├── .gitignore ├── LICENSE ├── README.md ├── README_zh-CN.md ├── app.py ├── checkpoints └── .gitkeep ├── configs ├── dataset_config.py └── lora_config.py ├── docs └── images │ └── demo_image.jpg ├── environment.yml ├── mmgpt ├── __init__.py ├── datasets │ ├── __init__.py │ ├── alpaca_gpt4_dataset.py │ ├── aokvqa_dataset.py │ ├── baize_dataset.py │ ├── builder.py │ ├── cc_sbu_align_dataset.py │ ├── clevr_dataset.py │ ├── coco_caption_dataset.py │ ├── dial_dataset.py │ ├── dolly_dataset.py │ ├── gqa_dataset.py │ ├── llava_dataset.py │ ├── nlvr_dataset.py │ ├── ocr_vqa_dataset.py │ ├── samplers │ │ ├── __init__.py │ │ └── infinite_sampler.py │ ├── snli_ve_datasets.py │ ├── text_ocr_dataset.py │ └── vqa_dataset.py ├── models │ ├── __init__.py │ ├── blip2 │ │ └── __init__.py │ ├── builder.py │ └── open_flamingo │ │ ├── __init__.py │ │ ├── builder.py │ │ ├── flamingo.py │ │ ├── flamingo_lm.py │ │ ├── helpers.py │ │ └── utils.py └── train │ ├── __init__.py │ ├── distributed.py │ ├── instruction_finetune.py │ └── train_utils.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | 3 | wandb/ 4 | 5 | checkpoints/ 6 | tests/ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Pycharm project settings 121 | .idea 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | *.out 139 | src/wandb 140 | wandb 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | # Training 146 | batchscript* 147 | work_dirs 148 | data -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018-2023 OpenMMLab. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 
24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. 
For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. 
If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. 
You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. 
(Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2018-2023 OpenMMLab. 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🤖 Multi-modal GPT 2 | 3 | Train a multi-modal chatbot with visual and language instructions! 4 | 5 | Based on the open-source multi-modal model [OpenFlamingo](https://github.com/mlfoundations/open_flamingo), we create various **visual instruction** data with open datasets, including VQA, Image Captioning, Visual Reasoning, Text OCR, and Visual Dialogue. Additionally, we also train the language model component of OpenFlamingo using only **language-only instruction** data. 6 | 7 | The **joint training** of visual and language instructions effectively improves the performance of the model! For more details please refer to our [technical report](https://arxiv.org/abs/2305.04790). 8 | 9 | Welcome to join us! 10 | 11 | 12 | 13 |
14 | 15 | English | [简体中文](README_zh-CN.md) 16 | 17 |
18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 |
38 | 39 | ## Features 40 | 41 | - Supports various vision and language instruction data 42 | - Parameter-efficient fine-tuning with LoRA 43 | - Tunes vision and language jointly so that they complement each other 44 | 45 | 46 | ## Installation 47 | 48 | To install the package in an existing environment, run 49 | 50 | ```bash 51 | git clone https://github.com/open-mmlab/Multimodal-GPT.git 52 | cd Multimodal-GPT 53 | pip install -r requirements.txt 54 | pip install -v -e . 55 | ``` 56 | 57 | or create a new conda environment 58 | 59 | ```bash 60 | conda env create -f environment.yml 61 | ``` 62 | 63 | 64 | ## Launch Demo Locally 65 | 66 | 1. Download the pre-trained weights. 67 | 68 | Use [this script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) to convert the LLaMA weights to the Hugging Face format. 69 | 70 | Download the OpenFlamingo pre-trained model from [openflamingo/OpenFlamingo-9B](https://huggingface.co/openflamingo/OpenFlamingo-9B). 71 | 72 | Download our LoRA weights from [here](https://download.openmmlab.com/mmgpt/v0/mmgpt-lora-v0-release.pt). 73 | 74 | Then place these models in the `checkpoints` folder as follows: 75 | 76 | ``` 77 | checkpoints 78 | ├── llama-7b_hf 79 | │ ├── config.json 80 | │ ├── pytorch_model-00001-of-00002.bin 81 | │ ├── ...... 82 | │ └── tokenizer.model 83 | ├── OpenFlamingo-9B 84 | │ └── checkpoint.pt 85 | └── mmgpt-lora-v0-release.pt 86 | ``` 87 | 2.
Launch the Gradio demo 88 | 89 | ```bash 90 | python app.py 91 | ``` 92 | 93 | ## Examples 94 | 95 | ### Recipe: 96 | ![image4](https://user-images.githubusercontent.com/12907710/234554562-8f3be88f-d563-47ba-97d9-ade8d47c46b0.png) 97 | 98 | ### Travel plan: 99 | ![image3](https://user-images.githubusercontent.com/12907710/234523464-80c4e3f0-f99f-4498-96ef-dc43ef89c64b.png) 100 | 101 | ### Movie: 102 | ![image2](https://user-images.githubusercontent.com/12907710/234523468-e11905a6-491f-4b87-934f-90da7d14d1c3.png) 103 | 104 | ### Famous person: 105 | ![image](https://user-images.githubusercontent.com/12907710/234523475-fd91f979-a344-4228-813f-6b55a1bc250f.png) 106 | 107 | 108 | ## Fine-tuning 109 | 110 | ### Prepare datasets 111 | 112 | 1. [A-OKVQA](https://allenai.org/project/a-okvqa/home) 113 | 114 | Download the annotations from [this link](https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz) and unzip them to `data/aokvqa/annotations`. 115 | 116 | It also requires images from the COCO dataset, which can be downloaded from [here](https://cocodataset.org/#home). 117 | 118 | 2. [COCO Caption](https://cs.stanford.edu/people/karpathy/deepimagesent/) 119 | 120 | Download from [this link](https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip) and unzip to `data/coco`. 121 | 122 | It also requires images from the COCO dataset, which can be downloaded from [here](https://cocodataset.org/#home). 123 | 124 | 3. [OCR VQA](https://ocr-vqa.github.io/) 125 | 126 | Download from [this link](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing) and place in `data/OCR_VQA/`. 127 | 128 | 4. [LLaVA](https://llava-vl.github.io/) 129 | 130 | Download from [liuhaotian/LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) and place in `data/llava/`. 131 | 132 | It also requires images from the COCO dataset, which can be downloaded from [here](https://cocodataset.org/#home). 133 | 134 | 5.
[MiniGPT-4](https://minigpt-4.github.io/) 135 | 136 | Download from [Vision-CAIR/cc_sbu_align](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align) and place it in `data/cc_sbu_align/`. 137 | 138 | 6. [Dolly 15k](https://www.databricks.com/blog/2023/03/24/hello-dolly-democratizing-magic-chatgpt-open-models.html) 139 | 140 | Download from [databricks/databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) and place it in `data/dolly/databricks-dolly-15k.jsonl`. 141 | 142 | 7. [Alpaca GPT4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) 143 | 144 | Download it from [this link](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/raw/main/data/alpaca_gpt4_data.json) and place it in `data/alpaca_gpt4/alpaca_gpt4_data.json`. 145 | 146 | You can also customize the data paths in [configs/dataset_config.py](configs/dataset_config.py). 147 | 148 | 8. [Baize](https://github.com/project-baize/baize-chatbot) 149 | 150 | Download it from [this link](https://github.com/project-baize/baize-chatbot/blob/main/data/quora_chat_data.json) and place it in `data/baize/quora_chat_data.json`.
151 | 152 | 153 | ## Start training 154 | 155 | ```bash 156 | torchrun --nproc_per_node=8 mmgpt/train/instruction_finetune.py \ 157 | --lm_path checkpoints/llama-7b_hf \ 158 | --tokenizer_path checkpoints/llama-7b_hf \ 159 | --pretrained_path checkpoints/OpenFlamingo-9B/checkpoint.pt \ 160 | --run_name train-my-gpt4 \ 161 | --learning_rate 1e-5 \ 162 | --lr_scheduler cosine \ 163 | --batch_size 1 \ 164 | --tuning_config configs/lora_config.py \ 165 | --dataset_config configs/dataset_config.py \ 166 | --report_to_wandb 167 | ``` 168 | 169 | 170 | ## Acknowledgements 171 | 172 | - [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) 173 | - [LAVIS](https://github.com/salesforce/LAVIS) 174 | - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) 175 | - [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) 176 | - [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main) 177 | - [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) 178 | 179 | If you find our project useful for your research and applications, please cite using this BibTeX: 180 | 181 | ```bibtex 182 | @misc{gong2023multimodalgpt, 183 | title={MultiModal-GPT: A Vision and Language Model for Dialogue with Humans}, 184 | author={Tao Gong and Chengqi Lyu and Shilong Zhang and Yudong Wang and Miao Zheng and Qian Zhao and Kuikun Liu and Wenwei Zhang and Ping Luo and Kai Chen}, 185 | year={2023}, 186 | eprint={2305.04790}, 187 | archivePrefix={arXiv}, 188 | primaryClass={cs.CV} 189 | } 190 | ``` 191 | -------------------------------------------------------------------------------- /README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # 🤖 Multi-modal GPT 2 | 3 | 使用视觉和语言指令训练一个多模态聊天机器人! 
4 | 5 | 基于开源多模态模型 [OpenFlamingo](https://github.com/mlfoundations/open_flamingo),我们使用公开数据集创建了各种**视觉指令**数据,包括视觉问答、图像字幕、视觉推理、文本 OCR 和视觉对话。此外,我们还使用仅包含**语言指令**数据的语言模型组件进行了训练。 6 | 7 | 视觉和语言指令的**联合训练**有效提高了模型的性能!更多细节请参阅我们的[技术报告](https://arxiv.org/abs/2305.04790)。 8 | 9 | 欢迎加入我们! 10 | 11 | 12 | 13 |
14 | 15 | [English](README.md) | 简体中文 16 | 17 |
18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 |
38 | 39 | ## 特性 40 | 41 | - 支持各种视觉和语言指令数据 42 | - 使用 LoRA 进行参数高效微调 43 | - 同时微调视觉和语言,使二者相互补充 44 | 45 | ## 安装 46 | 47 | 在一个已有环境中安装依赖包,运行以下指令 48 | 49 | ```bash 50 | git clone https://github.com/open-mmlab/Multimodal-GPT.git 51 | cd Multimodal-GPT 52 | pip install -r requirements.txt 53 | pip install -v -e . 54 | ``` 55 | 56 | 或者创建一个新的 conda 环境 57 | 58 | ```bash 59 | conda env create -f environment.yml 60 | ``` 61 | 62 | ## Demo 63 | 64 | 1. 下载预训练权重 65 | 66 | 使用[这个脚本](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py)把 LLaMA 权重转换成 HuggingFace 格式。 67 | 68 | 从 [openflamingo/OpenFlamingo-9B](https://huggingface.co/openflamingo/OpenFlamingo-9B) 下载 OpenFlamingo 预训练模型。 69 | 70 | 从[这个链接](https://download.openmmlab.com/mmgpt/v0/mmgpt-lora-v0-release.pt) 下载我们的 LoRA 权重。 71 | 72 | 然后把所有模型权重放到 `checkpoints` 文件夹下,目录结构如下: 73 | 74 | ``` 75 | checkpoints 76 | ├── llama-7b_hf 77 | │ ├── config.json 78 | │ ├── pytorch_model-00001-of-00002.bin 79 | │ ├── ...... 80 | │ └── tokenizer.model 81 | ├── OpenFlamingo-9B 82 | │ └── checkpoint.pt 83 | └── mmgpt-lora-v0-release.pt 84 | ``` 85 | 2. 启动 gradio demo 86 | 87 | ```bash 88 | python app.py 89 | ``` 90 | 91 | ## 示例 92 | 93 | ### 食谱: 94 | ![image4](https://user-images.githubusercontent.com/12907710/234554562-8f3be88f-d563-47ba-97d9-ade8d47c46b0.png) 95 | 96 | ### 旅行计划: 97 | ![image3](https://user-images.githubusercontent.com/12907710/234523464-80c4e3f0-f99f-4498-96ef-dc43ef89c64b.png) 98 | 99 | ### 电影: 100 | ![image2](https://user-images.githubusercontent.com/12907710/234523468-e11905a6-491f-4b87-934f-90da7d14d1c3.png) 101 | 102 | ### 名人: 103 | ![image](https://user-images.githubusercontent.com/12907710/234523475-fd91f979-a344-4228-813f-6b55a1bc250f.png) 104 | 105 | 106 | ## 微调 Fine-tuning 107 | 108 | ### 准备数据集 109 | 110 | 1.
[A-OKVQA](https://allenai.org/project/a-okvqa/home) 111 | 112 | 从[这个链接](https://prior-datasets.s3.us-east-2.amazonaws.com/aokvqa/aokvqa_v1p0.tar.gz)下载标注,解压到 `data/aokvqa/annotations` 路径下。 113 | 114 | 同时还需要 COCO 数据集的图像,可以从[这里](https://cocodataset.org/#home)下载。 115 | 116 | 2. [COCO Caption](https://cs.stanford.edu/people/karpathy/deepimagesent/) 117 | 118 | 从[这个链接](https://cs.stanford.edu/people/karpathy/deepimagesent/coco.zip)下载,解压到 `data/coco` 路径下。 119 | 120 | 同时还需要 COCO 数据集的图像,可以从[这里](https://cocodataset.org/#home)下载。 121 | 122 | 3. [OCR VQA](https://ocr-vqa.github.io/) 123 | 124 | 从 [这个链接](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing) 下载数据集,放到 `data/OCR_VQA/` 路径下。 125 | 126 | 4. [LLaVA](https://llava-vl.github.io/) 127 | 128 | 从 [liuhaotian/LLaVA-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) 下载数据集,放到 `data/llava/` 路径下。 129 | 130 | 同时还需要 COCO 数据集的图像,可以从[这里](https://cocodataset.org/#home)下载。 131 | 132 | 5. [MiniGPT-4](https://minigpt-4.github.io/) 133 | 134 | 从 [Vision-CAIR/cc_sbu_align](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align) 下载数据集,放到 `data/cc_sbu_align/` 路径下。 135 | 136 | 6. [Dolly 15k](https://www.databricks.com/blog/2023/03/24/hello-dolly-democratizing-magic-chatgpt-open-models.html) 137 | 138 | 从 [databricks/databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) 下载数据集,放到 `data/dolly/databricks-dolly-15k.jsonl` 路径下。 139 | 140 | 7.
[Alpaca GPT4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) 141 | 142 | 从[这个链接](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/raw/main/data/alpaca_gpt4_data.json) 下载数据集,放到 `data/alpaca_gpt4/alpaca_gpt4_data.json` 路径下。 143 | 144 | 你也可以在 [configs/dataset_config.py](configs/dataset_config.py) 文件中自定义数据集路径。 145 | 146 | 147 | ## 开启训练 148 | 149 | ```bash 150 | torchrun --nproc_per_node=8 mmgpt/train/instruction_finetune.py \ 151 | --lm_path checkpoints/llama-7b_hf \ 152 | --tokenizer_path checkpoints/llama-7b_hf \ 153 | --pretrained_path checkpoints/OpenFlamingo-9B/checkpoint.pt \ 154 | --run_name train-my-gpt4 \ 155 | --learning_rate 1e-5 \ 156 | --lr_scheduler cosine \ 157 | --batch_size 1 \ 158 | --tuning_config configs/lora_config.py \ 159 | --dataset_config configs/dataset_config.py \ 160 | --report_to_wandb 161 | ``` 162 | 163 | 164 | ## 致谢 165 | 166 | - [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) 167 | - [LAVIS](https://github.com/salesforce/LAVIS) 168 | - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) 169 | - [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) 170 | - [LLaVA](https://github.com/haotian-liu/LLaVA/tree/main) 171 | - [Instruction Tuning with GPT-4](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) 172 | 173 | 如果你觉得我们的项目对你的研究和应用有帮助,请用以下 BibTeX 进行引用 174 | 175 | ```bibtex 176 | @misc{gong2023multimodalgpt, 177 | title={MultiModal-GPT: A Vision and Language Model for Dialogue with Humans}, 178 | author={Tao Gong and Chengqi Lyu and Shilong Zhang and Yudong Wang and Miao Zheng and Qian Zhao and Kuikun Liu and Wenwei Zhang and Ping Luo and Kai Chen}, 179 | year={2023}, 180 | eprint={2305.04790}, 181 | archivePrefix={arXiv}, 182 | primaryClass={cs.CV} 183 | } 184 | ``` 185 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import gradio as gr 
4 | import torch 5 | from PIL import Image 6 | 7 | from mmgpt.models.builder import create_model_and_transforms 8 | 9 | TEMPLATE = "Below is an instruction that describes a task. Write a response that appropriately completes the request." 10 | response_split = "### Response:" 11 | 12 | 13 | class Inferencer: 14 | 15 | def __init__(self, finetune_path, llama_path, open_flamingo_path): 16 | ckpt = torch.load(finetune_path, map_location="cpu") 17 | if "model_state_dict" in ckpt: 18 | state_dict = ckpt["model_state_dict"] 19 | # remove the "module." prefix 20 | state_dict = { 21 | k[7:]: v 22 | for k, v in state_dict.items() if k.startswith("module.") 23 | } 24 | else: 25 | state_dict = ckpt 26 | tuning_config = ckpt.get("tuning_config") 27 | if tuning_config is None: 28 | print("tuning_config not found in checkpoint") 29 | else: 30 | print("tuning_config found in checkpoint: ", tuning_config) 31 | model, image_processor, tokenizer = create_model_and_transforms( 32 | model_name="open_flamingo", 33 | clip_vision_encoder_path="ViT-L-14", 34 | clip_vision_encoder_pretrained="openai", 35 | lang_encoder_path=llama_path, 36 | tokenizer_path=llama_path, 37 | pretrained_model_path=open_flamingo_path, 38 | tuning_config=tuning_config, 39 | ) 40 | model.load_state_dict(state_dict, strict=False) 41 | model.half() 42 | model = model.to("cuda") 43 | model.eval() 44 | tokenizer.padding_side = "left" 45 | tokenizer.add_eos_token = False 46 | self.model = model 47 | self.image_processor = image_processor 48 | self.tokenizer = tokenizer 49 | 50 | def __call__(self, prompt, imgpaths, max_new_token, num_beams, temperature, 51 | top_k, top_p, do_sample): 52 | if len(imgpaths) > 1: 53 | raise gr.Error( 54 | "Current only support one image, please clear gallery and upload one image" 55 | ) 56 | lang_x = self.tokenizer([prompt], return_tensors="pt") 57 | if len(imgpaths) == 0 or imgpaths is None: 58 | for layer in self.model.lang_encoder._get_decoder_layers(): 59 | 
layer.condition_only_lang_x(True) 60 | output_ids = self.model.lang_encoder.generate( 61 | input_ids=lang_x["input_ids"].cuda(), 62 | attention_mask=lang_x["attention_mask"].cuda(), 63 | max_new_tokens=max_new_token, 64 | num_beams=num_beams, 65 | temperature=temperature, 66 | top_k=top_k, 67 | top_p=top_p, 68 | do_sample=do_sample, 69 | )[0] 70 | for layer in self.model.lang_encoder._get_decoder_layers(): 71 | layer.condition_only_lang_x(False) 72 | else: 73 | images = (Image.open(fp) for fp in imgpaths) 74 | vision_x = [self.image_processor(im).unsqueeze(0) for im in images] 75 | vision_x = torch.cat(vision_x, dim=0) 76 | vision_x = vision_x.unsqueeze(1).unsqueeze(0).half() 77 | 78 | output_ids = self.model.generate( 79 | vision_x=vision_x.cuda(), 80 | lang_x=lang_x["input_ids"].cuda(), 81 | attention_mask=lang_x["attention_mask"].cuda(), 82 | max_new_tokens=max_new_token, 83 | num_beams=num_beams, 84 | temperature=temperature, 85 | top_k=top_k, 86 | top_p=top_p, 87 | do_sample=do_sample, 88 | )[0] 89 | generated_text = self.tokenizer.decode( 90 | output_ids, skip_special_tokens=True) 91 | # print(generated_text) 92 | result = generated_text.split(response_split)[-1].strip() 93 | return result 94 | 95 | 96 | class PromptGenerator: 97 | 98 | def __init__( 99 | self, 100 | prompt_template=TEMPLATE, 101 | ai_prefix="Response", 102 | user_prefix="Instruction", 103 | sep: str = "\n\n### ", 104 | buffer_size=0, 105 | ): 106 | self.all_history = list() 107 | self.ai_prefix = ai_prefix 108 | self.user_prefix = user_prefix 109 | self.buffer_size = buffer_size 110 | self.prompt_template = prompt_template 111 | self.sep = sep 112 | 113 | def add_message(self, role, message): 114 | self.all_history.append([role, message]) 115 | 116 | def get_images(self): 117 | img_list = list() 118 | if self.buffer_size > 0: 119 | all_history = self.all_history[-2 * (self.buffer_size + 1):] 120 | elif self.buffer_size == 0: 121 | all_history = self.all_history[-2:] 122 | else: 123 | 
all_history = self.all_history[:] 124 | for his in all_history: 125 | if type(his[-1]) == tuple: 126 | img_list.append(his[-1][-1]) 127 | return img_list 128 | 129 | def get_prompt(self): 130 | format_dict = dict() 131 | if "{user_prefix}" in self.prompt_template: 132 | format_dict["user_prefix"] = self.user_prefix 133 | if "{ai_prefix}" in self.prompt_template: 134 | format_dict["ai_prefix"] = self.ai_prefix 135 | prompt_template = self.prompt_template.format(**format_dict) 136 | ret = prompt_template 137 | if self.buffer_size > 0: 138 | all_history = self.all_history[-2 * (self.buffer_size + 1):] 139 | elif self.buffer_size == 0: 140 | all_history = self.all_history[-2:] 141 | else: 142 | all_history = self.all_history[:] 143 | context = [] 144 | have_image = False 145 | for role, message in all_history[::-1]: 146 | if message: 147 | if type(message) is tuple and message[ 148 | 1] is not None and not have_image: 149 | message, _ = message 150 | context.append(self.sep + "Image:\n" + self.sep + 151 | role + ":\n" + message) 152 | else: 153 | context.append(self.sep + role + ":\n" + message) 154 | else: 155 | context.append(self.sep + role + ":\n") 156 | 157 | ret += "".join(context[::-1]) 158 | return ret 159 | 160 | 161 | def to_gradio_chatbot(prompt_generator): 162 | ret = [] 163 | for i, (role, msg) in enumerate(prompt_generator.all_history): 164 | if i % 2 == 0: 165 | if type(msg) is tuple: 166 | import base64 167 | from io import BytesIO 168 | 169 | msg, image = msg 170 | if type(image) is str: 171 | from PIL import Image 172 | 173 | image = Image.open(image) 174 | max_hw, min_hw = max(image.size), min(image.size) 175 | aspect_ratio = max_hw / min_hw 176 | max_len, min_len = 800, 400 177 | shortest_edge = int( 178 | min(max_len / aspect_ratio, min_len, min_hw)) 179 | longest_edge = int(shortest_edge * aspect_ratio) 180 | H, W = image.size 181 | if H > W: 182 | H, W = longest_edge, shortest_edge 183 | else: 184 | H, W = shortest_edge, longest_edge 185 | image 
= image.resize((H, W)) 186 | # image = image.resize((224, 224)) 187 | buffered = BytesIO() 188 | image.save(buffered, format="JPEG") 189 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 190 | img_str = f'user upload image' 191 | msg = msg + img_str 192 | ret.append([msg, None]) 193 | else: 194 | ret[-1][-1] = msg 195 | return ret 196 | 197 | 198 | def bot( 199 | text, 200 | image, 201 | state, 202 | prompt, 203 | ai_prefix, 204 | user_prefix, 205 | seperator, 206 | history_buffer, 207 | max_new_token, 208 | num_beams, 209 | temperature, 210 | top_k, 211 | top_p, 212 | do_sample, 213 | ): 214 | state.prompt_template = prompt 215 | state.ai_prefix = ai_prefix 216 | state.user_prefix = user_prefix 217 | state.sep = seperator 218 | state.buffer_size = history_buffer 219 | if image: 220 | state.add_message(user_prefix, (text, image)) 221 | else: 222 | state.add_message(user_prefix, text) 223 | state.add_message(ai_prefix, None) 224 | inputs = state.get_prompt() 225 | image_paths = state.get_images()[-1:] 226 | 227 | inference_results = inferencer(inputs, image_paths, max_new_token, 228 | num_beams, temperature, top_k, top_p, 229 | do_sample) 230 | state.all_history[-1][-1] = inference_results 231 | memory_allocated = str(round(torch.cuda.memory_allocated() / 1024**3, 232 | 2)) + 'GB' 233 | return state, to_gradio_chatbot(state), "", None, inputs, memory_allocated 234 | 235 | 236 | def clear(state): 237 | state.all_history = [] 238 | return state, to_gradio_chatbot(state), "", None, "" 239 | 240 | 241 | title_markdown = (""" 242 | # 🤖 Multi-modal GPT 243 | [[Project]](https://github.com/open-mmlab/Multimodal-GPT.git)""") 244 | 245 | 246 | def build_conversation_demo(): 247 | with gr.Blocks(title="Multi-modal GPT") as demo: 248 | gr.Markdown(title_markdown) 249 | 250 | state = gr.State(PromptGenerator()) 251 | with gr.Row(): 252 | with gr.Column(scale=3): 253 | memory_allocated = gr.Textbox( 254 | value=init_memory, label="Memory") 255 | imagebox = 
gr.Image(type="filepath") 256 | # TODO config parameters 257 | with gr.Accordion( 258 | "Parameters", 259 | open=True, 260 | ): 261 | max_new_token_bar = gr.Slider( 262 | 0, 1024, 512, label="max_new_token", step=1) 263 | num_beams_bar = gr.Slider( 264 | 0.0, 10, 3, label="num_beams", step=1) 265 | temperature_bar = gr.Slider( 266 | 0.0, 1.0, 1.0, label="temperature", step=0.01) 267 | topk_bar = gr.Slider(0, 100, 20, label="top_k", step=1) 268 | topp_bar = gr.Slider(0, 1.0, 1.0, label="top_p", step=0.01) 269 | do_sample = gr.Checkbox(True, label="do_sample") 270 | with gr.Accordion( 271 | "Prompt", 272 | open=False, 273 | ): 274 | with gr.Row(): 275 | ai_prefix = gr.Text("Response", label="AI Prefix") 276 | user_prefix = gr.Text( 277 | "Instruction", label="User Prefix") 278 | seperator = gr.Text("\n\n### ", label="Seperator") 279 | history_buffer = gr.Slider( 280 | -1, 10, -1, label="History buffer", step=1) 281 | prompt = gr.Text(TEMPLATE, label="Prompt") 282 | model_inputs = gr.Textbox(label="Actual inputs for Model") 283 | 284 | with gr.Column(scale=6): 285 | with gr.Row(): 286 | with gr.Column(): 287 | chatbot = gr.Chatbot(elem_id="chatbot").style( 288 | height=750) 289 | with gr.Row(): 290 | with gr.Column(scale=8): 291 | textbox = gr.Textbox( 292 | show_label=False, 293 | placeholder="Enter text and press ENTER", 294 | ).style(container=False) 295 | submit_btn = gr.Button(value="Submit") 296 | clear_btn = gr.Button(value="🗑️ Clear history") 297 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 298 | gr.Examples( 299 | examples=[ 300 | [ 301 | f"{cur_dir}/docs/images/demo_image.jpg", 302 | "What is in this image?" 
303 | ], 304 | ], 305 | inputs=[imagebox, textbox], 306 | ) 307 | textbox.submit( 308 | bot, 309 | [ 310 | textbox, 311 | imagebox, 312 | state, 313 | prompt, 314 | ai_prefix, 315 | user_prefix, 316 | seperator, 317 | history_buffer, 318 | max_new_token_bar, 319 | num_beams_bar, 320 | temperature_bar, 321 | topk_bar, 322 | topp_bar, 323 | do_sample, 324 | ], 325 | [ 326 | state, chatbot, textbox, imagebox, model_inputs, 327 | memory_allocated 328 | ], 329 | ) 330 | submit_btn.click( 331 | bot, 332 | [ 333 | textbox, 334 | imagebox, 335 | state, 336 | prompt, 337 | ai_prefix, 338 | user_prefix, 339 | seperator, 340 | history_buffer, 341 | max_new_token_bar, 342 | num_beams_bar, 343 | temperature_bar, 344 | topk_bar, 345 | topp_bar, 346 | do_sample, 347 | ], 348 | [ 349 | state, chatbot, textbox, imagebox, model_inputs, 350 | memory_allocated 351 | ], 352 | ) 353 | clear_btn.click(clear, [state], 354 | [state, chatbot, textbox, imagebox, model_inputs]) 355 | return demo 356 | 357 | 358 | if __name__ == "__main__": 359 | llama_path = "checkpoints/llama-7b_hf" 360 | open_flamingo_path = "checkpoints/OpenFlamingo-9B/checkpoint.pt" 361 | finetune_path = "checkpoints/mmgpt-lora-v0-release.pt" 362 | 363 | inferencer = Inferencer( 364 | llama_path=llama_path, 365 | open_flamingo_path=open_flamingo_path, 366 | finetune_path=finetune_path) 367 | init_memory = str(round(torch.cuda.memory_allocated() / 1024**3, 2)) + 'GB' 368 | demo = build_conversation_demo() 369 | demo.queue(concurrency_count=3) 370 | IP = "0.0.0.0" 371 | PORT = 8997 372 | demo.launch(server_name=IP, server_port=PORT, share=True) 373 | -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/Multimodal-GPT/9c73e47ad6c339e828a44f164d1a2c5bff904747/checkpoints/.gitkeep 
-------------------------------------------------------------------------------- /configs/dataset_config.py: -------------------------------------------------------------------------------- 1 | visual_datasets = [ 2 | dict( 3 | type="llava", 4 | vis_root="data/coco/train2017", 5 | ann_paths=[ 6 | "data/llava/detail_23k.json", 7 | "data/llava/complex_reasoning_77k.json", 8 | ], 9 | ), 10 | dict( 11 | type="llava_dial", 12 | vis_root="data/coco/train2017", 13 | ann_paths=[ 14 | "data/llava/conversation_58k.json", 15 | ], 16 | ), 17 | dict( 18 | type="aokvqa", 19 | vis_root="data/coco/images", 20 | ann_paths=[ 21 | "data/aokvqa/annotations/aokvqa_v1p0_train.json", 22 | ], 23 | sample=5000, 24 | ), 25 | dict( 26 | type="minigpt4", 27 | vis_root="data/cc_sbu_align/image", 28 | ann_paths=[ 29 | "data/cc_sbu_align/filter_cap.json", 30 | ], 31 | ), 32 | dict( 33 | type="coco_caption", 34 | vis_root="data/coco", 35 | ann_paths=[ 36 | "data/coco/annotations/coco_karpathy_train_converted.json", 37 | "data/coco/annotations/coco_karpathy_val.json", 38 | ], 39 | sample=512, 40 | ), 41 | dict( 42 | type="ocr_vqa", 43 | vis_root="data/OCR_VQA/image", 44 | ann_paths=[ 45 | "data/OCR_VQA/downloaded_dataset.json", 46 | ], 47 | sample=512, 48 | ), 49 | ] 50 | 51 | language_datasets = [ 52 | dict( 53 | type="dolly", 54 | ann_path="data/dolly/databricks-dolly-15k.jsonl", 55 | ), 56 | dict( 57 | type="alpaca_gpt4", 58 | ann_path="data/alpaca_gpt4/alpaca_gpt4_data.json", 59 | ), 60 | dict( 61 | type="baize", 62 | ann_path="data/baize/quora_chat_data.json", 63 | ), 64 | ] 65 | -------------------------------------------------------------------------------- /configs/lora_config.py: -------------------------------------------------------------------------------- 1 | tuning_config = dict( 2 | lora=True, 3 | lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "to_q", "to_kv", "to_out", "ff.1", "ff.3"], 4 | lora_r=16, 5 | lora_alpha=16, 6 | lora_dropout=0.0, 7 | vis=True, 8 | 
unfrozen=[], 9 | ) 10 | -------------------------------------------------------------------------------- /docs/images/demo_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/Multimodal-GPT/9c73e47ad6c339e828a44f164d1a2c5bff904747/docs/images/demo_image.jpg -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mmgpt 2 | channels: 3 | - defaults 4 | dependencies: 5 | - python=3.9 6 | - conda-forge::openjdk 7 | - pip 8 | - pip: 9 | - -r requirements.txt 10 | - -e . 11 | -------------------------------------------------------------------------------- /mmgpt/__init__.py: -------------------------------------------------------------------------------- 1 | from .models.builder import create_model_and_transforms 2 | from .models.open_flamingo import Flamingo 3 | -------------------------------------------------------------------------------- /mmgpt/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_dataset # noqa: F401 2 | from .dial_dataset import DialDataset # noqa: F401 3 | from .samplers import InfiniteSampler # noqa: F401 4 | from .vqa_dataset import VQADataset # noqa: F401 5 | -------------------------------------------------------------------------------- /mmgpt/datasets/alpaca_gpt4_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from mmgpt.datasets.dolly_dataset import DollyDataset 4 | 5 | 6 | class AlpacaGPT4Dataset(DollyDataset): 7 | """ 8 | ```json 9 | [ 10 | { 11 | "instruction": "Identify the odd one out.", 12 | "input": "Twitter, Instagram, Telegram", 13 | "output": "The odd one out is Telegram. 
Twitter and Instagram are social media platforms mainly for sharing information, images and videos while Telegram is a cloud-based instant messaging and voice-over-IP service." 14 | }, 15 | ] 16 | """ 17 | 18 | def load_annotation(self, ann_path): 19 | self.annotation = json.load(open(ann_path, "r")) 20 | 21 | def process_text(self, ann): 22 | instruction = ann["instruction"] 23 | input = ann["input"] 24 | output = ann["output"] 25 | instruction = self.prompter(instruction=instruction, input=input) 26 | return dict(instruction=instruction, answer=output) 27 | -------------------------------------------------------------------------------- /mmgpt/datasets/aokvqa_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from .vqa_dataset import VQADataset 4 | 5 | REASON_QUESTIONS = [ 6 | "Why?", 7 | "Why is this?", 8 | "And why?", 9 | "What is the reason?", 10 | "And can you tell me why?", 11 | "Can you tell me why?", 12 | "Can you tell me the reason?", 13 | ] 14 | 15 | 16 | class AOKVQADataset(VQADataset): 17 | def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs): 18 | super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs) 19 | 20 | def process_text(self, ann): 21 | question = ann["question"] 22 | question = question + " " + random.choice(REASON_QUESTIONS) 23 | 24 | choices = ann["choices"] 25 | true_answer = choices[ann["correct_choice_idx"]] 26 | answer = "The answer is " + true_answer + ". 
Because " + " ".join(ann["rationales"]) 27 | 28 | is_option = random.random() < self.option_prob and len(choices) > 1 29 | if is_option: 30 | instruction = self.prompter(question, choices) 31 | else: 32 | instruction = self.prompter(question) 33 | 34 | instruction = self.prompter(question) 35 | return dict(instruction=instruction, answer=answer) 36 | 37 | 38 | def build_aokvqa_dataset( 39 | tokenizer, 40 | vis_processor, 41 | vis_root="data/coco/images", 42 | ann_paths=["data/aokvqa/annotations/aokvqa_v1p0_train.json"], 43 | sample_image=False, 44 | ): 45 | return AOKVQADataset( 46 | tokenizer=tokenizer, 47 | vis_processor=vis_processor, 48 | vis_root=vis_root, 49 | ann_paths=ann_paths, 50 | sample_image=sample_image, 51 | ) 52 | -------------------------------------------------------------------------------- /mmgpt/datasets/baize_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from mmgpt.datasets.dolly_dataset import DollyDataset 4 | 5 | 6 | TEMPLATE = { 7 | "description": "Template used by Alpaca-LoRA.", 8 | "prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n", 9 | "prompt_qa": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Instruction:\n{question}\n\n### Response:\n", 10 | "prompt_dial": "\n\n### Instruction:\n{question}\n\n### Response:\n", 11 | "response_split": "### Response:", 12 | } 13 | 14 | class LangDialPrompter: 15 | def __call__(self, question, options=None): 16 | if options: 17 | options = ", ".join(options) 18 | res = TEMPLATE["prompt_choice"].format(image="", question=question, options=options) 19 | else: 20 | res = TEMPLATE["prompt_dial"].format(question=question) 21 | return res 22 | 23 | def get_response(self, output: str) -> str: 24 | return output.split(TEMPLATE["response_split"])[-1].strip() 25 | 26 | class BaiZeDataset(DollyDataset): 27 | """ 28 | ```json 29 | [ 30 | { 31 | "instruction": "Identify the odd one out.", 32 | "input": "Twitter, Instagram, Telegram", 33 | "output": "The odd one out is Telegram. Twitter and Instagram are social media platforms mainly for sharing information, images and videos while Telegram is a cloud-based instant messaging and voice-over-IP service." 34 | }, 35 | ] 36 | """ 37 | def __init__(self, *args, **kwargs): 38 | super(BaiZeDataset, self).__init__(*args, **kwargs) 39 | self.prompter = LangDialPrompter() 40 | 41 | def load_annotation(self, ann_path): 42 | self.annotation = json.load(open(ann_path, "r")) 43 | 44 | def process_text(self, anns): 45 | # TODO remove this 46 | begin_string = "Below is an instruction that describes a task. Write a response that appropriately completes the request." 
47 | convs = anns['input'].split("[|Human|] ") 48 | conv_list = [] 49 | for conv_id, one_conv in enumerate(convs[1:-1]): 50 | question, answer = one_conv.split("[|AI|] ") 51 | question = question.replace("\n", "") 52 | answer = answer.replace("\n", "") 53 | instruction = self.prompter(question) 54 | if conv_id == 0: 55 | single_conv = dict(instruction=begin_string + instruction, answer=answer) 56 | else: 57 | single_conv = dict(instruction=instruction, answer=answer) 58 | conv_list.append(single_conv) 59 | return conv_list 60 | 61 | def __getitem__(self, index): 62 | ann = self.annotation[index] 63 | text_list = self.process_text(ann) 64 | res_list = [] 65 | for text in text_list: 66 | single_res = self.tokenize(text) 67 | single_res["instruction"] = text["instruction"] 68 | single_res["answer"] = text["answer"] 69 | res_list.append(single_res) 70 | 71 | input_ids = [] 72 | attention_mask = [] 73 | labels = [] 74 | instruction = [] 75 | answer = [] 76 | for res in res_list: 77 | input_ids.extend(res["input_ids"]) 78 | attention_mask.extend(res["attention_mask"]) 79 | labels.extend(res["labels"]) 80 | instruction.append(res["instruction"]) 81 | answer.append(res["answer"]) 82 | 83 | res = dict( 84 | input_ids=input_ids, attention_mask=attention_mask, labels=labels, instruction=instruction, answer=answer 85 | ) 86 | return res 87 | -------------------------------------------------------------------------------- /mmgpt/datasets/builder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .alpaca_gpt4_dataset import AlpacaGPT4Dataset # noqa: F401 5 | from .aokvqa_dataset import AOKVQADataset # noqa: F401 6 | from .cc_sbu_align_dataset import CcSbuAlignDataset # noqa: F401 7 | from .clevr_dataset import CLEVRDataset # noqa: F401 8 | from .coco_caption_dataset import COCOCaptionDataset # noqa: F401 9 | from .dial_dataset import DialDataset # noqa: F401 10 | from .dolly_dataset import 
DollyDataset # noqa: F401 11 | from .gqa_dataset import GQADataset # noqa: F401 12 | from .llava_dataset import LlavaDataset # noqa: F401 13 | from .nlvr_dataset import NLVRv1Dataset, NLVRv2Dataset # noqa: F401 14 | from .ocr_vqa_dataset import OCRVQADataset # noqa: F401 15 | from .snli_ve_datasets import SNLIVEDataset # noqa: F401 16 | from .text_ocr_dataset import TextOCRDataset # noqa: F401 17 | from .vqa_dataset import ConcatDataset, VQADataset # noqa: F401 18 | from .baize_dataset import BaiZeDataset # noqa: F401 19 | 20 | 21 | def build_dataset(dataset_config, **kwargs): 22 | if isinstance(dataset_config, list): 23 | datasets = [build_dataset(cfg, **kwargs) for cfg in dataset_config] 24 | return ConcatDataset(datasets) 25 | dataset_type = dataset_config.pop("type") 26 | sample = dataset_config.pop("sample", -1) 27 | if dataset_type == "llava": 28 | dataset = LlavaDataset( 29 | **dataset_config, 30 | **kwargs, 31 | ) 32 | elif dataset_type == "vqa": 33 | dataset = VQADataset( 34 | **dataset_config, 35 | **kwargs, 36 | ) 37 | elif dataset_type == "minigpt4": 38 | dataset = CcSbuAlignDataset( 39 | **dataset_config, 40 | **kwargs, 41 | ) 42 | elif dataset_type == "llava_dial": 43 | dataset = DialDataset( 44 | **dataset_config, 45 | **kwargs, 46 | ) 47 | elif dataset_type == "coco_dial": 48 | dataset = DialDataset( 49 | **dataset_config, 50 | **kwargs, 51 | ) 52 | elif dataset_type == "aokvqa": 53 | dataset = AOKVQADataset( 54 | **dataset_config, 55 | **kwargs, 56 | ) 57 | elif dataset_type == "okvqa": 58 | dataset = VQADataset( 59 | **dataset_config, 60 | **kwargs, 61 | ) 62 | elif dataset_type == "text_ocr": 63 | dataset = TextOCRDataset( 64 | **dataset_config, 65 | **kwargs, 66 | ) 67 | elif dataset_type == "ocr_vqa": 68 | dataset = OCRVQADataset( 69 | **dataset_config, 70 | **kwargs, 71 | ) 72 | elif dataset_type == "coco_caption": 73 | dataset = COCOCaptionDataset( 74 | **dataset_config, 75 | **kwargs, 76 | ) 77 | elif dataset_type == "gqa": 78 | dataset = 
GQADataset( 79 | **dataset_config, 80 | **kwargs, 81 | ) 82 | elif dataset_type == "clevr": 83 | dataset = CLEVRDataset( 84 | **dataset_config, 85 | **kwargs, 86 | ) 87 | elif dataset_type == "nlvrv1": 88 | dataset = NLVRv1Dataset( 89 | **dataset_config, 90 | **kwargs, 91 | ) 92 | elif dataset_type == "nlvrv2": 93 | dataset = NLVRv2Dataset( 94 | **dataset_config, 95 | **kwargs, 96 | ) 97 | elif dataset_type == "snlive": 98 | dataset = SNLIVEDataset( 99 | **dataset_config, 100 | **kwargs, 101 | ) 102 | elif dataset_type == "dolly": 103 | dataset = DollyDataset( 104 | **dataset_config, 105 | **kwargs, 106 | ) 107 | elif dataset_type == "alpaca_gpt4": 108 | dataset = AlpacaGPT4Dataset( 109 | **dataset_config, 110 | **kwargs, 111 | ) 112 | elif dataset_type == "baize": 113 | dataset = BaiZeDataset( 114 | **dataset_config, 115 | **kwargs, 116 | ) 117 | else: 118 | raise NotImplementedError 119 | 120 | if sample > 0: 121 | random_indices = np.random.choice(len(dataset), min(sample, len(dataset)), replace=False) 122 | subsample_dataset = torch.utils.data.Subset(dataset, random_indices) 123 | subsample_dataset.collater = dataset.collater 124 | return subsample_dataset 125 | else: 126 | return dataset 127 | -------------------------------------------------------------------------------- /mmgpt/datasets/cc_sbu_align_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | from PIL import Image 6 | 7 | from .vqa_dataset import VQADataset, VQAPrompter 8 | 9 | QUESTIONS = [ 10 | "please describe the image", 11 | "can you describe the image", 12 | "Could you provide a description of the image?", 13 | "What do you see in this image?", 14 | "Share your thoughts on the content of the image.", 15 | "Please narrate what's happening in the picture.", 16 | "Can you give a brief explanation of the image?", 17 | "Describe the main elements and details present in the image.", 18 | "In your own 
words, what is depicted in the image?", 19 | "Can you outline the key aspects of the image?", 20 | "What are the most striking features in this image?", 21 | "Please provide a summary of the image's content.", 22 | "Describe the overall theme or concept captured in the image.", 23 | "How would you explain the image's composition and focus?", 24 | "What is the focal point or main subject of the image?", 25 | "How do the different components of the image interact with each other?", 26 | "What would be a fitting caption for this image?", 27 | "Can you create a concise description that captures the essence of the image?", 28 | "How would you briefly summarize the content of this image in a phrase or sentence?", 29 | "Please provide a catchy and relevant caption for this picture.", 30 | "If you were to give this image a title, what would it be?", 31 | "Describe the image in one creative sentence.", 32 | "Please suggest a memorable phrase that encapsulates the image's content.", 33 | "What engaging phrase would best represent this image?", 34 | "Can you create an expressive caption that highlights the main theme of the image?", 35 | "How would you sum up the image's story for a caption?", 36 | "Provide an eye-catching caption that conveys the image's core message.", 37 | "If you were to give this image a headline, what would it say?", 38 | "Can you craft a captivating caption that communicates the essence of the image?", 39 | "How would you describe the image's content in a powerful caption?", 40 | "Please provide an inventive title to summarize the scene depicted in the image.", 41 | "Compose a concise and striking phrase that reflects the image's key elements.", 42 | "If you were to create a caption for this image, what would it be?", 43 | "Offer a compelling caption that highlights the central focus of the image.", 44 | "Can you produce a unique caption that encapsulates the image's overall mood?", 45 | "Please generate an attention-grabbing caption that would best 
illustrate the events captured in this image", 46 | "How would you express the image's main idea in an impactful sentence?", 47 | "Please create a vivid and concise title that conveys the essence of the picture.", 48 | "Compose an imaginative caption that reflects the image's most striking features.", 49 | "What memorable statement would best represent the scene illustrated in this image?", 50 | "Draft an evocative caption that brings the image to life for the reader.", 51 | "Can you suggest an insightful caption that highlights the underlying message of the image?", 52 | "What engaging phrase would effectively convey the action or subject matter depicted in this picture?", 53 | "How would you encapsulate the image's core theme in a concise and expressive manner?", 54 | "Please provide a creative and impactful title that captures the spirit of the image.", 55 | "Craft a captivating caption that showcases the image's most prominent attributes.", 56 | "What intriguing statement would best sum up the scene presented in this image?", 57 | "Develop a descriptive caption that paints a vivid picture for the viewer.", 58 | "Can you give a detailed account of the image's contents?", 59 | "What are the key elements and features visible in this image?", 60 | "How would you narrate the events or actions depicted in the picture?", 61 | "Please share your observations about the various components present in the image.", 62 | "What is the overall theme or concept captured in this image? 
Can you describe it?", 63 | ] 64 | 65 | 66 | class CcSbuAlignDataset(VQADataset): 67 | def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, add_eos=True, ignore_instruction=True): 68 | self.tokenizer = tokenizer 69 | self.vis_root = vis_root 70 | 71 | self.annotation = [] 72 | for ann_path in ann_paths: 73 | self.annotation.extend(json.load(open(ann_path, "r"))["annotations"]) 74 | 75 | self.vis_processor = vis_processor 76 | self.prompter = VQAPrompter() 77 | self.add_eos = add_eos 78 | self.ignore_instruction = ignore_instruction 79 | 80 | def process_text(self, ann): 81 | # random select a question 82 | question = random.choice(QUESTIONS) 83 | answer = ann["caption"] 84 | instruction = self.prompter(question) 85 | return dict(instruction=instruction, answer=answer) 86 | 87 | def process_image(self, ann): 88 | image_path = os.path.join(self.vis_root, ann["image_id"] + ".jpg") 89 | image = Image.open(image_path).convert("RGB") 90 | 91 | image = self.vis_processor(image) 92 | return image 93 | 94 | 95 | def build_ccsbualign_dataset( 96 | tokenizer, 97 | vis_processor, 98 | vis_root="data/cc_sbu_align/image/", 99 | ann_paths=["data/cc_sbu_align/filter_cap.json"], 100 | **kwargs, 101 | ): 102 | return CcSbuAlignDataset( 103 | tokenizer=tokenizer, 104 | vis_processor=vis_processor, 105 | vis_root=vis_root, 106 | ann_paths=ann_paths, 107 | ) 108 | -------------------------------------------------------------------------------- /mmgpt/datasets/clevr_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from collections import defaultdict 5 | 6 | from PIL import Image 7 | 8 | from .vqa_dataset import VQADataset 9 | 10 | 11 | class CLEVRDataset(VQADataset): 12 | """Visual Reasoning Dataset. It also contains Dialog. 13 | 14 | Note: The image is a little bit simple. with several objects and simple background. 
15 | """ 16 | 17 | def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs): 18 | super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs) 19 | 20 | self.annotation = self.load_annotations(ann_paths) 21 | if self.sample_image: 22 | print("randomly sample one annotation for each image") 23 | self.annotation = self.parse_annotation(self.annotation) 24 | self._add_instance_ids() 25 | 26 | @staticmethod 27 | def load_annotations(ann_paths): 28 | annotation = [] 29 | for ann_path in ann_paths: 30 | ann = json.load(open(ann_path, "r")) 31 | annotation.extend(ann["questions"]) 32 | return annotation 33 | 34 | def parse_annotation(self, annotation): 35 | image_list = defaultdict(list) 36 | for ann in annotation: 37 | image_list[ann["image_filename"]].append(ann) 38 | annotation = [] 39 | for ann_list in image_list.values(): 40 | annotation.append(random.choice(ann_list)) 41 | return annotation 42 | 43 | def process_text(self, ann): 44 | question = ann["question"] 45 | answer = ann["answer"] 46 | instruction = self.prompter(question) 47 | return dict(instruction=instruction, answer=answer) 48 | 49 | def process_image(self, ann): 50 | split = ann["split"] 51 | image_path = os.path.join(self.vis_root, split, ann["image_filename"]) 52 | image = Image.open(image_path).convert("RGB") 53 | 54 | image = self.vis_processor(image) 55 | return image 56 | 57 | 58 | def build_clevr_dataset( 59 | tokenizer, 60 | vis_processor, 61 | vis_root="data/clevr/CLEVR_v1.0/images", 62 | ann_paths=[ 63 | "data/clevr/CLEVR_v1.0/questions/CLEVR_train_questions.json", 64 | "data/clevr/CLEVR_v1.0/questions/CLEVR_val_questions.json", 65 | ], 66 | sample_image=False, 67 | ): 68 | return CLEVRDataset( 69 | tokenizer=tokenizer, 70 | vis_processor=vis_processor, 71 | vis_root=vis_root, 72 | ann_paths=ann_paths, 73 | sample_image=sample_image, 74 | ) 75 | -------------------------------------------------------------------------------- 
# ==============================================================================
# /mmgpt/datasets/coco_caption_dataset.py
# ==============================================================================
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import json
import os
import random

import numpy as np
from PIL import Image
from transformers import LlamaTokenizer

from .vqa_dataset import VQADataset

QUESTIONS = [
    "please describe the image",
    "can you describe the image",
    "Could you provide a description of the image?",
    "What do you see in this image?",
    "Share your thoughts on the content of the image.",
    "Please narrate what's happening in the picture.",
    "Can you give a brief explanation of the image?",
    "Describe the main elements and details present in the image.",
    "In your own words, what is depicted in the image?",
    "Can you outline the key aspects of the image?",
    "What are the most striking features in this image?",
    "Please provide a summary of the image's content.",
    "Describe the overall theme or concept captured in the image.",
    "How would you explain the image's composition and focus?",
    "What is the focal point or main subject of the image?",
    "How do the different components of the image interact with each other?",
    "What would be a fitting caption for this image?",
    "Can you create a concise description that captures the essence of the image?",
    "How would you briefly summarize the content of this image in a phrase or sentence?",
    "Please provide a catchy and relevant caption for this picture.",
    "If you were to give this image a title, what would it be?",
    "Describe the image in one creative sentence.",
    "Please suggest a memorable phrase that encapsulates the image's content.",
    "What engaging phrase would best represent this image?",
    "Can you create an expressive caption that highlights the main theme of the image?",
    "How would you sum up the image's story for a caption?",
    "Provide an eye-catching caption that conveys the image's core message.",
    "If you were to give this image a headline, what would it say?",
    "Can you craft a captivating caption that communicates the essence of the image?",
    "How would you describe the image's content in a powerful caption?",
    "Please provide an inventive title to summarize the scene depicted in the image.",
    "Compose a concise and striking phrase that reflects the image's key elements.",
    "If you were to create a caption for this image, what would it be?",
    "Offer a compelling caption that highlights the central focus of the image.",
    "Can you produce a unique caption that encapsulates the image's overall mood?",
    "Please generate an attention-grabbing caption that would best illustrate the events captured in this image",
    "How would you express the image's main idea in an impactful sentence?",
    "Please create a vivid and concise title that conveys the essence of the picture.",
    "Compose an imaginative caption that reflects the image's most striking features.",
    "What memorable statement would best represent the scene illustrated in this image?",
    "Draft an evocative caption that brings the image to life for the reader.",
    "Can you suggest an insightful caption that highlights the underlying message of the image?",
    "What engaging phrase would effectively convey the action or subject matter depicted in this picture?",
    "How would you encapsulate the image's core theme in a concise and expressive manner?",
    "Please provide a creative and impactful title that captures the spirit of the image.",
    "Craft a captivating caption that showcases the image's most prominent attributes.",
    "What intriguing statement would best sum up the scene presented in this image?",
    "Develop a descriptive caption that paints a vivid picture for the viewer.",
    "Can you give a detailed account of the image's contents?",
    "What are the key elements and features visible in this image?",
    "How would you narrate the events or actions depicted in the picture?",
    "Please share your observations about the various components present in the image.",
    "What is the overall theme or concept captured in this image? Can you describe it?",
]


class COCOCaptionDataset(VQADataset):
    """Image-captioning samples: a random caption paired with a random instruction.

    This class does NOT call ``VQADataset.__init__``; it only sets the attributes
    assigned in its own ``__init__`` below.
    """

    def __init__(
        self, tokenizer, vis_processor=None, vis_root=None, ann_paths=(), add_eos=True, ignore_instruction=True
    ):
        """
        Args:
            tokenizer: tokenizer used later by ``tokenize`` (typed as ``LlamaTokenizer``).
            vis_processor: callable applied to each loaded PIL image.
            vis_root (str): root directory of images (e.g. coco/images/).
            ann_paths (Iterable[str]): JSON files, each holding a list of annotations
                with at least ``"image"`` and ``"caption"`` keys.
            add_eos (bool): stored for ``tokenize``.
            ignore_instruction (bool): stored for ``tokenize``.
        """
        self.tokenizer: LlamaTokenizer = tokenizer
        self.vis_root = vis_root

        self.annotation = []
        for ann_path in ann_paths:
            with open(ann_path, "r") as f:
                self.annotation.extend(json.load(f))

        self.vis_processor = vis_processor

        # Pre-render one full prompt per question template.
        self.instructions = [
            "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Response:\n".format(
                image="", question=question
            )
            for question in QUESTIONS
        ]
        self.add_eos = add_eos
        self.ignore_instruction = ignore_instruction

    def process_image(self, ann):
        """Load ``vis_root/ann["image"]`` as RGB and apply ``vis_processor``."""
        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")
        return self.vis_processor(image)

    def process_text(self, ann):
        """Return ``{"instruction", "answer"}`` with a random caption and prompt.

        ``ann["caption"]`` may be a single string or a list of strings.
        """
        all_captions = ann["caption"]
        if not isinstance(all_captions, list):
            all_captions = [all_captions]
        caption = random.choice(all_captions)
        instruction = random.choice(self.instructions)
        return dict(instruction=instruction, answer=caption)


# ==============================================================================
# /mmgpt/datasets/dial_dataset.py
# ==============================================================================
from .vqa_dataset import VQADataset

TEMPLATE = {
    "description": "Template used by Alpaca-LoRA.",
    "prompt_choice": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n",
    "prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Response:\n",
    # Per-turn prompt for dialogues; the shared header is prepended only to turn 0.
    "prompt_dial": "\n\n### Instruction:\n{question}\n\n### Response:\n",
    "response_split": "### Response:",
}


class DialPrompter:
    """Format dialogue turns with ``TEMPLATE`` and extract model responses."""

    def __call__(self, question, options=None):
        """Return the prompt for ``question``; use the multiple-choice form if ``options``."""
        if options:
            options = ", ".join(options)
            # NOTE(review): the image placeholder is filled with an empty string;
            # confirm whether a media token was intended here.
            return TEMPLATE["prompt_choice"].format(image="", question=question, options=options)
        return TEMPLATE["prompt_dial"].format(question=question)

    def get_response(self, output: str) -> str:
        """Return the text after the last ``### Response:`` marker, stripped."""
        return output.split(TEMPLATE["response_split"])[-1].strip()


class DialDataset(VQADataset):
    """Multi-turn dialogue dataset built from ``ann["conversations"]``.

    ``conversations`` is read as alternating (question, answer) entries, each a
    dict with a ``"value"`` key; a trailing unpaired entry is ignored.
    """

    def __init__(self, *args, **kwargs):
        super(DialDataset, self).__init__(*args, **kwargs)
        self.prompter = DialPrompter()

    def _add_instance_ids(self, key="id"):
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)

    def process_text(self, anns):
        """Return one ``{"instruction", "answer"}`` dict per dialogue turn."""
        # TODO remove this
        begin_string = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}".format(
            image=""
        )
        conversations = anns["conversations"]
        conv_list = []
        for turn in range(len(conversations) // 2):
            # Turn i is the pair (2*i, 2*i + 1); indexing by i and i + 1 would
            # pair an answer with the following question from turn 1 onwards.
            question = conversations[2 * turn]["value"]
            # remove tag and '\n'
            # NOTE(review): the tag literal is an empty string, which makes the
            # first replace() a no-op; a media tag was probably meant — confirm.
            question = question.replace("", "").replace("\n", "")
            answer = conversations[2 * turn + 1]["value"]
            instruction = self.prompter(question)
            if turn == 0:
                instruction = begin_string + instruction
            conv_list.append(dict(instruction=instruction, answer=answer))
        return conv_list

    def __getitem__(self, index):
        """Tokenize every turn and concatenate them into one training sample.

        ``instruction`` and ``answer`` in the result are lists with one string per turn.
        """
        ann = self.annotation[index]
        image = self.process_image(ann)

        input_ids, attention_mask, labels = [], [], []
        instruction, answer = [], []
        for text in self.process_text(ann):
            tokens = self.tokenize(text)
            input_ids.extend(tokens["input_ids"])
            attention_mask.extend(tokens["attention_mask"])
            labels.extend(tokens["labels"])
            # append, not extend: extending with a str splits it into characters
            instruction.append(text["instruction"])
            answer.append(text["answer"])

        res = dict(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels, instruction=instruction, answer=answer
        )
        res.update(image=image)
        return res


# ==============================================================================
# /mmgpt/datasets/dolly_dataset.py
# ==============================================================================
import copy
import json

import numpy as np
from torch.utils.data import Dataset
from transformers import LlamaTokenizer

TEMPLATE = {
    "description": "Template used by LLM.",
    "prompt_no_input_format": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
    "prompt_with_input_format": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "response_split": "### Response:",
}


class LMPrompter:
    """Alpaca-style text-only prompt builder."""

    def __call__(self, instruction, input=None):
        """Return the prompt; the ``### Input`` section is used only when ``input`` is non-empty."""
        if input is None or len(input) == 0:
            return TEMPLATE["prompt_no_input_format"].format(instruction=instruction)
        return TEMPLATE["prompt_with_input_format"].format(instruction=instruction, input=input)

    def get_response(self, output: str) -> str:
        """Return the text after the last ``### Response:`` marker, stripped."""
        return output.split(TEMPLATE["response_split"])[-1].strip()


class DollyDataset(Dataset):
    """Text-only instruction dataset read from a JSON-lines file.

    Each non-blank line is a JSON object with at least these fields::

        {
            "instruction": "What is a dispersive prism?",
            "context": "In optics, a dispersive prism is an optical prism ...",
            "response": "A dispersive prism is an optical prism that disperses ...",
            "category": "summarization"
        }
    """

    def __init__(self, tokenizer, ann_path: str, add_eos=True, ignore_instruction=True, max_length=512, **kwargs):
        """
        Args:
            tokenizer: tokenizer that must not append EOS on its own (asserted).
            ann_path (str): path to the JSON-lines annotation file.
            add_eos (bool): append the EOS token manually in ``tokenize``.
            ignore_instruction (bool): mask prompt tokens out of the labels with -100.
            max_length (int): truncation length for tokenization (default 512).
            **kwargs: accepted and ignored.
        """
        assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default"
        self.tokenizer: LlamaTokenizer = tokenizer

        self.annotation = []
        self.prompter = LMPrompter()
        self.add_eos = add_eos
        self.ignore_instruction = ignore_instruction
        self.max_length = max_length
        self.load_annotation(ann_path)

    def load_annotation(self, ann_path):
        """Read one JSON object per non-blank line of ``ann_path``."""
        self.annotation = []
        with open(ann_path, "r") as f:
            for line in f:
                line = line.strip()
                if line:  # tolerate blank (e.g. trailing) lines
                    self.annotation.append(json.loads(line))

    def __len__(self):
        return len(self.annotation)

    def process_text(self, ann):
        """Return ``{"instruction", "answer"}`` built from one annotation."""
        instruction = self.prompter(instruction=ann["instruction"], input=ann["context"])
        return dict(instruction=instruction, answer=ann["response"])

    def tokenize(self, text):
        """Tokenize ``instruction + answer`` and build ``labels``.

        EOS is appended when ``add_eos`` is set, the sequence is not already
        EOS-terminated, and it is shorter than ``max_length``.
        """
        max_length = self.max_length
        res = self.tokenizer(
            text["instruction"] + text["answer"],
            return_tensors=None,
            padding="do_not_pad",
            truncation=True,
            max_length=max_length,
        )

        # manually add eos token
        ids = res["input_ids"]
        if self.add_eos and (not ids or ids[-1] != self.tokenizer.eos_token_id) and len(ids) < max_length:
            res["input_ids"].append(self.tokenizer.eos_token_id)
            res["attention_mask"].append(1)
        labels = copy.deepcopy(res["input_ids"])
        # ignore instruction_token: only the response contributes to the loss
        if self.ignore_instruction:
            instruction_token = self.tokenizer(
                text["instruction"], return_tensors=None, padding="do_not_pad", truncation=True, max_length=max_length
            )
            n_prompt = len(instruction_token["input_ids"])
            labels = [-100] * n_prompt + labels[n_prompt:]

        res.update(labels=labels)
        return res

    def __getitem__(self, index):
        ann = self.annotation[index]
        text = self.process_text(ann)
        res = self.tokenize(text)
        res.update(text)
        return res

    def collater(self, samples):
        """Pad samples from ``__getitem__`` into a batch.

        Labels are padded with -100 on the tokenizer's padding side first, because
        ``tokenizer.pad`` needs them at equal length and does not pad them itself.
        Afterwards, labels equal to ``pad_token_id`` and the whole first column are
        set to -100 (presumably the BOS position under right padding — confirm).
        """
        question_list = [sample["instruction"] for sample in samples]
        answer_list = [sample["answer"] for sample in samples]
        input_id_list = [sample["input_ids"] for sample in samples]
        attention_mask_list = [sample["attention_mask"] for sample in samples]
        labels_list = [sample["labels"] for sample in samples]

        max_label_length = max(len(lab) for lab in labels_list)
        padding_side = self.tokenizer.padding_side
        padded_labels = []
        for lab in labels_list:
            remainder = [-100] * (max_label_length - len(lab))
            if isinstance(lab, list):
                lab = lab + remainder if padding_side == "right" else remainder + lab
            elif padding_side == "right":
                lab = np.concatenate([lab, remainder]).astype(np.int64)
            else:
                lab = np.concatenate([remainder, lab]).astype(np.int64)
            padded_labels.append(lab)

        padded_samples = self.tokenizer.pad(
            {"input_ids": input_id_list, "attention_mask": attention_mask_list, "labels": padded_labels},
            return_tensors="pt",
            padding="longest",
        )

        labels = padded_samples["labels"]
        labels[labels == self.tokenizer.pad_token_id] = -100
        labels[:, 0] = -100
        return {
            "input_ids": padded_samples["input_ids"],
            "attention_mask": padded_samples["attention_mask"],
            "labels": labels,
            "instruction": question_list,
            "answer": answer_list,
        }


def build_dolly_dataset(
    tokenizer,
    ann_path="data/dolly/databricks-dolly-15k.jsonl",
    **kwargs,
):
    """Construct a ``DollyDataset``; extra kwargs are forwarded unchanged."""
    return DollyDataset(
        tokenizer=tokenizer,
        ann_path=ann_path,
        **kwargs,
    )


# ==============================================================================
# /mmgpt/datasets/gqa_dataset.py
# ==============================================================================
import json
import os
import random
from collections import defaultdict

from PIL import Image

from .vqa_dataset import VQADataset


class GQADataset(VQADataset):
    """Visual Reasoning Dataset."""

    def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
        # The parent is given no annotation files; GQA's dict-shaped files are
        # loaded here instead.
        super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)

        self.annotation = self.load_annotations(ann_paths)
        if self.sample_image:
            print("randomly sample one annotation for each image")
            self.annotation = self.parse_annotation(self.annotation)
        self._add_instance_ids()
        # Probability of answering with ``fullAnswer`` instead of the short answer.
        self.answer_prob = 1.0

    @staticmethod
    def load_annotations(ann_paths):
        """Flatten ``{question_id: annotation}`` JSON files into a list.

        Each annotation gets its key stored under ``"question_id"``.
        """
        annotation = []
        for ann_path in ann_paths:
            with open(ann_path, "r") as f:
                ann = json.load(f)
            for k, v in ann.items():
                v["question_id"] = k
                annotation.append(v)
        return annotation

    def parse_annotation(self, annotation):
        """Keep one randomly chosen annotation per ``imageId``."""
        image_list = defaultdict(list)
        for ann in annotation:
            image_list[ann["imageId"]].append(ann)
        return [random.choice(ann_list) for ann_list in image_list.values()]

    def process_text(self, ann):
        """Return ``{"instruction", "answer"}``, choosing full or short answer by ``answer_prob``."""
        question = ann["question"]

        answer = ann["answer"]
        full_answer = ann["fullAnswer"]

        # TODO: check which one is better
        select_answer = full_answer if random.random() < self.answer_prob else answer

        instruction = self.prompter(question)
        return dict(instruction=instruction, answer=select_answer)

    def process_image(self, ann):
        """Load ``vis_root/<imageId>.jpg`` as RGB and apply ``vis_processor``."""
        image_path = os.path.join(self.vis_root, ann["imageId"] + ".jpg")
        image = Image.open(image_path).convert("RGB")
        return self.vis_processor(image)


def build_gqa_dataset(
    tokenizer,
    vis_processor,
    vis_root="data/gqa/images",
    ann_paths=None,
    sample_image=False,
):
    """Construct a ``GQADataset``.

    ``ann_paths`` defaults to the train_all_questions_0 and val_all_questions files.
    """
    if ann_paths is None:
        ann_paths = [
            "data/gqa/questions/train_all_questions/train_all_questions_0.json",
            "data/gqa/questions/val_all_questions.json",
        ]
    return GQADataset(
        tokenizer=tokenizer,
        vis_processor=vis_processor,
        vis_root=vis_root,
        ann_paths=ann_paths,
        sample_image=sample_image,
    )


# ==============================================================================
# /mmgpt/datasets/llava_dataset.py
# ==============================================================================
from .vqa_dataset import VQADataset


class LlavaDataset(VQADataset):
    """Single-turn samples taken from the first two entries of ``ann["conversations"]``."""

    def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
        super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs)

    def _add_instance_ids(self, key="id"):
        for idx, ann in enumerate(self.annotation):
            ann[key] = str(idx)

    def process_text(self, ann):
        """Return ``{"instruction", "answer"}`` from the first question/answer pair."""
        question = ann["conversations"][0]["value"]
        # remove tag and '\n'
        # NOTE(review): the tag literal is an empty string, which makes the first
        # replace() a no-op; a media tag was probably meant — confirm.
        question = question.replace("", "").replace("\n", "")
        answer = ann["conversations"][1]["value"]
        instruction = self.prompter(question)
        return dict(instruction=instruction, answer=answer)


# ==============================================================================
# /mmgpt/datasets/nlvr_dataset.py
# ==============================================================================
import copy
import json
import os
import random
from collections import defaultdict

import torch
import torch.nn.functional as F
from PIL import Image

from .vqa_dataset import VQADataset

# Each entry is a separate prompt: a missing comma would silently concatenate
# two adjacent literals into one.
QUESTIONS = [
    "Is this true?",
    "Is this right?",
    "Can you confirm this information?",
    "Do you agree with this statement?",
    "Does this align with your understanding?",
    "How do you interpret this information?",
    "Can you confirm this?",
    "Is this statement correct?",
    "Could you verify this information?",
    "Do you agree with this?",
    "Is this accurate?",
    "Can you validate this claim?",
    "Are these details valid?",
    "Is this factually correct?",
    "Is the following information correct?",
    "Could you please verify this fact?",
    "Do you agree with this assertion?",
    "Are these details accurate?",
    "Does this claim hold true?",
]


def _make_nlvr_text(dataset, ann):
    """Build the instruction/answer pair shared by both NLVR dataset versions.

    With probability ``dataset.option_prob`` the prompt lists "true"/"false" as
    options; the answer is ``ann["label"]`` unchanged.
    """
    question = ann["sentence"] + " " + random.choice(QUESTIONS)
    if random.random() < dataset.option_prob:
        instruction = dataset.prompter(question, ["true", "false"])
    else:
        instruction = dataset.prompter(question)
    return dict(instruction=instruction, answer=ann["label"])


class NLVRv1Dataset(VQADataset):
    """Visual Reasoning Dataset."""

    def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
        super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs)

        self.annotation = self.load_annotations(ann_paths)
        if self.sample_image:
            print("randomly sample one annotation for each image")
            self.annotation = self.parse_annotation(self.annotation)
        self._add_instance_ids()

    @staticmethod
    def load_annotations(ann_paths):
        """Read JSON-lines files, tagging each record with its split.

        The split ("train"/"dev"/"test") is inferred from the file name.

        Raises:
            ValueError: if a path matches none of the known split file names.
        """
        annotation = []
        for ann_path in ann_paths:
            if "train.json" in ann_path:
                split = "train"
            elif "dev.json" in ann_path:
                split = "dev"
            elif "test.json" in ann_path:
                split = "test"
            else:
                raise ValueError(f"Unknown split for {ann_path}")

            with open(ann_path, "r") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        ann = json.loads(line)
                        ann["split"] = split
                        annotation.append(ann)

        return annotation

    def parse_annotation(self, annotation):
        """Keep one randomly chosen annotation per (split, identifier) pair."""
        image_list = defaultdict(list)
        for ann in annotation:
            image_list[f"{ann['split']}-{ann['identifier']}"].append(ann)
        return [random.choice(ann_list) for ann_list in image_list.values()]

    def process_text(self, ann):
        return _make_nlvr_text(self, ann)

    def process_image(self, ann):
        """Load one of the six rendered images (index 0-5, chosen at random)."""
        # TODO: check whether using all 6 images?
        random_id = random.randint(0, 5)
        image_name = f"{ann['split']}-{ann['identifier']}-{random_id}.png"
        image_path = os.path.join(self.vis_root, ann["split"], "images", ann["directory"], image_name)
        image = Image.open(image_path).convert("RGB")
        return self.vis_processor(image)


class NLVRv2Dataset(VQADataset):
    """Visual Reasoning Dataset."""

    def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs):
        super().__init__(tokenizer, vis_processor, vis_root, ann_paths, **kwargs)
        # Probability that __getitem__ applies ``_flip`` to a sample.
        self.flip_prob = 0.5

    def parse_annotation(self, annotation):
        """Keep one randomly chosen annotation per first image."""
        image_list = defaultdict(list)
        for ann in annotation:
            image_list[ann["images"][0]].append(ann)
        return [random.choice(ann_list) for ann_list in image_list.values()]

    def process_text(self, ann):
        return _make_nlvr_text(self, ann)

    def process_image(self, ann):
        """Load and process both images of the pair; returns ``(image_0, image_1)``."""
        image_0_path = os.path.join(self.vis_root, ann["images"][0])
        image_1_path = os.path.join(self.vis_root, ann["images"][1])

        image_0 = self.vis_processor(Image.open(image_0_path).convert("RGB"))
        image_1 = self.vis_processor(Image.open(image_1_path).convert("RGB"))
        return image_0, image_1

    @staticmethod
    def _flip(samples):
        """Randomly swap the two images, keeping "left"/"right" wording consistent.

        If the sentence mentions neither word, the images are swapped with
        probability 0.5. Otherwise, with probability 0.5, "left" and "right" are
        exchanged in the sentence and the images are swapped together.
        """
        sentence = samples["sentence"]
        image0, image1 = samples["image0"], samples["image1"]

        if "left" not in sentence and "right" not in sentence:
            if random.random() < 0.5:
                image0, image1 = image1, image0
        else:
            if random.random() < 0.5:
                sentence = sentence.replace("left", "[TEMP_TOKEN]")
                sentence = sentence.replace("right", "left")
                sentence = sentence.replace("[TEMP_TOKEN]", "right")

                image0, image1 = image1, image0

        samples["sentence"] = sentence
        samples["image0"] = image0
        samples["image1"] = image1

        return samples

    def __getitem__(self, index):
        """Return a tokenized sample whose image is the two pair images side by side.

        The images are concatenated along dim 2 and resized back to the size of
        ``image_0``'s last two dims.
        """
        ann = copy.deepcopy(self.annotation[index])
        image_0, image_1 = self.process_image(ann)
        if random.random() < self.flip_prob:
            samples = self._flip({"sentence": ann["sentence"], "image0": image_0, "image1": image_1})
            image_0, image_1 = samples["image0"], samples["image1"]
            ann["sentence"] = samples["sentence"]
        # TODO: https://github.com/salesforce/LAVIS/blob/main/lavis/models/blip_models/blip_nlvr.py
        # model logic need update if using nlvr2
        image = torch.cat([image_0, image_1], dim=2)
        image = F.interpolate(image[None, ...], size=(image_0.shape[1], image_0.shape[2]))[0]
        text = self.process_text(ann)
        res = self.tokenize(text)
        res.update(image=image)
        res.update(text)
        return res


def build_nlvrv1_dataset(
    tokenizer,
    vis_processor,
    vis_root="data/nlvr",
    ann_paths=None,
    sample_image=False,
):
    """Construct an ``NLVRv1Dataset``; ``ann_paths`` defaults to the train split file."""
    if ann_paths is None:
        ann_paths = ["data/nlvr//train/train.json"]
    return NLVRv1Dataset(
        tokenizer=tokenizer,
        vis_processor=vis_processor,
        vis_root=vis_root,
        ann_paths=ann_paths,
        sample_image=sample_image,
    )


def build_nlvrv2_dataset(
    tokenizer,
    vis_processor,
    vis_root="data/nlvr2",
    ann_paths=None,
    sample_image=False,
):
    """Construct an ``NLVRv2Dataset``; ``ann_paths`` defaults to the train annotation file."""
    if ann_paths is None:
        ann_paths = ["data/nlvr2/annotations/nlvr_train.json"]
    return NLVRv2Dataset(
        tokenizer=tokenizer,
        vis_processor=vis_processor,
        vis_root=vis_root,
        ann_paths=ann_paths,
        sample_image=sample_image,
    )


# ==============================================================================
# /mmgpt/datasets/ocr_vqa_dataset.py
# ==============================================================================
import os
import random

from PIL import Image

from .vqa_dataset import VQADataset


class OCRVQADataset(VQADataset):
    """OCR-VQA: each annotation holds parallel ``questions``/``answers`` lists."""

    def process_image(self, ann):
        """Load ``vis_root/ann["filename"]`` as RGB and apply ``vis_processor``."""
        image_path = os.path.join(self.vis_root, ann["filename"])
        image = Image.open(image_path).convert("RGB")
        return self.vis_processor(image)

    def process_text(self, ann):
        """Pick one question/answer pair at random."""
        index = random.randrange(len(ann["questions"]))
        question = ann["questions"][index]
        answer = ann["answers"][index]

        instruction = self.prompter(question)
        return dict(instruction=instruction, answer=answer)


# ==============================================================================
# /mmgpt/datasets/samplers/__init__.py
# ==============================================================================
from .infinite_sampler import InfiniteSampler
# ------------------------------------------------------------------------------
/mmgpt/datasets/samplers/infinite_sampler.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | from torch.utils.data.sampler import Sampler 5 | 6 | from mmgpt.train.distributed import world_info_from_env 7 | 8 | 9 | class InfiniteSampler(Sampler): 10 | def __init__(self, dataset: int, shuffle: bool = True, seed: int = 0): 11 | self._size = len(dataset) 12 | self._shuffle = shuffle 13 | self._seed = int(seed) 14 | _, rank, world_size = world_info_from_env() 15 | 16 | self._rank = rank 17 | self._world_size = world_size 18 | 19 | def __iter__(self): 20 | start = self._rank 21 | yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) 22 | 23 | def _infinite_indices(self): 24 | g = torch.Generator() 25 | g.manual_seed(self._seed) 26 | while True: 27 | if self._shuffle: 28 | yield from torch.randperm(self._size, generator=g).tolist() 29 | else: 30 | yield from torch.arange(self._size).tolist() 31 | -------------------------------------------------------------------------------- /mmgpt/datasets/snli_ve_datasets.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from collections import defaultdict 5 | 6 | from PIL import Image 7 | 8 | from .vqa_dataset import VQADataset 9 | 10 | QUESTIONS = [ 11 | "What do you think of the above sentence?", 12 | "Can you confirm this statement?", 13 | "How do you interpret the given information?", 14 | "What is your opinion on this matter?", 15 | "Could you provide your perspective on this statement?", 16 | "How would you respond to the provided claim?", 17 | "What are your thoughts regarding the mentioned subject?", 18 | "Can you elaborate on this idea in English?", 19 | "Do you have any insights or feedback on this topic?", 20 | "What's your take on the given statement?", 21 | "What is your perspective on the given statement?", 22 | "How would you 
interpret this remark?", 23 | "Could you please provide your opinion on this?", 24 | "Can you share your understanding of the above point?", 25 | "Would you mind elaborating on this topic?", 26 | "What are your views about the given statement?", 27 | "How do you feel about the presented information?", 28 | "Could you provide your perspective on this?", 29 | "What is your opinion regarding this statement?", 30 | "Can you share your thoughts about the mentioned claim?", 31 | "How would you interpret the above comment?", 32 | "Would you mind sharing your insights on this issue?", 33 | ] 34 | 35 | 36 | class SNLIVEDataset(VQADataset): 37 | """Visual Reasoning Dataset.""" 38 | 39 | def __init__(self, tokenizer, vis_processor, vis_root, ann_paths, **kwargs): 40 | super().__init__(tokenizer, vis_processor, vis_root, ann_paths=[], **kwargs) 41 | 42 | self.annotation = self.load_annotations(ann_paths) 43 | if self.sample_image: 44 | print("randomly sample one annotation for each image") 45 | self.annotation = self.parse_annotation(self.annotation) 46 | self._add_instance_ids() 47 | 48 | @staticmethod 49 | def load_annotations(ann_paths): 50 | annotation = [] 51 | for ann_path in ann_paths: 52 | with open(ann_path, "r") as f: 53 | for line in f.readlines(): 54 | line = line.strip() 55 | if len(line) != 0: 56 | ann = json.loads(line) 57 | annotation.append(ann) 58 | return annotation 59 | 60 | def parse_annotation(self, annotation): 61 | image_list = defaultdict(list) 62 | for ann in annotation: 63 | image_list[ann["Flickr30K_ID"]].append(ann) 64 | annotation = [] 65 | for ann_list in image_list.values(): 66 | annotation.append(random.choice(ann_list)) 67 | return annotation 68 | 69 | def process_text(self, ann): 70 | question = ann["sentence2"] + " " + random.choice(QUESTIONS) 71 | answer = ann["gold_label"] 72 | if random.random() < self.option_prob: 73 | instruction = self.prompter(question, ["entailment", "neutral", "contradiction"]) 74 | else: 75 | instruction = 
self.prompter(question) 76 | return dict(instruction=instruction, answer=answer) 77 | 78 | def process_image(self, ann): 79 | image_path = os.path.join(self.vis_root, ann["Flickr30K_ID"] + ".jpg") 80 | image = Image.open(image_path).convert("RGB") 81 | image = self.vis_processor(image) 82 | return image 83 | -------------------------------------------------------------------------------- /mmgpt/datasets/text_ocr_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | from PIL import Image 7 | from transformers import LlamaTokenizer 8 | 9 | from .vqa_dataset import VQADataset, VQAPrompter 10 | 11 | 12 | class TextOCRDataset(VQADataset): 13 | def __init__( 14 | self, tokenizer, vis_processor=None, vis_root=None, ann_paths=[], add_eos=True, ignore_instruction=True 15 | ): 16 | """ 17 | vis_root (string): Root directory of images (e.g. coco/images/) 18 | ann_root (string): directory to store the annotation file 19 | """ 20 | assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default" 21 | self.tokenizer: LlamaTokenizer = tokenizer 22 | self.vis_root = vis_root 23 | 24 | self.annotation = [] 25 | for ann_path in ann_paths: 26 | self.annotation.extend(json.load(open(ann_path, "r"))["data"]) 27 | 28 | self.vis_processor = vis_processor 29 | 30 | self._add_instance_ids() 31 | self.option_prob = 0.5 32 | self.prompter = VQAPrompter() 33 | self.add_eos = add_eos 34 | self.ignore_instruction = ignore_instruction 35 | 36 | def process_image(self, ann): 37 | image_path = os.path.join(self.vis_root, ann["image_id"] + ".jpg") 38 | image = Image.open(image_path).convert("RGB") 39 | 40 | image = self.vis_processor(image) 41 | return image 42 | 43 | def process_text(self, ann): 44 | question = ann["question"] 45 | 46 | answer_weight = {} 47 | for answer in ann["answers"]: 48 | if answer in answer_weight.keys(): 49 | answer_weight[answer] += 
1 / len(ann["answers"]) 50 | else: 51 | answer_weight[answer] = 1 / len(ann["answers"]) 52 | 53 | answers = list(answer_weight.keys()) 54 | weights = list(answer_weight.values()) 55 | 56 | # create instruction 57 | true_answer = answers[np.argmax(weights)] 58 | is_option = random.random() < self.option_prob and len(answers) > 1 59 | if is_option: 60 | instruction = self.prompter(question, answers) 61 | else: 62 | instruction = self.prompter(question) 63 | 64 | return dict(instruction=instruction, answer=true_answer) 65 | -------------------------------------------------------------------------------- /mmgpt/datasets/vqa_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import copy 9 | import json 10 | import os 11 | import random 12 | from collections import defaultdict 13 | from typing import Iterable 14 | 15 | import numpy as np 16 | import torch 17 | from PIL import Image 18 | from torch.utils.data import ConcatDataset, Dataset 19 | from torch.utils.data.dataloader import default_collate 20 | from transformers import LlamaTokenizer 21 | 22 | TEMPLATE = { 23 | "description": "Template used by Alpaca-LoRA.", 24 | # "prompt_choice": "Below is a multiple choice question about an image, along with answer options. Please choose the correct answer from these options.\n\n### Image:\n{image}\n\n### Question:\n{question}\n\n### Input:\n{options}\n\n### Answer:\n", 25 | # "prompt_qa": "Below is a question about an image. Write a response to answer the question.\n\n### Image:\n{image}\n\n### Question:\n{question}\n\n### Answer:\n", 26 | "prompt_choice": "Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Input:\n{options}\n\n### Response:\n", 27 | "prompt_qa": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Image:\n{image}\n\n### Instruction:\n{question}\n\n### Response:\n", 28 | "response_split": "### Response:", 29 | } 30 | 31 | 32 | class VQAPrompter: 33 | def __call__(self, question, options=None): 34 | if options: 35 | options = ", ".join(options) 36 | res = TEMPLATE["prompt_choice"].format(image="", question=question, options=options) 37 | else: 38 | res = TEMPLATE["prompt_qa"].format(image="", question=question) 39 | return res 40 | 41 | def get_response(self, output: str) -> str: 42 | return output.split(TEMPLATE["response_split"])[-1].strip() 43 | 44 | 45 | class VQADataset(Dataset): 46 | def __init__( 47 | self, 48 | tokenizer, 49 | vis_processor=None, 50 | vis_root=None, 51 | ann_paths=[], 52 | add_eos=True, 53 | ignore_instruction=True, 54 | sample_image=False, 55 | ): 56 | """ 57 | vis_root (string): Root directory of images (e.g. 
coco/images/) 58 | ann_root (string): directory to store the annotation file 59 | """ 60 | assert tokenizer.add_eos_token is False, "tokenizer should not add eos token by default" 61 | self.tokenizer: LlamaTokenizer = tokenizer 62 | self.vis_root = vis_root 63 | 64 | self.annotation = [] 65 | for ann_path in ann_paths: 66 | self.annotation.extend(json.load(open(ann_path, "r"))) 67 | 68 | self.sample_image = sample_image 69 | if self.sample_image: 70 | print("randomly sample one annotation for each image") 71 | self.annotation = self.parse_annotation(self.annotation) 72 | 73 | self.vis_processor = vis_processor 74 | 75 | self._add_instance_ids() 76 | self.option_prob = 0.5 77 | self.prompter = VQAPrompter() 78 | self.add_eos = add_eos 79 | self.ignore_instruction = ignore_instruction 80 | 81 | def parse_annotation(self, annotation): 82 | image_list = defaultdict(list) 83 | for ann in annotation: 84 | image_list[ann["image"]].append(ann) 85 | # image_name_list = list(image_list.keys()) 86 | annotation = [] 87 | for ann_list in image_list.values(): 88 | annotation.append(random.choice(ann_list)) 89 | return annotation 90 | 91 | def __len__(self): 92 | return len(self.annotation) 93 | 94 | def _add_instance_ids(self, key="instance_id"): 95 | for idx, ann in enumerate(self.annotation): 96 | ann[key] = str(idx) 97 | 98 | def process_image(self, ann): 99 | image_path = os.path.join(self.vis_root, ann["image"]) 100 | image = Image.open(image_path).convert("RGB") 101 | 102 | image = self.vis_processor(image) 103 | return image 104 | 105 | def process_text(self, ann): 106 | question = ann["question"] 107 | 108 | answer_weight = {} 109 | for answer in ann["answer"]: 110 | if answer in answer_weight.keys(): 111 | answer_weight[answer] += 1 / len(ann["answer"]) 112 | else: 113 | answer_weight[answer] = 1 / len(ann["answer"]) 114 | 115 | answers = list(answer_weight.keys()) 116 | weights = list(answer_weight.values()) 117 | 118 | # create instruction 119 | true_answer = 
answers[np.argmax(weights)] 120 | is_option = random.random() < self.option_prob and len(answers) > 1 121 | if is_option: 122 | instruction = self.prompter(question, answers) 123 | else: 124 | instruction = self.prompter(question) 125 | 126 | return dict(instruction=instruction, answer=true_answer) 127 | 128 | def tokenize(self, text): 129 | res = self.tokenizer( 130 | text["instruction"] + text["answer"], 131 | return_tensors=None, 132 | padding="do_not_pad", 133 | truncation=True, 134 | max_length=512, 135 | ) 136 | 137 | # manually add eos token 138 | if res["input_ids"][-1] != self.tokenizer.eos_token_id and len(res["input_ids"]) < 512 and self.add_eos: 139 | res["input_ids"].append(self.tokenizer.eos_token_id) 140 | res["attention_mask"].append(1) 141 | labels = copy.deepcopy(res["input_ids"]) 142 | # ignore instruction_token 143 | if self.ignore_instruction: 144 | instruction_token = self.tokenizer( 145 | text["instruction"], return_tensors=None, padding="do_not_pad", truncation=True, max_length=512 146 | ) 147 | labels = [-100] * len(instruction_token["input_ids"]) + labels[len(instruction_token["input_ids"]) :] 148 | 149 | res.update(labels=labels) 150 | return res 151 | 152 | def __getitem__(self, index): 153 | ann = self.annotation[index] 154 | image = self.process_image(ann) 155 | text = self.process_text(ann) 156 | res = self.tokenize(text) 157 | res.update(image=image) 158 | res.update(text) 159 | return res 160 | 161 | def collater(self, samples): 162 | image_list, question_list, answer_list, input_id_list, attention_mask_list, labels_list = [], [], [], [], [], [] 163 | 164 | for sample in samples: 165 | image_list.append(sample["image"]) 166 | question_list.append(sample["instruction"]) 167 | answer_list.append(sample["answer"]) 168 | input_id_list.append(sample["input_ids"]) 169 | attention_mask_list.append(sample["attention_mask"]) 170 | labels_list.append(sample["labels"]) 171 | 172 | # We have to pad the labels before calling `tokenizer.pad` as 
this method won't pad them and needs them of the 173 | # same length to return tensors. 174 | max_label_length = max(len(l) for l in labels_list) 175 | padding_side = self.tokenizer.padding_side 176 | padded_labels = [] 177 | for l in labels_list: 178 | remainder = [-100] * (max_label_length - len(l)) 179 | if isinstance(l, list): 180 | l = l + remainder if padding_side == "right" else remainder + l 181 | elif padding_side == "right": 182 | l = np.concatenate([l, remainder]).astype(np.int64) 183 | else: 184 | l = np.concatenate([remainder, l]).astype(np.int64) 185 | padded_labels.append(l) 186 | 187 | padded_samples = self.tokenizer.pad( 188 | {"input_ids": input_id_list, "attention_mask": attention_mask_list, "labels": padded_labels}, 189 | return_tensors="pt", 190 | padding="longest", 191 | ) 192 | 193 | labels = padded_samples["labels"] 194 | media_token_id = self.tokenizer("", add_special_tokens=False)["input_ids"][-1] 195 | labels[labels == self.tokenizer.pad_token_id] = -100 196 | labels[:, 0] = -100 197 | labels[labels == media_token_id] = -100 198 | return { 199 | "image": torch.stack(image_list, dim=0), 200 | "input_ids": padded_samples["input_ids"], 201 | "attention_mask": padded_samples["attention_mask"], 202 | "labels": labels, 203 | "instruction": question_list, 204 | "answer": answer_list, 205 | } 206 | 207 | 208 | class ConcatDataset(ConcatDataset): 209 | def __init__(self, datasets: Iterable[Dataset]) -> None: 210 | super().__init__(datasets) 211 | 212 | def collater(self, samples): 213 | # TODO For now only supports datasets with same underlying collater implementations 214 | 215 | all_keys = set() 216 | for s in samples: 217 | all_keys.update(s) 218 | 219 | shared_keys = all_keys 220 | for s in samples: 221 | shared_keys = shared_keys & set(s.keys()) 222 | 223 | samples_shared_keys = [] 224 | for s in samples: 225 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys}) 226 | 227 | return 
self.datasets[0].collater(samples_shared_keys) 228 | -------------------------------------------------------------------------------- /mmgpt/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/Multimodal-GPT/9c73e47ad6c339e828a44f164d1a2c5bff904747/mmgpt/models/__init__.py -------------------------------------------------------------------------------- /mmgpt/models/blip2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/open-mmlab/Multimodal-GPT/9c73e47ad6c339e828a44f164d1a2c5bff904747/mmgpt/models/blip2/__init__.py -------------------------------------------------------------------------------- /mmgpt/models/builder.py: -------------------------------------------------------------------------------- 1 | from .open_flamingo import create_model_and_transforms as create_open_flamingo_model_and_transforms 2 | import torch.nn as nn 3 | from transformers import LlamaTokenizer, LlamaForCausalLM 4 | 5 | def create_model_and_transforms( 6 | model_name: str, 7 | clip_vision_encoder_path: str, 8 | clip_vision_encoder_pretrained: str, 9 | lang_encoder_path: str, 10 | tokenizer_path: str, 11 | tuning_config, 12 | pretrained_model_path, 13 | **kwargs, 14 | ): 15 | if model_name == "open_flamingo": 16 | return create_open_flamingo_model_and_transforms( 17 | clip_vision_encoder_path=clip_vision_encoder_path, 18 | clip_vision_encoder_pretrained=clip_vision_encoder_pretrained, 19 | lang_encoder_path=lang_encoder_path, 20 | tokenizer_path=tokenizer_path, 21 | tuning_config=tuning_config, 22 | pretrained_model_path=pretrained_model_path, 23 | **kwargs, 24 | ) 25 | # TODO: support BLIP2 26 | else: 27 | raise ValueError(f"Unknown model name: {model_name}") 28 | 29 | # only for debugging 30 | def create_toy_model_and_transforms( 31 | model_name: str, 32 | clip_vision_encoder_path: str, 33 | 
clip_vision_encoder_pretrained: str, 34 | lang_encoder_path: str, 35 | tokenizer_path: str, 36 | tuning_config, 37 | pretrained_model_path, 38 | **kwargs, 39 | ): 40 | print("init toy vision encoder") 41 | import torchvision 42 | 43 | image_processor = torchvision.transforms.Compose( 44 | [ 45 | torchvision.transforms.Resize((224, 224)), 46 | torchvision.transforms.ToTensor(), 47 | ] 48 | ) 49 | print("init tokenizer") 50 | text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path) 51 | # add Flamingo special tokens to the tokenizer 52 | text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", ""]}) 53 | if text_tokenizer.pad_token is None: 54 | # Issue: GPT models don't have a pad token, which we use to 55 | # modify labels for the loss. 56 | text_tokenizer.add_special_tokens({"pad_token": ""}) 57 | 58 | class ToyModel(nn.Module): 59 | def __init__(self, *args, **kwargs): 60 | super().__init__() 61 | self.input_embeddings = nn.Embedding(38000, 512) 62 | self.layer = nn.Linear(512, 512) 63 | self.config = {"hidden_size": 512} 64 | 65 | def forward(self, lang_x, **kwargs): 66 | x = self.input_embeddings(lang_x) 67 | x = self.layer(x) 68 | loss = x.sum() 69 | 70 | return (loss,) 71 | 72 | model = ToyModel() 73 | 74 | return model, image_processor, text_tokenizer 75 | -------------------------------------------------------------------------------- /mmgpt/models/open_flamingo/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import create_model_and_transforms 2 | from .flamingo import Flamingo 3 | from .flamingo_lm import FlamingoLMMixin 4 | -------------------------------------------------------------------------------- /mmgpt/models/open_flamingo/builder.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/mlfoundations/open_flamingo""" 2 | import open_clip 3 | import torch 4 | import torch.nn as nn 5 | 
from bigmodelvis import Visualization
from peft import LoraConfig, get_peft_model
from transformers import LlamaForCausalLM, LlamaTokenizer

from .flamingo import Flamingo
from .flamingo_lm import FlamingoLMMixin
from .utils import extend_instance


def create_model_and_transforms(
    clip_vision_encoder_path: str,
    clip_vision_encoder_pretrained: str,
    lang_encoder_path: str,
    tokenizer_path: str,
    decoder_layers_attr_name: str = None,
    pretrained_model_path: str = None,
    tuning_config=None,
    **flamingo_kwargs,
):
    """
    Initialize a Flamingo model from a pretrained vision encoder and language encoder.

    Appends the Flamingo special tokens to the tokenizer, freezes every parameter of the
    assembled model, then hands it to ``prepare_model_for_tuning`` which re-enables the
    parameters selected by ``tuning_config``.

    Args:
        clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
        clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
        lang_encoder_path (str): path to pretrained language encoder
        tokenizer_path (str): path to pretrained tokenizer
        decoder_layers_attr_name (str, optional): dotted name of the decoder layers attribute.
            Defaults to None, in which case it is inferred from the language model's class name.
        pretrained_model_path (str, optional): checkpoint whose state dict is loaded
            non-strictly into the assembled Flamingo model. Defaults to None (no loading).
        tuning_config: configuration consumed by ``prepare_model_for_tuning``. Required.
        **flamingo_kwargs: extra keyword arguments forwarded to ``Flamingo``.
    Returns:
        Flamingo: Flamingo model from pretrained vision and language encoders
        Image processor: Pipeline to preprocess input images
        Tokenizer: A tokenizer for the language model
    Raises:
        ValueError: if ``tuning_config`` is None (checked before any weights are loaded),
            or if ``decoder_layers_attr_name`` is None and cannot be inferred.
    """
    # Fail fast: a missing tuning config used to be rejected only after CLIP and LLaMA
    # had been fully loaded.
    if tuning_config is None:
        raise ValueError("tuning_config must be provided")

    print("init clip vision encoder")
    vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
        clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained
    )
    # set the vision encoder to output the visual features
    vision_encoder.visual.output_tokens = True

    print("init tokenizer")
    text_tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
    # add Flamingo special tokens to the tokenizer
    # NOTE(review): the second special token and the pad token below are empty strings in
    # this copy of the file; upstream OpenFlamingo uses "<image>" as the media token.
    # Confirm the literals were not stripped during extraction.
    text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", ""]})
    if text_tokenizer.pad_token is None:
        # Issue: GPT models don't have a pad token, which we use to
        # modify labels for the loss.
        text_tokenizer.add_special_tokens({"pad_token": ""})
    text_tokenizer.bos_token_id = 1
    text_tokenizer.eos_token_id = 2

    print("init llama")
    lang_encoder = LlamaForCausalLM.from_pretrained(lang_encoder_path)
    extend_instance(lang_encoder, FlamingoLMMixin)

    if decoder_layers_attr_name is None:
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
    lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
    # grow the embedding matrix to cover the special tokens added above
    lang_encoder.resize_token_embeddings(len(text_tokenizer))

    model = Flamingo(
        vision_encoder,
        lang_encoder,
        text_tokenizer.encode("<|endofchunk|>")[-1],
        # NOTE(review): with an empty media-token literal, encode() presumably returns only
        # the BOS id, which would make BOS the media token -- see the note above.
        text_tokenizer.encode("")[-1],
        vis_dim=open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"]["width"],
        cross_attn_every_n_layers=4,
        **flamingo_kwargs,
    )

    if pretrained_model_path is not None:
        print(f"loading pretrained model from {pretrained_model_path}")
        # Load onto the CPU: this function never moves the model to a device, and without
        # map_location the checkpoint would be restored onto whatever device saved it
        # (failing on CPU-only hosts, or costing an extra GPU copy before load_state_dict).
        state_dict = torch.load(pretrained_model_path, map_location="cpu")
        model.load_state_dict(state_dict, strict=False)

    # Freeze all parameters
    model.requires_grad_(False)
    assert sum(p.numel() for p in model.parameters() if p.requires_grad) == 0

    model = prepare_model_for_tuning(model, tuning_config)

    num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Flamingo model initialized with {num_trainable} trainable parameters")

    return model, image_processor, text_tokenizer


def _infer_decoder_layers_attr_name(model):
    """Return the dotted attribute path of ``model``'s decoder-layer ModuleList.

    Each key of ``__KNOWN_DECODER_LAYERS_ATTR_NAMES`` is tested, in dict order, as a
    substring of the lower-cased class name; the first match wins.

    Raises:
        ValueError: if no known key matches the class name.
    """
    class_name = model.__class__.__name__
    lowered = class_name.lower()
    for key, attr_name in __KNOWN_DECODER_LAYERS_ATTR_NAMES.items():
        if key.lower() in lowered:
            return attr_name

    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. "
        f"Please supply this string manually (could not infer it for {class_name})."
    )


# Substring of the (lower-cased) model class name -> dotted path of its decoder-layer list.
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptneo": "transformer.h",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
}


def prepare_model_for_tuning(model: nn.Module, config):
    """Make the selected parts of a frozen Flamingo model trainable.

    Args:
        model: Flamingo model whose parameters have all been frozen.
        config: tuning configuration; this function reads ``lora``, ``lora_r``,
            ``lora_alpha``, ``lora_target_modules``, ``lora_dropout``, ``unfrozen``
            and ``vis``.

    Returns:
        The same model object, with ``model.lang_encoder`` wrapped by ``get_peft_model``
        when ``config.lora`` is truthy.
    """
    if config.lora:
        lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",  # won't use bias currently
            modules_to_save=[],  # TODO: might be helpful if save partial model
            task_type="VL",
        )
        model.lang_encoder = get_peft_model(model.lang_encoder, peft_config=lora_config)

    # manually unfreeze modules, we use a `substring` fashion matching
    for name, param in model.named_parameters():
        if any(substr in name for substr in config.unfrozen):
            param.requires_grad = True

    # only draw the structure graph once in multi-process runs
    if config.vis and is_rank0():
        Visualization(model).structure_graph()
    return model
# temporary workaround, should use a common utils in the future 139 | def is_rank0(): 140 | if not torch.distributed.is_initialized(): 141 | return True 142 | return torch.distributed.get_rank() == 0 143 | -------------------------------------------------------------------------------- /mmgpt/models/open_flamingo/flamingo.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/mlfoundations/open_flamingo""" 2 | import torch 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | from .helpers import PerceiverResampler 7 | 8 | 9 | class Flamingo(nn.Module): 10 | def __init__( 11 | self, 12 | vision_encoder: nn.Module, 13 | lang_encoder: nn.Module, 14 | eoc_token_id: int, 15 | media_token_id: int, 16 | vis_dim: int, 17 | cross_attn_every_n_layers: int = 1, 18 | use_media_placement_augmentation: bool = False, 19 | ): 20 | """ 21 | Args: 22 | vision_encoder (nn.Module): HF CLIPModel 23 | lang_encoder (nn.Module): HF causal language model 24 | eoc_token_id (int): Token id for <|endofchunk|> 25 | media_token_id (int): Token id for 26 | vis_dim (int): Dimension of the visual features. 27 | Visual features are projected to match this shape along the last dimension. 28 | cross_attn_every_n_layers (int, optional): How often to apply cross attention after transformer layer. Defaults to 1. 29 | use_media_placement_augmentation (bool, optional): Whether to randomly assign images to the preceding or following text in training. Defaults to False. 
30 | """ 31 | super().__init__() 32 | self.eoc_token_id = eoc_token_id 33 | self.media_token_id = media_token_id 34 | self.use_media_placement_augmentation = use_media_placement_augmentation 35 | self.vis_dim = vis_dim 36 | self.vision_encoder = vision_encoder 37 | self.perceiver = PerceiverResampler(dim=self.vis_dim) 38 | self.lang_encoder = lang_encoder 39 | self.lang_encoder.init_flamingo( 40 | media_token_id=media_token_id, 41 | vis_hidden_size=self.vis_dim, 42 | cross_attn_every_n_layers=cross_attn_every_n_layers, 43 | use_media_placement_augmentation=self.use_media_placement_augmentation, 44 | ) 45 | 46 | def forward( 47 | self, 48 | vision_x: torch.Tensor, 49 | lang_x: torch.Tensor, 50 | attention_mask: torch.Tensor = None, 51 | labels: torch.Tensor = None, 52 | use_cached_vision_x: bool = False, 53 | clear_conditioned_layers: bool = True, 54 | past_key_values=None, 55 | use_cache: bool = False, 56 | ): 57 | """ 58 | Forward pass of Flamingo. 59 | 60 | Args: 61 | vision_x (torch.Tensor): Vision input 62 | shape (B, T_img, F, C, H, W) with F=1 63 | lang_x (torch.Tensor): Language input ids 64 | shape (B, T_txt) 65 | attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. 66 | labels (torch.Tensor, optional): Labels. Defaults to None. 67 | clear_conditioned_layers: if True, clear the conditioned layers 68 | once the foward pass is completed. Set this to false if the 69 | same set of images will be reused in another subsequent 70 | forward pass. 71 | past_key_values: pre-computed values to pass to language model. 72 | See past_key_values documentation in Hugging Face 73 | CausalLM models. 74 | use_cache: whether to use cached key values. See use_cache 75 | documentation in Hugging Face CausalLM models. 
76 | """ 77 | if vision_x is None and use_cached_vision_x is False: 78 | for layer in self.lang_encoder._get_decoder_layers(): 79 | layer.condition_only_lang_x(True) 80 | output = self.lang_encoder( 81 | input_ids=lang_x, 82 | attention_mask=attention_mask, 83 | labels=labels, 84 | past_key_values=past_key_values, 85 | use_cache=use_cache, 86 | ) 87 | for layer in self.lang_encoder._get_decoder_layers(): 88 | layer.condition_only_lang_x(False) 89 | return output 90 | assert ( 91 | vision_x is not None 92 | ) or use_cached_vision_x, "Must provide either vision_x or use_cached_vision_x to True." 93 | 94 | if use_cached_vision_x: 95 | # Case: use cached; vision_x should be cached and other 96 | # vision-related inputs should not be provided. 97 | assert vision_x is None, "Expect vision_x to be None when use_cached_vision_x is True." 98 | assert self.lang_encoder.is_conditioned() 99 | 100 | else: 101 | # Case: do not use caching (i.e. this is a standard forward pass); 102 | self._encode_vision_x(vision_x=vision_x) 103 | 104 | output = self.lang_encoder( 105 | input_ids=lang_x, 106 | attention_mask=attention_mask, 107 | labels=labels, 108 | past_key_values=past_key_values, 109 | use_cache=use_cache, 110 | ) 111 | 112 | if clear_conditioned_layers: 113 | self.lang_encoder.clear_conditioned_layers() 114 | 115 | return output 116 | 117 | def generate( 118 | self, 119 | vision_x: torch.Tensor, 120 | lang_x: torch.Tensor, 121 | attention_mask: torch.Tensor = None, 122 | num_beams=1, 123 | max_new_tokens=None, 124 | temperature=1.0, 125 | top_k=0, 126 | top_p=1.0, 127 | no_repeat_ngram_size=0, 128 | prefix_allowed_tokens_fn=None, 129 | length_penalty=1.0, 130 | num_return_sequences=1, 131 | do_sample=False, 132 | early_stopping=False, 133 | ): 134 | """ 135 | Generate text conditioned on vision and language inputs. 
136 | 137 | Args: 138 | vision_x (torch.Tensor): Vision input 139 | shape (B, T_img, F, C, H, W) 140 | images in the same chunk are collated along T_img, and frames are collated along F 141 | currently only F=1 is supported (single-frame videos) 142 | lang_x (torch.Tensor): Language input 143 | shape (B, T_txt) 144 | max_length (int, optional): Maximum length of the output. Defaults to None. 145 | attention_mask (torch.Tensor, optional): Attention mask. Defaults to None. 146 | num_beams (int, optional): Number of beams. Defaults to 1. 147 | max_new_tokens (int, optional): Maximum new tokens. Defaults to None. 148 | temperature (float, optional): Temperature. Defaults to 1.0. 149 | top_k (int, optional): Top k. Defaults to 0. 150 | top_p (float, optional): Top p. Defaults to 1.0. 151 | no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0. 152 | length_penalty (float, optional): Length penalty. Defaults to 1.0. 153 | num_return_sequences (int, optional): Number of return sequences. Defaults to 1. 154 | do_sample (bool, optional): Do sample. Defaults to False. 155 | early_stopping (bool, optional): Early stopping. Defaults to False. 
156 | Returns: 157 | torch.Tensor: lang_x with generated tokens appended to it 158 | """ 159 | if num_beams > 1: 160 | vision_x = vision_x.repeat_interleave(num_beams, dim=0) 161 | 162 | self._encode_vision_x(vision_x=vision_x) 163 | 164 | output = self.lang_encoder.generate( 165 | lang_x, 166 | attention_mask=attention_mask, 167 | # eos_token_id=self.eoc_token_id, 168 | num_beams=num_beams, 169 | max_new_tokens=max_new_tokens, 170 | temperature=temperature, 171 | top_k=top_k, 172 | top_p=top_p, 173 | prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, 174 | no_repeat_ngram_size=no_repeat_ngram_size, 175 | length_penalty=length_penalty, 176 | num_return_sequences=num_return_sequences, 177 | do_sample=do_sample, 178 | early_stopping=early_stopping, 179 | ) 180 | 181 | self.lang_encoder.clear_conditioned_layers() 182 | return output 183 | 184 | def _encode_vision_x(self, vision_x: torch.Tensor): 185 | """ 186 | Compute media tokens from vision input by passing it through vision encoder and conditioning language model. 
187 | Args: 188 | vision_x (torch.Tensor): Vision input 189 | shape (B, T_img, F, C, H, W) 190 | Images in the same chunk are collated along T_img, and frames are collated along F 191 | Currently only F=1 is supported (single-frame videos) 192 | 193 | rearrange code based on https://github.com/dhansmair/flamingo-mini 194 | """ 195 | 196 | assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)" 197 | b, T, F = vision_x.shape[:3] 198 | assert F == 1, "Only single frame supported" 199 | 200 | vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w") 201 | with torch.no_grad(): 202 | vision_x = self.vision_encoder.visual(vision_x)[1] 203 | vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) 204 | 205 | vision_x = self.perceiver(vision_x) # reshapes to (b, T, n, d) 206 | 207 | for layer in self.lang_encoder._get_decoder_layers(): 208 | layer.condition_vis_x(vision_x) 209 | -------------------------------------------------------------------------------- /mmgpt/models/open_flamingo/flamingo_lm.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/mlfoundations/open_flamingo""" 2 | import random 3 | 4 | import torch.nn as nn 5 | 6 | from .helpers import GatedCrossAttentionBlock 7 | from .utils import getattr_recursive, setattr_recursive 8 | 9 | 10 | class FlamingoLayer(nn.Module): 11 | def __init__(self, gated_cross_attn_layer, decoder_layer): 12 | super().__init__() 13 | self.gated_cross_attn_layer = gated_cross_attn_layer 14 | self.decoder_layer = decoder_layer 15 | self.vis_x = None 16 | self.media_locations = None 17 | self.only_lang_x = False 18 | 19 | def is_conditioned(self) -> bool: 20 | """Check whether the layer is conditioned.""" 21 | return self.vis_x is not None 22 | 23 | # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/) 24 | def condition_vis_x(self, vis_x): 25 | self.vis_x = vis_x 26 
| 27 | def condition_only_lang_x(self, only_lang_x=False): 28 | self.only_lang_x = only_lang_x 29 | 30 | def condition_media_locations(self, media_locations): 31 | self.media_locations = media_locations 32 | 33 | def condition_attend_previous(self, attend_previous): 34 | self.attend_previous = attend_previous 35 | 36 | def forward( 37 | self, 38 | lang_x, 39 | attention_mask=None, 40 | **decoder_layer_kwargs, 41 | ): 42 | if self.gated_cross_attn_layer is None or self.only_lang_x: 43 | return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs) 44 | 45 | if self.vis_x is None: 46 | raise ValueError("vis_x must be conditioned before forward pass") 47 | 48 | if self.media_locations is None: 49 | raise ValueError("media_locations must be conditioned before forward pass") 50 | 51 | lang_x = self.gated_cross_attn_layer( 52 | lang_x, 53 | self.vis_x, 54 | media_locations=self.media_locations, 55 | attend_previous=self.attend_previous, 56 | ) 57 | lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs) 58 | return lang_x 59 | 60 | 61 | class FlamingoLMMixin(nn.Module): 62 | """ 63 | Mixin to add cross-attention layers to a language model. 64 | """ 65 | 66 | def set_decoder_layers_attr_name(self, decoder_layers_attr_name): 67 | self.decoder_layers_attr_name = decoder_layers_attr_name 68 | 69 | def _get_decoder_layers(self): 70 | return getattr_recursive(self, self.decoder_layers_attr_name) 71 | 72 | def _set_decoder_layers(self, value): 73 | setattr_recursive(self, self.decoder_layers_attr_name, value) 74 | 75 | def init_flamingo( 76 | self, 77 | media_token_id, 78 | vis_hidden_size, 79 | cross_attn_every_n_layers, 80 | use_media_placement_augmentation, 81 | ): 82 | """ 83 | Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. 
84 | """ 85 | 86 | self.gated_cross_attn_layers = nn.ModuleList( 87 | [ 88 | GatedCrossAttentionBlock(dim=self.config.hidden_size, dim_visual=vis_hidden_size) 89 | if (layer_idx + 1) % cross_attn_every_n_layers == 0 90 | else None 91 | for layer_idx, _ in enumerate(self._get_decoder_layers()) 92 | ] 93 | ) 94 | self._set_decoder_layers( 95 | nn.ModuleList( 96 | [ 97 | FlamingoLayer(gated_cross_attn_layer, decoder_layer) 98 | for gated_cross_attn_layer, decoder_layer in zip( 99 | self.gated_cross_attn_layers, self._get_decoder_layers() 100 | ) 101 | ] 102 | ) 103 | ) 104 | self.media_token_id = media_token_id 105 | self.use_media_placement_augmentation = use_media_placement_augmentation 106 | self.initialized_flamingo = True 107 | 108 | def forward(self, *input, **kwargs): 109 | """Condition the Flamingo layers on the media locations before forward()""" 110 | if not self.initialized_flamingo: 111 | raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.") 112 | 113 | input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0] 114 | media_locations = input_ids == self.media_token_id 115 | attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else False 116 | 117 | for layer in self.get_decoder().layers: 118 | layer.condition_media_locations(media_locations) 119 | layer.condition_attend_previous(attend_previous) 120 | 121 | return super().forward(*input, **kwargs) # Call the other parent's forward method 122 | 123 | def is_conditioned(self) -> bool: 124 | """Check whether all decoder layers are already conditioned.""" 125 | return all(l.is_conditioned() for l in self._get_decoder_layers()) 126 | 127 | def clear_conditioned_layers(self): 128 | for layer in self._get_decoder_layers(): 129 | layer.condition_vis_x(None) 130 | layer.condition_media_locations(None) 131 | layer.condition_attend_previous(None) 132 | -------------------------------------------------------------------------------- 
/mmgpt/models/open_flamingo/helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/lucidrains/flamingo-pytorch 3 | """ 4 | 5 | import torch 6 | from einops import rearrange, repeat 7 | from einops_exts import rearrange_many 8 | from torch import einsum, nn 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def FeedForward(dim, mult=4): 16 | inner_dim = int(dim * mult) 17 | return nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, inner_dim, bias=False), 20 | nn.GELU(), 21 | nn.Linear(inner_dim, dim, bias=False), 22 | ) 23 | 24 | 25 | class PerceiverAttention(nn.Module): 26 | def __init__(self, *, dim, dim_head=64, heads=8): 27 | super().__init__() 28 | self.scale = dim_head**-0.5 29 | self.heads = heads 30 | inner_dim = dim_head * heads 31 | 32 | self.norm_media = nn.LayerNorm(dim) 33 | self.norm_latents = nn.LayerNorm(dim) 34 | 35 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 37 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 38 | 39 | def forward(self, x, latents): 40 | """ 41 | Args: 42 | x (torch.Tensor): image features 43 | shape (b, T, n1, D) 44 | latent (torch.Tensor): latent features 45 | shape (b, T, n2, D) 46 | """ 47 | x = self.norm_media(x) 48 | latents = self.norm_latents(latents) 49 | 50 | h = self.heads 51 | 52 | q = self.to_q(latents) 53 | kv_input = torch.cat((x, latents), dim=-2) 54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 55 | q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h) 56 | q = q * self.scale 57 | 58 | # attention 59 | sim = einsum("... i d, ... j d -> ... i j", q, k) 60 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 61 | attn = sim.softmax(dim=-1) 62 | 63 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 64 | out = rearrange(out, "b h t n d -> b t n (h d)", h=h) 65 | return self.to_out(out) 66 | 67 | 68 | class PerceiverResampler(nn.Module): 69 | def __init__( 70 | self, 71 | *, 72 | dim, 73 | depth=6, 74 | dim_head=64, 75 | heads=8, 76 | num_latents=64, 77 | max_num_media=None, 78 | max_num_frames=None, 79 | ff_mult=4, 80 | ): 81 | super().__init__() 82 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 83 | self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim)) if exists(max_num_frames) else None 84 | self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim)) if exists(max_num_media) else None 85 | 86 | self.layers = nn.ModuleList([]) 87 | for _ in range(depth): 88 | self.layers.append( 89 | nn.ModuleList( 90 | [ 91 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 92 | FeedForward(dim=dim, mult=ff_mult), 93 | ] 94 | ) 95 | ) 96 | 97 | self.norm = nn.LayerNorm(dim) 98 | 99 | def forward(self, x): 100 | """ 101 | Args: 102 | x (torch.Tensor): image features 103 | shape (b, T, F, v, D) 104 | Returns: 105 | shape (b, T, n, D) where n is self.num_latents 106 | """ 107 | b, T, F, v = x.shape[:4] 108 | 109 | # frame and media time embeddings 110 | if exists(self.frame_embs): 111 | frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v) 112 | x = x + frame_embs 113 | x = rearrange(x, "b T F v d -> b T (F v) d") # flatten the frame and spatial dimensions 114 | if exists(self.media_time_embs): 115 | x = x + self.media_time_embs[:T] 116 | 117 | # blocks 118 | latents = repeat(self.latents, "n d -> b T n d", b=b, T=T) 119 | for attn, ff in self.layers: 120 | latents = attn(x, latents) + latents 121 | latents = ff(latents) + latents 122 | return self.norm(latents) 123 | 124 | 125 | # gated cross attention 126 | 127 | 128 | class MaskedCrossAttention(nn.Module): 129 | def __init__( 130 | self, 131 | *, 132 | dim, 133 | dim_visual, 134 | dim_head=64, 135 | heads=8, 136 | 
only_attend_immediate_media=True, 137 | ): 138 | super().__init__() 139 | self.scale = dim_head**-0.5 140 | self.heads = heads 141 | inner_dim = dim_head * heads 142 | 143 | self.norm = nn.LayerNorm(dim) 144 | 145 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 146 | self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False) 147 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 148 | 149 | # whether for text to only attend to immediate preceding image, or all previous images 150 | self.only_attend_immediate_media = only_attend_immediate_media 151 | 152 | def forward(self, x, media, media_locations=None, attend_previous=True): 153 | """ 154 | Args: 155 | x (torch.Tensor): text features 156 | shape (B, T_txt, D_txt) 157 | media (torch.Tensor): image features 158 | shape (B, T_img, n, D_img) where n is the dim of the latents 159 | media_locations: boolean mask identifying the media tokens in x 160 | shape (B, T_txt) 161 | attend_previous: bool 162 | If false, ignores immediately preceding image and starts attending when following image 163 | """ 164 | _, T_img, n = media.shape[:3] 165 | h = self.heads 166 | 167 | x = self.norm(x) 168 | 169 | q = self.to_q(x) 170 | media = rearrange(media, "b t n d -> b (t n) d") 171 | 172 | k, v = self.to_kv(media).chunk(2, dim=-1) 173 | q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=h) 174 | 175 | q = q * self.scale 176 | 177 | sim = einsum("... i d, ... j d -> ... 
i j", q, k) 178 | 179 | if exists(media_locations): 180 | # at each boolean of True, increment the time counter (relative to media time) 181 | text_time = media_locations.cumsum(dim=-1) 182 | media_time = torch.arange(T_img, device=x.device) + 1 183 | 184 | if not attend_previous: 185 | text_time[~media_locations] += 1 186 | # make sure max is still the number of images in the sequence 187 | text_time[ 188 | text_time 189 | > repeat( 190 | torch.count_nonzero(media_locations, dim=1), 191 | "b -> b i", 192 | i=text_time.shape[1], 193 | ) 194 | ] = 0 195 | 196 | # text time must equal media time if only attending to most immediate image 197 | # otherwise, as long as text time is greater than media time (if attending to all previous images / media) 198 | mask_op = torch.eq if self.only_attend_immediate_media else torch.ge 199 | 200 | text_to_media_mask = mask_op( 201 | rearrange(text_time, "b i -> b 1 i 1"), 202 | repeat(media_time, "j -> 1 1 1 (j n)", n=n), 203 | ) 204 | sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max) 205 | 206 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 207 | attn = sim.softmax(dim=-1) 208 | 209 | if exists(media_locations) and self.only_attend_immediate_media: 210 | # any text without a preceding media needs to have attention zeroed out 211 | text_without_media_mask = text_time == 0 212 | text_without_media_mask = rearrange(text_without_media_mask, "b i -> b 1 i 1") 213 | attn = attn.masked_fill(text_without_media_mask, 0.0) 214 | 215 | out = einsum("... i j, ... j d -> ... 
i d", attn, v) 216 | out = rearrange(out, "b h n d -> b n (h d)") 217 | return self.to_out(out) 218 | 219 | 220 | class GatedCrossAttentionBlock(nn.Module): 221 | def __init__( 222 | self, 223 | *, 224 | dim, 225 | dim_visual, 226 | dim_head=64, 227 | heads=8, 228 | ff_mult=4, 229 | only_attend_immediate_media=True, 230 | ): 231 | super().__init__() 232 | self.attn = MaskedCrossAttention( 233 | dim=dim, 234 | dim_visual=dim_visual, 235 | dim_head=dim_head, 236 | heads=heads, 237 | only_attend_immediate_media=only_attend_immediate_media, 238 | ) 239 | self.attn_gate = nn.Parameter(torch.tensor([0.0])) 240 | 241 | self.ff = FeedForward(dim, mult=ff_mult) 242 | self.ff_gate = nn.Parameter(torch.tensor([0.0])) 243 | 244 | def forward( 245 | self, 246 | x, 247 | media, 248 | media_locations=None, 249 | attend_previous=True, 250 | ): 251 | x = ( 252 | self.attn( 253 | x, 254 | media, 255 | media_locations=media_locations, 256 | attend_previous=attend_previous, 257 | ) 258 | * self.attn_gate.tanh() 259 | + x 260 | ) 261 | x = self.ff(x) * self.ff_gate.tanh() + x 262 | 263 | return x 264 | -------------------------------------------------------------------------------- /mmgpt/models/open_flamingo/utils.py: -------------------------------------------------------------------------------- 1 | def extend_instance(obj, mixin): 2 | """Apply mixins to a class instance after creation""" 3 | base_cls = obj.__class__ 4 | base_cls_name = obj.__class__.__name__ 5 | obj.__class__ = type( 6 | base_cls_name, (mixin, base_cls), {} 7 | ) # mixin needs to go first for our forward() logic to work 8 | 9 | 10 | def getattr_recursive(obj, att): 11 | """ 12 | Return nested attribute of obj 13 | Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c 14 | """ 15 | if att == "": 16 | return obj 17 | i = att.find(".") 18 | if i < 0: 19 | return getattr(obj, att) 20 | else: 21 | return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :]) 22 | 23 | 24 | def setattr_recursive(obj, att, 
val): 25 | """ 26 | Set nested attribute of obj 27 | Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val 28 | """ 29 | if "." in att: 30 | obj = getattr_recursive(obj, ".".join(att.split(".")[:-1])) 31 | setattr(obj, att.split(".")[-1], val) 32 | -------------------------------------------------------------------------------- /mmgpt/train/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /mmgpt/train/distributed.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/mlfoundations/open_flamingo""" 2 | import os 3 | 4 | import torch 5 | 6 | try: 7 | import horovod.torch as hvd 8 | except ImportError: 9 | hvd = None 10 | 11 | 12 | def is_global_master(args): 13 | return args.rank == 0 14 | 15 | 16 | def is_local_master(args): 17 | return args.local_rank == 0 18 | 19 | 20 | def is_master(args, local=False): 21 | return is_local_master(args) if local else is_global_master(args) 22 | 23 | 24 | def is_using_horovod(): 25 | # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set 26 | # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required... 
27 | ompi_vars = ["OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"] 28 | pmi_vars = ["PMI_RANK", "PMI_SIZE"] 29 | if all([var in os.environ for var in ompi_vars]) or all([var in os.environ for var in pmi_vars]): 30 | return True 31 | else: 32 | return False 33 | 34 | 35 | def is_using_distributed(): 36 | if "WORLD_SIZE" in os.environ: 37 | return int(os.environ["WORLD_SIZE"]) > 1 38 | if "SLURM_NTASKS" in os.environ: 39 | return int(os.environ["SLURM_NTASKS"]) > 1 40 | return False 41 | 42 | 43 | def world_info_from_env(): 44 | local_rank = 0 45 | for v in ( 46 | "LOCAL_RANK", 47 | "MPI_LOCALRANKID", 48 | "SLURM_LOCALID", 49 | "OMPI_COMM_WORLD_LOCAL_RANK", 50 | ): 51 | if v in os.environ: 52 | local_rank = int(os.environ[v]) 53 | break 54 | global_rank = 0 55 | for v in ("RANK", "PMI_RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK"): 56 | if v in os.environ: 57 | global_rank = int(os.environ[v]) 58 | break 59 | world_size = 1 60 | for v in ("WORLD_SIZE", "PMI_SIZE", "SLURM_NTASKS", "OMPI_COMM_WORLD_SIZE"): 61 | if v in os.environ: 62 | world_size = int(os.environ[v]) 63 | break 64 | 65 | return local_rank, global_rank, world_size 66 | 67 | 68 | def init_distributed_device(args): 69 | # Distributed training = training on more than one GPU. 70 | # Works in both single and multi-node scenarios. 
71 | args.distributed = False 72 | args.world_size = 1 73 | args.rank = 0 # global rank 74 | args.local_rank = 0 75 | if args.horovod: 76 | assert hvd is not None, "Horovod is not installed" 77 | hvd.init() 78 | args.local_rank = int(hvd.local_rank()) 79 | args.rank = hvd.rank() 80 | args.world_size = hvd.size() 81 | args.distributed = True 82 | os.environ["LOCAL_RANK"] = str(args.local_rank) 83 | os.environ["RANK"] = str(args.rank) 84 | os.environ["WORLD_SIZE"] = str(args.world_size) 85 | elif is_using_distributed(): 86 | if "SLURM_PROCID" in os.environ: 87 | # DDP via SLURM 88 | args.local_rank, args.rank, args.world_size = world_info_from_env() 89 | # SLURM var -> torch.distributed vars in case needed 90 | os.environ["LOCAL_RANK"] = str(args.local_rank) 91 | os.environ["RANK"] = str(args.rank) 92 | os.environ["WORLD_SIZE"] = str(args.world_size) 93 | torch.distributed.init_process_group( 94 | backend=args.dist_backend, 95 | init_method=args.dist_url, 96 | world_size=args.world_size, 97 | rank=args.rank, 98 | ) 99 | else: 100 | # DDP via torchrun, torch.distributed.launch 101 | args.local_rank, _, _ = world_info_from_env() 102 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url) 103 | args.world_size = torch.distributed.get_world_size() 104 | args.rank = torch.distributed.get_rank() 105 | args.distributed = True 106 | else: 107 | # needed to run on single gpu 108 | torch.distributed.init_process_group( 109 | backend=args.dist_backend, 110 | init_method=args.dist_url, 111 | world_size=1, 112 | rank=0, 113 | ) 114 | 115 | if torch.cuda.is_available(): 116 | if args.distributed and not args.no_set_device_rank: 117 | device = "cuda:%d" % args.local_rank 118 | else: 119 | device = "cuda:0" 120 | torch.cuda.set_device(device) 121 | else: 122 | device = "cpu" 123 | args.device = device 124 | device = torch.device(device) 125 | return device 126 | 127 | 128 | def is_rank0(): 129 | if not torch.distributed.is_initialized(): 130 | 
return True 131 | return torch.distributed.get_rank() == 0 132 | -------------------------------------------------------------------------------- /mmgpt/train/instruction_finetune.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/mlfoundations/open_flamingo""" 2 | 3 | import argparse 4 | import copy 5 | import glob 6 | import os 7 | import random 8 | import time 9 | 10 | import numpy as np 11 | import torch 12 | import wandb 13 | from mmengine import Config 14 | from torch.nn.parallel import DistributedDataParallel as DDP 15 | from torch.utils.data import DataLoader, DistributedSampler 16 | from tqdm import tqdm 17 | from transformers import ( 18 | get_constant_schedule_with_warmup, 19 | get_cosine_schedule_with_warmup, 20 | get_linear_schedule_with_warmup, 21 | ) 22 | 23 | from mmgpt import create_model_and_transforms 24 | from mmgpt.models.builder import create_toy_model_and_transforms 25 | from mmgpt.datasets import InfiniteSampler, build_dataset 26 | from mmgpt.train.distributed import init_distributed_device, world_info_from_env 27 | from mmgpt.train.train_utils import AverageMeter, get_autocast, get_cast_dtype, get_checkpoint 28 | 29 | 30 | def random_seed(seed=42, rank=0): 31 | torch.manual_seed(seed + rank) 32 | np.random.seed(seed + rank) 33 | random.seed(seed + rank) 34 | 35 | 36 | def main(): 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str) 39 | parser.add_argument("--vision_encoder_pretrained", default="openai", type=str) 40 | parser.add_argument("--lm_path", default="checkpoints/llama-7b_hf", type=str) 41 | parser.add_argument( 42 | "--tokenizer_path", 43 | default="checkpoints/llama-7b_hf", 44 | type=str, 45 | help="path to tokenizer", 46 | ) 47 | parser.add_argument( 48 | "--pretrained_path", 49 | default="checkpoints/OpenFlamingo-9B/checkpoint.pt", 50 | type=str, 51 | help="path to pretrained model", 52 | 
) 53 | parser.add_argument( 54 | "--run_name", 55 | type=str, 56 | default="train-my-gpt4", 57 | help="used to name saving directory and wandb run", 58 | ) 59 | parser.add_argument("--use_media_placement_augmentation", action="store_true") 60 | parser.add_argument("--offline", action="store_true") 61 | parser.add_argument("--num_epochs", type=int, default=1) 62 | parser.add_argument("--logging_steps", type=int, default=100, help="log loss every n steps") 63 | # Sum of gradient optimization batch size 64 | parser.add_argument( 65 | "--resume_from_checkpoint", 66 | type=str, 67 | help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states", 68 | default=None, 69 | ) 70 | parser.add_argument( 71 | "--delete_previous_checkpoint", 72 | action="store_true", 73 | help="delete previous checkpoint when saving new checkpoint", 74 | ) 75 | parser.add_argument("--seed", type=int, default=42) 76 | parser.add_argument("--learning_rate", default=1e-5, type=float) 77 | parser.add_argument( 78 | "--lr_scheduler", 79 | default="constant", 80 | type=str, 81 | help="constant, linear, or cosine", 82 | ) 83 | parser.add_argument("--warmup_steps", default=100, type=int) 84 | parser.add_argument("--weight_decay", default=0.1, type=float) 85 | parser.add_argument( 86 | "--precision", 87 | choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], 88 | default="amp", 89 | help="Floating point precision.", 90 | ) 91 | # data args 92 | parser.add_argument("--workers", type=int, default=0) 93 | parser.add_argument("--batch_size", type=int, default=1) 94 | parser.add_argument("--dataset_config", type=str, default=None, help="path to dataset config file") 95 | parser.add_argument("--gradient_accumulation_steps", type=int, default=16) 96 | # Finetune config 97 | parser.add_argument("--tuning_config", type=str, default=None, help="path to tuning config file") 98 | # distributed training args 99 | parser.add_argument( 100 | "--dist-url", 101 | 
default="env://", 102 | type=str, 103 | help="url used to set up distributed training", 104 | ) 105 | parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend") 106 | parser.add_argument( 107 | "--horovod", 108 | default=False, 109 | action="store_true", 110 | help="Use horovod for distributed training.", 111 | ) 112 | parser.add_argument( 113 | "--no-set-device-rank", 114 | default=False, 115 | action="store_true", 116 | help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", 117 | ) 118 | # wandb args 119 | parser.add_argument("--report_to_wandb", default=False, action="store_true") 120 | parser.add_argument( 121 | "--wandb_project", 122 | type=str, 123 | ) 124 | parser.add_argument( 125 | "--wandb_entity", 126 | type=str, 127 | ) 128 | parser.add_argument( 129 | "--save_checkpoints_to_wandb", 130 | default=False, 131 | action="store_true", 132 | help="save checkpoints to wandb", 133 | ) 134 | 135 | args = parser.parse_args() 136 | 137 | if args.save_checkpoints_to_wandb and not args.report_to_wandb: 138 | raise ValueError("save_checkpoints_to_wandb requires report_to_wandb") 139 | 140 | if args.offline: 141 | os.environ["WANDB_MODE"] = "offline" 142 | os.environ["TRANSFORMERS_OFFLINE"] = "1" 143 | 144 | args.local_rank, args.rank, args.world_size = world_info_from_env() 145 | 146 | if args.rank == 0: 147 | if not os.path.exists(args.run_name): 148 | os.makedirs(args.run_name) 149 | 150 | device_id = init_distributed_device(args) 151 | 152 | random_seed(args.seed) 153 | 154 | if args.tuning_config is not None: 155 | tuning_config = Config.fromfile(args.tuning_config) 156 | else: 157 | raise ValueError("tuning_config must be specified") 158 | 159 | model, image_processor, tokenizer = create_model_and_transforms( 160 | model_name="open_flamingo", 161 | clip_vision_encoder_path=args.vision_encoder_path, 162 | clip_vision_encoder_pretrained=args.vision_encoder_pretrained, 163 | 
lang_encoder_path=args.lm_path, 164 | tokenizer_path=args.tokenizer_path if args.tokenizer_path else args.lm_path, 165 | use_media_placement_augmentation=args.use_media_placement_augmentation, 166 | pretrained_model_path=args.pretrained_path, 167 | tuning_config=tuning_config.tuning_config, 168 | ) 169 | 170 | if args.dataset_config is not None: 171 | dataset_config = Config.fromfile(args.dataset_config) 172 | else: 173 | raise ValueError("dataset_config must be specified") 174 | 175 | dataset = build_dataset( 176 | dataset_config=dataset_config.visual_datasets, 177 | vis_processor=image_processor, 178 | tokenizer=tokenizer, 179 | ) 180 | train_dataloader = DataLoader( 181 | dataset, 182 | batch_size=args.batch_size, 183 | num_workers=args.workers, 184 | sampler=DistributedSampler(dataset, shuffle=True, drop_last=True), 185 | collate_fn=dataset.collater, 186 | ) 187 | 188 | # build language dataset and dataloader for multi-modality training 189 | if dataset_config.get('language_datasets') is not None and len(dataset_config.language_datasets) > 0: 190 | lang_dataset = build_dataset( 191 | dataset_config=dataset_config.language_datasets, 192 | tokenizer=tokenizer, 193 | ) 194 | lang_dataloader = DataLoader( 195 | lang_dataset, 196 | batch_size=args.batch_size, 197 | num_workers=args.workers, 198 | sampler=InfiniteSampler(lang_dataset, shuffle=True), 199 | collate_fn=lang_dataset.collater, 200 | ) 201 | lang_dataloader = iter(lang_dataloader) 202 | else: 203 | lang_dataloader = None 204 | 205 | random_seed(args.seed, args.rank) 206 | 207 | print(f"Start running training on rank {args.rank}.") 208 | 209 | if args.rank == 0 and args.report_to_wandb: 210 | wandb.init( 211 | project=args.wandb_project, 212 | entity=args.wandb_entity, 213 | name=args.run_name, 214 | config=vars(args), 215 | ) 216 | 217 | device_id = args.rank % torch.cuda.device_count() 218 | model = model.to(device_id) 219 | 220 | ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=True) 
221 | 222 | def get_grouped_params(model): 223 | params_with_wd, params_without_wd = [], [] 224 | 225 | def apply_decay(x): 226 | return ( 227 | "gated_cross_attn_layer" in x 228 | and "ff_gate" not in x 229 | and "attn_gate" not in x 230 | and "norm" not in x 231 | and "bias" not in x 232 | ) 233 | 234 | for n, p in model.named_parameters(): 235 | # if p.requires_grad: 236 | if apply_decay(n): 237 | params_with_wd.append(p) 238 | else: 239 | params_without_wd.append(p) 240 | 241 | return [ 242 | {"params": params_with_wd, "weight_decay": args.weight_decay}, 243 | {"params": params_without_wd, "weight_decay": 0.0}, 244 | ] 245 | 246 | optimizer = torch.optim.AdamW(get_grouped_params(ddp_model), lr=args.learning_rate) 247 | 248 | total_training_steps = len(train_dataloader) * args.num_epochs 249 | 250 | if args.rank == 0: 251 | print(f"Total training steps: {total_training_steps}") 252 | 253 | if args.lr_scheduler == "linear": 254 | lr_scheduler = get_linear_schedule_with_warmup( 255 | optimizer, 256 | num_warmup_steps=args.warmup_steps, 257 | num_training_steps=total_training_steps // args.gradient_accumulation_steps, 258 | ) 259 | elif args.lr_scheduler == "cosine": 260 | lr_scheduler = get_cosine_schedule_with_warmup( 261 | optimizer, 262 | num_warmup_steps=args.warmup_steps, 263 | num_training_steps=total_training_steps // args.gradient_accumulation_steps, 264 | ) 265 | else: 266 | lr_scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps) 267 | 268 | # check if a checkpoint exists for this run 269 | if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None: 270 | checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt") 271 | if len(checkpoint_list) == 0: 272 | print(f"Found no checkpoints for run {args.run_name}.") 273 | else: 274 | args.resume_from_checkpoint = sorted(checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0]))[-1] 275 | print(f"Found checkpoint 
{args.resume_from_checkpoint} for run {args.run_name}.") 276 | 277 | resume_from_epoch = 0 278 | if args.resume_from_checkpoint is not None: 279 | if args.rank == 0: 280 | print(f"Loading checkpoint from {args.resume_from_checkpoint}") 281 | checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") 282 | ddp_model.load_state_dict(checkpoint["model_state_dict"], False) 283 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 284 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"]) 285 | resume_from_epoch = checkpoint["epoch"] + 1 286 | 287 | ddp_model.train() 288 | 289 | for epoch in range(resume_from_epoch, args.num_epochs): 290 | train_dataloader.sampler.set_epoch(epoch) 291 | 292 | train_one_epoch( 293 | args=args, 294 | model=ddp_model, 295 | epoch=epoch, 296 | tokenizer=tokenizer, 297 | optimizer=optimizer, 298 | lr_scheduler=lr_scheduler, 299 | train_dataloader=train_dataloader, 300 | language_dataloader=lang_dataloader, 301 | device_id=device_id, 302 | wandb=wandb, 303 | ) 304 | 305 | if args.rank == 0: 306 | if not os.path.exists(args.run_name): 307 | os.makedirs(args.run_name) 308 | 309 | checkpoint_dict = { 310 | "epoch": epoch, 311 | "model_state_dict": get_checkpoint(ddp_model), 312 | "optimizer_state_dict": optimizer.state_dict(), 313 | "lr_scheduler_state_dict": lr_scheduler.state_dict(), 314 | "tuning_config": tuning_config, 315 | } 316 | 317 | print(f"Saving checkpoint to {args.run_name}/checkpoint_{epoch}.pt") 318 | torch.save(checkpoint_dict, f"{args.run_name}/checkpoint_{epoch}.pt") 319 | if args.report_to_wandb and args.save_checkpoints_to_wandb: 320 | wandb.save(f"{args.run_name}/checkpoint_{epoch}.pt") 321 | 322 | if args.delete_previous_checkpoint: 323 | if epoch > 0: 324 | os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt") 325 | if args.rank == 0: 326 | torch.save( 327 | {"model_state_dict": get_checkpoint(ddp_model.module), "tuning_config": tuning_config}, 328 | 
f"{args.run_name}/final_weights.pt", 329 | ) 330 | if args.report_to_wandb and args.save_checkpoints_to_wandb: 331 | wandb.save(f"{args.run_name}/final_weights.pt") 332 | 333 | 334 | def train_one_epoch( 335 | args, 336 | model, 337 | epoch, 338 | train_dataloader, 339 | language_dataloader, 340 | tokenizer, 341 | optimizer, 342 | lr_scheduler, 343 | device_id, 344 | wandb, 345 | ): 346 | num_batches_per_epoch = len(train_dataloader) 347 | 348 | total_training_steps = num_batches_per_epoch * args.num_epochs 349 | 350 | autocast = get_autocast(args.precision) 351 | cast_dtype = get_cast_dtype(args.precision) 352 | 353 | model.train() 354 | 355 | # setup logging 356 | step_time_m = AverageMeter() # time for one optimizer step (> 1 batch if using gradient accum) 357 | data_time_m = ( 358 | AverageMeter() 359 | ) # avg time to load one batch of both C4 AND laion (= 1 batch regardless of gradient accum) 360 | end = time.time() 361 | 362 | # loop through dataloader 363 | for num_steps, batch in tqdm( 364 | enumerate(train_dataloader), 365 | disable=args.rank != 0, 366 | total=total_training_steps, 367 | initial=(epoch * num_batches_per_epoch), 368 | ): 369 | data_time_m.update(time.time() - end) 370 | 371 | global_step = num_steps + epoch * num_batches_per_epoch 372 | 373 | #### VISION FORWARD PASS #### 374 | images = batch["image"].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1) 375 | input_ids = batch["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True) 376 | attention_mask = batch["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True) 377 | labels = batch["labels"].to(device_id, dtype=cast_dtype, non_blocking=True) 378 | 379 | with autocast(): 380 | loss_batch = model( 381 | vision_x=images, 382 | lang_x=input_ids, 383 | attention_mask=attention_mask, 384 | labels=labels, 385 | )[0] 386 | loss = loss_batch / args.gradient_accumulation_steps 387 | loss_vision = loss # for logging 388 | 389 | #### BACKWARD PASS #### 
390 | loss.backward() 391 | 392 | #### LANGUAGE FORWARD PASS #### 393 | if language_dataloader is not None: 394 | batch_lang = next(language_dataloader) 395 | lang_input_ids = batch_lang["input_ids"].to(device_id, dtype=cast_dtype, non_blocking=True) 396 | lang_attention_mask = batch_lang["attention_mask"].to(device_id, dtype=cast_dtype, non_blocking=True) 397 | lang_labels = batch_lang["labels"].to(device_id, dtype=cast_dtype, non_blocking=True) 398 | 399 | with autocast(): 400 | lang_loss_batch = model( 401 | vision_x=None, 402 | lang_x=lang_input_ids, 403 | attention_mask=lang_attention_mask, 404 | labels=lang_labels, 405 | )[0] 406 | lang_loss = lang_loss_batch / args.gradient_accumulation_steps 407 | #### BACKWARD PASS #### 408 | lang_loss.backward() 409 | 410 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 411 | 412 | # step optimizer and log 413 | if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1): 414 | optimizer.step() 415 | lr_scheduler.step() 416 | optimizer.zero_grad() 417 | 418 | # step time and reset end outside of rank 0 419 | step_time_m.update(time.time() - end) 420 | end = time.time() 421 | 422 | if args.rank == 0 and args.report_to_wandb: 423 | # compute within rank 0 424 | samples_per_second = ( 425 | args.gradient_accumulation_steps * args.batch_size * args.world_size / step_time_m.val 426 | ) 427 | samples_per_second_per_gpu = args.gradient_accumulation_steps * args.batch_size / step_time_m.val 428 | 429 | wandb.log( 430 | { 431 | "data_time": data_time_m.avg, 432 | "step_time": step_time_m.avg, 433 | "samples_per_second": samples_per_second, 434 | "samples_per_second_per_gpu": samples_per_second_per_gpu, 435 | "lr": optimizer.param_groups[0]["lr"], 436 | }, 437 | commit=False, 438 | ) 439 | step_time_m.reset() 440 | data_time_m.reset() 441 | 442 | loss_log = { 443 | "loss": loss.item(), 444 | "loss_vision": loss_vision.item(), 445 | "global_step": global_step, 446 | } 447 | if 
language_dataloader is not None: 448 | loss_log["loss_lang"] = lang_loss.item() 449 | 450 | wandb.log(loss_log, commit=True) 451 | 452 | # Log loss to console 453 | if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0: 454 | print( 455 | f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss: {loss.item():.3f}" 456 | ) 457 | 458 | 459 | if __name__ == "__main__": 460 | main() 461 | -------------------------------------------------------------------------------- /mmgpt/train/train_utils.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/mlfoundations/open_flamingo""" 2 | import time 3 | from contextlib import suppress 4 | 5 | import torch 6 | from tqdm import tqdm 7 | 8 | 9 | def get_cast_dtype(precision: str): 10 | cast_dtype = None 11 | if precision == "bf16": 12 | cast_dtype = torch.bfloat16 13 | elif precision == "fp16": 14 | cast_dtype = torch.float16 15 | return cast_dtype 16 | 17 | 18 | def get_autocast(precision): 19 | if precision == "amp": 20 | return torch.cuda.amp.autocast 21 | elif precision == "amp_bfloat16" or precision == "amp_bf16": 22 | # amp_bfloat16 is more stable than amp float16 for clip training 23 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 24 | else: 25 | return suppress 26 | 27 | 28 | def train_one_epoch( 29 | args, 30 | model, 31 | epoch, 32 | laion_loader, 33 | mmc4_loader, 34 | tokenizer, 35 | optimizer, 36 | lr_scheduler, 37 | device_id, 38 | wandb, 39 | ): 40 | num_batches_per_epoch_laion = laion_loader.num_batches 41 | num_batches_per_epoch_mmc4 = mmc4_loader.num_batches 42 | 43 | assert ( 44 | num_batches_per_epoch_laion == num_batches_per_epoch_mmc4 45 | ), "Number of batches in laion and mmc4 datasets must be the same" 46 | num_batches_per_epoch = num_batches_per_epoch_mmc4 47 | total_training_steps = num_batches_per_epoch * args.num_epochs 48 | 49 | autocast = 
get_autocast(args.precision) 50 | cast_dtype = get_cast_dtype(args.precision) 51 | 52 | media_token_id = tokenizer("", add_special_tokens=False)["input_ids"][-1] 53 | endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1] 54 | 55 | model.train() 56 | 57 | # setup logging 58 | step_time_m = AverageMeter() # time for one optimizer step (> 1 batch if using gradient accum) 59 | data_time_m = ( 60 | AverageMeter() 61 | ) # avg time to load one batch of both C4 AND laion (= 1 batch regardless of gradient accum) 62 | end = time.time() 63 | 64 | # loop through dataloader 65 | for num_steps, (batch_laion, batch_mmc4) in tqdm( 66 | enumerate(zip(laion_loader, mmc4_loader)), 67 | disable=args.rank != 0, 68 | total=total_training_steps, 69 | initial=(epoch * num_batches_per_epoch), 70 | ): 71 | data_time_m.update(time.time() - end) 72 | 73 | global_step = num_steps + epoch * num_batches_per_epoch 74 | 75 | #### LAION FORWARD PASS #### 76 | images = batch_laion[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1) 77 | 78 | input_ids = batch_laion[1][0].to(device_id, dtype=cast_dtype, non_blocking=True) 79 | attention_mask = batch_laion[1][1].to(device_id, dtype=cast_dtype, non_blocking=True) 80 | 81 | labels = input_ids.clone() 82 | labels[labels == tokenizer.pad_token_id] = -100 83 | labels[:, 0] = -100 84 | labels[labels == media_token_id] = -100 85 | labels.to(device_id) 86 | 87 | with autocast(): 88 | loss_laion = model( 89 | vision_x=images, 90 | lang_x=input_ids, 91 | attention_mask=attention_mask, 92 | labels=labels, 93 | )[0] 94 | divided_loss_laion = loss_laion / args.gradient_accumulation_steps 95 | 96 | #### C4 FORWARD PASS #### 97 | images = batch_mmc4[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(2) 98 | input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1) 99 | attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1) 100 | 101 | # NOTE: irena: expected 
shape of clip_text_input_ids / attention_mask is (N, I, max_seq_len) 102 | labels = input_ids.clone() 103 | labels[labels == tokenizer.pad_token_id] = -100 104 | labels[:, 0] = -100 105 | 106 | for i in range(labels.shape[0]): 107 | # remove loss for any token before the first token 108 | label_idx = 0 109 | while label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id: 110 | labels[i][label_idx] = -100 111 | label_idx += 1 112 | 113 | # get index of all endofchunk tokens in the sequence 114 | endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0] 115 | for endofchunk_idx in endofchunk_idxs: 116 | token_idx = endofchunk_idx + 1 117 | while token_idx < labels.shape[1] and labels[i][token_idx] != media_token_id: 118 | labels[i][token_idx] = -100 119 | token_idx += 1 120 | 121 | labels[labels == media_token_id] = -100 122 | labels.to(device_id) 123 | 124 | with autocast(): 125 | loss_mmc4 = model( 126 | vision_x=images, 127 | lang_x=input_ids, 128 | attention_mask=attention_mask, 129 | labels=labels, 130 | )[0] 131 | 132 | # if loss is nan, skip this batch 133 | if torch.isnan(loss_mmc4): 134 | print("loss is nan, skipping this batch") 135 | print("input_ids: ", tokenizer.batch_decode(input_ids)) 136 | print("labels: ", labels) 137 | print("images: ", images) 138 | optimizer.zero_grad() 139 | continue 140 | 141 | divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps 142 | 143 | #### BACKWARD PASS #### 144 | loss = divided_loss_laion * args.loss_multiplier_laion + divided_loss_mmc4 * args.loss_multiplier_mmc4 145 | loss.backward() 146 | 147 | #### MASK GRADIENTS FOR EMBEDDINGS #### 148 | # Note (anas): Do not apply weight decay to embeddings as it will break this function. 
149 | def mask_embedding(m): 150 | if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad: 151 | zero_mask = torch.zeros_like(m.weight.grad) 152 | zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id]) 153 | zero_mask[endofchunk_token_id] = torch.ones_like(zero_mask[endofchunk_token_id]) 154 | m.weight.grad = m.weight.grad * zero_mask 155 | 156 | model.apply(mask_embedding) 157 | 158 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 159 | 160 | # step optimizer and log 161 | if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1): 162 | optimizer.step() 163 | lr_scheduler.step() 164 | optimizer.zero_grad() 165 | 166 | # step time and reset end outside of rank 0 167 | step_time_m.update(time.time() - end) 168 | end = time.time() 169 | 170 | if args.rank == 0 and args.report_to_wandb: 171 | # compute within rank 0 172 | laion_samples_per_second = ( 173 | args.gradient_accumulation_steps * args.batch_size_laion * args.world_size / step_time_m.val 174 | ) 175 | laion_samples_per_second_per_gpu = ( 176 | args.gradient_accumulation_steps * args.batch_size_laion / step_time_m.val 177 | ) 178 | 179 | c4_samples_per_second = ( 180 | args.gradient_accumulation_steps * args.batch_size_mmc4 * args.world_size / step_time_m.val 181 | ) 182 | c4_samples_per_second_per_gpu = ( 183 | args.gradient_accumulation_steps * args.batch_size_mmc4 / step_time_m.val 184 | ) 185 | 186 | wandb.log( 187 | { 188 | "data_time": data_time_m.avg, 189 | "step_time": step_time_m.avg, 190 | "laion_samples_per_second": laion_samples_per_second, 191 | "laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu, 192 | "c4_samples_per_second": c4_samples_per_second, 193 | "c4_samples_per_second_per_gpu": c4_samples_per_second_per_gpu, 194 | "lr": optimizer.param_groups[0]["lr"], 195 | }, 196 | commit=False, 197 | ) 198 | step_time_m.reset() 199 | data_time_m.reset() 200 | 201 | wandb.log( 202 | { 203 | 
"loss_laion": divided_loss_laion.item(), 204 | "global_step": global_step, 205 | }, 206 | commit=False, 207 | ) 208 | wandb.log( 209 | {"loss_mmc4": divided_loss_mmc4.item(), "global_step": global_step}, 210 | commit=True, 211 | ) 212 | 213 | # Log loss to console 214 | if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0: 215 | print( 216 | f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss LAION: {loss_laion.item():.3f} // Loss MMC4: {loss_mmc4.item():.3f}" 217 | ) 218 | 219 | 220 | def get_checkpoint(model: torch.nn.Module): 221 | state_dict = model.state_dict() 222 | parameters = {k: v for k, v in model.named_parameters()} 223 | # drop every state_dict entry that named_parameters() does not report, e.g. buffers and (with PyTorch's default de-duplication) extra names of shared/tied parameters 224 | duplicate_keys = set(state_dict.keys()) - set(parameters.keys()) 225 | for k in duplicate_keys: 226 | del state_dict[k] 227 | # drop frozen (requires_grad=False) parameters so the returned dict holds only trainable weights; NOTE(review): `del` raises KeyError if a parameter name is missing from state_dict (e.g. a state_dict hook renamed keys) 228 | for name, p in parameters.items(): 229 | if not p.requires_grad: 230 | del state_dict[name] 231 | 232 | return state_dict 233 | 234 | 235 | class AverageMeter(object): 236 | """Track the latest value (val) and a count-weighted running average (avg = sum / count).""" 237 | 238 | def __init__(self): 239 | self.reset() 240 | 241 | def reset(self): 242 | self.val = 0 243 | self.avg = 0 244 | self.sum = 0 245 | self.count = 0 246 | 247 | def update(self, val, n=1): 248 | self.val = val 249 | self.sum += val * n 250 | self.count += n 251 | self.avg = self.sum / self.count 252 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | einops-exts 3 | transformers 4 | peft 5 | bigmodelvis 6 | torch 7 | torchvision 8 | pillow 9 | more-itertools 10 | datasets 11 | braceexpand 12 | webdataset 13 | wandb 14 | nltk 15 | scipy 16 | inflection 17 | sentencepiece 18 | open_clip_torch 19 | mmengine 20 | gradio 21 | --------------------------------------------------------------------------------
"""Packaging script (setup.py) for the ``mmgpt`` distribution.

Importing this module does nothing; ``setup()`` only runs when the file is
executed as ``__main__`` (which is how setuptools / pip invoke it).
"""
from pathlib import Path

from setuptools import find_packages, setup


def _long_description():
    """Return the text of the README.md located next to this script."""
    readme_path = Path(__file__).parent / "README.md"
    return readme_path.read_text(encoding="utf-8")


def _install_requires():
    """Return the runtime dependencies declared for the package.

    TODO: the list is maintained by hand because, per the original TODO,
    requirements.txt could not be read from here; it should be derived from
    a single source of truth.
    NOTE(review): requirements.txt additionally lists peft, bigmodelvis,
    mmengine and gradio, none of which appear below -- confirm whether the
    installed package needs them at runtime.
    """
    return [
        "einops",
        "einops-exts",
        "transformers",
        "torch",
        "torchvision",
        "pillow",
        "more-itertools",
        "datasets",
        "braceexpand",
        "webdataset",
        "wandb",
        "nltk",
        "scipy",
        "inflection",
        "sentencepiece",
        "open_clip_torch",
    ]


def _main():
    """Run setuptools with the project's metadata."""
    # Read the README first so a missing file fails before anything else runs.
    readme_text = _long_description()
    dependencies = _install_requires()

    setup(
        name="mmgpt",
        packages=find_packages(),
        include_package_data=True,
        version="0.0.1",
        license="Apache 2.0",
        description="An open-source framework for multi-modality instruction fine-tuning",
        long_description=readme_text,
        long_description_content_type="text/markdown",
        data_files=[(".", ["README.md"])],
        keywords=["machine learning"],
        install_requires=dependencies,
    )


if __name__ == "__main__":
    _main()