├── .gitattributes ├── .gitignore ├── FROMAGe_example_notebook.ipynb ├── LICENSE ├── README.md ├── datasets ├── cc3m_train.tsv └── cc3m_val.tsv ├── demo ├── app.py └── share_btn.py ├── evals ├── VIST_Contextual_Image_Retrieval.ipynb ├── VisDial_Inference_IT2T_Generation.ipynb ├── VisDial_Inference_T2I_Retrieval.ipynb ├── eval_visdial_generation.py ├── eval_visdial_retrieval.py └── eval_vist_retrieval.py ├── fromage ├── __init__.py ├── data.py ├── evaluate.py ├── extract_img_embs.py ├── losses.py ├── models.py ├── prune_model_ckpt.py └── utils.py ├── fromage_model ├── fromage_vis4 │ ├── model_args.json │ └── pretrained_ckpt.pth.tar ├── model_args.json └── pretrained_ckpt.pth.tar ├── main.py ├── requirements.txt ├── teaser.png ├── teaser_gif.gif └── test_main.py /.gitattributes: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/.gitattributes -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pyc 3 | __pycache__ 4 | .pytest_cache 5 | venv 6 | runs/ 7 | data/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Grounding Language Models to Images for Multimodal Inputs and Outputs 2 | 3 |

4 | [Figure: FROMAGe model architecture] 5 |
6 | [Figure: FROMAGe chat animation] 7 |
8 | 9 | This repository hosts the code and model weights for FROMAGe. 10 | 11 | [Paper](https://arxiv.org/abs/2301.13823) | [Project Webpage](https://jykoh.com/fromage) | [Demo](https://huggingface.co/spaces/jykoh/fromage) 12 | 13 | 14 | ## Setup instructions 15 | 16 | ### Environment 17 | Set up a new virtualenv, and install the required libraries: 18 | ``` 19 | python -m venv venv 20 | source venv/bin/activate 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | Add the `fromage` library to PYTHONPATH: 25 | ``` 26 | export PYTHONPATH=$PYTHONPATH:/home/path/to/fromage/ 27 | ``` 28 | 29 | ### Pretrained Checkpoints 30 | 31 | The FROMAGe model weights (linear layers and [RET] embedding) are small (around 11MB), and are included in this Git repo. They will be in the `fromage_model/` folder after cloning. The checkpoint and model config in `fromage_model/` reproduce the results reported in our paper. 32 | 33 | We have also included a second model trained with a stronger visual linear layer (4 visual tokens instead of 1), located at `fromage_model/fromage_vis4`. This model generally does better in dialogue settings and does not require as much tuning of inference-time hyperparameters, as it is able to better represent more complex images. 34 | 35 | ### Precomputed Embeddings For Image Retrieval 36 | 37 | The visual embeddings for Conceptual Captions images with valid URLs are precomputed and stored at this [URL](https://drive.google.com/file/d/1wMojZNqEwApNlsCZVvSgQVtZLgbeLoKi/view?usp=share_link). These are used to enable the model to retrieve images. The embeddings take up around 3GB, and are compatible with both model configs we provide. Download the files and place `cc3m_embeddings.pkl` into the `fromage_model/` directory. 38 | 39 | If you wish to precompute these embeddings for a different set of image URLs or for a different model, edit `fromage/extract_img_embs.py` with the list of image URLs and run it: 40 | 41 | ```python fromage/extract_img_embs.py``` 42 | 43 | 44 | ## Inference 45 | 46 | Check out `FROMAGe_example_notebook.ipynb` for examples on calling the model for inference. Several of the figures presented in the paper are reproduced in this notebook using greedy decoding of the model. Note that there may be minor differences in image outputs due to CC3M images being lost over time. 47 | 48 | 49 | ## Training 50 | 51 | ### Preparing CC3M 52 | 53 | Our model is trained on the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions) dataset. After following the instructions on the website to download the captions and images, format the data into `.tsv` files as follows: 54 | 55 | ``` 56 | caption image 57 | A picture of a cat cat.png 58 | Mountains mountain.png 59 | ``` 60 | where each line contains a caption followed by the filename of the corresponding image file. Save these `.tsv` files into the `datasets/` folder (the default names expected are `cc3m_train.tsv` and `cc3m_val.tsv`). The repo contains two placeholder files, and you will have to replace them with the appropriate data. 61 | 62 | The corresponding image files should be saved in the `data/` directory. The directory can be changed with the `--image-dir` runtime flag. 
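For reference, here is a minimal sketch of producing such a `.tsv` file with Python's standard `csv` module. The caption/filename pairs below are placeholders, and tab-separated columns are assumed from the `.tsv` extension:

```python
import csv

# Placeholder (caption, image filename) pairs; replace with the real CC3M annotations.
pairs = [
    ("A picture of a cat", "cat.png"),
    ("Mountains", "mountain.png"),
]

with open("datasets/cc3m_train.tsv", "w", newline="") as f:
    writer = csv.writer(f, delimiter="\t")
    writer.writerow(["caption", "image"])  # header row, matching the placeholder files in datasets/
    writer.writerows(pairs)
```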
63 | 64 | 65 | ### Training FROMAGe 66 | 67 | After preparing CC3M as detailed above, you can start a new training job with the following command: 68 | 69 | ``` 70 | randport=$(shuf -i8000-9999 -n1) # Generate a random port number 71 | python -u main.py \ 72 | --dist-url "tcp://127.0.0.1:${randport}" --dist-backend 'nccl' \ 73 | --multiprocessing-distributed --world-size 1 --rank 0 \ 74 | --dataset=cc3m --val-dataset=cc3m \ 75 | --opt-version='facebook/opt-6.7b' --visual-model='openai/clip-vit-large-patch14' \ 76 | --exp_name='fromage_exp' --image-dir='data/' --log-base-dir='runs/' \ 77 | --batch-size=180 --val-batch-size=100 --learning-rate=0.0003 --precision='bf16' --print-freq=100 78 | ``` 79 | 80 | On a single A6000 GPU, the model converges within 24 hours (with a batch size of 180). For GPUs with less available memory, you might need to reduce the batch size, enable gradient accumulation, or adjust hyperparameters to get good performance. You may also have to disable NCCL P2P with `export NCCL_P2P_DISABLE=1` if you run into issues. 81 | 82 | 83 | ### Pruning Model Weights 84 | 85 | As FROMAGe only consists of a few pretrained linear layers and the `[RET]` embedding, we can discard most of the pretrained weights to save on disk space. If you have trained a new model and wish to do so, you can use `fromage/prune_model_ckpt.py` to prune the model weights. We used the same script to create the weights in the `fromage_model` directory. 86 | 87 | 88 | ### Unit Tests 89 | 90 | You can also test that the code runs locally by running the unit test with `pytest -x`. This runs a short training and evaluation job with smaller models to ensure the code works. The test should complete within approximately 90s. 91 | 92 | Note that because of exception catching (to ensure data errors don't terminate training), the test will silently fail and not terminate if there is an I/O error when reading data. Hence, we recommend running the Python command above for debugging data preprocessing. 93 | 94 | 95 | ## Evaluation 96 | 97 | We provide an evaluation script to reproduce our results on contextual image retrieval on Visual Storytelling (results in Table 1 of our paper). The script can be run from `evals/eval_vist_retrieval.py`. There is also an IPython notebook version (`VIST_Contextual_Image_Retrieval.ipynb`) in the same directory. 98 | 99 | Similarly, we provide scripts to reproduce the text generation and image retrieval results on VisDial (presented in Table 2 of our paper). The script for VisDial text generation can be run from `evals/eval_visdial_generation.py` (or through the notebook version, `VisDial_Inference_IT2T_Generation.ipynb`). This reports the NDCG, MRR, and R@k scores for VisDial. 100 | 101 | The results for image retrieval can be reproduced by running the `evals/eval_visdial_retrieval.py` script (or through the notebook version `VisDial_Inference_T2I_Retrieval.ipynb`), which reports R@k scores. 102 | 103 | 104 | ## Gradio Demo 105 | 106 | You can launch your own version of the Gradio demo locally by running `python demo/app.py`, or by duplicating the [HuggingFace space](https://huggingface.co/spaces/jykoh/fromage). 
107 | 108 | Check out other unofficial HuggingFace spaces for FROMAGe: 109 | - [alvanlii FROMAGe demo](https://huggingface.co/spaces/alvanlii/FROMAGe) 110 | 111 | 112 | ## Citation 113 | 114 | If you find this work useful, please consider citing: 115 | 116 | ``` 117 | @inproceedings{koh2023grounding, 118 | title={Grounding Language Models to Images for Multimodal Inputs and Outputs}, 119 | author={Koh, Jing Yu and Salakhutdinov, Ruslan and Fried, Daniel}, 120 | journal={ICML}, 121 | year={2023} 122 | } 123 | ``` -------------------------------------------------------------------------------- /datasets/cc3m_train.tsv: -------------------------------------------------------------------------------- 1 | caption image 2 | author : a life in photography -- in pictures 0_1595581236 3 | the player staring intently at a computer screen . 3_2841214833 4 | the - bedroom stone cottage can sleep people 5_2227185927 5 | party in the park under cherry blossoms 8_1666482269 6 | -------------------------------------------------------------------------------- /datasets/cc3m_val.tsv: -------------------------------------------------------------------------------- 1 | caption image 2 | author : a life in photography -- in pictures 0_1595581236 3 | the player staring intently at a computer screen . 3_2841214833 4 | the - bedroom stone cottage can sleep people 5_2227185927 5 | party in the park under cherry blossoms 8_1666482269 6 | -------------------------------------------------------------------------------- /demo/app.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from share_btn import community_icon_html, loading_icon_html, share_js, save_js 3 | import huggingface_hub 4 | import gradio as gr 5 | from fromage import utils 6 | from fromage import models 7 | import matplotlib.pyplot as plt 8 | from PIL import Image 9 | import torch 10 | import numpy as np 11 | import os 12 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False" 13 | 14 | 15 | css = """ 16 | #chatbot { min-height: 300px; } 17 | #save-btn { 18 | background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0)); 19 | } 20 | #save-btn:hover { 21 | background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0)); 22 | } 23 | #share-btn { 24 | background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0)); 25 | } 26 | #share-btn:hover { 27 | background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0)); 28 | } 29 | #gallery { z-index: 999999; } 30 | #gallery img:hover {transform: scale(2.3); z-index: 999999; position: relative; padding-right: 30%; padding-bottom: 30%;} 31 | #gallery button img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;} 32 | @media (hover: none) { 33 | #gallery img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; 0;} 34 | } 35 | """ 36 | 37 | examples = [ 38 | 'examples/sparrow.png', 39 | 'examples/beaver.png', 40 | 'examples/couch.png', 41 | 'examples/guac.png', 42 | 'examples/scraped_knee.png' 43 | ] 44 | 45 | # Download model from HF Hub. 
46 | ckpt_path = huggingface_hub.hf_hub_download( 47 | repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar') 48 | args_path = huggingface_hub.hf_hub_download( 49 | repo_id='jykoh/fromage', filename='model_args.json') 50 | model = models.load_fromage('./', args_path, ckpt_path) 51 | 52 | 53 | def upload_image(state, image_input): 54 | conversation = state[0] 55 | chat_history = state[1] 56 | input_image = Image.open(image_input.name).resize( 57 | (224, 224)).convert('RGB') 58 | input_image.save(image_input.name) # Overwrite with smaller image. 59 | conversation += [(f'', "")] 60 | return [conversation, chat_history + [input_image, ""]], conversation 61 | 62 | 63 | def reset(): 64 | return [[], []], [] 65 | 66 | 67 | def reset_last(state): 68 | conversation = state[0][:-1] 69 | chat_history = state[1][:-2] 70 | return [conversation, chat_history], conversation 71 | 72 | 73 | def save_image_to_local(image: Image.Image): 74 | # TODO(jykoh): Update so the url path is used, to prevent repeat saving. 75 | filename = next(tempfile._get_candidate_names()) + '.png' 76 | image.save(filename) 77 | return filename 78 | 79 | 80 | def generate_for_prompt(input_text, state, ret_scale_factor, max_num_rets, num_words, temperature): 81 | # Ignore empty inputs. 82 | if len(input_text) == 0: 83 | return state, state[0], gr.update(visible=True) 84 | 85 | input_prompt = 'Q: ' + input_text + '\nA:' 86 | conversation = state[0] 87 | chat_history = state[1] 88 | print('Generating for', chat_history, flush=True) 89 | 90 | # If an image was uploaded, prepend it to the model. 91 | model_inputs = chat_history 92 | model_inputs.append(input_prompt) 93 | 94 | top_p = 1.0 95 | if temperature != 0.0: 96 | top_p = 0.95 97 | 98 | print('Running model.generate_for_images_and_texts with', 99 | model_inputs, flush=True) 100 | model_outputs = model.generate_for_images_and_texts(model_inputs, 101 | num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p, 102 | temperature=temperature, max_num_rets=max_num_rets) 103 | print('model_outputs', model_outputs, flush=True) 104 | 105 | im_names = [] 106 | response = '' 107 | text_outputs = [] 108 | for output_i, output in enumerate(model_outputs): 109 | if type(output) == str: 110 | if output_i > 0: 111 | response += '
<br/>' 112 | text_outputs.append(output) 113 | response += output 114 | if len(model_outputs) > 1: 115 | response += '<br/>
' 116 | elif type(output) == list: 117 | for image in output: 118 | filename = save_image_to_local(image) 119 | response += f'<img src="/file={filename}">' 120 | elif type(output) == Image.Image: 121 | filename = save_image_to_local(output) 122 | response += f'<img src="/file={filename}">' 123 | 124 | chat_history = model_inputs + \ 125 | [' '.join([s for s in model_outputs if type(s) == str]) + '\n'] 126 | # Remove [RET] from outputs. 127 | conversation.append((input_text, response.replace('[RET]', ''))) 128 | 129 | # Set input image to None. 130 | print('state', state, flush=True) 131 | print('updated state', [conversation, chat_history], flush=True) 132 | return [conversation, chat_history], conversation, gr.update(visible=True), gr.update(visible=True) 133 | 134 | 135 | with gr.Blocks(css=css) as demo: 136 | gr.HTML(""" 137 |

🧀 FROMAGe
138 | This is the official Gradio demo for the FROMAGe model, a model that can process arbitrarily interleaved image and text inputs, and produce image and text outputs.
139 | 140 | Paper: Grounding Language Models to Images for Multimodal Generation 141 |
142 | Project Website: FROMAGe Website 143 |
144 | Code and Models: GitHub 145 |
146 |
147 | 148 | Tips: 149 | 155 | """) 156 | 157 | gr_state = gr.State([[], []]) # conversation, chat_history 158 | 159 | with gr.Row(): 160 | with gr.Column(scale=0.7, min_width=500): 161 | with gr.Row(): 162 | chatbot = gr.Chatbot(elem_id="chatbot", label="🧀 FROMAGe Chatbot") 163 | with gr.Row(): 164 | image_btn = gr.UploadButton("🖼️ Upload Image", file_types=["image"]) 165 | 166 | text_input = gr.Textbox(label="Message", placeholder="Type a message") 167 | 168 | with gr.Column(): 169 | submit_btn = gr.Button( 170 | "Submit", interactive=True, variant="primary") 171 | clear_last_btn = gr.Button("Undo") 172 | clear_btn = gr.Button("Reset All") 173 | with gr.Row(visible=False) as save_group: 174 | save_button = gr.Button("💾 Save Conversation as .png", elem_id="save-btn") 175 | 176 | with gr.Row(visible=False) as share_group: 177 | share_button = gr.Button("🤗 Share to Community (opens new window)", elem_id="share-btn") 178 | 179 | with gr.Column(scale=0.3, min_width=400): 180 | ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True, 181 | label="Frequency multiplier for returning images (higher means more frequent)") 182 | max_ret_images = gr.Number( 183 | minimum=0, maximum=3, value=2, precision=1, interactive=True, label="Max images to return") 184 | gr_max_len = gr.Slider(minimum=1, maximum=64, value=32, 185 | step=1, interactive=True, label="Max # of words") 186 | gr_temperature = gr.Slider( 187 | minimum=0.0, maximum=1.0, value=0.0, interactive=True, label="Temperature (0 for deterministic, higher for more randomness)") 188 | 189 | gallery = gr.Gallery( 190 | value=[Image.open(e) for e in examples], label="Example Conversations", show_label=True, elem_id="gallery", 191 | ).style(grid=[2], height="auto") 192 | 193 | text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor, 194 | max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group]) 195 | text_input.submit(lambda: "", None, text_input) # Reset chatbox. 196 | submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor, 197 | max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group]) 198 | submit_btn.click(lambda: "", None, text_input) # Reset chatbox. 
199 | 200 | image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot]) 201 | clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot]) 202 | clear_btn.click(reset, [], [gr_state, chatbot]) 203 | share_button.click(None, [], [], _js=share_js) 204 | save_button.click(None, [], [], _js=save_js) 205 | 206 | 207 | demo.queue(concurrency_count=1, api_open=False, max_size=16) 208 | demo.launch(debug=True, server_name="0.0.0.0") 209 | # demo.launch(debug=True, server_name="127.0.0.1") 210 | -------------------------------------------------------------------------------- /demo/share_btn.py: -------------------------------------------------------------------------------- 1 | # Modified from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/79681cd8cb235160a27cdd100673346eb1784e53/share_btn.py 2 | 3 | community_icon_html = """""" 7 | 8 | loading_icon_html = """""" 12 | 13 | share_js = """ 14 | async () => { 15 | const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default; 16 | async function uploadFile(file) { 17 | console.log(file.type) 18 | const UPLOAD_URL = 'https://huggingface.co/uploads'; 19 | const response = await fetch(UPLOAD_URL, { 20 | method: 'POST', 21 | headers: { 22 | 'Content-Type': file.type, 23 | 'X-Requested-With': 'XMLHttpRequest', 24 | }, 25 | body: file, /// <- File inherits from Blob 26 | }); 27 | const url = await response.text(); 28 | return url; 29 | } 30 | async function getImageFile(div) { 31 | return new Promise((resolve, reject) => 32 | html2canvas(div) 33 | .then((canvas) => { 34 | const imageBlob = canvas.toBlob((blob) => { 35 | const imageId = Date.now(); 36 | const fileName = "FROMAGe-" + imageId + ".jpg"; 37 | resolve(new File([blob], fileName, { type: 'image/jpeg' })); 38 | }, 'image/jpeg', 0.95); 39 | }) 40 | 41 | ) 42 | } 43 | 44 | const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app'); 45 | const chatbotEl = gradioEl.querySelector('#chatbot') 46 | const imageFile = await getImageFile(chatbotEl); 47 | console.log(imageFile); 48 | const urlChatbotImage = await uploadFile(imageFile); 49 | console.log(urlChatbotImage); 50 | let titleTxt = `FROMAGe Example`; 51 | 52 | //const shareBtnEl = gradioEl.querySelector('#share-btn'); 53 | //shareBtnEl.style.pointerEvents = 'none'; 54 | const descriptionMd = ` 55 | 56 | 57 | `; 58 | const params = new URLSearchParams({ 59 | title: titleTxt, 60 | description: descriptionMd, 61 | }); 62 | const paramsStr = params.toString(); 63 | window.open(`https://huggingface.co/spaces/jykoh/fromage/discussions/new?${paramsStr}`, '_blank'); 64 | //shareBtnEl.style.removeProperty('pointer-events'); 65 | } 66 | """ 67 | 68 | save_js = """ 69 | async () => { 70 | const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default; 71 | 72 | function saveAs(uri, filename) { 73 | var link = document.createElement('a'); 74 | if (typeof link.download === 'string') { 75 | link.href = uri; 76 | link.download = filename; 77 | 78 | //Firefox requires the link to be in the body 79 | document.body.appendChild(link); 80 | 81 | //simulate click 82 | link.click(); 83 | 84 | //remove the link when done 85 | document.body.removeChild(link); 86 | } else { 87 | window.open(uri); 88 | } 89 | } 90 | 91 | async function getImageFile(div) { 92 | return new Promise((resolve, reject) => 93 | html2canvas(div) 94 | .then((canvas) => { 95 | const 
imageId = Date.now(); 96 | const fileName = "FROMAGe-" + imageId + ".png"; 97 | saveAs(canvas.toDataURL(), fileName); 98 | }) 99 | 100 | ) 101 | } 102 | const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app'); 103 | const chatbotEl = gradioEl.querySelector('#chatbot') 104 | const imageFile = await getImageFile(chatbotEl); 105 | console.log(imageFile); 106 | } 107 | """ -------------------------------------------------------------------------------- /evals/VIST_Contextual_Image_Retrieval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a66bc991", 6 | "metadata": {}, 7 | "source": [ 8 | "# FROMAGe Contextual Image Retrieval\n", 9 | "\n", 10 | "This is a notebook showcasing the contextual image retrieval results from our paper, [Grounding Language Models to Images for Multimodal Generation](https://arxiv.org/abs/2301.13823). This result is reported in Table 1. The results of this notebook may be slightly different from the paper, as the Flickr images from Visual Storytelling may disappear over time.\n", 11 | "\n", 12 | "At least 18GB of GPU memory is required to run FROMAGe, and it has only been tested on A6000, V100, and 3090 GPUs." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "475add8f", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import numpy as np\n", 23 | "import collections\n", 24 | "import copy\n", 25 | "import json\n", 26 | "import os\n", 27 | "import torch\n", 28 | "from transformers import logging\n", 29 | "from tqdm import notebook\n", 30 | "logging.set_verbosity_error()\n", 31 | "\n", 32 | "from PIL import Image\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "\n", 35 | "from fromage import models\n", 36 | "from fromage import utils" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "id": "7e884127", 42 | "metadata": {}, 43 | "source": [ 44 | "### Load Pretrained Model" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "id": "4646a124", 51 | "metadata": { 52 | "scrolled": true 53 | }, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "Using HuggingFace AutoFeatureExtractor for openai/clip-vit-large-patch14.\n", 60 | "Using facebook/opt-6.7b for the language model.\n", 61 | "Using openai/clip-vit-large-patch14 for the visual model with 1 visual tokens.\n" 62 | ] 63 | }, 64 | { 65 | "data": { 66 | "application/vnd.jupyter.widget-view+json": { 67 | "model_id": "6bb4489d7d8b4235abe3f965318af27d", 68 | "version_major": 2, 69 | "version_minor": 0 70 | }, 71 | "text/plain": [ 72 | "Loading checkpoint shards: 0%| | 0/2 [00:00... [RET]`, providing this as input to FROMAGe, and retrieve the image corresponding to the `[RET]` embedding." 
239 | ] 240 | }, 241 | { 242 | "cell_type": "code", 243 | "execution_count": 6, 244 | "id": "d20c3c02", 245 | "metadata": { 246 | "scrolled": true 247 | }, 248 | "outputs": [ 249 | { 250 | "data": { 251 | "application/vnd.jupyter.widget-view+json": { 252 | "model_id": "b9e320375086411eaf5e0ca0acc6f6fa", 253 | "version_major": 2, 254 | "version_minor": 0 255 | }, 256 | "text/plain": [ 257 | " 0%| | 0/4990 [00:00 0:\n", 221 | " contexts.append('A: ' + answers[prev_d['answer']])\n", 222 | " contexts.append('Q: ' + questions[current_d['question']] + '?')\n", 223 | " answer_options = [answers[i] for i in current_d['answer_options']]\n", 224 | " answer = answers[current_d['answer']]\n", 225 | " gt_index = current_d['gt_index']\n", 226 | " caption = '\\n'.join(contexts) + '\\nA: '\n", 227 | "\n", 228 | " # Run through every possible option, and pick the option with the lowest loss (= lowest perplexity)\n", 229 | " example_losses = []\n", 230 | " # Tokenize the dialogue sequence (as this is the same for all answer choices).\n", 231 | " caption_ids = model.model.tokenizer(\n", 232 | " caption, add_special_tokens=True, return_tensors=\"pt\").input_ids\n", 233 | " caption_ids = caption_ids.to(images.device)\n", 234 | " caption_embs = model.model.input_embeddings(caption_ids) # (N, T, D)\n", 235 | " condition_length = visual_embs.shape[1] + caption_embs.shape[1]\n", 236 | "\n", 237 | " all_example_embs = []\n", 238 | " all_example_labels = []\n", 239 | "\n", 240 | " for _, ans in enumerate(answer_options):\n", 241 | " ans_ids = model.model.tokenizer(ans, add_special_tokens=True, return_tensors=\"pt\").input_ids\n", 242 | " ans_ids = ans_ids.to(images.device)\n", 243 | " ans_embs = model.model.input_embeddings(ans_ids)\n", 244 | " input_embs = torch.cat([\n", 245 | " visual_embs,\n", 246 | " caption_embs,\n", 247 | " ans_embs], dim=1)\n", 248 | " labels = torch.cat([\n", 249 | " torch.zeros(visual_embs.shape[:-1], device=caption_ids.device, dtype=caption_ids.dtype) - 100,\n", 250 | " caption_ids,\n", 251 | " ans_ids], dim=1)\n", 252 | " assert labels.shape[1] == input_embs.shape[1]\n", 253 | "\n", 254 | " all_example_embs.append(input_embs)\n", 255 | " all_example_labels.append(labels)\n", 256 | "\n", 257 | " max_len = max([x.shape[1] for x in all_example_labels])\n", 258 | " padded_example_embs = [torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[1])) for x in all_example_embs]\n", 259 | " padded_example_embs = torch.cat(padded_example_embs, axis=0)\n", 260 | "\n", 261 | " padded_example_labels = [torch.nn.functional.pad(x, (0, max_len - x.shape[1]), value=-100) for x in all_example_labels]\n", 262 | " padded_example_labels = torch.cat(padded_example_labels, axis=0)\n", 263 | "\n", 264 | " all_logits = []\n", 265 | " batches = int(np.ceil(padded_example_embs.shape[0] / batch_size))\n", 266 | " for i in range(batches):\n", 267 | " start_idx = i * batch_size\n", 268 | " end_idx = start_idx + batch_size\n", 269 | " out = model.model.lm(\n", 270 | " inputs_embeds=padded_example_embs[start_idx:end_idx, ...],\n", 271 | " labels=None,\n", 272 | " use_cache=False,\n", 273 | " output_hidden_states=True)\n", 274 | " all_logits.append(out.logits)\n", 275 | "\n", 276 | " logits = torch.cat(all_logits, dim=0)\n", 277 | " example_losses = ce_loss(logits.reshape((-1, logits.shape[-1])), padded_example_labels.reshape((-1,)))\n", 278 | " example_losses = example_losses.reshape((100, max_len))[:, condition_length:]\n", 279 | " example_losses = example_losses.sum(axis=1)\n", 280 | "\n", 281 | " 
all_losses.append(example_losses.cpu().float().numpy())\n", 282 | " scores = -example_losses\n", 283 | " _, preds = scores.topk(max(topk))\n", 284 | " all_preds.append(preds)\n", 285 | " all_gt_results.append(gt_index)\n", 286 | "\n", 287 | " with open(save_path, 'wb') as wf:\n", 288 | " np.save(wf, {'all_preds': all_preds, 'all_gt_results': all_gt_results, 'all_losses': all_losses})" 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "id": "aef2a81a", 294 | "metadata": {}, 295 | "source": [ 296 | "### Computing Results\n", 297 | "\n", 298 | "Finally, we can compute NDCG, MRR, and Recall@k:" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 6, 304 | "id": "8c0b673c", 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "# Define some classes to help us compute NDCG and MRR.\n", 309 | "# Modified from https://github.com/batra-mlp-lab/visdial-challenge-starter-pytorch/blob/master/visdialch/metrics.py\n", 310 | "\n", 311 | "class NDCG(object):\n", 312 | " def __init__(self):\n", 313 | " self._ndcg_numerator = 0.0\n", 314 | " self._ndcg_denominator = 0.0\n", 315 | "\n", 316 | " def observe(\n", 317 | " self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor\n", 318 | " ):\n", 319 | " \"\"\"\n", 320 | " Observe model output scores and target ground truth relevance and\n", 321 | " accumulate NDCG metric.\n", 322 | " Parameters\n", 323 | " ----------\n", 324 | " predicted_scores: torch.Tensor\n", 325 | " A tensor of shape (batch_size, num_options), because dense\n", 326 | " annotations are available for 1 randomly picked round out of 10.\n", 327 | " target_relevance: torch.Tensor\n", 328 | " A tensor of shape same as predicted scores, indicating ground truth\n", 329 | " relevance of each answer option for a particular round.\n", 330 | " \"\"\"\n", 331 | " predicted_scores = predicted_scores.detach()\n", 332 | "\n", 333 | " # shape: (batch_size, 1, num_options)\n", 334 | " predicted_scores = predicted_scores.unsqueeze(1)\n", 335 | " predicted_ranks = scores_to_ranks(predicted_scores)\n", 336 | "\n", 337 | " # shape: (batch_size, num_options)\n", 338 | " predicted_ranks = predicted_ranks.squeeze(1)\n", 339 | " batch_size, num_options = predicted_ranks.size()\n", 340 | "\n", 341 | " k = torch.sum(target_relevance != 0, dim=-1)\n", 342 | "\n", 343 | " # shape: (batch_size, num_options)\n", 344 | " _, rankings = torch.sort(predicted_ranks, dim=-1)\n", 345 | " # Sort relevance in descending order so highest relevance gets top rnk.\n", 346 | " _, best_rankings = torch.sort(\n", 347 | " target_relevance, dim=-1, descending=True\n", 348 | " )\n", 349 | "\n", 350 | " # shape: (batch_size, )\n", 351 | " batch_ndcg = []\n", 352 | " for batch_index in range(batch_size):\n", 353 | " num_relevant = k[batch_index]\n", 354 | " dcg = self._dcg(\n", 355 | " rankings[batch_index][:num_relevant],\n", 356 | " target_relevance[batch_index],\n", 357 | " )\n", 358 | " best_dcg = self._dcg(\n", 359 | " best_rankings[batch_index][:num_relevant],\n", 360 | " target_relevance[batch_index],\n", 361 | " )\n", 362 | " batch_ndcg.append(dcg / best_dcg)\n", 363 | "\n", 364 | " self._ndcg_denominator += batch_size\n", 365 | " self._ndcg_numerator += sum(batch_ndcg)\n", 366 | "\n", 367 | " def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor):\n", 368 | " sorted_relevance = relevance[rankings].cpu().float()\n", 369 | " discounts = torch.log2(torch.arange(len(rankings)).float() + 2)\n", 370 | " return torch.sum(sorted_relevance / discounts, dim=-1)\n", 371 | 
"\n", 372 | " def retrieve(self, reset: bool = True, key=\"\"):\n", 373 | " if self._ndcg_denominator > 0:\n", 374 | " metrics = {\n", 375 | " key + \"ndcg\": float(self._ndcg_numerator / self._ndcg_denominator)\n", 376 | " }\n", 377 | " else:\n", 378 | " metrics = {}\n", 379 | "\n", 380 | " if reset:\n", 381 | " self.reset()\n", 382 | " return metrics\n", 383 | "\n", 384 | " def reset(self):\n", 385 | " self._ndcg_numerator = 0.0\n", 386 | " self._ndcg_denominator = 0.0\n", 387 | " \n", 388 | "\n", 389 | "def scores_to_ranks(scores: torch.Tensor):\n", 390 | " \"\"\"Convert model output scores into ranks.\"\"\"\n", 391 | " batch_size, num_rounds, num_options = scores.size()\n", 392 | " scores = scores.view(-1, num_options)\n", 393 | "\n", 394 | " # sort in descending order - largest score gets highest rank\n", 395 | " sorted_ranks, ranked_idx = scores.sort(1, descending=True)\n", 396 | "\n", 397 | " # i-th position in ranked_idx specifies which score shall take this\n", 398 | " # position but we want i-th position to have rank of score at that\n", 399 | " # position, do this conversion\n", 400 | " ranks = ranked_idx.clone().fill_(0)\n", 401 | " for i in range(ranked_idx.size(0)):\n", 402 | " for j in range(num_options):\n", 403 | " ranks[i][ranked_idx[i][j]] = j\n", 404 | " # convert from 0-99 ranks to 1-100 ranks\n", 405 | " ranks += 1\n", 406 | " ranks = ranks.view(batch_size, num_rounds, num_options)\n", 407 | " return ranks" 408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 7, 413 | "id": "15e4c2aa", 414 | "metadata": {}, 415 | "outputs": [], 416 | "source": [ 417 | "with open(save_path, 'rb') as rf:\n", 418 | " all_data = np.load(rf, allow_pickle=True).item()\n", 419 | " all_preds = all_data['all_preds']\n", 420 | " all_gt_results = all_data['all_gt_results']\n", 421 | " all_losses = all_data['all_losses']" 422 | ] 423 | }, 424 | { 425 | "cell_type": "code", 426 | "execution_count": 8, 427 | "id": "7686317d", 428 | "metadata": {}, 429 | "outputs": [ 430 | { 431 | "name": "stdout", 432 | "output_type": "stream", 433 | "text": [ 434 | "top-k, k=1, acc=0.17573\n", 435 | "top-k, k=5, acc=0.19971\n", 436 | "top-k, k=10, acc=0.24414\n", 437 | "top-k, k=20, acc=0.48309\n", 438 | "MRR: 0.21997\n", 439 | "NDCG: 0.16594\n" 440 | ] 441 | } 442 | ], 443 | "source": [ 444 | "top_k_accuracy = collections.defaultdict(list)\n", 445 | "mrr_results = []\n", 446 | "all_ranks = []\n", 447 | "topk = (1, 5, 10, 20)\n", 448 | "ndcg = NDCG()\n", 449 | "\n", 450 | "assert len(all_preds) == len(all_gt_results)\n", 451 | "for gt, loss in zip(all_gt_results, all_losses):\n", 452 | " scores = -loss\n", 453 | " _, preds = torch.tensor(scores).topk(100)\n", 454 | " rank = np.where(preds == gt)[0][0] + 1\n", 455 | " all_ranks.append(rank)\n", 456 | " mrr_results.append(1 / rank)\n", 457 | "\n", 458 | " for k in topk:\n", 459 | " acc = gt in preds[:k]\n", 460 | " top_k_accuracy[k].append(acc)\n", 461 | " \n", 462 | "dense_mrr = []\n", 463 | "for i in range(len(dense_data)):\n", 464 | " idx = i * 10 + dense_data[i]['round_id']\n", 465 | " if idx >= len(all_losses):\n", 466 | " break\n", 467 | " scores = -torch.tensor(all_losses[idx])[None, :]\n", 468 | " relevance = torch.tensor(dense_data[i]['gt_relevance'])[None, :]\n", 469 | " ndcg.observe(scores, relevance)\n", 470 | " dense_mrr.append(mrr_results[idx])\n", 471 | "\n", 472 | "for k in topk:\n", 473 | " print(f'top-k, k={k}, acc={np.mean(top_k_accuracy[k]):.5f}')\n", 474 | "print(f'MRR: {np.mean(mrr_results):.5f}')\n", 475 | 
"print(f'NDCG: {ndcg.retrieve(reset=True)[\"ndcg\"]:.5f}')" 476 | ] 477 | }, 478 | { 479 | "cell_type": "code", 480 | "execution_count": null, 481 | "id": "1982d6fe", 482 | "metadata": {}, 483 | "outputs": [], 484 | "source": [] 485 | } 486 | ], 487 | "metadata": { 488 | "kernelspec": { 489 | "display_name": "Python 3 (ipykernel)", 490 | "language": "python", 491 | "name": "python3" 492 | }, 493 | "language_info": { 494 | "codemirror_mode": { 495 | "name": "ipython", 496 | "version": 3 497 | }, 498 | "file_extension": ".py", 499 | "mimetype": "text/x-python", 500 | "name": "python", 501 | "nbconvert_exporter": "python", 502 | "pygments_lexer": "ipython3", 503 | "version": "3.10.4" 504 | } 505 | }, 506 | "nbformat": 4, 507 | "nbformat_minor": 5 508 | } 509 | -------------------------------------------------------------------------------- /evals/VisDial_Inference_T2I_Retrieval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a66bc991", 6 | "metadata": {}, 7 | "source": [ 8 | "# FROMAGe Visual Dialog (Image Retrieval)\n", 9 | "\n", 10 | "This is a notebook reproducing the VisDial text-to-image (T2I) retrieval results from our paper, [Grounding Language Models to Images for Multimodal Inputs and Outputs](https://arxiv.org/abs/2301.13823). This result is reported in Table 2 of the paper. This measures the recall of the model in selecting the appropriate image conditioned on a dialogue sequence.\n", 11 | "\n", 12 | "At least 18GB of GPU memory is required to run FROMAGe, and it has only been tested on A6000, V100, and 3090 GPUs." 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "id": "475add8f", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "import numpy as np\n", 23 | "import collections\n", 24 | "import copy\n", 25 | "import json\n", 26 | "import os\n", 27 | "import torch\n", 28 | "from transformers import logging\n", 29 | "from tqdm import notebook\n", 30 | "logging.set_verbosity_error()\n", 31 | "\n", 32 | "from PIL import Image\n", 33 | "import matplotlib.pyplot as plt\n", 34 | "\n", 35 | "from fromage import models\n", 36 | "from fromage import utils" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "id": "7e884127", 42 | "metadata": {}, 43 | "source": [ 44 | "### Load Pretrained FROMAGe Model" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 2, 50 | "id": "4646a124", 51 | "metadata": { 52 | "scrolled": true 53 | }, 54 | "outputs": [ 55 | { 56 | "name": "stdout", 57 | "output_type": "stream", 58 | "text": [ 59 | "Using HuggingFace AutoFeatureExtractor for openai/clip-vit-large-patch14.\n", 60 | "Using facebook/opt-6.7b for the language model.\n", 61 | "Using openai/clip-vit-large-patch14 for the visual model with 1 visual tokens.\n" 62 | ] 63 | }, 64 | { 65 | "data": { 66 | "application/vnd.jupyter.widget-view+json": { 67 | "model_id": "b2aa1cfecea64ed6a4d2c3cbb946d463", 68 | "version_major": 2, 69 | "version_minor": 0 70 | }, 71 | "text/plain": [ 72 | "Loading checkpoint shards: 0%| | 0/2 [00:00 0: 128 | metrics = { 129 | key + "ndcg": float(self._ndcg_numerator / self._ndcg_denominator) 130 | } 131 | else: 132 | metrics = {} 133 | 134 | if reset: 135 | self.reset() 136 | return metrics 137 | 138 | def reset(self): 139 | self._ndcg_numerator = 0.0 140 | self._ndcg_denominator = 0.0 141 | 142 | 143 | def scores_to_ranks(scores: torch.Tensor): 144 | """Convert model output scores into ranks.""" 145 | 
batch_size, num_rounds, num_options = scores.size() 146 | scores = scores.view(-1, num_options) 147 | 148 | # sort in descending order - largest score gets highest rank 149 | sorted_ranks, ranked_idx = scores.sort(1, descending=True) 150 | 151 | # i-th position in ranked_idx specifies which score shall take this 152 | # position but we want i-th position to have rank of score at that 153 | # position, do this conversion 154 | ranks = ranked_idx.clone().fill_(0) 155 | for i in range(ranked_idx.size(0)): 156 | for j in range(num_options): 157 | ranks[i][ranked_idx[i][j]] = j 158 | # convert from 0-99 ranks to 1-100 ranks 159 | ranks += 1 160 | ranks = ranks.view(batch_size, num_rounds, num_options) 161 | return ranks 162 | 163 | 164 | if __name__ == "__main__": 165 | # Load model used in the paper. 166 | model_dir = './fromage_model/' 167 | model = models.load_fromage(model_dir) 168 | 169 | # Load VisDial data. 170 | img_dir = os.path.join(base_dir, f'VisualDialog_{split}2018') 171 | 172 | with open(os.path.join(base_dir, f'visdial_1.0_{split}.json'), 'r') as f: 173 | visdial_data = json.load(f) 174 | 175 | with open(os.path.join(base_dir, f'visdial_1.0_{split}_dense_annotations.json'), 'r') as f: 176 | dense_data = json.load(f) 177 | 178 | # Check that dense and sparse data are aligned. 179 | assert len(dense_data) == len(visdial_data['data']['dialogs']) 180 | for i in range(len(dense_data)): 181 | assert dense_data[i]['image_id'] == visdial_data['data']['dialogs'][i]['image_id'] 182 | 183 | questions = visdial_data['data']['questions'] 184 | answers = visdial_data['data']['answers'] 185 | dialogs = visdial_data['data']['dialogs'] 186 | 187 | # Then, for each VisDial example, we compute the loss 188 | # conditioned on the image and the preceding dialogue. 189 | # We return the option with the lowest loss as the answer: 190 | ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none').cuda() 191 | 192 | if os.path.exists(save_path): 193 | with open(save_path, 'rb') as rf: 194 | all_data = np.load(rf, allow_pickle=True).item() 195 | all_preds = all_data['all_preds'] 196 | all_gt_results = all_data['all_gt_results'] 197 | all_losses = all_data['all_losses'] 198 | assert len(all_preds) == len(all_gt_results) == len(all_losses) 199 | else: 200 | # No in progress data, initialize from scratch. 
201 | all_preds = [] 202 | all_gt_results = [] 203 | all_losses = [] 204 | 205 | for example_idx in notebook.tqdm(range(len(all_preds) // 10, len(dialogs))): 206 | dialog = dialogs[example_idx] 207 | image_id = str(dialog['image_id']).rjust(12, '0') 208 | contexts = [] 209 | 210 | with torch.no_grad(): 211 | images = get_pixel_values_from_path( 212 | os.path.join(img_dir, f'VisualDialog_{split}2018_{image_id}.jpg'), 213 | model.model.feature_extractor) 214 | visual_embs = model.model.get_visual_embs(images, mode='captioning') 215 | 216 | for i in range(len(dialog['dialog'])): 217 | prev_d = dialog['dialog'][i-1] 218 | current_d = dialog['dialog'][i] 219 | if i > 0: 220 | contexts.append('A: ' + answers[prev_d['answer']]) 221 | contexts.append('Q: ' + questions[current_d['question']] + '?') 222 | answer_options = [answers[i] for i in current_d['answer_options']] 223 | answer = answers[current_d['answer']] 224 | gt_index = current_d['gt_index'] 225 | caption = '\n'.join(contexts) + '\nA: ' 226 | 227 | # Run through every possible option, and pick the option with 228 | # the lowest loss (= lowest perplexity) 229 | example_losses = [] 230 | # Tokenize the dialogue sequence (as this is the same for all answer choices). 231 | caption_ids = model.model.tokenizer( 232 | caption, add_special_tokens=True, return_tensors="pt").input_ids 233 | caption_ids = caption_ids.to(images.device) 234 | caption_embs = model.model.input_embeddings(caption_ids) # (N, T, D) 235 | condition_length = visual_embs.shape[1] + caption_embs.shape[1] 236 | 237 | all_example_embs = [] 238 | all_example_labels = [] 239 | 240 | for _, ans in enumerate(answer_options): 241 | ans_ids = model.model.tokenizer(ans, add_special_tokens=True, return_tensors="pt").input_ids 242 | ans_ids = ans_ids.to(images.device) 243 | ans_embs = model.model.input_embeddings(ans_ids) 244 | input_embs = torch.cat([ 245 | visual_embs, 246 | caption_embs, 247 | ans_embs], dim=1) 248 | labels = torch.cat([ 249 | torch.zeros(visual_embs.shape[:-1], device=caption_ids.device, dtype=caption_ids.dtype) - 100, 250 | caption_ids, 251 | ans_ids], dim=1) 252 | assert labels.shape[1] == input_embs.shape[1] 253 | 254 | all_example_embs.append(input_embs) 255 | all_example_labels.append(labels) 256 | 257 | max_len = max([x.shape[1] for x in all_example_labels]) 258 | padded_example_embs = [torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[1])) for x in all_example_embs] 259 | padded_example_embs = torch.cat(padded_example_embs, axis=0) 260 | 261 | padded_example_labels = [torch.nn.functional.pad(x, (0, max_len - x.shape[1]), value=-100) for x in all_example_labels] 262 | padded_example_labels = torch.cat(padded_example_labels, axis=0) 263 | 264 | all_logits = [] 265 | batches = int(np.ceil(padded_example_embs.shape[0] / batch_size)) 266 | for i in range(batches): 267 | start_idx = i * batch_size 268 | end_idx = start_idx + batch_size 269 | out = model.model.lm( 270 | inputs_embeds=padded_example_embs[start_idx:end_idx, ...], 271 | labels=None, 272 | use_cache=False, 273 | output_hidden_states=True) 274 | all_logits.append(out.logits) 275 | 276 | logits = torch.cat(all_logits, dim=0) 277 | example_losses = ce_loss(logits.reshape((-1, logits.shape[-1])), padded_example_labels.reshape((-1,))) 278 | example_losses = example_losses.reshape((100, max_len))[:, condition_length:] 279 | example_losses = example_losses.sum(axis=1) 280 | 281 | all_losses.append(example_losses.cpu().float().numpy()) 282 | scores = -example_losses 283 | _, preds = scores.topk(max(topk)) 
284 | all_preds.append(preds) 285 | all_gt_results.append(gt_index) 286 | 287 | with open(save_path, 'wb') as wf: 288 | np.save(wf, {'all_preds': all_preds, 'all_gt_results': all_gt_results, 'all_losses': all_losses}) 289 | 290 | # Finally, we can compute NDCG, MRR, and Recall@k: 291 | with open(save_path, 'rb') as rf: 292 | all_data = np.load(rf, allow_pickle=True).item() 293 | all_preds = all_data['all_preds'] 294 | all_gt_results = all_data['all_gt_results'] 295 | all_losses = all_data['all_losses'] 296 | 297 | top_k_accuracy = collections.defaultdict(list) 298 | mrr_results = [] 299 | all_ranks = [] 300 | topk = (1, 5, 10, 20) 301 | ndcg = NDCG() 302 | 303 | assert len(all_preds) == len(all_gt_results) 304 | for gt, loss in zip(all_gt_results, all_losses): 305 | scores = -loss 306 | _, preds = torch.tensor(scores).topk(100) 307 | rank = np.where(preds == gt)[0][0] + 1 308 | all_ranks.append(rank) 309 | mrr_results.append(1 / rank) 310 | 311 | for k in topk: 312 | acc = gt in preds[:k] 313 | top_k_accuracy[k].append(acc) 314 | 315 | dense_mrr = [] 316 | for i in range(len(dense_data)): 317 | idx = i * 10 + dense_data[i]['round_id'] 318 | if idx >= len(all_losses): 319 | break 320 | scores = -torch.tensor(all_losses[idx])[None, :] 321 | relevance = torch.tensor(dense_data[i]['gt_relevance'])[None, :] 322 | ndcg.observe(scores, relevance) 323 | dense_mrr.append(mrr_results[idx]) 324 | 325 | for k in topk: 326 | print(f'top-k, k={k}, acc={np.mean(top_k_accuracy[k]):.5f}') 327 | print(f'MRR: {np.mean(mrr_results):.5f}') 328 | print(f'NDCG: {ndcg.retrieve(reset=True)["ndcg"]:.5f}') 329 | -------------------------------------------------------------------------------- /evals/eval_visdial_retrieval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | """ 5 | This is a script reproducing the VisDial text-to-image (T2I) retrieval results from our paper, 6 | Grounding Language Models to Images for Multimodal Inputs and Outputs (https://arxiv.org/abs/2301.13823). 7 | This result is reported in Table 2 of the paper. This measures the recall of the model in selecting 8 | the appropriate image conditioned on a dialogue sequence. 9 | 10 | Example usage: `python eval_visdial_retrieval.py` 11 | """ 12 | 13 | import numpy as np 14 | import collections 15 | import copy 16 | import json 17 | import os 18 | import torch 19 | from transformers import logging 20 | from tqdm import notebook 21 | logging.set_verbosity_error() 22 | 23 | from PIL import Image 24 | import matplotlib.pyplot as plt 25 | 26 | from fromage import models 27 | from fromage import utils 28 | 29 | 30 | def get_pixel_values_from_path(path: str, feature_extractor): 31 | """Helper function for getting images pixels from a local path.""" 32 | img = Image.open(path) 33 | img = img.resize((224, 224)) 34 | img = img.convert('RGB') 35 | pixel_values = utils.get_pixel_values_for_model(feature_extractor, img) 36 | if torch.cuda.is_available(): 37 | pixel_values = pixel_values.bfloat16() 38 | pixel_values = pixel_values.cuda() 39 | return pixel_values[None, ...] 40 | 41 | 42 | if __name__ == "__main__": 43 | # Load model used in the paper. 
44 | model_dir = './' #'./fromage_model/' 45 | model = models.load_fromage(model_dir) 46 | 47 | 48 | # Download the VisDial validation annotations (https://www.dropbox.com/s/ibs3a0zhw74zisc/visdial_1.0_val.zip?dl=0), 49 | # the dense answer annotations (https://www.dropbox.com/s/3knyk09ko4xekmc/visdial_1.0_val_dense_annotations.json?dl=0) 50 | # (for computing MRR) and the images (https://www.dropbox.com/s/twmtutniktom7tu/VisualDialog_val2018.zip?dl=0). 51 | # Extract everything to the `VisualDialog` folder. 52 | 53 | base_dir = '/projects/tir6/general/jingyuk/VisualDialog' 54 | split = 'val' 55 | img_dir = os.path.join(base_dir, f'VisualDialog_{split}2018') 56 | 57 | with open(os.path.join(base_dir, f'visdial_1.0_{split}.json'), 'r') as f: 58 | visdial_data = json.load(f) 59 | 60 | with open(os.path.join(base_dir, f'visdial_1.0_{split}_dense_annotations.json'), 'r') as f: 61 | dense_data = json.load(f) 62 | 63 | # Check that dense and sparse data are aligned. 64 | assert len(dense_data) == len(visdial_data['data']['dialogs']) 65 | for i in range(len(dense_data)): 66 | assert dense_data[i]['image_id'] == visdial_data['data']['dialogs'][i]['image_id'] 67 | 68 | questions = visdial_data['data']['questions'] 69 | answers = visdial_data['data']['answers'] 70 | dialogs = visdial_data['data']['dialogs'] 71 | 72 | # Then, we compute the image features and text features for each VisDial example: 73 | topk = (1, 5, 10) 74 | ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none').cuda() 75 | 76 | all_visual_embs = [] 77 | all_text_embs = [] 78 | 79 | for example_idx in notebook.tqdm(range(len(dialogs))): 80 | dialog = dialogs[example_idx] 81 | image_id = str(dialog['image_id']).rjust(12, '0') 82 | contexts = [] 83 | 84 | with torch.no_grad(): 85 | images = get_pixel_values_from_path( 86 | os.path.join(img_dir, f'VisualDialog_{split}2018_{image_id}.jpg'), 87 | model.model.feature_extractor) 88 | visual_embs = model.model.get_visual_embs(images, mode='retrieval') 89 | 90 | for i in range(len(dialog['dialog'])): 91 | contexts.append('Q: ' + questions[dialog['dialog'][i]['question']] + '?') 92 | contexts.append('A: ' + answers[dialog['dialog'][i]['answer']] + '.') 93 | 94 | full_context_sent = ' '.join(contexts) + '[RET]' 95 | input_ids = model.model.tokenizer(full_context_sent, add_special_tokens=True, return_tensors="pt").input_ids 96 | input_ids = input_ids.cuda() 97 | input_embs = model.model.input_embeddings(input_ids) # (N, T, D) 98 | generated_ids, output_embs, _ = model(input_embs, None, None, generate=True, num_words=1, temperature=0.0) 99 | embeddings = output_embs[0] 100 | 101 | full_input_ids = torch.cat([input_ids, generated_ids], dim=1) 102 | ret_emb = embeddings[:, -1, :] 103 | 104 | all_visual_embs.append(visual_embs.cpu().detach().float().numpy()) 105 | all_text_embs.append(ret_emb.cpu().detach().float().numpy()) 106 | 107 | # Compute scores over the whole dataset: 108 | scores = np.concatenate(all_visual_embs, axis=0)[:, 0, :] @ np.concatenate(all_text_embs, axis=0).T 109 | scores = torch.tensor(scores).float() 110 | assert scores.shape == (2064, 2064), scores.shape 111 | 112 | 113 | # Finally, we can compute the Recall@k scores: 114 | _, preds = scores.topk(max(topk)) 115 | for k in topk: 116 | labels = torch.arange(preds.shape[0]) 117 | correct = torch.any(preds[:, :k] == labels[:, None], axis=1).sum() 118 | acc = correct / preds.shape[0] 119 | print(f'top-k, k={k}, acc={acc:.5f}') 120 | print('=' * 20) 121 | 122 | 123 | 124 | 
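# For reference, the Recall@k computation above can be reproduced in isolation.
# The helper below is an illustrative sketch (not part of the original evaluation):
# it assumes a square score matrix in which row i is the text query whose
# ground-truth image is column i, and uses random scores as a stand-in.
def _recall_at_k_sketch(num_examples: int = 8, ks=(1, 5)):
    import torch
    scores = torch.randn(num_examples, num_examples)  # (num_texts, num_images)
    _, preds = scores.topk(num_examples)               # rank all images per text query
    labels = torch.arange(num_examples)                # ground truth is the diagonal
    for k in ks:
        correct = torch.any(preds[:, :k] == labels[:, None], dim=1).sum()
        print(f'top-k, k={k}, acc={correct / num_examples:.5f}')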
-------------------------------------------------------------------------------- /evals/eval_vist_retrieval.py: -------------------------------------------------------------------------------- 1 | # This is a script reproducing the contextual image retrieval results from our paper. 2 | # This result is reported in Table 1. 3 | # The results of this may be slightly different from the paper, as the Flickr images from Visual Storytelling may disappear over time. 4 | 5 | import numpy as np 6 | import collections 7 | import copy 8 | import json 9 | import os 10 | import torch 11 | from transformers import logging 12 | from tqdm import tqdm 13 | logging.set_verbosity_error() 14 | 15 | from PIL import Image 16 | import matplotlib.pyplot as plt 17 | 18 | from fromage import models 19 | from fromage import utils 20 | 21 | 22 | # Load model used in the paper. 23 | model_dir = './fromage_model/' 24 | model = models.load_fromage(model_dir) 25 | 26 | # Download the Visual Storytelling SIS dataset from https://visionandlanguage.net/VIST/json_files/story-in-sequence/SIS-with-labels.tar.gz 27 | # Extract the files (there should be three sets: train, val, and test). 28 | # We use the val set for reporting results. 29 | vist_val_json_path = 'sis/val.story-in-sequence.json' 30 | with open(vist_val_json_path, 'r') as f: 31 | vist_data_raw = json.load(f) 32 | 33 | # Format into a dictionary of {story_id: data} items. 34 | vist_data = { 35 | 'annotations': collections.defaultdict(list) 36 | } 37 | used_image_ids = [] 38 | 39 | 40 | for ann in vist_data_raw['annotations']: 41 | assert len(ann) == 1 42 | ann = ann[0] 43 | story_id = ann['story_id'] 44 | vist_data['annotations'][story_id].append({ 45 | 'caption': ann['text'], 46 | 'image_id': ann['photo_flickr_id'], 47 | 'sequence_index': ann['worker_arranged_photo_order'], 48 | }) 49 | used_image_ids.append(ann['photo_flickr_id']) 50 | 51 | used_image_ids = set(used_image_ids) 52 | print(len(used_image_ids)) 53 | 54 | 55 | # Precompute image features for running retrieval. 56 | embs_fn = 'sis_img_features.npy' 57 | id2url = {} 58 | 59 | for image_data in vist_data_raw['images']: 60 | image_id = image_data['id'] 61 | if image_id in used_image_ids: 62 | image_url = image_data.get('url_o', None) 63 | if image_url is not None: 64 | id2url[image_id] = image_url 65 | 66 | if not os.path.exists(embs_fn): 67 | print(f'{embs_fn} does not exist, computing it from scratch.') 68 | all_visual_embs = [] 69 | all_image_ids = [] 70 | 71 | for image_id, image_url in tqdm(id2url.items()): 72 | try: 73 | images = utils.get_image_from_url(image_url) 74 | pixel_values = utils.get_pixel_values_for_model(model.model.feature_extractor, images) 75 | pixel_values = pixel_values.to(device=model.model.logit_scale.device, dtype=model.model.logit_scale.dtype) 76 | pixel_values = pixel_values[None, ...] 77 | visual_embs = model.model.get_visual_embs(pixel_values, mode='retrieval') 78 | all_visual_embs.append(visual_embs.float().cpu().detach().numpy()) 79 | all_image_ids.append(image_id) 80 | except Image.UnidentifiedImageError: 81 | pass 82 | 83 | all_image_ids = np.array(all_image_ids) 84 | all_visual_embs = np.concatenate(all_visual_embs, axis=0) 85 | assert all_image_ids.shape[0] == all_visual_embs.shape[0], (all_image_ids.shape, all_visual_embs.shape) 86 | print(all_image_ids.shape, all_visual_embs.shape) 87 | 88 | with open(embs_fn, 'wb') as wf: 89 | np.save(wf, {'image_ids': all_image_ids, 'embeddings': all_visual_embs}) 90 | 91 | # Load embeddings. 
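# The cached file stores a dict of the form {'image_ids': np.ndarray, 'embeddings': np.ndarray}, written by the block above via np.save.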
92 | with open(embs_fn, 'rb') as wf: 93 | embs_data = np.load(wf, allow_pickle=True).item() 94 | all_image_ids = embs_data['image_ids'] 95 | emb_matrix = embs_data['embeddings'] 96 | 97 | # Normalize embedding matrix to be suitable for image retrieval. 98 | logit_scale = model.model.logit_scale.exp() 99 | emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device) 100 | emb_matrix = emb_matrix / emb_matrix.norm(dim=-1, keepdim=True) 101 | emb_matrix = logit_scale * emb_matrix 102 | print('emb_matrix.shape', emb_matrix.shape) 103 | 104 | 105 | # Then, for each VIST example, we process it as `... [RET]`, 106 | # providing this as input to FROMAGe, and retrieve the image corresponding to the `[RET]` embedding. 107 | 108 | topk = (1, 5, 10) 109 | top_k_preds = {} 110 | 111 | with torch.no_grad(): 112 | for story_idx, (story_id, story_data) in tqdm(enumerate(vist_data['annotations'].items()), total=len(vist_data['annotations'])): 113 | gt_image_id = story_data[-1]['image_id'] 114 | skip = False # Skip examples that do not have images (due to URLs being taken down, or something) 115 | for s in story_data: 116 | if s['image_id'] not in all_image_ids or s['image_id'] not in id2url: 117 | skip = True 118 | break 119 | 120 | if not skip: 121 | # Use the first n-1 images and n captions as input. 122 | image_urls = [id2url[s['image_id']] for s in story_data[:-1]] 123 | captions = [s['caption'] for s in story_data] 124 | assert len(image_urls) == len(captions) - 1 125 | 126 | visual_embs = [] 127 | # Compute embeddings for the input images. 128 | images = [utils.get_image_from_url(image_url) for image_url in image_urls] 129 | pixel_values = [utils.get_pixel_values_for_model(model.model.feature_extractor, image) for image in images] 130 | pixel_values = torch.stack(pixel_values, dim=0) # (n-1, 3, 224, 224) 131 | pixel_values = pixel_values.to(device=model.model.logit_scale.device, dtype=model.model.logit_scale.dtype) 132 | visual_embs = model.model.get_visual_embs(pixel_values, mode='captioning') 133 | 134 | # Compute embeddings for the input captions. 135 | all_input_ids = [] 136 | for i, c in enumerate(captions): 137 | if i == len(captions) - 1: 138 | c += '[RET]' # Add the [RET] token to the final caption. 139 | input_ids = model.model.tokenizer(c, add_special_tokens=True, return_tensors="pt").input_ids.to(emb_matrix.device) 140 | all_input_ids.append(input_ids) 141 | 142 | input_embs = [model.model.input_embeddings(s)[0, ...] for s in all_input_ids] # (N, T, D) 143 | 144 | # Interleave captions and images as [caption1, image1, caption2, ..., image4, caption5]. 145 | final_input_embs = [] 146 | assert len(visual_embs) == len(input_embs) - 1 147 | for i in range(len(images)): 148 | final_input_embs.append(input_embs[i]) 149 | final_input_embs.append(visual_embs[i]) 150 | final_input_embs.append(input_embs[len(images)]) 151 | final_input_embs = torch.cat(final_input_embs, dim=0)[None, ...] # (1, T, 4096) 152 | 153 | # Get embedding of the [RET] token, and compute scores: 154 | output = model.model.lm(inputs_embeds=final_input_embs, labels=None, use_cache=False, output_hidden_states=True) 155 | last_hidden_state = model.model.text_hidden_fcs[0](output.hidden_states[-1]) 156 | ret_emb = last_hidden_state[:, -1, :] 157 | 158 | ret_emb = ret_emb / ret_emb.norm(dim=1, keepdim=True) 159 | scores = ret_emb.squeeze() @ emb_matrix.squeeze().T 160 | 161 | # Don't retrieve previously seen images. 
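# Subtracting a large constant from the scores of images that already appear earlier in the story effectively excludes them from the top-k retrieval below.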
162 | prev_image_ids = [s['image_id'] for s in story_data[:-1]] 163 | for prev_id in prev_image_ids: 164 | scores[np.where(all_image_ids == prev_id)[0]] -= 10000 165 | 166 | # Store top-k preds. 167 | _, preds = scores.topk(max(topk)) 168 | preds = preds.cpu().detach().numpy() 169 | preds = [all_image_ids[p] for p in preds] 170 | top_k_preds[story_id] = {'topk_preds': preds, 'gt': gt_image_id} 171 | 172 | 173 | # Finally, we can compute Recall@k: 174 | top_k_accuracy = collections.defaultdict(list) 175 | 176 | for story_id, results in top_k_preds.items(): 177 | for k in topk: 178 | acc = results['gt'] in results['topk_preds'][:k] 179 | top_k_accuracy[k].append(acc) 180 | 181 | for k in topk: 182 | result_str = f'k={k}, acc={np.mean(top_k_accuracy[k]):.5f}' 183 | print(result_str) 184 | -------------------------------------------------------------------------------- /fromage/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/fromage/__init__.py -------------------------------------------------------------------------------- /fromage/data.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/mlfoundations/open_clip""" 2 | 3 | from typing import Optional, Tuple 4 | 5 | import collections 6 | import logging 7 | import os 8 | import numpy as np 9 | import pandas as pd 10 | import torch 11 | import torchvision.datasets as datasets 12 | from torchvision import transforms as T 13 | from PIL import Image, ImageFont 14 | from torch.utils.data import Dataset 15 | 16 | from fromage import utils 17 | 18 | 19 | def collate_fn(batch): 20 | batch = list(filter(lambda x: x is not None, batch)) 21 | return torch.utils.data.dataloader.default_collate(batch) 22 | 23 | 24 | def get_dataset(args, split: str, tokenizer, precision: str = 'fp32') -> Dataset: 25 | assert split in ['train', 'val' 26 | ], 'Expected split to be one of "train" or "val", got {split} instead.' 27 | 28 | dataset_paths = [] 29 | image_data_dirs = [] 30 | train = split == 'train' 31 | 32 | # Default configs for datasets. 
33 | # Folder structure should look like: 34 | if split == 'train': 35 | if 'cc3m' in args.dataset: 36 | dataset_paths.append(os.path.join(args.dataset_dir, 'cc3m_train.tsv')) 37 | image_data_dirs.append(os.path.join(args.image_dir, 'cc3m/training/')) 38 | else: 39 | raise NotImplementedError 40 | 41 | elif split == 'val': 42 | if 'cc3m' in args.val_dataset: 43 | dataset_paths.append(os.path.join(args.dataset_dir, 'cc3m_val.tsv')) 44 | image_data_dirs.append(os.path.join(args.image_dir, 'cc3m/validation')) 45 | else: 46 | raise NotImplementedError 47 | 48 | assert len(dataset_paths) == len(image_data_dirs) == 1, (dataset_paths, image_data_dirs) 49 | else: 50 | raise NotImplementedError 51 | 52 | if len(dataset_paths) > 1: 53 | print(f'{len(dataset_paths)} datasets requested: {dataset_paths}') 54 | dataset = torch.utils.data.ConcatDataset([ 55 | CsvDataset(path, image_dir, tokenizer, 'image', 56 | 'caption', args.visual_model, train=train, max_len=args.max_len, precision=args.precision, 57 | image_size=args.image_size, retrieval_token_idx=args.retrieval_token_idx) 58 | for (path, image_dir) in zip(dataset_paths, image_data_dirs)]) 59 | elif len(dataset_paths) == 1: 60 | dataset = CsvDataset(dataset_paths[0], image_data_dirs[0], tokenizer, 'image', 61 | 'caption', args.visual_model, train=train, max_len=args.max_len, precision=args.precision, 62 | image_size=args.image_size, retrieval_token_idx=args.retrieval_token_idx) 63 | else: 64 | raise ValueError(f'There should be at least one valid dataset, got train={args.dataset}, val={args.val_dataset} instead.') 65 | return dataset 66 | 67 | 68 | class CsvDataset(Dataset): 69 | def __init__(self, input_filename, base_image_dir, tokenizer, img_key, 70 | caption_key, feature_extractor_model: str, 71 | train: bool = True, max_len: int = 32, sep="\t", precision: str = 'fp32', 72 | image_size: int = 224, retrieval_token_idx: int = -1): 73 | logging.debug(f'Loading tsv data from {input_filename}.') 74 | df = pd.read_csv(input_filename, sep=sep) 75 | 76 | self.base_image_dir = base_image_dir 77 | self.images = df[img_key].tolist() 78 | self.captions = df[caption_key].tolist() 79 | assert len(self.images) == len(self.captions) 80 | 81 | self.feature_extractor_model = feature_extractor_model 82 | self.feature_extractor = utils.get_feature_extractor_for_model( 83 | feature_extractor_model, image_size=image_size, train=False) 84 | self.image_size = image_size 85 | 86 | self.tokenizer = tokenizer 87 | self.max_len = max_len 88 | self.precision = precision 89 | self.retrieval_token_idx = retrieval_token_idx 90 | 91 | self.font = None 92 | 93 | logging.debug('Done loading data.') 94 | 95 | def __len__(self): 96 | return len(self.captions) 97 | 98 | def __getitem__(self, idx): 99 | while True: 100 | image_path = os.path.join(self.base_image_dir, str(self.images[idx])) 101 | caption = str(self.captions[idx]) 102 | 103 | try: 104 | img = Image.open(image_path) 105 | images = utils.get_pixel_values_for_model(self.feature_extractor, img) 106 | 107 | caption += '[RET]' 108 | tokenized_data = self.tokenizer( 109 | caption, 110 | return_tensors="pt", 111 | padding='max_length', 112 | truncation=True, 113 | max_length=self.max_len) 114 | tokens = tokenized_data.input_ids[0] 115 | 116 | caption_len = tokenized_data.attention_mask[0].sum() 117 | 118 | decode_caption = self.tokenizer.decode(tokens, skip_special_tokens=False) 119 | self.font = self.font or ImageFont.load_default() 120 | cap_img = utils.create_image_of_text(decode_caption.encode('ascii', 'ignore'), 
width=self.image_size, nrows=2, font=self.font) 121 | 122 | if tokens[-1] not in [self.retrieval_token_idx, self.tokenizer.pad_token_id]: 123 | tokens[-1] = self.retrieval_token_idx 124 | 125 | return image_path, images, cap_img, tokens, caption_len 126 | except Exception as e: 127 | print(f'Error reading {image_path} with caption {caption}: {e}') 128 | # Pick a new example at random. 129 | idx = np.random.randint(0, len(self)-1) 130 | -------------------------------------------------------------------------------- /fromage/evaluate.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import json 3 | import os 4 | from PIL import Image 5 | import numpy as np 6 | import time 7 | import tqdm 8 | import torch 9 | import torch.distributed as dist 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torchmetrics import BLEUScore 12 | import torchvision 13 | 14 | from fromage import losses as losses_utils 15 | from fromage import utils 16 | 17 | 18 | def validate(val_loader, model, tokenizer, criterion, epoch, args): 19 | ngpus_per_node = torch.cuda.device_count() 20 | writer = SummaryWriter(args.log_dir) 21 | bleu_scorers = [BLEUScore(n_gram=i) for i in [1, 2, 3, 4]] 22 | actual_step = (epoch + 1) * args.steps_per_epoch 23 | model_modes = ['captioning', 'retrieval'] 24 | num_words = 32 # Number of tokens to generate. 25 | 26 | feature_extractor = utils.get_feature_extractor_for_model(args.visual_model, image_size=args.image_size, train=False) 27 | 28 | def get_pixel_values_from_path(path: str): 29 | img = Image.open(path) 30 | img = img.resize((args.image_size, args.image_size)) 31 | pixel_values = utils.get_pixel_values_for_model(feature_extractor, img)[None, ...] 32 | 33 | if args.precision == 'fp16': 34 | pixel_values = pixel_values.half() 35 | elif args.precision == 'bf16': 36 | pixel_values = pixel_values.bfloat16() 37 | if torch.cuda.is_available(): 38 | pixel_values = pixel_values.cuda() 39 | return pixel_values 40 | 41 | def run_validate(loader, base_progress=0): 42 | with torch.no_grad(): 43 | end = time.time() 44 | all_generated_captions = [] 45 | all_gt_captions = [] 46 | all_generated_image_paths = [] 47 | all_image_features = [] 48 | all_text_features = [] 49 | 50 | for i, (image_paths, images, caption_images, tgt_tokens, token_len) in tqdm.tqdm(enumerate(loader), position=0, total=len(loader)): 51 | i = base_progress + i 52 | 53 | if torch.cuda.is_available(): 54 | tgt_tokens = tgt_tokens.cuda(args.gpu, non_blocking=True) 55 | token_len = token_len.cuda(args.gpu, non_blocking=True) 56 | images = images.cuda() 57 | 58 | if args.precision == 'fp16': 59 | images = images.half() 60 | elif args.precision == 'bf16': 61 | images = images.bfloat16() 62 | 63 | for model_mode in model_modes: 64 | (model_output, full_labels, last_embedding, _, visual_embs) = model( 65 | images, tgt_tokens, token_len, mode=model_mode, input_prefix=args.input_prompt, inference=True) # (N, T, C) 66 | 67 | if model_mode == 'captioning': 68 | loss = args.cap_loss_scale * model_output.loss 69 | elif model_mode == 'retrieval': 70 | loss = args.ret_loss_scale * model_output.loss 71 | else: 72 | raise NotImplementedError 73 | 74 | output = model_output.logits 75 | if model_mode == 'captioning': 76 | acc1, acc5 = utils.accuracy(output[:, :-1, :], full_labels[:, 1:], -100, topk=(1, 5)) 77 | top1.update(acc1[0], images.size(0)) 78 | top5.update(acc5[0], images.size(0)) 79 | ce_losses.update(loss.item(), images.size(0)) 80 | 81 | if model_mode == 
'captioning': 82 | losses.update(loss.item(), images.size(0)) 83 | elif model_mode == 'retrieval': 84 | if args.distributed: 85 | original_last_embedding = torch.clone(last_embedding) 86 | all_visual_embs = [torch.zeros_like(visual_embs) for _ in range(dist.get_world_size())] 87 | all_last_embedding = [torch.zeros_like(last_embedding) for _ in range(dist.get_world_size())] 88 | dist.all_gather(all_visual_embs, visual_embs) 89 | dist.all_gather(all_last_embedding, last_embedding) 90 | 91 | # Overwrite with embeddings produced on this replica, which track the gradients. 92 | all_visual_embs[dist.get_rank()] = visual_embs 93 | all_last_embedding[dist.get_rank()] = last_embedding 94 | visual_embs = torch.cat(all_visual_embs) 95 | last_embedding = torch.cat(all_last_embedding) 96 | start_idx = args.rank * images.shape[0] 97 | end_idx = start_idx + images.shape[0] 98 | assert torch.all(last_embedding[start_idx:end_idx] == original_last_embedding), args.rank 99 | 100 | all_text_features.append(last_embedding.cpu()) 101 | all_image_features.append(visual_embs.cpu()) 102 | 103 | # Run auto-regressive generation sample 104 | if model_mode == 'captioning': 105 | input_embs = model.module.model.get_visual_embs(images, mode='captioning') # (2, n_visual_tokens, D) 106 | if args.input_prompt is not None: 107 | print(f'Adding prefix "{args.input_prompt}" to captioning generate=True.') 108 | prompt_ids = tokenizer(args.input_prompt, add_special_tokens=False, return_tensors="pt").input_ids 109 | prompt_ids = prompt_ids.to(visual_embs.device) 110 | prompt_embs = model.module.model.input_embeddings(prompt_ids) 111 | prompt_embs = prompt_embs.repeat(input_embs.shape[0], 1, 1) 112 | input_embs = torch.cat([input_embs, prompt_embs], dim=1) 113 | 114 | generated_ids, _, _ = model(input_embs, tgt_tokens, token_len, 115 | generate=True, num_words=num_words, temperature=0.0, top_p=1.0, 116 | min_word_tokens=num_words) 117 | 118 | if args.distributed and ngpus_per_node > 1: 119 | all_generated_ids = [torch.zeros_like(generated_ids) for _ in range(dist.get_world_size())] 120 | dist.all_gather(all_generated_ids, generated_ids) 121 | all_generated_ids[dist.get_rank()] = generated_ids 122 | generated_ids = torch.cat(all_generated_ids) 123 | 124 | all_tgt_tokens = [torch.zeros_like(tgt_tokens) for _ in range(dist.get_world_size())] 125 | dist.all_gather(all_tgt_tokens, tgt_tokens) 126 | all_tgt_tokens[dist.get_rank()] = tgt_tokens 127 | all_tgt_tokens = torch.cat(all_tgt_tokens) 128 | 129 | all_image_paths = [[None for _ in image_paths] for _ in range(dist.get_world_size())] 130 | dist.all_gather_object(all_image_paths, image_paths) 131 | all_image_paths[dist.get_rank()] = image_paths 132 | image_paths = [] 133 | for p in all_image_paths: 134 | image_paths.extend(p) 135 | else: 136 | all_tgt_tokens = tgt_tokens 137 | 138 | all_tgt_tokens[all_tgt_tokens == -100] = tokenizer.pad_token_id 139 | generated_captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) 140 | gt_captions = tokenizer.batch_decode(all_tgt_tokens, skip_special_tokens=True) 141 | 142 | for cap_i in range(len(generated_captions)): 143 | image_path = image_paths[cap_i] 144 | all_generated_image_paths.append(image_path) 145 | stop_idx = generated_captions[cap_i].find('.') 146 | if stop_idx > 5: 147 | all_generated_captions.append(generated_captions[cap_i][:stop_idx]) 148 | else: 149 | all_generated_captions.append(generated_captions[cap_i]) 150 | all_gt_captions.append([gt_captions[cap_i]]) 151 | elif model_mode == 'retrieval': 152 | if i 
== 0: 153 | # Generate without image input to visualize text-generation ability. 154 | input_ids = tgt_tokens[:, :3] # Use first 3 tokens as initial prompt for generation. 155 | input_embs = model.module.model.input_embeddings(input_ids) # (N, T, D) 156 | generated_ids, _, _ = model(input_embs, tgt_tokens, token_len, generate=True, num_words=num_words, temperature=0.0, top_p=1.0) 157 | generated_ids = torch.cat([input_ids, generated_ids], dim=1) 158 | generated_captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=False) 159 | gt_captions = tokenizer.batch_decode(tgt_tokens, skip_special_tokens=False) 160 | else: 161 | raise NotImplementedError 162 | 163 | if i == 0: 164 | max_to_display = 5 165 | print('=' * 30) 166 | print('Generated samples:') 167 | for cap_i, cap in enumerate(generated_captions[:max_to_display]): 168 | print(f'{cap_i}) {cap}') 169 | print('=' * 30) 170 | print('Real samples:') 171 | for cap_i, cap in enumerate(gt_captions[:max_to_display]): 172 | print(f'{cap_i}) {cap}') 173 | print('=' * 30) 174 | 175 | # Write images and captions to Tensorboard. 176 | if not args.distributed or (args.rank % ngpus_per_node == 0): 177 | max_images_to_show = 16 178 | normalized_images = images - images.min() 179 | normalized_images /= normalized_images.max() # (N, 3, H, W) 180 | # Create generated caption text. 181 | generated_cap_images = torch.stack([ 182 | utils.create_image_of_text( 183 | generated_captions[j].encode('ascii', 'ignore'), 184 | width=normalized_images.shape[3], 185 | color=(255, 255, 0)) 186 | for j in range(normalized_images.shape[0])], axis=0) 187 | # Append gt/generated caption images. 188 | display_images = torch.cat([normalized_images.float().cpu(), caption_images, generated_cap_images], axis=2)[:max_images_to_show] 189 | grid = torchvision.utils.make_grid(display_images, nrow=int(max_images_to_show ** 0.5), padding=4) 190 | writer.add_image(f'val/images_{model_mode}', grid, actual_step) 191 | 192 | # measure elapsed time 193 | batch_time.update(time.time() - end) 194 | end = time.time() 195 | 196 | if i % args.print_freq == 0: 197 | progress.display(i + 1) 198 | 199 | if i == args.val_steps_per_epoch - 1: 200 | break 201 | 202 | # Measure captioning metrics. 203 | path2captions = collections.defaultdict(list) 204 | for image_path, caption in zip(all_generated_image_paths, all_gt_captions): 205 | assert len(caption) == 1, caption 206 | path2captions[image_path].append(caption[0].replace('[RET]', '')) 207 | full_gt_captions = [path2captions[path] for path in all_generated_image_paths] 208 | 209 | print(f'Computing BLEU with {len(all_generated_captions)} generated captions:' 210 | f'{all_generated_captions[:5]} and {len(full_gt_captions)} groundtruth captions:', 211 | f'{full_gt_captions[:5]}.') 212 | bleu1_score = bleu_scorers[0](all_generated_captions, full_gt_captions) 213 | bleu1.update(bleu1_score, 1) 214 | bleu2_score = bleu_scorers[1](all_generated_captions, full_gt_captions) 215 | bleu2.update(bleu2_score, 1) 216 | bleu3_score = bleu_scorers[2](all_generated_captions, full_gt_captions) 217 | bleu3.update(bleu3_score, 2) 218 | bleu4_score = bleu_scorers[3](all_generated_captions, full_gt_captions) 219 | bleu4.update(bleu4_score, 3) 220 | 221 | # Measure retrieval metrics over the entire validation set. 
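# Image/text pairs are aligned by index across the validation set, so the diagonal of the similarity matrix below holds the positive pairs; contrastive_loss and contrastive_acc use torch.arange targets accordingly (see fromage/losses.py).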
222 | all_image_features = torch.cat(all_image_features, axis=0) # (coco_val_len, 2048) 223 | all_text_features = torch.cat(all_text_features, axis=0) # (coco_val_len, 2048) 224 | 225 | print(f"Computing similarity between {all_image_features.shape} and {all_text_features.shape}.") 226 | logits_per_image = all_image_features @ all_text_features.t() 227 | logits_per_text = logits_per_image.t() 228 | all_image_acc1, all_image_acc5 = losses_utils.contrastive_acc(logits_per_image, topk=(1, 5)) 229 | all_caption_acc1, all_caption_acc5 = losses_utils.contrastive_acc(logits_per_text, topk=(1, 5)) 230 | image_loss = losses_utils.contrastive_loss(logits_per_image) 231 | caption_loss = losses_utils.contrastive_loss(logits_per_text) 232 | 233 | loss = args.ret_loss_scale * (image_loss + caption_loss) / 2.0 234 | losses.update(loss.item(), logits_per_image.size(0)) 235 | top1_caption.update(all_caption_acc1.item(), logits_per_image.size(0)) 236 | top5_caption.update(all_caption_acc5.item(), logits_per_image.size(0)) 237 | top1_image.update(all_image_acc1.item(), logits_per_image.size(0)) 238 | top5_image.update(all_image_acc5.item(), logits_per_image.size(0)) 239 | 240 | 241 | batch_time = utils.AverageMeter('Time', ':6.3f', utils.Summary.AVERAGE) 242 | losses = utils.AverageMeter('Loss', ':.4e', utils.Summary.AVERAGE) 243 | ce_losses = utils.AverageMeter('CeLoss', ':.4e', utils.Summary.AVERAGE) 244 | top1 = utils.AverageMeter('Acc@1', ':6.2f', utils.Summary.AVERAGE) 245 | top5 = utils.AverageMeter('Acc@5', ':6.2f', utils.Summary.AVERAGE) 246 | bleu1 = utils.AverageMeter('BLEU@1', ':6.2f', utils.Summary.AVERAGE) 247 | bleu2 = utils.AverageMeter('BLEU@2', ':6.2f', utils.Summary.AVERAGE) 248 | bleu3 = utils.AverageMeter('BLEU@3', ':6.2f', utils.Summary.AVERAGE) 249 | bleu4 = utils.AverageMeter('BLEU@4', ':6.2f', utils.Summary.AVERAGE) 250 | top1_caption = utils.AverageMeter('CaptionAcc@1', ':6.2f', utils.Summary.AVERAGE) 251 | top5_caption = utils.AverageMeter('CaptionAcc@5', ':6.2f', utils.Summary.AVERAGE) 252 | top1_image = utils.AverageMeter('ImageAcc@1', ':6.2f', utils.Summary.AVERAGE) 253 | top5_image = utils.AverageMeter('ImageAcc@5', ':6.2f', utils.Summary.AVERAGE) 254 | 255 | progress = utils.ProgressMeter( 256 | len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))), 257 | [batch_time, losses, top1, top5, bleu4], 258 | prefix='Test: ') 259 | 260 | # switch to evaluate mode 261 | model.eval() 262 | 263 | run_validate(val_loader) 264 | if args.distributed: 265 | batch_time.all_reduce() 266 | losses.all_reduce() 267 | bleu1.all_reduce() 268 | bleu2.all_reduce() 269 | bleu3.all_reduce() 270 | bleu4.all_reduce() 271 | top1.all_reduce() 272 | top5.all_reduce() 273 | top1_caption.all_reduce() 274 | top5_caption.all_reduce() 275 | top1_image.all_reduce() 276 | top5_image.all_reduce() 277 | 278 | if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)): 279 | aux_val_dataset = Subset(val_loader.dataset, 280 | range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset))) 281 | aux_val_loader = torch.utils.data.DataLoader( 282 | aux_val_dataset, batch_size=(args.val_batch_size or args.batch_size), shuffle=False, 283 | num_workers=args.workers, pin_memory=True, collate_fn=data.collate_fn) 284 | run_validate(aux_val_loader, len(val_loader)) 285 | 286 | progress.display_summary() 287 | 288 | writer.add_scalar('val/total_secs_per_batch', batch_time.avg, actual_step) 289 | 
writer.add_scalar('val/seq_top1_acc', top1.avg, actual_step) 290 | writer.add_scalar('val/seq_top5_acc', top5.avg, actual_step) 291 | writer.add_scalar('val/ce_loss', losses.avg, actual_step) 292 | writer.add_scalar('val/bleu1', bleu1.avg, actual_step) 293 | writer.add_scalar('val/bleu2', bleu2.avg, actual_step) 294 | writer.add_scalar('val/bleu3', bleu3.avg, actual_step) 295 | writer.add_scalar('val/bleu4', bleu4.avg, actual_step) 296 | writer.add_scalar('val/contrastive_loss', losses.avg, actual_step) 297 | writer.add_scalar('val/t2i_top1_acc', top1_caption.avg, actual_step) 298 | writer.add_scalar('val/t2i_top5_acc', top5_caption.avg, actual_step) 299 | writer.add_scalar('val/i2t_top1_acc', top1_image.avg, actual_step) 300 | writer.add_scalar('val/i2t_top5_acc', top5_image.avg, actual_step) 301 | writer.add_scalar('val/top1_acc', (top1_caption.avg + top1_image.avg) / 2.0, actual_step) 302 | writer.add_scalar('val/top5_acc', (top5_caption.avg + top5_image.avg) / 2.0, actual_step) 303 | 304 | writer.close() 305 | 306 | # Use top1 accuracy as the metric for keeping the best checkpoint. 307 | return top1_caption.avg 308 | -------------------------------------------------------------------------------- /fromage/extract_img_embs.py: -------------------------------------------------------------------------------- 1 | """Extract image embeddings for a list of image urls. 2 | 3 | Example usage: 4 | python extract_img_embs.py 5 | """ 6 | import torch 7 | from fromage import models, utils 8 | 9 | from PIL import Image 10 | import os 11 | import requests 12 | from io import BytesIO 13 | import pickle as pkl 14 | 15 | 16 | def extract_embeddings_for_urls(image_urls: list[str], emb_output_path: str, device: str = "cuda"): 17 | # Load model checkpoint. 18 | model = models.load_fromage("./fromage_model/") 19 | model.eval() 20 | 21 | visual_encoder = "openai/clip-vit-large-patch14" 22 | feature_extractor = utils.get_feature_extractor_for_model( 23 | visual_encoder, train=False 24 | ) 25 | 26 | output_data = {"paths": [], "embeddings": []} 27 | with torch.no_grad(): 28 | for img_url in image_urls: 29 | img = Image.open(BytesIO(requests.get(img_url).content)) 30 | 31 | img_tensor = utils.get_pixel_values_for_model(feature_extractor, img) 32 | img_tensor = img_tensor[None, ...].to(device).bfloat16() 33 | img_emb = model.model.get_visual_embs(img_tensor, mode="retrieval") 34 | img_emb = img_emb[0, 0, :].cpu() 35 | output_data["paths"].append(img_url) 36 | output_data["embeddings"].append(img_emb) 37 | 38 | with open(emb_output_path, "wb") as f: 39 | pkl.dump(output_data, f) 40 | 41 | 42 | if __name__ == "__main__": 43 | image_urls = [] # TODO: Replace with image urls 44 | if image_urls == []: 45 | raise ValueError("Please replace `image_urls` with a list of image urls.") 46 | 47 | extract_embeddings_for_urls(image_urls, "fromage_model/cc3m_embeddings.pkl") 48 | -------------------------------------------------------------------------------- /fromage/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | from fromage import utils 4 | 5 | def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: 6 | return torch.nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) 7 | 8 | 9 | def contrastive_acc(logits: torch.Tensor, target: Optional[torch.Tensor] = None, topk=(1,)) -> torch.Tensor: 10 | """ 11 | Args: 12 | logits: (N, N) predictions. 13 | target: (N, num_correct_answers) labels. 
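topk: Tuple of k values to compute accuracy at.
Returns:
A list with one accuracy value (as a percentage) per value of k in topk.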
14 | """ 15 | assert len(logits.shape) == 2, logits.shape 16 | batch_size = logits.shape[0] 17 | 18 | if target is None: 19 | target = torch.arange(len(logits), device=logits.device) 20 | return utils.accuracy(logits, target, -1, topk) 21 | else: 22 | assert len(target.shape) == 2, target.shape 23 | with torch.no_grad(): 24 | maxk = max(topk) 25 | if logits.shape[-1] < maxk: 26 | print(f"[WARNING] Less than {maxk} predictions available. Using {logits.shape[-1]} for topk.") 27 | maxk = min(maxk, logits.shape[-1]) 28 | 29 | # Take topk along the last dimension. 30 | _, pred = logits.topk(maxk, -1, True, True) # (N, topk) 31 | assert pred.shape == (batch_size, maxk) 32 | 33 | target_expand = target[:, :, None].repeat(1, 1, maxk) # (N, num_correct_answers, topk) 34 | pred_expand = pred[:, None, :].repeat(1, target.shape[1], 1) # (N, num_correct_answers, topk) 35 | correct = pred_expand.eq(target_expand) # (N, num_correct_answers, topk) 36 | correct = torch.any(correct, dim=1) # (N, topk) 37 | 38 | res = [] 39 | for k in topk: 40 | any_k_correct = torch.clamp(correct[:, :k].sum(1), max=1) # (N,) 41 | correct_k = any_k_correct.float().sum(0, keepdim=True) 42 | res.append(correct_k.mul_(100.0 / batch_size)) 43 | return res 44 | 45 | -------------------------------------------------------------------------------- /fromage/models.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List, Optional, Tuple, Union 2 | from collections import namedtuple 3 | import json 4 | import glob 5 | import math 6 | import numpy as np 7 | import os 8 | import torch 9 | from torch import Tensor 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from einops import rearrange 13 | from functools import partial 14 | import pickle as pkl 15 | from PIL import Image, UnidentifiedImageError 16 | 17 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM 18 | from transformers import OPTForCausalLM, GPT2Tokenizer 19 | from transformers import CLIPVisionModel, CLIPVisionConfig 20 | 21 | from fromage import utils 22 | 23 | 24 | class FrozenArgs: 25 | freeze_lm: bool = True 26 | freeze_vm: bool = True 27 | opt_version: str = 'facebook/opt-6.7b' 28 | visual_encoder: str = 'openai/clip-vit-large-patch14' 29 | n_visual_tokens: int = 1 30 | image_embed_dropout_prob: float = 0.0 31 | task: str = 'captioning' 32 | shared_emb_dim: Optional[int] = 256 33 | text_emb_layers: List[int] = [-1] 34 | retrieval_token_idx: int = 0 35 | 36 | 37 | class FromageModel(nn.Module): 38 | def __init__(self, tokenizer, args: FrozenArgs = FrozenArgs()): 39 | super().__init__() 40 | self.tokenizer = tokenizer 41 | self.feature_extractor = utils.get_feature_extractor_for_model(args.visual_encoder, train=False) 42 | self.image_token = self.tokenizer.cls_token_id 43 | assert args.text_emb_layers != set(args.text_emb_layers), 'text_emb_layers not unique' 44 | self.args = args 45 | 46 | opt_version = args.opt_version 47 | visual_encoder = args.visual_encoder 48 | n_visual_tokens = args.n_visual_tokens 49 | print(f"Using {opt_version} for the language model.") 50 | print(f"Using {visual_encoder} for the visual model with {n_visual_tokens} visual tokens.") 51 | 52 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 53 | 54 | if 'facebook/opt' in opt_version: 55 | self.lm = OPTForCausalLM.from_pretrained(opt_version) 56 | else: 57 | raise NotImplementedError 58 | 59 | self.opt_version = opt_version 60 | 61 | if self.args.freeze_lm: 62 | 
self.lm.eval() 63 | print("Freezing the LM.") 64 | for param in self.lm.parameters(): 65 | param.requires_grad = False 66 | else: 67 | self.lm.train() 68 | 69 | # NOTE: Resizing sets all token embeddings and all lm_head weights (since they are tied in OPT) 70 | # to be trainable (param.requires_grad = True). 71 | self.retrieval_token_idx = args.retrieval_token_idx 72 | print(f'Initializing embedding for the retrieval token [RET] (id = {self.retrieval_token_idx}).') 73 | self.lm.resize_token_embeddings(len(tokenizer)) 74 | 75 | self.input_embeddings = self.lm.get_input_embeddings() 76 | 77 | print("Restoring pretrained weights for the visual model.") 78 | if 'clip' in visual_encoder: 79 | self.visual_model = CLIPVisionModel.from_pretrained(visual_encoder) 80 | else: 81 | self.visual_model = AutoModel.from_pretrained(visual_encoder) 82 | 83 | if 'clip' in visual_encoder: 84 | hidden_size = self.visual_model.config.hidden_size 85 | else: 86 | raise NotImplementedError 87 | 88 | if self.args.freeze_vm: 89 | print("Freezing the VM.") 90 | self.visual_model.eval() 91 | for param in self.visual_model.parameters(): 92 | param.requires_grad = False 93 | else: 94 | self.visual_model.train() 95 | 96 | self.visual_model_name = visual_encoder 97 | 98 | embedding_dim = self.input_embeddings.embedding_dim * self.args.n_visual_tokens 99 | self.text_hidden_fcs = nn.ModuleList([]) 100 | if self.args.shared_emb_dim is None: 101 | if len(self.args.text_emb_layers) == 1: 102 | if (self.args.text_emb_layers[0] in [-1, self.lm.config.num_hidden_layers]) and ('bert' not in opt_version): 103 | out_dim = self.lm.config.word_embed_proj_dim 104 | else: 105 | out_dim = self.lm.config.hidden_size 106 | else: 107 | if (-1 in self.args.text_emb_layers) or (self.lm.config.num_hidden_layers in self.args.text_emb_layers) \ 108 | and (self.lm.config.word_embed_proj_dim != self.lm.config.hidden_size): 109 | raise ValueError('No projection dim specified but model uses last output layer and an intermediate one (which have different dims).') 110 | else: 111 | out_dim = self.lm.config.hidden_size 112 | else: 113 | out_dim = self.args.shared_emb_dim 114 | 115 | for layer_idx in self.args.text_emb_layers: 116 | if (layer_idx == -1 or layer_idx == self.lm.config.num_hidden_layers) and ('bert' not in opt_version): 117 | in_dim = self.lm.config.word_embed_proj_dim 118 | 119 | text_fc = [nn.Linear(in_dim, out_dim), nn.Dropout(self.args.text_embed_dropout_prob)] 120 | self.text_hidden_fcs.append(nn.Sequential(*text_fc)) 121 | 122 | elif layer_idx < self.lm.config.num_hidden_layers: 123 | text_fc = [nn.Linear(self.lm.config.hidden_size, out_dim), nn.Dropout(self.args.text_embed_dropout_prob)] 124 | self.text_hidden_fcs.append(nn.Sequential(*text_fc)) 125 | else: 126 | raise ValueError(f'Embedding of layer {layer_idx} was requested but model only has {self.lm.config.num_hidden_layers} layers.') 127 | 128 | self.visual_embeddings = nn.Linear(hidden_size, embedding_dim) 129 | self.visual_fc = nn.Linear(hidden_size, out_dim) 130 | 131 | self.image_dropout = nn.Dropout(self.args.image_embed_dropout_prob) 132 | 133 | 134 | def get_visual_embs(self, pixel_values: torch.FloatTensor, mode: str = 'captioning'): 135 | if mode not in ['captioning', 'retrieval']: 136 | raise ValueError(f'mode should be one of ["caption", "retrieval"], got {mode} instead.') 137 | 138 | # Extract visual embeddings from the vision encoder. 
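# For CLIP encoders, the pooled output (pooler_output) is used as the image representation; it is then mapped by a mode-specific linear layer below: visual_embeddings (reshaped into n_visual_tokens tokens) for captioning, or visual_fc (a single embedding) for retrieval.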
139 | if 'clip' in self.visual_model_name: 140 | outputs = self.visual_model(pixel_values) 141 | encoder_outputs = outputs.pooler_output 142 | else: 143 | raise NotImplementedError 144 | 145 | # Use the correct fc based on function argument. 146 | if mode == 'captioning': 147 | visual_embs = self.visual_embeddings(encoder_outputs) # (2, D * n_visual_tokens) 148 | visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], self.args.n_visual_tokens, -1)) 149 | elif mode == 'retrieval': 150 | visual_embs = self.visual_fc(encoder_outputs) # (2, D * n_visual_tokens) 151 | visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1)) 152 | else: 153 | raise NotImplementedError 154 | 155 | visual_embs = self.image_dropout(visual_embs) 156 | return visual_embs 157 | 158 | 159 | def train(self, mode=True): 160 | super(FromageModel, self).train(mode=mode) 161 | # Overwrite train() to ensure Frozen models remain frozen. 162 | if self.args.freeze_lm: 163 | self.lm.eval() 164 | if self.args.freeze_vm: 165 | self.visual_model.eval() 166 | 167 | 168 | def forward( 169 | self, 170 | pixel_values: torch.FloatTensor, 171 | labels: torch.LongTensor, 172 | caption_len: torch.LongTensor, 173 | mode: str = 'captioning', 174 | concat_captions: bool = False, 175 | input_prefix: Optional[str] = None, 176 | inference: bool = False, 177 | ): 178 | visual_embs = self.get_visual_embs(pixel_values, mode) 179 | 180 | batch_size, vis_seq_len, _ = visual_embs.shape # vis_seq_len = n_visual_tokens 181 | if labels is not None: 182 | assert labels.shape[0] == batch_size, (visual_embs.shape, labels.shape) 183 | 184 | input_embs = self.input_embeddings(labels) # (N, T, D) 185 | 186 | last_embedding_idx = caption_len - 1 # -1 to retrieve the token before the eos token 187 | 188 | if input_prefix is not None: 189 | prompt_ids = self.tokenizer(input_prefix, add_special_tokens=False, return_tensors="pt").input_ids 190 | prompt_ids = prompt_ids.to(visual_embs.device) 191 | prompt_embs = self.input_embeddings(prompt_ids) 192 | prompt_embs = prompt_embs.repeat(batch_size, 1, 1) 193 | assert prompt_embs.shape[0] == batch_size, prompt_embs.shape 194 | assert prompt_embs.shape[2] == input_embs.shape[2], prompt_embs.shape 195 | assert len(prompt_embs.shape) == 3, prompt_embs.shape 196 | 197 | if mode == 'captioning': 198 | # Concat to text embeddings. 199 | condition_seq_len = 0 200 | if input_prefix is None: 201 | # Just add visual embeddings. 202 | input_embs = torch.cat([visual_embs, input_embs], axis=1) 203 | last_embedding_idx += vis_seq_len 204 | condition_seq_len += vis_seq_len 205 | full_labels = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100 206 | else: 207 | # Add visual and prompt embeddings. 208 | prefix_embs = torch.cat([visual_embs, prompt_embs], axis=1) 209 | input_embs = torch.cat([prefix_embs, input_embs], axis=1) 210 | 211 | last_embedding_idx += prefix_embs.shape[1] 212 | condition_seq_len += prefix_embs.shape[1] 213 | full_labels = torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100 214 | 215 | # Mask out embedding tokens in the labels. 216 | full_labels = torch.cat([full_labels, labels], axis=1) 217 | 218 | pad_idx = [] 219 | 220 | for label in full_labels: 221 | for k, token in enumerate(label): 222 | # Mask out retrieval token if it exists. 223 | if token in [self.tokenizer.pad_token_id, self.retrieval_token_idx]: 224 | label[k:] = -100 225 | pad_idx.append(k) 226 | break 227 | if k == len(label) - 1: # No padding found. 
228 | pad_idx.append(k + 1) 229 | assert len(pad_idx) == batch_size, (len(pad_idx), batch_size) 230 | 231 | bs, seq_len, embs_dim = input_embs.shape 232 | if concat_captions: 233 | assert len(input_embs.shape) == 3, input_embs 234 | assert len(full_labels.shape) == 2, full_labels 235 | assert batch_size % 2 == 0 236 | all_concat_input_embs = [] 237 | all_concat_labels = [] 238 | 239 | # Rearrange embeddings and labels (and their padding) to concatenate captions. 240 | for i in range(batch_size // 2): 241 | first_idx = i * 2 242 | second_idx = first_idx + 1 243 | first_emb = input_embs[first_idx, :pad_idx[first_idx], :] 244 | first_labels = full_labels[first_idx, :pad_idx[first_idx]] 245 | first_padding = input_embs[first_idx, pad_idx[first_idx]:, :] 246 | first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:] 247 | 248 | second_emb = input_embs[second_idx, :pad_idx[second_idx], :] 249 | second_labels = full_labels[second_idx, :pad_idx[second_idx]] 250 | second_padding = input_embs[second_idx, pad_idx[second_idx]:, :] 251 | second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:] 252 | 253 | assert torch.all(first_labels_padding == -100), first_labels_padding 254 | assert torch.all(second_labels_padding == -100), second_labels_padding 255 | concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0) # (T*2, 768) 256 | concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0) # (T*2, 768) 257 | all_concat_input_embs.append(concat_input_embs) 258 | all_concat_labels.append(concat_labels) 259 | 260 | # Pad to max length. 261 | input_embs = torch.stack(all_concat_input_embs, axis=0) # (N/2, T*2, 768) 262 | full_labels = torch.stack(all_concat_labels, axis=0) # (N/2, T*2, 768) 263 | assert input_embs.shape == (bs // 2, seq_len * 2, embs_dim), input_embs.shape 264 | assert full_labels.shape == (bs // 2, seq_len * 2), full_labels.shape 265 | 266 | output = self.lm(inputs_embeds=input_embs, 267 | labels=full_labels, 268 | output_hidden_states=True) 269 | elif mode == 'retrieval': 270 | full_labels = torch.clone(labels) 271 | if input_prefix is not None: 272 | print(f'Adding prefix "{input_prefix}" to retrieval.') 273 | # Add prompt embeddings. 274 | prefix_embs = prompt_embs 275 | input_embs = torch.cat([prefix_embs, input_embs], axis=1) 276 | last_embedding_idx += prefix_embs.shape[1] 277 | full_labels = torch.cat([ 278 | torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(labels.device) - 100, 279 | full_labels 280 | ], axis=1) 281 | 282 | pad_idx = [] 283 | for label in full_labels: 284 | for k, token in enumerate(label): 285 | if token == self.tokenizer.pad_token_id: 286 | label[k:] = -100 287 | pad_idx.append(k) 288 | break 289 | if k == len(label) - 1: # No padding found. 
290 | pad_idx.append(k + 1) 291 | assert len(pad_idx) == batch_size, (len(pad_idx), batch_size) 292 | 293 | output = self.lm(inputs_embeds=input_embs, 294 | labels=full_labels, 295 | output_hidden_states=True) 296 | else: 297 | raise NotImplementedError 298 | 299 | last_embedding = None 300 | last_output_logit = None 301 | hidden_states = [] 302 | 303 | if mode == 'retrieval': 304 | if self.args.shared_emb_dim is not None: 305 | for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs): 306 | hidden_states.append(fc_layer(output.hidden_states[idx])) # (N, seq_len, 2048) 307 | else: 308 | for idx in self.args.text_emb_layers: 309 | hidden_states.append(output.hidden_states[idx]) 310 | 311 | # Add hidden states together. 312 | last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1) 313 | 314 | if not concat_captions: 315 | last_embedding = torch.stack([last_hidden_state[i, last_embedding_idx[i], :] for i in range(batch_size)], axis=0) # (N, D) 316 | last_output_logit = torch.stack([output.logits[i, last_embedding_idx[i] - 1, :] for i in range(batch_size)], axis=0) # (N, D) 317 | else: 318 | raise NotImplementedError 319 | 320 | # Compute retrieval loss. 321 | assert visual_embs.shape[1] == 1, visual_embs.shape 322 | visual_embs = visual_embs[:, 0, :] 323 | visual_embs = visual_embs / visual_embs.norm(dim=1, keepdim=True) 324 | last_embedding = last_embedding / last_embedding.norm(dim=1, keepdim=True) 325 | 326 | # cosine similarity as logits 327 | logit_scale = self.logit_scale.exp() 328 | visual_embs = logit_scale * visual_embs 329 | elif mode == 'captioning': 330 | pass 331 | else: 332 | raise NotImplementedError 333 | 334 | return output, full_labels, last_embedding, last_output_logit, visual_embs 335 | 336 | def generate(self, embeddings = torch.FloatTensor, max_len: int = 32, 337 | temperature: float = 0.0, top_p: float = 1.0, min_word_tokens: int = 0, 338 | ret_scale_factor: float = 1.0, filter_value: float = -float('Inf')): 339 | """Runs greedy decoding and returns generated captions. 340 | 341 | Args: 342 | embeddings: Input condition that the model uses for autoregressive generation. 343 | max_len: Maximum number of tokens to generate. 344 | temperature: Used to modulate logit distribution. 345 | top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation. 346 | min_word_tokens: Minimum number of words to generate before allowing a [RET] output. 347 | ret_scale_factor: Proportion to scale [RET] token logits by. A higher value may increase the probability of the model generating [RET] outputs. 348 | filter_value: Value to assign to tokens that should never be generated. 349 | Outputs: 350 | out: (N, T) int32 sequence of output tokens. 351 | output_embeddings: (N, T, 256) sequence of text output embeddings. 
352 | """ 353 | self.lm.eval() 354 | 355 | with torch.no_grad(): # no tracking history 356 | batch_size, s, _ = embeddings.shape 357 | # init output with image tokens 358 | out = None 359 | past_key_values = None 360 | output_embeddings = [] 361 | output_logits = [] 362 | 363 | for i in range(max_len): 364 | if 'opt' in self.opt_version: 365 | output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True) 366 | else: 367 | if i == 0: 368 | output = self.lm(inputs_embeds=embeddings, use_cache=True, past_key_values=None, output_hidden_states=True) 369 | else: 370 | output = self.lm(input_ids=out[:, -1:], use_cache=True, past_key_values=past_key_values, output_hidden_states=True) 371 | 372 | # Collect and sum the hidden states. 373 | hidden_states = [] 374 | if self.args.shared_emb_dim is not None: 375 | for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs): 376 | hidden_states.append(fc_layer(output.hidden_states[idx])) # (N, seq_len, 2048) 377 | else: 378 | for idx in self.args.text_emb_layers: 379 | hidden_states.append(output.hidden_states[idx]) 380 | # Add hidden states together. 381 | last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1) # (N, T, 256) 382 | last_embedding = last_hidden_state / last_hidden_state.norm(dim=-1, keepdim=True) 383 | output_embeddings.append(last_embedding) 384 | 385 | logits = output.logits[:, -1, :] # (N, vocab_size) 386 | if top_p == 1.0: 387 | logits = logits.cpu() 388 | output_logits.append(logits) 389 | 390 | if self.retrieval_token_idx != -1 and self.retrieval_token_idx is not None: 391 | if i < min_word_tokens: 392 | # Eliminate probability of generating [RET] if this is earlier than min_word_tokens. 393 | logits[:, self.retrieval_token_idx] = filter_value 394 | else: 395 | # Multiply by scaling factor. 396 | logits[:, self.retrieval_token_idx] = logits[:, self.retrieval_token_idx] * ret_scale_factor 397 | 398 | past_key_values = output.past_key_values 399 | 400 | if temperature == 0.0: 401 | if top_p != 1.0: 402 | raise ValueError('top_p cannot be set if temperature is 0 (greedy decoding).') 403 | next_token = torch.argmax(logits, keepdim=True, dim=-1) # (N, 1) 404 | else: 405 | logits = logits / temperature 406 | 407 | # Apply top-p filtering. 408 | if top_p < 1.0: 409 | assert top_p > 0, f'top_p should be above 0, got {top_p} instead.' 
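# Nucleus (top-p) sampling: sort the logits, keep the smallest set of tokens whose cumulative probability reaches top_p, and assign filter_value to the rest before sampling from the renormalized distribution.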
410 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (N, D) and (N, D) 411 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (N, D) 412 | 413 | # Remove tokens with cumulative probability above the threshold 414 | sorted_indices_to_remove = cumulative_probs > top_p 415 | # Shift the indices to the right to keep also the first token above the threshold 416 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 417 | sorted_indices_to_remove[..., 0] = 0 418 | 419 | for j in range(sorted_indices.shape[0]): 420 | indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]] 421 | logits[j, indices_to_remove] = filter_value 422 | 423 | token_weights = logits.exp() # (N, vocab_size) 424 | next_token = torch.multinomial(token_weights, 1) # (N, 1) 425 | 426 | next_token = next_token.long().to(embeddings.device) 427 | if out is not None: 428 | out = torch.cat([out, next_token], dim=-1) 429 | else: 430 | out = next_token 431 | 432 | if 'opt' in self.opt_version: 433 | next_embedding = self.input_embeddings(next_token) 434 | embeddings = torch.cat([embeddings, next_embedding], dim=1) 435 | elif (self.tokenizer.eos_token_id and (next_token == self.tokenizer.eos_token_id).all()): 436 | # End of generation. 437 | break 438 | 439 | return out, output_embeddings, output_logits 440 | 441 | 442 | class Fromage(nn.Module): 443 | def __init__(self, tokenizer, model_args: Optional[FrozenArgs] = None, 444 | path_array: Optional[List[str]] = None, emb_matrix: Optional[torch.tensor] = None): 445 | super().__init__() 446 | self.model = FromageModel(tokenizer, model_args) 447 | self.path_array = path_array 448 | self.emb_matrix = emb_matrix 449 | 450 | def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None, 451 | generate: bool = False, num_words: int = 32, temperature: float = 1.0, top_p: float = 1.0, 452 | ret_scale_factor: float = 1.0, min_word_tokens: int = 0, 453 | mode: str = 'captioning', concat_captions: bool = False, 454 | input_prefix: Optional[str] = None, inference: bool = False) -> Tensor: 455 | if generate: 456 | return self.model.generate(images, num_words, temperature=temperature, top_p=top_p, 457 | min_word_tokens=min_word_tokens, ret_scale_factor=ret_scale_factor) 458 | else: 459 | output = self.model( 460 | pixel_values = images, 461 | labels = tgt_tokens, 462 | caption_len = caption_len, 463 | mode = mode, 464 | concat_captions = concat_captions, 465 | input_prefix = input_prefix, 466 | inference = inference) 467 | return output 468 | 469 | def generate_for_images_and_texts( 470 | self, prompts: List, num_words: int = 0, ret_scale_factor: float = 1.0, top_p: float = 1.0, temperature: float = 0.0, 471 | max_num_rets: int = 1, max_img_per_ret: int = 1): 472 | """ 473 | Encode prompts into embeddings. 474 | 475 | Args: 476 | prompts: List of interleaved PIL.Image.Image and strings representing input to the model. 477 | num_words: Maximum number of words to generate for. If num_words = 0, the model will run its forward pass and return the outputs. 478 | ret_scale_factor: Proportion to scale [RET] token logits by. A higher value may increase the probability of the model generating [RET] outputs. 479 | top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation. 480 | temperature: Used to modulate logit distribution. 481 | max_num_rets: Maximum number of images to return in one generation pass. 
482 | max_img_per_ret: Maximum number of images to return for each [RET] token. 483 | Returns: 484 | return_outputs: List consisting of either str or List[PIL.Image.Image] objects, representing image-text interleaved model outputs. 485 | """ 486 | input_embs = [] 487 | input_ids = [] 488 | add_bos = True 489 | 490 | for i, p in enumerate(prompts): 491 | if type(p) == Image.Image: 492 | # Encode as image. 493 | pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p) 494 | pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype) 495 | pixel_values = pixel_values[None, ...] 496 | 497 | visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning') # (1, n_visual_tokens, D) 498 | input_embs.append(visual_embs) 499 | elif type(p) == str: 500 | text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(self.model.logit_scale.device) 501 | if not add_bos: 502 | # Remove tag. 503 | text_ids = text_ids[:, 1:] 504 | else: 505 | # Only add once. 506 | add_bos = False 507 | 508 | text_embs = self.model.input_embeddings(text_ids) # (1, T, D) 509 | input_embs.append(text_embs) 510 | input_ids.append(text_ids) 511 | else: 512 | raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.') 513 | input_embs = torch.cat(input_embs, dim=1) 514 | input_ids = torch.cat(input_ids, dim=1) 515 | 516 | if num_words == 0: 517 | generated_ids = input_ids 518 | outputs = self.model.lm(inputs_embeds=input_embs, use_cache=False, output_hidden_states=True) 519 | # Map outputs to embeddings, so we can retrieve embeddings from the [RET] tokens. 520 | out = [] 521 | for x, fc in zip(self.model.args.text_emb_layers, self.model.text_hidden_fcs): 522 | out.append(fc(outputs.hidden_states[x])) 523 | embeddings = torch.stack(out, dim=-1).sum(dim=-1) 524 | embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (N, T, 256) 525 | elif num_words > 0: 526 | generated_ids, generated_embeddings, _ = self.model.generate(input_embs, num_words, 527 | temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor) 528 | embeddings = generated_embeddings[-1][:, input_embs.shape[1]:] 529 | 530 | # Truncate to newline. 531 | newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0] 532 | trunc_idx = 0 533 | for j in range(generated_ids.shape[1]): 534 | if generated_ids[0, j] == newline_token_id: 535 | trunc_idx = j 536 | break 537 | if trunc_idx > 0: 538 | generated_ids = generated_ids[:, :trunc_idx] 539 | embeddings = embeddings[:, :trunc_idx] 540 | else: 541 | raise ValueError 542 | 543 | # Save outputs as an interleaved list. 544 | return_outputs = [] 545 | # Find up to max_num_rets [RET] tokens, and their corresponding scores. 546 | all_ret_idx = [i for i, x in enumerate(generated_ids[0, :] == self.model.retrieval_token_idx) if x][:max_num_rets] 547 | seen_image_idx = [] # Avoid showing the same image multiple times. 548 | 549 | last_ret_idx = 0 550 | if len(all_ret_idx) == 0: 551 | # No [RET] tokens. 552 | caption = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 553 | return_outputs.append(utils.truncate_caption(caption)) 554 | else: 555 | for ret_idx in all_ret_idx: 556 | ret_emb = embeddings[:, ret_idx, :] 557 | scores = self.emb_matrix @ ret_emb.T 558 | 559 | # Downweight seen images. 
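# As in the retrieval evals, a large constant is subtracted from the scores of already-returned images so that repeated [RET] tokens surface different results.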
560 | for seen_idx in seen_image_idx: 561 | scores[seen_idx, :] -= 1000 562 | 563 | # Get the top max_img_per_ret + 3 (in case some fail) images for each image. 564 | _, top_image_idx = scores.squeeze().topk(max_img_per_ret + 3) 565 | image_outputs = [] 566 | for img_idx in top_image_idx: 567 | # Find the first image that does not error out. 568 | try: 569 | seen_image_idx.append(img_idx) 570 | img = utils.get_image_from_url(self.path_array[img_idx]) 571 | image_outputs.append(img) 572 | if len(image_outputs) == max_img_per_ret: 573 | break 574 | except UnidentifiedImageError: 575 | pass 576 | 577 | caption = self.model.tokenizer.batch_decode(generated_ids[:, last_ret_idx:ret_idx], skip_special_tokens=True)[0] 578 | last_ret_idx = ret_idx + 1 579 | return_outputs.append(utils.truncate_caption(caption) + ' [RET]') 580 | return_outputs.append(image_outputs) 581 | 582 | return return_outputs 583 | 584 | def get_log_lik_scores( 585 | self, prompts: List): 586 | """ 587 | Output the log likelihoods of the given interleaved prompts. 588 | 589 | Args: 590 | prompts: List of interleaved PIL.Image.Image and strings representing input to the model. 591 | Returns: 592 | log lik score of prompt sequence. 593 | """ 594 | input_embs = [] 595 | input_ids = [] 596 | add_bos = True 597 | 598 | for i, p in enumerate(prompts): 599 | if type(p) == Image.Image: 600 | # Encode as image. 601 | pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p) 602 | pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype) 603 | pixel_values = pixel_values[None, ...] 604 | 605 | visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning') # (1, n_visual_tokens, D) 606 | input_embs.append(visual_embs) 607 | id_ = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100 608 | input_ids.append(id_) 609 | elif type(p) == str: 610 | text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(self.model.logit_scale.device) 611 | if not add_bos: 612 | # Remove tag. 613 | text_ids = text_ids[:, 1:] 614 | else: 615 | # Only add once. 616 | add_bos = False 617 | 618 | text_embs = self.model.input_embeddings(text_ids) # (1, T, D) 619 | input_embs.append(text_embs) 620 | input_ids.append(text_ids) 621 | else: 622 | raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.') 623 | input_embs = torch.cat(input_embs, dim=1) 624 | input_ids = torch.cat(input_ids, dim=1) 625 | 626 | outputs = self.model.lm(inputs_embeds=input_embs, labels=input_ids, use_cache=False, output_hidden_states=True) 627 | return -outputs.loss.item() 628 | 629 | def load_fromage(model_dir: str) -> Fromage: 630 | model_args_path = os.path.join(model_dir, 'model_args.json') 631 | model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar') 632 | embs_paths = [s for s in glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl'))] 633 | 634 | if not os.path.exists(model_args_path): 635 | raise ValueError(f'model_args.json does not exist in {model_dir}.') 636 | if not os.path.exists(model_ckpt_path): 637 | raise ValueError(f'pretrained_ckpt.pth.tar does not exist in {model_dir}.') 638 | if len(embs_paths) == 0: 639 | raise ValueError(f'cc3m_embeddings_*.pkl files do not exist in {model_dir}.') 640 | 641 | # Load embeddings. 642 | # Construct embedding matrix for nearest neighbor lookup. 
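# Each cc3m_embeddings*.pkl is expected to be a dict with 'paths' (image URLs) and 'embeddings' (their precomputed retrieval embeddings); rows of emb_matrix line up with path_array.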
643 | path_array = [] 644 | emb_matrix = [] 645 | 646 | # These were precomputed for all CC3M images with `model.get_visual_embs(image, mode='retrieval')`. 647 | for p in embs_paths: 648 | with open(p, 'rb') as wf: 649 | train_embs_data = pkl.load(wf) 650 | path_array.extend(train_embs_data['paths']) 651 | emb_matrix.append(train_embs_data['embeddings']) 652 | emb_matrix = np.concatenate(emb_matrix, axis=0) 653 | 654 | # Number of paths should be equal to number of embeddings. 655 | assert len(path_array) == emb_matrix.shape[0], (len(path_array), emb_matrix.shape[0]) 656 | 657 | with open(model_args_path, 'r') as f: 658 | model_kwargs = json.load(f) 659 | 660 | # Initialize tokenizer. 661 | tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version']) 662 | tokenizer.pad_token = tokenizer.eos_token 663 | # Add special tokens to the model to enable [RET]. 664 | tokenizer.add_special_tokens({"cls_token": "<|image|>"}) 665 | tokenizer.add_tokens('[RET]') 666 | ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids 667 | assert len(ret_token_idx) == 1, ret_token_idx 668 | model_kwargs['retrieval_token_idx'] = ret_token_idx[0] 669 | args = namedtuple('args', model_kwargs)(**model_kwargs) 670 | 671 | # Initialize model for inference. 672 | model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix) 673 | model = model.eval() 674 | model = model.bfloat16() 675 | model = model.cuda() 676 | 677 | # Load pretrained linear mappings and [RET] embeddings. 678 | checkpoint = torch.load(model_ckpt_path) 679 | model.load_state_dict(checkpoint['state_dict'], strict=False) 680 | with torch.no_grad(): 681 | model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].squeeze().cpu().detach()) 682 | 683 | logit_scale = model.model.logit_scale.exp() 684 | emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device) 685 | emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True) 686 | emb_matrix = logit_scale * emb_matrix 687 | model.emb_matrix = emb_matrix 688 | 689 | return model 690 | 691 | -------------------------------------------------------------------------------- /fromage/prune_model_ckpt.py: -------------------------------------------------------------------------------- 1 | """Prune pretrained model weights to reduce size. 2 | 3 | This keeps only the weights that we finetune, and discards the pretrained LLM / visual encoder weights. 4 | """ 5 | import torch 6 | import json 7 | from collections import OrderedDict 8 | 9 | ckpt_path = 'ckpt.pth.tar' 10 | pruned_output_path = 'ckpt_pruned.pth.tar' 11 | model_args_path = 'model_args.json' 12 | 13 | if __name__ == '__main__': 14 | with open(model_args_path, 'r') as f: 15 | model_kwargs = json.load(f) 16 | ret_token_idx = model_kwargs['retrieval_token_idx'] 17 | 18 | with open(ckpt_path, 'rb') as f: 19 | checkpoint = torch.load(f) 20 | 21 | stripped_state_dict = { 22 | k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items() if 23 | ('.lm' not in k and '.visual_model' not in k) 24 | } 25 | stripped_state_dict = OrderedDict(sorted(stripped_state_dict.items())) 26 | 27 | del checkpoint['epoch'] 28 | print('Best score:', checkpoint['best_score']) 29 | del checkpoint['optimizer'] 30 | del checkpoint['scheduler'] 31 | for k, v in stripped_state_dict.items(): 32 | stripped_state_dict[k] = v.detach().clone() 33 | 34 | # Prune the pretrained token embeddings and keep just [RET]. 
35 | ret_embedding = stripped_state_dict['model.input_embeddings.weight'][ret_token_idx:ret_token_idx+1, :].detach().clone() 36 | stripped_state_dict['ret_input_embeddings.weight'] = ret_embedding 37 | 38 | with open(pruned_output_path, 'wb') as f: 39 | torch.save({'state_dict': stripped_state_dict}, f) 40 | -------------------------------------------------------------------------------- /fromage/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import subprocess 3 | import sys 4 | import shutil 5 | import torch 6 | import torch.distributed as dist 7 | from torchvision.transforms import functional as F 8 | from torchvision import transforms as T 9 | from transformers import AutoFeatureExtractor 10 | from PIL import Image, ImageDraw, ImageFont, ImageOps 11 | import requests 12 | from io import BytesIO 13 | 14 | import random 15 | 16 | 17 | def dump_git_status(out_file=sys.stdout, exclude_file_patterns=['*.ipynb', '*.th', '*.sh', '*.txt', '*.json']): 18 | """Logs git status to stdout.""" 19 | subprocess.call('git rev-parse HEAD', shell=True, stdout=out_file) 20 | subprocess.call('echo', shell=True, stdout=out_file) 21 | exclude_string = '' 22 | subprocess.call('git --no-pager diff -- . {}'.format(exclude_string), shell=True, stdout=out_file) 23 | 24 | 25 | def get_image_from_url(url: str): 26 | response = requests.get(url) 27 | img = Image.open(BytesIO(response.content)) 28 | img = img.resize((224, 224)) 29 | img = img.convert('RGB') 30 | return img 31 | 32 | 33 | def truncate_caption(caption: str) -> str: 34 | """Truncate captions at periods and newlines.""" 35 | caption = caption.strip('\n') 36 | trunc_index = caption.find('\n') + 1 37 | if trunc_index <= 0: 38 | trunc_index = caption.find('.') + 1 39 | if trunc_index > 0: 40 | caption = caption[:trunc_index] 41 | return caption 42 | 43 | 44 | def pad_to_size(x, size=256): 45 | delta_w = size - x.size[0] 46 | delta_h = size - x.size[1] 47 | padding = ( 48 | delta_w // 2, 49 | delta_h // 2, 50 | delta_w - (delta_w // 2), 51 | delta_h - (delta_h // 2), 52 | ) 53 | new_im = ImageOps.expand(x, padding) 54 | return new_im 55 | 56 | 57 | class RandCropResize(object): 58 | 59 | """ 60 | Randomly crops, then randomly resizes, then randomly crops again, an image. Mirroring the augmentations from https://arxiv.org/abs/2102.12092 61 | """ 62 | 63 | def __init__(self, target_size): 64 | self.target_size = target_size 65 | 66 | def __call__(self, img): 67 | img = pad_to_size(img, self.target_size) 68 | d_min = min(img.size) 69 | img = T.RandomCrop(size=d_min)(img) 70 | t_min = min(d_min, round(9 / 8 * self.target_size)) 71 | t_max = min(d_min, round(12 / 8 * self.target_size)) 72 | t = random.randint(t_min, t_max + 1) 73 | img = T.Resize(t)(img) 74 | if min(img.size) < 256: 75 | img = T.Resize(256)(img) 76 | return T.RandomCrop(size=self.target_size)(img) 77 | 78 | 79 | class SquarePad(object): 80 | """Pads image to square. 
81 | From https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9 82 | """ 83 | def __call__(self, image): 84 | max_wh = max(image.size) 85 | p_left, p_top = [(max_wh - s) // 2 for s in image.size] 86 | p_right, p_bottom = [max_wh - (s+pad) for s, pad in zip(image.size, [p_left, p_top])] 87 | padding = (p_left, p_top, p_right, p_bottom) 88 | return F.pad(image, padding, 0, 'constant') 89 | 90 | 91 | def create_image_of_text(text: str, width: int = 224, nrows: int = 2, color=(255, 255, 255), font=None) -> torch.Tensor: 92 | """Creates a (3, nrows * 14, width) image of text. 93 | 94 | Returns: 95 | cap_img: (3, 14 * nrows, width) image of wrapped text. 96 | """ 97 | height = 12 98 | padding = 5 99 | effective_width = width - 2 * padding 100 | # Create a black image to draw text on. 101 | cap_img = Image.new('RGB', (effective_width * nrows, height), color = (0, 0, 0)) 102 | draw = ImageDraw.Draw(cap_img) 103 | draw.text((0, 0), text, color, font=font or ImageFont.load_default()) 104 | cap_img = F.convert_image_dtype(F.pil_to_tensor(cap_img), torch.float32) # (3, height, W * nrows) 105 | cap_img = torch.split(cap_img, effective_width, dim=-1) # List of nrow elements of shape (3, height, W) 106 | cap_img = torch.cat(cap_img, dim=1) # (3, height * nrows, W) 107 | # Add zero padding. 108 | cap_img = torch.nn.functional.pad(cap_img, [padding, padding, 0, padding]) 109 | return cap_img 110 | 111 | 112 | def get_feature_extractor_for_model(model_name: str, image_size: int = 224, train: bool = True): 113 | print(f'Using HuggingFace AutoFeatureExtractor for {model_name}.') 114 | feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) 115 | return feature_extractor 116 | 117 | 118 | def get_pixel_values_for_model(feature_extractor, img): 119 | pixel_values = feature_extractor( 120 | img.convert('RGB'), 121 | return_tensors="pt").pixel_values[0, ...] # (3, H, W) 122 | return pixel_values 123 | 124 | 125 | def save_checkpoint(state, is_best, filename='checkpoint'): 126 | torch.save(state, filename + '.pth.tar') 127 | if is_best: 128 | shutil.copyfile(filename + '.pth.tar', filename + '_best.pth.tar') 129 | 130 | 131 | def accuracy(output, target, padding, topk=(1,)): 132 | """Computes the accuracy over the k top predictions for the specified values of k""" 133 | with torch.no_grad(): 134 | maxk = max(topk) 135 | if output.shape[-1] < maxk: 136 | print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.") 137 | 138 | maxk = min(maxk, output.shape[-1]) 139 | batch_size = target.size(0) 140 | 141 | # Take topk along the last dimension. 
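# Positions where the target equals `padding` are masked out below, so they count toward neither the numerator nor the denominator of the accuracy.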
142 | _, pred = output.topk(maxk, -1, True, True) # (N, T, topk) 143 | 144 | mask = (target != padding).type(target.dtype) 145 | target_expand = target[..., None].expand_as(pred) 146 | correct = pred.eq(target_expand) 147 | correct = correct * mask[..., None].expand_as(correct) 148 | 149 | res = [] 150 | for k in topk: 151 | correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True) 152 | res.append(correct_k.mul_(100.0 / mask.sum())) 153 | return res 154 | 155 | 156 | def get_params_count(model, max_name_len: int = 60): 157 | params = [(name[:max_name_len], p.numel(), str(tuple(p.shape)), p.requires_grad) for name, p in model.named_parameters()] 158 | total_trainable_params = sum([x[1] for x in params if x[-1]]) 159 | total_nontrainable_params = sum([x[1] for x in params if not x[-1]]) 160 | return params, total_trainable_params, total_nontrainable_params 161 | 162 | 163 | def get_params_count_str(model, max_name_len: int = 60): 164 | padding = 70 # Hardcoded depending on desired amount of padding and separators. 165 | params, total_trainable_params, total_nontrainable_params = get_params_count(model, max_name_len) 166 | param_counts_text = '' 167 | param_counts_text += '=' * (max_name_len + padding) + '\n' 168 | param_counts_text += f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |\n' 169 | param_counts_text += '-' * (max_name_len + padding) + '\n' 170 | for name, param_count, shape, trainable in params: 171 | param_counts_text += f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {param_count:>12,} |\n' 172 | param_counts_text += '-' * (max_name_len + padding) + '\n' 173 | param_counts_text += f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_trainable_params:>12,} |\n' 174 | param_counts_text += f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_nontrainable_params:>12,} |\n' 175 | param_counts_text += '=' * (max_name_len + padding) + '\n' 176 | return param_counts_text 177 | 178 | 179 | class Summary(Enum): 180 | NONE = 0 181 | AVERAGE = 1 182 | SUM = 2 183 | COUNT = 3 184 | 185 | 186 | class ProgressMeter(object): 187 | def __init__(self, num_batches, meters, prefix=""): 188 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 189 | self.meters = meters 190 | self.prefix = prefix 191 | 192 | def display(self, batch): 193 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 194 | entries += [str(meter) for meter in self.meters] 195 | print('\t'.join(entries)) 196 | 197 | def display_summary(self): 198 | entries = [" *"] 199 | entries += [meter.summary() for meter in self.meters] 200 | print(' '.join(entries)) 201 | 202 | def _get_batch_fmtstr(self, num_batches): 203 | num_digits = len(str(num_batches // 1)) 204 | fmt = '{:' + str(num_digits) + 'd}' 205 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 206 | 207 | 208 | class AverageMeter(object): 209 | """Computes and stores the average and current value""" 210 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 211 | self.name = name 212 | self.fmt = fmt 213 | self.summary_type = summary_type 214 | self.reset() 215 | 216 | def reset(self): 217 | self.val = 0 218 | self.avg = 0 219 | self.sum = 0 220 | self.count = 0 221 | 222 | def update(self, val, n=1): 223 | self.val = val 224 | self.sum += val * n 225 | self.count += n 226 | self.avg = self.sum / self.count 227 | 228 | def all_reduce(self): 229 | device = "cuda" if torch.cuda.is_available() else "cpu" 
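# The running sum and count are reduced together so that every worker recomputes the same global average.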
230 | total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device) 231 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) 232 | self.sum, self.count = total.tolist() 233 | self.avg = self.sum / self.count 234 | 235 | def __str__(self): 236 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 237 | return fmtstr.format(**self.__dict__) 238 | 239 | def summary(self): 240 | fmtstr = '' 241 | if self.summary_type is Summary.NONE: 242 | fmtstr = '' 243 | elif self.summary_type is Summary.AVERAGE: 244 | fmtstr = '{name} {avg:.3f}' 245 | elif self.summary_type is Summary.SUM: 246 | fmtstr = '{name} {sum:.3f}' 247 | elif self.summary_type is Summary.COUNT: 248 | fmtstr = '{name} {count:.3f}' 249 | else: 250 | raise ValueError('invalid summary type %r' % self.summary_type) 251 | 252 | return fmtstr.format(**self.__dict__) 253 | -------------------------------------------------------------------------------- /fromage_model/fromage_vis4/model_args.json: -------------------------------------------------------------------------------- 1 | { 2 | "opt_version": "facebook/opt-6.7b", 3 | "freeze_lm": true, 4 | "visual_encoder": "openai/clip-vit-large-patch14", 5 | "freeze_vm": true, 6 | "n_visual_tokens": 4, 7 | "use_image_embed_norm": false, 8 | "image_embed_dropout_prob": 0.0, 9 | "use_text_embed_layernorm": false, 10 | "text_embed_dropout_prob": 0.0, 11 | "shared_emb_dim": 256, 12 | "text_emb_layers": [ 13 | -1 14 | ], 15 | "retrieval_token_idx": 50266 16 | } -------------------------------------------------------------------------------- /fromage_model/fromage_vis4/pretrained_ckpt.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/fromage_model/fromage_vis4/pretrained_ckpt.pth.tar -------------------------------------------------------------------------------- /fromage_model/model_args.json: -------------------------------------------------------------------------------- 1 | { 2 | "opt_version": "facebook/opt-6.7b", 3 | "task": "multitask", 4 | "freeze_lm": true, 5 | "visual_encoder": "openai/clip-vit-large-patch14", 6 | "freeze_vm": true, 7 | "pretrained_visual": true, 8 | "use_pooler": true, 9 | "n_visual_tokens": 1, 10 | "image_embed_dropout_prob": 0.0, 11 | "text_embed_dropout_prob": 0.0, 12 | "shared_emb_dim": 256, 13 | "text_emb_layers": [ 14 | -1 15 | ], 16 | "append_retrieval_token": true, 17 | "num_appended_retrieval_tokens": 1, 18 | "input_prompt": "A picture of", 19 | "add_input_to_ret": true, 20 | "tunable_prompt_length": 0, 21 | "retrieval_token_idx": 50266 22 | } 23 | -------------------------------------------------------------------------------- /fromage_model/pretrained_ckpt.pth.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/fromage_model/pretrained_ckpt.pth.tar -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """Training example. 2 | 3 | Modified from https://github.com/pytorch/examples/blob/main/imagenet/main.py. 
4 | """ 5 | import argparse 6 | import json 7 | import os 8 | import sys 9 | import time 10 | import warnings 11 | 12 | import numpy as np 13 | from PIL import Image 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torch.backends.cudnn as cudnn 18 | import torch.distributed as dist 19 | import torch.optim 20 | from torch.optim.lr_scheduler import StepLR 21 | from warmup_scheduler import GradualWarmupScheduler 22 | import torch.multiprocessing as mp 23 | import torch.utils.data 24 | import torch.utils.data.distributed 25 | import torchvision.transforms as transforms 26 | import torchvision.datasets as datasets 27 | from torch.utils.tensorboard import SummaryWriter 28 | import torchvision 29 | 30 | from fromage import data 31 | from fromage import losses as losses_utils 32 | from fromage import models 33 | from fromage import utils 34 | from fromage import evaluate 35 | from transformers import AutoTokenizer 36 | 37 | # Disable HuggingFace tokenizer parallelism. 38 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 39 | 40 | # Available LLM models. 41 | llm_models = ['facebook/opt-125m', 'facebook/opt-350m', 'facebook/opt-1.3b', 42 | 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-13b', 'facebook/opt-30b', 43 | 'facebook/opt-66b'] 44 | datasets = ['cc3m'] 45 | best_score = 0 # Variable to keep track of best model so far. 46 | 47 | 48 | def parse_args(args): 49 | parser = argparse.ArgumentParser(description='FROMAGe training') 50 | parser.add_argument('--opt-version', default='facebook/opt-6.7b', 51 | choices=llm_models, 52 | help='OPT versions: ' + 53 | ' | '.join(llm_models) + 54 | ' (default: "facebook/opt-6.7b")') 55 | parser.add_argument('--visual-model', default='openai/clip-vit-large-patch14', type=str, 56 | help="Visual encoder to use.") 57 | parser.add_argument('-d', '--dataset', metavar='DATASET', help='Delimited list of datasets:' + 58 | ' | '.join(datasets), default='cc3m', 59 | type=lambda s: [x for x in s.split(',')]) 60 | 61 | parser.add_argument('--val-dataset', metavar='DATASET', default='cc3m', 62 | type=lambda s: [x for x in s.split(',')], 63 | help='Validation dataset: ' + 64 | ' | '.join(datasets) + 65 | ' (default: cc3m)') 66 | parser.add_argument('--dataset_dir', default='datasets', type=str, 67 | help='Dataset directory containing .tsv files.') 68 | parser.add_argument('--image-dir', default='./data/', type=str, 69 | help='Dataset directory containing image folders.') 70 | parser.add_argument('--log-base-dir', default='./runs/', type=str, 71 | help='Base directory to write logs and ckpts to.') 72 | parser.add_argument('--exp_name', default='frozen', type=str, 73 | help='Name of experiment, used for saving checkpoints.') 74 | 75 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 76 | help='number of data loading workers (default: 4)') 77 | parser.add_argument('--epochs', default=10, type=int, metavar='N', 78 | help='number of total epochs to run') 79 | parser.add_argument('--steps-per-epoch', default=2000, type=int, metavar='N', 80 | help='number of training steps per epoch') 81 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 82 | help='manual epoch number (useful on restarts)') 83 | parser.add_argument('--val-steps-per-epoch', default=-1, type=int, metavar='N', 84 | help='number of validation steps per epoch.') 85 | parser.add_argument('-b', '--batch-size', default=180, type=int, 86 | metavar='N', 87 | help='mini-batch size (default: 180), this is the total ' 88 | 'batch size 
of all GPUs on the current node when ' 89 | 'using Data Parallel or Distributed Data Parallel') 90 | parser.add_argument('--val-batch-size', default=None, type=int) 91 | parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float, 92 | metavar='LR', help='initial learning rate', dest='lr') 93 | parser.add_argument('--lr-warmup-steps', default=100, type=int, 94 | metavar='N', help='Number of steps to warm up lr.') 95 | parser.add_argument('--lr-schedule-step-size', default=10, type=int, 96 | metavar='N', help='Number of steps before decaying lr.') 97 | parser.add_argument('--lr-schedule-gamma', default=0.1, type=float, 98 | metavar='N', help='Decay parameter for learning rate scheduler.') 99 | parser.add_argument('--grad-accumulation-steps', default=1, type=int, metavar='N', 100 | help='number of gradient accumulation steps') 101 | parser.add_argument('--grad-clip', default=1.0, type=float, help='gradient clipping amount') 102 | 103 | parser.add_argument('--precision', default='fp32', type=str, choices=['fp32', 'fp16', 'bf16'], help="Precision to train in.") 104 | parser.add_argument('--cap-loss-scale', type=float, default=1.0, help="Scale on captioning loss.") 105 | parser.add_argument('--ret-loss-scale', type=float, default=1.0, help="Scale on retrieval loss.") 106 | 107 | parser.add_argument('--concat-captions-prob', type=float, default=0.5, help="Probability of concatenating two examples sequentially for captioning.") 108 | parser.add_argument('--concat-for-ret', action='store_true', default=False, help="Whether to concatenate examples for retrieval mode.") 109 | parser.add_argument('--input-prompt', default=None, type=str, help="Input prompt for the language model, if any.") 110 | 111 | parser.add_argument('--image-size', default=224, type=int, metavar='N', help='Size of images.') 112 | parser.add_argument('--use_image_embed_norm', action='store_true', default=False, help="Whether to use norm on the image embeddings to make them equal to language.") 113 | parser.add_argument('--image_embed_dropout_prob', type=float, default=0.0, help="Dropout probability on the image embeddings.") 114 | parser.add_argument('--use_text_embed_layernorm', action='store_true', default=False, help="Whether to use layer norm on the text embeddings for retrieval.") 115 | parser.add_argument('--text_embed_dropout_prob', type=float, default=0.0, help="Dropout probability on the text embeddings.") 116 | parser.add_argument('--shared-emb-dim', default=256, type=int, metavar='N', help='Embedding dimension for retrieval.') 117 | parser.add_argument('--text-emb-layers', help='Layer to use for text embeddings. 
OPT-2.7b has 33 layers.', default='-1', 118 | type=lambda s: [int(x) for x in s.split(',')]) 119 | 120 | parser.add_argument('--max-len', default=24, type=int, 121 | metavar='N', help='Maximum length to truncate captions / generations to.') 122 | parser.add_argument('--n-visual-tokens', default=1, type=int, 123 | metavar='N', help='Number of visual tokens to use for the Frozen model.') 124 | 125 | parser.add_argument('--beta1', default=0.9, type=float, metavar='M', help='beta1 for Adam') 126 | parser.add_argument('--beta2', default=0.95, type=float, metavar='M', help='beta2 for Adam') 127 | parser.add_argument('--wd', '--weight-decay', default=0.0, type=float, 128 | metavar='W', help='weight decay (default: 0.0)', dest='weight_decay') 129 | parser.add_argument('-p', '--print-freq', default=10, type=int, 130 | metavar='N', help='print frequency (default: 10)') 131 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 132 | help='path to latest checkpoint (default: none)') 133 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 134 | help='evaluate model on validation set') 135 | parser.add_argument('--world-size', default=-1, type=int, 136 | help='number of nodes for distributed training') 137 | parser.add_argument('--rank', default=-1, type=int, 138 | help='node rank for distributed training') 139 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:1337', type=str, 140 | help='url used to set up distributed training') 141 | parser.add_argument('--dist-backend', default='nccl', type=str, 142 | help='distributed backend') 143 | parser.add_argument('--seed', default=None, type=int, 144 | help='seed for initializing training. ') 145 | parser.add_argument('--gpu', default=None, type=int, 146 | help='GPU id to use.') 147 | parser.add_argument('--multiprocessing-distributed', action='store_true', 148 | help='Use multi-processing distributed training to launch ' 149 | 'N processes per node, which has N GPUs. This is the ' 150 | 'fastest way to use PyTorch for either single node or ' 151 | 'multi node data parallel training') 152 | return parser.parse_args(args) 153 | 154 | 155 | def main(args): 156 | args = parse_args(args) 157 | i = 1 158 | args.log_dir = os.path.join(args.log_base_dir, args.exp_name) 159 | while os.path.exists(args.log_dir): 160 | args.log_dir = os.path.join(args.log_base_dir, f'{args.exp_name}_{i}') 161 | i += 1 162 | os.makedirs(args.log_dir) 163 | 164 | with open(os.path.join(args.log_dir, f'args.json'), 'w') as wf: 165 | json.dump(vars(args), wf, indent=4) 166 | 167 | with open(os.path.join(args.log_dir, f'git_info.txt'), 'w') as wf: 168 | utils.dump_git_status(out_file=wf) 169 | 170 | print(f'Logging to {args.log_dir}.') 171 | 172 | if args.seed is not None: 173 | torch.manual_seed(args.seed) 174 | cudnn.deterministic = True 175 | warnings.warn('You have chosen to seed training. ' 176 | 'This will turn on the CUDNN deterministic setting, ' 177 | 'which can slow down your training considerably! ' 178 | 'You may see unexpected behavior when restarting ' 179 | 'from checkpoints.') 180 | 181 | if args.gpu is not None: 182 | warnings.warn('You have chosen a specific GPU. 
This will completely ' 183 | 'disable data parallelism.') 184 | 185 | if args.dist_url == "env://" and args.world_size == -1: 186 | args.world_size = int(os.environ["WORLD_SIZE"]) 187 | 188 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 189 | 190 | ngpus_per_node = torch.cuda.device_count() 191 | if args.multiprocessing_distributed: 192 | # Since we have ngpus_per_node processes per node, the total world_size 193 | # needs to be adjusted accordingly 194 | args.world_size = ngpus_per_node * args.world_size 195 | # Use torch.multiprocessing.spawn to launch distributed processes: the 196 | # main_worker process function 197 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 198 | else: 199 | # Simply call main_worker function 200 | main_worker(args.gpu, ngpus_per_node, args) 201 | 202 | 203 | def main_worker(gpu, ngpus_per_node, args): 204 | """Setup code.""" 205 | global best_score 206 | args.gpu = gpu 207 | 208 | if args.gpu is not None: 209 | print("Use GPU: {} for training".format(args.gpu)) 210 | 211 | if args.distributed: 212 | if args.dist_url == "env://" and args.rank == -1: 213 | args.rank = int(os.environ["RANK"]) 214 | if args.multiprocessing_distributed: 215 | # For multiprocessing distributed training, rank needs to be the 216 | # global rank among all the processes 217 | args.rank = args.rank * ngpus_per_node + gpu 218 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 219 | world_size=args.world_size, rank=args.rank) 220 | 221 | # Create model 222 | model_args = models.FrozenArgs() 223 | model_args.opt_version = args.opt_version 224 | model_args.freeze_lm = True 225 | model_args.visual_encoder = args.visual_model 226 | model_args.freeze_vm = True 227 | model_args.n_visual_tokens = args.n_visual_tokens 228 | model_args.use_image_embed_norm = args.use_image_embed_norm 229 | model_args.image_embed_dropout_prob = args.image_embed_dropout_prob 230 | model_args.use_text_embed_layernorm = args.use_text_embed_layernorm 231 | model_args.text_embed_dropout_prob = args.text_embed_dropout_prob 232 | model_args.shared_emb_dim = args.shared_emb_dim 233 | model_args.text_emb_layers = args.text_emb_layers 234 | 235 | tokenizer = AutoTokenizer.from_pretrained(args.opt_version, use_fast=False) 236 | # Add an image token for loss masking (and visualization) purposes. 237 | tokenizer.add_special_tokens({"cls_token": "<|image|>"}) # add special image token to tokenizer 238 | print('Adding [RET] token to vocabulary.') 239 | print('Before adding new token, tokenizer("[RET]") =', tokenizer('[RET]', add_special_tokens=False)) 240 | num_added_tokens = tokenizer.add_tokens('[RET]') 241 | print(f'After adding {num_added_tokens} new tokens, tokenizer("[RET]") =', tokenizer('[RET]', add_special_tokens=False)) 242 | ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids 243 | assert len(ret_token_idx) == 1, ret_token_idx 244 | model_args.retrieval_token_idx = ret_token_idx[0] 245 | args.retrieval_token_idx = ret_token_idx[0] 246 | 247 | # Save model args to disk. 248 | with open(os.path.join(args.log_dir, 'model_args.json'), 'w') as f: 249 | json.dump(vars(model_args), f, indent=4) 250 | 251 | model = models.Fromage(tokenizer, model_args) 252 | if args.precision == 'fp16': 253 | model = model.half() 254 | elif args.precision == 'bf16': 255 | model = model.bfloat16() 256 | 257 | # Print parameters and count of model. 
258 | param_counts_text = utils.get_params_count_str(model) 259 | with open(os.path.join(args.log_dir, 'param_count.txt'), 'w') as f: 260 | f.write(param_counts_text) 261 | 262 | # Log trainable parameters to Tensorboard. 263 | _, total_trainable_params, total_nontrainable_params = utils.get_params_count(model) 264 | writer = SummaryWriter(args.log_dir) 265 | writer.add_scalar('params/total', total_trainable_params + total_nontrainable_params, 0) 266 | writer.add_scalar('params/total_trainable', total_trainable_params, 0) 267 | writer.add_scalar('params/total_non_trainable', total_nontrainable_params, 0) 268 | writer.close() 269 | 270 | if not torch.cuda.is_available(): 271 | print('WARNING: using CPU, this will be slow!') 272 | model = torch.nn.DataParallel(model) 273 | elif args.distributed: 274 | # For multiprocessing distributed, DistributedDataParallel constructor 275 | # should always set the single device scope, otherwise, 276 | # DistributedDataParallel will use all available devices. 277 | if args.gpu is not None: 278 | torch.cuda.set_device(args.gpu) 279 | model.cuda(args.gpu) 280 | # When using a single GPU per process and per 281 | # DistributedDataParallel, we need to divide the batch size 282 | # ourselves based on the total number of GPUs of the current node. 283 | args.batch_size = int(args.batch_size / ngpus_per_node) 284 | args.val_batch_size = int((args.val_batch_size or args.batch_size) / ngpus_per_node) 285 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 286 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 287 | else: 288 | model.cuda() 289 | # DistributedDataParallel will divide and allocate batch_size to all 290 | # available GPUs if device_ids are not set 291 | model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=False) 292 | elif args.gpu is not None: 293 | torch.cuda.set_device(args.gpu) 294 | model = model.cuda(args.gpu) 295 | else: 296 | model = torch.nn.DataParallel(model).cuda() 297 | 298 | # define loss function (criterion), optimizer, and learning rate scheduler 299 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 300 | optimizer_cls = torch.optim.AdamW 301 | print('Using torch.optim.AdamW as the optimizer.') 302 | optimizer = optimizer_cls(model.parameters(), args.lr, 303 | betas=(args.beta1, args.beta2), 304 | weight_decay=args.weight_decay, 305 | eps=1e-8) 306 | 307 | """Sets the learning rate to the initial LR decayed by 10 every 5 epochs""" 308 | scheduler_steplr = StepLR(optimizer, step_size=args.lr_schedule_step_size * args.steps_per_epoch, gamma=args.lr_schedule_gamma) 309 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.lr_warmup_steps, after_scheduler=scheduler_steplr) 310 | 311 | # optionally resume from a checkpoint 312 | if args.resume: 313 | if os.path.isfile(args.resume): 314 | print("=> loading checkpoint '{}'".format(args.resume)) 315 | if args.gpu is None: 316 | checkpoint = torch.load(args.resume) 317 | else: 318 | # Map model to be loaded to specified single gpu. 
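# map_location ensures the checkpoint tensors land on this process's GPU rather than the device they were saved from.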
319 | loc = 'cuda:{}'.format(args.gpu) 320 | checkpoint = torch.load(args.resume, map_location=loc) 321 | args.start_epoch = checkpoint['epoch'] 322 | best_score = checkpoint.get('best_score', 0) 323 | model.load_state_dict(checkpoint['state_dict']) 324 | optimizer.load_state_dict(checkpoint['optimizer']) 325 | scheduler.load_state_dict(checkpoint['scheduler']) 326 | print("=> loaded checkpoint '{}' (epoch {})" 327 | .format(args.resume, checkpoint['epoch'])) 328 | else: 329 | print("=> no checkpoint found at '{}'".format(args.resume)) 330 | 331 | cudnn.benchmark = True 332 | 333 | # Data loading code 334 | train_dataset = data.get_dataset(args, 'train', tokenizer) 335 | val_dataset = data.get_dataset(args, 'val', tokenizer) 336 | print(f'Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples.') 337 | 338 | if args.distributed: 339 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, drop_last=True) 340 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True) 341 | else: 342 | train_sampler = None 343 | val_sampler = None 344 | 345 | train_loader = torch.utils.data.DataLoader( 346 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 347 | num_workers=args.workers, pin_memory=True, sampler=train_sampler) 348 | val_loader = torch.utils.data.DataLoader( 349 | val_dataset, batch_size=(args.val_batch_size or args.batch_size), shuffle=False, 350 | num_workers=args.workers, pin_memory=True, sampler=val_sampler) 351 | 352 | if args.evaluate: 353 | evaluate.validate(val_loader, model, tokenizer, criterion, epoch, args) 354 | return 355 | 356 | for epoch in range(args.start_epoch, args.epochs): 357 | if epoch == 0: 358 | evaluate.validate(val_loader, model, tokenizer, criterion, epoch-1, args) 359 | if args.distributed: 360 | train_sampler.set_epoch(epoch) 361 | 362 | # train for one epoch 363 | train(train_loader, model, tokenizer, criterion, optimizer, epoch, scheduler, args) 364 | 365 | # evaluate on validation set 366 | eval_score = evaluate.validate(val_loader, model, tokenizer, criterion, epoch, args) 367 | 368 | # remember best score and save checkpoint 369 | is_best = eval_score > best_score 370 | best_score = max(eval_score, best_score) 371 | 372 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 373 | and args.rank % ngpus_per_node == 0): 374 | utils.save_checkpoint({ 375 | 'epoch': epoch + 1, 376 | 'state_dict': model.state_dict(), 377 | 'best_score': best_score, 378 | 'optimizer' : optimizer.state_dict(), 379 | 'scheduler' : scheduler.state_dict() 380 | }, is_best, os.path.join(args.log_dir, 'ckpt')) 381 | 382 | 383 | def train(train_loader, model, tokenizer, criterion, optimizer, epoch, scheduler, args): 384 | """Main training loop.""" 385 | ngpus_per_node = torch.cuda.device_count() 386 | batch_time = utils.AverageMeter('Time', ':6.3f') 387 | cap_time = utils.AverageMeter('CaptioningTime', ':6.3f') 388 | ret_time = utils.AverageMeter('RetrievalTime', ':6.3f') 389 | data_time = utils.AverageMeter('Data', ':6.3f') 390 | losses = utils.AverageMeter('Loss', ':.4e') 391 | ce_losses = utils.AverageMeter('CeLoss', ':.4e') 392 | top1 = utils.AverageMeter('Acc@1', ':6.2f') 393 | top5 = utils.AverageMeter('Acc@5', ':6.2f') 394 | cont_losses = utils.AverageMeter('ContLoss', ':.4e') 395 | top1_caption = utils.AverageMeter('AccCaption@1', ':6.2f') 396 | top5_caption = utils.AverageMeter('AccCaption@5', ':6.2f') 397 | top1_image = 
utils.AverageMeter('AccImage@1', ':6.2f') 398 | top5_image = utils.AverageMeter('AccImage@5', ':6.2f') 399 | 400 | writer = SummaryWriter(args.log_dir) 401 | 402 | progress = utils.ProgressMeter( 403 | args.steps_per_epoch, 404 | [batch_time, losses, ce_losses, cont_losses, top1, top5], 405 | prefix="Epoch: [{}]".format(epoch)) 406 | 407 | # switch to train mode 408 | model.train() 409 | 410 | end = time.time() 411 | 412 | for i, (image_paths, images, caption_images, tgt_tokens, token_len) in enumerate(train_loader): 413 | actual_step = epoch * args.steps_per_epoch + i + 1 414 | # measure data loading time 415 | data_time.update(time.time() - end) 416 | 417 | if torch.cuda.is_available(): 418 | images = images.cuda(args.gpu, non_blocking=True) 419 | tgt_tokens = tgt_tokens.cuda(args.gpu, non_blocking=True) 420 | token_len = token_len.cuda(args.gpu, non_blocking=True) 421 | 422 | if args.precision == 'fp16': 423 | images = images.half() 424 | elif args.precision == 'bf16': 425 | images = images.bfloat16() 426 | 427 | model_modes = ['captioning', 'retrieval'] 428 | loss = 0 429 | 430 | for model_mode in model_modes: 431 | mode_start = time.time() 432 | # compute output 433 | concat_captions = np.random.uniform(0, 1) < args.concat_captions_prob 434 | if not args.concat_for_ret: 435 | concat_captions = concat_captions and model_mode == 'captioning' 436 | 437 | (model_output, full_labels, last_embedding, _, visual_embs) = model( 438 | images, tgt_tokens, token_len, mode=model_mode, concat_captions=concat_captions, inference=False) 439 | output = model_output.logits 440 | 441 | # Measure captioning accuracy for multi-task models and next-token prediction for retrieval models. 442 | if model_mode == 'captioning': 443 | acc1, acc5 = utils.accuracy(output[:, :-1, :], full_labels[:, 1:], -100, topk=(1, 5)) 444 | top1.update(acc1[0], images.size(0)) 445 | top5.update(acc5[0], images.size(0)) 446 | 447 | ce_loss = model_output.loss 448 | if model_mode == 'captioning': 449 | ce_loss = ce_loss * args.cap_loss_scale 450 | elif model_mode == 'retrieval': 451 | ce_loss = ce_loss * args.ret_loss_scale 452 | else: 453 | raise NotImplementedError 454 | 455 | loss += ce_loss 456 | ce_losses.update(ce_loss.item(), images.size(0)) 457 | 458 | if model_mode == 'retrieval': 459 | # Cross replica concat for embeddings. 460 | if args.distributed: 461 | all_visual_embs = [torch.zeros_like(visual_embs) for _ in range(dist.get_world_size())] 462 | all_last_embedding = [torch.zeros_like(last_embedding) for _ in range(dist.get_world_size())] 463 | dist.all_gather(all_visual_embs, visual_embs) 464 | dist.all_gather(all_last_embedding, last_embedding) 465 | # Overwrite with embeddings produced on this replace, which have the gradient. 466 | all_visual_embs[dist.get_rank()] = visual_embs 467 | all_last_embedding[dist.get_rank()] = last_embedding 468 | visual_embs = torch.cat(all_visual_embs) 469 | last_embedding = torch.cat(all_last_embedding) 470 | 471 | start_idx = args.rank * images.shape[0] 472 | end_idx = start_idx + images.shape[0] 473 | 474 | logits_per_image = visual_embs @ last_embedding.t() 475 | logits_per_text = logits_per_image.t() 476 | if i == 0: 477 | print(f'Running contrastive loss over logits_per_text.shape = {logits_per_text.shape} and logits_per_image.shape = {logits_per_image.shape}') 478 | 479 | # Compute contrastive losses for retrieval. 
480 | caption_loss = losses_utils.contrastive_loss(logits_per_text) 481 | image_loss = losses_utils.contrastive_loss(logits_per_image) 482 | caption_acc1, caption_acc5 = losses_utils.contrastive_acc(logits_per_text, topk=(1, 5)) 483 | image_acc1, image_acc5 = losses_utils.contrastive_acc(logits_per_image, topk=(1, 5)) 484 | loss += args.ret_loss_scale * (caption_loss + image_loss) / 2.0 485 | cont_losses.update(loss.item(), images.size(0)) 486 | 487 | # measure accuracy and record loss 488 | top1_caption.update(caption_acc1[0], images.size(0)) 489 | top5_caption.update(caption_acc5[0], images.size(0)) 490 | top1_image.update(image_acc1[0], images.size(0)) 491 | top5_image.update(image_acc5[0], images.size(0)) 492 | 493 | if model_mode == 'retrieval': 494 | ret_time.update(time.time() - mode_start) 495 | elif model_mode == 'captioning': 496 | cap_time.update(time.time() - mode_start) 497 | 498 | loss = loss / args.grad_accumulation_steps 499 | losses.update(loss.item(), images.size(0)) 500 | loss.backward() 501 | 502 | # Update weights 503 | if ((i + 1) % args.grad_accumulation_steps == 0) or (i == args.steps_per_epoch - 1): 504 | # Zero out gradients of the embedding matrix outside of [RET]. 505 | for param in model.module.model.input_embeddings.parameters(): 506 | assert param.grad.shape[0] == len(tokenizer) 507 | # Keep other embeddings frozen. 508 | mask = torch.arange(param.grad.shape[0]) != args.retrieval_token_idx 509 | param.grad[mask, :] = 0 510 | 511 | # compute gradient and do SGD step 512 | if args.grad_clip > 0: 513 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 514 | optimizer.step() 515 | optimizer.zero_grad() 516 | 517 | with torch.no_grad(): 518 | # Normalize trainable embeddings. 519 | frozen_norm = torch.norm(model.module.model.input_embeddings.weight[:-1, :], dim=1).mean(0) 520 | trainable_weight = model.module.model.input_embeddings.weight[-1, :] 521 | model.module.model.input_embeddings.weight[-1, :].div_(torch.norm(trainable_weight) / frozen_norm) 522 | 523 | # measure elapsed time 524 | batch_time.update(time.time() - end) 525 | end = time.time() 526 | 527 | if actual_step == 1 or (i + 1) % args.print_freq == 0: 528 | ex_per_sec = args.batch_size / batch_time.avg 529 | if args.distributed: 530 | batch_time.all_reduce() 531 | data_time.all_reduce() 532 | ex_per_sec = (args.batch_size / batch_time.avg) * ngpus_per_node 533 | 534 | losses.all_reduce() 535 | ce_losses.all_reduce() 536 | top1.all_reduce() 537 | top5.all_reduce() 538 | ret_time.all_reduce() 539 | cont_losses.all_reduce() 540 | top1_caption.all_reduce() 541 | top5_caption.all_reduce() 542 | top1_image.all_reduce() 543 | top5_image.all_reduce() 544 | cap_time.all_reduce() 545 | 546 | progress.display(i + 1) 547 | 548 | writer.add_scalar('train/loss', losses.avg, actual_step) 549 | writer.add_scalar('train/ce_loss', ce_losses.avg, actual_step) 550 | writer.add_scalar('train/seq_top1_acc', top1.avg, actual_step) 551 | writer.add_scalar('train/seq_top5_acc', top5.avg, actual_step) 552 | writer.add_scalar('train/contrastive_loss', cont_losses.avg, actual_step) 553 | writer.add_scalar('train/t2i_top1_acc', top1_caption.avg, actual_step) 554 | writer.add_scalar('train/t2i_top5_acc', top5_caption.avg, actual_step) 555 | writer.add_scalar('train/i2t_top1_acc', top1_image.avg, actual_step) 556 | writer.add_scalar('train/i2t_top5_acc', top5_image.avg, actual_step) 557 | writer.add_scalar('metrics/total_secs_per_batch', batch_time.avg, actual_step) 558 | 
writer.add_scalar('metrics/total_secs_captioning', cap_time.avg, actual_step) 559 | writer.add_scalar('metrics/total_secs_retrieval', ret_time.avg, actual_step) 560 | writer.add_scalar('metrics/data_secs_per_batch', data_time.avg, actual_step) 561 | writer.add_scalar('metrics/examples_per_sec', ex_per_sec, actual_step) 562 | 563 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed 564 | and args.rank % ngpus_per_node == 0): 565 | image_bs = images.shape[0] 566 | normalized_images = images - images.min() 567 | normalized_images /= normalized_images.max() # (N, 3, H, W) 568 | max_images_to_show = 16 569 | 570 | # Append caption text. 571 | pred_tokens = output[:, args.n_visual_tokens-1:-1, :].argmax(dim=-1) 572 | generated_captions = tokenizer.batch_decode(pred_tokens, skip_special_tokens=False) 573 | 574 | # Log image (and generated caption) outputs to Tensorboard. 575 | if model_mode == 'captioning': 576 | # Create generated caption text. 577 | generated_cap_images = torch.stack([ 578 | utils.create_image_of_text( 579 | generated_captions[i].encode('ascii', 'ignore'), 580 | width=normalized_images.shape[3], 581 | color=(255, 255, 0)) 582 | for i in range(len(generated_captions))], axis=0) 583 | 584 | # Duplicate captions if we concatenated them. 585 | if (args.concat_captions_prob > 0 and model_mode == 'captioning' and generated_cap_images.shape[0] != caption_images.shape[0]): 586 | generated_cap_images = torch.cat([generated_cap_images, generated_cap_images], axis=0) 587 | 588 | display_images = torch.cat([normalized_images.float().cpu(), caption_images, generated_cap_images], axis=2)[:max_images_to_show] 589 | grid = torchvision.utils.make_grid(display_images, nrow=int(max_images_to_show ** 0.5), padding=4) 590 | writer.add_image('train/images_gen_cap', grid, actual_step) 591 | 592 | # Retrieved images (from text). 593 | retrieved_image_idx = logits_per_text[:image_bs, :image_bs].argmax(-1) 594 | t2i_images = torch.stack( 595 | [normalized_images[retrieved_image_idx[i], ...] for i in range(len(retrieved_image_idx))], 596 | axis=0) 597 | t2i_images = torch.cat([t2i_images.float().cpu(), caption_images], axis=2)[:max_images_to_show] 598 | t2i_grid = torchvision.utils.make_grid(t2i_images, nrow=int(max_images_to_show ** 0.5), padding=4) 599 | writer.add_image('train/t2i_ret', t2i_grid, actual_step) 600 | 601 | # Retrieved text (from image). 602 | retrieved_text_idx = logits_per_image[:image_bs, :image_bs].argmax(-1) 603 | retrieved_text = torch.stack( 604 | [caption_images[retrieved_text_idx[i], ...] for i in range(len(retrieved_text_idx))], 605 | axis=0) 606 | i2t_images = torch.cat([normalized_images.float().cpu(), retrieved_text], axis=2)[:max_images_to_show] 607 | i2t_grid = torchvision.utils.make_grid(i2t_images, nrow=int(max_images_to_show ** 0.5), padding=4) 608 | writer.add_image('train/i2t_ret', i2t_grid, actual_step) 609 | 610 | batch_time.reset() 611 | cap_time.reset() 612 | ret_time.reset() 613 | data_time.reset() 614 | losses.reset() 615 | ce_losses.reset() 616 | top1.reset() 617 | top5.reset() 618 | cont_losses.reset() 619 | top1_caption.reset() 620 | top5_caption.reset() 621 | top1_image.reset() 622 | top5_image.reset() 623 | 624 | if i == args.steps_per_epoch - 1: 625 | break 626 | 627 | scheduler.step() 628 | curr_lr = scheduler.get_last_lr() 629 | if (actual_step == 1) or (i + 1) % args.print_freq == 0: 630 | # Write current learning rate to Tensorboard. 
631 | writer = SummaryWriter(args.log_dir) 632 | writer.add_scalar('train/lr', curr_lr[0], actual_step) 633 | writer.close() 634 | 635 | writer.close() 636 | 637 | 638 | if __name__ == '__main__': 639 | main(sys.argv[1:]) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs==22.2.0 2 | certifi==2022.12.7 3 | charset-normalizer==3.0.1 4 | contourpy==1.0.7 5 | cycler==0.11.0 6 | einops==0.4.1 7 | exceptiongroup==1.1.0 8 | filelock==3.9.0 9 | fonttools==4.38.0 10 | huggingface-hub==0.12.0 11 | idna==3.4 12 | iniconfig==2.0.0 13 | kiwisolver==1.4.4 14 | matplotlib==3.6.3 15 | numpy==1.24.2 16 | packaging==23.0 17 | pandas==1.5.3 18 | Pillow==9.4.0 19 | pluggy==1.0.0 20 | pyparsing==3.0.9 21 | pytest==7.2.1 22 | python-dateutil==2.8.2 23 | PyYAML==6.0 24 | regex==2022.10.31 25 | requests==2.28.2 26 | six==1.16.0 27 | tensorboard==2.12.0 28 | tensorboard-data-server==0.7.0 29 | tensorboard-plugin-wit==1.8.1 30 | tokenizers==0.12.1 31 | tomli==2.0.1 32 | torch==1.11.0 33 | torchaudio==0.11.0 34 | torchmetrics==0.9.3 35 | torchvision==0.12.0 36 | tqdm==4.64.1 37 | transformers==4.21.3 38 | typing_extensions==4.4.0 39 | urllib3==1.26.14 40 | git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git 41 | -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/teaser.png -------------------------------------------------------------------------------- /teaser_gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/teaser_gif.gif -------------------------------------------------------------------------------- /test_main.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import tempfile 3 | 4 | import unittest 5 | import argparse 6 | import main 7 | import os 8 | 9 | 10 | def get_base_args(): 11 | args = [ 12 | '-b', '2', '--opt-version', 'facebook/opt-125m', 13 | '--val-steps-per-epoch', '2', '--epochs', '1', '--steps-per-epoch', '2', 14 | '--text-emb-layers', '-1', '--shared-emb-dim', '256', 15 | '--n-visual-tokens', '1', '--visual-model', 'openai/clip-vit-base-patch32', '--concat-captions-prob', '0.5'] 16 | return args 17 | 18 | def check_workdir_outputs(workdir_path): 19 | workdir_content = os.listdir(workdir_path) 20 | print('workdir content: %s', workdir_content) 21 | 22 | assert 'ckpt.pth.tar' in workdir_content 23 | assert 'model_args.json' in workdir_content 24 | assert 'param_count.txt' in workdir_content 25 | assert 'git_info.txt' in workdir_content 26 | assert any(['events.out.tfevents' in fn for fn in workdir_content]) 27 | 28 | 29 | class MultitaskTrainTest(unittest.TestCase): 30 | """Test captioning.""" 31 | def test_train_and_evaluate(self): 32 | workdir = tempfile.mkdtemp() 33 | proj_root_dir = pathlib.Path(__file__).parents[0] 34 | exp_name = 'test_multitask' 35 | 36 | parser = argparse.ArgumentParser(description='Unit test parser') 37 | args = get_base_args() + ['--log-base-dir', workdir, '--exp_name', exp_name] 38 | main.main(args) 39 | check_workdir_outputs(os.path.join(workdir, exp_name)) 40 | 41 | 42 | if __name__ == '__main__': 43 | unittest.main() 44 
| --------------------------------------------------------------------------------