├── LICENSE ├── README.md ├── Security.md ├── configs ├── action-localization │ └── ava_actions │ │ └── SPELL_default.yaml ├── action-segmentation │ └── 50salads │ │ └── SPELL_default.yaml ├── active-speaker-detection │ └── ava_active-speaker │ │ ├── SPELL_default.yaml │ │ └── SPELL_plus_default.yaml └── video-summarization │ ├── SumMe │ └── SPELL_default.yaml │ └── TVSum │ └── SPELL_default.yaml ├── data ├── annotations │ └── merge_ava_activespeaker.py ├── generate_spatial-temporal_graphs.py └── generate_temporal_graphs.py ├── docs ├── GETTING_STARTED_AL.md ├── GETTING_STARTED_AS.md ├── GETTING_STARTED_VS.md └── images │ └── gravit_teaser.jpg ├── gravit ├── __init__.py ├── datasets │ ├── __init__.py │ └── dataset_context_reasoning.py ├── models │ ├── __init__.py │ ├── build.py │ ├── context_reasoning │ │ ├── __init__.py │ │ └── spell.py │ └── losses.py └── utils │ ├── __init__.py │ ├── ava │ ├── README.md │ ├── __init__.py │ ├── label_map_util.py │ ├── metrics.py │ ├── np_box_list.py │ ├── np_box_list_ops.py │ ├── np_box_mask_list.py │ ├── np_box_mask_list_ops.py │ ├── np_box_ops.py │ ├── np_mask_ops.py │ ├── object_detection_evaluation.py │ ├── per_image_evaluation.py │ └── standard_fields.py │ ├── eval_tool.py │ ├── formatter.py │ ├── logger.py │ ├── parser.py │ └── vs │ ├── avg_splits.py │ ├── knapsack.py │ ├── run_vs_exp_summe.sh │ └── run_vs_exp_tvsum.sh ├── requirements.txt ├── setup.py └── tools ├── evaluate.py └── train_context_reasoning.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023, Intel Corporation 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GraVi-T 2 | This repository contains an open-source codebase for Graph-based long-term Video undersTanding (GraVi-T). It is designed to serve as a spatial-temporal graph learning framework for multiple video understanding tasks. 
The current version supports training and evaluating two state-of-the-art models: [SPELL](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136950367.pdf) for the tasks of active speaker detection, action localization, and action segmentation, and [VideoSAGE](https://arxiv.org/abs/2404.10539) for video summarization. 3 | 4 | In the near future, we will release more advanced graph-based approaches (e.g. [STHG](https://arxiv.org/abs/2306.10608)) for other tasks, including audio-visual diarization. 5 | 6 | ![](docs/images/gravit_teaser.jpg?raw=true) 7 | 8 | ## Ego4D Challenges and ActivityNet 9 | We want to note that our method has recently won many challenges, including the Ego4D challenges [@ECCV22](https://ego4d-data.org/workshops/eccv22/), [@CVPR23](https://ego4d-data.org/workshops/cvpr23/) and ActivityNet [@CVPR22](https://research.google.com/ava/challenge.html). We summarize ASD (active speaker detection) and AVD (audio-visual diarization) performance comparisons on the validation set of the Ego4D dataset: 10 | | ASD Model | ASD mAP(%)↑ | ASD mAP@0.5(%)↑ | AVD DER(%)↓ | 11 | |:------------|:-------------------:|:-----------------------:|:------------------:| 12 | | RegionCls | - | 24.6 | 80.0 | 13 | | TalkNet | - | 50.6 | 79.3 | 14 | | [SPELL](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136950367.pdf) (Ours) | 71.3 | 60.7 | 66.6 | 15 | | [STHG](https://arxiv.org/abs/2306.10608) (Ours) | **75.7** | **63.7** | **59.4** | 16 | 17 | :bulb:In this table, we report two metrics to evaluate ASD performance: mAP quantifies the ASD results by assuming that the face bounding box detections are the ground truth (i.e. assuming a perfect face detector), whereas mAP@0.5 quantifies the ASD results on the detected face bounding boxes (i.e. a face detection is considered positive only if the IoU between a detected face bounding box and the ground-truth exceeds 0.5). For AVD, we report DER (diarization error rate): a lower DER value indicates better AVD performance. For more information, please refer to our technical reports for the challenge. 18 | 19 | :bulb:We computed mAP@0.5 by using [Ego4D's official evaluation tool](https://github.com/EGO4D/audio-visual/tree/main/active-speaker-detection/active_speaker/active_speaker_evaluation). 20 | 21 | ## Use Cases and Performance 22 | ### Active Speaker Detection (Dataset: AVA-ActiveSpeaker v1.0) 23 | | Model | Feature | validation mAP (%) | 24 | |:---------------|:------------------:|:--------------------------:| 25 | | SPELL (Ours) | RESNET18-TSM-AUG | **94.2** (up from 88.0) | 26 | | SPELL (Ours) | RESNET50-TSM-AUG | **94.9** (up from 89.3) | 27 | > Numbers in parentheses indicate the mAP scores without using the suggested graph learning method. 28 | 29 | ### Action Localization (Dataset: AVA-Actions v2.2) 30 | | Model | Feature | validation mAP (%) | 31 | |:---------------|:----------------------:|:--------------------------:| 32 | | SPELL (Ours) | SLOWFAST-64x2-R101 | **36.8** (up from 29.4) | 33 | > Number in parentheses indicates the mAP score without using the suggested graph learning method.
34 | 35 | ### Action Segmentation (Dataset: 50Salads - split2) 36 | | Model | Feature | F1@0.1 (%) | Acc (%) | 37 | |:---------------|:------------:|:-------------------------:|:-------------------------:| 38 | | SPELL (Ours) | MSTCN++ | **84.7** (up from 83.4) | **85.0** (up from 84.6) | 39 | | SPELL (Ours) | ASFORMER | **89.8** (up from 86.1) | **88.2** (up from 87.8) | 40 | > Numbers in parentheses indicate the scores without using the suggested graph learning method. 41 | 42 | ### Video Summarization (Datasets: SumMe & TVSum) 43 | | Model | Feature | [Kendall's Tau](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.kendalltau.html#scipy.stats.kendalltau)* | [Spearman's Rho](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.spearmanr.html#scipy.stats.spearmanr)* | 44 | |:-------------------|:---------------------------------:|:-------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:| 45 | | VideoSAGE (Ours) | eccv16_dataset_summe_google_pool5 | **0.12** (up from 0.09) | **0.16** (up from 0.12) | 46 | | VideoSAGE (Ours) | eccv16_dataset_tvsum_google_pool5 | **0.30** (up from 0.27) | **0.42** (up from 0.39) | 47 | > Numbers in parentheses indicate the scores without using the suggested graph learning method.\ 48 | > *Correlation metric between predicted frame importance and ground truth. 49 | 50 | ## Requirements 51 | Preliminary requirements: 52 | - Python>=3.9 53 | - CUDA 12.1 54 | 55 | Run the following command if you have CUDA 12.1: 56 | ``` 57 | pip3 install -r requirements.txt 58 | ``` 59 | 60 | Alternatively, you can manually install PyYAML, pandas, and [PyG](https://www.pyg.org)>=2.0.3 with CUDA>=11.1 61 | 62 | ## Installation 63 | After confirming the above requirements, run the following commands: 64 | ``` 65 | git clone https://github.com/IntelLabs/GraVi-T.git 66 | cd GraVi-T 67 | pip3 install -e . 68 | ``` 69 | 70 | ## Getting Started (Active Speaker Detection) 71 | ### Annotations 72 | 1) Download the annotations of AVA-ActiveSpeaker from the official site: 73 | ``` 74 | DATA_DIR="data/annotations" 75 | 76 | wget https://research.google.com/ava/download/ava_activespeaker_val_v1.0.tar.bz2 -P ${DATA_DIR} 77 | tar -xf ${DATA_DIR}/ava_activespeaker_val_v1.0.tar.bz2 -C ${DATA_DIR} 78 | ``` 79 | 80 | 2) Preprocess the annotations: 81 | ``` 82 | python data/annotations/merge_ava_activespeaker.py 83 | ``` 84 | 85 | ### Features 86 | Download `RESNET18-TSM-AUG.zip` from the Google Drive link from [SPELL](https://github.com/SRA2/SPELL#code-usage) and unzip under `data/features`. 87 | > We use the features from the thirdparty repositories. 88 | 89 | ### Directory Structure 90 | The data directories should look as follows: 91 | ``` 92 | |-- data 93 | |-- annotations 94 | |-- ava_activespeaker_val_v1.0.csv 95 | |-- features 96 | |-- RESNET18-TSM-AUG 97 | |-- train 98 | |-- val 99 | ``` 100 | 101 | ### Experiments 102 | We can perform the experiments on active speaker detection with the default configuration by following the three steps below. 103 | 104 | #### Step 1: Graph Generation 105 | Run the following command to generate spatial-temporal graphs from the features: 106 | ``` 107 | python data/generate_spatial-temporal_graphs.py --features RESNET18-TSM-AUG --ec_mode csi --time_span 90 --tau 0.9 108 | ``` 109 | The generated graphs will be saved under `data/graphs`. 
Each graph captures long temporal context information in a video, which spans about 90 seconds (specified by `--time_span`). 110 | 111 | #### Step 2: Training 112 | Next, run the training script by passing the default configuration file: 113 | ``` 114 | python tools/train_context_reasoning.py --cfg configs/active-speaker-detection/ava_active-speaker/SPELL_default.yaml 115 | ``` 116 | The results and logs will be saved under `results`. 117 | 118 | #### Step 3: Evaluation 119 | Now, we can evaluate the trained model's performance: 120 | ``` 121 | python tools/evaluate.py --exp_name SPELL_ASD_default --eval_type AVA_ASD 122 | ``` 123 | This will print the evaluation score. 124 | 125 | ## Getting Started (Action Localization) 126 | Please refer to the instructions in [GETTING_STARTED_AL.md](docs/GETTING_STARTED_AL.md). 127 | 128 | ## Getting Started (Action Segmentation) 129 | Please refer to the instructions in [GETTING_STARTED_AS.md](docs/GETTING_STARTED_AS.md). 130 | 131 | ## Getting Started (Video Summarization) 132 | Please refer to the instructions in [GETTING_STARTED_VS.md](docs/GETTING_STARTED_VS.md). 133 | 134 | ## Contributor 135 | GraVi-T is written and maintained by [Kyle Min](https://github.com/kylemin) (from version 1.0.0 to 1.1.0) and [Jose Rojas Chaves](https://github.com/joserochh) (version 1.2.0). Please refer to the release notes for details about each version's supported features and applications. 136 | 137 | ## Citation 138 | ECCV 2022 paper about [SPELL](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136950367.pdf): 139 | ```bibtex 140 | @inproceedings{min2022learning, 141 | title={Learning Long-Term Spatial-Temporal Graphs for Active Speaker Detection}, 142 | author={Min, Kyle and Roy, Sourya and Tripathi, Subarna and Guha, Tanaya and Majumdar, Somdeb}, 143 | booktitle={European Conference on Computer Vision}, 144 | pages={371--387}, 145 | year={2022}, 146 | organization={Springer} 147 | } 148 | ``` 149 | 150 | Ego4D workshop paper [@ECCV22](https://ego4d-data.org/workshops/eccv22/): 151 | ```bibtex 152 | @article{min2022intel, 153 | title={Intel Labs at Ego4D Challenge 2022: A Better Baseline for Audio-Visual Diarization}, 154 | author={Min, Kyle}, 155 | journal={arXiv preprint arXiv:2210.07764}, 156 | year={2022} 157 | } 158 | ``` 159 | 160 | Ego4D workshop paper [@CVPR23](https://ego4d-data.org/workshops/cvpr23/) about [STHG](https://arxiv.org/abs/2306.10608): 161 | ```bibtex 162 | @article{min2023sthg, 163 | title={STHG: Spatial-Temporal Heterogeneous Graph Learning for Advanced Audio-Visual Diarization}, 164 | author={Min, Kyle}, 165 | journal={arXiv preprint arXiv:2306.10608}, 166 | year={2023} 167 | } 168 | ``` 169 | 170 | SG2RL workshop paper [@CVPR24](https://sites.google.com/view/sg2rl) about [VideoSAGE](https://arxiv.org/abs/2404.10539): 171 | ```bibtex 172 | @article{chaves2024videosage, 173 | title={VideoSAGE: Video Summarization with Graph Representation Learning}, 174 | author={Jose M. Rojas Chaves and Subarna Tripathi}, 175 | journal={arXiv preprint arXiv:2404.10539}, 176 | year={2024} 177 | } 178 | ``` 179 | 180 | ## Disclaimer 181 | 182 | > This “research quality code” is for Non-Commercial purposes and provided by Intel “As Is” without any express or implied warranty of any kind. Please see the dataset's applicable license for terms and conditions. Intel does not own the rights to this data set and does not confer any rights to it. 
Intel does not warrant or assume responsibility for the accuracy or completeness of any information, text, graphics, links or other items within the code. A thorough security review has not been performed on this code. Additionally, this repository may contain components that are out of date or contain known security vulnerabilities. 183 | 184 | > AVA-ActiveSpeaker, AVA-Actions, 50Salads, TVSum, SumMe: Please see the dataset's applicable license for terms and conditions. Intel does not own the rights to this data set and does not confer any rights to it. 185 | 186 | ## Datasets & Models Disclaimer 187 | 188 | > To the extent that any public datasets are referenced by Intel or accessed using tools or code on this site those datasets are provided by the third party indicated as the data source. Intel does not create the data, or datasets, and does not warrant their accuracy or quality. By accessing the public dataset(s), or using a model trained on those datasets, you agree to the terms associated with those datasets and that your use complies with the applicable license. 189 | 190 | > Intel expressly disclaims the accuracy, adequacy, or completeness of any public datasets, and is not liable for any errors, omissions, or defects in the data, or for any reliance on the data. Intel is not liable for any liability or damages relating to your use of public datasets. 191 | -------------------------------------------------------------------------------- /Security.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | Intel is committed to rapidly addressing security vulnerabilities affecting our customers and providing clear guidance on the solution, impact, severity and mitigation. 3 | 4 | ## Reporting a Vulnerability 5 | Please report any security vulnerabilities in this project [utilizing the guidelines here](https://www.intel.com/content/www/us/en/security-center/vulnerability-handling-guidelines.html). 
6 | 7 | -------------------------------------------------------------------------------- /configs/action-localization/ava_actions/SPELL_default.yaml: -------------------------------------------------------------------------------- 1 | exp_name: SPELL_AL_default 2 | model_name: SPELL 3 | graph_name: SLOWFAST-64x2-R101_cdi_90.0_3.0 4 | loss_name: bce_logit 5 | use_spf: True 6 | use_ref: False 7 | num_modality: 1 8 | channel1: 1024 9 | channel2: 512 10 | proj_dim: 64 11 | final_dim: 80 12 | num_att_heads: 0 13 | dropout: 0.2 14 | lr: 0.0005 15 | wd: 0.0001 16 | batch_size: 16 17 | sch_param: 5 18 | num_epoch: 20 19 | -------------------------------------------------------------------------------- /configs/action-segmentation/50salads/SPELL_default.yaml: -------------------------------------------------------------------------------- 1 | exp_name: SPELL_AS_default 2 | model_name: SPELL 3 | graph_name: ASFORMER_10_10 4 | loss_name: ce_ref 5 | use_spf: False 6 | use_ref: True 7 | w_ref: 5 8 | num_modality: 1 9 | channel1: 64 10 | channel2: 64 11 | final_dim: 19 12 | num_att_heads: 4 13 | dropout: 0.2 14 | lr: 0.0005 15 | wd: 0 16 | batch_size: 1 17 | sch_param: 5 18 | num_epoch: 50 19 | sample_rate: 2 20 | -------------------------------------------------------------------------------- /configs/active-speaker-detection/ava_active-speaker/SPELL_default.yaml: -------------------------------------------------------------------------------- 1 | exp_name: SPELL_ASD_default 2 | model_name: SPELL 3 | graph_name: RESNET18-TSM-AUG_csi_90.0_0.9 4 | loss_name: bce_logit 5 | use_spf: True 6 | use_ref: False 7 | num_modality: 2 8 | channel1: 64 9 | channel2: 16 10 | proj_dim: 64 11 | final_dim: 1 12 | num_att_heads: 0 13 | dropout: 0.2 14 | lr: 0.0005 15 | wd: 0 16 | batch_size: 16 17 | sch_param: 10 18 | num_epoch: 70 19 | -------------------------------------------------------------------------------- /configs/active-speaker-detection/ava_active-speaker/SPELL_plus_default.yaml: -------------------------------------------------------------------------------- 1 | exp_name: SPELL_plus_ASD_default 2 | model_name: SPELL 3 | graph_name: RESNET50-TSM-AUG_csi_90.0_0.9 4 | loss_name: bce_logit 5 | use_spf: True 6 | use_ref: False 7 | num_modality: 2 8 | channel1: 64 9 | channel2: 16 10 | proj_dim: 64 11 | final_dim: 1 12 | num_att_heads: 0 13 | dropout: 0.2 14 | lr: 0.0003 15 | wd: 0 16 | batch_size: 12 17 | sch_param: 10 18 | num_epoch: 50 19 | -------------------------------------------------------------------------------- /configs/video-summarization/SumMe/SPELL_default.yaml: -------------------------------------------------------------------------------- 1 | exp_name: SPELL_VS_SumMe_default 2 | model_name: SPELL 3 | graph_name: SumMe_10_0 4 | dataset: SumMe 5 | loss_name: bce_logit 6 | use_spf: False 7 | use_ref: False 8 | num_modality: 1 9 | channel1: 128 10 | channel2: 256 11 | final_dim: 1 12 | num_att_heads: 0 13 | dropout: 0.5 14 | lr: 0.001 15 | wd: 0.003 16 | batch_size: 1 17 | sch_param: 5 18 | num_epoch: 40 19 | sample_rate: 1 20 | -------------------------------------------------------------------------------- /configs/video-summarization/TVSum/SPELL_default.yaml: -------------------------------------------------------------------------------- 1 | exp_name: SPELL_VS_TVSum_default 2 | model_name: SPELL 3 | graph_name: TVSum_5_0 4 | dataset: TVSum 5 | loss_name: bce_logit 6 | use_spf: False 7 | use_ref: False 8 | num_modality: 1 9 | channel1: 256 10 | channel2: 128 11 | final_dim: 1 12 | num_att_heads: 
0 13 | dropout: 0.5 14 | lr: 0.02 15 | wd: 0.0001 16 | batch_size: 1 17 | sch_param: 5 18 | num_epoch: 50 19 | sample_rate: 1 20 | -------------------------------------------------------------------------------- /data/annotations/merge_ava_activespeaker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import glob 4 | 5 | 6 | def merge_csv_files(path_annts, sp): 7 | """ 8 | Merge multiple csv files into a single file 9 | """ 10 | 11 | csv_files = sorted(glob.glob(os.path.join(path_annts, '*.csv'))) 12 | data = [] 13 | for csv_file in csv_files: 14 | with open(csv_file) as f: 15 | reader = csv.reader(f) 16 | data.extend(list(reader)) 17 | 18 | with open(f'data/annotations/ava_activespeaker_{sp}_v1.0.csv', 'w') as f: 19 | writer = csv.writer(f, delimiter =',') 20 | writer.writerows(data) 21 | 22 | 23 | if __name__ == "__main__": 24 | path_annts = 'data/annotations/ava_activespeaker_test_v1.0' 25 | sp = 'val' 26 | 27 | merge_csv_files(path_annts, sp) 28 | -------------------------------------------------------------------------------- /data/generate_spatial-temporal_graphs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import pickle #nosec 5 | import argparse 6 | import numpy as np 7 | from functools import partial 8 | from multiprocessing import Pool 9 | from torch_geometric.data import Data 10 | 11 | 12 | def _get_time_windows(list_fts, time_span): 13 | """ 14 | Get the time windows from the list of frame_timestamps 15 | Each window is a subset of frame_timestamps where its time span is not greater than "time_span" 16 | 17 | e.g. 18 | input: 19 | list_fts: [902, 903, 904, 905, 910, 911, 912, 913, 914, 917] 20 | time_span: 3 21 | output: 22 | twd_all: [[902, 903, 904], [905], [910, 911, 912], [913, 914], [917]] 23 | """ 24 | 25 | twd_all = [] 26 | 27 | start = end = 0 28 | while end < len(list_fts): 29 | while end < len(list_fts) and list_fts[end] < list_fts[start] + time_span: 30 | end += 1 31 | 32 | twd_all.append(list_fts[start:end]) 33 | start = end 34 | 35 | return twd_all 36 | 37 | 38 | def generate_graph(data_file, args, path_graphs, sp): 39 | """ 40 | Generate graphs of a single video 41 | Time span of each graph is not greater than "time_span" 42 | """ 43 | 44 | video_id = os.path.splitext(os.path.basename(data_file))[0] 45 | with open(data_file, 'rb') as f: 46 | data = pickle.load(f) #nosec 47 | 48 | # Get a list of frame_timestamps 49 | list_fts = sorted([float(frame_timestamp) for frame_timestamp in data.keys()]) 50 | 51 | # Get the time windows where the time span of each window is not greater than "time_span" 52 | twd_all = _get_time_windows(list_fts, args.time_span) 53 | 54 | # Iterate over every time window 55 | num_graph = 0 56 | for twd in twd_all: 57 | # Skip the training graphs without any temporal edges 58 | if sp == 'train' and len(twd) == 1: 59 | continue 60 | 61 | # Get lists of the timestamps, features, coordinates, labels, person_ids, and global_ids for a given time window 62 | timestamp, feature, coord, label, person_id, global_id = [], [], [], [], [], [] 63 | for fts in twd: 64 | for entity in data[f'{fts:g}']: 65 | timestamp.append(fts) 66 | feature.append(entity['feature']) 67 | x1, y1, x2, y2 = [float(c) for c in entity['person_box'].split(',')] 68 | coord.append(np.array([(x1+x2)/2, (y1+y2)/2, x2-x1, y2-y1], dtype=np.float32)) 69 | label.append(entity['label']) 70 | person_id.append(entity['person_id']) 71 
| global_id.append(entity['global_id']) 72 | 73 | # Get a list of the edge information: these are for edge_index and edge_attr 74 | node_source = [] 75 | node_target = [] 76 | edge_attr = [] 77 | for i in range(len(timestamp)): 78 | for j in range(len(timestamp)): 79 | # Time difference between the i-th and j-th nodes 80 | time_diff = timestamp[i] - timestamp[j] 81 | 82 | # If the edge connection mode is csi, nodes having the same identity are connected across the frames 83 | # If the edge connection mode is cdi, temporally-distant nodes with different identities are also connected 84 | if args.ec_mode == 'csi': 85 | id_condition = person_id[i] == person_id[j] 86 | elif args.ec_mode == 'cdi': 87 | id_condition = True 88 | 89 | # The edge ij connects the i-th node and j-th node 90 | # Positive edge_attr indicates that the edge ij is backward (negative: forward) 91 | if time_diff == 0 or (abs(time_diff) <= args.tau and id_condition): 92 | node_source.append(i) 93 | node_target.append(j) 94 | edge_attr.append(np.sign(time_diff)) 95 | 96 | # x: features 97 | # c: coordinates of person_box 98 | # g: global_ids 99 | # edge_index: information on how the graph nodes are connected 100 | # edge_attr: information about whether the edge is spatial (0) or temporal (positive: backward, negative: forward) 101 | # y: labels 102 | graphs = Data(x = torch.tensor(np.array(feature, dtype=np.float32), dtype=torch.float32), 103 | c = torch.tensor(np.array(coord, dtype=np.float32), dtype=torch.float32), 104 | g = torch.tensor(global_id, dtype=torch.long), 105 | edge_index = torch.tensor(np.array([node_source, node_target], dtype=np.int64), dtype=torch.long), 106 | edge_attr = torch.tensor(edge_attr, dtype=torch.float32), 107 | y = torch.tensor(np.array(label, dtype=np.float32), dtype=torch.float32)) 108 | 109 | num_graph += 1 110 | torch.save(graphs, os.path.join(path_graphs, f'{video_id}_{num_graph:04d}.pt')) 111 | 112 | return num_graph 113 | 114 | 115 | if __name__ == "__main__": 116 | """ 117 | Generate spatial-temporal graphs from the extracted features 118 | """ 119 | 120 | parser = argparse.ArgumentParser() 121 | # Default paths for the training process 122 | parser.add_argument('--root_data', type=str, help='Root directory to the data', default='./data') 123 | parser.add_argument('--features', type=str, help='Name of the features', required=True) 124 | 125 | # Two options for the edge connection mode: 126 | # csi: Connect the nodes only with the same identities across the frames 127 | # cdi: Connect different identities across the frames 128 | parser.add_argument('--ec_mode', type=str, help='Edge connection mode (csi | cdi)', required=True) 129 | parser.add_argument('--time_span', type=float, help='Maximum time span for each graph in seconds', required=True) 130 | parser.add_argument('--tau', type=float, help='Maximum time difference between neighboring nodes in seconds', required=True) 131 | 132 | args = parser.parse_args() 133 | 134 | # Iterate over train/val splits 135 | print ('This process might take a few minutes') 136 | for sp in ['train', 'val']: 137 | path_graphs = os.path.join(args.root_data, f'graphs/{args.features}_{args.ec_mode}_{args.time_span}_{args.tau}/{sp}') 138 | os.makedirs(path_graphs, exist_ok=True) 139 | 140 | list_data_files = sorted(glob.glob(os.path.join(args.root_data, f'features/{args.features}/{sp}/*.pkl'))) 141 | 142 | with Pool(processes=20) as pool: 143 | num_graph = pool.map(partial(generate_graph, args=args, path_graphs=path_graphs, sp=sp), list_data_files) 144 | 145 | 
print (f'Graph generation for {sp} is finished (number of graphs: {sum(num_graph)})') 146 | -------------------------------------------------------------------------------- /data/generate_temporal_graphs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import h5py 5 | import argparse 6 | import numpy as np 7 | from functools import partial 8 | from multiprocessing import Pool 9 | from torch_geometric.data import Data 10 | from random import randint, seed 11 | 12 | 13 | def get_edge_info(num_frame, args): 14 | skip = args.skip_factor 15 | 16 | # Get a list of the edge information: these are for edge_index and edge_attr 17 | node_source = [] 18 | node_target = [] 19 | edge_attr = [] 20 | for i in range(num_frame): 21 | for j in range(num_frame): 22 | # Frame difference between the i-th and j-th nodes 23 | frame_diff = i - j 24 | 25 | # The edge ij connects the i-th node and j-th node 26 | # Positive edge_attr indicates that the edge ij is backward (negative: forward) 27 | if abs(frame_diff) <= args.tauf: 28 | node_source.append(i) 29 | node_target.append(j) 30 | edge_attr.append(np.sign(frame_diff)) 31 | 32 | # Make additional connections between non-adjacent nodes 33 | # This can help reduce over-segmentation of predictions in some cases 34 | elif skip: 35 | if (frame_diff % skip == 0) and (abs(frame_diff) <= skip * args.tauf): 36 | node_source.append(i) 37 | node_target.append(j) 38 | edge_attr.append(np.sign(frame_diff)) 39 | 40 | return node_source, node_target, edge_attr 41 | 42 | 43 | def generate_sum_temporal_graph(video_data, args, path_graphs): 44 | """ 45 | Generate temporal graphs of a single video from video summarization data set 46 | """ 47 | video_id = video_data["global_id"] 48 | out_folder = video_data["purpose"] 49 | features = video_data["features"] 50 | gtscore = np.array([video_data["gtscore"]], dtype=np.float32)[::args.sample_rate] 51 | gtscore = gtscore.transpose() 52 | num_samples = features.shape[0] 53 | 54 | # Get a list of the edge information: these are for edge_index and edge_attr 55 | node_source, node_target, edge_attr = get_edge_info(num_samples, args) 56 | 57 | # x: features 58 | # g: global_id 59 | # edge_index: information on how the graph nodes are connected 60 | # edge_attr: information about whether the edge is spatial (0) or temporal (positive: backward, negative: forward) 61 | # y: gtscore 62 | graphs = Data(x=torch.tensor(np.array(features, dtype=np.float32), dtype=torch.float32), 63 | g=video_id, 64 | edge_index=torch.tensor(np.array([node_source, node_target], dtype=np.int64), dtype=torch.long), 65 | edge_attr=torch.tensor(edge_attr, dtype=torch.float32), 66 | y=torch.tensor(gtscore, dtype=torch.float32)) 67 | 68 | torch.save(graphs, os.path.join(path_graphs, f'{out_folder}/{video_id}.pt')) 69 | 70 | 71 | def generate_temporal_graph(data_file, args, path_graphs, actions, train_ids, all_ids): 72 | """ 73 | Generate temporal graphs of a single video 74 | """ 75 | 76 | video_id = os.path.splitext(os.path.basename(data_file))[0] 77 | feature = np.transpose(np.load(data_file)) 78 | num_frame = feature.shape[0] 79 | 80 | # Get a list of ground-truth action labels 81 | with open(os.path.join(args.root_data, f'annotations/{args.dataset}/groundTruth/{video_id}.txt')) as f: 82 | label = [actions[line.strip()] for line in f] 83 | 84 | # Get a list of the edge information: these are for edge_index and edge_attr 85 | node_source, node_target, edge_attr = 
get_edge_info(num_frame, args) 86 | 87 | # x: features 88 | # g: global_id 89 | # edge_index: information on how the graph nodes are connected 90 | # edge_attr: information about whether the edge is spatial (0) or temporal (positive: backward, negative: forward) 91 | # y: labels 92 | graphs = Data(x = torch.tensor(np.array(feature, dtype=np.float32), dtype=torch.float32), 93 | g = all_ids.index(video_id), 94 | edge_index = torch.tensor(np.array([node_source, node_target], dtype=np.int64), dtype=torch.long), 95 | edge_attr = torch.tensor(edge_attr, dtype=torch.float32), 96 | y = torch.tensor(np.array(label, dtype=np.int64)[::args.sample_rate], dtype=torch.long)) 97 | 98 | if video_id in train_ids: 99 | torch.save(graphs, os.path.join(path_graphs, 'train', f'{video_id}.pt')) 100 | else: 101 | torch.save(graphs, os.path.join(path_graphs, 'val', f'{video_id}.pt')) 102 | 103 | 104 | if __name__ == "__main__": 105 | """ 106 | Generate temporal graphs from the extracted features 107 | """ 108 | 109 | parser = argparse.ArgumentParser() 110 | # Default paths for the training process 111 | parser.add_argument('--root_data', type=str, help='Root directory to the data', default='./data') 112 | parser.add_argument('--dataset', type=str, help='Name of the dataset', default='50salads') 113 | parser.add_argument('--features', type=str, help='Name of the features', required=True) 114 | 115 | # Hyperparameters for the graph generation 116 | parser.add_argument('--tauf', type=int, help='Maximum frame difference between neighboring nodes', required=True) 117 | parser.add_argument('--skip_factor', type=int, help='Make additional connections between non-adjacent nodes', default=10) 118 | parser.add_argument('--sample_rate', type=int, help='Downsampling rate for the input', default=2) 119 | 120 | args = parser.parse_args() 121 | 122 | print ('This process might take a few minutes') 123 | 124 | actions = {} 125 | all_ids = [] 126 | if args.dataset == "50salads": 127 | 128 | # Build a mapping from action classes to action ids 129 | with open(os.path.join(args.root_data, f'annotations/{args.dataset}/mapping.txt')) as f: 130 | for line in f: 131 | aid, cls = line.strip().split(' ') 132 | actions[cls] = int(aid) 133 | 134 | # Get a list of all video ids 135 | all_ids = sorted([os.path.splitext(v)[0] for v in 136 | os.listdir(os.path.join(args.root_data, f'annotations/{args.dataset}/groundTruth'))]) 137 | 138 | # Iterate over different splits 139 | list_splits = sorted(os.listdir(os.path.join(args.root_data, f'features/{args.features}'))) 140 | for split in list_splits: 141 | # Get a list of training video ids 142 | with open(os.path.join(args.root_data, f'annotations/{args.dataset}/splits/train.{split}.bundle')) as f: 143 | train_ids = [os.path.splitext(line.strip())[0] for line in f] 144 | 145 | path_graphs = os.path.join(args.root_data, f'graphs/{args.features}_{args.tauf}_{args.skip_factor}/{split}') 146 | os.makedirs(os.path.join(path_graphs, 'train'), exist_ok=True) 147 | os.makedirs(os.path.join(path_graphs, 'val'), exist_ok=True) 148 | 149 | list_data_files = sorted(glob.glob(os.path.join(args.root_data, f'features/{args.features}/{split}/*.npy'))) 150 | 151 | with Pool(processes=20) as pool: 152 | pool.map(partial(generate_temporal_graph, args=args, path_graphs=path_graphs, actions=actions, train_ids=train_ids, all_ids=all_ids), list_data_files) 153 | 154 | print (f'Graph generation for {split} is finished') 155 | 156 | elif args.dataset == "SumMe" or args.dataset == "TVSum": 157 | 158 | path_dataset = 
os.path.join(args.root_data, 159 | f'annotations/{args.dataset}/{args.features}.h5') 160 | 161 | with h5py.File(path_dataset, 'r') as hdf: 162 | all_videos = list(hdf.keys()) 163 | 164 | all_ids = [] 165 | dataset = [None] * len(all_videos) 166 | for video in all_videos: 167 | id = int(video.split("_")[1]) 168 | all_ids.append(id - 1) 169 | 170 | data = {} 171 | data["global_id"] = id 172 | data["purpose"] = "train" 173 | data["features"] = np.array(hdf.get(video + '/features')) 174 | data["gtscore"] = np.array(hdf.get(video + '/gtscore')) 175 | dataset[id - 1] = data 176 | 177 | # Set the seed value 178 | seed(42) 179 | 180 | amount_to_select = len(all_ids) // 5 181 | for split_i in range(1, 6): 182 | # Init 183 | for i in all_ids: 184 | dataset[i]['purpose'] = "train" 185 | 186 | # Randomly select 20% of videos for validation 187 | train_ids = all_ids.copy() 188 | for _ in range(amount_to_select): 189 | train_idx = randint(0, len(train_ids) - 1) 190 | video_idx = train_ids.pop(train_idx) 191 | dataset[video_idx]["purpose"] = "val" 192 | 193 | path_graphs = os.path.join(args.root_data, 194 | f'graphs/{args.dataset}_{args.tauf}_{args.skip_factor}/split{split_i}') 195 | os.makedirs(os.path.join(path_graphs, 'train'), exist_ok=True) 196 | os.makedirs(os.path.join(path_graphs, 'val'), exist_ok=True) 197 | 198 | with Pool(processes=20) as pool: 199 | pool.map(partial(generate_sum_temporal_graph, args=args, path_graphs=path_graphs), dataset) 200 | 201 | print(f'Graph generation for split{split_i} is finished') -------------------------------------------------------------------------------- /docs/GETTING_STARTED_AL.md: -------------------------------------------------------------------------------- 1 | ## Getting Started (Action Localization) 2 | ### Annotations 3 | Download the annotations of AVA-Actions from the official site: 4 | ``` 5 | DATA_DIR="data/annotations" 6 | 7 | wget https://research.google.com/ava/download/ava_val_v2.2.csv.zip -P ${DATA_DIR} 8 | wget https://research.google.com/ava/download/ava_action_list_v2.2_for_activitynet_2019.pbtxt -P ${DATA_DIR} 9 | 10 | unzip ${DATA_DIR}/ava_val_v2.2.csv.zip -d ${DATA_DIR} 11 | mv ${DATA_DIR}/research/action_recognition/ava/website/www/download/ava_val_v2.2.csv ${DATA_DIR} 12 | ``` 13 | 14 | ### Features 15 | Download `SLOWFAST-64x2-R101.zip` from the Google Drive link from [SPELL](https://github.com/SRA2/SPELL#code-usage) and unzip under `data/features`. 16 | > We use the features from the thirdparty repositories. SLOWFAST-64x2-R101 is obtained by using the official code of [SlowFast](https://github.com/facebookresearch/SlowFast) with the pretrained checkpoint ([SLOWFAST_64x2_R101_50_50.pkl](https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/ava/SLOWFAST_64x2_R101_50_50.pkl)) in [SlowFast Model Zoo](https://github.com/facebookresearch/SlowFast/blob/main/MODEL_ZOO.md). 17 | 18 | ### Directory Structure 19 | The data directories should look as follows: 20 | ``` 21 | |-- data 22 | |-- annotations 23 | |-- ava_val_v2.2.csv 24 | |-- ava_action_list_v2.2_for_activitynet_2019.pbtxt 25 | |-- features 26 | |-- SLOWFAST-64x2-R101 27 | |-- train 28 | |-- val 29 | ``` 30 | 31 | ### Experiments 32 | We can perform the experiments on action localization with the default configuration by following the three steps below. 
33 | 34 | #### Step 1: Graph Generation 35 | Run the following command to generate spatial-temporal graphs from the features: 36 | ``` 37 | python data/generate_spatial-temporal_graphs.py --features SLOWFAST-64x2-R101 --ec_mode cdi --time_span 90 --tau 3 38 | ``` 39 | The generated graphs will be saved under `data/graphs`. Each graph captures long temporal context information in a video, which spans about 90 seconds (specified by `--time_span`). 40 | 41 | #### Step 2: Training 42 | Next, run the training script by passing the default configuration file: 43 | ``` 44 | python tools/train_context_reasoning.py --cfg configs/action-localization/ava_actions/SPELL_default.yaml 45 | ``` 46 | The results and logs will be saved under `results`. 47 | 48 | #### Step 3: Evaluation 49 | Now, we can evaluate the trained model's performance: 50 | ``` 51 | python tools/evaluate.py --exp_name SPELL_AL_default --eval_type AVA_AL 52 | ``` 53 | This will print the evaluation score. 54 | -------------------------------------------------------------------------------- /docs/GETTING_STARTED_AS.md: -------------------------------------------------------------------------------- 1 | ## Getting Started (Action Segmentation) 2 | ### Annotations 3 | We suggest using the same set of annotations used by [MS-TCN++](https://github.com/sj-li/MS-TCN2) and [ASFormer](https://github.com/ChinaYi/ASFormer). Download the 50Salads dataset from the links provided by either of the two repositories. 4 | 5 | ### Features 6 | We suggest extracting the features using [ASFormer](https://github.com/ChinaYi/ASFormer). Please use their repository and the pre-trained model checkpoints ([link](https://github.com/ChinaYi/ASFormer/tree/main#reproduce-our-results)) to extract the frame-wise features for each split of the dataset. Please extract the features from each of the four refinement layers and concatenate them. To be more specific, you can concatenate the 64-dimensional features from this [line](https://github.com/ChinaYi/ASFormer/blob/main/model.py#L315), which will give you 256-dimensional (frame-wise) features. Similarly, you can also extract MS-TCN++ features from this [line](https://github.com/sj-li/MS-TCN2/blob/master/model.py#L23). 7 | > We use the features from the thirdparty repositories. 8 | 9 | ### Directory Structure 10 | The data directories should look as follows: 11 | ``` 12 | |-- data 13 | |-- annotations 14 | |-- 50salads 15 | |-- groundTruth 16 | |-- splits 17 | |-- mapping.txt 18 | |-- features 19 | |-- ASFORMER 20 | |-- split1 21 | |-- split2 22 | |-- split3 23 | |-- split4 24 | |-- split5 25 | ``` 26 | 27 | ### Experiments 28 | We can perform the experiments on action segmentation with the default configuration by following the three steps below. 29 | 30 | #### Step 1: Graph Generation 31 | Run the following command to generate temporal graphs from the features: 32 | ``` 33 | python data/generate_temporal_graphs.py --features ASFORMER --tauf 10 34 | ``` 35 | The generated graphs will be saved under `data/graphs`. Each graph captures long temporal context information in a video. 36 | 37 | #### Step 2: Training 38 | Next, run the training script by passing the default configuration file. You also need to specify which split to perform the experiments on: 39 | ``` 40 | python tools/train_context_reasoning.py --cfg configs/action-segmentation/50salads/SPELL_default.yaml --split 2 41 | ``` 42 | The results and logs will be saved under `results`. 
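Before running the evaluation, it can be helpful to sanity-check what Step 1 produced. The snippet below is a minimal sketch (it is not part of the GraVi-T tools) that loads one of the generated temporal graphs and prints its contents; it assumes the default `--skip_factor 10` from Step 1, so the graphs land under `data/graphs/ASFORMER_10_10`, and that the feature folders are named `split1` to `split5` as in the directory structure above.
```
import glob
import torch

# Load an arbitrary training graph written by Step 1 (default output location)
path = sorted(glob.glob('data/graphs/ASFORMER_10_10/split2/train/*.pt'))[0]
graph = torch.load(path)

print(path)
print('x (frame-wise features):', tuple(graph.x.shape))
print('y (action labels):', tuple(graph.y.shape))
print('edge_index:', tuple(graph.edge_index.shape))
print('edge_attr values:', graph.edge_attr.unique().tolist())  # sign of the frame difference between connected nodes
```
The attribute names follow `data/generate_temporal_graphs.py`: `x` holds the frame-wise features, `y` the action ids from `mapping.txt`, and `edge_attr` marks whether an edge is forward (negative), within the same frame (0), or backward (positive) in time.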
43 | 44 | #### Step 3: Evaluation 45 | Now, we can evaluate the trained model's performance. You also need to specify which split to evaluate the experiments on: 46 | ``` 47 | python tools/evaluate.py --dataset 50salads --exp_name SPELL_AS_default --eval_type AS --split 2 48 | ``` 49 | This will print the evaluation scores. 50 | -------------------------------------------------------------------------------- /docs/GETTING_STARTED_VS.md: -------------------------------------------------------------------------------- 1 | ## Getting Started (Video Summarization) 2 | ### Datasets with annotations and features 3 | We suggest using the same set of datasets used by [PGL-SUM](https://github.com/e-apostolidis/PGL-SUM) or [A2Summ](https://github.com/boheumd/A2Summ). Download the TVSum & SumMe datasets from the links provided by either of the two repositories. 4 | 5 | ### Directory Structure 6 | The data directories should look as follows: 7 | ``` 8 | |-- data 9 | |-- annotations 10 | |-- SumMe 11 | |-- eccv16_dataset_summe_google_pool5.h5 12 | |-- TVSum 13 | |-- eccv16_dataset_tvsum_google_pool5.h5 14 | ``` 15 | 16 | ### Experiments 17 | We can perform the experiments on video summarization with the default configuration by following the three steps below. 18 | 19 | #### Step 1: Graph Generation 20 | Run the following command to generate temporal graphs from the features: 21 | 22 | On SumMe: 23 | ``` 24 | python data/generate_temporal_graphs.py --dataset SumMe --features eccv16_dataset_summe_google_pool5 --tauf 10 --skip_factor 0 25 | ``` 26 | On TVSum: 27 | ``` 28 | python data/generate_temporal_graphs.py --dataset TVSum --features eccv16_dataset_tvsum_google_pool5 --tauf 5 --skip_factor 0 29 | ``` 30 | The generated graphs will be saved under `data/graphs`. Each graph captures long temporal context information in a video. 31 | 32 | #### Step 2: Training 33 | Next, run the training script by passing the default configuration file. You also need to specify which split to perform the experiments on: 34 | 35 | On SumMe: 36 | ``` 37 | python tools/train_context_reasoning.py --cfg configs/video-summarization/SumMe/SPELL_default.yaml --split 4 38 | ``` 39 | On TVSum: 40 | ``` 41 | python tools/train_context_reasoning.py --cfg configs/video-summarization/TVSum/SPELL_default.yaml --split 4 42 | ``` 43 | The results and logs will be saved under `results`. 44 | 45 | #### Step 3: Evaluation 46 | Now, we can evaluate the trained model's performance. You also need to specify which split to perform the evaluation on: 47 | 48 | On SumMe: 49 | ``` 50 | python tools/evaluate.py --exp_name SPELL_VS_SumMe_default --eval_type VS_max --split 4 51 | ``` 52 | On TVSum: 53 | ``` 54 | python tools/evaluate.py --exp_name SPELL_VS_TVSum_default --eval_type VS_avg --split 4 55 | ``` 56 | 57 | This will print the evaluation scores. 58 | 59 | #### Step 3: Evaluation Alternative 60 | You can also get average results from all splits: 61 | 62 | On SumMe: 63 | ``` 64 | python tools/evaluate.py --exp_name SPELL_VS_SumMe_default --eval_type VS_max --all_splits 65 | ``` 66 | On TVSum: 67 | ``` 68 | python tools/evaluate.py --exp_name SPELL_VS_TVSum_default --eval_type VS_avg --all_splits 69 | ``` 70 | #### Note: 71 | 72 | You can use bash scripts from `gravit/utils/vs/` to train models on all the splits and get evaluation metrics for TVSum and SumMe; a rough sketch of that loop is shown below.
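The two bash scripts (`run_vs_exp_summe.sh` and `run_vs_exp_tvsum.sh`) are the intended way to automate this. As an illustration only — this is a sketch of the overall loop, not the contents of those scripts — the equivalent procedure for SumMe trains one model per split and then reports the averaged metrics with `--all_splits`, using only the commands documented above:
```
import subprocess

# Train one model per split (1-5) using the default SumMe configuration
for split in range(1, 6):
    subprocess.run(['python', 'tools/train_context_reasoning.py',
                    '--cfg', 'configs/video-summarization/SumMe/SPELL_default.yaml',
                    '--split', str(split)], check=True)

# Average the evaluation metrics over all splits (see "Step 3: Evaluation Alternative")
subprocess.run(['python', 'tools/evaluate.py',
                '--exp_name', 'SPELL_VS_SumMe_default',
                '--eval_type', 'VS_max',
                '--all_splits'], check=True)
```
The TVSum loop is analogous, using `configs/video-summarization/TVSum/SPELL_default.yaml`, `SPELL_VS_TVSum_default`, and `--eval_type VS_avg`.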
-------------------------------------------------------------------------------- /docs/images/gravit_teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/GraVi-T/fb343d43d575cc91f7deaa47f2321d99bac1aad5/docs/images/gravit_teaser.jpg -------------------------------------------------------------------------------- /gravit/__init__.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import get_distribution 2 | 3 | try: 4 | __version__ = get_distribution('gravit').version 5 | except: 6 | __version__ = '1.1.0' 7 | -------------------------------------------------------------------------------- /gravit/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_context_reasoning import GraphDataset 2 | -------------------------------------------------------------------------------- /gravit/datasets/dataset_context_reasoning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | from torch_geometric.data import Data, Dataset 5 | 6 | 7 | class GraphDataset(Dataset): 8 | """ 9 | General class for graph dataset 10 | """ 11 | 12 | def __init__(self, path_graphs): 13 | super(GraphDataset, self).__init__() 14 | self.all_graphs = sorted(glob.glob(os.path.join(path_graphs, '*.pt'))) 15 | 16 | def len(self): 17 | return len(self.all_graphs) 18 | 19 | def get(self, idx): 20 | data = torch.load(self.all_graphs[idx]) 21 | return data 22 | -------------------------------------------------------------------------------- /gravit/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model 2 | from .losses import get_loss_func 3 | -------------------------------------------------------------------------------- /gravit/models/build.py: -------------------------------------------------------------------------------- 1 | from .context_reasoning import * 2 | 3 | 4 | def build_model(cfg, device): 5 | """ 6 | Build the model corresponding to "model_name" 7 | """ 8 | 9 | model_name = cfg['model_name'] 10 | model = globals()[model_name](cfg) 11 | 12 | return model.to(device) 13 | -------------------------------------------------------------------------------- /gravit/models/context_reasoning/__init__.py: -------------------------------------------------------------------------------- 1 | from .spell import SPELL 2 | -------------------------------------------------------------------------------- /gravit/models/context_reasoning/spell.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module, ModuleList, Conv1d, Sequential, ReLU, Dropout 3 | from torch_geometric.nn import Linear, EdgeConv, GATv2Conv, SAGEConv, BatchNorm 4 | 5 | 6 | class DilatedResidualLayer(Module): 7 | def __init__(self, dilation, in_channels, out_channels): 8 | super(DilatedResidualLayer, self).__init__() 9 | self.conv_dilated = Conv1d(in_channels, out_channels, kernel_size=3, padding=dilation, dilation=dilation) 10 | self.conv_1x1 = Conv1d(out_channels, out_channels, kernel_size=1) 11 | self.relu = ReLU() 12 | self.dropout = Dropout() 13 | 14 | def forward(self, x): 15 | out = self.relu(self.conv_dilated(x)) 16 | out = self.conv_1x1(out) 17 | out = self.dropout(out) 18 | return x + out 19 | 20 | 21 | # This is for the 
iterative refinement (we refer to MSTCN++: https://github.com/sj-li/MS-TCN2) 22 | class Refinement(Module): 23 | def __init__(self, final_dim, num_layers=10, interm_dim=64): 24 | super(Refinement, self).__init__() 25 | self.conv_1x1 = Conv1d(final_dim, interm_dim, kernel_size=1) 26 | self.layers = ModuleList([DilatedResidualLayer(2**i, interm_dim, interm_dim) for i in range(num_layers)]) 27 | self.conv_out = Conv1d(interm_dim, final_dim, kernel_size=1) 28 | 29 | def forward(self, x): 30 | f = self.conv_1x1(x) 31 | for layer in self.layers: 32 | f = layer(f) 33 | out = self.conv_out(f) 34 | return out 35 | 36 | 37 | class SPELL(Module): 38 | def __init__(self, cfg): 39 | super(SPELL, self).__init__() 40 | self.use_spf = cfg['use_spf'] # whether to use the spatial features 41 | self.use_ref = cfg['use_ref'] 42 | self.num_modality = cfg['num_modality'] 43 | channels = [cfg['channel1'], cfg['channel2']] 44 | final_dim = cfg['final_dim'] 45 | num_att_heads = cfg['num_att_heads'] 46 | dropout = cfg['dropout'] 47 | 48 | if self.use_spf: 49 | self.layer_spf = Linear(-1, cfg['proj_dim']) # projection layer for spatial features 50 | 51 | self.layer011 = Linear(-1, channels[0]) 52 | if self.num_modality == 2: 53 | self.layer012 = Linear(-1, channels[0]) 54 | 55 | self.batch01 = BatchNorm(channels[0]) 56 | self.relu = ReLU() 57 | self.dropout = Dropout(dropout) 58 | 59 | self.layer11 = EdgeConv(Sequential(Linear(2*channels[0], channels[0]), ReLU(), Linear(channels[0], channels[0]))) 60 | self.batch11 = BatchNorm(channels[0]) 61 | self.layer12 = EdgeConv(Sequential(Linear(2*channels[0], channels[0]), ReLU(), Linear(channels[0], channels[0]))) 62 | self.batch12 = BatchNorm(channels[0]) 63 | self.layer13 = EdgeConv(Sequential(Linear(2*channels[0], channels[0]), ReLU(), Linear(channels[0], channels[0]))) 64 | self.batch13 = BatchNorm(channels[0]) 65 | 66 | if num_att_heads > 0: 67 | self.layer21 = GATv2Conv(channels[0], channels[1], heads=num_att_heads) 68 | else: 69 | self.layer21 = SAGEConv(channels[0], channels[1]) 70 | num_att_heads = 1 71 | self.batch21 = BatchNorm(channels[1]*num_att_heads) 72 | 73 | self.layer31 = SAGEConv(channels[1]*num_att_heads, final_dim) 74 | self.layer32 = SAGEConv(channels[1]*num_att_heads, final_dim) 75 | self.layer33 = SAGEConv(channels[1]*num_att_heads, final_dim) 76 | 77 | if self.use_ref: 78 | self.layer_ref1 = Refinement(final_dim) 79 | self.layer_ref2 = Refinement(final_dim) 80 | self.layer_ref3 = Refinement(final_dim) 81 | 82 | 83 | def forward(self, x, edge_index, edge_attr, c=None): 84 | feature_dim = x.shape[1] 85 | 86 | if self.use_spf: 87 | x_visual = self.layer011(torch.cat((x[:, :feature_dim//self.num_modality], self.layer_spf(c)), dim=1)) 88 | else: 89 | x_visual = self.layer011(x[:, :feature_dim//self.num_modality]) 90 | 91 | if self.num_modality == 1: 92 | x = x_visual 93 | elif self.num_modality == 2: 94 | x_audio = self.layer012(x[:, feature_dim//self.num_modality:]) 95 | x = x_visual + x_audio 96 | 97 | x = self.batch01(x) 98 | x = self.relu(x) 99 | 100 | edge_index_f = edge_index[:, edge_attr<=0] 101 | edge_index_b = edge_index[:, edge_attr>=0] 102 | 103 | # Forward-graph stream 104 | x1 = self.layer11(x, edge_index_f) 105 | x1 = self.batch11(x1) 106 | x1 = self.relu(x1) 107 | x1 = self.dropout(x1) 108 | x1 = self.layer21(x1, edge_index_f) 109 | x1 = self.batch21(x1) 110 | x1 = self.relu(x1) 111 | x1 = self.dropout(x1) 112 | 113 | # Backward-graph stream 114 | x2 = self.layer12(x, edge_index_b) 115 | x2 = self.batch12(x2) 116 | x2 = self.relu(x2) 117 | 
x2 = self.dropout(x2)
118 |         x2 = self.layer21(x2, edge_index_b)
119 |         x2 = self.batch21(x2)
120 |         x2 = self.relu(x2)
121 |         x2 = self.dropout(x2)
122 | 
123 |         # Undirected-graph stream
124 |         x3 = self.layer13(x, edge_index)
125 |         x3 = self.batch13(x3)
126 |         x3 = self.relu(x3)
127 |         x3 = self.dropout(x3)
128 |         x3 = self.layer21(x3, edge_index)
129 |         x3 = self.batch21(x3)
130 |         x3 = self.relu(x3)
131 |         x3 = self.dropout(x3)
132 | 
133 |         x1 = self.layer31(x1, edge_index_f)
134 |         x2 = self.layer32(x2, edge_index_b)
135 |         x3 = self.layer33(x3, edge_index)
136 | 
137 |         out = x1+x2+x3
138 | 
139 |         if self.use_ref:
140 |             xr0 = torch.permute(out, (1, 0)).unsqueeze(0)
141 |             xr1 = self.layer_ref1(torch.softmax(xr0, dim=1))
142 |             xr2 = self.layer_ref2(torch.softmax(xr1, dim=1))
143 |             xr3 = self.layer_ref3(torch.softmax(xr2, dim=1))
144 |             out = torch.stack((xr0, xr1, xr2, xr3), dim=0).squeeze(1).transpose(2, 1).contiguous()
145 | 
146 |         return out
147 | 
--------------------------------------------------------------------------------
/gravit/models/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | from torch.nn import Module, CrossEntropyLoss, BCEWithLogitsLoss, MSELoss
4 | 
5 | 
6 | class CEWithREF(Module):
7 |     def __init__(self, w_ref, mode='train'):
8 |         super(CEWithREF, self).__init__()
9 |         self.ce = CrossEntropyLoss()
10 |         self.mse = MSELoss(reduction='none')
11 |         self.w_ref = w_ref
12 |         self.mode = mode
13 | 
14 |     def forward(self, input: Tensor, target: Tensor) -> Tensor:
15 |         if self.mode == 'train':
16 |             loss = 0
17 |             for pred in input:
18 |                 loss += self.ce(pred, target)
19 |                 loss += self.w_ref * self.mse(torch.log_softmax(pred[1:, :], dim=1), torch.log_softmax(pred.detach()[:-1, :], dim=1)).clamp(0, 16).mean()
20 |         else:
21 |             pred = input[-1]
22 |             loss = self.ce(pred, target) + self.w_ref * self.mse(torch.log_softmax(pred[1:, :], dim=1), torch.log_softmax(pred.detach()[:-1, :], dim=1)).clamp(0, 16).mean()
23 | 
24 |         return loss
25 | 
26 | 
27 | _LOSSES = {
28 |     'ce': CrossEntropyLoss,
29 |     'bce_logit': BCEWithLogitsLoss,
30 |     'mse': MSELoss,
31 |     'ce_ref': CEWithREF
32 | }
33 | 
34 | 
35 | def get_loss_func(cfg, mode='train'):
36 |     """
37 |     Get the loss function corresponding to "loss_name"
38 |     """
39 | 
40 |     loss_name = cfg['loss_name']
41 |     if loss_name not in _LOSSES:
42 |         raise ValueError(f'Loss {loss_name} is not implemented in models/losses.py')
43 | 
44 |     if cfg['use_ref']:
45 |         return _LOSSES[loss_name](cfg['w_ref'], mode)
46 | 
47 |     return _LOSSES[loss_name]()
48 | 
--------------------------------------------------------------------------------
/gravit/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntelLabs/GraVi-T/fb343d43d575cc91f7deaa47f2321d99bac1aad5/gravit/utils/__init__.py
--------------------------------------------------------------------------------
/gravit/utils/ava/README.md:
--------------------------------------------------------------------------------
1 | The code under this folder is taken from the official [ActivityNet repository](https://github.com/activitynet/ActivityNet/tree/master/Evaluation/ava) with minimal modification. We do not own the rights to this code.
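The SPELL network and the refinement-aware CEWithREF loss defined above can be exercised end to end on a toy graph. The sketch below is illustrative only: the config values are made up (they are not taken from the released YAML configs), and it assumes `torch` and `torch_geometric` are installed and that `SPELL` is importable from `gravit.models.context_reasoning.spell` as the package layout suggests.

```python
import torch
# Assumed import path; adjust to wherever SPELL is exposed in your checkout.
from gravit.models.context_reasoning.spell import SPELL

# Toy config with illustrative values; the keys mirror what SPELL.__init__ reads.
cfg = {'use_spf': False, 'use_ref': False, 'num_modality': 1,
       'channel1': 64, 'channel2': 32, 'final_dim': 8,
       'num_att_heads': 0, 'dropout': 0.1}
model = SPELL(cfg)

# A small temporal chain: 20 nodes with 256-dim features and bidirectional edges.
num_nodes, feat_dim = 20, 256
x = torch.randn(num_nodes, feat_dim)
src = torch.arange(num_nodes - 1)
edge_index = torch.cat([torch.stack([src, src + 1]),
                        torch.stack([src + 1, src])], dim=1)
# The sign of edge_attr decides which stream sees an edge:
# edge_attr <= 0 feeds the forward-graph stream, edge_attr >= 0 the backward one.
edge_attr = torch.cat([torch.full((num_nodes - 1,), -1.0),
                       torch.full((num_nodes - 1,), 1.0)])

logits = model(x, edge_index, edge_attr)
print(logits.shape)  # torch.Size([20, 8]) -- one final_dim score vector per node

# Plain cross-entropy on per-node labels, just to show the shapes line up.
labels = torch.randint(0, cfg['final_dim'], (num_nodes,))
loss = torch.nn.CrossEntropyLoss()(logits, labels)
```

With `use_ref` enabled, `forward()` instead returns the stacked outputs of the initial prediction and the three refinement stages, which is what CEWithREF iterates over during training.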
2 | 3 | All rights belong to the original copyright owner: ActivityNet 4 | 5 | Copyright (c) 2015 ActivityNet 6 | 7 | Licensed under The MIT License 8 | 9 | Please refer to https://github.com/activitynet/ActivityNet/blob/master/LICENSE 10 | -------------------------------------------------------------------------------- /gravit/utils/ava/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IntelLabs/GraVi-T/fb343d43d575cc91f7deaa47f2321d99bac1aad5/gravit/utils/ava/__init__.py -------------------------------------------------------------------------------- /gravit/utils/ava/label_map_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Label map utility functions.""" 16 | 17 | import logging 18 | 19 | # from google.protobuf import text_format 20 | # from google3.third_party.tensorflow_models.object_detection.protos import string_int_label_map_pb2 21 | 22 | 23 | def _validate_label_map(label_map): 24 | """Checks if a label map is valid. 25 | 26 | Args: 27 | label_map: StringIntLabelMap to validate. 28 | 29 | Raises: 30 | ValueError: if label map is invalid. 31 | """ 32 | for item in label_map.item: 33 | if item.id < 1: 34 | raise ValueError('Label map ids should be >= 1.') 35 | 36 | 37 | def create_category_index(categories): 38 | """Creates dictionary of COCO compatible categories keyed by category id. 39 | 40 | Args: 41 | categories: a list of dicts, each of which has the following keys: 42 | 'id': (required) an integer id uniquely identifying this category. 43 | 'name': (required) string representing category name 44 | e.g., 'cat', 'dog', 'pizza'. 45 | 46 | Returns: 47 | category_index: a dict containing the same entries as categories, but keyed 48 | by the 'id' field of each category. 49 | """ 50 | category_index = {} 51 | for cat in categories: 52 | category_index[cat['id']] = cat 53 | return category_index 54 | 55 | 56 | def get_max_label_map_index(label_map): 57 | """Get maximum index in label map. 58 | 59 | Args: 60 | label_map: a StringIntLabelMapProto 61 | 62 | Returns: 63 | an integer 64 | """ 65 | return max([item.id for item in label_map.item]) 66 | 67 | 68 | def convert_label_map_to_categories(label_map, 69 | max_num_classes, 70 | use_display_name=True): 71 | """Loads label map proto and returns categories list compatible with eval. 72 | 73 | This function loads a label map and returns a list of dicts, each of which 74 | has the following keys: 75 | 'id': (required) an integer id uniquely identifying this category. 76 | 'name': (required) string representing category name 77 | e.g., 'cat', 'dog', 'pizza'. 
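To illustrate `create_category_index` defined above, the following short snippet builds a COCO-style index from a hand-written category list; the import path assumes this repository's vendored package layout.

```python
from gravit.utils.ava import label_map_util

# Two illustrative categories keyed by their integer ids.
categories = [{'id': 1, 'name': 'person'}, {'id': 2, 'name': 'car'}]
index = label_map_util.create_category_index(categories)
print(index[2]['name'])  # 'car' -- entries are keyed by each category's 'id'
```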
78 | We only allow class into the list if its id-label_id_offset is 79 | between 0 (inclusive) and max_num_classes (exclusive). 80 | If there are several items mapping to the same id in the label map, 81 | we will only keep the first one in the categories list. 82 | 83 | Args: 84 | label_map: a StringIntLabelMapProto or None. If None, a default categories 85 | list is created with max_num_classes categories. 86 | max_num_classes: maximum number of (consecutive) label indices to include. 87 | use_display_name: (boolean) choose whether to load 'display_name' field 88 | as category name. If False or if the display_name field does not exist, 89 | uses 'name' field as category names instead. 90 | Returns: 91 | categories: a list of dictionaries representing all possible categories. 92 | """ 93 | categories = [] 94 | list_of_ids_already_added = [] 95 | if not label_map: 96 | label_id_offset = 1 97 | for class_id in range(max_num_classes): 98 | categories.append({ 99 | 'id': class_id + label_id_offset, 100 | 'name': 'category_{}'.format(class_id + label_id_offset) 101 | }) 102 | return categories 103 | for item in label_map.item: 104 | if not 0 < item.id <= max_num_classes: 105 | logging.info('Ignore item %d since it falls outside of requested ' 106 | 'label range.', item.id) 107 | continue 108 | if use_display_name and item.HasField('display_name'): 109 | name = item.display_name 110 | else: 111 | name = item.name 112 | if item.id not in list_of_ids_already_added: 113 | list_of_ids_already_added.append(item.id) 114 | categories.append({'id': item.id, 'name': name}) 115 | return categories 116 | 117 | 118 | def load_labelmap(path): 119 | """Loads label map proto. 120 | 121 | Args: 122 | path: path to StringIntLabelMap proto text file. 123 | Returns: 124 | a StringIntLabelMapProto 125 | """ 126 | with open(path, 'r') as fid: 127 | label_map_string = fid.read() 128 | label_map = string_int_label_map_pb2.StringIntLabelMap() 129 | try: 130 | text_format.Merge(label_map_string, label_map) 131 | except text_format.ParseError: 132 | label_map.ParseFromString(label_map_string) 133 | _validate_label_map(label_map) 134 | return label_map 135 | 136 | 137 | def get_label_map_dict(label_map_path, use_display_name=False): 138 | """Reads a label map and returns a dictionary of label names to id. 139 | 140 | Args: 141 | label_map_path: path to label_map. 142 | use_display_name: whether to use the label map items' display names as keys. 143 | 144 | Returns: 145 | A dictionary mapping label names to id. 146 | """ 147 | label_map = load_labelmap(label_map_path) 148 | label_map_dict = {} 149 | for item in label_map.item: 150 | if use_display_name: 151 | label_map_dict[item.display_name] = item.id 152 | else: 153 | label_map_dict[item.name] = item.id 154 | return label_map_dict 155 | 156 | 157 | def create_category_index_from_labelmap(label_map_path): 158 | """Reads a label map and returns a category index. 159 | 160 | Args: 161 | label_map_path: Path to `StringIntLabelMap` proto text file. 162 | 163 | Returns: 164 | A category index, which is a dictionary that maps integer ids to dicts 165 | containing categories, e.g. 
166 | {1: {'id': 1, 'name': 'dog'}, 2: {'id': 2, 'name': 'cat'}, ...} 167 | """ 168 | label_map = load_labelmap(label_map_path) 169 | max_num_classes = max(item.id for item in label_map.item) 170 | categories = convert_label_map_to_categories(label_map, max_num_classes) 171 | return create_category_index(categories) 172 | 173 | 174 | def create_class_agnostic_category_index(): 175 | """Creates a category index with a single `object` class.""" 176 | return {1: {'id': 1, 'name': 'object'}} 177 | -------------------------------------------------------------------------------- /gravit/utils/ava/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Functions for computing metrics like precision, recall, CorLoc and etc.""" 17 | from __future__ import division 18 | 19 | import numpy as np 20 | 21 | 22 | def compute_precision_recall(scores, labels, num_gt): 23 | """Compute precision and recall. 24 | 25 | Args: 26 | scores: A float numpy array representing detection score 27 | labels: A boolean numpy array representing true/false positive labels 28 | num_gt: Number of ground truth instances 29 | 30 | Raises: 31 | ValueError: if the input is not of the correct format 32 | 33 | Returns: 34 | precision: Fraction of positive instances over detected ones. This value is 35 | None if no ground truth labels are present. 36 | recall: Fraction of detected positive instance over all positive instances. 37 | This value is None if no ground truth labels are present. 38 | 39 | """ 40 | if not isinstance( 41 | labels, np.ndarray) or labels.dtype != bool or len(labels.shape) != 1: 42 | raise ValueError("labels must be single dimension bool numpy array") 43 | 44 | if not isinstance( 45 | scores, np.ndarray) or len(scores.shape) != 1: 46 | raise ValueError("scores must be single dimension numpy array") 47 | 48 | if num_gt < np.sum(labels): 49 | raise ValueError("Number of true positives must be smaller than num_gt.") 50 | 51 | if len(scores) != len(labels): 52 | raise ValueError("scores and labels must be of the same size.") 53 | 54 | if num_gt == 0: 55 | return None, None 56 | 57 | sorted_indices = np.argsort(scores) 58 | sorted_indices = sorted_indices[::-1] 59 | labels = labels.astype(int) 60 | true_positive_labels = labels[sorted_indices] 61 | false_positive_labels = 1 - true_positive_labels 62 | cum_true_positives = np.cumsum(true_positive_labels) 63 | cum_false_positives = np.cumsum(false_positive_labels) 64 | precision = cum_true_positives.astype(float) / ( 65 | cum_true_positives + cum_false_positives) 66 | recall = cum_true_positives.astype(float) / num_gt 67 | return precision, recall 68 | 69 | 70 | def compute_average_precision(precision, recall): 71 | """Compute Average Precision according to the definition in VOCdevkit. 
72 | 73 | Precision is modified to ensure that it does not decrease as recall 74 | decrease. 75 | 76 | Args: 77 | precision: A float [N, 1] numpy array of precisions 78 | recall: A float [N, 1] numpy array of recalls 79 | 80 | Raises: 81 | ValueError: if the input is not of the correct format 82 | 83 | Returns: 84 | average_precison: The area under the precision recall curve. NaN if 85 | precision and recall are None. 86 | 87 | """ 88 | if precision is None: 89 | if recall is not None: 90 | raise ValueError("If precision is None, recall must also be None") 91 | return np.NAN 92 | 93 | if not isinstance(precision, np.ndarray) or not isinstance(recall, 94 | np.ndarray): 95 | raise ValueError("precision and recall must be numpy array") 96 | if precision.dtype != float or recall.dtype != float: 97 | raise ValueError("input must be float numpy array.") 98 | if len(precision) != len(recall): 99 | raise ValueError("precision and recall must be of the same size.") 100 | if not precision.size: 101 | return 0.0 102 | if np.amin(precision) < 0 or np.amax(precision) > 1: 103 | raise ValueError("Precision must be in the range of [0, 1].") 104 | if np.amin(recall) < 0 or np.amax(recall) > 1: 105 | raise ValueError("recall must be in the range of [0, 1].") 106 | if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)): 107 | raise ValueError("recall must be a non-decreasing array") 108 | 109 | recall = np.concatenate([[0], recall, [1]]) 110 | precision = np.concatenate([[0], precision, [0]]) 111 | 112 | # Preprocess precision to be a non-decreasing array 113 | for i in range(len(precision) - 2, -1, -1): 114 | precision[i] = np.maximum(precision[i], precision[i + 1]) 115 | 116 | indices = np.where(recall[1:] != recall[:-1])[0] + 1 117 | average_precision = np.sum( 118 | (recall[indices] - recall[indices - 1]) * precision[indices]) 119 | return average_precision 120 | 121 | 122 | def compute_cor_loc(num_gt_imgs_per_class, 123 | num_images_correctly_detected_per_class): 124 | """Compute CorLoc according to the definition in the following paper. 125 | 126 | https://www.robots.ox.ac.uk/~vgg/rg/papers/deselaers-eccv10.pdf 127 | 128 | Returns nans if there are no ground truth images for a class. 129 | 130 | Args: 131 | num_gt_imgs_per_class: 1D array, representing number of images containing 132 | at least one object instance of a particular class 133 | num_images_correctly_detected_per_class: 1D array, representing number of 134 | images that are correctly detected at least one object instance of a 135 | particular class 136 | 137 | Returns: 138 | corloc_per_class: A float numpy array represents the corloc score of each 139 | class 140 | """ 141 | # Divide by zero expected for classes with no gt examples. 142 | with np.errstate(divide="ignore", invalid="ignore"): 143 | return np.where( 144 | num_gt_imgs_per_class == 0, np.nan, 145 | num_images_correctly_detected_per_class / num_gt_imgs_per_class) 146 | -------------------------------------------------------------------------------- /gravit/utils/ava/np_box_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
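To make the metric functions in metrics.py above concrete, here is a small hand-checked example; the values in the comments are computed by hand from the definitions, and the import path assumes the vendored layout.

```python
import numpy as np
from gravit.utils.ava import metrics

# Four detections sorted implicitly by score; two of the three ground truths are found.
scores = np.array([0.9, 0.8, 0.6, 0.3])
labels = np.array([True, False, True, False])

precision, recall = metrics.compute_precision_recall(scores, labels, num_gt=3)
# precision ~ [1.0, 0.5, 0.667, 0.5], recall ~ [0.333, 0.333, 0.667, 0.667]

ap = metrics.compute_average_precision(precision, recall)
# Interpolated precision is 1.0, 0.667 and 0.5 over the three recall steps of 1/3,
# so ap is roughly 0.72.
print(precision, recall, ap)
```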
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Numpy BoxList classes and functions.""" 17 | 18 | import numpy as np 19 | 20 | 21 | class BoxList(object): 22 | """Box collection. 23 | 24 | BoxList represents a list of bounding boxes as numpy array, where each 25 | bounding box is represented as a row of 4 numbers, 26 | [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a 27 | given list correspond to a single image. 28 | 29 | Optionally, users can add additional related fields (such as 30 | objectness/classification scores). 31 | """ 32 | 33 | def __init__(self, data): 34 | """Constructs box collection. 35 | 36 | Args: 37 | data: a numpy array of shape [N, 4] representing box coordinates 38 | 39 | Raises: 40 | ValueError: if bbox data is not a numpy array 41 | ValueError: if invalid dimensions for bbox data 42 | """ 43 | if not isinstance(data, np.ndarray): 44 | raise ValueError('data must be a numpy array.') 45 | if len(data.shape) != 2 or data.shape[1] != 4: 46 | raise ValueError('Invalid dimensions for box data.') 47 | if data.dtype != np.float32 and data.dtype != np.float64: 48 | raise ValueError('Invalid data type for box data: float is required.') 49 | if not self._is_valid_boxes(data): 50 | raise ValueError('Invalid box data. data must be a numpy array of ' 51 | 'N*[y_min, x_min, y_max, x_max]') 52 | self.data = {'boxes': data} 53 | 54 | def num_boxes(self): 55 | """Return number of boxes held in collections.""" 56 | return self.data['boxes'].shape[0] 57 | 58 | def get_extra_fields(self): 59 | """Return all non-box fields.""" 60 | return [k for k in self.data.keys() if k != 'boxes'] 61 | 62 | def has_field(self, field): 63 | return field in self.data 64 | 65 | def add_field(self, field, field_data): 66 | """Add data to a specified field. 67 | 68 | Args: 69 | field: a string parameter used to speficy a related field to be accessed. 70 | field_data: a numpy array of [N, ...] representing the data associated 71 | with the field. 72 | Raises: 73 | ValueError: if the field is already exist or the dimension of the field 74 | data does not matches the number of boxes. 75 | """ 76 | if self.has_field(field): 77 | raise ValueError('Field ' + field + 'already exists') 78 | if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes(): 79 | raise ValueError('Invalid dimensions for field data') 80 | self.data[field] = field_data 81 | 82 | def get(self): 83 | """Convenience function for accesssing box coordinates. 84 | 85 | Returns: 86 | a numpy array of shape [N, 4] representing box corners 87 | """ 88 | return self.get_field('boxes') 89 | 90 | def get_field(self, field): 91 | """Accesses data associated with the specified field in the box collection. 92 | 93 | Args: 94 | field: a string parameter used to speficy a related field to be accessed. 
95 | 96 | Returns: 97 | a numpy 1-d array representing data of an associated field 98 | 99 | Raises: 100 | ValueError: if invalid field 101 | """ 102 | if not self.has_field(field): 103 | raise ValueError('field {} does not exist'.format(field)) 104 | return self.data[field] 105 | 106 | def get_coordinates(self): 107 | """Get corner coordinates of boxes. 108 | 109 | Returns: 110 | a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max] 111 | """ 112 | box_coordinates = self.get() 113 | y_min = box_coordinates[:, 0] 114 | x_min = box_coordinates[:, 1] 115 | y_max = box_coordinates[:, 2] 116 | x_max = box_coordinates[:, 3] 117 | return [y_min, x_min, y_max, x_max] 118 | 119 | def _is_valid_boxes(self, data): 120 | """Check whether data fullfills the format of N*[ymin, xmin, ymax, xmin]. 121 | 122 | Args: 123 | data: a numpy array of shape [N, 4] representing box coordinates 124 | 125 | Returns: 126 | a boolean indicating whether all ymax of boxes are equal or greater than 127 | ymin, and all xmax of boxes are equal or greater than xmin. 128 | """ 129 | if data.shape[0] > 0: 130 | for i in range(data.shape[0]): 131 | if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]: 132 | return False 133 | return True 134 | -------------------------------------------------------------------------------- /gravit/utils/ava/np_box_list_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Bounding Box List operations for Numpy BoxLists. 17 | 18 | Example box operations that are supported: 19 | * Areas: compute bounding box areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | import numpy as np 23 | 24 | from . import np_box_list 25 | from . import np_box_ops 26 | 27 | 28 | class SortOrder(object): 29 | """Enum class for sort order. 30 | 31 | Attributes: 32 | ascend: ascend order. 33 | descend: descend order. 34 | """ 35 | ASCEND = 1 36 | DESCEND = 2 37 | 38 | 39 | def area(boxlist): 40 | """Computes area of boxes. 41 | 42 | Args: 43 | boxlist: BoxList holding N boxes 44 | 45 | Returns: 46 | a numpy array with shape [N*1] representing box areas 47 | """ 48 | y_min, x_min, y_max, x_max = boxlist.get_coordinates() 49 | return (y_max - y_min) * (x_max - x_min) 50 | 51 | 52 | def intersection(boxlist1, boxlist2): 53 | """Compute pairwise intersection areas between boxes. 54 | 55 | Args: 56 | boxlist1: BoxList holding N boxes 57 | boxlist2: BoxList holding M boxes 58 | 59 | Returns: 60 | a numpy array with shape [N*M] representing pairwise intersection area 61 | """ 62 | return np_box_ops.intersection(boxlist1.get(), boxlist2.get()) 63 | 64 | 65 | def iou(boxlist1, boxlist2): 66 | """Computes pairwise intersection-over-union between box collections. 
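A minimal sketch of how the BoxList container defined above is meant to be used, assuming the vendored import path; the area values in the comment follow directly from the [y_min, x_min, y_max, x_max] corner convention.

```python
import numpy as np
from gravit.utils.ava import np_box_list, np_box_list_ops

boxes = np.array([[0.1, 0.1, 0.5, 0.5],
                  [0.2, 0.3, 0.8, 0.9]], dtype=np.float32)
boxlist = np_box_list.BoxList(boxes)
boxlist.add_field('scores', np.array([0.9, 0.4], dtype=np.float32))

print(boxlist.num_boxes())            # 2
print(boxlist.get_extra_fields())     # ['scores']
print(np_box_list_ops.area(boxlist))  # [0.16 0.36] (up to float32 rounding)
```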
67 | 68 | Args: 69 | boxlist1: BoxList holding N boxes 70 | boxlist2: BoxList holding M boxes 71 | 72 | Returns: 73 | a numpy array with shape [N, M] representing pairwise iou scores. 74 | """ 75 | return np_box_ops.iou(boxlist1.get(), boxlist2.get()) 76 | 77 | 78 | def ioa(boxlist1, boxlist2): 79 | """Computes pairwise intersection-over-area between box collections. 80 | 81 | Intersection-over-area (ioa) between two boxes box1 and box2 is defined as 82 | their intersection area over box2's area. Note that ioa is not symmetric, 83 | that is, IOA(box1, box2) != IOA(box2, box1). 84 | 85 | Args: 86 | boxlist1: BoxList holding N boxes 87 | boxlist2: BoxList holding M boxes 88 | 89 | Returns: 90 | a numpy array with shape [N, M] representing pairwise ioa scores. 91 | """ 92 | return np_box_ops.ioa(boxlist1.get(), boxlist2.get()) 93 | 94 | 95 | def gather(boxlist, indices, fields=None): 96 | """Gather boxes from BoxList according to indices and return new BoxList. 97 | 98 | By default, gather returns boxes corresponding to the input index list, as 99 | well as all additional fields stored in the boxlist (indexing into the 100 | first dimension). However one can optionally only gather from a 101 | subset of fields. 102 | 103 | Args: 104 | boxlist: BoxList holding N boxes 105 | indices: a 1-d numpy array of type int_ 106 | fields: (optional) list of fields to also gather from. If None (default), 107 | all fields are gathered from. Pass an empty fields list to only gather 108 | the box coordinates. 109 | 110 | Returns: 111 | subboxlist: a BoxList corresponding to the subset of the input BoxList 112 | specified by indices 113 | 114 | Raises: 115 | ValueError: if specified field is not contained in boxlist or if the 116 | indices are not of type int_ 117 | """ 118 | if indices.size: 119 | if np.amax(indices) >= boxlist.num_boxes() or np.amin(indices) < 0: 120 | raise ValueError('indices are out of valid range.') 121 | subboxlist = np_box_list.BoxList(boxlist.get()[indices, :]) 122 | if fields is None: 123 | fields = boxlist.get_extra_fields() 124 | for field in fields: 125 | extra_field_data = boxlist.get_field(field) 126 | subboxlist.add_field(field, extra_field_data[indices, ...]) 127 | return subboxlist 128 | 129 | 130 | def sort_by_field(boxlist, field, order=SortOrder.DESCEND): 131 | """Sort boxes and associated fields according to a scalar field. 132 | 133 | A common use case is reordering the boxes according to descending scores. 134 | 135 | Args: 136 | boxlist: BoxList holding N boxes. 137 | field: A BoxList field for sorting and reordering the BoxList. 138 | order: (Optional) 'descend' or 'ascend'. Default is descend. 139 | 140 | Returns: 141 | sorted_boxlist: A sorted BoxList with the field in the specified order. 142 | 143 | Raises: 144 | ValueError: if specified field does not exist or is not of single dimension. 145 | ValueError: if the order is not either descend or ascend. 
146 | """ 147 | if not boxlist.has_field(field): 148 | raise ValueError('Field ' + field + ' does not exist') 149 | if len(boxlist.get_field(field).shape) != 1: 150 | raise ValueError('Field ' + field + 'should be single dimension.') 151 | if order != SortOrder.DESCEND and order != SortOrder.ASCEND: 152 | raise ValueError('Invalid sort order') 153 | 154 | field_to_sort = boxlist.get_field(field) 155 | sorted_indices = np.argsort(field_to_sort) 156 | if order == SortOrder.DESCEND: 157 | sorted_indices = sorted_indices[::-1] 158 | return gather(boxlist, sorted_indices) 159 | 160 | 161 | def non_max_suppression(boxlist, 162 | max_output_size=10000, 163 | iou_threshold=1.0, 164 | score_threshold=-10.0): 165 | """Non maximum suppression. 166 | 167 | This op greedily selects a subset of detection bounding boxes, pruning 168 | away boxes that have high IOU (intersection over union) overlap (> thresh) 169 | with already selected boxes. In each iteration, the detected bounding box with 170 | highest score in the available pool is selected. 171 | 172 | Args: 173 | boxlist: BoxList holding N boxes. Must contain a 'scores' field 174 | representing detection scores. All scores belong to the same class. 175 | max_output_size: maximum number of retained boxes 176 | iou_threshold: intersection over union threshold. 177 | score_threshold: minimum score threshold. Remove the boxes with scores 178 | less than this value. Default value is set to -10. A very 179 | low threshold to pass pretty much all the boxes, unless 180 | the user sets a different score threshold. 181 | 182 | Returns: 183 | a BoxList holding M boxes where M <= max_output_size 184 | Raises: 185 | ValueError: if 'scores' field does not exist 186 | ValueError: if threshold is not in [0, 1] 187 | ValueError: if max_output_size < 0 188 | """ 189 | if not boxlist.has_field('scores'): 190 | raise ValueError('Field scores does not exist') 191 | if iou_threshold < 0. or iou_threshold > 1.0: 192 | raise ValueError('IOU threshold must be in [0, 1]') 193 | if max_output_size < 0: 194 | raise ValueError('max_output_size must be bigger than 0.') 195 | 196 | boxlist = filter_scores_greater_than(boxlist, score_threshold) 197 | if boxlist.num_boxes() == 0: 198 | return boxlist 199 | 200 | boxlist = sort_by_field(boxlist, 'scores') 201 | 202 | # Prevent further computation if NMS is disabled. 
203 | if iou_threshold == 1.0: 204 | if boxlist.num_boxes() > max_output_size: 205 | selected_indices = np.arange(max_output_size) 206 | return gather(boxlist, selected_indices) 207 | else: 208 | return boxlist 209 | 210 | boxes = boxlist.get() 211 | num_boxes = boxlist.num_boxes() 212 | # is_index_valid is True only for all remaining valid boxes, 213 | is_index_valid = np.full(num_boxes, 1, dtype=bool) 214 | selected_indices = [] 215 | num_output = 0 216 | for i in range(num_boxes): 217 | if num_output < max_output_size: 218 | if is_index_valid[i]: 219 | num_output += 1 220 | selected_indices.append(i) 221 | is_index_valid[i] = False 222 | valid_indices = np.where(is_index_valid)[0] 223 | if valid_indices.size == 0: 224 | break 225 | 226 | intersect_over_union = np_box_ops.iou( 227 | np.expand_dims(boxes[i, :], axis=0), boxes[valid_indices, :]) 228 | intersect_over_union = np.squeeze(intersect_over_union, axis=0) 229 | is_index_valid[valid_indices] = np.logical_and( 230 | is_index_valid[valid_indices], 231 | intersect_over_union <= iou_threshold) 232 | return gather(boxlist, np.array(selected_indices)) 233 | 234 | 235 | def multi_class_non_max_suppression(boxlist, score_thresh, iou_thresh, 236 | max_output_size): 237 | """Multi-class version of non maximum suppression. 238 | 239 | This op greedily selects a subset of detection bounding boxes, pruning 240 | away boxes that have high IOU (intersection over union) overlap (> thresh) 241 | with already selected boxes. It operates independently for each class for 242 | which scores are provided (via the scores field of the input box_list), 243 | pruning boxes with score less than a provided threshold prior to 244 | applying NMS. 245 | 246 | Args: 247 | boxlist: BoxList holding N boxes. Must contain a 'scores' field 248 | representing detection scores. This scores field is a tensor that can 249 | be 1 dimensional (in the case of a single class) or 2-dimensional, which 250 | which case we assume that it takes the shape [num_boxes, num_classes]. 251 | We further assume that this rank is known statically and that 252 | scores.shape[1] is also known (i.e., the number of classes is fixed 253 | and known at graph construction time). 254 | score_thresh: scalar threshold for score (low scoring boxes are removed). 255 | iou_thresh: scalar threshold for IOU (boxes that that high IOU overlap 256 | with previously selected boxes are removed). 257 | max_output_size: maximum number of retained boxes per class. 258 | 259 | Returns: 260 | a BoxList holding M boxes with a rank-1 scores field representing 261 | corresponding scores for each box with scores sorted in decreasing order 262 | and a rank-1 classes field representing a class label for each box. 263 | Raises: 264 | ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have 265 | a valid scores field. 
266 | """ 267 | if not 0 <= iou_thresh <= 1.0: 268 | raise ValueError('thresh must be between 0 and 1') 269 | if not isinstance(boxlist, np_box_list.BoxList): 270 | raise ValueError('boxlist must be a BoxList') 271 | if not boxlist.has_field('scores'): 272 | raise ValueError('input boxlist must have \'scores\' field') 273 | scores = boxlist.get_field('scores') 274 | if len(scores.shape) == 1: 275 | scores = np.reshape(scores, [-1, 1]) 276 | elif len(scores.shape) == 2: 277 | if scores.shape[1] is None: 278 | raise ValueError('scores field must have statically defined second ' 279 | 'dimension') 280 | else: 281 | raise ValueError('scores field must be of rank 1 or 2') 282 | num_boxes = boxlist.num_boxes() 283 | num_scores = scores.shape[0] 284 | num_classes = scores.shape[1] 285 | 286 | if num_boxes != num_scores: 287 | raise ValueError('Incorrect scores field length: actual vs expected.') 288 | 289 | selected_boxes_list = [] 290 | for class_idx in range(num_classes): 291 | boxlist_and_class_scores = np_box_list.BoxList(boxlist.get()) 292 | class_scores = np.reshape(scores[0:num_scores, class_idx], [-1]) 293 | boxlist_and_class_scores.add_field('scores', class_scores) 294 | boxlist_filt = filter_scores_greater_than(boxlist_and_class_scores, 295 | score_thresh) 296 | nms_result = non_max_suppression(boxlist_filt, 297 | max_output_size=max_output_size, 298 | iou_threshold=iou_thresh, 299 | score_threshold=score_thresh) 300 | nms_result.add_field( 301 | 'classes', np.zeros_like(nms_result.get_field('scores')) + class_idx) 302 | selected_boxes_list.append(nms_result) 303 | selected_boxes = concatenate(selected_boxes_list) 304 | sorted_boxes = sort_by_field(selected_boxes, 'scores') 305 | return sorted_boxes 306 | 307 | 308 | def scale(boxlist, y_scale, x_scale): 309 | """Scale box coordinates in x and y dimensions. 310 | 311 | Args: 312 | boxlist: BoxList holding N boxes 313 | y_scale: float 314 | x_scale: float 315 | 316 | Returns: 317 | boxlist: BoxList holding N boxes 318 | """ 319 | y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1) 320 | y_min = y_scale * y_min 321 | y_max = y_scale * y_max 322 | x_min = x_scale * x_min 323 | x_max = x_scale * x_max 324 | scaled_boxlist = np_box_list.BoxList(np.hstack([y_min, x_min, y_max, x_max])) 325 | 326 | fields = boxlist.get_extra_fields() 327 | for field in fields: 328 | extra_field_data = boxlist.get_field(field) 329 | scaled_boxlist.add_field(field, extra_field_data) 330 | 331 | return scaled_boxlist 332 | 333 | 334 | def clip_to_window(boxlist, window): 335 | """Clip bounding boxes to a window. 336 | 337 | This op clips input bounding boxes (represented by bounding box 338 | corners) to a window, optionally filtering out boxes that do not 339 | overlap at all with the window. 340 | 341 | Args: 342 | boxlist: BoxList holding M_in boxes 343 | window: a numpy array of shape [4] representing the 344 | [y_min, x_min, y_max, x_max] window to which the op 345 | should clip boxes. 
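The greedy single-class NMS routine above can be exercised on three toy boxes, two of which overlap heavily; coordinates follow the [y_min, x_min, y_max, x_max] convention and the import path assumes the vendored layout.

```python
import numpy as np
from gravit.utils.ava import np_box_list, np_box_list_ops

boxes = np.array([[0.0, 0.0, 1.0, 1.0],
                  [0.0, 0.0, 1.0, 0.9],   # IOU of 0.9 with the first box
                  [0.0, 2.0, 1.0, 3.0]], dtype=np.float32)
boxlist = np_box_list.BoxList(boxes)
boxlist.add_field('scores', np.array([0.9, 0.8, 0.7], dtype=np.float32))

kept = np_box_list_ops.non_max_suppression(
    boxlist, max_output_size=10, iou_threshold=0.5, score_threshold=0.0)
print(kept.get())  # the first and third boxes survive; the second is suppressed
```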
346 | 347 | Returns: 348 | a BoxList holding M_out boxes where M_out <= M_in 349 | """ 350 | y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1) 351 | win_y_min = window[0] 352 | win_x_min = window[1] 353 | win_y_max = window[2] 354 | win_x_max = window[3] 355 | y_min_clipped = np.fmax(np.fmin(y_min, win_y_max), win_y_min) 356 | y_max_clipped = np.fmax(np.fmin(y_max, win_y_max), win_y_min) 357 | x_min_clipped = np.fmax(np.fmin(x_min, win_x_max), win_x_min) 358 | x_max_clipped = np.fmax(np.fmin(x_max, win_x_max), win_x_min) 359 | clipped = np_box_list.BoxList( 360 | np.hstack([y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped])) 361 | clipped = _copy_extra_fields(clipped, boxlist) 362 | areas = area(clipped) 363 | nonzero_area_indices = np.reshape(np.nonzero(np.greater(areas, 0.0)), 364 | [-1]).astype(np.int32) 365 | return gather(clipped, nonzero_area_indices) 366 | 367 | 368 | def prune_non_overlapping_boxes(boxlist1, boxlist2, minoverlap=0.0): 369 | """Prunes the boxes in boxlist1 that overlap less than thresh with boxlist2. 370 | 371 | For each box in boxlist1, we want its IOA to be more than minoverlap with 372 | at least one of the boxes in boxlist2. If it does not, we remove it. 373 | 374 | Args: 375 | boxlist1: BoxList holding N boxes. 376 | boxlist2: BoxList holding M boxes. 377 | minoverlap: Minimum required overlap between boxes, to count them as 378 | overlapping. 379 | 380 | Returns: 381 | A pruned boxlist with size [N', 4]. 382 | """ 383 | intersection_over_area = ioa(boxlist2, boxlist1) # [M, N] tensor 384 | intersection_over_area = np.amax(intersection_over_area, axis=0) # [N] tensor 385 | keep_bool = np.greater_equal(intersection_over_area, np.array(minoverlap)) 386 | keep_inds = np.nonzero(keep_bool)[0] 387 | new_boxlist1 = gather(boxlist1, keep_inds) 388 | return new_boxlist1 389 | 390 | 391 | def prune_outside_window(boxlist, window): 392 | """Prunes bounding boxes that fall outside a given window. 393 | 394 | This function prunes bounding boxes that even partially fall outside the given 395 | window. See also ClipToWindow which only prunes bounding boxes that fall 396 | completely outside the window, and clips any bounding boxes that partially 397 | overflow. 398 | 399 | Args: 400 | boxlist: a BoxList holding M_in boxes. 401 | window: a numpy array of size 4, representing [ymin, xmin, ymax, xmax] 402 | of the window. 403 | 404 | Returns: 405 | pruned_corners: a tensor with shape [M_out, 4] where M_out <= M_in. 406 | valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes 407 | in the input tensor. 408 | """ 409 | 410 | y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1) 411 | win_y_min = window[0] 412 | win_x_min = window[1] 413 | win_y_max = window[2] 414 | win_x_max = window[3] 415 | coordinate_violations = np.hstack([np.less(y_min, win_y_min), 416 | np.less(x_min, win_x_min), 417 | np.greater(y_max, win_y_max), 418 | np.greater(x_max, win_x_max)]) 419 | valid_indices = np.reshape( 420 | np.where(np.logical_not(np.max(coordinate_violations, axis=1))), [-1]) 421 | return gather(boxlist, valid_indices), valid_indices 422 | 423 | 424 | def concatenate(boxlists, fields=None): 425 | """Concatenate list of BoxLists. 426 | 427 | This op concatenates a list of input BoxLists into a larger BoxList. It also 428 | handles concatenation of BoxList fields as long as the field tensor shapes 429 | are equal except for the first dimension. 
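A short example of `clip_to_window` as defined above: one box straddles the unit window and gets clipped, while the other lies completely outside and is dropped because its clipped area is zero. Paths and values are purely illustrative.

```python
import numpy as np
from gravit.utils.ava import np_box_list, np_box_list_ops

boxes = np.array([[-0.2, -0.2, 0.5, 0.5],
                  [ 2.0,  2.0, 3.0, 3.0]], dtype=np.float32)  # second box is outside
boxlist = np_box_list.BoxList(boxes)
window = np.array([0.0, 0.0, 1.0, 1.0], dtype=np.float32)

clipped = np_box_list_ops.clip_to_window(boxlist, window)
print(clipped.get())  # only the first box remains, clipped to [0, 0, 0.5, 0.5]
```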
430 | 431 | Args: 432 | boxlists: list of BoxList objects 433 | fields: optional list of fields to also concatenate. By default, all 434 | fields from the first BoxList in the list are included in the 435 | concatenation. 436 | 437 | Returns: 438 | a BoxList with number of boxes equal to 439 | sum([boxlist.num_boxes() for boxlist in BoxList]) 440 | Raises: 441 | ValueError: if boxlists is invalid (i.e., is not a list, is empty, or 442 | contains non BoxList objects), or if requested fields are not contained in 443 | all boxlists 444 | """ 445 | if not isinstance(boxlists, list): 446 | raise ValueError('boxlists should be a list') 447 | if not boxlists: 448 | raise ValueError('boxlists should have nonzero length') 449 | for boxlist in boxlists: 450 | if not isinstance(boxlist, np_box_list.BoxList): 451 | raise ValueError('all elements of boxlists should be BoxList objects') 452 | concatenated = np_box_list.BoxList( 453 | np.vstack([boxlist.get() for boxlist in boxlists])) 454 | if fields is None: 455 | fields = boxlists[0].get_extra_fields() 456 | for field in fields: 457 | first_field_shape = boxlists[0].get_field(field).shape 458 | first_field_shape = first_field_shape[1:] 459 | for boxlist in boxlists: 460 | if not boxlist.has_field(field): 461 | raise ValueError('boxlist must contain all requested fields') 462 | field_shape = boxlist.get_field(field).shape 463 | field_shape = field_shape[1:] 464 | if field_shape != first_field_shape: 465 | raise ValueError('field %s must have same shape for all boxlists ' 466 | 'except for the 0th dimension.' % field) 467 | concatenated_field = np.concatenate( 468 | [boxlist.get_field(field) for boxlist in boxlists], axis=0) 469 | concatenated.add_field(field, concatenated_field) 470 | return concatenated 471 | 472 | 473 | def filter_scores_greater_than(boxlist, thresh): 474 | """Filter to keep only boxes with score exceeding a given threshold. 475 | 476 | This op keeps the collection of boxes whose corresponding scores are 477 | greater than the input threshold. 478 | 479 | Args: 480 | boxlist: BoxList holding N boxes. Must contain a 'scores' field 481 | representing detection scores. 482 | thresh: scalar threshold 483 | 484 | Returns: 485 | a BoxList holding M boxes where M <= N 486 | 487 | Raises: 488 | ValueError: if boxlist not a BoxList object or if it does not 489 | have a scores field 490 | """ 491 | if not isinstance(boxlist, np_box_list.BoxList): 492 | raise ValueError('boxlist must be a BoxList') 493 | if not boxlist.has_field('scores'): 494 | raise ValueError('input boxlist must have \'scores\' field') 495 | scores = boxlist.get_field('scores') 496 | if len(scores.shape) > 2: 497 | raise ValueError('Scores should have rank 1 or 2') 498 | if len(scores.shape) == 2 and scores.shape[1] != 1: 499 | raise ValueError('Scores should have rank 1 or have shape ' 500 | 'consistent with [None, 1]') 501 | high_score_indices = np.reshape(np.where(np.greater(scores, thresh)), 502 | [-1]).astype(np.int32) 503 | return gather(boxlist, high_score_indices) 504 | 505 | 506 | def change_coordinate_frame(boxlist, window): 507 | """Change coordinate frame of the boxlist to be relative to window's frame. 508 | 509 | Given a window of the form [ymin, xmin, ymax, xmax], 510 | changes bounding box coordinates from boxlist to be relative to this window 511 | (e.g., the min corner maps to (0,0) and the max corner maps to (1,1)). 
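`concatenate` and `filter_scores_greater_than`, both defined above, compose naturally when merging detections from several sources; the small `make_boxlist` helper below is not part of the library, just a convenience for this sketch.

```python
import numpy as np
from gravit.utils.ava import np_box_list, np_box_list_ops

def make_boxlist(boxes, scores):
    # Illustrative helper: wrap raw arrays into a BoxList with a 'scores' field.
    bl = np_box_list.BoxList(np.asarray(boxes, dtype=np.float32))
    bl.add_field('scores', np.asarray(scores, dtype=np.float32))
    return bl

a = make_boxlist([[0.0, 0.0, 0.5, 0.5]], [0.9])
b = make_boxlist([[0.5, 0.5, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]], [0.2, 0.6])

merged = np_box_list_ops.concatenate([a, b])  # 3 boxes; 'scores' is concatenated too
confident = np_box_list_ops.filter_scores_greater_than(merged, 0.5)
print(confident.num_boxes())  # 2 -- only the boxes scoring above 0.5 remain
```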
512 | 513 | An example use case is data augmentation: where we are given groundtruth 514 | boxes (boxlist) and would like to randomly crop the image to some 515 | window (window). In this case we need to change the coordinate frame of 516 | each groundtruth box to be relative to this new window. 517 | 518 | Args: 519 | boxlist: A BoxList object holding N boxes. 520 | window: a size 4 1-D numpy array. 521 | 522 | Returns: 523 | Returns a BoxList object with N boxes. 524 | """ 525 | win_height = window[2] - window[0] 526 | win_width = window[3] - window[1] 527 | boxlist_new = scale( 528 | np_box_list.BoxList(boxlist.get() - 529 | [window[0], window[1], window[0], window[1]]), 530 | 1.0 / win_height, 1.0 / win_width) 531 | _copy_extra_fields(boxlist_new, boxlist) 532 | 533 | return boxlist_new 534 | 535 | 536 | def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from): 537 | """Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to. 538 | 539 | Args: 540 | boxlist_to_copy_to: BoxList to which extra fields are copied. 541 | boxlist_to_copy_from: BoxList from which fields are copied. 542 | 543 | Returns: 544 | boxlist_to_copy_to with extra fields. 545 | """ 546 | for field in boxlist_to_copy_from.get_extra_fields(): 547 | boxlist_to_copy_to.add_field(field, boxlist_to_copy_from.get_field(field)) 548 | return boxlist_to_copy_to 549 | 550 | 551 | def _update_valid_indices_by_removing_high_iou_boxes( 552 | selected_indices, is_index_valid, intersect_over_union, threshold): 553 | max_iou = np.max(intersect_over_union[:, selected_indices], axis=1) 554 | return np.logical_and(is_index_valid, max_iou <= threshold) 555 | -------------------------------------------------------------------------------- /gravit/utils/ava/np_box_mask_list.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Numpy BoxMaskList classes and functions.""" 17 | 18 | import numpy as np 19 | from . import np_box_list 20 | 21 | 22 | class BoxMaskList(np_box_list.BoxList): 23 | """Convenience wrapper for BoxList with masks. 24 | 25 | BoxMaskList extends the np_box_list.BoxList to contain masks as well. 26 | In particular, its constructor receives both boxes and masks. Note that the 27 | masks correspond to the full image. 28 | """ 29 | 30 | def __init__(self, box_data, mask_data): 31 | """Constructs box collection. 32 | 33 | Args: 34 | box_data: a numpy array of shape [N, 4] representing box coordinates 35 | mask_data: a numpy array of shape [N, height, width] representing masks 36 | with values are in {0,1}. The masks correspond to the full 37 | image. The height and the width will be equal to image height and width. 
38 | 39 | Raises: 40 | ValueError: if bbox data is not a numpy array 41 | ValueError: if invalid dimensions for bbox data 42 | ValueError: if mask data is not a numpy array 43 | ValueError: if invalid dimension for mask data 44 | """ 45 | super(BoxMaskList, self).__init__(box_data) 46 | if not isinstance(mask_data, np.ndarray): 47 | raise ValueError('Mask data must be a numpy array.') 48 | if len(mask_data.shape) != 3: 49 | raise ValueError('Invalid dimensions for mask data.') 50 | if mask_data.dtype != np.uint8: 51 | raise ValueError('Invalid data type for mask data: uint8 is required.') 52 | if mask_data.shape[0] != box_data.shape[0]: 53 | raise ValueError('There should be the same number of boxes and masks.') 54 | self.data['masks'] = mask_data 55 | 56 | def get_masks(self): 57 | """Convenience function for accessing masks. 58 | 59 | Returns: 60 | a numpy array of shape [N, height, width] representing masks 61 | """ 62 | return self.get_field('masks') 63 | 64 | -------------------------------------------------------------------------------- /gravit/utils/ava/np_box_mask_list_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for np_box_mask_list.BoxMaskList. 17 | 18 | Example box operations that are supported: 19 | * Areas: compute bounding box areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | import numpy as np 23 | 24 | from . import np_box_list_ops 25 | from . import np_box_mask_list 26 | from . import np_mask_ops 27 | 28 | 29 | def box_list_to_box_mask_list(boxlist): 30 | """Converts a BoxList containing 'masks' into a BoxMaskList. 31 | 32 | Args: 33 | boxlist: An np_box_list.BoxList object. 34 | 35 | Returns: 36 | An np_box_mask_list.BoxMaskList object. 37 | 38 | Raises: 39 | ValueError: If boxlist does not contain `masks` as a field. 40 | """ 41 | if not boxlist.has_field('masks'): 42 | raise ValueError('boxlist does not contain mask field.') 43 | box_mask_list = np_box_mask_list.BoxMaskList( 44 | box_data=boxlist.get(), 45 | mask_data=boxlist.get_field('masks')) 46 | extra_fields = boxlist.get_extra_fields() 47 | for key in extra_fields: 48 | if key != 'masks': 49 | box_mask_list.data[key] = boxlist.get_field(key) 50 | return box_mask_list 51 | 52 | 53 | def area(box_mask_list): 54 | """Computes area of masks. 55 | 56 | Args: 57 | box_mask_list: np_box_mask_list.BoxMaskList holding N boxes and masks 58 | 59 | Returns: 60 | a numpy array with shape [N*1] representing mask areas 61 | """ 62 | return np_mask_ops.area(box_mask_list.get_masks()) 63 | 64 | 65 | def intersection(box_mask_list1, box_mask_list2): 66 | """Compute pairwise intersection areas between masks. 
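BoxMaskList, defined above, carries one [N, height, width] uint8 mask per box, and the mask-based `area` counts mask pixels instead of using box corners. The snippet is a sketch under the vendored import paths; the printed value assumes the pixel-counting behaviour described in the docstrings.

```python
import numpy as np
from gravit.utils.ava import np_box_mask_list, np_box_mask_list_ops

boxes = np.array([[0.0, 0.0, 2.0, 2.0]], dtype=np.float32)
masks = np.zeros((1, 4, 4), dtype=np.uint8)
masks[0, :2, :2] = 1  # the mask covers a 2x2 block of the 4x4 image

bml = np_box_mask_list.BoxMaskList(box_data=boxes, mask_data=masks)
print(bml.num_boxes())                 # 1
print(np_box_mask_list_ops.area(bml))  # [4.] -- number of mask pixels
```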
67 | 68 | Args: 69 | box_mask_list1: BoxMaskList holding N boxes and masks 70 | box_mask_list2: BoxMaskList holding M boxes and masks 71 | 72 | Returns: 73 | a numpy array with shape [N*M] representing pairwise intersection area 74 | """ 75 | return np_mask_ops.intersection(box_mask_list1.get_masks(), 76 | box_mask_list2.get_masks()) 77 | 78 | 79 | def iou(box_mask_list1, box_mask_list2): 80 | """Computes pairwise intersection-over-union between box and mask collections. 81 | 82 | Args: 83 | box_mask_list1: BoxMaskList holding N boxes and masks 84 | box_mask_list2: BoxMaskList holding M boxes and masks 85 | 86 | Returns: 87 | a numpy array with shape [N, M] representing pairwise iou scores. 88 | """ 89 | return np_mask_ops.iou(box_mask_list1.get_masks(), 90 | box_mask_list2.get_masks()) 91 | 92 | 93 | def ioa(box_mask_list1, box_mask_list2): 94 | """Computes pairwise intersection-over-area between box and mask collections. 95 | 96 | Intersection-over-area (ioa) between two masks mask1 and mask2 is defined as 97 | their intersection area over mask2's area. Note that ioa is not symmetric, 98 | that is, IOA(mask1, mask2) != IOA(mask2, mask1). 99 | 100 | Args: 101 | box_mask_list1: np_box_mask_list.BoxMaskList holding N boxes and masks 102 | box_mask_list2: np_box_mask_list.BoxMaskList holding M boxes and masks 103 | 104 | Returns: 105 | a numpy array with shape [N, M] representing pairwise ioa scores. 106 | """ 107 | return np_mask_ops.ioa(box_mask_list1.get_masks(), box_mask_list2.get_masks()) 108 | 109 | 110 | def gather(box_mask_list, indices, fields=None): 111 | """Gather boxes from np_box_mask_list.BoxMaskList according to indices. 112 | 113 | By default, gather returns boxes corresponding to the input index list, as 114 | well as all additional fields stored in the box_mask_list (indexing into the 115 | first dimension). However one can optionally only gather from a 116 | subset of fields. 117 | 118 | Args: 119 | box_mask_list: np_box_mask_list.BoxMaskList holding N boxes 120 | indices: a 1-d numpy array of type int_ 121 | fields: (optional) list of fields to also gather from. If None (default), 122 | all fields are gathered from. Pass an empty fields list to only gather 123 | the box coordinates. 124 | 125 | Returns: 126 | subbox_mask_list: a np_box_mask_list.BoxMaskList corresponding to the subset 127 | of the input box_mask_list specified by indices 128 | 129 | Raises: 130 | ValueError: if specified field is not contained in box_mask_list or if the 131 | indices are not of type int_ 132 | """ 133 | if fields is not None: 134 | if 'masks' not in fields: 135 | fields.append('masks') 136 | return box_list_to_box_mask_list( 137 | np_box_list_ops.gather( 138 | boxlist=box_mask_list, indices=indices, fields=fields)) 139 | 140 | 141 | def sort_by_field(box_mask_list, field, 142 | order=np_box_list_ops.SortOrder.DESCEND): 143 | """Sort boxes and associated fields according to a scalar field. 144 | 145 | A common use case is reordering the boxes according to descending scores. 146 | 147 | Args: 148 | box_mask_list: BoxMaskList holding N boxes. 149 | field: A BoxMaskList field for sorting and reordering the BoxMaskList. 150 | order: (Optional) 'descend' or 'ascend'. Default is descend. 151 | 152 | Returns: 153 | sorted_box_mask_list: A sorted BoxMaskList with the field in the specified 154 | order. 
155 | """ 156 | return box_list_to_box_mask_list( 157 | np_box_list_ops.sort_by_field( 158 | boxlist=box_mask_list, field=field, order=order)) 159 | 160 | 161 | def non_max_suppression(box_mask_list, 162 | max_output_size=10000, 163 | iou_threshold=1.0, 164 | score_threshold=-10.0): 165 | """Non maximum suppression. 166 | 167 | This op greedily selects a subset of detection bounding boxes, pruning 168 | away boxes that have high IOU (intersection over union) overlap (> thresh) 169 | with already selected boxes. In each iteration, the detected bounding box with 170 | highest score in the available pool is selected. 171 | 172 | Args: 173 | box_mask_list: np_box_mask_list.BoxMaskList holding N boxes. Must contain 174 | a 'scores' field representing detection scores. All scores belong to the 175 | same class. 176 | max_output_size: maximum number of retained boxes 177 | iou_threshold: intersection over union threshold. 178 | score_threshold: minimum score threshold. Remove the boxes with scores 179 | less than this value. Default value is set to -10. A very 180 | low threshold to pass pretty much all the boxes, unless 181 | the user sets a different score threshold. 182 | 183 | Returns: 184 | an np_box_mask_list.BoxMaskList holding M boxes where M <= max_output_size 185 | 186 | Raises: 187 | ValueError: if 'scores' field does not exist 188 | ValueError: if threshold is not in [0, 1] 189 | ValueError: if max_output_size < 0 190 | """ 191 | if not box_mask_list.has_field('scores'): 192 | raise ValueError('Field scores does not exist') 193 | if iou_threshold < 0. or iou_threshold > 1.0: 194 | raise ValueError('IOU threshold must be in [0, 1]') 195 | if max_output_size < 0: 196 | raise ValueError('max_output_size must be bigger than 0.') 197 | 198 | box_mask_list = filter_scores_greater_than(box_mask_list, score_threshold) 199 | if box_mask_list.num_boxes() == 0: 200 | return box_mask_list 201 | 202 | box_mask_list = sort_by_field(box_mask_list, 'scores') 203 | 204 | # Prevent further computation if NMS is disabled. 205 | if iou_threshold == 1.0: 206 | if box_mask_list.num_boxes() > max_output_size: 207 | selected_indices = np.arange(max_output_size) 208 | return gather(box_mask_list, selected_indices) 209 | else: 210 | return box_mask_list 211 | 212 | masks = box_mask_list.get_masks() 213 | num_masks = box_mask_list.num_boxes() 214 | 215 | # is_index_valid is True only for all remaining valid boxes, 216 | is_index_valid = np.full(num_masks, 1, dtype=bool) 217 | selected_indices = [] 218 | num_output = 0 219 | for i in range(num_masks): 220 | if num_output < max_output_size: 221 | if is_index_valid[i]: 222 | num_output += 1 223 | selected_indices.append(i) 224 | is_index_valid[i] = False 225 | valid_indices = np.where(is_index_valid)[0] 226 | if valid_indices.size == 0: 227 | break 228 | 229 | intersect_over_union = np_mask_ops.iou( 230 | np.expand_dims(masks[i], axis=0), masks[valid_indices]) 231 | intersect_over_union = np.squeeze(intersect_over_union, axis=0) 232 | is_index_valid[valid_indices] = np.logical_and( 233 | is_index_valid[valid_indices], 234 | intersect_over_union <= iou_threshold) 235 | return gather(box_mask_list, np.array(selected_indices)) 236 | 237 | 238 | def multi_class_non_max_suppression(box_mask_list, score_thresh, iou_thresh, 239 | max_output_size): 240 | """Multi-class version of non maximum suppression. 
241 | 242 | This op greedily selects a subset of detection bounding boxes, pruning 243 | away boxes that have high IOU (intersection over union) overlap (> thresh) 244 | with already selected boxes. It operates independently for each class for 245 | which scores are provided (via the scores field of the input box_list), 246 | pruning boxes with score less than a provided threshold prior to 247 | applying NMS. 248 | 249 | Args: 250 | box_mask_list: np_box_mask_list.BoxMaskList holding N boxes. Must contain a 251 | 'scores' field representing detection scores. This scores field is a 252 | tensor that can be 1 dimensional (in the case of a single class) or 253 | 2-dimensional, in which case we assume that it takes the 254 | shape [num_boxes, num_classes]. We further assume that this rank is known 255 | statically and that scores.shape[1] is also known (i.e., the number of 256 | classes is fixed and known at graph construction time). 257 | score_thresh: scalar threshold for score (low scoring boxes are removed). 258 | iou_thresh: scalar threshold for IOU (boxes that that high IOU overlap 259 | with previously selected boxes are removed). 260 | max_output_size: maximum number of retained boxes per class. 261 | 262 | Returns: 263 | a box_mask_list holding M boxes with a rank-1 scores field representing 264 | corresponding scores for each box with scores sorted in decreasing order 265 | and a rank-1 classes field representing a class label for each box. 266 | Raises: 267 | ValueError: if iou_thresh is not in [0, 1] or if input box_mask_list does 268 | not have a valid scores field. 269 | """ 270 | if not 0 <= iou_thresh <= 1.0: 271 | raise ValueError('thresh must be between 0 and 1') 272 | if not isinstance(box_mask_list, np_box_mask_list.BoxMaskList): 273 | raise ValueError('box_mask_list must be a box_mask_list') 274 | if not box_mask_list.has_field('scores'): 275 | raise ValueError('input box_mask_list must have \'scores\' field') 276 | scores = box_mask_list.get_field('scores') 277 | if len(scores.shape) == 1: 278 | scores = np.reshape(scores, [-1, 1]) 279 | elif len(scores.shape) == 2: 280 | if scores.shape[1] is None: 281 | raise ValueError('scores field must have statically defined second ' 282 | 'dimension') 283 | else: 284 | raise ValueError('scores field must be of rank 1 or 2') 285 | 286 | num_boxes = box_mask_list.num_boxes() 287 | num_scores = scores.shape[0] 288 | num_classes = scores.shape[1] 289 | 290 | if num_boxes != num_scores: 291 | raise ValueError('Incorrect scores field length: actual vs expected.') 292 | 293 | selected_boxes_list = [] 294 | for class_idx in range(num_classes): 295 | box_mask_list_and_class_scores = np_box_mask_list.BoxMaskList( 296 | box_data=box_mask_list.get(), 297 | mask_data=box_mask_list.get_masks()) 298 | class_scores = np.reshape(scores[0:num_scores, class_idx], [-1]) 299 | box_mask_list_and_class_scores.add_field('scores', class_scores) 300 | box_mask_list_filt = filter_scores_greater_than( 301 | box_mask_list_and_class_scores, score_thresh) 302 | nms_result = non_max_suppression( 303 | box_mask_list_filt, 304 | max_output_size=max_output_size, 305 | iou_threshold=iou_thresh, 306 | score_threshold=score_thresh) 307 | nms_result.add_field( 308 | 'classes', 309 | np.zeros_like(nms_result.get_field('scores')) + class_idx) 310 | selected_boxes_list.append(nms_result) 311 | selected_boxes = np_box_list_ops.concatenate(selected_boxes_list) 312 | sorted_boxes = np_box_list_ops.sort_by_field(selected_boxes, 'scores') 313 | return 
box_list_to_box_mask_list(boxlist=sorted_boxes) 314 | 315 | 316 | def prune_non_overlapping_masks(box_mask_list1, box_mask_list2, minoverlap=0.0): 317 | """Prunes the boxes in list1 that overlap less than thresh with list2. 318 | 319 | For each mask in box_mask_list1, we want its IOA to be more than minoverlap 320 | with at least one of the masks in box_mask_list2. If it does not, we remove 321 | it. If the masks are not full size image, we do the pruning based on boxes. 322 | 323 | Args: 324 | box_mask_list1: np_box_mask_list.BoxMaskList holding N boxes and masks. 325 | box_mask_list2: np_box_mask_list.BoxMaskList holding M boxes and masks. 326 | minoverlap: Minimum required overlap between boxes, to count them as 327 | overlapping. 328 | 329 | Returns: 330 | A pruned box_mask_list with size [N', 4]. 331 | """ 332 | intersection_over_area = ioa(box_mask_list2, box_mask_list1) # [M, N] tensor 333 | intersection_over_area = np.amax(intersection_over_area, axis=0) # [N] tensor 334 | keep_bool = np.greater_equal(intersection_over_area, np.array(minoverlap)) 335 | keep_inds = np.nonzero(keep_bool)[0] 336 | new_box_mask_list1 = gather(box_mask_list1, keep_inds) 337 | return new_box_mask_list1 338 | 339 | 340 | def concatenate(box_mask_lists, fields=None): 341 | """Concatenate list of box_mask_lists. 342 | 343 | This op concatenates a list of input box_mask_lists into a larger 344 | box_mask_list. It also 345 | handles concatenation of box_mask_list fields as long as the field tensor 346 | shapes are equal except for the first dimension. 347 | 348 | Args: 349 | box_mask_lists: list of np_box_mask_list.BoxMaskList objects 350 | fields: optional list of fields to also concatenate. By default, all 351 | fields from the first BoxMaskList in the list are included in the 352 | concatenation. 353 | 354 | Returns: 355 | a box_mask_list with number of boxes equal to 356 | sum([box_mask_list.num_boxes() for box_mask_list in box_mask_list]) 357 | Raises: 358 | ValueError: if box_mask_lists is invalid (i.e., is not a list, is empty, or 359 | contains non box_mask_list objects), or if requested fields are not 360 | contained in all box_mask_lists 361 | """ 362 | if fields is not None: 363 | if 'masks' not in fields: 364 | fields.append('masks') 365 | return box_list_to_box_mask_list( 366 | np_box_list_ops.concatenate(boxlists=box_mask_lists, fields=fields)) 367 | 368 | 369 | def filter_scores_greater_than(box_mask_list, thresh): 370 | """Filter to keep only boxes and masks with score exceeding a given threshold. 371 | 372 | This op keeps the collection of boxes and masks whose corresponding scores are 373 | greater than the input threshold. 374 | 375 | Args: 376 | box_mask_list: BoxMaskList holding N boxes and masks. Must contain a 377 | 'scores' field representing detection scores. 
378 | thresh: scalar threshold 379 | 380 | Returns: 381 | a BoxMaskList holding M boxes and masks where M <= N 382 | 383 | Raises: 384 | ValueError: if box_mask_list not a np_box_mask_list.BoxMaskList object or 385 | if it does not have a scores field 386 | """ 387 | if not isinstance(box_mask_list, np_box_mask_list.BoxMaskList): 388 | raise ValueError('box_mask_list must be a BoxMaskList') 389 | if not box_mask_list.has_field('scores'): 390 | raise ValueError('input box_mask_list must have \'scores\' field') 391 | scores = box_mask_list.get_field('scores') 392 | if len(scores.shape) > 2: 393 | raise ValueError('Scores should have rank 1 or 2') 394 | if len(scores.shape) == 2 and scores.shape[1] != 1: 395 | raise ValueError('Scores should have rank 1 or have shape ' 396 | 'consistent with [None, 1]') 397 | high_score_indices = np.reshape(np.where(np.greater(scores, thresh)), 398 | [-1]).astype(np.int32) 399 | return gather(box_mask_list, high_score_indices) 400 | -------------------------------------------------------------------------------- /gravit/utils/ava/np_box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for [N, 4] numpy arrays representing bounding boxes. 17 | 18 | Example box operations that are supported: 19 | * Areas: compute bounding box areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | import numpy as np 23 | 24 | 25 | def area(boxes): 26 | """Computes area of boxes. 27 | 28 | Args: 29 | boxes: Numpy array with shape [N, 4] holding N boxes 30 | 31 | Returns: 32 | a numpy array with shape [N*1] representing box areas 33 | """ 34 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 35 | 36 | 37 | def intersection(boxes1, boxes2): 38 | """Compute pairwise intersection areas between boxes. 
39 | 40 | Args: 41 | boxes1: a numpy array with shape [N, 4] holding N boxes 42 | boxes2: a numpy array with shape [M, 4] holding M boxes 43 | 44 | Returns: 45 | a numpy array with shape [N*M] representing pairwise intersection area 46 | """ 47 | [y_min1, x_min1, y_max1, x_max1] = np.split(boxes1, 4, axis=1) 48 | [y_min2, x_min2, y_max2, x_max2] = np.split(boxes2, 4, axis=1) 49 | 50 | all_pairs_min_ymax = np.minimum(y_max1, np.transpose(y_max2)) 51 | all_pairs_max_ymin = np.maximum(y_min1, np.transpose(y_min2)) 52 | intersect_heights = np.maximum( 53 | np.zeros(all_pairs_max_ymin.shape), 54 | all_pairs_min_ymax - all_pairs_max_ymin) 55 | all_pairs_min_xmax = np.minimum(x_max1, np.transpose(x_max2)) 56 | all_pairs_max_xmin = np.maximum(x_min1, np.transpose(x_min2)) 57 | intersect_widths = np.maximum( 58 | np.zeros(all_pairs_max_xmin.shape), 59 | all_pairs_min_xmax - all_pairs_max_xmin) 60 | return intersect_heights * intersect_widths 61 | 62 | 63 | def iou(boxes1, boxes2): 64 | """Computes pairwise intersection-over-union between box collections. 65 | 66 | Args: 67 | boxes1: a numpy array with shape [N, 4] holding N boxes. 68 | boxes2: a numpy array with shape [M, 4] holding N boxes. 69 | 70 | Returns: 71 | a numpy array with shape [N, M] representing pairwise iou scores. 72 | """ 73 | intersect = intersection(boxes1, boxes2) 74 | area1 = area(boxes1) 75 | area2 = area(boxes2) 76 | union = np.expand_dims(area1, axis=1) + np.expand_dims( 77 | area2, axis=0) - intersect 78 | return intersect / union 79 | 80 | 81 | def ioa(boxes1, boxes2): 82 | """Computes pairwise intersection-over-area between box collections. 83 | 84 | Intersection-over-area (ioa) between two boxes box1 and box2 is defined as 85 | their intersection area over box2's area. Note that ioa is not symmetric, 86 | that is, IOA(box1, box2) != IOA(box2, box1). 87 | 88 | Args: 89 | boxes1: a numpy array with shape [N, 4] holding N boxes. 90 | boxes2: a numpy array with shape [M, 4] holding N boxes. 91 | 92 | Returns: 93 | a numpy array with shape [N, M] representing pairwise ioa scores. 94 | """ 95 | intersect = intersection(boxes1, boxes2) 96 | areas = np.expand_dims(area(boxes2), axis=0) 97 | return intersect / areas 98 | -------------------------------------------------------------------------------- /gravit/utils/ava/np_mask_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Operations for [N, height, width] numpy arrays representing masks. 17 | 18 | Example mask operations that are supported: 19 | * Areas: compute mask areas 20 | * IOU: pairwise intersection-over-union scores 21 | """ 22 | import numpy as np 23 | 24 | EPSILON = 1e-7 25 | 26 | 27 | def area(masks): 28 | """Computes area of masks. 
29 | 30 | Args: 31 | masks: Numpy array with shape [N, height, width] holding N masks. Masks 32 | values are of type np.uint8 and values are in {0,1}. 33 | 34 | Returns: 35 | a numpy array with shape [N*1] representing mask areas. 36 | 37 | Raises: 38 | ValueError: If masks.dtype is not np.uint8 39 | """ 40 | if masks.dtype != np.uint8: 41 | raise ValueError('Masks type should be np.uint8') 42 | return np.sum(masks, axis=(1, 2), dtype=np.float32) 43 | 44 | 45 | def intersection(masks1, masks2): 46 | """Compute pairwise intersection areas between masks. 47 | 48 | Args: 49 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 50 | values are of type np.uint8 and values are in {0,1}. 51 | masks2: a numpy array with shape [M, height, width] holding M masks. Masks 52 | values are of type np.uint8 and values are in {0,1}. 53 | 54 | Returns: 55 | a numpy array with shape [N*M] representing pairwise intersection area. 56 | 57 | Raises: 58 | ValueError: If masks1 and masks2 are not of type np.uint8. 59 | """ 60 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 61 | raise ValueError('masks1 and masks2 should be of type np.uint8') 62 | n = masks1.shape[0] 63 | m = masks2.shape[0] 64 | answer = np.zeros([n, m], dtype=np.float32) 65 | for i in np.arange(n): 66 | for j in np.arange(m): 67 | answer[i, j] = np.sum(np.minimum(masks1[i], masks2[j]), dtype=np.float32) 68 | return answer 69 | 70 | 71 | def iou(masks1, masks2): 72 | """Computes pairwise intersection-over-union between mask collections. 73 | 74 | Args: 75 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 76 | values are of type np.uint8 and values are in {0,1}. 77 | masks2: a numpy array with shape [M, height, width] holding N masks. Masks 78 | values are of type np.uint8 and values are in {0,1}. 79 | 80 | Returns: 81 | a numpy array with shape [N, M] representing pairwise iou scores. 82 | 83 | Raises: 84 | ValueError: If masks1 and masks2 are not of type np.uint8. 85 | """ 86 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 87 | raise ValueError('masks1 and masks2 should be of type np.uint8') 88 | intersect = intersection(masks1, masks2) 89 | area1 = area(masks1) 90 | area2 = area(masks2) 91 | union = np.expand_dims(area1, axis=1) + np.expand_dims( 92 | area2, axis=0) - intersect 93 | return intersect / np.maximum(union, EPSILON) 94 | 95 | 96 | def ioa(masks1, masks2): 97 | """Computes pairwise intersection-over-area between box collections. 98 | 99 | Intersection-over-area (ioa) between two masks, mask1 and mask2 is defined as 100 | their intersection area over mask2's area. Note that ioa is not symmetric, 101 | that is, IOA(mask1, mask2) != IOA(mask2, mask1). 102 | 103 | Args: 104 | masks1: a numpy array with shape [N, height, width] holding N masks. Masks 105 | values are of type np.uint8 and values are in {0,1}. 106 | masks2: a numpy array with shape [M, height, width] holding N masks. Masks 107 | values are of type np.uint8 and values are in {0,1}. 108 | 109 | Returns: 110 | a numpy array with shape [N, M] representing pairwise ioa scores. 111 | 112 | Raises: 113 | ValueError: If masks1 and masks2 are not of type np.uint8. 
114 | """ 115 | if masks1.dtype != np.uint8 or masks2.dtype != np.uint8: 116 | raise ValueError('masks1 and masks2 should be of type np.uint8') 117 | intersect = intersection(masks1, masks2) 118 | areas = np.expand_dims(area(masks2), axis=0) 119 | return intersect / (areas + EPSILON) 120 | -------------------------------------------------------------------------------- /gravit/utils/ava/per_image_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Evaluate Object Detection result on a single image. 16 | 17 | Annotate each detected result as true positives or false positive according to 18 | a predefined IOU ratio. Non Maximum Supression is used by default. Multi class 19 | detection is supported by default. 20 | Based on the settings, per image evaluation is either performed on boxes or 21 | on object masks. 22 | """ 23 | import numpy as np 24 | 25 | from . import np_box_list 26 | from . import np_box_list_ops 27 | from . import np_box_mask_list 28 | from . import np_box_mask_list_ops 29 | 30 | 31 | class PerImageEvaluation(object): 32 | """Evaluate detection result of a single image.""" 33 | 34 | def __init__(self, 35 | num_groundtruth_classes, 36 | matching_iou_threshold=0.5): 37 | """Initialized PerImageEvaluation by evaluation parameters. 38 | 39 | Args: 40 | num_groundtruth_classes: Number of ground truth object classes 41 | matching_iou_threshold: A ratio of area intersection to union, which is 42 | the threshold to consider whether a detection is true positive or not 43 | """ 44 | self.matching_iou_threshold = matching_iou_threshold 45 | self.num_groundtruth_classes = num_groundtruth_classes 46 | 47 | def compute_object_detection_metrics( 48 | self, detected_boxes, detected_scores, detected_class_labels, 49 | groundtruth_boxes, groundtruth_class_labels, 50 | groundtruth_is_difficult_list, groundtruth_is_group_of_list, 51 | detected_masks=None, groundtruth_masks=None): 52 | """Evaluates detections as being tp, fp or ignored from a single image. 53 | 54 | The evaluation is done in two stages: 55 | 1. All detections are matched to non group-of boxes; true positives are 56 | determined and detections matched to difficult boxes are ignored. 57 | 2. Detections that are determined as false positives are matched against 58 | group-of boxes and ignored if matched. 59 | 60 | Args: 61 | detected_boxes: A float numpy array of shape [N, 4], representing N 62 | regions of detected object regions. 63 | Each row is of the format [y_min, x_min, y_max, x_max] 64 | detected_scores: A float numpy array of shape [N, 1], representing 65 | the confidence scores of the detected N object instances. 
66 | detected_class_labels: An integer numpy array of shape [N, 1], representing 67 | the class labels of the detected N object instances. 68 | groundtruth_boxes: A float numpy array of shape [M, 4], representing M 69 | regions of object instances in ground truth 70 | groundtruth_class_labels: An integer numpy array of shape [M, 1], 71 | representing M class labels of object instances in ground truth 72 | groundtruth_is_difficult_list: A boolean numpy array of length M denoting 73 | whether a ground truth box is a difficult instance or not 74 | groundtruth_is_group_of_list: A boolean numpy array of length M denoting 75 | whether a ground truth box has group-of tag 76 | detected_masks: (optional) A uint8 numpy array of shape 77 | [N, height, width]. If not None, the metrics will be computed based 78 | on masks. 79 | groundtruth_masks: (optional) A uint8 numpy array of shape 80 | [M, height, width]. 81 | 82 | Returns: 83 | scores: A list of C float numpy arrays. Each numpy array is of 84 | shape [K, 1], representing K scores detected with object class 85 | label c 86 | tp_fp_labels: A list of C boolean numpy arrays. Each numpy array 87 | is of shape [K, 1], representing K True/False positive label of 88 | object instances detected with class label c 89 | """ 90 | detected_boxes, detected_scores, detected_class_labels, detected_masks = ( 91 | self._remove_invalid_boxes(detected_boxes, detected_scores, 92 | detected_class_labels, detected_masks)) 93 | scores, tp_fp_labels = self._compute_tp_fp( 94 | detected_boxes=detected_boxes, 95 | detected_scores=detected_scores, 96 | detected_class_labels=detected_class_labels, 97 | groundtruth_boxes=groundtruth_boxes, 98 | groundtruth_class_labels=groundtruth_class_labels, 99 | groundtruth_is_difficult_list=groundtruth_is_difficult_list, 100 | groundtruth_is_group_of_list=groundtruth_is_group_of_list, 101 | detected_masks=detected_masks, 102 | groundtruth_masks=groundtruth_masks) 103 | 104 | return scores, tp_fp_labels 105 | 106 | def _compute_tp_fp(self, detected_boxes, detected_scores, 107 | detected_class_labels, groundtruth_boxes, 108 | groundtruth_class_labels, groundtruth_is_difficult_list, 109 | groundtruth_is_group_of_list, 110 | detected_masks=None, groundtruth_masks=None): 111 | """Labels true/false positives of detections of an image across all classes. 112 | 113 | Args: 114 | detected_boxes: A float numpy array of shape [N, 4], representing N 115 | regions of detected object regions. 116 | Each row is of the format [y_min, x_min, y_max, x_max] 117 | detected_scores: A float numpy array of shape [N, 1], representing 118 | the confidence scores of the detected N object instances. 119 | detected_class_labels: An integer numpy array of shape [N, 1], representing 120 | the class labels of the detected N object instances. 121 | groundtruth_boxes: A float numpy array of shape [M, 4], representing M 122 | regions of object instances in ground truth 123 | groundtruth_class_labels: An integer numpy array of shape [M, 1], 124 | representing M class labels of object instances in ground truth 125 | groundtruth_is_difficult_list: A boolean numpy array of length M denoting 126 | whether a ground truth box is a difficult instance or not 127 | groundtruth_is_group_of_list: A boolean numpy array of length M denoting 128 | whether a ground truth box has group-of tag 129 | detected_masks: (optional) A np.uint8 numpy array of shape 130 | [N, height, width]. If not None, the scores will be computed based 131 | on masks.
132 | groundtruth_masks: (optional) A np.uint8 numpy array of shape 133 | [M, height, width]. 134 | 135 | Returns: 136 | result_scores: A list of float numpy arrays. Each numpy array is of 137 | shape [K, 1], representing K scores detected with object class 138 | label c 139 | result_tp_fp_labels: A list of boolean numpy array. Each numpy array is of 140 | shape [K, 1], representing K True/False positive label of object 141 | instances detected with class label c 142 | 143 | Raises: 144 | ValueError: If detected masks is not None but groundtruth masks are None, 145 | or the other way around. 146 | """ 147 | if detected_masks is not None and groundtruth_masks is None: 148 | raise ValueError( 149 | 'Detected masks is available but groundtruth masks is not.') 150 | if detected_masks is None and groundtruth_masks is not None: 151 | raise ValueError( 152 | 'Groundtruth masks is available but detected masks is not.') 153 | 154 | result_scores = [] 155 | result_tp_fp_labels = [] 156 | for i in range(self.num_groundtruth_classes): 157 | groundtruth_is_difficult_list_at_ith_class = ( 158 | groundtruth_is_difficult_list[groundtruth_class_labels == i]) 159 | groundtruth_is_group_of_list_at_ith_class = ( 160 | groundtruth_is_group_of_list[groundtruth_class_labels == i]) 161 | (gt_boxes_at_ith_class, gt_masks_at_ith_class, 162 | detected_boxes_at_ith_class, detected_scores_at_ith_class, 163 | detected_masks_at_ith_class) = self._get_ith_class_arrays( 164 | detected_boxes, detected_scores, detected_masks, 165 | detected_class_labels, groundtruth_boxes, groundtruth_masks, 166 | groundtruth_class_labels, i) 167 | scores, tp_fp_labels = self._compute_tp_fp_for_single_class( 168 | detected_boxes=detected_boxes_at_ith_class, 169 | detected_scores=detected_scores_at_ith_class, 170 | groundtruth_boxes=gt_boxes_at_ith_class, 171 | groundtruth_is_difficult_list= 172 | groundtruth_is_difficult_list_at_ith_class, 173 | groundtruth_is_group_of_list= 174 | groundtruth_is_group_of_list_at_ith_class, 175 | detected_masks=detected_masks_at_ith_class, 176 | groundtruth_masks=gt_masks_at_ith_class) 177 | result_scores.append(scores) 178 | result_tp_fp_labels.append(tp_fp_labels) 179 | return result_scores, result_tp_fp_labels 180 | 181 | def _get_overlaps_and_scores_box_mode( 182 | self, 183 | detected_boxes, 184 | detected_scores, 185 | groundtruth_boxes, 186 | groundtruth_is_group_of_list): 187 | """Computes overlaps and scores between detected and groudntruth boxes. 188 | 189 | Args: 190 | detected_boxes: A numpy array of shape [N, 4] representing detected box 191 | coordinates 192 | detected_scores: A 1-d numpy array of length N representing classification 193 | score 194 | groundtruth_boxes: A numpy array of shape [M, 4] representing ground truth 195 | box coordinates 196 | groundtruth_is_group_of_list: A boolean numpy array of length M denoting 197 | whether a ground truth box has group-of tag. If a groundtruth box 198 | is group-of box, every detection matching this box is ignored. 199 | 200 | Returns: 201 | iou: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If 202 | gt_non_group_of_boxlist.num_boxes() == 0 it will be None. 203 | ioa: A float numpy array of size [num_detected_boxes, num_gt_boxes]. If 204 | gt_group_of_boxlist.num_boxes() == 0 it will be None. 205 | scores: The score of the detected boxlist. 206 | num_boxes: Number of non-maximum suppressed detected boxes. 
207 | """ 208 | detected_boxlist = np_box_list.BoxList(detected_boxes) 209 | detected_boxlist.add_field('scores', detected_scores) 210 | gt_non_group_of_boxlist = np_box_list.BoxList( 211 | groundtruth_boxes[~groundtruth_is_group_of_list]) 212 | iou = np_box_list_ops.iou(detected_boxlist, gt_non_group_of_boxlist) 213 | scores = detected_boxlist.get_field('scores') 214 | num_boxes = detected_boxlist.num_boxes() 215 | return iou, None, scores, num_boxes 216 | 217 | def _compute_tp_fp_for_single_class( 218 | self, detected_boxes, detected_scores, groundtruth_boxes, 219 | groundtruth_is_difficult_list, groundtruth_is_group_of_list, 220 | detected_masks=None, groundtruth_masks=None): 221 | """Labels boxes detected with the same class from the same image as tp/fp. 222 | 223 | Args: 224 | detected_boxes: A numpy array of shape [N, 4] representing detected box 225 | coordinates 226 | detected_scores: A 1-d numpy array of length N representing classification 227 | score 228 | groundtruth_boxes: A numpy array of shape [M, 4] representing ground truth 229 | box coordinates 230 | groundtruth_is_difficult_list: A boolean numpy array of length M denoting 231 | whether a ground truth box is a difficult instance or not. If a 232 | groundtruth box is difficult, every detection matching this box 233 | is ignored. 234 | groundtruth_is_group_of_list: A boolean numpy array of length M denoting 235 | whether a ground truth box has group-of tag. If a groundtruth box 236 | is group-of box, every detection matching this box is ignored. 237 | detected_masks: (optional) A uint8 numpy array of shape 238 | [N, height, width]. If not None, the scores will be computed based 239 | on masks. 240 | groundtruth_masks: (optional) A uint8 numpy array of shape 241 | [M, height, width]. 242 | 243 | Returns: 244 | Two arrays of the same size, containing all boxes that were evaluated as 245 | being true positives or false positives; if a box matched to a difficult 246 | box or to a group-of box, it is ignored. 247 | 248 | scores: A numpy array representing the detection scores. 249 | tp_fp_labels: a boolean numpy array indicating whether a detection is a 250 | true positive. 251 | """ 252 | if detected_boxes.size == 0: 253 | return np.array([], dtype=float), np.array([], dtype=bool) 254 | 255 | (iou, _, scores, 256 | num_detected_boxes) = self._get_overlaps_and_scores_box_mode( 257 | detected_boxes=detected_boxes, 258 | detected_scores=detected_scores, 259 | groundtruth_boxes=groundtruth_boxes, 260 | groundtruth_is_group_of_list=groundtruth_is_group_of_list) 261 | 262 | if groundtruth_boxes.size == 0: 263 | return scores, np.zeros(num_detected_boxes, dtype=bool) 264 | 265 | tp_fp_labels = np.zeros(num_detected_boxes, dtype=bool) 266 | is_matched_to_difficult_box = np.zeros(num_detected_boxes, dtype=bool) 267 | is_matched_to_group_of_box = np.zeros(num_detected_boxes, dtype=bool) 268 | 269 | # The evaluation is done in two stages: 270 | # 1. All detections are matched to non group-of boxes; true positives are 271 | # determined and detections matched to difficult boxes are ignored. 272 | # 2. Detections that are determined as false positives are matched against 273 | # group-of boxes and ignored if matched. 274 | 275 | # Tp-fp evaluation for non-group of boxes (if any). 
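# Editorial clarification (comment added for readability; not part of the upstream TensorFlow code): the loop below matches each detection to its single highest-IoU non-group-of ground-truth box. A detection is labeled a true positive only when that IoU reaches matching_iou_threshold, the matched box is not flagged difficult, and the box has not already been claimed by a previous detection in the loop; otherwise the detection stays a false positive or, if it matched a difficult box, is ignored and filtered out of the returned arrays.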
276 | if iou.shape[1] > 0: 277 | groundtruth_nongroup_of_is_difficult_list = groundtruth_is_difficult_list[ 278 | ~groundtruth_is_group_of_list] 279 | max_overlap_gt_ids = np.argmax(iou, axis=1) 280 | is_gt_box_detected = np.zeros(iou.shape[1], dtype=bool) 281 | for i in range(num_detected_boxes): 282 | gt_id = max_overlap_gt_ids[i] 283 | if iou[i, gt_id] >= self.matching_iou_threshold: 284 | if not groundtruth_nongroup_of_is_difficult_list[gt_id]: 285 | if not is_gt_box_detected[gt_id]: 286 | tp_fp_labels[i] = True 287 | is_gt_box_detected[gt_id] = True 288 | else: 289 | is_matched_to_difficult_box[i] = True 290 | 291 | return scores[~is_matched_to_difficult_box 292 | & ~is_matched_to_group_of_box], tp_fp_labels[ 293 | ~is_matched_to_difficult_box 294 | & ~is_matched_to_group_of_box] 295 | 296 | def _get_ith_class_arrays(self, detected_boxes, detected_scores, 297 | detected_masks, detected_class_labels, 298 | groundtruth_boxes, groundtruth_masks, 299 | groundtruth_class_labels, class_index): 300 | """Returns numpy arrays belonging to class with index `class_index`. 301 | 302 | Args: 303 | detected_boxes: A numpy array containing detected boxes. 304 | detected_scores: A numpy array containing detected scores. 305 | detected_masks: A numpy array containing detected masks. 306 | detected_class_labels: A numpy array containing detected class labels. 307 | groundtruth_boxes: A numpy array containing groundtruth boxes. 308 | groundtruth_masks: A numpy array containing groundtruth masks. 309 | groundtruth_class_labels: A numpy array containing groundtruth class 310 | labels. 311 | class_index: An integer index. 312 | 313 | Returns: 314 | gt_boxes_at_ith_class: A numpy array containing groundtruth boxes labeled 315 | as ith class. 316 | gt_masks_at_ith_class: A numpy array containing groundtruth masks labeled 317 | as ith class. 318 | detected_boxes_at_ith_class: A numpy array containing detected boxes 319 | corresponding to the ith class. 320 | detected_scores_at_ith_class: A numpy array containing detected scores 321 | corresponding to the ith class. 322 | detected_masks_at_ith_class: A numpy array containing detected masks 323 | corresponding to the ith class. 324 | """ 325 | selected_groundtruth = (groundtruth_class_labels == class_index) 326 | gt_boxes_at_ith_class = groundtruth_boxes[selected_groundtruth] 327 | if groundtruth_masks is not None: 328 | gt_masks_at_ith_class = groundtruth_masks[selected_groundtruth] 329 | else: 330 | gt_masks_at_ith_class = None 331 | selected_detections = (detected_class_labels == class_index) 332 | detected_boxes_at_ith_class = detected_boxes[selected_detections] 333 | detected_scores_at_ith_class = detected_scores[selected_detections] 334 | if detected_masks is not None: 335 | detected_masks_at_ith_class = detected_masks[selected_detections] 336 | else: 337 | detected_masks_at_ith_class = None 338 | return (gt_boxes_at_ith_class, gt_masks_at_ith_class, 339 | detected_boxes_at_ith_class, detected_scores_at_ith_class, 340 | detected_masks_at_ith_class) 341 | 342 | def _remove_invalid_boxes(self, detected_boxes, detected_scores, 343 | detected_class_labels, detected_masks=None): 344 | """Removes entries with invalid boxes. 345 | 346 | A box is invalid if either its xmax is smaller than its xmin, or its ymax 347 | is smaller than its ymin. 348 | 349 | Args: 350 | detected_boxes: A float numpy array of size [num_boxes, 4] containing box 351 | coordinates in [ymin, xmin, ymax, xmax] format. 352 | detected_scores: A float numpy array of size [num_boxes]. 
353 | detected_class_labels: A int32 numpy array of size [num_boxes]. 354 | detected_masks: A uint8 numpy array of size [num_boxes, height, width]. 355 | 356 | Returns: 357 | valid_detected_boxes: A float numpy array of size [num_valid_boxes, 4] 358 | containing box coordinates in [ymin, xmin, ymax, xmax] format. 359 | valid_detected_scores: A float numpy array of size [num_valid_boxes]. 360 | valid_detected_class_labels: A int32 numpy array of size 361 | [num_valid_boxes]. 362 | valid_detected_masks: A uint8 numpy array of size 363 | [num_valid_boxes, height, width]. 364 | """ 365 | valid_indices = np.logical_and(detected_boxes[:, 0] < detected_boxes[:, 2], 366 | detected_boxes[:, 1] < detected_boxes[:, 3]) 367 | detected_boxes = detected_boxes[valid_indices] 368 | detected_scores = detected_scores[valid_indices] 369 | detected_class_labels = detected_class_labels[valid_indices] 370 | if detected_masks is not None: 371 | detected_masks = detected_masks[valid_indices] 372 | return [ 373 | detected_boxes, detected_scores, detected_class_labels, detected_masks 374 | ] 375 | -------------------------------------------------------------------------------- /gravit/utils/ava/standard_fields.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Contains classes specifying naming conventions used for object detection. 17 | 18 | 19 | Specifies: 20 | InputDataFields: standard fields used by reader/preprocessor/batcher. 21 | DetectionResultFields: standard fields returned by object detector. 22 | BoxListFields: standard field used by BoxList 23 | TfExampleFields: standard fields for tf-example data format (go/tf-example). 24 | """ 25 | 26 | 27 | class InputDataFields(object): 28 | """Names for the input tensors. 29 | 30 | Holds the standard data field names to use for identifying input tensors. This 31 | should be used by the decoder to identify keys for the returned tensor_dict 32 | containing input tensors. And it should be used by the model to identify the 33 | tensors it needs. 34 | 35 | Attributes: 36 | image: image. 37 | original_image: image in the original input size. 38 | key: unique key corresponding to image. 39 | source_id: source of the original image. 40 | filename: original filename of the dataset (without common path). 41 | groundtruth_image_classes: image-level class labels. 42 | groundtruth_boxes: coordinates of the ground truth boxes in the image. 43 | groundtruth_classes: box-level class labels. 44 | groundtruth_label_types: box-level label types (e.g. explicit negative). 45 | groundtruth_is_crowd: [DEPRECATED, use groundtruth_group_of instead] 46 | is the groundtruth a single object or a crowd. 47 | groundtruth_area: area of a groundtruth segment. 
48 | groundtruth_difficult: is a `difficult` object 49 | groundtruth_group_of: is a `group_of` objects, e.g. multiple objects of the 50 | same class, forming a connected group, where instances are heavily 51 | occluding each other. 52 | proposal_boxes: coordinates of object proposal boxes. 53 | proposal_objectness: objectness score of each proposal. 54 | groundtruth_instance_masks: ground truth instance masks. 55 | groundtruth_instance_boundaries: ground truth instance boundaries. 56 | groundtruth_instance_classes: instance mask-level class labels. 57 | groundtruth_keypoints: ground truth keypoints. 58 | groundtruth_keypoint_visibilities: ground truth keypoint visibilities. 59 | groundtruth_label_scores: groundtruth label scores. 60 | groundtruth_weights: groundtruth weight factor for bounding boxes. 61 | num_groundtruth_boxes: number of groundtruth boxes. 62 | true_image_shapes: true shapes of images in the resized images, as resized 63 | images can be padded with zeros. 64 | """ 65 | image = 'image' 66 | original_image = 'original_image' 67 | key = 'key' 68 | source_id = 'source_id' 69 | filename = 'filename' 70 | groundtruth_image_classes = 'groundtruth_image_classes' 71 | groundtruth_boxes = 'groundtruth_boxes' 72 | groundtruth_classes = 'groundtruth_classes' 73 | groundtruth_label_types = 'groundtruth_label_types' 74 | groundtruth_is_crowd = 'groundtruth_is_crowd' 75 | groundtruth_area = 'groundtruth_area' 76 | groundtruth_difficult = 'groundtruth_difficult' 77 | groundtruth_group_of = 'groundtruth_group_of' 78 | proposal_boxes = 'proposal_boxes' 79 | proposal_objectness = 'proposal_objectness' 80 | groundtruth_instance_masks = 'groundtruth_instance_masks' 81 | groundtruth_instance_boundaries = 'groundtruth_instance_boundaries' 82 | groundtruth_instance_classes = 'groundtruth_instance_classes' 83 | groundtruth_keypoints = 'groundtruth_keypoints' 84 | groundtruth_keypoint_visibilities = 'groundtruth_keypoint_visibilities' 85 | groundtruth_label_scores = 'groundtruth_label_scores' 86 | groundtruth_weights = 'groundtruth_weights' 87 | num_groundtruth_boxes = 'num_groundtruth_boxes' 88 | true_image_shape = 'true_image_shape' 89 | 90 | 91 | class DetectionResultFields(object): 92 | """Naming conventions for storing the output of the detector. 93 | 94 | Attributes: 95 | source_id: source of the original image. 96 | key: unique key corresponding to image. 97 | detection_boxes: coordinates of the detection boxes in the image. 98 | detection_scores: detection scores for the detection boxes in the image. 99 | detection_classes: detection-level class labels. 100 | detection_masks: contains a segmentation mask for each detection box. 101 | detection_boundaries: contains an object boundary for each detection box. 102 | detection_keypoints: contains detection keypoints for each detection box. 103 | num_detections: number of detections in the batch. 104 | """ 105 | 106 | source_id = 'source_id' 107 | key = 'key' 108 | detection_boxes = 'detection_boxes' 109 | detection_scores = 'detection_scores' 110 | detection_classes = 'detection_classes' 111 | detection_masks = 'detection_masks' 112 | detection_boundaries = 'detection_boundaries' 113 | detection_keypoints = 'detection_keypoints' 114 | num_detections = 'num_detections' 115 | 116 | 117 | class BoxListFields(object): 118 | """Naming conventions for BoxLists. 119 | 120 | Attributes: 121 | boxes: bounding box coordinates. 122 | classes: classes per bounding box. 123 | scores: scores per bounding box. 124 | weights: sample weights per bounding box. 
125 | objectness: objectness score per bounding box. 126 | masks: masks per bounding box. 127 | boundaries: boundaries per bounding box. 128 | keypoints: keypoints per bounding box. 129 | keypoint_heatmaps: keypoint heatmaps per bounding box. 130 | """ 131 | boxes = 'boxes' 132 | classes = 'classes' 133 | scores = 'scores' 134 | weights = 'weights' 135 | objectness = 'objectness' 136 | masks = 'masks' 137 | boundaries = 'boundaries' 138 | keypoints = 'keypoints' 139 | keypoint_heatmaps = 'keypoint_heatmaps' 140 | 141 | 142 | class TfExampleFields(object): 143 | """TF-example proto feature names for object detection. 144 | 145 | Holds the standard feature names to load from an Example proto for object 146 | detection. 147 | 148 | Attributes: 149 | image_encoded: JPEG encoded string 150 | image_format: image format, e.g. "JPEG" 151 | filename: filename 152 | channels: number of channels of image 153 | colorspace: colorspace, e.g. "RGB" 154 | height: height of image in pixels, e.g. 462 155 | width: width of image in pixels, e.g. 581 156 | source_id: original source of the image 157 | object_class_text: labels in text format, e.g. ["person", "cat"] 158 | object_class_label: labels in numbers, e.g. [16, 8] 159 | object_bbox_xmin: xmin coordinates of groundtruth box, e.g. 10, 30 160 | object_bbox_xmax: xmax coordinates of groundtruth box, e.g. 50, 40 161 | object_bbox_ymin: ymin coordinates of groundtruth box, e.g. 40, 50 162 | object_bbox_ymax: ymax coordinates of groundtruth box, e.g. 80, 70 163 | object_view: viewpoint of object, e.g. ["frontal", "left"] 164 | object_truncated: is object truncated, e.g. [true, false] 165 | object_occluded: is object occluded, e.g. [true, false] 166 | object_difficult: is object difficult, e.g. [true, false] 167 | object_group_of: is object a single object or a group of objects 168 | object_depiction: is object a depiction 169 | object_is_crowd: [DEPRECATED, use object_group_of instead] 170 | is the object a single object or a crowd 171 | object_segment_area: the area of the segment. 172 | object_weight: a weight factor for the object's bounding box. 173 | instance_masks: instance segmentation masks. 174 | instance_boundaries: instance boundaries. 175 | instance_classes: Classes for each instance segmentation mask. 176 | detection_class_label: class label in numbers. 177 | detection_bbox_ymin: ymin coordinates of a detection box. 178 | detection_bbox_xmin: xmin coordinates of a detection box. 179 | detection_bbox_ymax: ymax coordinates of a detection box. 180 | detection_bbox_xmax: xmax coordinates of a detection box. 181 | detection_score: detection score for the class label and box. 
182 | """ 183 | image_encoded = 'image/encoded' 184 | image_format = 'image/format' # format is reserved keyword 185 | filename = 'image/filename' 186 | channels = 'image/channels' 187 | colorspace = 'image/colorspace' 188 | height = 'image/height' 189 | width = 'image/width' 190 | source_id = 'image/source_id' 191 | object_class_text = 'image/object/class/text' 192 | object_class_label = 'image/object/class/label' 193 | object_bbox_ymin = 'image/object/bbox/ymin' 194 | object_bbox_xmin = 'image/object/bbox/xmin' 195 | object_bbox_ymax = 'image/object/bbox/ymax' 196 | object_bbox_xmax = 'image/object/bbox/xmax' 197 | object_view = 'image/object/view' 198 | object_truncated = 'image/object/truncated' 199 | object_occluded = 'image/object/occluded' 200 | object_difficult = 'image/object/difficult' 201 | object_group_of = 'image/object/group_of' 202 | object_depiction = 'image/object/depiction' 203 | object_is_crowd = 'image/object/is_crowd' 204 | object_segment_area = 'image/object/segment/area' 205 | object_weight = 'image/object/weight' 206 | instance_masks = 'image/segmentation/object' 207 | instance_boundaries = 'image/boundaries/object' 208 | instance_classes = 'image/segmentation/object/class' 209 | detection_class_label = 'image/detection/label' 210 | detection_bbox_ymin = 'image/detection/bbox/ymin' 211 | detection_bbox_xmin = 'image/detection/bbox/xmin' 212 | detection_bbox_ymax = 'image/detection/bbox/ymax' 213 | detection_bbox_xmax = 'image/detection/bbox/xmax' 214 | detection_score = 'image/detection/score' 215 | -------------------------------------------------------------------------------- /gravit/utils/eval_tool.py: -------------------------------------------------------------------------------- 1 | # This code is based the official ActivityNet repository: https://github.com/activitynet/ActivityNet 2 | # The owner of the official ActivityNet repository: ActivityNet 3 | # Copyright (c) 2015 ActivityNet 4 | # Licensed under The MIT License 5 | # Please refer to https://github.com/activitynet/ActivityNet/blob/master/LICENSE 6 | 7 | import os 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from collections import defaultdict 12 | import csv 13 | import decimal 14 | import heapq 15 | import h5py 16 | from .ava import object_detection_evaluation 17 | from .ava import standard_fields 18 | from .vs import knapsack 19 | from scipy import stats 20 | from scipy.stats import rankdata 21 | 22 | def compute_average_precision(precision, recall): 23 | """Compute Average Precision according to the definition in VOCdevkit. 24 | Precision is modified to ensure that it does not decrease as recall 25 | decrease. 26 | Args: 27 | precision: A float [N, 1] numpy array of precisions 28 | recall: A float [N, 1] numpy array of recalls 29 | Raises: 30 | ValueError: if the input is not of the correct format 31 | Returns: 32 | average_precison: The area under the precision recall curve. NaN if 33 | precision and recall are None. 
34 | """ 35 | if precision is None: 36 | if recall is not None: 37 | raise ValueError("If precision is None, recall must also be None") 38 | return np.NAN 39 | 40 | if not isinstance(precision, np.ndarray) or not isinstance( 41 | recall, np.ndarray): 42 | raise ValueError("precision and recall must be numpy array") 43 | if precision.dtype != float or recall.dtype != float: 44 | raise ValueError("input must be float numpy array.") 45 | if len(precision) != len(recall): 46 | raise ValueError("precision and recall must be of the same size.") 47 | if not precision.size: 48 | return 0.0 49 | if np.amin(precision) < 0 or np.amax(precision) > 1: 50 | raise ValueError("Precision must be in the range of [0, 1].") 51 | if np.amin(recall) < 0 or np.amax(recall) > 1: 52 | raise ValueError("recall must be in the range of [0, 1].") 53 | if not all(recall[i] <= recall[i + 1] for i in range(len(recall) - 1)): 54 | raise ValueError("recall must be a non-decreasing array") 55 | 56 | recall = np.concatenate([[0], recall, [1]]) 57 | precision = np.concatenate([[0], precision, [0]]) 58 | 59 | # Smooth precision to be monotonically decreasing. 60 | for i in range(len(precision) - 2, -1, -1): 61 | precision[i] = np.maximum(precision[i], precision[i + 1]) 62 | 63 | indices = np.where(recall[1:] != recall[:-1])[0] + 1 64 | average_precision = np.sum( 65 | (recall[indices] - recall[indices - 1]) * precision[indices]) 66 | return average_precision 67 | 68 | 69 | def load_csv(filename, column_names): 70 | """Loads CSV from the filename using given column names. 71 | Adds uid column. 72 | Args: 73 | filename: Path to the CSV file to load. 74 | column_names: A list of column names for the data. 75 | Returns: 76 | df: A Pandas DataFrame containing the data. 77 | """ 78 | # Here and elsewhere, df indicates a DataFrame variable. 79 | df = pd.read_csv(filename, header=None, names=column_names) 80 | # Creates a unique id from frame timestamp and entity id. 81 | df["uid"] = (df["frame_timestamp"].map(str) + ":" + df["entity_id"]) 82 | return df 83 | 84 | 85 | def eq(a, b, tolerance=1e-09): 86 | """Returns true if values are approximately equal.""" 87 | return abs(a - b) <= tolerance 88 | 89 | 90 | def merge_groundtruth_and_predictions(df_groundtruth, df_predictions): 91 | """Merges groundtruth and prediction DataFrames. 92 | The returned DataFrame is merged on uid field and sorted in descending order 93 | by score field. Bounding boxes are checked to make sure they match between 94 | groundtruth and predictions. 95 | Args: 96 | df_groundtruth: A DataFrame with groundtruth data. 97 | df_predictions: A DataFrame with predictions data. 98 | Returns: 99 | df_merged: A merged DataFrame, with rows matched on uid column. 100 | """ 101 | if df_groundtruth["uid"].count() != df_predictions["uid"].count(): 102 | raise ValueError( 103 | "Groundtruth and predictions CSV must have the same number of " 104 | "unique rows.") 105 | 106 | if df_predictions["label"].unique() != ["SPEAKING_AUDIBLE"]: 107 | raise ValueError( 108 | "Predictions CSV must contain only SPEAKING_AUDIBLE label.") 109 | 110 | if df_predictions["score"].count() < df_predictions["uid"].count(): 111 | raise ValueError("Predictions CSV must contain score value for every row.") 112 | 113 | # Merges groundtruth and predictions on uid, validates that uid is unique 114 | # in both frames, and sorts the resulting frame by the predictions score. 
115 | df_merged = df_groundtruth.merge( 116 | df_predictions, 117 | on="uid", 118 | suffixes=("_groundtruth", "_prediction"), 119 | validate="1:1").sort_values( 120 | by=["score"], ascending=False).reset_index() 121 | # Validates that bounding boxes in ground truth and predictions match for the 122 | # same uids. 123 | df_merged["bounding_box_correct"] = np.where( 124 | eq(df_merged["entity_box_x1_groundtruth"], 125 | df_merged["entity_box_x1_prediction"]) 126 | & eq(df_merged["entity_box_x2_groundtruth"], 127 | df_merged["entity_box_x2_prediction"]) 128 | & eq(df_merged["entity_box_y1_groundtruth"], 129 | df_merged["entity_box_y1_prediction"]) 130 | & eq(df_merged["entity_box_y2_groundtruth"], 131 | df_merged["entity_box_y2_prediction"]), True, False) 132 | 133 | if (~df_merged["bounding_box_correct"]).sum() > 0: 134 | raise ValueError( 135 | "Mismatch between groundtruth and predictions bounding boxes found at " 136 | + str(list(df_merged[~df_merged["bounding_box_correct"]]["uid"]))) 137 | 138 | return df_merged 139 | 140 | 141 | def get_all_positives(df_merged): 142 | """Counts all positive examples in the groundtruth dataset.""" 143 | return df_merged[df_merged["label_groundtruth"] == 144 | "SPEAKING_AUDIBLE"]["uid"].count() 145 | 146 | 147 | def calculate_precision_recall(df_merged): 148 | """Calculates precision and recall arrays going through df_merged row-wise.""" 149 | all_positives = get_all_positives(df_merged) 150 | 151 | # Populates each row with 1 if this row is a true positive 152 | # (at its score level). 153 | df_merged["is_tp"] = np.where( 154 | (df_merged["label_groundtruth"] == "SPEAKING_AUDIBLE") & 155 | (df_merged["label_prediction"] == "SPEAKING_AUDIBLE"), 1, 0) 156 | 157 | # Counts true positives up to and including that row. 158 | df_merged["tp"] = df_merged["is_tp"].cumsum() 159 | 160 | # Calculates precision for every row counting true positives up to 161 | # and including that row over the index (1-based) of that row. 162 | df_merged["precision"] = df_merged["tp"] / (df_merged.index + 1) 163 | 164 | # Calculates recall for every row counting true positives up to 165 | # and including that row over all positives in the groundtruth dataset. 166 | df_merged["recall"] = df_merged["tp"] / all_positives 167 | 168 | return np.array(df_merged["precision"]), np.array(df_merged["recall"]) 169 | 170 | 171 | def run_evaluation_asd(predictions, groundtruth): 172 | """Runs AVA Active Speaker evaluation, returns average precision result.""" 173 | column_names=[ 174 | "video_id", "frame_timestamp", "entity_box_x1", "entity_box_y1", 175 | "entity_box_x2", "entity_box_y2", "label", "entity_id" 176 | ] 177 | df_groundtruth = load_csv(groundtruth, column_names=column_names) 178 | df_predictions = pd.DataFrame(predictions, columns=column_names+["score"]) 179 | # Creates a unique id from frame timestamp and entity id. 180 | df_predictions["uid"] = (df_predictions["frame_timestamp"].map(str) + ":" + df_predictions["entity_id"]) 181 | 182 | df_merged = merge_groundtruth_and_predictions(df_groundtruth, df_predictions) 183 | precision, recall = calculate_precision_recall(df_merged) 184 | 185 | return compute_average_precision(precision, recall) 186 | 187 | 188 | def make_image_key(video_id, timestamp): 189 | """Returns a unique identifier for a video id & timestamp.""" 190 | return "%s,%.6f" % (video_id, decimal.Decimal(timestamp)) 191 | 192 | 193 | def read_csv(csv_file, class_whitelist=None, capacity=0): 194 | """Loads boxes and class labels from a CSV file in the AVA format. 
195 | CSV file format described at https://research.google.com/ava/download.html. 196 | Args: 197 | csv_file: A file object. 198 | class_whitelist: If provided, boxes corresponding to (integer) class labels 199 | not in this set are skipped. 200 | capacity: Maximum number of labeled boxes allowed for each example. Default 201 | is 0 where there is no limit. 202 | Returns: 203 | boxes: A dictionary mapping each unique image key (string) to a list of 204 | boxes, given as coordinates [y1, x1, y2, x2]. 205 | labels: A dictionary mapping each unique image key (string) to a list of 206 | integer class labels, matching the corresponding box in `boxes`. 207 | scores: A dictionary mapping each unique image key (string) to a list of 208 | score values, matching the corresponding label in `labels`. If 209 | scores are not provided in the csv, then they will default to 1.0. 210 | all_keys: A set of all image keys found in the csv file. 211 | """ 212 | entries = defaultdict(list) 213 | boxes = defaultdict(list) 214 | labels = defaultdict(list) 215 | scores = defaultdict(list) 216 | all_keys = set() 217 | reader = csv.reader(csv_file) 218 | for row in reader: 219 | assert len(row) in [2, 7, 8], "Wrong number of columns: " + str(row) 220 | image_key = make_image_key(row[0], row[1]) 221 | all_keys.add(image_key) 222 | # Rows with 2 tokens (videoid,timestamp) indicate images with no detected 223 | # / ground truth action boxes. Add them to all_keys, so we can score 224 | # appropriately, but otherwise skip the box creation steps. 225 | if len(row) == 2: 226 | continue 227 | x1, y1, x2, y2 = [float(n) for n in row[2:6]] 228 | action_id = int(row[6]) 229 | if class_whitelist and action_id not in class_whitelist: 230 | continue 231 | score = 1.0 232 | if len(row) == 8: 233 | score = float(row[7]) 234 | if capacity < 1 or len(entries[image_key]) < capacity: 235 | heapq.heappush(entries[image_key], (score, action_id, y1, x1, y2, x2)) 236 | elif score > entries[image_key][0][0]: 237 | heapq.heapreplace(entries[image_key], (score, action_id, y1, x1, y2, x2)) 238 | for image_key in entries: 239 | # Evaluation API assumes boxes with descending scores 240 | entry = sorted(entries[image_key], key=lambda tup: -tup[0]) 241 | for item in entry: 242 | score, action_id, y1, x1, y2, x2 = item 243 | boxes[image_key].append([y1, x1, y2, x2]) 244 | labels[image_key].append(action_id) 245 | scores[image_key].append(score) 246 | return boxes, labels, scores, all_keys 247 | 248 | 249 | def read_detections(detections, class_whitelist, capacity=50): 250 | """ 251 | Loads boxes and class labels from a list of detections in the AVA format.
252 | """ 253 | entries = defaultdict(list) 254 | boxes = defaultdict(list) 255 | labels = defaultdict(list) 256 | scores = defaultdict(list) 257 | for row in detections: 258 | image_key = make_image_key(row[0], row[1]) 259 | x1, y1, x2, y2 = row[2:6] 260 | action_id = int(row[6]) 261 | if class_whitelist and action_id not in class_whitelist: 262 | continue 263 | score = float(row[7]) 264 | if capacity < 1 or len(entries[image_key]) < capacity: 265 | heapq.heappush(entries[image_key], (score, action_id, y1, x1, y2, x2)) 266 | elif score > entries[image_key][0][0]: 267 | heapq.heapreplace(entries[image_key], (score, action_id, y1, x1, y2, x2)) 268 | for image_key in entries: 269 | # Evaluation API assumes boxes with descending scores 270 | entry = sorted(entries[image_key], key=lambda tup: -tup[0]) 271 | for item in entry: 272 | score, action_id, y1, x1, y2, x2 = item 273 | boxes[image_key].append([y1, x1, y2, x2]) 274 | labels[image_key].append(action_id) 275 | scores[image_key].append(score) 276 | return boxes, labels, scores 277 | 278 | 279 | def read_labelmap(labelmap_file): 280 | """Reads a labelmap without the dependency on protocol buffers. 281 | Args: 282 | labelmap_file: A file object containing a label map protocol buffer. 283 | Returns: 284 | labelmap: The label map in the form used by the object_detection_evaluation 285 | module - a list of {"id": integer, "name": classname } dicts. 286 | class_ids: A set containing all of the valid class id integers. 287 | """ 288 | labelmap = [] 289 | class_ids = set() 290 | name = "" 291 | class_id = "" 292 | for line in labelmap_file: 293 | if line.startswith(" name:"): 294 | name = line.split('"')[1] 295 | elif line.startswith(" id:") or line.startswith(" label_id:"): 296 | class_id = int(line.strip().split(" ")[-1]) 297 | labelmap.append({"id": class_id, "name": name}) 298 | class_ids.add(class_id) 299 | return labelmap, class_ids 300 | 301 | 302 | def run_evaluation_al(detections, groundtruth, labelmap): 303 | """ 304 | Runs AVA Actions evaluation, returns mean average precision result 305 | """ 306 | with open(labelmap, 'r') as f: 307 | categories, class_whitelist = read_labelmap(f) 308 | 309 | pascal_evaluator = object_detection_evaluation.PascalDetectionEvaluator(categories) 310 | 311 | # Reads the ground truth data. 312 | with open(groundtruth, 'r') as f: 313 | boxes, labels, _, included_keys = read_csv(f, class_whitelist) 314 | for image_key in boxes: 315 | pascal_evaluator.add_single_ground_truth_image_info( 316 | image_key, { 317 | standard_fields.InputDataFields.groundtruth_boxes: 318 | np.array(boxes[image_key], dtype=float), 319 | standard_fields.InputDataFields.groundtruth_classes: 320 | np.array(labels[image_key], dtype=int), 321 | standard_fields.InputDataFields.groundtruth_difficult: 322 | np.zeros(len(boxes[image_key]), dtype=bool) 323 | }) 324 | 325 | # Reads detections data. 
326 | boxes, labels, scores = read_detections(detections, class_whitelist) 327 | for image_key in boxes: 328 | if image_key not in included_keys: 329 | continue 330 | pascal_evaluator.add_single_detected_image_info( 331 | image_key, { 332 | standard_fields.DetectionResultFields.detection_boxes: 333 | np.array(boxes[image_key], dtype=float), 334 | standard_fields.DetectionResultFields.detection_classes: 335 | np.array(labels[image_key], dtype=int), 336 | standard_fields.DetectionResultFields.detection_scores: 337 | np.array(scores[image_key], dtype=float) 338 | }) 339 | 340 | metrics = pascal_evaluator.evaluate() 341 | return metrics['PascalBoxes_Precision/mAP@0.5IOU'] 342 | 343 | 344 | def get_class_start_end_times(result): 345 | """ 346 | Return the classes and their corresponding start and end times 347 | """ 348 | last_class = result[0] 349 | classes = [last_class] 350 | starts = [0] 351 | ends = [] 352 | 353 | for i, c in enumerate(result): 354 | if c != last_class: 355 | classes.append(c) 356 | starts.append(i) 357 | ends.append(i) 358 | last_class = c 359 | 360 | ends.append(len(result)-1) 361 | 362 | return classes, starts, ends 363 | 364 | 365 | def compare_segmentation(pred, label, th): 366 | """ 367 | Temporally compare the predicted and ground-truth segmentations 368 | """ 369 | 370 | pc, ps, pe = get_class_start_end_times(pred) 371 | lc, ls, le = get_class_start_end_times(label) 372 | 373 | tp = 0 374 | fp = 0 375 | matched = [0]*len(lc) 376 | for i in range(len(pc)): 377 | inter = np.minimum(pe[i], le) - np.maximum(ps[i], ls) 378 | union = np.maximum(pe[i], le) - np.minimum(ps[i], ls) 379 | IoU = (inter/union) * [pc[i] == lc[j] for j in range(len(lc))] 380 | 381 | best_idx = np.array(IoU).argmax() 382 | if IoU[best_idx] >= th and not matched[best_idx]: 383 | tp += 1 384 | matched[best_idx] = 1 385 | else: 386 | fp += 1 387 | 388 | fn = len(lc) - sum(matched) 389 | 390 | return tp, fp, fn 391 | 392 | 393 | def get_eval_score(cfg, preds): 394 | """ 395 | Compute the evaluation score 396 | """ 397 | 398 | # Path to the annotations 399 | path_annts = os.path.join(cfg['root_data'], 'annotations') 400 | 401 | eval_type = cfg['eval_type'] 402 | str_score = "" 403 | if eval_type == 'AVA_ASD': 404 | groundtruth = os.path.join(path_annts, 'ava_activespeaker_val_v1.0.csv') 405 | score = run_evaluation_asd(preds, groundtruth) 406 | str_score = f'{score*100:.2f}%' 407 | elif eval_type == 'AVA_AL': 408 | groundtruth = os.path.join(path_annts, 'ava_val_v2.2.csv') 409 | labelmap = os.path.join(path_annts, 'ava_action_list_v2.2_for_activitynet_2019.pbtxt') 410 | score = run_evaluation_al(preds, groundtruth, labelmap) 411 | str_score = f'{score*100:.2f}%' 412 | elif eval_type == 'AS': 413 | total = 0 414 | correct = 0 415 | threshold = [0.1, 0.25, 0.5] 416 | tp, fp, fn = [0]*len(threshold), [0]*len(threshold), [0]*len(threshold) 417 | 418 | for video_id, pred in preds: 419 | # Get a list of ground-truth action labels 420 | with open(os.path.join(path_annts, f'{cfg["dataset"]}/groundTruth/{video_id}.txt')) as f: 421 | label = [line.strip() for line in f] 422 | 423 | total += len(label) 424 | for i, lb in enumerate(label): 425 | if pred[i] == lb: 426 | correct += 1 427 | 428 | for i, th in enumerate(threshold): 429 | tp_, fp_, fn_ = compare_segmentation(pred, label, th) 430 | tp[i] += tp_ 431 | fp[i] += fp_ 432 | fn[i] += fn_ 433 | 434 | acc = correct/total 435 | str_score = f'(Acc) {acc*100:.2f}%' 436 | for i, th in enumerate(threshold): 437 | pre = tp[i] / (tp[i]+fp[i]) 438 | rec = tp[i] / 
(tp[i]+fn[i]) 439 | f1 = np.nan_to_num(2*pre*rec / (pre+rec)) 440 | str_score += f', (F1@{th}) {f1*100:.2f}%' 441 | elif eval_type == "VS_max" or eval_type == "VS_avg": 442 | 443 | path_dataset = os.path.join(cfg['root_data'], f'annotations/{cfg["dataset"]}/eccv16_dataset_{cfg["dataset"].lower()}_google_pool5.h5') 444 | with h5py.File(path_dataset, 'r') as hdf: 445 | 446 | all_f1_scores = [] 447 | all_taus = [] 448 | all_rhos = [] 449 | for video, scores in preds: 450 | 451 | n_samples = hdf.get(video + '/n_steps')[()] 452 | n_frames = hdf.get(video + '/n_frames')[()] 453 | gt_segments = np.array(hdf.get(video + '/change_points')) 454 | gt_samples = np.array(hdf.get(video + '/picks')) 455 | gt_scores = np.array(hdf.get(video + '/gtscore')) 456 | user_summaries = np.array(hdf.get(video + '/user_summary')) 457 | 458 | # Take scores from sampled frames to all frames 459 | gt_samples = np.append(gt_samples, [n_frames - 1]) # To account for last frames within loop 460 | frame_scores = np.zeros(n_frames, dtype=np.float32) 461 | for idx in range(n_samples): 462 | frame_scores[gt_samples[idx]:gt_samples[idx + 1]] = scores[idx] 463 | 464 | # Calculate segments' avg score and length 465 | # (Segment_X = video[frame_A:frame_B]) 466 | n_segments = len(gt_segments) 467 | s_scores = np.empty(n_segments) 468 | s_lengths = np.empty(n_segments, dtype=np.int32) 469 | for idx in range(n_segments): 470 | s_lengths[idx] = gt_segments[idx][1] - gt_segments[idx][0] + 1 471 | s_scores[idx] = (frame_scores[gt_segments[idx][0]:gt_segments[idx][1]].mean()) 472 | 473 | # Select for max importance 474 | final_len = int(n_frames * 0.15) # 15% of total length 475 | segments = knapsack.fill_knapsack(final_len, s_scores, s_lengths) 476 | 477 | # Mark frames from selected segments 478 | sum_segs = np.zeros((len(segments), 2), dtype=int) 479 | pred_summary = np.zeros(n_frames, dtype=np.int8) 480 | for i, seg in enumerate(segments): 481 | pred_summary[gt_segments[seg][0]:gt_segments[seg][1]] = 1 482 | sum_segs[i][0] = gt_segments[seg][0] 483 | sum_segs[i][1] = gt_segments[seg][1] 484 | 485 | # Calculate F1-Score per user summary 486 | user_summary = np.zeros(n_frames, dtype=np.int8) 487 | n_user_sums = user_summaries.shape[0] 488 | f1_scores = np.empty(n_user_sums) 489 | 490 | for u_sum_idx in range(n_user_sums): 491 | user_summary[:n_frames] = user_summaries[u_sum_idx] 492 | 493 | # F-1 494 | tp = pred_summary & user_summary 495 | precision = sum(tp)/sum(pred_summary) 496 | recall = sum(tp)/sum(user_summary) 497 | 498 | if (precision + recall) == 0: 499 | f1_scores[u_sum_idx] = 0 500 | else: 501 | f1_scores[u_sum_idx] = (2 * precision * recall * 100 / (precision + recall)) 502 | 503 | # Correlation Metrics 504 | pred_imp_score = np.array(scores) 505 | ref_imp_scores = gt_scores 506 | rho_coeff, _ = stats.spearmanr(pred_imp_score, ref_imp_scores) 507 | tau_coeff, _ = stats.kendalltau(rankdata(-pred_imp_score), rankdata(-ref_imp_scores)) 508 | 509 | all_taus.append(tau_coeff) 510 | all_rhos.append(rho_coeff) 511 | 512 | # Calculate one F1-Score from all user summaries 513 | if eval_type == "VS_max": 514 | f1 = max(f1_scores) 515 | else: 516 | f1 = np.mean(f1_scores) 517 | 518 | all_f1_scores.append(f1) 519 | 520 | f1_score = sum(all_f1_scores) / len(all_f1_scores) 521 | tau = sum(all_taus) / len(all_taus) 522 | rho = sum(all_rhos) / len(all_rhos) 523 | 524 | str_score = f"F1-Score = {f1_score}, Tau = {tau}, Rho = {rho}" 525 | return str_score 526 | 
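# Illustrative usage sketch (not part of the original module): a toy run of the
# segmental-F1 helpers above on a made-up 10-frame sequence, mirroring how
# get_eval_score uses them for the 'AS' evaluation type.
if __name__ == '__main__':
    label = ['A'] * 3 + ['B'] * 4 + ['A'] * 3
    pred = ['A'] * 4 + ['B'] * 3 + ['A'] * 3

    # Frame-level accuracy: only frame 3 differs -> 0.90
    acc = sum(p == l for p, l in zip(pred, label)) / len(label)

    # Segment-level F1@0.5: every predicted segment overlaps its ground-truth
    # counterpart with IoU >= 0.5, so tp=3, fp=0, fn=0
    tp, fp, fn = compare_segmentation(pred, label, th=0.5)
    pre, rec = tp / (tp + fp), tp / (tp + fn)
    print(f'(Acc) {acc*100:.2f}%, (F1@0.5) {2*pre*rec/(pre+rec)*100:.2f}%')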
-------------------------------------------------------------------------------- /gravit/utils/formatter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import pickle #nosec 5 | 6 | 7 | def get_formatting_data_dict(cfg): 8 | """ 9 | Get a dictionary that is used to format the results following the formatting rules of the evaluation tool 10 | """ 11 | 12 | root_data = cfg['root_data'] 13 | dataset = cfg['dataset'] 14 | data_dict = {} 15 | 16 | if 'AVA' in cfg['eval_type']: 17 | # Get a list of the feature files 18 | features = '_'.join(cfg['graph_name'].split('_')[:-3]) 19 | list_data_files = sorted(glob.glob(os.path.join(root_data, f'features/{features}/val/*.pkl'))) 20 | 21 | for data_file in list_data_files: 22 | video_id = os.path.splitext(os.path.basename(data_file))[0] 23 | 24 | with open(data_file, 'rb') as f: 25 | data = pickle.load(f) #nosec 26 | 27 | # Get a list of frame_timestamps 28 | list_fts = sorted([float(frame_timestamp) for frame_timestamp in data.keys()]) 29 | 30 | # Iterate over all the frame_timestamps and retrieve the required data for evaluation 31 | for fts in list_fts: 32 | frame_timestamp = f'{fts:g}' 33 | for entity in data[frame_timestamp]: 34 | data_dict[entity['global_id']] = {'video_id': video_id, 35 | 'frame_timestamp': frame_timestamp, 36 | 'person_box': entity['person_box'], 37 | 'person_id': entity['person_id']} 38 | elif 'AS' in cfg['eval_type']: 39 | # Build a mapping from action ids to action classes 40 | data_dict['actions'] = {} 41 | with open(os.path.join(root_data, 'annotations', dataset, 'mapping.txt')) as f: 42 | for line in f: 43 | aid, cls = line.strip().split(' ') 44 | data_dict['actions'][int(aid)] = cls 45 | 46 | # Get a list of all video ids 47 | data_dict['all_ids'] = sorted([os.path.splitext(v)[0] for v in os.listdir(os.path.join(root_data, f'annotations/{dataset}/groundTruth'))]) 48 | 49 | return data_dict 50 | 51 | 52 | def get_formatted_preds(cfg, logits, g, data_dict): 53 | """ 54 | Get a list of formatted predictions from the model output, which is used to compute the evaluation score 55 | """ 56 | 57 | eval_type = cfg['eval_type'] 58 | preds = [] 59 | if 'AVA' in eval_type: 60 | # Compute scores from the logits 61 | scores_all = torch.sigmoid(logits.detach().cpu()).numpy() 62 | 63 | # Iterate over all the nodes and get the formatted predictions for evaluation 64 | for scores, global_id in zip(scores_all, g): 65 | data = data_dict[global_id] 66 | video_id = data['video_id'] 67 | frame_timestamp = float(data['frame_timestamp']) 68 | x1, y1, x2, y2 = [float(c) for c in data['person_box'].split(',')] 69 | 70 | if eval_type == 'AVA_ASD': 71 | # Line formatted following Challenge #2: http://activity-net.org/challenges/2019/tasks/guest_ava.html 72 | person_id = data['person_id'] 73 | score = scores.item() 74 | pred = [video_id, frame_timestamp, x1, y1, x2, y2, 'SPEAKING_AUDIBLE', person_id, score] 75 | preds.append(pred) 76 | 77 | elif eval_type == 'AVA_AL': 78 | # Line formatted following Challenge #1: http://activity-net.org/challenges/2019/tasks/guest_ava.html 79 | for action_id, score in enumerate(scores, 1): 80 | pred = [video_id, frame_timestamp, x1, y1, x2, y2, action_id, score] 81 | preds.append(pred) 82 | elif 'AS' in eval_type: 83 | tmp = logits 84 | if cfg['use_ref']: 85 | tmp = logits[-1] 86 | 87 | tmp = torch.softmax(tmp.detach().cpu(), dim=1).max(dim=1)[1].tolist() 88 | 89 | # Upsample the predictions to fairly compare with the 
ground-truth labels 90 | preds = [] 91 | for pred in tmp: 92 | preds.extend([data_dict['actions'][pred]] * cfg['sample_rate']) 93 | 94 | # Pair the final predictions with the video_id 95 | (g,) = g 96 | video_id = data_dict['all_ids'][g] 97 | preds = [(video_id, preds)] 98 | 99 | elif 'VS' in eval_type: 100 | tmp = logits 101 | tmp = torch.sigmoid(tmp.squeeze().cpu()).numpy().tolist() 102 | (g,) = g 103 | preds.append([f"video_{g}", tmp]) 104 | 105 | return preds 106 | -------------------------------------------------------------------------------- /gravit/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | 4 | def get_logger(path_result, file_name, file_mode='w'): 5 | """ 6 | Get the logger that logs runtime messages under "path_result" 7 | """ 8 | 9 | logging.basicConfig(format='%(asctime)s.%(msecs)03d %(message)s', 10 | datefmt='%m/%d/%Y %H:%M:%S', 11 | level=logging.DEBUG, 12 | handlers=[logging.StreamHandler(), 13 | logging.FileHandler(filename=os.path.join(path_result, f'{file_name}.log'), mode=file_mode)]) 14 | 15 | return logging.getLogger() 16 | -------------------------------------------------------------------------------- /gravit/utils/parser.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | 4 | 5 | def get_args(): 6 | """ 7 | Get the command-line arguments for the configuration 8 | """ 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument('--cfg', type=str, help='Path to the configuration file', required=True) 13 | 14 | # Additional arguments from the command line override the configuration from cfg file 15 | 16 | # Root directories for the training process 17 | parser.add_argument('--root_data', type=str, help='Root directory to the data', default='./data') 18 | parser.add_argument('--root_result', type=str, help='Root directory to output', default='./results') 19 | 20 | # Names required for the training process 21 | parser.add_argument('--exp_name', type=str, help='Name of the experiment') 22 | parser.add_argument('--model_name', type=str, help='Name of the model') 23 | parser.add_argument('--graph_name', type=str, help='Name of the graphs') 24 | parser.add_argument('--loss_name', type=str, help='Name of the loss function') 25 | parser.add_argument('--eval_type', type=str, help='Type of the evaluation') 26 | 27 | # Other hyper-parameters 28 | parser.add_argument('--use_spf', type=bool, help='Whether to use the spatial features') 29 | parser.add_argument('--use_ref', type=bool, help='Whether to use the iterative refinement') 30 | parser.add_argument('--w_ref', type=float, help='Weight for the iterative refinement') 31 | parser.add_argument('--num_modality', type=int, help='Number of input modalities') 32 | parser.add_argument('--channel1', type=int, help='Filter dimension of the first GCN layers') 33 | parser.add_argument('--channel2', type=int, help='Filter dimension of the rest GCN layers') 34 | parser.add_argument('--proj_dim', type=int, help='Dimension of the projected spatial feature') 35 | parser.add_argument('--final_dim', type=int, help='Dimension of the final output') 36 | parser.add_argument('--num_att_heads', type=int, help='Number of attention heads of GATv2') 37 | parser.add_argument('--dropout', type=float, help='Dropout for the last GCN layers') 38 | parser.add_argument('--lr', type=float, help='Initial learning rate') 39 | parser.add_argument('--wd', type=float, help='Weight decay 
value for regularization') 40 | parser.add_argument('--batch_size', type=int, help='Batch size during the training process') 41 | parser.add_argument('--sch_param', type=int, help='Parameter for lr_scheduler') 42 | parser.add_argument('--num_epoch', type=int, help='Total number of epochs') 43 | parser.add_argument('--sample_rate', type=int, help='Downsampling rate for the input') 44 | parser.add_argument('--split', type=int, help='Which fold to use for cross-validation') 45 | 46 | return parser.parse_args() 47 | 48 | 49 | def get_cfg(args): 50 | """ 51 | Initialize the configuration given the optional command-line arguments 52 | """ 53 | 54 | with open(args.cfg, 'r') as f: 55 | cfg = yaml.safe_load(f) 56 | delattr(args, 'cfg') 57 | 58 | for k, v in vars(args).items(): 59 | if v is None and k in cfg: 60 | continue 61 | 62 | cfg[k] = v 63 | 64 | return cfg 65 | -------------------------------------------------------------------------------- /gravit/utils/vs/avg_splits.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def print_results(all_eval_results): 5 | """ 6 | Given a set of results from all splits, calculates avg value for each metric. 7 | """ 8 | all_f1_scores = [] 9 | all_taus = [] 10 | all_rhos = [] 11 | for eval in all_eval_results: 12 | _, f1, tau, rho = re.findall("\d+\.?\d*", eval) 13 | all_f1_scores.append(float(f1)) 14 | all_taus.append(float(tau)) 15 | all_rhos.append(float(rho)) 16 | 17 | final_f1_score = sum(all_f1_scores) / len(all_f1_scores) 18 | final_tau = sum(all_taus) / len(all_taus) 19 | final_rho = sum(all_rhos) / len(all_rhos) 20 | 21 | print(f"Final average results: F1-Score = {final_f1_score:.4}, Kendall's Tau = {final_tau:.3}, Spearman's Rho = {final_rho:.3}") 22 | -------------------------------------------------------------------------------- /gravit/utils/vs/knapsack.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def fill_knapsack(final_len, scores, lens): 5 | """ 6 | Given a set of segments, each with a length and an importance value, 7 | this method determines which segments to include in a final summary 8 | so that total length is less than or equal to final_len and total 9 | added importance is maximized. 
10 | """ 11 | 12 | n_segments = len(scores) 13 | 14 | k_table = np.zeros((n_segments + 1, final_len + 1)) 15 | 16 | for seg_idx in range(1, n_segments + 1): 17 | for len_step in range(1, final_len + 1): 18 | if lens[seg_idx - 1] <= len_step: 19 | k_table[seg_idx, len_step] = max( 20 | scores[seg_idx - 1] + 21 | k_table[seg_idx - 1, len_step - lens[seg_idx - 1]], 22 | k_table[seg_idx - 1, len_step]) 23 | else: 24 | k_table[seg_idx, len_step] = k_table[seg_idx - 1, len_step] 25 | 26 | segments = [] 27 | len_left = final_len 28 | for seg_idx in range(n_segments, 0, -1): 29 | # print(f"seg {seg_idx} len {len_left}") 30 | if k_table[seg_idx, len_left] != k_table[seg_idx - 1, len_left]: 31 | segments.insert(0, seg_idx - 1) 32 | len_left -= lens[seg_idx - 1] 33 | 34 | return segments -------------------------------------------------------------------------------- /gravit/utils/vs/run_vs_exp_summe.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | # Training 3 | python tools/train_context_reasoning.py --cfg configs/video-summarization/SumMe/SPELL_default.yaml --split 1 4 | python tools/train_context_reasoning.py --cfg configs/video-summarization/SumMe/SPELL_default.yaml --split 2 5 | python tools/train_context_reasoning.py --cfg configs/video-summarization/SumMe/SPELL_default.yaml --split 3 6 | python tools/train_context_reasoning.py --cfg configs/video-summarization/SumMe/SPELL_default.yaml --split 4 7 | python tools/train_context_reasoning.py --cfg configs/video-summarization/SumMe/SPELL_default.yaml --split 5 8 | # Evaluation 9 | python tools/evaluate.py --exp_name SPELL_VS_SumMe_default --eval_type VS_max --all_splits -------------------------------------------------------------------------------- /gravit/utils/vs/run_vs_exp_tvsum.sh: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/bash 3 | # Training 4 | python tools/train_context_reasoning.py --cfg configs/video-summarization/TVSum/SPELL_default.yaml --split 1 5 | python tools/train_context_reasoning.py --cfg configs/video-summarization/TVSum/SPELL_default.yaml --split 2 6 | python tools/train_context_reasoning.py --cfg configs/video-summarization/TVSum/SPELL_default.yaml --split 3 7 | python tools/train_context_reasoning.py --cfg configs/video-summarization/TVSum/SPELL_default.yaml --split 4 8 | python tools/train_context_reasoning.py --cfg configs/video-summarization/TVSum/SPELL_default.yaml --split 5 9 | # Evaluation 10 | python tools/evaluate.py --exp_name SPELL_VS_TVSum_default --eval_type VS_avg --all_splits 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py 2 | pyyaml 3 | pandas 4 | scikit-learn 5 | -f https://download.pytorch.org/whl/cu121/torch_stable.html 6 | -f https://data.pyg.org/whl/torch-2.5.1+cu121.html 7 | torch==2.5.1 8 | torchvision==0.20.1 9 | pyg_lib 10 | torch_scatter 11 | torch_sparse 12 | torch_cluster 13 | torch_spline_conv -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.md', 'r') as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name='gravit', 8 | version='1.1.0', 9 | description='Graph learning framework for long-term Video undersTanding', 10 | 
long_description=long_description, 11 | license='Apache License 2.0', 12 | author='Kyle Min', 13 | author_email='kyle.min@intel.com', 14 | packages=find_packages(), 15 | python_requires='>=3.7', 16 | install_requires=['pyyaml', 'pandas', 'torch', 'torch-geometric>=2.0.3'], 17 | scripts=['data/generate_spatial-temporal_graphs.py', 18 | 'data/generate_temporal_graphs.py', 19 | 'tools/train_context_reasoning.py', 20 | 'tools/evaluate.py'] 21 | ) 22 | -------------------------------------------------------------------------------- /tools/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from torch_geometric.loader import DataLoader 7 | from gravit.utils.parser import get_cfg 8 | from gravit.utils.logger import get_logger 9 | from gravit.models import build_model 10 | from gravit.datasets import GraphDataset 11 | from gravit.utils.formatter import get_formatting_data_dict, get_formatted_preds 12 | from gravit.utils.eval_tool import get_eval_score 13 | from gravit.utils.vs import avg_splits 14 | 15 | 16 | def evaluate(cfg): 17 | """ 18 | Run the evaluation process given the configuration 19 | """ 20 | 21 | # Input and output paths 22 | path_graphs = os.path.join(cfg['root_data'], f'graphs/{cfg["graph_name"]}') 23 | path_result = os.path.join(cfg['root_result'], f'{cfg["exp_name"]}') 24 | if cfg['split'] is not None: 25 | path_graphs = os.path.join(path_graphs, f'split{cfg["split"]}') 26 | path_result = os.path.join(path_result, f'split{cfg["split"]}') 27 | 28 | # Prepare the logger 29 | logger = get_logger(path_result, file_name='eval') 30 | logger.info(cfg['exp_name']) 31 | logger.info(path_result) 32 | # Build a model and prepare the data loaders 33 | logger.info('Preparing a model and data loaders') 34 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 35 | model = build_model(cfg, device) 36 | val_loader = DataLoader(GraphDataset(os.path.join(path_graphs, 'val'))) 37 | num_val_graphs = len(val_loader) 38 | 39 | # Init 40 | #x_dummy = torch.tensor(np.array(np.random.rand(10, 1024), dtype=np.float32), dtype=torch.float32).to(device) 41 | #node_source_dummy = np.random.randint(10, size=5) 42 | #node_target_dummy = np.random.randint(10, size=5) 43 | #edge_index_dummy = torch.tensor(np.array([node_source_dummy, node_target_dummy], dtype=np.int64), dtype=torch.long).to(device) 44 | #signs = np.sign(node_source_dummy - node_target_dummy) 45 | #edge_attr_dummy = torch.tensor(signs, dtype=torch.float32).to(device) 46 | #model(x_dummy, edge_index_dummy, edge_attr_dummy, None) 47 | 48 | # Load the trained model 49 | logger.info('Loading the trained model') 50 | state_dict = torch.load(os.path.join(path_result, 'ckpt_best.pt'), map_location=torch.device('cpu')) 51 | model.load_state_dict(state_dict) 52 | model.eval() 53 | 54 | # Load the feature files to properly format the evaluation results 55 | logger.info('Retrieving the formatting dictionary') 56 | data_dict = get_formatting_data_dict(cfg) 57 | 58 | # Run the evaluation process 59 | logger.info('Evaluation process started') 60 | 61 | preds_all = [] 62 | with torch.no_grad(): 63 | for i, data in enumerate(val_loader, 1): 64 | g = data.g.tolist() 65 | x = data.x.to(device) 66 | edge_index = data.edge_index.to(device) 67 | edge_attr = data.edge_attr.to(device) 68 | c = None 69 | if cfg['use_spf']: 70 | c = data.c.to(device) 71 | 72 | logits = model(x, edge_index, edge_attr, c) 73 | 74 | # 
Change the format of the model output 75 | preds = get_formatted_preds(cfg, logits, g, data_dict) 76 | preds_all.extend(preds) 77 | 78 | logger.info(f'[{i:04d}|{num_val_graphs:04d}] processed') 79 | 80 | # Compute the evaluation score 81 | logger.info(f'Computing the evaluation score') 82 | eval_score = get_eval_score(cfg, preds_all) 83 | logger.info(f'{cfg["eval_type"]} evaluation finished: {eval_score}\n') 84 | return eval_score 85 | 86 | if __name__ == "__main__": 87 | """ 88 | Evaluate the trained model from the experiment "exp_name" 89 | """ 90 | 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument('--root_data', type=str, help='Root directory to the data', default='./data') 93 | parser.add_argument('--root_result', type=str, help='Root directory to output', default='./results') 94 | parser.add_argument('--dataset', type=str, help='Name of the dataset') 95 | parser.add_argument('--exp_name', type=str, help='Name of the experiment', required=True) 96 | parser.add_argument('--eval_type', type=str, help='Type of the evaluation', required=True) 97 | parser.add_argument('--split', type=int, help='Split to evaluate') 98 | parser.add_argument('--all_splits', action='store_true', help='Evaluate all splits') 99 | 100 | args = parser.parse_args() 101 | 102 | path_result = os.path.join(args.root_result, args.exp_name) 103 | if not os.path.isdir(path_result): 104 | raise ValueError(f'Please run the training experiment "{args.exp_name}" first') 105 | 106 | results = [] 107 | if args.all_splits: 108 | results = glob.glob(os.path.join(path_result, "*", "cfg.yaml")) 109 | else: 110 | if args.split: 111 | path_result = os.path.join(path_result, f'split{args.split}') 112 | if not os.path.isdir(path_result): 113 | raise ValueError(f'Please run the training experiment "{args.exp_name}" first') 114 | 115 | results.append(os.path.join(path_result, 'cfg.yaml')) 116 | 117 | all_eval_results = [] 118 | for result in results: 119 | args.cfg = result 120 | cfg = get_cfg(args) 121 | all_eval_results.append(evaluate(cfg)) 122 | 123 | if "VS" in args.eval_type and args.all_splits: 124 | avg_splits.print_results(all_eval_results) 125 | -------------------------------------------------------------------------------- /tools/train_context_reasoning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import torch 4 | import torch.optim as optim 5 | from torch_geometric.loader import DataLoader 6 | from gravit.utils.parser import get_args, get_cfg 7 | from gravit.utils.logger import get_logger 8 | from gravit.models import build_model, get_loss_func 9 | from gravit.datasets import GraphDataset 10 | 11 | 12 | def train(cfg): 13 | """ 14 | Run the training process given the configuration 15 | """ 16 | 17 | # Input and output paths 18 | path_graphs = os.path.join(cfg['root_data'], f'graphs/{cfg["graph_name"]}') 19 | path_result = os.path.join(cfg['root_result'], f'{cfg["exp_name"]}') 20 | if cfg['split'] is not None: 21 | path_graphs = os.path.join(path_graphs, f'split{cfg["split"]}') 22 | path_result = os.path.join(path_result, f'split{cfg["split"]}') 23 | os.makedirs(path_result, exist_ok=True) 24 | 25 | # Prepare the logger and save the current configuration for future reference 26 | logger = get_logger(path_result, file_name='train') 27 | logger.info(cfg['exp_name']) 28 | logger.info('Saving the configuration file') 29 | with open(os.path.join(path_result, 'cfg.yaml'), 'w') as f: 30 | yaml.dump({k: v for k, v in cfg.items() if v is not 
None}, f, default_flow_style=False, sort_keys=False) 31 | 32 | # Build a model and prepare the data loaders 33 | logger.info('Preparing a model and data loaders') 34 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 35 | model = build_model(cfg, device) 36 | train_loader = DataLoader(GraphDataset(os.path.join(path_graphs, 'train')), batch_size=cfg['batch_size'], shuffle=True) 37 | val_loader = DataLoader(GraphDataset(os.path.join(path_graphs, 'val'))) 38 | 39 | # Prepare the experiment 40 | loss_func = get_loss_func(cfg) 41 | loss_func_val = get_loss_func(cfg, 'val') 42 | optimizer = optim.Adam(model.parameters(), lr=cfg['lr'], weight_decay=cfg['wd']) 43 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['sch_param']) 44 | 45 | # Run the training process 46 | logger.info('Training process started') 47 | 48 | min_loss_val = float('inf') 49 | for epoch in range(1, cfg['num_epoch']+1): 50 | model.train() 51 | 52 | # Train for a single epoch 53 | loss_sum = 0. 54 | for data in train_loader: 55 | optimizer.zero_grad() 56 | 57 | x, y = data.x.to(device), data.y.to(device) 58 | edge_index = data.edge_index.to(device) 59 | edge_attr = data.edge_attr.to(device) 60 | c = None 61 | if cfg['use_spf']: 62 | c = data.c.to(device) 63 | 64 | logits = model(x, edge_index, edge_attr, c) 65 | 66 | loss = loss_func(logits, y) 67 | loss.backward() 68 | loss_sum += loss.item() 69 | optimizer.step() 70 | 71 | # Adjust the learning rate 72 | scheduler.step() 73 | 74 | loss_train = loss_sum / len(train_loader) 75 | 76 | # Get the validation loss 77 | loss_val = val(val_loader, cfg['use_spf'], model, device, loss_func_val) 78 | 79 | # Save the best-performing checkpoint 80 | if loss_val < min_loss_val: 81 | min_loss_val = loss_val 82 | epoch_best = epoch 83 | torch.save(model.state_dict(), os.path.join(path_result, 'ckpt_best.pt')) 84 | 85 | # Log the losses for every epoch 86 | logger.info(f'Epoch [{epoch:03d}|{cfg["num_epoch"]:03d}] loss_train: {loss_train:.4f}, loss_val: {loss_val:.4f}, best: epoch {epoch_best:03d}') 87 | 88 | logger.info('Training finished') 89 | 90 | 91 | def val(val_loader, use_spf, model, device, loss_func): 92 | """ 93 | Run a single validation process 94 | """ 95 | 96 | model.eval() 97 | loss_sum = 0 98 | with torch.no_grad(): 99 | for data in val_loader: 100 | x, y = data.x.to(device), data.y.to(device) 101 | edge_index = data.edge_index.to(device) 102 | edge_attr = data.edge_attr.to(device) 103 | c = None 104 | if use_spf: 105 | c = data.c.to(device) 106 | 107 | logits = model(x, edge_index, edge_attr, c) 108 | loss = loss_func(logits, y) 109 | loss_sum += loss.item() 110 | 111 | return loss_sum / len(val_loader) 112 | 113 | 114 | if __name__ == "__main__": 115 | args = get_args() 116 | cfg = get_cfg(args) 117 | 118 | train(cfg) 119 | --------------------------------------------------------------------------------
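The following is a small, hypothetical sketch (not part of the repository) of the precedence rule implemented by get_cfg in gravit/utils/parser.py, which tools/train_context_reasoning.py and tools/evaluate.py rely on: command-line options that are left unset fall back to the values in the YAML configuration, options passed explicitly override them, and unset options missing from the YAML end up as None. The file name, keys, and values below are illustrative only.

# merge_demo.py -- hypothetical, self-contained sketch (not part of the repo)
import argparse
import tempfile

import yaml

from gravit.utils.parser import get_cfg

# A minimal stand-in for a SPELL_default.yaml; keys and values are made up.
with tempfile.NamedTemporaryFile('w', suffix='.yaml', delete=False) as f:
    yaml.dump({'exp_name': 'demo', 'lr': 0.0005, 'num_epoch': 50}, f)
    path_cfg = f.name

# Simulate "--cfg <file> --lr 0.001" with every other option left unset (None).
args = argparse.Namespace(cfg=path_cfg, lr=0.001, num_epoch=None, split=None)
cfg = get_cfg(args)

print(cfg['lr'])         # 0.001 -> explicit command-line value wins
print(cfg['num_epoch'])  # 50    -> unset option falls back to the YAML value
print(cfg['split'])      # None  -> unset and absent from the YAML stays None

This mirrors how the training and evaluation scripts obtain their configuration before running.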