├── .gitignore ├── LICENSE ├── MANIFEST.in ├── NOTICE.txt ├── README.md ├── batch_inference.py ├── datasets ├── hico │ ├── hico_600_annots.txt │ └── hico_600_taglist.txt ├── imagenet_multi │ ├── imagenet_multi_1000_annots.txt │ └── imagenet_multi_1000_taglist.txt ├── openimages_common_214 │ ├── imgs │ │ └── .gitkeep │ ├── openimages_common_214_ram_annots.txt │ ├── openimages_common_214_ram_taglist.txt │ ├── openimages_common_214_tag2text_idannots.txt │ └── openimages_common_214_tag2text_tagidlist.txt └── openimages_rare_200 │ ├── imgs │ └── .gitkeep │ ├── openimages_rare_200_llm_tag_descriptions.json │ ├── openimages_rare_200_ram_annots.txt │ └── openimages_rare_200_ram_taglist.txt ├── finetune.py ├── generate_tag_des_llm.py ├── gui_demo.ipynb ├── images ├── 1641173_2291260800.jpg ├── demo │ ├── demo1.jpg │ ├── demo2.jpg │ ├── demo3.jpg │ └── demo4.jpg ├── experiment_comparison.png ├── localization_and_recognition.jpg ├── openset_example.jpg ├── ram_grounded_sam.jpg ├── ram_plus_compare.jpg ├── ram_plus_experiment.png ├── ram_plus_framework.jpg ├── ram_plus_visualization.jpg ├── tag2text_framework.png ├── tag2text_grounded_sam.jpg ├── tag2text_retrieval_visualization.png ├── tag2text_visualization.png └── tagging_results.jpg ├── inference_ram.py ├── inference_ram_openset.py ├── inference_ram_plus.py ├── inference_ram_plus_openset.py ├── inference_tag2text.py ├── pretrain.py ├── ram ├── __init__.py ├── configs │ ├── finetune.yaml │ ├── finetune_tag2text.yaml │ ├── med_config.json │ ├── pretrain.yaml │ ├── pretrain_tag2text.yaml │ ├── q2l_config.json │ └── swin │ │ ├── config_swinB_224.json │ │ ├── config_swinB_384.json │ │ ├── config_swinL_224.json │ │ └── config_swinL_384.json ├── data │ ├── __init__.py │ ├── dataset.py │ ├── ram_tag_list.txt │ ├── ram_tag_list_chinese.txt │ ├── ram_tag_list_threshold.txt │ ├── randaugment.py │ ├── tag2text_ori_tag_list.txt │ ├── tag_list.txt │ └── utils.py ├── inference.py ├── models │ ├── __init__.py │ ├── bert.py │ ├── ram.py │ ├── ram_plus.py │ ├── swin_transformer.py │ ├── tag2text.py │ ├── utils.py │ └── vit.py ├── transform.py └── utils │ ├── __init__.py │ ├── metrics.py │ └── openset_utils.py ├── recognize_anything_demo.ipynb ├── requirements.txt ├── setup.cfg ├── setup.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # checkpoint 132 | *.pth 133 | outputs/ 134 | 135 | # Editor 136 | .idea/ 137 | .vscode/ 138 | 139 | # gradio cache 140 | gradio_cached_examples/ 141 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | 190 | Copyright (c) 2022 OPPO 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | https://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include ram/configs/*.json 2 | include ram/configs/swin/*.json 3 | include ram/data/*.txt 4 | -------------------------------------------------------------------------------- /datasets/hico/hico_600_taglist.txt: -------------------------------------------------------------------------------- 1 | person board airplane 2 | person direct airplane 3 | person exit airplane 4 | person fly airplane 5 | person inspect airplane 6 | person load airplane 7 | person ride airplane 8 | person sit_on airplane 9 | person wash airplane 10 | person no_interaction airplane 11 | person carry bicycle 12 | person hold bicycle 13 | person inspect bicycle 14 | person jump bicycle 15 | person hop_on bicycle 16 | person park bicycle 17 | person push bicycle 18 | person repair bicycle 19 | person ride bicycle 20 | person sit_on bicycle 21 | person straddle bicycle 22 | person walk bicycle 23 | person wash bicycle 24 | person no_interaction bicycle 25 | person chase bird 26 | person feed bird 27 | person hold bird 28 | person pet bird 29 | person release bird 30 | person watch bird 31 | person no_interaction bird 32 | person board boat 33 | person drive boat 34 | person exit boat 35 | person inspect boat 36 | person jump boat 37 | person launch boat 38 | person repair boat 39 | person ride boat 40 | person row boat 41 | person sail boat 42 | person sit_on boat 43 | person stand_on boat 44 | person tie boat 45 | person wash boat 46 | person no_interaction boat 47 | person carry bottle 48 | person drink_with bottle 49 | person hold bottle 50 | person inspect bottle 51 | person lick bottle 52 | person open bottle 53 | person pour bottle 54 | person no_interaction bottle 55 | person board bus 56 | person direct bus 57 | person drive bus 58 | person exit bus 59 | person inspect bus 60 | person load bus 61 | person ride bus 62 | person sit_on bus 63 | person wash bus 64 | person wave bus 65 | person no_interaction bus 66 | person board car 67 | person direct car 68 | person drive car 69 | person hose car 70 | person inspect car 71 | person jump car 72 | person load car 73 | person park car 74 | person ride car 75 | person wash car 76 | person no_interaction car 77 | person dry cat 78 | person feed cat 79 | person hold cat 80 | person hug cat 81 | person kiss cat 82 | person pet cat 83 | person scratch cat 84 | person wash cat 85 | person chase cat 86 | person no_interaction cat 87 | person carry chair 88 | person hold chair 89 | person lie_on chair 90 | person sit_on chair 91 | person stand_on chair 92 | person no_interaction chair 93 | person carry couch 94 | person lie_on couch 95 | person sit_on couch 96 | person no_interaction couch 97 | person feed cow 98 | person herd cow 99 | person hold cow 100 | person hug cow 101 | person kiss cow 102 | person lasso cow 103 | person milk cow 104 | person pet cow 105 | person ride cow 106 | person walk cow 107 | person no_interaction cow 108 | person clean dining_table 109 | person eat_at dining_table 110 | person sit_at dining_table 111 | person no_interaction dining_table 112 | person carry dog 113 | person dry dog 114 | person feed dog 115 | person groom dog 116 | person hold dog 117 | person hose dog 118 | person hug dog 119 | person inspect dog 120 | person kiss dog 121 | person pet dog 122 | person run dog 123 | person scratch dog 124 | person straddle dog 125 | person train dog 126 | person 
walk dog 127 | person wash dog 128 | person chase dog 129 | person no_interaction dog 130 | person feed horse 131 | person groom horse 132 | person hold horse 133 | person hug horse 134 | person jump horse 135 | person kiss horse 136 | person load horse 137 | person hop_on horse 138 | person pet horse 139 | person race horse 140 | person ride horse 141 | person run horse 142 | person straddle horse 143 | person train horse 144 | person walk horse 145 | person wash horse 146 | person no_interaction horse 147 | person hold motorcycle 148 | person inspect motorcycle 149 | person jump motorcycle 150 | person hop_on motorcycle 151 | person park motorcycle 152 | person push motorcycle 153 | person race motorcycle 154 | person ride motorcycle 155 | person sit_on motorcycle 156 | person straddle motorcycle 157 | person turn motorcycle 158 | person walk motorcycle 159 | person wash motorcycle 160 | person no_interaction motorcycle 161 | person carry person 162 | person greet person 163 | person hold person 164 | person hug person 165 | person kiss person 166 | person stab person 167 | person tag person 168 | person teach person 169 | person lick person 170 | person no_interaction person 171 | person carry potted_plant 172 | person hold potted_plant 173 | person hose potted_plant 174 | person no_interaction potted_plant 175 | person carry sheep 176 | person feed sheep 177 | person herd sheep 178 | person hold sheep 179 | person hug sheep 180 | person kiss sheep 181 | person pet sheep 182 | person ride sheep 183 | person shear sheep 184 | person walk sheep 185 | person wash sheep 186 | person no_interaction sheep 187 | person board train 188 | person drive train 189 | person exit train 190 | person load train 191 | person ride train 192 | person sit_on train 193 | person wash train 194 | person no_interaction train 195 | person control tv 196 | person repair tv 197 | person watch tv 198 | person no_interaction tv 199 | person buy apple 200 | person cut apple 201 | person eat apple 202 | person hold apple 203 | person inspect apple 204 | person peel apple 205 | person pick apple 206 | person smell apple 207 | person wash apple 208 | person no_interaction apple 209 | person carry backpack 210 | person hold backpack 211 | person inspect backpack 212 | person open backpack 213 | person wear backpack 214 | person no_interaction backpack 215 | person buy banana 216 | person carry banana 217 | person cut banana 218 | person eat banana 219 | person hold banana 220 | person inspect banana 221 | person peel banana 222 | person pick banana 223 | person smell banana 224 | person no_interaction banana 225 | person break baseball_bat 226 | person carry baseball_bat 227 | person hold baseball_bat 228 | person sign baseball_bat 229 | person swing baseball_bat 230 | person throw baseball_bat 231 | person wield baseball_bat 232 | person no_interaction baseball_bat 233 | person hold baseball_glove 234 | person wear baseball_glove 235 | person no_interaction baseball_glove 236 | person feed bear 237 | person hunt bear 238 | person watch bear 239 | person no_interaction bear 240 | person clean bed 241 | person lie_on bed 242 | person sit_on bed 243 | person no_interaction bed 244 | person inspect bench 245 | person lie_on bench 246 | person sit_on bench 247 | person no_interaction bench 248 | person carry book 249 | person hold book 250 | person open book 251 | person read book 252 | person no_interaction book 253 | person hold bowl 254 | person stir bowl 255 | person wash bowl 256 | person lick bowl 257 | person 
no_interaction bowl 258 | person cut broccoli 259 | person eat broccoli 260 | person hold broccoli 261 | person smell broccoli 262 | person stir broccoli 263 | person wash broccoli 264 | person no_interaction broccoli 265 | person blow cake 266 | person carry cake 267 | person cut cake 268 | person eat cake 269 | person hold cake 270 | person light cake 271 | person make cake 272 | person pick_up cake 273 | person no_interaction cake 274 | person carry carrot 275 | person cook carrot 276 | person cut carrot 277 | person eat carrot 278 | person hold carrot 279 | person peel carrot 280 | person smell carrot 281 | person stir carrot 282 | person wash carrot 283 | person no_interaction carrot 284 | person carry cell_phone 285 | person hold cell_phone 286 | person read cell_phone 287 | person repair cell_phone 288 | person talk_on cell_phone 289 | person text_on cell_phone 290 | person no_interaction cell_phone 291 | person check clock 292 | person hold clock 293 | person repair clock 294 | person set clock 295 | person no_interaction clock 296 | person carry cup 297 | person drink_with cup 298 | person hold cup 299 | person inspect cup 300 | person pour cup 301 | person sip cup 302 | person smell cup 303 | person fill cup 304 | person wash cup 305 | person no_interaction cup 306 | person buy donut 307 | person carry donut 308 | person eat donut 309 | person hold donut 310 | person make donut 311 | person pick_up donut 312 | person smell donut 313 | person no_interaction donut 314 | person feed elephant 315 | person hold elephant 316 | person hose elephant 317 | person hug elephant 318 | person kiss elephant 319 | person hop_on elephant 320 | person pet elephant 321 | person ride elephant 322 | person walk elephant 323 | person wash elephant 324 | person watch elephant 325 | person no_interaction elephant 326 | person hug fire_hydrant 327 | person inspect fire_hydrant 328 | person open fire_hydrant 329 | person paint fire_hydrant 330 | person no_interaction fire_hydrant 331 | person hold fork 332 | person lift fork 333 | person stick fork 334 | person lick fork 335 | person wash fork 336 | person no_interaction fork 337 | person block frisbee 338 | person catch frisbee 339 | person hold frisbee 340 | person spin frisbee 341 | person throw frisbee 342 | person no_interaction frisbee 343 | person feed giraffe 344 | person kiss giraffe 345 | person pet giraffe 346 | person ride giraffe 347 | person watch giraffe 348 | person no_interaction giraffe 349 | person hold hair_drier 350 | person operate hair_drier 351 | person repair hair_drier 352 | person no_interaction hair_drier 353 | person carry handbag 354 | person hold handbag 355 | person inspect handbag 356 | person no_interaction handbag 357 | person carry hot_dog 358 | person cook hot_dog 359 | person cut hot_dog 360 | person eat hot_dog 361 | person hold hot_dog 362 | person make hot_dog 363 | person no_interaction hot_dog 364 | person carry keyboard 365 | person clean keyboard 366 | person hold keyboard 367 | person type_on keyboard 368 | person no_interaction keyboard 369 | person assemble kite 370 | person carry kite 371 | person fly kite 372 | person hold kite 373 | person inspect kite 374 | person launch kite 375 | person pull kite 376 | person no_interaction kite 377 | person cut_with knife 378 | person hold knife 379 | person stick knife 380 | person wash knife 381 | person wield knife 382 | person lick knife 383 | person no_interaction knife 384 | person hold laptop 385 | person open laptop 386 | person read laptop 387 | person 
repair laptop 388 | person type_on laptop 389 | person no_interaction laptop 390 | person clean microwave 391 | person open microwave 392 | person operate microwave 393 | person no_interaction microwave 394 | person control mouse 395 | person hold mouse 396 | person repair mouse 397 | person no_interaction mouse 398 | person buy orange 399 | person cut orange 400 | person eat orange 401 | person hold orange 402 | person inspect orange 403 | person peel orange 404 | person pick orange 405 | person squeeze orange 406 | person wash orange 407 | person no_interaction orange 408 | person clean oven 409 | person hold oven 410 | person inspect oven 411 | person open oven 412 | person repair oven 413 | person operate oven 414 | person no_interaction oven 415 | person check parking_meter 416 | person pay parking_meter 417 | person repair parking_meter 418 | person no_interaction parking_meter 419 | person buy pizza 420 | person carry pizza 421 | person cook pizza 422 | person cut pizza 423 | person eat pizza 424 | person hold pizza 425 | person make pizza 426 | person pick_up pizza 427 | person slide pizza 428 | person smell pizza 429 | person no_interaction pizza 430 | person clean refrigerator 431 | person hold refrigerator 432 | person move refrigerator 433 | person open refrigerator 434 | person no_interaction refrigerator 435 | person hold remote 436 | person point remote 437 | person swing remote 438 | person no_interaction remote 439 | person carry sandwich 440 | person cook sandwich 441 | person cut sandwich 442 | person eat sandwich 443 | person hold sandwich 444 | person make sandwich 445 | person no_interaction sandwich 446 | person cut_with scissors 447 | person hold scissors 448 | person open scissors 449 | person no_interaction scissors 450 | person clean sink 451 | person repair sink 452 | person wash sink 453 | person no_interaction sink 454 | person carry skateboard 455 | person flip skateboard 456 | person grind skateboard 457 | person hold skateboard 458 | person jump skateboard 459 | person pick_up skateboard 460 | person ride skateboard 461 | person sit_on skateboard 462 | person stand_on skateboard 463 | person no_interaction skateboard 464 | person adjust skis 465 | person carry skis 466 | person hold skis 467 | person inspect skis 468 | person jump skis 469 | person pick_up skis 470 | person repair skis 471 | person ride skis 472 | person stand_on skis 473 | person wear skis 474 | person no_interaction skis 475 | person adjust snowboard 476 | person carry snowboard 477 | person grind snowboard 478 | person hold snowboard 479 | person jump snowboard 480 | person ride snowboard 481 | person stand_on snowboard 482 | person wear snowboard 483 | person no_interaction snowboard 484 | person hold spoon 485 | person lick spoon 486 | person wash spoon 487 | person sip spoon 488 | person no_interaction spoon 489 | person block sports_ball 490 | person carry sports_ball 491 | person catch sports_ball 492 | person dribble sports_ball 493 | person hit sports_ball 494 | person hold sports_ball 495 | person inspect sports_ball 496 | person kick sports_ball 497 | person pick_up sports_ball 498 | person serve sports_ball 499 | person sign sports_ball 500 | person spin sports_ball 501 | person throw sports_ball 502 | person no_interaction sports_ball 503 | person hold stop_sign 504 | person stand_under stop_sign 505 | person stop_at stop_sign 506 | person no_interaction stop_sign 507 | person carry suitcase 508 | person drag suitcase 509 | person hold suitcase 510 | person hug suitcase 511 | 
person load suitcase 512 | person open suitcase 513 | person pack suitcase 514 | person pick_up suitcase 515 | person zip suitcase 516 | person no_interaction suitcase 517 | person carry surfboard 518 | person drag surfboard 519 | person hold surfboard 520 | person inspect surfboard 521 | person jump surfboard 522 | person lie_on surfboard 523 | person load surfboard 524 | person ride surfboard 525 | person stand_on surfboard 526 | person sit_on surfboard 527 | person wash surfboard 528 | person no_interaction surfboard 529 | person carry teddy_bear 530 | person hold teddy_bear 531 | person hug teddy_bear 532 | person kiss teddy_bear 533 | person no_interaction teddy_bear 534 | person carry tennis_racket 535 | person hold tennis_racket 536 | person inspect tennis_racket 537 | person swing tennis_racket 538 | person no_interaction tennis_racket 539 | person adjust tie 540 | person cut tie 541 | person hold tie 542 | person inspect tie 543 | person pull tie 544 | person tie tie 545 | person wear tie 546 | person no_interaction tie 547 | person hold toaster 548 | person operate toaster 549 | person repair toaster 550 | person no_interaction toaster 551 | person clean toilet 552 | person flush toilet 553 | person open toilet 554 | person repair toilet 555 | person sit_on toilet 556 | person stand_on toilet 557 | person wash toilet 558 | person no_interaction toilet 559 | person brush_with toothbrush 560 | person hold toothbrush 561 | person wash toothbrush 562 | person no_interaction toothbrush 563 | person install traffic_light 564 | person repair traffic_light 565 | person stand_under traffic_light 566 | person stop_at traffic_light 567 | person no_interaction traffic_light 568 | person direct truck 569 | person drive truck 570 | person inspect truck 571 | person load truck 572 | person repair truck 573 | person ride truck 574 | person sit_on truck 575 | person wash truck 576 | person no_interaction truck 577 | person carry umbrella 578 | person hold umbrella 579 | person lose umbrella 580 | person open umbrella 581 | person repair umbrella 582 | person set umbrella 583 | person stand_under umbrella 584 | person no_interaction umbrella 585 | person hold vase 586 | person make vase 587 | person paint vase 588 | person no_interaction vase 589 | person fill wine_glass 590 | person hold wine_glass 591 | person sip wine_glass 592 | person toast wine_glass 593 | person lick wine_glass 594 | person wash wine_glass 595 | person no_interaction wine_glass 596 | person feed zebra 597 | person hold zebra 598 | person pet zebra 599 | person watch zebra 600 | person no_interaction zebra -------------------------------------------------------------------------------- /datasets/imagenet_multi/imagenet_multi_1000_taglist.txt: -------------------------------------------------------------------------------- 1 | tench 2 | goldfish 3 | great white shark 4 | tiger shark 5 | hammerhead shark 6 | electric ray 7 | stingray 8 | rooster 9 | hen 10 | ostrich 11 | brambling 12 | goldfinch 13 | house finch 14 | junco 15 | indigo bunting 16 | American robin 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | American dipper 22 | kite (bird of prey) 23 | bald eagle 24 | vulture 25 | great grey owl 26 | fire salamander 27 | smooth newt 28 | newt 29 | spotted salamander 30 | axolotl 31 | American bullfrog 32 | tree frog 33 | tailed frog 34 | loggerhead sea turtle 35 | leatherback sea turtle 36 | mud turtle 37 | terrapin 38 | box turtle 39 | banded gecko 40 | green iguana 41 | Carolina anole 42 | desert grassland whiptail 
lizard 43 | agama 44 | frilled-necked lizard 45 | alligator lizard 46 | Gila monster 47 | European green lizard 48 | chameleon 49 | Komodo dragon 50 | Nile crocodile 51 | American alligator 52 | triceratops 53 | worm snake 54 | ring-necked snake 55 | eastern hog-nosed snake 56 | smooth green snake 57 | kingsnake 58 | garter snake 59 | water snake 60 | vine snake 61 | night snake 62 | boa constrictor 63 | African rock python 64 | Indian cobra 65 | green mamba 66 | sea snake 67 | Saharan horned viper 68 | eastern diamondback rattlesnake 69 | sidewinder rattlesnake 70 | trilobite 71 | harvestman 72 | scorpion 73 | yellow garden spider 74 | barn spider 75 | European garden spider 76 | southern black widow 77 | tarantula 78 | wolf spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse 84 | prairie grouse 85 | peafowl 86 | quail 87 | partridge 88 | african grey parrot 89 | macaw 90 | sulphur-crested cockatoo 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | duck 99 | red-breasted merganser 100 | goose 101 | black swan 102 | tusker 103 | echidna 104 | platypus 105 | wallaby 106 | koala 107 | wombat 108 | jellyfish 109 | sea anemone 110 | brain coral 111 | flatworm 112 | nematode 113 | conch 114 | snail 115 | slug 116 | sea slug 117 | chiton 118 | chambered nautilus 119 | Dungeness crab 120 | rock crab 121 | fiddler crab 122 | red king crab 123 | American lobster 124 | spiny lobster 125 | crayfish 126 | hermit crab 127 | isopod 128 | white stork 129 | black stork 130 | spoonbill 131 | flamingo 132 | little blue heron 133 | great egret 134 | bittern bird 135 | crane bird 136 | limpkin 137 | common gallinule 138 | American coot 139 | bustard 140 | ruddy turnstone 141 | dunlin 142 | common redshank 143 | dowitcher 144 | oystercatcher 145 | pelican 146 | king penguin 147 | albatross 148 | grey whale 149 | killer whale 150 | dugong 151 | sea lion 152 | Chihuahua 153 | Japanese Chin 154 | Maltese 155 | Pekingese 156 | Shih Tzu 157 | King Charles Spaniel 158 | Papillon 159 | toy terrier 160 | Rhodesian Ridgeback 161 | Afghan Hound 162 | Basset Hound 163 | Beagle 164 | Bloodhound 165 | Bluetick Coonhound 166 | Black and Tan Coonhound 167 | Treeing Walker Coonhound 168 | English foxhound 169 | Redbone Coonhound 170 | borzoi 171 | Irish Wolfhound 172 | Italian Greyhound 173 | Whippet 174 | Ibizan Hound 175 | Norwegian Elkhound 176 | Otterhound 177 | Saluki 178 | Scottish Deerhound 179 | Weimaraner 180 | Staffordshire Bull Terrier 181 | American Staffordshire Terrier 182 | Bedlington Terrier 183 | Border Terrier 184 | Kerry Blue Terrier 185 | Irish Terrier 186 | Norfolk Terrier 187 | Norwich Terrier 188 | Yorkshire Terrier 189 | Wire Fox Terrier 190 | Lakeland Terrier 191 | Sealyham Terrier 192 | Airedale Terrier 193 | Cairn Terrier 194 | Australian Terrier 195 | Dandie Dinmont Terrier 196 | Boston Terrier 197 | Miniature Schnauzer 198 | Giant Schnauzer 199 | Standard Schnauzer 200 | Scottish Terrier 201 | Tibetan Terrier 202 | Australian Silky Terrier 203 | Soft-coated Wheaten Terrier 204 | West Highland White Terrier 205 | Lhasa Apso 206 | Flat-Coated Retriever 207 | Curly-coated Retriever 208 | Golden Retriever 209 | Labrador Retriever 210 | Chesapeake Bay Retriever 211 | German Shorthaired Pointer 212 | Vizsla 213 | English Setter 214 | Irish Setter 215 | Gordon Setter 216 | Brittany dog 217 | Clumber Spaniel 218 | English Springer Spaniel 219 | Welsh Springer Spaniel 220 | Cocker Spaniel 221 | Sussex Spaniel 222 | Irish 
Water Spaniel 223 | Kuvasz 224 | Schipperke 225 | Groenendael dog 226 | Malinois 227 | Briard 228 | Australian Kelpie 229 | Komondor 230 | Old English Sheepdog 231 | Shetland Sheepdog 232 | collie 233 | Border Collie 234 | Bouvier des Flandres dog 235 | Rottweiler 236 | German Shepherd Dog 237 | Dobermann 238 | Miniature Pinscher 239 | Greater Swiss Mountain Dog 240 | Bernese Mountain Dog 241 | Appenzeller Sennenhund 242 | Entlebucher Sennenhund 243 | Boxer 244 | Bullmastiff 245 | Tibetan Mastiff 246 | French Bulldog 247 | Great Dane 248 | St. Bernard 249 | husky 250 | Alaskan Malamute 251 | Siberian Husky 252 | Dalmatian 253 | Affenpinscher 254 | Basenji 255 | pug 256 | Leonberger 257 | Newfoundland dog 258 | Great Pyrenees dog 259 | Samoyed 260 | Pomeranian 261 | Chow Chow 262 | Keeshond 263 | brussels griffon 264 | Pembroke Welsh Corgi 265 | Cardigan Welsh Corgi 266 | Toy Poodle 267 | Miniature Poodle 268 | Standard Poodle 269 | Mexican hairless dog (xoloitzcuintli) 270 | grey wolf 271 | Alaskan tundra wolf 272 | red wolf or maned wolf 273 | coyote 274 | dingo 275 | dhole 276 | African wild dog 277 | hyena 278 | red fox 279 | kit fox 280 | Arctic fox 281 | grey fox 282 | tabby cat 283 | tiger cat 284 | Persian cat 285 | Siamese cat 286 | Egyptian Mau 287 | cougar 288 | lynx 289 | leopard 290 | snow leopard 291 | jaguar 292 | lion 293 | tiger 294 | cheetah 295 | brown bear 296 | American black bear 297 | polar bear 298 | sloth bear 299 | mongoose 300 | meerkat 301 | tiger beetle 302 | ladybug 303 | ground beetle 304 | longhorn beetle 305 | leaf beetle 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant 312 | grasshopper 313 | cricket insect 314 | stick insect 315 | cockroach 316 | praying mantis 317 | cicada 318 | leafhopper 319 | lacewing 320 | dragonfly 321 | damselfly 322 | red admiral butterfly 323 | ringlet butterfly 324 | monarch butterfly 325 | small white butterfly 326 | sulphur butterfly 327 | gossamer-winged butterfly 328 | starfish 329 | sea urchin 330 | sea cucumber 331 | cottontail rabbit 332 | hare 333 | Angora rabbit 334 | hamster 335 | porcupine 336 | fox squirrel 337 | marmot 338 | beaver 339 | guinea pig 340 | common sorrel horse 341 | zebra 342 | pig 343 | wild boar 344 | warthog 345 | hippopotamus 346 | ox 347 | water buffalo 348 | bison 349 | ram (adult male sheep) 350 | bighorn sheep 351 | Alpine ibex 352 | hartebeest 353 | impala (antelope) 354 | gazelle 355 | arabian camel 356 | llama 357 | weasel 358 | mink 359 | European polecat 360 | black-footed ferret 361 | otter 362 | skunk 363 | badger 364 | armadillo 365 | three-toed sloth 366 | orangutan 367 | gorilla 368 | chimpanzee 369 | gibbon 370 | siamang 371 | guenon 372 | patas monkey 373 | baboon 374 | macaque 375 | langur 376 | black-and-white colobus 377 | proboscis monkey 378 | marmoset 379 | white-headed capuchin 380 | howler monkey 381 | titi monkey 382 | Geoffroy's spider monkey 383 | common squirrel monkey 384 | ring-tailed lemur 385 | indri 386 | Asian elephant 387 | African bush elephant 388 | red panda 389 | giant panda 390 | snoek fish 391 | eel 392 | silver salmon 393 | rock beauty fish 394 | clownfish 395 | sturgeon 396 | gar fish 397 | lionfish 398 | pufferfish 399 | abacus 400 | abaya 401 | academic gown 402 | accordion 403 | acoustic guitar 404 | aircraft carrier 405 | airliner 406 | airship 407 | altar 408 | ambulance 409 | amphibious vehicle 410 | analog clock 411 | apiary 412 | apron 413 | trash can 414 | assault rifle 415 | backpack 416 | bakery 417 | balance 
beam 418 | balloon 419 | ballpoint pen 420 | Band-Aid 421 | banjo 422 | baluster / handrail 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel 429 | wheelbarrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | swimming cap 435 | bath towel 436 | bathtub 437 | station wagon 438 | lighthouse 439 | beaker 440 | military hat (bearskin or shako) 441 | beer bottle 442 | beer glass 443 | bell tower 444 | baby bib 445 | tandem bicycle 446 | bikini 447 | ring binder 448 | binoculars 449 | birdhouse 450 | boathouse 451 | bobsleigh 452 | bolo tie 453 | poke bonnet 454 | bookcase 455 | bookstore 456 | bottle cap 457 | hunting bow 458 | bow tie 459 | brass memorial plaque 460 | bra 461 | breakwater 462 | breastplate 463 | broom 464 | bucket 465 | buckle 466 | bulletproof vest 467 | high-speed train 468 | butcher shop 469 | taxicab 470 | cauldron 471 | candle 472 | cannon 473 | canoe 474 | can opener 475 | cardigan 476 | car mirror 477 | carousel 478 | tool kit 479 | cardboard box / carton 480 | car wheel 481 | automated teller machine 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello 488 | mobile phone 489 | chain 490 | chain-link fence 491 | chain mail 492 | chainsaw 493 | storage chest 494 | chiffonier 495 | bell or wind chime 496 | china cabinet 497 | Christmas stocking 498 | church 499 | movie theater 500 | cleaver 501 | cliff dwelling 502 | cloak 503 | clogs 504 | cocktail shaker 505 | coffee mug 506 | coffeemaker 507 | spiral or coil 508 | combination lock 509 | computer keyboard 510 | candy store 511 | container ship 512 | convertible 513 | corkscrew 514 | cornet 515 | cowboy boot 516 | cowboy hat 517 | cradle 518 | construction crane 519 | crash helmet 520 | crate 521 | infant bed 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam 527 | desk 528 | desktop computer 529 | rotary dial telephone 530 | diaper 531 | digital clock 532 | digital watch 533 | dining table 534 | dishcloth 535 | dishwasher 536 | disc brake 537 | dock 538 | dog sled 539 | dome 540 | doormat 541 | drilling rig 542 | drum 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso machine 552 | face powder 553 | feather boa 554 | filing cabinet 555 | fireboat 556 | fire truck 557 | fire screen 558 | flagpole 559 | flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster bed 566 | freight car 567 | French horn 568 | frying pan 569 | fur coat 570 | garbage truck 571 | gas mask or respirator 572 | gas pump 573 | goblet 574 | go-kart 575 | golf ball 576 | golf cart 577 | gondola 578 | gong 579 | gown 580 | grand piano 581 | greenhouse 582 | radiator grille 583 | grocery store 584 | guillotine 585 | hair clip 586 | hair spray 587 | half-track 588 | hammer 589 | hamper 590 | hair dryer 591 | hand-held computer 592 | handkerchief 593 | hard disk drive 594 | harmonica 595 | harp 596 | combine harvester 597 | hatchet 598 | holster 599 | home theater 600 | honeycomb 601 | hook 602 | hoop skirt 603 | gymnastic horizontal bar 604 | horse-drawn vehicle 605 | hourglass 606 | iPod 607 | clothes iron 608 | carved pumpkin 609 | jeans 610 | jeep 611 | T-shirt 612 | jigsaw puzzle 613 | rickshaw 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat 619 | ladle 620 | lampshade 621 | laptop computer 622 | lawn mower 623 | lens cap 624 | letter 
opener 625 | library 626 | lifeboat 627 | lighter 628 | limousine 629 | ocean liner 630 | lipstick 631 | slip-on shoe 632 | lotion 633 | music speaker 634 | loupe magnifying glass 635 | sawmill 636 | magnetic compass 637 | messenger bag 638 | mailbox 639 | tights 640 | one-piece bathing suit 641 | manhole cover 642 | maraca 643 | marimba 644 | mask 645 | matchstick 646 | maypole 647 | maze 648 | measuring cup 649 | medicine cabinet 650 | megalith 651 | microphone 652 | microwave oven 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home 662 | ford model t 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar and pestle 668 | graduation cap 669 | mosque 670 | mosquito net 671 | vespa 672 | mountain bike 673 | tent 674 | computer mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | metal nail 679 | neck brace 680 | necklace 681 | baby pacifier 682 | notebook computer 683 | obelisk 684 | oboe 685 | ocarina 686 | odometer 687 | oil filter 688 | pipe organ 689 | oscilloscope 690 | overskirt 691 | bullock cart 692 | oxygen mask 693 | product packet / packaging 694 | paddle 695 | paddle wheel 696 | padlock 697 | paintbrush 698 | pajamas 699 | palace 700 | pan flute 701 | paper towel 702 | parachute 703 | parallel bars 704 | park bench 705 | parking meter 706 | railroad car 707 | patio 708 | payphone 709 | pedestal 710 | pencil case 711 | pencil sharpener 712 | perfume 713 | Petri dish 714 | photocopier 715 | plectrum 716 | Pickelhaube 717 | picket fence 718 | pickup truck 719 | pier 720 | piggy bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate ship 726 | drink pitcher 727 | block plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | farm plow 732 | plunger 733 | Polaroid camera 734 | pole 735 | police van 736 | poncho 737 | pool table 738 | soda bottle 739 | plant pot 740 | potter's wheel 741 | power drill 742 | prayer rug 743 | printer 744 | prison 745 | missile 746 | projector 747 | hockey puck 748 | punching bag 749 | purse 750 | quill 751 | quilt 752 | race car 753 | racket 754 | radiator 755 | radio 756 | radio telescope 757 | rain barrel 758 | recreational vehicle 759 | fishing casting reel 760 | reflex camera 761 | refrigerator 762 | remote control 763 | restaurant 764 | revolver 765 | rifle 766 | rocking chair 767 | rotisserie 768 | eraser 769 | rugby ball 770 | ruler measuring stick 771 | sneaker 772 | safe 773 | safety pin 774 | salt shaker 775 | sandal 776 | sarong 777 | saxophone 778 | scabbard 779 | weighing scale 780 | school bus 781 | schooner 782 | scoreboard 783 | CRT monitor 784 | screw 785 | screwdriver 786 | seat belt 787 | sewing machine 788 | shield 789 | shoe store 790 | shoji screen / room divider 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | balaclava ski mask 798 | sleeping bag 799 | slide rule 800 | sliding door 801 | slot machine 802 | snorkel 803 | snowmobile 804 | snowplow 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar thermal collector 809 | sombrero 810 | soup bowl 811 | keyboard space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | motorboat 816 | spider web 817 | spindle 818 | sports car 819 | spotlight 820 | stage 821 | steam locomotive 822 | through arch bridge 823 | steel drum 824 | stethoscope 825 | scarf 826 | stone wall 827 | stopwatch 828 | stove 829 | strainer 830 | tram 831 | stretcher 832 | couch 833 | stupa 
834 | submarine 835 | suit 836 | sundial 837 | sunglasses 838 | sunglasses 839 | sunscreen 840 | suspension bridge 841 | mop 842 | sweatshirt 843 | swim trunks / shorts 844 | swing 845 | electrical switch 846 | syringe 847 | table lamp 848 | tank 849 | tape player 850 | teapot 851 | teddy bear 852 | television 853 | tennis ball 854 | thatched roof 855 | front curtain 856 | thimble 857 | threshing machine 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck 866 | toy store 867 | tractor 868 | semi-trailer truck 869 | tray 870 | trench coat 871 | tricycle 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus 876 | trombone 877 | hot tub 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle 882 | upright piano 883 | vacuum cleaner 884 | vase 885 | vaulted or arched ceiling 886 | velvet fabric 887 | vending machine 888 | vestment 889 | viaduct 890 | violin 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet 895 | wardrobe 896 | military aircraft 897 | sink 898 | washing machine 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | hair wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | airplane wing 910 | wok 911 | wooden spoon 912 | wool 913 | split-rail fence 914 | shipwreck 915 | sailboat 916 | yurt 917 | website 918 | comic book 919 | crossword 920 | traffic or street sign 921 | traffic light 922 | dust jacket 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot 928 | trifle 929 | ice cream 930 | popsicle 931 | baguette 932 | bagel 933 | pretzel 934 | cheeseburger 935 | hot dog 936 | mashed potatoes 937 | cabbage 938 | broccoli 939 | cauliflower 940 | zucchini 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber 945 | artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith apple 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple 955 | banana 956 | jackfruit 957 | cherimoya (custard apple) 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate syrup 962 | dough 963 | meatloaf 964 | pizza 965 | pot pie 966 | burrito 967 | red wine 968 | espresso 969 | tea cup 970 | eggnog 971 | mountain 972 | bubble 973 | cliff 974 | coral reef 975 | geyser 976 | lakeshore 977 | promontory 978 | sandbar 979 | beach 980 | valley 981 | volcano 982 | baseball player 983 | bridegroom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper 988 | corn 989 | acorn 990 | rose hip 991 | horse chestnut seed 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn mushroom 996 | earth star fungus 997 | hen of the woods mushroom 998 | bolete 999 | corn cob 1000 | toilet paper -------------------------------------------------------------------------------- /datasets/openimages_common_214/imgs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/datasets/openimages_common_214/imgs/.gitkeep -------------------------------------------------------------------------------- /datasets/openimages_common_214/openimages_common_214_ram_taglist.txt: -------------------------------------------------------------------------------- 1 | accident 2 | accordion 3 | plane 4 | airport 5 | antelope 6 | apple 7 | art gallery 8 | eggplant 9 | auditorium 10 | autumn 11 | baboon 12 | backpack 13 | bakery 14 | bamboo 
15 | banana 16 | barbecue 17 | bed 18 | bedroom 19 | clock 20 | bicycle 21 | bikini 22 | birthday cake 23 | blackberry 24 | blueberry 25 | pig 26 | bookcase 27 | bridge 28 | broccoli 29 | bus 30 | butterfly 31 | calculator 32 | calendar 33 | camping 34 | candle 35 | candy 36 | cannon 37 | canyon 38 | car 39 | carousel 40 | cat 41 | cave 42 | ceiling 43 | cheese 44 | cheetah 45 | chef 46 | chicken 47 | christmas 48 | christmas tree 49 | clover 50 | coral 51 | corn 52 | courtyard 53 | crab 54 | lobster 55 | crocodile 56 | crosswalk 57 | crow 58 | cucumber 59 | cup 60 | currency 61 | dachshund 62 | deer 63 | desert 64 | die 65 | dinosaur 66 | dog 67 | dolphin 68 | doodle 69 | dragonfly 70 | drum 71 | duck 72 | dumbbell 73 | easter egg 74 | egg 75 | elephant 76 | faucet 77 | ferris wheel 78 | fire 79 | fireman 80 | firework 81 | flamingo 82 | flower 83 | football 84 | fountain 85 | fox 86 | fridge 87 | frog 88 | ham 89 | gas stove 90 | giraffe 91 | glacier 92 | glove 93 | goat 94 | goose 95 | gorilla 96 | grape 97 | guitar 98 | gull 99 | gym 100 | halloween 101 | hamburger 102 | hamster 103 | handbag 104 | hedgehog 105 | helicopter 106 | horse 107 | hummingbird 108 | jellyfish 109 | kangaroo 110 | kimono 111 | kite 112 | ladybird 113 | laptop 114 | leg 115 | mailbox 116 | library 117 | lightning 118 | lily 119 | lion 120 | lizard 121 | luggage 122 | mannequin 123 | map 124 | mask 125 | mattress 126 | microphone 127 | microwave 128 | monkey 129 | moon 130 | mosque 131 | mouse 132 | mushroom 133 | nebula 134 | sea 135 | ostrich 136 | palm tree 137 | paper 138 | pasta 139 | patient 140 | pavilion 141 | pear 142 | pebble 143 | penguin 144 | pet 145 | piano 146 | picture frame 147 | pine 148 | pineapple 149 | pizza 150 | police car 151 | pomegranate 152 | poodle 153 | popcorn 154 | stamp 155 | power station 156 | printer 157 | pumpkin 158 | raccoon 159 | rainbow 160 | rat 161 | restroom 162 | ring 163 | run 164 | salad 165 | sandwich 166 | sausage 167 | shark 168 | sheet music 169 | shrine 170 | snowboard 171 | snake 172 | sparrow 173 | squirrel 174 | stage 175 | starfish 176 | statue 177 | steering wheel 178 | stream 179 | street art 180 | street light 181 | submarine 182 | suite 183 | surfboard 184 | sushi 185 | swan 186 | tattoo 187 | teddy 188 | tennis court 189 | tennis racket 190 | tiger 191 | toast 192 | toilet bowl 193 | toy 194 | tractor 195 | train 196 | trampoline 197 | treadmill 198 | truck 199 | tunnel 200 | turkey 201 | vending machine 202 | waffle 203 | walnut 204 | washing machine 205 | water buffalo 206 | waterfall 207 | watermelon 208 | wheat 209 | wheelchair 210 | windmill 211 | winter 212 | wolf 213 | woodpecker 214 | zebra 215 | -------------------------------------------------------------------------------- /datasets/openimages_common_214/openimages_common_214_tag2text_tagidlist.txt: -------------------------------------------------------------------------------- 1 | 3 2 | 8 3 | 16 4 | 19 5 | 21 6 | 33 7 | 44 8 | 50 9 | 58 10 | 61 11 | 71 12 | 77 13 | 84 14 | 96 15 | 117 16 | 139 17 | 142 18 | 147 19 | 180 20 | 200 21 | 202 22 | 206 23 | 244 24 | 267 25 | 317 26 | 321 27 | 347 28 | 361 29 | 380 30 | 387 31 | 398 32 | 407 33 | 471 34 | 486 35 | 489 36 | 509 37 | 514 38 | 530 39 | 568 40 | 590 41 | 595 42 | 612 43 | 622 44 | 626 45 | 654 46 | 658 47 | 664 48 | 684 49 | 699 50 | 704 51 | 717 52 | 720 53 | 727 54 | 760 55 | 773 56 | 786 57 | 787 58 | 812 59 | 814 60 | 817 61 | 843 62 | 855 63 | 856 64 | 907 65 | 950 66 | 955 67 | 957 68 | 1023 69 | 1042 70 | 1056 71 | 1066 72 | 
1091 73 | 1094 74 | 1108 75 | 1141 76 | 1148 77 | 1152 78 | 1168 79 | 1174 80 | 1187 81 | 1231 82 | 1235 83 | 1246 84 | 1276 85 | 1277 86 | 1305 87 | 1308 88 | 1344 89 | 1359 90 | 1362 91 | 1393 92 | 1394 93 | 1410 94 | 1411 95 | 1468 96 | 1504 97 | 1524 98 | 1536 99 | 1540 100 | 1542 101 | 1546 102 | 1553 103 | 1572 104 | 1574 105 | 1606 106 | 1610 107 | 1615 108 | 1655 109 | 1672 110 | 1680 111 | 1682 112 | 1687 113 | 1691 114 | 1692 115 | 1711 116 | 1712 117 | 1713 118 | 1719 119 | 1727 120 | 1733 121 | 1761 122 | 1770 123 | 1782 124 | 1784 125 | 1786 126 | 1803 127 | 1812 128 | 1816 129 | 1820 130 | 1829 131 | 1831 132 | 1841 133 | 1845 134 | 1878 135 | 1882 136 | 1931 137 | 1940 138 | 1944 139 | 1947 140 | 1974 141 | 1975 142 | 1977 143 | 2009 144 | 2031 145 | 2035 146 | 2052 147 | 2065 148 | 2110 149 | 2113 150 | 2138 151 | 2149 152 | 2154 153 | 2157 154 | 2174 155 | 2178 156 | 2184 157 | 2185 158 | 2202 159 | 2222 160 | 2233 161 | 2291 162 | 2301 163 | 2302 164 | 2317 165 | 2320 166 | 2351 167 | 2354 168 | 2373 169 | 2383 170 | 2393 171 | 2403 172 | 2413 173 | 2415 174 | 2417 175 | 2423 176 | 2449 177 | 2454 178 | 2455 179 | 2472 180 | 2494 181 | 2495 182 | 2528 183 | 2541 184 | 2543 185 | 2553 186 | 2563 187 | 2589 188 | 2603 189 | 2654 190 | 2656 191 | 2658 192 | 2676 193 | 2690 194 | 2693 195 | 2700 196 | 2708 197 | 2720 198 | 2721 199 | 2729 200 | 2732 201 | 2734 202 | 2756 203 | 2786 204 | 2792 205 | 2801 206 | 2821 207 | 2851 208 | 2887 209 | 2906 210 | 2909 211 | 2924 212 | 2929 213 | 2966 214 | 2980 215 | -------------------------------------------------------------------------------- /datasets/openimages_rare_200/imgs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/datasets/openimages_rare_200/imgs/.gitkeep -------------------------------------------------------------------------------- /datasets/openimages_rare_200/openimages_rare_200_ram_taglist.txt: -------------------------------------------------------------------------------- 1 | Aerial photography 2 | Aircraft engine 3 | Ale 4 | Aloe 5 | Amphibian 6 | Angling 7 | Anole 8 | Antique car 9 | Arcade game 10 | Arthropod 11 | Assault rifle 12 | Athletic shoe 13 | Auto racing 14 | Backlighting 15 | Bagpipes 16 | Ball game 17 | Barbecue chicken 18 | Barechested 19 | Barquentine 20 | Beef tenderloin 21 | Billiard room 22 | Billiards 23 | Bird of prey 24 | Black swan 25 | Black-and-white 26 | Blond 27 | Boating 28 | Bonbon 29 | Bottled water 30 | Bouldering 31 | Bovine 32 | Bratwurst 33 | Breadboard 34 | Briefs 35 | Brisket 36 | Brochette 37 | Calabaza 38 | Camera operator 39 | Canola 40 | Childbirth 41 | Chordophone 42 | Church bell 43 | Classical sculpture 44 | Close-up 45 | Cobblestone 46 | Coca-cola 47 | Combat sport 48 | Comics 49 | Compact car 50 | Computer speaker 51 | Cookies and crackers 52 | Coral reef fish 53 | Corn on the cob 54 | Cosmetics 55 | Crocodilia 56 | Digital camera 57 | Dishware 58 | Divemaster 59 | Dobermann 60 | Dog walking 61 | Domestic rabbit 62 | Domestic short-haired cat 63 | Double-decker bus 64 | Drums 65 | Electric guitar 66 | Electric piano 67 | Electronic instrument 68 | Equestrianism 69 | Equitation 70 | Erinaceidae 71 | Extreme sport 72 | Falafel 73 | Figure skating 74 | Filling station 75 | Fire apparatus 76 | Firearm 77 | Flatbread 78 | Floristry 79 | Forklift truck 80 | Freight transport 81 | Fried food 82 | Fried noodles 83 | Frigate 84 | 
Frozen yogurt 85 | Frying 86 | Full moon 87 | Galleon 88 | Glacial landform 89 | Gliding 90 | Go-kart 91 | Goats 92 | Grappling 93 | Great white shark 94 | Gumbo 95 | Gun turret 96 | Hair coloring 97 | Halter 98 | Headphones 99 | Heavy cruiser 100 | Herding 101 | High-speed rail 102 | Holding hands 103 | Horse and buggy 104 | Horse racing 105 | Hound 106 | Hunting knife 107 | Hurdling 108 | Inflatable 109 | Jackfruit 110 | Jeans 111 | Jiaozi 112 | Junk food 113 | Khinkali 114 | Kitesurfing 115 | Lawn game 116 | Leaf vegetable 117 | Lechon 118 | Lifebuoy 119 | Locust 120 | Lumpia 121 | Luxury vehicle 122 | Machine tool 123 | Medical imaging 124 | Melee weapon 125 | Microcontroller 126 | Middle ages 127 | Military person 128 | Military vehicle 129 | Milky way 130 | Miniature Poodle 131 | Modern dance 132 | Molluscs 133 | Monoplane 134 | Motorcycling 135 | Musical theatre 136 | Narcissus 137 | Nest box 138 | Newsagent's shop 139 | Nile crocodile 140 | Nordic skiing 141 | Nuclear power plant 142 | Orator 143 | Outdoor shoe 144 | Parachuting 145 | Pasta salad 146 | Peafowl 147 | Pelmeni 148 | Perching bird 149 | Performance car 150 | Personal water craft 151 | Pit bull 152 | Plant stem 153 | Pork chop 154 | Portrait photography 155 | Primate 156 | Procyonidae 157 | Prosciutto 158 | Public speaking 159 | Racewalking 160 | Ramen 161 | Rear-view mirror 162 | Residential area 163 | Ribs 164 | Rice ball 165 | Road cycling 166 | Roller skating 167 | Roman temple 168 | Rowing 169 | Rural area 170 | Sailboat racing 171 | Scaled reptile 172 | Scuba diving 173 | Senior citizen 174 | Shallot 175 | Shinto shrine 176 | Shooting range 177 | Siberian husky 178 | Sledding 179 | Soba 180 | Solar energy 181 | Sport climbing 182 | Sport utility vehicle 183 | Steamed rice 184 | Stemware 185 | Sumo 186 | Surfing Equipment 187 | Team sport 188 | Touring car 189 | Toy block 190 | Trampolining 191 | Underwater diving 192 | Vegetarian food 193 | Wallaby 194 | Water polo 195 | Watercolor paint 196 | Whiskers 197 | Wind wave 198 | Woodwind instrument 199 | Yakitori 200 | Zeppelin 201 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * RAM++ & RAM & Tag2Text finetune 3 | * Written by Xinyu Huang 4 | ''' 5 | import argparse 6 | import os 7 | import ruamel.yaml as yaml 8 | import numpy as np 9 | import random 10 | import time 11 | import datetime 12 | import json 13 | from pathlib import Path 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.backends.cudnn as cudnn 19 | import torch.distributed as dist 20 | from torch.utils.data import DataLoader 21 | 22 | from ram.models import ram_plus, ram, tag2text 23 | import utils 24 | from utils import cosine_lr_schedule 25 | from ram.data import create_dataset, create_sampler, create_loader 26 | 27 | import clip 28 | 29 | def build_text_embed(model_clip, caption): 30 | run_on_gpu = torch.cuda.is_available() 31 | with torch.no_grad(): 32 | 33 | texts = clip.tokenize(caption,truncate = True) # tokenize 34 | if run_on_gpu: 35 | texts = texts.cuda() 36 | model_clip = model_clip.cuda() 37 | text_embeddings = model_clip.encode_text(texts) 38 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 39 | return text_embeddings 40 | 41 | 42 | 43 | def train_ram_plus(model, data_loader, optimizer, epoch, device, config, model_clip): 44 | # train 45 | model.train() 46 | 47 | metric_logger = 
utils.MetricLogger(delimiter=" ") 48 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 49 | metric_logger.add_meter('loss_tag', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 50 | metric_logger.add_meter('loss_dis', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 51 | metric_logger.add_meter('loss_alignment', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 52 | 53 | header = 'Train Epoch: [{}]'.format(epoch) 54 | print_freq = 50 55 | 56 | data_loader.sampler.set_epoch(epoch) 57 | 58 | for i, (image, image_224, caption, image_tag, parse_tag) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 59 | 60 | optimizer.zero_grad() 61 | 62 | batch_text_embed = build_text_embed(model_clip,caption) 63 | 64 | image = image.to(device,non_blocking=True) 65 | image_224 = image_224.to(device,non_blocking=True) 66 | 67 | with torch.no_grad(): 68 | clip_image_feature = model_clip.encode_image(image_224) 69 | 70 | loss_tag, loss_dis, loss_alignment = model(image, caption, image_tag, clip_image_feature, batch_text_embed) 71 | loss = loss_tag + loss_dis + loss_alignment 72 | 73 | loss.backward() 74 | optimizer.step() 75 | 76 | metric_logger.update(loss_tag=loss_tag.item()) 77 | metric_logger.update(loss_dis=loss_dis.item()) 78 | metric_logger.update(loss_alignment=loss_alignment.item()) 79 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 80 | 81 | 82 | # gather the stats from all processes 83 | metric_logger.synchronize_between_processes() 84 | print("Averaged stats:", metric_logger.global_avg()) 85 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 86 | 87 | 88 | 89 | def train_ram(model, data_loader, optimizer, epoch, device, config, model_clip): 90 | # train 91 | model.train() 92 | 93 | metric_logger = utils.MetricLogger(delimiter=" ") 94 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 95 | metric_logger.add_meter('loss_t2t', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 96 | metric_logger.add_meter('loss_tag', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 97 | metric_logger.add_meter('loss_dis', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 98 | 99 | header = 'Train Epoch: [{}]'.format(epoch) 100 | print_freq = 50 101 | 102 | data_loader.sampler.set_epoch(epoch) 103 | 104 | for i, (image, image_224, caption, image_tag, parse_tag) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 105 | 106 | optimizer.zero_grad() 107 | 108 | image = image.to(device,non_blocking=True) 109 | image_224 = image_224.to(device,non_blocking=True) 110 | 111 | with torch.no_grad(): 112 | clip_image_feature = model_clip.encode_image(image_224) 113 | 114 | loss_t2t, loss_tag, loss_dis = model(image, caption, image_tag, parse_tag, clip_image_feature) 115 | loss = loss_t2t + loss_tag/(loss_tag/loss_t2t).detach() + loss_dis 116 | 117 | loss.backward() 118 | optimizer.step() 119 | 120 | metric_logger.update(loss_t2t=loss_t2t.item()) 121 | metric_logger.update(loss_tag=loss_tag.item()) 122 | metric_logger.update(loss_dis=loss_dis.item()) 123 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 124 | 125 | 126 | # gather the stats from all processes 127 | metric_logger.synchronize_between_processes() 128 | print("Averaged stats:", metric_logger.global_avg()) 129 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 130 | 131 | 132 | def train_tag2text(model, data_loader, 
optimizer, epoch, device, config): 133 | # train 134 | model.train() 135 | 136 | metric_logger = utils.MetricLogger(delimiter=" ") 137 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 138 | metric_logger.add_meter('loss_t2t', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 139 | metric_logger.add_meter('loss_tag', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 140 | 141 | header = 'Train Epoch: [{}]'.format(epoch) 142 | print_freq = 50 143 | 144 | data_loader.sampler.set_epoch(epoch) 145 | 146 | for i, (image, _, caption, _, parse_tag) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 147 | 148 | 149 | optimizer.zero_grad() 150 | 151 | image = image.to(device,non_blocking=True) 152 | 153 | loss_t2t, loss_tag = model(image, caption, parse_tag) 154 | loss = loss_t2t + loss_tag/(loss_tag/loss_t2t).detach() 155 | 156 | loss.backward() 157 | optimizer.step() 158 | 159 | metric_logger.update(loss_t2t=loss_t2t.item()) 160 | metric_logger.update(loss_tag=loss_tag.item()) 161 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 162 | 163 | 164 | # gather the stats from all processes 165 | metric_logger.synchronize_between_processes() 166 | print("Averaged stats:", metric_logger.global_avg()) 167 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 168 | 169 | 170 | def main(args, config): 171 | utils.init_distributed_mode(args) 172 | 173 | device = torch.device(args.device) 174 | 175 | # fix the seed for reproducibility 176 | seed = args.seed + utils.get_rank() 177 | torch.manual_seed(seed) 178 | np.random.seed(seed) 179 | random.seed(seed) 180 | cudnn.benchmark = True 181 | 182 | #### Dataset #### 183 | print("Creating dataset") 184 | datasets = [create_dataset('finetune', config, min_scale=0.2)] 185 | print('number of training samples: %d'%len(datasets[0])) 186 | 187 | num_tasks = utils.get_world_size() 188 | global_rank = utils.get_rank() 189 | samplers = create_sampler(datasets, [True], num_tasks, global_rank) 190 | 191 | data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0] 192 | 193 | print("Creating model") 194 | if args.checkpoint: 195 | print("load from:", args.checkpoint) 196 | 197 | #### Model #### 198 | if args.model_type == 'ram_plus': 199 | print("Creating pretrained CLIP model") 200 | model_clip, _ = clip.load("ViT-B/16", device=device) 201 | 202 | print("Creating RAM model") 203 | model = ram_plus(pretrained = args.checkpoint,image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 204 | vit_ckpt_layer=config['vit_ckpt_layer']) 205 | 206 | elif args.model_type == 'ram': 207 | print("Creating pretrained CLIP model") 208 | model_clip, _ = clip.load("ViT-B/16", device=device) 209 | 210 | print("Creating RAM model") 211 | model = ram(pretrained = args.checkpoint,image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 212 | vit_ckpt_layer=config['vit_ckpt_layer']) 213 | 214 | elif args.model_type == 'tag2text': 215 | print("Creating Tag2Text model") 216 | model = tag2text(pretrained = args.checkpoint,image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 217 | vit_ckpt_layer=config['vit_ckpt_layer'], tag_list='ram/data/ram_tag_list.txt') 218 | model = model.to(device) 219 | 220 | ### Frozen CLIP model ### 221 | model_clip = model_clip.to(device) 222 | for _, param in 
model_clip.named_parameters(): 223 | param.requires_grad = False 224 | 225 | ### Frozen label embedding for open-set recogniztion ### 226 | model.label_embed.requires_grad = False 227 | optimizer = torch.optim.AdamW(filter(lambda x: x.requires_grad, model.parameters()), lr=config['init_lr'], weight_decay=config['weight_decay']) 228 | 229 | start_epoch = 0 230 | 231 | model_without_ddp = model 232 | if args.distributed: 233 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 234 | model_without_ddp = model.module 235 | 236 | print("Start training") 237 | start_time = time.time() 238 | for epoch in range(start_epoch, config['max_epoch']): 239 | 240 | cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr']) 241 | 242 | if args.model_type == 'ram_plus': 243 | train_stats = train_ram_plus(model, data_loader, optimizer, epoch, device, config, model_clip) 244 | elif args.model_type == 'ram': 245 | train_stats = train_ram(model, data_loader, optimizer, epoch, device, config, model_clip) 246 | elif args.model_type == 'tag2text': 247 | train_stats = train_tag2text(model, data_loader, optimizer, epoch, device, config) 248 | 249 | if utils.is_main_process(): 250 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 251 | 'epoch': epoch, 252 | } 253 | save_obj = { 254 | 'model': model_without_ddp.state_dict(), 255 | 'optimizer': optimizer.state_dict(), 256 | 'config': config, 257 | 'epoch': epoch, 258 | } 259 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 260 | 261 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 262 | f.write(json.dumps(log_stats) + "\n") 263 | 264 | dist.barrier() 265 | 266 | total_time = time.time() - start_time 267 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 268 | print('Training time {}'.format(total_time_str)) 269 | 270 | 271 | if __name__ == '__main__': 272 | parser = argparse.ArgumentParser() 273 | parser.add_argument('--config', default='./configs/pretrain.yaml') 274 | parser.add_argument("--model-type",type=str,choices=("ram_plus", "ram", "tag2text"),required=True) 275 | parser.add_argument('--output-dir', default='output/Pretrain') 276 | parser.add_argument('--checkpoint', default='') 277 | parser.add_argument('--evaluate', action='store_true') 278 | parser.add_argument('--device', default='cuda') 279 | parser.add_argument('--seed', default=42, type=int) 280 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 281 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 282 | parser.add_argument('--distributed', default=True, type=bool) 283 | args = parser.parse_args() 284 | 285 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 286 | 287 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 288 | 289 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 290 | 291 | main(args, config) -------------------------------------------------------------------------------- /generate_tag_des_llm.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import json 3 | from tqdm import tqdm 4 | import argparse 5 | from ram.utils.openset_utils import openimages_rare_unseen 6 | 7 | parser = argparse.ArgumentParser( 8 | description='Generate LLM tag descriptions for RAM++ open-set recognition') 9 | parser.add_argument('--openai_api_key', 10 | 
default='sk-xxxxx') 11 | parser.add_argument('--output_file_path', 12 | help='save path of llm tag descriptions', 13 | default='datasets/openimages_rare_200/openimages_rare_200_llm_tag_descriptions.json') 14 | 15 | 16 | def analyze_tags(tag): 17 | # Generate LLM tag descriptions 18 | 19 | llm_prompts = [ f"Describe concisely what a(n) {tag} looks like:", \ 20 | f"How can you identify a(n) {tag} concisely?", \ 21 | f"What does a(n) {tag} look like concisely?",\ 22 | f"What are the identifying characteristics of a(n) {tag}:", \ 23 | f"Please provide a concise description of the visual characteristics of {tag}:"] 24 | 25 | results = {} 26 | result_lines = [] 27 | 28 | result_lines.append(f"a photo of a {tag}.") 29 | 30 | for llm_prompt in tqdm(llm_prompts): 31 | 32 | # send message 33 | response = openai.ChatCompletion.create( 34 | model="gpt-3.5-turbo", 35 | messages=[{"role": "assistant", "content": llm_prompt}], 36 | max_tokens=77, 37 | temperature=0.99, 38 | n=10, 39 | stop=None 40 | ) 41 | 42 | # parse the response 43 | for item in response.choices: 44 | result_lines.append(item.message['content'].strip()) 45 | results[tag] = result_lines 46 | return results 47 | 48 | 49 | if __name__ == "__main__": 50 | 51 | args = parser.parse_args() 52 | 53 | # set OpenAI API key 54 | openai.api_key = args.openai_api_key 55 | 56 | categories = openimages_rare_unseen 57 | 58 | tag_descriptions = [] 59 | 60 | for tag in categories: 61 | result = analyze_tags(tag) 62 | tag_descriptions.append(result) 63 | 64 | output_file_path = args.output_file_path 65 | 66 | with open(output_file_path, 'w') as w: 67 | json.dump(tag_descriptions, w, indent=3) 68 | 69 | -------------------------------------------------------------------------------- /images/1641173_2291260800.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/1641173_2291260800.jpg -------------------------------------------------------------------------------- /images/demo/demo1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/demo/demo1.jpg -------------------------------------------------------------------------------- /images/demo/demo2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/demo/demo2.jpg -------------------------------------------------------------------------------- /images/demo/demo3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/demo/demo3.jpg -------------------------------------------------------------------------------- /images/demo/demo4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/demo/demo4.jpg -------------------------------------------------------------------------------- /images/experiment_comparison.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/experiment_comparison.png -------------------------------------------------------------------------------- /images/localization_and_recognition.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/localization_and_recognition.jpg -------------------------------------------------------------------------------- /images/openset_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/openset_example.jpg -------------------------------------------------------------------------------- /images/ram_grounded_sam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/ram_grounded_sam.jpg -------------------------------------------------------------------------------- /images/ram_plus_compare.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/ram_plus_compare.jpg -------------------------------------------------------------------------------- /images/ram_plus_experiment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/ram_plus_experiment.png -------------------------------------------------------------------------------- /images/ram_plus_framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/ram_plus_framework.jpg -------------------------------------------------------------------------------- /images/ram_plus_visualization.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/ram_plus_visualization.jpg -------------------------------------------------------------------------------- /images/tag2text_framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/tag2text_framework.png -------------------------------------------------------------------------------- /images/tag2text_grounded_sam.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/tag2text_grounded_sam.jpg -------------------------------------------------------------------------------- /images/tag2text_retrieval_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/tag2text_retrieval_visualization.png -------------------------------------------------------------------------------- 
/images/tag2text_visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/tag2text_visualization.png -------------------------------------------------------------------------------- /images/tagging_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xinyu1205/recognize-anything/7cb804a8609e9f4b1a50b7f31436d2df40bb9481/images/tagging_results.jpg -------------------------------------------------------------------------------- /inference_ram.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Recognize Anything Model (RAM) 3 | * Written by Xinyu Huang 4 | ''' 5 | import argparse 6 | import numpy as np 7 | import random 8 | 9 | import torch 10 | 11 | from PIL import Image 12 | from ram.models import ram 13 | from ram import inference_ram as inference 14 | from ram import get_transform 15 | 16 | 17 | parser = argparse.ArgumentParser( 18 | description='RAM inference for image tagging') 19 | parser.add_argument('--image', 20 | metavar='DIR', 21 | help='path to input image', 22 | default='images/demo/demo1.jpg') 23 | parser.add_argument('--pretrained', 24 | metavar='DIR', 25 | help='path to pretrained model', 26 | default='pretrained/ram_swin_large_14m.pth') 27 | parser.add_argument('--image-size', 28 | default=384, 29 | type=int, 30 | metavar='N', 31 | help='input image size (default: 384)') 32 | 33 | 34 | if __name__ == "__main__": 35 | 36 | args = parser.parse_args() 37 | 38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | 40 | transform = get_transform(image_size=args.image_size) 41 | 42 | #######load model 43 | model = ram(pretrained=args.pretrained, 44 | image_size=args.image_size, 45 | vit='swin_l') 46 | model.eval() 47 | 48 | model = model.to(device) 49 | 50 | image = transform(Image.open(args.image)).unsqueeze(0).to(device) 51 | 52 | res = inference(image, model) 53 | print("Image Tags: ", res[0]) 54 | print("图像标签: ", res[1]) 55 | -------------------------------------------------------------------------------- /inference_ram_openset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Recognize Anything Model (RAM) inference on unseen classes 3 | * Written by Xinyu Huang 4 | ''' 5 | import argparse 6 | import numpy as np 7 | import random 8 | 9 | import torch 10 | 11 | from PIL import Image 12 | from ram.models import ram 13 | from ram import inference_ram_openset as inference 14 | from ram import get_transform 15 | 16 | from ram.utils import build_openset_label_embedding 17 | from torch import nn 18 | 19 | parser = argparse.ArgumentParser( 20 | description='RAM open-set inference for image tagging') 21 | parser.add_argument('--image', 22 | metavar='DIR', 23 | help='path to input image', 24 | default='images/openset_example.jpg') 25 | parser.add_argument('--pretrained', 26 | metavar='DIR', 27 | help='path to pretrained model', 28 | default='pretrained/ram_swin_large_14m.pth') 29 | parser.add_argument('--image-size', 30 | default=384, 31 | type=int, 32 | metavar='N', 33 | help='input image size (default: 384)') 34 | 35 | 36 | if __name__ == "__main__": 37 | 38 | args = parser.parse_args() 39 | 40 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 41 | 42 | transform = 
get_transform(image_size=args.image_size) 43 | 44 | #######load model 45 | model = ram(pretrained=args.pretrained, 46 | image_size=args.image_size, 47 | vit='swin_l') 48 | 49 | #######set openset inference 50 | openset_label_embedding, openset_categories = build_openset_label_embedding() 51 | 52 | model.tag_list = np.array(openset_categories) 53 | 54 | model.label_embed = nn.Parameter(openset_label_embedding.float()) 55 | 56 | model.num_class = len(openset_categories) 57 | # the threshold for unseen categories is often lower 58 | model.class_threshold = torch.ones(model.num_class) * 0.5 59 | ####### 60 | 61 | model.eval() 62 | 63 | model = model.to(device) 64 | 65 | image = transform(Image.open(args.image)).unsqueeze(0).to(device) 66 | 67 | res = inference(image, model) 68 | print("Image Tags: ", res) 69 | -------------------------------------------------------------------------------- /inference_ram_plus.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Recognize Anything Plus Model (RAM++) 3 | * Written by Xinyu Huang 4 | ''' 5 | import argparse 6 | import numpy as np 7 | import random 8 | 9 | import torch 10 | 11 | from PIL import Image 12 | from ram.models import ram_plus 13 | from ram import inference_ram as inference 14 | from ram import get_transform 15 | 16 | 17 | parser = argparse.ArgumentParser( 18 | description='RAM++ inference for image tagging') 19 | parser.add_argument('--image', 20 | metavar='DIR', 21 | help='path to input image', 22 | default='images/demo/demo1.jpg') 23 | parser.add_argument('--pretrained', 24 | metavar='DIR', 25 | help='path to pretrained model', 26 | default='pretrained/ram_plus_swin_large_14m.pth') 27 | parser.add_argument('--image-size', 28 | default=384, 29 | type=int, 30 | metavar='N', 31 | help='input image size (default: 384)') 32 | 33 | 34 | if __name__ == "__main__": 35 | 36 | args = parser.parse_args() 37 | 38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 39 | 40 | transform = get_transform(image_size=args.image_size) 41 | 42 | #######load model 43 | model = ram_plus(pretrained=args.pretrained, 44 | image_size=args.image_size, 45 | vit='swin_l') 46 | model.eval() 47 | 48 | model = model.to(device) 49 | 50 | image = transform(Image.open(args.image)).unsqueeze(0).to(device) 51 | 52 | res = inference(image, model) 53 | print("Image Tags: ", res[0]) 54 | print("图像标签: ", res[1]) 55 | -------------------------------------------------------------------------------- /inference_ram_plus_openset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Recognize Anything Plus Model (RAM++) inference on unseen classes 3 | * Written by Xinyu Huang 4 | ''' 5 | import argparse 6 | import numpy as np 7 | import random 8 | 9 | import torch 10 | 11 | from PIL import Image 12 | from ram.models import ram_plus 13 | from ram import inference_ram_openset as inference 14 | from ram import get_transform 15 | 16 | from ram.utils import build_openset_llm_label_embedding 17 | from torch import nn 18 | import json 19 | 20 | parser = argparse.ArgumentParser( 21 | description='RAM++ open-set inference for image tagging') 22 | parser.add_argument('--image', 23 | metavar='DIR', 24 | help='path to input image', 25 | default='images/openset_example.jpg') 26 | parser.add_argument('--pretrained', 27 | metavar='DIR', 28 | help='path to pretrained model', 29 | default='pretrained/ram_plus_swin_large_14m.pth') 30 | parser.add_argument('--image-size', 31 
| default=384, 32 | type=int, 33 | metavar='N', 34 | help='input image size (default: 384)') 35 | parser.add_argument('--llm_tag_des', 36 | metavar='DIR', 37 | help='path to LLM tag descriptions', 38 | default='datasets/openimages_rare_200/openimages_rare_200_llm_tag_descriptions.json') 39 | 40 | if __name__ == "__main__": 41 | 42 | args = parser.parse_args() 43 | 44 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 45 | 46 | transform = get_transform(image_size=args.image_size) 47 | 48 | #######load model 49 | model = ram_plus(pretrained=args.pretrained, 50 | image_size=args.image_size, 51 | vit='swin_l') 52 | 53 | #######set openset inference 54 | 55 | print('Building tag embedding:') 56 | with open(args.llm_tag_des, 'rb') as fo: 57 | llm_tag_des = json.load(fo) 58 | openset_label_embedding, openset_categories = build_openset_llm_label_embedding(llm_tag_des) 59 | 60 | model.tag_list = np.array(openset_categories) 61 | 62 | model.label_embed = nn.Parameter(openset_label_embedding.float()) 63 | 64 | model.num_class = len(openset_categories) 65 | # the threshold for unseen categories is often lower 66 | model.class_threshold = torch.ones(model.num_class) * 0.5 67 | ####### 68 | 69 | model.eval() 70 | 71 | model = model.to(device) 72 | 73 | image = transform(Image.open(args.image)).unsqueeze(0).to(device) 74 | 75 | res = inference(image, model) 76 | print("Image Tags: ", res) 77 | -------------------------------------------------------------------------------- /inference_tag2text.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Tag2Text Model 3 | * Written by Xinyu Huang 4 | ''' 5 | import argparse 6 | import numpy as np 7 | import random 8 | 9 | import torch 10 | 11 | from PIL import Image 12 | from ram.models import tag2text 13 | from ram import inference_tag2text as inference 14 | from ram import get_transform 15 | 16 | 17 | parser = argparse.ArgumentParser( 18 | description='Tag2Text inference for tagging and captioning') 19 | parser.add_argument('--image', 20 | metavar='DIR', 21 | help='path to input image', 22 | default='images/1641173_2291260800.jpg') 23 | parser.add_argument('--pretrained', 24 | metavar='DIR', 25 | help='path to pretrained model', 26 | default='pretrained/tag2text_swin_14m.pth') 27 | parser.add_argument('--image-size', 28 | default=384, 29 | type=int, 30 | metavar='N', 31 | help='input image size (default: 384)') 32 | parser.add_argument('--thre', 33 | default=0.68, 34 | type=float, 35 | metavar='N', 36 | help='threshold value') 37 | parser.add_argument('--specified-tags', 38 | default='None', 39 | help='User input specified tags') 40 | 41 | 42 | if __name__ == "__main__": 43 | 44 | args = parser.parse_args() 45 | 46 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 47 | 48 | transform = get_transform(image_size=args.image_size) 49 | 50 | # delete some tags that may disturb captioning 51 | # 127: "quarter"; 2961: "back", 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one" 52 | delete_tag_index = [127,2961, 3351, 3265, 3338, 3355, 3359] 53 | 54 | #######load model 55 | model = tag2text(pretrained=args.pretrained, 56 | image_size=args.image_size, 57 | vit='swin_b', 58 | delete_tag_index=delete_tag_index) 59 | model.threshold = args.thre # threshold for tagging 60 | model.eval() 61 | 62 | model = model.to(device) 63 | 64 | image = transform(Image.open(args.image)).unsqueeze(0).to(device) 65 | 66 | res = inference(image, model, args.specified_tags) 67 | 
print("Model Identified Tags: ", res[0]) 68 | print("User Specified Tags: ", res[1]) 69 | print("Image Caption: ", res[2]) 70 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * RAM++ & RAM & Tag2Text pretrain 3 | * Written by Xinyu Huang 4 | ''' 5 | import argparse 6 | import os 7 | import ruamel.yaml as yaml 8 | import numpy as np 9 | import random 10 | import time 11 | import datetime 12 | import json 13 | from pathlib import Path 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.backends.cudnn as cudnn 19 | import torch.distributed as dist 20 | from torch.utils.data import DataLoader 21 | 22 | from ram.models import ram_plus, ram, tag2text 23 | import utils 24 | from utils import warmup_lr_schedule, step_lr_schedule 25 | from ram.data import create_dataset, create_sampler, create_loader 26 | 27 | import clip 28 | 29 | def build_text_embed(model_clip, caption): 30 | run_on_gpu = torch.cuda.is_available() 31 | with torch.no_grad(): 32 | 33 | texts = clip.tokenize(caption,truncate = True) # tokenize 34 | if run_on_gpu: 35 | texts = texts.cuda() 36 | model_clip = model_clip.cuda() 37 | text_embeddings = model_clip.encode_text(texts) 38 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 39 | # text_embedding = text_embeddings.mean(dim=0) 40 | # text_embedding /= text_embedding.norm() 41 | return text_embeddings 42 | 43 | 44 | 45 | def train_ram_plus(model, data_loader, optimizer, epoch, device, config, model_clip): 46 | # train 47 | model.train() 48 | 49 | metric_logger = utils.MetricLogger(delimiter=" ") 50 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 51 | metric_logger.add_meter('loss_tag', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 52 | metric_logger.add_meter('loss_dis', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 53 | metric_logger.add_meter('loss_alignment', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 54 | 55 | header = 'Train Epoch: [{}]'.format(epoch) 56 | print_freq = 50 57 | 58 | data_loader.sampler.set_epoch(epoch) 59 | 60 | for i, (image, caption, image_tag, parse_tag) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 61 | 62 | if epoch==0: 63 | warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr']) 64 | 65 | optimizer.zero_grad() 66 | 67 | batch_text_embed = build_text_embed(model_clip,caption) 68 | 69 | image = image.to(device,non_blocking=True) 70 | 71 | with torch.no_grad(): 72 | clip_image_feature = model_clip.encode_image(image) 73 | 74 | loss_tag, loss_dis, loss_alignment = model(image, caption, image_tag, clip_image_feature, batch_text_embed) 75 | loss = loss_tag + loss_dis + loss_alignment 76 | 77 | loss.backward() 78 | optimizer.step() 79 | 80 | metric_logger.update(loss_tag=loss_tag.item()) 81 | metric_logger.update(loss_dis=loss_dis.item()) 82 | metric_logger.update(loss_alignment=loss_alignment.item()) 83 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 84 | 85 | 86 | # gather the stats from all processes 87 | metric_logger.synchronize_between_processes() 88 | print("Averaged stats:", metric_logger.global_avg()) 89 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 90 | 91 | 92 | 93 | def train_ram(model, data_loader, optimizer, epoch, device, config, model_clip): 94 | 
# train 95 | model.train() 96 | 97 | metric_logger = utils.MetricLogger(delimiter=" ") 98 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 99 | metric_logger.add_meter('loss_t2t', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 100 | metric_logger.add_meter('loss_tag', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 101 | metric_logger.add_meter('loss_dis', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 102 | 103 | header = 'Train Epoch: [{}]'.format(epoch) 104 | print_freq = 50 105 | 106 | data_loader.sampler.set_epoch(epoch) 107 | 108 | for i, (image, caption, image_tag, parse_tag) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 109 | 110 | if epoch==0: 111 | warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr']) 112 | 113 | optimizer.zero_grad() 114 | 115 | image = image.to(device,non_blocking=True) 116 | 117 | with torch.no_grad(): 118 | clip_image_feature = model_clip.encode_image(image) 119 | 120 | loss_t2t, loss_tag, loss_dis = model(image, caption, image_tag, parse_tag, clip_image_feature) 121 | loss = loss_t2t + loss_tag/(loss_tag/loss_t2t).detach() + loss_dis 122 | 123 | loss.backward() 124 | optimizer.step() 125 | 126 | metric_logger.update(loss_t2t=loss_t2t.item()) 127 | metric_logger.update(loss_tag=loss_tag.item()) 128 | metric_logger.update(loss_dis=loss_dis.item()) 129 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 130 | 131 | 132 | # gather the stats from all processes 133 | metric_logger.synchronize_between_processes() 134 | print("Averaged stats:", metric_logger.global_avg()) 135 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 136 | 137 | 138 | def train_tag2text(model, data_loader, optimizer, epoch, device, config): 139 | # train 140 | model.train() 141 | 142 | metric_logger = utils.MetricLogger(delimiter=" ") 143 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}')) 144 | metric_logger.add_meter('loss_t2t', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 145 | metric_logger.add_meter('loss_tag', utils.SmoothedValue(window_size=50, fmt='{value:.4f}')) 146 | 147 | header = 'Train Epoch: [{}]'.format(epoch) 148 | print_freq = 50 149 | 150 | data_loader.sampler.set_epoch(epoch) 151 | 152 | for i, (image, caption, _, parse_tag) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 153 | 154 | if epoch==0: 155 | warmup_lr_schedule(optimizer, i, config['warmup_steps'], config['warmup_lr'], config['init_lr']) 156 | 157 | optimizer.zero_grad() 158 | 159 | image = image.to(device,non_blocking=True) 160 | 161 | loss_t2t, loss_tag = model(image, caption, parse_tag) 162 | loss = loss_t2t + loss_tag/(loss_tag/loss_t2t).detach() 163 | 164 | loss.backward() 165 | optimizer.step() 166 | 167 | metric_logger.update(loss_t2t=loss_t2t.item()) 168 | metric_logger.update(loss_tag=loss_tag.item()) 169 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 170 | 171 | 172 | # gather the stats from all processes 173 | metric_logger.synchronize_between_processes() 174 | print("Averaged stats:", metric_logger.global_avg()) 175 | return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()} 176 | 177 | 178 | def main(args, config): 179 | utils.init_distributed_mode(args) 180 | 181 | device = torch.device(args.device) 182 | 183 | # fix the seed for reproducibility 184 | seed = args.seed + utils.get_rank() 185 | 
torch.manual_seed(seed) 186 | np.random.seed(seed) 187 | random.seed(seed) 188 | cudnn.benchmark = True 189 | 190 | #### Dataset #### 191 | print("Creating dataset") 192 | datasets = [create_dataset('pretrain', config, min_scale=0.2)] 193 | print('number of training samples: %d'%len(datasets[0])) 194 | 195 | num_tasks = utils.get_world_size() 196 | global_rank = utils.get_rank() 197 | samplers = create_sampler(datasets, [True], num_tasks, global_rank) 198 | 199 | data_loader = create_loader(datasets,samplers,batch_size=[config['batch_size']], num_workers=[4], is_trains=[True], collate_fns=[None])[0] 200 | 201 | #### Model #### 202 | if args.model_type == 'ram_plus': 203 | print("Creating pretrained CLIP model") 204 | model_clip, _ = clip.load("ViT-B/16", device=device) 205 | 206 | print("Creating RAM model") 207 | model = ram_plus(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 208 | vit_ckpt_layer=config['vit_ckpt_layer'], stage = 'train_from_scratch') 209 | 210 | elif args.model_type == 'ram': 211 | print("Creating pretrained CLIP model") 212 | model_clip, _ = clip.load("ViT-B/16", device=device) 213 | 214 | print("Creating RAM model") 215 | model = ram(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 216 | vit_ckpt_layer=config['vit_ckpt_layer'], stage = 'train_from_scratch') 217 | 218 | elif args.model_type == 'tag2text': 219 | print("Creating Tag2Text model") 220 | model = tag2text(image_size=config['image_size'], vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], 221 | vit_ckpt_layer=config['vit_ckpt_layer'], stage = 'train_from_scratch', tag_list='ram/data/ram_tag_list.txt') 222 | model = model.to(device) 223 | 224 | ### Frozen CLIP model ### 225 | model_clip = model_clip.to(device) 226 | for _, param in model_clip.named_parameters(): 227 | param.requires_grad = False 228 | 229 | ### Frozen label embedding for open-set recogniztion ### 230 | model.label_embed.requires_grad = False 231 | optimizer = torch.optim.AdamW(filter(lambda x: x.requires_grad, model.parameters()), lr=config['init_lr'], weight_decay=config['weight_decay']) 232 | 233 | start_epoch = 0 234 | if args.checkpoint: 235 | checkpoint = torch.load(args.checkpoint, map_location='cpu') 236 | state_dict = checkpoint['model'] 237 | model.load_state_dict(state_dict) 238 | 239 | optimizer.load_state_dict(checkpoint['optimizer']) 240 | start_epoch = checkpoint['epoch']+1 241 | print('resume checkpoint from %s'%args.checkpoint) 242 | 243 | model_without_ddp = model 244 | if args.distributed: 245 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 246 | model_without_ddp = model.module 247 | 248 | print("Start training") 249 | start_time = time.time() 250 | for epoch in range(start_epoch, config['max_epoch']): 251 | 252 | step_lr_schedule(optimizer, epoch, config['init_lr'], config['min_lr'], config['lr_decay_rate']) 253 | 254 | if args.model_type == 'ram_plus': 255 | train_stats = train_ram_plus(model, data_loader, optimizer, epoch, device, config, model_clip) 256 | elif args.model_type == 'ram': 257 | train_stats = train_ram(model, data_loader, optimizer, epoch, device, config, model_clip) 258 | elif args.model_type == 'tag2text': 259 | train_stats = train_tag2text(model, data_loader, optimizer, epoch, device, config) 260 | 261 | if utils.is_main_process(): 262 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 263 | 'epoch': epoch, 264 | } 265 | save_obj = { 266 | 'model': 
model_without_ddp.state_dict(), 267 | 'optimizer': optimizer.state_dict(), 268 | 'config': config, 269 | 'epoch': epoch, 270 | } 271 | torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch)) 272 | 273 | with open(os.path.join(args.output_dir, "log.txt"),"a") as f: 274 | f.write(json.dumps(log_stats) + "\n") 275 | 276 | dist.barrier() 277 | 278 | total_time = time.time() - start_time 279 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 280 | print('Training time {}'.format(total_time_str)) 281 | 282 | 283 | if __name__ == '__main__': 284 | parser = argparse.ArgumentParser() 285 | parser.add_argument('--config', default='./configs/pretrain.yaml') 286 | parser.add_argument("--model-type",type=str,choices=("ram_plus", "ram", "tag2text"),required=True) 287 | parser.add_argument('--output-dir', default='output/Pretrain') 288 | parser.add_argument('--checkpoint', default='') 289 | parser.add_argument('--evaluate', action='store_true') 290 | parser.add_argument('--device', default='cuda') 291 | parser.add_argument('--seed', default=42, type=int) 292 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 293 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 294 | parser.add_argument('--distributed', default=True, type=bool) 295 | args = parser.parse_args() 296 | 297 | config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader) 298 | 299 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 300 | 301 | yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w')) 302 | 303 | main(args, config) -------------------------------------------------------------------------------- /ram/__init__.py: -------------------------------------------------------------------------------- 1 | from .inference import inference_tag2text, inference_ram, inference_ram_openset 2 | from .transform import get_transform 3 | -------------------------------------------------------------------------------- /ram/configs/finetune.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'datasets/train/coco_train_rmcocodev_ram.json', 3 | ] 4 | image_path_root: "" 5 | 6 | # size of vit model; base or large 7 | vit: 'swin_l' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 384 12 | batch_size: 26 13 | 14 | # optimizer 15 | weight_decay: 0.05 16 | init_lr: 5e-06 17 | min_lr: 0 18 | max_epoch: 2 19 | warmup_steps: 3000 20 | 21 | class_num: 4585 22 | 23 | -------------------------------------------------------------------------------- /ram/configs/finetune_tag2text.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'datasets/train/coco_train_rmcocodev_ram.json', 3 | ] 4 | image_path_root: "" 5 | 6 | # size of vit model; base or large 7 | vit: 'swin_b' 8 | vit_grad_ckpt: False 9 | vit_ckpt_layer: 0 10 | 11 | image_size: 384 12 | batch_size: 36 13 | 14 | # optimizer 15 | weight_decay: 0.05 16 | init_lr: 5e-06 17 | min_lr: 0 18 | max_epoch: 2 19 | warmup_steps: 3000 20 | 21 | class_num: 4585 22 | 23 | -------------------------------------------------------------------------------- /ram/configs/med_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | 
"hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 12, 15 | "num_hidden_layers": 12, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30524, 19 | "encoder_width": 768, 20 | "add_cross_attention": true 21 | } -------------------------------------------------------------------------------- /ram/configs/pretrain.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'datasets/train/coco_train_rmcocodev_ram.json', 3 | 'datasets/train/vg_ram.json', 4 | 'datasets/train/sbu_ram.json', 5 | 'datasets/train/cc3m_train_ram.json', 6 | 'datasets/train/cc3m_val_ram.json', 7 | 'datasets/train/cc12m_ram.json', 8 | ] 9 | image_path_root: "" 10 | 11 | # size of vit model; base or large 12 | vit: 'swin_l' 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | 16 | image_size: 224 17 | batch_size: 52 18 | 19 | # optimizer 20 | weight_decay: 0.05 21 | init_lr: 1e-4 22 | min_lr: 5e-7 23 | warmup_lr: 5e-7 24 | lr_decay_rate: 0.9 25 | max_epoch: 5 26 | warmup_steps: 3000 27 | 28 | class_num: 4585 29 | 30 | -------------------------------------------------------------------------------- /ram/configs/pretrain_tag2text.yaml: -------------------------------------------------------------------------------- 1 | train_file: [ 2 | 'datasets/train/coco_train_rmcocodev_ram.json', 3 | 'datasets/train/vg_ram.json', 4 | 'datasets/train/sbu_ram.json', 5 | 'datasets/train/cc3m_train_ram.json', 6 | 'datasets/train/cc3m_val_ram.json', 7 | 'datasets/train/cc12m_ram.json', 8 | ] 9 | image_path_root: "" 10 | 11 | # size of vit model; base or large 12 | vit: 'swin_b' 13 | vit_grad_ckpt: False 14 | vit_ckpt_layer: 0 15 | 16 | image_size: 224 17 | batch_size: 80 18 | 19 | # optimizer 20 | weight_decay: 0.05 21 | init_lr: 1e-4 22 | min_lr: 5e-7 23 | warmup_lr: 5e-7 24 | lr_decay_rate: 0.9 25 | max_epoch: 5 26 | warmup_steps: 3000 27 | 28 | class_num: 4585 29 | 30 | -------------------------------------------------------------------------------- /ram/configs/q2l_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "BertModel" 4 | ], 5 | "attention_probs_dropout_prob": 0.1, 6 | "hidden_act": "gelu", 7 | "hidden_dropout_prob": 0.1, 8 | "hidden_size": 768, 9 | "initializer_range": 0.02, 10 | "intermediate_size": 3072, 11 | "layer_norm_eps": 1e-12, 12 | "max_position_embeddings": 512, 13 | "model_type": "bert", 14 | "num_attention_heads": 4, 15 | "num_hidden_layers": 2, 16 | "pad_token_id": 0, 17 | "type_vocab_size": 2, 18 | "vocab_size": 30522, 19 | "encoder_width": 768, 20 | "add_cross_attention": true, 21 | "add_tag_cross_attention": false 22 | } -------------------------------------------------------------------------------- /ram/configs/swin/config_swinB_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_base_patch4_window7_224_22k.pth", 3 | "vision_width": 1024, 4 | "image_res": 224, 5 | "window_size": 7, 6 | "embed_dim": 128, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 4, 8, 16, 32 ] 9 | } -------------------------------------------------------------------------------- /ram/configs/swin/config_swinB_384.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": 
"pretrain_model/swin_base_patch4_window7_224_22k.pth", 3 | "vision_width": 1024, 4 | "image_res": 384, 5 | "window_size": 12, 6 | "embed_dim": 128, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 4, 8, 16, 32 ] 9 | } -------------------------------------------------------------------------------- /ram/configs/swin/config_swinL_224.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_large_patch4_window7_224_22k.pth", 3 | "vision_width": 1536, 4 | "image_res": 224, 5 | "window_size": 7, 6 | "embed_dim": 192, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 6, 12, 24, 48 ] 9 | } 10 | -------------------------------------------------------------------------------- /ram/configs/swin/config_swinL_384.json: -------------------------------------------------------------------------------- 1 | { 2 | "ckpt": "pretrain_model/swin_large_patch4_window12_384_22k.pth", 3 | "vision_width": 1536, 4 | "image_res": 384, 5 | "window_size": 12, 6 | "embed_dim": 192, 7 | "depths": [ 2, 2, 18, 2 ], 8 | "num_heads": [ 6, 12, 24, 48 ] 9 | } -------------------------------------------------------------------------------- /ram/data/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | from torchvision.transforms.functional import InterpolationMode 5 | 6 | from .dataset import pretrain_dataset, finetune_dataset 7 | from .randaugment import RandomAugment 8 | 9 | def create_dataset(dataset, config, min_scale=0.5): 10 | 11 | normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 12 | 13 | transform_train = transforms.Compose([ 14 | transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC), 15 | transforms.RandomHorizontalFlip(), 16 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 17 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 18 | transforms.ToTensor(), 19 | normalize, 20 | ]) 21 | 22 | transform_inputsize_224 = transforms.Compose([ 23 | transforms.RandomResizedCrop(224,scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC), 24 | transforms.RandomHorizontalFlip(), 25 | RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize', 26 | 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), 27 | transforms.ToTensor(), 28 | normalize, 29 | ]) 30 | 31 | if dataset=='pretrain': 32 | dataset = pretrain_dataset(config['train_file'], transform_train, class_num=config['class_num'], root=config['image_path_root']) 33 | return dataset 34 | 35 | elif dataset=='finetune': 36 | dataset = finetune_dataset(config['train_file'], transform_train, transform_inputsize_224, class_num=config['class_num'], root=config['image_path_root']) 37 | return dataset 38 | 39 | 40 | 41 | 42 | def create_sampler(datasets, shuffles, num_tasks, global_rank): 43 | samplers = [] 44 | for dataset,shuffle in zip(datasets,shuffles): 45 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle) 46 | samplers.append(sampler) 47 | return samplers 48 | 49 | 50 | def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns): 51 | loaders = [] 52 | for dataset,sampler,bs,n_worker,is_train,collate_fn in 
zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns): 53 | if is_train: 54 | shuffle = (sampler is None) 55 | drop_last = True 56 | else: 57 | shuffle = False 58 | drop_last = False 59 | loader = DataLoader( 60 | dataset, 61 | batch_size=bs, 62 | num_workers=n_worker, 63 | pin_memory=True, 64 | sampler=sampler, 65 | shuffle=shuffle, 66 | collate_fn=collate_fn, 67 | drop_last=drop_last, 68 | ) 69 | loaders.append(loader) 70 | return loaders 71 | 72 | -------------------------------------------------------------------------------- /ram/data/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | from torch.utils.data import Dataset 6 | 7 | from PIL import Image 8 | from PIL import ImageFile 9 | ImageFile.LOAD_TRUNCATED_IMAGES = True 10 | Image.MAX_IMAGE_PIXELS = None 11 | 12 | from .utils import pre_caption 13 | import os,glob 14 | 15 | import torch 16 | import numpy as np 17 | 18 | class pretrain_dataset(Dataset): 19 | def __init__(self, ann_file, transform, class_num = 4585, root = ''): 20 | 21 | self.ann = [] 22 | for f in ann_file: 23 | print('loading '+f) 24 | ann = json.load(open(f,'r')) 25 | self.ann += ann 26 | 27 | self.transform = transform 28 | self.class_num = class_num 29 | self.root = root 30 | 31 | 32 | def __len__(self): 33 | return len(self.ann) 34 | 35 | def __getitem__(self, index): 36 | 37 | ann = self.ann[index] 38 | 39 | image_path_use = os.path.join(self.root, ann['image_path']) 40 | image = Image.open(image_path_use).convert('RGB') 41 | image = self.transform(image) 42 | 43 | # required for tag2text support 44 | if ann.get('union_label_id') is not None: 45 | num = ann['union_label_id'] 46 | image_tag = np.zeros([self.class_num]) 47 | image_tag[num] = 1 48 | image_tag = torch.tensor(image_tag, dtype = torch.long) 49 | else: 50 | image_tag = None 51 | 52 | caption_index = np.random.randint(0, len(ann['caption'])) 53 | 54 | caption = pre_caption(ann['caption'][caption_index],30) 55 | 56 | num = ann['parse_label_id'][caption_index] 57 | parse_tag = np.zeros([self.class_num]) 58 | parse_tag[num] = 1 59 | parse_tag = torch.tensor(parse_tag, dtype = torch.long) 60 | 61 | return image, caption, image_tag, parse_tag 62 | 63 | 64 | class finetune_dataset(Dataset): 65 | def __init__(self, ann_file, transform, transform_224, class_num = 4585, root = ''): 66 | 67 | self.ann = [] 68 | for f in ann_file: 69 | print('loading '+f) 70 | ann = json.load(open(f,'r')) 71 | self.ann += ann 72 | 73 | self.transform = transform 74 | self.transform_224 = transform_224 75 | self.class_num = class_num 76 | self.root = root 77 | 78 | 79 | def __len__(self): 80 | return len(self.ann) 81 | 82 | def __getitem__(self, index): 83 | 84 | ann = self.ann[index] 85 | 86 | image_path_use = os.path.join(self.root, ann['image_path']) 87 | image = Image.open(image_path_use).convert('RGB') 88 | image = self.transform(image) 89 | 90 | image_224 = Image.open(image_path_use).convert('RGB') 91 | image_224 = self.transform_224(image_224) 92 | 93 | # required for tag2text support 94 | if ann.get('union_label_id') is not None: 95 | num = ann['union_label_id'] 96 | image_tag = np.zeros([self.class_num]) 97 | image_tag[num] = 1 98 | image_tag = torch.tensor(image_tag, dtype = torch.long) 99 | else: 100 | image_tag = None 101 | 102 | caption_index = np.random.randint(0, len(ann['caption'])) 103 | 104 | caption = pre_caption(ann['caption'][caption_index],30) 105 | 106 | num = ann['parse_label_id'][caption_index] 
107 | parse_tag = np.zeros([self.class_num]) 108 | parse_tag[num] = 1 109 | parse_tag = torch.tensor(parse_tag, dtype = torch.long) 110 | 111 | return image, image_224, caption, image_tag, parse_tag 112 | 113 | -------------------------------------------------------------------------------- /ram/data/randaugment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | ## aug functions 6 | def identity_func(img): 7 | return img 8 | 9 | 10 | def autocontrast_func(img, cutoff=0): 11 | ''' 12 | same output as PIL.ImageOps.autocontrast 13 | ''' 14 | n_bins = 256 15 | 16 | def tune_channel(ch): 17 | n = ch.size 18 | cut = cutoff * n // 100 19 | if cut == 0: 20 | high, low = ch.max(), ch.min() 21 | else: 22 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 23 | low = np.argwhere(np.cumsum(hist) > cut) 24 | low = 0 if low.shape[0] == 0 else low[0] 25 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 26 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 27 | if high <= low: 28 | table = np.arange(n_bins) 29 | else: 30 | scale = (n_bins - 1) / (high - low) 31 | offset = -low * scale 32 | table = np.arange(n_bins) * scale + offset 33 | table[table < 0] = 0 34 | table[table > n_bins - 1] = n_bins - 1 35 | table = table.clip(0, 255).astype(np.uint8) 36 | return table[ch] 37 | 38 | channels = [tune_channel(ch) for ch in cv2.split(img)] 39 | out = cv2.merge(channels) 40 | return out 41 | 42 | 43 | def equalize_func(img): 44 | ''' 45 | same output as PIL.ImageOps.equalize 46 | PIL's implementation is different from cv2.equalize 47 | ''' 48 | n_bins = 256 49 | 50 | def tune_channel(ch): 51 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 52 | non_zero_hist = hist[hist != 0].reshape(-1) 53 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 54 | if step == 0: return ch 55 | n = np.empty_like(hist) 56 | n[0] = step // 2 57 | n[1:] = hist[:-1] 58 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 59 | return table[ch] 60 | 61 | channels = [tune_channel(ch) for ch in cv2.split(img)] 62 | out = cv2.merge(channels) 63 | return out 64 | 65 | 66 | def rotate_func(img, degree, fill=(0, 0, 0)): 67 | ''' 68 | like PIL, rotate by degree, not radians 69 | ''' 70 | H, W = img.shape[0], img.shape[1] 71 | center = W / 2, H / 2 72 | M = cv2.getRotationMatrix2D(center, degree, 1) 73 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 74 | return out 75 | 76 | 77 | def solarize_func(img, thresh=128): 78 | ''' 79 | same output as PIL.ImageOps.solarize 80 | ''' 81 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 82 | table = table.clip(0, 255).astype(np.uint8) 83 | out = table[img] 84 | return out 85 | 86 | 87 | def color_func(img, factor): 88 | ''' 89 | same output as PIL.ImageEnhance.Color 90 | ''' 91 | ## implementation according to PIL definition, quite slow 92 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 93 | # out = blend(degenerate, img, factor) 94 | # M = ( 95 | # np.eye(3) * factor 96 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. 
- factor) 97 | # )[np.newaxis, np.newaxis, :] 98 | M = ( 99 | np.float32([ 100 | [0.886, -0.114, -0.114], 101 | [-0.587, 0.413, -0.587], 102 | [-0.299, -0.299, 0.701]]) * factor 103 | + np.float32([[0.114], [0.587], [0.299]]) 104 | ) 105 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 106 | return out 107 | 108 | 109 | def contrast_func(img, factor): 110 | """ 111 | same output as PIL.ImageEnhance.Contrast 112 | """ 113 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 114 | table = np.array([( 115 | el - mean) * factor + mean 116 | for el in range(256) 117 | ]).clip(0, 255).astype(np.uint8) 118 | out = table[img] 119 | return out 120 | 121 | 122 | def brightness_func(img, factor): 123 | ''' 124 | same output as PIL.ImageEnhance.Contrast 125 | ''' 126 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 127 | out = table[img] 128 | return out 129 | 130 | 131 | def sharpness_func(img, factor): 132 | ''' 133 | The differences the this result and PIL are all on the 4 boundaries, the center 134 | areas are same 135 | ''' 136 | kernel = np.ones((3, 3), dtype=np.float32) 137 | kernel[1][1] = 5 138 | kernel /= 13 139 | degenerate = cv2.filter2D(img, -1, kernel) 140 | if factor == 0.0: 141 | out = degenerate 142 | elif factor == 1.0: 143 | out = img 144 | else: 145 | out = img.astype(np.float32) 146 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :] 147 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 148 | out = out.astype(np.uint8) 149 | return out 150 | 151 | 152 | def shear_x_func(img, factor, fill=(0, 0, 0)): 153 | H, W = img.shape[0], img.shape[1] 154 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 155 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 156 | return out 157 | 158 | 159 | def translate_x_func(img, offset, fill=(0, 0, 0)): 160 | ''' 161 | same output as PIL.Image.transform 162 | ''' 163 | H, W = img.shape[0], img.shape[1] 164 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 165 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 166 | return out 167 | 168 | 169 | def translate_y_func(img, offset, fill=(0, 0, 0)): 170 | ''' 171 | same output as PIL.Image.transform 172 | ''' 173 | H, W = img.shape[0], img.shape[1] 174 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 175 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 176 | return out 177 | 178 | 179 | def posterize_func(img, bits): 180 | ''' 181 | same output as PIL.ImageOps.posterize 182 | ''' 183 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 184 | return out 185 | 186 | 187 | def shear_y_func(img, factor, fill=(0, 0, 0)): 188 | H, W = img.shape[0], img.shape[1] 189 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 190 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR).astype(np.uint8) 191 | return out 192 | 193 | 194 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 195 | replace = np.array(replace, dtype=np.uint8) 196 | H, W = img.shape[0], img.shape[1] 197 | rh, rw = np.random.random(2) 198 | pad_size = pad_size // 2 199 | ch, cw = int(rh * H), int(rw * W) 200 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 201 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W) 202 | out = img.copy() 203 | out[x1:x2, y1:y2, :] = replace 204 | return out 205 | 206 | 207 | ### level to args 208 | def enhance_level_to_args(MAX_LEVEL): 209 | def 
level_to_args(level): 210 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 211 | return level_to_args 212 | 213 | 214 | def shear_level_to_args(MAX_LEVEL, replace_value): 215 | def level_to_args(level): 216 | level = (level / MAX_LEVEL) * 0.3 217 | if np.random.random() > 0.5: level = -level 218 | return (level, replace_value) 219 | 220 | return level_to_args 221 | 222 | 223 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 224 | def level_to_args(level): 225 | level = (level / MAX_LEVEL) * float(translate_const) 226 | if np.random.random() > 0.5: level = -level 227 | return (level, replace_value) 228 | 229 | return level_to_args 230 | 231 | 232 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 233 | def level_to_args(level): 234 | level = int((level / MAX_LEVEL) * cutout_const) 235 | return (level, replace_value) 236 | 237 | return level_to_args 238 | 239 | 240 | def solarize_level_to_args(MAX_LEVEL): 241 | def level_to_args(level): 242 | level = int((level / MAX_LEVEL) * 256) 243 | return (level, ) 244 | return level_to_args 245 | 246 | 247 | def none_level_to_args(level): 248 | return () 249 | 250 | 251 | def posterize_level_to_args(MAX_LEVEL): 252 | def level_to_args(level): 253 | level = int((level / MAX_LEVEL) * 4) 254 | return (level, ) 255 | return level_to_args 256 | 257 | 258 | def rotate_level_to_args(MAX_LEVEL, replace_value): 259 | def level_to_args(level): 260 | level = (level / MAX_LEVEL) * 30 261 | if np.random.random() < 0.5: 262 | level = -level 263 | return (level, replace_value) 264 | 265 | return level_to_args 266 | 267 | 268 | func_dict = { 269 | 'Identity': identity_func, 270 | 'AutoContrast': autocontrast_func, 271 | 'Equalize': equalize_func, 272 | 'Rotate': rotate_func, 273 | 'Solarize': solarize_func, 274 | 'Color': color_func, 275 | 'Contrast': contrast_func, 276 | 'Brightness': brightness_func, 277 | 'Sharpness': sharpness_func, 278 | 'ShearX': shear_x_func, 279 | 'TranslateX': translate_x_func, 280 | 'TranslateY': translate_y_func, 281 | 'Posterize': posterize_func, 282 | 'ShearY': shear_y_func, 283 | } 284 | 285 | translate_const = 10 286 | MAX_LEVEL = 10 287 | replace_value = (128, 128, 128) 288 | arg_dict = { 289 | 'Identity': none_level_to_args, 290 | 'AutoContrast': none_level_to_args, 291 | 'Equalize': none_level_to_args, 292 | 'Rotate': rotate_level_to_args(MAX_LEVEL, replace_value), 293 | 'Solarize': solarize_level_to_args(MAX_LEVEL), 294 | 'Color': enhance_level_to_args(MAX_LEVEL), 295 | 'Contrast': enhance_level_to_args(MAX_LEVEL), 296 | 'Brightness': enhance_level_to_args(MAX_LEVEL), 297 | 'Sharpness': enhance_level_to_args(MAX_LEVEL), 298 | 'ShearX': shear_level_to_args(MAX_LEVEL, replace_value), 299 | 'TranslateX': translate_level_to_args( 300 | translate_const, MAX_LEVEL, replace_value 301 | ), 302 | 'TranslateY': translate_level_to_args( 303 | translate_const, MAX_LEVEL, replace_value 304 | ), 305 | 'Posterize': posterize_level_to_args(MAX_LEVEL), 306 | 'ShearY': shear_level_to_args(MAX_LEVEL, replace_value), 307 | } 308 | 309 | 310 | class RandomAugment(object): 311 | 312 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 313 | self.N = N 314 | self.M = M 315 | self.isPIL = isPIL 316 | if augs: 317 | self.augs = augs 318 | else: 319 | self.augs = list(arg_dict.keys()) 320 | 321 | def get_random_ops(self): 322 | sampled_ops = np.random.choice(self.augs, self.N) 323 | return [(op, 0.5, self.M) for op in sampled_ops] 324 | 325 | def __call__(self, img): 326 | if self.isPIL: 327 | img = np.array(img) 328 | 
ops = self.get_random_ops() 329 | for name, prob, level in ops: 330 | if np.random.random() > prob: 331 | continue 332 | args = arg_dict[name](level) 333 | img = func_dict[name](img, *args) 334 | return img 335 | 336 | 337 | if __name__ == '__main__': 338 | a = RandomAugment() 339 | img = np.random.randn(32, 32, 3) 340 | a(img) -------------------------------------------------------------------------------- /ram/data/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import utils 9 | 10 | def pre_caption(caption,max_words=50): 11 | caption = re.sub( 12 | r"([.!\"()*#:;~])", 13 | ' ', 14 | caption.lower(), 15 | ) 16 | caption = re.sub( 17 | r"\s{2,}", 18 | ' ', 19 | caption, 20 | ) 21 | caption = caption.rstrip('\n') 22 | caption = caption.strip(' ') 23 | 24 | #truncate caption 25 | caption_words = caption.split(' ') 26 | if len(caption_words)>max_words: 27 | caption = ' '.join(caption_words[:max_words]) 28 | 29 | return caption 30 | 31 | def pre_question(question,max_ques_words=50): 32 | question = re.sub( 33 | r"([.!\"()*#:;~])", 34 | '', 35 | question.lower(), 36 | ) 37 | question = question.rstrip(' ') 38 | 39 | #truncate question 40 | question_words = question.split(' ') 41 | if len(question_words)>max_ques_words: 42 | question = ' '.join(question_words[:max_ques_words]) 43 | 44 | return question 45 | 46 | 47 | def save_result(result, result_dir, filename, remove_duplicate=''): 48 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank())) 49 | final_result_file = os.path.join(result_dir, '%s.json'%filename) 50 | 51 | json.dump(result,open(result_file,'w')) 52 | 53 | dist.barrier() 54 | 55 | if utils.is_main_process(): 56 | # combine results from all processes 57 | result = [] 58 | 59 | for rank in range(utils.get_world_size()): 60 | result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank)) 61 | res = json.load(open(result_file,'r')) 62 | result += res 63 | 64 | if remove_duplicate: 65 | result_new = [] 66 | id_list = [] 67 | for res in result: 68 | if res[remove_duplicate] not in id_list: 69 | id_list.append(res[remove_duplicate]) 70 | result_new.append(res) 71 | result = result_new 72 | 73 | json.dump(result,open(final_result_file,'w')) 74 | print('result file saved to %s'%final_result_file) 75 | 76 | return final_result_file 77 | 78 | 79 | 80 | from pycocotools.coco import COCO 81 | from pycocoevalcap.eval import COCOEvalCap 82 | from torchvision.datasets.utils import download_url 83 | 84 | def coco_caption_eval(coco_gt_root, results_file, split): 85 | urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json', 86 | 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'} 87 | filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'} 88 | 89 | download_url(urls[split],coco_gt_root) 90 | annotation_file = os.path.join(coco_gt_root,filenames[split]) 91 | 92 | # create coco object and coco_result object 93 | coco = COCO(annotation_file) 94 | coco_result = coco.loadRes(results_file) 95 | 96 | # create coco_eval object by taking coco and coco_result 97 | coco_eval = COCOEvalCap(coco, coco_result) 98 | 99 | # evaluate on a subset of images by setting 100 | # coco_eval.params['image_id'] = coco_result.getImgIds() 101 | # please remove this line when evaluating the full 
validation set 102 | # coco_eval.params['image_id'] = coco_result.getImgIds() 103 | 104 | # evaluate results 105 | # SPICE will take a few minutes the first time, but speeds up due to caching 106 | coco_eval.evaluate() 107 | 108 | # print output evaluation scores 109 | for metric, score in coco_eval.eval.items(): 110 | print(f'{metric}: {score:.3f}') 111 | 112 | return coco_eval -------------------------------------------------------------------------------- /ram/inference.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Inference of RAM and Tag2Text Models 3 | * Written by Xinyu Huang 4 | ''' 5 | import torch 6 | 7 | 8 | def inference_tag2text(image, model, input_tag="None"): 9 | 10 | with torch.no_grad(): 11 | caption, tag_predict = model.generate(image, 12 | tag_input=None, 13 | max_length=50, 14 | return_tag_predict=True) 15 | 16 | if input_tag == '' or input_tag == 'none' or input_tag == 'None': 17 | return tag_predict[0], None, caption[0] 18 | 19 | # If user input specified tags: 20 | else: 21 | input_tag_list = [] 22 | input_tag_list.append(input_tag.replace(',', ' | ')) 23 | 24 | with torch.no_grad(): 25 | caption, input_tag = model.generate(image, 26 | tag_input=input_tag_list, 27 | max_length=50, 28 | return_tag_predict=True) 29 | 30 | return tag_predict[0], input_tag[0], caption[0] 31 | 32 | 33 | def inference_ram(image, model): 34 | 35 | with torch.no_grad(): 36 | tags, tags_chinese = model.generate_tag(image) 37 | 38 | return tags[0],tags_chinese[0] 39 | 40 | 41 | def inference_ram_openset(image, model): 42 | 43 | with torch.no_grad(): 44 | tags = model.generate_tag_openset(image) 45 | 46 | return tags[0] 47 | -------------------------------------------------------------------------------- /ram/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ram_plus import ram_plus 2 | from .ram import ram 3 | from .tag2text import tag2text 4 | -------------------------------------------------------------------------------- /ram/models/ram_plus.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Recognize Anything Plus Model (RAM++) 3 | * Written by Xinyu Huang 4 | ''' 5 | import json 6 | import warnings 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | import torch.nn.functional as F 13 | from .bert import BertConfig, BertLMHeadModel, BertModel 14 | from .swin_transformer import SwinTransformer 15 | from .utils import * 16 | 17 | warnings.filterwarnings("ignore") 18 | 19 | 20 | 21 | class RAM_plus(nn.Module): 22 | def __init__(self, 23 | med_config=f'{CONFIG_PATH}/configs/med_config.json', 24 | image_size=384, 25 | text_encoder_type='bert-base-uncased', 26 | vit='base', 27 | vit_grad_ckpt=False, 28 | vit_ckpt_layer=0, 29 | threshold=0.68, 30 | delete_tag_index=[], 31 | tag_list=f'{CONFIG_PATH}/data/ram_tag_list.txt', 32 | tag_list_chinese=f'{CONFIG_PATH}/data/ram_tag_list_chinese.txt', 33 | stage='eval'): 34 | r""" The Recognize Anything Plus Model (RAM++) inference module. 35 | RAM++ is a strong image tagging model, which can recognize any category with high accuracy using tag categories. 
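        A typical zero-shot tagging call looks like the sketch below (the checkpoint
        path, device, image transform and image path are illustrative assumptions,
        not part of this file):
            model = ram_plus(pretrained='ram_plus_swin_large_14m.pth',
                             image_size=384, vit='swin_l').eval().to(device)
            image = transform(Image.open('demo.jpg')).unsqueeze(0).to(device)
            tags, tags_chinese = model.generate_tag(image)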
36 | Described in the paper "Open-Set Image Tagging with Multi-Grained Text Supervision" https://arxiv.org/abs/2310.15200 37 | 38 | Args: 39 | med_config (str): path for the mixture of encoder-decoder model's configuration file 40 | image_size (int): input image size 41 | vit (str): model size of vision transformer 42 | threshold (int): tagging threshold 43 | delete_tag_index (list): delete some tags that may disturb captioning 44 | """ 45 | super().__init__() 46 | 47 | # create image encoder 48 | if vit == 'swin_b': 49 | if image_size == 224: 50 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 51 | elif image_size == 384: 52 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 53 | vision_config = read_json(vision_config_path) 54 | assert image_size == vision_config['image_res'] 55 | # assert config['patch_size'] == 32 56 | vision_width = vision_config['vision_width'] 57 | 58 | self.visual_encoder = SwinTransformer( 59 | img_size=vision_config['image_res'], 60 | patch_size=4, 61 | in_chans=3, 62 | embed_dim=vision_config['embed_dim'], 63 | depths=vision_config['depths'], 64 | num_heads=vision_config['num_heads'], 65 | window_size=vision_config['window_size'], 66 | mlp_ratio=4., 67 | qkv_bias=True, 68 | drop_rate=0.0, 69 | drop_path_rate=0.1, 70 | ape=False, 71 | patch_norm=True, 72 | use_checkpoint=False) 73 | 74 | if stage == 'train_from_scratch': 75 | # download from https://github.com/microsoft/Swin-Transformer 76 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] 77 | 78 | for k in list(state_dict.keys()): 79 | if 'relative_position_bias_table' in k: 80 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2 81 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) 82 | elif ('relative_position_index' in k) or ('attn_mask' in k): 83 | del state_dict[k] 84 | 85 | print("### Load Vision Backbone", vit) 86 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False) 87 | print("missing_keys: ", msg.missing_keys) 88 | print("unexpected_keys: ", msg.unexpected_keys) 89 | 90 | elif vit == 'swin_l': 91 | if image_size == 224: 92 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' 93 | elif image_size == 384: 94 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' 95 | vision_config = read_json(vision_config_path) 96 | assert image_size == vision_config['image_res'] 97 | # assert config['patch_size'] == 32 98 | vision_width = vision_config['vision_width'] 99 | 100 | self.visual_encoder = SwinTransformer( 101 | img_size=vision_config['image_res'], 102 | patch_size=4, 103 | in_chans=3, 104 | embed_dim=vision_config['embed_dim'], 105 | depths=vision_config['depths'], 106 | num_heads=vision_config['num_heads'], 107 | window_size=vision_config['window_size'], 108 | mlp_ratio=4., 109 | qkv_bias=True, 110 | drop_rate=0.0, 111 | drop_path_rate=0.1, 112 | ape=False, 113 | patch_norm=True, 114 | use_checkpoint=False) 115 | 116 | if stage == 'train_from_scratch': 117 | # download from https://github.com/microsoft/Swin-Transformer 118 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] 119 | 120 | for k in list(state_dict.keys()): 121 | if 'relative_position_bias_table' in k: 122 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2 123 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) 124 | elif ('relative_position_index' in k) or ('attn_mask' in k): 125 | del 
state_dict[k] 126 | 127 | print("### Load Vision Backbone", vit) 128 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False) 129 | print("missing_keys: ", msg.missing_keys) 130 | print("unexpected_keys: ", msg.unexpected_keys) 131 | 132 | else: 133 | self.visual_encoder, vision_width = create_vit( 134 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer) 135 | 136 | # create tokenzier 137 | self.tokenizer = init_tokenizer(text_encoder_type) 138 | 139 | self.delete_tag_index = delete_tag_index 140 | 141 | # load tag list 142 | self.tag_list = self.load_tag_list(tag_list) 143 | self.tag_list_chinese = self.load_tag_list(tag_list_chinese) 144 | 145 | # create image-tag recognition decoder 146 | self.threshold = threshold 147 | self.num_class = len(self.tag_list) 148 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') 149 | q2l_config.encoder_width = 512 150 | self.tagging_head = BertModel(config=q2l_config, 151 | add_pooling_layer=False) 152 | self.tagging_head.resize_token_embeddings(len(self.tokenizer)) 153 | 154 | if stage == 'train_from_scratch': 155 | self.label_embed = nn.Parameter(torch.load(f'{CONFIG_PATH}/data/frozen_tag_embedding/ram_plus_tag_embedding_class_4585_des_51.pth',map_location='cpu').float()) 156 | else: 157 | # when eval with pretrained RAM++ model, directly load from ram_plus_swin_large_14m.pth 158 | self.label_embed = nn.Parameter(torch.zeros(self.num_class * 51, q2l_config.encoder_width)) 159 | 160 | if q2l_config.hidden_size != 512: 161 | self.wordvec_proj = nn.Linear(512, q2l_config.hidden_size) 162 | else: 163 | self.wordvec_proj = nn.Identity() 164 | 165 | self.fc = nn.Linear(q2l_config.hidden_size, 1) 166 | 167 | self.del_selfattention() 168 | 169 | self.image_proj = nn.Linear(vision_width, 512) 170 | 171 | # adjust thresholds for some tags 172 | self.class_threshold = torch.ones(self.num_class) * self.threshold 173 | ram_class_threshold_path = f'{CONFIG_PATH}/data/ram_tag_list_threshold.txt' 174 | with open(ram_class_threshold_path, 'r', encoding='utf-8') as f: 175 | ram_class_threshold = [float(s.strip()) for s in f] 176 | for key,value in enumerate(ram_class_threshold): 177 | self.class_threshold[key] = value 178 | 179 | self.reweight_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 180 | 181 | self.tagging_loss_function = AsymmetricLoss(gamma_neg=7, 182 | gamma_pos=0, 183 | clip=0.05) 184 | 185 | self.text_alignment_loss_function = AsymmetricLoss(gamma_neg=4, 186 | gamma_pos=0, 187 | clip=0.05) 188 | 189 | def load_tag_list(self, tag_list_file): 190 | with open(tag_list_file, 'r', encoding="utf-8") as f: 191 | tag_list = f.read().splitlines() 192 | tag_list = np.array(tag_list) 193 | return tag_list 194 | 195 | # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label 196 | def del_selfattention(self): 197 | del self.tagging_head.embeddings 198 | for layer in self.tagging_head.encoder.layer: 199 | del layer.attention 200 | 201 | def forward(self, image, caption, image_tag, clip_feature, batch_text_embed): 202 | """ 203 | call function as forward 204 | 205 | Args: 206 | image: type: torch.Tensor shape: batch_size * 3 * 384 * 384 207 | caption: type: list[string] len: batch_size 208 | tag: type: torch.Tensor shape: batch * class_num (e.g. 
3429) value: positive sample is 1.0, negative sample is 0.0 209 | 210 | Returns: 211 | loss: type: torch.Tensor 212 | """ 213 | 214 | image_embeds = self.image_proj(self.visual_encoder(image)) 215 | image_atts = torch.ones(image_embeds.size()[:-1], 216 | dtype=torch.long).to(image.device) 217 | 218 | ##================= Distillation from CLIP ================## 219 | image_cls_embeds = image_embeds[:, 0, :] 220 | image_spatial_embeds = image_embeds[:, 1:, :] 221 | 222 | loss_dis = F.l1_loss(image_cls_embeds, clip_feature) 223 | 224 | ###===========multi tag des reweight==============### 225 | bs = image_embeds.shape[0] 226 | 227 | des_per_class = int(self.label_embed.shape[0] / self.num_class) 228 | 229 | image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True) 230 | reweight_scale = self.reweight_scale.exp() 231 | logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t()) 232 | logits_per_image = logits_per_image.view(bs, -1,des_per_class) 233 | 234 | weight_normalized = F.softmax(logits_per_image, dim=2) 235 | label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype) 236 | 237 | for i in range(bs): 238 | reshaped_value = self.label_embed.view(-1, des_per_class, 512) 239 | product = weight_normalized[i].unsqueeze(-1) * reshaped_value 240 | label_embed_reweight[i] = product.sum(dim=1) 241 | 242 | label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight)) 243 | 244 | ##================= Image Tagging ================## 245 | 246 | tagging_embed = self.tagging_head( 247 | encoder_embeds=label_embed, 248 | encoder_hidden_states=image_embeds, 249 | encoder_attention_mask=image_atts, 250 | return_dict=False, 251 | mode='tagging', 252 | ) 253 | 254 | logits = self.fc(tagging_embed[0]).squeeze(-1) 255 | 256 | loss_tag = self.tagging_loss_function(logits, image_tag) 257 | 258 | ##================= Image-text Alignment ================## 259 | 260 | batch_text_embed = torch.nn.functional.relu(self.wordvec_proj(batch_text_embed.to(self.label_embed.dtype))) 261 | batch_text_embed = batch_text_embed.unsqueeze(0).repeat(bs, 1, 1) 262 | alignment_embedding = self.tagging_head( 263 | encoder_embeds=batch_text_embed, 264 | encoder_hidden_states=image_embeds, 265 | encoder_attention_mask=image_atts, 266 | return_dict=False, 267 | mode='tagging', 268 | ) 269 | alignment_logits = self.fc(alignment_embedding[0]).squeeze(-1) 270 | 271 | with torch.no_grad(): 272 | alignment_targets = torch.zeros(alignment_logits.size()).to(image.device) 273 | alignment_targets.fill_diagonal_(1) 274 | 275 | loss_alignment = self.text_alignment_loss_function(alignment_logits,alignment_targets) 276 | 277 | return loss_tag, loss_dis, loss_alignment 278 | 279 | 280 | def generate_tag(self, 281 | image 282 | ): 283 | 284 | image_embeds = self.image_proj(self.visual_encoder(image)) 285 | image_atts = torch.ones(image_embeds.size()[:-1], 286 | dtype=torch.long).to(image.device) 287 | 288 | image_cls_embeds = image_embeds[:, 0, :] 289 | image_spatial_embeds = image_embeds[:, 1:, :] 290 | 291 | bs = image_spatial_embeds.shape[0] 292 | 293 | des_per_class = int(self.label_embed.shape[0] / self.num_class) 294 | 295 | image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True) 296 | reweight_scale = self.reweight_scale.exp() 297 | logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t()) 298 | logits_per_image = logits_per_image.view(bs, -1,des_per_class) 299 | 300 | weight_normalized = 
F.softmax(logits_per_image, dim=2) 301 | label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype) 302 | 303 | for i in range(bs): 304 | # reshape value_ori here, then use broadcasting 305 | reshaped_value = self.label_embed.view(-1, des_per_class, 512) 306 | product = weight_normalized[i].unsqueeze(-1) * reshaped_value 307 | label_embed_reweight[i] = product.sum(dim=1) 308 | 309 | label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight)) 310 | 311 | # recognize image tags using the alignment decoder 312 | tagging_embed = self.tagging_head( 313 | encoder_embeds=label_embed, 314 | encoder_hidden_states=image_embeds, 315 | encoder_attention_mask=image_atts, 316 | return_dict=False, 317 | mode='tagging', 318 | ) 319 | 320 | logits = self.fc(tagging_embed[0]).squeeze(-1) 321 | 322 | targets = torch.where( 323 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 324 | torch.tensor(1.0).to(image.device), 325 | torch.zeros(self.num_class).to(image.device)) 326 | 327 | tag = targets.cpu().numpy() 328 | tag[:,self.delete_tag_index] = 0 329 | tag_output = [] 330 | tag_output_chinese = [] 331 | for b in range(bs): 332 | index = np.argwhere(tag[b] == 1) 333 | token = self.tag_list[index].squeeze(axis=1) 334 | tag_output.append(' | '.join(token)) 335 | token_chinese = self.tag_list_chinese[index].squeeze(axis=1) 336 | tag_output_chinese.append(' | '.join(token_chinese)) 337 | 338 | 339 | return tag_output, tag_output_chinese 340 | 341 | def generate_tag_openset(self, 342 | image, 343 | threshold=0.68, 344 | tag_input=None, 345 | ): 346 | 347 | image_embeds = self.image_proj(self.visual_encoder(image)) 348 | image_atts = torch.ones(image_embeds.size()[:-1], 349 | dtype=torch.long).to(image.device) 350 | 351 | image_cls_embeds = image_embeds[:, 0, :] 352 | image_spatial_embeds = image_embeds[:, 1:, :] 353 | 354 | bs = image_spatial_embeds.shape[0] 355 | 356 | des_per_class = int(self.label_embed.shape[0] / self.num_class) 357 | 358 | image_cls_embeds = image_cls_embeds / image_cls_embeds.norm(dim=-1, keepdim=True) 359 | reweight_scale = self.reweight_scale.exp() 360 | logits_per_image = (reweight_scale * image_cls_embeds @ self.label_embed.t()) 361 | logits_per_image = logits_per_image.view(bs, -1,des_per_class) 362 | 363 | weight_normalized = F.softmax(logits_per_image, dim=2) 364 | label_embed_reweight = torch.empty(bs, self.num_class, 512).to(image.device).to(image.dtype) 365 | 366 | for i in range(bs): 367 | # reshape value_ori here, then use broadcasting 368 | reshaped_value = self.label_embed.view(-1, des_per_class, 512) 369 | product = weight_normalized[i].unsqueeze(-1) * reshaped_value 370 | label_embed_reweight[i] = product.sum(dim=1) 371 | 372 | label_embed = torch.nn.functional.relu(self.wordvec_proj(label_embed_reweight)) 373 | 374 | # recognize image tags using the alignment decoder 375 | tagging_embed = self.tagging_head( 376 | encoder_embeds=label_embed, 377 | encoder_hidden_states=image_embeds, 378 | encoder_attention_mask=image_atts, 379 | return_dict=False, 380 | mode='tagging', 381 | ) 382 | 383 | logits = self.fc(tagging_embed[0]).squeeze(-1) 384 | 385 | targets = torch.where( 386 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 387 | torch.tensor(1.0).to(image.device), 388 | torch.zeros(self.num_class).to(image.device)) 389 | 390 | tag = targets.cpu().numpy() 391 | tag[:,self.delete_tag_index] = 0 392 | tag_output = [] 393 | for b in range(bs): 394 | index = np.argwhere(tag[b] == 1) 395 | token =
self.tag_list[index].squeeze(axis=1) 396 | tag_output.append(' | '.join(token)) 397 | 398 | return tag_output 399 | 400 | 401 | # load RAM++ pretrained model parameters 402 | def ram_plus(pretrained='', **kwargs): 403 | model = RAM_plus(**kwargs) 404 | if pretrained: 405 | if kwargs['vit'] == 'swin_b': 406 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) 407 | elif kwargs['vit'] == 'swin_l': 408 | model, msg = load_checkpoint_swinlarge(model, pretrained, kwargs) 409 | else: 410 | model, msg = load_checkpoint(model, pretrained) 411 | print('vit:', kwargs['vit']) 412 | # print('msg', msg) 413 | return model 414 | -------------------------------------------------------------------------------- /ram/models/tag2text.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * The Tag2Text Model 3 | * Written by Xinyu Huang 4 | ''' 5 | import numpy as np 6 | import json 7 | import torch 8 | import warnings 9 | 10 | from torch import nn 11 | from .bert import BertConfig, BertModel, BertLMHeadModel 12 | from .swin_transformer import SwinTransformer 13 | 14 | from .utils import * 15 | 16 | warnings.filterwarnings("ignore") 17 | 18 | 19 | class Tag2Text(nn.Module): 20 | 21 | def __init__(self, 22 | med_config=f'{CONFIG_PATH}/configs/med_config.json', 23 | image_size=384, 24 | text_encoder_type='bert-base-uncased', 25 | vit='base', 26 | vit_grad_ckpt=False, 27 | vit_ckpt_layer=0, 28 | prompt='a picture of ', 29 | threshold=0.68, 30 | delete_tag_index=[127,2961, 3351, 3265, 3338, 3355, 3359], 31 | tag_list=f'{CONFIG_PATH}/data/tag2text_ori_tag_list.txt', 32 | stage='eval'): 33 | r""" Tag2Text inference module, both captioning and tagging are included. 34 | Tag2Text is an efficient and controllable vision-language pre-training framework. 
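        A typical captioning-with-tagging call looks like the sketch below (the
        checkpoint path, device, image transform and image path are illustrative
        assumptions, not part of this file):
            model = tag2text(pretrained='tag2text_swin_14m.pth',
                             image_size=384, vit='swin_b').eval().to(device)
            image = transform(Image.open('demo.jpg')).unsqueeze(0).to(device)
            captions, tags = model.generate(image, return_tag_predict=True)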
35 | Described in the paper "Tag2Text: Guiding Vision-Language Model via Image Tagging" https://arxiv.org/abs/2303.05657 36 | 37 | Args: 38 | med_config (str): path for the mixture of encoder-decoder model's configuration file 39 | image_size (int): input image size 40 | vit (str): model size of vision transformer 41 | threshold (int): tagging threshold 42 | delete_tag_index (list): delete some tags that may disturb captioning 43 | """ 44 | super().__init__() 45 | 46 | # create image encoder 47 | if vit == 'swin_b': 48 | if image_size == 224: 49 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 50 | elif image_size == 384: 51 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 52 | vision_config = read_json(vision_config_path) 53 | assert image_size == vision_config['image_res'] 54 | # assert config['patch_size'] == 32 55 | vision_width = vision_config['vision_width'] 56 | 57 | self.visual_encoder = SwinTransformer( 58 | img_size=vision_config['image_res'], 59 | patch_size=4, 60 | in_chans=3, 61 | embed_dim=vision_config['embed_dim'], 62 | depths=vision_config['depths'], 63 | num_heads=vision_config['num_heads'], 64 | window_size=vision_config['window_size'], 65 | mlp_ratio=4., 66 | qkv_bias=True, 67 | drop_rate=0.0, 68 | drop_path_rate=0.1, 69 | ape=False, 70 | patch_norm=True, 71 | use_checkpoint=False) 72 | 73 | if stage == 'train_from_scratch': 74 | # download from https://github.com/microsoft/Swin-Transformer 75 | state_dict = torch.load(vision_config['ckpt'], map_location="cpu")['model'] 76 | 77 | for k in list(state_dict.keys()): 78 | if 'relative_position_bias_table' in k: 79 | dst_num_pos = (2 * vision_config['window_size'] - 1) ** 2 80 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], dst_num_pos, param_name=k) 81 | elif ('relative_position_index' in k) or ('attn_mask' in k): 82 | del state_dict[k] 83 | 84 | print("### Load Vision Backbone", vit) 85 | msg = self.visual_encoder.load_state_dict(state_dict, strict = False) 86 | print("missing_keys: ", msg.missing_keys) 87 | print("unexpected_keys: ", msg.unexpected_keys) 88 | 89 | else: 90 | self.visual_encoder, vision_width = create_vit( 91 | vit, image_size, vit_grad_ckpt, vit_ckpt_layer) 92 | 93 | # create tokenzier 94 | self.tokenizer = init_tokenizer(text_encoder_type) 95 | 96 | # Tag2Text employ encoder-decoder architecture for image-tag-text generation: image-tag interaction encoder and image-tag-text decoder 97 | # create image-tag interaction encoder 98 | encoder_config = BertConfig.from_json_file(med_config) 99 | encoder_config.encoder_width = vision_width 100 | self.tag_encoder = BertModel(config=encoder_config, 101 | add_pooling_layer=False) 102 | 103 | # create image-tag-text decoder 104 | decoder_config = BertConfig.from_json_file(med_config) 105 | self.text_decoder = BertLMHeadModel(config=decoder_config) 106 | 107 | # delete some tags that may disturb captioning 108 | # 127: "quarter"; 2961: "back"; 3351: "two"; 3265: "three"; 3338: "four"; 3355: "five"; 3359: "one" 109 | self.delete_tag_index = delete_tag_index 110 | self.prompt = prompt 111 | self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1 112 | 113 | # load tag list 114 | self.tag_list = self.load_tag_list(tag_list) 115 | 116 | # create image-tag recognition decoder 117 | self.threshold = threshold 118 | self.num_class = len(self.tag_list) 119 | q2l_config = BertConfig.from_json_file(f'{CONFIG_PATH}/configs/q2l_config.json') 120 | q2l_config.encoder_width = vision_width 121 | 
self.tagging_head = BertModel(config=q2l_config, 122 | add_pooling_layer=False) 123 | self.tagging_head.resize_token_embeddings(len(self.tokenizer)) 124 | self.label_embed = nn.Embedding(self.num_class, q2l_config.hidden_size) 125 | self.fc = GroupWiseLinear(self.num_class, 126 | q2l_config.hidden_size, 127 | bias=True) 128 | self.del_selfattention() 129 | 130 | self.tagging_loss_function = AsymmetricLoss(gamma_neg=7, 131 | gamma_pos=0, 132 | clip=0.05) 133 | 134 | # share weights of the lowest 2-layer of "image-tag interaction encoder" with the "image-tag recogntion decoder" 135 | tie_encoder_decoder_weights(self.tag_encoder, self.tagging_head, '', 136 | ' ') 137 | 138 | # adjust thresholds for some tags 139 | # default threshold: 0.68 140 | # 2701: "person"; 2828: "man"; 1167: "woman"; 141 | tag_thrshold = {2701:0.7, 2828: 0.7, 1167: 0.7} 142 | self.class_threshold = torch.ones(self.num_class) * self.threshold 143 | for key,value in tag_thrshold.items(): 144 | self.class_threshold[key] = value 145 | 146 | def load_tag_list(self, tag_list_file): 147 | with open(tag_list_file, 'r') as f: 148 | tag_list = f.read().splitlines() 149 | tag_list = np.array(tag_list) 150 | return tag_list 151 | 152 | # delete self-attention layer of image-tag recognition decoder to reduce computation, follower Query2Label 153 | def del_selfattention(self): 154 | del self.tagging_head.embeddings 155 | for layer in self.tagging_head.encoder.layer: 156 | del layer.attention 157 | 158 | 159 | def forward(self, image, caption, tag): 160 | """ 161 | call function as forward 162 | 163 | Args: 164 | image: type: torch.Tensor shape: batch_size * 3 * 384 * 384 165 | caption: type: list[string] len: batch_size 166 | tag: type: torch.Tensor shape: batch * class_num (e.g. 3429) value: positive sample is 1.0, negative sample is 0.0 167 | 168 | Returns: 169 | loss: type: torch.Tensor 170 | """ 171 | 172 | image_embeds = self.visual_encoder(image) 173 | image_atts = torch.ones(image_embeds.size()[:-1], 174 | dtype=torch.long).to(image.device) 175 | 176 | ##================= Image Tagging ================## 177 | bs = image_embeds.shape[0] 178 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) 179 | 180 | tagging_embed = self.tagging_head( 181 | encoder_embeds=label_embed, 182 | encoder_hidden_states=image_embeds, 183 | encoder_attention_mask=image_atts, 184 | return_dict=False, 185 | mode='tagging', 186 | ) 187 | 188 | logits = self.fc(tagging_embed[0]) 189 | 190 | loss_tag = self.tagging_loss_function(logits, tag) 191 | 192 | ##================= Image-Tag-Text Generation ================## 193 | tag = tag.cpu().numpy() 194 | tag_input = [] 195 | for b in range(bs): 196 | index = np.argwhere(tag[b] == 1) 197 | token = self.tag_list[index].squeeze(axis=1) 198 | tag_input.append(' | '.join(token)) 199 | 200 | # tokenizer input tags 201 | tag_input_tokenzier = self.tokenizer(tag_input, 202 | padding='max_length', 203 | truncation=True, 204 | max_length=40, 205 | return_tensors="pt").to( 206 | image.device) 207 | encoder_input_ids = tag_input_tokenzier.input_ids 208 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id 209 | 210 | # put input tag into image-tag interaction encoder to interact with image embeddings 211 | output_tagembedding = self.tag_encoder( 212 | encoder_input_ids, 213 | attention_mask=tag_input_tokenzier.attention_mask, 214 | encoder_hidden_states=image_embeds, 215 | encoder_attention_mask=image_atts, 216 | return_dict=True, 217 | ) 218 | 219 | text = self.tokenizer(caption, 220 | 
padding='longest', 221 | truncation=True, 222 | max_length=40, 223 | return_tensors="pt").to( 224 | image.device) 225 | 226 | decoder_input_ids = text.input_ids 227 | decoder_input_ids[:,0] = self.tokenizer.bos_token_id 228 | 229 | decoder_targets = decoder_input_ids.masked_fill( 230 | decoder_input_ids == self.tokenizer.pad_token_id, -100) 231 | decoder_targets[:,:self.prompt_length] = -100 232 | 233 | decoder_output = self.text_decoder(decoder_input_ids, 234 | attention_mask = text.attention_mask, 235 | encoder_hidden_states = output_tagembedding.last_hidden_state, 236 | encoder_attention_mask = None, 237 | labels = decoder_targets, 238 | return_dict = True, 239 | ) 240 | 241 | loss_t2t = decoder_output.loss 242 | 243 | return loss_t2t, loss_tag 244 | 245 | 246 | def generate(self, 247 | image, 248 | sample=False, 249 | num_beams=3, 250 | max_length=30, 251 | min_length=10, 252 | top_p=0.9, 253 | repetition_penalty=1.0, 254 | tag_input=None, 255 | return_tag_predict=False): 256 | 257 | image_embeds = self.visual_encoder(image) 258 | image_atts = torch.ones(image_embeds.size()[:-1], 259 | dtype=torch.long).to(image.device) 260 | 261 | # if not user specified tags, recognized image tags using image-tag recogntiion decoder 262 | if tag_input == None: 263 | 264 | bs = image_embeds.shape[0] 265 | label_embed = self.label_embed.weight.unsqueeze(0).repeat(bs, 1, 1) 266 | tagging_embed = self.tagging_head( 267 | encoder_embeds=label_embed, 268 | encoder_hidden_states=image_embeds, 269 | encoder_attention_mask=image_atts, 270 | return_dict=False, 271 | mode='tagging', 272 | ) 273 | 274 | logits = self.fc(tagging_embed[0]) 275 | 276 | targets = torch.where( 277 | torch.sigmoid(logits) > self.class_threshold.to(image.device), 278 | torch.tensor(1.0).to(image.device), 279 | torch.zeros(self.num_class).to(image.device)) 280 | 281 | tag = targets.cpu().numpy() 282 | 283 | # delete some tags that may disturb captioning 284 | tag[:, self.delete_tag_index] = 0 285 | 286 | tag_input = [] 287 | for b in range(bs): 288 | index = np.argwhere(tag[b] == 1) 289 | token = self.tag_list[index].squeeze(axis=1) 290 | tag_input.append(' | '.join(token)) 291 | 292 | tag_output = tag_input 293 | 294 | # beam search for text generation(default) 295 | if not sample: 296 | image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) 297 | tag_input_temp = [] 298 | for tag in tag_input: 299 | for i in range(num_beams): 300 | tag_input_temp.append(tag) 301 | tag_input = tag_input_temp 302 | 303 | image_atts = torch.ones(image_embeds.size()[:-1], 304 | dtype=torch.long).to(image.device) 305 | 306 | # tokenizer input tags 307 | tag_input_tokenzier = self.tokenizer(tag_input, 308 | padding='max_length', 309 | truncation=True, 310 | max_length=40, 311 | return_tensors="pt").to( 312 | image.device) 313 | encoder_input_ids = tag_input_tokenzier.input_ids 314 | encoder_input_ids[:, 0] = self.tokenizer.enc_token_id 315 | 316 | # put input tag into image-tag interaction encoder to interact with image embeddings 317 | output_tagembedding = self.tag_encoder( 318 | encoder_input_ids, 319 | attention_mask=tag_input_tokenzier.attention_mask, 320 | encoder_hidden_states=image_embeds, 321 | encoder_attention_mask=image_atts, 322 | return_dict=True, 323 | ) 324 | 325 | # prompt trick for better captioning, followed BLIP 326 | prompt = [self.prompt] * image.size(0) 327 | input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( 328 | image.device) 329 | input_ids[:, 0] = self.tokenizer.bos_token_id 330 | input_ids = 
input_ids[:, :-1] 331 | 332 | if sample: 333 | # nucleus sampling 334 | model_kwargs = { 335 | "encoder_hidden_states": output_tagembedding.last_hidden_state, 336 | "encoder_attention_mask": None 337 | } 338 | outputs = self.text_decoder.generate( 339 | input_ids=input_ids, 340 | max_length=max_length, 341 | min_length=min_length, 342 | do_sample=True, 343 | top_p=top_p, 344 | num_return_sequences=1, 345 | eos_token_id=self.tokenizer.sep_token_id, 346 | pad_token_id=self.tokenizer.pad_token_id, 347 | repetition_penalty=1.1, 348 | **model_kwargs) 349 | else: 350 | # beam search (default) 351 | model_kwargs = { 352 | "encoder_hidden_states": output_tagembedding.last_hidden_state, 353 | "encoder_attention_mask": None 354 | } 355 | outputs = self.text_decoder.generate( 356 | input_ids=input_ids, 357 | max_length=max_length, 358 | min_length=min_length, 359 | num_beams=num_beams, 360 | eos_token_id=self.tokenizer.sep_token_id, 361 | pad_token_id=self.tokenizer.pad_token_id, 362 | repetition_penalty=repetition_penalty, 363 | **model_kwargs) 364 | 365 | captions = [] 366 | for output in outputs: 367 | caption = self.tokenizer.decode(output, skip_special_tokens=True) 368 | captions.append(caption[len(self.prompt):]) 369 | if return_tag_predict == True: 370 | return captions, tag_output 371 | return captions 372 | 373 | 374 | # load Tag2Text pretrained model parameters 375 | def tag2text(pretrained='', **kwargs): 376 | model = Tag2Text(**kwargs) 377 | if pretrained: 378 | if kwargs['vit'] == 'swin_b': 379 | model, msg = load_checkpoint_swinbase(model, pretrained, kwargs) 380 | else: 381 | model, msg = load_checkpoint(model, pretrained) 382 | print('vit:', kwargs['vit']) 383 | # print('msg', msg) 384 | return model 385 | 386 | -------------------------------------------------------------------------------- /ram/models/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import math 5 | 6 | from torch import nn 7 | from typing import List 8 | from transformers import BertTokenizer 9 | from urllib.parse import urlparse 10 | from timm.models.hub import download_cached_file 11 | from .vit import interpolate_pos_embed, VisionTransformer 12 | from .swin_transformer import interpolate_relative_pos_embed 13 | from pathlib import Path 14 | CONFIG_PATH=(Path(__file__).resolve().parents[1]) 15 | 16 | def read_json(rpath): 17 | with open(rpath, 'r') as f: 18 | return json.load(f) 19 | 20 | 21 | def tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, 22 | base_model_prefix: str, skip_key: str): 23 | uninitialized_encoder_weights: List[str] = [] 24 | if decoder.__class__ != encoder.__class__: 25 | logger.info( 26 | f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder weights are correctly initialized." 
27 | ) 28 | 29 | def tie_encoder_to_decoder_recursively( 30 | decoder_pointer: nn.Module, 31 | encoder_pointer: nn.Module, 32 | module_name: str, 33 | uninitialized_encoder_weights: List[str], 34 | skip_key: str, 35 | depth=0, 36 | ): 37 | assert isinstance(decoder_pointer, nn.Module) and isinstance( 38 | encoder_pointer, nn.Module 39 | ), f"{decoder_pointer} and {encoder_pointer} have to be of type torch.nn.Module" 40 | if hasattr(decoder_pointer, "weight") and skip_key not in module_name: 41 | assert hasattr(encoder_pointer, "weight") 42 | encoder_pointer.weight = decoder_pointer.weight 43 | if hasattr(decoder_pointer, "bias"): 44 | assert hasattr(encoder_pointer, "bias") 45 | encoder_pointer.bias = decoder_pointer.bias 46 | print(module_name + ' is tied') 47 | return 48 | 49 | encoder_modules = encoder_pointer._modules 50 | decoder_modules = decoder_pointer._modules 51 | if len(decoder_modules) > 0: 52 | assert ( 53 | len(encoder_modules) > 0 54 | ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" 55 | 56 | all_encoder_weights = set([ 57 | module_name + "/" + sub_name 58 | for sub_name in encoder_modules.keys() 59 | ]) 60 | encoder_layer_pos = 0 61 | for name, module in decoder_modules.items(): 62 | if name.isdigit(): 63 | encoder_name = str(int(name) + encoder_layer_pos) 64 | decoder_name = name 65 | if not isinstance( 66 | decoder_modules[decoder_name], 67 | type(encoder_modules[encoder_name])) and len( 68 | encoder_modules) != len(decoder_modules): 69 | # this can happen if the name corresponds to the position in a list module list of layers 70 | # in this case the decoder has added a cross-attention that the encoder does not have 71 | # thus skip this step and subtract one layer pos from encoder 72 | encoder_layer_pos -= 1 73 | continue 74 | elif name not in encoder_modules: 75 | continue 76 | elif depth > 500: 77 | raise ValueError( 78 | "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is a circular dependency between two or more `nn.Modules` of your model." 79 | ) 80 | else: 81 | decoder_name = encoder_name = name 82 | tie_encoder_to_decoder_recursively( 83 | decoder_modules[decoder_name], 84 | encoder_modules[encoder_name], 85 | module_name + "/" + name, 86 | uninitialized_encoder_weights, 87 | skip_key, 88 | depth=depth + 1, 89 | ) 90 | all_encoder_weights.remove(module_name + "/" + encoder_name) 91 | 92 | uninitialized_encoder_weights += list(all_encoder_weights) 93 | 94 | # tie weights recursively 95 | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, 96 | uninitialized_encoder_weights, skip_key) 97 | 98 | 99 | class GroupWiseLinear(nn.Module): 100 | # could be changed to: 101 | # output = torch.einsum('ijk,zjk->ij', x, self.W) 102 | # or output = torch.einsum('ijk,jk->ij', x, self.W[0]) 103 | def __init__(self, num_class, hidden_dim, bias=True): 104 | super().__init__() 105 | self.num_class = num_class 106 | self.hidden_dim = hidden_dim 107 | self.bias = bias 108 | 109 | self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim)) 110 | if bias: 111 | self.b = nn.Parameter(torch.Tensor(1, num_class)) 112 | self.reset_parameters() 113 | 114 | def reset_parameters(self): 115 | stdv = 1. 
/ math.sqrt(self.W.size(2)) 116 | for i in range(self.num_class): 117 | self.W[0][i].data.uniform_(-stdv, stdv) 118 | if self.bias: 119 | for i in range(self.num_class): 120 | self.b[0][i].data.uniform_(-stdv, stdv) 121 | 122 | def forward(self, x): 123 | # x: B,K,d 124 | x = (self.W * x).sum(-1) 125 | if self.bias: 126 | x = x + self.b 127 | return x 128 | 129 | 130 | def init_tokenizer(text_encoder_type='bert-base-uncased'): 131 | tokenizer = BertTokenizer.from_pretrained(text_encoder_type) 132 | tokenizer.add_special_tokens({'bos_token': '[DEC]'}) 133 | tokenizer.add_special_tokens({'additional_special_tokens': ['[ENC]']}) 134 | tokenizer.enc_token_id = tokenizer.additional_special_tokens_ids[0] 135 | return tokenizer 136 | 137 | 138 | def create_vit(vit, 139 | image_size, 140 | use_grad_checkpointing=False, 141 | ckpt_layer=0, 142 | drop_path_rate=0): 143 | 144 | assert vit in ['base', 'large'], "vit parameter must be base or large" 145 | if vit == 'base': 146 | vision_width = 768 147 | visual_encoder = VisionTransformer( 148 | img_size=image_size, 149 | patch_size=16, 150 | embed_dim=vision_width, 151 | depth=12, 152 | num_heads=12, 153 | use_grad_checkpointing=use_grad_checkpointing, 154 | ckpt_layer=ckpt_layer, 155 | drop_path_rate=0 or drop_path_rate) 156 | elif vit == 'large': 157 | vision_width = 1024 158 | visual_encoder = VisionTransformer( 159 | img_size=image_size, 160 | patch_size=16, 161 | embed_dim=vision_width, 162 | depth=24, 163 | num_heads=16, 164 | use_grad_checkpointing=use_grad_checkpointing, 165 | ckpt_layer=ckpt_layer, 166 | drop_path_rate=0.1 or drop_path_rate) 167 | return visual_encoder, vision_width 168 | 169 | 170 | def is_url(url_or_filename): 171 | parsed = urlparse(url_or_filename) 172 | return parsed.scheme in ("http", "https") 173 | 174 | 175 | def load_checkpoint(model, url_or_filename): 176 | if is_url(url_or_filename): 177 | cached_file = download_cached_file(url_or_filename, 178 | check_hash=False, 179 | progress=True) 180 | checkpoint = torch.load(cached_file, map_location='cpu') 181 | elif os.path.isfile(url_or_filename): 182 | checkpoint = torch.load(url_or_filename, map_location='cpu') 183 | else: 184 | raise RuntimeError('checkpoint url or path is invalid') 185 | 186 | state_dict = checkpoint['model'] 187 | 188 | state_dict['visual_encoder.pos_embed'] = interpolate_pos_embed( 189 | state_dict['visual_encoder.pos_embed'], model.visual_encoder) 190 | if 'visual_encoder_m.pos_embed' in model.state_dict().keys(): 191 | state_dict['visual_encoder_m.pos_embed'] = interpolate_pos_embed( 192 | state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m) 193 | for key in model.state_dict().keys(): 194 | if key in state_dict.keys(): 195 | if state_dict[key].shape != model.state_dict()[key].shape: 196 | del state_dict[key] 197 | 198 | msg = model.load_state_dict(state_dict, strict=False) 199 | print('load checkpoint from %s' % url_or_filename) 200 | return model, msg 201 | 202 | 203 | def load_checkpoint_swinbase(model, url_or_filename, kwargs): 204 | if kwargs['image_size'] == 224: 205 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_224.json' 206 | elif kwargs['image_size'] == 384: 207 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinB_384.json' 208 | window_size = read_json(vision_config_path)['window_size'] 209 | print('--------------') 210 | print(url_or_filename) 211 | print('--------------') 212 | if is_url(url_or_filename): 213 | cached_file = download_cached_file(url_or_filename, 214 | check_hash=False, 215 | 
progress=True) 216 | checkpoint = torch.load(cached_file, map_location='cpu') 217 | elif os.path.isfile(url_or_filename): 218 | checkpoint = torch.load(url_or_filename, map_location='cpu') 219 | else: 220 | raise RuntimeError('checkpoint url or path is invalid') 221 | 222 | state_dict = checkpoint['model'] 223 | 224 | for k in list(state_dict.keys()): 225 | if 'relative_position_bias_table' in k: 226 | dst_num_pos = (2 * window_size - 1)**2 227 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], 228 | dst_num_pos, 229 | param_name=k) 230 | elif ('relative_position_index' in k) or ('attn_mask' in k): 231 | del state_dict[k] 232 | elif "vision_multi" in k: 233 | state_dict[k.replace("vision_multi", 234 | "tagging_head")] = state_dict.pop(k) 235 | 236 | msg = model.load_state_dict(state_dict, strict=False) 237 | print('load checkpoint from %s' % url_or_filename) 238 | return model, msg 239 | 240 | 241 | def load_checkpoint_swinlarge(model, url_or_filename, kwargs): 242 | if kwargs['image_size'] == 224: 243 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_224.json' 244 | elif kwargs['image_size'] == 384: 245 | vision_config_path = f'{CONFIG_PATH}/configs/swin/config_swinL_384.json' 246 | window_size = read_json(vision_config_path)['window_size'] 247 | print('--------------') 248 | print(url_or_filename) 249 | print('--------------') 250 | if is_url(url_or_filename): 251 | cached_file = download_cached_file(url_or_filename, 252 | check_hash=False, 253 | progress=True) 254 | checkpoint = torch.load(cached_file, map_location='cpu') 255 | elif os.path.isfile(url_or_filename): 256 | checkpoint = torch.load(url_or_filename, map_location='cpu') 257 | else: 258 | raise RuntimeError('checkpoint url or path is invalid') 259 | 260 | state_dict = checkpoint['model'] 261 | 262 | for k in list(state_dict.keys()): 263 | if 'relative_position_bias_table' in k: 264 | dst_num_pos = (2 * window_size - 1)**2 265 | state_dict[k] = interpolate_relative_pos_embed(state_dict[k], 266 | dst_num_pos, 267 | param_name=k) 268 | elif ('relative_position_index' in k) or ('attn_mask' in k): 269 | del state_dict[k] 270 | elif "vision_multi" in k: 271 | state_dict[k.replace("vision_multi", 272 | "tagging_head")] = state_dict.pop(k) 273 | 274 | msg = model.load_state_dict(state_dict, strict=False) 275 | print('load checkpoint from %s' % url_or_filename) 276 | return model, msg 277 | 278 | 279 | # Tagging loss function 280 | # copy from https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py 281 | class AsymmetricLoss(nn.Module): 282 | def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True): 283 | super(AsymmetricLoss, self).__init__() 284 | 285 | self.gamma_neg = gamma_neg 286 | self.gamma_pos = gamma_pos 287 | self.clip = clip 288 | self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss 289 | self.eps = eps 290 | 291 | def forward(self, x, y): 292 | """" 293 | Parameters 294 | ---------- 295 | x: input logits 296 | y: targets (multi-label binarized vector) 297 | """ 298 | 299 | # Calculating Probabilities 300 | x_sigmoid = torch.sigmoid(x) 301 | xs_pos = x_sigmoid 302 | xs_neg = 1 - x_sigmoid 303 | 304 | # Asymmetric Clipping 305 | if self.clip is not None and self.clip > 0: 306 | xs_neg = (xs_neg + self.clip).clamp(max=1) 307 | 308 | # Basic CE calculation 309 | los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) 310 | los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) 311 | loss = los_pos + los_neg 312 | 313 | # 
Asymmetric Focusing 314 | if self.gamma_neg > 0 or self.gamma_pos > 0: 315 | if self.disable_torch_grad_focal_loss: 316 | torch.set_grad_enabled(False) 317 | pt0 = xs_pos * y 318 | pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p 319 | pt = pt0 + pt1 320 | one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) 321 | one_sided_w = torch.pow(1 - pt, one_sided_gamma) 322 | if self.disable_torch_grad_focal_loss: 323 | torch.set_grad_enabled(True) 324 | loss *= one_sided_w 325 | 326 | return -loss.sum() 327 | -------------------------------------------------------------------------------- /ram/models/vit.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2022, salesforce.com, inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: BSD-3-Clause 5 | * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | * By Junnan Li 7 | * Based on timm code base 8 | * https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | ''' 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from functools import partial 15 | 16 | from timm.models.vision_transformer import _cfg, PatchEmbed 17 | from timm.models.registry import register_model 18 | from timm.models.layers import trunc_normal_, DropPath 19 | from timm.models.helpers import named_apply, adapt_input_conv 20 | 21 | from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper 22 | 23 | class Mlp(nn.Module): 24 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 25 | """ 26 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 27 | super().__init__() 28 | out_features = out_features or in_features 29 | hidden_features = hidden_features or in_features 30 | self.fc1 = nn.Linear(in_features, hidden_features) 31 | self.act = act_layer() 32 | self.fc2 = nn.Linear(hidden_features, out_features) 33 | self.drop = nn.Dropout(drop) 34 | 35 | def forward(self, x): 36 | x = self.fc1(x) 37 | x = self.act(x) 38 | x = self.drop(x) 39 | x = self.fc2(x) 40 | x = self.drop(x) 41 | return x 42 | 43 | 44 | class Attention(nn.Module): 45 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 46 | super().__init__() 47 | self.num_heads = num_heads 48 | head_dim = dim // num_heads 49 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 50 | self.scale = qk_scale or head_dim ** -0.5 51 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 52 | self.attn_drop = nn.Dropout(attn_drop) 53 | self.proj = nn.Linear(dim, dim) 54 | self.proj_drop = nn.Dropout(proj_drop) 55 | self.attn_gradients = None 56 | self.attention_map = None 57 | 58 | def save_attn_gradients(self, attn_gradients): 59 | self.attn_gradients = attn_gradients 60 | 61 | def get_attn_gradients(self): 62 | return self.attn_gradients 63 | 64 | def save_attention_map(self, attention_map): 65 | self.attention_map = attention_map 66 | 67 | def get_attention_map(self): 68 | return self.attention_map 69 | 70 | def forward(self, x, register_hook=False): 71 | B, N, C = x.shape 72 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 73 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 74 | 75 | attn = (q @ k.transpose(-2, -1)) * self.scale 76 | attn = attn.softmax(dim=-1) 77 | attn = 
self.attn_drop(attn) 78 | 79 | if register_hook: 80 | self.save_attention_map(attn) 81 | attn.register_hook(self.save_attn_gradients) 82 | 83 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 84 | x = self.proj(x) 85 | x = self.proj_drop(x) 86 | return x 87 | 88 | 89 | class Block(nn.Module): 90 | 91 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 92 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_grad_checkpointing=False): 93 | super().__init__() 94 | self.norm1 = norm_layer(dim) 95 | self.attn = Attention( 96 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 97 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 98 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 99 | self.norm2 = norm_layer(dim) 100 | mlp_hidden_dim = int(dim * mlp_ratio) 101 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 102 | 103 | if use_grad_checkpointing: 104 | self.attn = checkpoint_wrapper(self.attn) 105 | self.mlp = checkpoint_wrapper(self.mlp) 106 | 107 | def forward(self, x, register_hook=False): 108 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) 109 | x = x + self.drop_path(self.mlp(self.norm2(x))) 110 | return x 111 | 112 | 113 | class VisionTransformer(nn.Module): 114 | """ Vision Transformer 115 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 116 | https://arxiv.org/abs/2010.11929 117 | """ 118 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 119 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 120 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, 121 | use_grad_checkpointing=False, ckpt_layer=0): 122 | """ 123 | Args: 124 | img_size (int, tuple): input image size 125 | patch_size (int, tuple): patch size 126 | in_chans (int): number of input channels 127 | num_classes (int): number of classes for classification head 128 | embed_dim (int): embedding dimension 129 | depth (int): depth of transformer 130 | num_heads (int): number of attention heads 131 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 132 | qkv_bias (bool): enable bias for qkv if True 133 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 134 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 135 | drop_rate (float): dropout rate 136 | attn_drop_rate (float): attention dropout rate 137 | drop_path_rate (float): stochastic depth rate 138 | norm_layer: (nn.Module): normalization layer 139 | """ 140 | super().__init__() 141 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 142 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 143 | 144 | self.patch_embed = PatchEmbed( 145 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 146 | 147 | num_patches = self.patch_embed.num_patches 148 | 149 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 150 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 151 | self.pos_drop = nn.Dropout(p=drop_rate) 152 | 153 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 154 | self.blocks = nn.ModuleList([ 155 | Block( 156 | 
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 157 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 158 | use_grad_checkpointing=(use_grad_checkpointing and i>=depth-ckpt_layer) 159 | ) 160 | for i in range(depth)]) 161 | self.norm = norm_layer(embed_dim) 162 | 163 | trunc_normal_(self.pos_embed, std=.02) 164 | trunc_normal_(self.cls_token, std=.02) 165 | self.apply(self._init_weights) 166 | 167 | def _init_weights(self, m): 168 | if isinstance(m, nn.Linear): 169 | trunc_normal_(m.weight, std=.02) 170 | if isinstance(m, nn.Linear) and m.bias is not None: 171 | nn.init.constant_(m.bias, 0) 172 | elif isinstance(m, nn.LayerNorm): 173 | nn.init.constant_(m.bias, 0) 174 | nn.init.constant_(m.weight, 1.0) 175 | 176 | @torch.jit.ignore 177 | def no_weight_decay(self): 178 | return {'pos_embed', 'cls_token'} 179 | 180 | def forward(self, x, register_blk=-1): 181 | B = x.shape[0] 182 | x = self.patch_embed(x) 183 | 184 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 185 | x = torch.cat((cls_tokens, x), dim=1) 186 | 187 | x = x + self.pos_embed[:,:x.size(1),:] 188 | x = self.pos_drop(x) 189 | 190 | for i,blk in enumerate(self.blocks): 191 | x = blk(x, register_blk==i) 192 | x = self.norm(x) 193 | 194 | return x 195 | 196 | @torch.jit.ignore() 197 | def load_pretrained(self, checkpoint_path, prefix=''): 198 | _load_weights(self, checkpoint_path, prefix) 199 | 200 | 201 | @torch.no_grad() 202 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 203 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 204 | """ 205 | import numpy as np 206 | 207 | def _n2p(w, t=True): 208 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 209 | w = w.flatten() 210 | if t: 211 | if w.ndim == 4: 212 | w = w.transpose([3, 2, 0, 1]) 213 | elif w.ndim == 3: 214 | w = w.transpose([2, 0, 1]) 215 | elif w.ndim == 2: 216 | w = w.transpose([1, 0]) 217 | return torch.from_numpy(w) 218 | 219 | w = np.load(checkpoint_path) 220 | if not prefix and 'opt/target/embedding/kernel' in w: 221 | prefix = 'opt/target/' 222 | 223 | if hasattr(model.patch_embed, 'backbone'): 224 | # hybrid 225 | backbone = model.patch_embed.backbone 226 | stem_only = not hasattr(backbone, 'stem') 227 | stem = backbone if stem_only else backbone.stem 228 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 229 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 230 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 231 | if not stem_only: 232 | for i, stage in enumerate(backbone.stages): 233 | for j, block in enumerate(stage.blocks): 234 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 235 | for r in range(3): 236 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 237 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 238 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 239 | if block.downsample is not None: 240 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 241 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 242 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 243 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 244 | else: 245 | embed_conv_w = adapt_input_conv( 246 | model.patch_embed.proj.weight.shape[1], 
_n2p(w[f'{prefix}embedding/kernel'])) 247 | model.patch_embed.proj.weight.copy_(embed_conv_w) 248 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 249 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 250 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 251 | if pos_embed_w.shape != model.pos_embed.shape: 252 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights 253 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 254 | model.pos_embed.copy_(pos_embed_w) 255 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 256 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 257 | # if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 258 | # model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 259 | # model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 260 | # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 261 | # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 262 | # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 263 | for i, block in enumerate(model.blocks.children()): 264 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 265 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 266 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 267 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 268 | block.attn.qkv.weight.copy_(torch.cat([ 269 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 270 | block.attn.qkv.bias.copy_(torch.cat([ 271 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 272 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 273 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 274 | for r in range(2): 275 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 276 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 277 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 278 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 279 | 280 | 281 | def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): 282 | # interpolate position embedding 283 | embedding_size = pos_embed_checkpoint.shape[-1] 284 | num_patches = visual_encoder.patch_embed.num_patches 285 | num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches 286 | # height (== width) for the checkpoint position embedding 287 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 288 | # height (== width) for the new position embedding 289 | new_size = int(num_patches ** 0.5) 290 | 291 | if orig_size!=new_size: 292 | # class_token and dist_token are kept unchanged 293 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 294 | # only the position tokens are interpolated 295 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 296 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 297 | pos_tokens = torch.nn.functional.interpolate( 298 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 299 | pos_tokens = pos_tokens.permute(0, 2, 3, 
1).flatten(1, 2) 300 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 301 | print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) 302 | 303 | return new_pos_embed 304 | else: 305 | return pos_embed_checkpoint -------------------------------------------------------------------------------- /ram/transform.py: -------------------------------------------------------------------------------- 1 | from torchvision.transforms import Normalize, Compose, Resize, ToTensor 2 | 3 | 4 | def convert_to_rgb(image): 5 | return image.convert("RGB") 6 | 7 | def get_transform(image_size=384): 8 | return Compose([ 9 | convert_to_rgb, 10 | Resize((image_size, image_size)), 11 | ToTensor(), 12 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 13 | ]) 14 | -------------------------------------------------------------------------------- /ram/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import get_mAP, get_PR 2 | from .openset_utils import build_openset_label_embedding, build_openset_llm_label_embedding 3 | -------------------------------------------------------------------------------- /ram/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import numpy as np 4 | from numpy import ndarray 5 | 6 | 7 | def get_mAP( 8 | preds: ndarray, 9 | gt_file: str, 10 | taglist: List[str] 11 | ) -> Tuple[float, ndarray]: 12 | assert preds.shape[1] == len(taglist) 13 | 14 | # When mapping categories from test datasets to our system, there might be 15 | # multiple vs one situation due to different semantic definitions of tags. 16 | # So there can be duplicate tags in `taglist`. This special case is taken 17 | # into account. 18 | tag2idxs = {} 19 | for idx, tag in enumerate(taglist): 20 | if tag not in tag2idxs: 21 | tag2idxs[tag] = [] 22 | tag2idxs[tag].append(idx) 23 | 24 | # build targets 25 | targets = np.zeros_like(preds) 26 | with open(gt_file, "r") as f: 27 | lines = [line.strip("\n").split(",") for line in f.readlines()] 28 | assert len(lines) == targets.shape[0] 29 | for i, line in enumerate(lines): 30 | for tag in line[1:]: 31 | targets[i, tag2idxs[tag]] = 1.0 32 | 33 | # compute average precision for each class 34 | APs = np.zeros(preds.shape[1]) 35 | for k in range(preds.shape[1]): 36 | APs[k] = _average_precision(preds[:, k], targets[:, k]) 37 | 38 | return APs.mean(), APs 39 | 40 | 41 | def _average_precision(output: ndarray, target: ndarray) -> float: 42 | epsilon = 1e-8 43 | 44 | # sort examples 45 | indices = output.argsort()[::-1] 46 | # Computes prec@i 47 | total_count_ = np.cumsum(np.ones((len(output), 1))) 48 | 49 | target_ = target[indices] 50 | ind = target_ == 1 51 | pos_count_ = np.cumsum(ind) 52 | total = pos_count_[-1] 53 | pos_count_[np.logical_not(ind)] = 0 54 | pp = pos_count_ / total_count_ 55 | precision_at_i_ = np.sum(pp) 56 | precision_at_i = precision_at_i_ / (total + epsilon) 57 | 58 | return precision_at_i 59 | 60 | 61 | def get_PR( 62 | pred_file: str, 63 | gt_file: str, 64 | taglist: List[str] 65 | ) -> Tuple[float, float, ndarray, ndarray]: 66 | # When mapping categories from test datasets to our system, there might be 67 | # multiple vs one situation due to different semantic definitions of tags. 68 | # So there can be duplicate tags in `taglist`. This special case is taken 69 | # into account. 
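    # Illustrative example with hypothetical values: taglist = ["person", "man", "person"]
    # yields tag2idxs = {"person": [0, 2], "man": [1]}, so a single predicted or annotated
    # tag marks every column that shares its name.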
70 | tag2idxs = {} 71 | for idx, tag in enumerate(taglist): 72 | if tag not in tag2idxs: 73 | tag2idxs[tag] = [] 74 | tag2idxs[tag].append(idx) 75 | 76 | # build preds 77 | with open(pred_file, "r", encoding="utf-8") as f: 78 | lines = [line.strip().split(",") for line in f.readlines()] 79 | preds = np.zeros((len(lines), len(tag2idxs)), dtype=bool) 80 | for i, line in enumerate(lines): 81 | for tag in line[1:]: 82 | preds[i, tag2idxs[tag]] = True 83 | 84 | # build targets 85 | with open(gt_file, "r", encoding="utf-8") as f: 86 | lines = [line.strip().split(",") for line in f.readlines()] 87 | targets = np.zeros((len(lines), len(tag2idxs)), dtype=bool) 88 | for i, line in enumerate(lines): 89 | for tag in line[1:]: 90 | targets[i, tag2idxs[tag]] = True 91 | 92 | assert preds.shape == targets.shape 93 | 94 | # calculate P and R 95 | TPs = ( preds & targets).sum(axis=0) # noqa: E201, E222 96 | FPs = ( preds & ~targets).sum(axis=0) # noqa: E201, E222 97 | FNs = (~preds & targets).sum(axis=0) # noqa: E201, E222 98 | eps = 1.e-9 99 | Ps = TPs / (TPs + FPs + eps) 100 | Rs = TPs / (TPs + FNs + eps) 101 | 102 | return Ps.mean(), Rs.mean(), Ps, Rs 103 | -------------------------------------------------------------------------------- /ram/utils/openset_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | from clip import clip 7 | 8 | 9 | def article(name): 10 | return "an" if name[0] in "aeiou" else "a" 11 | 12 | 13 | def processed_name(name, rm_dot=False): 14 | # _ for lvis 15 | # / for obj365 16 | res = name.replace("_", " ").replace("/", " or ").lower() 17 | if rm_dot: 18 | res = res.rstrip(".") 19 | return res 20 | 21 | 22 | single_template = ["a photo of a {}."] 23 | 24 | multiple_templates = [ 25 | "There is {article} {} in the scene.", 26 | "There is the {} in the scene.", 27 | "a photo of {article} {} in the scene.", 28 | "a photo of the {} in the scene.", 29 | "a photo of one {} in the scene.", 30 | "itap of {article} {}.", 31 | "itap of my {}.", # itap: I took a picture of 32 | "itap of the {}.", 33 | "a photo of {article} {}.", 34 | "a photo of my {}.", 35 | "a photo of the {}.", 36 | "a photo of one {}.", 37 | "a photo of many {}.", 38 | "a good photo of {article} {}.", 39 | "a good photo of the {}.", 40 | "a bad photo of {article} {}.", 41 | "a bad photo of the {}.", 42 | "a photo of a nice {}.", 43 | "a photo of the nice {}.", 44 | "a photo of a cool {}.", 45 | "a photo of the cool {}.", 46 | "a photo of a weird {}.", 47 | "a photo of the weird {}.", 48 | "a photo of a small {}.", 49 | "a photo of the small {}.", 50 | "a photo of a large {}.", 51 | "a photo of the large {}.", 52 | "a photo of a clean {}.", 53 | "a photo of the clean {}.", 54 | "a photo of a dirty {}.", 55 | "a photo of the dirty {}.", 56 | "a bright photo of {article} {}.", 57 | "a bright photo of the {}.", 58 | "a dark photo of {article} {}.", 59 | "a dark photo of the {}.", 60 | "a photo of a hard to see {}.", 61 | "a photo of the hard to see {}.", 62 | "a low resolution photo of {article} {}.", 63 | "a low resolution photo of the {}.", 64 | "a cropped photo of {article} {}.", 65 | "a cropped photo of the {}.", 66 | "a close-up photo of {article} {}.", 67 | "a close-up photo of the {}.", 68 | "a jpeg corrupted photo of {article} {}.", 69 | "a jpeg corrupted photo of the {}.", 70 | "a blurry photo of {article} {}.", 71 | "a blurry photo of the {}.", 72 | "a pixelated photo of {article} {}.", 73 | "a pixelated photo of the 
{}.", 74 | "a black and white photo of the {}.", 75 | "a black and white photo of {article} {}.", 76 | "a plastic {}.", 77 | "the plastic {}.", 78 | "a toy {}.", 79 | "the toy {}.", 80 | "a plushie {}.", 81 | "the plushie {}.", 82 | "a cartoon {}.", 83 | "the cartoon {}.", 84 | "an embroidered {}.", 85 | "the embroidered {}.", 86 | "a painting of the {}.", 87 | "a painting of a {}.", 88 | ] 89 | 90 | 91 | openimages_rare_unseen = ['Aerial photography', 92 | 'Aircraft engine', 93 | 'Ale', 94 | 'Aloe', 95 | 'Amphibian', 96 | 'Angling', 97 | 'Anole', 98 | 'Antique car', 99 | 'Arcade game', 100 | 'Arthropod', 101 | 'Assault rifle', 102 | 'Athletic shoe', 103 | 'Auto racing', 104 | 'Backlighting', 105 | 'Bagpipes', 106 | 'Ball game', 107 | 'Barbecue chicken', 108 | 'Barechested', 109 | 'Barquentine', 110 | 'Beef tenderloin', 111 | 'Billiard room', 112 | 'Billiards', 113 | 'Bird of prey', 114 | 'Black swan', 115 | 'Black-and-white', 116 | 'Blond', 117 | 'Boating', 118 | 'Bonbon', 119 | 'Bottled water', 120 | 'Bouldering', 121 | 'Bovine', 122 | 'Bratwurst', 123 | 'Breadboard', 124 | 'Briefs', 125 | 'Brisket', 126 | 'Brochette', 127 | 'Calabaza', 128 | 'Camera operator', 129 | 'Canola', 130 | 'Childbirth', 131 | 'Chordophone', 132 | 'Church bell', 133 | 'Classical sculpture', 134 | 'Close-up', 135 | 'Cobblestone', 136 | 'Coca-cola', 137 | 'Combat sport', 138 | 'Comics', 139 | 'Compact car', 140 | 'Computer speaker', 141 | 'Cookies and crackers', 142 | 'Coral reef fish', 143 | 'Corn on the cob', 144 | 'Cosmetics', 145 | 'Crocodilia', 146 | 'Digital camera', 147 | 'Dishware', 148 | 'Divemaster', 149 | 'Dobermann', 150 | 'Dog walking', 151 | 'Domestic rabbit', 152 | 'Domestic short-haired cat', 153 | 'Double-decker bus', 154 | 'Drums', 155 | 'Electric guitar', 156 | 'Electric piano', 157 | 'Electronic instrument', 158 | 'Equestrianism', 159 | 'Equitation', 160 | 'Erinaceidae', 161 | 'Extreme sport', 162 | 'Falafel', 163 | 'Figure skating', 164 | 'Filling station', 165 | 'Fire apparatus', 166 | 'Firearm', 167 | 'Flatbread', 168 | 'Floristry', 169 | 'Forklift truck', 170 | 'Freight transport', 171 | 'Fried food', 172 | 'Fried noodles', 173 | 'Frigate', 174 | 'Frozen yogurt', 175 | 'Frying', 176 | 'Full moon', 177 | 'Galleon', 178 | 'Glacial landform', 179 | 'Gliding', 180 | 'Go-kart', 181 | 'Goats', 182 | 'Grappling', 183 | 'Great white shark', 184 | 'Gumbo', 185 | 'Gun turret', 186 | 'Hair coloring', 187 | 'Halter', 188 | 'Headphones', 189 | 'Heavy cruiser', 190 | 'Herding', 191 | 'High-speed rail', 192 | 'Holding hands', 193 | 'Horse and buggy', 194 | 'Horse racing', 195 | 'Hound', 196 | 'Hunting knife', 197 | 'Hurdling', 198 | 'Inflatable', 199 | 'Jackfruit', 200 | 'Jeans', 201 | 'Jiaozi', 202 | 'Junk food', 203 | 'Khinkali', 204 | 'Kitesurfing', 205 | 'Lawn game', 206 | 'Leaf vegetable', 207 | 'Lechon', 208 | 'Lifebuoy', 209 | 'Locust', 210 | 'Lumpia', 211 | 'Luxury vehicle', 212 | 'Machine tool', 213 | 'Medical imaging', 214 | 'Melee weapon', 215 | 'Microcontroller', 216 | 'Middle ages', 217 | 'Military person', 218 | 'Military vehicle', 219 | 'Milky way', 220 | 'Miniature Poodle', 221 | 'Modern dance', 222 | 'Molluscs', 223 | 'Monoplane', 224 | 'Motorcycling', 225 | 'Musical theatre', 226 | 'Narcissus', 227 | 'Nest box', 228 | 'Newsagent\'s shop', 229 | 'Nile crocodile', 230 | 'Nordic skiing', 231 | 'Nuclear power plant', 232 | 'Orator', 233 | 'Outdoor shoe', 234 | 'Parachuting', 235 | 'Pasta salad', 236 | 'Peafowl', 237 | 'Pelmeni', 238 | 'Perching bird', 239 | 'Performance car', 240 | 'Personal 
water craft', 241 | 'Pit bull', 242 | 'Plant stem', 243 | 'Pork chop', 244 | 'Portrait photography', 245 | 'Primate', 246 | 'Procyonidae', 247 | 'Prosciutto', 248 | 'Public speaking', 249 | 'Racewalking', 250 | 'Ramen', 251 | 'Rear-view mirror', 252 | 'Residential area', 253 | 'Ribs', 254 | 'Rice ball', 255 | 'Road cycling', 256 | 'Roller skating', 257 | 'Roman temple', 258 | 'Rowing', 259 | 'Rural area', 260 | 'Sailboat racing', 261 | 'Scaled reptile', 262 | 'Scuba diving', 263 | 'Senior citizen', 264 | 'Shallot', 265 | 'Shinto shrine', 266 | 'Shooting range', 267 | 'Siberian husky', 268 | 'Sledding', 269 | 'Soba', 270 | 'Solar energy', 271 | 'Sport climbing', 272 | 'Sport utility vehicle', 273 | 'Steamed rice', 274 | 'Stemware', 275 | 'Sumo', 276 | 'Surfing Equipment', 277 | 'Team sport', 278 | 'Touring car', 279 | 'Toy block', 280 | 'Trampolining', 281 | 'Underwater diving', 282 | 'Vegetarian food', 283 | 'Wallaby', 284 | 'Water polo', 285 | 'Watercolor paint', 286 | 'Whiskers', 287 | 'Wind wave', 288 | 'Woodwind instrument', 289 | 'Yakitori', 290 | 'Zeppelin'] 291 | 292 | 293 | def build_openset_label_embedding(categories=None): 294 | if categories is None: 295 | categories = openimages_rare_unseen 296 | print("Creating pretrained CLIP model") 297 | model, _ = clip.load("ViT-B/16") 298 | templates = multiple_templates 299 | 300 | run_on_gpu = torch.cuda.is_available() 301 | 302 | with torch.no_grad(): 303 | openset_label_embedding = [] 304 | for category in categories: 305 | texts = [ 306 | template.format( 307 | processed_name(category, rm_dot=True), article=article(category) 308 | ) 309 | for template in templates 310 | ] 311 | texts = [ 312 | "This is " + text if text.startswith("a") or text.startswith("the") else text 313 | for text in texts 314 | ] 315 | texts = clip.tokenize(texts) # tokenize 316 | if run_on_gpu: 317 | texts = texts.cuda() 318 | model = model.cuda() 319 | text_embeddings = model.encode_text(texts) 320 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 321 | text_embedding = text_embeddings.mean(dim=0) 322 | text_embedding /= text_embedding.norm() 323 | openset_label_embedding.append(text_embedding) 324 | openset_label_embedding = torch.stack(openset_label_embedding, dim=1) 325 | if run_on_gpu: 326 | openset_label_embedding = openset_label_embedding.cuda() 327 | 328 | openset_label_embedding = openset_label_embedding.t() 329 | return openset_label_embedding, categories 330 | 331 | 332 | 333 | import json 334 | from tqdm import tqdm 335 | 336 | def build_openset_llm_label_embedding(llm_tag_des): 337 | print("Creating pretrained CLIP model") 338 | model, _ = clip.load("ViT-B/16") 339 | llm_tag_des = llm_tag_des 340 | categories = [] 341 | 342 | run_on_gpu = torch.cuda.is_available() 343 | 344 | with torch.no_grad(): 345 | openset_label_embedding = [] 346 | for item in tqdm(llm_tag_des): 347 | category = list(item.keys())[0] 348 | des = list(item.values())[0] 349 | 350 | categories.append(category) 351 | 352 | texts = clip.tokenize(des, truncate=True) # tokenize 353 | if run_on_gpu: 354 | texts = texts.cuda() 355 | model = model.cuda() 356 | text_embeddings = model.encode_text(texts) 357 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 358 | # text_embedding = text_embeddings.mean(dim=0) 359 | # text_embedding /= text_embedding.norm() 360 | # openset_label_embedding.append(text_embedding) 361 | openset_label_embedding.append(text_embeddings) 362 | # openset_label_embedding = torch.stack(openset_label_embedding, dim=1) 363 | 
openset_label_embedding = torch.cat(openset_label_embedding, dim=0) 364 | if run_on_gpu: 365 | openset_label_embedding = openset_label_embedding.cuda() 366 | 367 | # openset_label_embedding = openset_label_embedding.t() 368 | return openset_label_embedding, categories 369 | 370 | 371 | 372 | 373 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.4.12 2 | transformers>=4.25.1 3 | fairscale==0.4.4 4 | pycocoevalcap 5 | torch 6 | torchvision 7 | Pillow 8 | scipy 9 | clip @ git+https://github.com/openai/CLIP.git 10 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = ram 3 | version = 0.0.1 4 | description = Recognize Anything Plus Model, Recognize Anything Model and Tag2Text Model 5 | 6 | [options] 7 | packages = find: 8 | include_package_data = True 9 | 10 | [options.packages.find] 11 | exclude = 12 | datasets 13 | images 14 | outputs 15 | pretrained 16 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | setuptools.setup() 3 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr): 3 | """Decay the learning rate""" 4 | lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr 5 | for param_group in optimizer.param_groups: 6 | param_group['lr'] = lr 7 | 8 | def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr): 9 | """Warmup the learning rate""" 10 | lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step) 11 | for param_group in optimizer.param_groups: 12 | 13 | param_group['lr'] = lr 14 | 15 | def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate): 16 | """Decay the learning rate""" 17 | lr = max(min_lr, init_lr * (decay_rate**epoch)) 18 | for param_group in optimizer.param_groups: 19 | param_group['lr'] = lr 20 | 21 | import numpy as np 22 | import io 23 | import os 24 | import time 25 | from collections import defaultdict, deque 26 | import datetime 27 | 28 | import torch 29 | import torch.distributed as dist 30 | 31 | class SmoothedValue(object): 32 | """Track a series of values and provide access to smoothed values over a 33 | window or the global series average. 34 | """ 35 | 36 | def __init__(self, window_size=20, fmt=None): 37 | if fmt is None: 38 | fmt = "{median:.4f} ({global_avg:.4f})" 39 | self.deque = deque(maxlen=window_size) 40 | self.total = 0.0 41 | self.count = 0 42 | self.fmt = fmt 43 | 44 | def update(self, value, n=1): 45 | self.deque.append(value) 46 | self.count += n 47 | self.total += value * n 48 | 49 | def synchronize_between_processes(self): 50 | """ 51 | Warning: does not synchronize the deque! 
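        Only `count` and `total` (and therefore `global_avg`) are all-reduced across
        processes; `median`, `avg`, `max` and `value` stay local to this process's window.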
52 | """ 53 | if not is_dist_avail_and_initialized(): 54 | return 55 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 56 | dist.barrier() 57 | dist.all_reduce(t) 58 | t = t.tolist() 59 | self.count = int(t[0]) 60 | self.total = t[1] 61 | 62 | @property 63 | def median(self): 64 | d = torch.tensor(list(self.deque)) 65 | return d.median().item() 66 | 67 | @property 68 | def avg(self): 69 | d = torch.tensor(list(self.deque), dtype=torch.float32) 70 | return d.mean().item() 71 | 72 | @property 73 | def global_avg(self): 74 | return self.total / self.count 75 | 76 | @property 77 | def max(self): 78 | return max(self.deque) 79 | 80 | @property 81 | def value(self): 82 | return self.deque[-1] 83 | 84 | def __str__(self): 85 | return self.fmt.format( 86 | median=self.median, 87 | avg=self.avg, 88 | global_avg=self.global_avg, 89 | max=self.max, 90 | value=self.value) 91 | 92 | 93 | class MetricLogger(object): 94 | def __init__(self, delimiter="\t"): 95 | self.meters = defaultdict(SmoothedValue) 96 | self.delimiter = delimiter 97 | 98 | def update(self, **kwargs): 99 | for k, v in kwargs.items(): 100 | if isinstance(v, torch.Tensor): 101 | v = v.item() 102 | assert isinstance(v, (float, int)) 103 | self.meters[k].update(v) 104 | 105 | def __getattr__(self, attr): 106 | if attr in self.meters: 107 | return self.meters[attr] 108 | if attr in self.__dict__: 109 | return self.__dict__[attr] 110 | raise AttributeError("'{}' object has no attribute '{}'".format( 111 | type(self).__name__, attr)) 112 | 113 | def __str__(self): 114 | loss_str = [] 115 | for name, meter in self.meters.items(): 116 | loss_str.append( 117 | "{}: {}".format(name, str(meter)) 118 | ) 119 | return self.delimiter.join(loss_str) 120 | 121 | def global_avg(self): 122 | loss_str = [] 123 | for name, meter in self.meters.items(): 124 | loss_str.append( 125 | "{}: {:.4f}".format(name, meter.global_avg) 126 | ) 127 | return self.delimiter.join(loss_str) 128 | 129 | def synchronize_between_processes(self): 130 | for meter in self.meters.values(): 131 | meter.synchronize_between_processes() 132 | 133 | def add_meter(self, name, meter): 134 | self.meters[name] = meter 135 | 136 | def log_every(self, iterable, print_freq, header=None): 137 | i = 0 138 | if not header: 139 | header = '' 140 | start_time = time.time() 141 | end = time.time() 142 | iter_time = SmoothedValue(fmt='{avg:.4f}') 143 | data_time = SmoothedValue(fmt='{avg:.4f}') 144 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 145 | log_msg = [ 146 | header, 147 | '[{0' + space_fmt + '}/{1}]', 148 | 'eta: {eta}', 149 | '{meters}', 150 | 'time: {time}', 151 | 'data: {data}' 152 | ] 153 | if torch.cuda.is_available(): 154 | log_msg.append('max mem: {memory:.0f}') 155 | log_msg = self.delimiter.join(log_msg) 156 | MB = 1024.0 * 1024.0 157 | for obj in iterable: 158 | data_time.update(time.time() - end) 159 | yield obj 160 | iter_time.update(time.time() - end) 161 | if i % print_freq == 0 or i == len(iterable) - 1: 162 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 163 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 164 | if torch.cuda.is_available(): 165 | print(log_msg.format( 166 | i, len(iterable), eta=eta_string, 167 | meters=str(self), 168 | time=str(iter_time), data=str(data_time), 169 | memory=torch.cuda.max_memory_allocated() / MB)) 170 | else: 171 | print(log_msg.format( 172 | i, len(iterable), eta=eta_string, 173 | meters=str(self), 174 | time=str(iter_time), data=str(data_time))) 175 | i += 1 176 | end = 
time.time() 177 | total_time = time.time() - start_time 178 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 179 | print('{} Total time: {} ({:.4f} s / it)'.format( 180 | header, total_time_str, total_time / len(iterable))) 181 | 182 | 183 | class AttrDict(dict): 184 | def __init__(self, *args, **kwargs): 185 | super(AttrDict, self).__init__(*args, **kwargs) 186 | self.__dict__ = self 187 | 188 | 189 | def compute_acc(logits, label, reduction='mean'): 190 | ret = (torch.argmax(logits, dim=1) == label).float() 191 | if reduction == 'none': 192 | return ret.detach() 193 | elif reduction == 'mean': 194 | return ret.mean().item() 195 | 196 | def compute_n_params(model, return_str=True): 197 | tot = 0 198 | for p in model.parameters(): 199 | w = 1 200 | for x in p.shape: 201 | w *= x 202 | tot += w 203 | if return_str: 204 | if tot >= 1e6: 205 | return '{:.1f}M'.format(tot / 1e6) 206 | else: 207 | return '{:.1f}K'.format(tot / 1e3) 208 | else: 209 | return tot 210 | 211 | def setup_for_distributed(is_master): 212 | """ 213 | This function disables printing when not in master process 214 | """ 215 | import builtins as __builtin__ 216 | builtin_print = __builtin__.print 217 | 218 | def print(*args, **kwargs): 219 | force = kwargs.pop('force', False) 220 | if is_master or force: 221 | builtin_print(*args, **kwargs) 222 | 223 | __builtin__.print = print 224 | 225 | 226 | def is_dist_avail_and_initialized(): 227 | if not dist.is_available(): 228 | return False 229 | if not dist.is_initialized(): 230 | return False 231 | return True 232 | 233 | 234 | def get_world_size(): 235 | if not is_dist_avail_and_initialized(): 236 | return 1 237 | return dist.get_world_size() 238 | 239 | 240 | def get_rank(): 241 | if not is_dist_avail_and_initialized(): 242 | return 0 243 | return dist.get_rank() 244 | 245 | 246 | def is_main_process(): 247 | return get_rank() == 0 248 | 249 | 250 | def save_on_master(*args, **kwargs): 251 | if is_main_process(): 252 | torch.save(*args, **kwargs) 253 | 254 | 255 | def init_distributed_mode(args): 256 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 257 | args.rank = int(os.environ["RANK"]) 258 | args.world_size = int(os.environ['WORLD_SIZE']) 259 | args.gpu = int(os.environ['LOCAL_RANK']) 260 | elif 'SLURM_PROCID' in os.environ: 261 | args.rank = int(os.environ['SLURM_PROCID']) 262 | args.gpu = args.rank % torch.cuda.device_count() 263 | else: 264 | print('Not using distributed mode') 265 | args.distributed = False 266 | return 267 | 268 | args.distributed = True 269 | 270 | torch.cuda.set_device(args.gpu) 271 | args.dist_backend = 'nccl' 272 | print('| distributed init (rank {}, world {}): {}'.format( 273 | args.rank, args.world_size, args.dist_url), flush=True) 274 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 275 | world_size=args.world_size, rank=args.rank) 276 | torch.distributed.barrier() 277 | setup_for_distributed(args.rank == 0) 278 | 279 | --------------------------------------------------------------------------------