├── .DS_Store ├── .gitignore ├── LICENSE ├── NOTICE.txt ├── README.md ├── capture ├── capture_metric │ ├── __init__.py │ ├── capture.py │ └── stop_words.py ├── ckpt │ └── README.md └── setup.py ├── detail_caption_construction ├── ckpt │ └── README.md ├── config_llava15_7b_detailcaps_4870 │ ├── caption_anything_caption_reorganization.yaml │ ├── stage1_overall_caption.yaml │ ├── stage2_bbox.yaml │ ├── stage3_local_caption.yaml │ ├── stage4_filter.yaml │ └── stage5_caption_merge.yaml ├── generate_all.sh ├── generate_stage1_overall_caption.py ├── generate_stage2_bbox.py ├── generate_stage3_local_caption.py ├── generate_stage4_filter.py ├── generate_stage5_caption_merge.py ├── merge_results.py └── utils │ ├── bbox_cluster.py │ ├── bbox_statistics.py │ ├── image_processing_owlv2.py │ └── utils.py ├── images └── intro.png └── prepare.sh /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-multimodal-models/CAPTURE/52eeb2781e8b4b1854c07a57b573c75edfd45688/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *all-mpnet-base-v2* 3 | *flan-t5-base-VG-factual-sg* 4 | detail_caption_construction/LLaVA 5 | capture/FactualSceneGraph 6 | detail_caption_construction/*data 7 | detail_caption_construction/scripts_output/* 8 | check.py 9 | *__pycache__* 10 | opensource_git_commit.log 11 | sensitive_info_result.txt 12 | 13 | *FactualSceneGraph* 14 | capture/README.md 15 | capture/dist 16 | capture/*.egg-info 17 | capture/build 18 | 19 | *.DS_Store* 20 | *.out 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /NOTICE.txt: -------------------------------------------------------------------------------- 1 | Copyright (2024) Bytedance Ltd. 
and/or its affiliates -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Benchmarking and Improving Detail Image Caption 2 | [![License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/foundation-multimodal-models/CAPTURE) 3 | 4 | [![Dataset](https://img.shields.io/badge/Dataset-Huggingface%204.0-yellow)](https://huggingface.co/datasets/foundation-multimodal-models/DetailCaps-4870) 5 | 6 | 7 | Code and data for the paper: 8 | 9 | *Benchmarking and Improving Detail Image Caption*. 10 | Hongyuan Dong\*, Jiawen Li\*, Bohong Wu, Jiacong Wang, Yuan Zhang, Haoyuan Guo (* Equal Contribution) 11 | 12 | 13 | Our paper is now available on [arXiv](https://arxiv.org/abs/2405.19092). 14 | 15 | 16 | ## Overview 17 | Image captioning has long been regarded as a fundamental task in visual understanding. 18 | Recently, however, little large vision-language model (LVLM) research has discussed models' image captioning performance, owing to outdated short-caption benchmarks and unreliable evaluation metrics. 19 | In this work, we propose to benchmark the detail image caption task by curating high-quality evaluation datasets annotated by human experts, GPT-4V and Gemini-1.5-Pro. 20 | We also design a more reliable caption evaluation metric called **CAPTURE** (CAPtion evaluation by exTracting and coUpling coRE information). 21 | CAPTURE extracts visual elements, e.g., objects, attributes and relations, from captions, and then matches these elements through three stages, achieving the highest consistency with expert judgements among rule-based and model-based caption metrics. 22 | The proposed benchmark and metric provide reliable evaluation of LVLMs' detail image captioning ability. 23 | Guided by this evaluation, we further explore how to unleash LVLMs' detail caption capabilities by synthesizing high-quality data through a five-stage data construction pipeline. 24 | Our pipeline uses only a given LVLM itself and other open-source tools, without any human or GPT-4V annotation in the loop. 25 | Experiments show that the proposed data construction strategy significantly improves model-generated detail caption data quality for LVLMs with leading performance, and the data quality can be further improved in a self-looping paradigm. 26 |
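For orientation, the CAPTURE score reported throughout this repository is a weighted combination of the F1 scores obtained for the three element types (objects, attributes, relations) after the matching stages. The minimal sketch below only restates that combination with the default weights used in `capture/capture_metric/capture.py` (`alpha = beta = 0.5`, `gamma = 0.2`); it is an illustration of the scoring formula, not the matching pipeline itself.

```python
# Minimal sketch of the final score combination used by CAPTURE (defaults taken from capture.py).
def combine_capture_score(object_f1: float, attribute_f1: float, relation_f1: float,
                          alpha: float = 0.5, beta: float = 0.5, gamma: float = 0.2) -> float:
    """Weighted average of per-type F1 scores, normalized by the weight sum."""
    return (alpha * object_f1 + beta * attribute_f1 + gamma * relation_f1) / (alpha + beta + gamma)

# Example: perfect object and attribute matching, half of the relations matched.
print(round(combine_capture_score(1.0, 1.0, 0.5), 4))  # 0.9167
```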

27 | 28 |

29 | 30 | 31 | ## Detail Image Caption Benchmark 32 | We release the DetailCaps-4870 benchmark, which contains 4870 images with high-quality reference captions annotated by GPT-4V&Gemini-1.5-Pro. 33 | The statistics of DetailCaps-4870 compared with other image caption benchmarks of comparables sizes is shown below: 34 | 35 | | Benchmark | Data source | Annt. expert | Img num | ref num | Avg len | Uni. 2-gram | 36 | | --- | --- | --- | --- | --- | --- | --- | 37 | | **COCOtest** | COCO | Human | $5000$ | $25,010$ | $10.59$ | $61,448$ | 38 | | **Nocapsval** | Openimages | Human | $4500$ | $45,000$ | $11.49$ | $116,969$ | 39 | | **DetailCaps-100** | COCO, SAM, LAION, CC, SBU | GPT-4V, Human | $100$ | $100$ | $175.96$ | $10,858$ | 40 | | **DetailCaps-4870** | COCO, SAM, LAION, CC, SBU, Coyo, Flickr | GPT-4V, GPT4O, Gemini-1.5-Pro | $4870$ | $14610$ | $122.06$ | $533,201$ | 41 | 42 | The evaluation dataset will soon be available on [Huggingface](https://huggingface.co/). 43 | Please download the dataset and put it under the `datasets` folder. 44 | 45 | ## Detail Image Caption Evaluation Metric: CAPTURE 46 | The proposed metric **CAPTURE** (CAPtion evaluation by exTracting and coUpling coRE information) achieves the highest consistency with expert judgements on DetailCaps benchmarks. 47 | We show the average consistency scores on DetailCaps-100 and DetailCaps-4870 benchmarks in the table below. 48 | 49 | | Caption metric | PCC $\rho$ $\uparrow$ | $1-R^2$ $\downarrow$ | Kendall's $\tau$ $\uparrow$ | Sample $\tau$ $\uparrow$ | 50 | | --- | --- | --- | --- | --- | 51 | | **BLEU** | $0.2608$ | $54.75$ | $0.1866$ | $0.2462$ | 52 | | **ROUGE-L** | $0.2951$ | $134.12$ | $0.2149$ | $0.3383$ | 53 | | **CIDEr** | $0.1148$ | $2.6e^7$ | $0.1165$ | $0.0991$ | 54 | | **METEOR** | $0.4022$ | $290.38$ | $0.2927$ | $0.4062$ | 55 | | **SPICE** | $0.4386$ | $155.95$ | $0.3244$ | $0.4718$ | 56 | | **CLIPScore** | $0.3558$ | $21.46$ | $0.2479$ | $0.3841$ | 57 | | **CAPTURE** | $0.5091$ | $8.29$ | $0.3861$ | $0.6018$ | 58 | 59 | We evaluate SOTA open-source LVLMs' detail captioning abilities with our benchmark and metric. 60 | The results are listed below. 61 | 62 | | Model | Language Model | Caption Data | Resolution | CAPTURE | 63 | | --- | --- | --- | --- | --- | 64 | | **CogVLM** | Vicuna-7B | Human Annt. | $490^2$ | $60.06$ | 65 | | **ShareCaptioner-7B** | Vicuna-7B | GPT-4V Annt. | $448^2$ | $59.80$ | 66 | | **LLaVA-1.5-7B** | Vicuna-7B | Synthesized | $336^2$ | $51.05$ | 67 | | **LLaVA-1.5-13B** | Vicuna-13B | Synthesized | $336^2$ | $51.20$ | 68 | | **LLaVA-NEXT-7B** | Vicuna-7B | GPT-4V Annt. | $336^2$*{1-5} | $58.61$ | 69 | | **LLaVA-NEXT-13B** | Vicuna-13B | GPT-4V Annt. | $336^2$*{1-5} | $59.01$ | 70 | | **LLaVA-NEXT-34B** | Hermes-2-Yi-34B | GPT-4V Annt. | $336^2$*{1-5} | $59.20$ | 71 | | **Mini-Gemini-HD-7B** | Vicuna-7B | GPT-4V Annt. | $336^2$*5 | $57.95$ | 72 | | **Mini-Gemini-HD-13B** | Vicuna-13B | GPT-4V Annt. | $336^2$*5 | $58.66$ | 73 | | **Intern-XComposerV2** | Vicuna-7B | GPT-4V Annt. | $490^2$ | $59.86$ | 74 | | **InternVL-V1.2-PLUS-40B** | Hermes-2-Yi-34B | GPT-4V Annt. | $448^2$ | $60.69$ | 75 | | **InternVL-V1.5-26B** | InternLM-20B | GPT-4V Annt. | $448^2$*{1-41} | $63.42$ | 76 | 77 | 78 | ## Detail Image Caption Construction 79 | We construct a data construction pipeline to unleash LVLM's detail image captioning ability with open-source vision and language tools. 
80 | We show the performance of the proposed data construction pipeline with different LVLM backbones below. 81 | 82 | | Caption | DetailCaps-100 | DetailCaps-4870 | Average | 83 | | --- | --- | --- | --- | 84 | | **LLaVA-1.5-7B self** | $51.23$ | $51.05$ | $51.14$ | 85 | | **LLaVA-1.5-7B syn** | $57.11$ | $56.25$ | $56.68$ | 86 | | **LLaVA-1.5-13B self** | $51.76$ | $51.20$ | $51.48$ | 87 | | **LLaVA-1.5-13B syn** | $57.36$ | $57.05$ | $57.20$ | 88 | | **LLaVA-NEXT-7B self** | $61.48$ | $58.61$ | $60.73$ | 89 | | **LLaVA-NEXT-7B syn** | $62.24$ | $60.39$ | $61.31$ | 90 | | **Mini-Gemini-7B-HD self** | $59.51$ | $57.95$ | $58.73$ | 91 | | **Mini-Gemini-7B-HD syn** | $60.44$ | $59.07$ | $59.75$ | 92 | 93 | 94 | 95 | ## Quick Start 96 | 97 | ### Environment 98 | Run the following scripts to prepare the environment for CAPTURE and the data construction pipeline. 99 | ```bash 100 | conda create -n detailcaption python=3.9 101 | conda activate detailcaption 102 | bash prepare.sh 103 | ``` 104 | 105 | ### Detail Image Caption Evaluation 106 | We have wrapped the proposed CAPTURE evaluation metric into a pip package, and you can install it as follows: 107 | ```bash 108 | pip3 install capture_metric 109 | ``` 110 | After installation, the CAPTURE metric can be used in the same way as other caption evaluation metrics implemented in [pycocoevalcap](https://github.com/sks3i/pycocoevalcap), such as BLEU, CIDEr, METEOR, ROUGE, etc. 111 | Here is an example: 112 | ```python 113 | from capture_metric.capture import CAPTURE 114 | refs = { 115 | <sample_key>: [ref_0, ref_1, ...], 116 | ... 117 | } 118 | preds = { 119 | <sample_key>: [pred_caption], 120 | ... 121 | } 122 | 123 | evaluator = CAPTURE() 124 | score = evaluator.compute_score(refs, preds) 125 | print(f"CAPTURE score: {score}") 126 | ``` 127 | 128 | You can now use [lmms_eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) to evaluate your LVLM's detail image caption performance on the DetailCaps-4870 benchmark with the CAPTURE metric. 129 | We refer to [lmms detailcaps](https://github.com/EvolvingLMMs-Lab/lmms-eval/tasks/detailcaps) for more details. 130 | 131 | 132 | ### Detail Image Caption Construction 133 | For detail image caption construction, first download SAM, Owlv2, LLaVA-v1.5 (or another LVLM), and LLaMA-2, and place them under the `ckpt` folder: 134 | ``` 135 | ckpt 136 | ├─sam 137 | | ├─sam_vit_h_4b8939.pth 138 | | └─sam_vit_l_0b3195.pth 139 | ├─owlv2-large-patch14-ensemble 140 | ├─llava-v1.5-13b 141 | ├─llava-v1.5-7b 142 | ├─llava-v1.5-13b 143 | ├─Llama-2-7b-chat-hf 144 | └─Llama-2-13b-chat-hf 145 | ``` 146 | Then organize your image data in `.parquet` format with the binary image stored in the `frame` field. 147 | Run the following script to generate annotations for your parquet data files stored in `<data_path>`. 148 | `<model_size>` should be set to either `7b` or `13b`, corresponding to the pipelines for different model sizes.
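Before launching the pipeline, the expected input can be produced with a short script. The sketch below is illustrative only: it assumes nothing beyond what is stated above (raw image bytes in a `frame` column of a `.parquet` file); the source image directory, the output file name, and the extra `image_id` column are made-up examples, and the output location simply mirrors the `source_path` set in `stage1_overall_caption.yaml`. The launch command itself follows right after this sketch.

```python
# Illustrative sketch: pack raw image bytes into a .parquet file with a `frame` column.
# Requires pandas with a parquet engine (e.g., pyarrow). Paths and the `image_id`
# column are examples only; keep whatever extra metadata columns you need.
import glob
import pandas as pd

rows = []
for path in sorted(glob.glob("my_images/*.jpg")):  # hypothetical source directory
    with open(path, "rb") as f:
        rows.append({"image_id": path, "frame": f.read()})  # binary image bytes go in `frame`

# Write to the location your config points at, e.g. the stage-1 `source_path`
# (detail_caption_construction/data/source_data/) or the `<data_path>` you pass to the script.
pd.DataFrame(rows).to_parquet("detail_caption_construction/data/source_data/images_000.parquet")
```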
149 | ```bash 150 | bash generate_all_annotations.sh 151 | ``` 152 | 153 | 154 | ## Citation 155 | ```bibtex 156 | @article{dong2024benchmarking, 157 | title={Benchmarking and Improving Detail Image Caption}, 158 | author={Dong, Hongyuan and Li, Jiawen and Wu, Bohong and Wang, Jiacong and Zhang, Yuan and Guo, Haoyuan}, 159 | journal={arXiv preprint arXiv:2405.19092}, 160 | year={2024} 161 | } 162 | ``` 163 | 164 | 165 | -------------------------------------------------------------------------------- /capture/capture_metric/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import nltk 3 | import os 4 | 5 | 6 | def nltk_find_and_download(package_name, path): 7 | downloaded = False 8 | try: 9 | if (nltk.data.find(f'{path}/{package_name}')): 10 | downloaded = True 11 | except: 12 | pass 13 | try: 14 | if nltk.data.find(f'{path}/{package_name}.zip'): 15 | downloaded = True 16 | except: 17 | pass 18 | 19 | if not downloaded: 20 | nltk.download(package_name) 21 | 22 | 23 | def download_nltk_data(): 24 | nltk_find_and_download('wordnet', 'corpora') 25 | nltk_find_and_download('punkt', 'tokenizers') 26 | nltk_find_and_download('averaged_perceptron_tagger', 'taggers') 27 | 28 | 29 | if int(os.environ.get("RANK", 0)) == 0: 30 | download_nltk_data() 31 | -------------------------------------------------------------------------------- /capture/capture_metric/capture.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (2024) CAPTURE project Authors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | import functools 19 | import tabulate 20 | import nltk 21 | from nltk.corpus import wordnet 22 | from nltk.stem import WordNetLemmatizer 23 | from nltk.tokenize import sent_tokenize 24 | import collections 25 | import torch 26 | import tqdm 27 | import contextlib 28 | import io 29 | from sentence_transformers import SentenceTransformer 30 | import numpy as np 31 | import multiprocessing 32 | from statistics import mean 33 | 34 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser 35 | from factual_scene_graph.evaluation.soft_spice_evaluation import encode_phrases 36 | 37 | 38 | _tabulate_format = tabulate.TableFormat( 39 | lineabove=tabulate.Line("+", "-", "+", "+"), 40 | linebelowheader=tabulate.Line("|", "-", "+", "|"), 41 | linebetweenrows=None, 42 | linebelow=tabulate.Line("+", "-", "+", "+"), 43 | headerrow=tabulate.DataRow("|", "|", "|"), 44 | datarow=tabulate.DataRow("|", "|", "|"), 45 | padding=1, with_header_hide=None 46 | ) 47 | 48 | def tprint(graph, file=None): 49 | """ 50 | Print a scene graph as a table. 51 | The printed strings contain essential information about the parsed scene graph. 
52 | """ 53 | assert isinstance(graph, dict), 'Input must be a dictionary' 54 | _print = functools.partial(print, file=file) 55 | 56 | _print('Entities:') 57 | entities_data = [ 58 | [e['head'].lower(), e.get('quantity', ''), ','.join(e.get('attributes', set()))] 59 | for e in graph['entities'] 60 | ] 61 | _print(tabulate.tabulate(entities_data, headers=['Entity', 'Quantity', 'Attributes'], tablefmt=_tabulate_format)) 62 | 63 | _print('Relations:') 64 | relations_data = [ 65 | [ 66 | graph['entities'][rel['subject']]['head'].lower(), 67 | rel['relation'].lower(), 68 | graph['entities'][rel['object']]['head'].lower() 69 | ] 70 | for rel in graph['relations'] 71 | ] 72 | _print(tabulate.tabulate(relations_data, headers=['Subject', 'Relation', 'Object'], tablefmt=_tabulate_format)) 73 | 74 | 75 | def merge_sentence_results(results, text_processor): 76 | # from IPython import embed; embed() 77 | objects, attributes, relations = set(), collections.defaultdict(set), set() 78 | for result in results: 79 | for entity in result['entities']: 80 | lemmatized_obj = text_processor.normalize_word(entity['head'], wordnet.NOUN) 81 | objects.add(lemmatized_obj) 82 | for attribute in entity['attributes']: 83 | attribute = text_processor.normalize_word(attribute, wordnet.ADJ) 84 | if ' of' in attribute: 85 | continue 86 | attributes[lemmatized_obj].add(attribute) 87 | for relation in result['relations']: 88 | relations.add(( 89 | text_processor.normalize_word(result['entities'][relation['subject']]['head'], wordnet.NOUN), 90 | relation['relation'], 91 | text_processor.normalize_word(result['entities'][relation['object']]['head'], wordnet.NOUN) 92 | )) 93 | 94 | return objects, attributes, relations 95 | 96 | 97 | def are_tuples_match(synsets1, synsets2): 98 | """ 99 | Determine if two lists of synsets have non-empty intersections for corresponding elements. 100 | 101 | :param synsets1: First list of synsets. 102 | :param synsets2: Second list of synsets. 103 | :return: True if all corresponding synsets have a non-empty intersection, False otherwise. 104 | """ 105 | 106 | return len(synsets1) == len(synsets2) and all(s1.intersection(s2) for s1, s2 in zip(synsets1, synsets2)) 107 | 108 | 109 | def get_synonyms(word): 110 | synsets = wordnet.synsets(word) 111 | synonyms = set() 112 | for synset in synsets: 113 | for lemma in synset.lemmas(): 114 | synonyms.add(lemma.name()) 115 | return synonyms 116 | 117 | 118 | def set_mp_context(expected_context='spawn'): 119 | default_context_name = torch.multiprocessing.get_context().get_start_method() 120 | if default_context_name != expected_context: 121 | torch.multiprocessing.set_start_method('spawn', force=True) 122 | return 123 | 124 | 125 | class TextProcessor: 126 | def __init__(self) -> None: 127 | self.wnl = WordNetLemmatizer() 128 | 129 | def normalize_word(self, word, pos): 130 | return self.wnl.lemmatize(word, pos=pos) 131 | 132 | 133 | class CAPTURE: 134 | def __init__( 135 | self, 136 | alpha: float = 0.5, 137 | beta: float = 0.5, 138 | gamma: float = 0.2, 139 | synonym_matching: bool = True, 140 | soft_matching: bool = True, 141 | stop_words: bool = True, 142 | eps: float = 1e-6, 143 | ): 144 | """ 145 | Args: 146 | alpha (`float`, *optional*, defaults to be 0.5): 147 | The ratio of object F1 score considered in CAPTURE score computation. 148 | beta (`float`, *optional*, defaults to be 0.5): 149 | The ratio of attribute F1 score considered in CAPTURE score computation. 150 | The summation of alpha and beta must equals to 1. 
151 | gamma (`float`, *optional*, defaults to be 0.2): 152 | The ratio of relation F1 score considered in CAPTURE score computation. 153 | synonym_matching (`bool`, *optional*, defaults to be True): 154 | Controls whether to use synonym_matching for visual elements matching. 155 | soft_matching (`bool`, *optional*, defaults to be True): 156 | Controls whether to use soft_matching for visual elements matching. 157 | stop_words (`bool`, *optional*, defaults to be True): 158 | Controls whether to use stop words for object elements filtering. 159 | eps (`float`, *optional*, defaults to be 1e-6): 160 | A small number to avoid division by zero when computing precision, recall and F1. 161 | """ 162 | self.alpha = alpha 163 | self.beta = beta 164 | assert self.alpha + self.beta == 1. 165 | self.gamma = gamma 166 | self.parser = None 167 | self.text_processor = TextProcessor() 168 | self.synonym_matching = synonym_matching 169 | 170 | if stop_words: 171 | from capture_metric.stop_words import stop_words_list 172 | self.stop_words_list = set(stop_words_list) 173 | else: 174 | self.stop_words_list = set([]) 175 | 176 | self.eps = eps 177 | 178 | self.soft_matching = soft_matching 179 | if self.soft_matching: 180 | self.text_encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").to('cuda:0').eval() 181 | 182 | 183 | def compute_synonyms_score(self, word1, word2): 184 | # in case word1 or word2 consists of multiple words 185 | if word1 in word2 or word2 in word1: 186 | return 1 187 | elif len(word1.split()) > 0 or len(word2.split()) > 0: 188 | word1 = '_'.join(word1.split()) 189 | word2 = '_'.join(word2.split()) 190 | 191 | synonyms1 = get_synonyms(word1) 192 | synonyms2 = get_synonyms(word2) 193 | iou = len(synonyms1.intersection(synonyms2)) / (len(synonyms1.union(synonyms2)) + self.eps) 194 | return iou 195 | 196 | 197 | def compute_match(self, all_cand, all_gt): 198 | total_match = 0 199 | matched_cand_indices, matched_ref_indices = set(), set() 200 | for ii, cand in enumerate(all_cand): 201 | for jj, ref in enumerate(all_gt): 202 | if cand == ref and jj not in matched_ref_indices: 203 | matched_cand_indices.add(ii) 204 | matched_ref_indices.add(jj) 205 | # print(cand, ref) 206 | total_match += 1 207 | break 208 | 209 | if self.synonym_matching: 210 | for ii, cand in enumerate(all_cand): 211 | if ii not in matched_cand_indices: 212 | for jj, ref in enumerate(all_gt): 213 | if jj not in matched_ref_indices and self.compute_synonyms_score(cand, ref) > 0.: 214 | matched_cand_indices.add(ii) 215 | matched_ref_indices.add(jj) 216 | # print(cand, ref) 217 | total_match += 1 218 | break 219 | 220 | remained_cands = [cand for i, cand in enumerate(all_cand) if i not in matched_cand_indices] 221 | remained_refs = [gt for j, gt in enumerate(all_gt) if j not in matched_ref_indices] 222 | cand_match = total_match 223 | ref_match = total_match 224 | if self.soft_matching and len(remained_cands) > 0 and len(remained_refs) > 0: 225 | with io.StringIO() as f: 226 | with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f): 227 | remained_cands_features, remained_refs_features = encode_phrases(self.text_encoder, remained_cands, remained_refs, batch_size=4) 228 | sim_mat = remained_cands_features.dot(remained_refs_features.T) 229 | remained_cands_match = np.sum(np.max(sim_mat, axis=1)) 230 | remained_refs_match = np.sum(np.max(sim_mat, axis=0)) 231 | cand_match = total_match + remained_cands_match 232 | ref_match = total_match + remained_refs_match 233 | 234 | return total_match, cand_match,
ref_match 235 | 236 | 237 | def get_all_lemmatized_nouns(self, text): 238 | tokens = nltk.word_tokenize(text) 239 | tagged = nltk.pos_tag(tokens) 240 | nouns = [self.text_processor.normalize_word(token, pos=wordnet.NOUN) for token, tag in tagged if tag.startswith('NN')] 241 | return nouns 242 | 243 | 244 | def compute_f_score(self, gt_parsed, cand_parsed): 245 | gt_objects, gt_attributes, gt_relations = gt_parsed 246 | cand_objects, cand_attributes, cand_relations = cand_parsed 247 | 248 | # Objects 249 | object_match, object_cand_match, object_ref_match = self.compute_match(cand_objects, gt_objects) 250 | object_precision, object_recall = object_cand_match / (len(cand_objects) + self.eps), object_ref_match / (len(gt_objects) + self.eps) 251 | object_f1 = 2 * object_precision * object_recall / (object_precision + object_recall + self.eps) 252 | 253 | # Attributes 254 | gt_attributes_words, cand_attributes_words = [], [] 255 | for k, v in gt_attributes.items(): 256 | gt_attributes_words.extend(v) 257 | for k, v in cand_attributes.items(): 258 | cand_attributes_words.extend(v) 259 | attribute_match, attribute_cand_match, attribute_ref_match = self.compute_match(cand_attributes_words, gt_attributes_words) 260 | attribute_precision, attribute_recall = attribute_cand_match / (len(cand_attributes_words) + self.eps), attribute_ref_match / (len(gt_attributes_words) + self.eps) 261 | attribute_f1 = 2 * attribute_precision * attribute_recall / (attribute_precision + attribute_recall + self.eps) 262 | 263 | # Relations 264 | relation_match = 0 265 | matched_cand_indices, matched_ref_indices = set(), set() 266 | for i, cand in enumerate(cand_relations): 267 | for j, ref in enumerate(gt_relations): 268 | if cand == ref and j not in matched_ref_indices: 269 | matched_cand_indices.add(i) 270 | matched_ref_indices.add(j) 271 | relation_match += 1 272 | break 273 | 274 | if self.synonym_matching: 275 | for i, cand in enumerate(cand_relations): 276 | if i not in matched_cand_indices: 277 | for j, ref in enumerate(gt_relations): 278 | if j not in matched_ref_indices and all([self.compute_synonyms_score(cand_ele, ref_ele) > 0. 
for cand_ele, ref_ele in zip(cand, ref)]): 279 | matched_cand_indices.add(i) 280 | matched_ref_indices.add(j) 281 | relation_match += 1 282 | break 283 | 284 | remained_cands = [' '.join(cand) for i, cand in enumerate(cand_relations) if i not in matched_cand_indices] 285 | remained_refs = [' '.join(gt) for j, gt in enumerate(gt_relations) if j not in matched_ref_indices] 286 | cands_match = relation_match 287 | refs_match = relation_match 288 | if self.soft_matching and len(remained_cands) > 0 and len(remained_refs) > 0: 289 | with io.StringIO() as f: 290 | with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f): 291 | remained_cands_features, remained_refs_features = encode_phrases(self.text_encoder, remained_cands, remained_refs, batch_size=4) 292 | sim_mat = remained_cands_features.dot(remained_refs_features.T) 293 | remained_cands_match = np.sum(np.max(sim_mat, axis=1)) 294 | remained_refs_match = np.sum(np.max(sim_mat, axis=0)) 295 | cands_match += remained_cands_match 296 | refs_match += remained_refs_match 297 | 298 | relation_precision, relation_recall = cands_match / (len(cand_relations) + self.eps), refs_match / (len(gt_relations) + self.eps) 299 | relation_f1 = 2 * relation_precision * relation_recall / (relation_precision + relation_recall + self.eps) 300 | 301 | capture_score = self.alpha*object_f1 + self.beta*attribute_f1 + self.gamma * relation_f1 302 | capture_score /= (self.alpha + self.beta + self.gamma) 303 | # print(f"obj_f1: {object_f1}, attr_f1: {attribute_f1}, rel_f1: {relation_f1} capture: {capture_score}") 304 | 305 | return capture_score, object_precision, object_recall, object_f1, \ 306 | attribute_precision, attribute_recall, attribute_f1, \ 307 | relation_precision, relation_recall, relation_f1 308 | 309 | 310 | def sample_to_parse_results(self, sample): 311 | sample_index, text = sample[0], sample[1] 312 | try: 313 | sentences = sent_tokenize(text) 314 | except Exception as e: 315 | print(e) 316 | print(f"text: {text}") 317 | import pdb; pdb.set_trace() 318 | with torch.no_grad(): 319 | with io.StringIO() as f: 320 | with contextlib.redirect_stdout(f), contextlib.redirect_stderr(f): 321 | graph_obj = self.parser.parse(sentences, beam_size=5, return_text=False,max_output_len=128) 322 | 323 | objects, attributes, relations = merge_sentence_results(graph_obj, self.text_processor) 324 | text_all_nouns = set(self.get_all_lemmatized_nouns(text)) 325 | objects = [object for object in objects if object not in self.stop_words_list and (object in text_all_nouns or all([piece in text_all_nouns for piece in object.split(' ')]))] 326 | attributes = {k: v for k,v in attributes.items() if (k in text_all_nouns or all([piece in text_all_nouns for piece in k.split(' ')]))} # k in text_all_nouns and k not in self.stop_words_list} 327 | relations = set([relation for relation in relations if (relation[0] in text_all_nouns or all([piece in text_all_nouns for piece in relation[0].split(' ')])) and (relation[2] in text_all_nouns or all([piece in text_all_nouns for piece in relation[2].split(' ')])) ]) 328 | return sample_index, objects, attributes, relations 329 | 330 | 331 | def parse_samples(self, samples, device, desc=""): 332 | torch.cuda.set_device(int(str(device)[-1])) 333 | if self.parser is not None and hasattr(self.parser, 'device') and self.parser.device == device: 334 | pass 335 | else: 336 | if self.parser is not None: 337 | print(f"self.parser.device {self.parser.device} device {device}") 338 | if torch.cuda.is_available(): 339 | self.parser = 
SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', device=device) 340 | else: 341 | self.parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', device='cpu') 342 | self.parser.model.eval() 343 | parsed_samples = [] 344 | for sample in tqdm.tqdm(samples, desc=desc + ' ' + str(device)): 345 | parsed_sample = self.sample_to_parse_results(sample) 346 | parsed_samples.append(parsed_sample) 347 | return parsed_samples 348 | 349 | 350 | def process_samples_multiprocessing(self, partitioned_data, desc="parsing"): 351 | set_mp_context() 352 | with multiprocessing.Pool(processes=torch.cuda.device_count()) as pool: 353 | futures = [] 354 | for idx, this_partitioned_data in enumerate(partitioned_data): 355 | future = pool.apply_async(self.parse_samples, args=(this_partitioned_data, torch.device(f'cuda:{idx}'), desc)) 356 | futures.append(future) 357 | all_parsed = [] 358 | for future in futures: 359 | results = future.get() 360 | all_parsed.extend(results) 361 | # all_parsed.sort(key=lambda x: x[0]) 362 | # all_parsed = [(res[1], res[2], res[3]) for res in all_parsed] 363 | # return all_parsed 364 | 365 | all_parsed_dict = collections.defaultdict(list) 366 | for parsed_sample in all_parsed: 367 | all_parsed_dict[parsed_sample[0]].append(parsed_sample[1:]) 368 | return all_parsed_dict 369 | 370 | 371 | def compute_score(self, gts, res, prev_gt_parsed=None, prev_cand_parsed=None, return_parse_results=False): 372 | gts = [(sample_key, gt) for sample_key, sample_gts in gts.items() for gt in sample_gts] 373 | cands = [(sample_key, sample_res[0]) for sample_key, sample_res in res.items()] 374 | 375 | def partition_data(data): 376 | num_chunk = torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1 377 | chunk_size = len(data) // num_chunk 378 | partitioned_data = [] 379 | start = 0 380 | for i in range(num_chunk): 381 | end = start + chunk_size 382 | if i < len(data) % num_chunk: 383 | end += 1 384 | partitioned_data.append(data[start:end]) 385 | start = end 386 | return partitioned_data 387 | 388 | if prev_cand_parsed is None: 389 | partitioned_data = partition_data(cands) 390 | cand_parsed = self.process_samples_multiprocessing(partitioned_data, desc='parsing cand') 391 | else: 392 | print("parsing cand skip") 393 | cand_parsed = prev_cand_parsed 394 | 395 | if prev_gt_parsed is None: 396 | partitioned_data = partition_data(gts) 397 | gt_parsed = self.process_samples_multiprocessing(partitioned_data, desc='parsing gt') 398 | else: 399 | print("parsing gt skip") 400 | gt_parsed = prev_gt_parsed 401 | 402 | scores = [] 403 | parse_results = [] 404 | for sample_key in tqdm.tqdm(gt_parsed.keys(), desc="computing score"): 405 | sample_gt_parsed, sample_cand_parsed = gt_parsed[sample_key], cand_parsed[sample_key][0] 406 | results = [ 407 | self.compute_f_score(this_gt_parsed, sample_cand_parsed) for this_gt_parsed in sample_gt_parsed 408 | ] 409 | sample_scores = [result[0] for result in results] 410 | sample_score = sum(sample_scores) / len(sample_scores) 411 | scores.append(sample_score) 412 | parse_results.append({ 413 | "sample_key": sample_key, 414 | "gt_parsed": sample_gt_parsed, 415 | "cand_parsed": sample_cand_parsed, 416 | "object_precision": round(mean([result[1]*100 for result in results]), 2), 417 | 'object_recall': round(mean([result[2]*100 for result in results]), 2), 418 | 'object_f1': round(mean([result[3]*100 for result in results]), 2), 419 | 'attribute_precision': round(mean([result[4]*100 for result in results]), 2), 420 | 'attribute_recall': 
round(mean([result[5]*100 for result in results]), 2), 421 | 'attribute_f1': round(mean([result[6]*100 for result in results]), 2), 422 | 'relation_precision': round(mean([result[7]*100 for result in results]), 2), 423 | 'relation_recall': round(mean([result[8]*100 for result in results]), 2), 424 | 'relation_f1': round(mean([result[9]*100 for result in results]), 2), 425 | }) 426 | 427 | score = sum(scores) / len(scores) 428 | 429 | if return_parse_results: 430 | return score, scores, parse_results 431 | else: 432 | return score, scores 433 | 434 | 435 | 436 | if __name__ == '__main__': 437 | torch.multiprocessing.set_start_method("spawn") 438 | 439 | refs = { 440 | 'example_0': [ 441 | "The image depicts a busy city street with cars running in the foreground, including a red car and a white truck. The street is surrounded by green trees. In the backgound of the image, modern edifices and a clock tower stand under a clear blue sky. ", 442 | "The image depicts a busy city street with cars running in the foreground, including a red car and a white truck. The street is surrounded by green trees. In the backgound of the image, modern edifices and a clock tower stand under a clear blue sky. " 443 | ], 444 | } 445 | preds = { 446 | 'example_0': [ 447 | "The image shows a red car, a white truck and other automobiles running on a city road. Pedestrians are walking on the side. Tall buildings can be seen under a clear blue sky." 448 | ] 449 | } 450 | assert refs.keys() == preds.keys() 451 | 452 | evaluator = CAPTURE() 453 | score = evaluator.compute_score(refs, preds) 454 | print(f"CAPTURE score: {score}") 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | -------------------------------------------------------------------------------- /capture/capture_metric/stop_words.py: -------------------------------------------------------------------------------- 1 | 2 | stop_words_list = [ 3 | 'background', 4 | 'scene', 5 | 'object', 6 | 'image', 7 | 'element', 8 | 'text', 9 | 'landscape', 10 | 'setting', 11 | 'space', 12 | 'foreground', 13 | 'life', 14 | 'backdrop', 15 | 'nature', 16 | 'design', 17 | 'atmosphere', 18 | 'surface', 19 | 'position', 20 | 'structure', 21 | 'hue', 22 | 'picture', 23 | 'color', 24 | 'frame', 25 | 'charm', 26 | 'area', 27 | 'attention', 28 | 'presence', 29 | 'attire', 30 | 'surroundings', 31 | 'world', 32 | 'distance', 33 | 'location', 34 | 'ambiance', 35 | 'gaze', 36 | 'tone', 37 | 'figure', 38 | 'field', 39 | 'center', 40 | 'day', 41 | 'air', 42 | 'pattern', 43 | 'tranquility', 44 | 'view', 45 | 'photo', 46 | 'story', 47 | 'activity', 48 | 'expression', 49 | 'item', 50 | 'time', 51 | 'composition', 52 | 'viewer', 53 | 'environment', 54 | 'horizon', 55 | 'moment', 56 | 'narrative', 57 | 'place', 58 | 'spectator', 59 | 'architecture', 60 | 'texture', 61 | 'anticipation', 62 | 'content', 63 | 'detail', 64 | 'back', 65 | 'cityscape', 66 | 'motion', 67 | 'top', 68 | 'comfort', 69 | 'display', 70 | 'journey', 71 | 'description', 72 | 'action', 73 | 'Atmosphäre', 74 | 'expanse', 75 | 'glow', 76 | 'adventure', 77 | 'individual', 78 | 'joy', 79 | 'city life', 80 | 'interior', 81 | 'character', 82 | 'night', 83 | 'angle', 84 | 'perspective', 85 | 'stance', 86 | 'history', 87 | 'shape', 88 | 'energy', 89 | 'spirit', 90 | 'match', 91 | 'movement', 92 | 'game', 93 | 'exterior', 94 | 'beauty', 95 | 'event', 96 | 'accent', 97 | 'focus', 98 | 'bustle', 99 | 'touch', 100 | 'camaraderie', 101 | 'aesthetic', 102 | 'vantage point', 103 | 'excitement', 104 | 'direction', 105 | 'warmth', 106 | 
'harmony', 107 | 'border', 108 | 'right', 109 | 'terrain', 110 | 'mid-air', 111 | 'palette', 112 | 'message', 113 | 'interaction', 114 | 'workspace', 115 | 'spectacle', 116 | 'base', 117 | 'name', 118 | 'theme', 119 | 'mood', 120 | 'solitude', 121 | 'left', 122 | 'culture', 123 | 'grandeur', 124 | 'facade', 125 | 'boundary', 126 | 'white', 127 | 'information', 128 | 'illustration', 129 | 'row', 130 | 'balance', 131 | 'creativity', 132 | 'layout', 133 | 'style', 134 | 'simplicity', 135 | 'body language', 136 | 'photograph', 137 | 'celebration', 138 | 'hustle', 139 | 'work', 140 | 'form', 141 | 'court', 142 | 'edge', 143 | 'chaos', 144 | 'posture', 145 | 'elegance', 146 | 'handle', 147 | 'tableau', 148 | 'short', 149 | 'clothing', 150 | 'context', 151 | 'depth', 152 | 'uniform', 153 | 'rhythm', 154 | 'everyday life', 155 | 'experience', 156 | 'reflection', 157 | 'identity', 158 | 'pleasure', 159 | 'marvel', 160 | 'weather', 161 | 'neighborhood', 162 | 'track', 163 | 'urban life', 164 | 'functionality', 165 | 'silhouette', 166 | 'aura', 167 | 'pose', 168 | 'outdoors', 169 | 'conversation', 170 | 'imagination', 171 | 'surrounding', 172 | 'countertop', 173 | 'serenity', 174 | 'side', 175 | 'scenery', 176 | 'atmosfere', 177 | 'peak', 178 | 'traffic', 179 | 'arrangement', 180 | 'mystery', 181 | 'thrill', 182 | 'essence', 183 | 'formation', 184 | 'skill', 185 | 'outline', 186 | 'presentation', 187 | 'community', 188 | 'darkness', 189 | 'demeanor', 190 | 'section', 191 | 'group', 192 | 'positioning', 193 | 'scale', 194 | 'memory', 195 | 'living space', 196 | 'front', 197 | 'emotion', 198 | 'task', 199 | 'travel', 200 | 'freedom', 201 | 'feature', 202 | 'curiosity', 203 | 'decor', 204 | 'performance', 205 | 'dining experience', 206 | 'purpose', 207 | 'resilience', 208 | 'passion', 209 | 'Moment', 210 | 'art', 211 | 'wilderness', 212 | 'effect', 213 | 'urban setting', 214 | 'power', 215 | 'him', 216 | 'vibe', 217 | 'placement', 218 | 'destination', 219 | 'wild', 220 | 'vastness', 221 | 'determination', 222 | 'layer', 223 | 'heritage', 224 | 'contrast', 225 | 'innocence', 226 | 'beauté', 227 | 'decoration', 228 | 'infrastructure', 229 | 'urban', 230 | 'ensemble', 231 | 'page', 232 | 'countryside', 233 | 'winter', 234 | 'material', 235 | 'intersection', 236 | 'subject', 237 | 'stillness', 238 | 'appeal', 239 | 'height', 240 | 'play', 241 | 'tall', 242 | 'modernity', 243 | 'connection', 244 | 'natural', 245 | 'exploration', 246 | 'coexistence', 247 | 'bottom', 248 | 'tension', 249 | 'outdoor', 250 | 'Nature', 251 | 'natural beauté', 252 | 'technology', 253 | 'imagery', 254 | 'reality', 255 | 'relaxation', 256 | 'spot', 257 | 'point', 258 | 'show', 259 | 'significance', 260 | 'personality', 261 | 'her', 262 | 'snapshot', 263 | 'viewpoint', 264 | 'entrance', 265 | 'rim', 266 | 'operation', 267 | 'intensity', 268 | 'nostalgia', 269 | 'afternoon', 270 | 'it', 271 | 'stride', 272 | 'case', 273 | 'key', 274 | 'component', 275 | 'pant', 276 | 'collage', 277 | 'speed', 278 | 'promise', 279 | 'atmosphère', 280 | 'vibrancy', 281 | 'remote', 282 | 'teamwork', 283 | 'site', 284 | 'turn', 285 | 'end', 286 | 'network', 287 | 'love', 288 | 'flavor', 289 | 'strap', 290 | 'era', 291 | 'marking', 292 | 'music', 293 | 'confidence', 294 | 'knowledge', 295 | 'appearance', 296 | 'condition', 297 | 'routine', 298 | 'team', 299 | 'role', 300 | 'unity', 301 | 'bloom', 302 | 'ecosystem', 303 | 'topping', 304 | 'enclosure', 305 | 'order', 306 | 'step', 307 | 'flight', 308 | 'ingredient', 309 | 'diversity', 310 | 'tip', 311 | 
'universe', 312 | 'breakfast', 313 | 'middle', 314 | 'treat', 315 | 'companionship', 316 | 'essential', 317 | 'red', 318 | 'symbolism', 319 | 'interpretation', 320 | ] -------------------------------------------------------------------------------- /capture/ckpt/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-multimodal-models/CAPTURE/52eeb2781e8b4b1854c07a57b573c75edfd45688/capture/ckpt/README.md -------------------------------------------------------------------------------- /capture/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='capture_metric', 5 | version='0.1.12', 6 | author='Hongyuan Dong', 7 | author_email='d_ousia@icloud.com', 8 | description='A package for detail image caption evaluation.', 9 | long_description=open('README.md').read(), 10 | long_description_content_type='text/markdown', 11 | url='https://github.com/foundation-multimodal-models/CAPTURE', 12 | packages=find_packages(), 13 | include_package_data=True, 14 | install_requires=[ 15 | 'torch', 16 | 'transformers', 17 | 'tqdm', 18 | 'nltk', 19 | 'spacy', 20 | 'scipy', 21 | 'sentence-transformers', 22 | 'pandas', 23 | 'numpy', 24 | 'tabulate', 25 | 'FactualSceneGraph' 26 | # Add other dependencies needed for your package 27 | ], 28 | license="Apache-2.0", 29 | ) 30 | -------------------------------------------------------------------------------- /detail_caption_construction/ckpt/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-multimodal-models/CAPTURE/52eeb2781e8b4b1854c07a57b573c75edfd45688/detail_caption_construction/ckpt/README.md -------------------------------------------------------------------------------- /detail_caption_construction/config_llava15_7b_detailcaps_4870/caption_anything_caption_reorganization.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | vision_encoder_type: vit-eva_h/14 # eva_clip_g # vit-eva_h/14 3 | lm_decoder_type: 'llama' 4 | lm_pretrained_path: ./reservoir/vicuna-13b #hdfs://haruna/home/byte_data_aml_research/user/cruise/Altman/ckpt_ziye/vicuna_cn_90w_new_version 5 | stage: 'stage1' 6 | freeze_lm: true 7 | freeze_vision: true 8 | freeze_qformer: false 9 | freeze_linear: false 10 | gradient_checkpointing: ['vision'] 11 | max_new_tokens: 64 # max tokens in generation at validation stage 12 | tokenizer_path: ./reservoir/vicuna_tokenizer 13 | 14 | # qformer's config = bert config 15 | qformer_config: 16 | add_cross_attention: true 17 | attention_probs_dropout_prob: 0.1 18 | classifier_dropout: null 19 | cross_attention_freq: 2 20 | gradient_checkpointing: false 21 | hidden_act: gelu 22 | hidden_dropout_prob: 0.1 23 | hidden_size: 768 24 | initializer_range: 0.02 25 | intermediate_size: 3072 26 | layer_norm_eps: 1e-12 27 | max_position_embeddings: 512 28 | model_type: bert 29 | num_attention_heads: 12 30 | num_hidden_layers: 12 31 | pad_token_id: 0 32 | position_embedding_type: absolute 33 | query_length: 32 34 | type_vocab_size: 2 35 | use_cache: true 36 | vocab_size: 30522 37 | load_legacy_ckpt: # TODO: Qformer.pth doesn't include "query_tokens" 38 | - ./reservoir/mp_rank_00_model_states.pt 39 | ckpt_rename_parameters: 40 | - {'module.': ''} 41 | 42 | 43 | origin: false 44 | 45 | hdfs_data_paths: 46 | # - 
hdfs://haruna/home/byte_data_aml/user/donghongyuan.dousia/caption_anything/case_study/llava_local_caption.parquet 47 | # - hdfs://haruna/home/byte_data_aml/user/donghongyuan.dousia/caption_anything/dataset/zc_1m/local_caption/* 48 | - reservoir/local_caption_cropped_merged_processed_data 49 | # - hdfs://haruna/home/byte_data_aml/user/donghongyuan.dousia/caption_anything/case_study/filter_area_merge_bbox_kmeans_local_caption.parquet 50 | 51 | 52 | # hdfs_data_processed_path: hdfs://haruna/home/byte_data_aml/user/donghongyuan.dousia/caption_anything/dataset/detailcaps_100_llava15_recaption/caption_reorganization_cropped_filtered 53 | hdfs_data_processed_path: hdfs://haruna/home/byte_data_aml/user/donghongyuan.dousia/caption_anything/dataset/detailcaps_5000_llava15_7b_pipeline/caption_reorganization_cropped -------------------------------------------------------------------------------- /detail_caption_construction/config_llava15_7b_detailcaps_4870/stage1_overall_caption.yaml: -------------------------------------------------------------------------------- 1 | 2 | model_path: llava-v1.5-7b 3 | img_key: 'frame' 4 | batch_size: 8 5 | 6 | ckpt_path: detail_caption_construction/ckpt/ 7 | source_path: detail_caption_construction/data/source_data/ 8 | target_path: detail_caption_construction/data/stage1_overall_caption/ 9 | -------------------------------------------------------------------------------- /detail_caption_construction/config_llava15_7b_detailcaps_4870/stage2_bbox.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | # 7b pipeline 3 | visual_encoder: 'vit_l' 4 | checkpoint: sam/sam_vit_l_0b3195.pth 5 | # # 13b pipeline 6 | # visual_encoder: 'vit_h' 7 | # checkpoint: sam/sam_vit_h_4b8939.pth 8 | mask: 9 | points_per_side: 8 # point num per size, 10 | pred_iou_thresh: 0.66 # iou thresh 11 | stability_score_thresh: 0.86 # 12 | crop_n_layers: 1 13 | crop_n_points_downscale_factor: 2 14 | min_mask_region_area: 100 # Requires open-cv to run post-processing on the mask. 15 | 16 | cluster: 17 | merge_threshold: 0. 
18 | kmeans_center_num: 10 19 | compress_scale: 4 20 | expected_cropped_bbox_num_per_bbox: 2 21 | bbox_to_be_cropped_num: 3 22 | Draw: 23 | draw_bbox: False 24 | 25 | do_cluster: True 26 | do_crop: True 27 | do_eval: True 28 | img_key: 'frame' 29 | batch_size: 1 30 | 31 | ckpt_path: detail_caption_construction/ckpt/ 32 | source_path: detail_caption_construction/data/stage1_overall_caption/ 33 | target_path: detail_caption_construction/data/stage2_bbox/ 34 | -------------------------------------------------------------------------------- /detail_caption_construction/config_llava15_7b_detailcaps_4870/stage3_local_caption.yaml: -------------------------------------------------------------------------------- 1 | 2 | model_path: llava-v1.5-7b 3 | batch_size: 8 4 | 5 | ckpt_path: detail_caption_construction/ckpt/ 6 | source_path: detail_caption_construction/data/stage2_bbox/ 7 | target_path: detail_caption_construction/data/stage3_local_caption/ 8 | -------------------------------------------------------------------------------- /detail_caption_construction/config_llava15_7b_detailcaps_4870/stage4_filter.yaml: -------------------------------------------------------------------------------- 1 | 2 | model_path: owlv2-large-patch14-ensemble 3 | threshold: 0.01 4 | nms_threshold: 0.1 5 | batch_size: 8 6 | 7 | ckpt_path: detail_caption_construction/ckpt/ 8 | source_path: detail_caption_construction/data/stage3_local_caption/ 9 | target_path: detail_caption_construction/data/stage4_filter/ 10 | -------------------------------------------------------------------------------- /detail_caption_construction/config_llava15_7b_detailcaps_4870/stage5_caption_merge.yaml: -------------------------------------------------------------------------------- 1 | 2 | model_path: Llama-2-7b-chat-hf # Llama-2-13b-chat-hf 3 | batch_size: 4 4 | 5 | ckpt_path: detail_caption_construction/ckpt/ 6 | source_path: detail_caption_construction/data/stage4_filter/ 7 | target_path: detail_caption_construction/data/stage5_caption_merge/ 8 | 9 | -------------------------------------------------------------------------------- /detail_caption_construction/generate_all.sh: -------------------------------------------------------------------------------- 1 | 2 | set -ex 3 | 4 | node_index=$1 5 | node_num=$2 6 | chunk_num=$(nvidia-smi | grep MiB | wc -l) 7 | 8 | bash prepare.sh 9 | 10 | 11 | for (( chunk_index=0; chunk_index<=$[$chunk_num-1]; chunk_index++ )) 12 | do 13 | CUDA_VISIBLE_DEVICES=$chunk_index nohup python3 detail_caption_construction/generate_stage1_overall_caption.py \ 14 | --config_path detail_caption_construction/config_llava15_7b_detailcaps_4870/stage1_overall_caption.yaml \ 15 | --chunk_index $chunk_index \ 16 | --chunk_num $chunk_num \ 17 | --node_index $node_index \ 18 | --node_num $node_num > detail_caption_construction/scripts_output/stage1_overall_caption_$chunk_index.log 2>&1 & 19 | done 20 | 21 | python3 detail_caption_construction/merge_results.py --config detail_caption_construction/config_llava15_7b_detailcaps_4870/stage1_overall_caption.yaml --node_index $node_index --node_num $node_num > detail_caption_construction/scripts_output/watch_and_upload_stage1_overall_caption.log 2>&1 22 | wait 23 | 24 | 25 | for (( chunk_index=0; chunk_index<=$[$chunk_num-1]; chunk_index++ )) 26 | do 27 | CUDA_VISIBLE_DEVICES=$chunk_index nohup python3 detail_caption_construction/generate_stage2_bbox.py \ 28 | --config_path detail_caption_construction/config_llava15_7b_detailcaps_4870/stage2_bbox.yaml \ 29 | --chunk_index 
$chunk_index \ 30 | --chunk_num $chunk_num \ 31 | --node_index $node_index \ 32 | --node_num $node_num > detail_caption_construction/scripts_output/stage2_bbox_$chunk_index.log 2>&1 & 33 | done 34 | 35 | python3 detail_caption_construction/merge_results.py --config detail_caption_construction/config_llava15_7b_detailcaps_4870/stage2_bbox.yaml --node_index $node_index --node_num $node_num > detail_caption_construction/scripts_output/watch_and_upload_stage2_bbox.log 2>&1 36 | wait 37 | 38 | 39 | for (( chunk_index=0; chunk_index<=$[$chunk_num-1]; chunk_index++ )) 40 | do 41 | CUDA_VISIBLE_DEVICES=$chunk_index nohup python3 detail_caption_construction/generate_stage3_local_caption.py \ 42 | --config_path detail_caption_construction/config_llava15_7b_detailcaps_4870/stage3_local_caption.yaml \ 43 | --chunk_index $chunk_index \ 44 | --chunk_num $chunk_num \ 45 | --node_index $node_index \ 46 | --node_num $node_num > detail_caption_construction/scripts_output/stage3_local_caption_$chunk_index.log 2>&1 & 47 | done 48 | 49 | python3 detail_caption_construction/merge_results.py --config detail_caption_construction/config_llava15_7b_detailcaps_4870/stage3_local_caption.yaml --node_index $node_index --node_num $node_num > detail_caption_construction/scripts_output/watch_and_upload_stage3_local_caption.log 2>&1 50 | wait 51 | 52 | 53 | for (( chunk_index=0; chunk_index<=$[$chunk_num-1]; chunk_index++ )) 54 | do 55 | CUDA_VISIBLE_DEVICES=$chunk_index nohup python3 detail_caption_construction/generate_stage4_filter.py \ 56 | --config_path detail_caption_construction/config_llava15_7b_detailcaps_4870/stage4_filter.yaml \ 57 | --chunk_index $chunk_index \ 58 | --chunk_num $chunk_num \ 59 | --node_index $node_index \ 60 | --node_num $node_num > detail_caption_construction/scripts_output/stage4_filter_$chunk_index.log 2>&1 & 61 | done 62 | 63 | python3 detail_caption_construction/merge_results.py --config detail_caption_construction/config_llava15_7b_detailcaps_4870/stage4_filter.yaml --node_index $node_index --node_num $node_num > detail_caption_construction/scripts_output/watch_and_upload_stage4_filter.log 2>&1 64 | wait 65 | 66 | 67 | for (( chunk_index=0; chunk_index<=$[$chunk_num-1]; chunk_index++ )) 68 | do 69 | CUDA_VISIBLE_DEVICES=$chunk_index nohup python3 detail_caption_construction/generate_stage5_caption_merge.py \ 70 | --config_path detail_caption_construction/config_llava15_7b_detailcaps_4870/stage5_caption_merge.yaml \ 71 | --chunk_index $chunk_index \ 72 | --chunk_num $chunk_num \ 73 | --node_index $node_index \ 74 | --node_num $node_num > detail_caption_construction/scripts_output/stage5_caption_merge_$chunk_index.log 2>&1 & 75 | done 76 | 77 | python3 detail_caption_construction/merge_results.py --config detail_caption_construction/config_llava15_7b_detailcaps_4870/stage5_caption_merge.yaml --node_index $node_index --node_num $node_num > detail_caption_construction/scripts_output/watch_and_upload_stage5_caption_merge.log 2>&1 78 | wait 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /detail_caption_construction/generate_stage1_overall_caption.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (2024) CAPTURE project Authors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | import os 19 | import sys 20 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..')) 21 | 22 | import io 23 | import torch 24 | import pandas as pd 25 | import yaml 26 | import tqdm 27 | import argparse 28 | from PIL import Image 29 | 30 | from LLaVA.llava.constants import IMAGE_TOKEN_INDEX 31 | from LLaVA.llava.model.builder import load_pretrained_model 32 | from LLaVA.llava.mm_utils import process_images, tokenizer_image_token 33 | from LLaVA.llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 34 | from LLaVA.llava.constants import IMAGE_TOKEN_INDEX 35 | from detail_caption_construction.utils.utils import get_data_files 36 | 37 | 38 | def spotter_llava(model, batch, image_processor, tokenizer, img_key): 39 | images = [Image.open(io.BytesIO(batch.loc[i, img_key])).convert("RGB") for i in range(len(batch))] 40 | # from IPython import embed; embed() 41 | try: 42 | image_tensor = process_images(images, image_processor, None).to(model.dtype) 43 | except Exception as e: 44 | all_image_tensor = [] 45 | for image in images: 46 | try: 47 | this_image_tensor = process_images([image], image_processor, None).to(model.dtype) 48 | all_image_tensor.append(this_image_tensor) 49 | except Exception as e: 50 | print("an image is corrupted, skipping") 51 | image_tensor = torch.cat(all_image_tensor) 52 | 53 | llava_ori_prompt = 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions. USER: \n' 54 | spotter_prompt = ' Describe this image in detail. 
ASSISTANT:' 55 | spotter_prompts = [llava_ori_prompt + spotter_prompt for _ in range(image_tensor.shape[0])] 56 | input_ids = tokenizer_image_token(spotter_prompts[0], tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').to(model.device) 57 | all_input_ids = torch.tile(input_ids, [image_tensor.shape[0], 1]) 58 | with torch.inference_mode(): 59 | with torch.amp.autocast('cuda', dtype=torch.bfloat16): 60 | output_ids = model.generate( 61 | all_input_ids, 62 | images=image_tensor, 63 | do_sample=True, 64 | temperature=0.2, 65 | max_new_tokens=192, 66 | use_cache=True) 67 | 68 | res = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 69 | batch_processed_data = [] 70 | for i in range(len(res)): 71 | processed_sample = batch.loc[i].to_dict() 72 | res[i] = res[i].replace(spotter_prompt, '') 73 | processed_sample['overall_caption'] = res[i] 74 | batch_processed_data.append(processed_sample) 75 | 76 | return batch_processed_data 77 | 78 | 79 | def main(config_path, chunk_index, chunk_num, node_index, node_num): 80 | with open(config_path) as f: 81 | config = yaml.load(f,Loader=yaml.FullLoader) 82 | 83 | lvlm = config['model_path'] 84 | model_path = f"{config['ckpt_path']}/{lvlm}" 85 | print(f"loading {model_path}") 86 | if "llava" in lvlm and '1.6' in lvlm: 87 | model_name = get_model_name_from_path(model_path) 88 | elif "llava" in lvlm: 89 | model_name = lvlm 90 | else: 91 | raise ValueError(f"lvlm {lvlm} not supported") 92 | tokenizer, llava_model, image_processor, context_len = load_pretrained_model(model_path, None, model_name) 93 | llava_model.eval().cuda() 94 | 95 | img_key = config['img_key'] 96 | batch_size = config['batch_size'] 97 | source_data_files, target_data_files = get_data_files(config, node_index=node_index, node_num=node_num) 98 | 99 | for source_data_file in source_data_files: 100 | if f"{source_data_file.split('/')[-1].split('.')[0]}_processed" in target_data_files: 101 | print(f"file {source_data_file} processed, skipping") 102 | continue 103 | print(f"processing {source_data_file}") 104 | processed_data = [] 105 | df = pd.read_parquet(source_data_file) 106 | start, end = chunk_index * (len(df) // chunk_num), (chunk_index + 1) * (len(df) // chunk_num) - 1 107 | if len(df) - end < len(df) // chunk_num: 108 | end = len(df) - 1 109 | df = df.loc[start: end] 110 | for offset in tqdm.trange(0, len(df), batch_size): 111 | offset += start 112 | batch = df.loc[offset: offset + batch_size - 1].reset_index(drop=True) 113 | batch_processed_data = [] 114 | 115 | with torch.no_grad(): 116 | batch_processed_data = spotter_llava(llava_model, batch, image_processor, tokenizer, img_key) 117 | processed_data.extend(batch_processed_data) 118 | 119 | processed_df = pd.DataFrame(processed_data).reset_index(drop=True) 120 | base_path = os.path.basename(source_data_file) 121 | output_path = f"detail_caption_construction/data/processed_data/{base_path.split('.')[0]}_chunk{chunk_index}.parquet" 122 | processed_df.to_parquet(output_path) 123 | 124 | 125 | if __name__ == '__main__': 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument('--chunk_index', type=int) 128 | parser.add_argument('--chunk_num', type=int, default=8) 129 | parser.add_argument('--node_index', type=int, default=0) 130 | parser.add_argument('--node_num', type=int, default=1) 131 | parser.add_argument('--config_path', type=str) 132 | args = parser.parse_args() 133 | 134 | main( 135 | config_path=args.config_path, 136 | chunk_index=args.chunk_index, 137 | chunk_num=args.chunk_num, 138 | 
node_index=args.node_index, 139 | node_num=args.node_num 140 | ) 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | -------------------------------------------------------------------------------- /detail_caption_construction/generate_stage2_bbox.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (2024) CAPTURE project Authors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | 18 | from collections import defaultdict 19 | from email.policy import default 20 | import os 21 | import sys 22 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..')) 23 | # sys.path.append('tagore/module/detection/CoDETR') 24 | sys.path.append('tagore/module/detection/SAM/sam_train_eval_example') 25 | 26 | import cv2 27 | import yaml 28 | import torch 29 | import argparse 30 | import pandas as pd 31 | import numpy as np 32 | from cruise.data_module.utils import parse_data_source 33 | 34 | from PIL import Image 35 | from abc import ABC, abstractmethod 36 | from tqdm import tqdm 37 | 38 | 39 | from segment_anything import SamAutomaticMaskGenerator, sam_model_registry 40 | from pycocotools import mask as mask_utils 41 | from matplotlib import pyplot as plt 42 | from collections import defaultdict 43 | 44 | from detail_caption_construction.utils.bbox_cluster import cluster, convert_bbox 45 | from detail_caption_construction.utils.bbox_statistics import compute_metrics 46 | from detail_caption_construction.utils.utils import get_data_files 47 | 48 | 49 | class BBoxPredictor(ABC): 50 | def __init__(self, config): 51 | self.config = config 52 | self.init_predictor() 53 | 54 | @abstractmethod 55 | def init_predictor(self, **kwargs): 56 | pass 57 | 58 | @abstractmethod 59 | def process_input(self, inputs, **kwargs): 60 | pass 61 | 62 | @abstractmethod 63 | def predict(self, batch, **kwargs): 64 | pass 65 | 66 | @abstractmethod 67 | def process_output(self, outputs, **kwargs): 68 | pass 69 | 70 | 71 | class SAMPredictor(BBoxPredictor): 72 | 73 | def init_predictor(self): 74 | self.device = f"cuda:{os.environ.get('RANK', 0)}" 75 | local_checkpoint = 'detail_caption_construction/ckpt/sam/' + os.path.basename(self.config['checkpoint']) 76 | print(f"loading {local_checkpoint}") 77 | sam = sam_model_registry[self.config['visual_encoder']]( 78 | checkpoint=local_checkpoint) 79 | sam.to(self.device) 80 | self.mask_generator = SamAutomaticMaskGenerator( 81 | model=sam, 82 | points_per_side=self.config['mask']['points_per_side'], # point num per size, 83 | pred_iou_thresh=self.config['mask']['pred_iou_thresh'], # iou thresh 84 | stability_score_thresh=self.config['mask']['stability_score_thresh'], # 85 | crop_n_layers=self.config['mask']['crop_n_layers'], 86 | crop_n_points_downscale_factor=self.config['mask']['crop_n_points_downscale_factor'], 87 | min_mask_region_area=self.config['mask']['min_mask_region_area'], # Requires open-cv to run post-processing 
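            # all mask-generation hyper-parameters above are read from the 'mask' section of stage2_bbox.yaml;
            # fewer points per side and stricter iou/stability thresholds yield fewer masks but run faster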
88 | ) 89 | 90 | def visualize(self, image, bboxes, out_file, linewidth=2): 91 | 92 | image_h, image_w = image.shape[:2] 93 | fig, ax = plt.subplots(figsize=(image_w/100, image_h/100), dpi=100) 94 | ax.axis('off') 95 | fig.subplots_adjust(left=0, right=1, top=1, bottom=0) 96 | ax.imshow(image) 97 | 98 | for i in range(len(bboxes)): 99 | color = np.random.rand(3) 100 | bbox = bboxes[i] 101 | x0, y0, w, h = bbox 102 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0,0,0,0), lw=linewidth)) 103 | 104 | plt.savefig(out_file, format='jpeg') 105 | plt.close() 106 | 107 | def process_input(self, inputs, img_key, order="RGB"): 108 | batch_images = [] 109 | batch_imagehw = [] 110 | for i, sample in inputs.iterrows(): 111 | img_np = np.frombuffer(sample[img_key], np.uint8) 112 | image = cv2.imdecode(img_np, cv2.IMREAD_COLOR) 113 | if image is None: 114 | batch_images.append(None) 115 | batch_imagehw.append((0,0)) 116 | continue 117 | if order == "RGB": 118 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 119 | elif order == "BGR": 120 | pass 121 | batch_images.append(image) 122 | batch_imagehw.append(image.shape[:2]) 123 | return {'image': batch_images, 'hw': batch_imagehw} 124 | 125 | def predict(self, batch): 126 | outputs = [] 127 | for image in batch["image"]: 128 | try: 129 | mask = self.mask_generator.generate(image) 130 | except Exception as e: 131 | mask = None 132 | outputs.append(mask) 133 | batch["outputs"] = outputs 134 | torch.cuda.empty_cache() 135 | return batch 136 | 137 | def process_output(self, outputs): 138 | result = defaultdict(list) 139 | output = outputs["outputs"] 140 | result["hw"] = outputs["hw"] 141 | bboxes = [] 142 | for i in range(len(output)): 143 | bbox = [] 144 | if output[i] is None: 145 | result["sam"].append(None) 146 | bboxes.append(bbox) 147 | continue 148 | for j in range(len(output[i])): 149 | output[i][j]["segmentation"] = mask_utils.encode(output[i][j]["segmentation"]) 150 | output[i][j]["segmentation"]["counts"] = output[i][j]["segmentation"]["counts"].decode() 151 | x, y, w, h = output[i][j]["bbox"] 152 | bbox.append([x, y, x+w, y+h]) 153 | bboxes.append(bbox) 154 | result["sam"].append(output[i]) 155 | result["SAM_bboxes"] = bboxes 156 | return result 157 | 158 | 159 | def main(config_path, chunk_index, chunk_num, node_index, node_num): 160 | with open(config_path) as f: 161 | config = yaml.load(f,Loader=yaml.FullLoader) 162 | do_cluster, do_crop, do_eval = config['do_cluster'], config['do_crop'], config['do_eval'] 163 | img_key = config['img_key'] 164 | 165 | model_hparams = config['model'] 166 | model = SAMPredictor(config=model_hparams) 167 | 168 | batch_size = config['batch_size'] 169 | source_data_files, target_data_files = get_data_files(config, node_index=node_index, node_num=node_num) 170 | 171 | for source_data_file in source_data_files: 172 | if f"{source_data_file.split('/')[-1].split('.')[0]}_processed" in target_data_files: 173 | print(f"file {source_data_file} processed, skipping") 174 | continue 175 | print(f"processing {source_data_file}") 176 | 177 | df = pd.read_parquet(source_data_file) 178 | start, end = chunk_index * (len(df) // chunk_num), (chunk_index + 1) * (len(df) // chunk_num) - 1 179 | if len(df) - end < len(df) // chunk_num: 180 | end = len(df) - 1 181 | df = df.loc[start: end] 182 | 183 | processed_data = defaultdict(list) 184 | for offset in tqdm(range(0, len(df), batch_size)): 185 | offset += start 186 | inputs = df.loc[offset: offset + batch_size - 1] 187 | batch = model.process_input(inputs, 
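                # process_input decodes the raw image bytes into RGB arrays (corrupted images become None),
                # predict runs SAM mask generation, and process_output converts the masks into xyxy boxes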
img_key=img_key) 188 | outputs = model.predict(batch) 189 | results = model.process_output(outputs) 190 | 191 | for key, val in results.items(): 192 | processed_data[key].extend(val) 193 | 194 | if do_cluster: 195 | print("### Doing clustering ###") 196 | item_id = df["item_id"].tolist() if "item_id" in df.columns else [i+start for i in range(len(df))] 197 | cluster_info = {"item_id": item_id, img_key: df[img_key].tolist(), "bboxes": processed_data[f"{'SAM'}_bboxes"], "hw": processed_data["hw"]} 198 | df_cluster = pd.DataFrame(cluster_info) 199 | df_cluster = cluster(df_cluster, 'SAM', config['cluster']) 200 | print("### Clustering completed ###") 201 | 202 | if do_eval: 203 | print("### Doing evaluation ###") 204 | compute_metrics(df_cluster, config['cluster']['compress_scale'], keys=["cluster_centers", "merged_cluster_centers", "cropped_boxes"]) 205 | 206 | processed_data[f"{'SAM'}_cluster_centers"] = df_cluster["cluster_centers"] 207 | processed_data[f"{'SAM'}_merged_cluster_centers"] = df_cluster["merged_cluster_centers"] 208 | if do_crop: 209 | processed_data[f"{'SAM'}_cropped_boxes"] = df_cluster["cropped_boxes"] 210 | 211 | df = df.reset_index(drop=True) 212 | for key, val in processed_data.items(): 213 | df[key] = val 214 | 215 | if not do_cluster and do_eval: 216 | print("### Doing evaluation ###") 217 | compute_metrics(df, config['cluster']['compress_scale'], baseline_key="bboxes") 218 | 219 | # from IPython import embed; embed() 220 | base_path = os.path.basename(source_data_file) 221 | output_path = f"detail_caption_construction/data/processed_data/{base_path.split('.')[0]}_chunk{chunk_index}.parquet" 222 | print(df) 223 | print(output_path) 224 | print(df.columns) 225 | df.to_parquet(output_path) 226 | 227 | 228 | if __name__ == '__main__': 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument('--config_path', type=str) 231 | parser.add_argument('--chunk_index', type=int, default=0) 232 | parser.add_argument('--chunk_num', type=int, default=1) 233 | parser.add_argument('--node_index', type=int, default=0) 234 | parser.add_argument('--node_num', type=int, default=10) 235 | 236 | args = parser.parse_args() 237 | main( 238 | config_path=args.config_path, 239 | chunk_index=args.chunk_index, 240 | chunk_num=args.chunk_num, 241 | node_index=args.node_index, 242 | node_num=args.node_num, 243 | ) 244 | -------------------------------------------------------------------------------- /detail_caption_construction/generate_stage3_local_caption.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (2024) CAPTURE project Authors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | 18 | import math 19 | import os 20 | import sys 21 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..')) 22 | 23 | import io 24 | import torch 25 | import pandas as pd 26 | import numpy as np 27 | import yaml 28 | import tqdm 29 | import argparse 30 | from PIL import Image 31 | 32 | import sys 33 | sys.path.append('./reservoir/llava_code_base') 34 | from LLaVA.llava.constants import IMAGE_TOKEN_INDEX 35 | from LLaVA.llava.model.builder import load_pretrained_model 36 | from LLaVA.llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 37 | from LLaVA.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 38 | from LLaVA.llava.conversation import conv_templates 39 | from detail_caption_construction.utils.utils import get_data_files 40 | 41 | 42 | def process_batch(model, batch, image_processor, tokenizer): 43 | def get_area(coordinates): 44 | return (coordinates[3] - coordinates[1]) * (coordinates[2] - coordinates[0]) 45 | 46 | def get_cropped_images_and_indices(batch): 47 | all_image = [] 48 | all_processed_boxes = [] 49 | all_cropped_images = [] 50 | all_cropped_images_indices = [] 51 | for index, sample in batch.iterrows(): 52 | image = np.array(Image.open(io.BytesIO(sample['frame'])).convert("RGB")) 53 | boxes = [[round(coordinate) for coordinate in coordinates] for coordinates in sample['SAM_cropped_boxes']] 54 | 55 | processed_boxes = [] 56 | cropped_images = [] 57 | for box in boxes: 58 | if True: # get_area(box) < 7000: 59 | dilate_x, dilate_y = sample['hw'][1] // 50, sample['hw'][0] // 50, 60 | box[0], box[1], box[2], box[3] = max(0, box[0]-dilate_x), max(0, box[1]-dilate_y), min(image.shape[1] - 1, box[2]+dilate_x), min(image.shape[0] - 1, box[3]+dilate_y) 61 | if box[2] - box[0] <= 2 or box[3] - box[1] <= 2: 62 | continue 63 | processed_boxes.append(box) 64 | cropped_images.append(image[box[1]: box[3], box[0]: box[2], :]) 65 | all_processed_boxes.append(processed_boxes) 66 | cropped_images = [Image.fromarray(img.astype('uint8')).convert('RGB') for img in cropped_images] 67 | all_image.append(image) 68 | all_cropped_images.extend(cropped_images) 69 | all_cropped_images_indices.extend([index]*len(cropped_images)) 70 | 71 | return all_image, all_processed_boxes, all_cropped_images, all_cropped_images_indices 72 | 73 | 74 | all_image, all_processed_boxes, all_cropped_images, all_cropped_images_indices = get_cropped_images_and_indices(batch) 75 | llava_ori_prompt = 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions. USER: \n' 76 | spotter_prompt = ' describe this picture in detail with no more than twenty words. 
ASSISTANT:' 77 | final_prompt = llava_ori_prompt + spotter_prompt 78 | this_batch_size = 32 79 | all_res = [] 80 | for offset in range(0, len(all_cropped_images), this_batch_size): 81 | batch_cropped_images = all_cropped_images[offset: offset + this_batch_size] 82 | image_tensor = process_images(batch_cropped_images, image_processor, None).to(model.dtype) 83 | input_ids = tokenizer_image_token(final_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').to(model.device) 84 | all_input_ids = torch.tile(input_ids, [len(batch_cropped_images), 1]) 85 | with torch.inference_mode(): 86 | with torch.amp.autocast('cuda', dtype=torch.bfloat16): 87 | output_ids = model.generate( 88 | all_input_ids, 89 | images=image_tensor, 90 | do_sample=True, 91 | temperature=0.2, 92 | max_new_tokens=64, 93 | use_cache=True 94 | ) 95 | res = tokenizer.batch_decode(output_ids, skip_special_tokens=True) 96 | all_res.extend(res) 97 | 98 | regrouped_all_res = [[] for _ in range(len(batch))] 99 | for index, res in enumerate(all_res): 100 | regrouped_all_res[all_cropped_images_indices[index]].append(res) 101 | 102 | return all_image, all_processed_boxes, regrouped_all_res 103 | 104 | 105 | def prepare_model_for_inference(model, dtype): 106 | model.cuda() 107 | model.eval() 108 | if dtype is not None: 109 | model.to(dtype) 110 | 111 | 112 | def visualize(image, boxes, semantic_tags, out_file='reservoir/temp.jpg', linewidth=6): 113 | import matplotlib.pyplot as plt 114 | 115 | image_h, image_w = image.shape[:2] 116 | fig, ax = plt.subplots(figsize=(image_w/100, image_h/100), dpi=100) 117 | ax.axis('off') 118 | fig.subplots_adjust(left=0, right=1, top=1, bottom=0) 119 | ax.imshow(image) 120 | 121 | draw_label_setting = {'facecolor': 'black', 'alpha': 0.8, 'pad': 0.7, 'edgecolor': 'none'} 122 | color = np.random.rand(len(semantic_tags), 3) 123 | n_terms = 10 124 | delta_y = math.floor(image_h/25) 125 | for i, (box, label) in enumerate(zip(boxes, semantic_tags)): 126 | box = [round(i, 2) for i in box] 127 | ax.add_patch(plt.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], edgecolor=color[i], facecolor=(0,0,0,0), lw=linewidth)) 128 | label = label.strip() 129 | label_ = label.split() 130 | if len(label_) < n_terms: 131 | ax.text(box[0], box[1], label, fontsize=6, bbox=draw_label_setting, verticalalignment='top', color="gray") 132 | else: 133 | n_labels = (len(label_)-1)//n_terms+1 134 | for label_idx in range(n_labels): 135 | start, end = label_idx * n_terms, min((n_terms * (label_idx+1), len(label_))) 136 | this_label = ' '.join(label_[start:end]) 137 | this_y = box[1] + delta_y * label_idx 138 | ax.text(box[0], this_y, this_label, fontsize=6, bbox=draw_label_setting, verticalalignment='top', color="gray") 139 | 140 | plt.savefig(out_file, format='jpeg') 141 | plt.close() 142 | 143 | 144 | 145 | def main(config_path, chunk_index, chunk_num, node_index, node_num): 146 | with open(config_path) as f: 147 | config = yaml.load(f,Loader=yaml.FullLoader) 148 | 149 | lvlm = config['model_path'] 150 | model_path = f"{config['ckpt_path']}/{lvlm}" 151 | print(f"loading {model_path}") 152 | if "llava" in lvlm and '1.6' in lvlm: 153 | model_name = get_model_name_from_path(model_path) 154 | elif "llava" in lvlm: 155 | model_name = lvlm 156 | else: 157 | raise ValueError(f"lvlm {lvlm} not supported") 158 | tokenizer, llava_model, image_processor, context_len = load_pretrained_model(model_path, None, model_name) 159 | llava_model.eval().cuda() 160 | 161 | batch_size = config['batch_size'] 162 | source_data_files, 
target_data_files = get_data_files(config, node_index=node_index, node_num=node_num) 163 | 164 | for source_data_file in source_data_files: 165 | if f"{source_data_file.split('/')[-1].split('.')[0]}_processed" in target_data_files: 166 | print(f"file {source_data_file} processed, skipping") 167 | continue 168 | print(f"processing {source_data_file}") 169 | processed_data = [] 170 | df = pd.read_parquet(source_data_file) 171 | start, end = chunk_index * (len(df) // chunk_num), (chunk_index + 1) * (len(df) // chunk_num) - 1 172 | if len(df) - end < len(df) // chunk_num: 173 | end = len(df) - 1 174 | df = df.loc[start: end] 175 | 176 | for offset in tqdm.trange(0, len(df), batch_size): 177 | offset += start 178 | batch = df.loc[offset: offset + batch_size - 1].reset_index(drop=True) 179 | all_image, all_boxes, all_tags = process_batch(llava_model, batch, image_processor, tokenizer) 180 | for index, (image, boxes, tags) in enumerate(zip(all_image, all_boxes, all_tags)): 181 | sample = batch.loc[index].to_dict() 182 | local_caption = repr([(boxes[i], tags[i]) for i in range(len(boxes))]) 183 | sample['local_caption'] = local_caption 184 | processed_data.append(sample) 185 | 186 | processed_df = pd.DataFrame(processed_data).reset_index(drop=True) 187 | base_path = os.path.basename(source_data_file) 188 | output_path = f"detail_caption_construction/data/processed_data/{base_path.split('.')[0]}_chunk{chunk_index}.parquet" 189 | processed_df.to_parquet(output_path) 190 | 191 | 192 | if __name__ == '__main__': 193 | parser = argparse.ArgumentParser() 194 | parser.add_argument('--chunk_index', type=int, default=0) 195 | parser.add_argument('--chunk_num', type=int, default=4) 196 | parser.add_argument('--node_index', type=int, default=0) 197 | parser.add_argument('--node_num', type=int, default=10) 198 | parser.add_argument('--config_path', type=str) 199 | args = parser.parse_args() 200 | 201 | main( 202 | config_path=args.config_path, 203 | chunk_index=args.chunk_index, 204 | chunk_num=args.chunk_num, 205 | node_index=args.node_index, 206 | node_num=args.node_num, 207 | ) 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | # image = Image.open(io.BytesIO(image)).convert("RGB") 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | -------------------------------------------------------------------------------- /detail_caption_construction/generate_stage4_filter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (2024) CAPTURE project Authors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | 18 | import os 19 | import sys 20 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..')) 21 | 22 | import io 23 | import torch 24 | import pandas as pd 25 | import yaml 26 | import tqdm 27 | import argparse 28 | from PIL import Image 29 | from transformers import AutoProcessor, Owlv2ForObjectDetection 30 | from nltk.tokenize import sent_tokenize 31 | import re 32 | 33 | from capture.FactualSceneGraph.src.factual_scene_graph.parser.scene_graph_parser import SceneGraphParser 34 | from detail_caption_construction.utils.utils import get_data_files 35 | 36 | 37 | stop_words = [] 38 | with open('capture/diffed_objects.txt', 'r') as f: 39 | for i, line in enumerate(f.readlines()): 40 | if i > 513: 41 | break 42 | if not line.startswith('#'): 43 | stop_words.append(line.strip().split(':')[0]) 44 | stop_words = set(stop_words) 45 | 46 | 47 | def get_phrases(parse_res, caption): 48 | objects = [entity['head'] for entity in parse_res['entities']] 49 | attributes = [(adj, entity['head']) for entity in parse_res['entities'] for adj in entity['attributes']] 50 | 51 | all_phrases = [] 52 | 53 | # first run: find {adj} {noun} phrases 54 | for attr_index, attribute in enumerate(attributes): 55 | phrase = f"{attribute[0]} {attribute[1]}" 56 | # all_phrases.append(phrase) 57 | if attribute[1] in objects: 58 | objects.remove(attribute[1]) 59 | 60 | # second run: find remaining noun phrases 61 | for obj_index, object in enumerate(objects): 62 | phrase = f"{object}" 63 | all_phrases.append(phrase) 64 | 65 | all_phrases = [phrase for phrase in all_phrases if phrase in caption] 66 | 67 | return all_phrases 68 | 69 | 70 | def get_elements(parse_res, caption): 71 | objects = [(entity['head'], entity['head']) for entity in parse_res['entities']] 72 | attributes = [(adj, f"{adj} {entity['head']}") for entity in parse_res['entities'] for adj in entity['attributes']] 73 | 74 | all_elements = [element for element in objects + attributes if element[0] in caption] 75 | all_phrases = [element[1] for element in all_elements] 76 | ori_words = [element[0] for element in all_elements] 77 | 78 | return all_phrases, ori_words 79 | 80 | 81 | def ground_attributes_to_sentence(parse_res, sentence): 82 | objects = [entity['head'] for entity in parse_res['entities']] 83 | attributes = [(adj, entity['head']) for entity in parse_res['entities'] for adj in entity['attributes']] 84 | match_info = {} 85 | attribute_shot_indices = set() 86 | object_shot_indices = set() 87 | temp_sentence = sentence 88 | 89 | # first run: find {adj} {noun} phrases 90 | for attr_index, attribute in enumerate(attributes): 91 | phrase = f"{attribute[0]} {attribute[1]}" 92 | start = sentence.find(phrase) 93 | if start != -1: 94 | match_info[phrase] = [start, start + len(phrase)] 95 | attribute_shot_indices.add(attr_index) 96 | if attribute[1] in objects: 97 | objects.remove(attribute[1]) 98 | temp_sentence = temp_sentence.replace(phrase, '#'*len(phrase)) 99 | 100 | # second run: find {noun} {adj} phrases 101 | for attr_index, attribute in enumerate(attributes): 102 | phrase = f"{attribute[1]} {attribute[0]}" 103 | start = sentence.find(phrase) 104 | if start != -1: 105 | match_info[phrase] = [start, start + len(phrase)] 106 | attribute_shot_indices.add(attr_index) 107 | if attribute[1] in objects: 108 | objects.remove(attribute[1]) 109 | temp_sentence = temp_sentence.replace(phrase, '#'*len(phrase)) 110 | 111 | # third run: find remaining noun phrases 112 | for obj_index, object in enumerate(objects): 113 | phrase = 
f"{object}" 114 | start = sentence.find(phrase) 115 | if start != -1: 116 | match_info[phrase] = [start, start + len(phrase)] 117 | object_shot_indices.add(obj_index) 118 | temp_sentence = temp_sentence.replace(phrase, '#'*len(phrase)) 119 | 120 | return match_info 121 | 122 | 123 | def process_batch(model, processor, parser, batch, threshold, nms_threshold): 124 | images = [Image.open(io.BytesIO(batch.loc[i, 'frame'])).convert("RGB") for i in range(len(batch))] 125 | 126 | all_local_caption = [eval(sample['local_caption']) for _, sample in batch.iterrows()] 127 | filtered_all_local_caption = [] 128 | for sample_idx, local_caption in enumerate(all_local_caption): 129 | graph_obj = parser.parse([caption for bbox, caption in local_caption], beam_size=5, return_text=False, max_output_len=128) 130 | all_phrases, all_ori_words = [], [] 131 | for bbox_idx, res in enumerate(graph_obj): 132 | phrases = get_phrases(res, local_caption[bbox_idx][1]) 133 | all_phrases.append(phrases) 134 | 135 | for phrases_idx, phrases in enumerate(all_phrases): 136 | phrases.append(local_caption[phrases_idx][1]) 137 | all_phrases[phrases_idx] = [' '.join(phrase.split(' ')) if ' ' in phrase else phrase for phrase in phrases ] 138 | 139 | sample_images = [images[sample_idx] for _ in range(len(all_phrases))] 140 | results = [] 141 | for mini_batch_start in range(0, len(all_phrases), 8): 142 | mini_batch_all_phrases = all_phrases[mini_batch_start: mini_batch_start+8] 143 | mini_batch_sample_images = sample_images[mini_batch_start: mini_batch_start+8] 144 | 145 | if len(mini_batch_all_phrases) == 0 or sum([len(phrases) for phrases in mini_batch_all_phrases]) == 0: 146 | mini_batch_results = [{ 147 | 'scores': torch.tensor([]), 148 | 'labels': torch.tensor([]), 149 | 'boxes': torch.tensor([]), 150 | } for _ in range(len(mini_batch_all_phrases))] 151 | results.extend(mini_batch_results) 152 | continue 153 | 154 | input_tensor = processor(text=mini_batch_all_phrases, images=mini_batch_sample_images, truncation=True, return_tensors="pt") 155 | input_tensor = input_tensor.to("cuda") 156 | 157 | with torch.no_grad(): 158 | outputs = model(**input_tensor) 159 | 160 | padded_image_size = [max(images[sample_idx].size), max(images[sample_idx].size)] 161 | target_sizes = [padded_image_size for _ in range(len(mini_batch_all_phrases))] 162 | 163 | if nms_threshold < 1.0: 164 | mini_batch_results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, nms_threshold=nms_threshold, threshold=threshold) 165 | else: 166 | mini_batch_results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=threshold) 167 | 168 | results.extend(mini_batch_results) 169 | 170 | filtered_local_caption = [] 171 | for bbox_idx, (phrases, result) in enumerate(zip(all_phrases, results)): 172 | label_scores = [0 for _ in range(len(phrases))] 173 | try: 174 | for score_idx in range(result['scores'].shape[0]): 175 | label_idx = result['labels'][score_idx] 176 | if label_idx < len(label_scores): 177 | label_scores[label_idx] = max(label_scores[label_idx], result['scores'][score_idx]) 178 | except: 179 | from IPython import embed; embed() 180 | 181 | this_local_caption = list(all_local_caption[sample_idx][bbox_idx]) 182 | 183 | for phrase_idx, phrase in enumerate(phrases[:-1]): 184 | if label_scores[phrase_idx] < threshold: 185 | this_local_caption[1] = re.sub(pattern=rf"([a-z]*){phrase}([a-z]*)", repl='', string=this_local_caption[1]) 186 | filtered_local_caption.append(this_local_caption) 
187 | 188 | final_filtered_local_caption = filtered_local_caption 189 | filtered_all_local_caption.append(final_filtered_local_caption) 190 | 191 | all_overall_caption = [sample['overall_caption'] for _, sample in batch.iterrows()] 192 | filtered_all_overall_caption = [] 193 | for sample_idx, caption in enumerate(all_overall_caption): 194 | sentences = sent_tokenize(caption) 195 | graph_obj = parser.parse(sentences, beam_size=5, return_text=False, max_output_len=128) 196 | all_phrases, all_ori_words = [], [] 197 | for sent, res in zip(sentences, graph_obj): 198 | phrases_se = ground_attributes_to_sentence(res, sent) 199 | all_phrases.append(list(phrases_se.keys())) 200 | 201 | sample_images = [images[sample_idx] for _ in range(len(all_phrases))] 202 | results = [] 203 | for mini_batch_start in range(0, len(all_phrases), 8): 204 | mini_batch_all_phrases = all_phrases[mini_batch_start: mini_batch_start+8] 205 | mini_batch_sample_images = sample_images[mini_batch_start: mini_batch_start+8] 206 | 207 | if len(mini_batch_all_phrases) == 0 or sum([len(phrases) for phrases in mini_batch_all_phrases]) == 0: 208 | mini_batch_results = [{ 209 | 'scores': torch.tensor([]), 210 | 'labels': torch.tensor([]), 211 | 'boxes': torch.tensor([]), 212 | } for _ in range(len(mini_batch_all_phrases))] 213 | results.extend(mini_batch_results) 214 | continue 215 | 216 | input_tensor = processor(text=mini_batch_all_phrases, images=mini_batch_sample_images, truncation=True, return_tensors="pt") 217 | input_tensor = input_tensor.to("cuda") 218 | 219 | with torch.no_grad(): 220 | outputs = model(**input_tensor) 221 | 222 | padded_image_size = [max(images[sample_idx].size), max(images[sample_idx].size)] 223 | target_sizes = [padded_image_size for _ in range(len(mini_batch_all_phrases))] 224 | 225 | if nms_threshold < 1.0: 226 | mini_batch_results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, nms_threshold=nms_threshold, threshold=threshold) 227 | else: 228 | mini_batch_results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_sizes, threshold=threshold) 229 | 230 | results.extend(mini_batch_results) 231 | 232 | filtered_caption = [] 233 | for sent_idx, (sent, phrases, result) in enumerate(zip(sentences, all_phrases, results)): 234 | label_scores = [0 for _ in range(len(phrases))] 235 | try: 236 | for score_idx in range(result['scores'].shape[0]): 237 | label_idx = result['labels'][score_idx] 238 | if label_idx < len(label_scores): 239 | label_scores[label_idx] = max(label_scores[label_idx], result['scores'][score_idx]) 240 | except: 241 | from IPython import embed; embed() 242 | 243 | for phrase, score in zip(phrases, label_scores): 244 | if score < threshold: 245 | sent = sent.replace(phrase, "") 246 | 247 | filtered_caption.append(sent) 248 | 249 | filtered_caption = ' '.join(filtered_caption) 250 | filtered_all_overall_caption.append(filtered_caption) 251 | 252 | batch_processed_data = [] 253 | for sample_idx in range(len(batch)): 254 | processed_sample = batch.loc[sample_idx].to_dict() 255 | processed_sample['filtered_local_caption'] = repr(filtered_all_local_caption[sample_idx]) 256 | processed_sample['filtered_overall_caption'] = filtered_all_overall_caption[sample_idx] 257 | batch_processed_data.append(processed_sample) 258 | 259 | return batch_processed_data 260 | 261 | 262 | def prepare_model_for_inference(model, dtype): 263 | model.cuda() 264 | model.eval() 265 | if dtype is not None: 266 | model.to(dtype) 267 | 268 | 269 | def 
main(config_path, chunk_index, chunk_num, node_index, node_num): 270 | with open(config_path) as f: 271 | config = yaml.load(f,Loader=yaml.FullLoader) 272 | 273 | processor = AutoProcessor.from_pretrained(f"{config['ckpt_path']}/{config['model_path']}") 274 | model = Owlv2ForObjectDetection.from_pretrained(f"{config['ckpt_path']}/{config['model_path']}").to("cuda") 275 | parser = SceneGraphParser('capture/ckpt/flan-t5-base-VG-factual-sg', device='cuda') 276 | 277 | source_data_files, target_data_files = get_data_files(config, node_index=node_index, node_num=node_num) 278 | 279 | for source_data_file in source_data_files: 280 | if f"{source_data_file.split('/')[-1].split('.')[0]}_processed" in target_data_files: 281 | print(f"file {source_data_file} processed, skipping") 282 | continue 283 | print(f"processing {source_data_file}") 284 | batch_size = config['batch_size'] 285 | processed_data = [] 286 | df = pd.read_parquet(source_data_file) 287 | 288 | start, end = chunk_index * (len(df) // chunk_num), (chunk_index + 1) * (len(df) // chunk_num) - 1 289 | if len(df) - end < len(df) // chunk_num: 290 | end = len(df) - 1 291 | df = df.loc[start: end] 292 | for offset in tqdm.trange(0, len(df), batch_size): 293 | offset += start 294 | batch = df.loc[offset: offset + batch_size - 1].reset_index(drop=True) 295 | 296 | with torch.inference_mode(): 297 | batch_processed_data = process_batch(model, processor, parser, batch, config['threshold'], config['nms_threshold']) 298 | processed_data.extend(batch_processed_data) 299 | 300 | processed_df = pd.DataFrame(processed_data).reset_index(drop=True) 301 | base_path = os.path.basename(source_data_file) 302 | output_path = f"detail_caption_construction/data/processed_data/{base_path.split('.')[0]}_chunk{chunk_index}.parquet" 303 | processed_df.to_parquet(output_path) 304 | 305 | 306 | if __name__ == '__main__': 307 | parser = argparse.ArgumentParser() 308 | parser.add_argument('--config_path', type=str) 309 | parser.add_argument('--chunk_index', type=int, default=0) 310 | parser.add_argument('--chunk_num', type=int, default=4) 311 | parser.add_argument('--node_index', type=int, default=0) 312 | parser.add_argument('--node_num', type=int, default=10) 313 | args = parser.parse_args() 314 | 315 | main( 316 | config_path=args.config_path, 317 | chunk_index=args.chunk_index, 318 | chunk_num=args.chunk_num, 319 | node_index=args.node_index, 320 | node_num=args.node_num 321 | ) 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | # image = Image.open(io.BytesIO(image)).convert("RGB") 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | -------------------------------------------------------------------------------- /detail_caption_construction/generate_stage5_caption_merge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (2024) CAPTURE project Authors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | 18 | import os 19 | import sys 20 | sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', '..')) 21 | 22 | import torch 23 | import pandas as pd 24 | import yaml 25 | import tqdm 26 | import argparse 27 | 28 | from detail_caption_construction.utils.utils import get_data_files 29 | 30 | 31 | def caption_merge(llm, left_padding_tokenizer, batch): 32 | def get_prompt(overall_caption, local_captions, mode="augment", icl=False): 33 | local_captions = [(h*w, caption) for [x1, y1, h, w], caption in local_captions] 34 | local_captions = [item[1].strip('.').strip(',') for item in local_captions] 35 | prompt = "" 36 | prompt = prompt + "Overall description: " + overall_caption + "\n\n" 37 | prompt = prompt + " Elements appearing in the image: " + '. '.join(local_captions) + '. \n\n' 38 | prompt = prompt + "Please combine them into a description, refining the overall description with detailed annotations. " + \ 39 | "Avoid simply concatenating the sentences. Ignore elements that do not fit in the scene. Reply with no more than three hundred words.\n\n" + \ 40 | "Refined description: " 41 | return prompt 42 | 43 | batch_processed_data = [sample.to_dict() for i, sample in batch.iterrows()] 44 | if 'filtered_overall_caption' in batch_processed_data[0].keys() and 'filtered_local_caption' in batch_processed_data[0].keys(): 45 | overall_caption_key, local_caption_key = 'filtered_overall_caption', 'filtered_local_caption' 46 | else: 47 | overall_caption_key, local_caption_key = 'overall_caption', 'local_caption' 48 | 49 | if left_padding_tokenizer.pad_token == left_padding_tokenizer.unk_token: 50 | overall_captions = [f"{get_prompt(sample[overall_caption_key], eval(sample[local_caption_key]))} " for sample in batch_processed_data] 51 | tokenized_input = left_padding_tokenizer(overall_captions, add_special_tokens=False, padding='longest', return_tensors='pt') 52 | input_ids, attention_mask = tokenized_input['input_ids'].to(llm.device), tokenized_input['attention_mask'].to(llm.device) 53 | 54 | with torch.inference_mode(): 55 | res = left_padding_tokenizer.batch_decode(llm.generate( 56 | input_ids=input_ids, 57 | attention_mask=attention_mask, 58 | eos_token_id=left_padding_tokenizer.eos_token_id, 59 | pad_token_id=left_padding_tokenizer.pad_token_id, 60 | use_cache=True, 61 | max_new_tokens=500, 62 | do_sample=False, 63 | temperature=0.2, 64 | num_beams=3, 65 | top_p=0.95, 66 | num_return_sequences=1), skip_special_tokens=True) 67 | 68 | res = [this_res.replace(overall_captions[i].replace('', ''), '').strip('\n ') for i, this_res in enumerate(res)] 69 | 70 | for i, sample in enumerate(batch_processed_data): 71 | sample['synthesized_caption'] = res[i] 72 | 73 | return batch_processed_data 74 | 75 | 76 | def main(config_path, chunk_index, chunk_num, node_index, node_num): 77 | with open(config_path) as f: 78 | config = yaml.load(f,Loader=yaml.FullLoader) 79 | 80 | from transformers import AutoTokenizer, AutoModelForCausalLM 81 | path = f"{config['ckpt_path']}/{config['model_path']}" 82 | print(f"loading {path}") 83 | left_padding_tokenizer = AutoTokenizer.from_pretrained(path, padding_side='left') 84 | left_padding_tokenizer.pad_token = left_padding_tokenizer.unk_token 85 | llm = AutoModelForCausalLM.from_pretrained(path, device_map='auto') 86 | llm.eval().cuda() 87 | 88 | source_data_files, target_data_files = get_data_files(config, node_index=node_index, node_num=node_num) 89 | batch_size = config['batch_size'] 90 | 91 | for source_data_file in source_data_files: 92 | if 
f"{source_data_file.split('/')[-1].split('.')[0]}_processed" in target_data_files: 93 | print(f"file {source_data_file} processed, skipping") 94 | continue 95 | print(f"processing {source_data_file}") 96 | processed_data = [] 97 | df = pd.read_parquet(source_data_file) 98 | 99 | start, end = chunk_index * (len(df) // chunk_num), (chunk_index + 1) * (len(df) // chunk_num) - 1 100 | if len(df) - end < len(df) // chunk_num: 101 | end = len(df) - 1 102 | df = df.loc[start: end] 103 | for offset in tqdm.trange(0, len(df), batch_size): 104 | offset += start 105 | batch = df.loc[offset: offset + batch_size - 1].reset_index(drop=True) 106 | with torch.no_grad(): 107 | batch_processed_data = caption_merge(llm, left_padding_tokenizer, batch) 108 | processed_data.extend(batch_processed_data) 109 | 110 | processed_df = pd.DataFrame(processed_data).reset_index(drop=True) 111 | base_path = os.path.basename(source_data_file) 112 | output_path = f"detail_caption_construction/data/processed_data/{base_path.split('.')[0]}_chunk{chunk_index}.parquet" 113 | processed_df.to_parquet(output_path) 114 | 115 | 116 | if __name__ == '__main__': 117 | parser = argparse.ArgumentParser() 118 | parser.add_argument('--chunk_index', type=int) 119 | parser.add_argument('--chunk_num', type=int, default=4) 120 | parser.add_argument('--node_index', type=int, default=0) 121 | parser.add_argument('--node_num', type=int, default=10) 122 | parser.add_argument('--config_path', type=str) 123 | args = parser.parse_args() 124 | 125 | main( 126 | config_path=args.config_path, 127 | chunk_index=args.chunk_index, 128 | chunk_num=args.chunk_num, 129 | node_index=args.node_index, 130 | node_num=args.node_num, 131 | ) 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /detail_caption_construction/merge_results.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (2024) CAPTURE project Authors 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | import os 18 | import pandas as pd 19 | import collections 20 | import torch 21 | import time 22 | import argparse 23 | import yaml 24 | from cruise.data_module.utils import parse_data_source 25 | 26 | 27 | def watch_and_upload(config_path, node_index, node_num): 28 | with open(config_path) as f: 29 | config = yaml.load(f,Loader=yaml.FullLoader) 30 | 31 | source_data_files = os.listdir(config['source_path']) 32 | source_data_files = [f"{config['source_path']}/{path}" for path in source_data_files] 33 | source_data_files.sort() 34 | start, end = node_index * (len(source_data_files) // node_num), (node_index + 1) * (len(source_data_files) // node_num) 35 | if len(source_data_files) - end < len(source_data_files) // node_num: 36 | end = len(source_data_files) 37 | source_data_files = source_data_files[start: end] 38 | 39 | target_data_files = os.listdir(f"{config['target_path']}/") 40 | target_data_files = [f"{config['target_path']}/{path}" for path in target_data_files] 41 | target_data_files.sort() 42 | target_data_files = [file.split('/')[-1].split('.')[0] for file in target_data_files] 43 | 44 | source_data_files = [source_data_file for source_data_file in source_data_files if f"{source_data_file.split('/')[-1].split('.')[0]}_processed" not in target_data_files] 45 | print(f"source_data_files: {source_data_files}") 46 | print(f"target_data_files: {target_data_files}") 47 | source_data_file_prefixes = [source_data_file.split('/')[-1].replace('.parquet', '').replace('.snappy', '') for source_data_file in source_data_files] 48 | source_data_file_prefixes = set(source_data_file_prefixes) 49 | remain_file_num = len(source_data_file_prefixes) 50 | print(f'remain_file_num: {remain_file_num}') 51 | 52 | while remain_file_num > 0: 53 | print(time.time()) 54 | print(f"remain_file_num: {remain_file_num}") 55 | files = ["detail_caption_construction/data/processed_data/" + file for file in os.listdir("detail_caption_construction/data/processed_data")] 56 | file_prefix_mapping = collections.defaultdict(list) 57 | 58 | for file in files: 59 | prefix = file.split("/")[-1].split("_chunk")[0] 60 | if prefix in source_data_file_prefixes: 61 | file_prefix_mapping[prefix].append(file) 62 | 63 | for prefix, file_group in file_prefix_mapping.items(): 64 | if len(file_group) == torch.cuda.device_count(): 65 | print(f"start processing {prefix} results") 66 | file_group.sort() 67 | print("relating chunk files: ") 68 | for file in file_group: 69 | print(file) 70 | all_df = [pd.read_parquet(file) for file in file_group] 71 | 72 | # discard useless columns 73 | df = pd.concat(all_df).reset_index() 74 | if 'sort_index' in df.keys(): 75 | df = df.drop('sort_index', axis=1) 76 | if 'index' in df.keys(): 77 | df = df.drop('index', axis=1) 78 | if 'level_0' in df.keys(): 79 | df = df.drop('level_0', axis=1) 80 | 81 | output_file = f"{config['target_path']}/{prefix.split('/')[-1]}_processed.parquet" 82 | df.to_parquet(output_file) 83 | 84 | os.system(f"rm detail_caption_construction/data/processed_data/{prefix}*") 85 | remain_file_num -= 1 86 | print(f"finished processing {prefix} results") 87 | print("========================================") 88 | 89 | if remain_file_num > 0: 90 | time.sleep(10) 91 | 92 | # for path in source_data_files: 93 | # os.system(f"rm {path}") 94 | # print(f"prev processed data {path} removed") 95 | 96 | 97 | if __name__ == '__main__': 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument('--config', type=str) 100 | parser.add_argument('--node_index', type=int, 
default=0) 101 | parser.add_argument('--node_num', type=int, default=10) 102 | args = parser.parse_args() 103 | watch_and_upload(args.config, args.node_index, args.node_num) 104 | -------------------------------------------------------------------------------- /detail_caption_construction/utils/bbox_cluster.py: -------------------------------------------------------------------------------- 1 | from sklearn.cluster import KMeans 2 | import numpy as np 3 | import pandas as pd 4 | import cv2 5 | import os 6 | from matplotlib import pyplot as plt 7 | from tqdm import tqdm 8 | import json 9 | import argparse 10 | import time 11 | import yaml 12 | from detail_caption_construction.utils.bbox_statistics import compute_metrics 13 | 14 | def convert_bbox(bboxes, mode): 15 | bbox = [] 16 | for b in bboxes: 17 | if mode == "xywh": 18 | x, y, w, h = b 19 | x1, y1, x2, y2 = x, y, x+w, y+h 20 | elif mode == "xyxy": 21 | x1, y1, x2, y2 = b 22 | bbox.append([x1, y1, x2, y2]) 23 | return bbox 24 | 25 | def convert_bbox_area(bboxes, mode): 26 | area = [] 27 | for b in bboxes: 28 | if mode == "xywh": 29 | x, y, w, h = b 30 | elif mode == "xyxy": 31 | x1, y1, x2, y2 = b 32 | w, h = x2-x1, y2-y1 33 | area.append(w*h) 34 | return area 35 | 36 | def show_box(box, ax, lw): 37 | color = np.random.rand(3) 38 | if isinstance(box[0], (int, float)): 39 | box = [box] 40 | for b in box: 41 | x0, y0 = b[0], b[1] 42 | w, h = b[2] - b[0], b[3] - b[1] 43 | ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor=color, facecolor=(0,0,0,0), lw=lw)) 44 | 45 | def draw_bbox(sample, img_key, bboxes, save_path, linewidth=4): 46 | image_bytes = sample[img_key] 47 | image = np.frombuffer(image_bytes, dtype=np.uint8) 48 | image = cv2.imdecode(image, cv2.IMREAD_COLOR) 49 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 50 | image_h, image_w = image.shape[:2] 51 | # print(image_h, image_w) 52 | 53 | fig, ax = plt.subplots(figsize=(image_w/100, image_h/100), dpi=100) 54 | 55 | ax.axis('off') 56 | fig.subplots_adjust(left=0, right=1, top=1, bottom=0) 57 | ax.imshow(image) 58 | 59 | for i in range(len(bboxes)): 60 | show_box(bboxes[i], ax, linewidth) 61 | 62 | plt.savefig(os.path.join(save_path, f'{int(sample["item_id"])}.jpg'), format='jpeg') 63 | plt.close() 64 | 65 | 66 | def compute_iou(bboxes1, bboxes2, type="iou"): 67 | 68 | bbox1_array = np.array(bboxes1) 69 | bbox2_array = np.array(bboxes2) 70 | 71 | bbox1_area = (bbox1_array[:, 2] - bbox1_array[:, 0]) * (bbox1_array[:, 3] - bbox1_array[:, 1]) 72 | bbox2_area = (bbox2_array[:, 2] - bbox2_array[:, 0]) * (bbox2_array[:, 3] - bbox2_array[:, 1]) 73 | 74 | intersection_tl = np.maximum(bbox1_array[:, None, :2], bbox2_array[:, :2]) 75 | intersection_br = np.minimum(bbox1_array[:, None, 2:], bbox2_array[:, 2:]) 76 | intersection_wh = np.maximum(0, intersection_br - intersection_tl) 77 | intersection_area = intersection_wh[:, :, 0] * intersection_wh[:, :, 1] 78 | union_area = bbox1_area[:, None] + bbox2_area - intersection_area 79 | iou = intersection_area / union_area 80 | if type == "iou": 81 | return iou 82 | 83 | cir_tl = np.minimum(bbox1_array[:, None, :2], bbox2_array[:, :2]) 84 | cir_br = np.maximum(bbox1_array[:, None, 2:], bbox2_array[:, 2:]) 85 | cir_wh = np.maximum(0, cir_br - cir_tl) 86 | cir_area = cir_wh[:, :, 0] * cir_wh[:, :, 1] 87 | 88 | giou = (cir_area - union_area) / cir_area 89 | iou = iou - giou 90 | if type == "giou": 91 | return iou, giou 92 | 93 | def filter_bbox(bboxes, areas, hw): 94 | res_bbox = [] 95 | res_area = [] 96 | if len(bboxes) == 0: 97 | return res_bbox, 
res_area 98 | h, w = hw 99 | for bbox, area in zip(bboxes, areas): 100 | bbox_h, bbox_w = bbox[3]-bbox[1], bbox[2]-bbox[0] 101 | if area/(h*w) < 0.001: 102 | continue 103 | if area/(h*w) > 0.9: 104 | continue 105 | # if bbox_h/h < 0.01 or bbox_h/h > 0.9: 106 | # continue 107 | # if bbox_w/w < 0.01 or bbox_w/w > 0.9: 108 | # continue 109 | res_bbox.append(bbox) 110 | res_area.append(area) 111 | return res_bbox, res_area 112 | 113 | def merge_bbox(bboxes, merge_threshold): 114 | # from IPython import embed; embed() 115 | sorted_indices = np.argsort((bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1])) 116 | bboxes = np.array(bboxes[sorted_indices]) 117 | 118 | merged_bboxes = [] 119 | while len(bboxes) > 0: 120 | bbox = bboxes[0] 121 | merge_indices = [0] 122 | 123 | for i in range(1, len(bboxes)): 124 | iou, giou = compute_iou([bbox], [bboxes[i]], "giou") 125 | if iou > merge_threshold or giou < 0.1: 126 | merge_indices.append(i) 127 | 128 | merged_bbox = [int(min(bboxes[merge_indices][:, 0])), 129 | int(min(bboxes[merge_indices][:, 1])), 130 | int(max(bboxes[merge_indices][:, 2])), 131 | int(max(bboxes[merge_indices][:, 3]))] 132 | merged_bboxes.append(merged_bbox) 133 | 134 | bboxes = np.array([bboxes[i] for i in range(len(bboxes)) if i not in merge_indices]) 135 | 136 | return merged_bboxes 137 | 138 | def init_cluster_centers(k, points): 139 | points = np.array(points) 140 | centers = [points[np.random.choice(len(points))]] 141 | 142 | while len(centers) < k: 143 | distances = np.linalg.norm(points[:, np.newaxis, :] - centers, axis=2) 144 | min_distances = np.min(distances, axis=1) 145 | probabilities = min_distances / min_distances.sum() 146 | next_centers_index = np.random.choice(len(points), p=probabilities) 147 | centers.append(points[next_centers_index]) 148 | 149 | return centers 150 | 151 | def iou_distance(bbox1, bbox2): 152 | x1, y1, x2, y2 = bbox1 153 | x3, y3, x4, y4 = bbox2 154 | intersection_area = max(0, min(x2, x4) - max(x1, x3)) * max(0, min(y2, y4) - max(y1, y3)) 155 | union_area = (x2 - x1) * (y2 - y1) + (x4 - x3) * (y4 - y3) - intersection_area 156 | iou = intersection_area / union_area 157 | 158 | cir_area = max(0, max(x2, x4) - min(x1, x3)) * max(0, max(y2, y4) - min(y1, y3)) 159 | 160 | giou = (cir_area - union_area) / cir_area 161 | iou = iou - giou 162 | return 1 - iou 163 | 164 | 165 | def point_is_in_boxes(point, bboxes): 166 | times = 0 167 | for i, bbox in enumerate(bboxes): 168 | if bbox[0] <= point[0] < bbox[2] and bbox[1] <= point[1] < bbox[3]: 169 | return True 170 | return False 171 | 172 | 173 | def get_area(bbox): 174 | return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) 175 | 176 | 177 | def maximalRectangle(matrix): 178 | continuous = [[0 for _ in range(len(matrix[0]))] for __ in range(len(matrix))] 179 | for i in range(len(matrix)): 180 | cur = 0 181 | for j in range(len(matrix[0])): 182 | if matrix[i][j] == 1: 183 | cur += 1 184 | else: 185 | cur = 0 186 | continuous[i][j] = cur 187 | 188 | def maxRec(nums): 189 | stack = [] 190 | right_bound = [0 for _ in range(len(nums))] 191 | for i in range(len(nums)): 192 | if len(stack) == 0 or nums[i] >= nums[stack[-1]]: 193 | stack.append(i) 194 | else: 195 | while len(stack) != 0 and nums[i] < nums[stack[-1]]: 196 | top = stack.pop() 197 | right_bound[top] = i - 1 198 | stack.append(i) 199 | 200 | while len(stack) != 0: 201 | top = stack.pop() 202 | right_bound[top] = len(nums) - 1 203 | 204 | left_bound = [0 for _ in range(len(nums))] 205 | for i in range(len(nums) - 1, -1, -1): 206 | if len(stack) 
== 0 or nums[i] >= nums[stack[-1]]: 207 | stack.append(i) 208 | else: 209 | while len(stack) != 0 and nums[i] < nums[stack[-1]]: 210 | top = stack.pop() 211 | left_bound[top] = i + 1 212 | stack.append(i) 213 | 214 | while len(stack) != 0: 215 | top = stack.pop() 216 | left_bound[top] = 0 217 | 218 | best, best_left_bound, best_right_bound = 0, 0, 0 219 | # print(left_bound) 220 | # print(right_bound) 221 | for i in range(len(left_bound)): 222 | if nums[i] * (right_bound[i] - left_bound[i] + 1) > best: 223 | best = nums[i] * (right_bound[i] - left_bound[i] + 1) 224 | best_left_bound, best_right_bound = left_bound[i], right_bound[i] 225 | return best, best_left_bound, best_right_bound 226 | 227 | best = 0 228 | bbox = [0, 0, 0, 0] 229 | for j in range(len(matrix[0])): 230 | nums = [continuous[i][j] for i in range(len(matrix))] 231 | # from IPython import embed; embed() 232 | this_best, this_best_left_bound, this_best_right_bound = maxRec(nums) 233 | if best < this_best: 234 | width = this_best // (this_best_right_bound - this_best_left_bound + 1) 235 | best, bbox[0], bbox[1], bbox[2], bbox[3] = this_best, j - width + 1, this_best_left_bound, j, this_best_right_bound 236 | 237 | return best, bbox 238 | 239 | 240 | def compress_pixels(bbox_map, scale=2): 241 | compressed_bbox_map = np.zeros([bbox_map.shape[0] // scale, bbox_map.shape[1] // scale], dtype=int) 242 | for x in range(compressed_bbox_map.shape[1]): 243 | for y in range(compressed_bbox_map.shape[0]): 244 | if bbox_map[y * scale: y * scale + scale, x * scale: x * scale + scale].sum() >= scale ** 2 / 2: 245 | compressed_bbox_map[y, x] = 1 246 | return compressed_bbox_map 247 | 248 | 249 | def recover_compressed_bbox(bbox, scale=2): 250 | for i in range(len(bbox)): 251 | bbox[i] = bbox[i] * scale 252 | 253 | return bbox 254 | 255 | 256 | def cluster(df, name, config): 257 | new_path = {} 258 | if config['Draw']['draw_bbox']: 259 | for key, val in config['Draw.path'].__dict__.items(): 260 | path = val.split("/") 261 | new_path[key] = "/".join(path[:-1] + [name] + [path[-1]]) 262 | os.makedirs(new_path[key], exist_ok=True) 263 | 264 | df["bbox"] = df["bboxes"].apply(convert_bbox, mode="xyxy") 265 | df["area"] = df["bboxes"].apply(convert_bbox_area, mode="xyxy") 266 | 267 | cluster_centers, merged_cluster_centers, cropped_boxes = [], [], [] 268 | k_means_times, merge_times, crop_times = [], [], [] 269 | for i in tqdm(range(len(df))): 270 | start_time_stamp = time.time() 271 | bbox_list, area_list = filter_bbox(df["bbox"][i], df["area"][i], df["hw"][i]) 272 | if len(bbox_list) == 0: 273 | cluster_centers.append([]) 274 | merged_cluster_centers.append([]) 275 | cropped_boxes.append([]) 276 | continue 277 | points = [bbox+[(bbox[0]+bbox[2])/2, (bbox[1]+bbox[3])/2] for bbox in bbox_list] 278 | # points = [[(bbox[0]+bbox[2])/2, (bbox[1]+bbox[3])/2] for bbox in bbox_list] 279 | 280 | sorted_bbox_indices = np.argsort(area_list)[::-1] 281 | initial_centers = [points[j] for j in sorted_bbox_indices[:min(config['kmeans_center_num'], len(sorted_bbox_indices))]] 282 | km = KMeans(n_clusters=len(initial_centers), n_init=10, init="k-means++", random_state=0) 283 | 284 | km.fit(points) 285 | k_means_time_stamp = time.time() 286 | k_means_times.append(time.time() - start_time_stamp) 287 | 288 | labels = km.labels_ 289 | cluster_center = km.cluster_centers_ 290 | clustered_bboxes = [] 291 | new_cluster_center = [] 292 | for j in range(len(initial_centers)): 293 | cur_bbox = np.array(bbox_list)[labels==j] 294 | if len(cur_bbox) == 0: 295 | continue 296 | 
clustered_bboxes.append(cur_bbox.tolist()) 297 | new_cluster_center.append([int(min(cur_bbox[:, 0])), int(min(cur_bbox[:, 1])), int(max(cur_bbox[:, 2])), int(max(cur_bbox[:, 3]))]) 298 | 299 | cluster_centers.append(new_cluster_center) 300 | bboxes_for_each_center = np.array(new_cluster_center) 301 | 302 | merged_center = merge_bbox(bboxes_for_each_center, config['merge_threshold']) 303 | merged_cluster_centers.append(merged_center) 304 | 305 | merge_time_stamp = time.time() 306 | merge_times.append(merge_time_stamp - k_means_time_stamp) 307 | 308 | area_merged_center = [[get_area(bbox), bbox] for bbox in merged_center] 309 | area_merged_center.sort(key=lambda x: x[0], reverse=True) 310 | 311 | bboxes_to_be_cropped = [item[1] for item in area_merged_center[:config['bbox_to_be_cropped_num']]] 312 | remain_bboxes = [item[1] for item in area_merged_center[config['bbox_to_be_cropped_num']:]] 313 | for kk, bbox in enumerate(bboxes_to_be_cropped): 314 | bbox_map = np.zeros([int(bbox[3] - bbox[1]), int(bbox[2] - bbox[0])], dtype=int) 315 | for x in range(int(bbox[2]) - int(bbox[0])): 316 | for y in range(int(bbox[3]) - int(bbox[1])): 317 | point = [bbox[0] + x, bbox[1] + y] 318 | if not point_is_in_boxes(point, remain_bboxes): 319 | bbox_map[y][x] = 1 320 | compressed_bbox_map = compress_pixels(bbox_map, scale=config['compress_scale']) 321 | for _ in range(config['expected_cropped_bbox_num_per_bbox']): 322 | if compressed_bbox_map.shape[0] == 0 or compressed_bbox_map.shape[1] == 0: 323 | break 324 | if compressed_bbox_map.sum() / (compressed_bbox_map.shape[0] * compressed_bbox_map.shape[1]) < 0.3: 325 | break 326 | __, best_bbox = maximalRectangle(compressed_bbox_map) 327 | compressed_bbox_map[best_bbox[1]: best_bbox[3], best_bbox[0]: best_bbox[2]] = 0 328 | if get_area(best_bbox) / (compressed_bbox_map.shape[0] * compressed_bbox_map.shape[1]) < 0.3: 329 | break 330 | best_bbox = recover_compressed_bbox(best_bbox, scale=config['compress_scale']) 331 | best_bbox[0], best_bbox[1], best_bbox[2], best_bbox[3] = \ 332 | best_bbox[0] + bbox[0], best_bbox[1] + bbox[1], best_bbox[2] + bbox[0], best_bbox[3] + bbox[1] 333 | remain_bboxes.append(best_bbox) 334 | cropped_boxes.append(remain_bboxes) 335 | crop_time_stamp = time.time() 336 | crop_times.append(crop_time_stamp - merge_time_stamp) 337 | 338 | if config['Draw']['draw_bbox']: 339 | draw_bbox(df.loc[i], config['img_key'], clustered_bboxes, save_path=new_path["clustered_bboxes"], linewidth=6) 340 | draw_bbox(df.loc[i], config['img_key'], new_cluster_center, save_path=new_path["cluster_center"], linewidth=6) 341 | draw_bbox(df.loc[i], config['img_key'], merged_center, save_path=new_path["merge"], linewidth=6) 342 | draw_bbox(df.loc[i], config['img_key'], remain_bboxes, save_path=new_path["cropped"], linewidth=6) 343 | 344 | df["cluster_centers"] = cluster_centers 345 | df["merged_cluster_centers"] = merged_cluster_centers 346 | df["cropped_boxes"] = cropped_boxes 347 | 348 | print(f"kmeans time {sum(k_means_times) / len(k_means_times)}") 349 | print(f"merge time {sum(merge_times) / len(merge_times)}") 350 | print(f"crop time {sum(crop_times) / len(crop_times)}") 351 | 352 | return df 353 | 354 | if __name__ == "__main__": 355 | parser = argparse.ArgumentParser() 356 | parser.add_argument('--config_path', type=str) 357 | args = parser.parse_args() 358 | 359 | name = "SAM" 360 | with open(args.config_path) as f: 361 | config = yaml.load(f,Loader=yaml.FullLoader) 362 | config = config['cluster'] 363 | 364 | data_source = config['data']['sourcesam'] 365 | 
    img_key = config['img_key']
366 | 
367 |     data_paths = ["detail_caption_construction/data/source_data/detailcaps_100_frame.parquet"]
368 | 
369 |     for data_path in data_paths:
370 |         df = pd.read_parquet(data_path)
371 |         # df = df[:20]
372 |         bboxes = []
373 |         hw = []
374 |         for i in tqdm(range(len(df))):
375 |             bbox = []
376 |             for j in range(len(df["annotations"][i])):
377 |                 x, y, w, h = df["annotations"][i][j]["bbox"].tolist()
378 |                 bbox.append([x, y, x+w, y+h])
379 |             bboxes.append(bbox)
380 |             hw.append(df["annotations"][i][0]["segmentation"]["size"].tolist())
381 | 
382 |         print("### Doing clustering ###")
383 |         config['img_key'] = img_key
384 |         item_id = df["item_id"].tolist() if "item_id" in df.columns else [i for i in range(len(df))]
385 |         cluster_info = {"item_id": item_id, "frame": df["frame"].tolist(), "bboxes": bboxes, "hw": hw}
386 |         df_cluster = pd.DataFrame(cluster_info)
387 |         df_cluster = cluster(df_cluster, name, config)
388 | 
389 |         print("### Doing evaluation ###")
390 |         compute_metrics(df_cluster, config['compress_scale'], keys=["cluster_centers", "merged_cluster_centers", "cropped_boxes"])
391 |         df["hw"] = hw
392 |         df[f"{name}_bboxes"] = bboxes
393 |         df[f"{name}_cluster_centers"] = df_cluster["cluster_centers"]
394 |         df[f"{name}_merged_cluster_centers"] = df_cluster["merged_cluster_centers"]
395 |         df[f"{name}_cropped_boxes"] = df_cluster["cropped_boxes"]
396 | 
397 |     from IPython import embed; embed()
398 | 
--------------------------------------------------------------------------------
/detail_caption_construction/utils/bbox_statistics.py:
--------------------------------------------------------------------------------
1 | from sklearn.cluster import KMeans
2 | import numpy as np
3 | import pandas as pd
4 | import cv2
5 | import os
6 | from matplotlib import pyplot as plt
7 | import tqdm
8 | import json
9 | import argparse
10 | from collections import defaultdict
11 | 
12 | 
13 | def point_is_in_boxes(point, bboxes):
14 |     for i, bbox in enumerate(bboxes):
15 |         if bbox[0] <= point[0] < bbox[2] and bbox[1] <= point[1] < bbox[3]:
16 |             return True
17 |     return False
18 | 
19 | 
20 | def get_area(bbox):
21 |     return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
22 | 
23 | 
24 | def compute_coverage_and_overlap(image_size, bboxes):
25 |     covered, overlap = 0, 0
26 |     for x in range(image_size[1]):
27 |         for y in range(image_size[0]):
28 |             if point_is_in_boxes([x, y], bboxes):
29 |                 covered += 1
30 |     if covered == 0 or image_size[0]==0 or image_size[1]==0:
31 |         return 0, 0, 0
32 |     else:
33 |         coverage = covered / (image_size[0] * image_size[1])
34 |         areas = [get_area(bbox) for bbox in bboxes]
35 |         overlap = sum(areas) / covered
36 |         return coverage, overlap, areas
37 | 
38 | 
39 | def compress_bbox(bbox, scale=2):
40 |     if type(bbox) == list:
41 |         for i in range(len(bbox)):
42 |             bbox[i] = bbox[i] // scale
43 |     elif type(bbox) == np.ndarray:
44 |         bbox = bbox // scale
45 |     else:
46 |         raise TypeError(f"bbox type {type(bbox)} is unexpected")
47 | 
48 |     return bbox
49 | 
50 | 
51 | def compute_metrics(df, compress_scale, keys=None):
52 |     all_coverage, all_overlap, all_areas, all_bbox_num = defaultdict(list), defaultdict(list), defaultdict(list), defaultdict(list)
53 |     for i in tqdm.tqdm(range(len(df))):
54 |         sample = df.loc[i]
55 |         # image_size = sample['annotations'][0]['segmentation']['size'] // compress_scale
56 |         image_size = sample["hw"][0] // compress_scale, sample["hw"][1] // compress_scale
57 | 
58 |         for key in keys:
59 |             bboxes = np.array(sample[key]).tolist()
60 |             if len(bboxes) == 0:
61 |                 continue
62 |             bboxes = [compress_bbox(bbox, scale=compress_scale) for bbox in bboxes]
63 | 
64 |             coverage, overlap, areas = compute_coverage_and_overlap(image_size, bboxes)
65 |             if coverage == 0:
66 |                 continue
67 |             all_coverage[key].append(coverage)
68 |             all_overlap[key].append(overlap)
69 |             all_areas[key].extend(areas)
70 |             all_bbox_num[key].append(len(bboxes))
71 | 
72 |     # from IPython import embed; embed()
73 |     quantiles = [0.1, 0.3, 0.5, 0.7, 0.9]
74 |     for key in keys:
75 |         print(f"######## {key} ########")
76 |         print(f"bbox_num: {sum(all_bbox_num[key]) / len(all_bbox_num[key])}")
77 |         print(f"bbox_num_quantile: {quantiles}: {np.quantile(np.array(all_bbox_num[key]), quantiles)}")
78 |         print(f"coverage: {sum(all_coverage[key]) / len(all_coverage[key])}")
79 |         print(f"coverage_quantile: {quantiles}: {np.quantile(np.array(all_coverage[key]), quantiles)}")
80 |         print(f"overlap: {sum(all_overlap[key]) / len(all_overlap[key])}")
81 |         print(f"overlap_quantile: {quantiles}: {np.quantile(np.array(all_overlap[key]), quantiles)}")
82 |         print(f"bbox_area_quantile: {quantiles}: {np.quantile(np.array(all_areas[key]), quantiles)}")
83 |     return
84 | 
85 | 
86 | 
87 | if __name__ == "__main__":
88 |     parser = argparse.ArgumentParser()
89 |     parser.add_argument('--file_path', type=str, default="reservoir/processed_data/cropped_bboxes.parquet")
90 |     parser.add_argument('--baseline_key', type=str, default="cluster_centers")
91 |     parser.add_argument('--exp_key', type=str, default="cropped_boxes")
92 |     parser.add_argument('--compress_scale', type=int, default=2)
93 |     args = parser.parse_args()
94 |     df = pd.read_parquet(args.file_path)
95 |     compute_metrics(df, args.compress_scale, keys=[args.baseline_key, args.exp_key])
96 | 
--------------------------------------------------------------------------------
/detail_caption_construction/utils/image_processing_owlv2.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Image processor class for OWLv2.""" 16 | 17 | import warnings 18 | from typing import Dict, List, Optional, Tuple, Union 19 | 20 | import numpy as np 21 | 22 | from ...image_processing_utils import BaseImageProcessor, BatchFeature 23 | from ...image_transforms import ( 24 | center_to_corners_format, 25 | pad, 26 | to_channel_dimension_format, 27 | ) 28 | from ...image_utils import ( 29 | OPENAI_CLIP_MEAN, 30 | OPENAI_CLIP_STD, 31 | ChannelDimension, 32 | ImageInput, 33 | PILImageResampling, 34 | get_image_size, 35 | infer_channel_dimension_format, 36 | is_scaled_image, 37 | make_list_of_images, 38 | to_numpy_array, 39 | valid_images, 40 | ) 41 | from ...utils import ( 42 | TensorType, 43 | is_scipy_available, 44 | is_torch_available, 45 | is_vision_available, 46 | logging, 47 | requires_backends, 48 | ) 49 | 50 | 51 | if is_torch_available(): 52 | import torch 53 | 54 | 55 | if is_vision_available(): 56 | import PIL 57 | 58 | if is_scipy_available(): 59 | from scipy import ndimage as ndi 60 | 61 | 62 | logger = logging.get_logger(__name__) 63 | 64 | 65 | # Copied from transformers.models.owlvit.image_processing_owlvit._upcast 66 | def _upcast(t): 67 | # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type 68 | if t.is_floating_point(): 69 | return t if t.dtype in (torch.float32, torch.float64) else t.float() 70 | else: 71 | return t if t.dtype in (torch.int32, torch.int64) else t.int() 72 | 73 | 74 | # Copied from transformers.models.owlvit.image_processing_owlvit.box_area 75 | def box_area(boxes): 76 | """ 77 | Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates. 78 | 79 | Args: 80 | boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`): 81 | Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1 82 | < x2` and `0 <= y1 < y2`. 83 | Returns: 84 | `torch.FloatTensor`: a tensor containing the area for each box. 85 | """ 86 | boxes = _upcast(boxes) 87 | return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 88 | 89 | 90 | # Copied from transformers.models.owlvit.image_processing_owlvit.box_iou 91 | def box_iou(boxes1, boxes2): 92 | area1 = box_area(boxes1) 93 | area2 = box_area(boxes2) 94 | 95 | left_top = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 96 | right_bottom = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 97 | 98 | width_height = (right_bottom - left_top).clamp(min=0) # [N,M,2] 99 | inter = width_height[:, :, 0] * width_height[:, :, 1] # [N,M] 100 | 101 | union = area1[:, None] + area2 - inter 102 | 103 | iou = inter / union 104 | return iou, union 105 | 106 | 107 | def _preprocess_resize_output_shape(image, output_shape): 108 | """Validate resize output shape according to input image. 109 | 110 | Args: 111 | image (`np.ndarray`): 112 | Image to be resized. 113 | output_shape (`iterable`): 114 | Size of the generated output image `(rows, cols[, ...][, dim])`. If `dim` is not provided, the number of 115 | channels is preserved. 116 | 117 | Returns 118 | image (`np.ndarray): 119 | The input image, but with additional singleton dimensions appended in the case where `len(output_shape) > 120 | input.ndim`. 121 | output_shape (`Tuple`): 122 | The output shape converted to tuple. 123 | 124 | Raises ------ ValueError: 125 | If output_shape length is smaller than the image number of dimensions. 
126 | 127 | Notes ----- The input image is reshaped if its number of dimensions is not equal to output_shape_length. 128 | 129 | """ 130 | output_shape = tuple(output_shape) 131 | output_ndim = len(output_shape) 132 | input_shape = image.shape 133 | if output_ndim > image.ndim: 134 | # append dimensions to input_shape 135 | input_shape += (1,) * (output_ndim - image.ndim) 136 | image = np.reshape(image, input_shape) 137 | elif output_ndim == image.ndim - 1: 138 | # multichannel case: append shape of last axis 139 | output_shape = output_shape + (image.shape[-1],) 140 | elif output_ndim < image.ndim: 141 | raise ValueError("output_shape length cannot be smaller than the " "image number of dimensions") 142 | 143 | return image, output_shape 144 | 145 | 146 | def _clip_warp_output(input_image, output_image): 147 | """Clip output image to range of values of input image. 148 | 149 | Note that this function modifies the values of *output_image* in-place. 150 | 151 | Taken from: 152 | https://github.com/scikit-image/scikit-image/blob/b4b521d6f0a105aabeaa31699949f78453ca3511/skimage/transform/_warps.py#L640. 153 | 154 | Args: 155 | input_image : ndarray 156 | Input image. 157 | output_image : ndarray 158 | Output image, which is modified in-place. 159 | """ 160 | min_val = np.min(input_image) 161 | if np.isnan(min_val): 162 | # NaNs detected, use NaN-safe min/max 163 | min_func = np.nanmin 164 | max_func = np.nanmax 165 | min_val = min_func(input_image) 166 | else: 167 | min_func = np.min 168 | max_func = np.max 169 | max_val = max_func(input_image) 170 | 171 | output_image = np.clip(output_image, min_val, max_val) 172 | 173 | return output_image 174 | 175 | 176 | class Owlv2ImageProcessor(BaseImageProcessor): 177 | r""" 178 | Constructs an OWLv2 image processor. 179 | 180 | Args: 181 | do_rescale (`bool`, *optional*, defaults to `True`): 182 | Whether to rescale the image by the specified scale `rescale_factor`. Can be overriden by `do_rescale` in 183 | the `preprocess` method. 184 | rescale_factor (`int` or `float`, *optional*, defaults to `1/255`): 185 | Scale factor to use if rescaling the image. Can be overriden by `rescale_factor` in the `preprocess` 186 | method. 187 | do_pad (`bool`, *optional*, defaults to `True`): 188 | Whether to pad the image to a square with gray pixels on the bottom and the right. Can be overriden by 189 | `do_pad` in the `preprocess` method. 190 | do_resize (`bool`, *optional*, defaults to `True`): 191 | Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overriden 192 | by `do_resize` in the `preprocess` method. 193 | size (`Dict[str, int]` *optional*, defaults to `{"height": 960, "width": 960}`): 194 | Size to resize the image to. Can be overriden by `size` in the `preprocess` method. 195 | resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`): 196 | Resampling method to use if resizing the image. Can be overriden by `resample` in the `preprocess` method. 197 | do_normalize (`bool`, *optional*, defaults to `True`): 198 | Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess` 199 | method. 200 | image_mean (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_MEAN`): 201 | Mean to use if normalizing the image. This is a float or list of floats the length of the number of 202 | channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. 
203 | image_std (`float` or `List[float]`, *optional*, defaults to `OPENAI_CLIP_STD`): 204 | Standard deviation to use if normalizing the image. This is a float or list of floats the length of the 205 | number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method. 206 | """ 207 | 208 | model_input_names = ["pixel_values"] 209 | 210 | def __init__( 211 | self, 212 | do_rescale: bool = True, 213 | rescale_factor: Union[int, float] = 1 / 255, 214 | do_pad: bool = True, 215 | do_resize: bool = True, 216 | size: Dict[str, int] = None, 217 | resample: PILImageResampling = PILImageResampling.BILINEAR, 218 | do_normalize: bool = True, 219 | image_mean: Optional[Union[float, List[float]]] = None, 220 | image_std: Optional[Union[float, List[float]]] = None, 221 | **kwargs, 222 | ) -> None: 223 | super().__init__(**kwargs) 224 | 225 | self.do_rescale = do_rescale 226 | self.rescale_factor = rescale_factor 227 | self.do_pad = do_pad 228 | self.do_resize = do_resize 229 | self.size = size if size is not None else {"height": 960, "width": 960} 230 | self.resample = resample 231 | self.do_normalize = do_normalize 232 | self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN 233 | self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD 234 | 235 | def pad( 236 | self, 237 | image: np.array, 238 | data_format: Optional[Union[str, ChannelDimension]] = None, 239 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 240 | ): 241 | """ 242 | Pad an image to a square with gray pixels on the bottom and the right, as per the original OWLv2 243 | implementation. 244 | 245 | Args: 246 | image (`np.ndarray`): 247 | Image to pad. 248 | data_format (`str` or `ChannelDimension`, *optional*): 249 | The channel dimension format of the image. If not provided, it will be the same as the input image. 250 | input_data_format (`ChannelDimension` or `str`, *optional*): 251 | The channel dimension format of the input image. If not provided, it will be inferred from the input 252 | image. 253 | """ 254 | height, width = get_image_size(image) 255 | size = max(height, width) 256 | image = pad( 257 | image=image, 258 | padding=((0, size - height), (0, size - width)), 259 | constant_values=0.5, 260 | data_format=data_format, 261 | input_data_format=input_data_format, 262 | ) 263 | 264 | return image 265 | 266 | def resize( 267 | self, 268 | image: np.ndarray, 269 | size: Dict[str, int], 270 | anti_aliasing: bool = True, 271 | anti_aliasing_sigma=None, 272 | data_format: Optional[Union[str, ChannelDimension]] = None, 273 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 274 | **kwargs, 275 | ) -> np.ndarray: 276 | """ 277 | Resize an image as per the original implementation. 278 | 279 | Args: 280 | image (`np.ndarray`): 281 | Image to resize. 282 | size (`Dict[str, int]`): 283 | Dictionary containing the height and width to resize the image to. 284 | anti_aliasing (`bool`, *optional*, defaults to `True`): 285 | Whether to apply anti-aliasing when downsampling the image. 286 | anti_aliasing_sigma (`float`, *optional*, defaults to `None`): 287 | Standard deviation for Gaussian kernel when downsampling the image. If `None`, it will be calculated 288 | automatically. 289 | data_format (`str` or `ChannelDimension`, *optional*): 290 | The channel dimension format of the image. If not provided, it will be the same as the input image. 
291 | input_data_format (`ChannelDimension` or `str`, *optional*): 292 | The channel dimension format of the input image. If not provided, it will be inferred from the input 293 | image. 294 | """ 295 | requires_backends(self, "scipy") 296 | 297 | output_shape = (size["height"], size["width"]) 298 | image = to_channel_dimension_format(image, ChannelDimension.LAST) 299 | image, output_shape = _preprocess_resize_output_shape(image, output_shape) 300 | input_shape = image.shape 301 | factors = np.divide(input_shape, output_shape) 302 | 303 | # Translate modes used by np.pad to those used by scipy.ndimage 304 | ndi_mode = "mirror" 305 | cval = 0 306 | order = 1 307 | if anti_aliasing: 308 | if anti_aliasing_sigma is None: 309 | anti_aliasing_sigma = np.maximum(0, (factors - 1) / 2) 310 | else: 311 | anti_aliasing_sigma = np.atleast_1d(anti_aliasing_sigma) * np.ones_like(factors) 312 | if np.any(anti_aliasing_sigma < 0): 313 | raise ValueError("Anti-aliasing standard deviation must be " "greater than or equal to zero") 314 | elif np.any((anti_aliasing_sigma > 0) & (factors <= 1)): 315 | warnings.warn( 316 | "Anti-aliasing standard deviation greater than zero but " "not down-sampling along all axes" 317 | ) 318 | filtered = ndi.gaussian_filter(image, anti_aliasing_sigma, cval=cval, mode=ndi_mode) 319 | else: 320 | filtered = image 321 | 322 | zoom_factors = [1 / f for f in factors] 323 | out = ndi.zoom(filtered, zoom_factors, order=order, mode=ndi_mode, cval=cval, grid_mode=True) 324 | 325 | image = _clip_warp_output(image, out) 326 | 327 | image = to_channel_dimension_format(image, input_data_format, ChannelDimension.LAST) 328 | image = ( 329 | to_channel_dimension_format(image, data_format, input_data_format) if data_format is not None else image 330 | ) 331 | return image 332 | 333 | def preprocess( 334 | self, 335 | images: ImageInput, 336 | do_pad: bool = None, 337 | do_resize: bool = None, 338 | size: Dict[str, int] = None, 339 | do_rescale: bool = None, 340 | rescale_factor: float = None, 341 | do_normalize: bool = None, 342 | image_mean: Optional[Union[float, List[float]]] = None, 343 | image_std: Optional[Union[float, List[float]]] = None, 344 | return_tensors: Optional[Union[str, TensorType]] = None, 345 | data_format: ChannelDimension = ChannelDimension.FIRST, 346 | input_data_format: Optional[Union[str, ChannelDimension]] = None, 347 | **kwargs, 348 | ) -> PIL.Image.Image: 349 | """ 350 | Preprocess an image or batch of images. 351 | 352 | Args: 353 | images (`ImageInput`): 354 | Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If 355 | passing in images with pixel values between 0 and 1, set `do_rescale=False`. 356 | do_pad (`bool`, *optional*, defaults to `self.do_pad`): 357 | Whether to pad the image to a square with gray pixels on the bottom and the right. 358 | do_resize (`bool`, *optional*, defaults to `self.do_resize`): 359 | Whether to resize the image. 360 | size (`Dict[str, int]`, *optional*, defaults to `self.size`): 361 | Size to resize the image to. 362 | do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): 363 | Whether to rescale the image values between [0 - 1]. 364 | rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): 365 | Rescale factor to rescale the image by if `do_rescale` is set to `True`. 366 | do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): 367 | Whether to normalize the image. 
368 | image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`): 369 | Image mean. 370 | image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`): 371 | Image standard deviation. 372 | return_tensors (`str` or `TensorType`, *optional*): 373 | The type of tensors to return. Can be one of: 374 | - Unset: Return a list of `np.ndarray`. 375 | - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. 376 | - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. 377 | - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. 378 | - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. 379 | data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): 380 | The channel dimension format for the output image. Can be one of: 381 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 382 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 383 | - Unset: Use the channel dimension format of the input image. 384 | input_data_format (`ChannelDimension` or `str`, *optional*): 385 | The channel dimension format for the input image. If unset, the channel dimension format is inferred 386 | from the input image. Can be one of: 387 | - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format. 388 | - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format. 389 | - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. 390 | """ 391 | do_rescale = do_rescale if do_rescale is not None else self.do_rescale 392 | rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor 393 | do_pad = do_pad if do_pad is not None else self.do_pad 394 | do_resize = do_resize if do_resize is not None else self.do_resize 395 | do_normalize = do_normalize if do_normalize is not None else self.do_normalize 396 | image_mean = image_mean if image_mean is not None else self.image_mean 397 | image_std = image_std if image_std is not None else self.image_std 398 | 399 | size = size if size is not None else self.size 400 | 401 | images = make_list_of_images(images) 402 | 403 | if not valid_images(images): 404 | raise ValueError( 405 | "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " 406 | "torch.Tensor, tf.Tensor or jax.ndarray." 407 | ) 408 | 409 | if do_resize and size is None: 410 | raise ValueError("Size must be specified if do_resize is True.") 411 | 412 | if do_rescale and rescale_factor is None: 413 | raise ValueError("Rescale factor must be specified if do_rescale is True.") 414 | 415 | if do_normalize and (image_mean is None or image_std is None): 416 | raise ValueError("Image mean and std must be specified if do_normalize is True.") 417 | 418 | # All transformations expect numpy arrays. 419 | images = [to_numpy_array(image) for image in images] 420 | 421 | if is_scaled_image(images[0]) and do_rescale: 422 | logger.warning_once( 423 | "It looks like you are trying to rescale already rescaled images. If the input" 424 | " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again." 425 | ) 426 | 427 | if input_data_format is None: 428 | # We assume that all images have the same channel dimension format. 
429 | input_data_format = infer_channel_dimension_format(images[0]) 430 | 431 | if do_rescale: 432 | images = [ 433 | self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format) 434 | for image in images 435 | ] 436 | 437 | if do_pad: 438 | images = [self.pad(image=image, input_data_format=input_data_format) for image in images] 439 | 440 | if do_resize: 441 | images = [ 442 | self.resize( 443 | image=image, 444 | size=size, 445 | input_data_format=input_data_format, 446 | ) 447 | for image in images 448 | ] 449 | 450 | if do_normalize: 451 | images = [ 452 | self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format) 453 | for image in images 454 | ] 455 | 456 | images = [ 457 | to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images 458 | ] 459 | 460 | data = {"pixel_values": images} 461 | return BatchFeature(data=data, tensor_type=return_tensors) 462 | 463 | # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_object_detection 464 | def post_process_object_detection( 465 | self, outputs, threshold: float = 0.1, target_sizes: Union[TensorType, List[Tuple]] = None, nms_threshold: float = 1.0 466 | ): 467 | """ 468 | Converts the raw output of [`OwlViTForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y, 469 | bottom_right_x, bottom_right_y) format. 470 | Args: 471 | outputs ([`OwlViTObjectDetectionOutput`]): 472 | Raw outputs of the model. 473 | threshold (`float`, *optional*): 474 | Score threshold to keep object detection predictions. 475 | target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*): 476 | Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size 477 | `(height, width)` of each image in the batch. If unset, predictions will not be resized. 478 | Returns: 479 | `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image 480 | in the batch as predicted by the model. 481 | """ 482 | # TODO: (amy) add support for other frameworks 483 | logits, boxes = outputs.logits, outputs.pred_boxes 484 | if target_sizes is not None: 485 | if len(logits) != len(target_sizes): 486 | raise ValueError( 487 | "Make sure that you pass in as many target sizes as the batch dimension of the logits" 488 | ) 489 | probs = torch.max(logits, dim=-1) 490 | scores = torch.sigmoid(probs.values) 491 | labels = probs.indices 492 | # Convert to [x0, y0, x1, y1] format 493 | boxes = center_to_corners_format(boxes) 494 | # Apply non-maximum suppression (NMS) 495 | if nms_threshold < 1.0: 496 | for idx in range(boxes.shape[0]): 497 | for i in torch.argsort(-scores[idx]): 498 | if not scores[idx][i]: 499 | continue 500 | ious = box_iou(boxes[idx][i, :].unsqueeze(0), boxes[idx])[0][0] 501 | ious[i] = -1.0 # Mask self-IoU. 
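                    # Greedy NMS step: boxes are visited in descending score order, and every other
                    # box whose IoU with the current box exceeds nms_threshold has its score zeroed below.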
502 | scores[idx][ious > nms_threshold] = 0.0 503 | # Convert from relative [0, 1] to absolute [0, height] coordinates 504 | if target_sizes is not None: 505 | if isinstance(target_sizes, List): 506 | img_h = torch.Tensor([i[0] for i in target_sizes]) 507 | img_w = torch.Tensor([i[1] for i in target_sizes]) 508 | else: 509 | img_h, img_w = target_sizes.unbind(1) 510 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(boxes.device) 511 | boxes = boxes * scale_fct[:, None, :] 512 | results = [] 513 | for s, l, b in zip(scores, labels, boxes): 514 | score = s[s > threshold] 515 | label = l[s > threshold] 516 | box = b[s > threshold] 517 | results.append({"scores": score, "labels": label, "boxes": box}) 518 | return results 519 | 520 | # Copied from transformers.models.owlvit.image_processing_owlvit.OwlViTImageProcessor.post_process_image_guided_detection 521 | def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_threshold=0.3, target_sizes=None): 522 | """ 523 | Converts the output of [`OwlViTForObjectDetection.image_guided_detection`] into the format expected by the COCO 524 | api. 525 | 526 | Args: 527 | outputs ([`OwlViTImageGuidedObjectDetectionOutput`]): 528 | Raw outputs of the model. 529 | threshold (`float`, *optional*, defaults to 0.0): 530 | Minimum confidence threshold to use to filter out predicted boxes. 531 | nms_threshold (`float`, *optional*, defaults to 0.3): 532 | IoU threshold for non-maximum suppression of overlapping boxes. 533 | target_sizes (`torch.Tensor`, *optional*): 534 | Tensor of shape (batch_size, 2) where each entry is the (height, width) of the corresponding image in 535 | the batch. If set, predicted normalized bounding boxes are rescaled to the target sizes. If left to 536 | None, predictions will not be unnormalized. 537 | 538 | Returns: 539 | `List[Dict]`: A list of dictionaries, each dictionary containing the scores, labels and boxes for an image 540 | in the batch as predicted by the model. All labels are set to None as 541 | `OwlViTForObjectDetection.image_guided_detection` perform one-shot object detection. 542 | """ 543 | logits, target_boxes = outputs.logits, outputs.target_pred_boxes 544 | 545 | if len(logits) != len(target_sizes): 546 | raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") 547 | if target_sizes.shape[1] != 2: 548 | raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") 549 | 550 | probs = torch.max(logits, dim=-1) 551 | scores = torch.sigmoid(probs.values) 552 | 553 | # Convert to [x0, y0, x1, y1] format 554 | target_boxes = center_to_corners_format(target_boxes) 555 | 556 | # Apply non-maximum suppression (NMS) 557 | if nms_threshold < 1.0: 558 | for idx in range(target_boxes.shape[0]): 559 | for i in torch.argsort(-scores[idx]): 560 | if not scores[idx][i]: 561 | continue 562 | 563 | ious = box_iou(target_boxes[idx][i, :].unsqueeze(0), target_boxes[idx])[0][0] 564 | ious[i] = -1.0 # Mask self-IoU. 
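                    # Same greedy NMS as in post_process_object_detection: suppress (zero the score of)
                    # any box overlapping the current box with IoU above nms_threshold.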
565 | scores[idx][ious > nms_threshold] = 0.0 566 | 567 | # Convert from relative [0, 1] to absolute [0, height] coordinates 568 | img_h, img_w = target_sizes.unbind(1) 569 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device) 570 | target_boxes = target_boxes * scale_fct[:, None, :] 571 | 572 | # Compute box display alphas based on prediction scores 573 | results = [] 574 | alphas = torch.zeros_like(scores) 575 | 576 | for idx in range(target_boxes.shape[0]): 577 | # Select scores for boxes matching the current query: 578 | query_scores = scores[idx] 579 | if not query_scores.nonzero().numel(): 580 | continue 581 | 582 | # Apply threshold on scores before scaling 583 | query_scores[query_scores < threshold] = 0.0 584 | 585 | # Scale box alpha such that the best box for each query has alpha 1.0 and the worst box has alpha 0.1. 586 | # All other boxes will either belong to a different query, or will not be shown. 587 | max_score = torch.max(query_scores) + 1e-6 588 | query_alphas = (query_scores - (max_score * 0.1)) / (max_score * 0.9) 589 | query_alphas = torch.clip(query_alphas, 0.0, 1.0) 590 | alphas[idx] = query_alphas 591 | 592 | mask = alphas[idx] > 0 593 | box_scores = alphas[idx][mask] 594 | boxes = target_boxes[idx][mask] 595 | results.append({"scores": box_scores, "labels": None, "boxes": boxes}) 596 | 597 | return results 598 | -------------------------------------------------------------------------------- /detail_caption_construction/utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | 5 | def get_data_files(config, node_index, node_num): 6 | source_path = config['source_path'] 7 | source_data_files = os.listdir(source_path) 8 | source_data_files = [f"{source_path}/{path}" for path in source_data_files] 9 | source_data_files.sort() 10 | start, end = node_index * (len(source_data_files) // node_num), (node_index + 1) * (len(source_data_files) // node_num) 11 | if len(source_data_files) - end < len(source_data_files) // node_num: 12 | end = len(source_data_files) 13 | source_data_files = source_data_files[start: end] 14 | 15 | # os.makedirs(f"{config['target_path']}/node_{node_index}/", exist_ok=True) 16 | target_data_files = os.listdir(f"{config['target_path']}/") 17 | target_data_files = [f"{config['target_path']}/{path}" for path in target_data_files] 18 | target_data_files.sort() 19 | target_data_files = [file.split('/')[-1].split('.')[0] for file in target_data_files] 20 | print(f"processed_files: {target_data_files}") 21 | 22 | return source_data_files, target_data_files 23 | 24 | 25 | -------------------------------------------------------------------------------- /images/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foundation-multimodal-models/CAPTURE/52eeb2781e8b4b1854c07a57b573c75edfd45688/images/intro.png -------------------------------------------------------------------------------- /prepare.sh: -------------------------------------------------------------------------------- 1 | 2 | if [ -d "./detail_caption_construction/LLaVA" ]; then 3 | echo "detail_caption_construction/LLaVA already configured, skipping......" 4 | else 5 | cd detail_caption_construction 6 | git clone https://github.com/haotian-liu/LLaVA 7 | cd LLaVA 8 | pip3 install -e . 9 | cd .. 10 | cd .. 
11 | location=$(pip3 show transformers | grep "Location") 12 | location=${location/Location: /} 13 | sudo chmod -R 777 ${location}/transformers/models/owlv2/ 14 | rm ${location}/transformers/models/owlv2/image_processing_owlv2.py 15 | cp detail_caption_construction/utils/image_processing_owlv2.py ${location}/transformers/models/owlv2/ 16 | fi 17 | 18 | sam_installed=$(pip3 list | grep segment) 19 | if [ -n "$sam_installed" ]; then 20 | echo "sam already configured, skipping......" 21 | else 22 | pip3 install git+https://github.com/facebookresearch/segment-anything.git 23 | pip3 install opencv-python pycocotools matplotlib onnxruntime onnx 24 | fi 25 | 26 | if [ -d "./detail_caption_construction/data" ]; then 27 | echo "detail_caption_construction/data already exists, skipping......" 28 | else 29 | mkdir ./detail_caption_construction/data 30 | fi 31 | 32 | if [ -d "./detail_caption_construction/data/source_data" ]; then 33 | echo "detail_caption_construction data folders already configured, skipping......" 34 | else 35 | cd detail_caption_construction 36 | cd data 37 | mkdir source_data 38 | mkdir stage1_overall_caption 39 | mkdir stage2_bbox 40 | mkdir stage3_local_caption 41 | mkdir stage4_filter 42 | mkdir stage5_caption_merge 43 | mkdir processed_data 44 | cd .. 45 | cd .. 46 | fi 47 | 48 | if [ -d "./detail_caption_construction/scripts_output" ]; then 49 | rm -r ./detail_caption_construction/scripts_output 50 | mkdir ./detail_caption_construction/scripts_output 51 | else 52 | mkdir ./detail_caption_construction/scripts_output 53 | fi 54 | 55 | --------------------------------------------------------------------------------
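Note on the OWLv2 patch above: prepare.sh replaces the installed transformers copy of image_processing_owlv2.py with the version shipped in detail_caption_construction/utils/, whose post_process_object_detection accepts an nms_threshold argument. The sketch below is one minimal, illustrative way to exercise that entry point; the checkpoint id google/owlv2-base-patch16-ensemble and the sample image URL are assumptions for demonstration only and do not come from this repository.

# Minimal usage sketch (illustrative assumptions: checkpoint id and image URL;
# prepare.sh must already have installed the patched image processor for
# the nms_threshold argument to be available).
import requests
import torch
from PIL import Image
from transformers import Owlv2Processor, Owlv2ForObjectDetection

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
texts = [["a photo of a cat", "a photo of a remote control"]]
inputs = processor(text=texts, images=image, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# (height, width) of the original image, used to rescale the normalized boxes.
target_sizes = torch.tensor([image.size[::-1]])

# nms_threshold is the extra argument exposed by the patched image processor above.
results = processor.image_processor.post_process_object_detection(
    outputs, threshold=0.1, target_sizes=target_sizes, nms_threshold=0.3
)
print(results[0]["boxes"].shape, results[0]["scores"])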