├── .gitattributes
├── .gitignore
├── FROMAGe_example_notebook.ipynb
├── LICENSE
├── README.md
├── datasets
│   ├── cc3m_train.tsv
│   └── cc3m_val.tsv
├── demo
│   ├── app.py
│   └── share_btn.py
├── evals
│   ├── VIST_Contextual_Image_Retrieval.ipynb
│   ├── VisDial_Inference_IT2T_Generation.ipynb
│   ├── VisDial_Inference_T2I_Retrieval.ipynb
│   ├── eval_visdial_generation.py
│   ├── eval_visdial_retrieval.py
│   └── eval_vist_retrieval.py
├── fromage
│   ├── __init__.py
│   ├── data.py
│   ├── evaluate.py
│   ├── extract_img_embs.py
│   ├── losses.py
│   ├── models.py
│   ├── prune_model_ckpt.py
│   └── utils.py
├── fromage_model
│   ├── fromage_vis4
│   │   ├── model_args.json
│   │   └── pretrained_ckpt.pth.tar
│   ├── model_args.json
│   └── pretrained_ckpt.pth.tar
├── main.py
├── requirements.txt
├── teaser.png
├── teaser_gif.gif
└── test_main.py
/.gitattributes:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/.gitattributes
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.pyc
3 | __pycache__
4 | .pytest_cache
5 | venv
6 | runs/
7 | data/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Grounding Language Models to Images for Multimodal Inputs and Outputs
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | This repository hosts the code and model weights for FROMAGe.
10 |
11 | [Paper](https://arxiv.org/abs/2301.13823) | [Project Webpage](https://jykoh.com/fromage) | [Demo](https://huggingface.co/spaces/jykoh/fromage)
12 |
13 |
14 | ## Setup instructions
15 |
16 | ### Environment
17 | Set up a new virtualenv, and install required libraries:
18 | ```
19 | python -m venv venv
20 | source venv/bin/activate
21 | pip install -r requirements.txt
22 | ```
23 |
24 | Add the `fromage` library to PYTHONPATH:
25 | ```
26 | export PYTHONPATH=$PYTHONPATH:/home/path/to/fromage/
27 | ```
28 |
29 | ### Pretrained Checkpoints
30 |
31 | The FROMAGe model weights (linear layers and [RET] embedding) are small (around 11MB), and are included in this Git repo. They will be in the `fromage_model/` folder after cloning. The checkpoint and model config in `fromage_model/` reproduce the results reported in our paper.
32 |
33 | We have also included a second model trained with a stronger visual linear layer (4 visual tokens instead of 1), located at `fromage_model/fromage_vis4`. This model generally performs better in dialogue settings and requires less tuning of inference-time hyperparameters, as it is able to better represent more complex images.
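
Both model directories share the same layout (a `model_args.json` and a `pretrained_ckpt.pth.tar`), so either variant can be loaded with `models.load_fromage`. The snippet below is a minimal sketch following the evaluation notebooks; loading `fromage_vis4` by pointing at its directory is an assumption based on that shared layout:

```
# Minimal loading sketch (paths assume you are at the repo root).
from fromage import models

model = models.load_fromage('./fromage_model/')                     # default model from the paper
model_vis4 = models.load_fromage('./fromage_model/fromage_vis4/')   # assumed: vis4 variant loads the same way
```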
34 |
35 | ### Precomputed Embeddings For Image Retrieval
36 |
37 | The visual embeddings for Conceptual Captions images with valid URLs are precomputed and stored at this [URL](https://drive.google.com/file/d/1wMojZNqEwApNlsCZVvSgQVtZLgbeLoKi/view?usp=share_link). These are used to enable the model to retrieve images. The embeddings take up around 3GB, and are compatible with both model configs we provide. Download the files and place `cc3m_embeddings.pkl` into the `fromage_model/` directory.
38 |
39 | If you wish to precompute these embeddings for a different set of image URLs or for a different model, edit `fromage/extract_img_embs.py` with the list of image URLs and run it:
40 |
41 | ```python fromage/extract_img_embs.py```
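
After downloading, a quick sanity check that the embeddings file is readable can be done with `pickle`; this sketch only confirms the file loads and makes no assumptions about its internal structure:

```
# Hedged sanity check: confirms fromage_model/cc3m_embeddings.pkl loads without error.
import pickle

with open('fromage_model/cc3m_embeddings.pkl', 'rb') as f:
    emb_data = pickle.load(f)
print(type(emb_data))
```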
42 |
43 |
44 | ## Inference
45 |
46 | Check out `FROMAGe_example_notebook.ipynb` for examples on calling the model for inference. Several of the figures presented in the paper are reproduced in this notebook using greedy decoding of the model. Note that there may be minor differences in image outputs due to CC3M images being lost over time.
47 |
48 |
49 | ## Training
50 |
51 | ### Preparing CC3M
52 |
53 | Our model is trained on the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions) dataset. After following the instructions on the website to download the captions and images, format the data into `.tsv` files as follows:
54 |
55 | ```
56 | caption image
57 | A picture of a cat cat.png
58 | Mountains mountain.png
59 | ```
60 | where each line contains the caption followed by the filename of the image file. Save these `.tsv` files into the `datasets/` folder (the default names expected are `cc3m_train.tsv` and `cc3m_val.tsv`). The repo contains two placeholder files, which you will have to replace with the appropriate data.
61 |
62 | The corresponding image files should be saved in the `data/` directory. The directory can be changed with the `--image-dir` runtime flag.
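
For reference, here is a minimal sketch of writing a file in this format with the standard `csv` module (the captions and filenames are placeholders, and tab-separated columns are assumed from the `.tsv` extension):

```
import csv

rows = [('A picture of a cat', 'cat.png'), ('Mountains', 'mountain.png')]  # placeholder data
with open('datasets/cc3m_train.tsv', 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerow(['caption', 'image'])  # header row, as in the placeholder files
    writer.writerows(rows)
```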
63 |
64 |
65 | ### Training FROMAGe
66 |
67 | After preparing CC3M as detailed above, you can start a new training job with the following command:
68 |
69 | ```
70 | randport=$(shuf -i8000-9999 -n1) # Generate a random port number
71 | python -u main.py \
72 | --dist-url "tcp://127.0.0.1:${randport}" --dist-backend 'nccl' \
73 | --multiprocessing-distributed --world-size 1 --rank 0 \
74 | --dataset=cc3m --val-dataset=cc3m \
75 | --opt-version='facebook/opt-6.7b' --visual-model='openai/clip-vit-large-patch14' \
76 | --exp_name='fromage_exp' --image-dir='data/' --log-base-dir='runs/' \
77 | --batch-size=180 --val-batch-size=100 --learning-rate=0.0003 --precision='bf16' --print-freq=100
78 | ```
79 |
80 | On a single A6000 GPU, the model converges within 24 hours (with a batch size of 180). For GPUs with less available memory, you may need to reduce the batch size, enable gradient accumulation, or adjust other hyperparameters to get good performance. You may also have to disable NCCL P2P with `export NCCL_P2P_DISABLE=1` if you run into issues.
81 |
82 |
83 | ### Pruning Model Weights
84 |
85 | As FROMAGe only consists of a few pretrained linear layers and the `[RET]` embedding, we can discard most of the pretrained weights to save disk space. If you have trained a new model and wish to do the same, you can use `fromage/prune_model_ckpt.py` to prune the model weights. We used the same script to create the weights in the `fromage_model` directory.
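
Conceptually, pruning amounts to keeping only the small set of trained parameters and dropping everything that can be restored from the frozen pretrained LM and visual encoder. The sketch below illustrates the idea only; the checkpoint layout, file paths, and key prefixes are assumptions, and `fromage/prune_model_ckpt.py` remains the authoritative implementation:

```
# Illustrative pruning sketch (not the actual prune_model_ckpt.py logic).
import torch

ckpt = torch.load('runs/fromage_exp/ckpt_best.pth.tar', map_location='cpu')  # hypothetical checkpoint path
state_dict = ckpt['state_dict']  # assumed checkpoint layout
keep_prefixes = ('model.text_hidden_fcs', 'model.input_embeddings')  # assumed trainable-parameter prefixes
pruned = {k: v for k, v in state_dict.items() if k.startswith(keep_prefixes)}
torch.save({'state_dict': pruned}, 'pretrained_ckpt.pth.tar')
```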
86 |
87 |
88 | ### Unit Tests
89 |
90 | You can also test that the code runs locally by running the unit test with `pytest -x`. This runs a short training and evaluation job, with smaller models, to ensure the code works. The test should complete within approximately 90s.
91 |
92 | Note that because of exception catching (to ensure data errors don't terminate training), the test will silently fail and not terminate if there is an I/O error when reading data. Hence, we recommend running the Python command above for debugging data preprocessing.
93 |
94 |
95 | ## Evaluation
96 |
97 | We provide an evaluation script to reproduce our results on contextual image retrieval on Visual Storytelling (the results in Table 1 of our paper). The script can be run from `evals/eval_vist_retrieval.py`. There is also an iPython notebook version (`VIST_Contextual_Image_Retrieval.ipynb`) in the same directory.
98 |
99 | Similarly, we provide scripts to reproduce the text generation and image retrieval results on VisDial (presented in Table 2 of our paper). The script for VisDial text generation can be run from `evals/eval_visdial_generation.py` (or through the notebook version, `VisDial_Inference_IT2T_Generation.ipynb`). This reports the NDCG, MRR, and R@k scores for VisDial.
100 |
101 | The results for image retrieval can be reproduced by running the `evals/eval_visdial_retrieval.py` script (or through the notebook version `VisDial_Inference_T2I_Retrieval.ipynb`), which reports R@k scores.
102 |
103 |
104 | ## Gradio Demo
105 |
106 | You can launch your own version of the Gradio demo locally by running `python demo/app.py`, or duplicating the [HuggingFace space](https://huggingface.co/spaces/jykoh/fromage).
107 |
108 | Check out other unofficial HuggingFace spaces for FROMAGe:
109 | - [alvanlii FROMAGe demo](https://huggingface.co/spaces/alvanlii/FROMAGe)
110 |
111 |
112 | ## Citation
113 |
114 | If you find this work useful, please consider citing:
115 |
116 | ```
117 | @inproceedings{koh2023grounding,
118 | title={Grounding Language Models to Images for Multimodal Inputs and Outputs},
119 | author={Koh, Jing Yu and Salakhutdinov, Ruslan and Fried, Daniel},
120 |   booktitle={ICML},
121 | year={2023}
122 | }
123 | ```
--------------------------------------------------------------------------------
/datasets/cc3m_train.tsv:
--------------------------------------------------------------------------------
1 | caption image
2 | author : a life in photography -- in pictures 0_1595581236
3 | the player staring intently at a computer screen . 3_2841214833
4 | the - bedroom stone cottage can sleep people 5_2227185927
5 | party in the park under cherry blossoms 8_1666482269
6 |
--------------------------------------------------------------------------------
/datasets/cc3m_val.tsv:
--------------------------------------------------------------------------------
1 | caption image
2 | author : a life in photography -- in pictures 0_1595581236
3 | the player staring intently at a computer screen . 3_2841214833
4 | the - bedroom stone cottage can sleep people 5_2227185927
5 | party in the park under cherry blossoms 8_1666482269
6 |
--------------------------------------------------------------------------------
/demo/app.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from share_btn import community_icon_html, loading_icon_html, share_js, save_js
3 | import huggingface_hub
4 | import gradio as gr
5 | from fromage import utils
6 | from fromage import models
7 | import matplotlib.pyplot as plt
8 | from PIL import Image
9 | import torch
10 | import numpy as np
11 | import os
12 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
13 |
14 |
15 | css = """
16 | #chatbot { min-height: 300px; }
17 | #save-btn {
18 | background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
19 | }
20 | #save-btn:hover {
21 | background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
22 | }
23 | #share-btn {
24 | background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
25 | }
26 | #share-btn:hover {
27 | background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
28 | }
29 | #gallery { z-index: 999999; }
30 | #gallery img:hover {transform: scale(2.3); z-index: 999999; position: relative; padding-right: 30%; padding-bottom: 30%;}
31 | #gallery button img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;}
32 | @media (hover: none) {
33 | #gallery img:hover {transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0;}
34 | }
35 | """
36 |
37 | examples = [
38 | 'examples/sparrow.png',
39 | 'examples/beaver.png',
40 | 'examples/couch.png',
41 | 'examples/guac.png',
42 | 'examples/scraped_knee.png'
43 | ]
44 |
45 | # Download model from HF Hub.
46 | ckpt_path = huggingface_hub.hf_hub_download(
47 | repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
48 | args_path = huggingface_hub.hf_hub_download(
49 | repo_id='jykoh/fromage', filename='model_args.json')
50 | model = models.load_fromage('./', args_path, ckpt_path)
51 |
52 |
53 | def upload_image(state, image_input):
54 | conversation = state[0]
55 | chat_history = state[1]
56 | input_image = Image.open(image_input.name).resize(
57 | (224, 224)).convert('RGB')
58 | input_image.save(image_input.name) # Overwrite with smaller image.
59 | conversation += [(f' ', "")]
60 | return [conversation, chat_history + [input_image, ""]], conversation
61 |
62 |
63 | def reset():
64 | return [[], []], []
65 |
66 |
67 | def reset_last(state):
68 | conversation = state[0][:-1]
69 | chat_history = state[1][:-2]
70 | return [conversation, chat_history], conversation
71 |
72 |
73 | def save_image_to_local(image: Image.Image):
74 | # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
75 | filename = next(tempfile._get_candidate_names()) + '.png'
76 | image.save(filename)
77 | return filename
78 |
79 |
80 | def generate_for_prompt(input_text, state, ret_scale_factor, max_num_rets, num_words, temperature):
81 | # Ignore empty inputs.
82 | if len(input_text) == 0:
83 | return state, state[0], gr.update(visible=True)
84 |
85 | input_prompt = 'Q: ' + input_text + '\nA:'
86 | conversation = state[0]
87 | chat_history = state[1]
88 | print('Generating for', chat_history, flush=True)
89 |
90 | # If an image was uploaded, prepend it to the model.
91 | model_inputs = chat_history
92 | model_inputs.append(input_prompt)
93 |
94 | top_p = 1.0
95 | if temperature != 0.0:
96 | top_p = 0.95
97 |
98 | print('Running model.generate_for_images_and_texts with',
99 | model_inputs, flush=True)
100 | model_outputs = model.generate_for_images_and_texts(model_inputs,
101 | num_words=max(num_words, 1), ret_scale_factor=ret_scale_factor, top_p=top_p,
102 | temperature=temperature, max_num_rets=max_num_rets)
103 | print('model_outputs', model_outputs, flush=True)
104 |
105 | im_names = []
106 | response = ''
107 | text_outputs = []
108 | for output_i, output in enumerate(model_outputs):
109 | if type(output) == str:
110 | if output_i > 0:
111 | response += ' '
112 | text_outputs.append(output)
113 | response += output
114 | if len(model_outputs) > 1:
115 | response += ' '
116 | elif type(output) == list:
117 | for image in output:
118 | filename = save_image_to_local(image)
119 | response += f' '
120 | elif type(output) == Image.Image:
121 | filename = save_image_to_local(output)
122 | response += f' '
123 |
124 | chat_history = model_inputs + \
125 | [' '.join([s for s in model_outputs if type(s) == str]) + '\n']
126 | # Remove [RET] from outputs.
127 | conversation.append((input_text, response.replace('[RET]', '')))
128 |
129 | # Set input image to None.
130 | print('state', state, flush=True)
131 | print('updated state', [conversation, chat_history], flush=True)
132 | return [conversation, chat_history], conversation, gr.update(visible=True), gr.update(visible=True)
133 |
134 |
135 | with gr.Blocks(css=css) as demo:
136 | gr.HTML("""
137 | 🧀 FROMAGe
138 | This is the official Gradio demo for the FROMAGe model, a model that can process arbitrarily interleaved image and text inputs, and produce image and text outputs.
139 |
140 | Paper: Grounding Language Models to Images for Multimodal Generation
141 |
142 | Project Website: FROMAGe Website
143 |
144 | Code and Models: GitHub
145 |
146 |
147 |
148 | Tips:
149 |
150 | Start by inputting either image or text prompts (or both) and chat with FROMAGe to get image-and-text replies.
151 | Tweak the level of sensitivity to images and text using the parameters on the right.
152 | Check out cool conversations in the examples or community tab for inspiration and share your own!
153 | For faster inference without waiting in queue, you may duplicate the space and use your own GPU:
154 |
155 | """)
156 |
157 | gr_state = gr.State([[], []]) # conversation, chat_history
158 |
159 | with gr.Row():
160 | with gr.Column(scale=0.7, min_width=500):
161 | with gr.Row():
162 | chatbot = gr.Chatbot(elem_id="chatbot", label="🧀 FROMAGe Chatbot")
163 | with gr.Row():
164 | image_btn = gr.UploadButton("🖼️ Upload Image", file_types=["image"])
165 |
166 | text_input = gr.Textbox(label="Message", placeholder="Type a message")
167 |
168 | with gr.Column():
169 | submit_btn = gr.Button(
170 | "Submit", interactive=True, variant="primary")
171 | clear_last_btn = gr.Button("Undo")
172 | clear_btn = gr.Button("Reset All")
173 | with gr.Row(visible=False) as save_group:
174 | save_button = gr.Button("💾 Save Conversation as .png", elem_id="save-btn")
175 |
176 | with gr.Row(visible=False) as share_group:
177 | share_button = gr.Button("🤗 Share to Community (opens new window)", elem_id="share-btn")
178 |
179 | with gr.Column(scale=0.3, min_width=400):
180 | ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True,
181 | label="Frequency multiplier for returning images (higher means more frequent)")
182 | max_ret_images = gr.Number(
183 | minimum=0, maximum=3, value=2, precision=1, interactive=True, label="Max images to return")
184 | gr_max_len = gr.Slider(minimum=1, maximum=64, value=32,
185 | step=1, interactive=True, label="Max # of words")
186 | gr_temperature = gr.Slider(
187 | minimum=0.0, maximum=1.0, value=0.0, interactive=True, label="Temperature (0 for deterministic, higher for more randomness)")
188 |
189 | gallery = gr.Gallery(
190 | value=[Image.open(e) for e in examples], label="Example Conversations", show_label=True, elem_id="gallery",
191 | ).style(grid=[2], height="auto")
192 |
193 | text_input.submit(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
194 | max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
195 | text_input.submit(lambda: "", None, text_input) # Reset chatbox.
196 | submit_btn.click(generate_for_prompt, [text_input, gr_state, ret_scale_factor,
197 | max_ret_images, gr_max_len, gr_temperature], [gr_state, chatbot, share_group, save_group])
198 | submit_btn.click(lambda: "", None, text_input) # Reset chatbox.
199 |
200 | image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
201 | clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot])
202 | clear_btn.click(reset, [], [gr_state, chatbot])
203 | share_button.click(None, [], [], _js=share_js)
204 | save_button.click(None, [], [], _js=save_js)
205 |
206 |
207 | demo.queue(concurrency_count=1, api_open=False, max_size=16)
208 | demo.launch(debug=True, server_name="0.0.0.0")
209 | # demo.launch(debug=True, server_name="127.0.0.1")
210 |
--------------------------------------------------------------------------------
/demo/share_btn.py:
--------------------------------------------------------------------------------
1 | # Modified from https://huggingface.co/spaces/haoheliu/audioldm-text-to-audio-generation/blob/79681cd8cb235160a27cdd100673346eb1784e53/share_btn.py
2 |
3 | community_icon_html = """
4 |
5 |
6 | """
7 |
8 | loading_icon_html = """ """
12 |
13 | share_js = """
14 | async () => {
15 | const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default;
16 | async function uploadFile(file) {
17 | console.log(file.type)
18 | const UPLOAD_URL = 'https://huggingface.co/uploads';
19 | const response = await fetch(UPLOAD_URL, {
20 | method: 'POST',
21 | headers: {
22 | 'Content-Type': file.type,
23 | 'X-Requested-With': 'XMLHttpRequest',
24 | },
25 | body: file, /// <- File inherits from Blob
26 | });
27 | const url = await response.text();
28 | return url;
29 | }
30 | async function getImageFile(div) {
31 | return new Promise((resolve, reject) =>
32 | html2canvas(div)
33 | .then((canvas) => {
34 | const imageBlob = canvas.toBlob((blob) => {
35 | const imageId = Date.now();
36 | const fileName = "FROMAGe-" + imageId + ".jpg";
37 | resolve(new File([blob], fileName, { type: 'image/jpeg' }));
38 | }, 'image/jpeg', 0.95);
39 | })
40 |
41 | )
42 | }
43 |
44 | const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
45 | const chatbotEl = gradioEl.querySelector('#chatbot')
46 | const imageFile = await getImageFile(chatbotEl);
47 | console.log(imageFile);
48 | const urlChatbotImage = await uploadFile(imageFile);
49 | console.log(urlChatbotImage);
50 | let titleTxt = `FROMAGe Example`;
51 |
52 | //const shareBtnEl = gradioEl.querySelector('#share-btn');
53 | //shareBtnEl.style.pointerEvents = 'none';
54 | const descriptionMd = `
55 |
56 |
57 | `;
58 | const params = new URLSearchParams({
59 | title: titleTxt,
60 | description: descriptionMd,
61 | });
62 | const paramsStr = params.toString();
63 | window.open(`https://huggingface.co/spaces/jykoh/fromage/discussions/new?${paramsStr}`, '_blank');
64 | //shareBtnEl.style.removeProperty('pointer-events');
65 | }
66 | """
67 |
68 | save_js = """
69 | async () => {
70 | const html2canvas = (await import('https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.1/html2canvas.esm.js')).default;
71 |
72 | function saveAs(uri, filename) {
73 | var link = document.createElement('a');
74 | if (typeof link.download === 'string') {
75 | link.href = uri;
76 | link.download = filename;
77 |
78 | //Firefox requires the link to be in the body
79 | document.body.appendChild(link);
80 |
81 | //simulate click
82 | link.click();
83 |
84 | //remove the link when done
85 | document.body.removeChild(link);
86 | } else {
87 | window.open(uri);
88 | }
89 | }
90 |
91 | async function getImageFile(div) {
92 | return new Promise((resolve, reject) =>
93 | html2canvas(div)
94 | .then((canvas) => {
95 | const imageId = Date.now();
96 | const fileName = "FROMAGe-" + imageId + ".png";
97 | saveAs(canvas.toDataURL(), fileName);
98 | })
99 |
100 | )
101 | }
102 | const gradioEl = document.querySelector("gradio-app").shadowRoot || document.querySelector('body > gradio-app');
103 | const chatbotEl = gradioEl.querySelector('#chatbot')
104 | const imageFile = await getImageFile(chatbotEl);
105 | console.log(imageFile);
106 | }
107 | """
--------------------------------------------------------------------------------
/evals/VIST_Contextual_Image_Retrieval.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "a66bc991",
6 | "metadata": {},
7 | "source": [
8 | "# FROMAGe Contextual Image Retrieval\n",
9 | "\n",
10 | "This is a notebook showcasing the contextual image retrieval results from our paper, [Grounding Language Models to Images for Multimodal Generation](https://arxiv.org/abs/2301.13823). This result is reported in Table 1. The results of this notebook may be slightly different from the paper, as the Flickr images from Visual Storytelling may disappear over time.\n",
11 | "\n",
12 | "At least 18GB of GPU memory is required to run FROMAGe, and it has only been tested on A6000, V100, and 3090 GPUs."
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 1,
18 | "id": "475add8f",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "import numpy as np\n",
23 | "import collections\n",
24 | "import copy\n",
25 | "import json\n",
26 | "import os\n",
27 | "import torch\n",
28 | "from transformers import logging\n",
29 | "from tqdm import notebook\n",
30 | "logging.set_verbosity_error()\n",
31 | "\n",
32 | "from PIL import Image\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "\n",
35 | "from fromage import models\n",
36 | "from fromage import utils"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "id": "7e884127",
42 | "metadata": {},
43 | "source": [
44 | "### Load Pretrained Model"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "id": "4646a124",
51 | "metadata": {
52 | "scrolled": true
53 | },
54 | "outputs": [
55 | {
56 | "name": "stdout",
57 | "output_type": "stream",
58 | "text": [
59 | "Using HuggingFace AutoFeatureExtractor for openai/clip-vit-large-patch14.\n",
60 | "Using facebook/opt-6.7b for the language model.\n",
61 | "Using openai/clip-vit-large-patch14 for the visual model with 1 visual tokens.\n"
62 | ]
63 | },
64 | {
65 | "data": {
66 | "application/vnd.jupyter.widget-view+json": {
67 | "model_id": "6bb4489d7d8b4235abe3f965318af27d",
68 | "version_major": 2,
69 | "version_minor": 0
70 | },
71 | "text/plain": [
72 | "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
73 | ]
74 | },
75 | "metadata": {},
76 | "output_type": "display_data"
77 | },
78 | {
79 | "name": "stdout",
80 | "output_type": "stream",
81 | "text": [
82 | "Freezing the LM.\n",
83 | "Initializing embedding for the retrieval token [RET] (id = 50266).\n",
84 | "Restoring pretrained weights for the visual model.\n",
85 | "Freezing the VM.\n"
86 | ]
87 | }
88 | ],
89 | "source": [
90 | "# Load model used in the paper.\n",
91 | "model_dir = './fromage_model/'\n",
92 | "model = models.load_fromage(model_dir)"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "id": "ad8d373b",
98 | "metadata": {},
99 | "source": [
100 | "### Contextual Image Retrieval for Visual Storytelling\n",
101 | "\n",
102 | "Download the Visual Storytelling SIS dataset from [their website](https://visionandlanguage.net/VIST/json_files/story-in-sequence/SIS-with-labels.tar.gz). Extract the files (there should be three sets: train, val, and test). We'll use the val set for reporting results.\n",
103 | "\n",
104 | "First, we'll do some data preprocessing to make things easier for us later on:"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 3,
110 | "id": "6bf39013",
111 | "metadata": {},
112 | "outputs": [
113 | {
114 | "name": "stdout",
115 | "output_type": "stream",
116 | "text": [
117 | "8034\n"
118 | ]
119 | }
120 | ],
121 | "source": [
122 | "vist_val_json_path = 'sis/val.story-in-sequence.json'\n",
123 | "with open(vist_val_json_path, 'r') as f:\n",
124 | " vist_data_raw = json.load(f)\n",
125 | " \n",
126 | "# Format into a dictionary of {story_id: data} items.\n",
127 | "vist_data = {\n",
128 | " 'annotations': collections.defaultdict(list)\n",
129 | "}\n",
130 | "used_image_ids = []\n",
131 | "\n",
132 | "\n",
133 | "for ann in vist_data_raw['annotations']:\n",
134 | " assert len(ann) == 1\n",
135 | " ann = ann[0]\n",
136 | " story_id = ann['story_id']\n",
137 | " vist_data['annotations'][story_id].append({\n",
138 | " 'caption': ann['text'],\n",
139 | " 'image_id': ann['photo_flickr_id'],\n",
140 | " 'sequence_index': ann['worker_arranged_photo_order'],\n",
141 | " })\n",
142 | " used_image_ids.append(ann['photo_flickr_id'])\n",
143 | "\n",
144 | "used_image_ids = set(used_image_ids)\n",
145 | "print(len(used_image_ids))"
146 | ]
147 | },
148 | {
149 | "cell_type": "markdown",
150 | "id": "c6c8b664",
151 | "metadata": {},
152 | "source": [
153 | "Then, we can precompute features for all images. This will be used for image retrieval."
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": 4,
159 | "id": "b17d5cca",
160 | "metadata": {},
161 | "outputs": [],
162 | "source": [
163 | "id2url = {}\n",
164 | "\n",
165 | "for image_data in vist_data_raw['images']:\n",
166 | " image_id = image_data['id']\n",
167 | " if image_id in used_image_ids:\n",
168 | " image_url = image_data.get('url_o', None)\n",
169 | " if image_url is not None:\n",
170 | " id2url[image_id] = image_url\n",
171 | "\n",
172 | "# Extract image features.\n",
173 | "embs_fn = 'sis_img_features.npy'\n",
174 | "\n",
175 | "# Compute visual embeddings.\n",
176 | "if not os.path.exists(embs_fn):\n",
177 | " print(f'{embs_fn} does not exist, computing it from scratch.')\n",
178 | " all_visual_embs = []\n",
179 | " all_image_ids = []\n",
180 | "\n",
181 | " for image_id, image_url in notebook.tqdm(id2url.items()):\n",
182 | " try:\n",
183 | " images = utils.get_image_from_url(image_url)\n",
184 | " pixel_values = utils.get_pixel_values_for_model(model.model.feature_extractor, images)\n",
185 | " pixel_values = pixel_values.to(device=model.model.logit_scale.device, dtype=model.model.logit_scale.dtype)\n",
186 | " pixel_values = pixel_values[None, ...]\n",
187 | " visual_embs = model.model.get_visual_embs(pixel_values, mode='retrieval')\n",
188 | " all_visual_embs.append(visual_embs.float().cpu().detach().numpy())\n",
189 | " all_image_ids.append(image_id)\n",
190 | " except Image.UnidentifiedImageError:\n",
191 | " pass\n",
192 | "\n",
193 | " all_image_ids = np.array(all_image_ids)\n",
194 | " all_visual_embs = np.concatenate(all_visual_embs, axis=0)\n",
195 | " assert all_image_ids.shape[0] == all_visual_embs.shape[0], (all_image_ids.shape, all_visual_embs.shape)\n",
196 | " print(all_image_ids.shape, all_visual_embs.shape)\n",
197 | "\n",
198 | " with open(embs_fn, 'wb') as wf:\n",
199 | " np.save(wf, {'image_ids': all_image_ids, 'embeddings': all_visual_embs})\n",
200 | "\n",
201 | "# Load embeddings.\n",
202 | "with open(embs_fn, 'rb') as wf:\n",
203 | " embs_data = np.load(wf, allow_pickle=True).item()\n",
204 | " all_image_ids = embs_data['image_ids']\n",
205 | " emb_matrix = embs_data['embeddings']"
206 | ]
207 | },
208 | {
209 | "cell_type": "code",
210 | "execution_count": 5,
211 | "id": "970a5d15",
212 | "metadata": {},
213 | "outputs": [
214 | {
215 | "name": "stdout",
216 | "output_type": "stream",
217 | "text": [
218 | "emb_matrix.shape torch.Size([7043, 1, 256])\n"
219 | ]
220 | }
221 | ],
222 | "source": [
223 | "len(all_image_ids), emb_matrix.shape\n",
224 | "\n",
225 | "# Normalize embedding matrix to be suitable for image retrieval.\n",
226 | "logit_scale = model.model.logit_scale.exp()\n",
227 | "emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)\n",
228 | "emb_matrix = emb_matrix / emb_matrix.norm(dim=-1, keepdim=True)\n",
229 | "emb_matrix = logit_scale * emb_matrix\n",
230 | "print('emb_matrix.shape', emb_matrix.shape)"
231 | ]
232 | },
233 | {
234 | "cell_type": "markdown",
235 | "id": "944691a6",
236 | "metadata": {},
237 | "source": [
238 | "Then, for each VIST example, we process it as `... [RET]`, providing this as input to FROMAGe, and retrieve the image corresponding to the `[RET]` embedding."
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 6,
244 | "id": "d20c3c02",
245 | "metadata": {
246 | "scrolled": true
247 | },
248 | "outputs": [
249 | {
250 | "data": {
251 | "application/vnd.jupyter.widget-view+json": {
252 | "model_id": "b9e320375086411eaf5e0ca0acc6f6fa",
253 | "version_major": 2,
254 | "version_minor": 0
255 | },
256 | "text/plain": [
257 | " 0%| | 0/4990 [00:00, ?it/s]"
258 | ]
259 | },
260 | "metadata": {},
261 | "output_type": "display_data"
262 | }
263 | ],
264 | "source": [
265 | "topk = (1, 5, 10)\n",
266 | "top_k_preds = {}\n",
267 | "\n",
268 | "with torch.no_grad():\n",
269 | " for story_idx, (story_id, story_data) in notebook.tqdm(enumerate(vist_data['annotations'].items()), total=len(vist_data['annotations'])):\n",
270 | " gt_image_id = story_data[-1]['image_id']\n",
271 | " skip = False # Skip examples that do not have images (due to URLs being taken down, or something)\n",
272 | " for s in story_data:\n",
273 | " if s['image_id'] not in all_image_ids or s['image_id'] not in id2url:\n",
274 | " skip = True\n",
275 | " break\n",
276 | "\n",
277 | " if not skip:\n",
278 | " # Use the first n-1 images and n captions as input.\n",
279 | " image_urls = [id2url[s['image_id']] for s in story_data[:-1]]\n",
280 | " captions = [s['caption'] for s in story_data]\n",
281 | " assert len(image_urls) == len(captions) - 1\n",
282 | "\n",
283 | " visual_embs = []\n",
284 | " # Compute embeddings for the input images.\n",
285 | " images = [utils.get_image_from_url(image_url) for image_url in image_urls]\n",
286 | " pixel_values = [utils.get_pixel_values_for_model(model.model.feature_extractor, image) for image in images]\n",
287 | " pixel_values = torch.stack(pixel_values, dim=0) # (n-1, 3, 224, 224)\n",
288 | " pixel_values = pixel_values.to(device=model.model.logit_scale.device, dtype=model.model.logit_scale.dtype)\n",
289 | " visual_embs = model.model.get_visual_embs(pixel_values, mode='captioning')\n",
290 | "\n",
291 | " # Compute embeddings for the input captions.\n",
292 | " all_input_ids = []\n",
293 | " for i, c in enumerate(captions):\n",
294 | " if i == len(captions) - 1:\n",
295 | " c += '[RET]' # Add the [RET] token to the final caption.\n",
296 | " input_ids = model.model.tokenizer(c, add_special_tokens=True, return_tensors=\"pt\").input_ids.to(emb_matrix.device)\n",
297 | " all_input_ids.append(input_ids)\n",
298 | " \n",
299 | " input_embs = [model.model.input_embeddings(s)[0, ...] for s in all_input_ids] # (N, T, D)\n",
300 | "\n",
301 | " # Interleave captions and images as [caption1, image1, caption2, ..., image4, caption5].\n",
302 | " final_input_embs = []\n",
303 | " assert len(visual_embs) == len(input_embs) - 1\n",
304 | " for i in range(len(images)):\n",
305 | " final_input_embs.append(input_embs[i])\n",
306 | " final_input_embs.append(visual_embs[i])\n",
307 | " final_input_embs.append(input_embs[len(images)])\n",
308 | " final_input_embs = torch.cat(final_input_embs, dim=0)[None, ...] # (1, T, 4096)\n",
309 | " \n",
310 | " # Get embedding of the [RET] token, and compute scores:\n",
311 | " output = model.model.lm(inputs_embeds=final_input_embs, labels=None, use_cache=False, output_hidden_states=True)\n",
312 | " last_hidden_state = model.model.text_hidden_fcs[0](output.hidden_states[-1])\n",
313 | " ret_emb = last_hidden_state[:, -1, :]\n",
314 | "\n",
315 | " ret_emb = ret_emb / ret_emb.norm(dim=1, keepdim=True)\n",
316 | " scores = ret_emb.squeeze() @ emb_matrix.squeeze().T\n",
317 | " \n",
318 | " # Don't retrieve previously seen images.\n",
319 | " prev_image_ids = [s['image_id'] for s in story_data[:-1]]\n",
320 | " for prev_id in prev_image_ids:\n",
321 | " scores[np.where(all_image_ids == prev_id)[0]] -= 10000\n",
322 | " \n",
323 | " # Store top-k preds.\n",
324 | " _, preds = scores.topk(max(topk))\n",
325 | " preds = preds.cpu().detach().numpy()\n",
326 | " preds = [all_image_ids[p] for p in preds]\n",
327 | " top_k_preds[story_id] = {'topk_preds': preds, 'gt': gt_image_id}"
328 | ]
329 | },
330 | {
331 | "cell_type": "markdown",
332 | "id": "aef2a81a",
333 | "metadata": {},
334 | "source": [
335 | "Finally, we can compute Recall@k:"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": 7,
341 | "id": "7686317d",
342 | "metadata": {},
343 | "outputs": [
344 | {
345 | "name": "stdout",
346 | "output_type": "stream",
347 | "text": [
348 | "k=1, acc=0.18232\n",
349 | "k=5, acc=0.42682\n",
350 | "k=10, acc=0.51775\n"
351 | ]
352 | }
353 | ],
354 | "source": [
355 | "top_k_accuracy = collections.defaultdict(list)\n",
356 | "\n",
357 | "for story_id, results in top_k_preds.items():\n",
358 | " for k in topk:\n",
359 | " acc = results['gt'] in results['topk_preds'][:k]\n",
360 | " top_k_accuracy[k].append(acc)\n",
361 | "\n",
362 | "for k in topk:\n",
363 | " result_str = f'k={k}, acc={np.mean(top_k_accuracy[k]):.5f}'\n",
364 | " print(result_str)"
365 | ]
366 | },
367 | {
368 | "cell_type": "code",
369 | "execution_count": null,
370 | "id": "807fd749",
371 | "metadata": {},
372 | "outputs": [],
373 | "source": []
374 | }
375 | ],
376 | "metadata": {
377 | "kernelspec": {
378 | "display_name": "Python 3 (ipykernel)",
379 | "language": "python",
380 | "name": "python3"
381 | },
382 | "language_info": {
383 | "codemirror_mode": {
384 | "name": "ipython",
385 | "version": 3
386 | },
387 | "file_extension": ".py",
388 | "mimetype": "text/x-python",
389 | "name": "python",
390 | "nbconvert_exporter": "python",
391 | "pygments_lexer": "ipython3",
392 | "version": "3.10.4"
393 | }
394 | },
395 | "nbformat": 4,
396 | "nbformat_minor": 5
397 | }
398 |
--------------------------------------------------------------------------------
/evals/VisDial_Inference_IT2T_Generation.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "a66bc991",
6 | "metadata": {},
7 | "source": [
8 | "# FROMAGe Visual Dialog (Text Generation)\n",
9 | "\n",
10 | "This is a notebook showcasing the VisDial image-and-text-to-text (IT2T) results from our paper, [Grounding Language Models to Images for Multimodal Generation](https://arxiv.org/abs/2301.13823). This result is reported in Table 2 of the paper. This is the standard [VisDial](https://arxiv.org/abs/1611.08669) evaluation, which measures the ability of models to pick out the correct text answer out of 100 options.\n",
11 | "\n",
12 | "At least 18GB of GPU memory is required to run FROMAGe, and it has only been tested on A6000, V100, and 3090 GPUs."
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 1,
18 | "id": "475add8f",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "import numpy as np\n",
23 | "import collections\n",
24 | "import copy\n",
25 | "import json\n",
26 | "import os\n",
27 | "import torch\n",
28 | "from transformers import logging\n",
29 | "from tqdm import notebook\n",
30 | "logging.set_verbosity_error()\n",
31 | "\n",
32 | "from PIL import Image\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "\n",
35 | "from fromage import models\n",
36 | "from fromage import utils"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "id": "7e884127",
42 | "metadata": {},
43 | "source": [
44 | "### Load Pretrained FROMAGe Model"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "id": "4646a124",
51 | "metadata": {
52 | "scrolled": true
53 | },
54 | "outputs": [
55 | {
56 | "name": "stdout",
57 | "output_type": "stream",
58 | "text": [
59 | "Using HuggingFace AutoFeatureExtractor for openai/clip-vit-large-patch14.\n",
60 | "Using facebook/opt-6.7b for the language model.\n",
61 | "Using openai/clip-vit-large-patch14 for the visual model with 1 visual tokens.\n"
62 | ]
63 | },
64 | {
65 | "data": {
66 | "application/vnd.jupyter.widget-view+json": {
67 | "model_id": "7cf51dce17a5479992af55c8dbf341bf",
68 | "version_major": 2,
69 | "version_minor": 0
70 | },
71 | "text/plain": [
72 | "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
73 | ]
74 | },
75 | "metadata": {},
76 | "output_type": "display_data"
77 | },
78 | {
79 | "name": "stdout",
80 | "output_type": "stream",
81 | "text": [
82 | "Freezing the LM.\n",
83 | "Initializing embedding for the retrieval token [RET] (id = 50266).\n",
84 | "Restoring pretrained weights for the visual model.\n",
85 | "Freezing the VM.\n"
86 | ]
87 | }
88 | ],
89 | "source": [
90 | "# Load model used in the paper.\n",
91 | "model_dir = './fromage_model/'\n",
92 | "model = models.load_fromage(model_dir)"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "id": "ad8d373b",
98 | "metadata": {},
99 | "source": [
100 | "### VisDial\n",
101 | "\n",
102 | "Download the VisDial validation [annotations](https://www.dropbox.com/s/ibs3a0zhw74zisc/visdial_1.0_val.zip?dl=0), the [dense answer annotations](https://www.dropbox.com/s/3knyk09ko4xekmc/visdial_1.0_val_dense_annotations.json?dl=0) (for computing MRR) and the [images](https://www.dropbox.com/s/twmtutniktom7tu/VisualDialog_val2018.zip?dl=0). Extract everything to the `VisualDialog` folder.\n",
103 | "\n",
104 | "First, we'll load the annotations, and define the paths to our images and annotations:"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 3,
110 | "id": "6bf39013",
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "base_dir = 'VisualDialog/'\n",
115 | "split = 'val'\n",
116 | "img_dir = os.path.join(base_dir, f'VisualDialog_{split}2018')\n",
117 | "\n",
118 | "with open(os.path.join(base_dir, f'visdial_1.0_{split}.json'), 'r') as f:\n",
119 | " visdial_data = json.load(f)\n",
120 | " \n",
121 | "with open(os.path.join(base_dir, f'visdial_1.0_{split}_dense_annotations.json'), 'r') as f:\n",
122 | " dense_data = json.load(f)\n",
123 | "\n",
124 | "# Check that dense and sparse data are aligned.\n",
125 | "assert len(dense_data) == len(visdial_data['data']['dialogs'])\n",
126 | "for i in range(len(dense_data)):\n",
127 | " assert dense_data[i]['image_id'] == visdial_data['data']['dialogs'][i]['image_id']\n",
128 | " \n",
129 | "questions = visdial_data['data']['questions']\n",
130 | "answers = visdial_data['data']['answers']\n",
131 | "dialogs = visdial_data['data']['dialogs']"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 4,
137 | "id": "4a4ae3b3",
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "def get_pixel_values_from_path(path: str, feature_extractor):\n",
142 | " \"\"\"Helper function for getting images pixels from a local path.\"\"\"\n",
143 | " img = Image.open(path)\n",
144 | " img = img.resize((224, 224))\n",
145 | " img = img.convert('RGB')\n",
146 | " pixel_values = utils.get_pixel_values_for_model(feature_extractor, img)\n",
147 | " if torch.cuda.is_available():\n",
148 | " pixel_values = pixel_values.bfloat16()\n",
149 | " pixel_values = pixel_values.cuda()\n",
150 | " return pixel_values[None, ...]"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "id": "944691a6",
156 | "metadata": {},
157 | "source": [
158 | "Then, for each VisDial example, we compute the loss conditioned on the image and the preceding dialogue. We return the option with the lowest loss as the answer:"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 5,
164 | "id": "d20c3c02",
165 | "metadata": {
166 | "scrolled": true
167 | },
168 | "outputs": [
169 | {
170 | "data": {
171 | "application/vnd.jupyter.widget-view+json": {
172 | "model_id": "8d4c377d84004a969c6603604dd62540",
173 | "version_major": 2,
174 | "version_minor": 0
175 | },
176 | "text/plain": [
177 | " 0%| | 0/2064 [00:00, ?it/s]"
178 | ]
179 | },
180 | "metadata": {},
181 | "output_type": "display_data"
182 | }
183 | ],
184 | "source": [
185 | "topk = (1, 5, 10)\n",
186 | "# Number of options in a batch to compute loss for.\n",
187 | "# If using a GPU with lower VRAM, this may have to be lowered.\n",
188 | "batch_size = 20\n",
189 | "ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none').cuda()\n",
190 | "\n",
191 | "# Save intermediate results to a numpy file, to allow resuming in case of interruptions.\n",
192 | "save_path = 'visdial_results_full.npy'\n",
193 | "if os.path.exists(save_path):\n",
194 | " with open(save_path, 'rb') as rf:\n",
195 | " all_data = np.load(rf, allow_pickle=True).item()\n",
196 | " all_preds = all_data['all_preds']\n",
197 | " all_gt_results = all_data['all_gt_results']\n",
198 | " all_losses = all_data['all_losses']\n",
199 | " assert len(all_preds) == len(all_gt_results) == len(all_losses)\n",
200 | "else:\n",
201 | " # No in progress data, initialize from scratch.\n",
202 | " all_preds = []\n",
203 | " all_gt_results = []\n",
204 | " all_losses = []\n",
205 | "\n",
206 | "for example_idx in notebook.tqdm(range(len(all_preds) // 10, len(dialogs))):\n",
207 | " dialog = dialogs[example_idx]\n",
208 | " image_id = str(dialog['image_id']).rjust(12, '0')\n",
209 | " contexts = []\n",
210 | "\n",
211 | " with torch.no_grad():\n",
212 | " images = get_pixel_values_from_path(\n",
213 | " os.path.join(img_dir, f'VisualDialog_{split}2018_{image_id}.jpg'),\n",
214 | " model.model.feature_extractor)\n",
215 | " visual_embs = model.model.get_visual_embs(images, mode='captioning')\n",
216 | "\n",
217 | " for i in range(len(dialog['dialog'])):\n",
218 | " prev_d = dialog['dialog'][i-1]\n",
219 | " current_d = dialog['dialog'][i]\n",
220 | " if i > 0:\n",
221 | " contexts.append('A: ' + answers[prev_d['answer']])\n",
222 | " contexts.append('Q: ' + questions[current_d['question']] + '?')\n",
223 | " answer_options = [answers[i] for i in current_d['answer_options']]\n",
224 | " answer = answers[current_d['answer']]\n",
225 | " gt_index = current_d['gt_index']\n",
226 | " caption = '\\n'.join(contexts) + '\\nA: '\n",
227 | "\n",
228 | " # Run through every possible option, and pick the option with the lowest loss (= lowest perplexity)\n",
229 | " example_losses = []\n",
230 | " # Tokenize the dialogue sequence (as this is the same for all answer choices).\n",
231 | " caption_ids = model.model.tokenizer(\n",
232 | " caption, add_special_tokens=True, return_tensors=\"pt\").input_ids\n",
233 | " caption_ids = caption_ids.to(images.device)\n",
234 | " caption_embs = model.model.input_embeddings(caption_ids) # (N, T, D)\n",
235 | " condition_length = visual_embs.shape[1] + caption_embs.shape[1]\n",
236 | "\n",
237 | " all_example_embs = []\n",
238 | " all_example_labels = []\n",
239 | "\n",
240 | " for _, ans in enumerate(answer_options):\n",
241 | " ans_ids = model.model.tokenizer(ans, add_special_tokens=True, return_tensors=\"pt\").input_ids\n",
242 | " ans_ids = ans_ids.to(images.device)\n",
243 | " ans_embs = model.model.input_embeddings(ans_ids)\n",
244 | " input_embs = torch.cat([\n",
245 | " visual_embs,\n",
246 | " caption_embs,\n",
247 | " ans_embs], dim=1)\n",
248 | " labels = torch.cat([\n",
249 | " torch.zeros(visual_embs.shape[:-1], device=caption_ids.device, dtype=caption_ids.dtype) - 100,\n",
250 | " caption_ids,\n",
251 | " ans_ids], dim=1)\n",
252 | " assert labels.shape[1] == input_embs.shape[1]\n",
253 | "\n",
254 | " all_example_embs.append(input_embs)\n",
255 | " all_example_labels.append(labels)\n",
256 | "\n",
257 | " max_len = max([x.shape[1] for x in all_example_labels])\n",
258 | " padded_example_embs = [torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[1])) for x in all_example_embs]\n",
259 | " padded_example_embs = torch.cat(padded_example_embs, axis=0)\n",
260 | "\n",
261 | " padded_example_labels = [torch.nn.functional.pad(x, (0, max_len - x.shape[1]), value=-100) for x in all_example_labels]\n",
262 | " padded_example_labels = torch.cat(padded_example_labels, axis=0)\n",
263 | "\n",
264 | " all_logits = []\n",
265 | " batches = int(np.ceil(padded_example_embs.shape[0] / batch_size))\n",
266 | " for i in range(batches):\n",
267 | " start_idx = i * batch_size\n",
268 | " end_idx = start_idx + batch_size\n",
269 | " out = model.model.lm(\n",
270 | " inputs_embeds=padded_example_embs[start_idx:end_idx, ...],\n",
271 | " labels=None,\n",
272 | " use_cache=False,\n",
273 | " output_hidden_states=True)\n",
274 | " all_logits.append(out.logits)\n",
275 | "\n",
276 | " logits = torch.cat(all_logits, dim=0)\n",
277 | " example_losses = ce_loss(logits.reshape((-1, logits.shape[-1])), padded_example_labels.reshape((-1,)))\n",
278 | " example_losses = example_losses.reshape((100, max_len))[:, condition_length:]\n",
279 | " example_losses = example_losses.sum(axis=1)\n",
280 | "\n",
281 | " all_losses.append(example_losses.cpu().float().numpy())\n",
282 | " scores = -example_losses\n",
283 | " _, preds = scores.topk(max(topk))\n",
284 | " all_preds.append(preds)\n",
285 | " all_gt_results.append(gt_index)\n",
286 | "\n",
287 | " with open(save_path, 'wb') as wf:\n",
288 | " np.save(wf, {'all_preds': all_preds, 'all_gt_results': all_gt_results, 'all_losses': all_losses})"
289 | ]
290 | },
291 | {
292 | "cell_type": "markdown",
293 | "id": "aef2a81a",
294 | "metadata": {},
295 | "source": [
296 | "### Computing Results\n",
297 | "\n",
298 | "Finally, we can compute NDCG, MRR, and Recall@k:"
299 | ]
300 | },
301 | {
302 | "cell_type": "code",
303 | "execution_count": 6,
304 | "id": "8c0b673c",
305 | "metadata": {},
306 | "outputs": [],
307 | "source": [
308 | "# Define some classes to help us compute NDCG and MRR.\n",
309 | "# Modified from https://github.com/batra-mlp-lab/visdial-challenge-starter-pytorch/blob/master/visdialch/metrics.py\n",
310 | "\n",
311 | "class NDCG(object):\n",
312 | " def __init__(self):\n",
313 | " self._ndcg_numerator = 0.0\n",
314 | " self._ndcg_denominator = 0.0\n",
315 | "\n",
316 | " def observe(\n",
317 | " self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor\n",
318 | " ):\n",
319 | " \"\"\"\n",
320 | " Observe model output scores and target ground truth relevance and\n",
321 | " accumulate NDCG metric.\n",
322 | " Parameters\n",
323 | " ----------\n",
324 | " predicted_scores: torch.Tensor\n",
325 | " A tensor of shape (batch_size, num_options), because dense\n",
326 | " annotations are available for 1 randomly picked round out of 10.\n",
327 | " target_relevance: torch.Tensor\n",
328 | " A tensor of shape same as predicted scores, indicating ground truth\n",
329 | " relevance of each answer option for a particular round.\n",
330 | " \"\"\"\n",
331 | " predicted_scores = predicted_scores.detach()\n",
332 | "\n",
333 | " # shape: (batch_size, 1, num_options)\n",
334 | " predicted_scores = predicted_scores.unsqueeze(1)\n",
335 | " predicted_ranks = scores_to_ranks(predicted_scores)\n",
336 | "\n",
337 | " # shape: (batch_size, num_options)\n",
338 | " predicted_ranks = predicted_ranks.squeeze(1)\n",
339 | " batch_size, num_options = predicted_ranks.size()\n",
340 | "\n",
341 | " k = torch.sum(target_relevance != 0, dim=-1)\n",
342 | "\n",
343 | " # shape: (batch_size, num_options)\n",
344 | " _, rankings = torch.sort(predicted_ranks, dim=-1)\n",
345 | " # Sort relevance in descending order so highest relevance gets top rnk.\n",
346 | " _, best_rankings = torch.sort(\n",
347 | " target_relevance, dim=-1, descending=True\n",
348 | " )\n",
349 | "\n",
350 | " # shape: (batch_size, )\n",
351 | " batch_ndcg = []\n",
352 | " for batch_index in range(batch_size):\n",
353 | " num_relevant = k[batch_index]\n",
354 | " dcg = self._dcg(\n",
355 | " rankings[batch_index][:num_relevant],\n",
356 | " target_relevance[batch_index],\n",
357 | " )\n",
358 | " best_dcg = self._dcg(\n",
359 | " best_rankings[batch_index][:num_relevant],\n",
360 | " target_relevance[batch_index],\n",
361 | " )\n",
362 | " batch_ndcg.append(dcg / best_dcg)\n",
363 | "\n",
364 | " self._ndcg_denominator += batch_size\n",
365 | " self._ndcg_numerator += sum(batch_ndcg)\n",
366 | "\n",
367 | " def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor):\n",
368 | " sorted_relevance = relevance[rankings].cpu().float()\n",
369 | " discounts = torch.log2(torch.arange(len(rankings)).float() + 2)\n",
370 | " return torch.sum(sorted_relevance / discounts, dim=-1)\n",
371 | "\n",
372 | " def retrieve(self, reset: bool = True, key=\"\"):\n",
373 | " if self._ndcg_denominator > 0:\n",
374 | " metrics = {\n",
375 | " key + \"ndcg\": float(self._ndcg_numerator / self._ndcg_denominator)\n",
376 | " }\n",
377 | " else:\n",
378 | " metrics = {}\n",
379 | "\n",
380 | " if reset:\n",
381 | " self.reset()\n",
382 | " return metrics\n",
383 | "\n",
384 | " def reset(self):\n",
385 | " self._ndcg_numerator = 0.0\n",
386 | " self._ndcg_denominator = 0.0\n",
387 | " \n",
388 | "\n",
389 | "def scores_to_ranks(scores: torch.Tensor):\n",
390 | " \"\"\"Convert model output scores into ranks.\"\"\"\n",
391 | " batch_size, num_rounds, num_options = scores.size()\n",
392 | " scores = scores.view(-1, num_options)\n",
393 | "\n",
394 | " # sort in descending order - largest score gets highest rank\n",
395 | " sorted_ranks, ranked_idx = scores.sort(1, descending=True)\n",
396 | "\n",
397 | " # i-th position in ranked_idx specifies which score shall take this\n",
398 | " # position but we want i-th position to have rank of score at that\n",
399 | " # position, do this conversion\n",
400 | " ranks = ranked_idx.clone().fill_(0)\n",
401 | " for i in range(ranked_idx.size(0)):\n",
402 | " for j in range(num_options):\n",
403 | " ranks[i][ranked_idx[i][j]] = j\n",
404 | " # convert from 0-99 ranks to 1-100 ranks\n",
405 | " ranks += 1\n",
406 | " ranks = ranks.view(batch_size, num_rounds, num_options)\n",
407 | " return ranks"
408 | ]
409 | },
410 | {
411 | "cell_type": "code",
412 | "execution_count": 7,
413 | "id": "15e4c2aa",
414 | "metadata": {},
415 | "outputs": [],
416 | "source": [
417 | "with open(save_path, 'rb') as rf:\n",
418 | " all_data = np.load(rf, allow_pickle=True).item()\n",
419 | " all_preds = all_data['all_preds']\n",
420 | " all_gt_results = all_data['all_gt_results']\n",
421 | " all_losses = all_data['all_losses']"
422 | ]
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": 8,
427 | "id": "7686317d",
428 | "metadata": {},
429 | "outputs": [
430 | {
431 | "name": "stdout",
432 | "output_type": "stream",
433 | "text": [
434 | "top-k, k=1, acc=0.17573\n",
435 | "top-k, k=5, acc=0.19971\n",
436 | "top-k, k=10, acc=0.24414\n",
437 | "top-k, k=20, acc=0.48309\n",
438 | "MRR: 0.21997\n",
439 | "NDCG: 0.16594\n"
440 | ]
441 | }
442 | ],
443 | "source": [
444 | "top_k_accuracy = collections.defaultdict(list)\n",
445 | "mrr_results = []\n",
446 | "all_ranks = []\n",
447 | "topk = (1, 5, 10, 20)\n",
448 | "ndcg = NDCG()\n",
449 | "\n",
450 | "assert len(all_preds) == len(all_gt_results)\n",
451 | "for gt, loss in zip(all_gt_results, all_losses):\n",
452 | " scores = -loss\n",
453 | " _, preds = torch.tensor(scores).topk(100)\n",
454 | " rank = np.where(preds == gt)[0][0] + 1\n",
455 | " all_ranks.append(rank)\n",
456 | " mrr_results.append(1 / rank)\n",
457 | "\n",
458 | " for k in topk:\n",
459 | " acc = gt in preds[:k]\n",
460 | " top_k_accuracy[k].append(acc)\n",
461 | " \n",
462 | "dense_mrr = []\n",
463 | "for i in range(len(dense_data)):\n",
464 | " idx = i * 10 + dense_data[i]['round_id']\n",
465 | " if idx >= len(all_losses):\n",
466 | " break\n",
467 | " scores = -torch.tensor(all_losses[idx])[None, :]\n",
468 | " relevance = torch.tensor(dense_data[i]['gt_relevance'])[None, :]\n",
469 | " ndcg.observe(scores, relevance)\n",
470 | " dense_mrr.append(mrr_results[idx])\n",
471 | "\n",
472 | "for k in topk:\n",
473 | " print(f'top-k, k={k}, acc={np.mean(top_k_accuracy[k]):.5f}')\n",
474 | "print(f'MRR: {np.mean(mrr_results):.5f}')\n",
475 | "print(f'NDCG: {ndcg.retrieve(reset=True)[\"ndcg\"]:.5f}')"
476 | ]
477 | },
478 | {
479 | "cell_type": "code",
480 | "execution_count": null,
481 | "id": "1982d6fe",
482 | "metadata": {},
483 | "outputs": [],
484 | "source": []
485 | }
486 | ],
487 | "metadata": {
488 | "kernelspec": {
489 | "display_name": "Python 3 (ipykernel)",
490 | "language": "python",
491 | "name": "python3"
492 | },
493 | "language_info": {
494 | "codemirror_mode": {
495 | "name": "ipython",
496 | "version": 3
497 | },
498 | "file_extension": ".py",
499 | "mimetype": "text/x-python",
500 | "name": "python",
501 | "nbconvert_exporter": "python",
502 | "pygments_lexer": "ipython3",
503 | "version": "3.10.4"
504 | }
505 | },
506 | "nbformat": 4,
507 | "nbformat_minor": 5
508 | }
509 |
--------------------------------------------------------------------------------
/evals/VisDial_Inference_T2I_Retrieval.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "a66bc991",
6 | "metadata": {},
7 | "source": [
8 | "# FROMAGe Visual Dialog (Image Retrieval)\n",
9 | "\n",
10 | "This is a notebook reproducing the VisDial text-to-image (T2I) retrieval results from our paper, [Grounding Language Models to Images for Multimodal Inputs and Outputs](https://arxiv.org/abs/2301.13823). This result is reported in Table 2 of the paper. This measures the recall of the model in selecting the appropriate image conditioned on a dialogue sequence.\n",
11 | "\n",
12 | "At least 18GB of GPU memory is required to run FROMAGe, and it has only been tested on A6000, V100, and 3090 GPUs."
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 1,
18 | "id": "475add8f",
19 | "metadata": {},
20 | "outputs": [],
21 | "source": [
22 | "import numpy as np\n",
23 | "import collections\n",
24 | "import copy\n",
25 | "import json\n",
26 | "import os\n",
27 | "import torch\n",
28 | "from transformers import logging\n",
29 | "from tqdm import notebook\n",
30 | "logging.set_verbosity_error()\n",
31 | "\n",
32 | "from PIL import Image\n",
33 | "import matplotlib.pyplot as plt\n",
34 | "\n",
35 | "from fromage import models\n",
36 | "from fromage import utils"
37 | ]
38 | },
39 | {
40 | "cell_type": "markdown",
41 | "id": "7e884127",
42 | "metadata": {},
43 | "source": [
44 | "### Load Pretrained FROMAGe Model"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": 2,
50 | "id": "4646a124",
51 | "metadata": {
52 | "scrolled": true
53 | },
54 | "outputs": [
55 | {
56 | "name": "stdout",
57 | "output_type": "stream",
58 | "text": [
59 | "Using HuggingFace AutoFeatureExtractor for openai/clip-vit-large-patch14.\n",
60 | "Using facebook/opt-6.7b for the language model.\n",
61 | "Using openai/clip-vit-large-patch14 for the visual model with 1 visual tokens.\n"
62 | ]
63 | },
64 | {
65 | "data": {
66 | "application/vnd.jupyter.widget-view+json": {
67 | "model_id": "b2aa1cfecea64ed6a4d2c3cbb946d463",
68 | "version_major": 2,
69 | "version_minor": 0
70 | },
71 | "text/plain": [
72 | "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
73 | ]
74 | },
75 | "metadata": {},
76 | "output_type": "display_data"
77 | },
78 | {
79 | "name": "stdout",
80 | "output_type": "stream",
81 | "text": [
82 | "Freezing the LM.\n",
83 | "Initializing embedding for the retrieval token [RET] (id = 50266).\n",
84 | "Restoring pretrained weights for the visual model.\n",
85 | "Freezing the VM.\n"
86 | ]
87 | }
88 | ],
89 | "source": [
90 | "# Load model used in the paper.\n",
91 | "model_dir = './fromage_model/'\n",
92 | "model = models.load_fromage(model_dir)"
93 | ]
94 | },
95 | {
96 | "cell_type": "markdown",
97 | "id": "ad8d373b",
98 | "metadata": {},
99 | "source": [
100 | "### VisDial\n",
101 | "\n",
102 | "Download the VisDial validation [annotations](https://www.dropbox.com/s/ibs3a0zhw74zisc/visdial_1.0_val.zip?dl=0), the [dense answer annotations](https://www.dropbox.com/s/3knyk09ko4xekmc/visdial_1.0_val_dense_annotations.json?dl=0) (for computing MRR) and the [images](https://www.dropbox.com/s/twmtutniktom7tu/VisualDialog_val2018.zip?dl=0). Extract everything to the `VisualDialog` folder.\n",
103 | "\n",
104 | "First, we'll do some data preprocessing to make things easier for us later on:"
105 | ]
106 | },
107 | {
108 | "cell_type": "code",
109 | "execution_count": 3,
110 | "id": "6bf39013",
111 | "metadata": {},
112 | "outputs": [],
113 | "source": [
114 | "base_dir = 'VisualDialog/'\n",
115 | "split = 'val'\n",
116 | "img_dir = os.path.join(base_dir, f'VisualDialog_{split}2018')\n",
117 | "\n",
118 | "with open(os.path.join(base_dir, f'visdial_1.0_{split}.json'), 'r') as f:\n",
119 | " visdial_data = json.load(f)\n",
120 | " \n",
121 | "with open(os.path.join(base_dir, f'visdial_1.0_{split}_dense_annotations.json'), 'r') as f:\n",
122 | " dense_data = json.load(f)\n",
123 | "\n",
124 | "# Check that dense and sparse data are aligned.\n",
125 | "assert len(dense_data) == len(visdial_data['data']['dialogs'])\n",
126 | "for i in range(len(dense_data)):\n",
127 | " assert dense_data[i]['image_id'] == visdial_data['data']['dialogs'][i]['image_id']\n",
128 | " \n",
129 | "questions = visdial_data['data']['questions']\n",
130 | "answers = visdial_data['data']['answers']\n",
131 | "dialogs = visdial_data['data']['dialogs']"
132 | ]
133 | },
134 | {
135 | "cell_type": "code",
136 | "execution_count": 4,
137 | "id": "418ee205",
138 | "metadata": {},
139 | "outputs": [],
140 | "source": [
141 | "def get_pixel_values_from_path(path: str, feature_extractor):\n",
142 | " \"\"\"Helper function for getting images pixels from a local path.\"\"\"\n",
143 | " img = Image.open(path)\n",
144 | " img = img.resize((224, 224))\n",
145 | " img = img.convert('RGB')\n",
146 | " pixel_values = utils.get_pixel_values_for_model(feature_extractor, img)\n",
147 | " if torch.cuda.is_available():\n",
148 | " pixel_values = pixel_values.bfloat16()\n",
149 | " pixel_values = pixel_values.cuda()\n",
150 | " return pixel_values[None, ...]"
151 | ]
152 | },
153 | {
154 | "cell_type": "markdown",
155 | "id": "944691a6",
156 | "metadata": {},
157 | "source": [
158 | "Then, we compute the image features and text features for each VisDial example:"
159 | ]
160 | },
161 | {
162 | "cell_type": "code",
163 | "execution_count": 5,
164 | "id": "d20c3c02",
165 | "metadata": {
166 | "scrolled": true
167 | },
168 | "outputs": [
169 | {
170 | "data": {
171 | "application/vnd.jupyter.widget-view+json": {
172 | "model_id": "6fdb984f282f45babbe00eb49fbb6ac0",
173 | "version_major": 2,
174 | "version_minor": 0
175 | },
176 | "text/plain": [
177 | " 0%| | 0/2064 [00:00, ?it/s]"
178 | ]
179 | },
180 | "metadata": {},
181 | "output_type": "display_data"
182 | }
183 | ],
184 | "source": [
185 | "topk = (1, 5, 10)\n",
186 | "ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none').cuda()\n",
187 | "\n",
188 | "all_visual_embs = []\n",
189 | "all_text_embs = []\n",
190 | "\n",
191 | "for example_idx in notebook.tqdm(range(len(dialogs))):\n",
192 | " dialog = dialogs[example_idx]\n",
193 | " image_id = str(dialog['image_id']).rjust(12, '0')\n",
194 | " contexts = []\n",
195 | "\n",
196 | " with torch.no_grad():\n",
197 | " images = get_pixel_values_from_path(\n",
198 | " os.path.join(img_dir, f'VisualDialog_{split}2018_{image_id}.jpg'),\n",
199 | " model.model.feature_extractor)\n",
200 | " visual_embs = model.model.get_visual_embs(images, mode='retrieval')\n",
201 | "\n",
202 | " for i in range(len(dialog['dialog'])):\n",
203 | " contexts.append('Q: ' + questions[dialog['dialog'][i]['question']] + '?')\n",
204 | " contexts.append('A: ' + answers[dialog['dialog'][i]['answer']] + '.')\n",
205 | "\n",
206 | " full_context_sent = ' '.join(contexts) + '[RET]'\n",
207 | " input_ids = model.model.tokenizer(full_context_sent, add_special_tokens=True, return_tensors=\"pt\").input_ids\n",
208 | " input_ids = input_ids.cuda()\n",
209 | " input_embs = model.model.input_embeddings(input_ids) # (N, T, D)\n",
210 | " generated_ids, output_embs, _ = model(input_embs, None, None, generate=True, num_words=1, temperature=0.0)\n",
211 | " embeddings = output_embs[0]\n",
212 | "\n",
213 | " full_input_ids = torch.cat([input_ids, generated_ids], dim=1)\n",
214 | " ret_emb = embeddings[:, -1, :] \n",
215 | "\n",
216 | " all_visual_embs.append(visual_embs.cpu().detach().float().numpy())\n",
217 | " all_text_embs.append(ret_emb.cpu().detach().float().numpy())\n",
218 | "\n",
219 | "# Compute scores over the whole dataset:\n",
220 | "scores = np.concatenate(all_visual_embs, axis=0)[:, 0, :] @ np.concatenate(all_text_embs, axis=0).T\n",
221 | "scores = torch.tensor(scores).float()\n",
222 | "assert scores.shape == (2064, 2064), scores.shape"
223 | ]
224 | },
225 | {
226 | "cell_type": "markdown",
227 | "id": "6d0649fc",
228 | "metadata": {},
229 | "source": [
230 | "Finally, we can compute the Recall@k scores:"
231 | ]
232 | },
233 | {
234 | "cell_type": "code",
235 | "execution_count": 6,
236 | "id": "627e9f7c",
237 | "metadata": {},
238 | "outputs": [
239 | {
240 | "name": "stdout",
241 | "output_type": "stream",
242 | "text": [
243 | "top-k, k=1, acc=0.20785\n",
244 | "top-k, k=5, acc=0.44913\n",
245 | "top-k, k=10, acc=0.55959\n",
246 | "====================\n"
247 | ]
248 | }
249 | ],
250 | "source": [
251 | "_, preds = scores.topk(max(topk))\n",
252 | "for k in topk:\n",
253 | " labels = torch.arange(preds.shape[0])\n",
254 | " correct = torch.any(preds[:, :k] == labels[:, None], axis=1).sum()\n",
255 | " acc = correct / preds.shape[0]\n",
256 | " print(f'top-k, k={k}, acc={acc:.5f}')\n",
257 | "print('=' * 20)"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": null,
263 | "id": "d4a01a5b",
264 | "metadata": {},
265 | "outputs": [],
266 | "source": []
267 | }
268 | ],
269 | "metadata": {
270 | "kernelspec": {
271 | "display_name": "Python 3 (ipykernel)",
272 | "language": "python",
273 | "name": "python3"
274 | },
275 | "language_info": {
276 | "codemirror_mode": {
277 | "name": "ipython",
278 | "version": 3
279 | },
280 | "file_extension": ".py",
281 | "mimetype": "text/x-python",
282 | "name": "python",
283 | "nbconvert_exporter": "python",
284 | "pygments_lexer": "ipython3",
285 | "version": "3.10.4"
286 | }
287 | },
288 | "nbformat": 4,
289 | "nbformat_minor": 5
290 | }
291 |
--------------------------------------------------------------------------------
/evals/eval_visdial_generation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | """
5 | This is a script reproducing the VisDial image-and-text-to-text (IT2T) results from our paper,
6 | Grounding Language Models to Images for Multimodal Inputs and Outputs (https://arxiv.org/abs/2301.13823).
7 | This result is reported in Table 2 of the paper. This is the standard VisDial (https://arxiv.org/abs/1611.08669)
8 | evaluation, which measures the ability of models to pick out the correct text answer out of 100 options.
9 | This script reports NDCG, MRR, and R@k results.
10 |
11 | Example usage: `python eval_visdial_generation.py`
12 | """
13 |
14 | import numpy as np
15 | import collections
16 | import copy
17 | import json
18 | import os
19 | import torch
20 | from transformers import logging
21 | from tqdm import notebook
22 | logging.set_verbosity_error()
23 |
24 | from PIL import Image
25 | import matplotlib.pyplot as plt
26 |
27 | from fromage import models
28 | from fromage import utils
29 |
30 |
31 | # Parameters used for eval.
32 | topk = (1, 5, 10)
33 | # Number of options in a batch to compute loss for.
34 | # If using a GPU with lower VRAM, this may have to be lowered.
35 | batch_size = 20
36 |
37 | # Download the VisDial validation annotations (https://www.dropbox.com/s/ibs3a0zhw74zisc/visdial_1.0_val.zip?dl=0),
38 | # the dense answer annotations (https://www.dropbox.com/s/3knyk09ko4xekmc/visdial_1.0_val_dense_annotations.json?dl=0)
39 | # (for computing MRR) and the images (https://www.dropbox.com/s/twmtutniktom7tu/VisualDialog_val2018.zip?dl=0).
40 | # Extract everything to the `VisualDialog` folder.
41 | # First, we'll load the annotations, and define the paths to our images
42 | # and annotations:
43 | base_dir = 'VisualDialog/'
44 | split = 'val'
45 |
46 | # Path to save intermediate results to, to allow resuming in case of interruptions.
47 | save_path = 'visdial_results_full.npy'
48 |
49 |
50 |
51 | def get_pixel_values_from_path(path: str, feature_extractor):
52 | """Helper function for getting images pixels from a local path."""
53 | img = Image.open(path)
54 | img = img.resize((224, 224))
55 | img = img.convert('RGB')
56 | pixel_values = utils.get_pixel_values_for_model(feature_extractor, img)
57 | if torch.cuda.is_available():
58 | pixel_values = pixel_values.bfloat16()
59 | pixel_values = pixel_values.cuda()
60 | return pixel_values[None, ...]
61 |
62 |
63 | # Define some classes to help us compute NDCG and MRR.
64 | # Modified from https://github.com/batra-mlp-lab/visdial-challenge-starter-pytorch/blob/master/visdialch/metrics.py
65 | class NDCG(object):
66 | def __init__(self):
67 | self._ndcg_numerator = 0.0
68 | self._ndcg_denominator = 0.0
69 |
70 | def observe(
71 | self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor
72 | ):
73 | """
74 | Observe model output scores and target ground truth relevance and
75 | accumulate NDCG metric.
76 | Parameters
77 | ----------
78 | predicted_scores: torch.Tensor
79 | A tensor of shape (batch_size, num_options), because dense
80 | annotations are available for 1 randomly picked round out of 10.
81 | target_relevance: torch.Tensor
82 |             A tensor of the same shape as predicted_scores, indicating ground truth
83 | relevance of each answer option for a particular round.
84 | """
85 | predicted_scores = predicted_scores.detach()
86 |
87 | # shape: (batch_size, 1, num_options)
88 | predicted_scores = predicted_scores.unsqueeze(1)
89 | predicted_ranks = scores_to_ranks(predicted_scores)
90 |
91 | # shape: (batch_size, num_options)
92 | predicted_ranks = predicted_ranks.squeeze(1)
93 | batch_size, num_options = predicted_ranks.size()
94 |
95 | k = torch.sum(target_relevance != 0, dim=-1)
96 |
97 | # shape: (batch_size, num_options)
98 | _, rankings = torch.sort(predicted_ranks, dim=-1)
99 |         # Sort relevance in descending order so highest relevance gets top rank.
100 | _, best_rankings = torch.sort(
101 | target_relevance, dim=-1, descending=True
102 | )
103 |
104 | # shape: (batch_size, )
105 | batch_ndcg = []
106 | for batch_index in range(batch_size):
107 | num_relevant = k[batch_index]
108 | dcg = self._dcg(
109 | rankings[batch_index][:num_relevant],
110 | target_relevance[batch_index],
111 | )
112 | best_dcg = self._dcg(
113 | best_rankings[batch_index][:num_relevant],
114 | target_relevance[batch_index],
115 | )
116 | batch_ndcg.append(dcg / best_dcg)
117 |
118 | self._ndcg_denominator += batch_size
119 | self._ndcg_numerator += sum(batch_ndcg)
120 |
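    # DCG over the top-k predicted options is sum_i rel_i / log2(i + 1) with i
    # starting at 1 (the 0-based positions below use log2(pos + 2)). observe()
    # divides this by the ideal DCG obtained by ranking options by their
    # ground-truth relevance, giving an NDCG value in [0, 1].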
121 | def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor):
122 | sorted_relevance = relevance[rankings].cpu().float()
123 | discounts = torch.log2(torch.arange(len(rankings)).float() + 2)
124 | return torch.sum(sorted_relevance / discounts, dim=-1)
125 |
126 | def retrieve(self, reset: bool = True, key=""):
127 | if self._ndcg_denominator > 0:
128 | metrics = {
129 | key + "ndcg": float(self._ndcg_numerator / self._ndcg_denominator)
130 | }
131 | else:
132 | metrics = {}
133 |
134 | if reset:
135 | self.reset()
136 | return metrics
137 |
138 | def reset(self):
139 | self._ndcg_numerator = 0.0
140 | self._ndcg_denominator = 0.0
141 |
142 |
143 | def scores_to_ranks(scores: torch.Tensor):
144 | """Convert model output scores into ranks."""
145 | batch_size, num_rounds, num_options = scores.size()
146 | scores = scores.view(-1, num_options)
147 |
148 | # sort in descending order - largest score gets highest rank
149 | sorted_ranks, ranked_idx = scores.sort(1, descending=True)
150 |
151 |     # Entry j of ranked_idx holds the index of the score that takes rank j,
152 |     # but we want entry i to hold the rank of score i, so we invert the
153 |     # permutation below.
154 | ranks = ranked_idx.clone().fill_(0)
155 | for i in range(ranked_idx.size(0)):
156 | for j in range(num_options):
157 | ranks[i][ranked_idx[i][j]] = j
158 | # convert from 0-99 ranks to 1-100 ranks
159 | ranks += 1
160 | ranks = ranks.view(batch_size, num_rounds, num_options)
161 | return ranks
162 |
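# Illustrative example (comment only, not executed): for a single round with
# scores [[[0.1, 0.9, 0.5]]], scores_to_ranks returns [[[3, 1, 2]]] -- the
# option with the highest score receives rank 1.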
163 |
164 | if __name__ == "__main__":
165 | # Load model used in the paper.
166 | model_dir = './fromage_model/'
167 | model = models.load_fromage(model_dir)
168 |
169 | # Load VisDial data.
170 | img_dir = os.path.join(base_dir, f'VisualDialog_{split}2018')
171 |
172 | with open(os.path.join(base_dir, f'visdial_1.0_{split}.json'), 'r') as f:
173 | visdial_data = json.load(f)
174 |
175 | with open(os.path.join(base_dir, f'visdial_1.0_{split}_dense_annotations.json'), 'r') as f:
176 | dense_data = json.load(f)
177 |
178 | # Check that dense and sparse data are aligned.
179 | assert len(dense_data) == len(visdial_data['data']['dialogs'])
180 | for i in range(len(dense_data)):
181 | assert dense_data[i]['image_id'] == visdial_data['data']['dialogs'][i]['image_id']
182 |
183 | questions = visdial_data['data']['questions']
184 | answers = visdial_data['data']['answers']
185 | dialogs = visdial_data['data']['dialogs']
186 |
187 | # Then, for each VisDial example, we compute the loss
188 | # conditioned on the image and the preceding dialogue.
189 | # We return the option with the lowest loss as the answer:
190 | ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none').cuda()
191 |
192 | if os.path.exists(save_path):
193 | with open(save_path, 'rb') as rf:
194 | all_data = np.load(rf, allow_pickle=True).item()
195 | all_preds = all_data['all_preds']
196 | all_gt_results = all_data['all_gt_results']
197 | all_losses = all_data['all_losses']
198 | assert len(all_preds) == len(all_gt_results) == len(all_losses)
199 | else:
200 | # No in progress data, initialize from scratch.
201 | all_preds = []
202 | all_gt_results = []
203 | all_losses = []
204 |
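    # Each dialogue contributes 10 rounds (one prediction per round), so
    # len(all_preds) // 10 is the index of the first dialogue that has not yet
    # been processed when resuming from a saved results file.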
205 | for example_idx in notebook.tqdm(range(len(all_preds) // 10, len(dialogs))):
206 | dialog = dialogs[example_idx]
207 | image_id = str(dialog['image_id']).rjust(12, '0')
208 | contexts = []
209 |
210 | with torch.no_grad():
211 | images = get_pixel_values_from_path(
212 | os.path.join(img_dir, f'VisualDialog_{split}2018_{image_id}.jpg'),
213 | model.model.feature_extractor)
214 | visual_embs = model.model.get_visual_embs(images, mode='captioning')
215 |
216 | for i in range(len(dialog['dialog'])):
217 | prev_d = dialog['dialog'][i-1]
218 | current_d = dialog['dialog'][i]
219 | if i > 0:
220 | contexts.append('A: ' + answers[prev_d['answer']])
221 | contexts.append('Q: ' + questions[current_d['question']] + '?')
222 | answer_options = [answers[i] for i in current_d['answer_options']]
223 | answer = answers[current_d['answer']]
224 | gt_index = current_d['gt_index']
225 | caption = '\n'.join(contexts) + '\nA: '
226 |
227 | # Run through every possible option, and pick the option with
228 | # the lowest loss (= lowest perplexity)
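                # In other words: for each of the 100 candidate answers, compute
                # the LM cross-entropy of the answer tokens conditioned on the
                # image and the dialogue so far, then rank candidates by
                # ascending loss.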
229 | example_losses = []
230 | # Tokenize the dialogue sequence (as this is the same for all answer choices).
231 | caption_ids = model.model.tokenizer(
232 | caption, add_special_tokens=True, return_tensors="pt").input_ids
233 | caption_ids = caption_ids.to(images.device)
234 | caption_embs = model.model.input_embeddings(caption_ids) # (N, T, D)
235 | condition_length = visual_embs.shape[1] + caption_embs.shape[1]
236 |
237 | all_example_embs = []
238 | all_example_labels = []
239 |
240 | for _, ans in enumerate(answer_options):
241 | ans_ids = model.model.tokenizer(ans, add_special_tokens=True, return_tensors="pt").input_ids
242 | ans_ids = ans_ids.to(images.device)
243 | ans_embs = model.model.input_embeddings(ans_ids)
244 | input_embs = torch.cat([
245 | visual_embs,
246 | caption_embs,
247 | ans_embs], dim=1)
248 | labels = torch.cat([
249 | torch.zeros(visual_embs.shape[:-1], device=caption_ids.device, dtype=caption_ids.dtype) - 100,
250 | caption_ids,
251 | ans_ids], dim=1)
252 | assert labels.shape[1] == input_embs.shape[1]
253 |
254 | all_example_embs.append(input_embs)
255 | all_example_labels.append(labels)
256 |
257 | max_len = max([x.shape[1] for x in all_example_labels])
258 | padded_example_embs = [torch.nn.functional.pad(x, (0, 0, 0, max_len - x.shape[1])) for x in all_example_embs]
259 | padded_example_embs = torch.cat(padded_example_embs, axis=0)
260 |
261 | padded_example_labels = [torch.nn.functional.pad(x, (0, max_len - x.shape[1]), value=-100) for x in all_example_labels]
262 | padded_example_labels = torch.cat(padded_example_labels, axis=0)
263 |
264 | all_logits = []
265 | batches = int(np.ceil(padded_example_embs.shape[0] / batch_size))
266 | for i in range(batches):
267 | start_idx = i * batch_size
268 | end_idx = start_idx + batch_size
269 | out = model.model.lm(
270 | inputs_embeds=padded_example_embs[start_idx:end_idx, ...],
271 | labels=None,
272 | use_cache=False,
273 | output_hidden_states=True)
274 | all_logits.append(out.logits)
275 |
276 | logits = torch.cat(all_logits, dim=0)
277 | example_losses = ce_loss(logits.reshape((-1, logits.shape[-1])), padded_example_labels.reshape((-1,)))
278 | example_losses = example_losses.reshape((100, max_len))[:, condition_length:]
279 | example_losses = example_losses.sum(axis=1)
280 |
281 | all_losses.append(example_losses.cpu().float().numpy())
282 | scores = -example_losses
283 | _, preds = scores.topk(max(topk))
284 | all_preds.append(preds)
285 | all_gt_results.append(gt_index)
286 |
287 | with open(save_path, 'wb') as wf:
288 | np.save(wf, {'all_preds': all_preds, 'all_gt_results': all_gt_results, 'all_losses': all_losses})
289 |
290 | # Finally, we can compute NDCG, MRR, and Recall@k:
291 | with open(save_path, 'rb') as rf:
292 | all_data = np.load(rf, allow_pickle=True).item()
293 | all_preds = all_data['all_preds']
294 | all_gt_results = all_data['all_gt_results']
295 | all_losses = all_data['all_losses']
296 |
297 | top_k_accuracy = collections.defaultdict(list)
298 | mrr_results = []
299 | all_ranks = []
300 | topk = (1, 5, 10, 20)
301 | ndcg = NDCG()
302 |
303 | assert len(all_preds) == len(all_gt_results)
304 | for gt, loss in zip(all_gt_results, all_losses):
305 | scores = -loss
306 | _, preds = torch.tensor(scores).topk(100)
307 | rank = np.where(preds == gt)[0][0] + 1
308 | all_ranks.append(rank)
309 | mrr_results.append(1 / rank)
310 |
311 | for k in topk:
312 | acc = gt in preds[:k]
313 | top_k_accuracy[k].append(acc)
314 |
315 | dense_mrr = []
316 | for i in range(len(dense_data)):
317 | idx = i * 10 + dense_data[i]['round_id']
318 | if idx >= len(all_losses):
319 | break
320 | scores = -torch.tensor(all_losses[idx])[None, :]
321 | relevance = torch.tensor(dense_data[i]['gt_relevance'])[None, :]
322 | ndcg.observe(scores, relevance)
323 | dense_mrr.append(mrr_results[idx])
324 |
325 | for k in topk:
326 | print(f'top-k, k={k}, acc={np.mean(top_k_accuracy[k]):.5f}')
327 | print(f'MRR: {np.mean(mrr_results):.5f}')
328 | print(f'NDCG: {ndcg.retrieve(reset=True)["ndcg"]:.5f}')
329 |
--------------------------------------------------------------------------------
/evals/eval_visdial_retrieval.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 |
4 | """
5 | This is a script reproducing the VisDial text-to-image (T2I) retrieval results from our paper,
6 | Grounding Language Models to Images for Multimodal Inputs and Outputs (https://arxiv.org/abs/2301.13823).
7 | This result is reported in Table 2 of the paper. This measures the recall of the model in selecting
8 | the appropriate image conditioned on a dialogue sequence.
9 |
10 | Example usage: `python eval_visdial_retrieval.py`
11 | """
12 |
13 | import numpy as np
14 | import collections
15 | import copy
16 | import json
17 | import os
18 | import torch
19 | from transformers import logging
20 | from tqdm import notebook
21 | logging.set_verbosity_error()
22 |
23 | from PIL import Image
24 | import matplotlib.pyplot as plt
25 |
26 | from fromage import models
27 | from fromage import utils
28 |
29 |
30 | def get_pixel_values_from_path(path: str, feature_extractor):
31 | """Helper function for getting images pixels from a local path."""
32 | img = Image.open(path)
33 | img = img.resize((224, 224))
34 | img = img.convert('RGB')
35 | pixel_values = utils.get_pixel_values_for_model(feature_extractor, img)
36 | if torch.cuda.is_available():
37 | pixel_values = pixel_values.bfloat16()
38 | pixel_values = pixel_values.cuda()
39 | return pixel_values[None, ...]
40 |
41 |
42 | if __name__ == "__main__":
43 | # Load model used in the paper.
44 |     model_dir = './fromage_model/'
45 | model = models.load_fromage(model_dir)
46 |
47 |
48 | # Download the VisDial validation annotations (https://www.dropbox.com/s/ibs3a0zhw74zisc/visdial_1.0_val.zip?dl=0),
49 | # the dense answer annotations (https://www.dropbox.com/s/3knyk09ko4xekmc/visdial_1.0_val_dense_annotations.json?dl=0)
50 | # (for computing MRR) and the images (https://www.dropbox.com/s/twmtutniktom7tu/VisualDialog_val2018.zip?dl=0).
51 | # Extract everything to the `VisualDialog` folder.
52 |
53 |     base_dir = 'VisualDialog/'
54 | split = 'val'
55 | img_dir = os.path.join(base_dir, f'VisualDialog_{split}2018')
56 |
57 | with open(os.path.join(base_dir, f'visdial_1.0_{split}.json'), 'r') as f:
58 | visdial_data = json.load(f)
59 |
60 | with open(os.path.join(base_dir, f'visdial_1.0_{split}_dense_annotations.json'), 'r') as f:
61 | dense_data = json.load(f)
62 |
63 | # Check that dense and sparse data are aligned.
64 | assert len(dense_data) == len(visdial_data['data']['dialogs'])
65 | for i in range(len(dense_data)):
66 | assert dense_data[i]['image_id'] == visdial_data['data']['dialogs'][i]['image_id']
67 |
68 | questions = visdial_data['data']['questions']
69 | answers = visdial_data['data']['answers']
70 | dialogs = visdial_data['data']['dialogs']
71 |
72 | # Then, we compute the image features and text features for each VisDial example:
73 | topk = (1, 5, 10)
74 | ce_loss = torch.nn.CrossEntropyLoss(ignore_index=-100, reduction='none').cuda()
75 |
76 | all_visual_embs = []
77 | all_text_embs = []
78 |
79 | for example_idx in notebook.tqdm(range(len(dialogs))):
80 | dialog = dialogs[example_idx]
81 | image_id = str(dialog['image_id']).rjust(12, '0')
82 | contexts = []
83 |
84 | with torch.no_grad():
85 | images = get_pixel_values_from_path(
86 | os.path.join(img_dir, f'VisualDialog_{split}2018_{image_id}.jpg'),
87 | model.model.feature_extractor)
88 | visual_embs = model.model.get_visual_embs(images, mode='retrieval')
89 |
90 | for i in range(len(dialog['dialog'])):
91 | contexts.append('Q: ' + questions[dialog['dialog'][i]['question']] + '?')
92 | contexts.append('A: ' + answers[dialog['dialog'][i]['answer']] + '.')
93 |
94 | full_context_sent = ' '.join(contexts) + '[RET]'
95 | input_ids = model.model.tokenizer(full_context_sent, add_special_tokens=True, return_tensors="pt").input_ids
96 | input_ids = input_ids.cuda()
97 | input_embs = model.model.input_embeddings(input_ids) # (N, T, D)
98 | generated_ids, output_embs, _ = model(input_embs, None, None, generate=True, num_words=1, temperature=0.0)
99 | embeddings = output_embs[0]
100 |
101 | full_input_ids = torch.cat([input_ids, generated_ids], dim=1)
102 | ret_emb = embeddings[:, -1, :]
103 |
104 | all_visual_embs.append(visual_embs.cpu().detach().float().numpy())
105 | all_text_embs.append(ret_emb.cpu().detach().float().numpy())
106 |
107 | # Compute scores over the whole dataset:
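    # scores[i, j] is the similarity between image i and the [RET] embedding of
    # dialogue j, so matching image-dialogue pairs lie on the diagonal.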
108 | scores = np.concatenate(all_visual_embs, axis=0)[:, 0, :] @ np.concatenate(all_text_embs, axis=0).T
109 | scores = torch.tensor(scores).float()
110 | assert scores.shape == (2064, 2064), scores.shape
111 |
112 |
113 | # Finally, we can compute the Recall@k scores:
114 | _, preds = scores.topk(max(topk))
115 | for k in topk:
116 | labels = torch.arange(preds.shape[0])
117 | correct = torch.any(preds[:, :k] == labels[:, None], axis=1).sum()
118 | acc = correct / preds.shape[0]
119 | print(f'top-k, k={k}, acc={acc:.5f}')
120 | print('=' * 20)
121 |
122 |
123 |
124 |
--------------------------------------------------------------------------------
/evals/eval_vist_retrieval.py:
--------------------------------------------------------------------------------
1 | # This is a script reproducing the contextual image retrieval results from our paper.
2 | # This result is reported in Table 1.
3 | # The results may differ slightly from those reported in the paper, as the Flickr images from Visual Storytelling may disappear over time.
4 |
5 | import numpy as np
6 | import collections
7 | import copy
8 | import json
9 | import os
10 | import torch
11 | from transformers import logging
12 | from tqdm import tqdm
13 | logging.set_verbosity_error()
14 |
15 | from PIL import Image
16 | import matplotlib.pyplot as plt
17 |
18 | from fromage import models
19 | from fromage import utils
20 |
21 |
22 | # Load model used in the paper.
23 | model_dir = './fromage_model/'
24 | model = models.load_fromage(model_dir)
25 |
26 | # Download the Visual Storytelling SIS dataset from https://visionandlanguage.net/VIST/json_files/story-in-sequence/SIS-with-labels.tar.gz
27 | # Extract the files (there should be three sets: train, val, and test).
28 | # We use the val set for reporting results.
29 | vist_val_json_path = 'sis/val.story-in-sequence.json'
30 | with open(vist_val_json_path, 'r') as f:
31 | vist_data_raw = json.load(f)
32 |
33 | # Format into a dictionary of {story_id: data} items.
34 | vist_data = {
35 | 'annotations': collections.defaultdict(list)
36 | }
37 | used_image_ids = []
38 |
39 |
40 | for ann in vist_data_raw['annotations']:
41 | assert len(ann) == 1
42 | ann = ann[0]
43 | story_id = ann['story_id']
44 | vist_data['annotations'][story_id].append({
45 | 'caption': ann['text'],
46 | 'image_id': ann['photo_flickr_id'],
47 | 'sequence_index': ann['worker_arranged_photo_order'],
48 | })
49 | used_image_ids.append(ann['photo_flickr_id'])
50 |
51 | used_image_ids = set(used_image_ids)
52 | print(len(used_image_ids))
53 |
54 |
55 | # Precompute image features for running retrieval.
56 | embs_fn = 'sis_img_features.npy'
57 | id2url = {}
58 |
59 | for image_data in vist_data_raw['images']:
60 | image_id = image_data['id']
61 | if image_id in used_image_ids:
62 | image_url = image_data.get('url_o', None)
63 | if image_url is not None:
64 | id2url[image_id] = image_url
65 |
66 | if not os.path.exists(embs_fn):
67 | print(f'{embs_fn} does not exist, computing it from scratch.')
68 | all_visual_embs = []
69 | all_image_ids = []
70 |
71 | for image_id, image_url in tqdm(id2url.items()):
72 | try:
73 | images = utils.get_image_from_url(image_url)
74 | pixel_values = utils.get_pixel_values_for_model(model.model.feature_extractor, images)
75 | pixel_values = pixel_values.to(device=model.model.logit_scale.device, dtype=model.model.logit_scale.dtype)
76 | pixel_values = pixel_values[None, ...]
77 | visual_embs = model.model.get_visual_embs(pixel_values, mode='retrieval')
78 | all_visual_embs.append(visual_embs.float().cpu().detach().numpy())
79 | all_image_ids.append(image_id)
80 | except Image.UnidentifiedImageError:
81 | pass
82 |
83 | all_image_ids = np.array(all_image_ids)
84 | all_visual_embs = np.concatenate(all_visual_embs, axis=0)
85 | assert all_image_ids.shape[0] == all_visual_embs.shape[0], (all_image_ids.shape, all_visual_embs.shape)
86 | print(all_image_ids.shape, all_visual_embs.shape)
87 |
88 | with open(embs_fn, 'wb') as wf:
89 | np.save(wf, {'image_ids': all_image_ids, 'embeddings': all_visual_embs})
90 |
91 | # Load embeddings.
92 | with open(embs_fn, 'rb') as wf:
93 | embs_data = np.load(wf, allow_pickle=True).item()
94 | all_image_ids = embs_data['image_ids']
95 | emb_matrix = embs_data['embeddings']
96 |
97 | # Normalize embedding matrix to be suitable for image retrieval.
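# Each row is L2-normalized and scaled by the learned logit scale, so a dot
# product with the normalized [RET] embedding yields a scaled cosine similarity.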
98 | logit_scale = model.model.logit_scale.exp()
99 | emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
100 | emb_matrix = emb_matrix / emb_matrix.norm(dim=-1, keepdim=True)
101 | emb_matrix = logit_scale * emb_matrix
102 | print('emb_matrix.shape', emb_matrix.shape)
103 |
104 |
105 | # Then, for each VIST example, we process it as `... [RET]`,
106 | # providing this as input to FROMAGe, and retrieve the image corresponding to the `[RET]` embedding.
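# Concretely, the input is interleaved as
#   [caption_1][image_1][caption_2][image_2] ... [image_{n-1}][caption_n][RET]
# (see the interleaving step below), and the hidden state at the [RET] position
# is scored against the precomputed image embedding matrix.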
107 |
108 | topk = (1, 5, 10)
109 | top_k_preds = {}
110 |
111 | with torch.no_grad():
112 | for story_idx, (story_id, story_data) in tqdm(enumerate(vist_data['annotations'].items()), total=len(vist_data['annotations'])):
113 | gt_image_id = story_data[-1]['image_id']
114 |         skip = False  # Skip examples whose images are missing (e.g., because the Flickr URLs have been taken down).
115 | for s in story_data:
116 | if s['image_id'] not in all_image_ids or s['image_id'] not in id2url:
117 | skip = True
118 | break
119 |
120 | if not skip:
121 | # Use the first n-1 images and n captions as input.
122 | image_urls = [id2url[s['image_id']] for s in story_data[:-1]]
123 | captions = [s['caption'] for s in story_data]
124 | assert len(image_urls) == len(captions) - 1
125 |
126 | visual_embs = []
127 | # Compute embeddings for the input images.
128 | images = [utils.get_image_from_url(image_url) for image_url in image_urls]
129 | pixel_values = [utils.get_pixel_values_for_model(model.model.feature_extractor, image) for image in images]
130 | pixel_values = torch.stack(pixel_values, dim=0) # (n-1, 3, 224, 224)
131 | pixel_values = pixel_values.to(device=model.model.logit_scale.device, dtype=model.model.logit_scale.dtype)
132 | visual_embs = model.model.get_visual_embs(pixel_values, mode='captioning')
133 |
134 | # Compute embeddings for the input captions.
135 | all_input_ids = []
136 | for i, c in enumerate(captions):
137 | if i == len(captions) - 1:
138 | c += '[RET]' # Add the [RET] token to the final caption.
139 | input_ids = model.model.tokenizer(c, add_special_tokens=True, return_tensors="pt").input_ids.to(emb_matrix.device)
140 | all_input_ids.append(input_ids)
141 |
142 | input_embs = [model.model.input_embeddings(s)[0, ...] for s in all_input_ids] # (N, T, D)
143 |
144 | # Interleave captions and images as [caption1, image1, caption2, ..., image4, caption5].
145 | final_input_embs = []
146 | assert len(visual_embs) == len(input_embs) - 1
147 | for i in range(len(images)):
148 | final_input_embs.append(input_embs[i])
149 | final_input_embs.append(visual_embs[i])
150 | final_input_embs.append(input_embs[len(images)])
151 | final_input_embs = torch.cat(final_input_embs, dim=0)[None, ...] # (1, T, 4096)
152 |
153 | # Get embedding of the [RET] token, and compute scores:
154 | output = model.model.lm(inputs_embeds=final_input_embs, labels=None, use_cache=False, output_hidden_states=True)
155 | last_hidden_state = model.model.text_hidden_fcs[0](output.hidden_states[-1])
156 | ret_emb = last_hidden_state[:, -1, :]
157 |
158 | ret_emb = ret_emb / ret_emb.norm(dim=1, keepdim=True)
159 | scores = ret_emb.squeeze() @ emb_matrix.squeeze().T
160 |
161 | # Don't retrieve previously seen images.
162 | prev_image_ids = [s['image_id'] for s in story_data[:-1]]
163 | for prev_id in prev_image_ids:
164 | scores[np.where(all_image_ids == prev_id)[0]] -= 10000
165 |
166 | # Store top-k preds.
167 | _, preds = scores.topk(max(topk))
168 | preds = preds.cpu().detach().numpy()
169 | preds = [all_image_ids[p] for p in preds]
170 | top_k_preds[story_id] = {'topk_preds': preds, 'gt': gt_image_id}
171 |
172 |
173 | # Finally, we can compute Recall@k:
174 | top_k_accuracy = collections.defaultdict(list)
175 |
176 | for story_id, results in top_k_preds.items():
177 | for k in topk:
178 | acc = results['gt'] in results['topk_preds'][:k]
179 | top_k_accuracy[k].append(acc)
180 |
181 | for k in topk:
182 | result_str = f'k={k}, acc={np.mean(top_k_accuracy[k]):.5f}'
183 | print(result_str)
184 |
--------------------------------------------------------------------------------
/fromage/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/fromage/__init__.py
--------------------------------------------------------------------------------
/fromage/data.py:
--------------------------------------------------------------------------------
1 | """Modified from https://github.com/mlfoundations/open_clip"""
2 |
3 | from typing import Optional, Tuple
4 |
5 | import collections
6 | import logging
7 | import os
8 | import numpy as np
9 | import pandas as pd
10 | import torch
11 | import torchvision.datasets as datasets
12 | from torchvision import transforms as T
13 | from PIL import Image, ImageFont
14 | from torch.utils.data import Dataset
15 |
16 | from fromage import utils
17 |
18 |
19 | def collate_fn(batch):
20 | batch = list(filter(lambda x: x is not None, batch))
21 | return torch.utils.data.dataloader.default_collate(batch)
22 |
23 |
24 | def get_dataset(args, split: str, tokenizer, precision: str = 'fp32') -> Dataset:
25 | assert split in ['train', 'val'
26 |         ], f'Expected split to be one of "train" or "val", got {split} instead.'
27 |
28 | dataset_paths = []
29 | image_data_dirs = []
30 | train = split == 'train'
31 |
32 | # Default configs for datasets.
33 | # Folder structure should look like:
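    #   {args.dataset_dir}/cc3m_train.tsv, {args.dataset_dir}/cc3m_val.tsv
    #   {args.image_dir}/cc3m/training/, {args.image_dir}/cc3m/validation/
    # (layout inferred from the paths constructed below)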
34 | if split == 'train':
35 | if 'cc3m' in args.dataset:
36 | dataset_paths.append(os.path.join(args.dataset_dir, 'cc3m_train.tsv'))
37 | image_data_dirs.append(os.path.join(args.image_dir, 'cc3m/training/'))
38 | else:
39 | raise NotImplementedError
40 |
41 | elif split == 'val':
42 | if 'cc3m' in args.val_dataset:
43 | dataset_paths.append(os.path.join(args.dataset_dir, 'cc3m_val.tsv'))
44 | image_data_dirs.append(os.path.join(args.image_dir, 'cc3m/validation'))
45 | else:
46 | raise NotImplementedError
47 |
48 | assert len(dataset_paths) == len(image_data_dirs) == 1, (dataset_paths, image_data_dirs)
49 | else:
50 | raise NotImplementedError
51 |
52 | if len(dataset_paths) > 1:
53 | print(f'{len(dataset_paths)} datasets requested: {dataset_paths}')
54 | dataset = torch.utils.data.ConcatDataset([
55 | CsvDataset(path, image_dir, tokenizer, 'image',
56 | 'caption', args.visual_model, train=train, max_len=args.max_len, precision=args.precision,
57 | image_size=args.image_size, retrieval_token_idx=args.retrieval_token_idx)
58 | for (path, image_dir) in zip(dataset_paths, image_data_dirs)])
59 | elif len(dataset_paths) == 1:
60 | dataset = CsvDataset(dataset_paths[0], image_data_dirs[0], tokenizer, 'image',
61 | 'caption', args.visual_model, train=train, max_len=args.max_len, precision=args.precision,
62 | image_size=args.image_size, retrieval_token_idx=args.retrieval_token_idx)
63 | else:
64 | raise ValueError(f'There should be at least one valid dataset, got train={args.dataset}, val={args.val_dataset} instead.')
65 | return dataset
66 |
67 |
68 | class CsvDataset(Dataset):
69 | def __init__(self, input_filename, base_image_dir, tokenizer, img_key,
70 | caption_key, feature_extractor_model: str,
71 | train: bool = True, max_len: int = 32, sep="\t", precision: str = 'fp32',
72 | image_size: int = 224, retrieval_token_idx: int = -1):
73 | logging.debug(f'Loading tsv data from {input_filename}.')
74 | df = pd.read_csv(input_filename, sep=sep)
75 |
76 | self.base_image_dir = base_image_dir
77 | self.images = df[img_key].tolist()
78 | self.captions = df[caption_key].tolist()
79 | assert len(self.images) == len(self.captions)
80 |
81 | self.feature_extractor_model = feature_extractor_model
82 | self.feature_extractor = utils.get_feature_extractor_for_model(
83 | feature_extractor_model, image_size=image_size, train=False)
84 | self.image_size = image_size
85 |
86 | self.tokenizer = tokenizer
87 | self.max_len = max_len
88 | self.precision = precision
89 | self.retrieval_token_idx = retrieval_token_idx
90 |
91 | self.font = None
92 |
93 | logging.debug('Done loading data.')
94 |
95 | def __len__(self):
96 | return len(self.captions)
97 |
98 | def __getitem__(self, idx):
99 | while True:
100 | image_path = os.path.join(self.base_image_dir, str(self.images[idx]))
101 | caption = str(self.captions[idx])
102 |
103 | try:
104 | img = Image.open(image_path)
105 | images = utils.get_pixel_values_for_model(self.feature_extractor, img)
106 |
107 | caption += '[RET]'
108 | tokenized_data = self.tokenizer(
109 | caption,
110 | return_tensors="pt",
111 | padding='max_length',
112 | truncation=True,
113 | max_length=self.max_len)
114 | tokens = tokenized_data.input_ids[0]
115 |
116 | caption_len = tokenized_data.attention_mask[0].sum()
117 |
118 | decode_caption = self.tokenizer.decode(tokens, skip_special_tokens=False)
119 | self.font = self.font or ImageFont.load_default()
120 | cap_img = utils.create_image_of_text(decode_caption.encode('ascii', 'ignore'), width=self.image_size, nrows=2, font=self.font)
121 |
122 | if tokens[-1] not in [self.retrieval_token_idx, self.tokenizer.pad_token_id]:
123 | tokens[-1] = self.retrieval_token_idx
124 |
125 | return image_path, images, cap_img, tokens, caption_len
126 | except Exception as e:
127 | print(f'Error reading {image_path} with caption {caption}: {e}')
128 | # Pick a new example at random.
129 | idx = np.random.randint(0, len(self)-1)
130 |
--------------------------------------------------------------------------------
/fromage/evaluate.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import json
3 | import os
4 | from PIL import Image
5 | import numpy as np
6 | import time
7 | import tqdm
8 | import torch
9 | import torch.distributed as dist
10 | from torch.utils.tensorboard import SummaryWriter
11 | from torchmetrics import BLEUScore
12 | import torchvision
13 |
14 | from fromage import losses as losses_utils
15 | from fromage import utils
16 |
17 |
18 | def validate(val_loader, model, tokenizer, criterion, epoch, args):
19 | ngpus_per_node = torch.cuda.device_count()
20 | writer = SummaryWriter(args.log_dir)
21 | bleu_scorers = [BLEUScore(n_gram=i) for i in [1, 2, 3, 4]]
22 | actual_step = (epoch + 1) * args.steps_per_epoch
23 | model_modes = ['captioning', 'retrieval']
24 | num_words = 32 # Number of tokens to generate.
25 |
26 | feature_extractor = utils.get_feature_extractor_for_model(args.visual_model, image_size=args.image_size, train=False)
27 |
28 | def get_pixel_values_from_path(path: str):
29 | img = Image.open(path)
30 | img = img.resize((args.image_size, args.image_size))
31 | pixel_values = utils.get_pixel_values_for_model(feature_extractor, img)[None, ...]
32 |
33 | if args.precision == 'fp16':
34 | pixel_values = pixel_values.half()
35 | elif args.precision == 'bf16':
36 | pixel_values = pixel_values.bfloat16()
37 | if torch.cuda.is_available():
38 | pixel_values = pixel_values.cuda()
39 | return pixel_values
40 |
41 | def run_validate(loader, base_progress=0):
42 | with torch.no_grad():
43 | end = time.time()
44 | all_generated_captions = []
45 | all_gt_captions = []
46 | all_generated_image_paths = []
47 | all_image_features = []
48 | all_text_features = []
49 |
50 | for i, (image_paths, images, caption_images, tgt_tokens, token_len) in tqdm.tqdm(enumerate(loader), position=0, total=len(loader)):
51 | i = base_progress + i
52 |
53 | if torch.cuda.is_available():
54 | tgt_tokens = tgt_tokens.cuda(args.gpu, non_blocking=True)
55 | token_len = token_len.cuda(args.gpu, non_blocking=True)
56 | images = images.cuda()
57 |
58 | if args.precision == 'fp16':
59 | images = images.half()
60 | elif args.precision == 'bf16':
61 | images = images.bfloat16()
62 |
63 | for model_mode in model_modes:
64 | (model_output, full_labels, last_embedding, _, visual_embs) = model(
65 | images, tgt_tokens, token_len, mode=model_mode, input_prefix=args.input_prompt, inference=True) # (N, T, C)
66 |
67 | if model_mode == 'captioning':
68 | loss = args.cap_loss_scale * model_output.loss
69 | elif model_mode == 'retrieval':
70 | loss = args.ret_loss_scale * model_output.loss
71 | else:
72 | raise NotImplementedError
73 |
74 | output = model_output.logits
75 | if model_mode == 'captioning':
76 | acc1, acc5 = utils.accuracy(output[:, :-1, :], full_labels[:, 1:], -100, topk=(1, 5))
77 | top1.update(acc1[0], images.size(0))
78 | top5.update(acc5[0], images.size(0))
79 | ce_losses.update(loss.item(), images.size(0))
80 |
81 | if model_mode == 'captioning':
82 | losses.update(loss.item(), images.size(0))
83 | elif model_mode == 'retrieval':
84 | if args.distributed:
85 | original_last_embedding = torch.clone(last_embedding)
86 | all_visual_embs = [torch.zeros_like(visual_embs) for _ in range(dist.get_world_size())]
87 | all_last_embedding = [torch.zeros_like(last_embedding) for _ in range(dist.get_world_size())]
88 | dist.all_gather(all_visual_embs, visual_embs)
89 | dist.all_gather(all_last_embedding, last_embedding)
90 |
91 | # Overwrite with embeddings produced on this replica, which track the gradients.
92 | all_visual_embs[dist.get_rank()] = visual_embs
93 | all_last_embedding[dist.get_rank()] = last_embedding
94 | visual_embs = torch.cat(all_visual_embs)
95 | last_embedding = torch.cat(all_last_embedding)
96 | start_idx = args.rank * images.shape[0]
97 | end_idx = start_idx + images.shape[0]
98 | assert torch.all(last_embedding[start_idx:end_idx] == original_last_embedding), args.rank
99 |
100 | all_text_features.append(last_embedding.cpu())
101 | all_image_features.append(visual_embs.cpu())
102 |
103 |                     # Run auto-regressive generation to produce sample captions.
104 | if model_mode == 'captioning':
105 | input_embs = model.module.model.get_visual_embs(images, mode='captioning') # (2, n_visual_tokens, D)
106 | if args.input_prompt is not None:
107 | print(f'Adding prefix "{args.input_prompt}" to captioning generate=True.')
108 | prompt_ids = tokenizer(args.input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
109 | prompt_ids = prompt_ids.to(visual_embs.device)
110 | prompt_embs = model.module.model.input_embeddings(prompt_ids)
111 | prompt_embs = prompt_embs.repeat(input_embs.shape[0], 1, 1)
112 | input_embs = torch.cat([input_embs, prompt_embs], dim=1)
113 |
114 | generated_ids, _, _ = model(input_embs, tgt_tokens, token_len,
115 | generate=True, num_words=num_words, temperature=0.0, top_p=1.0,
116 | min_word_tokens=num_words)
117 |
118 | if args.distributed and ngpus_per_node > 1:
119 | all_generated_ids = [torch.zeros_like(generated_ids) for _ in range(dist.get_world_size())]
120 | dist.all_gather(all_generated_ids, generated_ids)
121 | all_generated_ids[dist.get_rank()] = generated_ids
122 | generated_ids = torch.cat(all_generated_ids)
123 |
124 | all_tgt_tokens = [torch.zeros_like(tgt_tokens) for _ in range(dist.get_world_size())]
125 | dist.all_gather(all_tgt_tokens, tgt_tokens)
126 | all_tgt_tokens[dist.get_rank()] = tgt_tokens
127 | all_tgt_tokens = torch.cat(all_tgt_tokens)
128 |
129 | all_image_paths = [[None for _ in image_paths] for _ in range(dist.get_world_size())]
130 | dist.all_gather_object(all_image_paths, image_paths)
131 | all_image_paths[dist.get_rank()] = image_paths
132 | image_paths = []
133 | for p in all_image_paths:
134 | image_paths.extend(p)
135 | else:
136 | all_tgt_tokens = tgt_tokens
137 |
138 | all_tgt_tokens[all_tgt_tokens == -100] = tokenizer.pad_token_id
139 | generated_captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
140 | gt_captions = tokenizer.batch_decode(all_tgt_tokens, skip_special_tokens=True)
141 |
142 | for cap_i in range(len(generated_captions)):
143 | image_path = image_paths[cap_i]
144 | all_generated_image_paths.append(image_path)
145 | stop_idx = generated_captions[cap_i].find('.')
146 | if stop_idx > 5:
147 | all_generated_captions.append(generated_captions[cap_i][:stop_idx])
148 | else:
149 | all_generated_captions.append(generated_captions[cap_i])
150 | all_gt_captions.append([gt_captions[cap_i]])
151 | elif model_mode == 'retrieval':
152 | if i == 0:
153 | # Generate without image input to visualize text-generation ability.
154 | input_ids = tgt_tokens[:, :3] # Use first 3 tokens as initial prompt for generation.
155 | input_embs = model.module.model.input_embeddings(input_ids) # (N, T, D)
156 | generated_ids, _, _ = model(input_embs, tgt_tokens, token_len, generate=True, num_words=num_words, temperature=0.0, top_p=1.0)
157 | generated_ids = torch.cat([input_ids, generated_ids], dim=1)
158 | generated_captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
159 | gt_captions = tokenizer.batch_decode(tgt_tokens, skip_special_tokens=False)
160 | else:
161 | raise NotImplementedError
162 |
163 | if i == 0:
164 | max_to_display = 5
165 | print('=' * 30)
166 | print('Generated samples:')
167 | for cap_i, cap in enumerate(generated_captions[:max_to_display]):
168 | print(f'{cap_i}) {cap}')
169 | print('=' * 30)
170 | print('Real samples:')
171 | for cap_i, cap in enumerate(gt_captions[:max_to_display]):
172 | print(f'{cap_i}) {cap}')
173 | print('=' * 30)
174 |
175 | # Write images and captions to Tensorboard.
176 | if not args.distributed or (args.rank % ngpus_per_node == 0):
177 | max_images_to_show = 16
178 | normalized_images = images - images.min()
179 | normalized_images /= normalized_images.max() # (N, 3, H, W)
180 | # Create generated caption text.
181 | generated_cap_images = torch.stack([
182 | utils.create_image_of_text(
183 | generated_captions[j].encode('ascii', 'ignore'),
184 | width=normalized_images.shape[3],
185 | color=(255, 255, 0))
186 | for j in range(normalized_images.shape[0])], axis=0)
187 | # Append gt/generated caption images.
188 | display_images = torch.cat([normalized_images.float().cpu(), caption_images, generated_cap_images], axis=2)[:max_images_to_show]
189 | grid = torchvision.utils.make_grid(display_images, nrow=int(max_images_to_show ** 0.5), padding=4)
190 | writer.add_image(f'val/images_{model_mode}', grid, actual_step)
191 |
192 | # measure elapsed time
193 | batch_time.update(time.time() - end)
194 | end = time.time()
195 |
196 | if i % args.print_freq == 0:
197 | progress.display(i + 1)
198 |
199 | if i == args.val_steps_per_epoch - 1:
200 | break
201 |
202 | # Measure captioning metrics.
203 | path2captions = collections.defaultdict(list)
204 | for image_path, caption in zip(all_generated_image_paths, all_gt_captions):
205 | assert len(caption) == 1, caption
206 | path2captions[image_path].append(caption[0].replace('[RET]', ''))
207 | full_gt_captions = [path2captions[path] for path in all_generated_image_paths]
208 |
209 | print(f'Computing BLEU with {len(all_generated_captions)} generated captions:'
210 | f'{all_generated_captions[:5]} and {len(full_gt_captions)} groundtruth captions:',
211 | f'{full_gt_captions[:5]}.')
212 | bleu1_score = bleu_scorers[0](all_generated_captions, full_gt_captions)
213 | bleu1.update(bleu1_score, 1)
214 | bleu2_score = bleu_scorers[1](all_generated_captions, full_gt_captions)
215 | bleu2.update(bleu2_score, 1)
216 | bleu3_score = bleu_scorers[2](all_generated_captions, full_gt_captions)
217 | bleu3.update(bleu3_score, 1)
218 | bleu4_score = bleu_scorers[3](all_generated_captions, full_gt_captions)
219 | bleu4.update(bleu4_score, 1)
220 |
221 | # Measure retrieval metrics over the entire validation set.
222 | all_image_features = torch.cat(all_image_features, axis=0) # (coco_val_len, 2048)
223 | all_text_features = torch.cat(all_text_features, axis=0) # (coco_val_len, 2048)
224 |
225 | print(f"Computing similarity between {all_image_features.shape} and {all_text_features.shape}.")
226 | logits_per_image = all_image_features @ all_text_features.t()
227 | logits_per_text = logits_per_image.t()
228 | all_image_acc1, all_image_acc5 = losses_utils.contrastive_acc(logits_per_image, topk=(1, 5))
229 | all_caption_acc1, all_caption_acc5 = losses_utils.contrastive_acc(logits_per_text, topk=(1, 5))
230 | image_loss = losses_utils.contrastive_loss(logits_per_image)
231 | caption_loss = losses_utils.contrastive_loss(logits_per_text)
232 |
233 | loss = args.ret_loss_scale * (image_loss + caption_loss) / 2.0
234 | losses.update(loss.item(), logits_per_image.size(0))
235 | top1_caption.update(all_caption_acc1.item(), logits_per_image.size(0))
236 | top5_caption.update(all_caption_acc5.item(), logits_per_image.size(0))
237 | top1_image.update(all_image_acc1.item(), logits_per_image.size(0))
238 | top5_image.update(all_image_acc5.item(), logits_per_image.size(0))
239 |
240 |
241 | batch_time = utils.AverageMeter('Time', ':6.3f', utils.Summary.AVERAGE)
242 | losses = utils.AverageMeter('Loss', ':.4e', utils.Summary.AVERAGE)
243 | ce_losses = utils.AverageMeter('CeLoss', ':.4e', utils.Summary.AVERAGE)
244 | top1 = utils.AverageMeter('Acc@1', ':6.2f', utils.Summary.AVERAGE)
245 | top5 = utils.AverageMeter('Acc@5', ':6.2f', utils.Summary.AVERAGE)
246 | bleu1 = utils.AverageMeter('BLEU@1', ':6.2f', utils.Summary.AVERAGE)
247 | bleu2 = utils.AverageMeter('BLEU@2', ':6.2f', utils.Summary.AVERAGE)
248 | bleu3 = utils.AverageMeter('BLEU@3', ':6.2f', utils.Summary.AVERAGE)
249 | bleu4 = utils.AverageMeter('BLEU@4', ':6.2f', utils.Summary.AVERAGE)
250 | top1_caption = utils.AverageMeter('CaptionAcc@1', ':6.2f', utils.Summary.AVERAGE)
251 | top5_caption = utils.AverageMeter('CaptionAcc@5', ':6.2f', utils.Summary.AVERAGE)
252 | top1_image = utils.AverageMeter('ImageAcc@1', ':6.2f', utils.Summary.AVERAGE)
253 | top5_image = utils.AverageMeter('ImageAcc@5', ':6.2f', utils.Summary.AVERAGE)
254 |
255 | progress = utils.ProgressMeter(
256 | len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
257 | [batch_time, losses, top1, top5, bleu4],
258 | prefix='Test: ')
259 |
260 | # switch to evaluate mode
261 | model.eval()
262 |
263 | run_validate(val_loader)
264 | if args.distributed:
265 | batch_time.all_reduce()
266 | losses.all_reduce()
267 | bleu1.all_reduce()
268 | bleu2.all_reduce()
269 | bleu3.all_reduce()
270 | bleu4.all_reduce()
271 | top1.all_reduce()
272 | top5.all_reduce()
273 | top1_caption.all_reduce()
274 | top5_caption.all_reduce()
275 | top1_image.all_reduce()
276 | top5_image.all_reduce()
277 |
278 | if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
279 | aux_val_dataset = Subset(val_loader.dataset,
280 | range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
281 | aux_val_loader = torch.utils.data.DataLoader(
282 | aux_val_dataset, batch_size=(args.val_batch_size or args.batch_size), shuffle=False,
283 | num_workers=args.workers, pin_memory=True, collate_fn=data.collate_fn)
284 | run_validate(aux_val_loader, len(val_loader))
285 |
286 | progress.display_summary()
287 |
288 | writer.add_scalar('val/total_secs_per_batch', batch_time.avg, actual_step)
289 | writer.add_scalar('val/seq_top1_acc', top1.avg, actual_step)
290 | writer.add_scalar('val/seq_top5_acc', top5.avg, actual_step)
291 | writer.add_scalar('val/ce_loss', losses.avg, actual_step)
292 | writer.add_scalar('val/bleu1', bleu1.avg, actual_step)
293 | writer.add_scalar('val/bleu2', bleu2.avg, actual_step)
294 | writer.add_scalar('val/bleu3', bleu3.avg, actual_step)
295 | writer.add_scalar('val/bleu4', bleu4.avg, actual_step)
296 | writer.add_scalar('val/contrastive_loss', losses.avg, actual_step)
297 | writer.add_scalar('val/t2i_top1_acc', top1_caption.avg, actual_step)
298 | writer.add_scalar('val/t2i_top5_acc', top5_caption.avg, actual_step)
299 | writer.add_scalar('val/i2t_top1_acc', top1_image.avg, actual_step)
300 | writer.add_scalar('val/i2t_top5_acc', top5_image.avg, actual_step)
301 | writer.add_scalar('val/top1_acc', (top1_caption.avg + top1_image.avg) / 2.0, actual_step)
302 | writer.add_scalar('val/top5_acc', (top5_caption.avg + top5_image.avg) / 2.0, actual_step)
303 |
304 | writer.close()
305 |
306 | # Use top1 accuracy as the metric for keeping the best checkpoint.
307 | return top1_caption.avg
308 |
--------------------------------------------------------------------------------
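Note: the retrieval half of the validation loop above reduces to a cosine-similarity matrix between pooled image and text features, with matching pairs on the diagonal. A minimal, self-contained sketch of that computation (random features stand in for the real visual_embs / last_embedding tensors; the helpers are the ones from fromage/losses.py below):

import torch
from fromage import losses as losses_utils

# Toy stand-ins for the gathered validation features: batch of 8, shared dim 256.
image_feats = torch.nn.functional.normalize(torch.randn(8, 256), dim=-1)
text_feats = torch.nn.functional.normalize(torch.randn(8, 256), dim=-1)

# Image-to-text and text-to-image similarity matrices.
logits_per_image = image_feats @ text_feats.t()
logits_per_text = logits_per_image.t()

# Matching pairs share an index, so the implicit target is arange(N).
i2t_top1, i2t_top5 = losses_utils.contrastive_acc(logits_per_image, topk=(1, 5))
t2i_top1, t2i_top5 = losses_utils.contrastive_acc(logits_per_text, topk=(1, 5))
loss = (losses_utils.contrastive_loss(logits_per_image) +
        losses_utils.contrastive_loss(logits_per_text)) / 2.0
print(loss.item(), i2t_top1.item(), t2i_top1.item())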
/fromage/extract_img_embs.py:
--------------------------------------------------------------------------------
1 | """Extract image embeddings for a list of image urls.
2 |
3 | Example usage:
4 | python extract_img_embs.py
5 | """
6 | import torch
7 | from fromage import models, utils
8 |
9 | from PIL import Image
10 | import os
11 | import requests
12 | from io import BytesIO
13 | import pickle as pkl
14 |
15 |
16 | def extract_embeddings_for_urls(image_urls: list[str], emb_output_path: str, device: str = "cuda"):
17 | # Load model checkpoint.
18 | model = models.load_fromage("./fromage_model/")
19 | model.eval()
20 |
21 | visual_encoder = "openai/clip-vit-large-patch14"
22 | feature_extractor = utils.get_feature_extractor_for_model(
23 | visual_encoder, train=False
24 | )
25 |
26 | output_data = {"paths": [], "embeddings": []}
27 | with torch.no_grad():
28 | for img_url in image_urls:
29 | img = Image.open(BytesIO(requests.get(img_url).content))
30 |
31 | img_tensor = utils.get_pixel_values_for_model(feature_extractor, img)
32 | img_tensor = img_tensor[None, ...].to(device).bfloat16()
33 | img_emb = model.model.get_visual_embs(img_tensor, mode="retrieval")
34 | img_emb = img_emb[0, 0, :].cpu()
35 | output_data["paths"].append(img_url)
36 | output_data["embeddings"].append(img_emb)
37 |
38 | with open(emb_output_path, "wb") as f:
39 | pkl.dump(output_data, f)
40 |
41 |
42 | if __name__ == "__main__":
43 | image_urls = [] # TODO: Replace with image urls
44 | if image_urls == []:
45 | raise ValueError("Please replace `image_urls` with a list of image urls.")
46 |
47 | extract_embeddings_for_urls(image_urls, "fromage_model/cc3m_embeddings.pkl")
48 |
--------------------------------------------------------------------------------
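The pickle produced by extract_img_embs.py above matches the format that load_fromage (in fromage/models.py) expects for cc3m_embeddings*.pkl: a dict with parallel 'paths' and 'embeddings' lists. A small sketch, assuming the file written by the script exists at the path below, of reading it back and doing a nearest-neighbour lookup with a placeholder query vector:

import pickle as pkl
import torch

with open("fromage_model/cc3m_embeddings.pkl", "rb") as f:
    data = pkl.load(f)

paths = data["paths"]  # list of image URLs
emb_matrix = torch.stack([e.float() for e in data["embeddings"]])  # (N, D)
emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)

# Placeholder query; in the full model this is a normalized [RET] embedding.
query = torch.nn.functional.normalize(torch.randn(emb_matrix.shape[1]), dim=0)
scores = emb_matrix @ query  # (N,)
top_idx = scores.topk(min(3, len(paths))).indices
print([paths[i] for i in top_idx])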
/fromage/losses.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | from fromage import utils
4 |
5 | def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
6 | return torch.nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device))
7 |
8 |
9 | def contrastive_acc(logits: torch.Tensor, target: Optional[torch.Tensor] = None, topk=(1,)) -> torch.Tensor:
10 | """
11 | Args:
12 | logits: (N, N) predictions.
13 | target: (N, num_correct_answers) labels.
14 | """
15 | assert len(logits.shape) == 2, logits.shape
16 | batch_size = logits.shape[0]
17 |
18 | if target is None:
19 | target = torch.arange(len(logits), device=logits.device)
20 | return utils.accuracy(logits, target, -1, topk)
21 | else:
22 | assert len(target.shape) == 2, target.shape
23 | with torch.no_grad():
24 | maxk = max(topk)
25 | if logits.shape[-1] < maxk:
26 | print(f"[WARNING] Less than {maxk} predictions available. Using {logits.shape[-1]} for topk.")
27 | maxk = min(maxk, logits.shape[-1])
28 |
29 | # Take topk along the last dimension.
30 | _, pred = logits.topk(maxk, -1, True, True) # (N, topk)
31 | assert pred.shape == (batch_size, maxk)
32 |
33 | target_expand = target[:, :, None].repeat(1, 1, maxk) # (N, num_correct_answers, topk)
34 | pred_expand = pred[:, None, :].repeat(1, target.shape[1], 1) # (N, num_correct_answers, topk)
35 | correct = pred_expand.eq(target_expand) # (N, num_correct_answers, topk)
36 | correct = torch.any(correct, dim=1) # (N, topk)
37 |
38 | res = []
39 | for k in topk:
40 | any_k_correct = torch.clamp(correct[:, :k].sum(1), max=1) # (N,)
41 | correct_k = any_k_correct.float().sum(0, keepdim=True)
42 | res.append(correct_k.mul_(100.0 / batch_size))
43 | return res
44 |
45 |
--------------------------------------------------------------------------------
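Besides the default diagonal-target case used during training, contrastive_acc above also accepts an explicit (N, num_correct_answers) target, where a top-k prediction counts if it hits any of the listed answers (useful when several candidates are all considered correct). A quick sketch:

import torch
from fromage.losses import contrastive_acc

# (N, N) similarity scores for 4 queries over 4 candidates.
logits = torch.randn(4, 4)

# Each row lists two indices that both count as correct answers.
target = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
top1, top3 = contrastive_acc(logits, target, topk=(1, 3))
print(top1.item(), top3.item())  # percentages in [0, 100]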
/fromage/models.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple, Union
2 | from collections import namedtuple
3 | import json
4 | import glob
5 | import math
6 | import numpy as np
7 | import os
8 | import torch
9 | from torch import Tensor
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from einops import rearrange
13 | from functools import partial
14 | import pickle as pkl
15 | from PIL import Image, UnidentifiedImageError
16 |
17 | from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
18 | from transformers import OPTForCausalLM, GPT2Tokenizer
19 | from transformers import CLIPVisionModel, CLIPVisionConfig
20 |
21 | from fromage import utils
22 |
23 |
24 | class FrozenArgs:
25 | freeze_lm: bool = True
26 | freeze_vm: bool = True
27 | opt_version: str = 'facebook/opt-6.7b'
28 | visual_encoder: str = 'openai/clip-vit-large-patch14'
29 | n_visual_tokens: int = 1
30 | image_embed_dropout_prob: float = 0.0
31 | task: str = 'captioning'
32 | shared_emb_dim: Optional[int] = 256
33 | text_emb_layers: List[int] = [-1]
34 | retrieval_token_idx: int = 0
35 |
36 |
37 | class FromageModel(nn.Module):
38 | def __init__(self, tokenizer, args: FrozenArgs = FrozenArgs()):
39 | super().__init__()
40 | self.tokenizer = tokenizer
41 | self.feature_extractor = utils.get_feature_extractor_for_model(args.visual_encoder, train=False)
42 | self.image_token = self.tokenizer.cls_token_id
43 |     assert len(args.text_emb_layers) == len(set(args.text_emb_layers)), 'text_emb_layers not unique'
44 | self.args = args
45 |
46 | opt_version = args.opt_version
47 | visual_encoder = args.visual_encoder
48 | n_visual_tokens = args.n_visual_tokens
49 | print(f"Using {opt_version} for the language model.")
50 | print(f"Using {visual_encoder} for the visual model with {n_visual_tokens} visual tokens.")
51 |
52 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
53 |
54 | if 'facebook/opt' in opt_version:
55 | self.lm = OPTForCausalLM.from_pretrained(opt_version)
56 | else:
57 | raise NotImplementedError
58 |
59 | self.opt_version = opt_version
60 |
61 | if self.args.freeze_lm:
62 | self.lm.eval()
63 | print("Freezing the LM.")
64 | for param in self.lm.parameters():
65 | param.requires_grad = False
66 | else:
67 | self.lm.train()
68 |
69 | # NOTE: Resizing sets all token embeddings and all lm_head weights (since they are tied in OPT)
70 | # to be trainable (param.requires_grad = True).
71 | self.retrieval_token_idx = args.retrieval_token_idx
72 | print(f'Initializing embedding for the retrieval token [RET] (id = {self.retrieval_token_idx}).')
73 | self.lm.resize_token_embeddings(len(tokenizer))
74 |
75 | self.input_embeddings = self.lm.get_input_embeddings()
76 |
77 | print("Restoring pretrained weights for the visual model.")
78 | if 'clip' in visual_encoder:
79 | self.visual_model = CLIPVisionModel.from_pretrained(visual_encoder)
80 | else:
81 | self.visual_model = AutoModel.from_pretrained(visual_encoder)
82 |
83 | if 'clip' in visual_encoder:
84 | hidden_size = self.visual_model.config.hidden_size
85 | else:
86 | raise NotImplementedError
87 |
88 | if self.args.freeze_vm:
89 | print("Freezing the VM.")
90 | self.visual_model.eval()
91 | for param in self.visual_model.parameters():
92 | param.requires_grad = False
93 | else:
94 | self.visual_model.train()
95 |
96 | self.visual_model_name = visual_encoder
97 |
98 | embedding_dim = self.input_embeddings.embedding_dim * self.args.n_visual_tokens
99 | self.text_hidden_fcs = nn.ModuleList([])
100 | if self.args.shared_emb_dim is None:
101 | if len(self.args.text_emb_layers) == 1:
102 | if (self.args.text_emb_layers[0] in [-1, self.lm.config.num_hidden_layers]) and ('bert' not in opt_version):
103 | out_dim = self.lm.config.word_embed_proj_dim
104 | else:
105 | out_dim = self.lm.config.hidden_size
106 | else:
107 |         if ((-1 in self.args.text_emb_layers) or (self.lm.config.num_hidden_layers in self.args.text_emb_layers)) \
108 |           and (self.lm.config.word_embed_proj_dim != self.lm.config.hidden_size):
109 | raise ValueError('No projection dim specified but model uses last output layer and an intermediate one (which have different dims).')
110 | else:
111 | out_dim = self.lm.config.hidden_size
112 | else:
113 | out_dim = self.args.shared_emb_dim
114 |
115 | for layer_idx in self.args.text_emb_layers:
116 | if (layer_idx == -1 or layer_idx == self.lm.config.num_hidden_layers) and ('bert' not in opt_version):
117 | in_dim = self.lm.config.word_embed_proj_dim
118 |
119 | text_fc = [nn.Linear(in_dim, out_dim), nn.Dropout(self.args.text_embed_dropout_prob)]
120 | self.text_hidden_fcs.append(nn.Sequential(*text_fc))
121 |
122 | elif layer_idx < self.lm.config.num_hidden_layers:
123 | text_fc = [nn.Linear(self.lm.config.hidden_size, out_dim), nn.Dropout(self.args.text_embed_dropout_prob)]
124 | self.text_hidden_fcs.append(nn.Sequential(*text_fc))
125 | else:
126 | raise ValueError(f'Embedding of layer {layer_idx} was requested but model only has {self.lm.config.num_hidden_layers} layers.')
127 |
128 | self.visual_embeddings = nn.Linear(hidden_size, embedding_dim)
129 | self.visual_fc = nn.Linear(hidden_size, out_dim)
130 |
131 | self.image_dropout = nn.Dropout(self.args.image_embed_dropout_prob)
132 |
133 |
134 | def get_visual_embs(self, pixel_values: torch.FloatTensor, mode: str = 'captioning'):
135 | if mode not in ['captioning', 'retrieval']:
136 |       raise ValueError(f'mode should be one of ["captioning", "retrieval"], got {mode} instead.')
137 |
138 | # Extract visual embeddings from the vision encoder.
139 | if 'clip' in self.visual_model_name:
140 | outputs = self.visual_model(pixel_values)
141 | encoder_outputs = outputs.pooler_output
142 | else:
143 | raise NotImplementedError
144 |
145 | # Use the correct fc based on function argument.
146 | if mode == 'captioning':
147 | visual_embs = self.visual_embeddings(encoder_outputs) # (2, D * n_visual_tokens)
148 | visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], self.args.n_visual_tokens, -1))
149 | elif mode == 'retrieval':
150 | visual_embs = self.visual_fc(encoder_outputs) # (2, D * n_visual_tokens)
151 | visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1))
152 | else:
153 | raise NotImplementedError
154 |
155 | visual_embs = self.image_dropout(visual_embs)
156 | return visual_embs
157 |
158 |
159 | def train(self, mode=True):
160 | super(FromageModel, self).train(mode=mode)
161 | # Overwrite train() to ensure Frozen models remain frozen.
162 | if self.args.freeze_lm:
163 | self.lm.eval()
164 | if self.args.freeze_vm:
165 | self.visual_model.eval()
166 |
167 |
168 | def forward(
169 | self,
170 | pixel_values: torch.FloatTensor,
171 | labels: torch.LongTensor,
172 | caption_len: torch.LongTensor,
173 | mode: str = 'captioning',
174 | concat_captions: bool = False,
175 | input_prefix: Optional[str] = None,
176 | inference: bool = False,
177 | ):
178 | visual_embs = self.get_visual_embs(pixel_values, mode)
179 |
180 | batch_size, vis_seq_len, _ = visual_embs.shape # vis_seq_len = n_visual_tokens
181 | if labels is not None:
182 | assert labels.shape[0] == batch_size, (visual_embs.shape, labels.shape)
183 |
184 | input_embs = self.input_embeddings(labels) # (N, T, D)
185 |
186 | last_embedding_idx = caption_len - 1 # -1 to retrieve the token before the eos token
187 |
188 | if input_prefix is not None:
189 | prompt_ids = self.tokenizer(input_prefix, add_special_tokens=False, return_tensors="pt").input_ids
190 | prompt_ids = prompt_ids.to(visual_embs.device)
191 | prompt_embs = self.input_embeddings(prompt_ids)
192 | prompt_embs = prompt_embs.repeat(batch_size, 1, 1)
193 | assert prompt_embs.shape[0] == batch_size, prompt_embs.shape
194 | assert prompt_embs.shape[2] == input_embs.shape[2], prompt_embs.shape
195 | assert len(prompt_embs.shape) == 3, prompt_embs.shape
196 |
197 | if mode == 'captioning':
198 | # Concat to text embeddings.
199 | condition_seq_len = 0
200 | if input_prefix is None:
201 | # Just add visual embeddings.
202 | input_embs = torch.cat([visual_embs, input_embs], axis=1)
203 | last_embedding_idx += vis_seq_len
204 | condition_seq_len += vis_seq_len
205 | full_labels = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
206 | else:
207 | # Add visual and prompt embeddings.
208 | prefix_embs = torch.cat([visual_embs, prompt_embs], axis=1)
209 | input_embs = torch.cat([prefix_embs, input_embs], axis=1)
210 |
211 | last_embedding_idx += prefix_embs.shape[1]
212 | condition_seq_len += prefix_embs.shape[1]
213 | full_labels = torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
214 |
215 | # Mask out embedding tokens in the labels.
216 | full_labels = torch.cat([full_labels, labels], axis=1)
217 |
218 | pad_idx = []
219 |
220 | for label in full_labels:
221 | for k, token in enumerate(label):
222 | # Mask out retrieval token if it exists.
223 | if token in [self.tokenizer.pad_token_id, self.retrieval_token_idx]:
224 | label[k:] = -100
225 | pad_idx.append(k)
226 | break
227 | if k == len(label) - 1: # No padding found.
228 | pad_idx.append(k + 1)
229 | assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
230 |
231 | bs, seq_len, embs_dim = input_embs.shape
232 | if concat_captions:
233 | assert len(input_embs.shape) == 3, input_embs
234 | assert len(full_labels.shape) == 2, full_labels
235 | assert batch_size % 2 == 0
236 | all_concat_input_embs = []
237 | all_concat_labels = []
238 |
239 | # Rearrange embeddings and labels (and their padding) to concatenate captions.
240 | for i in range(batch_size // 2):
241 | first_idx = i * 2
242 | second_idx = first_idx + 1
243 | first_emb = input_embs[first_idx, :pad_idx[first_idx], :]
244 | first_labels = full_labels[first_idx, :pad_idx[first_idx]]
245 | first_padding = input_embs[first_idx, pad_idx[first_idx]:, :]
246 | first_labels_padding = full_labels[first_idx, pad_idx[first_idx]:]
247 |
248 | second_emb = input_embs[second_idx, :pad_idx[second_idx], :]
249 | second_labels = full_labels[second_idx, :pad_idx[second_idx]]
250 | second_padding = input_embs[second_idx, pad_idx[second_idx]:, :]
251 | second_labels_padding = full_labels[second_idx, pad_idx[second_idx]:]
252 |
253 | assert torch.all(first_labels_padding == -100), first_labels_padding
254 | assert torch.all(second_labels_padding == -100), second_labels_padding
255 | concat_input_embs = torch.cat([first_emb, second_emb, first_padding, second_padding], axis=0) # (T*2, 768)
256 |           concat_labels = torch.cat([first_labels, second_labels, first_labels_padding, second_labels_padding], axis=0) # (T*2,)
257 | all_concat_input_embs.append(concat_input_embs)
258 | all_concat_labels.append(concat_labels)
259 |
260 | # Pad to max length.
261 | input_embs = torch.stack(all_concat_input_embs, axis=0) # (N/2, T*2, 768)
262 |         full_labels = torch.stack(all_concat_labels, axis=0) # (N/2, T*2)
263 | assert input_embs.shape == (bs // 2, seq_len * 2, embs_dim), input_embs.shape
264 | assert full_labels.shape == (bs // 2, seq_len * 2), full_labels.shape
265 |
266 | output = self.lm(inputs_embeds=input_embs,
267 | labels=full_labels,
268 | output_hidden_states=True)
269 | elif mode == 'retrieval':
270 | full_labels = torch.clone(labels)
271 | if input_prefix is not None:
272 | print(f'Adding prefix "{input_prefix}" to retrieval.')
273 | # Add prompt embeddings.
274 | prefix_embs = prompt_embs
275 | input_embs = torch.cat([prefix_embs, input_embs], axis=1)
276 | last_embedding_idx += prefix_embs.shape[1]
277 | full_labels = torch.cat([
278 | torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(labels.device) - 100,
279 | full_labels
280 | ], axis=1)
281 |
282 | pad_idx = []
283 | for label in full_labels:
284 | for k, token in enumerate(label):
285 | if token == self.tokenizer.pad_token_id:
286 | label[k:] = -100
287 | pad_idx.append(k)
288 | break
289 | if k == len(label) - 1: # No padding found.
290 | pad_idx.append(k + 1)
291 | assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)
292 |
293 | output = self.lm(inputs_embeds=input_embs,
294 | labels=full_labels,
295 | output_hidden_states=True)
296 | else:
297 | raise NotImplementedError
298 |
299 | last_embedding = None
300 | last_output_logit = None
301 | hidden_states = []
302 |
303 | if mode == 'retrieval':
304 | if self.args.shared_emb_dim is not None:
305 | for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs):
306 | hidden_states.append(fc_layer(output.hidden_states[idx])) # (N, seq_len, 2048)
307 | else:
308 | for idx in self.args.text_emb_layers:
309 | hidden_states.append(output.hidden_states[idx])
310 |
311 | # Add hidden states together.
312 | last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1)
313 |
314 | if not concat_captions:
315 | last_embedding = torch.stack([last_hidden_state[i, last_embedding_idx[i], :] for i in range(batch_size)], axis=0) # (N, D)
316 | last_output_logit = torch.stack([output.logits[i, last_embedding_idx[i] - 1, :] for i in range(batch_size)], axis=0) # (N, D)
317 | else:
318 | raise NotImplementedError
319 |
320 | # Compute retrieval loss.
321 | assert visual_embs.shape[1] == 1, visual_embs.shape
322 | visual_embs = visual_embs[:, 0, :]
323 | visual_embs = visual_embs / visual_embs.norm(dim=1, keepdim=True)
324 | last_embedding = last_embedding / last_embedding.norm(dim=1, keepdim=True)
325 |
326 | # cosine similarity as logits
327 | logit_scale = self.logit_scale.exp()
328 | visual_embs = logit_scale * visual_embs
329 | elif mode == 'captioning':
330 | pass
331 | else:
332 | raise NotImplementedError
333 |
334 | return output, full_labels, last_embedding, last_output_logit, visual_embs
335 |
336 |   def generate(self, embeddings: torch.FloatTensor, max_len: int = 32,
337 | temperature: float = 0.0, top_p: float = 1.0, min_word_tokens: int = 0,
338 | ret_scale_factor: float = 1.0, filter_value: float = -float('Inf')):
339 | """Runs greedy decoding and returns generated captions.
340 |
341 | Args:
342 | embeddings: Input condition that the model uses for autoregressive generation.
343 | max_len: Maximum number of tokens to generate.
344 | temperature: Used to modulate logit distribution.
345 | top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation.
346 | min_word_tokens: Minimum number of words to generate before allowing a [RET] output.
347 | ret_scale_factor: Proportion to scale [RET] token logits by. A higher value may increase the probability of the model generating [RET] outputs.
348 | filter_value: Value to assign to tokens that should never be generated.
349 | Outputs:
350 | out: (N, T) int32 sequence of output tokens.
351 | output_embeddings: (N, T, 256) sequence of text output embeddings.
352 | """
353 | self.lm.eval()
354 |
355 | with torch.no_grad(): # no tracking history
356 | batch_size, s, _ = embeddings.shape
357 | # init output with image tokens
358 | out = None
359 | past_key_values = None
360 | output_embeddings = []
361 | output_logits = []
362 |
363 | for i in range(max_len):
364 | if 'opt' in self.opt_version:
365 | output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)
366 | else:
367 | if i == 0:
368 | output = self.lm(inputs_embeds=embeddings, use_cache=True, past_key_values=None, output_hidden_states=True)
369 | else:
370 | output = self.lm(input_ids=out[:, -1:], use_cache=True, past_key_values=past_key_values, output_hidden_states=True)
371 |
372 | # Collect and sum the hidden states.
373 | hidden_states = []
374 | if self.args.shared_emb_dim is not None:
375 | for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs):
376 | hidden_states.append(fc_layer(output.hidden_states[idx])) # (N, seq_len, 2048)
377 | else:
378 | for idx in self.args.text_emb_layers:
379 | hidden_states.append(output.hidden_states[idx])
380 | # Add hidden states together.
381 | last_hidden_state = torch.stack(hidden_states, dim=-1).sum(dim=-1) # (N, T, 256)
382 | last_embedding = last_hidden_state / last_hidden_state.norm(dim=-1, keepdim=True)
383 | output_embeddings.append(last_embedding)
384 |
385 | logits = output.logits[:, -1, :] # (N, vocab_size)
386 | if top_p == 1.0:
387 | logits = logits.cpu()
388 | output_logits.append(logits)
389 |
390 | if self.retrieval_token_idx != -1 and self.retrieval_token_idx is not None:
391 | if i < min_word_tokens:
392 | # Eliminate probability of generating [RET] if this is earlier than min_word_tokens.
393 | logits[:, self.retrieval_token_idx] = filter_value
394 | else:
395 | # Multiply by scaling factor.
396 | logits[:, self.retrieval_token_idx] = logits[:, self.retrieval_token_idx] * ret_scale_factor
397 |
398 | past_key_values = output.past_key_values
399 |
400 | if temperature == 0.0:
401 | if top_p != 1.0:
402 | raise ValueError('top_p cannot be set if temperature is 0 (greedy decoding).')
403 | next_token = torch.argmax(logits, keepdim=True, dim=-1) # (N, 1)
404 | else:
405 | logits = logits / temperature
406 |
407 | # Apply top-p filtering.
408 | if top_p < 1.0:
409 | assert top_p > 0, f'top_p should be above 0, got {top_p} instead.'
410 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) # (N, D) and (N, D)
411 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # (N, D)
412 |
413 | # Remove tokens with cumulative probability above the threshold
414 | sorted_indices_to_remove = cumulative_probs > top_p
415 | # Shift the indices to the right to keep also the first token above the threshold
416 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
417 | sorted_indices_to_remove[..., 0] = 0
418 |
419 | for j in range(sorted_indices.shape[0]):
420 | indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
421 | logits[j, indices_to_remove] = filter_value
422 |
423 | token_weights = logits.exp() # (N, vocab_size)
424 | next_token = torch.multinomial(token_weights, 1) # (N, 1)
425 |
426 | next_token = next_token.long().to(embeddings.device)
427 | if out is not None:
428 | out = torch.cat([out, next_token], dim=-1)
429 | else:
430 | out = next_token
431 |
432 | if 'opt' in self.opt_version:
433 | next_embedding = self.input_embeddings(next_token)
434 | embeddings = torch.cat([embeddings, next_embedding], dim=1)
435 | elif (self.tokenizer.eos_token_id and (next_token == self.tokenizer.eos_token_id).all()):
436 | # End of generation.
437 | break
438 |
439 | return out, output_embeddings, output_logits
440 |
441 |
442 | class Fromage(nn.Module):
443 | def __init__(self, tokenizer, model_args: Optional[FrozenArgs] = None,
444 | path_array: Optional[List[str]] = None, emb_matrix: Optional[torch.tensor] = None):
445 | super().__init__()
446 | self.model = FromageModel(tokenizer, model_args)
447 | self.path_array = path_array
448 | self.emb_matrix = emb_matrix
449 |
450 | def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None,
451 | generate: bool = False, num_words: int = 32, temperature: float = 1.0, top_p: float = 1.0,
452 | ret_scale_factor: float = 1.0, min_word_tokens: int = 0,
453 | mode: str = 'captioning', concat_captions: bool = False,
454 | input_prefix: Optional[str] = None, inference: bool = False) -> Tensor:
455 | if generate:
456 | return self.model.generate(images, num_words, temperature=temperature, top_p=top_p,
457 | min_word_tokens=min_word_tokens, ret_scale_factor=ret_scale_factor)
458 | else:
459 | output = self.model(
460 | pixel_values = images,
461 | labels = tgt_tokens,
462 | caption_len = caption_len,
463 | mode = mode,
464 | concat_captions = concat_captions,
465 | input_prefix = input_prefix,
466 | inference = inference)
467 | return output
468 |
469 | def generate_for_images_and_texts(
470 | self, prompts: List, num_words: int = 0, ret_scale_factor: float = 1.0, top_p: float = 1.0, temperature: float = 0.0,
471 | max_num_rets: int = 1, max_img_per_ret: int = 1):
472 | """
473 |     Generate text and/or retrieve images for a sequence of interleaved image and text prompts.
474 |
475 | Args:
476 | prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
477 | num_words: Maximum number of words to generate for. If num_words = 0, the model will run its forward pass and return the outputs.
478 | ret_scale_factor: Proportion to scale [RET] token logits by. A higher value may increase the probability of the model generating [RET] outputs.
479 | top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up to top_p or higher are kept for generation.
480 | temperature: Used to modulate logit distribution.
481 | max_num_rets: Maximum number of images to return in one generation pass.
482 | max_img_per_ret: Maximum number of images to return for each [RET] token.
483 | Returns:
484 | return_outputs: List consisting of either str or List[PIL.Image.Image] objects, representing image-text interleaved model outputs.
485 | """
486 | input_embs = []
487 | input_ids = []
488 | add_bos = True
489 |
490 | for i, p in enumerate(prompts):
491 | if type(p) == Image.Image:
492 | # Encode as image.
493 | pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
494 | pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
495 | pixel_values = pixel_values[None, ...]
496 |
497 | visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning') # (1, n_visual_tokens, D)
498 | input_embs.append(visual_embs)
499 | elif type(p) == str:
500 | text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
501 | if not add_bos:
502 | # Remove tag.
503 | text_ids = text_ids[:, 1:]
504 | else:
505 | # Only add once.
506 | add_bos = False
507 |
508 | text_embs = self.model.input_embeddings(text_ids) # (1, T, D)
509 | input_embs.append(text_embs)
510 | input_ids.append(text_ids)
511 | else:
512 | raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
513 | input_embs = torch.cat(input_embs, dim=1)
514 | input_ids = torch.cat(input_ids, dim=1)
515 |
516 | if num_words == 0:
517 | generated_ids = input_ids
518 | outputs = self.model.lm(inputs_embeds=input_embs, use_cache=False, output_hidden_states=True)
519 | # Map outputs to embeddings, so we can retrieve embeddings from the [RET] tokens.
520 | out = []
521 | for x, fc in zip(self.model.args.text_emb_layers, self.model.text_hidden_fcs):
522 | out.append(fc(outputs.hidden_states[x]))
523 | embeddings = torch.stack(out, dim=-1).sum(dim=-1)
524 | embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (N, T, 256)
525 | elif num_words > 0:
526 | generated_ids, generated_embeddings, _ = self.model.generate(input_embs, num_words,
527 | temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor)
528 | embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]
529 |
530 | # Truncate to newline.
531 | newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
532 | trunc_idx = 0
533 | for j in range(generated_ids.shape[1]):
534 | if generated_ids[0, j] == newline_token_id:
535 | trunc_idx = j
536 | break
537 | if trunc_idx > 0:
538 | generated_ids = generated_ids[:, :trunc_idx]
539 | embeddings = embeddings[:, :trunc_idx]
540 | else:
541 | raise ValueError
542 |
543 | # Save outputs as an interleaved list.
544 | return_outputs = []
545 | # Find up to max_num_rets [RET] tokens, and their corresponding scores.
546 | all_ret_idx = [i for i, x in enumerate(generated_ids[0, :] == self.model.retrieval_token_idx) if x][:max_num_rets]
547 | seen_image_idx = [] # Avoid showing the same image multiple times.
548 |
549 | last_ret_idx = 0
550 | if len(all_ret_idx) == 0:
551 | # No [RET] tokens.
552 | caption = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
553 | return_outputs.append(utils.truncate_caption(caption))
554 | else:
555 | for ret_idx in all_ret_idx:
556 | ret_emb = embeddings[:, ret_idx, :]
557 | scores = self.emb_matrix @ ret_emb.T
558 |
559 | # Downweight seen images.
560 | for seen_idx in seen_image_idx:
561 | scores[seen_idx, :] -= 1000
562 |
563 |         # Get the top max_img_per_ret + 3 (in case some fail) images for each [RET] token.
564 | _, top_image_idx = scores.squeeze().topk(max_img_per_ret + 3)
565 | image_outputs = []
566 | for img_idx in top_image_idx:
567 | # Find the first image that does not error out.
568 | try:
569 | seen_image_idx.append(img_idx)
570 | img = utils.get_image_from_url(self.path_array[img_idx])
571 | image_outputs.append(img)
572 | if len(image_outputs) == max_img_per_ret:
573 | break
574 | except UnidentifiedImageError:
575 | pass
576 |
577 | caption = self.model.tokenizer.batch_decode(generated_ids[:, last_ret_idx:ret_idx], skip_special_tokens=True)[0]
578 | last_ret_idx = ret_idx + 1
579 | return_outputs.append(utils.truncate_caption(caption) + ' [RET]')
580 | return_outputs.append(image_outputs)
581 |
582 | return return_outputs
583 |
584 | def get_log_lik_scores(
585 | self, prompts: List):
586 | """
587 | Output the log likelihoods of the given interleaved prompts.
588 |
589 | Args:
590 | prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
591 | Returns:
592 | log lik score of prompt sequence.
593 | """
594 | input_embs = []
595 | input_ids = []
596 | add_bos = True
597 |
598 | for i, p in enumerate(prompts):
599 | if type(p) == Image.Image:
600 | # Encode as image.
601 | pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
602 | pixel_values = pixel_values.to(device=self.model.logit_scale.device, dtype=self.model.logit_scale.dtype)
603 | pixel_values = pixel_values[None, ...]
604 |
605 | visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning') # (1, n_visual_tokens, D)
606 | input_embs.append(visual_embs)
607 | id_ = torch.zeros(visual_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
608 | input_ids.append(id_)
609 | elif type(p) == str:
610 | text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(self.model.logit_scale.device)
611 | if not add_bos:
612 | # Remove tag.
613 | text_ids = text_ids[:, 1:]
614 | else:
615 | # Only add once.
616 | add_bos = False
617 |
618 | text_embs = self.model.input_embeddings(text_ids) # (1, T, D)
619 | input_embs.append(text_embs)
620 | input_ids.append(text_ids)
621 | else:
622 | raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
623 | input_embs = torch.cat(input_embs, dim=1)
624 | input_ids = torch.cat(input_ids, dim=1)
625 |
626 | outputs = self.model.lm(inputs_embeds=input_embs, labels=input_ids, use_cache=False, output_hidden_states=True)
627 | return -outputs.loss.item()
628 |
629 | def load_fromage(model_dir: str) -> Fromage:
630 | model_args_path = os.path.join(model_dir, 'model_args.json')
631 | model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar')
632 | embs_paths = [s for s in glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl'))]
633 |
634 | if not os.path.exists(model_args_path):
635 | raise ValueError(f'model_args.json does not exist in {model_dir}.')
636 | if not os.path.exists(model_ckpt_path):
637 | raise ValueError(f'pretrained_ckpt.pth.tar does not exist in {model_dir}.')
638 | if len(embs_paths) == 0:
639 |     raise ValueError(f'cc3m_embeddings*.pkl files do not exist in {model_dir}.')
640 |
641 | # Load embeddings.
642 | # Construct embedding matrix for nearest neighbor lookup.
643 | path_array = []
644 | emb_matrix = []
645 |
646 | # These were precomputed for all CC3M images with `model.get_visual_embs(image, mode='retrieval')`.
647 | for p in embs_paths:
648 | with open(p, 'rb') as wf:
649 | train_embs_data = pkl.load(wf)
650 | path_array.extend(train_embs_data['paths'])
651 | emb_matrix.append(train_embs_data['embeddings'])
652 | emb_matrix = np.concatenate(emb_matrix, axis=0)
653 |
654 | # Number of paths should be equal to number of embeddings.
655 | assert len(path_array) == emb_matrix.shape[0], (len(path_array), emb_matrix.shape[0])
656 |
657 | with open(model_args_path, 'r') as f:
658 | model_kwargs = json.load(f)
659 |
660 | # Initialize tokenizer.
661 | tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'])
662 | tokenizer.pad_token = tokenizer.eos_token
663 | # Add special tokens to the model to enable [RET].
664 | tokenizer.add_special_tokens({"cls_token": "<|image|>"})
665 | tokenizer.add_tokens('[RET]')
666 | ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
667 | assert len(ret_token_idx) == 1, ret_token_idx
668 | model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
669 | args = namedtuple('args', model_kwargs)(**model_kwargs)
670 |
671 | # Initialize model for inference.
672 | model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
673 | model = model.eval()
674 | model = model.bfloat16()
675 | model = model.cuda()
676 |
677 | # Load pretrained linear mappings and [RET] embeddings.
678 | checkpoint = torch.load(model_ckpt_path)
679 | model.load_state_dict(checkpoint['state_dict'], strict=False)
680 | with torch.no_grad():
681 | model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(checkpoint['state_dict']['ret_input_embeddings.weight'].squeeze().cpu().detach())
682 |
683 | logit_scale = model.model.logit_scale.exp()
684 | emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
685 | emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
686 | emb_matrix = logit_scale * emb_matrix
687 | model.emb_matrix = emb_matrix
688 |
689 | return model
690 |
691 |
--------------------------------------------------------------------------------
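For completeness, a typical inference call wires the pieces of models.py above together: load_fromage restores the pruned checkpoint plus the precomputed cc3m_embeddings*.pkl, and generate_for_images_and_texts consumes an interleaved image/text prompt. A minimal sketch (the image URL is a placeholder; any RGB image works, and a CUDA device is assumed since load_fromage calls .cuda()):

from fromage import models, utils

# Directory containing model_args.json, pretrained_ckpt.pth.tar and cc3m_embeddings*.pkl.
model = models.load_fromage('./fromage_model/')

img = utils.get_image_from_url('https://example.com/cat.jpg')  # placeholder URL
prompts = [img, 'Q: What is in this image?\nA:']

# num_words > 0 runs autoregressive decoding; generated [RET] tokens trigger retrieval.
outputs = model.generate_for_images_and_texts(prompts, num_words=32, max_num_rets=1)
for out in outputs:
    print(out if isinstance(out, str) else f'{len(out)} retrieved image(s)')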
/fromage/prune_model_ckpt.py:
--------------------------------------------------------------------------------
1 | """Prune pretrained model weights to reduce size.
2 |
3 | This keeps only the weights that we finetune, and discards the pretrained LLM / visual encoder weights.
4 | """
5 | import torch
6 | import json
7 | from collections import OrderedDict
8 |
9 | ckpt_path = 'ckpt.pth.tar'
10 | pruned_output_path = 'ckpt_pruned.pth.tar'
11 | model_args_path = 'model_args.json'
12 |
13 | if __name__ == '__main__':
14 | with open(model_args_path, 'r') as f:
15 | model_kwargs = json.load(f)
16 | ret_token_idx = model_kwargs['retrieval_token_idx']
17 |
18 | with open(ckpt_path, 'rb') as f:
19 | checkpoint = torch.load(f)
20 |
21 | stripped_state_dict = {
22 | k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items() if
23 | ('.lm' not in k and '.visual_model' not in k)
24 | }
25 | stripped_state_dict = OrderedDict(sorted(stripped_state_dict.items()))
26 |
27 | del checkpoint['epoch']
28 | print('Best score:', checkpoint['best_score'])
29 | del checkpoint['optimizer']
30 | del checkpoint['scheduler']
31 | for k, v in stripped_state_dict.items():
32 | stripped_state_dict[k] = v.detach().clone()
33 |
34 | # Prune the pretrained token embeddings and keep just [RET].
35 | ret_embedding = stripped_state_dict['model.input_embeddings.weight'][ret_token_idx:ret_token_idx+1, :].detach().clone()
36 | stripped_state_dict['ret_input_embeddings.weight'] = ret_embedding
37 |
38 | with open(pruned_output_path, 'wb') as f:
39 | torch.save({'state_dict': stripped_state_dict}, f)
40 |
--------------------------------------------------------------------------------
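The pruned file from prune_model_ckpt.py above drops the frozen LM and visual-encoder weights and additionally stores the [RET] embedding row under ret_input_embeddings.weight, which is why load_fromage restores it with strict=False and then copies that row back explicitly. A short sketch for sanity-checking a pruned checkpoint (the path is a placeholder):

import torch

# Placeholder path; see ckpt_path / pruned_output_path in the script above.
ckpt = torch.load('ckpt_pruned.pth.tar', map_location='cpu')
state_dict = ckpt['state_dict']

# No frozen LM / visual encoder weights should survive the pruning.
assert not any('.lm' in k or '.visual_model' in k for k in state_dict)
assert 'ret_input_embeddings.weight' in state_dict

for k, v in state_dict.items():
    print(f'{k:<60} {tuple(v.shape)}')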
/fromage/utils.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | import subprocess
3 | import sys
4 | import shutil
5 | import torch
6 | import torch.distributed as dist
7 | from torchvision.transforms import functional as F
8 | from torchvision import transforms as T
9 | from transformers import AutoFeatureExtractor
10 | from PIL import Image, ImageDraw, ImageFont, ImageOps
11 | import requests
12 | from io import BytesIO
13 |
14 | import random
15 |
16 |
17 | def dump_git_status(out_file=sys.stdout, exclude_file_patterns=['*.ipynb', '*.th', '*.sh', '*.txt', '*.json']):
18 | """Logs git status to stdout."""
19 | subprocess.call('git rev-parse HEAD', shell=True, stdout=out_file)
20 | subprocess.call('echo', shell=True, stdout=out_file)
21 | exclude_string = ''
22 | subprocess.call('git --no-pager diff -- . {}'.format(exclude_string), shell=True, stdout=out_file)
23 |
24 |
25 | def get_image_from_url(url: str):
26 | response = requests.get(url)
27 | img = Image.open(BytesIO(response.content))
28 | img = img.resize((224, 224))
29 | img = img.convert('RGB')
30 | return img
31 |
32 |
33 | def truncate_caption(caption: str) -> str:
34 | """Truncate captions at periods and newlines."""
35 | caption = caption.strip('\n')
36 | trunc_index = caption.find('\n') + 1
37 | if trunc_index <= 0:
38 | trunc_index = caption.find('.') + 1
39 | if trunc_index > 0:
40 | caption = caption[:trunc_index]
41 | return caption
42 |
43 |
44 | def pad_to_size(x, size=256):
45 | delta_w = size - x.size[0]
46 | delta_h = size - x.size[1]
47 | padding = (
48 | delta_w // 2,
49 | delta_h // 2,
50 | delta_w - (delta_w // 2),
51 | delta_h - (delta_h // 2),
52 | )
53 | new_im = ImageOps.expand(x, padding)
54 | return new_im
55 |
56 |
57 | class RandCropResize(object):
58 |
59 | """
60 |   Randomly crops an image, then randomly resizes it, then randomly crops it again, mirroring the augmentations from https://arxiv.org/abs/2102.12092
61 | """
62 |
63 | def __init__(self, target_size):
64 | self.target_size = target_size
65 |
66 | def __call__(self, img):
67 | img = pad_to_size(img, self.target_size)
68 | d_min = min(img.size)
69 | img = T.RandomCrop(size=d_min)(img)
70 | t_min = min(d_min, round(9 / 8 * self.target_size))
71 | t_max = min(d_min, round(12 / 8 * self.target_size))
72 | t = random.randint(t_min, t_max + 1)
73 | img = T.Resize(t)(img)
74 | if min(img.size) < 256:
75 | img = T.Resize(256)(img)
76 | return T.RandomCrop(size=self.target_size)(img)
77 |
78 |
79 | class SquarePad(object):
80 | """Pads image to square.
81 | From https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
82 | """
83 | def __call__(self, image):
84 | max_wh = max(image.size)
85 | p_left, p_top = [(max_wh - s) // 2 for s in image.size]
86 | p_right, p_bottom = [max_wh - (s+pad) for s, pad in zip(image.size, [p_left, p_top])]
87 | padding = (p_left, p_top, p_right, p_bottom)
88 | return F.pad(image, padding, 0, 'constant')
89 |
90 |
91 | def create_image_of_text(text: str, width: int = 224, nrows: int = 2, color=(255, 255, 255), font=None) -> torch.Tensor:
92 | """Creates a (3, nrows * 14, width) image of text.
93 |
94 | Returns:
95 | cap_img: (3, 14 * nrows, width) image of wrapped text.
96 | """
97 | height = 12
98 | padding = 5
99 | effective_width = width - 2 * padding
100 | # Create a black image to draw text on.
101 | cap_img = Image.new('RGB', (effective_width * nrows, height), color = (0, 0, 0))
102 | draw = ImageDraw.Draw(cap_img)
103 | draw.text((0, 0), text, color, font=font or ImageFont.load_default())
104 | cap_img = F.convert_image_dtype(F.pil_to_tensor(cap_img), torch.float32) # (3, height, W * nrows)
105 | cap_img = torch.split(cap_img, effective_width, dim=-1) # List of nrow elements of shape (3, height, W)
106 | cap_img = torch.cat(cap_img, dim=1) # (3, height * nrows, W)
107 | # Add zero padding.
108 | cap_img = torch.nn.functional.pad(cap_img, [padding, padding, 0, padding])
109 | return cap_img
110 |
111 |
112 | def get_feature_extractor_for_model(model_name: str, image_size: int = 224, train: bool = True):
113 | print(f'Using HuggingFace AutoFeatureExtractor for {model_name}.')
114 | feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
115 | return feature_extractor
116 |
117 |
118 | def get_pixel_values_for_model(feature_extractor, img):
119 | pixel_values = feature_extractor(
120 | img.convert('RGB'),
121 | return_tensors="pt").pixel_values[0, ...] # (3, H, W)
122 | return pixel_values
123 |
124 |
125 | def save_checkpoint(state, is_best, filename='checkpoint'):
126 | torch.save(state, filename + '.pth.tar')
127 | if is_best:
128 | shutil.copyfile(filename + '.pth.tar', filename + '_best.pth.tar')
129 |
130 |
131 | def accuracy(output, target, padding, topk=(1,)):
132 | """Computes the accuracy over the k top predictions for the specified values of k"""
133 | with torch.no_grad():
134 | maxk = max(topk)
135 | if output.shape[-1] < maxk:
136 | print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.")
137 |
138 | maxk = min(maxk, output.shape[-1])
139 | batch_size = target.size(0)
140 |
141 | # Take topk along the last dimension.
142 | _, pred = output.topk(maxk, -1, True, True) # (N, T, topk)
143 |
144 | mask = (target != padding).type(target.dtype)
145 | target_expand = target[..., None].expand_as(pred)
146 | correct = pred.eq(target_expand)
147 | correct = correct * mask[..., None].expand_as(correct)
148 |
149 | res = []
150 | for k in topk:
151 | correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True)
152 | res.append(correct_k.mul_(100.0 / mask.sum()))
153 | return res
154 |
155 |
156 | def get_params_count(model, max_name_len: int = 60):
157 | params = [(name[:max_name_len], p.numel(), str(tuple(p.shape)), p.requires_grad) for name, p in model.named_parameters()]
158 | total_trainable_params = sum([x[1] for x in params if x[-1]])
159 | total_nontrainable_params = sum([x[1] for x in params if not x[-1]])
160 | return params, total_trainable_params, total_nontrainable_params
161 |
162 |
163 | def get_params_count_str(model, max_name_len: int = 60):
164 | padding = 70 # Hardcoded depending on desired amount of padding and separators.
165 | params, total_trainable_params, total_nontrainable_params = get_params_count(model, max_name_len)
166 | param_counts_text = ''
167 | param_counts_text += '=' * (max_name_len + padding) + '\n'
168 | param_counts_text += f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |\n'
169 | param_counts_text += '-' * (max_name_len + padding) + '\n'
170 | for name, param_count, shape, trainable in params:
171 | param_counts_text += f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {param_count:>12,} |\n'
172 | param_counts_text += '-' * (max_name_len + padding) + '\n'
173 | param_counts_text += f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_trainable_params:>12,} |\n'
174 | param_counts_text += f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_nontrainable_params:>12,} |\n'
175 | param_counts_text += '=' * (max_name_len + padding) + '\n'
176 | return param_counts_text
177 |
178 |
179 | class Summary(Enum):
180 | NONE = 0
181 | AVERAGE = 1
182 | SUM = 2
183 | COUNT = 3
184 |
185 |
186 | class ProgressMeter(object):
187 | def __init__(self, num_batches, meters, prefix=""):
188 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
189 | self.meters = meters
190 | self.prefix = prefix
191 |
192 | def display(self, batch):
193 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
194 | entries += [str(meter) for meter in self.meters]
195 | print('\t'.join(entries))
196 |
197 | def display_summary(self):
198 | entries = [" *"]
199 | entries += [meter.summary() for meter in self.meters]
200 | print(' '.join(entries))
201 |
202 | def _get_batch_fmtstr(self, num_batches):
203 | num_digits = len(str(num_batches // 1))
204 | fmt = '{:' + str(num_digits) + 'd}'
205 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
206 |
207 |
208 | class AverageMeter(object):
209 | """Computes and stores the average and current value"""
210 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
211 | self.name = name
212 | self.fmt = fmt
213 | self.summary_type = summary_type
214 | self.reset()
215 |
216 | def reset(self):
217 | self.val = 0
218 | self.avg = 0
219 | self.sum = 0
220 | self.count = 0
221 |
222 | def update(self, val, n=1):
223 | self.val = val
224 | self.sum += val * n
225 | self.count += n
226 | self.avg = self.sum / self.count
227 |
228 | def all_reduce(self):
229 | device = "cuda" if torch.cuda.is_available() else "cpu"
230 | total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
231 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
232 | self.sum, self.count = total.tolist()
233 | self.avg = self.sum / self.count
234 |
235 | def __str__(self):
236 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
237 | return fmtstr.format(**self.__dict__)
238 |
239 | def summary(self):
240 | fmtstr = ''
241 | if self.summary_type is Summary.NONE:
242 | fmtstr = ''
243 | elif self.summary_type is Summary.AVERAGE:
244 | fmtstr = '{name} {avg:.3f}'
245 | elif self.summary_type is Summary.SUM:
246 | fmtstr = '{name} {sum:.3f}'
247 | elif self.summary_type is Summary.COUNT:
248 | fmtstr = '{name} {count:.3f}'
249 | else:
250 | raise ValueError('invalid summary type %r' % self.summary_type)
251 |
252 | return fmtstr.format(**self.__dict__)
253 |
--------------------------------------------------------------------------------
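As a small illustration of how the meter and accuracy helpers in fromage/utils.py above are used by main.py and the validation loop, a self-contained sketch with made-up batch statistics:

import torch
from fromage import utils

losses = utils.AverageMeter('Loss', ':.4e', utils.Summary.AVERAGE)
top1 = utils.AverageMeter('Acc@1', ':6.2f', utils.Summary.AVERAGE)
progress = utils.ProgressMeter(10, [losses, top1], prefix='Demo: ')

for step in range(10):
    # Fake per-batch metrics; in training these come from the model outputs.
    logits = torch.randn(4, 100)
    targets = torch.randint(0, 100, (4,))
    acc1, = utils.accuracy(logits, targets, padding=-1, topk=(1,))
    losses.update(1.0 / (step + 1), n=4)
    top1.update(acc1.item(), n=4)
    progress.display(step + 1)

progress.display_summary()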
/fromage_model/fromage_vis4/model_args.json:
--------------------------------------------------------------------------------
1 | {
2 | "opt_version": "facebook/opt-6.7b",
3 | "freeze_lm": true,
4 | "visual_encoder": "openai/clip-vit-large-patch14",
5 | "freeze_vm": true,
6 | "n_visual_tokens": 4,
7 | "use_image_embed_norm": false,
8 | "image_embed_dropout_prob": 0.0,
9 | "use_text_embed_layernorm": false,
10 | "text_embed_dropout_prob": 0.0,
11 | "shared_emb_dim": 256,
12 | "text_emb_layers": [
13 | -1
14 | ],
15 | "retrieval_token_idx": 50266
16 | }
--------------------------------------------------------------------------------
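The keys in model_args.json above are consumed by load_fromage in fromage/models.py, which exposes them as an attribute-style args object; the same pattern can be reproduced directly when only the configuration values are needed (the path below points at the vis4 variant):

import json
from collections import namedtuple

with open('fromage_model/fromage_vis4/model_args.json', 'r') as f:
    model_kwargs = json.load(f)

# Mirrors load_fromage: turn the JSON keys into attribute-style args.
args = namedtuple('args', model_kwargs)(**model_kwargs)
print(args.opt_version, args.n_visual_tokens, args.retrieval_token_idx)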
/fromage_model/fromage_vis4/pretrained_ckpt.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/fromage_model/fromage_vis4/pretrained_ckpt.pth.tar
--------------------------------------------------------------------------------
/fromage_model/model_args.json:
--------------------------------------------------------------------------------
1 | {
2 | "opt_version": "facebook/opt-6.7b",
3 | "task": "multitask",
4 | "freeze_lm": true,
5 | "visual_encoder": "openai/clip-vit-large-patch14",
6 | "freeze_vm": true,
7 | "pretrained_visual": true,
8 | "use_pooler": true,
9 | "n_visual_tokens": 1,
10 | "image_embed_dropout_prob": 0.0,
11 | "text_embed_dropout_prob": 0.0,
12 | "shared_emb_dim": 256,
13 | "text_emb_layers": [
14 | -1
15 | ],
16 | "append_retrieval_token": true,
17 | "num_appended_retrieval_tokens": 1,
18 | "input_prompt": "A picture of",
19 | "add_input_to_ret": true,
20 | "tunable_prompt_length": 0,
21 | "retrieval_token_idx": 50266
22 | }
23 |
--------------------------------------------------------------------------------
/fromage_model/pretrained_ckpt.pth.tar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/fromage_model/pretrained_ckpt.pth.tar
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """Training example.
2 |
3 | Modified from https://github.com/pytorch/examples/blob/main/imagenet/main.py.
4 | """
5 | import argparse
6 | import json
7 | import os
8 | import sys
9 | import time
10 | import warnings
11 |
12 | import numpy as np
13 | from PIL import Image
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.parallel
17 | import torch.backends.cudnn as cudnn
18 | import torch.distributed as dist
19 | import torch.optim
20 | from torch.optim.lr_scheduler import StepLR
21 | from warmup_scheduler import GradualWarmupScheduler
22 | import torch.multiprocessing as mp
23 | import torch.utils.data
24 | import torch.utils.data.distributed
25 | import torchvision.transforms as transforms
26 | import torchvision.datasets as datasets
27 | from torch.utils.tensorboard import SummaryWriter
28 | import torchvision
29 |
30 | from fromage import data
31 | from fromage import losses as losses_utils
32 | from fromage import models
33 | from fromage import utils
34 | from fromage import evaluate
35 | from transformers import AutoTokenizer
36 |
37 | # Disable HuggingFace tokenizer parallelism.
38 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
39 |
40 | # Available LLM models.
41 | llm_models = ['facebook/opt-125m', 'facebook/opt-350m', 'facebook/opt-1.3b',
42 | 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-13b', 'facebook/opt-30b',
43 | 'facebook/opt-66b']
44 | datasets = ['cc3m']
45 | best_score = 0 # Variable to keep track of best model so far.
46 |
47 |
48 | def parse_args(args):
49 | parser = argparse.ArgumentParser(description='FROMAGe training')
50 | parser.add_argument('--opt-version', default='facebook/opt-6.7b',
51 | choices=llm_models,
52 | help='OPT versions: ' +
53 | ' | '.join(llm_models) +
54 | ' (default: "facebook/opt-6.7b")')
55 | parser.add_argument('--visual-model', default='openai/clip-vit-large-patch14', type=str,
56 | help="Visual encoder to use.")
57 | parser.add_argument('-d', '--dataset', metavar='DATASET', help='Delimited list of datasets:' +
58 | ' | '.join(datasets), default='cc3m',
59 | type=lambda s: [x for x in s.split(',')])
60 |
61 | parser.add_argument('--val-dataset', metavar='DATASET', default='cc3m',
62 | type=lambda s: [x for x in s.split(',')],
63 | help='Validation dataset: ' +
64 | ' | '.join(datasets) +
65 | ' (default: cc3m)')
66 | parser.add_argument('--dataset_dir', default='datasets', type=str,
67 | help='Dataset directory containing .tsv files.')
68 | parser.add_argument('--image-dir', default='./data/', type=str,
69 | help='Dataset directory containing image folders.')
70 | parser.add_argument('--log-base-dir', default='./runs/', type=str,
71 | help='Base directory to write logs and ckpts to.')
72 | parser.add_argument('--exp_name', default='frozen', type=str,
73 | help='Name of experiment, used for saving checkpoints.')
74 |
75 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
76 | help='number of data loading workers (default: 4)')
77 | parser.add_argument('--epochs', default=10, type=int, metavar='N',
78 | help='number of total epochs to run')
79 | parser.add_argument('--steps-per-epoch', default=2000, type=int, metavar='N',
80 | help='number of training steps per epoch')
81 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
82 | help='manual epoch number (useful on restarts)')
83 | parser.add_argument('--val-steps-per-epoch', default=-1, type=int, metavar='N',
84 | help='number of validation steps per epoch.')
85 | parser.add_argument('-b', '--batch-size', default=180, type=int,
86 | metavar='N',
87 | help='mini-batch size (default: 180), this is the total '
88 | 'batch size of all GPUs on the current node when '
89 | 'using Data Parallel or Distributed Data Parallel')
90 | parser.add_argument('--val-batch-size', default=None, type=int)
91 | parser.add_argument('--lr', '--learning-rate', default=0.0003, type=float,
92 | metavar='LR', help='initial learning rate', dest='lr')
93 | parser.add_argument('--lr-warmup-steps', default=100, type=int,
94 | metavar='N', help='Number of steps to warm up lr.')
95 | parser.add_argument('--lr-schedule-step-size', default=10, type=int,
96 | metavar='N', help='Number of steps before decaying lr.')
97 | parser.add_argument('--lr-schedule-gamma', default=0.1, type=float,
98 | metavar='N', help='Decay parameter for learning rate scheduler.')
99 | parser.add_argument('--grad-accumulation-steps', default=1, type=int, metavar='N',
100 | help='number of gradient accumulation steps')
101 | parser.add_argument('--grad-clip', default=1.0, type=float, help='gradient clipping amount')
102 |
103 | parser.add_argument('--precision', default='fp32', type=str, choices=['fp32', 'fp16', 'bf16'], help="Precision to train in.")
104 | parser.add_argument('--cap-loss-scale', type=float, default=1.0, help="Scale on captioning loss.")
105 | parser.add_argument('--ret-loss-scale', type=float, default=1.0, help="Scale on retrieval loss.")
106 |
107 | parser.add_argument('--concat-captions-prob', type=float, default=0.5, help="Probability of concatenating two examples sequentially for captioning.")
108 | parser.add_argument('--concat-for-ret', action='store_true', default=False, help="Whether to concatenate examples for retrieval mode.")
109 | parser.add_argument('--input-prompt', default=None, type=str, help="Input prompt for the language model, if any.")
110 |
111 | parser.add_argument('--image-size', default=224, type=int, metavar='N', help='Size of images.')
112 |   parser.add_argument('--use_image_embed_norm', action='store_true', default=False, help="Whether to apply a norm to the image embeddings so they match the scale of the language embeddings.")
113 | parser.add_argument('--image_embed_dropout_prob', type=float, default=0.0, help="Dropout probability on the image embeddings.")
114 | parser.add_argument('--use_text_embed_layernorm', action='store_true', default=False, help="Whether to use layer norm on the text embeddings for retrieval.")
115 | parser.add_argument('--text_embed_dropout_prob', type=float, default=0.0, help="Dropout probability on the text embeddings.")
116 | parser.add_argument('--shared-emb-dim', default=256, type=int, metavar='N', help='Embedding dimension for retrieval.')
117 | parser.add_argument('--text-emb-layers', help='Layer to use for text embeddings. OPT-2.7b has 33 layers.', default='-1',
118 | type=lambda s: [int(x) for x in s.split(',')])
119 |
120 | parser.add_argument('--max-len', default=24, type=int,
121 | metavar='N', help='Maximum length to truncate captions / generations to.')
122 | parser.add_argument('--n-visual-tokens', default=1, type=int,
123 | metavar='N', help='Number of visual tokens to use for the Frozen model.')
124 |
125 | parser.add_argument('--beta1', default=0.9, type=float, metavar='M', help='beta1 for Adam')
126 | parser.add_argument('--beta2', default=0.95, type=float, metavar='M', help='beta2 for Adam')
127 | parser.add_argument('--wd', '--weight-decay', default=0.0, type=float,
128 | metavar='W', help='weight decay (default: 0.0)', dest='weight_decay')
129 | parser.add_argument('-p', '--print-freq', default=10, type=int,
130 | metavar='N', help='print frequency (default: 10)')
131 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
132 | help='path to latest checkpoint (default: none)')
133 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
134 | help='evaluate model on validation set')
135 | parser.add_argument('--world-size', default=-1, type=int,
136 | help='number of nodes for distributed training')
137 | parser.add_argument('--rank', default=-1, type=int,
138 | help='node rank for distributed training')
139 | parser.add_argument('--dist-url', default='tcp://127.0.0.1:1337', type=str,
140 | help='url used to set up distributed training')
141 | parser.add_argument('--dist-backend', default='nccl', type=str,
142 | help='distributed backend')
143 | parser.add_argument('--seed', default=None, type=int,
144 | help='seed for initializing training. ')
145 | parser.add_argument('--gpu', default=None, type=int,
146 | help='GPU id to use.')
147 | parser.add_argument('--multiprocessing-distributed', action='store_true',
148 | help='Use multi-processing distributed training to launch '
149 | 'N processes per node, which has N GPUs. This is the '
150 | 'fastest way to use PyTorch for either single node or '
151 | 'multi node data parallel training')
152 | return parser.parse_args(args)
153 |
154 |
155 | def main(args):
156 | args = parse_args(args)
157 | i = 1
158 | args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
159 | while os.path.exists(args.log_dir):
160 | args.log_dir = os.path.join(args.log_base_dir, f'{args.exp_name}_{i}')
161 | i += 1
162 | os.makedirs(args.log_dir)
163 |
164 |   with open(os.path.join(args.log_dir, 'args.json'), 'w') as wf:
165 | json.dump(vars(args), wf, indent=4)
166 |
167 |   with open(os.path.join(args.log_dir, 'git_info.txt'), 'w') as wf:
168 | utils.dump_git_status(out_file=wf)
169 |
170 | print(f'Logging to {args.log_dir}.')
171 |
172 | if args.seed is not None:
173 | torch.manual_seed(args.seed)
174 | cudnn.deterministic = True
175 | warnings.warn('You have chosen to seed training. '
176 | 'This will turn on the CUDNN deterministic setting, '
177 | 'which can slow down your training considerably! '
178 | 'You may see unexpected behavior when restarting '
179 | 'from checkpoints.')
180 |
181 | if args.gpu is not None:
182 | warnings.warn('You have chosen a specific GPU. This will completely '
183 | 'disable data parallelism.')
184 |
185 | if args.dist_url == "env://" and args.world_size == -1:
186 | args.world_size = int(os.environ["WORLD_SIZE"])
187 |
188 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed
189 |
190 | ngpus_per_node = torch.cuda.device_count()
191 | if args.multiprocessing_distributed:
192 | # Since we have ngpus_per_node processes per node, the total world_size
193 | # needs to be adjusted accordingly
194 | args.world_size = ngpus_per_node * args.world_size
195 | # Use torch.multiprocessing.spawn to launch distributed processes: the
196 | # main_worker process function
197 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
198 | else:
199 | # Simply call main_worker function
200 | main_worker(args.gpu, ngpus_per_node, args)
201 |
202 |
203 | def main_worker(gpu, ngpus_per_node, args):
204 | """Setup code."""
205 | global best_score
206 | args.gpu = gpu
207 |
208 | if args.gpu is not None:
209 | print("Use GPU: {} for training".format(args.gpu))
210 |
211 | if args.distributed:
212 | if args.dist_url == "env://" and args.rank == -1:
213 | args.rank = int(os.environ["RANK"])
214 | if args.multiprocessing_distributed:
215 | # For multiprocessing distributed training, rank needs to be the
216 | # global rank among all the processes
217 | args.rank = args.rank * ngpus_per_node + gpu
218 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
219 | world_size=args.world_size, rank=args.rank)
220 |
221 | # Create model
222 | model_args = models.FrozenArgs()
223 | model_args.opt_version = args.opt_version
224 | model_args.freeze_lm = True
225 | model_args.visual_encoder = args.visual_model
226 | model_args.freeze_vm = True
227 | model_args.n_visual_tokens = args.n_visual_tokens
228 | model_args.use_image_embed_norm = args.use_image_embed_norm
229 | model_args.image_embed_dropout_prob = args.image_embed_dropout_prob
230 | model_args.use_text_embed_layernorm = args.use_text_embed_layernorm
231 | model_args.text_embed_dropout_prob = args.text_embed_dropout_prob
232 | model_args.shared_emb_dim = args.shared_emb_dim
233 | model_args.text_emb_layers = args.text_emb_layers
234 |
235 | tokenizer = AutoTokenizer.from_pretrained(args.opt_version, use_fast=False)
236 | # Add an image token for loss masking (and visualization) purposes.
237 | tokenizer.add_special_tokens({"cls_token": "<|image|>"}) # add special image token to tokenizer
238 | print('Adding [RET] token to vocabulary.')
239 | print('Before adding new token, tokenizer("[RET]") =', tokenizer('[RET]', add_special_tokens=False))
240 | num_added_tokens = tokenizer.add_tokens('[RET]')
241 | print(f'After adding {num_added_tokens} new tokens, tokenizer("[RET]") =', tokenizer('[RET]', add_special_tokens=False))
242 | ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
243 | assert len(ret_token_idx) == 1, ret_token_idx
244 | model_args.retrieval_token_idx = ret_token_idx[0]
245 | args.retrieval_token_idx = ret_token_idx[0]
246 |
247 | # Save model args to disk.
248 | with open(os.path.join(args.log_dir, 'model_args.json'), 'w') as f:
249 | json.dump(vars(model_args), f, indent=4)
250 |
251 | model = models.Fromage(tokenizer, model_args)
252 | if args.precision == 'fp16':
253 | model = model.half()
254 | elif args.precision == 'bf16':
255 | model = model.bfloat16()
256 |
257 |   # Write the model's parameter counts to disk.
258 | param_counts_text = utils.get_params_count_str(model)
259 | with open(os.path.join(args.log_dir, 'param_count.txt'), 'w') as f:
260 | f.write(param_counts_text)
261 |
262 | # Log trainable parameters to Tensorboard.
263 | _, total_trainable_params, total_nontrainable_params = utils.get_params_count(model)
264 | writer = SummaryWriter(args.log_dir)
265 | writer.add_scalar('params/total', total_trainable_params + total_nontrainable_params, 0)
266 | writer.add_scalar('params/total_trainable', total_trainable_params, 0)
267 | writer.add_scalar('params/total_non_trainable', total_nontrainable_params, 0)
268 | writer.close()
269 |
270 | if not torch.cuda.is_available():
271 | print('WARNING: using CPU, this will be slow!')
272 | model = torch.nn.DataParallel(model)
273 | elif args.distributed:
274 | # For multiprocessing distributed, DistributedDataParallel constructor
275 | # should always set the single device scope, otherwise,
276 | # DistributedDataParallel will use all available devices.
277 | if args.gpu is not None:
278 | torch.cuda.set_device(args.gpu)
279 | model.cuda(args.gpu)
280 | # When using a single GPU per process and per
281 | # DistributedDataParallel, we need to divide the batch size
282 | # ourselves based on the total number of GPUs of the current node.
283 | args.batch_size = int(args.batch_size / ngpus_per_node)
284 | args.val_batch_size = int((args.val_batch_size or args.batch_size) / ngpus_per_node)
285 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
286 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
287 | else:
288 | model.cuda()
289 | # DistributedDataParallel will divide and allocate batch_size to all
290 | # available GPUs if device_ids are not set
291 | model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=False)
292 | elif args.gpu is not None:
293 | torch.cuda.set_device(args.gpu)
294 | model = model.cuda(args.gpu)
295 | else:
296 | model = torch.nn.DataParallel(model).cuda()
297 |
298 | # define loss function (criterion), optimizer, and learning rate scheduler
299 | criterion = nn.CrossEntropyLoss().cuda(args.gpu)
300 | optimizer_cls = torch.optim.AdamW
301 | print('Using torch.optim.AdamW as the optimizer.')
302 | optimizer = optimizer_cls(model.parameters(), args.lr,
303 | betas=(args.beta1, args.beta2),
304 | weight_decay=args.weight_decay,
305 | eps=1e-8)
306 |
307 |   # Warm up the learning rate for lr_warmup_steps steps, then decay it by lr_schedule_gamma every lr_schedule_step_size epochs.
308 | scheduler_steplr = StepLR(optimizer, step_size=args.lr_schedule_step_size * args.steps_per_epoch, gamma=args.lr_schedule_gamma)
309 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.lr_warmup_steps, after_scheduler=scheduler_steplr)
310 |
311 | # optionally resume from a checkpoint
312 | if args.resume:
313 | if os.path.isfile(args.resume):
314 | print("=> loading checkpoint '{}'".format(args.resume))
315 | if args.gpu is None:
316 | checkpoint = torch.load(args.resume)
317 | else:
318 | # Map model to be loaded to specified single gpu.
319 | loc = 'cuda:{}'.format(args.gpu)
320 | checkpoint = torch.load(args.resume, map_location=loc)
321 | args.start_epoch = checkpoint['epoch']
322 | best_score = checkpoint.get('best_score', 0)
323 | model.load_state_dict(checkpoint['state_dict'])
324 | optimizer.load_state_dict(checkpoint['optimizer'])
325 | scheduler.load_state_dict(checkpoint['scheduler'])
326 | print("=> loaded checkpoint '{}' (epoch {})"
327 | .format(args.resume, checkpoint['epoch']))
328 | else:
329 | print("=> no checkpoint found at '{}'".format(args.resume))
330 |
331 | cudnn.benchmark = True
332 |
333 | # Data loading code
334 | train_dataset = data.get_dataset(args, 'train', tokenizer)
335 | val_dataset = data.get_dataset(args, 'val', tokenizer)
336 | print(f'Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples.')
337 |
338 | if args.distributed:
339 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, drop_last=True)
340 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
341 | else:
342 | train_sampler = None
343 | val_sampler = None
344 |
345 | train_loader = torch.utils.data.DataLoader(
346 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
347 | num_workers=args.workers, pin_memory=True, sampler=train_sampler)
348 | val_loader = torch.utils.data.DataLoader(
349 | val_dataset, batch_size=(args.val_batch_size or args.batch_size), shuffle=False,
350 | num_workers=args.workers, pin_memory=True, sampler=val_sampler)
351 |
352 | if args.evaluate:
353 |     evaluate.validate(val_loader, model, tokenizer, criterion, args.start_epoch, args)
354 | return
355 |
356 | for epoch in range(args.start_epoch, args.epochs):
357 | if epoch == 0:
358 | evaluate.validate(val_loader, model, tokenizer, criterion, epoch-1, args)
359 | if args.distributed:
360 | train_sampler.set_epoch(epoch)
361 |
362 | # train for one epoch
363 | train(train_loader, model, tokenizer, criterion, optimizer, epoch, scheduler, args)
364 |
365 | # evaluate on validation set
366 | eval_score = evaluate.validate(val_loader, model, tokenizer, criterion, epoch, args)
367 |
368 | # remember best score and save checkpoint
369 | is_best = eval_score > best_score
370 | best_score = max(eval_score, best_score)
371 |
372 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
373 | and args.rank % ngpus_per_node == 0):
374 | utils.save_checkpoint({
375 | 'epoch': epoch + 1,
376 | 'state_dict': model.state_dict(),
377 | 'best_score': best_score,
378 | 'optimizer' : optimizer.state_dict(),
379 | 'scheduler' : scheduler.state_dict()
380 | }, is_best, os.path.join(args.log_dir, 'ckpt'))
381 |
382 |
383 | def train(train_loader, model, tokenizer, criterion, optimizer, epoch, scheduler, args):
384 | """Main training loop."""
385 | ngpus_per_node = torch.cuda.device_count()
386 | batch_time = utils.AverageMeter('Time', ':6.3f')
387 | cap_time = utils.AverageMeter('CaptioningTime', ':6.3f')
388 | ret_time = utils.AverageMeter('RetrievalTime', ':6.3f')
389 | data_time = utils.AverageMeter('Data', ':6.3f')
390 | losses = utils.AverageMeter('Loss', ':.4e')
391 | ce_losses = utils.AverageMeter('CeLoss', ':.4e')
392 | top1 = utils.AverageMeter('Acc@1', ':6.2f')
393 | top5 = utils.AverageMeter('Acc@5', ':6.2f')
394 | cont_losses = utils.AverageMeter('ContLoss', ':.4e')
395 | top1_caption = utils.AverageMeter('AccCaption@1', ':6.2f')
396 | top5_caption = utils.AverageMeter('AccCaption@5', ':6.2f')
397 | top1_image = utils.AverageMeter('AccImage@1', ':6.2f')
398 | top5_image = utils.AverageMeter('AccImage@5', ':6.2f')
399 |
400 | writer = SummaryWriter(args.log_dir)
401 |
402 | progress = utils.ProgressMeter(
403 | args.steps_per_epoch,
404 | [batch_time, losses, ce_losses, cont_losses, top1, top5],
405 | prefix="Epoch: [{}]".format(epoch))
406 |
407 | # switch to train mode
408 | model.train()
409 |
410 | end = time.time()
411 |
412 | for i, (image_paths, images, caption_images, tgt_tokens, token_len) in enumerate(train_loader):
413 | actual_step = epoch * args.steps_per_epoch + i + 1
414 | # measure data loading time
415 | data_time.update(time.time() - end)
416 |
417 | if torch.cuda.is_available():
418 | images = images.cuda(args.gpu, non_blocking=True)
419 | tgt_tokens = tgt_tokens.cuda(args.gpu, non_blocking=True)
420 | token_len = token_len.cuda(args.gpu, non_blocking=True)
421 |
422 | if args.precision == 'fp16':
423 | images = images.half()
424 | elif args.precision == 'bf16':
425 | images = images.bfloat16()
426 |
427 | model_modes = ['captioning', 'retrieval']
428 | loss = 0
429 |
430 | for model_mode in model_modes:
431 | mode_start = time.time()
432 | # compute output
433 | concat_captions = np.random.uniform(0, 1) < args.concat_captions_prob
434 | if not args.concat_for_ret:
435 | concat_captions = concat_captions and model_mode == 'captioning'
436 |
437 | (model_output, full_labels, last_embedding, _, visual_embs) = model(
438 | images, tgt_tokens, token_len, mode=model_mode, concat_captions=concat_captions, inference=False)
439 | output = model_output.logits
440 |
441 | # Measure captioning accuracy for multi-task models and next-token prediction for retrieval models.
442 | if model_mode == 'captioning':
443 | acc1, acc5 = utils.accuracy(output[:, :-1, :], full_labels[:, 1:], -100, topk=(1, 5))
444 | top1.update(acc1[0], images.size(0))
445 | top5.update(acc5[0], images.size(0))
446 |
447 | ce_loss = model_output.loss
448 | if model_mode == 'captioning':
449 | ce_loss = ce_loss * args.cap_loss_scale
450 | elif model_mode == 'retrieval':
451 | ce_loss = ce_loss * args.ret_loss_scale
452 | else:
453 | raise NotImplementedError
454 |
455 | loss += ce_loss
456 | ce_losses.update(ce_loss.item(), images.size(0))
457 |
458 | if model_mode == 'retrieval':
459 | # Cross replica concat for embeddings.
460 | if args.distributed:
461 | all_visual_embs = [torch.zeros_like(visual_embs) for _ in range(dist.get_world_size())]
462 | all_last_embedding = [torch.zeros_like(last_embedding) for _ in range(dist.get_world_size())]
463 | dist.all_gather(all_visual_embs, visual_embs)
464 | dist.all_gather(all_last_embedding, last_embedding)
465 |           # Overwrite with embeddings produced on this replica, which have the gradient.
466 | all_visual_embs[dist.get_rank()] = visual_embs
467 | all_last_embedding[dist.get_rank()] = last_embedding
468 | visual_embs = torch.cat(all_visual_embs)
469 | last_embedding = torch.cat(all_last_embedding)
470 |
471 | start_idx = args.rank * images.shape[0]
472 | end_idx = start_idx + images.shape[0]
473 |
474 | logits_per_image = visual_embs @ last_embedding.t()
475 | logits_per_text = logits_per_image.t()
476 | if i == 0:
477 | print(f'Running contrastive loss over logits_per_text.shape = {logits_per_text.shape} and logits_per_image.shape = {logits_per_image.shape}')
478 |
479 | # Compute contrastive losses for retrieval.
480 | caption_loss = losses_utils.contrastive_loss(logits_per_text)
481 | image_loss = losses_utils.contrastive_loss(logits_per_image)
482 | caption_acc1, caption_acc5 = losses_utils.contrastive_acc(logits_per_text, topk=(1, 5))
483 | image_acc1, image_acc5 = losses_utils.contrastive_acc(logits_per_image, topk=(1, 5))
484 | loss += args.ret_loss_scale * (caption_loss + image_loss) / 2.0
485 | cont_losses.update(loss.item(), images.size(0))
486 |
487 | # measure accuracy and record loss
488 | top1_caption.update(caption_acc1[0], images.size(0))
489 | top5_caption.update(caption_acc5[0], images.size(0))
490 | top1_image.update(image_acc1[0], images.size(0))
491 | top5_image.update(image_acc5[0], images.size(0))
492 |
493 | if model_mode == 'retrieval':
494 | ret_time.update(time.time() - mode_start)
495 | elif model_mode == 'captioning':
496 | cap_time.update(time.time() - mode_start)
497 |
498 | loss = loss / args.grad_accumulation_steps
499 | losses.update(loss.item(), images.size(0))
500 | loss.backward()
501 |
502 | # Update weights
503 | if ((i + 1) % args.grad_accumulation_steps == 0) or (i == args.steps_per_epoch - 1):
504 | # Zero out gradients of the embedding matrix outside of [RET].
505 | for param in model.module.model.input_embeddings.parameters():
506 | assert param.grad.shape[0] == len(tokenizer)
507 | # Keep other embeddings frozen.
508 | mask = torch.arange(param.grad.shape[0]) != args.retrieval_token_idx
509 | param.grad[mask, :] = 0
510 |
511 |       # Clip gradients and take an optimizer step.
512 | if args.grad_clip > 0:
513 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
514 | optimizer.step()
515 | optimizer.zero_grad()
516 |
517 | with torch.no_grad():
518 |         # Rescale the trainable [RET] embedding so its norm matches the mean norm of the frozen embeddings.
519 | frozen_norm = torch.norm(model.module.model.input_embeddings.weight[:-1, :], dim=1).mean(0)
520 | trainable_weight = model.module.model.input_embeddings.weight[-1, :]
521 | model.module.model.input_embeddings.weight[-1, :].div_(torch.norm(trainable_weight) / frozen_norm)
522 |
523 | # measure elapsed time
524 | batch_time.update(time.time() - end)
525 | end = time.time()
526 |
527 | if actual_step == 1 or (i + 1) % args.print_freq == 0:
528 | ex_per_sec = args.batch_size / batch_time.avg
529 | if args.distributed:
530 | batch_time.all_reduce()
531 | data_time.all_reduce()
532 | ex_per_sec = (args.batch_size / batch_time.avg) * ngpus_per_node
533 |
534 | losses.all_reduce()
535 | ce_losses.all_reduce()
536 | top1.all_reduce()
537 | top5.all_reduce()
538 | ret_time.all_reduce()
539 | cont_losses.all_reduce()
540 | top1_caption.all_reduce()
541 | top5_caption.all_reduce()
542 | top1_image.all_reduce()
543 | top5_image.all_reduce()
544 | cap_time.all_reduce()
545 |
546 | progress.display(i + 1)
547 |
548 | writer.add_scalar('train/loss', losses.avg, actual_step)
549 | writer.add_scalar('train/ce_loss', ce_losses.avg, actual_step)
550 | writer.add_scalar('train/seq_top1_acc', top1.avg, actual_step)
551 | writer.add_scalar('train/seq_top5_acc', top5.avg, actual_step)
552 | writer.add_scalar('train/contrastive_loss', cont_losses.avg, actual_step)
553 | writer.add_scalar('train/t2i_top1_acc', top1_caption.avg, actual_step)
554 | writer.add_scalar('train/t2i_top5_acc', top5_caption.avg, actual_step)
555 | writer.add_scalar('train/i2t_top1_acc', top1_image.avg, actual_step)
556 | writer.add_scalar('train/i2t_top5_acc', top5_image.avg, actual_step)
557 | writer.add_scalar('metrics/total_secs_per_batch', batch_time.avg, actual_step)
558 | writer.add_scalar('metrics/total_secs_captioning', cap_time.avg, actual_step)
559 | writer.add_scalar('metrics/total_secs_retrieval', ret_time.avg, actual_step)
560 | writer.add_scalar('metrics/data_secs_per_batch', data_time.avg, actual_step)
561 | writer.add_scalar('metrics/examples_per_sec', ex_per_sec, actual_step)
562 |
563 | if not args.multiprocessing_distributed or (args.multiprocessing_distributed
564 | and args.rank % ngpus_per_node == 0):
565 | image_bs = images.shape[0]
566 | normalized_images = images - images.min()
567 | normalized_images /= normalized_images.max() # (N, 3, H, W)
568 | max_images_to_show = 16
569 |
570 |           # Decode the model's predicted caption tokens (offset by the visual-token prefix) for visualization.
571 | pred_tokens = output[:, args.n_visual_tokens-1:-1, :].argmax(dim=-1)
572 | generated_captions = tokenizer.batch_decode(pred_tokens, skip_special_tokens=False)
573 |
574 | # Log image (and generated caption) outputs to Tensorboard.
575 | if model_mode == 'captioning':
576 | # Create generated caption text.
577 | generated_cap_images = torch.stack([
578 | utils.create_image_of_text(
579 | generated_captions[i].encode('ascii', 'ignore'),
580 | width=normalized_images.shape[3],
581 | color=(255, 255, 0))
582 | for i in range(len(generated_captions))], axis=0)
583 |
584 | # Duplicate captions if we concatenated them.
585 | if (args.concat_captions_prob > 0 and model_mode == 'captioning' and generated_cap_images.shape[0] != caption_images.shape[0]):
586 | generated_cap_images = torch.cat([generated_cap_images, generated_cap_images], axis=0)
587 |
588 | display_images = torch.cat([normalized_images.float().cpu(), caption_images, generated_cap_images], axis=2)[:max_images_to_show]
589 | grid = torchvision.utils.make_grid(display_images, nrow=int(max_images_to_show ** 0.5), padding=4)
590 | writer.add_image('train/images_gen_cap', grid, actual_step)
591 |
592 | # Retrieved images (from text).
593 | retrieved_image_idx = logits_per_text[:image_bs, :image_bs].argmax(-1)
594 | t2i_images = torch.stack(
595 | [normalized_images[retrieved_image_idx[i], ...] for i in range(len(retrieved_image_idx))],
596 | axis=0)
597 | t2i_images = torch.cat([t2i_images.float().cpu(), caption_images], axis=2)[:max_images_to_show]
598 | t2i_grid = torchvision.utils.make_grid(t2i_images, nrow=int(max_images_to_show ** 0.5), padding=4)
599 | writer.add_image('train/t2i_ret', t2i_grid, actual_step)
600 |
601 | # Retrieved text (from image).
602 | retrieved_text_idx = logits_per_image[:image_bs, :image_bs].argmax(-1)
603 | retrieved_text = torch.stack(
604 | [caption_images[retrieved_text_idx[i], ...] for i in range(len(retrieved_text_idx))],
605 | axis=0)
606 | i2t_images = torch.cat([normalized_images.float().cpu(), retrieved_text], axis=2)[:max_images_to_show]
607 | i2t_grid = torchvision.utils.make_grid(i2t_images, nrow=int(max_images_to_show ** 0.5), padding=4)
608 | writer.add_image('train/i2t_ret', i2t_grid, actual_step)
609 |
610 | batch_time.reset()
611 | cap_time.reset()
612 | ret_time.reset()
613 | data_time.reset()
614 | losses.reset()
615 | ce_losses.reset()
616 | top1.reset()
617 | top5.reset()
618 | cont_losses.reset()
619 | top1_caption.reset()
620 | top5_caption.reset()
621 | top1_image.reset()
622 | top5_image.reset()
623 |
624 | if i == args.steps_per_epoch - 1:
625 | break
626 |
627 | scheduler.step()
628 | curr_lr = scheduler.get_last_lr()
629 | if (actual_step == 1) or (i + 1) % args.print_freq == 0:
630 | # Write current learning rate to Tensorboard.
631 | writer = SummaryWriter(args.log_dir)
632 | writer.add_scalar('train/lr', curr_lr[0], actual_step)
633 | writer.close()
634 |
635 | writer.close()
636 |
637 |
638 | if __name__ == '__main__':
639 | main(sys.argv[1:])
--------------------------------------------------------------------------------
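main.py can also be driven programmatically by passing a flag list to main.main(), which is how test_main.py exercises it. The values below are an illustrative, CPU-friendly smoke test (they assume the CC3M .tsv files and images are available under datasets/ and ./data/), not the settings used for the released model.

# Illustrative smoke-test invocation (flag values chosen for speed, not for reproducing the released checkpoint).
import main

main.main([
    '--opt-version', 'facebook/opt-125m',              # smallest OPT variant for quick iteration
    '--visual-model', 'openai/clip-vit-base-patch32',
    '-b', '2',
    '--epochs', '1',
    '--steps-per-epoch', '2',
    '--val-steps-per-epoch', '2',
    '--dataset_dir', 'datasets',
    '--image-dir', './data/',
    '--log-base-dir', './runs/',
    '--exp_name', 'smoke_test',
])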
/requirements.txt:
--------------------------------------------------------------------------------
1 | attrs==22.2.0
2 | certifi==2022.12.7
3 | charset-normalizer==3.0.1
4 | contourpy==1.0.7
5 | cycler==0.11.0
6 | einops==0.4.1
7 | exceptiongroup==1.1.0
8 | filelock==3.9.0
9 | fonttools==4.38.0
10 | huggingface-hub==0.12.0
11 | idna==3.4
12 | iniconfig==2.0.0
13 | kiwisolver==1.4.4
14 | matplotlib==3.6.3
15 | numpy==1.24.2
16 | packaging==23.0
17 | pandas==1.5.3
18 | Pillow==9.4.0
19 | pluggy==1.0.0
20 | pyparsing==3.0.9
21 | pytest==7.2.1
22 | python-dateutil==2.8.2
23 | PyYAML==6.0
24 | regex==2022.10.31
25 | requests==2.28.2
26 | six==1.16.0
27 | tensorboard==2.12.0
28 | tensorboard-data-server==0.7.0
29 | tensorboard-plugin-wit==1.8.1
30 | tokenizers==0.12.1
31 | tomli==2.0.1
32 | torch==1.11.0
33 | torchaudio==0.11.0
34 | torchmetrics==0.9.3
35 | torchvision==0.12.0
36 | tqdm==4.64.1
37 | transformers==4.21.3
38 | typing_extensions==4.4.0
39 | urllib3==1.26.14
40 | git+https://github.com/ildoonet/pytorch-gradual-warmup-lr.git
41 |
--------------------------------------------------------------------------------
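The last requirement is installed from GitHub because pytorch-gradual-warmup-lr is not published on PyPI. main.py composes its GradualWarmupScheduler with a StepLR schedule and steps it once per training iteration; the sketch below shows that composition in isolation, with a placeholder parameter and hyperparameter values standing in for the real model and command-line flags.

# Sketch of the warmup + step-decay schedule used in main.py (placeholder parameter and counts).
import torch
from torch.optim.lr_scheduler import StepLR
from warmup_scheduler import GradualWarmupScheduler

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for the model's trainable parameters
optimizer = torch.optim.AdamW(params, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.0, eps=1e-8)

steps_per_epoch = 2000
scheduler_steplr = StepLR(optimizer, step_size=10 * steps_per_epoch, gamma=0.1)  # decay every 10 epochs' worth of steps
scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=100,   # linear warmup over the first 100 steps
                                   after_scheduler=scheduler_steplr)

for step in range(steps_per_epoch):
    optimizer.step()   # backward pass omitted in this sketch
    scheduler.step()   # stepped once per iteration, matching train()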
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/teaser.png
--------------------------------------------------------------------------------
/teaser_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kohjingyu/fromage/b36a1889e16cb9486e83e1853dce68ab653068c9/teaser_gif.gif
--------------------------------------------------------------------------------
/test_main.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | import tempfile
3 |
4 | import unittest
5 | import argparse
6 | import main
7 | import os
8 |
9 |
10 | def get_base_args():
11 | args = [
12 | '-b', '2', '--opt-version', 'facebook/opt-125m',
13 | '--val-steps-per-epoch', '2', '--epochs', '1', '--steps-per-epoch', '2',
14 | '--text-emb-layers', '-1', '--shared-emb-dim', '256',
15 | '--n-visual-tokens', '1', '--visual-model', 'openai/clip-vit-base-patch32', '--concat-captions-prob', '0.5']
16 | return args
17 |
18 | def check_workdir_outputs(workdir_path):
19 | workdir_content = os.listdir(workdir_path)
20 |     print('workdir content: %s' % workdir_content)
21 |
22 | assert 'ckpt.pth.tar' in workdir_content
23 | assert 'model_args.json' in workdir_content
24 | assert 'param_count.txt' in workdir_content
25 | assert 'git_info.txt' in workdir_content
26 | assert any(['events.out.tfevents' in fn for fn in workdir_content])
27 |
28 |
29 | class MultitaskTrainTest(unittest.TestCase):
30 | """Test captioning."""
31 | def test_train_and_evaluate(self):
32 | workdir = tempfile.mkdtemp()
33 | proj_root_dir = pathlib.Path(__file__).parents[0]
34 | exp_name = 'test_multitask'
35 |
36 | parser = argparse.ArgumentParser(description='Unit test parser')
37 | args = get_base_args() + ['--log-base-dir', workdir, '--exp_name', exp_name]
38 | main.main(args)
39 | check_workdir_outputs(os.path.join(workdir, exp_name))
40 |
41 |
42 | if __name__ == '__main__':
43 | unittest.main()
44 |
--------------------------------------------------------------------------------