├── .gitignore ├── README.md ├── assets ├── demo.gif ├── examples.png ├── logo.png ├── teaser-1.png ├── teaser.gif └── teaser.png ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── conversation └── conversation.py ├── demo.py ├── engine.py ├── eval.py ├── eval_mme.py ├── lavin ├── __init__.py ├── eval_model.py ├── generator.py ├── mm_adaptation.py ├── mm_adapter.py ├── model.py └── tokenizer.py ├── requirements.txt ├── scripts ├── eval_mme_benchmark.sh ├── finetuning_sqa_13b.sh ├── finetuning_sqa_13b_lite.sh ├── finetuning_sqa_7b.sh ├── finetuning_sqa_7b_lite.sh ├── finetuning_sqa_vicuna_13b.sh ├── finetuning_sqa_vicuna_7b.sh ├── vl_instruction_tuning_13b.sh └── vl_instruction_tuning_vicuna_13b.sh ├── setup.py ├── tools └── data_processing.py ├── train.py └── util ├── apply_delta.py ├── base_prompt.py ├── datasets.py ├── lars.py ├── lr_decay.py ├── lr_sched.py ├── misc.py ├── pos_embed.py └── quantization.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # checkpoint 132 | *.pth 133 | outputs/ 134 | 135 | .idea/ 136 | debug.py 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](./assets/logo.png) 2 | 3 | --- 4 | 5 | This repository contains the implementation of the NeurIPS 2023 paper: 6 | > **Cheap and Quick: Efficient Vision-Language Instruction Tuning for Large Language Models** 7 | > [[Project Page]](https://luogen1996.github.io/lavin/) [[Paper]](https://arxiv.org/pdf/2305.15023.pdf)
8 | > [Gen Luo](https://luogen1996.github.io)<sup>1</sup>, Yiyi Zhou<sup>12</sup>, [Tianhe Ren](https://rentainhe.github.io)<sup>1</sup>, Shengxin Chen<sup>1</sup>, [Xiaoshuai Sun](https://sites.google.com/view/xssun)<sup>12</sup>, [Rongrong Ji](https://mac.xmu.edu.cn/rrji/)<sup>12</sup>
9 | > <sup>1</sup>Media Analytics and Computing Lab, Department of Artificial Intelligence, School of Informatics, Xiamen University 10 | > <sup>2</sup>Institute of Artificial Intelligence, Xiamen University 11 | 12 | In this work, we propose a novel and affordable solution for vision-language instruction tuning, namely Mixture-of-Modality Adaptation (MMA). 13 | In particular, MMA is an end-to-end optimization regime that connects the image encoder and the LLM via lightweight adapters. Meanwhile, we also propose a novel routing algorithm in MMA, which helps the model automatically shift its reasoning paths for single- and multi-modal instructions. Based on MMA, we develop a large vision-language instructed model called LaVIN, which demonstrates superior training efficiency and better reasoning ability than existing multimodal LLMs on various instruction-following tasks. 14 | 15 | --- 16 | 17 |
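To make the adapter-plus-router idea concrete, here is a minimal, illustrative sketch of a mixture-of-modality adapter. It is written for intuition only and is **not** the implementation in `lavin/mm_adapter.py`; the class and variable names below are invented for this example.

```python
import torch
import torch.nn as nn


class ToyMMAdapter(nn.Module):
    """Illustrative only: a lightweight bottleneck adapter whose two branches are
    mixed by a router, so that single- and multi-modal instructions can take
    different reasoning paths while the backbone LLM stays frozen."""

    def __init__(self, dim: int, adapter_dim: int = 8, temperature: float = 10.0):
        super().__init__()
        self.down = nn.Linear(dim, adapter_dim)     # shared down-projection
        self.up_text = nn.Linear(adapter_dim, dim)  # branch favored by text-only instructions
        self.up_mm = nn.Linear(adapter_dim, dim)    # branch favored by text-image instructions
        self.router = nn.Linear(dim, 2)             # predicts the mixing weights per sample
        self.temperature = temperature

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len, dim]; route on the mean token representation
        w = torch.softmax(self.router(x.mean(dim=1)) / self.temperature, dim=-1)  # [batch, 2]
        h = torch.relu(self.down(x))
        out = w[:, 0, None, None] * self.up_text(h) + w[:, 1, None, None] * self.up_mm(h)
        return x + out  # residual: only the tiny adapter parameters are trained
```

Adapters of roughly this size are what keep the trainable parameters in the few-million range (3.8M for LaVIN-7B, 5.4M for LaVIN-13B) reported in the Model Zoo below.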
18 | 19 |
20 | 21 | ## News 22 | - **`2023/09/22`**: 🔥🔥🔥 Our paper is accepted by NeurIPS 2023! 23 | - **`2023/06/30`**: 🔥🔥🔥 With very limited training data and cost, LaVIN achieves 5-th place of Perception and Cognition on [MME benchmark](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation), outperforming seven existing multimodal LLMs. Evaluation codes are available. 24 | - **`2023/06/27`**: 🔥4-bit trainings are available now ! LaVIN-lite can be trained on one 3090 GPU, taking around 9G and 15G GPU memory for the scales of 7B and 13B , respectively. Technical details are available in [知乎](https://zhuanlan.zhihu.com/p/638784025). 25 | - **`2023/05/29`**: 🔥We released the demo and the pre-trained checkpoint (LLaMA-13B) for multimodal chatbot. 26 | - **`2023/05/25`**: 🔥We released the code of **LaVIN: Large Vision-Language Instructed model**, which achieves 89.4 (LaVIN-7B) and 90.8 (LaVIN-13B) accuracy on ScienceQA! 🔥With the proposed **mixture-of-modality adaptation**, the training time and trainable parameters can be reduced to 1.4 hours and 3.8M, respectively! Checkout the [paper](https://arxiv.org/pdf/2305.15023.pdf). 27 | 28 | ## TODO 29 | - [x] Release training codes. 30 | - [x] Release checkpoints and demo. 31 | - [x] 4-bit training. 32 | - [ ] Support more modalities, e.g., audio and video. 33 | 34 | ## Contents 35 | - [Setup](#setup) 36 | - [Fine-tuning](#fine-tuning) 37 | - [Demo](#demo) 38 | - [Model Zoo](#model-zoo) 39 | 40 | ## Setup 41 | ### Install Package 42 | ```bash 43 | conda create -n lavin python=3.8 -y 44 | conda activate lavin 45 | 46 | # install pytorch 47 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 -c pytorch 48 | 49 | # install dependency and lavin 50 | pip install -r requirements.txt 51 | pip install -e . 52 | ``` 53 | ### Data Preparation 54 | - For ScienceQA, please prepare the dataset from the [official repo](https://github.com/lupantech/ScienceQA). 55 | - For Multimodal Chatbot, download the images in _train2014_ split from [MSCOCO](http://images.cocodataset.org/zips/train2014.zip), and obtain the prepared 52k text-only and 158k text-image instruction-following data from [here](https://drive.google.com/file/d/1gORDPruqwXbgy6NYmhpDXO7t089yzsg3/view?usp=share_link). 56 | - Obtain the weights of LLaMA from [this form](https://forms.gle/jk851eBVbX1m5TAv5) (official) or Download [LLaMA-7B](https://huggingface.co/nyanko7/LLaMA-7B/tree/main) and [LLaMA-13B](https://huggingface.co/TheBloke/llama-13b) from HuggingFace (unofficial). 57 | - If you want to use Vicuna weights to initialize the model, please download from [here](https://huggingface.co/lmsys). 58 | After that, the file structure should look like: 59 | 60 | ```bash 61 | LaVIN/ 62 | |-- lavin 63 | |-- scripts 64 | |-- train.py 65 | |-- eval.py 66 | ...... 
67 | data/ 68 | |-- problems.json 69 | |-- pid_splits.json 70 | |-- captions.json 71 | |-- all_data.json 72 | |-- images 73 | |-- train2014 # MSCOCO 2014 74 | |-- val2014 # MSCOCO 2014 75 | |-- train # ScienceQA train image 76 | |-- val # ScienceQA val image 77 | |-- test # ScienceQA test image 78 | |-- weights 79 | |-- tokenizer.model 80 | |--7B 81 | |-- params.json 82 | |-- consolidated.00.pth 83 | |--13B 84 | |-- params.json 85 | |-- consolidated.00.pth 86 | |-- consolidated.01.pth 87 | |--vicuna_7B 88 | |--vicuna_13B 89 | |-- config.json 90 | |-- generation_config.json 91 | |-- pytorch_model.bin.index.json 92 | |-- special_tokens_map.json 93 | |-- tokenizer_config.json 94 | |-- tokenizer.model 95 | |-- pytorch_model-00001-of-00003.bin 96 | |-- pytorch_model-00002-of-00003.bin 97 | |-- pytorch_model-00003-of-00003.bin 98 | ...... 99 | ``` 100 | ## Fine-tuning 101 | ### ScienceQA 102 | Reproduce the performance of LaVIN-7B on ScienceQA. 103 | For the 7B model, we fine-tune it on 2x A100 GPUs (we find that the performance is affected by the number of GPUs; we are working to address this problem). 104 | 105 | 106 | LLaMA weights: 107 | ```bash 108 | bash ./scripts/finetuning_sqa_7b.sh 109 | ``` 110 | 111 | Vicuna weights: 112 | ```bash 113 | bash ./scripts/finetuning_sqa_vicuna_7b.sh 114 | ``` 115 | 116 | LaVIN-lite with LLaMA weights (single GPU): 117 | ```bash 118 | bash ./scripts/finetuning_sqa_7b_lite.sh 119 | ``` 120 | 121 | Reproduce the performance of LaVIN-13B on ScienceQA (~2 hours on 8x A100 (80G)). 122 | For the 13B model, we fine-tune it on 8x A100 GPUs. 123 | 124 | LLaMA weights: 125 | ```bash 126 | bash ./scripts/finetuning_sqa_13b.sh 127 | ``` 128 | 129 | Vicuna weights: 130 | ```bash 131 | bash ./scripts/finetuning_sqa_vicuna_13b.sh 132 | ``` 133 | LaVIN-lite with LLaMA weights (single GPU): 134 | ```bash 135 | bash ./scripts/finetuning_sqa_13b_lite.sh 136 | ``` 137 | ### Multimodal ChatBot 138 | Fine-tune LaVIN-13B on the 210k instruction-following data (~75 hours with 15 epochs and ~25 hours with 5 epochs on 8x A100 (80G)). 139 | 140 | LLaMA weights: 141 | ```bash 142 | bash ./scripts/vl_instruction_tuning_13b.sh 143 | ``` 144 | 145 | Vicuna weights: 146 | ```bash 147 | bash ./scripts/vl_instruction_tuning_vicuna_13b.sh 148 | ``` 149 | To train on fewer GPUs, you can reduce the number of GPUs in the scripts and increase gradient accumulation via ```--accum_iter``` to keep the total batch size at 32. Setting ```--gradient_checkpointing``` and ```--bits 4bit``` in the scripts will greatly reduce the GPU memory requirements. 150 | 151 | 152 | ## Demo 153 | 154 | LaVIN supports both single- and multi-modal instruction inputs. Try your custom instructions in our demo: 155 | 156 | - **Launch a Gradio web server on your machine; then you can interact with LaVIN as you like.** 157 | ``` 158 | torchrun --nproc_per_node 1 demo.py --server_name 127.0.0.1 159 | ``` 160 | 161 |
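If you prefer to script the chatbot instead of using the web UI, the `Chat` wrapper in `conversation/conversation.py` can be driven directly. The sketch below is untested and only illustrates the call order; it assumes the same weights layout as `demo.py`, uses placeholder paths (e.g., `your_image.png`, the adapter checkpoint), and must be launched with `torchrun` so that the distributed/model-parallel setup succeeds.

```python
import os
import torch
import torch.distributed as dist
from PIL import Image
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from conversation.conversation import Chat, CONV_VISION
from eval import load  # the same loader that demo.py uses

# model-parallel init, mirroring setup_model_parallel() in demo.py (run under torchrun)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
dist.init_process_group("nccl")
initialize_model_parallel(world_size)
torch.cuda.set_device(local_rank)
torch.manual_seed(1)

# build the LaVIN generator as demo.py does; paths here are placeholders
lavin = load(
    ckpt_dir="../data/weights/", llm_model="13B", adapter_path="./15-eph-pretrain.pth",
    max_seq_len=512, max_batch_size=4, adapter_type="attn", adapter_dim=8, adapter_scale=1,
    hidden_proj=128, visual_adapter_type="router", temperature=5.0, tokenizer_path="",
    local_rank=local_rank, world_size=world_size, use_vicuna=False,
)

vis_processor = transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

chat = Chat(lavin, vis_processor, device=torch.device("cuda"))
state, img_list = CONV_VISION.copy(), []
chat.upload_img("your_image.png", state, img_list)  # optional: skip for text-only instructions
chat.ask("What is unusual about this image?", state)
print(chat.answer(state, img_list))
```

Passing `bits='4bit'` or `cpu_load=True` to `load()` (both defined in `eval.py`) should further reduce GPU memory, similar to the 4-bit training options above.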
162 | 163 |
164 | 165 | 166 | ## Model Zoo 167 | ### ScienceQA 168 | | Model | Weights | Time | Memory | #Params | Acc | Weights | 169 | |-----------|----------:|----------:|-------:|--------:|-----:|-----------------:| 170 | | LaVIN-7B-lite | LLaMA | 29 hours (single GPU) | 9G | 3.8M | 88.35 | [google drive](https://drive.google.com/file/d/1oVtoTgt-d9EqmrVic27oZUreN9dLClMo/view?usp=sharing) | 171 | | LaVIN-13B-lite | LLaMA | 42 hours (single GPU) | 14G | 5.4M | 89.44 | [google drive](https://drive.google.com/file/d/1PyVsap3FnmgXOGXFXjYsAtR75cFypaHw/view?usp=sharing) | 172 | | LaVIN-7B | LLaMA | 1.4 hours | 33.9G | 3.8M | 89.37 | [google drive](https://drive.google.com/file/d/10X2qCBYrLH1grZOHwHRMXLUoz-S6MSgV/view?usp=share_link) | 173 | | LaVIN-7B | Vicuna | 1.4 hours | 33.9G | 3.8M | 89.41 | [google drive](https://drive.google.com/file/d/1nuMxeiWlnJKxDybCshg8pVGSvLc5dZy8/view?usp=share_link) | 174 | | LaVIN-13B | LLaMA | 2 hours | 55.9G | 5.4M | 90.54 | [google drive](https://drive.google.com/file/d/1LkKUY54spZkkeXrR7BDmU-xmK9YadcKM/view?usp=share_link) | 175 | | LaVIN-13B | LLaMA | 4 hours | 55.9G | 5.4M | 90.8 | - | 176 | 177 | ### Multimodal ChatBot 178 | | Model |Weights | Time | Memory | #Params | Acc | Weights | 179 | |-----------|----------:|---------:|-------:|--------:|----:|-----------------:| 180 | | LaVIN-13B | LLaMA | 25 hours | 55.9G | 5.4M | - | - | 181 | | LaVIN-13B | LLaMA | 75 hours | 55.9G | 5.4M | - | [google drive](https://drive.google.com/file/d/1rHQNSaiGzFHYGgsamtySPYnd5AW4OE9j/view?usp=share_link)| 182 | 183 | ## Examples 184 |
185 | 186 |
187 | 188 | ## Star History 189 | [![Star History Chart](https://api.star-history.com/svg?repos=luogen1996/LaVIN&type=Date)](https://star-history.com/#luogen1996/LaVIN&Date) 190 | 191 | ## Citation 192 | If you think our code and paper helpful, please kindly cite LaVIN and [RepAdapter](https://github.com/luogen1996/RepAdapter/): 193 | ```BibTeX 194 | @article{luo2023towards, 195 | title={Towards Efficient Visual Adaption via Structural Re-parameterization}, 196 | author={Luo, Gen and Huang, Minglang and Zhou, Yiyi and Sun, Xiaoshuai and Jiang, Guangnan and Wang, Zhiyu and Ji, Rongrong}, 197 | journal={arXiv preprint arXiv:2302.08106}, 198 | year={2023} 199 | } 200 | 201 | @article{luo2023cheap, 202 | title={Cheap and Quick: Efficient Vision-Language Instruction Tuning for Large Language Models}, 203 | author={Luo, Gen and Zhou, Yiyi and Ren, Tianhe and Chen, Shengxin and Sun, Xiaoshuai and Ji, Rongrong}, 204 | journal={Advances in neural information processing systems (NeurIPS)}, 205 | year={2023} 206 | } 207 | ``` 208 | 209 | 210 | ## Acknowledgement 211 | This repo borrows some data and codes from [LLaMA](https://github.com/facebookresearch/llama), [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca), [LLaVA](https://github.com/haotian-liu/LLaVA), [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) and [LLaMA-Adapter](https://github.com/ZrrSkywalker/LLaMA-Adapter/). Thanks for their great works. 212 | -------------------------------------------------------------------------------- /assets/demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/demo.gif -------------------------------------------------------------------------------- /assets/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/examples.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/logo.png -------------------------------------------------------------------------------- /assets/teaser-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/teaser-1.png -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/teaser.gif -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/teaser.png -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | 
output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def patch_device(module): 149 | try: 150 | graphs = [module.graph] if hasattr(module, "graph") else [] 151 | except RuntimeError: 152 | graphs = [] 153 | 154 | if hasattr(module, "forward1"): 155 | graphs.append(module.forward1.graph) 156 | 157 | for graph in graphs: 158 | for node in graph.findAllNodes("prim::Constant"): 159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 160 | node.copyAttributes(device_node) 161 | 162 | model.apply(patch_device) 163 | patch_device(model.encode_image) 164 | patch_device(model.encode_text) 165 | 166 | # patch dtype to float32 on CPU 167 | if str(device) == "cpu": 168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 170 | float_node = float_input.node() 171 | 172 | def patch_float(module): 173 | try: 174 | graphs = [module.graph] if hasattr(module, "graph") else [] 175 | except RuntimeError: 176 | graphs = [] 177 | 178 | if hasattr(module, "forward1"): 179 | graphs.append(module.forward1.graph) 180 | 181 | for graph in graphs: 182 | for node in graph.findAllNodes("aten::to"): 183 | inputs = list(node.inputs()) 184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 185 | if inputs[i].node()["value"] == 5: 186 | inputs[i].node().copyAttributes(float_node) 187 | 188 | model.apply(patch_float) 189 | patch_float(model.encode_image) 190 | patch_float(model.encode_text) 191 | 192 | model.float() 193 | 194 | return model, _transform(model.input_resolution.item()) 195 | 196 | 197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 198 | """ 199 | Returns the tokenized representation of given input string(s) 200 | 201 | Parameters 202 | ---------- 203 | texts : Union[str, List[str]] 204 | An input string or a list of input strings to tokenize 205 | 206 | context_length : int 207 | The context length to use; all CLIP models use 77 as the context length 208 | 209 | truncate: bool 210 | Whether to truncate the text in case its encoding is longer than the context length 211 | 212 | Returns 213 | ------- 214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | else: 226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 227 | 228 | for i, tokens in enumerate(all_tokens): 229 | if len(tokens) > context_length: 230 | if truncate: 231 | tokens = tokens[:context_length] 232 | tokens[-1] = eot_token 233 | else: 234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 235 | result[i, :len(tokens)] = torch.tensor(tokens) 236 | 237 | return result 238 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /conversation/conversation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | from PIL import Image 4 | 5 | import torch 6 | from transformers import StoppingCriteria, StoppingCriteriaList 7 | 8 | import dataclasses 9 | from enum import auto, Enum 10 | from typing import List, Tuple, Any 11 | import re 12 | 13 | ERROR_CODE= [260, 1794, 11440] 14 | ERROR_MESSAGE=[1, 7423, 29892, 474, 508, 29915, 
29873, 1234, 445, 1139, 29889, 2] 15 | 16 | 17 | class SeparatorStyle(Enum): 18 | """Different separator style.""" 19 | SINGLE = auto() 20 | TWO = auto() 21 | 22 | 23 | @dataclasses.dataclass 24 | class Conversation: 25 | """A class that keeps all conversation history.""" 26 | system: str 27 | roles: List[str] 28 | messages: List[List[str]] 29 | offset: int 30 | # system_img: List[Image.Image] = [] 31 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 32 | sep: str = "" 33 | sep2: str = None 34 | 35 | skip_next: bool = False 36 | conv_id: Any = None 37 | 38 | def get_prompt(self): 39 | if self.sep_style == SeparatorStyle.SINGLE: 40 | ret = self.system 41 | for role, message in self.messages: 42 | if message: 43 | ret += role + ": " + message + self.sep 44 | else: 45 | ret += role + ":" 46 | return ret 47 | elif self.sep_style == SeparatorStyle.TWO: 48 | seps = [self.sep, self.sep2] 49 | ret = self.system + seps[0] 50 | for i, (role, message) in enumerate(self.messages): 51 | if message: 52 | ret += role + ": " + message + seps[i % 2] 53 | else: 54 | ret += role + ":" 55 | return ret 56 | else: 57 | raise ValueError(f"Invalid style: {self.sep_style}") 58 | 59 | def append_message(self, role, message): 60 | self.messages.append([role, message]) 61 | 62 | def to_gradio_chatbot(self): 63 | ret = [] 64 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 65 | if i % 2 == 0: 66 | ret.append([msg, None]) 67 | else: 68 | ret[-1][-1] = msg 69 | return ret 70 | 71 | def copy(self): 72 | return Conversation( 73 | system=self.system, 74 | # system_img=self.system_img, 75 | roles=self.roles, 76 | messages=[[x, y] for x, y in self.messages], 77 | offset=self.offset, 78 | sep_style=self.sep_style, 79 | sep=self.sep, 80 | sep2=self.sep2, 81 | conv_id=self.conv_id) 82 | 83 | def dict(self): 84 | return { 85 | "system": self.system, 86 | # "system_img": self.system_img, 87 | "roles": self.roles, 88 | "messages": self.messages, 89 | "offset": self.offset, 90 | "sep": self.sep, 91 | "sep2": self.sep2, 92 | "conv_id": self.conv_id, 93 | } 94 | 95 | 96 | class StoppingCriteriaSub(StoppingCriteria): 97 | 98 | def __init__(self, stops=[], encounters=1): 99 | super().__init__() 100 | self.stops = stops 101 | 102 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): 103 | for stop in self.stops: 104 | if torch.all((stop == input_ids[0][-len(stop):])).item(): 105 | return True 106 | 107 | return False 108 | 109 | 110 | CONV_VISION = Conversation( 111 | system="", 112 | roles=("Instruction", "Responese"), 113 | messages=[], 114 | offset=2, 115 | sep_style=SeparatorStyle.SINGLE, 116 | sep="\n", 117 | ) 118 | 119 | 120 | 121 | class Chat: 122 | def __init__(self, model, vis_processor, device='cuda:0'): 123 | self.device = device 124 | self.lavin = model 125 | self.vis_processor = vis_processor 126 | 127 | def ask(self, text, conv): 128 | if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ 129 | and conv.messages[-1][1][-6:] == '': # last message is image. 
130 | conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) 131 | else: 132 | conv.append_message(conv.roles[0], text) 133 | 134 | def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, 135 | repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000,n_feats=6): 136 | conv.append_message(conv.roles[1], None) 137 | prompt, indicator, img = self.get_context_emb(conv, img_list) 138 | 139 | current_max_len = len(prompt) + max_new_tokens+n_feats 140 | if current_max_len - max_length > 0: 141 | print('Warning: The number of tokens in current conversation exceeds the max length. ' 142 | 'The model will not see the contexts outside the range.') 143 | begin_idx = max(0, current_max_len - max_length) 144 | 145 | prompt = prompt[begin_idx:] 146 | CODE=self.lavin.tokenizer.encode(prompt, bos=False, eos=False) 147 | if ERROR_CODE in [CODE[i:i+len(ERROR_CODE)] for i in range(len(CODE)-len(ERROR_CODE)+1)]: 148 | output_text=self.lavin.tokenizer.decode(ERROR_MESSAGE).split('Responese:')[-1].strip() 149 | else: 150 | outputs = self.lavin.generate( 151 | prompts= [prompt], 152 | images= img, 153 | indicators=[indicator], 154 | max_gen_len=max_length, 155 | n_feats=n_feats, 156 | temperature = 0.1, 157 | top_p = 0.75, 158 | ) 159 | 160 | output_text = outputs[0].split('Responese:')[-1].strip() 161 | 162 | conv.messages[-1][1] = output_text 163 | return output_text 164 | 165 | def upload_img(self, image, conv, img_list): 166 | if isinstance(image, str): # is a image path 167 | raw_image = Image.open(image).convert('RGB') 168 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) 169 | elif isinstance(image, Image.Image): 170 | raw_image = image 171 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) 172 | elif isinstance(image, torch.Tensor): 173 | if len(image.shape) == 3: 174 | image = image.unsqueeze(0) 175 | image = image.to(self.device) 176 | 177 | # image_emb, _ = self.lavin.backbone.encode_img(image) 178 | img_list.append(image) 179 | conv.append_message(conv.roles[0], "") 180 | msg = "Received." 
181 | # self.conv.append_message(self.conv.roles[1], msg) 182 | return msg 183 | 184 | def get_context_emb(self, conv, img_list): 185 | prompt = conv.get_prompt() 186 | 187 | if '' in prompt: 188 | indicator=1 189 | prompt=prompt.replace('','') 190 | else: 191 | indicator=0 192 | assert img_list is None or len(img_list) <= 1 193 | 194 | return prompt, indicator, img_list[0] if indicator==1 else torch.Tensor(torch.zeros(1,3, 224, 224).float()) -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.backends.cudnn as cudnn 8 | import gradio as gr 9 | from PIL import Image 10 | # from minigpt4.common.config import Config 11 | from util.misc import get_rank 12 | # from minigpt4.common.registry import registry 13 | from conversation.conversation import Chat, CONV_VISION 14 | from torchvision.transforms import transforms 15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 16 | from eval import load 17 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel 18 | from typing import Tuple 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description="Demo") 21 | parser.add_argument("--server_name", type=str, default="127.0.0.1", help="server name") 22 | parser.add_argument("--ckpt_dir", type=str, default="../data/weights/", help="dir of pre-trained weights.") 23 | parser.add_argument("--llm_model", type=str, default="13B", help="the type of llm.") 24 | parser.add_argument("--max_seq_len", type=int, default=512, help="decoder length") 25 | parser.add_argument('--adapter_type', type=str, default='attn', metavar='LENGTH',choices=['block','attn'], 26 | help='the insert position of adapter layer') 27 | parser.add_argument('--adapter_path', type=str, default='./15-eph-pretrain.pth', help='path of pre-trained adapter') 28 | parser.add_argument('--temperature', type=float, default=5., metavar='LENGTH', 29 | help='the temperature of router') 30 | parser.add_argument('--use_vicuna', action='store_true', help='use vicuna weights') 31 | parser.add_argument( 32 | "--options", 33 | nargs="+", 34 | help="override some settings in the used config, the key-value pair " 35 | "in xxx=yyy format will be merged into config file (deprecate), " 36 | "change to --cfg-options instead.", 37 | ) 38 | args = parser.parse_args() 39 | return args 40 | 41 | def setup_model_parallel() -> Tuple[int, int]: 42 | local_rank = int(os.environ.get("LOCAL_RANK", -1)) 43 | world_size = int(os.environ.get("WORLD_SIZE", -1)) 44 | 45 | torch.distributed.init_process_group("nccl") 46 | initialize_model_parallel(world_size) 47 | torch.cuda.set_device(local_rank) 48 | 49 | # seed must be the same in all processes 50 | torch.manual_seed(1) 51 | return local_rank, world_size 52 | 53 | def setup_seeds(config): 54 | seed = config.run_cfg.seed + get_rank() 55 | 56 | random.seed(seed) 57 | np.random.seed(seed) 58 | torch.manual_seed(seed) 59 | 60 | cudnn.benchmark = False 61 | cudnn.deterministic = True 62 | 63 | 64 | # ======================================== 65 | # Model Initialization 66 | # ======================================== 67 | 68 | print('Initializing Chat') 69 | args = parse_args() 70 | 71 | 72 | local_rank, world_size = setup_model_parallel() 73 | lavin=load( 74 | ckpt_dir=args.ckpt_dir, 75 | llm_model=args.llm_model, 76 | 
adapter_path=args.adapter_path, 77 | max_seq_len=512, 78 | max_batch_size=4, 79 | adapter_type='attn', 80 | adapter_dim=8, 81 | adapter_scale=1, 82 | hidden_proj=128, 83 | visual_adapter_type='router', 84 | temperature=args.temperature, 85 | tokenizer_path='', 86 | local_rank=local_rank, 87 | world_size=world_size, 88 | use_vicuna=args.use_vicuna 89 | ) 90 | 91 | vis_processor = transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 92 | chat = Chat(lavin, vis_processor, device=torch.device('cuda')) 93 | print('Initialization Finished') 94 | 95 | 96 | # ======================================== 97 | # Gradio Setting 98 | # ======================================== 99 | 100 | def gradio_reset(chat_state, img_list): 101 | if chat_state is not None: 102 | chat_state.messages = [] 103 | if img_list is not None: 104 | img_list = [] 105 | return None, gr.update(value=None, interactive=True), gr.update(placeholder='Type and press Enter', 106 | interactive=True), gr.update( 107 | value="Upload & Start Chat", interactive=True), chat_state, img_list 108 | 109 | 110 | def upload_img(gr_img, text_input, chat_state): 111 | if gr_img is None: 112 | return None, None, gr.update(interactive=True), chat_state, None 113 | chat_state = CONV_VISION.copy() 114 | img_list = [] 115 | llm_message = chat.upload_img(gr_img, chat_state, img_list) 116 | return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update( 117 | value="Start Chatting", interactive=False), chat_state, img_list 118 | 119 | 120 | def gradio_ask(user_message, chatbot, chat_state): 121 | if len(user_message) == 0: 122 | return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state 123 | if chat_state is None: 124 | chat_state=CONV_VISION.copy() 125 | chat.ask(user_message, chat_state) 126 | chatbot = chatbot + [[user_message, None]] 127 | return '', chatbot, chat_state 128 | 129 | 130 | def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): 131 | llm_message = chat.answer(conv=chat_state, 132 | img_list=img_list, 133 | num_beams=num_beams, 134 | temperature=temperature, 135 | max_new_tokens=300, 136 | max_length=2000) 137 | chatbot[-1][1] = llm_message 138 | return chatbot, chat_state, img_list 139 | 140 | 141 | title = """

Demo of LaVIN

""" 142 | description = """

This is the demo of LaVIN. Upload your images and start chatting!

""" 143 | 144 | 145 | 146 | with gr.Blocks() as demo: 147 | gr.Markdown(title) 148 | gr.Markdown(description) 149 | 150 | with gr.Row(): 151 | with gr.Column(scale=0.5): 152 | image = gr.Image(type="pil") 153 | upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary") 154 | clear = gr.Button("Restart") 155 | 156 | num_beams = gr.Slider( 157 | minimum=1, 158 | maximum=10, 159 | value=1, 160 | step=1, 161 | interactive=True, 162 | label="beam search numbers)", 163 | ) 164 | 165 | temperature = gr.Slider( 166 | minimum=0.1, 167 | maximum=2.0, 168 | value=1.0, 169 | step=0.1, 170 | interactive=True, 171 | label="Temperature", 172 | ) 173 | 174 | with gr.Column(): 175 | chat_state = gr.State() 176 | img_list = gr.State() 177 | chatbot = gr.Chatbot(label='LaVIN-13B') 178 | text_input = gr.Textbox(label='User', placeholder='Type and press Enter', interactive=True) 179 | 180 | upload_button.click(upload_img, [image, text_input, chat_state], 181 | [image, text_input, upload_button, chat_state, img_list]) 182 | 183 | text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then( 184 | gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list] 185 | ) 186 | clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], 187 | queue=False) 188 | 189 | demo.launch(share=True, enable_queue=True,server_name=args.server_name) -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Iterable 4 | 5 | import torch 6 | 7 | import util.misc as misc 8 | import util.lr_sched as lr_sched 9 | 10 | 11 | 12 | def train_one_epoch(model: torch.nn.Module, 13 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 14 | device: torch.device, epoch: int, loss_scaler, 15 | log_writer=None, 16 | args=None): 17 | 18 | model.train(True) 19 | metric_logger = misc.MetricLogger(delimiter=" ") 20 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 21 | header = 'Epoch: [{}]'.format(epoch) 22 | print_freq = 10 23 | 24 | accum_iter = args.accum_iter 25 | 26 | optimizer.zero_grad() 27 | 28 | if log_writer is not None: 29 | print('log_dir: {}'.format(log_writer.log_dir)) 30 | 31 | 32 | 33 | prefix_img = torch.tensor(data_loader.dataset.tokenizer.encode("Image: ", bos=False, eos=False), dtype=torch.int64) 34 | prefix_nonimg = torch.tensor(data_loader.dataset.tokenizer.encode("Image: N/A", bos=False, eos=False), dtype=torch.int64) 35 | 36 | for data_iter_step, (examples, labels, example_mask,images,indicators) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 37 | # we use a per iteration (instead of per epoch) lr scheduler 38 | if data_iter_step % accum_iter == 0: 39 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) 40 | 41 | prefix_img=prefix_img.to(examples.device) 42 | prefix_nonimg=prefix_nonimg.to(examples.device) 43 | c_loss = model(examples, labels,images=images, prefix_img=prefix_img, prefix_nonimg=prefix_nonimg,img_indicators=indicators) 44 | loss = c_loss 45 | loss_value = loss.item() 46 | c_loss_value = c_loss.item() 47 | 48 | 49 | if torch.isnan(loss): 50 | print("NaN loss encountered. 
Skipping this batch.") 51 | continue 52 | 53 | loss = loss/accum_iter 54 | 55 | loss_scaler(loss, optimizer, parameters=model.parameters(), 56 | update_grad=(data_iter_step + 1) % accum_iter == 0,clip_grad=args.clip_grad) 57 | if (data_iter_step + 1) % accum_iter == 0: 58 | optimizer.zero_grad() 59 | 60 | torch.cuda.synchronize() 61 | 62 | metric_logger.update(closs=c_loss_value) 63 | 64 | lr = optimizer.param_groups[0]["lr"] 65 | metric_logger.update(lr=lr) 66 | 67 | loss_value_reduce = misc.all_reduce_mean(loss_value) 68 | c_loss_value_reduce = misc.all_reduce_mean(c_loss_value) 69 | 70 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 71 | """ We use epoch_1000x as the x-axis in tensorboard. 72 | This calibrates different curves when batch size changes. 73 | """ 74 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 75 | log_writer.add_scalar('c_train_loss', c_loss_value_reduce, epoch_1000x) 76 | log_writer.add_scalar('lr', lr, epoch_1000x) 77 | 78 | # gather the stats from all processes 79 | metric_logger.synchronize_between_processes() 80 | print("Averaged stats:", metric_logger) 81 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 82 | 83 | 84 | def val_one_epoch(model: torch.nn.Module, 85 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 86 | device: torch.device, epoch: int, loss_scaler, 87 | log_writer=None, 88 | args=None): 89 | model.eval() 90 | metric_logger = misc.MetricLogger(delimiter=" ") 91 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) 92 | header = 'Epoch: [{}]'.format(epoch) 93 | print_freq = 10 94 | 95 | accum_iter = args.accum_iter 96 | 97 | if log_writer is not None: 98 | print('log_dir: {}'.format(log_writer.log_dir)) 99 | for data_iter_step, (examples, labels, example_mask) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 100 | 101 | with torch.no_grad(): 102 | c_loss = model(examples, labels) 103 | loss = c_loss 104 | loss_value = loss.item() 105 | 106 | c_loss_value = c_loss.item() 107 | 108 | if not math.isfinite(loss_value): 109 | print("Loss is {}, stopping training".format(loss_value)) 110 | sys.exit(1) 111 | 112 | metric_logger.update(closs=c_loss_value) 113 | 114 | lr = optimizer.param_groups[0]["lr"] 115 | metric_logger.update(lr=lr) 116 | 117 | loss_value_reduce = misc.all_reduce_mean(loss_value) 118 | c_loss_value_reduce = misc.all_reduce_mean(c_loss_value) 119 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 120 | """ We use epoch_1000x as the x-axis in tensorboard. 121 | This calibrates different curves when batch size changes. 122 | """ 123 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 124 | log_writer.add_scalar('c_train_loss', c_loss_value_reduce, epoch_1000x) 125 | log_writer.add_scalar('lr', lr, epoch_1000x) 126 | 127 | # gather the stats from all processes 128 | metric_logger.synchronize_between_processes() 129 | print("Averaged stats:", metric_logger) 130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 131 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 
3 | 4 | from typing import Tuple 5 | import os 6 | import sys 7 | import torch 8 | import fire 9 | import time 10 | import json 11 | 12 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel 13 | 14 | from lavin.eval_model import ModelArgs, Transformer 15 | from lavin.tokenizer import Tokenizer 16 | from lavin.generator import LaVIN_Generator 17 | from lavin.mm_adapter import set_MMAdapter,set_Clip_Adapter 18 | from util.base_prompt import build_prompt 19 | from dataclasses import dataclass 20 | import re 21 | import random 22 | 23 | import warnings 24 | import pandas as pd 25 | from PIL import Image 26 | 27 | from torchvision.transforms import transforms 28 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 29 | 30 | from pathlib import Path 31 | import fairscale.nn.model_parallel.initialize as fs_init 32 | import torch.distributed as dist 33 | from util.apply_delta import apply_model_delta_online 34 | 35 | warnings.filterwarnings('ignore') 36 | 37 | @dataclass 38 | class PromptArgs: 39 | prompt_format='QCM-ALE' 40 | use_caption=True 41 | options=["A", "B", "C", "D", "E"] 42 | 43 | def setup_model_parallel() -> Tuple[int, int]: 44 | local_rank = int(os.environ.get("LOCAL_RANK", -1)) 45 | world_size = int(os.environ.get("WORLD_SIZE", -1)) 46 | 47 | torch.distributed.init_process_group("nccl") 48 | initialize_model_parallel(world_size) 49 | torch.cuda.set_device(local_rank) 50 | 51 | # seed must be the same in all processes 52 | torch.manual_seed(1) 53 | return local_rank, world_size 54 | 55 | def _load_and_redistribute_checkpoint(llama_model_path, model_name): 56 | 57 | with open(Path(llama_model_path) / model_name / 'params.json') as f: 58 | params = json.load(f) 59 | tokenizer = Tokenizer(model_path=str(Path(llama_model_path) / 'tokenizer.model')) 60 | print('Using model path: %s, model_name: %s' % (llama_model_path, model_name)) 61 | if model_name=='7B': 62 | checkpoint = torch.load(llama_model_path + model_name + '/consolidated.00.pth', map_location="cpu") 63 | return checkpoint, tokenizer, params 64 | 65 | checkpoints = (Path(llama_model_path) / model_name).glob('*.pth') 66 | checkpoints = sorted(checkpoints) 67 | 68 | mp_world_size = fs_init.get_model_parallel_world_size() 69 | mp_rank = fs_init.get_model_parallel_rank() 70 | if mp_world_size == len(checkpoints): 71 | print('same number of shards of checkpoints and training, loading directly...') 72 | dist.barrier() 73 | print('[rank=%d, mp_rank=%d] loading from %s' % (dist.get_rank(), mp_rank, checkpoints[mp_rank])) 74 | checkpoint = torch.load(checkpoints[mp_rank], map_location='cpu') 75 | else: 76 | print('different number of shards of checkpoints and training, redistributing...') 77 | if dist.get_rank() == 0: 78 | loaded = [] 79 | for x in checkpoints: 80 | print('loading from', x) 81 | loaded.append(torch.load(x, map_location='cpu')) 82 | 83 | full_state_dict = {} 84 | split_dims = {} 85 | 86 | def add_weight_with_split_dim(name, dim): 87 | if dim < 0: # bcast without split 88 | full_state_dict[name] = loaded[0][name].clone() 89 | else: 90 | full_state_dict[name] = torch.cat([x[name] for x in loaded], dim=dim) 91 | for x in loaded: 92 | del x[name] 93 | split_dims[name] = dim 94 | 95 | add_weight_with_split_dim('tok_embeddings.weight', 1) 96 | add_weight_with_split_dim('norm.weight', -1) 97 | add_weight_with_split_dim('output.weight', 0) 98 | for i in range(params['n_layers']): 99 | print('gathering layer %d of %d' % (i, params['n_layers'])) 100 | layer_prefix = f'layers.{i}.' 
101 | bcast_names = [ 102 | 'attention_norm.weight', 103 | 'ffn_norm.weight', 104 | ] 105 | column_parallel_names = [ 106 | 'attention.wq.weight', 107 | 'attention.wk.weight', 108 | 'attention.wv.weight', 109 | 'feed_forward.w1.weight', 110 | 'feed_forward.w3.weight', 111 | ] 112 | row_parallel_names = [ 113 | 'attention.wo.weight', 114 | 'feed_forward.w2.weight', 115 | ] 116 | for key in bcast_names: 117 | add_weight_with_split_dim(layer_prefix + key, -1) 118 | for key in column_parallel_names: 119 | add_weight_with_split_dim(layer_prefix + key, 0) 120 | for key in row_parallel_names: 121 | add_weight_with_split_dim(layer_prefix + key, 1) 122 | 123 | full_state_dict_meta = dict((k, v.shape) for k, v in full_state_dict.items()) 124 | dist.broadcast_object_list([full_state_dict_meta, split_dims], src=0) 125 | 126 | else: # dist.get_rank() != 0 127 | recv_objs = [None, None] 128 | dist.broadcast_object_list(recv_objs, src=0) 129 | full_state_dict_meta, split_dims = recv_objs 130 | 131 | local_state_dict = {} 132 | for k in sorted(full_state_dict_meta.keys()): 133 | print('redistributing weights: %s' % k) 134 | if dist.get_rank() == 0: 135 | value = full_state_dict[k].cuda().half() 136 | del full_state_dict[k] 137 | else: 138 | value = torch.empty(full_state_dict_meta[k], device='cuda', dtype=torch.half) 139 | dist.broadcast(value, src=0) 140 | value = value.cpu() 141 | if split_dims[k] < 0: 142 | local_state_dict[k] = value 143 | else: 144 | dim = split_dims[k] 145 | assert dim >= 0 and dim < value.ndim and value.size(dim) % mp_world_size == 0 146 | shard_size = value.size(dim) // mp_world_size 147 | shard_st, shard_ed = shard_size * mp_rank, shard_size * (mp_rank + 1) 148 | # TODO: make more general 149 | if dim == 0: 150 | value = value[shard_st: shard_ed] 151 | elif dim == 1: 152 | value = value[:, shard_st: shard_ed] 153 | else: 154 | raise NotImplementedError() 155 | local_state_dict[k] = value.clone() 156 | 157 | checkpoint = local_state_dict 158 | 159 | return checkpoint, tokenizer, params 160 | 161 | 162 | 163 | def get_acc_with_contion(res_pd, key, values): 164 | if isinstance(values, list): 165 | total_pd = res_pd[res_pd[key].isin(values)] 166 | else: 167 | total_pd = res_pd[res_pd[key] == values] 168 | correct_pd = total_pd[total_pd['true_false'] == True] 169 | acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100) 170 | return acc 171 | 172 | 173 | def get_scores(result_file, data_file): 174 | # read result file 175 | results = json.load(open(result_file)) 176 | num = len(results) 177 | assert num == 4241 178 | 179 | sqa_data = json.load(open(data_file)) 180 | 181 | # construct pandas data 182 | sqa_pd = pd.DataFrame(sqa_data).T 183 | res_pd = sqa_pd[sqa_pd['split'] == 'test'] # test set 184 | 185 | # update data 186 | for index, row in res_pd.iterrows(): 187 | 188 | res_pd.loc[index, 'no_context'] = True if (not row['hint'] and not row['image']) else False 189 | res_pd.loc[index, 'has_text'] = True if row['hint'] else False 190 | res_pd.loc[index, 'has_image'] = True if row['image'] else False 191 | res_pd.loc[index, 'has_text_image'] = True if (row['hint'] and row['image']) else False 192 | 193 | label = row['answer'] 194 | pred = int(results[index]) 195 | res_pd.loc[index, 'pred'] = pred 196 | res_pd.loc[index, 'true_false'] = (label == pred) 197 | 198 | # accuracy scores 199 | acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100 200 | 201 | scores = { 202 | 'acc_natural': 203 | get_acc_with_contion(res_pd, 'subject', 'natural science'), 204 | 
'acc_social': 205 | get_acc_with_contion(res_pd, 'subject', 'social science'), 206 | 'acc_language': 207 | get_acc_with_contion(res_pd, 'subject', 'language science'), 208 | 'acc_has_text': 209 | get_acc_with_contion(res_pd, 'has_text', True), 210 | 'acc_has_image': 211 | get_acc_with_contion(res_pd, 'has_image', True), 212 | 'acc_no_context': 213 | get_acc_with_contion(res_pd, 'no_context', True), 214 | 'acc_grade_1_6': 215 | get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']), 216 | 'acc_grade_7_12': 217 | get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']), 218 | 'acc_average': 219 | "{:.2f}".format(acc_average), 220 | } 221 | 222 | return scores 223 | 224 | 225 | def print_scores(scores): 226 | latex_output = "" 227 | for key, score in scores.items(): 228 | print(f"{key[4:]}: \t{score}") 229 | latex_output += f"& {score} " 230 | latex_output += "\\\\" 231 | print(latex_output) 232 | 233 | def load( 234 | ckpt_dir: str, 235 | llm_model:str, 236 | tokenizer_path: str, 237 | adapter_path: str, 238 | local_rank: int, 239 | world_size: int, 240 | max_seq_len: int, 241 | max_batch_size: int, 242 | adapter_type: str, 243 | adapter_dim:int, 244 | adapter_scale:float, 245 | hidden_proj:int, 246 | visual_adapter_type: str, 247 | temperature: float, 248 | use_vicuna: bool, 249 | bits: str='16bits', 250 | cpu_load:bool=False, 251 | ) -> LaVIN_Generator: 252 | start_time = time.time() 253 | checkpoint, tokenizer, params = _load_and_redistribute_checkpoint(ckpt_dir, llm_model) 254 | 255 | print("Loading") 256 | adapter_checkpoint = torch.load(adapter_path, map_location="cpu") 257 | 258 | 259 | model_args: ModelArgs = ModelArgs( 260 | max_seq_len=max_seq_len, max_batch_size=max_batch_size,hidden_proj=hidden_proj, **params 261 | ) 262 | model_args.vocab_size = tokenizer.n_words 263 | 264 | if cpu_load: 265 | #cpu load is slow, but is freindly for GPU with limited memory. 266 | torch.set_default_tensor_type(torch.HalfTensor) 267 | else: 268 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 269 | 270 | model = Transformer(model_args) 271 | #delete language encoder 272 | del model.backbone.transformer 273 | 274 | torch.set_default_tensor_type(torch.FloatTensor) 275 | 276 | if bits in ['4bit','8bit']: 277 | from util.quantization import quant_model_bnb 278 | model.layers = quant_model_bnb(model.layers, quant_bit='4bit') 279 | 280 | set_MMAdapter(model, adapter_type, dim=adapter_dim, s=adapter_scale,t=temperature) 281 | set_Clip_Adapter(model.backbone.visual, visual_adapter_type, dim=adapter_dim, s=adapter_scale,t=temperature) 282 | 283 | model.load_state_dict(checkpoint, strict=False) 284 | 285 | if use_vicuna: 286 | apply_model_delta_online(model,'../data/weights/vicuna_'+llm_model) 287 | 288 | 289 | state_dict={} 290 | for key in adapter_checkpoint['model']: 291 | state_dict[key.replace('module.','')]=adapter_checkpoint['model'][key] 292 | 293 | model.load_state_dict(state_dict, strict=False) 294 | model.to(torch.device('cuda')) 295 | 296 | for name, param in model.named_parameters(): 297 | print(name,param.dtype) 298 | generator = LaVIN_Generator(model, tokenizer) 299 | print(f"Loaded in {time.time() - start_time:.2f} seconds") 300 | return generator 301 | 302 | def get_pred_idx(prediction, choices, options): 303 | """ 304 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 305 | """ 306 | if prediction in options[:len(choices)]: 307 | return options.index(prediction) 308 | else: 309 | return random.choice(range(len(choices))) 310 | 311 | def main( 312 | ckpt_dir: str, 313 | tokenizer_path: str, 314 | adapter_path: str, 315 | data_root:str, 316 | caption_file:str, 317 | max_seq_len: int, 318 | max_batch_size: int, 319 | llm_model:str='7B', 320 | generation_temperature: float = 0.1, 321 | top_p: float = 0.75, 322 | split='val', 323 | prompt_format='QCM-ALE', 324 | use_caption=False, 325 | options=["A", "B", "C", "D", "E"], 326 | adapter_type='repattn', 327 | adapter_dim=8, 328 | adapter_scale=1, 329 | n_prompt=10, 330 | hidden_proj=128, 331 | visual_adapter_type='normal', 332 | temperature=10., 333 | use_vicuna=False, 334 | bits: str='16bits', 335 | cpu_load:bool=False, 336 | ): 337 | print(max_batch_size,max_seq_len) 338 | print('use caption: ',use_caption) 339 | local_rank, world_size = setup_model_parallel() 340 | if local_rank > 0: 341 | sys.stdout = open(os.devnull, "w") 342 | 343 | generator = load( 344 | ckpt_dir,llm_model, tokenizer_path, adapter_path, local_rank, world_size, max_seq_len, max_batch_size, 345 | adapter_type,adapter_dim,adapter_scale,hidden_proj,visual_adapter_type, 346 | temperature,use_vicuna,bits=bits,cpu_load=cpu_load) 347 | 348 | print('split: ', split) 349 | problems = json.load(open(os.path.join(data_root, 'problems.json'))) 350 | pid_splits = json.load(open(os.path.join(data_root, 'pid_splits.json'))) 351 | captions = json.load(open(caption_file))["captions"] 352 | image_path = os.path.join(data_root, 'images', split) 353 | qids = pid_splits['%s' % (split)] 354 | total_items=len(qids) 355 | for qid in problems: 356 | problems[qid]['caption'] = captions[qid] if qid in captions else "" 357 | print('total_items: ',total_items) 358 | 359 | 360 | image_transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 361 | 362 | prompt_args=PromptArgs() 363 | prompt_args.prompt_format = prompt_format 364 | prompt_args.use_caption = use_caption 365 | prompt_args.options = options 366 | 367 | pattern = re.compile(r'The answer is ([A-Z]).') 368 | 369 | answers = [] 370 | preds=[] 371 | for i in range(total_items//max_batch_size+1): 372 | print('progresses: ',i,' / ', total_items//max_batch_size+1) 373 | batch_qids=qids[i*max_batch_size:(i+1)*max_batch_size] 374 | if len(batch_qids)==0: 375 | break 376 | indicators = [] 377 | prompts=[] 378 | images = [] 379 | for qid in batch_qids: 380 | prompt,_ = build_prompt(problems, qid, prompt_args) 381 | 382 | answer=problems[qid]["answer"] 383 | if problems[qid]['image'] is not None: 384 | image = Image.open(os.path.join(image_path, qid, 'image.png')).convert('RGB') 385 | image = image_transforms(image) 386 | indicator = 1 387 | else: 388 | image = torch.Tensor(torch.zeros(3, 224, 224).float()) 389 | indicator = 0 390 | prompts.append(prompt) 391 | answers.append(answer) 392 | images.append(image.unsqueeze(0)) 393 | indicators.append(indicator) 394 | images=torch.cat(images,0) 395 | 396 | 397 | results = generator.generate( 398 | prompts,images=images,indicators=indicators, max_gen_len=64, temperature=generation_temperature, top_p=top_p,n_feats=n_prompt 399 | ) 400 | 401 | for result in results: 402 | pred = pattern.findall(result) 403 | 404 | if len(pred) >= 1: 405 | pred = pred[0] # 'A', 'B', ... 
406 | else: 407 | print(result) 408 | pred = "FAILED" 409 | preds.append(pred) 410 | 411 | #evaluations 412 | results={} 413 | correct=0 414 | for i, prediction in enumerate(preds): 415 | pred_idx = get_pred_idx(prediction, problems[qids[i]]["choices"], 416 | prompt_args.options) # 0, 1, ..., 4 417 | if pred_idx == answers[i]: 418 | correct += 1 419 | results[qids[i]] = pred_idx 420 | acc = correct / len(results) * 100 421 | print('overall accuracy: ', acc) 422 | 423 | with open('./preds.json', 'w') as f: 424 | json.dump(results,f) 425 | 426 | scores=get_scores('./preds.json',os.path.join(data_root, 'problems.json')) 427 | print(scores) 428 | import time 429 | with open(str(time.time())+'.txt','w') as f: 430 | f.write(str(scores)) 431 | 432 | 433 | if __name__ == "__main__": 434 | fire.Fire(main) 435 | -------------------------------------------------------------------------------- /eval_mme.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from typing import Tuple 5 | import os 6 | import sys 7 | import torch 8 | import fire 9 | import time 10 | import json 11 | 12 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel 13 | 14 | from lavin.eval_model import ModelArgs, Transformer 15 | from lavin.tokenizer import Tokenizer 16 | from lavin.generator import LaVIN_Generator 17 | from lavin.mm_adapter import set_MMAdapter,set_Clip_Adapter 18 | from util.base_prompt import build_prompt 19 | from dataclasses import dataclass 20 | import re 21 | import random 22 | 23 | import warnings 24 | import pandas as pd 25 | from PIL import Image 26 | 27 | from torchvision.transforms import transforms 28 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 29 | 30 | from pathlib import Path 31 | import fairscale.nn.model_parallel.initialize as fs_init 32 | import torch.distributed as dist 33 | from util.apply_delta import apply_model_delta_online 34 | 35 | warnings.filterwarnings('ignore') 36 | 37 | @dataclass 38 | class PromptArgs: 39 | prompt_format='QCM-ALE' 40 | use_caption=True 41 | options=["A", "B", "C", "D", "E"] 42 | 43 | def setup_model_parallel() -> Tuple[int, int]: 44 | local_rank = int(os.environ.get("LOCAL_RANK", -1)) 45 | world_size = int(os.environ.get("WORLD_SIZE", -1)) 46 | 47 | torch.distributed.init_process_group("nccl") 48 | initialize_model_parallel(world_size) 49 | torch.cuda.set_device(local_rank) 50 | 51 | # seed must be the same in all processes 52 | torch.manual_seed(1) 53 | return local_rank, world_size 54 | 55 | def _load_and_redistribute_checkpoint(llama_model_path, model_name): 56 | 57 | with open(Path(llama_model_path) / model_name / 'params.json') as f: 58 | params = json.load(f) 59 | tokenizer = Tokenizer(model_path=str(Path(llama_model_path) / 'tokenizer.model')) 60 | print('Using model path: %s, model_name: %s' % (llama_model_path, model_name)) 61 | if model_name=='7B': 62 | checkpoint = torch.load(llama_model_path + model_name + '/consolidated.00.pth', map_location="cpu") 63 | return checkpoint, tokenizer, params 64 | 65 | checkpoints = (Path(llama_model_path) / model_name).glob('*.pth') 66 | checkpoints = sorted(checkpoints) 67 | 68 | mp_world_size = fs_init.get_model_parallel_world_size() 69 | mp_rank = fs_init.get_model_parallel_rank() 70 | if mp_world_size == len(checkpoints): 71 | print('same number of 
shards of checkpoints and training, loading directly...') 72 | dist.barrier() 73 | print('[rank=%d, mp_rank=%d] loading from %s' % (dist.get_rank(), mp_rank, checkpoints[mp_rank])) 74 | checkpoint = torch.load(checkpoints[mp_rank], map_location='cpu') 75 | else: 76 | print('different number of shards of checkpoints and training, redistributing...') 77 | if dist.get_rank() == 0: 78 | loaded = [] 79 | for x in checkpoints: 80 | print('loading from', x) 81 | loaded.append(torch.load(x, map_location='cpu')) 82 | 83 | full_state_dict = {} 84 | split_dims = {} 85 | 86 | def add_weight_with_split_dim(name, dim): 87 | if dim < 0: # bcast without split 88 | full_state_dict[name] = loaded[0][name].clone() 89 | else: 90 | full_state_dict[name] = torch.cat([x[name] for x in loaded], dim=dim) 91 | for x in loaded: 92 | del x[name] 93 | split_dims[name] = dim 94 | 95 | add_weight_with_split_dim('tok_embeddings.weight', 1) 96 | add_weight_with_split_dim('norm.weight', -1) 97 | add_weight_with_split_dim('output.weight', 0) 98 | for i in range(params['n_layers']): 99 | print('gathering layer %d of %d' % (i, params['n_layers'])) 100 | layer_prefix = f'layers.{i}.' 101 | bcast_names = [ 102 | 'attention_norm.weight', 103 | 'ffn_norm.weight', 104 | ] 105 | column_parallel_names = [ 106 | 'attention.wq.weight', 107 | 'attention.wk.weight', 108 | 'attention.wv.weight', 109 | 'feed_forward.w1.weight', 110 | 'feed_forward.w3.weight', 111 | ] 112 | row_parallel_names = [ 113 | 'attention.wo.weight', 114 | 'feed_forward.w2.weight', 115 | ] 116 | for key in bcast_names: 117 | add_weight_with_split_dim(layer_prefix + key, -1) 118 | for key in column_parallel_names: 119 | add_weight_with_split_dim(layer_prefix + key, 0) 120 | for key in row_parallel_names: 121 | add_weight_with_split_dim(layer_prefix + key, 1) 122 | 123 | full_state_dict_meta = dict((k, v.shape) for k, v in full_state_dict.items()) 124 | dist.broadcast_object_list([full_state_dict_meta, split_dims], src=0) 125 | 126 | else: # dist.get_rank() != 0 127 | recv_objs = [None, None] 128 | dist.broadcast_object_list(recv_objs, src=0) 129 | full_state_dict_meta, split_dims = recv_objs 130 | 131 | local_state_dict = {} 132 | for k in sorted(full_state_dict_meta.keys()): 133 | print('redistributing weights: %s' % k) 134 | if dist.get_rank() == 0: 135 | value = full_state_dict[k].cuda().half() 136 | del full_state_dict[k] 137 | else: 138 | value = torch.empty(full_state_dict_meta[k], device='cuda', dtype=torch.half) 139 | dist.broadcast(value, src=0) 140 | value = value.cpu() 141 | if split_dims[k] < 0: 142 | local_state_dict[k] = value 143 | else: 144 | dim = split_dims[k] 145 | assert dim >= 0 and dim < value.ndim and value.size(dim) % mp_world_size == 0 146 | shard_size = value.size(dim) // mp_world_size 147 | shard_st, shard_ed = shard_size * mp_rank, shard_size * (mp_rank + 1) 148 | # TODO: make more general 149 | if dim == 0: 150 | value = value[shard_st: shard_ed] 151 | elif dim == 1: 152 | value = value[:, shard_st: shard_ed] 153 | else: 154 | raise NotImplementedError() 155 | local_state_dict[k] = value.clone() 156 | 157 | checkpoint = local_state_dict 158 | 159 | return checkpoint, tokenizer, params 160 | 161 | 162 | 163 | def get_acc_with_contion(res_pd, key, values): 164 | if isinstance(values, list): 165 | total_pd = res_pd[res_pd[key].isin(values)] 166 | else: 167 | total_pd = res_pd[res_pd[key] == values] 168 | correct_pd = total_pd[total_pd['true_false'] == True] 169 | acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100) 170 | return acc 
171 | 172 | 173 | def get_scores(result_file, data_file): 174 | # read result file 175 | results = json.load(open(result_file)) 176 | num = len(results) 177 | assert num == 4241 178 | 179 | sqa_data = json.load(open(data_file)) 180 | 181 | # construct pandas data 182 | sqa_pd = pd.DataFrame(sqa_data).T 183 | res_pd = sqa_pd[sqa_pd['split'] == 'test'] # test set 184 | 185 | # update data 186 | for index, row in res_pd.iterrows(): 187 | 188 | res_pd.loc[index, 'no_context'] = True if (not row['hint'] and not row['image']) else False 189 | res_pd.loc[index, 'has_text'] = True if row['hint'] else False 190 | res_pd.loc[index, 'has_image'] = True if row['image'] else False 191 | res_pd.loc[index, 'has_text_image'] = True if (row['hint'] and row['image']) else False 192 | 193 | label = row['answer'] 194 | pred = int(results[index]) 195 | res_pd.loc[index, 'pred'] = pred 196 | res_pd.loc[index, 'true_false'] = (label == pred) 197 | 198 | # accuracy scores 199 | acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100 200 | 201 | scores = { 202 | 'acc_natural': 203 | get_acc_with_contion(res_pd, 'subject', 'natural science'), 204 | 'acc_social': 205 | get_acc_with_contion(res_pd, 'subject', 'social science'), 206 | 'acc_language': 207 | get_acc_with_contion(res_pd, 'subject', 'language science'), 208 | 'acc_has_text': 209 | get_acc_with_contion(res_pd, 'has_text', True), 210 | 'acc_has_image': 211 | get_acc_with_contion(res_pd, 'has_image', True), 212 | 'acc_no_context': 213 | get_acc_with_contion(res_pd, 'no_context', True), 214 | 'acc_grade_1_6': 215 | get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']), 216 | 'acc_grade_7_12': 217 | get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']), 218 | 'acc_average': 219 | "{:.2f}".format(acc_average), 220 | } 221 | 222 | return scores 223 | 224 | 225 | def print_scores(scores): 226 | latex_output = "" 227 | for key, score in scores.items(): 228 | print(f"{key[4:]}: \t{score}") 229 | latex_output += f"& {score} " 230 | latex_output += "\\\\" 231 | print(latex_output) 232 | 233 | def load( 234 | ckpt_dir: str, 235 | llm_model:str, 236 | tokenizer_path: str, 237 | adapter_path: str, 238 | local_rank: int, 239 | world_size: int, 240 | max_seq_len: int, 241 | max_batch_size: int, 242 | adapter_type: str, 243 | adapter_dim:int, 244 | adapter_scale:float, 245 | hidden_proj:int, 246 | visual_adapter_type: str, 247 | temperature: float, 248 | use_vicuna: bool 249 | ) -> LaVIN_Generator: 250 | start_time = time.time() 251 | checkpoint, tokenizer, params = _load_and_redistribute_checkpoint(ckpt_dir, llm_model) 252 | 253 | print("Loading") 254 | adapter_checkpoint = torch.load(adapter_path, map_location="cpu") 255 | 256 | 257 | model_args: ModelArgs = ModelArgs( 258 | max_seq_len=max_seq_len, max_batch_size=max_batch_size,hidden_proj=hidden_proj, **params 259 | ) 260 | model_args.vocab_size = tokenizer.n_words 261 | 262 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 263 | model = Transformer(model_args) 264 | set_MMAdapter(model, adapter_type, dim=adapter_dim, s=adapter_scale,t=temperature) 265 | set_Clip_Adapter(model.backbone.visual, visual_adapter_type, dim=adapter_dim, s=adapter_scale,t=temperature) 266 | 267 | torch.set_default_tensor_type(torch.FloatTensor) 268 | model.load_state_dict(checkpoint, strict=False) 269 | 270 | if use_vicuna: 271 | apply_model_delta_online(model,'../data/weights/vicuna_'+llm_model) 272 | 273 | state_dict={} 274 | for 
key in adapter_checkpoint['model']: 275 | state_dict[key.replace('module.','')]=adapter_checkpoint['model'][key] 276 | 277 | model.load_state_dict(state_dict, strict=False) 278 | 279 | generator = LaVIN_Generator(model, tokenizer) 280 | print(f"Loaded in {time.time() - start_time:.2f} seconds") 281 | return generator 282 | 283 | def get_pred_idx(prediction, choices, options): 284 | """ 285 | Get the index (e.g. 2) from the prediction (e.g. 'C') 286 | """ 287 | if prediction in options[:len(choices)]: 288 | return options.index(prediction) 289 | else: 290 | return random.choice(range(len(choices))) 291 | 292 | def prepare_data(dir): 293 | if os.path.exists(os.path.join(dir,'images')): 294 | image_dir=os.path.join(dir,'images') 295 | else: 296 | image_dir=dir 297 | if os.path.exists(os.path.join(dir,'questions_answers_YN')): 298 | ann_dir=os.path.join(dir,'questions_answers_YN') 299 | else: 300 | ann_dir=dir 301 | image_list=[] 302 | image_path=[] 303 | for root, dirs, files in os.walk(image_dir): 304 | for file in files: 305 | # check whether the file extension is .jpg or .png 306 | if file.endswith(".jpg") or file.endswith(".png"): 307 | # build the full path of the image file 308 | image_list.append(file) 309 | image_path.append(os.path.join(image_dir,file)) 310 | ann_list=[] 311 | for img_id in image_list: 312 | ann_file=img_id.replace('.jpg','.txt').replace('.png','.txt') 313 | ann={} 314 | with open(os.path.join(ann_dir,ann_file)) as f: 315 | pos=f.readline().split('\t')[0] 316 | neg=f.readline().split('\t')[0] 317 | ann['pos']=pos 318 | ann['neg']=neg 319 | ann_list.append(ann) 320 | 321 | return image_path,ann_list 322 | 323 | 324 | def main( 325 | ckpt_dir: str, 326 | tokenizer_path: str, 327 | adapter_path: str, 328 | data_root:str, 329 | caption_file:str, 330 | max_seq_len: int, 331 | max_batch_size: int, 332 | llm_model:str='7B', 333 | generation_temperature: float = 0.1, 334 | top_p: float = 0.75, 335 | split='val', 336 | prompt_format='QCM-ALE', 337 | use_caption=False, 338 | options=["A", "B", "C", "D", "E"], 339 | adapter_type='repattn', 340 | adapter_dim=8, 341 | adapter_scale=1, 342 | n_prompt=10, 343 | hidden_proj=128, 344 | visual_adapter_type='normal', 345 | temperature=10., 346 | use_vicuna=False, 347 | root_dir_='../data/mme' 348 | ): 349 | print(max_batch_size,max_seq_len) 350 | print('use caption: ',use_caption) 351 | local_rank, world_size = setup_model_parallel() 352 | if local_rank > 0: 353 | sys.stdout = open(os.devnull, "w") 354 | 355 | generator = load( 356 | ckpt_dir,llm_model, tokenizer_path, adapter_path, local_rank, world_size, max_seq_len, max_batch_size, 357 | adapter_type,adapter_dim,adapter_scale,hidden_proj,visual_adapter_type, 358 | temperature,use_vicuna) 359 | 360 | subsets=os.listdir(root_dir_) 361 | total_score=0 362 | cognition_score=0 363 | perception_score=0 364 | for subset in subsets: 365 | root_dir=os.path.join(root_dir_,subset) 366 | print('split: ', subset) 367 | img_list,ann_list=prepare_data(root_dir) 368 | qids=range(len(img_list)) 369 | total_items=len(img_list) 370 | print('total_items: ',total_items) 371 | 372 | 373 | image_transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 374 | 375 | prompt_args=PromptArgs() 376 | prompt_args.prompt_format = prompt_format 377 | prompt_args.use_caption = use_caption 378 | prompt_args.options = options 379 | 380 | pattern = re.compile(r'The answer is ([A-Z]).') 381 | 382 | answers = [] 383 | preds=[] 384 | max_batch_size=8 385 |
for i in range(total_items//max_batch_size+1): 386 | batch_qids=qids[i*max_batch_size:(i+1)*max_batch_size] 387 | if len(batch_qids)==0: 388 | break 389 | indicators = [] 390 | prompts=[] 391 | images = [] 392 | for qid in batch_qids: 393 | #pos 394 | prompt= 'Instruction: '+ ann_list[qid]['pos']+'\n'+\ 395 | 'Response: ' 396 | prompt = prompt.replace(" ", " ").strip() 397 | answer='yes' 398 | image = Image.open(img_list[qid]).convert('RGB') 399 | image = image_transforms(image) 400 | indicator = 1 401 | prompts.append(prompt) 402 | answers.append(answer) 403 | images.append(image.unsqueeze(0)) 404 | indicators.append(indicator) 405 | 406 | #neg 407 | prompt= 'Instruction: '+ ann_list[qid]['neg']+'\n'+\ 408 | 'Response: ' 409 | prompt = prompt.replace(" ", " ").strip() 410 | answer='no' 411 | indicator = 1 412 | prompts.append(prompt) 413 | answers.append(answer) 414 | images.append(image.unsqueeze(0)) 415 | indicators.append(indicator) 416 | images=torch.cat(images,0) 417 | 418 | 419 | results = generator.generate( 420 | prompts,images=images,indicators=indicators, max_gen_len=20, temperature=generation_temperature, top_p=top_p,n_feats=n_prompt 421 | ) 422 | 423 | for result in results: 424 | result=result.lower().strip().split('response:')[1] 425 | if 'yes' in result[:4]: 426 | pred='yes' 427 | elif 'no' in result[:4]: 428 | pred='no' 429 | else: 430 | pred='other' 431 | preds.append(pred) 432 | 433 | #evaluations 434 | correct=0 435 | corrects=[] 436 | assert len(preds)==len(answers) 437 | for i, prediction in enumerate(preds): 438 | if prediction == answers[i]: 439 | correct += 1 440 | corrects.append(1) 441 | else: 442 | corrects.append(0) 443 | import numpy as np 444 | corrects=np.array(corrects) 445 | acc = correct / len(preds) * 100 446 | acc_plus= (corrects.reshape(-1,2).sum(1)==2).sum()/ (len(preds)//2)* 100 447 | total_score+=acc 448 | total_score+=acc_plus 449 | if subset in ['commonsense_reasoning','numerical_calculation','text_translation','code_reasoning']: 450 | cognition_score+=acc_plus 451 | cognition_score+=acc 452 | else: 453 | perception_score+=acc_plus 454 | perception_score+=acc 455 | print('subset: ', subset) 456 | print('overall accuracy: ', acc) 457 | print('overall accuracy+: ', acc_plus) 458 | with open('mme_eval.txt','a') as f: 459 | f.write('subset: '+ subset+'\n') 460 | f.write('accuracy: '+ str(acc)+'\n') 461 | f.write('accuracy+: '+ str(acc_plus)+'\n') 462 | 463 | print('total_score: ',total_score) 464 | print('perception_score: ',perception_score) 465 | print('cognition_score: ',cognition_score) 466 | with open('mme_eval.txt', 'a') as f: 467 | f.write('total_score: ' + str(total_score) + '\n') 468 | f.write('perception_score: ' + str(perception_score) + '\n') 469 | f.write('cognition_score: ' + str(cognition_score) + '\n') 470 | 471 | if __name__ == "__main__": 472 | fire.Fire(main) 473 | -------------------------------------------------------------------------------- /lavin/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 
3 | 4 | from .generator import LaVIN_Generator 5 | from .model import ModelArgs, Transformer 6 | from .tokenizer import Tokenizer 7 | -------------------------------------------------------------------------------- /lavin/eval_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from typing import Optional, Tuple 5 | from dataclasses import dataclass 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | import clip 12 | import fairscale.nn.model_parallel.initialize as fs_init 13 | from fairscale.nn.model_parallel.layers import ( 14 | ParallelEmbedding, 15 | RowParallelLinear, 16 | ColumnParallelLinear, 17 | ) 18 | from lavin.model import AdapterMLP 19 | 20 | @dataclass 21 | class ModelArgs: 22 | dim: int = 512 23 | n_layers: int = 8 24 | n_heads: int = 8 25 | vocab_size: int = -1 # defined later by tokenizer 26 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 27 | norm_eps: float = 1e-5 28 | hidden_proj: int=128 29 | 30 | max_batch_size: int = 32 31 | max_seq_len: int = 2048 32 | 33 | 34 | 35 | class RMSNorm(torch.nn.Module): 36 | def __init__(self, dim: int, eps: float = 1e-6): 37 | super().__init__() 38 | self.eps = eps 39 | self.weight = nn.Parameter(torch.ones(dim)) 40 | 41 | def _norm(self, x): 42 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 43 | 44 | def forward(self, x): 45 | output = self._norm(x.float()).type_as(x) 46 | return output * self.weight 47 | 48 | 49 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 50 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 51 | t = torch.arange(end, device=freqs.device) # type: ignore 52 | freqs = torch.outer(t, freqs).float() # type: ignore 53 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 54 | return freqs_cis 55 | 56 | 57 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 58 | ndim = x.ndim 59 | assert 0 <= 1 < ndim 60 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 61 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 62 | return freqs_cis.view(*shape) 63 | 64 | 65 | def apply_rotary_emb( 66 | xq: torch.Tensor, 67 | xk: torch.Tensor, 68 | freqs_cis: torch.Tensor, 69 | ) -> Tuple[torch.Tensor, torch.Tensor]: 70 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 71 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 72 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 73 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 74 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 75 | return xq_out.type_as(xq), xk_out.type_as(xk) 76 | 77 | 78 | class Attention(nn.Module): 79 | def __init__(self, args: ModelArgs): 80 | super().__init__() 81 | 82 | self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size() 83 | self.head_dim = args.dim // args.n_heads 84 | 85 | self.wq = ColumnParallelLinear( 86 | args.dim, 87 | args.n_heads * self.head_dim, 88 | bias=False, 89 | gather_output=False, 90 | init_method=lambda x: x, 91 | ) 92 | self.wk = ColumnParallelLinear( 93 | args.dim, 94 | args.n_heads * self.head_dim, 95 | bias=False, 96 | gather_output=False, 97 | init_method=lambda x: x, 98 | ) 99 | self.wv = ColumnParallelLinear( 
100 | args.dim, 101 | args.n_heads * self.head_dim, 102 | bias=False, 103 | gather_output=False, 104 | init_method=lambda x: x, 105 | ) 106 | self.wo = RowParallelLinear( 107 | args.n_heads * self.head_dim, 108 | args.dim, 109 | bias=False, 110 | input_is_parallel=True, 111 | init_method=lambda x: x, 112 | ) 113 | 114 | self.cache_k = torch.zeros( 115 | (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim) 116 | ).cuda() 117 | self.cache_v = torch.zeros( 118 | (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim) 119 | ).cuda() 120 | self.gate = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1)) 121 | 122 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None): 123 | bsz, seqlen, _ = x.shape 124 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 125 | 126 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 127 | xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) 128 | xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) 129 | 130 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 131 | 132 | self.cache_k = self.cache_k.to(xq) 133 | self.cache_v = self.cache_v.to(xq) 134 | 135 | #add modilaty embedding 136 | if start_pos==0: 137 | self.cache_k[:bsz, start_pos : start_pos + seqlen-1] = xk[:,1:] 138 | self.cache_v[:bsz, start_pos : start_pos + seqlen-1] = xv[:,1:] 139 | 140 | keys = xk 141 | values = xv 142 | else: 143 | self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk 144 | self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv 145 | 146 | keys = self.cache_k[:bsz, : start_pos + seqlen] 147 | values = self.cache_v[:bsz, : start_pos + seqlen] 148 | 149 | 150 | xq = xq.transpose(1, 2) 151 | keys = keys.transpose(1, 2) 152 | values = values.transpose(1, 2) 153 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) 154 | if mask is not None: 155 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) 156 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 157 | output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) 158 | 159 | output = output.transpose( 160 | 1, 2 161 | ).contiguous().view(bsz, seqlen, -1) 162 | 163 | return self.wo(output) 164 | 165 | 166 | class FeedForward(nn.Module): 167 | def __init__( 168 | self, 169 | dim: int, 170 | hidden_dim: int, 171 | multiple_of: int, 172 | ): 173 | super().__init__() 174 | hidden_dim = int(2 * hidden_dim / 3) 175 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 176 | 177 | self.w1 = ColumnParallelLinear( 178 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x 179 | ) 180 | self.w2 = RowParallelLinear( 181 | hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x 182 | ) 183 | self.w3 = ColumnParallelLinear( 184 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x 185 | ) 186 | 187 | def forward(self, x): 188 | return self.w2(F.silu(self.w1(x)) * self.w3(x)) 189 | 190 | 191 | class TransformerBlock(nn.Module): 192 | def __init__(self, layer_id: int, args: ModelArgs): 193 | super().__init__() 194 | self.n_heads = args.n_heads 195 | self.dim = args.dim 196 | self.head_dim = args.dim // args.n_heads 197 | self.attention = Attention(args) 198 | self.feed_forward = FeedForward( 199 | dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of 200 | ) 201 | self.layer_id = layer_id 202 | self.attention_norm = RMSNorm(args.dim, 
eps=args.norm_eps) 203 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 204 | 205 | self.drop_path = nn.Identity() 206 | self.cache_weights = torch.zeros( 207 | (args.max_batch_size, 2) 208 | ).cuda() 209 | self.cache_weights_ffn = torch.zeros( 210 | (args.max_batch_size, 2) 211 | ).cuda() 212 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None): 213 | h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter) 214 | out = h + self.feed_forward.forward(self.ffn_norm(h)) 215 | return out 216 | 217 | from torch.cuda.amp import autocast 218 | class Transformer(nn.Module): 219 | def __init__(self, params: ModelArgs): 220 | super().__init__() 221 | self.params = params 222 | self.vocab_size = params.vocab_size 223 | self.n_layers = params.n_layers 224 | 225 | self.tok_embeddings = ParallelEmbedding( 226 | params.vocab_size, params.dim, init_method=lambda x: x 227 | ) 228 | 229 | self.layers = torch.nn.ModuleList() 230 | for layer_id in range(params.n_layers): 231 | self.layers.append(TransformerBlock(layer_id, params)) 232 | 233 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 234 | self.output = ColumnParallelLinear( 235 | params.dim, params.vocab_size, bias=False, init_method=lambda x: x 236 | ) 237 | 238 | self.freqs_cis = precompute_freqs_cis( 239 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 240 | ) 241 | 242 | self.backbone = clip.load('ViT-L/14')[0] 243 | 244 | self.adapter_proj = AdapterMLP(1024, params.hidden_proj, params.dim).float() 245 | self.adapter_modality_embedding=nn.Embedding(2,params.dim).float() 246 | 247 | @torch.inference_mode() 248 | def forward(self, tokens: torch.Tensor, start_pos: int): 249 | with autocast(): 250 | _bsz, seqlen,_ = tokens.shape 251 | # h = self.tok_embeddings(tokens) 252 | h=tokens 253 | self.freqs_cis = self.freqs_cis.to(h.device) 254 | freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen] 255 | 256 | mask = None 257 | if seqlen > 1: 258 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device) 259 | mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h) 260 | # mask decision token 261 | mask[:, :, 1:, 0] = float("-inf") 262 | 263 | for layer in self.layers: 264 | h = layer(h, start_pos, freqs_cis,mask) 265 | 266 | h = self.norm(h) 267 | output = self.output(h[:, -1, :]) # only compute last logits 268 | return output.float() 269 | -------------------------------------------------------------------------------- /lavin/generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 
3 | 4 | from typing import List 5 | 6 | import torch 7 | 8 | from lavin.tokenizer import Tokenizer 9 | from lavin.eval_model import Transformer 10 | from torch.cuda.amp import autocast 11 | 12 | class LaVIN_Generator: 13 | def __init__(self, model: Transformer, tokenizer: Tokenizer): 14 | self.model = model 15 | self.tokenizer = tokenizer 16 | # self.backbone = clip.load('ViT-B/16', device='cpu')[0] 17 | 18 | def insert_image_embeds(self,examples,image_embeds,prefix_img,prefix_nonimg,img_indicators): 19 | _bsz, seqlen,_ = examples.shape 20 | new_examples=[] 21 | for i, example in enumerate(examples): 22 | if img_indicators[i]>0.: 23 | new_example=torch.cat([example[:1],prefix_img,image_embeds[i],example[1:]],0) 24 | new_example = new_example[:seqlen] 25 | else: 26 | new_example=torch.cat([example[:1],prefix_nonimg,example[1:]],0) 27 | new_example = new_example[:seqlen] 28 | new_examples.append(new_example.unsqueeze(0)) 29 | new_examples = torch.cat(new_examples, 0) 30 | return new_examples 31 | 32 | @torch.inference_mode() 33 | def generate( 34 | self, 35 | prompts: List[str], 36 | images: torch.Tensor, 37 | indicators: List[int], 38 | max_gen_len: int, 39 | n_feats: int=3, 40 | temperature: float = 0.8, 41 | top_p: float = 0.95, 42 | ) -> List[str]: 43 | bsz = len(prompts) 44 | params = self.model.params 45 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size) 46 | self.model.eval() 47 | 48 | prefix_img_token = self.tokenizer.encode("Image: ", bos=True, eos=False) 49 | non_prefix_img_token= self.tokenizer.encode("Image: N/A", bos=True, eos=False) 50 | 51 | images=images.cuda() 52 | self.model.backbone.cuda() 53 | 54 | image_embeds= self.model.backbone.encode_image(images) 55 | image_embeds=self.model.adapter_proj(image_embeds) 56 | 57 | 58 | prompt_tokens=[] 59 | for i,x in enumerate(prompts): 60 | if indicators[i]==1: 61 | token_idx=prefix_img_token+[0]*n_feats+self.tokenizer.encode(x, bos=False, eos=False) 62 | else: 63 | token_idx = non_prefix_img_token + self.tokenizer.encode(x, bos=False, eos=False) 64 | prompt_tokens.append(token_idx) 65 | 66 | 67 | min_prompt_size = min([len(t) for t in prompt_tokens]) 68 | max_prompt_size = max([len(t) for t in prompt_tokens]) 69 | 70 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_size) 71 | 72 | tokens = torch.full((bsz, total_len), 0).cuda().long() 73 | input_text_mask=torch.zeros_like(tokens).bool() 74 | 75 | for k, t in enumerate(prompt_tokens): 76 | t=t[:total_len] 77 | tokens[k, : len(t)] = torch.tensor(t).long() 78 | input_text_mask[k,:len(t)]=True 79 | 80 | token_embeds=self.model.tok_embeddings(tokens) 81 | indicators=torch.Tensor(indicators).cuda().long() 82 | modality_embedding=self.model.adapter_modality_embedding(indicators).unsqueeze(1) 83 | 84 | for i in range(len(token_embeds)): 85 | if indicators[i]==1: 86 | pos=len(prefix_img_token) 87 | #insert image emebedding into the sequence 88 | image_token_embed=torch.cat([token_embeds[i,:pos],image_embeds[i],token_embeds[i,pos+n_feats:]],0) 89 | token_embeds[i]=image_token_embed 90 | 91 | 92 | 93 | start_pos = min_prompt_size 94 | prev_pos = 0 95 | for cur_pos in range(start_pos, total_len): 96 | 97 | if prev_pos==0: 98 | h=torch.cat([modality_embedding,token_embeds[:,prev_pos:cur_pos]],1) 99 | else: 100 | h=token_embeds[:,prev_pos:cur_pos] 101 | logits = self.model.forward(h, prev_pos) 102 | if temperature > 0: 103 | probs = torch.softmax(logits / temperature, dim=-1) 104 | next_token = sample_top_p(probs, top_p) 105 | else: 106 | next_token = 
torch.argmax(logits, dim=-1) 107 | next_token = next_token.reshape(-1) 108 | # only replace token if prompt has already been generated 109 | 110 | next_token_embeds = torch.where( 111 | input_text_mask[:, cur_pos,None], token_embeds[:, cur_pos], self.model.tok_embeddings(next_token) 112 | ) 113 | token_embeds[:,cur_pos]=next_token_embeds 114 | 115 | next_token = torch.where( 116 | input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token 117 | ) 118 | tokens[:, cur_pos] = next_token 119 | 120 | prev_pos = cur_pos 121 | 122 | decoded = [] 123 | for i, t in enumerate(tokens.tolist()): 124 | # cut to max gen len 125 | t = t[: len(prompt_tokens[i]) + max_gen_len] 126 | # cut to eos tok if any 127 | try: 128 | t = t[: t.index(self.tokenizer.eos_id)] 129 | except ValueError: 130 | pass 131 | decoded.append(self.tokenizer.decode(t)) 132 | 133 | 134 | return decoded 135 | 136 | 137 | def sample_top_p(probs, p): 138 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 139 | probs_sum = torch.cumsum(probs_sort, dim=-1) 140 | mask = probs_sum - probs_sort > p 141 | probs_sort[mask] = 0.0 142 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 143 | next_token = torch.multinomial(probs_sort, num_samples=1) 144 | next_token = torch.gather(probs_idx, -1, next_token) 145 | return next_token 146 | -------------------------------------------------------------------------------- /lavin/mm_adaptation.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | import json 5 | from lavin import ModelArgs, Tokenizer, Transformer 6 | from lavin.mm_adapter import set_MMAdapter,set_Clip_Adapter 7 | 8 | from pathlib import Path 9 | from util.apply_delta import apply_model_delta_online 10 | 11 | 12 | 13 | def _load_and_redistribute_checkpoint(llama_model_path, model_name): 14 | 15 | with open(Path(llama_model_path) / model_name / 'params.json') as f: 16 | params = json.load(f) 17 | tokenizer = Tokenizer(model_path=str(Path(llama_model_path) / 'tokenizer.model')) 18 | print('Using model path: %s, model_name: %s' % (llama_model_path, model_name)) 19 | if model_name=='7B': 20 | checkpoint = torch.load(llama_model_path + model_name + '/consolidated.00.pth', map_location="cpu") 21 | return checkpoint, tokenizer, params 22 | 23 | 24 | checkpoints = (Path(llama_model_path) / model_name).glob('*.pth') 25 | checkpoints = sorted(checkpoints) 26 | 27 | 28 | loaded = [] 29 | for x in checkpoints: 30 | print('loading from', x) 31 | loaded.append(torch.load(x, map_location='cpu')) 32 | 33 | full_state_dict = {} 34 | split_dims = {} 35 | 36 | def add_weight_with_split_dim(name, dim): 37 | if dim < 0: # bcast without split 38 | full_state_dict[name] = loaded[0][name].clone() 39 | else: 40 | full_state_dict[name] = torch.cat([x[name] for x in loaded], dim=dim) 41 | for x in loaded: 42 | del x[name] 43 | split_dims[name] = dim 44 | 45 | add_weight_with_split_dim('tok_embeddings.weight', 1) 46 | add_weight_with_split_dim('norm.weight', -1) 47 | add_weight_with_split_dim('output.weight', 0) 48 | for i in range(params['n_layers']): 49 | print('gathering layer %d of %d' % (i, params['n_layers'])) 50 | layer_prefix = f'layers.{i}.' 
51 | bcast_names = [ 52 | 'attention_norm.weight', 53 | 'ffn_norm.weight', 54 | ] 55 | column_parallel_names = [ 56 | 'attention.wq.weight', 57 | 'attention.wk.weight', 58 | 'attention.wv.weight', 59 | 'feed_forward.w1.weight', 60 | 'feed_forward.w3.weight', 61 | ] 62 | row_parallel_names = [ 63 | 'attention.wo.weight', 64 | 'feed_forward.w2.weight', 65 | ] 66 | for key in bcast_names: 67 | add_weight_with_split_dim(layer_prefix + key, -1) 68 | for key in column_parallel_names: 69 | add_weight_with_split_dim(layer_prefix + key, 0) 70 | for key in row_parallel_names: 71 | add_weight_with_split_dim(layer_prefix + key, 1) 72 | 73 | checkpoint=full_state_dict 74 | 75 | 76 | return checkpoint, tokenizer, params 77 | 78 | def LaVIN(args): 79 | 80 | llama_model_path =args.llama_model_path 81 | model_name = args.llm_model 82 | 83 | checkpoint, tokenizer, params = _load_and_redistribute_checkpoint(llama_model_path, model_name) 84 | 85 | 86 | model_args: ModelArgs = ModelArgs( 87 | max_seq_len=args.max_seq_len, max_batch_size=32,hidden_proj=args.hidden_proj,drop_path=args.drop_path, **params 88 | ) 89 | 90 | model_args.vocab_size = tokenizer.n_words 91 | 92 | if args.cpu_load: 93 | #cpu load is slow, but is freindly for GPU with limited memory. 94 | torch.set_default_tensor_type(torch.HalfTensor) 95 | else: 96 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 97 | 98 | llama = Transformer(model_args) 99 | 100 | #delete language encoder 101 | del llama.backbone.transformer 102 | 103 | torch.set_default_tensor_type(torch.FloatTensor) 104 | 105 | if args.bits in ['4bit','8bit']: 106 | from util.quantization import quant_model_bnb 107 | llama.layers=quant_model_bnb(llama.layers,quant_bit=args.bits) 108 | 109 | llama.load_state_dict(checkpoint, strict=False) 110 | if args.use_vicuna: 111 | apply_model_delta_online(llama,'../data/weights/vicuna_'+args.llm_model) 112 | 113 | 114 | if args.adapter_type=='block' or args.adapter_type=='attn': 115 | set_MMAdapter(llama,args.adapter_type,dim=args.adapter_dim,s=args.adapter_scale,t=args.temperature,gradient_checkpointing=args.gradient_checkpointing) 116 | set_Clip_Adapter(llama.backbone.visual,args.visual_adapter_type,dim=args.adapter_dim,s=args.adapter_scale,t=args.temperature) 117 | 118 | 119 | 120 | 121 | learnable_keys=['adapter'] 122 | total=0. 123 | trainable_names=[] 124 | for name, param in llama.named_parameters(): 125 | for key in learnable_keys: 126 | 127 | if key in name: 128 | param.requires_grad = True 129 | param.data = param.data.float() 130 | total += param.nelement() 131 | trainable_names.append(name) 132 | else: 133 | param.requires_grad = False 134 | print(trainable_names) 135 | print(' + Number of trainable params: %.2fM' % (total / 1e6)) 136 | return llama 137 | 138 | -------------------------------------------------------------------------------- /lavin/mm_adapter.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | import lavin 5 | from typing import Optional, Tuple 6 | from torch.cuda.amp import autocast 7 | import lavin.eval_model 8 | 9 | 10 | class RepAdapter_Router(nn.Module): 11 | """ Pytorch Implemention of RepAdapter for 1d tensor""" 12 | 13 | def __init__( 14 | self, 15 | in_features=768, 16 | hidden_dim=8, 17 | groups=2, 18 | scale=1, 19 | t=10. 
20 | ): 21 | super().__init__() 22 | self.conv_A=nn.Conv1d(in_features,hidden_dim,1,groups=1,bias=True) 23 | self.conv_B = nn.Conv1d(hidden_dim, in_features, 1, groups=groups, bias=True) 24 | 25 | self.conv_D = nn.Conv1d(hidden_dim, in_features, 1, groups=groups, bias=True) 26 | 27 | self.expert_weights=nn.Linear(in_features,2) 28 | 29 | self.dropout=nn.Dropout(0.1) 30 | self.groups=groups 31 | self.scale=scale 32 | self.t=t 33 | 34 | nn.init.xavier_uniform_( self.conv_A.weight) 35 | nn.init.zeros_(self.conv_A.bias) 36 | nn.init.zeros_(self.conv_B.weight) 37 | nn.init.zeros_(self.conv_B.bias) 38 | 39 | 40 | nn.init.zeros_(self.conv_D.weight) 41 | nn.init.zeros_(self.conv_D.bias) 42 | 43 | def forward(self, x,weights=None): 44 | with autocast(): 45 | if weights is None: 46 | weights=torch.softmax(self.expert_weights(x[:,0])/self.t,-1).half() 47 | x=x.transpose(1,2) 48 | x_=self.dropout(self.conv_A(x)) 49 | x=self.conv_B(x_)*self.scale*weights[:,0,None,None]+self.conv_D(x_)*self.scale*weights[:,1,None,None]+x 50 | x=x.transpose(1,2).contiguous() 51 | return x 52 | 53 | 54 | 55 | class RepAdapter(nn.Module): 56 | """ 57 | Pytorch Implemention of RepAdapter for 1d tensor 58 | copy from https://github.com/luogen1996/RepAdapter/blob/main/repadapter.py 59 | """ 60 | 61 | def __init__( 62 | self, 63 | in_features=768, 64 | hidden_dim=8, 65 | groups=2, 66 | scale=1 67 | ): 68 | super().__init__() 69 | self.conv_A=nn.Conv1d(in_features,hidden_dim,1,groups=1,bias=True) 70 | self.conv_B = nn.Conv1d(hidden_dim, in_features, 1, groups=groups, bias=True) 71 | 72 | self.dropout=nn.Dropout(0.1) 73 | self.groups=groups 74 | self.scale=scale 75 | 76 | nn.init.xavier_uniform_( self.conv_A.weight) 77 | nn.init.zeros_(self.conv_A.bias) 78 | nn.init.zeros_(self.conv_B.weight) 79 | nn.init.zeros_(self.conv_B.bias) 80 | 81 | def forward(self, x,weights=None): 82 | with autocast(): 83 | x=x.transpose(1,2) 84 | x=self.conv_B(self.dropout(self.conv_A(x))) 85 | x=x.transpose(1,2).contiguous() 86 | return x.float() 87 | 88 | 89 | def forward_llama_block(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None): 90 | if self.training and self.gradient_checkpointing: 91 | h = x + self.drop_path(torch.utils.checkpoint.checkpoint(self.attention, self.adapter_attn(self.attention_norm(x)), start_pos, freqs_cis, mask)) 92 | out = h + self.drop_path(torch.utils.checkpoint.checkpoint(self.feed_forward, self.adapter_mlp(self.ffn_norm(h)))) 93 | else: 94 | h = x + self.drop_path(self.attention.forward(self.adapter_attn(self.attention_norm(x)), start_pos, freqs_cis, mask, adapter)) 95 | out = h + self.drop_path(self.feed_forward.forward(self.adapter_mlp(self.ffn_norm(h)))) 96 | return out 97 | 98 | def forward_llama_attn(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None): 99 | if self.training and self.gradient_checkpointing: 100 | h = x + self.drop_path(torch.utils.checkpoint.checkpoint(self.attention, self.adapter_attn(self.attention_norm(x)), start_pos, freqs_cis, mask)) 101 | out = h + self.drop_path(torch.utils.checkpoint.checkpoint(self.feed_forward, self.ffn_norm(h))) 102 | else: 103 | h = x + self.drop_path(self.attention.forward(self.adapter_attn(self.attention_norm(x)), start_pos, freqs_cis, mask, adapter)) 104 | out = h + self.drop_path(self.feed_forward.forward(self.ffn_norm(h))) 105 | return out 106 | def forward_llama_attn_cache(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: 
Optional[torch.Tensor], adapter=None): 107 | bs_=x.shape[0] 108 | if start_pos==0: 109 | self.cache_weights[:bs_]=torch.softmax(self.adapter_attn.expert_weights(self.attention_norm(x)[:,0].float())/self.t,-1).half() 110 | h = x + self.drop_path(self.attention.forward(self.adapter_attn(self.attention_norm(x),weights=self.cache_weights[:bs_]), start_pos, freqs_cis, mask, adapter)) 111 | out = h + self.drop_path(self.feed_forward.forward(self.ffn_norm(h))) 112 | return out 113 | 114 | def forward_llama_block_cache(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None): 115 | bs_=x.shape[0] 116 | if start_pos==0: 117 | self.cache_weights[:bs_]=torch.softmax(self.adapter_attn.expert_weights(self.attention_norm(x)[:,0].float())/self.t,-1).half() 118 | self.cache_weights_ffn[:bs_]=torch.softmax(self.adapter_mlp.expert_weights(self.ffn_norm(x)[:,0].float())/self.t,-1).half() 119 | h = x + self.drop_path(self.attention.forward(self.adapter_attn(self.attention_norm(x),weights=self.cache_weights[:bs_]), start_pos, freqs_cis, mask, adapter)) 120 | out = h + self.drop_path(self.feed_forward.forward(self.adapter_mlp(self.ffn_norm(h),self.cache_weights_ffn[:bs_]))) 121 | return out 122 | 123 | def forward_clip(self, x: torch.Tensor): 124 | x = x + self.attention(self.adapter_attn(self.ln_1(x))) 125 | x = x + self.mlp(self.ln_2(x)) 126 | return x 127 | 128 | def forward_clip_full(self, x: torch.Tensor): 129 | x = x + self.attention(self.adapter_attn(self.ln_1(x))) 130 | x = x + self.mlp(self.adapter_mlp(self.ln_2(x))) 131 | return x 132 | 133 | 134 | def set_MMAdapter(model, method, dim=8, s=1, set_forward=True,t=10,gradient_checkpointing=False): 135 | if method == 'block': 136 | # not support right now 137 | assert NotImplementedError 138 | for _ in model.children(): 139 | if type(_) == lavin.model.TransformerBlock or type(_) == lavin.eval_model.TransformerBlock: 140 | _.adapter_attn = RepAdapter_Router(_.dim,hidden_dim=dim,scale=s,t=t) 141 | _.adapter_mlp = RepAdapter_Router(_.dim,hidden_dim=dim,scale=s,t=t) 142 | _.s = s 143 | _.t = t 144 | _.gradient_checkpointing=gradient_checkpointing 145 | if type(_) == lavin.eval_model.TransformerBlock: 146 | bound_method = forward_llama_block_cache.__get__(_, _.__class__) 147 | else: 148 | bound_method = forward_llama_block.__get__(_, _.__class__) 149 | if set_forward: 150 | setattr(_, 'forward', bound_method) 151 | elif len(list(_.children())) != 0: 152 | set_MMAdapter(_, method, dim, s,set_forward=set_forward,t=t,gradient_checkpointing=gradient_checkpointing) 153 | 154 | else: 155 | for _ in model.children(): 156 | if type(_) == lavin.model.TransformerBlock or type(_) == lavin.eval_model.TransformerBlock: 157 | _.adapter_attn = RepAdapter_Router(_.dim,hidden_dim=dim,scale=s,t=t) 158 | _.s = s 159 | _.t=t 160 | _.gradient_checkpointing = gradient_checkpointing 161 | if type(_) == lavin.eval_model.TransformerBlock: 162 | bound_method = forward_llama_attn_cache.__get__(_, _.__class__) 163 | else: 164 | bound_method = forward_llama_attn.__get__(_, _.__class__) 165 | if set_forward: 166 | setattr(_, 'forward', bound_method) 167 | elif len(list(_.children())) != 0: 168 | set_MMAdapter(_, method, dim, s, set_forward=set_forward,t=t,gradient_checkpointing=gradient_checkpointing) 169 | 170 | 171 | from clip.model import ResidualAttentionBlock 172 | def set_Clip_Adapter(model, method, dim=8, s=1, set_forward=True, t=10.): 173 | for _ in model.children(): 174 | if type(_) == ResidualAttentionBlock: 175 | if 
method=='router': 176 | _.adapter_attn = RepAdapter_Router(1024, hidden_dim=dim, scale=s, t=t) 177 | elif method=='router_block': 178 | _.adapter_attn = RepAdapter_Router(1024, hidden_dim=dim, scale=s, t=t) 179 | _.adapter_mlp = RepAdapter_Router(1024, hidden_dim=dim, scale=s, t=t) 180 | else: 181 | _.adapter_attn = RepAdapter(1024, hidden_dim=dim, scale=s) 182 | _.s = s 183 | if method=='router_block': 184 | bound_method = forward_clip_full.__get__(_, _.__class__) 185 | else: 186 | bound_method = forward_clip.__get__(_, _.__class__) 187 | if set_forward: 188 | setattr(_, 'forward', bound_method) 189 | elif len(list(_.children())) != 0: 190 | set_Clip_Adapter(_, method, dim, s, set_forward=set_forward, t=t) 191 | -------------------------------------------------------------------------------- /lavin/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from typing import Optional, Tuple 5 | from dataclasses import dataclass 6 | import math 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | import fairscale.nn.model_parallel.initialize as fs_init 13 | from fairscale.nn.model_parallel.layers import ( 14 | ParallelEmbedding, 15 | RowParallelLinear, 16 | ColumnParallelLinear, 17 | ) 18 | 19 | from torch.nn import Embedding, Linear 20 | import torch 21 | import pdb 22 | from timm.models.layers import DropPath 23 | import clip 24 | from torch.cuda.amp import autocast 25 | @dataclass 26 | class ModelArgs: 27 | dim: int = 512 28 | n_layers: int = 8 29 | n_heads: int = 8 30 | vocab_size: int = -1 # defined later by tokenizer 31 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 32 | norm_eps: float = 1e-5 33 | hidden_proj: int=128 34 | 35 | max_batch_size: int = 32 36 | max_seq_len: int = 2048 37 | drop_path: float=0. 
38 | 39 | 40 | class RMSNorm(torch.nn.Module): 41 | def __init__(self, dim: int, eps: float = 1e-6): 42 | super().__init__() 43 | self.eps = eps 44 | self.weight = nn.Parameter(torch.ones(dim)) 45 | 46 | def _norm(self, x): 47 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 48 | 49 | def forward(self, x): 50 | output = self._norm(x.float()).type_as(x) 51 | return output * self.weight 52 | 53 | 54 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): 55 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 56 | t = torch.arange(end, device=freqs.device) # type: ignore 57 | freqs = torch.outer(t, freqs).float() # type: ignore 58 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 59 | return freqs_cis 60 | 61 | 62 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 63 | ndim = x.ndim 64 | assert 0 <= 1 < ndim 65 | assert freqs_cis.shape == (x.shape[1], x.shape[-1]) 66 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] 67 | return freqs_cis.view(*shape) 68 | 69 | 70 | def apply_rotary_emb( 71 | xq: torch.Tensor, 72 | xk: torch.Tensor, 73 | freqs_cis: torch.Tensor, 74 | ) -> Tuple[torch.Tensor, torch.Tensor]: 75 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 76 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 77 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 78 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 79 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 80 | return xq_out.type_as(xq), xk_out.type_as(xk) 81 | 82 | 83 | class Attention(nn.Module): 84 | def __init__(self, args: ModelArgs): 85 | super().__init__() 86 | 87 | self.n_local_heads = args.n_heads 88 | self.head_dim = args.dim // args.n_heads 89 | 90 | #modified bias for reparameterizing 91 | self.wq = Linear( 92 | args.dim, 93 | args.n_heads * self.head_dim, 94 | bias=False 95 | ) 96 | self.wk = Linear( 97 | args.dim, 98 | args.n_heads * self.head_dim, 99 | bias=False 100 | ) 101 | self.wv = Linear( 102 | args.dim, 103 | args.n_heads * self.head_dim, 104 | bias=False 105 | ) 106 | self.wo = Linear( 107 | args.n_heads * self.head_dim, 108 | args.dim, 109 | bias=False 110 | ) 111 | 112 | # self.cache_k = torch.zeros( 113 | # (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim) 114 | # ).cuda() 115 | # self.cache_v = torch.zeros( 116 | # (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim) 117 | # ).cuda() 118 | # self.gate = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1)) 119 | 120 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None): 121 | 122 | bsz, seqlen, _ = x.shape 123 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 124 | 125 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) 126 | xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim) 127 | xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim) 128 | 129 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) 130 | 131 | keys = xk 132 | values = xv 133 | 134 | 135 | xq = xq.transpose(1, 2) 136 | keys = keys.transpose(1, 2) 137 | values = values.transpose(1, 2) 138 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) 139 | if mask is not None: 140 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen) 141 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 142 | output = 
torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim) 143 | output = output.transpose( 144 | 1, 2 145 | ).contiguous().view(bsz, seqlen, -1) 146 | 147 | return self.wo(output) 148 | 149 | 150 | class FeedForward(nn.Module): 151 | def __init__( 152 | self, 153 | dim: int, 154 | hidden_dim: int, 155 | multiple_of: int, 156 | ): 157 | super().__init__() 158 | hidden_dim = int(2 * hidden_dim / 3) 159 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) 160 | 161 | self.w1 = Linear( 162 | dim, hidden_dim, bias=False 163 | ) 164 | self.w2 = Linear( 165 | hidden_dim, dim, bias=False 166 | ) 167 | self.w3 = Linear( 168 | dim, hidden_dim, bias=False 169 | ) 170 | 171 | def forward(self, x): 172 | return self.w2(F.silu(self.w1(x),inplace=False) * self.w3(x)) 173 | 174 | 175 | class TransformerBlock(nn.Module): 176 | def __init__(self, layer_id: int, args: ModelArgs): 177 | super().__init__() 178 | self.n_heads = args.n_heads 179 | self.dim = args.dim 180 | self.head_dim = args.dim // args.n_heads 181 | self.attention = Attention(args) 182 | self.feed_forward = FeedForward( 183 | dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of 184 | ) 185 | self.layer_id = layer_id 186 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 187 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 188 | self.drop_path = DropPath(args.drop_path) if args.drop_path > 0. else nn.Identity() 189 | 190 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None): 191 | 192 | h = x + self.drop_path(self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter)) 193 | out = h + self.drop_path(self.feed_forward.forward(self.ffn_norm(h))) 194 | return out 195 | 196 | 197 | 198 | class AdapterMLP(nn.Module): 199 | """ Pytorch Implemention of RepAdapter for 1d tensor""" 200 | 201 | def __init__( 202 | self, 203 | in_features=768, 204 | hidden_dim=128, 205 | out_features=4096 206 | ): 207 | super().__init__() 208 | self.conv_A=nn.Linear(in_features,hidden_dim) 209 | self.conv_B = nn.Linear(hidden_dim, out_features) 210 | 211 | 212 | nn.init.xavier_uniform_( self.conv_A.weight) 213 | nn.init.zeros_(self.conv_A.bias) 214 | nn.init.xavier_uniform_(self.conv_B.weight) 215 | nn.init.zeros_(self.conv_B.bias) 216 | 217 | def forward(self, x): 218 | with autocast(): 219 | x=self.conv_B(F.silu(self.conv_A(x))) 220 | return x 221 | 222 | 223 | class Transformer(nn.Module): 224 | def __init__(self, params: ModelArgs): 225 | super().__init__() 226 | self.params = params 227 | self.vocab_size = params.vocab_size 228 | self.n_layers = params.n_layers 229 | self.tok_embeddings = Embedding( 230 | params.vocab_size, params.dim 231 | ) 232 | 233 | 234 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0) 235 | 236 | # with init_empty_weights(): 237 | self.layers = torch.nn.ModuleList() 238 | for layer_id in range(params.n_layers): 239 | self.layers.append(TransformerBlock(layer_id, params)) 240 | 241 | self.norm = RMSNorm(params.dim, eps=params.norm_eps) 242 | self.output = Linear( 243 | params.dim, params.vocab_size, bias=False 244 | ) 245 | 246 | self.freqs_cis = precompute_freqs_cis( 247 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2 248 | ) 249 | 250 | self.backbone = clip.load('ViT-L/14')[0] 251 | 252 | 253 | #handcraft define self.backbone.visual.transformer.width 254 | self.adapter_proj = AdapterMLP(1024, params.hidden_proj, params.dim).float() 255 | 
self.adapter_modality_embedding=nn.Embedding(2,params.dim).float() 256 | 257 | 258 | 259 | def insert_image_embeds(self,examples,labels,image_embeds,prefix_img,prefix_nonimg,img_indicators): 260 | _bsz, seqlen,_ = examples.shape 261 | new_examples=[] 262 | new_labels=[] 263 | for i, (example,label) in enumerate(zip(examples,labels)): 264 | if img_indicators[i]>0.: 265 | new_example=torch.cat([example[:1],prefix_img,image_embeds[i],example[1:]],0) 266 | new_label=torch.cat([label[:1], 267 | torch.zeros(prefix_img.shape[0]+image_embeds.shape[1]).to(examples.device).type_as(labels), 268 | label[1:]]) 269 | new_example = new_example[:seqlen] 270 | new_label = new_label[:seqlen] 271 | else: 272 | new_example=torch.cat([example[:1],prefix_nonimg,example[1:]],0) 273 | new_label=torch.cat([label[:1], 274 | torch.zeros(prefix_nonimg.shape[0]).to(examples.device).type_as(labels), 275 | label[1:]]) 276 | new_example = new_example[:seqlen] 277 | new_label = new_label[:seqlen] 278 | new_examples.append(new_example.unsqueeze(0)) 279 | new_labels.append(new_label.unsqueeze(0)) 280 | new_examples = torch.cat(new_examples, 0) 281 | new_labels = torch.cat(new_labels, 0) 282 | return new_examples,new_labels 283 | 284 | def forward(self, examples, labels,images=None, prefix_img=None, prefix_nonimg=None,img_indicators=None): 285 | 286 | # print(images.dtype) 287 | image_embeds = self.backbone.encode_image(images).half() 288 | 289 | # print(img_indicators) 290 | if isinstance(img_indicators,list): 291 | img_indicators = torch.Tensor(img_indicators).to(image_embeds.device).long() 292 | modality_embed=self.adapter_modality_embedding(img_indicators.unsqueeze(1)) 293 | 294 | # with autocast(): 295 | image_embeds=self.adapter_proj(image_embeds) 296 | 297 | # print(image_embeds.shape) 298 | 299 | _bsz, seqlen = examples.shape 300 | 301 | examples = self.tok_embeddings(examples) 302 | prefix_img=self.tok_embeddings(prefix_img.unsqueeze(0)).squeeze(0) 303 | prefix_nonimg=self.tok_embeddings(prefix_nonimg.unsqueeze(0)).squeeze(0) 304 | 305 | 306 | h,labels=self.insert_image_embeds(examples,labels,image_embeds,prefix_img,prefix_nonimg,img_indicators) 307 | 308 | h=torch.cat([modality_embed.half(),h],1)[:,:seqlen] 309 | modality_labels=torch.zeros(_bsz,1).to(labels.device).type_as(labels) 310 | labels=torch.cat([modality_labels,labels],1)[:,:seqlen] 311 | 312 | 313 | freqs_cis = self.freqs_cis.to(h.device) 314 | freqs_cis = freqs_cis[:seqlen] 315 | mask = None 316 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device) 317 | mask = torch.triu(mask, diagonal=0 + 1).type_as(h) 318 | 319 | #mask decision token 320 | mask[:,:,1:,0]=float("-inf") 321 | 322 | start_pos = 0 323 | for layer in self.layers: 324 | h = layer(h, start_pos, freqs_cis, mask) 325 | 326 | h = self.norm(h) 327 | output = self.output(h) 328 | output = output[:, :-1, :].reshape(-1, self.vocab_size) 329 | labels = labels[:, 1:].flatten() 330 | 331 | 332 | c_loss = self.criterion(output, labels) 333 | return c_loss 334 | -------------------------------------------------------------------------------- /lavin/tokenizer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 
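The Tokenizer class defined in this file is a thin wrapper around SentencePiece and is consumed by the datasets below via "from lavin import Tokenizer". A minimal usage sketch (the tokenizer.model path is a placeholder taken from the scripts and may differ in your setup):

from lavin import Tokenizer

tok = Tokenizer(model_path="../data/weights/tokenizer.model")  # placeholder path
ids = tok.encode("The answer is A.", bos=True, eos=False)      # token ids with BOS prepended
print(tok.n_words, ids[:5])                                    # vocab size and the first few ids
print(tok.decode(ids[1:]))                                     # drop BOS and decode back to text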
3 | 4 | from sentencepiece import SentencePieceProcessor 5 | from logging import getLogger 6 | from typing import List 7 | import os 8 | 9 | 10 | logger = getLogger() 11 | 12 | 13 | class Tokenizer: 14 | def __init__(self, model_path: str): 15 | # reload tokenizer 16 | assert os.path.isfile(model_path), model_path 17 | self.sp_model = SentencePieceProcessor(model_file=model_path) 18 | logger.info(f"Reloaded SentencePiece model from {model_path}") 19 | 20 | # BOS / EOS token IDs 21 | self.n_words: int = self.sp_model.vocab_size() 22 | self.bos_id: int = self.sp_model.bos_id() 23 | self.eos_id: int = self.sp_model.eos_id() 24 | self.pad_id: int = self.sp_model.pad_id() 25 | logger.info( 26 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 27 | ) 28 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() 29 | 30 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]: 31 | assert type(s) is str 32 | t = self.sp_model.encode(s) 33 | if bos: 34 | t = [self.bos_id] + t 35 | if eos: 36 | t = t + [self.eos_id] 37 | return t 38 | 39 | def decode(self, t: List[int]) -> str: 40 | return self.sp_model.decode(t) 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.0 2 | fairscale 3 | fire 4 | sentencepiece 5 | transformers==4.30.0 6 | timm 7 | tensorboard 8 | ftfy 9 | gradio 10 | bitsandbytes==0.39.0 -------------------------------------------------------------------------------- /scripts/eval_mme_benchmark.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 1 --master_port 11345 eval_mme.py \ 2 | --ckpt_dir ../data/weights/ \ 3 | --llm_model 7B\ 4 | --tokenizer_path ../data/weights/tokenizer.model \ 5 | --data_root ../data \ 6 | --caption_file ../data/captions.json \ 7 | --adapter_path ./15-eph-pretrain.pth \ 8 | --adapter_type attn \ 9 | --adapter_dim 8 \ 10 | --adapter_scale 1 \ 11 | --prompt_format QCM-ALE \ 12 | --max_batch_size 16\ 13 | --max_seq_len 512 \ 14 | --split test \ 15 | --n_prompt 6 \ 16 | --temperature 5.\ 17 | --visual_adapter_type router -------------------------------------------------------------------------------- /scripts/finetuning_sqa_13b.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 --master_port 12345 train.py \ 2 | --llm_model 13B\ 3 | --llama_model_path ../data/weights/ \ 4 | --data_path ../data/alpaca_data.json \ 5 | --max_seq_len 512 \ 6 | --batch_size 4 \ 7 | --accum_iter 1 \ 8 | --epochs 20 \ 9 | --warmup_epochs 2 \ 10 | --blr 9e-3 \ 11 | --weight_decay 0.02 \ 12 | --output_dir ./LaVIN-13B/\ 13 | --adapter_type attn\ 14 | --adapter_dim 8\ 15 | --adapter_scale 1\ 16 | --n_prompt 6 \ 17 | --prompt_format QCM-ALE \ 18 | --temperature 5.\ 19 | --visual_adapter_type router 20 | 21 | torchrun --nproc_per_node 1 eval.py \ 22 | --ckpt_dir ../data/weights/ \ 23 | --llm_model 13B\ 24 | --tokenizer_path ../data/weights/tokenizer.model \ 25 | --data_root ../data \ 26 | --caption_file ../data/captions.json \ 27 | --adapter_path ./LaVIN-13B/checkpoint-19.pth \ 28 | --adapter_type attn \ 29 | --adapter_dim 8 \ 30 | --adapter_scale 1 \ 31 | --prompt_format QCM-ALE \ 32 | --max_batch_size 64\ 33 | --max_seq_len 512 \ 34 | --split test \ 35 | --n_prompt 6 \ 36 | --temperature 5.\ 37 | --visual_adapter_type router 
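Note that --blr above is a base learning rate: train.py (further below) rescales it by the effective batch size, eff_batch_size = batch_size * accum_iter * number of processes, as lr = blr * eff_batch_size / 256. A minimal sketch of that arithmetic with the values from the 13B script above (illustrative only):

# learning-rate scaling as computed in train.py
batch_size, accum_iter, world_size = 4, 1, 8    # --batch_size 4, --accum_iter 1, torchrun --nproc_per_node 8
blr = 9e-3                                      # --blr 9e-3
eff_batch_size = batch_size * accum_iter * world_size   # 32
lr = blr * eff_batch_size / 256                          # 1.125e-03
print(eff_batch_size, lr)

The lite variants below keep the same effective batch size (batch size 1 with 32 accumulation steps on a single GPU), so the actual learning rate is unchanged.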
-------------------------------------------------------------------------------- /scripts/finetuning_sqa_13b_lite.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 1 train.py \ 2 | --llm_model 13B\ 3 | --llama_model_path ../data/weights/ \ 4 | --data_path ../data/alpaca_data.json \ 5 | --max_seq_len 512 \ 6 | --batch_size 1 \ 7 | --accum_iter 32 \ 8 | --epochs 20 \ 9 | --warmup_epochs 2 \ 10 | --blr 9e-3 \ 11 | --weight_decay 0.02 \ 12 | --output_dir ./LaVIN-13B-lite/\ 13 | --adapter_type attn\ 14 | --adapter_dim 8\ 15 | --adapter_scale 1\ 16 | --n_prompt 6 \ 17 | --prompt_format QCM-ALE \ 18 | --temperature 5.\ 19 | --visual_adapter_type router \ 20 | --gradient_checkpointing \ 21 | --bits 4bit \ 22 | --cpu_load 23 | 24 | torchrun --nproc_per_node 1 eval.py \ 25 | --ckpt_dir ../data/weights/ \ 26 | --llm_model 13B\ 27 | --tokenizer_path ../data/weights/tokenizer.model \ 28 | --data_root ../data \ 29 | --caption_file ../data/captions.json \ 30 | --adapter_path ./LaVIN-13B-lite/checkpoint-19.pth \ 31 | --adapter_type attn \ 32 | --adapter_dim 8 \ 33 | --adapter_scale 1 \ 34 | --prompt_format QCM-ALE \ 35 | --max_batch_size 64\ 36 | --max_seq_len 512 \ 37 | --split test \ 38 | --n_prompt 6 \ 39 | --temperature 5.\ 40 | --visual_adapter_type router \ 41 | --bits 4bit -------------------------------------------------------------------------------- /scripts/finetuning_sqa_7b.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 --master_port 11111 train.py \ 2 | --llm_model 7B\ 3 | --llama_model_path ../data/weights/ \ 4 | --data_path ../data/alpaca_data.json \ 5 | --max_seq_len 512 \ 6 | --batch_size 4 \ 7 | --accum_iter 4 \ 8 | --epochs 20 \ 9 | --warmup_epochs 2 \ 10 | --blr 9e-3 \ 11 | --weight_decay 0.02 \ 12 | --output_dir ./LaVIN-7B/\ 13 | --adapter_type attn\ 14 | --adapter_dim 8\ 15 | --adapter_scale 1\ 16 | --n_prompt 6 \ 17 | --prompt_format QCM-ALE \ 18 | --temperature 10.\ 19 | --visual_adapter_type router 20 | 21 | CUDA_VISIBLE_DEVICES=2 torchrun --nproc_per_node 1 --master_port 11111 eval.py \ 22 | --ckpt_dir ../data/weights/ \ 23 | --llm_model 7B\ 24 | --tokenizer_path ../data/weights/tokenizer.model \ 25 | --data_root ../data \ 26 | --caption_file ../data/captions.json \ 27 | --adapter_path ./LaVIN-7B/checkpoint-19.pth \ 28 | --adapter_type attn \ 29 | --adapter_dim 8 \ 30 | --adapter_scale 1 \ 31 | --prompt_format QCM-ALE \ 32 | --max_batch_size 64\ 33 | --max_seq_len 512 \ 34 | --split test \ 35 | --n_prompt 6 \ 36 | --temperature 10.\ 37 | --visual_adapter_type router 38 | -------------------------------------------------------------------------------- /scripts/finetuning_sqa_7b_lite.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=2 torchrun --nproc_per_node 1 --master_port 11111 train.py \ 2 | --llm_model 7B\ 3 | --llama_model_path ../data/weights/ \ 4 | --data_path ../data/alpaca_data.json \ 5 | --max_seq_len 512 \ 6 | --batch_size 1 \ 7 | --accum_iter 32 \ 8 | --epochs 20 \ 9 | --warmup_epochs 2 \ 10 | --blr 9e-3 \ 11 | --weight_decay 0.02 \ 12 | --output_dir ./LaVIN-7B-lite/\ 13 | --adapter_type attn\ 14 | --adapter_dim 8\ 15 | --adapter_scale 1\ 16 | --n_prompt 6 \ 17 | --prompt_format QCM-ALE \ 18 | --temperature 10.\ 19 | --visual_adapter_type router \ 20 | --gradient_checkpointing \ 21 | --bits 4bit \ 22 | --cpu_load 23 | 24 | CUDA_VISIBLE_DEVICES=2 
torchrun --nproc_per_node 1 eval.py \ 25 | --ckpt_dir ../data/weights/ \ 26 | --llm_model 7B\ 27 | --tokenizer_path ../data/weights/tokenizer.model \ 28 | --data_root ../data \ 29 | --caption_file ../data/captions.json \ 30 | --adapter_path ./LaVIN-7B-lite/checkpoint-19.pth \ 31 | --adapter_type attn \ 32 | --adapter_dim 8 \ 33 | --adapter_scale 1 \ 34 | --prompt_format QCM-ALE \ 35 | --max_batch_size 64\ 36 | --max_seq_len 512 \ 37 | --split test \ 38 | --n_prompt 6 \ 39 | --temperature 10.\ 40 | --visual_adapter_type router\ 41 | --bits 4bit \ 42 | --cpu_load -------------------------------------------------------------------------------- /scripts/finetuning_sqa_vicuna_13b.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 --master_port 12345 --nproc_per_node 8 train.py \ 2 | --llm_model 13B\ 3 | --llama_model_path ../data/weights/ \ 4 | --data_path ../data/alpaca_data.json \ 5 | --max_seq_len 512 \ 6 | --batch_size 4 \ 7 | --accum_iter 1 \ 8 | --epochs 20 \ 9 | --warmup_epochs 2 \ 10 | --blr 9e-3 \ 11 | --weight_decay 0.02 \ 12 | --output_dir ./LaVIN-Vicuna-13B/\ 13 | --adapter_type attn\ 14 | --adapter_dim 8\ 15 | --adapter_scale 1\ 16 | --n_prompt 6 \ 17 | --prompt_format QCM-ALE \ 18 | --temperature 5.\ 19 | --visual_adapter_type router \ 20 | --use_vicuna 21 | 22 | torchrun --nproc_per_node 1 eval.py \ 23 | --ckpt_dir ../data/weights/ \ 24 | --llm_model 13B\ 25 | --tokenizer_path ../data/weights/tokenizer.model \ 26 | --data_root ../data \ 27 | --caption_file ../data/captions.json \ 28 | --adapter_path ./LaVIN-Vicuna-13B/checkpoint-19.pth \ 29 | --adapter_type attn \ 30 | --adapter_dim 8 \ 31 | --adapter_scale 1 \ 32 | --prompt_format QCM-ALE \ 33 | --max_batch_size 64\ 34 | --max_seq_len 512 \ 35 | --split test \ 36 | --n_prompt 6 \ 37 | --temperature 5.\ 38 | --visual_adapter_type router \ 39 | --use_vicuna=True -------------------------------------------------------------------------------- /scripts/finetuning_sqa_vicuna_7b.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 2 --master_port 12345 --nproc_per_node 2 train.py \ 2 | --llm_model 7B\ 3 | --llama_model_path ../data/weights/ \ 4 | --data_path ../data/alpaca_data.json \ 5 | --max_seq_len 512 \ 6 | --batch_size 4 \ 7 | --accum_iter 4 \ 8 | --epochs 20 \ 9 | --warmup_epochs 2 \ 10 | --blr 9e-3 \ 11 | --weight_decay 0.02 \ 12 | --output_dir ./LaVIN-Vicuna-7B/\ 13 | --adapter_type attn\ 14 | --adapter_dim 8\ 15 | --adapter_scale 1\ 16 | --n_prompt 6 \ 17 | --prompt_format QCM-ALE \ 18 | --temperature 10.\ 19 | --visual_adapter_type router\ 20 | --use_vicuna 21 | 22 | torchrun --nproc_per_node 1 eval.py \ 23 | --ckpt_dir ../data/weights/ \ 24 | --llm_model 7B\ 25 | --tokenizer_path ../data/weights/tokenizer.model \ 26 | --data_root ../data \ 27 | --caption_file ../data/captions.json \ 28 | --adapter_path ./LaVIN-Vicuna-7B/checkpoint-19.pth \ 29 | --adapter_type attn \ 30 | --adapter_dim 8 \ 31 | --adapter_scale 1 \ 32 | --prompt_format QCM-ALE \ 33 | --max_batch_size 64\ 34 | --max_seq_len 512 \ 35 | --split test \ 36 | --n_prompt 6 \ 37 | --temperature 10.\ 38 | --visual_adapter_type router\ 39 | --use_vicuna=True -------------------------------------------------------------------------------- /scripts/vl_instruction_tuning_13b.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 --master_port 12345 --nproc_per_node 8 train.py \ 2 | 
--llm_model 13B\ 3 | --llama_model_path ../data/weights/ \ 4 | --data_path ../data/alpaca_data.json \ 5 | --max_seq_len 512 \ 6 | --batch_size 4 \ 7 | --accum_iter 1 \ 8 | --epochs 15 \ 9 | --warmup_epochs 0.2 \ 10 | --blr 9e-3 \ 11 | --weight_decay 0.02 \ 12 | --output_dir ./LaVIN-13B-VLIT/\ 13 | --adapter_type attn\ 14 | --adapter_dim 8\ 15 | --adapter_scale 1\ 16 | --n_prompt 6 \ 17 | --prompt_format QCM-ALE \ 18 | --temperature 5.\ 19 | --visual_adapter_type router \ 20 | --do_pretrain -------------------------------------------------------------------------------- /scripts/vl_instruction_tuning_vicuna_13b.sh: -------------------------------------------------------------------------------- 1 | torchrun --nproc_per_node 8 --master_port 12345 --nproc_per_node 8 train.py \ 2 | --llm_model 13B\ 3 | --llama_model_path ../data/weights/ \ 4 | --data_path ../data/alpaca_data.json \ 5 | --max_seq_len 512 \ 6 | --batch_size 4 \ 7 | --accum_iter 1 \ 8 | --epochs 15 \ 9 | --warmup_epochs 0.2 \ 10 | --blr 9e-3 \ 11 | --weight_decay 0.02 \ 12 | --output_dir ./LaVIN-Vicuna-13B-VLIT/\ 13 | --adapter_type attn\ 14 | --adapter_dim 8\ 15 | --adapter_scale 1\ 16 | --n_prompt 6 \ 17 | --prompt_format QCM-ALE \ 18 | --temperature 5.\ 19 | --visual_adapter_type router \ 20 | --do_pretrain \ 21 | --use_vicuna -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3. 3 | 4 | from setuptools import find_packages, setup 5 | 6 | setup(name="lavin", version="0.1", packages=find_packages()) -------------------------------------------------------------------------------- /tools/data_processing.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | #instruction: 4 | #answer: 5 | #input: 6 | #options: 7 | #qid: 8 | #image 9 | all_data=[] 10 | with open('../data/alpaca_data.json') as f: 11 | alpaca_data=json.load(f) 12 | for i,item in enumerate(alpaca_data): 13 | data={} 14 | input=item['input'] 15 | if len(input)==0: 16 | input='' 17 | data['instruction'] = 'Instruction: '+ item['instruction']+' '+input+'\n'+\ 18 | 'Response: ' 19 | data['instruction'] = data['instruction'].replace(" ", " ").strip() 20 | data['answer'] = item['output'] 21 | data['image'] = None 22 | data['options'] = None 23 | data['image_source'] = None 24 | data['qid']='alpaca_'+str(i) 25 | all_data.append(data) 26 | 27 | 28 | with open('../data/complex_reasoning_77k.json') as f: 29 | gpt4_data_0=json.load(f) 30 | with open('../data/detail_23k.json') as f: 31 | gpt4_data_1=json.load(f) 32 | with open('../data/conversation_58k.json') as f: 33 | gpt4_data_2=json.load(f) 34 | gpt4_data=gpt4_data_0+gpt4_data_1 35 | for i,item in enumerate(gpt4_data): 36 | data={} 37 | data['instruction'] = 'Instruction: '+item['conversations'][0]['value'].replace('\n','').replace('\n','')+'\n'+ \ 38 | 'Response: ' 39 | data['instruction'] = data['instruction'].replace(" ", " ").strip() 40 | data['answer'] = item['conversations'][1]['value'] 41 | data['image'] = item['image'] 42 | data['image_source']='mscoco' 43 | data['options'] = None 44 | data['qid']='gpt4_'+str(i) 45 | all_data.append(data) 46 | 47 | 48 | for i,item in enumerate(gpt4_data_2): 49 | for j in range(0,len(item['conversations']),2): 50 | data={} 51 | 
data['instruction'] = 'Instruction: '+item['conversations'][j]['value'].replace('\n','').replace('\n','')+'\n'+ \ 52 | 'Response: ' 53 | data['instruction'] = data['instruction'].replace(" ", " ").strip() 54 | data['answer'] = item['conversations'][j+1]['value'] 55 | data['image'] = item['image'] 56 | data['image_source']='mscoco' 57 | data['options'] = None 58 | data['qid']='gpt4_2_'+str(i)+'_'+str(j) 59 | all_data.append(data) 60 | 61 | 62 | 63 | full_data={} 64 | full_data['all']=all_data 65 | with open('../data/all_data.json','w') as f: 66 | json.dump(full_data,f) 67 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import datetime 4 | import json 5 | import time 6 | import numpy as np 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.backends.cudnn as cudnn 11 | from torch.utils.tensorboard import SummaryWriter 12 | import timm.optim.optim_factory as optim_factory 13 | 14 | import util.misc as misc 15 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 16 | from engine import train_one_epoch 17 | 18 | from util.datasets import ScienceQADataSet,InstrcutDataSet 19 | from lavin.mm_adaptation import LaVIN 20 | import random 21 | import bitsandbytes as bnb 22 | 23 | def get_args_parser(): 24 | parser = argparse.ArgumentParser('MAE pre-training', add_help=False) 25 | parser.add_argument('--batch_size', default=64, type=int, 26 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') 27 | parser.add_argument('--epochs', default=400, type=int) 28 | parser.add_argument('--bits', default='16bit', type=str,choices=['4bit','8bit','16bit'], 29 | help='Quantization bits for training, fp16 by default') 30 | parser.add_argument('--accum_iter', default=1, type=int, 31 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') 32 | 33 | # Model parameters 34 | parser.add_argument('--llama_model_path', default='./llama', type=str, 35 | help='path of llama model') 36 | 37 | parser.add_argument('--llm_model', default='7B', type=str, metavar='MODEL', 38 | help='Name of llm model to train') 39 | 40 | parser.add_argument('--use_vicuna', action='store_true', help='use vicuna weights') 41 | 42 | parser.add_argument('--cpu_load', action='store_true', help='load the model on cpu and avoid OOM on gpu') 43 | 44 | #block is not supported now. 
45 | parser.add_argument('--adapter_type', type=str, default='attn', metavar='LENGTH',choices=['block','attn'], 46 | help='the insert position of adapter layer') 47 | 48 | 49 | parser.add_argument('--visual_adapter_type', type=str, default='normal', metavar='LENGTH',choices=['normal','router','router_block'], 50 | help='the type of adapter layer') 51 | 52 | parser.add_argument('--adapter_dim', type=int, default=8, metavar='LENGTH', help='the dims of adapter layer') 53 | 54 | parser.add_argument('--hidden_proj', type=int, default=128, metavar='LENGTH', 55 | help='the visual adapter dim') 56 | 57 | parser.add_argument('--temperature', type=float, default=10., metavar='LENGTH', 58 | help='the temperature of router') 59 | 60 | parser.add_argument('--n_prompt', type=int, default=10, metavar='LENGTH', 61 | help='the length of visual features') 62 | parser.add_argument('--adapter_scale', type=float, default=1., metavar='LENGTH', help='the scales of adapter layer') 63 | parser.add_argument('--drop_path', type=float, default=0., metavar='LENGTH', help='drop path') 64 | 65 | parser.add_argument('--max_seq_len', type=int, default=512, metavar='LENGTH', 66 | help='the maximum sequence length') 67 | 68 | 69 | # Optimizer parameters 70 | parser.add_argument('--weight_decay', type=float, default=0.05, 71 | help='weight decay (default: 0.05)') 72 | 73 | parser.add_argument('--lr', type=float, default=None, metavar='LR', 74 | help='learning rate (absolute lr)') 75 | parser.add_argument('--clip_grad', type=float, default=None, metavar='clip gradient', 76 | help='clips gradient norm of an iterable of parameters') 77 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR', 78 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') 79 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR', 80 | help='lower lr bound for cyclic schedulers that hit 0') 81 | 82 | parser.add_argument('--gradient_checkpointing', action='store_true', 83 | help='saving memory costs via gradient_checkpointing') 84 | parser.add_argument('--warmup_epochs', type=float, default=40, metavar='N', 85 | help='epochs to warmup LR') 86 | 87 | # Dataset parameters 88 | parser.add_argument('--data_path', default='/instruction_dataset/', type=str, 89 | help='dataset path') 90 | 91 | parser.add_argument('--output_dir', default='./output_dir', 92 | help='path where to save, empty for no saving') 93 | parser.add_argument('--log_dir', default='./output_dir', 94 | help='path where to tensorboard log') 95 | parser.add_argument('--device', default='cuda', 96 | help='device to use for training / testing') 97 | parser.add_argument('--seed', default=0, type=int) 98 | parser.add_argument('--resume', default='', 99 | help='resume from checkpoint') 100 | 101 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 102 | help='start epoch') 103 | parser.add_argument('--num_workers', default=10, type=int) 104 | parser.add_argument('--pin_mem', action='store_true', 105 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 106 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') 107 | parser.set_defaults(pin_mem=True) 108 | 109 | # distributed training parameters 110 | parser.add_argument('--world_size', default=1, type=int, 111 | help='number of distributed processes') 112 | parser.add_argument('--local_rank', default=-1, type=int) 113 | parser.add_argument('--dist_on_itp', action='store_true') 114 | parser.add_argument('--dist_url', 
default='env://', 115 | help='url used to set up distributed training') 116 | 117 | #datasets 118 | parser.add_argument('--prompt_format', 119 | type=str, 120 | default='CQM-A', 121 | choices=[ 122 | 'CQM-A', 'CQM-LA', 'CQM-EA', 'CQM-LEA', 'CQM-ELA', 'CQM-AL', 'CQM-AE', 'CQM-ALE', 'QCM-A', 123 | 'QCM-LA', 'QCM-EA', 'QCM-LEA', 'QCM-ELA', 'QCM-AL', 'QCM-AE', 'QCM-ALE', 'QCML-A', 'QCME-A', 124 | 'QCMLE-A', 'QCLM-A', 'QCEM-A', 'QCLEM-A', 'QCML-AE' 125 | ], 126 | help='prompt format template') 127 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 128 | parser.add_argument('--caption_file', type=str, default='../data/captions.json') 129 | parser.add_argument('--data_root', type=str, default='../data') 130 | parser.add_argument('--use_caption', action='store_true', help='use image captions or not') 131 | parser.add_argument('--do_pretrain', action='store_true', help='pre-train on large scale vl instruction') 132 | 133 | return parser 134 | 135 | 136 | def main(args): 137 | 138 | misc.init_distributed_mode(args) 139 | 140 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) 141 | print("{}".format(args).replace(', ', ',\n')) 142 | 143 | device = torch.device(args.device) 144 | 145 | # fix the seed for reproducibility 146 | seed = args.seed + misc.get_rank() 147 | torch.manual_seed(seed) 148 | np.random.seed(seed) 149 | g = torch.Generator() 150 | g.manual_seed(seed) 151 | random.seed(seed) 152 | 153 | cudnn.benchmark = False 154 | cudnn.deterministic = True 155 | 156 | 157 | if args.do_pretrain: 158 | dataset_train = InstrcutDataSet(args, 'all', args.llama_model_path, args.max_seq_len) 159 | else: 160 | dataset_train = ScienceQADataSet(args, 'train', args.llama_model_path, args.max_seq_len) 161 | 162 | print(dataset_train) 163 | 164 | 165 | num_tasks = misc.get_world_size() 166 | global_rank = misc.get_rank() 167 | sampler_train = torch.utils.data.DistributedSampler( 168 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 169 | ) 170 | 171 | print("Sampler_train = %s" % str(sampler_train)) 172 | 173 | if global_rank == 0 and args.log_dir is not None: 174 | os.makedirs(args.log_dir, exist_ok=True) 175 | log_writer = SummaryWriter(log_dir=args.log_dir) 176 | else: 177 | log_writer = None 178 | 179 | data_loader_train = torch.utils.data.DataLoader( 180 | dataset_train, sampler=sampler_train, 181 | batch_size=args.batch_size, 182 | num_workers=args.num_workers, 183 | pin_memory=args.pin_mem, 184 | drop_last=True, 185 | generator=g, 186 | 187 | ) 188 | 189 | 190 | 191 | 192 | # define the model 193 | model = LaVIN(args) 194 | 195 | 196 | model.to(device) 197 | 198 | #for debug. print the data type. 199 | for name, param in model.named_parameters(): 200 | print(name,param.dtype) 201 | 202 | model_without_ddp = model 203 | 204 | #for debug. print the model. 
205 | # print("Model = %s" % str(model_without_ddp)) 206 | 207 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 208 | 209 | if args.lr is None: # only base_lr is specified 210 | args.lr = args.blr * eff_batch_size / 256 211 | 212 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 213 | print("actual lr: %.2e" % args.lr) 214 | 215 | print("accumulate grad iterations: %d" % args.accum_iter) 216 | print("effective batch size: %d" % eff_batch_size) 217 | 218 | if args.distributed: 219 | print(args.gpu) 220 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True) 221 | model_without_ddp = model.module 222 | 223 | # following timm: set wd as 0 for bias and norm layers 224 | param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay) 225 | 226 | #following qlora: apply paged optimizer 227 | optimizer = bnb.optim.AdamW32bit(param_groups, lr=args.lr, betas=(0.9, 0.95),is_paged=True) #torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) 228 | print(optimizer) 229 | 230 | #mixed precision scaler 231 | loss_scaler = NativeScaler() 232 | 233 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) 234 | 235 | print(f"Start training for {args.epochs} epochs") 236 | start_time = time.time() 237 | for epoch in range(args.start_epoch, args.epochs): 238 | 239 | if args.distributed: 240 | data_loader_train.sampler.set_epoch(epoch) 241 | 242 | train_stats = train_one_epoch( 243 | model, data_loader_train, 244 | optimizer, device, epoch, loss_scaler, 245 | log_writer=log_writer, 246 | args=args 247 | ) 248 | 249 | if args.output_dir: 250 | misc.save_model( 251 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, 252 | loss_scaler=loss_scaler, epoch=epoch) 253 | 254 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 255 | 'epoch': epoch,} 256 | 257 | 258 | if args.output_dir and misc.is_main_process(): 259 | if log_writer is not None: 260 | log_writer.flush() 261 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: 262 | f.write(json.dumps(log_stats) + "\n") 263 | 264 | total_time = time.time() - start_time 265 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 266 | print('Training time {}'.format(total_time_str)) 267 | 268 | 269 | if __name__ == '__main__': 270 | 271 | args = get_args_parser() 272 | args = args.parse_args() 273 | if args.output_dir: 274 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 275 | main(args) 276 | -------------------------------------------------------------------------------- /util/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Apply the delta weights on top of a base model. 
3 | 4 | Usage: 5 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1 6 | """ 7 | import argparse 8 | import gc 9 | import glob 10 | import json 11 | import os 12 | import shutil 13 | import tempfile 14 | 15 | from huggingface_hub import snapshot_download 16 | import torch 17 | from torch import nn 18 | from tqdm import tqdm 19 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig 20 | 21 | 22 | GB = 1 << 30 23 | 24 | 25 | def split_files(model_path, tmp_path, split_size): 26 | if not os.path.exists(model_path): 27 | model_path = snapshot_download(repo_id=model_path) 28 | if not os.path.exists(tmp_path): 29 | os.makedirs(tmp_path) 30 | 31 | file_pattern = os.path.join(model_path, "pytorch_model-*.bin") 32 | files = glob.glob(file_pattern) 33 | 34 | part = 0 35 | try: 36 | for file_path in tqdm(files): 37 | state_dict = torch.load(file_path) 38 | new_state_dict = {} 39 | 40 | current_size = 0 41 | for name, param in state_dict.items(): 42 | param_size = param.numel() * param.element_size() 43 | 44 | if current_size + param_size > split_size: 45 | new_file_name = f"pytorch_model-{part}.bin" 46 | new_file_path = os.path.join(tmp_path, new_file_name) 47 | torch.save(new_state_dict, new_file_path) 48 | current_size = 0 49 | new_state_dict = None 50 | gc.collect() 51 | new_state_dict = {} 52 | part += 1 53 | 54 | new_state_dict[name] = param 55 | current_size += param_size 56 | 57 | new_file_name = f"pytorch_model-{part}.bin" 58 | new_file_path = os.path.join(tmp_path, new_file_name) 59 | torch.save(new_state_dict, new_file_path) 60 | new_state_dict = None 61 | gc.collect() 62 | new_state_dict = {} 63 | part += 1 64 | except Exception as e: 65 | print(f"An error occurred during split_files: {e}") 66 | shutil.rmtree(tmp_path) 67 | raise 68 | 69 | 70 | def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): 71 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) 72 | delta_config = AutoConfig.from_pretrained(delta_path) 73 | 74 | if os.path.exists(target_model_path): 75 | shutil.rmtree(target_model_path) 76 | os.makedirs(target_model_path) 77 | 78 | split_size = 4 * GB 79 | 80 | with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path: 81 | print(f"Split files for the base model to {tmp_base_path}") 82 | split_files(base_model_path, tmp_base_path, split_size) 83 | print(f"Split files for the delta weights to {tmp_delta_path}") 84 | split_files(delta_path, tmp_delta_path, split_size) 85 | 86 | base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin") 87 | base_files = glob.glob(base_pattern) 88 | delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") 89 | delta_files = glob.glob(delta_pattern) 90 | delta_state_dict = torch.load(delta_files[0]) 91 | 92 | print("Applying the delta") 93 | weight_map = {} 94 | total_size = 0 95 | 96 | for i, base_file in tqdm(enumerate(base_files)): 97 | state_dict = torch.load(base_file) 98 | file_name = f"pytorch_model-{i}.bin" 99 | for name, param in state_dict.items(): 100 | if name not in delta_state_dict: 101 | for delta_file in delta_files: 102 | delta_state_dict = torch.load(delta_file) 103 | gc.collect() 104 | if name in delta_state_dict: 105 | break 106 | 107 | state_dict[name] += delta_state_dict[name] 108 | weight_map[name] = file_name 109 | total_size += param.numel() * param.element_size() 110 | gc.collect() 111 | torch.save(state_dict, 
os.path.join(target_model_path, file_name)) 112 | 113 | with open( 114 | os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w" 115 | ) as f: 116 | json.dump( 117 | {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f 118 | ) 119 | 120 | print(f"Saving the target model to {target_model_path}") 121 | delta_tokenizer.save_pretrained(target_model_path) 122 | delta_config.save_pretrained(target_model_path) 123 | 124 | 125 | def apply_delta(base_model_path, target_model_path, delta_path): 126 | print(f"Loading the delta weights from {delta_path}") 127 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) 128 | delta = AutoModelForCausalLM.from_pretrained( 129 | delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 130 | ) 131 | 132 | print(f"Loading the base model from {base_model_path}") 133 | base = AutoModelForCausalLM.from_pretrained( 134 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 135 | ) 136 | 137 | print("Applying the delta") 138 | for name, param in tqdm(base.state_dict().items(), desc="Applying delta"): 139 | assert name in delta.state_dict() 140 | param.data += delta.state_dict()[name] 141 | 142 | print(f"Saving the target model to {target_model_path}") 143 | base.save_pretrained(target_model_path) 144 | delta_tokenizer.save_pretrained(target_model_path) 145 | 146 | 147 | def huggingface2llama(key): 148 | key=key.replace('model.layers','layers') 149 | if 'embed_tokens' in key: 150 | key=key.replace('embed_tokens','tok_embeddings') 151 | if '.self_attn.' in key: 152 | key = key.replace( '.self_attn.q_proj','.attention.wq') 153 | key = key.replace( '.self_attn.k_proj','.attention.wk') 154 | key = key.replace( '.self_attn.v_proj','.attention.wv') 155 | key = key.replace( '.self_attn.o_proj','.attention.wo') 156 | if '.input_layernorm.' in key: 157 | key=key.replace('.input_layernorm.','.attention_norm.') 158 | if '.post_attention_layernorm.' in key: 159 | key = key.replace('.post_attention_layernorm.', '.ffn_norm.') 160 | if '.mlp.' in key: 161 | key=key.replace('.mlp.','.feed_forward.') 162 | if '.down_proj.' in key: 163 | key = key.replace('.down_proj.', '.w2.') 164 | if '.up_proj.' in key: 165 | key = key.replace('.up_proj.', '.w3.') 166 | if '.gate_proj.' in key: 167 | key = key.replace('.gate_proj.', '.w1.') 168 | if key=='model.norm.weight': 169 | key=key.replace('model.norm.','norm.') 170 | if key=='lm_head.weight': 171 | key=key.replace('lm_head.','output.') 172 | return key 173 | 174 | def llama2huggingface(key): 175 | key = key.replace('layers', 'model.layers') 176 | if 'tok_embeddings' in key: 177 | key = key.replace('tok_embeddings', 'model.embed_tokens') 178 | if '.attention.wq' in key: 179 | key = key.replace('.attention.wq', '.self_attn.q_proj') 180 | if '.attention.wk' in key: 181 | key = key.replace('.attention.wk', '.self_attn.k_proj') 182 | if '.attention.wv' in key: 183 | key = key.replace('.attention.wv', '.self_attn.v_proj') 184 | if '.attention.wo' in key: 185 | key = key.replace('.attention.wo', '.self_attn.o_proj') 186 | if '.attention_norm.' in key: 187 | key = key.replace('.attention_norm.', '.input_layernorm.') 188 | if '.ffn_norm.' in key: 189 | key = key.replace('.ffn_norm.', '.post_attention_layernorm.') 190 | if '.w3.' in key: 191 | key = key.replace('.w3.', '.up_proj.') 192 | if '.w2.' in key: 193 | key = key.replace('.w2.', '.down_proj.') 194 | if '.w1.' in key: 195 | key = key.replace('.w1.', '.gate_proj.') 196 | if '.feed_forward.' 
in key: 197 | key = key.replace('.feed_forward.', '.mlp.') 198 | if key=='norm.weight': 199 | key=key.replace('norm.','model.norm.') 200 | if key=='output.weight': 201 | key=key.replace('output.','lm_head.') 202 | return key 203 | 204 | def apply_model_delta_online(base_model, delta_path): 205 | print(f"Loading the delta weights from {delta_path}") 206 | # delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) 207 | delta = AutoModelForCausalLM.from_pretrained( 208 | delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True 209 | ) 210 | 211 | candidate_weight=set() 212 | exclude_weight=[] 213 | for name, param in base_model.state_dict().items(): 214 | if llama2huggingface(name) in delta.state_dict(): 215 | candidate_weight.add(name) 216 | else: 217 | exclude_weight.append(name) 218 | print("excluding these weights in llama: ",exclude_weight) 219 | 220 | print("Applying the delta") 221 | for name, param in base_model.named_parameters(): 222 | if name in candidate_weight: 223 | assert llama2huggingface(name) in delta.state_dict() 224 | param.data += delta.state_dict()[llama2huggingface(name)].to(param.data.device) 225 | 226 | 227 | 228 | if __name__ == "__main__": 229 | parser = argparse.ArgumentParser() 230 | parser.add_argument("--base-model-path", type=str, required=True) 231 | parser.add_argument("--target-model-path", type=str, required=True) 232 | parser.add_argument("--delta-path", type=str, required=True) 233 | parser.add_argument( 234 | "--low-cpu-mem", 235 | action="store_true", 236 | help="Lower the cpu memory usage. This will split large files and use " 237 | "disk as swap to reduce the memory usage below 10GB.", 238 | ) 239 | args = parser.parse_args() 240 | 241 | if args.low_cpu_mem: 242 | apply_delta_low_cpu_mem( 243 | args.base_model_path, args.target_model_path, args.delta_path 244 | ) 245 | else: 246 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) -------------------------------------------------------------------------------- /util/base_prompt.py: -------------------------------------------------------------------------------- 1 | def get_question_text(problem): 2 | question = problem['question'] 3 | return question 4 | 5 | 6 | def get_context_text(problem, use_caption): 7 | txt_context = problem['hint'] 8 | img_context = problem['caption'] if use_caption else "" 9 | context = " ".join([txt_context, img_context]).strip() 10 | if context == "": 11 | context = "N/A" 12 | return context 13 | 14 | 15 | def get_choice_text(probelm, options): 16 | choices = probelm['choices'] 17 | choice_list = [] 18 | for i, c in enumerate(choices): 19 | choice_list.append("({}) {}".format(options[i], c)) 20 | choice_txt = " ".join(choice_list) 21 | #print(choice_txt) 22 | return choice_txt 23 | 24 | 25 | def get_answer(problem, options): 26 | return options[problem['answer']] 27 | 28 | 29 | def get_lecture_text(problem): 30 | # \\n: GPT-3 can generate the lecture with more tokens. 
31 | lecture = problem['lecture'].replace("\n", "\\n") 32 | return lecture 33 | 34 | 35 | def get_solution_text(problem): 36 | # \\n: GPT-3 can generate the solution with more tokens 37 | solution = problem['solution'].replace("\n", "\\n") 38 | return solution 39 | 40 | 41 | def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True): 42 | 43 | input_format, output_format = format.split("-") 44 | 45 | ## Inputs 46 | if input_format == "CQM": 47 | input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n" 48 | elif input_format == "QCM": 49 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n" 50 | # upper bound experiment 51 | elif input_format == "QCML": 52 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n" 53 | elif input_format == "QCME": 54 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n" 55 | elif input_format == "QCMLE": 56 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n" 57 | 58 | elif input_format == "QCLM": 59 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n" 60 | elif input_format == "QCEM": 61 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n" 62 | elif input_format == "QCLEM": 63 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n" 64 | 65 | # Outputs 66 | if test_example: 67 | output = "Answer:" 68 | elif output_format == 'A': 69 | output = f"Answer: The answer is {answer}." 70 | 71 | elif output_format == 'AL': 72 | output = f"Answer: The answer is {answer}. BECAUSE: {solution}" 73 | elif output_format == 'AE': 74 | output = f"Answer: The answer is {answer}. BECAUSE: {lecture}" 75 | elif output_format == 'ALE': 76 | output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}" 77 | elif output_format == 'AEL': 78 | output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}" 79 | 80 | elif output_format == 'LA': 81 | output = f"Answer: {lecture} The answer is {answer}." 82 | elif output_format == 'EA': 83 | output = f"Answer: {solution} The answer is {answer}." 84 | elif output_format == 'LEA': 85 | output = f"Answer: {lecture} {solution} The answer is {answer}." 86 | elif output_format == 'ELA': 87 | output = f"Answer: {solution} {lecture} The answer is {answer}." 
88 | 89 | text = input + output 90 | text = text.replace(" ", " ").strip() 91 | if text.endswith("BECAUSE:"): 92 | text = text.replace("BECAUSE:", "").strip() 93 | return text 94 | 95 | 96 | def create_training_example(format, question, context, choice, answer, lecture, solution): 97 | 98 | input_format, output_format = format.split("-") 99 | 100 | ## Inputs 101 | if input_format == "CQM": 102 | input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n" 103 | elif input_format == "QCM": 104 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n" 105 | # upper bound experiment 106 | elif input_format == "QCML": 107 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n" 108 | elif input_format == "QCME": 109 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n" 110 | elif input_format == "QCMLE": 111 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n" 112 | 113 | elif input_format == "QCLM": 114 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n" 115 | elif input_format == "QCEM": 116 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n" 117 | elif input_format == "QCLEM": 118 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n" 119 | 120 | input+="Response:" 121 | input='\n'+input 122 | 123 | 124 | # Outputs 125 | if output_format == 'A': 126 | output = f"The answer is {answer}." 127 | 128 | elif output_format == 'AL': 129 | output = f"The answer is {answer}. BECAUSE: {solution}" 130 | elif output_format == 'AE': 131 | output = f"The answer is {answer}. BECAUSE: {lecture}" 132 | elif output_format == 'ALE': 133 | output = f"The answer is {answer}. BECAUSE: {lecture} {solution}" 134 | elif output_format == 'AEL': 135 | output = f"The answer is {answer}. BECAUSE: {solution} {lecture}" 136 | 137 | elif output_format == 'LA': 138 | output = f"{lecture} The answer is {answer}." 139 | elif output_format == 'EA': 140 | output = f"{solution} The answer is {answer}." 141 | elif output_format == 'LEA': 142 | output = f"{lecture} {solution} The answer is {answer}." 143 | elif output_format == 'ELA': 144 | output = f"{solution} {lecture} The answer is {answer}." 
145 | 146 | input = input.replace(" ", " ").strip() 147 | output = output.replace(" ", " ").strip() 148 | if output.endswith("BECAUSE:"): 149 | output = output.replace("BECAUSE:", "").strip()  # drop the dangling "BECAUSE:" when lecture/solution are empty 150 | 151 | # print(input) 152 | return input, output 153 | 154 | def build_few_shot_prompt(problems, shot_qids, test_qid, args): 155 | 156 | examples = [] 157 | 158 | # n-shot training examples 159 | for qid in shot_qids: 160 | question = get_question_text(problems[qid]) 161 | context = get_context_text(problems[qid], args.use_caption) 162 | choice = get_choice_text(problems[qid], args.options) 163 | answer = get_answer(problems[qid], args.options) 164 | lecture = get_lecture_text(problems[qid]) 165 | solution = get_solution_text(problems[qid]) 166 | 167 | train_example = create_one_example(args.prompt_format, 168 | question, 169 | context, 170 | choice, 171 | answer, 172 | lecture, 173 | solution, 174 | test_example=False) 175 | examples.append(train_example) 176 | 177 | # test example 178 | question = get_question_text(problems[test_qid]) 179 | context = get_context_text(problems[test_qid], args.use_caption) 180 | choice = get_choice_text(problems[test_qid], args.options) 181 | answer = get_answer(problems[test_qid], args.options) 182 | lecture = get_lecture_text(problems[test_qid]) 183 | solution = get_solution_text(problems[test_qid]) 184 | 185 | test_example = create_one_example(args.prompt_format, 186 | question, 187 | context, 188 | choice, 189 | answer, 190 | lecture, 191 | solution, 192 | test_example=True) 193 | examples.append(test_example) 194 | 195 | # create the prompt input 196 | prompt_input = '\n\n'.join(examples) 197 | 198 | return prompt_input 199 | 200 | def build_prompt(problems, test_qid, args): 201 | 202 | # test example 203 | question = get_question_text(problems[test_qid]) 204 | context = get_context_text(problems[test_qid], args.use_caption) 205 | choice = get_choice_text(problems[test_qid], args.options) 206 | answer = get_answer(problems[test_qid], args.options) 207 | lecture = get_lecture_text(problems[test_qid]) 208 | solution = get_solution_text(problems[test_qid]) 209 | 210 | test_example = create_training_example(args.prompt_format, 211 | question, 212 | context, 213 | choice, 214 | answer, 215 | lecture, 216 | solution) 217 | return test_example 218 | -------------------------------------------------------------------------------- /util/datasets.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 Gen Luo. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License.
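The ScienceQADataSet class below builds its training text with build_prompt / create_training_example from util/base_prompt.py above. A minimal sketch with made-up values showing roughly what the QCM-ALE format used in the scripts produces (the question, options, lecture, and solution strings here are invented for illustration):

from util.base_prompt import create_training_example

prompt, target = create_training_example(
    "QCM-ALE",
    question="Which of these animals is a mammal?",  # toy example, not from the dataset
    context="N/A",
    choice="(A) whale (B) trout",
    answer="A",
    lecture="Mammals are warm-blooded.",
    solution="A whale is a mammal.",
)
# prompt is roughly: "Question: ...\nContext: N/A\nOptions: (A) whale (B) trout\nResponse:"
# target is roughly: "The answer is A. BECAUSE: Mammals are warm-blooded. A whale is a mammal."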
15 | import json, re,random 16 | import torch.utils.data as Data 17 | from torchvision.transforms import transforms 18 | import os 19 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 20 | from PIL import Image 21 | from util.base_prompt import * 22 | import torch 23 | from lavin import Tokenizer 24 | import copy 25 | 26 | class ScienceQADataSet(Data.Dataset): 27 | def __init__(self, args,split,model_path,max_words=512,max_image_feats=1): 28 | super(ScienceQADataSet, self).__init__() 29 | self.args = args 30 | # -------------------------- 31 | # ---- Raw data loading --- 32 | # -------------------------- 33 | self.problems = json.load(open(os.path.join(args.data_root, 'problems.json'))) 34 | pid_splits = json.load(open(os.path.join(args.data_root, 'pid_splits.json'))) 35 | captions = json.load(open(args.caption_file))["captions"] 36 | self.image_path=os.path.join(args.data_root,'images',split) 37 | self.tokenizer = Tokenizer(model_path=model_path + '/tokenizer.model') 38 | self.max_words = max_words 39 | self.max_image_feats=max_image_feats 40 | self.split=split 41 | for qid in self.problems: 42 | self.problems[qid]['caption'] = captions[qid] if qid in captions else "" 43 | 44 | self.qids = pid_splits['%s' % (split)] 45 | 46 | print(f"number of problems in split {split}: {len(self.qids)}\n") 47 | 48 | self.transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 49 | 50 | def tokenize(self,prompt,answer): 51 | example=prompt+answer 52 | # print(prompt) 53 | prompt=torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.int64) 54 | example = torch.tensor(self.tokenizer.encode(example, bos=True, eos=True), dtype=torch.int64) 55 | padding = self.max_words - example.shape[0] 56 | if padding > 0: 57 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) 58 | elif padding < 0: 59 | example = example[:self.max_words] 60 | labels = copy.deepcopy(example) 61 | labels[:len(prompt)] = -1 62 | example_mask = example.ge(0) 63 | label_mask = labels.ge(0) 64 | example[~example_mask] = 0 65 | labels[~label_mask] = 0 66 | example_mask = example_mask.float() 67 | label_mask = label_mask.float() 68 | return example, labels, example_mask,label_mask 69 | 70 | 71 | def __getitem__(self, idx): 72 | 73 | prompt_question,prompt_answer= build_prompt(self.problems,self.qids[idx],self.args) 74 | answer,choices,qid=self.problems[self.qids[idx]]["answer"], self.problems[self.qids[idx]]["choices"],self.qids[idx] 75 | 76 | if self.problems[self.qids[idx]]['image'] is not None: 77 | image = Image.open(os.path.join(self.image_path, self.qids[idx], 'image.png')).convert('RGB') 78 | image = self.transforms(image) 79 | image_mask=torch.cat([torch.Tensor([float('-inf')]*self.max_image_feats),torch.zeros(self.max_words)]) 80 | indicator=1 81 | else: 82 | image=torch.Tensor(torch.zeros(3,224,224).float()) 83 | image_mask=torch.zeros(self.max_words+self.max_image_feats) 84 | indicator=0 85 | 86 | example, labels, example_mask, label_mask=self.tokenize(prompt_question,prompt_answer) 87 | 88 | return example, labels, example_mask, image,indicator 89 | 90 | def __len__(self): 91 | return len(self.qids) 92 | 93 | def shuffle_list(self, list): 94 | random.shuffle(list) 95 | 96 | 97 | 98 | class InstrcutDataSet(Data.Dataset): 99 | def __init__(self, args,split,model_path,max_words=512,max_image_feats=1): 100 | super(InstrcutDataSet, 
self).__init__() 101 | self.args = args 102 | # -------------------------- 103 | # ---- Raw data loading --- 104 | # -------------------------- 105 | self.data = json.load(open(os.path.join(args.data_root, 'all_data.json')))[split] 106 | 107 | self.tokenizer = Tokenizer(model_path=model_path + '/tokenizer.model') 108 | self.max_words = max_words 109 | self.max_image_feats=max_image_feats 110 | self.split=split 111 | self.qids = [item['qid'] for item in self.data] 112 | 113 | print(f"number of problems in split {split}: {len(self.qids)}\n") 114 | 115 | self.transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) 116 | 117 | def tokenize(self,prompt,answer,max_words=512): 118 | example=prompt+answer 119 | # print(prompt) 120 | prompt=torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.int64) 121 | example = torch.tensor(self.tokenizer.encode(example, bos=True, eos=True), dtype=torch.int64) 122 | padding = max_words - example.shape[0] 123 | if padding > 0: 124 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1)) 125 | elif padding < 0: 126 | example = example[:self.max_words] 127 | labels = copy.deepcopy(example) 128 | labels[:len(prompt)] = -1 129 | example_mask = example.ge(0) 130 | label_mask = labels.ge(0) 131 | example[~example_mask] = 0 132 | labels[~label_mask] = 0 133 | example_mask = example_mask.float() 134 | label_mask = label_mask.float() 135 | return example, labels, example_mask,label_mask 136 | 137 | 138 | def __getitem__(self, idx): 139 | 140 | prompt_question=self.data[idx]['instruction'] 141 | prompt_answer=self.data[idx]['answer'] 142 | 143 | if self.data[idx]['image'] is not None: 144 | # image_path='../data/images/train' if self.data[idx]['image_source']=='sqa' else '../data/images/train2014' 145 | if self.data[idx]['image_source'] == 'sqa': 146 | image = Image.open(os.path.join('../data/images/train', self.qids[idx], 'image.png')).convert('RGB') 147 | else: 148 | image = Image.open(os.path.join('../data/images/train2014', 'COCO_train2014_'+self.data[idx]['image'])).convert('RGB') 149 | image = self.transforms(image) 150 | indicator=1 151 | else: 152 | image=torch.Tensor(torch.zeros(3,224,224).float()) 153 | indicator=0 154 | 155 | # print(prompt_question,prompt_answer) 156 | example, labels, example_mask, label_mask=self.tokenize(prompt_question,prompt_answer) 157 | 158 | return example, labels, example_mask, image,indicator 159 | 160 | def __len__(self): 161 | return len(self.qids) 162 | 163 | def shuffle_list(self, list): 164 | random.shuffle(list) 165 | 166 | if __name__ == '__main__': 167 | from torch.utils.data import DataLoader 168 | class Cfg(): 169 | def __init__(self): 170 | super(Cfg, self).__init__() 171 | self.options = ["A", "B", "C", "D", "E"] 172 | self.use_caption = True 173 | self.prompt_format = 'CQM-A' 174 | self.data_root = './data' 175 | self.output_root = './output' 176 | self.caption_file = './data/captions.json' 177 | cfg=Cfg() 178 | dataset=ScienceQADataSet(cfg,'val','./data/weights') 179 | data_loader = DataLoader(dataset, 180 | batch_size=1, 181 | shuffle=False, 182 | pin_memory=True) 183 | max_question_len=0 184 | max_answer_len=0 185 | #406 max question 186 | for prompt_questions,question_mask,images,image_masks,prompt_answers,answers,qids in data_loader: 187 | print(prompt_questions) 188 | print(answers) 189 | # if len(prompt_questions[0].split())>max_question_len: 190 
| # max_question_len=len(prompt_questions[0].split()) 191 | # if len(prompt_answers[0].split())>max_answer_len: 192 | # max_answer_len=len(prompt_answers[0].split()) 193 | # print(max_question_len,max_answer_len) 194 | 195 | 196 | 197 | 198 | 199 | 200 | -------------------------------------------------------------------------------- /util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 17 | """ 18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 20 | super().__init__(params, defaults) 21 | 22 | @torch.no_grad() 23 | def step(self): 24 | for g in self.param_groups: 25 | for p in g['params']: 26 | dp = p.grad 27 | 28 | if dp is None: 29 | continue 30 | 31 | if p.ndim > 1: # if not normalization gamma/beta or bias 32 | dp = dp.add(p, alpha=g['weight_decay']) 33 | param_norm = torch.norm(p) 34 | update_norm = torch.norm(dp) 35 | one = torch.ones_like(param_norm) 36 | q = torch.where(param_norm > 0., 37 | torch.where(update_norm > 0, 38 | (g['trust_coefficient'] * param_norm / update_norm), one), 39 | one) 40 | dp = dp.mul(q) 41 | 42 | param_state = self.state[p] 43 | if 'mu' not in param_state: 44 | param_state['mu'] = torch.zeros_like(p) 45 | mu = param_state['mu'] 46 | mu.mul_(g['momentum']).add_(dp) 47 | p.add_(mu, alpha=-g['lr']) -------------------------------------------------------------------------------- /util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75): 16 | """ 17 | Parameter groups for layer-wise lr decay 18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 19 | """ 20 | param_group_names = {} 21 | param_groups = {} 22 | 23 | num_layers = len(model.blocks) + 1 24 | 25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 26 | 27 | for n, p in model.named_parameters(): 28 | if not p.requires_grad: 29 | continue 30 | 31 | # no decay: all 1D parameters and model specific ones 32 | if p.ndim == 1 or n in no_weight_decay_list: 33 | g_decay = "no_decay" 34 | this_decay = 0. 
35 | else: 36 | g_decay = "decay" 37 | this_decay = weight_decay 38 | 39 | layer_id = get_layer_id_for_vit(n, num_layers) 40 | group_name = "layer_%d_%s" % (layer_id, g_decay) 41 | 42 | if group_name not in param_group_names: 43 | this_scale = layer_scales[layer_id] 44 | 45 | param_group_names[group_name] = { 46 | "lr_scale": this_scale, 47 | "weight_decay": this_decay, 48 | "params": [], 49 | } 50 | param_groups[group_name] = { 51 | "lr_scale": this_scale, 52 | "weight_decay": this_decay, 53 | "params": [], 54 | } 55 | 56 | param_group_names[group_name]["params"].append(n) 57 | param_groups[group_name]["params"].append(p) 58 | 59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 60 | 61 | return list(param_groups.values()) 62 | 63 | 64 | def get_layer_id_for_vit(name, num_layers): 65 | """ 66 | Assign a parameter with its layer id 67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 68 | """ 69 | if name in ['cls_token', 'pos_embed']: 70 | return 0 71 | elif name.startswith('patch_embed'): 72 | return 0 73 | elif name.startswith('blocks'): 74 | return int(name.split('.')[1]) + 1 75 | else: 76 | return num_layers -------------------------------------------------------------------------------- /util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | def adjust_learning_rate(optimizer, epoch, args): 10 | """Decay the learning rate with half-cycle cosine after warmup""" 11 | if epoch < args.warmup_epochs: 12 | lr = args.lr * epoch / args.warmup_epochs 13 | else: 14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs))) 16 | for param_group in optimizer.param_groups: 17 | if "lr_scale" in param_group: 18 | param_group["lr"] = lr * param_group["lr_scale"] 19 | else: 20 | param_group["lr"] = lr 21 | return lr 22 | # 23 | # def adjust_learning_rate_v2(optimizer, steps, args,step_per_epoch): 24 | # """Decay the learning rate with half-cycle cosine after warmup""" 25 | # if steps < args.warmup_epochs: 26 | # lr = args.lr * steps / (args.warmup_epochs *step_per_epoch) 27 | # else: 28 | # lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \ 29 | # (1. + math.cos(math.pi * (steps - args.warmup_epochs*step_per_epoch) / (args.epochs*step_per_epoch - step_per_epoch*args.warmup_epochs))) 30 | # for param_group in optimizer.param_groups: 31 | # if "lr_scale" in param_group: 32 | # param_group["lr"] = lr * param_group["lr_scale"] 33 | # else: 34 | # param_group["lr"] = lr 35 | # return lr 36 | 37 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
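# A small, self-contained sketch (assuming it is run from the repository root) of
# how param_groups_lrd() from util/lr_decay.py and adjust_learning_rate() from
# util/lr_sched.py above are meant to work together: the first tags every
# parameter group with an "lr_scale" that decays towards the input layers, and
# the second multiplies the warmup/cosine learning rate by that scale each epoch.
# The ViT-like toy module and all hyper-parameter values here are invented for
# illustration; they are not LaVIN's actual training configuration.
import argparse

import torch
from torch import nn

from util.lr_decay import param_groups_lrd
from util.lr_sched import adjust_learning_rate


class TinyViT(nn.Module):
    """Just enough structure (cls_token / pos_embed / blocks) for get_layer_id_for_vit()."""
    def __init__(self, dim=8, depth=2):
        super().__init__()
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, 5, dim))
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(depth))
        self.head = nn.Linear(dim, 2)


model = TinyViT()
param_groups = param_groups_lrd(model, weight_decay=0.05,
                                no_weight_decay_list=['cls_token', 'pos_embed'],
                                layer_decay=0.75)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
# util.lars.LARS(param_groups, lr=1e-3) could be dropped in the same way, since
# adjust_learning_rate() only relies on the "lr" / "lr_scale" keys of each group.

# hypothetical schedule settings; in the repo these come from the training script's argparse
args = argparse.Namespace(lr=1e-3, min_lr=1e-5, warmup_epochs=2, epochs=20)
for epoch in range(args.epochs):
    base_lr = adjust_learning_rate(optimizer, epoch, args)
    # ... run one epoch of training here; every group now uses base_lr * group["lr_scale"]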
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch import inf 22 | 23 | 24 | class SmoothedValue(object): 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value) 84 | 85 | 86 | class MetricLogger(object): 87 | def __init__(self, delimiter="\t"): 88 | self.meters = defaultdict(SmoothedValue) 89 | self.delimiter = delimiter 90 | 91 | def update(self, **kwargs): 92 | for k, v in kwargs.items(): 93 | if v is None: 94 | continue 95 | if isinstance(v, torch.Tensor): 96 | v = v.item() 97 | assert isinstance(v, (float, int)) 98 | self.meters[k].update(v) 99 | 100 | def __getattr__(self, attr): 101 | if attr in self.meters: 102 | return self.meters[attr] 103 | if attr in self.__dict__: 104 | return self.__dict__[attr] 105 | raise AttributeError("'{}' object has no attribute '{}'".format( 106 | type(self).__name__, attr)) 107 | 108 | def __str__(self): 109 | loss_str = [] 110 | for name, meter in self.meters.items(): 111 | loss_str.append( 112 | "{}: {}".format(name, str(meter)) 113 | ) 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = '' 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt='{avg:.4f}') 130 | data_time = SmoothedValue(fmt='{avg:.4f}') 131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 132 | log_msg = [ 133 | header, 134 | '[{0' + space_fmt + '}/{1}]', 135 | 'eta: {eta}', 136 | 
'{meters}', 137 | 'time: {time}', 138 | 'data: {data}' 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append('max mem: {memory:.0f}') 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print(log_msg.format( 153 | i, len(iterable), eta=eta_string, 154 | meters=str(self), 155 | time=str(iter_time), data=str(data_time), 156 | memory=torch.cuda.max_memory_allocated() / MB)) 157 | else: 158 | print(log_msg.format( 159 | i, len(iterable), eta=eta_string, 160 | meters=str(self), 161 | time=str(iter_time), data=str(data_time))) 162 | i += 1 163 | end = time.time() 164 | total_time = time.time() - start_time 165 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 166 | print('{} Total time: {} ({:.4f} s / it)'.format( 167 | header, total_time_str, total_time / len(iterable))) 168 | 169 | 170 | def setup_for_distributed(is_master): 171 | """ 172 | This function disables printing when not in master process 173 | """ 174 | builtin_print = builtins.print 175 | 176 | def print(*args, **kwargs): 177 | force = kwargs.pop('force', False) 178 | force = force or (get_world_size() > 8) 179 | if is_master or force: 180 | now = datetime.datetime.now().time() 181 | builtin_print('[{}] '.format(now), end='') # print with time stamp 182 | builtin_print(*args, **kwargs) 183 | 184 | builtins.print = print 185 | 186 | 187 | def is_dist_avail_and_initialized(): 188 | if not dist.is_available(): 189 | return False 190 | if not dist.is_initialized(): 191 | return False 192 | return True 193 | 194 | 195 | def get_world_size(): 196 | if not is_dist_avail_and_initialized(): 197 | return 1 198 | return dist.get_world_size() 199 | 200 | 201 | def get_rank(): 202 | if not is_dist_avail_and_initialized(): 203 | return 0 204 | return dist.get_rank() 205 | 206 | 207 | def is_main_process(): 208 | return get_rank() == 0 209 | 210 | 211 | def save_on_master(*args, **kwargs): 212 | if is_main_process(): 213 | torch.save(*args, **kwargs) 214 | 215 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel 216 | def init_distributed_mode(args): 217 | if args.dist_on_itp: 218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 222 | os.environ['LOCAL_RANK'] = str(args.gpu) 223 | os.environ['RANK'] = str(args.rank) 224 | os.environ['WORLD_SIZE'] = str(args.world_size) 225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 227 | args.rank = int(os.environ["RANK"]) 228 | args.world_size = int(os.environ['WORLD_SIZE']) 229 | args.gpu = int(os.environ['LOCAL_RANK']) 230 | elif 'SLURM_PROCID' in os.environ: 231 | args.rank = int(os.environ['SLURM_PROCID']) 232 | args.gpu = args.rank % torch.cuda.device_count() 233 | else: 234 | print('Not using distributed mode') 235 | setup_for_distributed(is_master=True) # hack 236 | args.distributed = False 237 | return 238 | 239 | args.distributed = True 240 | 241 | 
torch.cuda.set_device(args.gpu) 242 | args.dist_backend = 'nccl' 243 | print('| distributed init (rank {}): {}, gpu {}'.format( 244 | args.rank, args.dist_url, args.gpu), flush=True) 245 | 246 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 247 | world_size=args.world_size, rank=args.rank) 248 | torch.distributed.barrier() 249 | setup_for_distributed(args.rank == 0) 250 | 251 | 252 | class NativeScalerWithGradNormCount: 253 | state_dict_key = "amp_scaler" 254 | 255 | def __init__(self): 256 | self._scaler = torch.cuda.amp.GradScaler() 257 | 258 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 259 | self._scaler.scale(loss).backward(create_graph=create_graph) 260 | if update_grad: 261 | if clip_grad is not None: 262 | assert parameters is not None 263 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 264 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 265 | else: 266 | self._scaler.unscale_(optimizer) 267 | norm = get_grad_norm_(parameters) 268 | self._scaler.step(optimizer) 269 | self._scaler.update() 270 | else: 271 | norm = None 272 | return norm 273 | 274 | def state_dict(self): 275 | return self._scaler.state_dict() 276 | 277 | def load_state_dict(self, state_dict): 278 | self._scaler.load_state_dict(state_dict) 279 | 280 | 281 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 282 | if isinstance(parameters, torch.Tensor): 283 | parameters = [parameters] 284 | parameters = [p for p in parameters if p.grad is not None] 285 | norm_type = float(norm_type) 286 | if len(parameters) == 0: 287 | return torch.tensor(0.) 288 | device = parameters[0].grad.device 289 | if norm_type == inf: 290 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 291 | else: 292 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 293 | return total_norm 294 | 295 | 296 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler): 297 | output_dir = Path(args.output_dir) 298 | epoch_name = str(epoch) 299 | model_without_ddp.eval() 300 | trainable = {} 301 | for n, p in model.named_parameters(): 302 | if 'adapter' in n: 303 | trainable[n] = p.data 304 | # if loss_scaler is not None: 305 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)] 306 | for checkpoint_path in checkpoint_paths: 307 | to_save = { 308 | 'model': trainable, 309 | 'optimizer': optimizer.state_dict(), 310 | 'epoch': epoch, 311 | 'scaler': loss_scaler.state_dict() if loss_scaler is not None else None, 312 | 'args': args, 313 | } 314 | save_on_master(to_save, checkpoint_path) 315 | 316 | 317 | def load_model(args, model_without_ddp, optimizer, loss_scaler): 318 | if args.resume: 319 | if args.resume.startswith('https'): 320 | checkpoint = torch.hub.load_state_dict_from_url( 321 | args.resume, map_location='cpu', check_hash=True) 322 | else: 323 | checkpoint = torch.load(args.resume, map_location='cpu') 324 | model_without_ddp.load_state_dict(checkpoint['model'],strict=False) 325 | print("Resume checkpoint %s" % args.resume) 326 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval): 327 | optimizer.load_state_dict(checkpoint['optimizer']) 328 | args.start_epoch = checkpoint['epoch'] + 1 329 | if 'scaler' in checkpoint: 330 | loss_scaler.load_state_dict(checkpoint['scaler']) 331 | 
print("With optim & sched!") 332 | 333 | 334 | def all_reduce_mean(x): 335 | world_size = get_world_size() 336 | if world_size > 1: 337 | x_reduce = torch.tensor(x).cuda() 338 | dist.all_reduce(x_reduce) 339 | x_reduce /= world_size 340 | return x_reduce.item() 341 | else: 342 | return x 343 | -------------------------------------------------------------------------------- /util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 58 | omega = 1. 
/ 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed 97 | -------------------------------------------------------------------------------- /util/quantization.py: -------------------------------------------------------------------------------- 1 | from transformers.utils.bitsandbytes import * 2 | from transformers import BitsAndBytesConfig 3 | import torch 4 | from torch import nn 5 | import bitsandbytes as bnb 6 | 7 | from fairscale.nn.model_parallel.layers import ( 8 | ParallelEmbedding, 9 | RowParallelLinear, 10 | ColumnParallelLinear, 11 | ) 12 | def _replace_with_bnb_linear( 13 | model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, has_been_replaced=False 14 | ): 15 | """ 16 | Private method that wraps the recursion for module replacement. 17 | 18 | Returns the converted model and a boolean that indicates if the conversion has been successfull or not. 
19 | """ 20 | for name, module in model.named_children(): 21 | if current_key_name is None: 22 | current_key_name = [] 23 | current_key_name.append(name) 24 | 25 | if (isinstance(module, nn.Linear) or isinstance(module, ColumnParallelLinear) or isinstance(module, RowParallelLinear) ) and name not in modules_to_not_convert: 26 | # Check if the current key is not in the `modules_to_not_convert` 27 | if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): 28 | # with init_empty_weights(): 29 | if quantization_config.quantization_method() == "llm_int8": 30 | model._modules[name] = bnb.nn.Linear8bitLt( 31 | module.in_features, 32 | module.out_features, 33 | module.bias is not None, 34 | has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, 35 | threshold=quantization_config.llm_int8_threshold, 36 | ) 37 | has_been_replaced = True 38 | else: 39 | if ( 40 | quantization_config.llm_int8_skip_modules is not None 41 | and name in quantization_config.llm_int8_skip_modules 42 | ): 43 | pass 44 | else: 45 | model._modules[name] = bnb.nn.Linear4bit( 46 | module.in_features, 47 | module.out_features, 48 | module.bias is not None, 49 | quantization_config.bnb_4bit_compute_dtype, 50 | compress_statistics=quantization_config.bnb_4bit_use_double_quant, 51 | quant_type=quantization_config.bnb_4bit_quant_type, 52 | ) 53 | has_been_replaced = True 54 | # Force requires grad to False to avoid unexpected errors 55 | model._modules[name].requires_grad_(False) 56 | if len(list(module.children())) > 0: 57 | _, has_been_replaced = _replace_with_bnb_linear( 58 | module, 59 | modules_to_not_convert, 60 | current_key_name, 61 | quantization_config, 62 | has_been_replaced=has_been_replaced, 63 | ) 64 | # Remove the last key for recursion 65 | current_key_name.pop(-1) 66 | return model, has_been_replaced 67 | 68 | 69 | def quant_model_bnb(model, quant_bit='4bit', keep_in_fp32_modules=[], 70 | quantization_config=None): 71 | if quantization_config is None: 72 | # set default quantization config 73 | # compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)) 74 | quantization_config = BitsAndBytesConfig( 75 | load_in_4bit=quant_bit == '4bit', 76 | load_in_8bit=quant_bit == '8bit', 77 | llm_int8_threshold=6.0, 78 | llm_int8_has_fp16_weight=False, 79 | bnb_4bit_compute_dtype=torch.float16, 80 | bnb_4bit_use_double_quant=True, 81 | bnb_4bit_quant_type='nf4', 82 | ) 83 | 84 | model,_ = _replace_with_bnb_linear( 85 | model, modules_to_not_convert=keep_in_fp32_modules, quantization_config=quantization_config 86 | ) 87 | 88 | return model 89 | 90 | --------------------------------------------------------------------------------