├── .gitignore ├── README.md ├── framework.png ├── input_text └── scene_graph_text │ ├── scene_graph_coco17 │ └── 000000000057.json │ ├── scene_graph_coco17_attr │ └── 000000000057.json │ └── scene_graph_coco17_caption │ └── 000000000057.json ├── main_aokvqa.py ├── main_okvqa.py ├── object_similarity ├── object_engineer_aokvqa.sh ├── object_engineer_okvqa.sh ├── object_engineering_aokvqa.py └── object_engineering_okvqa.py ├── preprocess ├── make_clip_features.py ├── make_line2sample.py ├── preprocess_aokvqa.sh ├── preprocess_okvqa.py └── reorganize_captions.py ├── run_aokvqa.sh ├── run_okvqa.sh └── utils_api.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | Code for [paper](https://arxiv.org/abs/2301.05226) *Visual Chain-of-Thought Prompting for Knowledge-based Visual Reasoning* 2 | ## Overall framework 3 | ![framework](framework.png) 4 | 5 | ## Preprocess datasets 6 | * COCO 2014 and 2017 datasets 7 | * Download the OK-VQA and A-OKVQA datasets, following the [PICa](https://github.com/microsoft/PICa) format 8 | * For A-OKVQA, run the preprocessing script `preprocess/preprocess_aokvqa.sh`. For OK-VQA, modify the script slightly to fit its format (A-OKVQA and OK-VQA have similar formats). 9 | * Build the training object-similarity files with the scripts under `object_similarity/` (`object_engineer_aokvqa.sh` for A-OKVQA and `object_engineer_okvqa.sh` for OK-VQA) 10 | ## Prepare scene graphs and captions 11 | * Before running experiments, VisualCoT also needs scene graphs and captions: three files for each input image (under `input_text/scene_graph_text/scene_graph_coco17`, `input_text/scene_graph_text/scene_graph_coco17_attr`, and `input_text/scene_graph_text/scene_graph_coco17_caption`). We provide an example for image No. 57 under each directory. Please follow the format of these examples and generate scene graphs for all other images. 12 | * If you do not want to run a scene graph model yourself, we provide the scene graphs and captions we generated (they require additional processing to match the format of the three examples above): 13 | * [Dense Captions](https://umass-my.sharepoint.com/:u:/g/personal/qinhongzhou_umass_edu/ETTHSIaFZt1AnxyZjGDfAhEBIxn1CKM8JIle6rRjFHlLaQ?e=05Bt7N). 14 | * [Attributes](https://umass-my.sharepoint.com/:u:/g/personal/qinhongzhou_umass_edu/EZh2wLg5CrNIku4nWew40QgB6hwbJiD6jBy5oAxXVYA0zQ?e=TISTYq) for COCO17 test dataset. 15 | * [Relations and Objects](https://umass-my.sharepoint.com/:u:/g/personal/qinhongzhou_umass_edu/ETf9rj1yrFFJmEkJFGvvCBEBQ1uDz3b6LTSafigANyZcBg?e=15rBIw) for COCO17 test dataset.
16 | * [Attributes](https://umass-my.sharepoint.com/:u:/g/personal/qinhongzhou_umass_edu/EYNsQp_JD5ZGqmImfz3AWTgBQ3NCPIP9GOGASzLpEXIATQ?e=TD2ToA) for COCO17 17 | * [Relations and Objects](https://umass-my.sharepoint.com/:u:/g/personal/qinhongzhou_umass_edu/ESSSDHfvXEFMrggmJeuExlEBgCmJar5Ibt17z8yY9ZFgdw?e=cfEisT) for COCO17 18 | ## Run experiments 19 | * `run_aokvqa.sh` for AOK-VQA 20 | * `run_okvqa.sh` for OK-VQA 21 | ## Main Results 22 | | Backbone | OK-VQA test (DA) | AOK-VQA val (DA) | AOK-VQA test (DA) | 23 | |-------------|------------------|------------------|-------------------| 24 | | OPT-66B | 44.6 | 46.4 | 46.0 | 25 | | Llama-2-70B | 54.9 | 50.5 | 54.4 | 26 | ## Cite 27 | arXiv version 28 | ``` 29 | @article{chen2023see, 30 | title={Visual Chain-of-Thought Prompting for Knowledge-based Visual Reasoning}, 31 | author={Chen, Zhenfang and Zhou, Qinhong and Shen, Yikang and Hong, Yining and Sun, Zhiqing and Gutfreund, Dan and Gan, Chuang}, 32 | journal={arXiv preprint arXiv:2301.05226}, 33 | year={2023} 34 | } 35 | ``` 36 | -------------------------------------------------------------------------------- /framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UMass-Embodied-AGI/VisualCoT/e8b1fbc595107dd294756bf859ef228dd82e4507/framework.png -------------------------------------------------------------------------------- /input_text/scene_graph_text/scene_graph_coco17/000000000057.json: -------------------------------------------------------------------------------- 1 | [[{"rect": [415.10400390625, 93.2335433959961, 484.9341125488281, 166.52694702148438], "class": "arm", "conf": 0.6582691669464111}, {"rect": [422.6954650878906, 123.94754791259766, 498.64739990234375, 171.53367614746094], "class": "arm", "conf": 0.3263871371746063}, {"rect": [390.8945007324219, 103.7408218383789, 498.7996520996094, 184.929931640625], "class": "arm", "conf": 0.08014032989740372}, {"rect": [429.9819030761719, 126.46842956542969, 481.1421203613281, 161.34910583496094], "class": "arm", "conf": 0.07948336005210876}, {"rect": [406.1563415527344, 95.29844665527344, 454.5167541503906, 171.31448364257812], "class": "arm", "conf": 0.0779888927936554}, {"rect": [382.6018981933594, 60.616153717041016, 408.8274841308594, 97.87584686279297], "class": "face", "conf": 0.5411292314529419}, {"rect": [377.2185974121094, 36.973026275634766, 422.32861328125, 77.21955871582031], "class": "hair", "conf": 0.900209367275238}, {"rect": [380.9769287109375, 51.788658142089844, 414.83685302734375, 75.4355239868164], "class": "hair", "conf": 0.07423733174800873}, {"rect": [478.95684814453125, 162.27984619140625, 507.67852783203125, 186.30447387695312], "class": "hand", "conf": 0.5605973601341248}, {"rect": [489.6508483886719, 164.33233642578125, 514.0590209960938, 186.6338348388672], "class": "hand", "conf": 0.3430511951446533}, {"rect": [480.520263671875, 171.02577209472656, 517.7135620117188, 188.58230590820312], "class": "handle", "conf": 0.09117522090673447}, {"rect": [380.29132080078125, 39.24116134643555, 418.8224182128906, 97.37042999267578], "class": "head", "conf": 0.48689112067222595}, {"rect": [381.456787109375, 210.2130889892578, 431.331298828125, 393.79925537109375], "class": "leg", "conf": 0.6896959543228149}, {"rect": [424.5220947265625, 206.35604858398438, 477.9910583496094, 319.5708923339844], "class": "leg", "conf": 0.6285281777381897}, {"rect": [390.049072265625, 235.66412353515625, 425.26397705078125, 362.8528137207031], "class": "leg", 
"conf": 0.21010053157806396}, {"rect": [378.7637023925781, 202.85675048828125, 477.3090515136719, 388.9565734863281], "class": "leg", "conf": 0.18848560750484467}, {"rect": [429.72235107421875, 239.2583770751953, 480.3968200683594, 302.7010803222656], "class": "leg", "conf": 0.09525711834430695}, {"rect": [362.8437805175781, 37.38288116455078, 515.8743286132812, 390.00128173828125], "class": "man", "conf": 0.8562116026878357}, {"rect": [351.92327880859375, 29.784034729003906, 490.094970703125, 213.55828857421875], "class": "man", "conf": 0.23160699009895325}, {"rect": [362.80120849609375, 32.96712875366211, 448.4388122558594, 162.9371795654297], "class": "man", "conf": 0.1079145073890686}, {"rect": [358.2061462402344, 33.88023376464844, 495.73516845703125, 370.9494934082031], "class": "person", "conf": 0.06393539905548096}, {"rect": [338.5758056640625, 17.330217361450195, 531.437744140625, 412.83319091796875], "class": "player", "conf": 0.24352310597896576}, {"rect": [481.2076110839844, 159.32720947265625, 574.5369873046875, 213.64181518554688], "class": "racket", "conf": 0.9561034440994263}, {"rect": [366.45166015625, 84.2710189819336, 447.8257751464844, 206.29042053222656], "class": "shirt", "conf": 0.9746922850608826}, {"rect": [400.484375, 363.1129455566406, 437.3819580078125, 404.67803955078125], "class": "shoe", "conf": 0.8509595990180969}, {"rect": [419.76239013671875, 293.77239990234375, 462.8852844238281, 345.2738037109375], "class": "shoe", "conf": 0.791804313659668}, {"rect": [400.3055114746094, 291.7102966308594, 458.1903381347656, 400.8177795410156], "class": "shoe", "conf": 0.451229065656662}, {"rect": [389.53765869140625, 289.766357421875, 473.3502197265625, 361.330078125], "class": "shoe", "conf": 0.13684089481830597}, {"rect": [392.7488708496094, 346.1171875, 447.36798095703125, 411.2497253417969], "class": "shoe", "conf": 0.0561157688498497}, {"rect": [370.8249206542969, 194.576416015625, 474.23089599609375, 270.8345642089844], "class": "short", "conf": 0.9614090919494629}, {"rect": [356.0423278808594, 176.1267852783203, 498.0373229980469, 296.9983215332031], "class": "short", "conf": 0.08240000903606415}, {"rect": [419.591796875, 296.43988037109375, 462.65545654296875, 346.5200500488281], "class": "sneaker", "conf": 0.18005944788455963}, {"rect": [399.8039245605469, 361.9225769042969, 437.9319763183594, 404.35595703125], "class": "sneaker", "conf": 0.09698790311813354}, {"rect": [400.3652648925781, 293.8009948730469, 459.1552429199219, 402.6984558105469], "class": "sneaker", "conf": 0.09357697516679764}, {"rect": [427.9765319824219, 288.8907775878906, 453.26446533203125, 312.6236572265625], "class": "sock", "conf": 0.743930995464325}, {"rect": [409.9351501464844, 350.21051025390625, 428.9573974609375, 372.9019470214844], "class": "sock", "conf": 0.6154168248176575}, {"rect": [437.1496276855469, 277.0347900390625, 463.5595397949219, 304.26019287109375], "class": "sock", "conf": 0.27959662675857544}, {"rect": [394.2384338378906, 273.92572021484375, 424.9969177246094, 367.5042724609375], "class": "sock", "conf": 0.15972097218036652}, {"rect": [420.8367919921875, 291.3815612792969, 444.9108581542969, 316.9632873535156], "class": "sock", "conf": 0.07458183169364929}, {"rect": [402.19635009765625, 343.0699157714844, 429.946533203125, 376.3409423828125], "class": "sock", "conf": 0.06334833055734634}], [{"subj_id": 17, "obj_id": 29, "class": "wearing", "conf": 0.20901232957839966, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [422.5279083251953, 
232.7054901123047]}, {"subj_id": 17, "obj_id": 23, "class": "wearing", "conf": 0.1914515346288681, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [407.1387176513672, 145.28071975708008]}, {"subj_id": 17, "obj_id": 6, "class": "has", "conf": 0.1440306454896927, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [399.7736053466797, 57.09629249572754]}, {"subj_id": 17, "obj_id": 22, "class": "holding", "conf": 0.13203947246074677, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [527.8722991943359, 186.48451232910156]}, {"subj_id": 23, "obj_id": 17, "class": "on", "conf": 0.1140482947230339, "subj_center": [407.1387176513672, 145.28071975708008], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 17, "obj_id": 24, "class": "wearing", "conf": 0.1134713664650917, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [418.93316650390625, 383.89549255371094]}, {"subj_id": 22, "obj_id": 8, "class": "in", "conf": 0.11252132058143616, "subj_center": [527.8722991943359, 186.48451232910156], "obj_center": [493.31768798828125, 174.2921600341797]}, {"subj_id": 17, "obj_id": 25, "class": "wearing", "conf": 0.11077643185853958, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [441.32383728027344, 319.5231018066406]}, {"subj_id": 29, "obj_id": 17, "class": "on", "conf": 0.10555530339479446, "subj_center": [422.5279083251953, 232.7054901123047], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 17, "obj_id": 0, "class": "has", "conf": 0.09894213825464249, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [450.01905822753906, 129.88024520874023]}, {"subj_id": 17, "obj_id": 12, "class": "has", "conf": 0.08594156056642532, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [406.39404296875, 302.0061721801758]}, {"subj_id": 17, "obj_id": 34, "class": "wearing", "conf": 0.08456337451934814, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [440.62049865722656, 300.75721740722656]}, {"subj_id": 17, "obj_id": 5, "class": "has", "conf": 0.08441606163978577, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [395.7146911621094, 79.24600028991699]}, {"subj_id": 22, "obj_id": 9, "class": "in", "conf": 0.07520977407693863, "subj_center": [527.8722991943359, 186.48451232910156], "obj_center": [501.8549346923828, 175.48308563232422]}, {"subj_id": 17, "obj_id": 11, "class": "has", "conf": 0.07134296745061874, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [399.55686950683594, 68.30579566955566]}, {"subj_id": 25, "obj_id": 17, "class": "on", "conf": 0.06812875717878342, "subj_center": [441.32383728027344, 319.5231018066406], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 17, "obj_id": 13, "class": "has", "conf": 0.06773655116558075, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [451.25657653808594, 262.9634704589844]}, {"subj_id": 6, "obj_id": 17, "class": "on", "conf": 0.06453395634889603, "subj_center": [399.7736053466797, 57.09629249572754], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 24, "obj_id": 17, "class": "on", "conf": 0.06328039616346359, "subj_center": [418.93316650390625, 383.89549255371094], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 17, "obj_id": 8, "class": "has", "conf": 0.05845368653535843, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [493.31768798828125, 
174.2921600341797]}, {"subj_id": 21, "obj_id": 29, "class": "wearing", "conf": 0.057542458176612854, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [422.5279083251953, 232.7054901123047]}, {"subj_id": 17, "obj_id": 26, "class": "wearing", "conf": 0.05445714294910431, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [429.2479248046875, 346.2640380859375]}, {"subj_id": 21, "obj_id": 23, "class": "wearing", "conf": 0.051322754472494125, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [407.1387176513672, 145.28071975708008]}, {"subj_id": 17, "obj_id": 1, "class": "has", "conf": 0.04824553430080414, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [460.6714324951172, 147.7406120300293]}, {"subj_id": 12, "obj_id": 17, "class": "of", "conf": 0.04671589657664299, "subj_center": [406.39404296875, 302.0061721801758], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 17, "obj_id": 35, "class": "wearing", "conf": 0.04174565523862839, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [419.44627380371094, 361.5562286376953]}, {"subj_id": 21, "obj_id": 6, "class": "has", "conf": 0.04004696011543274, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [399.7736053466797, 57.09629249572754]}, {"subj_id": 21, "obj_id": 22, "class": "holding", "conf": 0.03731017932295799, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [527.8722991943359, 186.48451232910156]}, {"subj_id": 11, "obj_id": 17, "class": "of", "conf": 0.034122686833143234, "subj_center": [399.55686950683594, 68.30579566955566], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 17, "obj_id": 36, "class": "wearing", "conf": 0.03160851076245308, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [450.3545837402344, 290.6474914550781]}, {"subj_id": 5, "obj_id": 17, "class": "of", "conf": 0.0309883002191782, "subj_center": [395.7146911621094, 79.24600028991699], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 17, "obj_id": 14, "class": "has", "conf": 0.03066200762987137, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [407.6565246582031, 299.2584686279297]}, {"subj_id": 23, "obj_id": 21, "class": "on", "conf": 0.029222799465060234, "subj_center": [407.1387176513672, 145.28071975708008], "obj_center": [435.00677490234375, 215.08170413970947]}, {"subj_id": 21, "obj_id": 0, "class": "has", "conf": 0.02841968461871147, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [450.01905822753906, 129.88024520874023]}, {"subj_id": 21, "obj_id": 12, "class": "has", "conf": 0.028232702985405922, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [406.39404296875, 302.0061721801758]}, {"subj_id": 17, "obj_id": 31, "class": "wearing", "conf": 0.02761729434132576, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [441.1236267089844, 321.47996520996094]}, {"subj_id": 29, "obj_id": 21, "class": "on", "conf": 0.027539042755961418, "subj_center": [422.5279083251953, 232.7054901123047], "obj_center": [435.00677490234375, 215.08170413970947]}, {"subj_id": 21, "obj_id": 24, "class": "wearing", "conf": 0.026661576703190804, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [418.93316650390625, 383.89549255371094]}, {"subj_id": 0, "obj_id": 17, "class": "of", "conf": 0.026350095868110657, "subj_center": [450.01905822753906, 129.88024520874023], 
"obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 13, "obj_id": 17, "class": "of", "conf": 0.025596775114536285, "subj_center": [451.25657653808594, 262.9634704589844], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 21, "obj_id": 5, "class": "has", "conf": 0.025592491030693054, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [395.7146911621094, 79.24600028991699]}, {"subj_id": 26, "obj_id": 17, "class": "on", "conf": 0.02557535097002983, "subj_center": [429.2479248046875, 346.2640380859375], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 21, "obj_id": 25, "class": "wearing", "conf": 0.025043601170182228, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [441.32383728027344, 319.5231018066406]}, {"subj_id": 17, "obj_id": 15, "class": "has", "conf": 0.024812625721096992, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [428.036376953125, 295.9066619873047]}, {"subj_id": 17, "obj_id": 37, "class": "wearing", "conf": 0.023029854521155357, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [409.61767578125, 320.7149963378906]}, {"subj_id": 21, "obj_id": 34, "class": "wearing", "conf": 0.022914117202162743, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [440.62049865722656, 300.75721740722656]}, {"subj_id": 17, "obj_id": 9, "class": "has", "conf": 0.02169770561158657, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [501.8549346923828, 175.48308563232422]}, {"subj_id": 21, "obj_id": 11, "class": "has", "conf": 0.02036985568702221, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [399.55686950683594, 68.30579566955566]}, {"subj_id": 17, "obj_id": 27, "class": "wearing", "conf": 0.020031966269016266, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [431.4439392089844, 325.5482177734375]}, {"subj_id": 21, "obj_id": 13, "class": "has", "conf": 0.01996336504817009, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [451.25657653808594, 262.9634704589844]}, {"subj_id": 6, "obj_id": 11, "class": "on", "conf": 0.018970759585499763, "subj_center": [399.7736053466797, 57.09629249572754], "obj_center": [399.55686950683594, 68.30579566955566]}, {"subj_id": 24, "obj_id": 21, "class": "on", "conf": 0.018663272261619568, "subj_center": [418.93316650390625, 383.89549255371094], "obj_center": [435.00677490234375, 215.08170413970947]}, {"subj_id": 31, "obj_id": 17, "class": "on", "conf": 0.018297147005796432, "subj_center": [441.1236267089844, 321.47996520996094], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 6, "obj_id": 21, "class": "on", "conf": 0.01806744560599327, "subj_center": [399.7736053466797, 57.09629249572754], "obj_center": [435.00677490234375, 215.08170413970947]}, {"subj_id": 18, "obj_id": 29, "class": "wearing", "conf": 0.018016386777162552, "subj_center": [421.0091247558594, 121.67116165161133], "obj_center": [422.5279083251953, 232.7054901123047]}, {"subj_id": 25, "obj_id": 21, "class": "on", "conf": 0.017825797200202942, "subj_center": [441.32383728027344, 319.5231018066406], "obj_center": [435.00677490234375, 215.08170413970947]}, {"subj_id": 17, "obj_id": 30, "class": "wearing", "conf": 0.017482098191976547, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [427.0398254394531, 236.56255340576172]}, {"subj_id": 20, "obj_id": 29, "class": "wearing", "conf": 0.016256628558039665, "subj_center": 
[426.9706573486328, 202.41486358642578], "obj_center": [422.5279083251953, 232.7054901123047]}, {"subj_id": 21, "obj_id": 8, "class": "has", "conf": 0.0158917848020792, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [493.31768798828125, 174.2921600341797]}, {"subj_id": 20, "obj_id": 23, "class": "wearing", "conf": 0.015681341290473938, "subj_center": [426.9706573486328, 202.41486358642578], "obj_center": [407.1387176513672, 145.28071975708008]}, {"subj_id": 18, "obj_id": 23, "class": "wearing", "conf": 0.01538869272917509, "subj_center": [421.0091247558594, 121.67116165161133], "obj_center": [407.1387176513672, 145.28071975708008]}, {"subj_id": 1, "obj_id": 17, "class": "of", "conf": 0.014746330678462982, "subj_center": [460.6714324951172, 147.7406120300293], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 12, "obj_id": 21, "class": "of", "conf": 0.014702294021844864, "subj_center": [406.39404296875, 302.0061721801758], "obj_center": [435.00677490234375, 215.08170413970947]}, {"subj_id": 17, "obj_id": 32, "class": "wearing", "conf": 0.01445777527987957, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [418.8679504394531, 383.13926696777344]}, {"subj_id": 34, "obj_id": 17, "class": "on", "conf": 0.014147467911243439, "subj_center": [440.62049865722656, 300.75721740722656], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 21, "obj_id": 1, "class": "has", "conf": 0.013778970576822758, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [460.6714324951172, 147.7406120300293]}, {"subj_id": 17, "obj_id": 33, "class": "wearing", "conf": 0.012615961022675037, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [429.76025390625, 348.2497253417969]}, {"subj_id": 21, "obj_id": 26, "class": "wearing", "conf": 0.012534061446785927, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [429.2479248046875, 346.2640380859375]}, {"subj_id": 8, "obj_id": 17, "class": "of", "conf": 0.012496856041252613, "subj_center": [493.31768798828125, 174.2921600341797], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 17, "obj_id": 16, "class": "has", "conf": 0.012466047890484333, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [455.05958557128906, 270.97972869873047]}, {"subj_id": 17, "obj_id": 18, "class": "wearing", "conf": 0.01217191107571125, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [421.0091247558594, 121.67116165161133]}, {"subj_id": 14, "obj_id": 17, "class": "of", "conf": 0.011931163258850574, "subj_center": [407.6565246582031, 299.2584686279297], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 27, "obj_id": 17, "class": "on", "conf": 0.011686550453305244, "subj_center": [431.4439392089844, 325.5482177734375], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 17, "obj_id": 2, "class": "has", "conf": 0.011522197164595127, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [444.8470764160156, 144.33537673950195]}, {"subj_id": 21, "obj_id": 35, "class": "wearing", "conf": 0.011430228129029274, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [419.44627380371094, 361.5562286376953]}, {"subj_id": 22, "obj_id": 17, "class": "on", "conf": 0.011316443793475628, "subj_center": [527.8722991943359, 186.48451232910156], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 17, "obj_id": 7, "class": "has", "conf": 
0.0112907774746418, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [397.9068908691406, 63.612091064453125]}, {"subj_id": 17, "obj_id": 3, "class": "has", "conf": 0.011079147458076477, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [455.56201171875, 143.9087677001953]}, {"subj_id": 20, "obj_id": 6, "class": "has", "conf": 0.010619232431054115, "subj_center": [426.9706573486328, 202.41486358642578], "obj_center": [399.7736053466797, 57.09629249572754]}, {"subj_id": 17, "obj_id": 4, "class": "has", "conf": 0.010512005537748337, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [430.3365478515625, 133.30646514892578]}, {"subj_id": 5, "obj_id": 21, "class": "of", "conf": 0.010504872538149357, "subj_center": [395.7146911621094, 79.24600028991699], "obj_center": [435.00677490234375, 215.08170413970947]}, {"subj_id": 15, "obj_id": 17, "class": "of", "conf": 0.010460702702403069, "subj_center": [428.036376953125, 295.9066619873047], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 20, "obj_id": 22, "class": "holding", "conf": 0.010447757318615913, "subj_center": [426.9706573486328, 202.41486358642578], "obj_center": [527.8722991943359, 186.48451232910156]}, {"subj_id": 35, "obj_id": 17, "class": "on", "conf": 0.010147219523787498, "subj_center": [419.44627380371094, 361.5562286376953], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 21, "obj_id": 14, "class": "has", "conf": 0.00966279860585928, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [407.6565246582031, 299.2584686279297]}, {"subj_id": 11, "obj_id": 21, "class": "of", "conf": 0.009651199914515018, "subj_center": [399.55686950683594, 68.30579566955566], "obj_center": [435.00677490234375, 215.08170413970947]}, {"subj_id": 6, "obj_id": 5, "class": "on", "conf": 0.008495799265801907, "subj_center": [399.7736053466797, 57.09629249572754], "obj_center": [395.7146911621094, 79.24600028991699]}, {"subj_id": 20, "obj_id": 24, "class": "wearing", "conf": 0.00840070378035307, "subj_center": [426.9706573486328, 202.41486358642578], "obj_center": [418.93316650390625, 383.89549255371094]}, {"subj_id": 32, "obj_id": 17, "class": "on", "conf": 0.008362277410924435, "subj_center": [418.8679504394531, 383.13926696777344], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 0, "obj_id": 21, "class": "of", "conf": 0.008350124582648277, "subj_center": [450.01905822753906, 129.88024520874023], "obj_center": [435.00677490234375, 215.08170413970947]}, {"subj_id": 37, "obj_id": 17, "class": "on", "conf": 0.008343961089849472, "subj_center": [409.61767578125, 320.7149963378906], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 21, "obj_id": 36, "class": "wearing", "conf": 0.008253660053014755, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [450.3545837402344, 290.6474914550781]}, {"subj_id": 21, "obj_id": 15, "class": "has", "conf": 0.008176947943866253, "subj_center": [435.00677490234375, 215.08170413970947], "obj_center": [428.036376953125, 295.9066619873047]}, {"subj_id": 20, "obj_id": 25, "class": "wearing", "conf": 0.008143127895891666, "subj_center": [426.9706573486328, 202.41486358642578], "obj_center": [441.32383728027344, 319.5231018066406]}, {"subj_id": 17, "obj_id": 28, "class": "wearing", "conf": 0.00801161490380764, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [420.0584259033203, 378.68345642089844]}, {"subj_id": 17, "obj_id": 39, "class": 
"wearing", "conf": 0.00791203323751688, "subj_center": [439.3590545654297, 213.69208145141602], "obj_center": [416.0714416503906, 359.70542907714844]}, {"subj_id": 23, "obj_id": 20, "class": "on", "conf": 0.00781384389847517, "subj_center": [407.1387176513672, 145.28071975708008], "obj_center": [426.9706573486328, 202.41486358642578]}, {"subj_id": 30, "obj_id": 17, "class": "on", "conf": 0.007469784002751112, "subj_center": [427.0398254394531, 236.56255340576172], "obj_center": [439.3590545654297, 213.69208145141602]}, {"subj_id": 29, "obj_id": 20, "class": "on", "conf": 0.007469267584383488, "subj_center": [422.5279083251953, 232.7054901123047], "obj_center": [426.9706573486328, 202.41486358642578]}, {"subj_id": 26, "obj_id": 21, "class": "on", "conf": 0.007462529465556145, "subj_center": [429.2479248046875, 346.2640380859375], "obj_center": [435.00677490234375, 215.08170413970947]}]] -------------------------------------------------------------------------------- /input_text/scene_graph_text/scene_graph_coco17_attr/000000000057.json: -------------------------------------------------------------------------------- 1 | [[{"rect": [375.7676086425781, 198.69309997558594, 477.95904541015625, 275.0142517089844], "class": "short", "conf": 0.9386906623840332, "attr": ["checkered", "blue", "plaid"], "attr_conf": [0.12672920525074005, 0.1929844468832016, 0.5183870196342468]}, {"rect": [369.84185791015625, 79.27885437011719, 455.9405517578125, 220.9800567626953], "class": "shirt", "conf": 0.8784235715866089, "attr": ["short sleeved", "orange", "red"], "attr_conf": [0.05093493312597275, 0.29195699095726013, 0.5454922318458557]}, {"rect": [56.157039642333984, 223.796875, 422.24627685546875, 401.2650146484375], "class": "shadow", "conf": 0.8617755770683289, "attr": ["white"], "attr_conf": [0.6941165924072266]}, {"rect": [369.4289855957031, 30.715730667114258, 519.0794067382812, 408.05767822265625], "class": "man", "conf": 0.8587678074836731, "attr": ["standing", "playing tennis", "playing"], "attr_conf": [0.06929700821638107, 0.11102530360221863, 0.22551782429218292]}, {"rect": [490.3601379394531, 161.1704559326172, 577.4737548828125, 213.88795471191406], "class": "racket", "conf": 0.8229435086250305, "attr": ["tennis", "white", "black", "green"], "attr_conf": [0.0901181548833847, 0.1125616729259491, 0.1457483023405075, 0.5095292329788208]}, {"rect": [135.04090881347656, 188.26712036132812, 154.80393981933594, 206.55105590820312], "class": "ball", "conf": 0.7927573323249817, "attr": ["flying", "green", "yellow"], "attr_conf": [0.0934738740324974, 0.1355827897787094, 0.5002978444099426]}, {"rect": [380.695068359375, 38.574012756347656, 425.5843505859375, 78.5858383178711], "class": "hair", "conf": 0.7818819880485535, "attr": ["short", "curly", "dark", "black"], "attr_conf": [0.05739080160856247, 0.07352045923471451, 0.1835852563381195, 0.5327075123786926]}, {"rect": [498.53619384765625, 166.53948974609375, 517.3152465820312, 188.61746215820312], "class": "hand", "conf": 0.7553640604019165, "attr": ["white"], "attr_conf": [0.3187938630580902]}, {"rect": [484.0550537109375, 164.32546997070312, 506.7552795410156, 184.32862854003906], "class": "hand", "conf": 0.7433161735534668, "attr": ["white"], "attr_conf": [0.22359587252140045]}, {"rect": [4.80386209487915, 221.55044555664062, 403.4365539550781, 426.288330078125], "class": "line", "conf": 0.7036184668540955, "attr": ["white"], "attr_conf": [0.8511332869529724]}, {"rect": [421.8260498046875, 296.9776916503906, 466.86767578125, 348.4229431152344], 
"class": "shoe", "conf": 0.6774736642837524, "attr": ["tennis", "white"], "attr_conf": [0.06794583797454834, 0.680751383304596]}, {"rect": [401.41473388671875, 370.21429443359375, 440.33245849609375, 408.3084411621094], "class": "shoe", "conf": 0.6501246690750122, "attr": ["white"], "attr_conf": [0.7100988030433655]}, {"rect": [380.360107421875, 40.33637619018555, 423.9418640136719, 101.17613220214844], "class": "head", "conf": 0.6428474187850952, "attr": ["short", "curly", "dark", "black"], "attr_conf": [0.07091108709573746, 0.0765766054391861, 0.1787957102060318, 0.46587973833084106]}, {"rect": [400.6651611328125, 5.3185014724731445, 633.91015625, 101.23121643066406], "class": "shadow", "conf": 0.6389392018318176, "attr": ["cast", "black", "brown", "dark"], "attr_conf": [0.08082034438848495, 0.15436798334121704, 0.15467478334903717, 0.3400002121925354]}, {"rect": [409.9427185058594, 350.501953125, 431.1708068847656, 374.5637512207031], "class": "sock", "conf": 0.5981297492980957, "attr": ["white"], "attr_conf": [0.9655088186264038]}, {"rect": [430.75341796875, 276.2204284667969, 469.3234558105469, 315.6880187988281], "class": "sock", "conf": 0.5629867911338806, "attr": ["white"], "attr_conf": [0.8572573065757751]}, {"rect": [0.0, 19.88435935974121, 639.2881469726562, 422.4766540527344], "class": "court", "conf": 0.4811999797821045, "attr": ["dirt", "red", "brown", "orange", "clay"], "attr_conf": [0.05257749557495117, 0.12571871280670166, 0.14589205384254456, 0.23927059769630432, 0.2732866108417511]}, {"rect": [456.0313720703125, 267.8884582519531, 483.96728515625, 291.3211975097656], "class": "knee", "conf": 0.44874992966651917, "attr": ["bent"], "attr_conf": [0.9092406630516052]}, {"rect": [5.822593688964844, 374.1316223144531, 321.70556640625, 426.288330078125], "class": "line", "conf": 0.4046552777290344, "attr": ["white"], "attr_conf": [0.8952876329421997]}, {"rect": [378.93402099609375, 212.7904052734375, 436.8543701171875, 380.7745666503906], "class": "leg", "conf": 0.3669198155403137, "attr": ["bent", "crossed", "bare"], "attr_conf": [0.06543419510126114, 0.07148923724889755, 0.07176770269870758]}, {"rect": [472.05950927734375, 152.61900329589844, 489.6600646972656, 171.71510314941406], "class": "watch", "conf": 0.3144718408584595, "attr": ["black"], "attr_conf": [0.88736891746521]}, {"rect": [0.0, 14.395047187805176, 639.2881469726562, 424.1204528808594], "class": "ground", "conf": 0.311477392911911, "attr": ["dirt", "red", "brown", "orange", "clay"], "attr_conf": [0.05359957367181778, 0.1195363849401474, 0.149090975522995, 0.24172335863113403, 0.27286815643310547]}, {"rect": [396.7987976074219, 300.0684509277344, 469.4784240722656, 403.863525390625], "class": "shoe", "conf": 0.3050905168056488, "attr": ["tennis", "white"], "attr_conf": [0.0603000707924366, 0.7632794976234436]}, {"rect": [387.6640930175781, 87.32012939453125, 399.74468994140625, 95.76163482666016], "class": "mouth", "conf": 0.30070754885673523, "attr": ["smiling", "closed", "open"], "attr_conf": [0.06646332144737244, 0.12914811074733734, 0.6207467913627625]}, {"rect": [472.3974304199219, 151.9155731201172, 491.1281433105469, 171.11224365234375], "class": "wrist", "conf": 0.2999599277973175, "attr": ["black"], "attr_conf": [0.8781831860542297]}, {"rect": [409.7373046875, 100.91047668457031, 441.9548034667969, 134.66603088378906], "class": "writing", "conf": 0.29682910442352295, "attr": ["white"], "attr_conf": [0.9781084060668945]}, {"rect": [135.15330505371094, 187.64132690429688, 155.61489868164062, 
206.47390747070312], "class": "tennis ball", "conf": 0.22090338170528412, "attr": ["flying", "green", "yellow"], "attr_conf": [0.09183763712644577, 0.13557341694831848, 0.49887967109680176]}, {"rect": [484.87200927734375, 161.73353576660156, 575.6378173828125, 213.27391052246094], "class": "tennis racket", "conf": 0.21902309358119965, "attr": ["tennis", "white", "black", "green"], "attr_conf": [0.1015172079205513, 0.11729150265455246, 0.16931116580963135, 0.4515334963798523]}, {"rect": [382.9950866699219, 58.9854850769043, 414.7087097167969, 100.23716735839844], "class": "face", "conf": 0.20735099911689758, "attr": ["smiling"], "attr_conf": [0.3399803340435028]}, {"rect": [408.52679443359375, 91.9161605834961, 500.1498718261719, 179.14306640625], "class": "arm", "conf": 0.20213431119918823, "attr": ["raised", "outstretched", "extended"], "attr_conf": [0.12956707179546356, 0.13083970546722412, 0.291843056678772]}]] -------------------------------------------------------------------------------- /input_text/scene_graph_text/scene_graph_coco17_caption/000000000057.json: -------------------------------------------------------------------------------- 1 | {"[408.52679443359375, 91.9161605834961, 500.1498718261719, 179.14306640625]": "Tennis ball is hitting in the air with a racquet in his right arm.", "[135.04090881347656, 188.26712036132812, 154.80393981933594, 206.55105590820312]": "An image with an orange background of a ball.", "[0.0, 19.88435935974121, 639.2881469726562, 422.4766540527344]": "Tennis playing on the court with a tennis racket.", "[382.9950866699219, 58.9854850769043, 414.7087097167969, 100.23716735839844]": "Dude with a serious expression on the face.", "[0.0, 14.395047187805176, 639.2881469726562, 424.1204528808594]": "Playing tennis on the tennis clay tennis field, with a tennis ball in the foreground.", "[380.695068359375, 38.574012756347656, 425.5843505859375, 78.5858383178711]": "Player with black hair in a red jacket.", "[484.0550537109375, 164.32546997070312, 506.7552795410156, 184.32862854003906]": "Someone holding the hand and the other hand.", "[498.53619384765625, 166.53948974609375, 517.3152465820312, 188.61746215820312]": "Someone holding their hand out to touch something.", "[380.360107421875, 40.33637619018555, 423.9418640136719, 101.17613220214844]": "Tennis baller with red uniform and black headband.", "[456.0313720703125, 267.8884582519531, 483.96728515625, 291.3211975097656]": "Animal with collar on a brown and orange ground.", "[378.93402099609375, 212.7904052734375, 436.8543701171875, 380.7745666503906]": "Tennis players' legs on court with a tennis racket.", "[4.80386209487915, 221.55044555664062, 403.4365539550781, 426.288330078125]": "Tennis ball coming toward the tennis ball that is coming towards him.", "[5.822593688964844, 374.1316223144531, 321.70556640625, 426.288330078125]": "Someone standing at the net with their shadow in a court.", "[369.4289855957031, 30.715730667114258, 519.0794067382812, 408.05767822265625]": "Tennis player in action with a racke in his hands and a tennis ball.", "[387.6640930175781, 87.32012939453125, 399.74468994140625, 95.76163482666016]": "A man with a beard and a red tie.", "[490.3601379394531, 161.1704559326172, 577.4737548828125, 213.88795471191406]": "Someone is swinging the rackets to hit it.", "[56.157039642333984, 223.796875, 422.24627685546875, 401.2650146484375]": "Tennis ball coming to the tennis ball on the tennis court.", "[400.6651611328125, 5.3185014724731445, 633.91015625, 101.23121643066406]": "Playing 
tennis in the sun on the court with a tennis racket.", "[369.84185791015625, 79.27885437011719, 455.9405517578125, 220.9800567626953]": "Tennis player in red t shirt swinging at ball.", "[396.7987976074219, 300.0684509277344, 469.4784240722656, 403.863525390625]": "Athletic shoes on a person playing tennis.", "[401.41473388671875, 370.21429443359375, 440.33245849609375, 408.3084411621094]": "A tennis player with rackets and a tennis shoe.", "[421.8260498046875, 296.9776916503906, 466.86767578125, 348.4229431152344]": "Shoes on a tennis court with a tennis ball.", "[375.7676086425781, 198.69309997558594, 477.95904541015625, 275.0142517089844]": "Tennis playing in an orange top and black shorts.", "[409.9427185058594, 350.501953125, 431.1708068847656, 374.5637512207031]": "Someone has made an interesting paper cone out of a cup and some kind.", "[430.75341796875, 276.2204284667969, 469.3234558105469, 315.6880187988281]": "Tennis ball and tennis player with white shoes and socks.", "[135.15330505371094, 187.64132690429688, 155.61489868164062, 206.47390747070312]": "A tennis ball is seen in this picture.", "[484.87200927734375, 161.73353576660156, 575.6378173828125, 213.27391052246094]": "Someone holding a racket on a tennis court.", "[472.05950927734375, 152.61900329589844, 489.6600646972656, 171.71510314941406]": "Animal that looks to have a watch on his wrist.", "[472.3974304199219, 151.9155731201172, 491.1281433105469, 171.11224365234375]": "Someone has a wrist on their left wrist.", "[409.7373046875, 100.91047668457031, 441.9548034667969, 134.66603088378906]": "Someone wearing an orange shirt with the word w on it."} -------------------------------------------------------------------------------- /main_aokvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import argparse 4 | import numpy as np 5 | import multiprocessing 6 | import time 7 | import json 8 | import time 9 | import torch 10 | import random 11 | import openai 12 | from tqdm import tqdm 13 | from transformers import GPT2Tokenizer 14 | import pdb 15 | import pickle 16 | import glob 17 | from transformers import CLIPProcessor, CLIPModel 18 | from transformers import CLIPTokenizer, CLIPTextModel 19 | from PIL import Image 20 | import utils_api 21 | 22 | def set_random_seed(seed): 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | set_random_seed(seed=0) 28 | 29 | def parse_sentence(raw_result_list): 30 | output_list = [] 31 | for raw_result in raw_result_list: 32 | if isinstance(raw_result, str): 33 | raw_result = raw_result.strip(" ") 34 | tmp_result_list = raw_result.split(",") 35 | tmp_output_list = [] 36 | for tmp_result in tmp_result_list: 37 | tmp_output_list +=tmp_result.split(" and ") 38 | output_list += tmp_output_list 39 | elif isinstance(raw_result, list): 40 | raw_ouput = parse_sentence(raw_result) 41 | output_list +=raw_ouput 42 | output_list = [ele.strip() for ele in output_list] 43 | output_list = [ele for ele in output_list if len(ele)>0] 44 | return output_list 45 | 46 | def parge_obj_name(raw_result_list): 47 | output_list = [] 48 | for raw_result in raw_result_list: 49 | if isinstance(raw_result, str): 50 | raw_result = raw_result.strip(" ") 51 | tmp_result_list = raw_result.split(",") 52 | tmp_output_list = [] 53 | for tmp_result in tmp_result_list: 54 | tmp_output_list +=tmp_result.split(" and ") 55 | output_list += tmp_output_list 56 | elif isinstance(raw_result, list): 
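# nested inputs are flattened by delegating to parse_sentence; the list comprehensions below then strip leading articles ("a"/"an"/"the") from each object name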
57 | raw_ouput = parse_sentence(raw_result) 58 | output_list +=raw_ouput 59 | output_list = [ele[2:] if ele.lower().startswith("a ") else ele for ele in output_list] 60 | output_list = [ele[3:] if ele.lower().startswith("an ") else ele for ele in output_list] 61 | output_list = [ele[4:] if ele.lower().startswith("the ") else ele for ele in output_list] 62 | output_list = [ele.strip() for ele in output_list] 63 | output_list = [ele for ele in output_list if len(ele)>0] 64 | return output_list 65 | 66 | 67 | ## initial cleaning for reference QA results; Please use vqav2 eval script for the final number 68 | def process_answer(answer): 69 | answer = answer.replace('.', '').replace(',', '').lower() 70 | to_be_removed = {'a', 'an', 'the', 'to', ''} 71 | answer_list = answer.split(' ') 72 | answer_list = [item for item in answer_list if item not in to_be_removed] 73 | return ' '.join(answer_list) 74 | 75 | def bounding_box_matching(box1, box2, thres=0.5): 76 | ax1, ay1, ax2, ay2 = box1 77 | bx1, by1, bx2, by2 = box2 78 | if ax1 >= bx2 or ax2 <= bx1 or ay1 >= by2 or ay2 <= by1: 79 | return 0 80 | intersection = (min(ax2, bx2) - max(ax1, bx1)) * (min(ay2, by2) - max(ay1, by1)) 81 | and_set = (ax2-ax1) * (ay2-ay1) + (bx2-bx1) * (by2-by1) - intersection 82 | return (intersection / and_set) 83 | 84 | class VisualCOT_AOKVQA: 85 | def __init__(self, args, apikey_list): 86 | self.args = args 87 | self.chain_of_thoughts = args.chain_of_thoughts 88 | ## loading input questions (and answer for reference accuracy computing) 89 | self.load_dataset(args) 90 | self.load_similarity() 91 | self.apikey_list = apikey_list 92 | self.apikey_idx = 0 93 | openai.api_key = self.apikey_list[self.apikey_idx] 94 | if args.engine == "opt": 95 | if args.with_six_gpus: 96 | self.initialize_opt_6gpu() 97 | elif args.with_one_gpu: 98 | self.initialize_opt_small() 99 | else: 100 | self.initialize_opt() 101 | elif args.engine == "llama": 102 | self.initialize_llama() 103 | elif args.engine == "bloom": 104 | from plm.bloom import get_model_and_tokenizer 105 | self.model, self.tokenizer = get_model_and_tokenizer(name="microsoft/bloom-deepspeed-inference-int8", 106 | dtype="int8") 107 | elif args.engine == "chat": 108 | import tiktoken 109 | self.tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") 110 | # elif args.engine == "codex": 111 | # import tiktoken 112 | # self.tokenizer = tiktoken.encoding_for_model(self.args.engine_name) 113 | # elif args.engine == "instruct": 114 | # import tiktoken 115 | # self.tokenizer = tiktoken.encoding_for_model("davinci-instruct-beta") 116 | elif args.engine == "chat-test": 117 | self.initialize_opt_small() 118 | elif args.engine in ['ada', 'babbage', 'curie', 'davinci', 'gpt3']: 119 | self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 120 | else: 121 | import tiktoken 122 | self.tokenizer = tiktoken.encoding_for_model(self.args.engine_name) 123 | # self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 124 | 125 | if self.args.use_blip2: 126 | if not self.args.with_blip2_api: 127 | blip2_model_name = "pretrain_flant5xl" if self.args.use_v100 else "pretrain_flant5xxl" 128 | from lavis.models import load_model_and_preprocess 129 | if args.engine == "chat-test": 130 | self.blip2_device = torch.device("cuda") if torch.cuda.is_available() else "cpu" 131 | self.blip2_model, self.blip2_vis_processors, _ = load_model_and_preprocess(name="blip2_t5", 132 | model_type=blip2_model_name, 133 | is_eval=True, 134 | device=self.blip2_device) 135 | import pdb 136 | pdb.set_trace() 137 | else: 138 | 
self.blip2_device = torch.device("cuda") if torch.cuda.is_available() else "cpu" 139 | self.blip2_model, self.blip2_vis_processors, _ = load_model_and_preprocess(name="blip2_t5", 140 | model_type=blip2_model_name, 141 | is_eval=True, 142 | device=self.blip2_device) 143 | print("Finish loading BLIP2 model") 144 | else: 145 | self.blip2_api = API_URLS = [ "http://localhost:5000/api/generate", ] 146 | 147 | if args.with_clip_verify or args.choice_only: 148 | model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch16") 149 | model = model.cuda() 150 | processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16") 151 | self.clip_model, self.clip_processor = model, processor 152 | 153 | if args.oracle_attend: 154 | with open(self.args.val_sim_file, "rb") as fh: 155 | self.val_oracle_attend = pickle.load(fh) 156 | 157 | self.temp_question = "What is the person doing?" 158 | 159 | def find_image(self, img_key): 160 | img_full_path = os.path.join(self.raw_image_dir, "%012d.jpg" % img_key) 161 | #print(img_full_path) 162 | return Image.open(img_full_path).convert("RGB") 163 | 164 | def load_dataset(self, args): 165 | test_name = "test" if args.test_only else "val" 166 | self.raw_image_dir = os.path.join(self.args.raw_image_dir, "%s2017" % test_name) 167 | _, self.answer_dict, self.question_dict, self.rationale_dict, self.choices_dict = \ 168 | self.load_anno(None, f'{args.coco_path}/aokvqa_v1p0_{test_name}.json', 169 | f'{args.coco_path}/aokvqa_v1p0_{test_name}.json', choice_only=args.choice_only) 170 | self.val_keys = list(self.question_dict.keys()) 171 | self.val_keys = self.val_keys[int(args.start * len(self.val_keys)):int(args.end * len(self.val_keys))] # OK 172 | 173 | ## load cached image representation (Coco caption & Tags) 174 | self.inputtext_dict = self.load_cachetext() 175 | 176 | self.traincontext_caption_dict, self.traincontext_answer_dict, \ 177 | self.traincontext_question_dict, self.traincontext_rationale_dict, \ 178 | self.traincontext_choices_dict = \ 179 | self.load_anno('%s/captions_train2017.json' % args.coco_path, \ 180 | '%s/aokvqa_v1p0_train.json' % args.coco_path, \ 181 | '%s/aokvqa_v1p0_train.json' % args.coco_path, choice_only=args.choice_only) 182 | 183 | self.traincontext_interactive_answer_dict = self.traincontext_answer_dict 184 | self.traincontext_interactive_question_dict = self.traincontext_question_dict 185 | self.train_keys = list(self.traincontext_answer_dict.keys()) # OK 186 | self.train_interactive_keys = self.train_keys # OK 187 | 188 | if args.caption_type == 'vinvl_ocr': 189 | self.load_ocr(os.path.join(self.args.sg_path, "coco17_ocr_train.json"), 190 | os.path.join(self.args.sg_path, f"coco17_ocr_{test_name}.json"), 191 | os.path.join(self.args.sg_path, "scene_graph_coco17_attr")) 192 | self.sg_dir = os.path.join(self.args.sg_path, "scene_graph_coco17") 193 | self.sg_attr_dir = os.path.join(self.args.sg_path, "scene_graph_coco17_attr") 194 | self.sg_cap_dir = os.path.join(self.args.sg_path, self.args.concept_caption_path) 195 | 196 | def load_anno(self, coco_caption_file, answer_anno_file, question_anno_file, choice_only=False): 197 | if coco_caption_file is not None: 198 | coco_caption = json.load(open(coco_caption_file, 'r')) 199 | if type(coco_caption) == type({}): coco_caption = coco_caption['annotations'] 200 | answer_anno = json.load(open(answer_anno_file, 'r')) 201 | question_anno = json.load(open(question_anno_file, 'r')) 202 | 203 | caption_dict = {} 204 | if coco_caption_file is not None: 205 | for sample in coco_caption: 206 
| if sample['image_id'] not in caption_dict: 207 | caption_dict[sample['image_id']] = [sample['caption']] 208 | else: 209 | caption_dict[sample['image_id']].append(sample['caption']) 210 | 211 | answer_dict = {} 212 | for sample in answer_anno: 213 | if str(sample['image_id']) + '<->' + str(sample['question_id']) not in answer_dict: 214 | if choice_only: 215 | if 'correct_choice_idx' in sample: 216 | answer_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample[ 217 | "correct_choice_idx"] 218 | else: 219 | answer_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = 0 220 | else: 221 | if 'direct_answers' in sample: 222 | answer_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample[ 223 | "direct_answers"] 224 | else: 225 | answer_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = [""] 226 | 227 | question_dict = {} 228 | for sample in question_anno: 229 | if str(sample['image_id']) + '<->' + str(sample['question_id']) not in question_dict: 230 | question_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample['question'] 231 | 232 | rationales_dict = {} 233 | for sample in answer_anno: 234 | if str(sample['image_id']) + '<->' + str(sample['question_id']) not in rationales_dict: 235 | if 'rationales' in sample: 236 | rationales_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample['rationales'] 237 | else: 238 | rationales_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = "" 239 | 240 | choices_dict = {} 241 | for sample in answer_anno: 242 | choices_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample['choices'] 243 | 244 | return caption_dict, answer_dict, question_dict, rationales_dict, choices_dict 245 | 246 | def sleep(self, sleep_time=1.5, switch_key=False): 247 | if self.args.engine == "codex": 248 | sleep_time = 0.1 249 | if switch_key: 250 | self.apikey_idx += 1 251 | if self.apikey_idx >= len(self.apikey_list): 252 | self.apikey_idx = 0 253 | openai.api_key = self.apikey_list[self.apikey_idx] 254 | time.sleep(sleep_time) 255 | 256 | def load_ocr(self, train_ocr, val_ocr, sg_path, thres=0.2): 257 | train_ocr_dict = json.load(open(train_ocr)) 258 | val_ocr_dict = json.load(open(val_ocr)) 259 | self.train_ocr_text = {} 260 | self.val_ocr_text = {} 261 | for key in train_ocr_dict: 262 | tmp_ocr_list = train_ocr_dict[key] 263 | ocr_text = {} 264 | if len(tmp_ocr_list) > 0: 265 | obj_list = json.load(open(os.path.join(sg_path, f"{key.split('_')[-1]}.json"))) 266 | for tmp_ocr in tmp_ocr_list: 267 | box = tmp_ocr["box"] 268 | text = tmp_ocr["text"] 269 | conf = tmp_ocr["conf"] 270 | if conf > thres: 271 | max_match_val = -1 272 | max_match_obj = "" 273 | box = [box[0][0], box[0][1], box[1][0], box[2][1]] 274 | for obj in obj_list[0]: 275 | if bounding_box_matching(box, obj['rect']) > max_match_val: 276 | max_match_obj = obj['class'] 277 | max_match_val = bounding_box_matching(box, obj['rect']) 278 | ocr_text[max_match_obj] = f"Text {text} is on the {max_match_obj}." 
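# every OCR box above the confidence threshold is assigned to the scene-graph object with the highest IoU overlap (bounding_box_matching); per-image results are keyed by the numeric COCO image id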
279 | self.train_ocr_text[int(key.split("_")[-1])] = ocr_text 280 | 281 | for key in val_ocr_dict: 282 | tmp_ocr_list = val_ocr_dict[key] 283 | ocr_text = {} 284 | if len(tmp_ocr_list) > 0: 285 | obj_list = json.load(open(os.path.join(sg_path, f"{key.split('_')[-1]}.json"))) 286 | for tmp_ocr in tmp_ocr_list: 287 | box = tmp_ocr["box"] 288 | text = tmp_ocr["text"] 289 | conf = tmp_ocr["conf"] 290 | if conf > thres: 291 | max_match_val = -1 292 | max_match_obj = "" 293 | box = [box[0][0], box[0][1], box[1][0], box[2][1]] 294 | for obj in obj_list[0]: 295 | if bounding_box_matching(box, obj['rect']) > max_match_val: 296 | max_match_obj = obj['class'] 297 | max_match_val = bounding_box_matching(box, obj['rect']) 298 | ocr_text[max_match_obj] = f"Text {text} is on the {max_match_obj}." 299 | self.val_ocr_text[int(key.split("_")[-1])] = ocr_text 300 | 301 | def initialize_llama(self): 302 | from transformers import LlamaForCausalLM, LlamaTokenizer 303 | self.model = LlamaForCausalLM.from_pretrained(self.args.llama_path, 304 | device_map="auto", torch_dtype=torch.float16) 305 | self.tokenizer = LlamaTokenizer.from_pretrained(self.args.llama_path) 306 | 307 | def initialize_opt_small(self): 308 | from transformers import OPTForCausalLM 309 | self.model = OPTForCausalLM.from_pretrained("facebook/opt-1.3b", 310 | device_map="auto", torch_dtype=torch.float16) 311 | self.tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-1.3b") 312 | 313 | def initialize_opt(self): 314 | num_device = 8 315 | import math 316 | num_layers = math.ceil(64 / 8.0) 317 | assert torch.cuda.device_count() >= num_device 318 | opt_device_map = {'model.decoder.embed_tokens': 0, 319 | 'lm_head': 0, 320 | 'model.decoder.embed_positions': 0, 321 | 'model.decoder.final_layer_norm': 0, 322 | 'model.decoder.layers.0': 0} 323 | for layer in range(64): 324 | layer_name = "model.decoder.layers.%s" % (str(layer)) 325 | device = layer // num_layers 326 | opt_device_map[layer_name] = device 327 | from transformers import OPTForCausalLM 328 | self.model = OPTForCausalLM.from_pretrained("facebook/opt-66b", 329 | device_map=opt_device_map, torch_dtype=torch.float16) 330 | self.tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-66b") 331 | 332 | def initialize_opt_6gpu(self): 333 | num_device = 6 334 | import math 335 | num_layers = math.ceil(64 / 6.0) 336 | assert torch.cuda.device_count() >= num_device 337 | opt_device_map = {'model.decoder.embed_tokens': 0, 338 | 'lm_head': 0, 339 | 'model.decoder.embed_positions': 0, 340 | 'model.decoder.final_layer_norm': 0, 341 | 'model.decoder.layers.0': 0} 342 | for layer in range(64): 343 | layer_name = "model.decoder.layers.%s" % (str(layer)) 344 | device = layer // num_layers 345 | opt_device_map[layer_name] = device 346 | from transformers import OPTForCausalLM 347 | self.model = OPTForCausalLM.from_pretrained("facebook/opt-66b", 348 | device_map=opt_device_map, torch_dtype=torch.float16) 349 | self.tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-66b") 350 | 351 | def decode_scene_graph(self, sg_attr): 352 | attr_list = [] 353 | for attr in sg_attr: 354 | if self.args.iterative_strategy == "sg": 355 | attr_list.append(f"{attr[1]} is {', '.join(attr[2])}") 356 | elif self.args.iterative_strategy == "caption": 357 | attr_list.append(attr[3]) 358 | if self.args.caption_type == "vinvl_ocr" and attr[4] != "": 359 | attr_list.append(attr[4]) 360 | else: 361 | assert False 362 | 363 | text = "" 364 | text += "\n".join(attr_list) 365 | return text 366 | 367 | def inference(self, 
save_every_step): 368 | answers = [] 369 | full_answers = [] 370 | # i=0 371 | if save_every_step: 372 | os.system("mkdir -p %s" % self.args.output_path) 373 | os.system("mkdir -p %s/prompt_samples" % self.args.output_path) 374 | os.system("mkdir -p %s/format_samples" % self.args.output_path) 375 | 376 | if self.args.pick_example_with_question_mode: 377 | while True: 378 | image_id = input("Input one image id please") 379 | question = input("Input one question please") 380 | image_id = str(image_id) 381 | self.given_question = question 382 | for idx, key in enumerate(tqdm(self.val_keys)): 383 | if image_id not in key: 384 | continue 385 | final_answer, answer_list = self.sample_inference_interactive(key) 386 | print(final_answer) 387 | print(answer_list) 388 | pdb.set_trace() 389 | 390 | for idx, key in enumerate(tqdm(self.val_keys)): 391 | if save_every_step: 392 | out_file_name = "%s/prompt_samples/sample_%s_*.json" % (self.args.output_path, str(idx)) 393 | print(out_file_name) 394 | out_file_list = glob.glob(out_file_name) 395 | if len(out_file_list) > 0: 396 | continue 397 | if self.args.pick_example_mode: 398 | if not self.pick_example(key): 399 | continue 400 | final_answer, answer_list = self.sample_inference_interactive(key) 401 | answers.append(final_answer) 402 | full_answers.append(answer_list) 403 | acc = 0. 404 | for answer in answers: 405 | acc += float(answer[3]) 406 | print(acc * 100. / len(answers), len(answers)) 407 | if save_every_step: 408 | json.dump(answers[-1], open("%s/prompt_samples/sample_%s_%s.json" % \ 409 | (self.args.output_path, str(idx), str(float(answers[-1][3]))), 'w')) 410 | json.dump(full_answers[-1], open("%s/format_samples/sample_%s_%s.json" % \ 411 | (self.args.output_path, str(idx), str(float(answers[-1][3]))), 'w')) 412 | # i += 1 413 | # if i > 3: 414 | # break 415 | return answers, full_answers 416 | 417 | def sample_inference_interactive(self, key): 418 | img_key = int(key.split('<->')[0]) if self.args.set_name!="fvqa" else self.image_dict[key] # for fvqa 419 | raw_image = self.find_image(img_key) 420 | if self.args.use_blip2: 421 | if not self.args.with_blip2_api: 422 | self.current_blip2_image = self.blip2_vis_processors["eval"](raw_image).unsqueeze(0).to( 423 | self.blip2_device) 424 | else: 425 | self.current_blip2_image = raw_image 426 | if self.args.debug: 427 | t1=time.time() 428 | if self.args.set_name!="fvqa": 429 | scene_graph_path = os.path.join(self.sg_attr_dir, str(img_key).zfill(12) + ".json") 430 | else: 431 | scene_graph_path = os.path.join(self.sg_attr_dir, str(img_key).replace(".jpg", "") + ".json") 432 | scene_graph_attr = json.load(open( scene_graph_path)) 433 | if self.args.iterative_strategy == "caption": 434 | scene_path = os.path.join(self.sg_cap_dir, str(img_key).zfill(12) + ".json") 435 | if not os.path.isfile(scene_path): 436 | # backup caption 437 | scene_path = os.path.join(self.sg_cap_dir + "_v2", str(img_key).zfill(12) + ".json") 438 | if not os.path.isfile(scene_path): 439 | scene_graph_caption = [f"{attr['class']} is {', '.join(attr['attr'])}." 
\ 440 | for attr in scene_graph_attr[0]] 441 | else: 442 | scene_graph_caption = json.load(open(scene_path)) 443 | attr_list = [] 444 | for attr_id, attr in enumerate(scene_graph_attr[0]): 445 | if self.args.iterative_strategy == "caption": 446 | if isinstance(scene_graph_caption, list): 447 | tmp_cap = scene_graph_caption[attr_id] 448 | else: 449 | rect_str = str(attr['rect']) 450 | try: 451 | tmp_cap = scene_graph_caption[rect_str] 452 | except: 453 | tmp_cap = attr['class'] 454 | print("Fail to parse attr\n") 455 | tmp_attr = [attr['conf'], attr['class'], attr['attr'], tmp_cap] 456 | else: 457 | tmp_attr = [attr['conf'], attr['class'], attr['attr']] 458 | if self.args.caption_type == "vinvl_ocr": 459 | if attr['class'] in self.val_ocr_text[img_key]: 460 | tmp_attr.append(self.val_ocr_text[img_key][attr['class']]) 461 | else: 462 | tmp_attr.append("") 463 | attr_list.append(tmp_attr) 464 | attr_list.sort(key=lambda x: x[0], reverse=True) 465 | 466 | if self.args.use_blip2: 467 | attr_list_blip2 = [] 468 | attr_name_list = [] 469 | for attr in attr_list: 470 | if attr[1] not in attr_name_list: 471 | attr_name_list.append(attr[1]) 472 | attr_list_blip2.append(attr) 473 | attr_list = attr_list_blip2 474 | self.current_global_caption = self.query_blip2_global_caption(self.question_dict[key]) 475 | 476 | answer_list = [] 477 | noticed_attr_list = [] 478 | thoughts = [] 479 | answer_text = "" 480 | if self.args.debug: 481 | t2=time.time() 482 | print(" PREPARE TIME", t2-t1) 483 | self.current_conversation = [] 484 | rounds = 1 if self.args.all_regional_captions else self.args.rounds 485 | for i in range(rounds): 486 | if self.args.debug: 487 | t3=time.time() 488 | if self.args.random_attend: 489 | idx = random.randint(0, len(attr_list)-1) 490 | elif self.args.oracle_attend: 491 | attr_scores = [self.val_oracle_attend[key][attr[1]] for attr in attr_list] 492 | idx = attr_scores.index(max(attr_scores)) 493 | elif self.args.all_regional_captions: 494 | idx = None 495 | else: 496 | idx = self.interactive(key, attr_list) 497 | # HERE 498 | torch.cuda.synchronize() 499 | if self.args.debug: 500 | t4=time.time() 501 | 502 | if self.args.use_blip2: 503 | if idx >= len(attr_list): 504 | idx = idx % len(attr_list) 505 | print("index %d, out of object list."%(idx)) 506 | print(attr_list) 507 | local_caption = self.query_blip2_local_caption(attr_list[idx][1], self.question_dict[key]) 508 | noticed_attr_list.append([attr_list[idx][0], attr_list[idx][1], [], local_caption, ""]) 509 | elif self.args.all_regional_captions: 510 | noticed_attr_list.extend([attr_i for attr_i in attr_list]) 511 | else: 512 | noticed_attr_list.append(attr_list[idx]) 513 | if self.chain_of_thoughts: 514 | answer_list.append(self.sample_inference(key, [] if self.args.ablation_visual else noticed_attr_list 515 | , [] if self.args.ablation_reason else thoughts)) 516 | else: 517 | answer_list.append(self.sample_inference(key, noticed_attr_list)) 518 | if self.args.debug: 519 | t5=time.time() 520 | print(" VISUAL LOOP TIME", t4-t3) 521 | print(" REASON LOOP TIME", t5-t4) 522 | torch.cuda.synchronize() 523 | if idx != None: 524 | attr_list = attr_list[:idx] + attr_list[idx + 1:] 525 | thoughts.append(answer_list[-1][4]) 526 | if answer_text == answer_list[-1][1]: 527 | break 528 | else: 529 | answer_text = answer_list[-1][1] 530 | final_answer = answer_list[-1] 531 | return final_answer, answer_list 532 | 533 | def pick_example(self, key): 534 | img_key = int(key.split('<->')[0]) if self.args.set_name != "fvqa" else self.image_dict[key] 
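The loop above in `sample_inference_interactive` alternates between a "see" step (pick one scene-graph object to attend to) and a "reason" step (re-answer with all evidence gathered so far), removing each attended object from the candidate list and stopping early once the answer stops changing. A minimal sketch of that control flow follows; `select_object`, `describe_object`, and `answer_with_context` are hypothetical stand-ins for `interactive`, the BLIP-2 / scene-graph captioning, and `sample_inference`, not functions from this repository.

```python
# Minimal sketch of the iterative see/reason loop; the three callables are
# hypothetical stand-ins, not functions defined in this repository.
from typing import Callable, List, Tuple

def visual_cot_loop(objects: List[str],
                    select_object: Callable[[List[str]], int],
                    describe_object: Callable[[str], str],
                    answer_with_context: Callable[[List[str], List[str]], Tuple[str, str]],
                    rounds: int = 3):
    evidence, thoughts = [], []        # captions seen so far / rationales so far
    prev_answer, answer = "", ""
    for _ in range(rounds):
        idx = select_object(objects)                    # "see": pick the most relevant object
        evidence.append(describe_object(objects[idx]))  # caption the attended region
        answer, thought = answer_with_context(evidence, thoughts)  # "reason"
        thoughts.append(thought)
        objects = objects[:idx] + objects[idx + 1:]     # do not attend to the same object twice
        if answer == prev_answer:                       # early stop once the answer is stable
            break
        prev_answer = answer
    return answer, evidence, thoughts
```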
# for fvqa 535 | scene_graph_path = os.path.join(self.sg_attr_dir, str(img_key).zfill(12) + ".json") 536 | scene_graph_attr = json.load(open(scene_graph_path)) 537 | for attr_id, attr in enumerate(scene_graph_attr[0]): 538 | if attr['class'] in ['girl', 'boy', 'man', 'woman'] and len(attr['attr']) > 0: 539 | description = attr['attr'][0] 540 | self.temp_question = f"What is the {description} {attr['class']} doing?" 541 | return True 542 | return False 543 | 544 | def query_blip2_basic(self, image, prompt, use_pred_answer=False): 545 | if not self.args.with_blip2_api: 546 | if use_pred_answer: 547 | output = self.blip2_model.predict_answers({"image": image, "text_input": prompt}, max_len=25) 548 | else: 549 | output = self.blip2_model.generate({"image": image, "text_input": prompt}) 550 | if self.args.debug: 551 | import pdb 552 | pdb.set_trace() 553 | else: 554 | # api only support predict_answers CALL 555 | output = utils_api.blip_completev2(images=[image ], texts= [prompt], blip_urls=self.blip2_api, num_beams=5, length_penalty=-1.0, encoding_format="PNG",) 556 | return output 557 | 558 | return output 559 | 560 | def query_blip2_objects(self): 561 | obj_list = [] 562 | max_obj_num = 10 563 | while len(obj_list) < max_obj_num: 564 | if len(obj_list)==0: 565 | tmp_obj_name_list = self.query_blip2_basic(image=self.current_blip2_image, 566 | prompt="Give me the name of one object, creature, or entity in the image.") 567 | else: 568 | tmp_prompt = "Give me the name of one object, creature, or entity in the image besides" 569 | for tmp_idx, tmp_name in enumerate(obj_list): 570 | tmp_prompt = tmp_prompt +" %s"%tmp_name 571 | if tmp_idx < len(obj_list) -1: 572 | tmp_prompt +="," 573 | else: 574 | tmp_prompt +="?" 575 | tmp_obj_name_list = self.query_blip2_basic(image=self.current_blip2_image, prompt=tmp_prompt) 576 | 577 | tmp_obj_name_list_refine = parge_obj_name(tmp_obj_name_list) 578 | print(tmp_obj_name_list_refine) 579 | 580 | all_exist_flag = True 581 | for obj_name in tmp_obj_name_list_refine: 582 | if obj_name not in obj_list: 583 | obj_list.append(obj_name) 584 | all_exist_flag = False 585 | if all_exist_flag: 586 | break 587 | obj_list = list(set(obj_list)) 588 | attr_list = [[1.0, obj_name] for obj_name in obj_list] # [[confidence, name]] 589 | print(attr_list) 590 | if self.args.debug: 591 | pdb.set_trace() 592 | return attr_list 593 | 594 | def query_blip2_global_caption(self, question): 595 | global_caption = self.query_blip2_basic(image=self.current_blip2_image, prompt="An image of ")[0] 596 | global_caption_question = self.query_blip2_basic(image=self.current_blip2_image, prompt=f"Question: Please " 597 | f"look at the picture and answer the following question. " 598 | f"{question} Answer:", use_pred_answer=True)[0] 599 | if self.args.debug: 600 | print(". ".join([global_caption, global_caption_question])) 601 | return ". ".join([global_caption, global_caption_question]) 602 | 603 | def query_blip2_local_caption(self, obj_name, question): 604 | local_caption_raw = self.query_blip2_basic(image=self.current_blip2_image, 605 | prompt=f"Question: Look at the {obj_name} in this image. Please give a detailed " 606 | f"description of the {obj_name} in this image. Answer:", use_pred_answer=True)[0] 607 | if self.args.engine == "chat": 608 | self.current_conversation.append({ 609 | 'role': 'user', 610 | 'content': f'You will to look at the {obj_name} in the picture and find {local_caption_raw}.' 611 | f'To find the answer to {question}, you can ask one question about the {obj_name}. 
' 612 | f'Please tell me the question you want to ask directly.' 613 | }) 614 | successful = False 615 | while not successful: 616 | try: 617 | self.sleep() 618 | response = openai.ChatCompletion.create( 619 | model="gpt-3.5-turbo", 620 | messages=self.current_conversation, 621 | max_tokens=40, 622 | temperature=0., 623 | stream=False, 624 | ) 625 | successful = True 626 | except Exception as e: 627 | print(e) 628 | self.sleep(switch_key=True) 629 | question_from_chatgpt = response['choices'][0]['message']['content'] 630 | elif self.args.engine == "chat-test": 631 | self.current_conversation.append({ 632 | 'role': 'user', 633 | 'content': f'You will to look at the {obj_name} in the picture and find {local_caption_raw}.' 634 | f'To find the answer to {question}, you can ask one question about the {obj_name}. ' 635 | f'Please tell me the question you want to ask directly.' 636 | }) 637 | question_from_chatgpt = "Who are you?" 638 | elif self.args.engine in ["ada", "babbage", "curie", "davinci", "codex", "instruct"]: 639 | prompt = f"I look at the {obj_name} in the picture and find {local_caption_raw}. " \ 640 | f"To find the answer to {question}, I ask one question about the {obj_name}. " \ 641 | f"My question is:" 642 | successful = False 643 | while not successful: 644 | try: 645 | self.sleep() 646 | response = openai.Completion.create( 647 | engine=self.args.engine_name, 648 | prompt=prompt, 649 | max_tokens=41, 650 | logprobs=1, 651 | temperature=0., 652 | stream=False, 653 | stop=["<|endoftext|>", "?", " ?"], 654 | ) 655 | successful = True 656 | except Exception as e: 657 | print(e) 658 | self.sleep(switch_key=True) 659 | question_from_chatgpt = response['choices'][0]['text'].strip()+ "?" 660 | else: 661 | question_from_chatgpt = "Empty" 662 | local_caption_question = self.query_blip2_basic(image=self.current_blip2_image, 663 | prompt=f"Question: Please look at the {obj_name} and answer the following question." 664 | f" {question_from_chatgpt} Answer:", use_pred_answer=True)[0] 665 | local_caption = ". ".join([local_caption_raw, question_from_chatgpt + " The answer is " + local_caption_question]) 666 | if self.args.debug: 667 | print(local_caption) 668 | pdb.set_trace() 669 | return local_caption 670 | 671 | def query_blip2_thought_match_img(self, thought): 672 | blip2_answer = self.query_blip2_basic(image=self.current_blip2_image, 673 | prompt=f"Question: Does this sentence match the facts in the picture? Please answer yes or no. " 674 | f"Sentence: In this picture, {thought} Answer:")[0] 675 | if self.args.debug: 676 | print(blip2_answer, thought) 677 | if blip2_answer == "no": 678 | correction = self.query_blip2_basic(image=self.current_blip2_image, 679 | prompt=f"Question: Please correct the following sentence according to " 680 | f"the image. Sentence: {thought}")[0] 681 | return correction 682 | else: 683 | return thought 684 | 685 | def interactive(self, key, attr_list): 686 | context_key_list, rel_obj_list = self.get_interactive_context_keys(key, \ 687 | self.args.similarity_metric, 688 | self.args.n_shot) 689 | question = self.question_dict[key] 690 | if self.args.pick_example_mode: 691 | question = self.temp_question 692 | if self.args.pick_example_with_question_mode: 693 | question = self.given_question 694 | if self.args.engine == "chat" or self.args.engine == "chat-test": 695 | system_prompt = "Let's play a game. I have an image and a complex question about it, and you will give me the " \ 696 | "name of object in the image you want to look at the most. 
Please follow the format" \ 697 | " of the following examples and give me an object name directly.\n" 698 | prompt = "===\n" 699 | else: 700 | prompt = 'Please select the object most related to the question.\n===\n' 701 | for ni in range(self.args.n_shot): 702 | context_key = context_key_list[ni] 703 | rel_obj = rel_obj_list[ni] 704 | rel_obj = [k for k, v in sorted(rel_obj.items(), key=lambda item: item[1], reverse=True)] 705 | if self.args.set_name=="fvqa": 706 | img_context_key = self.image_dict[context_key] 707 | else: 708 | img_context_key = int(context_key.split('<->')[0]) 709 | while True: ## make sure get context with valid question and answer 710 | if self.args.choice_only or (len(self.traincontext_interactive_question_dict[context_key]) != 0 and len( 711 | self.traincontext_interactive_answer_dict[context_key][0]) != 0): 712 | break 713 | context_key = self.train_interactive_keys[random.randint(0, len(self.train_interactive_keys) - 1)] 714 | if self.args.set_name=="fvqa": 715 | context_scene_graph = json.load(open(os.path.join(self.sg_dir, str(img_context_key).replace(".jpg", "") + ".json"))) 716 | context_scene_graph_attr = json.load(open(os.path.join(self.sg_attr_dir, str(img_context_key).replace(".jpg", "") + ".json"))) 717 | else: 718 | context_scene_graph = json.load(open(os.path.join(self.sg_dir, str(img_context_key).zfill(12) + ".json"))) 719 | context_scene_graph_attr = json.load(open(os.path.join(self.sg_attr_dir, str(img_context_key).zfill(12) + ".json"))) 720 | 721 | obj_list = [] 722 | conf_list = [] 723 | for obj in context_scene_graph[0]: 724 | if obj['class'] not in obj_list: 725 | obj_list.append(obj['class']) 726 | conf_list.append(obj['conf']) 727 | for obj in context_scene_graph_attr[0]: 728 | if obj['class'] not in obj_list: 729 | obj_list.append(obj['class']) 730 | conf_list.append(obj['conf']) 731 | if self.args.engine == "chat" or self.args.engine == "chat-test": 732 | prompt += 'Question: %s\n===\n' % (self.traincontext_interactive_question_dict[context_key]) 733 | elif self.args.use_attributes_to_see or self.args.use_caption_to_see: 734 | prompt += 'Question: %s\n===\n' % (self.traincontext_interactive_question_dict[context_key]) 735 | else: 736 | prompt += 'Question: %s\n===\nOptions:\n' % (self.traincontext_interactive_question_dict[context_key]) 737 | prompt += "The most related option is %s.\n\n===\n" % rel_obj[0] 738 | obj_list = [obj[1] for obj in attr_list] 739 | if self.args.engine == "chat" or self.args.engine == "chat-test": 740 | prompt += "Question: %s\n===\nOptions: %s\n" % (question, ", ".join(obj_list)) 741 | elif self.args.use_attributes_to_see: 742 | obj_list = [f"{obj[1]}: {' '.join(obj[2])} {obj[1]}" for obj in attr_list] 743 | prompt += "Question: %s\n===\nOptions: %s\n" % (question, ",\n".join(obj_list)) 744 | elif self.args.use_caption_to_see: 745 | obj_list = [f"{obj[1]}: {obj[-2]}" for obj in attr_list] 746 | prompt += "Question: %s\n===\nOptions: %s\n" % (question, ", ".join(obj_list)) 747 | else: 748 | prompt += "Question: %s\n===\nOptions:\n" % (question) 749 | prompt += "The most related option is" 750 | if self.args.engine in ["ada", "babbage", "curie", "davinci", "codex", "instruct", "gpt3"]: 751 | obj_idx_list = [] 752 | for obj in obj_list: 753 | obj_idx_list.append(self.tokenizer.encode(f" {obj}")[0]) 754 | logit_bias = {} 755 | current_bias = 100 756 | successful = False 757 | if self.args.engine == "codex": 758 | engine_name = "code-davinci-002" 759 | elif self.args.engine == "instruct": 760 | engine_name = 
"davinci-instruct-beta" 761 | elif self.args.engine == "gpt3": 762 | engine_name = "text-davinci-001" 763 | else: 764 | engine_name = self.args.engine 765 | while not successful: 766 | for tok_idx in obj_idx_list: 767 | logit_bias[str(tok_idx)] = current_bias 768 | try: 769 | self.sleep() 770 | response = openai.Completion.create( 771 | engine=engine_name, 772 | prompt=prompt, 773 | max_tokens=4, 774 | logprobs=1, 775 | temperature=0., 776 | stream=False, 777 | stop=["\n", "<|endoftext|>"], 778 | logit_bias=logit_bias 779 | ) 780 | successful = True 781 | except Exception as e: 782 | print(e) 783 | self.sleep(switch_key=True) 784 | result = self.tokenizer.encode(response['choices'][0]['text'])[0] 785 | if result in obj_idx_list: 786 | result = obj_idx_list.index(result) 787 | else: 788 | result = 0 789 | elif self.args.engine == "chat": 790 | obj_idx_list = [] 791 | for obj in obj_list: 792 | obj_idx_list.append(self.tokenizer.encode(f" {obj}")[0]) 793 | logit_bias = {} 794 | current_bias = 25 795 | successful = False 796 | while not successful: 797 | for tok_idx in obj_idx_list: 798 | logit_bias[str(tok_idx)] = current_bias 799 | try: 800 | self.sleep() 801 | response = openai.ChatCompletion.create( 802 | model="gpt-3.5-turbo", 803 | messages=[ 804 | {"role": "system", "content": system_prompt}, 805 | {"role": "user", "content": prompt}], 806 | max_tokens=5, 807 | temperature=0., 808 | stream=False, 809 | # stop=["\n", "<|endoftext|>"], 810 | logit_bias=logit_bias 811 | ) 812 | successful = True 813 | except Exception as e: 814 | print(e) 815 | print(prompt) 816 | current_bias = int(0.8 * current_bias) 817 | self.sleep(switch_key=True) 818 | result = self.tokenizer.encode(response['choices'][0]['message']['content'])[0] 819 | if result in obj_idx_list: 820 | result = obj_idx_list.index(result) 821 | else: 822 | result = 0 823 | self.current_conversation = [ 824 | {"role": "system", "content": system_prompt}, 825 | {"role": "user", "content": prompt}, 826 | {"role": "assistant", "content": response['choices'][0]['message']['content']} 827 | ] 828 | elif self.args.engine == "chat-test": 829 | print([{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]) 830 | self.current_conversation = [ 831 | {"role": "system", "content": system_prompt}, 832 | {"role": "user", "content": prompt}, 833 | {"role": "assistant", "content": "Object0"} 834 | ] 835 | pdb.set_trace() 836 | result = 0 837 | elif self.args.engine == "opt" or self.args.engine == "llama": 838 | obj_idx_list = [] 839 | for obj in obj_list: 840 | obj_idx_list.append(self.tokenizer.encode(f" {obj}")[1]) 841 | inputs = self.tokenizer(prompt, return_tensors="pt") 842 | outputs = self.model.generate(inputs.input_ids.to(torch.cuda.current_device()), max_length=len(inputs.input_ids[0]) + 5, 843 | return_dict_in_generate=True, output_scores=True) 844 | scores = outputs['scores'] 845 | scores = scores[0][0][obj_idx_list] 846 | if self.args.use_attributes_to_see or self.args.use_caption_to_see: 847 | result_str = self.tokenizer.decode(outputs['sequences'][0][len(inputs.input_ids[0]):]).split("\n")[0].strip() 848 | result_str = result_str[:-1] if result_str.endswith(".") else result_str 849 | result = -1 850 | for obj_id, obj in enumerate(obj_list): 851 | if result_str in obj: 852 | result = obj_id 853 | if result == -1: 854 | result = scores.argmax().item() 855 | else: 856 | result = scores.argmax().item() 857 | elif self.args.engine == "bloom": 858 | obj_idx_list = [] 859 | for obj in obj_list: 860 | 
obj_idx_list.append(self.tokenizer.encode(f" {obj}")[0]) 861 | inputs = self.tokenizer(prompt, return_tensors="pt") 862 | outputs = self.model.generate(inputs.input_ids.to(torch.cuda.current_device()), 863 | max_new_tokens=5, return_dict_in_generate=True, output_scores=True) 864 | scores = outputs['scores'] 865 | scores = scores[0][0][obj_idx_list] 866 | result = scores.argmax().item() 867 | else: 868 | assert False 869 | if self.args.debug: 870 | pdb.set_trace() 871 | return result 872 | 873 | def make_choices_text(self, choices, answer): 874 | return f"{', '.join(choices)}.", choices[answer] 875 | 876 | def sample_inference(self, key, scene_graph_attr, thoughts_list=None): 877 | img_key = int(key.split('<->')[0]) if self.args.set_name!="fvqa" else self.image_dict[key] # for fvqa 878 | if self.args.random_caption: 879 | random.seed(img_key) # keep random context in every step of the same sample consistent 880 | question, answer, caption = self.question_dict[key], self.answer_dict[key], self.inputtext_dict[img_key] 881 | if self.args.pick_example_mode: 882 | question = self.temp_question 883 | if self.args.pick_example_with_question_mode: 884 | question = self.given_question 885 | if self.args.random_caption: 886 | caption = random.choice(list(self.traincontext_caption_dict.values())) 887 | if self.args.use_blip2: 888 | caption_i = self.current_global_caption 889 | else: 890 | caption_i = caption[random.randint(0, 891 | len(caption) - 1)] ## select one caption if exists multiple, not true except COCO GT (5) 892 | 893 | default_sg_text = self.decode_scene_graph(scene_graph_attr) 894 | 895 | pred_answer_list, pred_prob_list, thought_list, all_thought_list = [], [], [], [] 896 | context_key_list = self.get_context_keys(key, self.args.similarity_metric, self.args.n_shot * self.args.n_ensemble) 897 | 898 | 899 | 900 | for repeat in range(self.args.n_ensemble): 901 | if self.args.debug: 902 | t1=time.time() 903 | if self.args.engine == "chat" or self.args.engine == "chat-test": 904 | prompt_before_answer = "Based on the given information, I must guess the most possible answer. Answer:\n" 905 | system_prompt = "Let's play a game. I have an image and a complex question about it. I will provide you some information about" \ 906 | " the image in the context, and you will give me the possible answer and reason to the question. You must provide an answer and can not say unclear or unknown. 
" \ 907 | "Please follow the format and answer style of the following examples and complete the last example.\n" 908 | prompt = "===\n" 909 | else: 910 | prompt_before_answer = "Answer: The answer is" 911 | prompt = 'Please answer the question according to the above context.\n===\n' 912 | ## prompt format following GPT-3 QA API 913 | cur_caption_i = "" if self.args.remove_caption else caption_i 914 | for ni in range(self.args.n_shot): 915 | if context_key_list is None: 916 | context_key = self.train_keys[random.randint(0, len(self.train_keys) - 1)] 917 | else: 918 | context_key = context_key_list[ni + self.args.n_shot * repeat] 919 | while True: ## make sure get context with valid question and answer 920 | if self.args.choice_only or (len(self.traincontext_question_dict[context_key]) != 0 and len( 921 | self.traincontext_answer_dict[context_key][0]) != 0): 922 | break 923 | context_key = self.train_keys[random.randint(0, len(self.train_keys) - 1)] 924 | img_context_key = int(context_key.split('<->')[0]) if self.args.set_name!="fvqa" else self.image_dict[context_key] # for fvqa 925 | if self.args.random_caption: 926 | context_caption = random.choice(list(self.traincontext_caption_dict.values())) 927 | context_caption = random.choice(context_caption) 928 | elif self.args.remove_caption: 929 | context_caption = "" 930 | else: 931 | context_caption = self.traincontext_caption_dict[img_context_key][ 932 | random.randint(0, len(self.traincontext_caption_dict[img_context_key]) - 1)] 933 | prompt += 'Context: %s\n===\n' % (context_caption) 934 | if self.args.choice_only: 935 | choice_text, answer_text = self.make_choices_text(self.traincontext_choices_dict[context_key], 936 | self.traincontext_answer_dict[context_key]) 937 | choice_text = f"\nChoices: {choice_text}" 938 | else: 939 | choice_text = "" 940 | answer_text = self.traincontext_answer_dict[context_key][0] 941 | #if self.args.set_name !="fvqa" else self.traincontext_answer_dict[context_key] 942 | if self.chain_of_thoughts: 943 | rationale_text = self.traincontext_rationale_dict[context_key][0] 944 | #if self.args.set_name !="fvqa" else self.traincontext_rationale_dict[context_key] 945 | prompt += 'Question: %s%s\n%s %s. 
%s\n\n===\n' % (self.traincontext_question_dict[context_key], 946 | choice_text, prompt_before_answer, answer_text, rationale_text) 947 | else: 948 | prompt += 'Question: %s%s\n%s %s\n\n===\n' % ( 949 | self.traincontext_question_dict[context_key], choice_text, prompt_before_answer, answer_text) 950 | 951 | if thoughts_list is not None and len(thoughts_list) > 0: 952 | cur_thoughts_list = [th for th in thoughts_list if th != ''] 953 | if len(cur_thoughts_list) > 0: 954 | cur_caption_i += "\n" 955 | cur_caption_i += " ".join(cur_thoughts_list) 956 | if self.args.choice_only: 957 | choice_text, _ = self.make_choices_text(self.choices_dict[key], 0) 958 | choice_text = f"\nChoices: {choice_text}" 959 | else: 960 | choice_text = "" 961 | if default_sg_text == "": 962 | prompt += 'Context: %s\n===\n' % cur_caption_i 963 | else: 964 | prompt += 'Context: %s\n%s\n===\n' % (cur_caption_i, default_sg_text) 965 | 966 | if self.chain_of_thoughts: 967 | prompt += 'Question: %s%s\n%s' % (question, choice_text, prompt_before_answer) 968 | else: 969 | prompt += 'Question: %s%s\n%s' % (question, choice_text, prompt_before_answer) 970 | response = None 971 | if self.args.debug: 972 | t2=time.time() 973 | if self.args.engine in ["ada", "babbage", "curie", "davinci", "codex", "instruct", "gpt3"]: 974 | successful = False 975 | if self.args.engine == "codex": 976 | engine_name = "code-davinci-002" 977 | elif self.args.engine == "instruct": 978 | engine_name = "davinci-instruct-beta" 979 | elif self.args.engine == "gpt3": 980 | engine_name = "text-davinci-001" 981 | else: 982 | engine_name = self.args.engine 983 | while not successful: 984 | try: 985 | self.sleep() 986 | response = openai.Completion.create( 987 | engine=engine_name, 988 | prompt=prompt, 989 | max_tokens=41, 990 | logprobs=1, 991 | temperature=0., 992 | stream=False, 993 | stop=["\n", "<|endoftext|>"] 994 | ) 995 | successful = True 996 | except Exception as e: 997 | print(e) 998 | self.sleep(switch_key=True) 999 | plist = [] 1000 | if self.chain_of_thoughts: 1001 | for ii in range(len(response['choices'][0]['logprobs']['tokens'])): 1002 | if response['choices'][0]['logprobs']['tokens'][ii].endswith("."): 1003 | break 1004 | plist.append(response['choices'][0]['logprobs']['token_logprobs'][ii]) 1005 | pred_answer_list.append(process_answer(response['choices'][0]["text"].split(".")[0])) 1006 | thought = ".".join(response['choices'][0]["text"].split(".")[1:]).strip() 1007 | pred_prob_list.append(sum(plist)) 1008 | else: 1009 | for ii in range(len(response['choices'][0]['logprobs']['tokens'])): 1010 | if response['choices'][0]['logprobs']['tokens'][ii] == '\n': 1011 | break 1012 | plist.append(response['choices'][0]['logprobs']['token_logprobs'][ii]) 1013 | pred_answer_list.append(process_answer(response['choices'][0]["text"])) 1014 | pred_prob_list.append(sum(plist)) 1015 | elif self.args.engine == "chat": 1016 | successful = False 1017 | while not successful: 1018 | try: 1019 | self.sleep() 1020 | response = openai.ChatCompletion.create( 1021 | model="gpt-3.5-turbo", 1022 | messages=[ 1023 | {"role": "system", "content": system_prompt}, 1024 | {"role": "user", "content": prompt} 1025 | ], 1026 | max_tokens=40, 1027 | temperature=0., 1028 | stream=False, 1029 | ) 1030 | successful = True 1031 | except Exception as e: 1032 | print(e) 1033 | print(prompt) 1034 | self.sleep(switch_key=True) 1035 | if self.chain_of_thoughts: 1036 | pred_answer_list.append(process_answer(response['choices'][0]['message']['content'].split(".")[0])) 1037 | thought = 
".".join(response['choices'][0]['message']['content'].split(".")[1:]).strip() 1038 | pred_prob_list.append(0) 1039 | else: 1040 | pred_answer_list.append(process_answer(response['choices'][0]['message']['content'])) 1041 | pred_prob_list.append(0) 1042 | elif self.args.engine == "chat-test": 1043 | print([{"role": "system", "content": system_prompt},{"role": "user", "content": prompt}]) 1044 | pdb.set_trace() 1045 | pred_answer_list.append("fake answer") 1046 | thought = "This is a fake thought." 1047 | pred_prob_list.append(0) 1048 | elif self.args.engine == "opt" or self.args.engine == "llama" or self.args.engine == "bloom": 1049 | inputs = self.tokenizer(prompt, return_tensors="pt") 1050 | with torch.no_grad(): 1051 | outputs = self.model.generate(inputs.input_ids.to(torch.cuda.current_device()), max_length=len(inputs.input_ids[0]) + 40, 1052 | return_dict_in_generate=True, output_scores=True) 1053 | plist = [] 1054 | result = self.tokenizer.batch_decode(outputs['sequences'][:, len(inputs.input_ids[0]):])[0] 1055 | if self.chain_of_thoughts: 1056 | for ii in range(len(inputs.input_ids[0]), len(outputs['sequences'][0])): 1057 | tok = outputs['sequences'][0][ii] 1058 | if self.tokenizer.decode([tok]) == '.': 1059 | break 1060 | scores = torch.log_softmax(outputs['scores'][ii - len(inputs.input_ids[0])], dim=-1) 1061 | plist.append(scores[0][tok]) 1062 | thought = ".".join(result.split("\n")[0].split("The answer is")[-1].split(".")[1:]).strip() 1063 | pred_answer = process_answer(result.split("\n")[0].split("The answer is")[-1].split(".")[0]) 1064 | pred_answer_list.append(pred_answer) 1065 | pred_prob_list.append(sum(plist)) 1066 | else: 1067 | for ii in range(len(inputs.input_ids[0]), len(outputs['sequences'][0])): 1068 | tok = outputs['sequences'][0][ii] 1069 | if self.tokenizer.decode([tok]) == '\n': 1070 | break 1071 | scores = torch.log_softmax(outputs['scores'][ii - len(inputs.input_ids[0])], dim=-1) 1072 | plist.append(scores[0][tok]) 1073 | pred_answer = process_answer(result.split("\n")[0]) 1074 | pred_answer_list.append(pred_answer) 1075 | pred_prob_list.append(sum(plist)) 1076 | else: 1077 | assert False 1078 | if self.args.debug: 1079 | t3=time.time() 1080 | if self.chain_of_thoughts and self.args.with_clip_verify: 1081 | if self.args.use_blip2: 1082 | tmp_thought_list = thought.split(".") 1083 | new_tmp_thought_list = [] 1084 | new_tmp_thought_list_all = [] 1085 | for thought in tmp_thought_list: 1086 | new_tmp_thought_list.append(self.query_blip2_thought_match_img(thought)) 1087 | new_tmp_thought_list_all.append(thought) 1088 | new_thought = ".".join(new_tmp_thought_list).strip() + "." 1089 | new_thought_all = ".".join(new_tmp_thought_list_all).strip() + "." 
1090 | if len(new_tmp_thought_list) > 0: 1091 | thought_list.append(new_thought) 1092 | else: 1093 | thought_list.append('') 1094 | all_thought_list.append(new_thought_all) 1095 | else: 1096 | with torch.no_grad(): 1097 | img_id = self.valkey2idx[key] 1098 | img_emb = torch.from_numpy(self.image_val_feature[img_id]).cuda().float().unsqueeze(dim=0) 1099 | tmp_thought_list = thought.split(".") 1100 | inputs = self.clip_processor(text=tmp_thought_list, return_tensors="pt", padding=True) 1101 | inputs = {k: v.cuda() for k, v in inputs.items()} 1102 | clip_outputs = self.clip_model(**inputs) 1103 | thought_emb = clip_outputs['pooler_output'] 1104 | thought_emb /= thought_emb.norm(dim=-1, keepdim=True) 1105 | img_emb /= img_emb.norm(dim=-1, keepdim=True) 1106 | sim_cands = img_emb @ thought_emb.T 1107 | sim_thre = self.args.verify_threshold 1108 | new_tmp_thought_list = [] 1109 | new_tmp_thought_list_all = [] 1110 | for tid in range(sim_cands.shape[1]): 1111 | sim = sim_cands[0, tid].item() 1112 | if sim > sim_thre and len(tmp_thought_list[tid]) > 0: 1113 | new_tmp_thought_list.append(tmp_thought_list[tid]) 1114 | new_tmp_thought_list_all.append(tmp_thought_list[tid]) 1115 | new_thought = ".".join(new_tmp_thought_list).strip() + "." 1116 | new_thought_all = ".".join(new_tmp_thought_list_all).strip() + "." 1117 | if self.args.random_rationale: 1118 | new_thought = random.choice(list(self.traincontext_rationale_dict.values())) 1119 | new_thought = random.choice(new_thought) 1120 | new_tmp_thought_list = new_thought.split(".") 1121 | new_thought_all = random.choice(list(self.traincontext_rationale_dict.values())) 1122 | new_thought_all = random.choice(new_thought_all) 1123 | elif self.args.oracle_rationale: 1124 | new_thought = self.rationale_dict[key][0] 1125 | new_tmp_thought_list = new_thought.split(".") 1126 | new_thought_all = self.rationale_dict[key][0] 1127 | if len(new_tmp_thought_list) > 0: 1128 | thought_list.append(new_thought) 1129 | else: 1130 | thought_list.append('') 1131 | all_thought_list.append(new_thought_all) 1132 | elif self.chain_of_thoughts: 1133 | if self.args.random_rationale: 1134 | assert False 1135 | thought_list.append(thought) 1136 | all_thought_list.append(new_thought) 1137 | if self.args.debug: 1138 | t4=time.time() 1139 | print(" REASON PREPARE TIME", t2-t1) 1140 | print(" REASON INF TIME", t3-t2) 1141 | print(" REASON POST TIME", t4-t3) 1142 | maxval = -999. 
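The `with_clip_verify` branch above keeps only those thought sentences whose CLIP text embedding is close enough to the image's CLIP embedding (threshold `--verify_threshold`). Below is a compact sketch of the same idea using the Hugging Face `CLIPModel` projection heads; unlike the code above, the image feature is computed on the fly rather than loaded from the cached `.npy` files, and the model name, image path, and threshold value are illustrative.

```python
# Sketch of CLIP-based thought verification: drop rationale sentences whose
# similarity to the image falls below a threshold. Names and threshold are examples.
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

image = Image.open("example.jpg").convert("RGB")          # placeholder image path
thoughts = "The man is holding a surfboard. He is standing on the moon.".split(".")
thoughts = [t.strip() for t in thoughts if t.strip()]

with torch.no_grad():
    img_emb = model.get_image_features(**processor(images=image, return_tensors="pt"))
    txt_emb = model.get_text_features(**processor(text=thoughts, return_tensors="pt", padding=True))
img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)

sims = (img_emb @ txt_emb.T)[0]                           # cosine similarity per sentence
verified = [t for t, s in zip(thoughts, sims.tolist()) if s > 0.2]
print(verified)
```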
1143 | for ii in range(len(pred_prob_list)): 1144 | if pred_prob_list[ii] > maxval: 1145 | if self.chain_of_thoughts: 1146 | thoughts, all_thoughts = thought_list[ii], all_thought_list[ii] 1147 | maxval, pred_answer = pred_prob_list[ii], pred_answer_list[ii] 1148 | ## a rough accuracy estimator for fast results check 1149 | if self.args.choice_only: 1150 | if pred_answer not in self.choices_dict[key]: 1151 | choices_list = self.choices_dict[key] + [pred_answer] 1152 | inputs = self.clip_processor(text=choices_list, return_tensors="pt", padding=True) 1153 | inputs = {k: v.cuda() for k, v in inputs.items()} 1154 | clip_outputs = self.clip_model(**inputs) 1155 | thought_emb = clip_outputs['pooler_output'] 1156 | thought_emb /= thought_emb.norm(dim=-1, keepdim=True) 1157 | sim = thought_emb[-1].unsqueeze(0) @ thought_emb[:-1].T 1158 | pred_answer = self.choices_dict[key][sim.argmax().item()] 1159 | final_score = 1 if pred_answer == self.choices_dict[key][answer] else 0 1160 | else: 1161 | counter = 0 1162 | for ii in range(len(answer)): 1163 | if pred_answer == answer[ii]: counter += 1 1164 | final_score = min(1., float(counter) * 0.3) 1165 | if self.args.debug: 1166 | print(prompt) 1167 | print(pred_answer) 1168 | print(answer) 1169 | pdb.set_trace() 1170 | if self.chain_of_thoughts: 1171 | return [key, pred_answer, prompt, final_score, thoughts, all_thoughts, float(maxval), 1172 | [attr[1] for attr in scene_graph_attr]] 1173 | return [key, pred_answer, prompt, final_score, float(maxval), [attr[1] for attr in scene_graph_attr]] 1174 | 1175 | def get_context_keys(self, key, metric, n): 1176 | if metric == 'question': 1177 | lineid = self.valkey2idx[key] 1178 | if self.args.pick_example_mode: 1179 | inputs = self.clip_processor(text=[self.temp_question], return_tensors="pt", padding=True) 1180 | inputs = {k: v.cuda() for k, v in inputs.items()} 1181 | clip_outputs = self.clip_model(**inputs) 1182 | val_feature = clip_outputs['pooler_output'].cpu() 1183 | val_feature /= val_feature.norm(dim=-1, keepdim=True) 1184 | similarity = np.matmul(self.train_feature, val_feature.detach()[0].numpy()) 1185 | else: 1186 | similarity = np.matmul(self.train_feature, self.val_feature[lineid, :]) 1187 | index = similarity.argsort()[-n:][::-1] 1188 | return [self.train_idx[str(x)] for x in index] 1189 | elif metric == 'imagequestion': 1190 | ## combined with Q-similairty (image+question) 1191 | lineid = self.valkey2idx[key] 1192 | if self.args.pick_example_mode: 1193 | inputs = self.clip_processor(text=[self.temp_question], return_tensors="pt", padding=True) 1194 | inputs = {k: v.cuda() for k, v in inputs.items()} 1195 | clip_outputs = self.clip_model(**inputs) 1196 | val_feature = clip_outputs['pooler_output'].cpu() 1197 | val_feature /= val_feature.norm(dim=-1, keepdim=True) 1198 | question_similarity = np.matmul(self.train_feature, val_feature.detach()[0].numpy()) 1199 | else: 1200 | question_similarity = np.matmul(self.train_feature, self.val_feature[lineid, :]) 1201 | ## end of Q-similairty 1202 | similarity = question_similarity + np.matmul(self.image_train_feature, self.image_val_feature[lineid, :]) 1203 | index = similarity.argsort()[-n:][::-1] 1204 | return [self.train_idx[str(x)] for x in index] 1205 | else: 1206 | return None 1207 | 1208 | def get_related_obj_dict(self, key): 1209 | if self.args.train_sim_metric == "rationale": 1210 | return self.get_related_obj_dict_rationale(key) 1211 | elif self.args.train_sim_metric == "answer": 1212 | if not hasattr(self, "train_object_select"): 1213 | 
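`get_context_keys` above selects the in-context shots by adding a question-to-question similarity and an image-to-image similarity computed over pre-extracted, L2-normalised CLIP features, then taking the top-n training indices. A numpy-only sketch of that retrieval is below; the random matrices stand in for the cached `.npy` feature files.

```python
# Sketch of shot selection: combine question and image similarities over
# L2-normalised CLIP features and take the n most similar training samples.
import numpy as np

rng = np.random.default_rng(0)

def l2norm(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

train_q = l2norm(rng.normal(size=(1000, 512)))    # stand-ins for the cached feature files
train_im = l2norm(rng.normal(size=(1000, 512)))
val_q = l2norm(rng.normal(size=(512,)))
val_im = l2norm(rng.normal(size=(512,)))

n_shot = 16
similarity = train_q @ val_q + train_im @ val_im  # question similarity + image similarity
top_idx = similarity.argsort()[-n_shot:][::-1]    # indices of the n most similar training samples
print(top_idx[:5])
```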
self.train_object_select = pickle.load(open(self.args.train_sim_file, "rb")) 1214 | return self.train_object_select[key] 1215 | 1216 | def get_related_obj_dict_rationale(self, key): 1217 | img_context_key = int(key.split('<->')[0]) 1218 | context_scene_graph = json.load(open(os.path.join(self.sg_dir, str(img_context_key).zfill(12) + ".json"))) 1219 | context_scene_graph_attr = json.load( 1220 | open(os.path.join(self.sg_attr_dir, str(img_context_key).zfill(12) + ".json"))) 1221 | 1222 | obj_list = [] 1223 | for obj in context_scene_graph[0]: 1224 | if obj['class'] not in obj_list: 1225 | obj_list.append(obj['class']) 1226 | for obj in context_scene_graph_attr[0]: 1227 | if obj['class'] not in obj_list: 1228 | obj_list.append(obj['class']) 1229 | 1230 | related_obj_dict = {} 1231 | rationale = self.traincontext_rationale_dict[key] 1232 | for obj in obj_list: 1233 | for r in rationale: 1234 | if obj in r: 1235 | if obj not in related_obj_dict: 1236 | related_obj_dict[obj] = 1 1237 | else: 1238 | related_obj_dict[obj] += 1 1239 | return related_obj_dict 1240 | 1241 | def get_interactive_context_keys(self, key, metric, n): 1242 | if metric == 'question': 1243 | assert False 1244 | elif metric == 'imagequestion': 1245 | ## combined with Q-similairty (image+question) 1246 | lineid = self.valkey2idx[key] 1247 | if self.args.pick_example_mode: 1248 | inputs = self.clip_processor(text=[self.temp_question], return_tensors="pt", padding=True) 1249 | inputs = {k: v.cuda() for k, v in inputs.items()} 1250 | clip_outputs = self.clip_model(**inputs) 1251 | val_feature = clip_outputs['pooler_output'].cpu() 1252 | val_feature /= val_feature.norm(dim=-1, keepdim=True) 1253 | question_similarity = np.matmul(self.train_feature, val_feature.detach()[0].numpy()) 1254 | else: 1255 | question_similarity = np.matmul(self.train_feature, self.val_feature[lineid, :]) 1256 | ## end of Q-similairty 1257 | similarity = question_similarity + np.matmul(self.image_train_feature, self.image_val_feature[lineid, :]) 1258 | similarity = similarity.argsort() 1259 | idx_list = [] 1260 | rel_obj_list = [] 1261 | for i in range(len(similarity)): 1262 | context_key = self.train_idx[str(similarity[-1 - i])] 1263 | rel_obj_dict = self.get_related_obj_dict(context_key) 1264 | if len(rel_obj_dict) > 0: 1265 | idx_list.append(context_key) 1266 | rel_obj_list.append(rel_obj_dict) 1267 | if len(idx_list) >= n: 1268 | break 1269 | return idx_list, rel_obj_list 1270 | else: 1271 | return None 1272 | 1273 | def load_similarity(self): 1274 | split = "test" if self.args.test_only else "val" 1275 | val_idx = json.load(open('%s/aokvqa_qa_line2sample_idx_%s2017.json' % (self.args.similarity_path, split), 'r')) 1276 | self.valkey2idx = {} 1277 | for ii in val_idx: 1278 | self.valkey2idx[val_idx[ii]] = int(ii) 1279 | if self.args.similarity_metric == 'question': 1280 | self.train_feature = np.load( 1281 | '%s/coco_clip_vitb16_train2017_aokvqa_question.npy' % self.args.similarity_path) 1282 | self.val_feature = np.load('%s/coco_clip_vitb16_%s2017_aokvqa_question.npy' % (self.args.similarity_path, split)) 1283 | self.train_idx = json.load( 1284 | open('%s/aokvqa_qa_line2sample_idx_train2017.json' % self.args.similarity_path, 'r')) 1285 | elif self.args.similarity_metric == 'imagequestion': 1286 | self.train_feature = np.load( 1287 | '%s/coco_clip_vitb16_train2017_aokvqa_question.npy' % self.args.similarity_path) 1288 | self.val_feature = np.load('%s/coco_clip_vitb16_%s2017_aokvqa_question.npy' % (self.args.similarity_path, split)) 1289 | 
self.train_idx = json.load( 1290 | open('%s/aokvqa_qa_line2sample_idx_train2017.json' % self.args.similarity_path, 'r')) 1291 | self.image_train_feature = np.load( 1292 | '%s/coco_clip_vitb16_train2017_aokvqa_convertedidx_image.npy' % self.args.similarity_path) 1293 | self.image_val_feature = np.load( 1294 | '%s/coco_clip_vitb16_%s2017_aokvqa_convertedidx_image.npy' % (self.args.similarity_path, split)) 1295 | 1296 | def load_tags(self): 1297 | tags_dict = {} 1298 | tagging_pred_file = '%s/test.score.json.tsv' % self.args.tag_path 1299 | read_tsv = csv.reader(open(tagging_pred_file, 'r'), delimiter="\t") 1300 | for row in read_tsv: 1301 | image_id, tags = int(row[0]), json.loads(row[1]) 1302 | tag_str = ', '.join([x['class'] for x in tags]) 1303 | tags_dict[image_id] = tag_str 1304 | tagging_pred_file = '%s/val.score.json.tsv' % self.args.tag_path 1305 | read_tsv = csv.reader(open(tagging_pred_file, 'r'), delimiter="\t") 1306 | for row in read_tsv: 1307 | image_id, tags = int(row[0]), json.loads(row[1]) 1308 | tag_str = ', '.join([x['class'] for x in tags]) 1309 | tags_dict[image_id] = tag_str 1310 | tagging_pred_file = '%s/train.score.json.tsv' % self.args.tag_path 1311 | read_tsv = csv.reader(open(tagging_pred_file, 'r'), delimiter="\t") 1312 | for row in read_tsv: 1313 | image_id, tags = int(row[0]), json.loads(row[1]) 1314 | tag_str = ', '.join([x['class'] for x in tags]) 1315 | tags_dict[image_id] = tag_str 1316 | return tags_dict 1317 | 1318 | def load_cachetext(self): 1319 | read_tsv = csv.reader(open(self.args.valcaption_file, 'r'), delimiter="\t") 1320 | caption_dict = {} 1321 | if 'tag' in self.args.caption_type: 1322 | tags_dict = self.load_tags() 1323 | if self.args.caption_type == 'vinvl_tag': 1324 | for row in read_tsv: 1325 | if int(row[0]) not in caption_dict: 1326 | caption_dict[int(row[0])] = [ 1327 | row[1].split('caption": "')[1].split('", "conf"')[0] + '. ' + tags_dict[int(row[0])]] 1328 | else: 1329 | caption_dict[int(row[0])].append( 1330 | row[1].split('caption": "')[1].split('", "conf"')[0] + '. 
' + tags_dict[int(row[0])]) 1331 | else: 1332 | for row in read_tsv: 1333 | if int(row[0]) not in caption_dict: 1334 | caption_dict[int(row[0])] = [row[1].split('caption": "')[1].split('", "conf"')[0]] 1335 | else: 1336 | caption_dict[int(row[0])].append(row[1].split('caption": "')[1].split('", "conf"')[0]) 1337 | return caption_dict 1338 | 1339 | def main(): 1340 | parser = argparse.ArgumentParser() 1341 | parser.add_argument('--apikey_file', type=str, default="", help='api key; https://openai.com/api/') 1342 | parser.add_argument('--apikey', type=str, default="", help='api key; https://openai.com/api/') 1343 | parser.add_argument('--engine', type=str, default='davinci', help='api engine; https://openai.com/api/') 1344 | parser.add_argument('--engine_name', type=str, default='text-davinci-003', help='api engine; https://openai.com/api/') 1345 | parser.add_argument('--caption_type', type=str, default='vinvl_tag', help='vinvl_tag, vinvl, vinvl_sg, vinvl_ocr') 1346 | parser.add_argument('--n_shot', type=int, default=16, help="number of shots") 1347 | parser.add_argument('--n_ensemble', type=int, default=1, help="number of ensemble") 1348 | parser.add_argument('--rounds', type=int, default=3, help="number of interactive rounds") 1349 | parser.add_argument('--image_id', type=int, default=-1, help="selected image id pick example only") 1350 | parser.add_argument('--iterative_strategy', type=str, default="caption", help="caption or sg") 1351 | parser.add_argument('--similarity_metric', type=str, default='imagequestion', help="random/question/imagequestion") 1352 | parser.add_argument('--valcaption_file', type=str, default='input_text/vinvl_caption/VinVL_base_val2014.tsv') 1353 | parser.add_argument('--tag_path', type=str, default='input_text/coco_caption_pred_tags') 1354 | parser.add_argument('--concept_caption_path', type=str, default='scene_graph_coco17_caption') 1355 | parser.add_argument('--sg_path', type=str, default='') 1356 | parser.add_argument('--coco_path', type=str, default='coco_annotations') 1357 | parser.add_argument('--similarity_path', type=str, default='coco_clip_new') 1358 | parser.add_argument('--output_path', type=str, default='output') 1359 | parser.add_argument('--llama_path', type=str, default='/') 1360 | parser.add_argument('--use_blip2', action='store_true') 1361 | parser.add_argument('--choice_only', action='store_true') 1362 | parser.add_argument('--chain_of_thoughts', action='store_true') 1363 | parser.add_argument('--with_six_gpus', action='store_true') 1364 | parser.add_argument('--with_one_gpu', action='store_true') 1365 | parser.add_argument('--test_only', action='store_true') 1366 | parser.add_argument('--random_attend', action='store_true') 1367 | parser.add_argument('--oracle_attend', action='store_true') 1368 | parser.add_argument('--random_caption', action='store_true') 1369 | parser.add_argument('--remove_caption', action='store_true') 1370 | parser.add_argument('--random_rationale', action='store_true') 1371 | parser.add_argument('--oracle_rationale', action='store_true') 1372 | parser.add_argument('--all_regional_captions', action='store_true') 1373 | parser.add_argument('--use_attributes_to_see', action='store_true') 1374 | parser.add_argument('--use_caption_to_see', action='store_true') 1375 | parser.add_argument('--pick_example_mode', action='store_true') 1376 | parser.add_argument('--pick_example_with_question_mode', action='store_true') 1377 | parser.add_argument('--train_sim_metric', type=str, default='rationale') 1378 | 
parser.add_argument('--train_sim_file', type=str, default='') 1379 | parser.add_argument('--val_sim_file', type=str, default='') 1380 | parser.add_argument('--verify_threshold', type=float, default=0.0) 1381 | parser.add_argument('--start', type=float, default=0.0, help="start point in validation set (0.0-1.0)") 1382 | parser.add_argument('--end', type=float, default=1.0, help="end point in validation set (0.0-1.0)") 1383 | parser.add_argument('--with_clip_verify', action='store_true') 1384 | parser.add_argument('--debug', action='store_true') 1385 | parser.add_argument('--ablation_visual', action='store_true') 1386 | parser.add_argument('--ablation_reason', action='store_true') 1387 | parser.add_argument('--use_v100', action='store_true') 1388 | parser.add_argument('--local_rank', required=False, type=int, help='used by dist launchers') 1389 | parser.add_argument('--raw_image_dir', type=str, default="/path/to/your/coco") 1390 | parser.add_argument('--with_blip2_api', action='store_true') 1391 | parser.add_argument('--set_name', type=str, default='aokvqa') 1392 | args = parser.parse_args() 1393 | 1394 | if args.apikey_file != "": 1395 | apikey_list = open(args.apikey_file).readlines() 1396 | apikey_list = [line.strip() for line in apikey_list] 1397 | else: 1398 | apikey_list = [args.apikey] 1399 | 1400 | aokvqa = VisualCOT_AOKVQA(args, apikey_list) 1401 | 1402 | ## main inference 1403 | #with torch.cuda.amp.autocast(dtype=torch.float): 1404 | answers, full_answers = aokvqa.inference(save_every_step=args.engine in ['ada', 'babbage', 'curie', 'davinci', 'gpt3', 1405 | 'chat', 'codex', 'instruct'] or 1406 | args.pick_example_mode) 1407 | 1408 | # prediction = {} 1409 | acc = 0. 1410 | # for answer in answers: 1411 | # prediction[answer[0]] = [answer[1], answer[2]] 1412 | # acc += float(answer[3]) 1413 | 1414 | format_prediction = [] 1415 | for answer in answers: 1416 | if args.chain_of_thoughts: 1417 | format_prediction.append({"answer": answer[1], "question_id": answer[0].split('<->')[1], 1418 | "thoughts": answer[5]}) 1419 | else: 1420 | format_prediction.append({"answer": answer[1], "question_id": answer[0].split('<->')[1]}) 1421 | 1422 | print(acc * 100. / len(answers), len(answers)) 1423 | acc = acc * 100. 
/ len(answers) 1424 | 1425 | ## if save final predictions 1426 | os.system("mkdir -p %s" % args.output_path) 1427 | os.system("mkdir -p %s/prompt_answer" % args.output_path) 1428 | os.system("mkdir -p %s/format_answer" % args.output_path) 1429 | output_name = 'VisualCOT_%s_n%d_repeat%d_%s_%f.json' % ( 1430 | args.caption_type, args.n_shot, args.n_ensemble, args.similarity_metric, acc) 1431 | json.dump(full_answers, open("%s/prompt_answer/%s" % (args.output_path, output_name), 'w')) 1432 | json.dump(format_prediction, open("%s/format_answer/%s" % (args.output_path, output_name), 'w')) 1433 | 1434 | if __name__ == '__main__': 1435 | main() 1436 | -------------------------------------------------------------------------------- /main_okvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import json 5 | import torch 6 | from PIL import Image 7 | from main_aokvqa import VisualCOT_AOKVQA 8 | 9 | class VisualCOT(VisualCOT_AOKVQA): 10 | def __init__(self, args, apikey_list): 11 | super().__init__(args, apikey_list) 12 | self.train_ok_keys = list(self.traincontext_ok_answer_dict.keys()) 13 | 14 | def find_image(self, img_key): 15 | split = "test" if self.args.test_only else "val" 16 | img_full_path = os.path.join(self.raw_image_dir, "COCO_%s2014_%012d.jpg" % (split, img_key)) 17 | print(img_full_path) 18 | return Image.open(img_full_path).convert("RGB") 19 | 20 | def load_dataset(self, args): 21 | test_name = "test" if args.test_only else "val" 22 | self.raw_image_dir = os.path.join(self.args.raw_image_dir, "%s2014" % test_name) 23 | _, self.answer_dict, self.question_dict = \ 24 | self.load_ok_anno(None, f'%s/mscoco_{test_name}2014_annotations.json' % args.coco_path, \ 25 | f'%s/OpenEnded_mscoco_{test_name}2014_questions.json' % args.coco_path) 26 | self.val_keys = list(self.question_dict.keys()) 27 | self.val_keys = self.val_keys[int(args.start * len(self.val_keys)):int(args.end * len(self.val_keys))] 28 | 29 | ## load cached image representation (Coco caption & Tags) 30 | self.inputtext_dict = self.load_cachetext() 31 | 32 | if self.args.with_ok_context: 33 | self.traincontext_caption_dict, self.traincontext_answer_dict, \ 34 | self.traincontext_question_dict = \ 35 | self.load_ok_anno(f"%s/captions_train2017.json" % args.coco_path, \ 36 | f'%s/mscoco_train2014_annotations.json' % args.coco_path, \ 37 | f'%s/OpenEnded_mscoco_train2014_questions.json' % args.coco_path) 38 | self.traincontext_ok_answer_dict = self.traincontext_answer_dict 39 | self.traincontext_ok_question_dict = self.traincontext_question_dict 40 | # without chain_of_thoughts for ok context 41 | assert not self.args.chain_of_thoughts 42 | else: 43 | self.traincontext_caption_dict, self.traincontext_answer_dict, \ 44 | self.traincontext_question_dict, self.traincontext_rationale_dict, \ 45 | self.traincontext_choices_dict = \ 46 | self.load_aok_anno('%s/captions_train2017.json' % args.coco_path, \ 47 | '%s/aokvqa_v1p0_train.json' % args.coco_path, \ 48 | '%s/aokvqa_v1p0_train.json' % args.coco_path, choice_only=args.choice_only) 49 | 50 | _, self.traincontext_ok_answer_dict, \ 51 | self.traincontext_ok_question_dict = \ 52 | self.load_ok_anno(None, f'%s/mscoco_train2014_annotations.json' % args.coco_path, \ 53 | f'%s/OpenEnded_mscoco_train2014_questions.json' % args.coco_path) 54 | if args.caption_type == 'vinvl_ocr': 55 | self.load_ocr(os.path.join(self.args.sg_path, "coco14_ocr_train.json"), 56 | os.path.join(self.args.sg_path, 
f"coco14_ocr_{test_name}.json"), 57 | os.path.join(self.args.sg_path, "scene_graph_coco17_attr")) 58 | self.sg_dir = os.path.join(self.args.sg_path, "scene_graph_coco17") 59 | self.sg_attr_dir = os.path.join(self.args.sg_path, "scene_graph_coco17_attr") 60 | self.sg_cap_dir = os.path.join(self.args.sg_path, self.args.concept_caption_path) 61 | 62 | self.train_keys = list(self.traincontext_answer_dict.keys()) 63 | self.train_interactive_keys = list(self.traincontext_ok_answer_dict.keys()) 64 | self.traincontext_interactive_answer_dict = self.traincontext_ok_answer_dict 65 | self.traincontext_interactive_question_dict = self.traincontext_ok_question_dict 66 | 67 | def load_ok_anno(self, coco_caption_file, answer_anno_file, question_anno_file): 68 | if coco_caption_file is not None: 69 | coco_caption = json.load(open(coco_caption_file, 'r')) 70 | if type(coco_caption) == type({}): coco_caption = coco_caption['annotations'] 71 | answer_anno = json.load(open(answer_anno_file, 'r')) 72 | question_anno = json.load(open(question_anno_file, 'r')) 73 | 74 | caption_dict = {} 75 | if coco_caption_file is not None: 76 | for sample in coco_caption: 77 | if sample['image_id'] not in caption_dict: 78 | caption_dict[sample['image_id']] = [sample['caption']] 79 | else: 80 | caption_dict[sample['image_id']].append(sample['caption']) 81 | answer_dict = {} 82 | for sample in answer_anno['annotations']: 83 | if str(sample['image_id']) + '<->' + str(sample['question_id']) not in answer_dict: 84 | answer_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = [x['answer'] for x in 85 | sample['answers']] 86 | 87 | question_dict = {} 88 | for sample in question_anno['questions']: 89 | if str(sample['image_id']) + '<->' + str(sample['question_id']) not in question_dict: 90 | question_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample['question'] 91 | return caption_dict, answer_dict, question_dict 92 | 93 | def load_aok_anno(self, coco_caption_file, answer_anno_file, question_anno_file, choice_only=False): 94 | if coco_caption_file is not None: 95 | coco_caption = json.load(open(coco_caption_file, 'r')) 96 | if type(coco_caption) == type({}): coco_caption = coco_caption['annotations'] 97 | answer_anno = json.load(open(answer_anno_file, 'r')) 98 | question_anno = json.load(open(question_anno_file, 'r')) 99 | 100 | caption_dict = {} 101 | if coco_caption_file is not None: 102 | for sample in coco_caption: 103 | if sample['image_id'] not in caption_dict: 104 | caption_dict[sample['image_id']] = [sample['caption']] 105 | else: 106 | caption_dict[sample['image_id']].append(sample['caption']) 107 | 108 | answer_dict = {} 109 | for sample in answer_anno: 110 | if str(sample['image_id']) + '<->' + str(sample['question_id']) not in answer_dict: 111 | if choice_only: 112 | if 'correct_choice_idx' in sample: 113 | answer_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample[ 114 | "correct_choice_idx"] 115 | else: 116 | assert False 117 | else: 118 | if 'direct_answers' in sample: 119 | answer_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample[ 120 | "direct_answers"] 121 | else: 122 | answer_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = [""] 123 | 124 | question_dict = {} 125 | for sample in question_anno: 126 | if str(sample['image_id']) + '<->' + str(sample['question_id']) not in question_dict: 127 | question_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample['question'] 128 | 129 | 
rationales_dict = {} 130 | for sample in answer_anno: 131 | if str(sample['image_id']) + '<->' + str(sample['question_id']) not in rationales_dict: 132 | if 'rationales' in sample: 133 | rationales_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample['rationales'] 134 | else: 135 | rationales_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = "" 136 | 137 | choices_dict = {} 138 | for sample in answer_anno: 139 | choices_dict[str(sample['image_id']) + '<->' + str(sample['question_id'])] = sample['choices'] 140 | 141 | return caption_dict, answer_dict, question_dict, rationales_dict, choices_dict 142 | 143 | def get_context_keys(self, key, metric, n): 144 | if not self.args.with_ok_context: 145 | if metric == 'question': 146 | lineid = self.valkey2idx[key] 147 | similarity = np.matmul(self.train_feature, self.val_feature[lineid, :]) 148 | index = similarity.argsort()[-n:][::-1] 149 | return [self.train_idx[str(x)] for x in index] 150 | elif metric == 'imagequestion': 151 | ## combined with Q-similairty (image+question) 152 | lineid = self.valkey2idx[key] 153 | question_similarity = np.matmul(self.train_feature, self.val_feature[lineid, :]) 154 | ## end of Q-similairty 155 | similarity = question_similarity + np.matmul(self.image_train_feature, self.image_val_feature[lineid, :]) 156 | index = similarity.argsort()[-n:][::-1] 157 | return [self.train_idx[str(x)] for x in index] 158 | else: 159 | return None 160 | else: 161 | if metric == 'question': 162 | lineid = self.valkey2idx[key] 163 | similarity = np.matmul(self.train_ok_feature, self.val_feature[lineid, :]) 164 | index = similarity.argsort()[-n:][::-1] 165 | return [self.train_ok_idx[str(x)] for x in index] 166 | elif metric == 'imagequestion': 167 | ## combined with Q-similairty (image+question) 168 | lineid = self.valkey2idx[key] 169 | question_similarity = np.matmul(self.train_ok_feature, self.val_feature[lineid, :]) 170 | ## end of Q-similairty 171 | similarity = question_similarity + np.matmul(self.image_train_ok_feature, self.image_val_feature[lineid, :]) 172 | index = similarity.argsort()[-n:][::-1] 173 | return [self.train_ok_idx[str(x)] for x in index] 174 | else: 175 | return None 176 | 177 | def get_interactive_context_keys(self, key, metric, n): 178 | if metric == 'question': 179 | assert False 180 | elif metric == 'imagequestion': 181 | ## combined with Q-similairty (image+question) 182 | lineid = self.valkey2idx[key] 183 | question_similarity = np.matmul(self.train_ok_feature, self.val_feature[lineid, :]) 184 | ## end of Q-similairty 185 | similarity = question_similarity + np.matmul(self.image_train_ok_feature, self.image_val_feature[lineid, :]) 186 | similarity = similarity.argsort() 187 | idx_list = [] 188 | rel_obj_list = [] 189 | for i in range(len(similarity)): 190 | context_key = self.train_ok_idx[str(similarity[-1 - i])] 191 | rel_obj_dict = self.get_related_obj_dict(context_key) 192 | if len(rel_obj_dict) > 0: 193 | idx_list.append(context_key) 194 | rel_obj_list.append(rel_obj_dict) 195 | if len(idx_list) >= n: 196 | break 197 | return idx_list, rel_obj_list 198 | else: 199 | return None 200 | 201 | def load_similarity(self): 202 | split = "test" if self.args.test_only else "val" 203 | val_idx = json.load(open('%s/okvqa_qa_line2sample_idx_%s2014.json' % (self.args.similarity_path, split), 'r')) 204 | self.valkey2idx = {} 205 | for ii in val_idx: 206 | self.valkey2idx[val_idx[ii]] = int(ii) 207 | if self.args.similarity_metric == 'question': 208 | self.train_feature = 
np.load( 209 | '%s/coco_clip_vitb16_train2017_aokvqa_question.npy' % self.args.similarity_path) 210 | self.train_ok_feature = np.load( 211 | '%s/coco_clip_vitb16_train2014_okvqa_question.npy' % self.args.similarity_path) 212 | self.val_feature = np.load( 213 | '%s/coco_clip_vitb16_%s2014_okvqa_question.npy' % (self.args.similarity_path, split)) 214 | self.train_idx = json.load( 215 | open('%s/aokvqa_qa_line2sample_idx_train2017.json' % self.args.similarity_path, 'r')) 216 | self.train_ok_idx = json.load( 217 | open('%s/okvqa_qa_line2sample_idx_train2014.json' % self.args.similarity_path, 'r')) 218 | elif self.args.similarity_metric == 'imagequestion': 219 | self.train_feature = np.load( 220 | '%s/coco_clip_vitb16_train2017_aokvqa_question.npy' % self.args.similarity_path) 221 | self.train_ok_feature = np.load( 222 | '%s/coco_clip_vitb16_train2014_okvqa_question.npy' % self.args.similarity_path) 223 | self.val_feature = np.load( 224 | '%s/coco_clip_vitb16_%s2014_okvqa_question.npy' % (self.args.similarity_path, split)) 225 | self.train_idx = json.load( 226 | open('%s/aokvqa_qa_line2sample_idx_train2017.json' % self.args.similarity_path, 'r')) 227 | self.train_ok_idx = json.load( 228 | open('%s/okvqa_qa_line2sample_idx_train2014.json' % self.args.similarity_path, 'r')) 229 | self.image_train_feature = np.load( 230 | '%s/coco_clip_vitb16_train2017_aokvqa_convertedidx_image.npy' % self.args.similarity_path) 231 | self.image_train_ok_feature = np.load( 232 | '%s/coco_clip_vitb16_train2014_okvqa_convertedidx_image.npy' % self.args.similarity_path) 233 | self.image_val_feature = np.load( 234 | '%s/coco_clip_vitb16_%s2014_okvqa_convertedidx_image.npy' % (self.args.similarity_path, split)) 235 | 236 | def main(): 237 | parser = argparse.ArgumentParser() 238 | parser.add_argument('--apikey_file', type=str, default="", help='api key; https://openai.com/api/') 239 | parser.add_argument('--apikey', type=str, default="", help='api key; https://openai.com/api/') 240 | parser.add_argument('--engine', type=str, default='davinci', help='api engine; https://openai.com/api/') 241 | parser.add_argument('--engine_name', type=str, default='text-davinci-003', help='api engine name') 242 | parser.add_argument('--caption_type', type=str, default='vinvl_tag', help='vinvl_tag, vinvl, vinvl_sg, vinvl_ocr') 243 | parser.add_argument('--n_shot', type=int, default=16, help="number of shots") 244 | parser.add_argument('--n_ensemble', type=int, default=1, help="number of ensemble") 245 | parser.add_argument('--rounds', type=int, default=3, help="number of interactive rounds") 246 | parser.add_argument('--iterative_strategy', type=str, default="caption", help="caption or sg") 247 | parser.add_argument('--similarity_metric', type=str, default='imagequestion', help="random/question/imagequestion") 248 | parser.add_argument('--valcaption_file', type=str, default='input_text/vinvl_caption/VinVL_base_val2014.tsv') 249 | parser.add_argument('--tag_path', type=str, default='input_text/coco_caption_pred_tags') 250 | parser.add_argument('--concept_caption_path', type=str, default="scene_graph_coco14_caption_ok") 251 | parser.add_argument('--sg_path', type=str, default='') 252 | parser.add_argument('--coco_path', type=str, default='coco_annotations') 253 | parser.add_argument('--similarity_path', type=str, default='coco_clip_new') 254 | parser.add_argument('--output_path', type=str, default='output') 255 | parser.add_argument('--use_blip2', action='store_true') 256 | parser.add_argument('--choice_only', action='store_true') 257 | 
parser.add_argument('--chain_of_thoughts', action='store_true') 258 | parser.add_argument('--all_regional_captions', action='store_true') 259 | parser.add_argument('--use_attributes_to_see', action='store_true') 260 | parser.add_argument('--with_six_gpus', action='store_true') 261 | parser.add_argument('--with_one_gpu', action='store_true') 262 | parser.add_argument('--test_only', action='store_true') 263 | parser.add_argument('--random_attend', action='store_true') 264 | parser.add_argument('--oracle_attend', action='store_true') 265 | parser.add_argument('--random_caption', action='store_true') 266 | parser.add_argument('--remove_caption', action='store_true') 267 | parser.add_argument('--random_rationale', action='store_true') 268 | parser.add_argument('--oracle_rationale', action='store_true') 269 | parser.add_argument('--llama_path', type=str, default='/') 270 | parser.add_argument('--train_sim_metric', type=str, default='rationale') 271 | parser.add_argument('--train_sim_file', type=str, default='') 272 | parser.add_argument('--val_sim_file', type=str, default='') 273 | parser.add_argument('--verify_threshold', type=float, default=0.0) 274 | parser.add_argument('--start', type=float, default=0.0, help="start point in validation set (0.0-1.0)") 275 | parser.add_argument('--end', type=float, default=1.0, help="end point in validation set (0.0-1.0)") 276 | parser.add_argument('--with_clip_verify', action='store_true') 277 | parser.add_argument('--debug', action='store_true') 278 | parser.add_argument('--with_ok_context', action='store_true') 279 | parser.add_argument('--ablation_visual', action='store_true') 280 | parser.add_argument('--ablation_reason', action='store_true') 281 | parser.add_argument('--use_v100', action='store_true') 282 | parser.add_argument('--local_rank', required=False, type=int, help='used by dist launchers') 283 | parser.add_argument('--raw_image_dir', type=str, default="/path/to/your/coco") 284 | parser.add_argument('--with_blip2_api', action='store_true') 285 | parser.add_argument('--set_name', type=str, default='okvqa') 286 | args = parser.parse_args() 287 | 288 | if args.apikey_file != "": 289 | apikey_list = open(args.apikey_file).readlines() 290 | apikey_list = [line.strip() for line in apikey_list] 291 | else: 292 | apikey_list = [args.apikey] 293 | 294 | okvqa = VisualCOT(args, apikey_list=apikey_list) 295 | 296 | ## main inference 297 | #with torch.cuda.amp.autocast(dtype=torch.float): 298 | answers = okvqa.inference(save_every_step=args.engine in ['ada', 'babbage', 'curie', 'davinci', 'chat', 'codex', 'instruct', 'gpt3']) 299 | 300 | prediction = {} 301 | acc = 0. 302 | for answer in answers: 303 | prediction[answer[0]] = [answer[1], answer[2]] 304 | acc += float(answer[3]) 305 | 306 | format_prediction = [] 307 | for answer in answers: 308 | if args.chain_of_thoughts: 309 | format_prediction.append({"answer": answer[1], "question_id": answer[0].split('<->')[1], 310 | "thoughts": answer[5]}) 311 | else: 312 | format_prediction.append({"answer": answer[1], "question_id": answer[0].split('<->')[1]}) 313 | 314 | print(acc * 100. / len(answers), len(answers)) 315 | acc = acc * 100. 
/ len(answers) 316 | 317 | ## if save final predictions 318 | os.system("mkdir -p %s" % args.output_path) 319 | os.system("mkdir -p %s/prompt_answer" % args.output_path) 320 | os.system("mkdir -p %s/format_answer" % args.output_path) 321 | output_name = 'VisualCOT_%s_n%d_repeat%d_%s_%f.json' % ( 322 | args.caption_type, args.n_shot, args.n_ensemble, args.similarity_metric, acc) 323 | json.dump(prediction, open("%s/prompt_answer/%s" % (args.output_path, output_name), 'w')) 324 | json.dump(format_prediction, open("%s/format_answer/%s" % (args.output_path, output_name), 'w')) 325 | 326 | 327 | if __name__ == '__main__': 328 | main() 329 | -------------------------------------------------------------------------------- /object_similarity/object_engineer_aokvqa.sh: -------------------------------------------------------------------------------- 1 | BASE="/PATH/TO/VisualCOT" 2 | engine=opt 3 | python object_engineering_aokvqa.py \ 4 | --apikey \ 5 | --output_path output \ 6 | --caption_type vinvl_sg \ 7 | --n_shot 8 \ 8 | --iterative_strategy caption \ 9 | --engine ${engine} \ 10 | --sg_path ${BASE}/input_text/scene_graph_text \ 11 | --with_six_gpus 12 | -------------------------------------------------------------------------------- /object_similarity/object_engineer_okvqa.sh: -------------------------------------------------------------------------------- 1 | BASE=/PATH/TO/VisualCOT 2 | engine=opt 3 | python object_engineering_okvqa.py \ 4 | --apikey \ 5 | --output_path output \ 6 | --caption_type vinvl_sg \ 7 | --n_shot 8 \ 8 | --iterative_strategy caption \ 9 | --engine ${engine} \ 10 | --sg_path ${BASE}/input_text/scene_graph_text \ 11 | --with_six_gpus 12 | -------------------------------------------------------------------------------- /object_similarity/object_engineering_aokvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import openai 5 | from transformers import CLIPTokenizer, CLIPTextModel 6 | import pdb 7 | import pickle 8 | 9 | def load_anno(coco_caption_file,answer_anno_file,question_anno_file): 10 | if coco_caption_file is not None: 11 | coco_caption = json.load(open(coco_caption_file,'r')) 12 | if type(coco_caption)==type({}): coco_caption = coco_caption['annotations'] 13 | answer_anno = json.load(open(answer_anno_file,'r')) 14 | question_anno = json.load(open(question_anno_file,'r')) 15 | 16 | caption_dict = {} 17 | if coco_caption_file is not None: 18 | for sample in coco_caption: 19 | if sample['image_id'] not in caption_dict: 20 | caption_dict[sample['image_id']] = [sample['caption']] 21 | else: 22 | caption_dict[sample['image_id']].append(sample['caption']) 23 | answer_dict = {} 24 | for sample in answer_anno: 25 | if str(sample['image_id'])+'<->'+str(sample['question_id']) not in answer_dict: 26 | answer_dict[str(sample['image_id'])+'<->'+str(sample['question_id'])] = sample['direct_answers'] 27 | 28 | question_dict = {} 29 | for sample in question_anno: 30 | if str(sample['image_id'])+'<->'+str(sample['question_id']) not in question_dict: 31 | question_dict[str(sample['image_id'])+'<->'+str(sample['question_id'])] = sample['question'] 32 | 33 | rationales_dict = {} 34 | for sample in answer_anno: 35 | if str(sample['image_id'])+'<->'+str(sample['question_id']) not in rationales_dict: 36 | rationales_dict[str(sample['image_id'])+'<->'+str(sample['question_id'])] = sample['rationales'] 37 | return caption_dict,answer_dict,question_dict,rationales_dict 38 | 39 | class AOKVQA: 40 | 
def __init__(self, args): 41 | self.args = args 42 | self.chain_of_thoughts = args.chain_of_thoughts 43 | ## loading input questions (and answer for reference accuracy computing) 44 | _,self.answer_dict,self.question_dict,self.rationale_dict = \ 45 | load_anno(None, '%s/aokvqa_v1p0_val.json'%args.coco_path, \ 46 | '%s/aokvqa_v1p0_val.json'%args.coco_path) 47 | self.val_keys = list(self.question_dict.keys()) 48 | self.val_keys = self.val_keys[int(args.start*len(self.val_keys)):int(args.end*len(self.val_keys))] 49 | 50 | self.traincontext_caption_dict,self.traincontext_answer_dict,\ 51 | self.traincontext_question_dict,self.traincontext_rationale_dict = \ 52 | load_anno('%s/captions_train2017.json'%args.coco_path, \ 53 | '%s/aokvqa_v1p0_train.json'%args.coco_path, \ 54 | '%s/aokvqa_v1p0_train.json'%args.coco_path) 55 | self.train_keys = list(self.traincontext_answer_dict.keys()) 56 | 57 | self.sg_dir = os.path.join(self.args.sg_path, "scene_graph_coco17") 58 | self.sg_attr_dir = os.path.join(self.args.sg_path, "scene_graph_coco17_attr") 59 | self.sg_cap_dir = os.path.join(self.args.sg_path, "scene_graph_coco17_caption") 60 | 61 | def get_related_obj_dict(self, key, metric, model=None, processor=None): 62 | img_context_key = int(key.split('<->')[0]) 63 | context_scene_graph = json.load(open(os.path.join(self.sg_dir, str(img_context_key).zfill(12) + ".json"))) 64 | context_scene_graph_attr = json.load( 65 | open(os.path.join(self.sg_attr_dir, str(img_context_key).zfill(12) + ".json"))) 66 | 67 | obj_list = [] 68 | conf_list = [] 69 | for obj in context_scene_graph[0]: 70 | if obj['class'] not in obj_list: 71 | obj_list.append(obj['class']) 72 | conf_list.append(obj['conf']) 73 | for obj in context_scene_graph_attr[0]: 74 | if obj['class'] not in obj_list: 75 | obj_list.append(obj['class']) 76 | conf_list.append(obj['conf']) 77 | 78 | related_obj_dict = {} 79 | if 'rationale' in metric: 80 | rationale = self.traincontext_rationale_dict[key] 81 | for obj in obj_list: 82 | for r in rationale: 83 | if obj in r: 84 | if obj not in related_obj_dict: 85 | related_obj_dict[obj] = 1 86 | else: 87 | related_obj_dict[obj] += 1 88 | elif 'answer' in metric: 89 | answer_list = self.traincontext_answer_dict[key] 90 | inputs = processor(text=answer_list, return_tensors="pt", padding=True) 91 | inputs = {k:v.cuda() for k,v in inputs.items()} 92 | outputs = model(**inputs) 93 | ans_text_emb = outputs['pooler_output'].mean(dim=0).unsqueeze(dim=0) 94 | 95 | inputs = processor(text=obj_list, return_tensors="pt", padding=True) 96 | inputs = {k:v.cuda() for k,v in inputs.items()} 97 | outputs = model(**inputs) 98 | cand_text_emb = outputs['pooler_output'] 99 | 100 | ans_text_emb /= ans_text_emb.norm(dim=-1, keepdim=True) 101 | cand_text_emb /= cand_text_emb.norm(dim=-1, keepdim=True) 102 | 103 | sim_cands = cand_text_emb @ ans_text_emb.T 104 | for idx, obj_name in enumerate(obj_list): 105 | related_obj_dict[obj_name] = sim_cands[idx, 0].cpu().item() 106 | return obj_list, conf_list, related_obj_dict 107 | 108 | def show_object_example(self): 109 | metric_list = ['rationale', 'answer', 'question'] 110 | prompt = 'Please select the object most related to the question.\n===\n' 111 | metric = metric_list[1] 112 | 113 | out_train_fn = f"./input_text/scene_graph_text/train_object_select_{metric}.pk" 114 | 115 | if 'answer' in metric: 116 | model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch16") 117 | processor = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16") 118 | model = model.cuda() 119 | 
else: 120 | model, processor = None, None 121 | 122 | out_object_sim_dict = {} 123 | 124 | for pid, img_ques in enumerate(self.train_keys): 125 | obj_list, conf_list, rel_obj_dict = self.get_related_obj_dict(img_ques, metric, model, processor) 126 | rel_obj = [k for k, v in sorted(rel_obj_dict.items(), key=lambda item: item[1], reverse=True)] 127 | 128 | prompt += 'Question: %s\n===\nOptions:\n' % (self.traincontext_question_dict[img_ques]) 129 | candidate_list = [cls for cls, conf in sorted(zip(obj_list, conf_list), key=lambda item: item[1], reverse=True)] 130 | candidate_list = candidate_list[:10] 131 | if rel_obj[0] not in candidate_list: 132 | candidate_list.append(rel_obj[0]) 133 | #random.shuffle(candidate_list) 134 | for oi, obj in enumerate(candidate_list): 135 | prompt += "%s: %s\n" % (chr(ord("A")+oi), obj) 136 | prompt += "The most related option is %s: %s\n\n===\n" % (chr(ord("A")+candidate_list.index(rel_obj[0])), rel_obj[0]) 137 | prompt += "The most related option %s\n\n===\n" % (rel_obj[0]) 138 | print(prompt) 139 | print("Answer: ") 140 | print(self.traincontext_answer_dict[img_ques]) 141 | pdb.set_trace() 142 | if pid % 100 ==0: 143 | print("%d/%d"%(pid, len(self.train_keys))) 144 | out_object_sim_dict[img_ques] = rel_obj_dict 145 | with open(out_train_fn, "wb") as fh: 146 | pickle.dump(out_object_sim_dict, fh) 147 | 148 | def main(): 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument('--apikey', type=str, required=True, help='api key; https://openai.com/api/') 151 | parser.add_argument('--engine', type=str, default='davinci', help='api engine; https://openai.com/api/') 152 | parser.add_argument('--caption_type', type=str, default='vinvl_tag', help='vinvl_tag, vinvl, vinvl_sg') 153 | parser.add_argument('--n_shot', type=int, default=16, help="number of shots") 154 | parser.add_argument('--n_ensemble', type=int, default=1, help="number of ensemble") 155 | parser.add_argument('--rounds', type=int, default=3, help="number of interactive rounds") 156 | parser.add_argument('--iterative_strategy', type=str, default="caption", help="caption or sg") 157 | parser.add_argument('--similarity_metric', type=str, default='imagequestion', help="random/question/imagequestion") 158 | parser.add_argument('--valcaption_file', type=str, default='input_text/vinvl_caption/VinVL_base_val2014.tsv') 159 | parser.add_argument('--tag_path', type=str, default='input_text/coco_caption_pred_tags') 160 | parser.add_argument('--sg_path', type=str, default='') 161 | parser.add_argument('--coco_path', type=str, default='coco_annotations') 162 | parser.add_argument('--similarity_path', type=str, default='coco_clip_new') 163 | parser.add_argument('--output_path', type=str, default='output') 164 | parser.add_argument('--chain_of_thoughts', action='store_true') 165 | parser.add_argument('--with_six_gpus', action='store_true') 166 | parser.add_argument('--start', type=float, default=0.0, help="start point in validation set (0.0-1.0)") 167 | parser.add_argument('--end', type=float, default=1.0, help="end point in validation set (0.0-1.0)") 168 | args = parser.parse_args() 169 | 170 | openai.api_key = args.apikey 171 | 172 | aokvqa = AOKVQA(args) 173 | 174 | aokvqa.show_object_example() 175 | 176 | if __name__ == '__main__': 177 | main() 178 | -------------------------------------------------------------------------------- /object_similarity/object_engineering_okvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import 
json 4 | import torch 5 | import openai 6 | from transformers import CLIPTokenizer, CLIPTextModel 7 | import pickle 8 | 9 | def load_anno(coco_caption_file,answer_anno_file,question_anno_file): 10 | if coco_caption_file is not None: 11 | coco_caption = json.load(open(coco_caption_file,'r')) 12 | if type(coco_caption)==type({}): coco_caption = coco_caption['annotations'] 13 | answer_anno = json.load(open(answer_anno_file,'r')) 14 | question_anno = json.load(open(question_anno_file,'r')) 15 | 16 | caption_dict = {} 17 | if coco_caption_file is not None: 18 | for sample in coco_caption: 19 | if sample['image_id'] not in caption_dict: 20 | caption_dict[sample['image_id']] = [sample['caption']] 21 | else: 22 | caption_dict[sample['image_id']].append(sample['caption']) 23 | answer_dict = {} 24 | for sample in answer_anno["annotations"]: 25 | if str(sample['image_id'])+'<->'+str(sample['question_id']) not in answer_dict: 26 | answers = [ans['raw_answer'] for ans in sample['answers']] 27 | answer_dict[str(sample['image_id'])+'<->'+str(sample['question_id'])] = answers 28 | 29 | question_dict = {} 30 | for sample in question_anno['questions']: 31 | if str(sample['image_id'])+'<->'+str(sample['question_id']) not in question_dict: 32 | question_dict[str(sample['image_id'])+'<->'+str(sample['question_id'])] = sample['question'] 33 | 34 | rationales_dict = {} 35 | return caption_dict,answer_dict,question_dict,rationales_dict 36 | 37 | class OKVQA: 38 | def __init__(self, args): 39 | self.args = args 40 | self.chain_of_thoughts = args.chain_of_thoughts 41 | _,self.answer_dict,self.question_dict,self.rationale_dict = \ 42 | load_anno(None, '%s/mscoco_val2014_annotations.json'%args.coco_path, \ 43 | '%s/OpenEnded_mscoco_val2014_questions.json'%args.coco_path) 44 | self.val_keys = list(self.question_dict.keys()) 45 | self.val_keys = self.val_keys[int(args.start*len(self.val_keys)):int(args.end*len(self.val_keys))] 46 | 47 | self.traincontext_caption_dict,self.traincontext_answer_dict,\ 48 | self.traincontext_question_dict,self.traincontext_rationale_dict = \ 49 | load_anno('%s/captions_train2014.json'%args.coco_path, \ 50 | '%s/mscoco_train2014_annotations.json'%args.coco_path, \ 51 | '%s/OpenEnded_mscoco_train2014_questions.json'%args.coco_path) 52 | self.train_keys = list(self.traincontext_answer_dict.keys()) 53 | 54 | self.sg_dir = os.path.join(self.args.sg_path, "scene_graph_coco17") 55 | self.sg_attr_dir = os.path.join(self.args.sg_path, "scene_graph_coco17_attr") 56 | self.sg_cap_dir = os.path.join(self.args.sg_path, "scene_graph_coco17_caption") 57 | 58 | def get_related_obj_dict(self, key, metric, model=None, processor=None): 59 | img_context_key = int(key.split('<->')[0]) 60 | context_scene_graph = json.load(open(os.path.join(self.sg_dir, str(img_context_key).zfill(12) + ".json"))) 61 | context_scene_graph_attr = json.load( 62 | open(os.path.join(self.sg_attr_dir, str(img_context_key).zfill(12) + ".json"))) 63 | 64 | obj_list = [] 65 | conf_list = [] 66 | for obj in context_scene_graph[0]: 67 | if obj['class'] not in obj_list: 68 | obj_list.append(obj['class']) 69 | conf_list.append(obj['conf']) 70 | for obj in context_scene_graph_attr[0]: 71 | if obj['class'] not in obj_list: 72 | obj_list.append(obj['class']) 73 | conf_list.append(obj['conf']) 74 | 75 | related_obj_dict = {} 76 | if 'rationale' in metric: 77 | rationale = self.traincontext_rationale_dict[key] 78 | for obj in obj_list: 79 | for r in rationale: 80 | if obj in r: 81 | if obj not in related_obj_dict: 82 | related_obj_dict[obj] = 1 
83 | else: 84 | related_obj_dict[obj] += 1 85 | elif 'answer' in metric: 86 | with torch.no_grad(): 87 | answer_list = self.traincontext_answer_dict[key] 88 | inputs = processor(text=answer_list, return_tensors="pt", padding=True) 89 | inputs = {k: v.cuda() for k, v in inputs.items()} 90 | outputs = model(**inputs) 91 | ans_text_emb = outputs['pooler_output'].mean(dim=0).unsqueeze(dim=0) 92 | 93 | inputs = processor(text=obj_list, return_tensors="pt", padding=True) 94 | inputs = {k: v.cuda() for k, v in inputs.items()} 95 | outputs = model(**inputs) 96 | cand_text_emb = outputs['pooler_output'] 97 | 98 | ans_text_emb /= ans_text_emb.norm(dim=-1, keepdim=True) 99 | cand_text_emb /= cand_text_emb.norm(dim=-1, keepdim=True) 100 | 101 | sim_cands = cand_text_emb @ ans_text_emb.T 102 | for idx, obj_name in enumerate(obj_list): 103 | related_obj_dict[obj_name] = sim_cands[idx, 0].detach().cpu().item() 104 | return obj_list, conf_list, related_obj_dict 105 | 106 | def show_object_example(self): 107 | metric_list = ['rationale', 'answer', 'question'] 108 | prompt = 'Please select the object most related to the question.\n===\n' 109 | metric = metric_list[1] 110 | 111 | out_train_fn = "./input_text/scene_graph_text/train_object_select_okvqa.pk" 112 | 113 | if 'answer' in metric: 114 | model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch16") 115 | processor = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16") 116 | model = model.cuda() 117 | else: 118 | model, processor = None, None 119 | 120 | out_object_sim_dict = {} 121 | 122 | for pid, img_ques in enumerate(self.train_keys): 123 | obj_list, conf_list, rel_obj_dict = self.get_related_obj_dict(img_ques, metric, model, processor) 124 | rel_obj = [k for k, v in sorted(rel_obj_dict.items(), key=lambda item: item[1], reverse=True)] 125 | 126 | prompt += 'Question: %s\n===\nOptions:\n' % (self.traincontext_question_dict[img_ques]) 127 | candidate_list = [cls for cls, conf in sorted(zip(obj_list, conf_list), key=lambda item: item[1], reverse=True)] 128 | if rel_obj[0] not in candidate_list: 129 | candidate_list.append(rel_obj[0]) 130 | for oi, obj in enumerate(candidate_list): 131 | prompt += "%s: %s\n" % (chr(ord("A")+oi), obj) 132 | prompt += "The most related option is %s: %s\n\n===\n" % (chr(ord("A")+candidate_list.index(rel_obj[0])), rel_obj[0]) 133 | prompt += "The most related option %s\n\n===\n" % (rel_obj[0]) 134 | if pid % 100 ==0: 135 | print("%d/%d"%(pid, len(self.train_keys))) 136 | out_object_sim_dict[img_ques] = rel_obj_dict 137 | with open(out_train_fn, "wb") as fh: 138 | pickle.dump(out_object_sim_dict, fh) 139 | 140 | def main(): 141 | parser = argparse.ArgumentParser() 142 | parser.add_argument('--apikey', type=str, required=True, help='api key; https://openai.com/api/') 143 | parser.add_argument('--engine', type=str, default='davinci', help='api engine; https://openai.com/api/') 144 | parser.add_argument('--caption_type', type=str, default='vinvl_tag', help='vinvl_tag, vinvl, vinvl_sg') 145 | parser.add_argument('--n_shot', type=int, default=16, help="number of shots") 146 | parser.add_argument('--n_ensemble', type=int, default=1, help="number of ensemble") 147 | parser.add_argument('--rounds', type=int, default=3, help="number of interactive rounds") 148 | parser.add_argument('--iterative_strategy', type=str, default="caption", help="caption or sg") 149 | parser.add_argument('--similarity_metric', type=str, default='imagequestion', help="random/question/imagequestion") 150 | 
parser.add_argument('--valcaption_file', type=str, default='input_text/vinvl_caption/VinVL_base_val2014.tsv') 151 | parser.add_argument('--tag_path', type=str, default='input_text/coco_caption_pred_tags') 152 | parser.add_argument('--sg_path', type=str, default='') 153 | parser.add_argument('--coco_path', type=str, default='coco_annotations') 154 | parser.add_argument('--similarity_path', type=str, default='coco_clip_new') 155 | parser.add_argument('--output_path', type=str, default='output') 156 | parser.add_argument('--chain_of_thoughts', action='store_true') 157 | parser.add_argument('--with_six_gpus', action='store_true') 158 | parser.add_argument('--start', type=float, default=0.0, help="start point in validation set (0.0-1.0)") 159 | parser.add_argument('--end', type=float, default=1.0, help="end point in validation set (0.0-1.0)") 160 | args = parser.parse_args() 161 | 162 | openai.api_key = args.apikey 163 | 164 | okvqa = OKVQA(args) 165 | 166 | okvqa.show_object_example() 167 | 168 | if __name__ == '__main__': 169 | main() 170 | -------------------------------------------------------------------------------- /preprocess/make_clip_features.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import json 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from transformers import CLIPProcessor, CLIPModel 7 | import argparse 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--questions', type=str, required=True, help='path to questions') 11 | parser.add_argument('--images', type=str, required=True, help='path to coco images') 12 | parser.add_argument('--qfeatures', type=str, required=True, help='output features path for questions') 13 | parser.add_argument('--ifeatures', type=str, required=True, help='output features path for images') 14 | args = parser.parse_args() 15 | 16 | dataset = json.load(open(args.questions)) 17 | if 'question' in dataset: 18 | # VQAv2 19 | dataset = dataset['questions'] 20 | else: 21 | dataset = [{'question': d['question'], 22 | 'image_id': d['image_id']} for d in dataset] 23 | model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16") 24 | model = model.cuda() 25 | processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16") 26 | text_embeds_list = [] 27 | image_embeds_list = [] 28 | for q in tqdm(dataset): 29 | imageID = q['image_id'] 30 | file_name = args.images+str(imageID).zfill(12)+".jpg" 31 | image = Image.open(file_name) 32 | inputs = processor(text=[q['question']], images=image, return_tensors="pt", padding=True) 33 | inputs = {k:v.cuda() for k,v in inputs.items()} 34 | outputs = model(**inputs) 35 | text_embeds_list.append(outputs['text_embeds'].detach().cpu()) 36 | image_embeds_list.append(outputs['image_embeds'].detach().cpu()) 37 | text_embeds_list = torch.cat(text_embeds_list, dim=0) 38 | image_embeds_list = torch.cat(image_embeds_list, dim=0) 39 | np.save(args.ifeatures, image_embeds_list.numpy()) 40 | np.save(args.qfeatures, text_embeds_list.numpy()) 41 | -------------------------------------------------------------------------------- /preprocess/make_line2sample.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument('--input', type=str, required=True, help='path to input') 6 | parser.add_argument('--output', type=str, required=True, help='path to output') 7 | args = parser.parse_args() 8 | 9 | 
inputs = json.load(open(args.input)) 10 | output_dict = {} 11 | if 'questions' in inputs: 12 | # VQA v2 13 | for idx, q in enumerate(inputs['questions']): 14 | val = "<->".join([str(q['image_id']), str(q['question_id'])]) 15 | output_dict[str(idx)] = val 16 | else: 17 | # A-OKVQA 18 | for idx, q in enumerate(inputs): 19 | val = "<->".join([str(q['image_id']), str(q['question_id'])]) 20 | output_dict[str(idx)] = val 21 | json.dump(output_dict, open(args.output, "w")) 22 | -------------------------------------------------------------------------------- /preprocess/preprocess_aokvqa.sh: -------------------------------------------------------------------------------- 1 | BASE=/PATH/TO/VisualCOT 2 | COCO17PATH=/PATH/TO/coco17 3 | COCO14PATH=/PATH/TO/coco14 4 | 5 | # reorganize train / val split 6 | python reorganize_captions.py \ 7 | --coco14root ${COCO14PATH} \ 8 | --coco17root ${COCO17PATH} \ 9 | --caption14train ${BASE}/coco_annotations/captions_train2014.json \ 10 | --caption14val ${BASE}/coco_annotations/captions_val2014.json \ 11 | --caption17train ${BASE}/coco_annotations/captions_train2017.json \ 12 | --caption17val ${BASE}/coco_annotations/captions_val2017.json 13 | for SPLIT in "train" "val" 14 | do 15 | # make line2sample 16 | python make_line2sample.py \ 17 | --input ${BASE}/coco_annotations/aokvqa_v1p0_${SPLIT}.json \ 18 | --output ${BASE}/coco_clip_new/aokvqa_qa_line2sample_idx_${SPLIT}2017.json 19 | # make clip features 20 | python make_clip_features.py \ 21 | --questions ${BASE}/coco_annotations/aokvqa_v1p0_${SPLIT}.json \ 22 | --images ${COCO17PATH}/${SPLIT}2017/ \ 23 | --ifeatures ${BASE}/coco_clip_new/coco_clip_vitb16_${SPLIT}2017_aokvqa_convertedidx_image.npy \ 24 | --qfeatures ${BASE}/coco_clip_new/coco_clip_vitb16_${SPLIT}2017_aokvqa_question.npy 25 | done 26 | -------------------------------------------------------------------------------- /preprocess/preprocess_okvqa.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import json 3 | import os 4 | 5 | BASE_PATH = "/PATH/TO/VisualCOT" 6 | 7 | def reorganize_captions(): 8 | cap_path = "/pathto/val_refine.json" 9 | sg_attr_dir = f"{BASE_PATH}/input_text/scene_graph_text/scene_graph_coco17_attr" 10 | out_dir = f"{BASE_PATH}/input_text/scene_graph_text/scene_graph_coco14_caption_ok" 11 | 12 | if not os.path.isdir(out_dir): 13 | os.mkdir(out_dir) 14 | 15 | cap_dict_list = json.load(open(cap_path)) 16 | 17 | def cap_dict_to_img(cap_dict_list): 18 | cap_dict_img = {} 19 | for idx, rgn_info in enumerate(cap_dict_list): 20 | img_id = rgn_info["image_id"].split("[")[0] 21 | if img_id not in cap_dict_img: 22 | cap_dict_img[img_id] = [] 23 | cap_dict_img[img_id].append(rgn_info) 24 | return cap_dict_img 25 | 26 | cap_dict_img = cap_dict_to_img(cap_dict_list) 27 | 28 | for img_id_str, rgn_list in cap_dict_img.items(): 29 | 30 | out_img_path = os.path.join(out_dir, img_id_str.zfill(12)+".json") 31 | 32 | def rgn2dict(rgn_list): 33 | rgn_dict = {} 34 | for rgn in rgn_list: 35 | rgn_box_id = rgn["image_id"].split("[")[1] 36 | rgn_box_id = "[" + rgn_box_id 37 | if rgn_box_id not in rgn_dict: 38 | rgn_dict[rgn_box_id] = [rgn] 39 | else: 40 | rgn_dict[rgn_box_id].append(rgn) 41 | return rgn_dict 42 | 43 | rgn_dict = rgn2dict(rgn_list) 44 | 45 | scene_graph_attr = json.load(open(os.path.join(sg_attr_dir, img_id_str.zfill(12) + ".json"))) 46 | cap_list = [] 47 | 48 | for idx, rgn in enumerate(scene_graph_attr[0]): 49 | rgn_id = str(rgn["rect"]) 50 | if rgn_id in rgn_dict: 51 | if 
len(rgn_dict[rgn_id])==1: 52 | cap_list.append(rgn_dict[rgn_id][0]["caption"]) 53 | else: 54 | find_valid_flag = False 55 | for tmp_idx in range(len(rgn_dict[rgn_id])): 56 | tmp_dict = rgn_dict[rgn_id][tmp_idx] 57 | if rgn["class"] in tmp_dict["concept"]: 58 | cap_list.append(rgn_dict[rgn_id][tmp_idx]["caption"]) 59 | find_valid_flag = True 60 | break 61 | #assert find_valid_flag 62 | if not find_valid_flag: 63 | import pdb; pdb.set_trace() 64 | else: 65 | attr_str = "" 66 | if len(rgn["attr_conf"])>0: 67 | val = max(rgn["attr_conf"]) 68 | idx = rgn["attr_conf"].index(val) 69 | attr_str = "%s "%(rgn["attr"][idx]) 70 | fake_cap = "%s %s"%(attr_str, rgn["class"]) 71 | fake_cap = fake_cap.strip() 72 | cap_list.append(fake_cap) 73 | print(rgn) 74 | print(fake_cap) 75 | with open(out_img_path, "w") as fh: 76 | json.dump(cap_list, fh) 77 | #import pdb; pdb.set_trace() 78 | 79 | if __name__=="__main__": 80 | reorganize_captions() -------------------------------------------------------------------------------- /preprocess/reorganize_captions.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | 5 | def find_files(dir, is14=False): 6 | files = [f.split(".")[0] for f in os.listdir(dir) if os.path.isfile(os.path.join(dir, f))] 7 | if is14: 8 | files = [f.split("_")[-1] for f in files] 9 | return files 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--coco14root', type=str, required=True, help='path to coco14') 13 | parser.add_argument('--coco17root', type=str, required=True, help='path to coco17') 14 | parser.add_argument('--caption14train', type=str, required=True, help='path to 14 caption train') 15 | parser.add_argument('--caption14val', type=str, required=True, help='path to 14 caption val') 16 | parser.add_argument('--caption17train', type=str, required=True, help='output path to 14 caption train') 17 | parser.add_argument('--caption17val', type=str, required=True, help='output path to 14 caption val') 18 | args = parser.parse_args() 19 | 20 | train17dir = os.path.join(args.coco17root, "train2017") 21 | val17dir = os.path.join(args.coco17root, "val2017") 22 | 23 | train14dir = os.path.join(args.coco14root, "train2014") 24 | val14dir = os.path.join(args.coco14root, "val2014") 25 | 26 | split14 = [find_files(train14dir, is14=True), find_files(val14dir, is14=True)] 27 | split17 = [find_files(train17dir, is14=False), find_files(val17dir, is14=False)] 28 | 29 | caption14train = json.load(open(args.caption14train))['annotations'] 30 | caption14val = json.load(open(args.caption14val))['annotations'] 31 | caption14 = caption14train + caption14val 32 | caption14_dict = {} 33 | for c in caption14: 34 | if c['image_id'] not in caption14_dict: 35 | caption14_dict[c['image_id']] = [c['caption']] 36 | else: 37 | caption14_dict[c['image_id']].append(c['caption']) 38 | 39 | caption17train = { 40 | "annotations": [] 41 | } 42 | caption17val = { 43 | "annotations": [] 44 | } 45 | 46 | for iid in split17[0]: 47 | iid = int(iid) 48 | captions = caption14_dict[iid] 49 | for cp in captions: 50 | caption17train["annotations"].append( 51 | {"image_id": iid, 52 | "caption": cp} 53 | ) 54 | for iid in split17[1]: 55 | iid = int(iid) 56 | captions = caption14_dict[iid] 57 | for cp in captions: 58 | caption17val["annotations"].append( 59 | {"image_id": iid, 60 | "caption": cp} 61 | ) 62 | 63 | json.dump(caption17train, open(args.caption17train, "w")) 64 | json.dump(caption17val, open(args.caption17val, "w")) 65 | 
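As a quick illustration of how the files written by `preprocess_aokvqa.sh` fit together, the sketch below (not a file in this repo) mirrors the `imagequestion` branch of `get_context_keys` in the main scripts: the line2sample JSON maps a feature row to an `image_id<->question_id` key, and the in-context examples for a validation question are the training rows with the highest combined question + image CLIP similarity. The paths and the shot count are taken from the scripts above; the variable and function names are illustrative only.

```
import json
import numpy as np

base = "coco_clip_new"  # output dir used by preprocess_aokvqa.sh

# row index (as string) -> "image_id<->question_id"
train_idx = json.load(open(f"{base}/aokvqa_qa_line2sample_idx_train2017.json"))
val_idx = json.load(open(f"{base}/aokvqa_qa_line2sample_idx_val2017.json"))
valkey2row = {key: int(row) for row, key in val_idx.items()}

# CLIP question and image features produced by make_clip_features.py
q_train = np.load(f"{base}/coco_clip_vitb16_train2017_aokvqa_question.npy")
q_val = np.load(f"{base}/coco_clip_vitb16_val2017_aokvqa_question.npy")
i_train = np.load(f"{base}/coco_clip_vitb16_train2017_aokvqa_convertedidx_image.npy")
i_val = np.load(f"{base}/coco_clip_vitb16_val2017_aokvqa_convertedidx_image.npy")

def context_keys(val_key, n=8):
    """Return the n training keys most similar to a validation question."""
    row = valkey2row[val_key]
    sim = q_train @ q_val[row] + i_train @ i_val[row]  # question + image similarity
    top = sim.argsort()[-n:][::-1]                      # highest similarity first
    return [train_idx[str(r)] for r in top]
```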
-------------------------------------------------------------------------------- /run_aokvqa.sh: -------------------------------------------------------------------------------- 1 | BASE="/PATH/TO/VisualCOT" 2 | engine=opt 3 | # engine choices: 4 | # opt, llama, bloom 5 | # 'ada', 'babbage', 'curie', 'davinci' (gpt-3), chat (gpt-3.5-turbo) 6 | # chat-test (for debug only) 7 | apikey=YOUR_OPENAI_API_KEY 8 | 9 | # parallel run 24 jobs 10 | for i in {0..23} 11 | do 12 | start=$(echo 'scale=4;'$i'/24' | bc) 13 | end=$(echo 'scale=4;('$i'+1)/24' | bc) 14 | echo $start 15 | echo $end 16 | python main_aokvqa.py \ 17 | --apikey ${apikey} \ 18 | --output_path output/${engine}/${i} \ 19 | --caption_type vinvl_ocr \ 20 | --n_shot 8 \ 21 | --n_ensemble 5 \ 22 | --rounds 5 \ 23 | --iterative_strategy caption \ 24 | --engine ${engine} \ 25 | --sg_path ${BASE}/input_text/scene_graph_text \ 26 | --train_sim_metric answer \ 27 | --train_sim_file "./input_text/scene_graph_text/train_object_select_answer_sim.pk" \ 28 | --chain_of_thoughts \ 29 | --start ${start} \ 30 | --with_clip_verify \ 31 | --end ${end} & 32 | # use --llama_path if engine == "llama" 33 | done 34 | -------------------------------------------------------------------------------- /run_okvqa.sh: -------------------------------------------------------------------------------- 1 | BASE="/PATH/TO/VisualCOT" 2 | engine=opt 3 | apikey=YOUR_OPENAI_API_KEY 4 | for i in {0..79} 5 | do 6 | start=$(echo 'scale=4;'$i'/80' | bc) 7 | end=$(echo 'scale=4;('$i'+1)/80' | bc) 8 | echo $start 9 | echo $end 10 | python main_okvqa.py \ 11 | --apikey ${apikey} \ 12 | --output_path output/okvqa_llama2/${i} \ 13 | --caption_type vinvl_ocr \ 14 | --n_shot 8 \ 15 | --n_ensemble 5 \ 16 | --rounds 5 \ 17 | --engine ${engine} \ 18 | --iterative_strategy caption \ 19 | --sg_path ${BASE}/input_text/scene_graph_text \ 20 | --train_sim_metric answer \ 21 | --train_sim_file "./input_text/scene_graph_text/train_object_select_okvqa.pk" \ 22 | --with_clip_verify \ 23 | --start ${start} \ 24 | --with_ok_context \ 25 | --end ${end} & 26 | done 27 | -------------------------------------------------------------------------------- /utils_api.py: -------------------------------------------------------------------------------- 1 | """Utilities for working with OpenAI GPT APIs. 2 | """ 3 | 4 | import random 5 | import functools 6 | import json 7 | import logging 8 | import os 9 | from io import BytesIO 10 | import time 11 | from multiprocessing import shared_memory 12 | 13 | import numpy as np 14 | import requests 15 | 16 | from concurrent.futures import ThreadPoolExecutor 17 | 18 | import openai 19 | from openai import error as openai_error 20 | 21 | 22 | def openai_complete( 23 | prompts, 24 | max_length, 25 | temperature, 26 | num_sampling=1, 27 | best_of=1, 28 | internal_batch_size=None, 29 | internal_num_sampling=None, 30 | sleep_time=3.0, # This is because of the rate limit: 20.000000 / min 31 | stop_token=None, 32 | logit_bias=None, 33 | presence_penalty=0.0, 34 | frequency_penalty=0.0, 35 | logprobs=None, 36 | top_p=1.0, 37 | ): 38 | """OpenAI API call. 
39 | Args: 40 | prompts: list of prompts 41 | max_length: max length of the output 42 | temperature: temperature of the output 43 | num_sampling: number of sampling 44 | best_of: number of best of 45 | internal_batch_size: internal batch size 46 | internal_num_sampling: internal number of sampling 47 | sleep_time: sleep time to avoid rate limit 48 | stop_token: stop token 49 | logit_bias: logit bias 50 | presence_penalty: presence penalty 51 | frequency_penalty: frequency penalty 52 | logprobs: logprobs 53 | top_p: top p 54 | Returns: 55 | list of responses 56 | """ 57 | if type(prompts) is str: 58 | prompts = [prompts] 59 | 60 | def openai_api_call(prompts, api_key, organization): 61 | time.sleep(sleep_time + random.random()) 62 | all_response = [] 63 | 64 | 65 | all_logprobs = [] 66 | accumulated_sleep_time = sleep_time 67 | if len(prompts) > 0: 68 | create_fn = openai.Completion.create 69 | 70 | if logit_bias is not None: 71 | create_fn = functools.partial(create_fn, logit_bias=json.loads(logit_bias)) 72 | 73 | if logprobs is not None: 74 | create_fn = functools.partial(create_fn, logprobs=logprobs) 75 | 76 | if internal_batch_size is None: 77 | responses, accumulated_sleep_time = call_openai_internal( 78 | create_fn, prompts, max_length, best_of, num_sampling, stop_token, 79 | temperature, presence_penalty, frequency_penalty, top_p, api_key, organization, 80 | accumulated_sleep_time, sleep_time 81 | ) 82 | all_response = [_["text"] for _ in responses["choices"]] 83 | 84 | if logprobs is not None: 85 | all_logprobs = [_["logprobs"] for _ in responses["choices"]] 86 | else: 87 | for start_idx in range(0, len(prompts), internal_batch_size): 88 | sub_prompts = prompts[start_idx:start_idx + internal_batch_size] 89 | 90 | if internal_num_sampling is None: 91 | responses, accumulated_sleep_time = call_openai_internal( 92 | create_fn, sub_prompts, max_length, best_of, num_sampling, stop_token, 93 | temperature, presence_penalty, frequency_penalty, top_p, api_key, organization, 94 | accumulated_sleep_time, sleep_time 95 | ) 96 | if start_idx < len(prompts) - internal_batch_size: 97 | time.sleep(accumulated_sleep_time + random.random()) 98 | all_response.extend([_["text"] for _ in responses["choices"]]) 99 | if logprobs is not None: 100 | all_logprobs.extend([_["logprobs"] for _ in responses["choices"]]) 101 | 102 | else: 103 | assert num_sampling == best_of 104 | assert num_sampling % internal_num_sampling == 0 105 | 106 | responses = dict() 107 | responses["choices"] = [] 108 | stacked_responses = [] 109 | for i in range(num_sampling // internal_num_sampling): 110 | response_choices, accumulated_sleep_time = call_openai_internal( 111 | create_fn, sub_prompts, max_length, internal_num_sampling, internal_num_sampling, stop_token, 112 | temperature, presence_penalty, frequency_penalty, top_p, api_key, organization, 113 | accumulated_sleep_time, sleep_time 114 | ) 115 | stacked_responses.append(response_choices["choices"]) 116 | if start_idx < len(prompts) - internal_batch_size or i < num_sampling // internal_num_sampling - 1: 117 | time.sleep(accumulated_sleep_time + random.random()) 118 | 119 | for i in range(len(stacked_responses[0])): 120 | for j in range(len(stacked_responses)): 121 | responses["choices"].append(stacked_responses[j][i]) 122 | 123 | all_response.extend([_["text"] for _ in responses["choices"]]) 124 | if logprobs is not None: 125 | all_logprobs.extend([_["logprobs"] for _ in responses["choices"]]) 126 | return all_response, all_logprobs 127 | else: 128 | return None 129 | 130 | 
api_dicts = [] 131 | multiple_api_key_file = "scripts/openai_keys.json" 132 | if os.path.exists(multiple_api_key_file): 133 | with open(multiple_api_key_file, "r") as f: 134 | lines = f.readlines() 135 | lines = "".join([_.strip() for _ in lines]) 136 | lines = lines.replace("}{", "}[split]{") 137 | lines = lines.split("[split]") 138 | for line in lines: 139 | api_dicts.append(json.loads(line)) 140 | 141 | if len(api_dicts) == 0: 142 | api_dicts = [{"api_key": openai.api_key, "organization": openai.organization}] 143 | 144 | targets = [] 145 | targets_logprobs = [] 146 | 147 | logging.info("Using %d API keys" % len(api_dicts)) 148 | with ThreadPoolExecutor(max_workers=len(api_dicts)) as executor: 149 | futures = [] 150 | for batch_idx, api_dict in enumerate(api_dicts): 151 | single_process_batch_size = ((len(prompts) - 1) // len(api_dicts)) + 1 152 | start_idx = single_process_batch_size * batch_idx 153 | end_idx = single_process_batch_size * (batch_idx + 1) 154 | 155 | if batch_idx == len(api_dicts) - 1: 156 | single_process_prompts = prompts[start_idx:] 157 | else: 158 | single_process_prompts = prompts[start_idx:end_idx] 159 | 160 | futures.append( 161 | executor.submit( 162 | openai_api_call, 163 | single_process_prompts, 164 | api_dict["api_key"], 165 | api_dict["organization"], 166 | )) 167 | 168 | for future in futures: 169 | responses = future.result() 170 | if responses is not None: 171 | targets.extend(responses[0]) 172 | targets_logprobs.extend(responses[1]) 173 | 174 | if len(targets_logprobs) > 0: 175 | return targets, targets_logprobs 176 | else: 177 | return targets 178 | 179 | 180 | 181 | def call_openai_internal(create_fn, prompts, max_length, best_of, num_sampling, stop_token, 182 | temperature, presence_penalty, frequency_penalty, top_p, api_key, organization, 183 | accumulated_sleep_time, sleep_time): 184 | """Call OpenAI API with retry.""" 185 | responses = None 186 | while responses is None: 187 | try: 188 | responses = create_fn( 189 | model="code-davinci-002", 190 | prompt=prompts, 191 | max_tokens=max_length, 192 | best_of=best_of, 193 | stop=stop_token, 194 | temperature=temperature, 195 | n=num_sampling, 196 | api_key=api_key, 197 | organization=organization, 198 | presence_penalty=presence_penalty, 199 | frequency_penalty=frequency_penalty, 200 | top_p=top_p, 201 | ) 202 | except openai.error.RateLimitError as e: 203 | print(e) 204 | print(f"Batch size: {len(prompts)}, best_of: {best_of}, max_tokens: {max_length}") 205 | time.sleep(accumulated_sleep_time) 206 | accumulated_sleep_time += sleep_time 207 | except openai.error.APIError as e: 208 | print(e) 209 | print(f"Batch size: {len(prompts)}, best_of: {best_of}, max_tokens: {max_length}") 210 | print("API-Key:", api_key, "Organization:", organization) 211 | time.sleep(accumulated_sleep_time) 212 | accumulated_sleep_time += sleep_time 213 | except openai_error.Timeout as e: 214 | print(e) 215 | print("API-Key:", api_key, "Organization:", organization) 216 | time.sleep(accumulated_sleep_time) 217 | accumulated_sleep_time += sleep_time 218 | except openai_error.APIConnectionError as e: 219 | print(e) 220 | print("API-Key:", api_key, "Organization:", organization) 221 | time.sleep(accumulated_sleep_time) 222 | accumulated_sleep_time += sleep_time 223 | return responses, sleep_time 224 | 225 | 226 | def blip_complete( 227 | images, 228 | texts, 229 | blip_urls, 230 | max_length=10, 231 | temperature=1.0, 232 | num_beams=5, 233 | length_penalty=-1.0, 234 | internal_batch_size=None, 235 | ): 236 | """BLIP API call. 
237 | Args: 238 | images: list of images, as numpy arrays 239 | texts: list of texts 240 | blip_urls: list of blip api urls 241 | max_length: max length of the output 242 | temperature: temperature of the output 243 | num_beams: number of beams 244 | length_penalty: length penalty 245 | internal_batch_size: internal batch size 246 | Returns: 247 | list of responses 248 | """ 249 | assert len(images) == len(texts) 250 | 251 | def blip_api_call(paired_image_text, url): 252 | response = None 253 | if len(paired_image_text) > 0: 254 | images = np.concatenate([img for img, _ in paired_image_text], axis=0) 255 | questions = [text for _, text in paired_image_text] 256 | 257 | port_number = url.split(":")[2].split("/")[0] 258 | NP_DATA_TYPE = np.float32 259 | MAX_BATCH_SIZE = 512 260 | NP_SHARED_NAME = f'npshared_{port_number}' 261 | shape_size = MAX_BATCH_SIZE * (224 * 224 * 3) 262 | d_size = np.dtype(NP_DATA_TYPE).itemsize * shape_size 263 | shm = shared_memory.SharedMemory(name=NP_SHARED_NAME, create=True, size=d_size) 264 | 265 | shared_images = np.ndarray((shape_size,), dtype=NP_DATA_TYPE, buffer=shm.buf) 266 | shared_images[:images.reshape(-1).shape[0]] = images.reshape(-1) 267 | shm.close() 268 | 269 | req = { 270 | "images_shape": images.shape, 271 | "texts": questions, 272 | "max_length": max_length, 273 | "temperature": temperature, 274 | "num_beams": num_beams, 275 | "length_penalty": length_penalty, 276 | } 277 | 278 | if internal_batch_size is not None: 279 | req["internal_batch_size"] = internal_batch_size 280 | 281 | res = requests.post(url, json=req) 282 | response = res.json()["targets"] 283 | shm.unlink() 284 | return response 285 | 286 | targets = [] 287 | 288 | with ThreadPoolExecutor(max_workers=len(blip_urls)) as executor: 289 | futures = [] 290 | for batch_idx, url in enumerate(blip_urls): 291 | single_process_batch_size = ((len(images) - 1) // len(blip_urls)) + 1 292 | start_idx = single_process_batch_size * batch_idx 293 | end_idx = single_process_batch_size * (batch_idx + 1) 294 | 295 | if batch_idx == len(blip_urls) - 1: 296 | single_process_paired_image_text = list(zip(images[start_idx:], texts[start_idx:])) 297 | else: 298 | single_process_paired_image_text = list(zip(images[start_idx:end_idx], texts[start_idx:end_idx])) 299 | 300 | futures.append( 301 | executor.submit( 302 | blip_api_call, 303 | single_process_paired_image_text, 304 | url, 305 | )) 306 | 307 | for future in futures: 308 | response = future.result() 309 | if response is not None: 310 | targets.extend(response) 311 | 312 | return targets 313 | 314 | 315 | def blip_completev2( 316 | images, 317 | texts, 318 | blip_urls, 319 | max_length=10, 320 | temperature=1.0, 321 | num_beams=5, 322 | length_penalty=-1.0, 323 | internal_batch_size=None, 324 | encoding_format="JPEG", 325 | ): 326 | """BLIP API call. 
327 | Args: 328 | images: list of images, as numpy arrays 329 | texts: list of texts 330 | blip_urls: list of blip api urls 331 | max_length: max length of the output 332 | temperature: temperature of the output 333 | num_beams: number of beams 334 | length_penalty: length penalty 335 | internal_batch_size: internal batch size 336 | encoding_format: encoding format of the image 337 | Returns: 338 | list of responses 339 | """ 340 | assert len(images) == len(texts) 341 | 342 | def blip_api_call(paired_image_text, url): 343 | response = None 344 | if len(paired_image_text) > 0: 345 | headers = { 346 | "User-Agent": "BLIP-2 HuggingFace Space", 347 | } 348 | 349 | prompts = [text for _, text in paired_image_text] 350 | 351 | data = { 352 | "prompts": "[split]".join(prompts), 353 | "temperature": temperature, 354 | "length_penalty": length_penalty, 355 | "num_beams": num_beams, 356 | "max_length": max_length, 357 | } 358 | 359 | if internal_batch_size is not None: 360 | data["internal_batch_size"] = internal_batch_size 361 | 362 | files = {} 363 | for idx, (image, _) in enumerate(paired_image_text): 364 | image = encode_image(image, encoding_format=encoding_format) 365 | files[f"image{idx}"] = image 366 | 367 | response = requests.post(url, data=data, files=files, headers=headers).json() 368 | return response 369 | 370 | targets = [] 371 | 372 | with ThreadPoolExecutor(max_workers=len(blip_urls)) as executor: 373 | futures = [] 374 | for batch_idx, url in enumerate(blip_urls): 375 | single_process_batch_size = ((len(images) - 1) // len(blip_urls)) + 1 376 | start_idx = single_process_batch_size * batch_idx 377 | end_idx = single_process_batch_size * (batch_idx + 1) 378 | 379 | if batch_idx == len(blip_urls) - 1: 380 | single_process_paired_image_text = list(zip(images[start_idx:], texts[start_idx:])) 381 | else: 382 | single_process_paired_image_text = list(zip(images[start_idx:end_idx], texts[start_idx:end_idx])) 383 | 384 | futures.append( 385 | executor.submit( 386 | blip_api_call, 387 | single_process_paired_image_text, 388 | url, 389 | )) 390 | 391 | for future in futures: 392 | response = future.result() 393 | if response is not None: 394 | targets.extend(response) 395 | 396 | return targets 397 | 398 | 399 | def encode_image(image, encoding_format="JPEG"): 400 | buffered = BytesIO() 401 | image.save(buffered, format=encoding_format) 402 | buffered.seek(0) 403 | return buffered 404 | --------------------------------------------------------------------------------
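For completeness, here is a minimal usage sketch for the API helpers in `utils_api.py`; it is not an entry point of the repo. It assumes an OpenAI key is configured, either through `scripts/openai_keys.json` or by setting `openai.api_key` directly; the prompts and the key value are placeholders. `openai_complete` returns one completion string per prompt, and additionally returns per-token log probabilities when `logprobs` is set.

```
import openai
from utils_api import openai_complete

openai.api_key = "YOUR_OPENAI_API_KEY"  # placeholder, as in run_aokvqa.sh

prompts = [
    "Question: What is the man holding?\nAnswer:",
    "Question: What season is shown in the picture?\nAnswer:",
]

# One completion per prompt; generation stops at the first newline.
answers = openai_complete(prompts, max_length=16, temperature=0.0, stop_token="\n")
for prompt, answer in zip(prompts, answers):
    print(prompt, answer.strip())
```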