├── .gitignore
├── README.md
├── assets
│   ├── demo.gif
│   ├── examples.png
│   ├── logo.png
│   ├── teaser-1.png
│   ├── teaser.gif
│   └── teaser.png
├── clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── conversation
│   └── conversation.py
├── demo.py
├── engine.py
├── eval.py
├── eval_mme.py
├── lavin
│   ├── __init__.py
│   ├── eval_model.py
│   ├── generator.py
│   ├── mm_adaptation.py
│   ├── mm_adapter.py
│   ├── model.py
│   └── tokenizer.py
├── requirements.txt
├── scripts
│   ├── eval_mme_benchmark.sh
│   ├── finetuning_sqa_13b.sh
│   ├── finetuning_sqa_13b_lite.sh
│   ├── finetuning_sqa_7b.sh
│   ├── finetuning_sqa_7b_lite.sh
│   ├── finetuning_sqa_vicuna_13b.sh
│   ├── finetuning_sqa_vicuna_7b.sh
│   ├── vl_instruction_tuning_13b.sh
│   └── vl_instruction_tuning_vicuna_13b.sh
├── setup.py
├── tools
│   └── data_processing.py
├── train.py
└── util
    ├── apply_delta.py
    ├── base_prompt.py
    ├── datasets.py
    ├── lars.py
    ├── lr_decay.py
    ├── lr_sched.py
    ├── misc.py
    ├── pos_embed.py
    └── quantization.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # checkpoint
132 | *.pth
133 | outputs/
134 |
135 | .idea/
136 | debug.py
137 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | ---
4 |
5 | This repository contains the implementation of the NeurIPS 2023 paper:
6 | > **Cheap and Quick: Efficient Vision-Language Instruction Tuning for Large Language Models**
7 | > [[Project Page]](https://luogen1996.github.io/lavin/) [[Paper]](https://arxiv.org/pdf/2305.15023.pdf)
8 | > [Gen Luo](https://luogen1996.github.io)<sup>1</sup>, Yiyi Zhou<sup>1,2</sup>, [Tianhe Ren](https://rentainhe.github.io)<sup>1</sup>, Shengxin Chen<sup>1</sup>, [Xiaoshuai Sun](https://sites.google.com/view/xssun)<sup>1,2</sup>, [Rongrong Ji](https://mac.xmu.edu.cn/rrji/)<sup>1,2</sup>
9 | > <sup>1</sup>Media Analytics and Computing Lab, Department of Artificial Intelligence, School of Informatics, Xiamen University
10 | > <sup>2</sup>Institute of Artificial Intelligence, Xiamen University
11 |
12 | In this work, we propose a novel and affordable solution for vision-language instruction tuning, namely Mixture-of-Modality Adaptation (MMA).
13 | Particularly, MMA is an end-to-end optimization regime that connects the image encoder and LLM via lightweight adapters. Meanwhile, we also propose a novel routing algorithm in MMA, which helps the model automatically shift its reasoning path between single- and multi-modal instructions (a rough sketch of this routing is given below). Based on MMA, we develop a large vision-language instructed model called LaVIN, which demonstrates superior training efficiency and better reasoning ability than existing multimodal LLMs on various instruction-following tasks.
14 |
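Below is a minimal, self-contained sketch of the routing idea, for illustration only: the class and parameter names here are placeholders and do not match the actual implementation in `lavin/mm_adapter.py`.

```python
import torch
import torch.nn as nn

class ModalityRouterSketch(nn.Module):
    """Toy mixture-of-modality adapter: a lightweight router softly mixes a
    text-oriented adapter and a multimodal-oriented adapter for each input."""

    def __init__(self, dim: int, bottleneck: int = 8, temperature: float = 10.0):
        super().__init__()
        self.text_adapter = nn.Sequential(nn.Linear(dim, bottleneck), nn.GELU(), nn.Linear(bottleneck, dim))
        self.mm_adapter = nn.Sequential(nn.Linear(dim, bottleneck), nn.GELU(), nn.Linear(bottleneck, dim))
        self.router = nn.Linear(dim, 2)        # two routing logits per example
        self.temperature = temperature

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: [batch, seq, dim]
        # Route on a pooled representation of the instruction tokens.
        weights = torch.softmax(self.router(x.mean(dim=1)) / self.temperature, dim=-1)  # [batch, 2]
        w_text, w_mm = weights[:, 0, None, None], weights[:, 1, None, None]
        # Residual adapter output, mixed by the routing weights.
        return x + w_text * self.text_adapter(x) + w_mm * self.mm_adapter(x)
```

In LaVIN, adapters of this kind are attached to both the image encoder and the LLM, so only a few million parameters need to be updated during instruction tuning.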
15 | ---
16 |
17 |
18 |
19 |
20 |
21 | ## News
22 | - **`2023/09/22`**: 🔥🔥🔥 Our paper is accepted by NeurIPS 2023!
23 | - **`2023/06/30`**: 🔥🔥🔥 With very limited training data and cost, LaVIN achieves 5th place in both Perception and Cognition on the [MME benchmark](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation), outperforming seven existing multimodal LLMs. Evaluation code is available.
24 | - **`2023/06/27`**: 🔥4-bit training is available now! LaVIN-lite can be trained on a single 3090 GPU, taking around 9GB and 15GB of GPU memory for the 7B and 13B scales, respectively. Technical details are available on [Zhihu](https://zhuanlan.zhihu.com/p/638784025).
25 | - **`2023/05/29`**: 🔥We released the demo and the pre-trained checkpoint (LLaMA-13B) for the multimodal chatbot.
26 | - **`2023/05/25`**: 🔥We released the code of **LaVIN: Large Vision-Language Instructed model**, which achieves 89.4 (LaVIN-7B) and 90.8 (LaVIN-13B) accuracy on ScienceQA! 🔥With the proposed **mixture-of-modality adaptation**, the training time and trainable parameters can be reduced to 1.4 hours and 3.8M, respectively! Checkout the [paper](https://arxiv.org/pdf/2305.15023.pdf).
27 |
28 | ## TODO
29 | - [x] Release training codes.
30 | - [x] Release checkpoints and demo.
31 | - [x] 4-bit training.
32 | - [ ] Support more modalities, e.g., audio and video.
33 |
34 | ## Contents
35 | - [Setup](#setup)
36 | - [Fine-tuning](#fine-tuning)
37 | - [Demo](#demo)
38 | - [Model Zoo](#model-zoo)
39 |
40 | ## Setup
41 | ### Install Package
42 | ```bash
43 | conda create -n lavin python=3.8 -y
44 | conda activate lavin
45 |
46 | # install pytorch
47 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 -c pytorch
48 |
49 | # install dependency and lavin
50 | pip install -r requirements.txt
51 | pip install -e .
52 | ```
53 | ### Data Preparation
54 | - For ScienceQA, please prepare the dataset from the [official repo](https://github.com/lupantech/ScienceQA).
55 | - For Multimodal Chatbot, download the images in _train2014_ split from [MSCOCO](http://images.cocodataset.org/zips/train2014.zip), and obtain the prepared 52k text-only and 158k text-image instruction-following data from [here](https://drive.google.com/file/d/1gORDPruqwXbgy6NYmhpDXO7t089yzsg3/view?usp=share_link).
56 | - Obtain the weights of LLaMA from [this form](https://forms.gle/jk851eBVbX1m5TAv5) (official) or download [LLaMA-7B](https://huggingface.co/nyanko7/LLaMA-7B/tree/main) and [LLaMA-13B](https://huggingface.co/TheBloke/llama-13b) from HuggingFace (unofficial).
57 | - If you want to use Vicuna weights to initialize the model, please download from [here](https://huggingface.co/lmsys).
58 | After that, the file structure should look like:
59 |
60 | ```bash
61 | LaVIN/
62 |   |-- lavin
63 |   |-- scripts
64 |   |-- train.py
65 |   |-- eval.py
66 |   ......
67 | data/
68 |   |-- problem.json
69 |   |-- pid_splits.json
70 |   |-- captions.json
71 |   |-- all_data.json
72 |   |-- images
73 |       |-- train2014    # MSCOCO 2014
74 |       |-- val2014      # MSCOCO 2014
75 |       |-- train        # ScienceQA train image
76 |       |-- val          # ScienceQA val image
77 |       |-- test         # ScienceQA test image
78 |   |-- weights
79 |       |-- tokenizer.model
80 |       |-- 7B
81 |           |-- params.json
82 |           |-- consolidated.00.pth
83 |       |-- 13B
84 |           |-- params.json
85 |           |-- consolidated.00.pth
86 |           |-- consolidated.01.pth
87 |       |-- vicuna_7B
88 |       |-- vicuna_13B
89 |           |-- config.json
90 |           |-- generation_config.json
91 |           |-- pytorch_model.bin.index.json
92 |           |-- special_tokens_map.json
93 |           |-- tokenizer_config.json
94 |           |-- tokenizer.model
95 |           |-- pytorch_model-00001-of-00003.bin
96 |           |-- pytorch_model-00002-of-00003.bin
97 |           |-- pytorch_model-00003-of-00003.bin
98 |       ......
99 | ```
100 | ## Fine-tuning
101 | ### ScienceQA
102 | Reproduce the performance of LaVIN-7B on ScienceQA.
103 | For the 7B model, we fine-tune it on 2x A100 GPUs (we find that performance is affected by the number of GPUs; we are working to address this issue).
104 |
105 |
106 | LLaMA weights:
107 | ```bash
108 | bash ./scripts/finetuning_sqa_7b.sh
109 | ```
110 |
111 | Vicuna weights:
112 | ```bash
113 | bash ./scripts/finetuning_sqa_vicuna_7b.sh
114 | ```
115 |
116 | LaVIN-lite with LLaMA weights (single GPU):
117 | ```bash
118 | bash ./scripts/finetuning_sqa_7b_lite.sh
119 | ```
120 |
121 | Reproduce the performance of LaVIN-13B on ScienceQA (~2 hours on 8x A100 (80G)).
122 | For the 13B model, we fine-tune it on 8x A100 GPUs.
123 |
124 | LLaMA weights:
125 | ```bash
126 | bash ./scripts/finetuning_sqa_13b.sh
127 | ```
128 |
129 | Vicuna weights:
130 | ```bash
131 | bash ./scripts/finetuning_sqa_vicuna_13b.sh
132 | ```
133 | LaVIN-lite with LLaMA weights (single GPU):
134 | ```bash
135 | bash ./scripts/finetuning_sqa_13b_lite.sh
136 | ```
137 | ### MultiModal ChatBot
138 | Fine-tune LaVIN-13B on the 210k instruction-following data (~75 hours for 15 epochs and ~25 hours for 5 epochs on 8x A100 (80G)).
139 |
140 | LLaMA weights:
141 | ```bash
142 | bash ./scripts/vl_instruction_tuning_13b.sh
143 | ```
144 |
145 | Vicuna weights:
146 | ```bash
147 | bash ./scripts/vl_instruction_tuning_vicuna_13b.sh
148 | ```
149 | To train on fewer GPUs, you can reduce the number of GPUs in the scripts and increase the gradient accumulation steps via ```--accum_iter``` to keep the total batch size at 32. Setting ```--gradient_checkpointing``` and ```--bits 4bit``` in the scripts will greatly reduce GPU memory requirements.
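For example, if a script assumes 8 GPUs with a per-GPU batch size of 4 and ```--accum_iter 1```, running it on 2 GPUs with ```--accum_iter 4``` keeps the effective batch size at 2 × 4 × 4 = 32 (check the per-GPU batch size in the specific script you use, since it may differ).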
150 |
151 |
152 | ## Demo
153 |
154 | LaVIN supports both single- and multi-modal instruction inputs. Try your custom instructions in our demo:
155 |
156 | - **Launch a gradio web server on your machine, then you can interact with LaVIN as you like.**
157 | ```
158 | torchrun --nproc_per_node 1 demo.py --server_name 127.0.0.1
159 | ```
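If you prefer to drive the chatbot from a script rather than the web UI, the same pieces that `demo.py` uses can be called directly. The sketch below follows `demo.py`/`eval.py`; the checkpoint directory, adapter file, and image path are placeholders, and the script must be launched with `torchrun` so that model parallelism can initialize:

```python
import torch
from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

from eval import load, setup_model_parallel
from conversation.conversation import Chat, CONV_VISION

local_rank, world_size = setup_model_parallel()   # reads LOCAL_RANK / WORLD_SIZE set by torchrun
lavin = load(ckpt_dir='../data/weights/', llm_model='13B', tokenizer_path='',
             adapter_path='./15-eph-pretrain.pth',  # path to the downloaded adapter checkpoint
             max_seq_len=512, max_batch_size=4,
             adapter_type='attn', adapter_dim=8, adapter_scale=1, hidden_proj=128,
             visual_adapter_type='router', temperature=5.,
             use_vicuna=False, local_rank=local_rank, world_size=world_size)

vis_processor = transforms.Compose([
    transforms.Resize((224, 224), interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])

chat = Chat(lavin, vis_processor, device=torch.device('cuda'))
conv, img_list = CONV_VISION.copy(), []
chat.upload_img('your_image.jpg', conv, img_list)   # placeholder image path
chat.ask('What is unusual about this image?', conv)
print(chat.answer(conv, img_list, max_new_tokens=300))
```

Run it with, e.g., `torchrun --nproc_per_node 1 your_script.py`.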
160 |
161 |
162 |
163 |
164 |
165 |
166 | ## Model Zoo
167 | ### ScienceQA
168 | | Model | LLM | Time | Memory | #Params | Acc | Checkpoint |
169 | |-----------|----------:|----------:|-------:|--------:|-----:|-----------------:|
170 | | LaVIN-7B-lite | LLaMA | 29 hours (single GPU) | 9G | 3.8M | 88.35 | [google drive](https://drive.google.com/file/d/1oVtoTgt-d9EqmrVic27oZUreN9dLClMo/view?usp=sharing) |
171 | | LaVIN-13B-lite | LLaMA | 42 hours (single GPU) | 14G | 5.4M | 89.44 | [google drive](https://drive.google.com/file/d/1PyVsap3FnmgXOGXFXjYsAtR75cFypaHw/view?usp=sharing) |
172 | | LaVIN-7B | LLaMA | 1.4 hours | 33.9G | 3.8M | 89.37 | [google drive](https://drive.google.com/file/d/10X2qCBYrLH1grZOHwHRMXLUoz-S6MSgV/view?usp=share_link) |
173 | | LaVIN-7B | Vicuna | 1.4 hours | 33.9G | 3.8M | 89.41 | [google drive](https://drive.google.com/file/d/1nuMxeiWlnJKxDybCshg8pVGSvLc5dZy8/view?usp=share_link) |
174 | | LaVIN-13B | LLaMA | 2 hours | 55.9G | 5.4M | 90.54 | [google drive](https://drive.google.com/file/d/1LkKUY54spZkkeXrR7BDmU-xmK9YadcKM/view?usp=share_link) |
175 | | LaVIN-13B | LLaMA | 4 hours | 55.9G | 5.4M | 90.8 | - |
176 |
177 | ### Multimodal ChatBot
178 | | Model | LLM | Time | Memory | #Params | Acc | Checkpoint |
179 | |-----------|----------:|---------:|-------:|--------:|----:|-----------------:|
180 | | LaVIN-13B | LLaMA | 25 hours | 55.9G | 5.4M | - | - |
181 | | LaVIN-13B | LLaMA | 75 hours | 55.9G | 5.4M | - | [google drive](https://drive.google.com/file/d/1rHQNSaiGzFHYGgsamtySPYnd5AW4OE9j/view?usp=share_link)|
182 |
183 | ## Examples
184 |
185 |
186 |
187 |
188 | ## Star History
189 | [Star History Chart](https://star-history.com/#luogen1996/LaVIN&Date)
190 |
191 | ## Citation
192 | If you find our code and paper helpful, please kindly cite LaVIN and [RepAdapter](https://github.com/luogen1996/RepAdapter/):
193 | ```BibTeX
194 | @article{luo2023towards,
195 | title={Towards Efficient Visual Adaption via Structural Re-parameterization},
196 | author={Luo, Gen and Huang, Minglang and Zhou, Yiyi and Sun, Xiaoshuai and Jiang, Guangnan and Wang, Zhiyu and Ji, Rongrong},
197 | journal={arXiv preprint arXiv:2302.08106},
198 | year={2023}
199 | }
200 |
201 | @article{luo2023cheap,
202 | title={Cheap and Quick: Efficient Vision-Language Instruction Tuning for Large Language Models},
203 | author={Luo, Gen and Zhou, Yiyi and Ren, Tianhe and Chen, Shengxin and Sun, Xiaoshuai and Ji, Rongrong},
204 | journal={Advances in Neural Information Processing Systems (NeurIPS)},
205 | year={2023}
206 | }
207 | ```
208 |
209 |
210 | ## Acknowledgement
211 | This repo borrows some data and code from [LLaMA](https://github.com/facebookresearch/llama), [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca), [LLaVA](https://github.com/haotian-liu/LLaVA), [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) and [LLaMA-Adapter](https://github.com/ZrrSkywalker/LLaMA-Adapter/). Thanks for their great work.
212 |
--------------------------------------------------------------------------------
/assets/demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/demo.gif
--------------------------------------------------------------------------------
/assets/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/examples.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/logo.png
--------------------------------------------------------------------------------
/assets/teaser-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/teaser-1.png
--------------------------------------------------------------------------------
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/teaser.gif
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/assets/teaser.png
--------------------------------------------------------------------------------
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luogen1996/LaVIN/dd0a1bfcfd6dd002f6cdc7113366b53957c927a6/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40 | }
41 |
42 |
43 | def _download(url: str, root: str):
44 | os.makedirs(root, exist_ok=True)
45 | filename = os.path.basename(url)
46 |
47 | expected_sha256 = url.split("/")[-2]
48 | download_target = os.path.join(root, filename)
49 |
50 | if os.path.exists(download_target) and not os.path.isfile(download_target):
51 | raise RuntimeError(f"{download_target} exists and is not a regular file")
52 |
53 | if os.path.isfile(download_target):
54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55 | return download_target
56 | else:
57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58 |
59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61 | while True:
62 | buffer = source.read(8192)
63 | if not buffer:
64 | break
65 |
66 | output.write(buffer)
67 | loop.update(len(buffer))
68 |
69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71 |
72 | return download_target
73 |
74 |
75 | def _convert_image_to_rgb(image):
76 | return image.convert("RGB")
77 |
78 |
79 | def _transform(n_px):
80 | return Compose([
81 | Resize(n_px, interpolation=BICUBIC),
82 | CenterCrop(n_px),
83 | _convert_image_to_rgb,
84 | ToTensor(),
85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86 | ])
87 |
88 |
89 | def available_models() -> List[str]:
90 | """Returns the names of available CLIP models"""
91 | return list(_MODELS.keys())
92 |
93 |
94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
95 | """Load a CLIP model
96 |
97 | Parameters
98 | ----------
99 | name : str
100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101 |
102 | device : Union[str, torch.device]
103 | The device to put the loaded model
104 |
105 | jit : bool
106 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
107 |
108 | download_root: str
109 | path to download the model files; by default, it uses "~/.cache/clip"
110 |
111 | Returns
112 | -------
113 | model : torch.nn.Module
114 | The CLIP model
115 |
116 | preprocess : Callable[[PIL.Image], torch.Tensor]
117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118 | """
119 | if name in _MODELS:
120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
121 | elif os.path.isfile(name):
122 | model_path = name
123 | else:
124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
125 |
126 | with open(model_path, 'rb') as opened_file:
127 | try:
128 | # loading JIT archive
129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
130 | state_dict = None
131 | except RuntimeError:
132 | # loading saved state dict
133 | if jit:
134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135 | jit = False
136 | state_dict = torch.load(opened_file, map_location="cpu")
137 |
138 | if not jit:
139 | model = build_model(state_dict or model.state_dict()).to(device)
140 | if str(device) == "cpu":
141 | model.float()
142 | return model, _transform(model.visual.input_resolution)
143 |
144 | # patch the device names
145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147 |
148 | def patch_device(module):
149 | try:
150 | graphs = [module.graph] if hasattr(module, "graph") else []
151 | except RuntimeError:
152 | graphs = []
153 |
154 | if hasattr(module, "forward1"):
155 | graphs.append(module.forward1.graph)
156 |
157 | for graph in graphs:
158 | for node in graph.findAllNodes("prim::Constant"):
159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
160 | node.copyAttributes(device_node)
161 |
162 | model.apply(patch_device)
163 | patch_device(model.encode_image)
164 | patch_device(model.encode_text)
165 |
166 | # patch dtype to float32 on CPU
167 | if str(device) == "cpu":
168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
170 | float_node = float_input.node()
171 |
172 | def patch_float(module):
173 | try:
174 | graphs = [module.graph] if hasattr(module, "graph") else []
175 | except RuntimeError:
176 | graphs = []
177 |
178 | if hasattr(module, "forward1"):
179 | graphs.append(module.forward1.graph)
180 |
181 | for graph in graphs:
182 | for node in graph.findAllNodes("aten::to"):
183 | inputs = list(node.inputs())
184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
185 | if inputs[i].node()["value"] == 5:
186 | inputs[i].node().copyAttributes(float_node)
187 |
188 | model.apply(patch_float)
189 | patch_float(model.encode_image)
190 | patch_float(model.encode_text)
191 |
192 | model.float()
193 |
194 | return model, _transform(model.input_resolution.item())
195 |
196 |
197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
198 | """
199 | Returns the tokenized representation of given input string(s)
200 |
201 | Parameters
202 | ----------
203 | texts : Union[str, List[str]]
204 | An input string or a list of input strings to tokenize
205 |
206 | context_length : int
207 | The context length to use; all CLIP models use 77 as the context length
208 |
209 | truncate: bool
210 | Whether to truncate the text in case its encoding is longer than the context length
211 |
212 | Returns
213 | -------
214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
216 | """
217 | if isinstance(texts, str):
218 | texts = [texts]
219 |
220 | sot_token = _tokenizer.encoder["<|startoftext|>"]
221 | eot_token = _tokenizer.encoder["<|endoftext|>"]
222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225 | else:
226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
227 |
228 | for i, tokens in enumerate(all_tokens):
229 | if len(tokens) > context_length:
230 | if truncate:
231 | tokens = tokens[:context_length]
232 | tokens[-1] = eot_token
233 | else:
234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
235 | result[i, :len(tokens)] = torch.tensor(tokens)
236 |
237 | return result
238 |
--------------------------------------------------------------------------------
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
/conversation/conversation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | from PIL import Image
4 |
5 | import torch
6 | from transformers import StoppingCriteria, StoppingCriteriaList
7 |
8 | import dataclasses
9 | from enum import auto, Enum
10 | from typing import List, Tuple, Any
11 | import re
12 |
13 | ERROR_CODE= [260, 1794, 11440]  # token-id sequence of a blocked phrase; if it appears in the encoded prompt, generation is skipped
14 | ERROR_MESSAGE=[1, 7423, 29892, 474, 508, 29915, 29873, 1234, 445, 1139, 29889, 2]  # token ids of the canned refusal reply returned in that case
15 |
16 |
17 | class SeparatorStyle(Enum):
18 | """Different separator style."""
19 | SINGLE = auto()
20 | TWO = auto()
21 |
22 |
23 | @dataclasses.dataclass
24 | class Conversation:
25 | """A class that keeps all conversation history."""
26 | system: str
27 | roles: List[str]
28 | messages: List[List[str]]
29 | offset: int
30 | # system_img: List[Image.Image] = []
31 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
32 | sep: str = ""
33 | sep2: str = None
34 |
35 | skip_next: bool = False
36 | conv_id: Any = None
37 |
38 | def get_prompt(self):
39 | if self.sep_style == SeparatorStyle.SINGLE:
40 | ret = self.system
41 | for role, message in self.messages:
42 | if message:
43 | ret += role + ": " + message + self.sep
44 | else:
45 | ret += role + ":"
46 | return ret
47 | elif self.sep_style == SeparatorStyle.TWO:
48 | seps = [self.sep, self.sep2]
49 | ret = self.system + seps[0]
50 | for i, (role, message) in enumerate(self.messages):
51 | if message:
52 | ret += role + ": " + message + seps[i % 2]
53 | else:
54 | ret += role + ":"
55 | return ret
56 | else:
57 | raise ValueError(f"Invalid style: {self.sep_style}")
58 |
59 | def append_message(self, role, message):
60 | self.messages.append([role, message])
61 |
62 | def to_gradio_chatbot(self):
63 | ret = []
64 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
65 | if i % 2 == 0:
66 | ret.append([msg, None])
67 | else:
68 | ret[-1][-1] = msg
69 | return ret
70 |
71 | def copy(self):
72 | return Conversation(
73 | system=self.system,
74 | # system_img=self.system_img,
75 | roles=self.roles,
76 | messages=[[x, y] for x, y in self.messages],
77 | offset=self.offset,
78 | sep_style=self.sep_style,
79 | sep=self.sep,
80 | sep2=self.sep2,
81 | conv_id=self.conv_id)
82 |
83 | def dict(self):
84 | return {
85 | "system": self.system,
86 | # "system_img": self.system_img,
87 | "roles": self.roles,
88 | "messages": self.messages,
89 | "offset": self.offset,
90 | "sep": self.sep,
91 | "sep2": self.sep2,
92 | "conv_id": self.conv_id,
93 | }
94 |
95 |
96 | class StoppingCriteriaSub(StoppingCriteria):
97 |
98 | def __init__(self, stops=[], encounters=1):
99 | super().__init__()
100 | self.stops = stops
101 |
102 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
103 | for stop in self.stops:
104 | if torch.all((stop == input_ids[0][-len(stop):])).item():
105 | return True
106 |
107 | return False
108 |
109 |
110 | CONV_VISION = Conversation(
111 | system="",
112 | roles=("Instruction", "Responese"),
113 | messages=[],
114 | offset=2,
115 | sep_style=SeparatorStyle.SINGLE,
116 | sep="\n",
117 | )
118 |
119 |
120 |
121 | class Chat:
122 | def __init__(self, model, vis_processor, device='cuda:0'):
123 | self.device = device
124 | self.lavin = model
125 | self.vis_processor = vis_processor
126 |
127 | def ask(self, text, conv):
128 | if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
129 | and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
130 | conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
131 | else:
132 | conv.append_message(conv.roles[0], text)
133 |
134 | def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
135 | repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000,n_feats=6):
136 | conv.append_message(conv.roles[1], None)
137 | prompt, indicator, img = self.get_context_emb(conv, img_list)
138 |
139 | current_max_len = len(prompt) + max_new_tokens+n_feats
140 | if current_max_len - max_length > 0:
141 | print('Warning: The number of tokens in current conversation exceeds the max length. '
142 | 'The model will not see the contexts outside the range.')
143 | begin_idx = max(0, current_max_len - max_length)
144 |
145 | prompt = prompt[begin_idx:]
146 | CODE=self.lavin.tokenizer.encode(prompt, bos=False, eos=False)
147 | if ERROR_CODE in [CODE[i:i+len(ERROR_CODE)] for i in range(len(CODE)-len(ERROR_CODE)+1)]:
148 | output_text=self.lavin.tokenizer.decode(ERROR_MESSAGE).split('Responese:')[-1].strip()
149 | else:
150 | outputs = self.lavin.generate(
151 | prompts= [prompt],
152 | images= img,
153 | indicators=[indicator],
154 | max_gen_len=max_length,
155 | n_feats=n_feats,
156 | temperature = 0.1,
157 | top_p = 0.75,
158 | )
159 |
160 | output_text = outputs[0].split('Responese:')[-1].strip()
161 |
162 | conv.messages[-1][1] = output_text
163 | return output_text
164 |
165 | def upload_img(self, image, conv, img_list):
166 | if isinstance(image, str): # is a image path
167 | raw_image = Image.open(image).convert('RGB')
168 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
169 | elif isinstance(image, Image.Image):
170 | raw_image = image
171 | image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
172 | elif isinstance(image, torch.Tensor):
173 | if len(image.shape) == 3:
174 | image = image.unsqueeze(0)
175 | image = image.to(self.device)
176 |
177 | # image_emb, _ = self.lavin.backbone.encode_img(image)
178 | img_list.append(image)
179 | conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
180 | msg = "Received."
181 | # self.conv.append_message(self.conv.roles[1], msg)
182 | return msg
183 |
184 | def get_context_emb(self, conv, img_list):
185 | prompt = conv.get_prompt()
186 |
187 | if '<Img><ImageHere></Img>' in prompt:
188 | indicator=1
189 | prompt=prompt.replace('<Img><ImageHere></Img>','')
190 | else:
191 | indicator=0
192 | assert img_list is None or len(img_list) <= 1
193 |
194 | return prompt, indicator, img_list[0] if indicator==1 else torch.Tensor(torch.zeros(1,3, 224, 224).float())
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import gradio as gr
9 | from PIL import Image
10 | # from minigpt4.common.config import Config
11 | from util.misc import get_rank
12 | # from minigpt4.common.registry import registry
13 | from conversation.conversation import Chat, CONV_VISION
14 | from torchvision.transforms import transforms
15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16 | from eval import load
17 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel
18 | from typing import Tuple
19 | def parse_args():
20 | parser = argparse.ArgumentParser(description="Demo")
21 | parser.add_argument("--server_name", type=str, default="127.0.0.1", help="server name")
22 | parser.add_argument("--ckpt_dir", type=str, default="../data/weights/", help="dir of pre-trained weights.")
23 | parser.add_argument("--llm_model", type=str, default="13B", help="the type of llm.")
24 | parser.add_argument("--max_seq_len", type=int, default=512, help="decoder length")
25 | parser.add_argument('--adapter_type', type=str, default='attn', metavar='LENGTH',choices=['block','attn'],
26 | help='the insert position of adapter layer')
27 | parser.add_argument('--adapter_path', type=str, default='./15-eph-pretrain.pth', help='path of pre-trained adapter')
28 | parser.add_argument('--temperature', type=float, default=5., metavar='LENGTH',
29 | help='the temperature of router')
30 | parser.add_argument('--use_vicuna', action='store_true', help='use vicuna weights')
31 | parser.add_argument(
32 | "--options",
33 | nargs="+",
34 | help="override some settings in the used config, the key-value pair "
35 | "in xxx=yyy format will be merged into config file (deprecate), "
36 | "change to --cfg-options instead.",
37 | )
38 | args = parser.parse_args()
39 | return args
40 |
41 | def setup_model_parallel() -> Tuple[int, int]:
42 | local_rank = int(os.environ.get("LOCAL_RANK", -1))
43 | world_size = int(os.environ.get("WORLD_SIZE", -1))
44 |
45 | torch.distributed.init_process_group("nccl")
46 | initialize_model_parallel(world_size)
47 | torch.cuda.set_device(local_rank)
48 |
49 | # seed must be the same in all processes
50 | torch.manual_seed(1)
51 | return local_rank, world_size
52 |
53 | def setup_seeds(config):
54 | seed = config.run_cfg.seed + get_rank()
55 |
56 | random.seed(seed)
57 | np.random.seed(seed)
58 | torch.manual_seed(seed)
59 |
60 | cudnn.benchmark = False
61 | cudnn.deterministic = True
62 |
63 |
64 | # ========================================
65 | # Model Initialization
66 | # ========================================
67 |
68 | print('Initializing Chat')
69 | args = parse_args()
70 |
71 |
72 | local_rank, world_size = setup_model_parallel()
73 | lavin=load(
74 | ckpt_dir=args.ckpt_dir,
75 | llm_model=args.llm_model,
76 | adapter_path=args.adapter_path,
77 | max_seq_len=512,
78 | max_batch_size=4,
79 | adapter_type='attn',
80 | adapter_dim=8,
81 | adapter_scale=1,
82 | hidden_proj=128,
83 | visual_adapter_type='router',
84 | temperature=args.temperature,
85 | tokenizer_path='',
86 | local_rank=local_rank,
87 | world_size=world_size,
88 | use_vicuna=args.use_vicuna
89 | )
90 |
91 | vis_processor = transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])
92 | chat = Chat(lavin, vis_processor, device=torch.device('cuda'))
93 | print('Initialization Finished')
94 |
95 |
96 | # ========================================
97 | # Gradio Setting
98 | # ========================================
99 |
100 | def gradio_reset(chat_state, img_list):
101 | if chat_state is not None:
102 | chat_state.messages = []
103 | if img_list is not None:
104 | img_list = []
105 | return None, gr.update(value=None, interactive=True), gr.update(placeholder='Type and press Enter',
106 | interactive=True), gr.update(
107 | value="Upload & Start Chat", interactive=True), chat_state, img_list
108 |
109 |
110 | def upload_img(gr_img, text_input, chat_state):
111 | if gr_img is None:
112 | return None, None, gr.update(interactive=True), chat_state, None
113 | chat_state = CONV_VISION.copy()
114 | img_list = []
115 | llm_message = chat.upload_img(gr_img, chat_state, img_list)
116 | return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
117 | value="Start Chatting", interactive=False), chat_state, img_list
118 |
119 |
120 | def gradio_ask(user_message, chatbot, chat_state):
121 | if len(user_message) == 0:
122 | return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
123 | if chat_state is None:
124 | chat_state=CONV_VISION.copy()
125 | chat.ask(user_message, chat_state)
126 | chatbot = chatbot + [[user_message, None]]
127 | return '', chatbot, chat_state
128 |
129 |
130 | def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
131 | llm_message = chat.answer(conv=chat_state,
132 | img_list=img_list,
133 | num_beams=num_beams,
134 | temperature=temperature,
135 | max_new_tokens=300,
136 | max_length=2000)
137 | chatbot[-1][1] = llm_message
138 | return chatbot, chat_state, img_list
139 |
140 |
141 | title = """Demo of LaVIN"""
142 | description = """This is the demo of LaVIN. Upload your images and start chatting!"""
143 |
144 |
145 |
146 | with gr.Blocks() as demo:
147 | gr.Markdown(title)
148 | gr.Markdown(description)
149 |
150 | with gr.Row():
151 | with gr.Column(scale=0.5):
152 | image = gr.Image(type="pil")
153 | upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
154 | clear = gr.Button("Restart")
155 |
156 | num_beams = gr.Slider(
157 | minimum=1,
158 | maximum=10,
159 | value=1,
160 | step=1,
161 | interactive=True,
162 | label="beam search numbers",
163 | )
164 |
165 | temperature = gr.Slider(
166 | minimum=0.1,
167 | maximum=2.0,
168 | value=1.0,
169 | step=0.1,
170 | interactive=True,
171 | label="Temperature",
172 | )
173 |
174 | with gr.Column():
175 | chat_state = gr.State()
176 | img_list = gr.State()
177 | chatbot = gr.Chatbot(label='LaVIN-13B')
178 | text_input = gr.Textbox(label='User', placeholder='Type and press Enter', interactive=True)
179 |
180 | upload_button.click(upload_img, [image, text_input, chat_state],
181 | [image, text_input, upload_button, chat_state, img_list])
182 |
183 | text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
184 | gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
185 | )
186 | clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
187 | queue=False)
188 |
189 | demo.launch(share=True, enable_queue=True,server_name=args.server_name)
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | import math
2 | import sys
3 | from typing import Iterable
4 |
5 | import torch
6 |
7 | import util.misc as misc
8 | import util.lr_sched as lr_sched
9 |
10 |
11 |
12 | def train_one_epoch(model: torch.nn.Module,
13 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
14 | device: torch.device, epoch: int, loss_scaler,
15 | log_writer=None,
16 | args=None):
17 |
18 | model.train(True)
19 | metric_logger = misc.MetricLogger(delimiter=" ")
20 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
21 | header = 'Epoch: [{}]'.format(epoch)
22 | print_freq = 10
23 |
24 | accum_iter = args.accum_iter
25 |
26 | optimizer.zero_grad()
27 |
28 | if log_writer is not None:
29 | print('log_dir: {}'.format(log_writer.log_dir))
30 |
31 |
32 |
33 | prefix_img = torch.tensor(data_loader.dataset.tokenizer.encode("Image: ", bos=False, eos=False), dtype=torch.int64)
34 | prefix_nonimg = torch.tensor(data_loader.dataset.tokenizer.encode("Image: N/A", bos=False, eos=False), dtype=torch.int64)
35 |
36 | for data_iter_step, (examples, labels, example_mask,images,indicators) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
37 | # we use a per iteration (instead of per epoch) lr scheduler
38 | if data_iter_step % accum_iter == 0:
39 | lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
40 |
41 | prefix_img=prefix_img.to(examples.device)
42 | prefix_nonimg=prefix_nonimg.to(examples.device)
43 | c_loss = model(examples, labels,images=images, prefix_img=prefix_img, prefix_nonimg=prefix_nonimg,img_indicators=indicators)
44 | loss = c_loss
45 | loss_value = loss.item()
46 | c_loss_value = c_loss.item()
47 |
48 |
49 | if torch.isnan(loss):
50 | print("NaN loss encountered. Skipping this batch.")
51 | continue
52 |
53 | loss = loss/accum_iter
54 |
55 | loss_scaler(loss, optimizer, parameters=model.parameters(),
56 | update_grad=(data_iter_step + 1) % accum_iter == 0,clip_grad=args.clip_grad)
57 | if (data_iter_step + 1) % accum_iter == 0:
58 | optimizer.zero_grad()
59 |
60 | torch.cuda.synchronize()
61 |
62 | metric_logger.update(closs=c_loss_value)
63 |
64 | lr = optimizer.param_groups[0]["lr"]
65 | metric_logger.update(lr=lr)
66 |
67 | loss_value_reduce = misc.all_reduce_mean(loss_value)
68 | c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
69 |
70 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
71 | """ We use epoch_1000x as the x-axis in tensorboard.
72 | This calibrates different curves when batch size changes.
73 | """
74 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
75 | log_writer.add_scalar('c_train_loss', c_loss_value_reduce, epoch_1000x)
76 | log_writer.add_scalar('lr', lr, epoch_1000x)
77 |
78 | # gather the stats from all processes
79 | metric_logger.synchronize_between_processes()
80 | print("Averaged stats:", metric_logger)
81 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
82 |
83 |
84 | def val_one_epoch(model: torch.nn.Module,
85 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
86 | device: torch.device, epoch: int, loss_scaler,
87 | log_writer=None,
88 | args=None):
89 | model.eval()
90 | metric_logger = misc.MetricLogger(delimiter=" ")
91 | metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
92 | header = 'Epoch: [{}]'.format(epoch)
93 | print_freq = 10
94 |
95 | accum_iter = args.accum_iter
96 |
97 | if log_writer is not None:
98 | print('log_dir: {}'.format(log_writer.log_dir))
99 | for data_iter_step, (examples, labels, example_mask) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
100 |
101 | with torch.no_grad():
102 | c_loss = model(examples, labels)
103 | loss = c_loss
104 | loss_value = loss.item()
105 |
106 | c_loss_value = c_loss.item()
107 |
108 | if not math.isfinite(loss_value):
109 | print("Loss is {}, stopping training".format(loss_value))
110 | sys.exit(1)
111 |
112 | metric_logger.update(closs=c_loss_value)
113 |
114 | lr = optimizer.param_groups[0]["lr"]
115 | metric_logger.update(lr=lr)
116 |
117 | loss_value_reduce = misc.all_reduce_mean(loss_value)
118 | c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
119 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
120 | """ We use epoch_1000x as the x-axis in tensorboard.
121 | This calibrates different curves when batch size changes.
122 | """
123 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
124 | log_writer.add_scalar('c_train_loss', c_loss_value_reduce, epoch_1000x)
125 | log_writer.add_scalar('lr', lr, epoch_1000x)
126 |
127 | # gather the stats from all processes
128 | metric_logger.synchronize_between_processes()
129 | print("Averaged stats:", metric_logger)
130 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
131 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from typing import Tuple
5 | import os
6 | import sys
7 | import torch
8 | import fire
9 | import time
10 | import json
11 |
12 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel
13 |
14 | from lavin.eval_model import ModelArgs, Transformer
15 | from lavin.tokenizer import Tokenizer
16 | from lavin.generator import LaVIN_Generator
17 | from lavin.mm_adapter import set_MMAdapter,set_Clip_Adapter
18 | from util.base_prompt import build_prompt
19 | from dataclasses import dataclass
20 | import re
21 | import random
22 |
23 | import warnings
24 | import pandas as pd
25 | from PIL import Image
26 |
27 | from torchvision.transforms import transforms
28 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
29 |
30 | from pathlib import Path
31 | import fairscale.nn.model_parallel.initialize as fs_init
32 | import torch.distributed as dist
33 | from util.apply_delta import apply_model_delta_online
34 |
35 | warnings.filterwarnings('ignore')
36 |
37 | @dataclass
38 | class PromptArgs:
39 | prompt_format='QCM-ALE'
40 | use_caption=True
41 | options=["A", "B", "C", "D", "E"]
42 |
43 | def setup_model_parallel() -> Tuple[int, int]:
44 | local_rank = int(os.environ.get("LOCAL_RANK", -1))
45 | world_size = int(os.environ.get("WORLD_SIZE", -1))
46 |
47 | torch.distributed.init_process_group("nccl")
48 | initialize_model_parallel(world_size)
49 | torch.cuda.set_device(local_rank)
50 |
51 | # seed must be the same in all processes
52 | torch.manual_seed(1)
53 | return local_rank, world_size
54 |
55 | def _load_and_redistribute_checkpoint(llama_model_path, model_name):
56 |
57 | with open(Path(llama_model_path) / model_name / 'params.json') as f:
58 | params = json.load(f)
59 | tokenizer = Tokenizer(model_path=str(Path(llama_model_path) / 'tokenizer.model'))
60 | print('Using model path: %s, model_name: %s' % (llama_model_path, model_name))
61 | if model_name=='7B':
62 | checkpoint = torch.load(llama_model_path + model_name + '/consolidated.00.pth', map_location="cpu")
63 | return checkpoint, tokenizer, params
64 |
65 | checkpoints = (Path(llama_model_path) / model_name).glob('*.pth')
66 | checkpoints = sorted(checkpoints)
67 |
68 | mp_world_size = fs_init.get_model_parallel_world_size()
69 | mp_rank = fs_init.get_model_parallel_rank()
70 | if mp_world_size == len(checkpoints):
71 | print('same number of shards of checkpoints and training, loading directly...')
72 | dist.barrier()
73 | print('[rank=%d, mp_rank=%d] loading from %s' % (dist.get_rank(), mp_rank, checkpoints[mp_rank]))
74 | checkpoint = torch.load(checkpoints[mp_rank], map_location='cpu')
75 | else:
76 | print('different number of shards of checkpoints and training, redistributing...')
77 | if dist.get_rank() == 0:
78 | loaded = []
79 | for x in checkpoints:
80 | print('loading from', x)
81 | loaded.append(torch.load(x, map_location='cpu'))
82 |
83 | full_state_dict = {}
84 | split_dims = {}
85 |
86 | def add_weight_with_split_dim(name, dim):
87 | if dim < 0: # bcast without split
88 | full_state_dict[name] = loaded[0][name].clone()
89 | else:
90 | full_state_dict[name] = torch.cat([x[name] for x in loaded], dim=dim)
91 | for x in loaded:
92 | del x[name]
93 | split_dims[name] = dim
94 |
95 | add_weight_with_split_dim('tok_embeddings.weight', 1)
96 | add_weight_with_split_dim('norm.weight', -1)
97 | add_weight_with_split_dim('output.weight', 0)
98 | for i in range(params['n_layers']):
99 | print('gathering layer %d of %d' % (i, params['n_layers']))
100 | layer_prefix = f'layers.{i}.'
101 | bcast_names = [
102 | 'attention_norm.weight',
103 | 'ffn_norm.weight',
104 | ]
105 | column_parallel_names = [
106 | 'attention.wq.weight',
107 | 'attention.wk.weight',
108 | 'attention.wv.weight',
109 | 'feed_forward.w1.weight',
110 | 'feed_forward.w3.weight',
111 | ]
112 | row_parallel_names = [
113 | 'attention.wo.weight',
114 | 'feed_forward.w2.weight',
115 | ]
116 | for key in bcast_names:
117 | add_weight_with_split_dim(layer_prefix + key, -1)
118 | for key in column_parallel_names:
119 | add_weight_with_split_dim(layer_prefix + key, 0)
120 | for key in row_parallel_names:
121 | add_weight_with_split_dim(layer_prefix + key, 1)
122 |
123 | full_state_dict_meta = dict((k, v.shape) for k, v in full_state_dict.items())
124 | dist.broadcast_object_list([full_state_dict_meta, split_dims], src=0)
125 |
126 | else: # dist.get_rank() != 0
127 | recv_objs = [None, None]
128 | dist.broadcast_object_list(recv_objs, src=0)
129 | full_state_dict_meta, split_dims = recv_objs
130 |
131 | local_state_dict = {}
132 | for k in sorted(full_state_dict_meta.keys()):
133 | print('redistributing weights: %s' % k)
134 | if dist.get_rank() == 0:
135 | value = full_state_dict[k].cuda().half()
136 | del full_state_dict[k]
137 | else:
138 | value = torch.empty(full_state_dict_meta[k], device='cuda', dtype=torch.half)
139 | dist.broadcast(value, src=0)
140 | value = value.cpu()
141 | if split_dims[k] < 0:
142 | local_state_dict[k] = value
143 | else:
144 | dim = split_dims[k]
145 | assert dim >= 0 and dim < value.ndim and value.size(dim) % mp_world_size == 0
146 | shard_size = value.size(dim) // mp_world_size
147 | shard_st, shard_ed = shard_size * mp_rank, shard_size * (mp_rank + 1)
148 | # TODO: make more general
149 | if dim == 0:
150 | value = value[shard_st: shard_ed]
151 | elif dim == 1:
152 | value = value[:, shard_st: shard_ed]
153 | else:
154 | raise NotImplementedError()
155 | local_state_dict[k] = value.clone()
156 |
157 | checkpoint = local_state_dict
158 |
159 | return checkpoint, tokenizer, params
160 |
161 |
162 |
163 | def get_acc_with_contion(res_pd, key, values):
164 | if isinstance(values, list):
165 | total_pd = res_pd[res_pd[key].isin(values)]
166 | else:
167 | total_pd = res_pd[res_pd[key] == values]
168 | correct_pd = total_pd[total_pd['true_false'] == True]
169 | acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100)
170 | return acc
171 |
172 |
173 | def get_scores(result_file, data_file):
174 | # read result file
175 | results = json.load(open(result_file))
176 | num = len(results)
177 | assert num == 4241
178 |
179 | sqa_data = json.load(open(data_file))
180 |
181 | # construct pandas data
182 | sqa_pd = pd.DataFrame(sqa_data).T
183 | res_pd = sqa_pd[sqa_pd['split'] == 'test'] # test set
184 |
185 | # update data
186 | for index, row in res_pd.iterrows():
187 |
188 | res_pd.loc[index, 'no_context'] = True if (not row['hint'] and not row['image']) else False
189 | res_pd.loc[index, 'has_text'] = True if row['hint'] else False
190 | res_pd.loc[index, 'has_image'] = True if row['image'] else False
191 | res_pd.loc[index, 'has_text_image'] = True if (row['hint'] and row['image']) else False
192 |
193 | label = row['answer']
194 | pred = int(results[index])
195 | res_pd.loc[index, 'pred'] = pred
196 | res_pd.loc[index, 'true_false'] = (label == pred)
197 |
198 | # accuracy scores
199 | acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100
200 |
201 | scores = {
202 | 'acc_natural':
203 | get_acc_with_contion(res_pd, 'subject', 'natural science'),
204 | 'acc_social':
205 | get_acc_with_contion(res_pd, 'subject', 'social science'),
206 | 'acc_language':
207 | get_acc_with_contion(res_pd, 'subject', 'language science'),
208 | 'acc_has_text':
209 | get_acc_with_contion(res_pd, 'has_text', True),
210 | 'acc_has_image':
211 | get_acc_with_contion(res_pd, 'has_image', True),
212 | 'acc_no_context':
213 | get_acc_with_contion(res_pd, 'no_context', True),
214 | 'acc_grade_1_6':
215 | get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']),
216 | 'acc_grade_7_12':
217 | get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']),
218 | 'acc_average':
219 | "{:.2f}".format(acc_average),
220 | }
221 |
222 | return scores
223 |
224 |
225 | def print_scores(scores):
226 | latex_output = ""
227 | for key, score in scores.items():
228 | print(f"{key[4:]}: \t{score}")
229 | latex_output += f"& {score} "
230 | latex_output += "\\\\"
231 | print(latex_output)
232 |
233 | def load(
234 | ckpt_dir: str,
235 | llm_model:str,
236 | tokenizer_path: str,
237 | adapter_path: str,
238 | local_rank: int,
239 | world_size: int,
240 | max_seq_len: int,
241 | max_batch_size: int,
242 | adapter_type: str,
243 | adapter_dim:int,
244 | adapter_scale:float,
245 | hidden_proj:int,
246 | visual_adapter_type: str,
247 | temperature: float,
248 | use_vicuna: bool,
249 | bits: str='16bits',
250 | cpu_load:bool=False,
251 | ) -> LaVIN_Generator:
252 | start_time = time.time()
253 | checkpoint, tokenizer, params = _load_and_redistribute_checkpoint(ckpt_dir, llm_model)
254 |
255 | print("Loading")
256 | adapter_checkpoint = torch.load(adapter_path, map_location="cpu")
257 |
258 |
259 | model_args: ModelArgs = ModelArgs(
260 | max_seq_len=max_seq_len, max_batch_size=max_batch_size,hidden_proj=hidden_proj, **params
261 | )
262 | model_args.vocab_size = tokenizer.n_words
263 |
264 | if cpu_load:
265 |         # CPU loading is slow, but friendly to GPUs with limited memory.
266 | torch.set_default_tensor_type(torch.HalfTensor)
267 | else:
268 | torch.set_default_tensor_type(torch.cuda.HalfTensor)
269 |
270 | model = Transformer(model_args)
271 | #delete language encoder
272 | del model.backbone.transformer
273 |
274 | torch.set_default_tensor_type(torch.FloatTensor)
275 |
276 | if bits in ['4bit','8bit']:
277 | from util.quantization import quant_model_bnb
278 |         model.layers = quant_model_bnb(model.layers, quant_bit=bits)
279 |
280 | set_MMAdapter(model, adapter_type, dim=adapter_dim, s=adapter_scale,t=temperature)
281 | set_Clip_Adapter(model.backbone.visual, visual_adapter_type, dim=adapter_dim, s=adapter_scale,t=temperature)
282 |
283 | model.load_state_dict(checkpoint, strict=False)
284 |
285 | if use_vicuna:
286 | apply_model_delta_online(model,'../data/weights/vicuna_'+llm_model)
287 |
288 |
289 | state_dict={}
290 | for key in adapter_checkpoint['model']:
291 | state_dict[key.replace('module.','')]=adapter_checkpoint['model'][key]
292 |
293 | model.load_state_dict(state_dict, strict=False)
294 | model.to(torch.device('cuda'))
295 |
296 | for name, param in model.named_parameters():
297 | print(name,param.dtype)
298 | generator = LaVIN_Generator(model, tokenizer)
299 | print(f"Loaded in {time.time() - start_time:.2f} seconds")
300 | return generator
301 |
302 | def get_pred_idx(prediction, choices, options):
303 | """
304 | Get the index (e.g. 2) from the prediction (e.g. 'C')
305 | """
306 | if prediction in options[:len(choices)]:
307 | return options.index(prediction)
308 | else:
309 | return random.choice(range(len(choices)))
310 |
311 | def main(
312 | ckpt_dir: str,
313 | tokenizer_path: str,
314 | adapter_path: str,
315 | data_root:str,
316 | caption_file:str,
317 | max_seq_len: int,
318 | max_batch_size: int,
319 | llm_model:str='7B',
320 | generation_temperature: float = 0.1,
321 | top_p: float = 0.75,
322 | split='val',
323 | prompt_format='QCM-ALE',
324 | use_caption=False,
325 | options=["A", "B", "C", "D", "E"],
326 | adapter_type='repattn',
327 | adapter_dim=8,
328 | adapter_scale=1,
329 | n_prompt=10,
330 | hidden_proj=128,
331 | visual_adapter_type='normal',
332 | temperature=10.,
333 | use_vicuna=False,
334 | bits: str='16bits',
335 | cpu_load:bool=False,
336 | ):
337 | print(max_batch_size,max_seq_len)
338 | print('use caption: ',use_caption)
339 | local_rank, world_size = setup_model_parallel()
340 | if local_rank > 0:
341 | sys.stdout = open(os.devnull, "w")
342 |
343 | generator = load(
344 | ckpt_dir,llm_model, tokenizer_path, adapter_path, local_rank, world_size, max_seq_len, max_batch_size,
345 | adapter_type,adapter_dim,adapter_scale,hidden_proj,visual_adapter_type,
346 | temperature,use_vicuna,bits=bits,cpu_load=cpu_load)
347 |
348 | print('split: ', split)
349 | problems = json.load(open(os.path.join(data_root, 'problems.json')))
350 | pid_splits = json.load(open(os.path.join(data_root, 'pid_splits.json')))
351 | captions = json.load(open(caption_file))["captions"]
352 | image_path = os.path.join(data_root, 'images', split)
353 | qids = pid_splits['%s' % (split)]
354 | total_items=len(qids)
355 | for qid in problems:
356 | problems[qid]['caption'] = captions[qid] if qid in captions else ""
357 | print('total_items: ',total_items)
358 |
359 |
360 | image_transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])
361 |
362 | prompt_args=PromptArgs()
363 | prompt_args.prompt_format = prompt_format
364 | prompt_args.use_caption = use_caption
365 | prompt_args.options = options
366 |
367 | pattern = re.compile(r'The answer is ([A-Z]).')
368 |
369 | answers = []
370 | preds=[]
371 | for i in range(total_items//max_batch_size+1):
372 |         print('progress: ', i, ' / ', total_items//max_batch_size+1)
373 | batch_qids=qids[i*max_batch_size:(i+1)*max_batch_size]
374 | if len(batch_qids)==0:
375 | break
376 | indicators = []
377 | prompts=[]
378 | images = []
379 | for qid in batch_qids:
380 | prompt,_ = build_prompt(problems, qid, prompt_args)
381 |
382 | answer=problems[qid]["answer"]
383 | if problems[qid]['image'] is not None:
384 | image = Image.open(os.path.join(image_path, qid, 'image.png')).convert('RGB')
385 | image = image_transforms(image)
386 | indicator = 1
387 | else:
388 | image = torch.Tensor(torch.zeros(3, 224, 224).float())
389 | indicator = 0
390 | prompts.append(prompt)
391 | answers.append(answer)
392 | images.append(image.unsqueeze(0))
393 | indicators.append(indicator)
394 | images=torch.cat(images,0)
395 |
396 |
397 | results = generator.generate(
398 | prompts,images=images,indicators=indicators, max_gen_len=64, temperature=generation_temperature, top_p=top_p,n_feats=n_prompt
399 | )
400 |
401 | for result in results:
402 | pred = pattern.findall(result)
403 |
404 | if len(pred) >= 1:
405 | pred = pred[0] # 'A', 'B', ...
406 | else:
407 | print(result)
408 | pred = "FAILED"
409 | preds.append(pred)
410 |
411 | #evaluations
412 | results={}
413 | correct=0
414 | for i, prediction in enumerate(preds):
415 | pred_idx = get_pred_idx(prediction, problems[qids[i]]["choices"],
416 | prompt_args.options) # 0, 1, ..., 4
417 | if pred_idx == answers[i]:
418 | correct += 1
419 | results[qids[i]] = pred_idx
420 | acc = correct / len(results) * 100
421 | print('overall accuracy: ', acc)
422 |
423 | with open('./preds.json', 'w') as f:
424 | json.dump(results,f)
425 |
426 | scores=get_scores('./preds.json',os.path.join(data_root, 'problems.json'))
427 | print(scores)
428 | import time
429 | with open(str(time.time())+'.txt','w') as f:
430 | f.write(str(scores))
431 |
432 |
433 | if __name__ == "__main__":
434 | fire.Fire(main)
435 |
--------------------------------------------------------------------------------
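Note: the re-sharding branch above gathers each full weight on rank 0 and then slices it along its model-parallel dimension (0 for column-parallel weights, 1 for row-parallel weights and the token embedding). A minimal CPU-only sketch of that slicing step, with invented tensor sizes rather than the real LLaMA shapes:

# Minimal sketch of the per-rank shard extraction used in
# _load_and_redistribute_checkpoint. The tensor, world size and
# split dims below are illustrative, not the real LLaMA shapes.
import torch

def shard(value: torch.Tensor, dim: int, mp_rank: int, mp_world_size: int) -> torch.Tensor:
    if dim < 0:                      # broadcast weights (e.g. norms) are kept whole
        return value.clone()
    assert value.size(dim) % mp_world_size == 0
    shard_size = value.size(dim) // mp_world_size
    start, end = shard_size * mp_rank, shard_size * (mp_rank + 1)
    if dim == 0:                     # column-parallel: split output features
        return value[start:end].clone()
    if dim == 1:                     # row-parallel / embedding: split input features
        return value[:, start:end].clone()
    raise NotImplementedError(dim)

full = torch.randn(8, 6)             # toy "full" weight gathered on rank 0
print(shard(full, dim=0, mp_rank=1, mp_world_size=2).shape)  # torch.Size([4, 6])
print(shard(full, dim=1, mp_rank=1, mp_world_size=2).shape)  # torch.Size([8, 3])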
/eval_mme.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from typing import Tuple
5 | import os
6 | import sys
7 | import torch
8 | import fire
9 | import time
10 | import json
11 |
12 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel
13 |
14 | from lavin.eval_model import ModelArgs, Transformer
15 | from lavin.tokenizer import Tokenizer
16 | from lavin.generator import LaVIN_Generator
17 | from lavin.mm_adapter import set_MMAdapter,set_Clip_Adapter
18 | from util.base_prompt import build_prompt
19 | from dataclasses import dataclass
20 | import re
21 | import random
22 |
23 | import warnings
24 | import pandas as pd
25 | from PIL import Image
26 |
27 | from torchvision.transforms import transforms
28 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
29 |
30 | from pathlib import Path
31 | import fairscale.nn.model_parallel.initialize as fs_init
32 | import torch.distributed as dist
33 | from util.apply_delta import apply_model_delta_online
34 |
35 | warnings.filterwarnings('ignore')
36 |
37 | @dataclass
38 | class PromptArgs:
39 | prompt_format='QCM-ALE'
40 | use_caption=True
41 | options=["A", "B", "C", "D", "E"]
42 |
43 | def setup_model_parallel() -> Tuple[int, int]:
44 | local_rank = int(os.environ.get("LOCAL_RANK", -1))
45 | world_size = int(os.environ.get("WORLD_SIZE", -1))
46 |
47 | torch.distributed.init_process_group("nccl")
48 | initialize_model_parallel(world_size)
49 | torch.cuda.set_device(local_rank)
50 |
51 | # seed must be the same in all processes
52 | torch.manual_seed(1)
53 | return local_rank, world_size
54 |
55 | def _load_and_redistribute_checkpoint(llama_model_path, model_name):
56 |
57 | with open(Path(llama_model_path) / model_name / 'params.json') as f:
58 | params = json.load(f)
59 | tokenizer = Tokenizer(model_path=str(Path(llama_model_path) / 'tokenizer.model'))
60 | print('Using model path: %s, model_name: %s' % (llama_model_path, model_name))
61 | if model_name=='7B':
62 | checkpoint = torch.load(llama_model_path + model_name + '/consolidated.00.pth', map_location="cpu")
63 | return checkpoint, tokenizer, params
64 |
65 | checkpoints = (Path(llama_model_path) / model_name).glob('*.pth')
66 | checkpoints = sorted(checkpoints)
67 |
68 | mp_world_size = fs_init.get_model_parallel_world_size()
69 | mp_rank = fs_init.get_model_parallel_rank()
70 | if mp_world_size == len(checkpoints):
71 | print('same number of shards of checkpoints and training, loading directly...')
72 | dist.barrier()
73 | print('[rank=%d, mp_rank=%d] loading from %s' % (dist.get_rank(), mp_rank, checkpoints[mp_rank]))
74 | checkpoint = torch.load(checkpoints[mp_rank], map_location='cpu')
75 | else:
76 | print('different number of shards of checkpoints and training, redistributing...')
77 | if dist.get_rank() == 0:
78 | loaded = []
79 | for x in checkpoints:
80 | print('loading from', x)
81 | loaded.append(torch.load(x, map_location='cpu'))
82 |
83 | full_state_dict = {}
84 | split_dims = {}
85 |
86 | def add_weight_with_split_dim(name, dim):
87 | if dim < 0: # bcast without split
88 | full_state_dict[name] = loaded[0][name].clone()
89 | else:
90 | full_state_dict[name] = torch.cat([x[name] for x in loaded], dim=dim)
91 | for x in loaded:
92 | del x[name]
93 | split_dims[name] = dim
94 |
95 | add_weight_with_split_dim('tok_embeddings.weight', 1)
96 | add_weight_with_split_dim('norm.weight', -1)
97 | add_weight_with_split_dim('output.weight', 0)
98 | for i in range(params['n_layers']):
99 | print('gathering layer %d of %d' % (i, params['n_layers']))
100 | layer_prefix = f'layers.{i}.'
101 | bcast_names = [
102 | 'attention_norm.weight',
103 | 'ffn_norm.weight',
104 | ]
105 | column_parallel_names = [
106 | 'attention.wq.weight',
107 | 'attention.wk.weight',
108 | 'attention.wv.weight',
109 | 'feed_forward.w1.weight',
110 | 'feed_forward.w3.weight',
111 | ]
112 | row_parallel_names = [
113 | 'attention.wo.weight',
114 | 'feed_forward.w2.weight',
115 | ]
116 | for key in bcast_names:
117 | add_weight_with_split_dim(layer_prefix + key, -1)
118 | for key in column_parallel_names:
119 | add_weight_with_split_dim(layer_prefix + key, 0)
120 | for key in row_parallel_names:
121 | add_weight_with_split_dim(layer_prefix + key, 1)
122 |
123 | full_state_dict_meta = dict((k, v.shape) for k, v in full_state_dict.items())
124 | dist.broadcast_object_list([full_state_dict_meta, split_dims], src=0)
125 |
126 | else: # dist.get_rank() != 0
127 | recv_objs = [None, None]
128 | dist.broadcast_object_list(recv_objs, src=0)
129 | full_state_dict_meta, split_dims = recv_objs
130 |
131 | local_state_dict = {}
132 | for k in sorted(full_state_dict_meta.keys()):
133 | print('redistributing weights: %s' % k)
134 | if dist.get_rank() == 0:
135 | value = full_state_dict[k].cuda().half()
136 | del full_state_dict[k]
137 | else:
138 | value = torch.empty(full_state_dict_meta[k], device='cuda', dtype=torch.half)
139 | dist.broadcast(value, src=0)
140 | value = value.cpu()
141 | if split_dims[k] < 0:
142 | local_state_dict[k] = value
143 | else:
144 | dim = split_dims[k]
145 | assert dim >= 0 and dim < value.ndim and value.size(dim) % mp_world_size == 0
146 | shard_size = value.size(dim) // mp_world_size
147 | shard_st, shard_ed = shard_size * mp_rank, shard_size * (mp_rank + 1)
148 | # TODO: make more general
149 | if dim == 0:
150 | value = value[shard_st: shard_ed]
151 | elif dim == 1:
152 | value = value[:, shard_st: shard_ed]
153 | else:
154 | raise NotImplementedError()
155 | local_state_dict[k] = value.clone()
156 |
157 | checkpoint = local_state_dict
158 |
159 | return checkpoint, tokenizer, params
160 |
161 |
162 |
163 | def get_acc_with_contion(res_pd, key, values):
164 | if isinstance(values, list):
165 | total_pd = res_pd[res_pd[key].isin(values)]
166 | else:
167 | total_pd = res_pd[res_pd[key] == values]
168 | correct_pd = total_pd[total_pd['true_false'] == True]
169 | acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100)
170 | return acc
171 |
172 |
173 | def get_scores(result_file, data_file):
174 | # read result file
175 | results = json.load(open(result_file))
176 | num = len(results)
177 | assert num == 4241
178 |
179 | sqa_data = json.load(open(data_file))
180 |
181 | # construct pandas data
182 | sqa_pd = pd.DataFrame(sqa_data).T
183 | res_pd = sqa_pd[sqa_pd['split'] == 'test'] # test set
184 |
185 | # update data
186 | for index, row in res_pd.iterrows():
187 |
188 | res_pd.loc[index, 'no_context'] = True if (not row['hint'] and not row['image']) else False
189 | res_pd.loc[index, 'has_text'] = True if row['hint'] else False
190 | res_pd.loc[index, 'has_image'] = True if row['image'] else False
191 | res_pd.loc[index, 'has_text_image'] = True if (row['hint'] and row['image']) else False
192 |
193 | label = row['answer']
194 | pred = int(results[index])
195 | res_pd.loc[index, 'pred'] = pred
196 | res_pd.loc[index, 'true_false'] = (label == pred)
197 |
198 | # accuracy scores
199 | acc_average = len(res_pd[res_pd['true_false'] == True]) / num * 100
200 |
201 | scores = {
202 | 'acc_natural':
203 | get_acc_with_contion(res_pd, 'subject', 'natural science'),
204 | 'acc_social':
205 | get_acc_with_contion(res_pd, 'subject', 'social science'),
206 | 'acc_language':
207 | get_acc_with_contion(res_pd, 'subject', 'language science'),
208 | 'acc_has_text':
209 | get_acc_with_contion(res_pd, 'has_text', True),
210 | 'acc_has_image':
211 | get_acc_with_contion(res_pd, 'has_image', True),
212 | 'acc_no_context':
213 | get_acc_with_contion(res_pd, 'no_context', True),
214 | 'acc_grade_1_6':
215 | get_acc_with_contion(res_pd, 'grade', ['grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6']),
216 | 'acc_grade_7_12':
217 | get_acc_with_contion(res_pd, 'grade', ['grade7', 'grade8', 'grade9', 'grade10', 'grade11', 'grade12']),
218 | 'acc_average':
219 | "{:.2f}".format(acc_average),
220 | }
221 |
222 | return scores
223 |
224 |
225 | def print_scores(scores):
226 | latex_output = ""
227 | for key, score in scores.items():
228 | print(f"{key[4:]}: \t{score}")
229 | latex_output += f"& {score} "
230 | latex_output += "\\\\"
231 | print(latex_output)
232 |
233 | def load(
234 | ckpt_dir: str,
235 | llm_model:str,
236 | tokenizer_path: str,
237 | adapter_path: str,
238 | local_rank: int,
239 | world_size: int,
240 | max_seq_len: int,
241 | max_batch_size: int,
242 | adapter_type: str,
243 | adapter_dim:int,
244 | adapter_scale:float,
245 | hidden_proj:int,
246 | visual_adapter_type: str,
247 | temperature: float,
248 | use_vicuna: bool
249 | ) -> LaVIN_Generator:
250 | start_time = time.time()
251 | checkpoint, tokenizer, params = _load_and_redistribute_checkpoint(ckpt_dir, llm_model)
252 |
253 | print("Loading")
254 | adapter_checkpoint = torch.load(adapter_path, map_location="cpu")
255 |
256 |
257 | model_args: ModelArgs = ModelArgs(
258 | max_seq_len=max_seq_len, max_batch_size=max_batch_size,hidden_proj=hidden_proj, **params
259 | )
260 | model_args.vocab_size = tokenizer.n_words
261 |
262 | torch.set_default_tensor_type(torch.cuda.HalfTensor)
263 | model = Transformer(model_args)
264 | set_MMAdapter(model, adapter_type, dim=adapter_dim, s=adapter_scale,t=temperature)
265 | set_Clip_Adapter(model.backbone.visual, visual_adapter_type, dim=adapter_dim, s=adapter_scale,t=temperature)
266 |
267 | torch.set_default_tensor_type(torch.FloatTensor)
268 | model.load_state_dict(checkpoint, strict=False)
269 |
270 | if use_vicuna:
271 | apply_model_delta_online(model,'../data/weights/vicuna_'+llm_model)
272 |
273 | state_dict={}
274 | for key in adapter_checkpoint['model']:
275 | state_dict[key.replace('module.','')]=adapter_checkpoint['model'][key]
276 |
277 | model.load_state_dict(state_dict, strict=False)
278 |
279 | generator = LaVIN_Generator(model, tokenizer)
280 | print(f"Loaded in {time.time() - start_time:.2f} seconds")
281 | return generator
282 |
283 | def get_pred_idx(prediction, choices, options):
284 | """
285 | Get the index (e.g. 2) from the prediction (e.g. 'C')
286 | """
287 | if prediction in options[:len(choices)]:
288 | return options.index(prediction)
289 | else:
290 | return random.choice(range(len(choices)))
291 |
292 | def prepare_data(dir):
293 | if os.path.exists(os.path.join(dir,'images')):
294 | image_dir=os.path.join(dir,'images')
295 | else:
296 | image_dir=dir
297 | if os.path.exists(os.path.join(dir,'questions_answers_YN')):
298 | ann_dir=os.path.join(dir,'questions_answers_YN')
299 | else:
300 | ann_dir=dir
301 | image_list=[]
302 | image_path=[]
303 | for root, dirs, files in os.walk(image_dir):
304 | for file in files:
305 |             # check whether the file extension is .jpg or .png
306 | if file.endswith(".jpg") or file.endswith(".png"):
307 |                 # join to get the file's full path
308 | image_list.append(file)
309 | image_path.append(os.path.join(image_dir,file))
310 | ann_list=[]
311 | for img_id in image_list:
312 | ann_file=img_id.replace('.jpg','.txt').replace('.png','.txt')
313 | ann={}
314 | with open(os.path.join(ann_dir,ann_file)) as f:
315 | pos=f.readline().split('\t')[0]
316 | neg=f.readline().split('\t')[0]
317 | ann['pos']=pos
318 | ann['neg']=neg
319 | ann_list.append(ann)
320 |
321 | return image_path,ann_list
322 |
323 |
324 | def main(
325 | ckpt_dir: str,
326 | tokenizer_path: str,
327 | adapter_path: str,
328 | data_root:str,
329 | caption_file:str,
330 | max_seq_len: int,
331 | max_batch_size: int,
332 | llm_model:str='7B',
333 | generation_temperature: float = 0.1,
334 | top_p: float = 0.75,
335 | split='val',
336 | prompt_format='QCM-ALE',
337 | use_caption=False,
338 | options=["A", "B", "C", "D", "E"],
339 | adapter_type='repattn',
340 | adapter_dim=8,
341 | adapter_scale=1,
342 | n_prompt=10,
343 | hidden_proj=128,
344 | visual_adapter_type='normal',
345 | temperature=10.,
346 | use_vicuna=False,
347 | root_dir_='../data/mme'
348 | ):
349 | print(max_batch_size,max_seq_len)
350 | print('use caption: ',use_caption)
351 | local_rank, world_size = setup_model_parallel()
352 | if local_rank > 0:
353 | sys.stdout = open(os.devnull, "w")
354 |
355 | generator = load(
356 | ckpt_dir,llm_model, tokenizer_path, adapter_path, local_rank, world_size, max_seq_len, max_batch_size,
357 | adapter_type,adapter_dim,adapter_scale,hidden_proj,visual_adapter_type,
358 | temperature,use_vicuna)
359 |
360 | subsets=os.listdir(root_dir_)
361 | total_score=0
362 | cognition_score=0
363 | perception_score=0
364 | for subset in subsets:
365 | root_dir=os.path.join(root_dir_,subset)
366 | print('split: ', subset)
367 | img_list,ann_list=prepare_data(root_dir)
368 | qids=range(len(img_list))
369 | total_items=len(img_list)
370 | print('total_items: ',total_items)
371 |
372 |
373 | image_transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])
374 |
375 | prompt_args=PromptArgs()
376 | prompt_args.prompt_format = prompt_format
377 | prompt_args.use_caption = use_caption
378 | prompt_args.options = options
379 |
380 | pattern = re.compile(r'The answer is ([A-Z]).')
381 |
382 | answers = []
383 | preds=[]
384 | max_batch_size=8
385 | for i in range(total_items//max_batch_size+1):
386 | batch_qids=qids[i*max_batch_size:(i+1)*max_batch_size]
387 | if len(batch_qids)==0:
388 | break
389 | indicators = []
390 | prompts=[]
391 | images = []
392 | for qid in batch_qids:
393 | #pos
394 | prompt= 'Instruction: '+ ann_list[qid]['pos']+'\n'+\
395 | 'Response: '
396 | prompt = prompt.replace(" ", " ").strip()
397 | answer='yes'
398 | image = Image.open(img_list[qid]).convert('RGB')
399 | image = image_transforms(image)
400 | indicator = 1
401 | prompts.append(prompt)
402 | answers.append(answer)
403 | images.append(image.unsqueeze(0))
404 | indicators.append(indicator)
405 |
406 | #neg
407 | prompt= 'Instruction: '+ ann_list[qid]['neg']+'\n'+\
408 | 'Response: '
409 | prompt = prompt.replace(" ", " ").strip()
410 | answer='no'
411 | indicator = 1
412 | prompts.append(prompt)
413 | answers.append(answer)
414 | images.append(image.unsqueeze(0))
415 | indicators.append(indicator)
416 | images=torch.cat(images,0)
417 |
418 |
419 | results = generator.generate(
420 | prompts,images=images,indicators=indicators, max_gen_len=20, temperature=generation_temperature, top_p=top_p,n_feats=n_prompt
421 | )
422 |
423 | for result in results:
424 | result=result.lower().strip().split('response:')[1]
425 | if 'yes' in result[:4]:
426 | pred='yes'
427 | elif 'no' in result[:4]:
428 | pred='no'
429 | else:
430 | pred='other'
431 | preds.append(pred)
432 |
433 | #evaluations
434 | correct=0
435 | corrects=[]
436 | assert len(preds)==len(answers)
437 | for i, prediction in enumerate(preds):
438 | if prediction == answers[i]:
439 | correct += 1
440 | corrects.append(1)
441 | else:
442 | corrects.append(0)
443 | import numpy as np
444 | corrects=np.array(corrects)
445 | acc = correct / len(preds) * 100
446 | acc_plus= (corrects.reshape(-1,2).sum(1)==2).sum()/ (len(preds)//2)* 100
447 | total_score+=acc
448 | total_score+=acc_plus
449 | if subset in ['commonsense_reasoning','numerical_calculation','text_translation','code_reasoning']:
450 | cognition_score+=acc_plus
451 | cognition_score+=acc
452 | else:
453 | perception_score+=acc_plus
454 | perception_score+=acc
455 | print('subset: ', subset)
456 | print('overall accuracy: ', acc)
457 | print('overall accuracy+: ', acc_plus)
458 | with open('mme_eval.txt','a') as f:
459 | f.write('subset: '+ subset+'\n')
460 | f.write('accuracy: '+ str(acc)+'\n')
461 | f.write('accuracy+: '+ str(acc_plus)+'\n')
462 |
463 | print('total_score: ',total_score)
464 | print('perception_score: ',perception_score)
465 | print('cognition_score: ',cognition_score)
466 | with open('mme_eval.txt', 'a') as f:
467 | f.write('total_score: ' + str(total_score) + '\n')
468 | f.write('perception_score: ' + str(perception_score) + '\n')
469 | f.write('cognition_score: ' + str(cognition_score) + '\n')
470 |
471 | if __name__ == "__main__":
472 | fire.Fire(main)
473 |
--------------------------------------------------------------------------------
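For reference, MME reports two numbers per subset: plain accuracy over all yes/no questions and acc_plus, which credits an image only when both of its paired positive and negative questions are answered correctly. A small self-contained sketch of that computation on invented predictions:

# Toy illustration of the MME acc / acc+ computation used above:
# predictions come in (positive, negative) pairs per image.
import numpy as np

preds   = ['yes', 'no', 'yes', 'yes', 'no', 'no']   # invented model outputs
answers = ['yes', 'no', 'yes', 'no',  'yes', 'no']  # ground truth, one (pos, neg) pair per image

corrects = np.array([int(p == a) for p, a in zip(preds, answers)])
acc = corrects.mean() * 100                                    # per-question accuracy
acc_plus = (corrects.reshape(-1, 2).sum(1) == 2).mean() * 100  # both questions of a pair correct

print(acc, acc_plus)  # ~66.67 and ~33.33 for the toy lists above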
/lavin/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from .generator import LaVIN_Generator
5 | from .model import ModelArgs, Transformer
6 | from .tokenizer import Tokenizer
7 |
--------------------------------------------------------------------------------
/lavin/eval_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from typing import Optional, Tuple
5 | from dataclasses import dataclass
6 | import math
7 |
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 | import clip
12 | import fairscale.nn.model_parallel.initialize as fs_init
13 | from fairscale.nn.model_parallel.layers import (
14 | ParallelEmbedding,
15 | RowParallelLinear,
16 | ColumnParallelLinear,
17 | )
18 | from lavin.model import AdapterMLP
19 |
20 | @dataclass
21 | class ModelArgs:
22 | dim: int = 512
23 | n_layers: int = 8
24 | n_heads: int = 8
25 | vocab_size: int = -1 # defined later by tokenizer
26 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
27 | norm_eps: float = 1e-5
28 | hidden_proj: int=128
29 |
30 | max_batch_size: int = 32
31 | max_seq_len: int = 2048
32 |
33 |
34 |
35 | class RMSNorm(torch.nn.Module):
36 | def __init__(self, dim: int, eps: float = 1e-6):
37 | super().__init__()
38 | self.eps = eps
39 | self.weight = nn.Parameter(torch.ones(dim))
40 |
41 | def _norm(self, x):
42 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
43 |
44 | def forward(self, x):
45 | output = self._norm(x.float()).type_as(x)
46 | return output * self.weight
47 |
48 |
49 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
50 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
51 | t = torch.arange(end, device=freqs.device) # type: ignore
52 | freqs = torch.outer(t, freqs).float() # type: ignore
53 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
54 | return freqs_cis
55 |
56 |
57 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
58 | ndim = x.ndim
59 | assert 0 <= 1 < ndim
60 | assert freqs_cis.shape == (x.shape[1], x.shape[-1])
61 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
62 | return freqs_cis.view(*shape)
63 |
64 |
65 | def apply_rotary_emb(
66 | xq: torch.Tensor,
67 | xk: torch.Tensor,
68 | freqs_cis: torch.Tensor,
69 | ) -> Tuple[torch.Tensor, torch.Tensor]:
70 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
71 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
72 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
73 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
74 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
75 | return xq_out.type_as(xq), xk_out.type_as(xk)
76 |
77 |
78 | class Attention(nn.Module):
79 | def __init__(self, args: ModelArgs):
80 | super().__init__()
81 |
82 | self.n_local_heads = args.n_heads // fs_init.get_model_parallel_world_size()
83 | self.head_dim = args.dim // args.n_heads
84 |
85 | self.wq = ColumnParallelLinear(
86 | args.dim,
87 | args.n_heads * self.head_dim,
88 | bias=False,
89 | gather_output=False,
90 | init_method=lambda x: x,
91 | )
92 | self.wk = ColumnParallelLinear(
93 | args.dim,
94 | args.n_heads * self.head_dim,
95 | bias=False,
96 | gather_output=False,
97 | init_method=lambda x: x,
98 | )
99 | self.wv = ColumnParallelLinear(
100 | args.dim,
101 | args.n_heads * self.head_dim,
102 | bias=False,
103 | gather_output=False,
104 | init_method=lambda x: x,
105 | )
106 | self.wo = RowParallelLinear(
107 | args.n_heads * self.head_dim,
108 | args.dim,
109 | bias=False,
110 | input_is_parallel=True,
111 | init_method=lambda x: x,
112 | )
113 |
114 | self.cache_k = torch.zeros(
115 | (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
116 | ).cuda()
117 | self.cache_v = torch.zeros(
118 | (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
119 | ).cuda()
120 | self.gate = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1))
121 |
122 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
123 | bsz, seqlen, _ = x.shape
124 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
125 |
126 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
127 | xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
128 | xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
129 |
130 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
131 |
132 | self.cache_k = self.cache_k.to(xq)
133 | self.cache_v = self.cache_v.to(xq)
134 |
135 |         # the first token is the modality embedding; exclude it from the KV cache
136 | if start_pos==0:
137 | self.cache_k[:bsz, start_pos : start_pos + seqlen-1] = xk[:,1:]
138 | self.cache_v[:bsz, start_pos : start_pos + seqlen-1] = xv[:,1:]
139 |
140 | keys = xk
141 | values = xv
142 | else:
143 | self.cache_k[:bsz, start_pos: start_pos + seqlen] = xk
144 | self.cache_v[:bsz, start_pos: start_pos + seqlen] = xv
145 |
146 | keys = self.cache_k[:bsz, : start_pos + seqlen]
147 | values = self.cache_v[:bsz, : start_pos + seqlen]
148 |
149 |
150 | xq = xq.transpose(1, 2)
151 | keys = keys.transpose(1, 2)
152 | values = values.transpose(1, 2)
153 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
154 | if mask is not None:
155 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)
156 | scores = F.softmax(scores.float(), dim=-1).type_as(xq)
157 | output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
158 |
159 | output = output.transpose(
160 | 1, 2
161 | ).contiguous().view(bsz, seqlen, -1)
162 |
163 | return self.wo(output)
164 |
165 |
166 | class FeedForward(nn.Module):
167 | def __init__(
168 | self,
169 | dim: int,
170 | hidden_dim: int,
171 | multiple_of: int,
172 | ):
173 | super().__init__()
174 | hidden_dim = int(2 * hidden_dim / 3)
175 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
176 |
177 | self.w1 = ColumnParallelLinear(
178 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
179 | )
180 | self.w2 = RowParallelLinear(
181 | hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
182 | )
183 | self.w3 = ColumnParallelLinear(
184 | dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
185 | )
186 |
187 | def forward(self, x):
188 | return self.w2(F.silu(self.w1(x)) * self.w3(x))
189 |
190 |
191 | class TransformerBlock(nn.Module):
192 | def __init__(self, layer_id: int, args: ModelArgs):
193 | super().__init__()
194 | self.n_heads = args.n_heads
195 | self.dim = args.dim
196 | self.head_dim = args.dim // args.n_heads
197 | self.attention = Attention(args)
198 | self.feed_forward = FeedForward(
199 | dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
200 | )
201 | self.layer_id = layer_id
202 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
203 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
204 |
205 | self.drop_path = nn.Identity()
206 | self.cache_weights = torch.zeros(
207 | (args.max_batch_size, 2)
208 | ).cuda()
209 | self.cache_weights_ffn = torch.zeros(
210 | (args.max_batch_size, 2)
211 | ).cuda()
212 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
213 | h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter)
214 | out = h + self.feed_forward.forward(self.ffn_norm(h))
215 | return out
216 |
217 | from torch.cuda.amp import autocast
218 | class Transformer(nn.Module):
219 | def __init__(self, params: ModelArgs):
220 | super().__init__()
221 | self.params = params
222 | self.vocab_size = params.vocab_size
223 | self.n_layers = params.n_layers
224 |
225 | self.tok_embeddings = ParallelEmbedding(
226 | params.vocab_size, params.dim, init_method=lambda x: x
227 | )
228 |
229 | self.layers = torch.nn.ModuleList()
230 | for layer_id in range(params.n_layers):
231 | self.layers.append(TransformerBlock(layer_id, params))
232 |
233 | self.norm = RMSNorm(params.dim, eps=params.norm_eps)
234 | self.output = ColumnParallelLinear(
235 | params.dim, params.vocab_size, bias=False, init_method=lambda x: x
236 | )
237 |
238 | self.freqs_cis = precompute_freqs_cis(
239 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
240 | )
241 |
242 | self.backbone = clip.load('ViT-L/14')[0]
243 |
244 | self.adapter_proj = AdapterMLP(1024, params.hidden_proj, params.dim).float()
245 | self.adapter_modality_embedding=nn.Embedding(2,params.dim).float()
246 |
247 | @torch.inference_mode()
248 | def forward(self, tokens: torch.Tensor, start_pos: int):
249 | with autocast():
250 | _bsz, seqlen,_ = tokens.shape
251 | # h = self.tok_embeddings(tokens)
252 | h=tokens
253 | self.freqs_cis = self.freqs_cis.to(h.device)
254 | freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
255 |
256 | mask = None
257 | if seqlen > 1:
258 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=tokens.device)
259 | mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
260 | # mask decision token
261 | mask[:, :, 1:, 0] = float("-inf")
262 |
263 | for layer in self.layers:
264 | h = layer(h, start_pos, freqs_cis,mask)
265 |
266 | h = self.norm(h)
267 | output = self.output(h[:, -1, :]) # only compute last logits
268 | return output.float()
269 |
--------------------------------------------------------------------------------
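The precompute_freqs_cis / apply_rotary_emb pair above implements LLaMA's rotary position embedding: each (even, odd) channel pair of the query/key heads is treated as a complex number and rotated by a position-dependent phase. A toy, CPU-only check with invented shapes (batch 2, sequence 4, one head of dimension 8):

# Toy demonstration of the rotary-embedding helpers defined above.
# Shapes are invented; in the real model head_dim = dim // n_heads.
import torch

def precompute_freqs_cis(dim, end, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, shape (end, dim // 2)

bsz, seqlen, n_heads, head_dim = 2, 4, 1, 8
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
freqs_cis = precompute_freqs_cis(head_dim, seqlen)

# view channel pairs as complex numbers and rotate them, as apply_rotary_emb does
xq_c = torch.view_as_complex(xq.reshape(bsz, seqlen, n_heads, -1, 2))
xq_rot = torch.view_as_real(xq_c * freqs_cis.view(1, seqlen, 1, -1)).flatten(3)

# a pure rotation preserves the norm of every query vector
print(torch.allclose(xq.norm(dim=-1), xq_rot.norm(dim=-1), atol=1e-5))  # True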
/lavin/generator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from typing import List
5 |
6 | import torch
7 |
8 | from lavin.tokenizer import Tokenizer
9 | from lavin.eval_model import Transformer
10 | from torch.cuda.amp import autocast
11 |
12 | class LaVIN_Generator:
13 | def __init__(self, model: Transformer, tokenizer: Tokenizer):
14 | self.model = model
15 | self.tokenizer = tokenizer
16 | # self.backbone = clip.load('ViT-B/16', device='cpu')[0]
17 |
18 | def insert_image_embeds(self,examples,image_embeds,prefix_img,prefix_nonimg,img_indicators):
19 | _bsz, seqlen,_ = examples.shape
20 | new_examples=[]
21 | for i, example in enumerate(examples):
22 | if img_indicators[i]>0.:
23 | new_example=torch.cat([example[:1],prefix_img,image_embeds[i],example[1:]],0)
24 | new_example = new_example[:seqlen]
25 | else:
26 | new_example=torch.cat([example[:1],prefix_nonimg,example[1:]],0)
27 | new_example = new_example[:seqlen]
28 | new_examples.append(new_example.unsqueeze(0))
29 | new_examples = torch.cat(new_examples, 0)
30 | return new_examples
31 |
32 | @torch.inference_mode()
33 | def generate(
34 | self,
35 | prompts: List[str],
36 | images: torch.Tensor,
37 | indicators: List[int],
38 | max_gen_len: int,
39 | n_feats: int=3,
40 | temperature: float = 0.8,
41 | top_p: float = 0.95,
42 | ) -> List[str]:
43 | bsz = len(prompts)
44 | params = self.model.params
45 | assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
46 | self.model.eval()
47 |
48 | prefix_img_token = self.tokenizer.encode("Image: ", bos=True, eos=False)
49 | non_prefix_img_token= self.tokenizer.encode("Image: N/A", bos=True, eos=False)
50 |
51 | images=images.cuda()
52 | self.model.backbone.cuda()
53 |
54 | image_embeds= self.model.backbone.encode_image(images)
55 | image_embeds=self.model.adapter_proj(image_embeds)
56 |
57 |
58 | prompt_tokens=[]
59 | for i,x in enumerate(prompts):
60 | if indicators[i]==1:
61 | token_idx=prefix_img_token+[0]*n_feats+self.tokenizer.encode(x, bos=False, eos=False)
62 | else:
63 | token_idx = non_prefix_img_token + self.tokenizer.encode(x, bos=False, eos=False)
64 | prompt_tokens.append(token_idx)
65 |
66 |
67 | min_prompt_size = min([len(t) for t in prompt_tokens])
68 | max_prompt_size = max([len(t) for t in prompt_tokens])
69 |
70 | total_len = min(params.max_seq_len, max_gen_len + max_prompt_size)
71 |
72 | tokens = torch.full((bsz, total_len), 0).cuda().long()
73 | input_text_mask=torch.zeros_like(tokens).bool()
74 |
75 | for k, t in enumerate(prompt_tokens):
76 | t=t[:total_len]
77 | tokens[k, : len(t)] = torch.tensor(t).long()
78 | input_text_mask[k,:len(t)]=True
79 |
80 | token_embeds=self.model.tok_embeddings(tokens)
81 | indicators=torch.Tensor(indicators).cuda().long()
82 | modality_embedding=self.model.adapter_modality_embedding(indicators).unsqueeze(1)
83 |
84 | for i in range(len(token_embeds)):
85 | if indicators[i]==1:
86 | pos=len(prefix_img_token)
87 |                 # insert the image embeddings into the sequence
88 | image_token_embed=torch.cat([token_embeds[i,:pos],image_embeds[i],token_embeds[i,pos+n_feats:]],0)
89 | token_embeds[i]=image_token_embed
90 |
91 |
92 |
93 | start_pos = min_prompt_size
94 | prev_pos = 0
95 | for cur_pos in range(start_pos, total_len):
96 |
97 | if prev_pos==0:
98 | h=torch.cat([modality_embedding,token_embeds[:,prev_pos:cur_pos]],1)
99 | else:
100 | h=token_embeds[:,prev_pos:cur_pos]
101 | logits = self.model.forward(h, prev_pos)
102 | if temperature > 0:
103 | probs = torch.softmax(logits / temperature, dim=-1)
104 | next_token = sample_top_p(probs, top_p)
105 | else:
106 | next_token = torch.argmax(logits, dim=-1)
107 | next_token = next_token.reshape(-1)
108 | # only replace token if prompt has already been generated
109 |
110 | next_token_embeds = torch.where(
111 | input_text_mask[:, cur_pos,None], token_embeds[:, cur_pos], self.model.tok_embeddings(next_token)
112 | )
113 | token_embeds[:,cur_pos]=next_token_embeds
114 |
115 | next_token = torch.where(
116 | input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
117 | )
118 | tokens[:, cur_pos] = next_token
119 |
120 | prev_pos = cur_pos
121 |
122 | decoded = []
123 | for i, t in enumerate(tokens.tolist()):
124 | # cut to max gen len
125 | t = t[: len(prompt_tokens[i]) + max_gen_len]
126 | # cut to eos tok if any
127 | try:
128 | t = t[: t.index(self.tokenizer.eos_id)]
129 | except ValueError:
130 | pass
131 | decoded.append(self.tokenizer.decode(t))
132 |
133 |
134 | return decoded
135 |
136 |
137 | def sample_top_p(probs, p):
138 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
139 | probs_sum = torch.cumsum(probs_sort, dim=-1)
140 | mask = probs_sum - probs_sort > p
141 | probs_sort[mask] = 0.0
142 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
143 | next_token = torch.multinomial(probs_sort, num_samples=1)
144 | next_token = torch.gather(probs_idx, -1, next_token)
145 | return next_token
146 |
--------------------------------------------------------------------------------
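sample_top_p above implements nucleus (top-p) sampling: probabilities are sorted, tokens whose preceding cumulative mass already exceeds p are zeroed out, and the remaining mass is renormalised before one token is drawn. A toy call on an invented five-token distribution:

# Standalone copy of lavin.generator.sample_top_p, exercised on a toy distribution.
import torch

def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p              # cumulative mass *before* each token
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)

torch.manual_seed(0)
probs = torch.tensor([[0.50, 0.25, 0.15, 0.07, 0.03]])  # invented distribution over 5 tokens
# With p=0.6 only tokens 0 and 1 survive the nucleus, so the sample is always 0 or 1.
print(sample_top_p(probs.clone(), p=0.6))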
/lavin/mm_adaptation.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 | import json
5 | from lavin import ModelArgs, Tokenizer, Transformer
6 | from lavin.mm_adapter import set_MMAdapter,set_Clip_Adapter
7 |
8 | from pathlib import Path
9 | from util.apply_delta import apply_model_delta_online
10 |
11 |
12 |
13 | def _load_and_redistribute_checkpoint(llama_model_path, model_name):
14 |
15 | with open(Path(llama_model_path) / model_name / 'params.json') as f:
16 | params = json.load(f)
17 | tokenizer = Tokenizer(model_path=str(Path(llama_model_path) / 'tokenizer.model'))
18 | print('Using model path: %s, model_name: %s' % (llama_model_path, model_name))
19 | if model_name=='7B':
20 | checkpoint = torch.load(llama_model_path + model_name + '/consolidated.00.pth', map_location="cpu")
21 | return checkpoint, tokenizer, params
22 |
23 |
24 | checkpoints = (Path(llama_model_path) / model_name).glob('*.pth')
25 | checkpoints = sorted(checkpoints)
26 |
27 |
28 | loaded = []
29 | for x in checkpoints:
30 | print('loading from', x)
31 | loaded.append(torch.load(x, map_location='cpu'))
32 |
33 | full_state_dict = {}
34 | split_dims = {}
35 |
36 | def add_weight_with_split_dim(name, dim):
37 | if dim < 0: # bcast without split
38 | full_state_dict[name] = loaded[0][name].clone()
39 | else:
40 | full_state_dict[name] = torch.cat([x[name] for x in loaded], dim=dim)
41 | for x in loaded:
42 | del x[name]
43 | split_dims[name] = dim
44 |
45 | add_weight_with_split_dim('tok_embeddings.weight', 1)
46 | add_weight_with_split_dim('norm.weight', -1)
47 | add_weight_with_split_dim('output.weight', 0)
48 | for i in range(params['n_layers']):
49 | print('gathering layer %d of %d' % (i, params['n_layers']))
50 | layer_prefix = f'layers.{i}.'
51 | bcast_names = [
52 | 'attention_norm.weight',
53 | 'ffn_norm.weight',
54 | ]
55 | column_parallel_names = [
56 | 'attention.wq.weight',
57 | 'attention.wk.weight',
58 | 'attention.wv.weight',
59 | 'feed_forward.w1.weight',
60 | 'feed_forward.w3.weight',
61 | ]
62 | row_parallel_names = [
63 | 'attention.wo.weight',
64 | 'feed_forward.w2.weight',
65 | ]
66 | for key in bcast_names:
67 | add_weight_with_split_dim(layer_prefix + key, -1)
68 | for key in column_parallel_names:
69 | add_weight_with_split_dim(layer_prefix + key, 0)
70 | for key in row_parallel_names:
71 | add_weight_with_split_dim(layer_prefix + key, 1)
72 |
73 | checkpoint=full_state_dict
74 |
75 |
76 | return checkpoint, tokenizer, params
77 |
78 | def LaVIN(args):
79 |
80 | llama_model_path =args.llama_model_path
81 | model_name = args.llm_model
82 |
83 | checkpoint, tokenizer, params = _load_and_redistribute_checkpoint(llama_model_path, model_name)
84 |
85 |
86 | model_args: ModelArgs = ModelArgs(
87 | max_seq_len=args.max_seq_len, max_batch_size=32,hidden_proj=args.hidden_proj,drop_path=args.drop_path, **params
88 | )
89 |
90 | model_args.vocab_size = tokenizer.n_words
91 |
92 | if args.cpu_load:
93 |         # CPU loading is slow, but friendly to GPUs with limited memory.
94 | torch.set_default_tensor_type(torch.HalfTensor)
95 | else:
96 | torch.set_default_tensor_type(torch.cuda.HalfTensor)
97 |
98 | llama = Transformer(model_args)
99 |
100 | #delete language encoder
101 | del llama.backbone.transformer
102 |
103 | torch.set_default_tensor_type(torch.FloatTensor)
104 |
105 | if args.bits in ['4bit','8bit']:
106 | from util.quantization import quant_model_bnb
107 | llama.layers=quant_model_bnb(llama.layers,quant_bit=args.bits)
108 |
109 | llama.load_state_dict(checkpoint, strict=False)
110 | if args.use_vicuna:
111 | apply_model_delta_online(llama,'../data/weights/vicuna_'+args.llm_model)
112 |
113 |
114 | if args.adapter_type=='block' or args.adapter_type=='attn':
115 | set_MMAdapter(llama,args.adapter_type,dim=args.adapter_dim,s=args.adapter_scale,t=args.temperature,gradient_checkpointing=args.gradient_checkpointing)
116 | set_Clip_Adapter(llama.backbone.visual,args.visual_adapter_type,dim=args.adapter_dim,s=args.adapter_scale,t=args.temperature)
117 |
118 |
119 |
120 |
121 | learnable_keys=['adapter']
122 | total=0.
123 | trainable_names=[]
124 | for name, param in llama.named_parameters():
125 | for key in learnable_keys:
126 |
127 | if key in name:
128 | param.requires_grad = True
129 | param.data = param.data.float()
130 | total += param.nelement()
131 | trainable_names.append(name)
132 | else:
133 | param.requires_grad = False
134 | print(trainable_names)
135 | print(' + Number of trainable params: %.2fM' % (total / 1e6))
136 | return llama
137 |
138 |
--------------------------------------------------------------------------------
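During fine-tuning only the inserted adapters are trained: every parameter whose name contains one of learnable_keys is unfrozen and cast to float32, and the remaining LLaMA/CLIP weights stay frozen. A minimal sketch of that selection on a stand-in module (the two-layer model below is a placeholder, not the real Transformer):

# Minimal sketch of the adapter-only fine-tuning selection used in LaVIN(args).
import torch
from torch import nn

model = nn.ModuleDict({
    'attention': nn.Linear(16, 16, bias=False),   # stands in for a frozen LLaMA weight
    'adapter_attn': nn.Linear(16, 8),             # stands in for an inserted adapter
})

learnable_keys = ['adapter']
trainable = 0
for name, param in model.named_parameters():
    if any(key in name for key in learnable_keys):
        param.requires_grad = True
        param.data = param.data.float()   # adapters are trained in fp32
        trainable += param.nelement()
    else:
        param.requires_grad = False       # everything else stays frozen

print('trainable params: %.2fK' % (trainable / 1e3))  # only adapter_attn.* counted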
/lavin/mm_adapter.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 | import lavin
5 | from typing import Optional, Tuple
6 | from torch.cuda.amp import autocast
7 | import lavin.eval_model
8 |
9 |
10 | class RepAdapter_Router(nn.Module):
11 |     """ PyTorch implementation of RepAdapter (router variant) for 1d tensors."""
12 |
13 | def __init__(
14 | self,
15 | in_features=768,
16 | hidden_dim=8,
17 | groups=2,
18 | scale=1,
19 | t=10.
20 | ):
21 | super().__init__()
22 | self.conv_A=nn.Conv1d(in_features,hidden_dim,1,groups=1,bias=True)
23 | self.conv_B = nn.Conv1d(hidden_dim, in_features, 1, groups=groups, bias=True)
24 |
25 | self.conv_D = nn.Conv1d(hidden_dim, in_features, 1, groups=groups, bias=True)
26 |
27 | self.expert_weights=nn.Linear(in_features,2)
28 |
29 | self.dropout=nn.Dropout(0.1)
30 | self.groups=groups
31 | self.scale=scale
32 | self.t=t
33 |
34 | nn.init.xavier_uniform_( self.conv_A.weight)
35 | nn.init.zeros_(self.conv_A.bias)
36 | nn.init.zeros_(self.conv_B.weight)
37 | nn.init.zeros_(self.conv_B.bias)
38 |
39 |
40 | nn.init.zeros_(self.conv_D.weight)
41 | nn.init.zeros_(self.conv_D.bias)
42 |
43 | def forward(self, x,weights=None):
44 | with autocast():
45 | if weights is None:
46 | weights=torch.softmax(self.expert_weights(x[:,0])/self.t,-1).half()
47 | x=x.transpose(1,2)
48 | x_=self.dropout(self.conv_A(x))
49 | x=self.conv_B(x_)*self.scale*weights[:,0,None,None]+self.conv_D(x_)*self.scale*weights[:,1,None,None]+x
50 | x=x.transpose(1,2).contiguous()
51 | return x
52 |
53 |
54 |
55 | class RepAdapter(nn.Module):
56 | """
57 |     PyTorch implementation of RepAdapter for 1d tensors.
58 |     Copied from https://github.com/luogen1996/RepAdapter/blob/main/repadapter.py
59 | """
60 |
61 | def __init__(
62 | self,
63 | in_features=768,
64 | hidden_dim=8,
65 | groups=2,
66 | scale=1
67 | ):
68 | super().__init__()
69 | self.conv_A=nn.Conv1d(in_features,hidden_dim,1,groups=1,bias=True)
70 | self.conv_B = nn.Conv1d(hidden_dim, in_features, 1, groups=groups, bias=True)
71 |
72 | self.dropout=nn.Dropout(0.1)
73 | self.groups=groups
74 | self.scale=scale
75 |
76 | nn.init.xavier_uniform_( self.conv_A.weight)
77 | nn.init.zeros_(self.conv_A.bias)
78 | nn.init.zeros_(self.conv_B.weight)
79 | nn.init.zeros_(self.conv_B.bias)
80 |
81 | def forward(self, x,weights=None):
82 | with autocast():
83 | x=x.transpose(1,2)
84 | x=self.conv_B(self.dropout(self.conv_A(x)))
85 | x=x.transpose(1,2).contiguous()
86 | return x.float()
87 |
88 |
89 | def forward_llama_block(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
90 | if self.training and self.gradient_checkpointing:
91 | h = x + self.drop_path(torch.utils.checkpoint.checkpoint(self.attention, self.adapter_attn(self.attention_norm(x)), start_pos, freqs_cis, mask))
92 | out = h + self.drop_path(torch.utils.checkpoint.checkpoint(self.feed_forward, self.adapter_mlp(self.ffn_norm(h))))
93 | else:
94 | h = x + self.drop_path(self.attention.forward(self.adapter_attn(self.attention_norm(x)), start_pos, freqs_cis, mask, adapter))
95 | out = h + self.drop_path(self.feed_forward.forward(self.adapter_mlp(self.ffn_norm(h))))
96 | return out
97 |
98 | def forward_llama_attn(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
99 | if self.training and self.gradient_checkpointing:
100 | h = x + self.drop_path(torch.utils.checkpoint.checkpoint(self.attention, self.adapter_attn(self.attention_norm(x)), start_pos, freqs_cis, mask))
101 | out = h + self.drop_path(torch.utils.checkpoint.checkpoint(self.feed_forward, self.ffn_norm(h)))
102 | else:
103 | h = x + self.drop_path(self.attention.forward(self.adapter_attn(self.attention_norm(x)), start_pos, freqs_cis, mask, adapter))
104 | out = h + self.drop_path(self.feed_forward.forward(self.ffn_norm(h)))
105 | return out
106 | def forward_llama_attn_cache(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
107 | bs_=x.shape[0]
108 | if start_pos==0:
109 | self.cache_weights[:bs_]=torch.softmax(self.adapter_attn.expert_weights(self.attention_norm(x)[:,0].float())/self.t,-1).half()
110 | h = x + self.drop_path(self.attention.forward(self.adapter_attn(self.attention_norm(x),weights=self.cache_weights[:bs_]), start_pos, freqs_cis, mask, adapter))
111 | out = h + self.drop_path(self.feed_forward.forward(self.ffn_norm(h)))
112 | return out
113 |
114 | def forward_llama_block_cache(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
115 | bs_=x.shape[0]
116 | if start_pos==0:
117 | self.cache_weights[:bs_]=torch.softmax(self.adapter_attn.expert_weights(self.attention_norm(x)[:,0].float())/self.t,-1).half()
118 | self.cache_weights_ffn[:bs_]=torch.softmax(self.adapter_mlp.expert_weights(self.ffn_norm(x)[:,0].float())/self.t,-1).half()
119 | h = x + self.drop_path(self.attention.forward(self.adapter_attn(self.attention_norm(x),weights=self.cache_weights[:bs_]), start_pos, freqs_cis, mask, adapter))
120 | out = h + self.drop_path(self.feed_forward.forward(self.adapter_mlp(self.ffn_norm(h),self.cache_weights_ffn[:bs_])))
121 | return out
122 |
123 | def forward_clip(self, x: torch.Tensor):
124 | x = x + self.attention(self.adapter_attn(self.ln_1(x)))
125 | x = x + self.mlp(self.ln_2(x))
126 | return x
127 |
128 | def forward_clip_full(self, x: torch.Tensor):
129 | x = x + self.attention(self.adapter_attn(self.ln_1(x)))
130 | x = x + self.mlp(self.adapter_mlp(self.ln_2(x)))
131 | return x
132 |
133 |
134 | def set_MMAdapter(model, method, dim=8, s=1, set_forward=True,t=10,gradient_checkpointing=False):
135 | if method == 'block':
136 |         # not supported right now
137 | assert NotImplementedError
138 | for _ in model.children():
139 | if type(_) == lavin.model.TransformerBlock or type(_) == lavin.eval_model.TransformerBlock:
140 | _.adapter_attn = RepAdapter_Router(_.dim,hidden_dim=dim,scale=s,t=t)
141 | _.adapter_mlp = RepAdapter_Router(_.dim,hidden_dim=dim,scale=s,t=t)
142 | _.s = s
143 | _.t = t
144 | _.gradient_checkpointing=gradient_checkpointing
145 | if type(_) == lavin.eval_model.TransformerBlock:
146 | bound_method = forward_llama_block_cache.__get__(_, _.__class__)
147 | else:
148 | bound_method = forward_llama_block.__get__(_, _.__class__)
149 | if set_forward:
150 | setattr(_, 'forward', bound_method)
151 | elif len(list(_.children())) != 0:
152 | set_MMAdapter(_, method, dim, s,set_forward=set_forward,t=t,gradient_checkpointing=gradient_checkpointing)
153 |
154 | else:
155 | for _ in model.children():
156 | if type(_) == lavin.model.TransformerBlock or type(_) == lavin.eval_model.TransformerBlock:
157 | _.adapter_attn = RepAdapter_Router(_.dim,hidden_dim=dim,scale=s,t=t)
158 | _.s = s
159 | _.t=t
160 | _.gradient_checkpointing = gradient_checkpointing
161 | if type(_) == lavin.eval_model.TransformerBlock:
162 | bound_method = forward_llama_attn_cache.__get__(_, _.__class__)
163 | else:
164 | bound_method = forward_llama_attn.__get__(_, _.__class__)
165 | if set_forward:
166 | setattr(_, 'forward', bound_method)
167 | elif len(list(_.children())) != 0:
168 | set_MMAdapter(_, method, dim, s, set_forward=set_forward,t=t,gradient_checkpointing=gradient_checkpointing)
169 |
170 |
171 | from clip.model import ResidualAttentionBlock
172 | def set_Clip_Adapter(model, method, dim=8, s=1, set_forward=True, t=10.):
173 | for _ in model.children():
174 | if type(_) == ResidualAttentionBlock:
175 | if method=='router':
176 | _.adapter_attn = RepAdapter_Router(1024, hidden_dim=dim, scale=s, t=t)
177 | elif method=='router_block':
178 | _.adapter_attn = RepAdapter_Router(1024, hidden_dim=dim, scale=s, t=t)
179 | _.adapter_mlp = RepAdapter_Router(1024, hidden_dim=dim, scale=s, t=t)
180 | else:
181 | _.adapter_attn = RepAdapter(1024, hidden_dim=dim, scale=s)
182 | _.s = s
183 | if method=='router_block':
184 | bound_method = forward_clip_full.__get__(_, _.__class__)
185 | else:
186 | bound_method = forward_clip.__get__(_, _.__class__)
187 | if set_forward:
188 | setattr(_, 'forward', bound_method)
189 | elif len(list(_.children())) != 0:
190 | set_Clip_Adapter(_, method, dim, s, set_forward=set_forward, t=t)
191 |
--------------------------------------------------------------------------------
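RepAdapter_Router above mixes two up-projections (conv_B and conv_D) with per-sample softmax weights computed from the first token and adds the result back to the input; because both up-projections start at zero, the adapter is an identity mapping at initialisation. A simplified CPU-only sketch of that routing using nn.Linear layers and invented sizes (the real module uses grouped nn.Conv1d in half precision inside autocast):

# Simplified sketch of the routing in RepAdapter_Router. Sizes are illustrative.
import torch
from torch import nn

class ToyRouterAdapter(nn.Module):
    def __init__(self, in_features=32, hidden_dim=8, scale=1.0, t=10.0):
        super().__init__()
        self.down = nn.Linear(in_features, hidden_dim)
        self.up_b = nn.Linear(hidden_dim, in_features)   # expert 1 (like conv_B)
        self.up_d = nn.Linear(hidden_dim, in_features)   # expert 2 (like conv_D)
        self.expert_weights = nn.Linear(in_features, 2)  # router over the two experts
        self.scale, self.t = scale, t
        nn.init.zeros_(self.up_b.weight); nn.init.zeros_(self.up_b.bias)
        nn.init.zeros_(self.up_d.weight); nn.init.zeros_(self.up_d.bias)

    def forward(self, x):                                # x: (bsz, seqlen, dim)
        w = torch.softmax(self.expert_weights(x[:, 0]) / self.t, -1)  # (bsz, 2)
        h = self.down(x)
        return x + self.scale * (w[:, 0, None, None] * self.up_b(h)
                                 + w[:, 1, None, None] * self.up_d(h))

x = torch.randn(2, 5, 32)
adapter = ToyRouterAdapter()
# Zero-initialised up-projections make the adapter an identity mapping at the start.
print(torch.allclose(adapter(x), x))  # True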
/lavin/model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from typing import Optional, Tuple
5 | from dataclasses import dataclass
6 | import math
7 |
8 | import torch
9 | from torch import nn
10 | import torch.nn.functional as F
11 |
12 | import fairscale.nn.model_parallel.initialize as fs_init
13 | from fairscale.nn.model_parallel.layers import (
14 | ParallelEmbedding,
15 | RowParallelLinear,
16 | ColumnParallelLinear,
17 | )
18 |
19 | from torch.nn import Embedding, Linear
20 | import torch
21 | import pdb
22 | from timm.models.layers import DropPath
23 | import clip
24 | from torch.cuda.amp import autocast
25 | @dataclass
26 | class ModelArgs:
27 | dim: int = 512
28 | n_layers: int = 8
29 | n_heads: int = 8
30 | vocab_size: int = -1 # defined later by tokenizer
31 | multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
32 | norm_eps: float = 1e-5
33 | hidden_proj: int=128
34 |
35 | max_batch_size: int = 32
36 | max_seq_len: int = 2048
37 | drop_path: float=0.
38 |
39 |
40 | class RMSNorm(torch.nn.Module):
41 | def __init__(self, dim: int, eps: float = 1e-6):
42 | super().__init__()
43 | self.eps = eps
44 | self.weight = nn.Parameter(torch.ones(dim))
45 |
46 | def _norm(self, x):
47 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
48 |
49 | def forward(self, x):
50 | output = self._norm(x.float()).type_as(x)
51 | return output * self.weight
52 |
53 |
54 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
55 | freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
56 | t = torch.arange(end, device=freqs.device) # type: ignore
57 | freqs = torch.outer(t, freqs).float() # type: ignore
58 | freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
59 | return freqs_cis
60 |
61 |
62 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
63 | ndim = x.ndim
64 | assert 0 <= 1 < ndim
65 | assert freqs_cis.shape == (x.shape[1], x.shape[-1])
66 | shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
67 | return freqs_cis.view(*shape)
68 |
69 |
70 | def apply_rotary_emb(
71 | xq: torch.Tensor,
72 | xk: torch.Tensor,
73 | freqs_cis: torch.Tensor,
74 | ) -> Tuple[torch.Tensor, torch.Tensor]:
75 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
76 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
77 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
78 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
79 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
80 | return xq_out.type_as(xq), xk_out.type_as(xk)
81 |
82 |
83 | class Attention(nn.Module):
84 | def __init__(self, args: ModelArgs):
85 | super().__init__()
86 |
87 | self.n_local_heads = args.n_heads
88 | self.head_dim = args.dim // args.n_heads
89 |
90 | #modified bias for reparameterizing
91 | self.wq = Linear(
92 | args.dim,
93 | args.n_heads * self.head_dim,
94 | bias=False
95 | )
96 | self.wk = Linear(
97 | args.dim,
98 | args.n_heads * self.head_dim,
99 | bias=False
100 | )
101 | self.wv = Linear(
102 | args.dim,
103 | args.n_heads * self.head_dim,
104 | bias=False
105 | )
106 | self.wo = Linear(
107 | args.n_heads * self.head_dim,
108 | args.dim,
109 | bias=False
110 | )
111 |
112 | # self.cache_k = torch.zeros(
113 | # (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
114 | # ).cuda()
115 | # self.cache_v = torch.zeros(
116 | # (args.max_batch_size, args.max_seq_len, self.n_local_heads, self.head_dim)
117 | # ).cuda()
118 | # self.gate = torch.nn.Parameter(torch.zeros(1, self.n_local_heads, 1, 1))
119 |
120 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
121 |
122 | bsz, seqlen, _ = x.shape
123 | xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
124 |
125 | xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
126 | xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
127 | xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
128 |
129 | xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
130 |
131 | keys = xk
132 | values = xv
133 |
134 |
135 | xq = xq.transpose(1, 2)
136 | keys = keys.transpose(1, 2)
137 | values = values.transpose(1, 2)
138 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
139 | if mask is not None:
140 | scores = scores + mask # (bs, n_local_heads, slen, cache_len + slen)
141 | scores = F.softmax(scores.float(), dim=-1).type_as(xq)
142 | output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
143 | output = output.transpose(
144 | 1, 2
145 | ).contiguous().view(bsz, seqlen, -1)
146 |
147 | return self.wo(output)
148 |
149 |
150 | class FeedForward(nn.Module):
151 | def __init__(
152 | self,
153 | dim: int,
154 | hidden_dim: int,
155 | multiple_of: int,
156 | ):
157 | super().__init__()
158 | hidden_dim = int(2 * hidden_dim / 3)
159 | hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
160 |
161 | self.w1 = Linear(
162 | dim, hidden_dim, bias=False
163 | )
164 | self.w2 = Linear(
165 | hidden_dim, dim, bias=False
166 | )
167 | self.w3 = Linear(
168 | dim, hidden_dim, bias=False
169 | )
170 |
171 | def forward(self, x):
172 | return self.w2(F.silu(self.w1(x),inplace=False) * self.w3(x))
173 |
174 |
175 | class TransformerBlock(nn.Module):
176 | def __init__(self, layer_id: int, args: ModelArgs):
177 | super().__init__()
178 | self.n_heads = args.n_heads
179 | self.dim = args.dim
180 | self.head_dim = args.dim // args.n_heads
181 | self.attention = Attention(args)
182 | self.feed_forward = FeedForward(
183 | dim=args.dim, hidden_dim=4 * args.dim, multiple_of=args.multiple_of
184 | )
185 | self.layer_id = layer_id
186 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
187 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
188 | self.drop_path = DropPath(args.drop_path) if args.drop_path > 0. else nn.Identity()
189 |
190 | def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], adapter=None):
191 |
192 | h = x + self.drop_path(self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask, adapter))
193 | out = h + self.drop_path(self.feed_forward.forward(self.ffn_norm(h)))
194 | return out
195 |
196 |
197 |
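# Lightweight bottleneck MLP that maps visual features from the CLIP backbone
# (width 1024 for ViT-L/14) into the LLaMA embedding space; it runs under
# autocast and is initialized with Xavier weights and zero biases.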
198 | class AdapterMLP(nn.Module):
199 | """ Pytorch Implemention of RepAdapter for 1d tensor"""
200 |
201 | def __init__(
202 | self,
203 | in_features=768,
204 | hidden_dim=128,
205 | out_features=4096
206 | ):
207 | super().__init__()
208 | self.conv_A=nn.Linear(in_features,hidden_dim)
209 | self.conv_B = nn.Linear(hidden_dim, out_features)
210 |
211 |
212 | nn.init.xavier_uniform_( self.conv_A.weight)
213 | nn.init.zeros_(self.conv_A.bias)
214 | nn.init.xavier_uniform_(self.conv_B.weight)
215 | nn.init.zeros_(self.conv_B.bias)
216 |
217 | def forward(self, x):
218 | with autocast():
219 | x=self.conv_B(F.silu(self.conv_A(x)))
220 | return x
221 |
222 |
223 | class Transformer(nn.Module):
224 | def __init__(self, params: ModelArgs):
225 | super().__init__()
226 | self.params = params
227 | self.vocab_size = params.vocab_size
228 | self.n_layers = params.n_layers
229 | self.tok_embeddings = Embedding(
230 | params.vocab_size, params.dim
231 | )
232 |
233 |
234 | self.criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
235 |
236 | # with init_empty_weights():
237 | self.layers = torch.nn.ModuleList()
238 | for layer_id in range(params.n_layers):
239 | self.layers.append(TransformerBlock(layer_id, params))
240 |
241 | self.norm = RMSNorm(params.dim, eps=params.norm_eps)
242 | self.output = Linear(
243 | params.dim, params.vocab_size, bias=False
244 | )
245 |
246 | self.freqs_cis = precompute_freqs_cis(
247 | self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
248 | )
249 |
250 | self.backbone = clip.load('ViT-L/14')[0]
251 |
252 |
253 | # hard-coded: 1024 matches self.backbone.visual.transformer.width of the CLIP ViT-L/14 backbone
254 | self.adapter_proj = AdapterMLP(1024, params.hidden_proj, params.dim).float()
255 | self.adapter_modality_embedding=nn.Embedding(2,params.dim).float()
256 |
257 |
258 |
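# Splices visual tokens into the text sequence: for samples with an image, the
# projected image embeddings (preceded by an image prefix prompt) are inserted
# right after the BOS token; text-only samples get a non-image prefix instead.
# Labels at the inserted positions are set to 0 so they are ignored by
# CrossEntropyLoss(ignore_index=0), and sequences are re-truncated to seqlen.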
259 | def insert_image_embeds(self,examples,labels,image_embeds,prefix_img,prefix_nonimg,img_indicators):
260 | _bsz, seqlen,_ = examples.shape
261 | new_examples=[]
262 | new_labels=[]
263 | for i, (example,label) in enumerate(zip(examples,labels)):
264 | if img_indicators[i]>0.:
265 | new_example=torch.cat([example[:1],prefix_img,image_embeds[i],example[1:]],0)
266 | new_label=torch.cat([label[:1],
267 | torch.zeros(prefix_img.shape[0]+image_embeds.shape[1]).to(examples.device).type_as(labels),
268 | label[1:]])
269 | new_example = new_example[:seqlen]
270 | new_label = new_label[:seqlen]
271 | else:
272 | new_example=torch.cat([example[:1],prefix_nonimg,example[1:]],0)
273 | new_label=torch.cat([label[:1],
274 | torch.zeros(prefix_nonimg.shape[0]).to(examples.device).type_as(labels),
275 | label[1:]])
276 | new_example = new_example[:seqlen]
277 | new_label = new_label[:seqlen]
278 | new_examples.append(new_example.unsqueeze(0))
279 | new_labels.append(new_label.unsqueeze(0))
280 | new_examples = torch.cat(new_examples, 0)
281 | new_labels = torch.cat(new_labels, 0)
282 | return new_examples,new_labels
283 |
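# Forward pass: encode the image with the CLIP backbone, project it into the
# LLaMA embedding space, splice image/text prefixes into the token embeddings,
# and prepend a modality ("decision") token. The attention mask is causal, and
# additionally hides the decision token from all later positions.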
284 | def forward(self, examples, labels,images=None, prefix_img=None, prefix_nonimg=None,img_indicators=None):
285 |
286 | # print(images.dtype)
287 | image_embeds = self.backbone.encode_image(images).half()
288 |
289 | # print(img_indicators)
290 | if isinstance(img_indicators,list):
291 | img_indicators = torch.Tensor(img_indicators).to(image_embeds.device).long()
292 | modality_embed=self.adapter_modality_embedding(img_indicators.unsqueeze(1))
293 |
294 | # with autocast():
295 | image_embeds=self.adapter_proj(image_embeds)
296 |
297 | # print(image_embeds.shape)
298 |
299 | _bsz, seqlen = examples.shape
300 |
301 | examples = self.tok_embeddings(examples)
302 | prefix_img=self.tok_embeddings(prefix_img.unsqueeze(0)).squeeze(0)
303 | prefix_nonimg=self.tok_embeddings(prefix_nonimg.unsqueeze(0)).squeeze(0)
304 |
305 |
306 | h,labels=self.insert_image_embeds(examples,labels,image_embeds,prefix_img,prefix_nonimg,img_indicators)
307 |
308 | h=torch.cat([modality_embed.half(),h],1)[:,:seqlen]
309 | modality_labels=torch.zeros(_bsz,1).to(labels.device).type_as(labels)
310 | labels=torch.cat([modality_labels,labels],1)[:,:seqlen]
311 |
312 |
313 | freqs_cis = self.freqs_cis.to(h.device)
314 | freqs_cis = freqs_cis[:seqlen]
315 | mask = None
316 | mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
317 | mask = torch.triu(mask, diagonal=0 + 1).type_as(h)
318 |
319 | #mask decision token
320 | mask[:,:,1:,0]=float("-inf")
321 |
322 | start_pos = 0
323 | for layer in self.layers:
324 | h = layer(h, start_pos, freqs_cis, mask)
325 |
326 | h = self.norm(h)
327 | output = self.output(h)
328 | output = output[:, :-1, :].reshape(-1, self.vocab_size)
329 | labels = labels[:, 1:].flatten()
330 |
331 |
332 | c_loss = self.criterion(output, labels)
333 | return c_loss
334 |
--------------------------------------------------------------------------------
/lavin/tokenizer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from sentencepiece import SentencePieceProcessor
5 | from logging import getLogger
6 | from typing import List
7 | import os
8 |
9 |
10 | logger = getLogger()
11 |
12 |
13 | class Tokenizer:
14 | def __init__(self, model_path: str):
15 | # reload tokenizer
16 | assert os.path.isfile(model_path), model_path
17 | self.sp_model = SentencePieceProcessor(model_file=model_path)
18 | logger.info(f"Reloaded SentencePiece model from {model_path}")
19 |
20 | # BOS / EOS token IDs
21 | self.n_words: int = self.sp_model.vocab_size()
22 | self.bos_id: int = self.sp_model.bos_id()
23 | self.eos_id: int = self.sp_model.eos_id()
24 | self.pad_id: int = self.sp_model.pad_id()
25 | logger.info(
26 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
27 | )
28 | assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
29 |
30 | def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
31 | assert type(s) is str
32 | t = self.sp_model.encode(s)
33 | if bos:
34 | t = [self.bos_id] + t
35 | if eos:
36 | t = t + [self.eos_id]
37 | return t
38 |
39 | def decode(self, t: List[int]) -> str:
40 | return self.sp_model.decode(t)
41 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.0
2 | fairscale
3 | fire
4 | sentencepiece
5 | transformers==4.30.0
6 | timm
7 | tensorboard
8 | ftfy
9 | gradio
10 | bitsandbytes==0.39.0
--------------------------------------------------------------------------------
/scripts/eval_mme_benchmark.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 1 --master_port 11345 eval_mme.py \
2 | --ckpt_dir ../data/weights/ \
3 | --llm_model 7B\
4 | --tokenizer_path ../data/weights/tokenizer.model \
5 | --data_root ../data \
6 | --caption_file ../data/captions.json \
7 | --adapter_path ./15-eph-pretrain.pth \
8 | --adapter_type attn \
9 | --adapter_dim 8 \
10 | --adapter_scale 1 \
11 | --prompt_format QCM-ALE \
12 | --max_batch_size 16\
13 | --max_seq_len 512 \
14 | --split test \
15 | --n_prompt 6 \
16 | --temperature 5.\
17 | --visual_adapter_type router
--------------------------------------------------------------------------------
/scripts/finetuning_sqa_13b.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 8 --master_port 12345 train.py \
2 | --llm_model 13B\
3 | --llama_model_path ../data/weights/ \
4 | --data_path ../data/alpaca_data.json \
5 | --max_seq_len 512 \
6 | --batch_size 4 \
7 | --accum_iter 1 \
8 | --epochs 20 \
9 | --warmup_epochs 2 \
10 | --blr 9e-3 \
11 | --weight_decay 0.02 \
12 | --output_dir ./LaVIN-13B/\
13 | --adapter_type attn\
14 | --adapter_dim 8\
15 | --adapter_scale 1\
16 | --n_prompt 6 \
17 | --prompt_format QCM-ALE \
18 | --temperature 5.\
19 | --visual_adapter_type router
20 |
21 | torchrun --nproc_per_node 1 eval.py \
22 | --ckpt_dir ../data/weights/ \
23 | --llm_model 13B\
24 | --tokenizer_path ../data/weights/tokenizer.model \
25 | --data_root ../data \
26 | --caption_file ../data/captions.json \
27 | --adapter_path ./LaVIN-13B/checkpoint-19.pth \
28 | --adapter_type attn \
29 | --adapter_dim 8 \
30 | --adapter_scale 1 \
31 | --prompt_format QCM-ALE \
32 | --max_batch_size 64\
33 | --max_seq_len 512 \
34 | --split test \
35 | --n_prompt 6 \
36 | --temperature 5.\
37 | --visual_adapter_type router
--------------------------------------------------------------------------------
/scripts/finetuning_sqa_13b_lite.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 1 train.py \
2 | --llm_model 13B\
3 | --llama_model_path ../data/weights/ \
4 | --data_path ../data/alpaca_data.json \
5 | --max_seq_len 512 \
6 | --batch_size 1 \
7 | --accum_iter 32 \
8 | --epochs 20 \
9 | --warmup_epochs 2 \
10 | --blr 9e-3 \
11 | --weight_decay 0.02 \
12 | --output_dir ./LaVIN-13B-lite/\
13 | --adapter_type attn\
14 | --adapter_dim 8\
15 | --adapter_scale 1\
16 | --n_prompt 6 \
17 | --prompt_format QCM-ALE \
18 | --temperature 5.\
19 | --visual_adapter_type router \
20 | --gradient_checkpointing \
21 | --bits 4bit \
22 | --cpu_load
23 |
24 | torchrun --nproc_per_node 1 eval.py \
25 | --ckpt_dir ../data/weights/ \
26 | --llm_model 13B\
27 | --tokenizer_path ../data/weights/tokenizer.model \
28 | --data_root ../data \
29 | --caption_file ../data/captions.json \
30 | --adapter_path ./LaVIN-13B-lite/checkpoint-19.pth \
31 | --adapter_type attn \
32 | --adapter_dim 8 \
33 | --adapter_scale 1 \
34 | --prompt_format QCM-ALE \
35 | --max_batch_size 64\
36 | --max_seq_len 512 \
37 | --split test \
38 | --n_prompt 6 \
39 | --temperature 5.\
40 | --visual_adapter_type router \
41 | --bits 4bit
--------------------------------------------------------------------------------
/scripts/finetuning_sqa_7b.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 --master_port 11111 train.py \
2 | --llm_model 7B\
3 | --llama_model_path ../data/weights/ \
4 | --data_path ../data/alpaca_data.json \
5 | --max_seq_len 512 \
6 | --batch_size 4 \
7 | --accum_iter 4 \
8 | --epochs 20 \
9 | --warmup_epochs 2 \
10 | --blr 9e-3 \
11 | --weight_decay 0.02 \
12 | --output_dir ./LaVIN-7B/\
13 | --adapter_type attn\
14 | --adapter_dim 8\
15 | --adapter_scale 1\
16 | --n_prompt 6 \
17 | --prompt_format QCM-ALE \
18 | --temperature 10.\
19 | --visual_adapter_type router
20 |
21 | CUDA_VISIBLE_DEVICES=2 torchrun --nproc_per_node 1 --master_port 11111 eval.py \
22 | --ckpt_dir ../data/weights/ \
23 | --llm_model 7B\
24 | --tokenizer_path ../data/weights/tokenizer.model \
25 | --data_root ../data \
26 | --caption_file ../data/captions.json \
27 | --adapter_path ./LaVIN-7B/checkpoint-19.pth \
28 | --adapter_type attn \
29 | --adapter_dim 8 \
30 | --adapter_scale 1 \
31 | --prompt_format QCM-ALE \
32 | --max_batch_size 64\
33 | --max_seq_len 512 \
34 | --split test \
35 | --n_prompt 6 \
36 | --temperature 10.\
37 | --visual_adapter_type router
38 |
--------------------------------------------------------------------------------
/scripts/finetuning_sqa_7b_lite.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=2 torchrun --nproc_per_node 1 --master_port 11111 train.py \
2 | --llm_model 7B\
3 | --llama_model_path ../data/weights/ \
4 | --data_path ../data/alpaca_data.json \
5 | --max_seq_len 512 \
6 | --batch_size 1 \
7 | --accum_iter 32 \
8 | --epochs 20 \
9 | --warmup_epochs 2 \
10 | --blr 9e-3 \
11 | --weight_decay 0.02 \
12 | --output_dir ./LaVIN-7B-lite/\
13 | --adapter_type attn\
14 | --adapter_dim 8\
15 | --adapter_scale 1\
16 | --n_prompt 6 \
17 | --prompt_format QCM-ALE \
18 | --temperature 10.\
19 | --visual_adapter_type router \
20 | --gradient_checkpointing \
21 | --bits 4bit \
22 | --cpu_load
23 |
24 | CUDA_VISIBLE_DEVICES=2 torchrun --nproc_per_node 1 eval.py \
25 | --ckpt_dir ../data/weights/ \
26 | --llm_model 7B\
27 | --tokenizer_path ../data/weights/tokenizer.model \
28 | --data_root ../data \
29 | --caption_file ../data/captions.json \
30 | --adapter_path ./LaVIN-7B-lite/checkpoint-19.pth \
31 | --adapter_type attn \
32 | --adapter_dim 8 \
33 | --adapter_scale 1 \
34 | --prompt_format QCM-ALE \
35 | --max_batch_size 64\
36 | --max_seq_len 512 \
37 | --split test \
38 | --n_prompt 6 \
39 | --temperature 10.\
40 | --visual_adapter_type router\
41 | --bits 4bit \
42 | --cpu_load
--------------------------------------------------------------------------------
/scripts/finetuning_sqa_vicuna_13b.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 8 --master_port 12345 train.py \
2 | --llm_model 13B\
3 | --llama_model_path ../data/weights/ \
4 | --data_path ../data/alpaca_data.json \
5 | --max_seq_len 512 \
6 | --batch_size 4 \
7 | --accum_iter 1 \
8 | --epochs 20 \
9 | --warmup_epochs 2 \
10 | --blr 9e-3 \
11 | --weight_decay 0.02 \
12 | --output_dir ./LaVIN-Vicuna-13B/\
13 | --adapter_type attn\
14 | --adapter_dim 8\
15 | --adapter_scale 1\
16 | --n_prompt 6 \
17 | --prompt_format QCM-ALE \
18 | --temperature 5.\
19 | --visual_adapter_type router \
20 | --use_vicuna
21 |
22 | torchrun --nproc_per_node 1 eval.py \
23 | --ckpt_dir ../data/weights/ \
24 | --llm_model 13B\
25 | --tokenizer_path ../data/weights/tokenizer.model \
26 | --data_root ../data \
27 | --caption_file ../data/captions.json \
28 | --adapter_path ./LaVIN-Vicuna-13B/checkpoint-19.pth \
29 | --adapter_type attn \
30 | --adapter_dim 8 \
31 | --adapter_scale 1 \
32 | --prompt_format QCM-ALE \
33 | --max_batch_size 64\
34 | --max_seq_len 512 \
35 | --split test \
36 | --n_prompt 6 \
37 | --temperature 5.\
38 | --visual_adapter_type router \
39 | --use_vicuna=True
--------------------------------------------------------------------------------
/scripts/finetuning_sqa_vicuna_7b.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 2 --master_port 12345 train.py \
2 | --llm_model 7B\
3 | --llama_model_path ../data/weights/ \
4 | --data_path ../data/alpaca_data.json \
5 | --max_seq_len 512 \
6 | --batch_size 4 \
7 | --accum_iter 4 \
8 | --epochs 20 \
9 | --warmup_epochs 2 \
10 | --blr 9e-3 \
11 | --weight_decay 0.02 \
12 | --output_dir ./LaVIN-Vicuna-7B/\
13 | --adapter_type attn\
14 | --adapter_dim 8\
15 | --adapter_scale 1\
16 | --n_prompt 6 \
17 | --prompt_format QCM-ALE \
18 | --temperature 10.\
19 | --visual_adapter_type router\
20 | --use_vicuna
21 |
22 | torchrun --nproc_per_node 1 eval.py \
23 | --ckpt_dir ../data/weights/ \
24 | --llm_model 7B\
25 | --tokenizer_path ../data/weights/tokenizer.model \
26 | --data_root ../data \
27 | --caption_file ../data/captions.json \
28 | --adapter_path ./LaVIN-Vicuna-7B/checkpoint-19.pth \
29 | --adapter_type attn \
30 | --adapter_dim 8 \
31 | --adapter_scale 1 \
32 | --prompt_format QCM-ALE \
33 | --max_batch_size 64\
34 | --max_seq_len 512 \
35 | --split test \
36 | --n_prompt 6 \
37 | --temperature 10.\
38 | --visual_adapter_type router\
39 | --use_vicuna=True
--------------------------------------------------------------------------------
/scripts/vl_instruction_tuning_13b.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 8 --master_port 12345 train.py \
2 | --llm_model 13B\
3 | --llama_model_path ../data/weights/ \
4 | --data_path ../data/alpaca_data.json \
5 | --max_seq_len 512 \
6 | --batch_size 4 \
7 | --accum_iter 1 \
8 | --epochs 15 \
9 | --warmup_epochs 0.2 \
10 | --blr 9e-3 \
11 | --weight_decay 0.02 \
12 | --output_dir ./LaVIN-13B-VLIT/\
13 | --adapter_type attn\
14 | --adapter_dim 8\
15 | --adapter_scale 1\
16 | --n_prompt 6 \
17 | --prompt_format QCM-ALE \
18 | --temperature 5.\
19 | --visual_adapter_type router \
20 | --do_pretrain
--------------------------------------------------------------------------------
/scripts/vl_instruction_tuning_vicuna_13b.sh:
--------------------------------------------------------------------------------
1 | torchrun --nproc_per_node 8 --master_port 12345 train.py \
2 | --llm_model 13B\
3 | --llama_model_path ../data/weights/ \
4 | --data_path ../data/alpaca_data.json \
5 | --max_seq_len 512 \
6 | --batch_size 4 \
7 | --accum_iter 1 \
8 | --epochs 15 \
9 | --warmup_epochs 0.2 \
10 | --blr 9e-3 \
11 | --weight_decay 0.02 \
12 | --output_dir ./LaVIN-Vicuna-13B-VLIT/\
13 | --adapter_type attn\
14 | --adapter_dim 8\
15 | --adapter_scale 1\
16 | --n_prompt 6 \
17 | --prompt_format QCM-ALE \
18 | --temperature 5.\
19 | --visual_adapter_type router \
20 | --do_pretrain \
21 | --use_vicuna
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # This software may be used and distributed according to the terms of the GNU General Public License version 3.
3 |
4 | from setuptools import find_packages, setup
5 |
6 | setup(name="lavin", version="0.1", packages=find_packages())
--------------------------------------------------------------------------------
/tools/data_processing.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | #instruction:
4 | #answer:
5 | #input:
6 | #options:
7 | #qid:
8 | #image
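# Builds the unified instruction-tuning file (all_data.json) by converting
# Alpaca text-only records and LLaVA-style GPT-4 conversation data into a
# common {instruction, answer, image, image_source, options, qid} format.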
9 | all_data=[]
10 | with open('../data/alpaca_data.json') as f:
11 | alpaca_data=json.load(f)
12 | for i,item in enumerate(alpaca_data):
13 | data={}
14 | input=item['input']
15 | if len(input)==0:
16 | input=''
17 | data['instruction'] = 'Instruction: '+ item['instruction']+' '+input+'\n'+\
18 | 'Response: '
19 | data['instruction'] = data['instruction'].replace("  ", " ").strip()
20 | data['answer'] = item['output']
21 | data['image'] = None
22 | data['options'] = None
23 | data['image_source'] = None
24 | data['qid']='alpaca_'+str(i)
25 | all_data.append(data)
26 |
27 |
28 | with open('../data/complex_reasoning_77k.json') as f:
29 | gpt4_data_0=json.load(f)
30 | with open('../data/detail_23k.json') as f:
31 | gpt4_data_1=json.load(f)
32 | with open('../data/conversation_58k.json') as f:
33 | gpt4_data_2=json.load(f)
34 | gpt4_data=gpt4_data_0+gpt4_data_1
35 | for i,item in enumerate(gpt4_data):
36 | data={}
37 | data['instruction'] = 'Instruction: '+item['conversations'][0]['value'].replace('<image>\n','').replace('\n','')+'\n'+ \
38 | 'Response: '
39 | data['instruction'] = data['instruction'].replace("  ", " ").strip()
40 | data['answer'] = item['conversations'][1]['value']
41 | data['image'] = item['image']
42 | data['image_source']='mscoco'
43 | data['options'] = None
44 | data['qid']='gpt4_'+str(i)
45 | all_data.append(data)
46 |
47 |
48 | for i,item in enumerate(gpt4_data_2):
49 | for j in range(0,len(item['conversations']),2):
50 | data={}
51 | data['instruction'] = 'Instruction: '+item['conversations'][j]['value'].replace('<image>\n','').replace('\n','')+'\n'+ \
52 | 'Response: '
53 | data['instruction'] = data['instruction'].replace("  ", " ").strip()
54 | data['answer'] = item['conversations'][j+1]['value']
55 | data['image'] = item['image']
56 | data['image_source']='mscoco'
57 | data['options'] = None
58 | data['qid']='gpt4_2_'+str(i)+'_'+str(j)
59 | all_data.append(data)
60 |
61 |
62 |
63 | full_data={}
64 | full_data['all']=all_data
65 | with open('../data/all_data.json','w') as f:
66 | json.dump(full_data,f)
67 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import datetime
4 | import json
5 | import time
6 | import numpy as np
7 | from pathlib import Path
8 |
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | from torch.utils.tensorboard import SummaryWriter
12 | import timm.optim.optim_factory as optim_factory
13 |
14 | import util.misc as misc
15 | from util.misc import NativeScalerWithGradNormCount as NativeScaler
16 | from engine import train_one_epoch
17 |
18 | from util.datasets import ScienceQADataSet,InstrcutDataSet
19 | from lavin.mm_adaptation import LaVIN
20 | import random
21 | import bitsandbytes as bnb
22 |
23 | def get_args_parser():
24 | parser = argparse.ArgumentParser('LaVIN training', add_help=False)
25 | parser.add_argument('--batch_size', default=64, type=int,
26 | help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
27 | parser.add_argument('--epochs', default=400, type=int)
28 | parser.add_argument('--bits', default='16bit', type=str,choices=['4bit','8bit','16bit'],
29 | help='Quantization bits for training, fp16 by default')
30 | parser.add_argument('--accum_iter', default=1, type=int,
31 | help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
32 |
33 | # Model parameters
34 | parser.add_argument('--llama_model_path', default='./llama', type=str,
35 | help='path of llama model')
36 |
37 | parser.add_argument('--llm_model', default='7B', type=str, metavar='MODEL',
38 | help='Name of llm model to train')
39 |
40 | parser.add_argument('--use_vicuna', action='store_true', help='use vicuna weights')
41 |
42 | parser.add_argument('--cpu_load', action='store_true', help='load the model on cpu and avoid OOM on gpu')
43 |
44 | #block is not supported now.
45 | parser.add_argument('--adapter_type', type=str, default='attn', metavar='LENGTH',choices=['block','attn'],
46 | help='the insert position of adapter layer')
47 |
48 |
49 | parser.add_argument('--visual_adapter_type', type=str, default='normal', metavar='LENGTH',choices=['normal','router','router_block'],
50 | help='the type of adapter layer')
51 |
52 | parser.add_argument('--adapter_dim', type=int, default=8, metavar='LENGTH', help='the dims of adapter layer')
53 |
54 | parser.add_argument('--hidden_proj', type=int, default=128, metavar='LENGTH',
55 | help='the visual adapter dim')
56 |
57 | parser.add_argument('--temperature', type=float, default=10., metavar='LENGTH',
58 | help='the temperature of router')
59 |
60 | parser.add_argument('--n_prompt', type=int, default=10, metavar='LENGTH',
61 | help='the length of visual features')
62 | parser.add_argument('--adapter_scale', type=float, default=1., metavar='LENGTH', help='the scales of adapter layer')
63 | parser.add_argument('--drop_path', type=float, default=0., metavar='LENGTH', help='drop path')
64 |
65 | parser.add_argument('--max_seq_len', type=int, default=512, metavar='LENGTH',
66 | help='the maximum sequence length')
67 |
68 |
69 | # Optimizer parameters
70 | parser.add_argument('--weight_decay', type=float, default=0.05,
71 | help='weight decay (default: 0.05)')
72 |
73 | parser.add_argument('--lr', type=float, default=None, metavar='LR',
74 | help='learning rate (absolute lr)')
75 | parser.add_argument('--clip_grad', type=float, default=None, metavar='clip gradient',
76 | help='clips gradient norm of an iterable of parameters')
77 | parser.add_argument('--blr', type=float, default=1e-3, metavar='LR',
78 | help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
79 | parser.add_argument('--min_lr', type=float, default=0., metavar='LR',
80 | help='lower lr bound for cyclic schedulers that hit 0')
81 |
82 | parser.add_argument('--gradient_checkpointing', action='store_true',
83 | help='saving memory costs via gradient_checkpointing')
84 | parser.add_argument('--warmup_epochs', type=float, default=40, metavar='N',
85 | help='epochs to warmup LR')
86 |
87 | # Dataset parameters
88 | parser.add_argument('--data_path', default='/instruction_dataset/', type=str,
89 | help='dataset path')
90 |
91 | parser.add_argument('--output_dir', default='./output_dir',
92 | help='path where to save, empty for no saving')
93 | parser.add_argument('--log_dir', default='./output_dir',
94 | help='path where to tensorboard log')
95 | parser.add_argument('--device', default='cuda',
96 | help='device to use for training / testing')
97 | parser.add_argument('--seed', default=0, type=int)
98 | parser.add_argument('--resume', default='',
99 | help='resume from checkpoint')
100 |
101 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
102 | help='start epoch')
103 | parser.add_argument('--num_workers', default=10, type=int)
104 | parser.add_argument('--pin_mem', action='store_true',
105 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
106 | parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
107 | parser.set_defaults(pin_mem=True)
108 |
109 | # distributed training parameters
110 | parser.add_argument('--world_size', default=1, type=int,
111 | help='number of distributed processes')
112 | parser.add_argument('--local_rank', default=-1, type=int)
113 | parser.add_argument('--dist_on_itp', action='store_true')
114 | parser.add_argument('--dist_url', default='env://',
115 | help='url used to set up distributed training')
116 |
117 | #datasets
118 | parser.add_argument('--prompt_format',
119 | type=str,
120 | default='CQM-A',
121 | choices=[
122 | 'CQM-A', 'CQM-LA', 'CQM-EA', 'CQM-LEA', 'CQM-ELA', 'CQM-AL', 'CQM-AE', 'CQM-ALE', 'QCM-A',
123 | 'QCM-LA', 'QCM-EA', 'QCM-LEA', 'QCM-ELA', 'QCM-AL', 'QCM-AE', 'QCM-ALE', 'QCML-A', 'QCME-A',
124 | 'QCMLE-A', 'QCLM-A', 'QCEM-A', 'QCLEM-A', 'QCML-AE'
125 | ],
126 | help='prompt format template')
127 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
128 | parser.add_argument('--caption_file', type=str, default='../data/captions.json')
129 | parser.add_argument('--data_root', type=str, default='../data')
130 | parser.add_argument('--use_caption', action='store_true', help='use image captions or not')
131 | parser.add_argument('--do_pretrain', action='store_true', help='pre-train on large-scale vision-language instruction data')
132 |
133 | return parser
134 |
135 |
136 | def main(args):
137 |
138 | misc.init_distributed_mode(args)
139 |
140 | print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
141 | print("{}".format(args).replace(', ', ',\n'))
142 |
143 | device = torch.device(args.device)
144 |
145 | # fix the seed for reproducibility
146 | seed = args.seed + misc.get_rank()
147 | torch.manual_seed(seed)
148 | np.random.seed(seed)
149 | g = torch.Generator()
150 | g.manual_seed(seed)
151 | random.seed(seed)
152 |
153 | cudnn.benchmark = False
154 | cudnn.deterministic = True
155 |
156 |
157 | if args.do_pretrain:
158 | dataset_train = InstrcutDataSet(args, 'all', args.llama_model_path, args.max_seq_len)
159 | else:
160 | dataset_train = ScienceQADataSet(args, 'train', args.llama_model_path, args.max_seq_len)
161 |
162 | print(dataset_train)
163 |
164 |
165 | num_tasks = misc.get_world_size()
166 | global_rank = misc.get_rank()
167 | sampler_train = torch.utils.data.DistributedSampler(
168 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
169 | )
170 |
171 | print("Sampler_train = %s" % str(sampler_train))
172 |
173 | if global_rank == 0 and args.log_dir is not None:
174 | os.makedirs(args.log_dir, exist_ok=True)
175 | log_writer = SummaryWriter(log_dir=args.log_dir)
176 | else:
177 | log_writer = None
178 |
179 | data_loader_train = torch.utils.data.DataLoader(
180 | dataset_train, sampler=sampler_train,
181 | batch_size=args.batch_size,
182 | num_workers=args.num_workers,
183 | pin_memory=args.pin_mem,
184 | drop_last=True,
185 | generator=g,
186 |
187 | )
188 |
189 |
190 |
191 |
192 | # define the model
193 | model = LaVIN(args)
194 |
195 |
196 | model.to(device)
197 |
198 | #for debug. print the data type.
199 | for name, param in model.named_parameters():
200 | print(name,param.dtype)
201 |
202 | model_without_ddp = model
203 |
204 | #for debug. print the model.
205 | # print("Model = %s" % str(model_without_ddp))
206 |
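# Linear lr scaling rule: when --lr is not given, the absolute learning rate is
# the base lr (--blr) scaled by effective_batch_size / 256, where the effective
# batch size is batch_size * accum_iter * world_size.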
207 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
208 |
209 | if args.lr is None: # only base_lr is specified
210 | args.lr = args.blr * eff_batch_size / 256
211 |
212 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
213 | print("actual lr: %.2e" % args.lr)
214 |
215 | print("accumulate grad iterations: %d" % args.accum_iter)
216 | print("effective batch size: %d" % eff_batch_size)
217 |
218 | if args.distributed:
219 | print(args.gpu)
220 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True)
221 | model_without_ddp = model.module
222 |
223 | # following timm: set wd as 0 for bias and norm layers
224 | param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay)
225 |
226 | #following qlora: apply paged optimizer
227 | optimizer = bnb.optim.AdamW32bit(param_groups, lr=args.lr, betas=(0.9, 0.95),is_paged=True) #torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
228 | print(optimizer)
229 |
230 | #mixed precision scaler
231 | loss_scaler = NativeScaler()
232 |
233 | misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
234 |
235 | print(f"Start training for {args.epochs} epochs")
236 | start_time = time.time()
237 | for epoch in range(args.start_epoch, args.epochs):
238 |
239 | if args.distributed:
240 | data_loader_train.sampler.set_epoch(epoch)
241 |
242 | train_stats = train_one_epoch(
243 | model, data_loader_train,
244 | optimizer, device, epoch, loss_scaler,
245 | log_writer=log_writer,
246 | args=args
247 | )
248 |
249 | if args.output_dir:
250 | misc.save_model(
251 | args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
252 | loss_scaler=loss_scaler, epoch=epoch)
253 |
254 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
255 | 'epoch': epoch,}
256 |
257 |
258 | if args.output_dir and misc.is_main_process():
259 | if log_writer is not None:
260 | log_writer.flush()
261 | with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
262 | f.write(json.dumps(log_stats) + "\n")
263 |
264 | total_time = time.time() - start_time
265 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
266 | print('Training time {}'.format(total_time_str))
267 |
268 |
269 | if __name__ == '__main__':
270 |
271 | args = get_args_parser()
272 | args = args.parse_args()
273 | if args.output_dir:
274 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
275 | main(args)
276 |
--------------------------------------------------------------------------------
/util/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Apply the delta weights on top of a base model.
3 |
4 | Usage:
5 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta-v1.1
6 | """
7 | import argparse
8 | import gc
9 | import glob
10 | import json
11 | import os
12 | import shutil
13 | import tempfile
14 |
15 | from huggingface_hub import snapshot_download
16 | import torch
17 | from torch import nn
18 | from tqdm import tqdm
19 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
20 |
21 |
22 | GB = 1 << 30
23 |
24 |
25 | def split_files(model_path, tmp_path, split_size):
26 | if not os.path.exists(model_path):
27 | model_path = snapshot_download(repo_id=model_path)
28 | if not os.path.exists(tmp_path):
29 | os.makedirs(tmp_path)
30 |
31 | file_pattern = os.path.join(model_path, "pytorch_model-*.bin")
32 | files = glob.glob(file_pattern)
33 |
34 | part = 0
35 | try:
36 | for file_path in tqdm(files):
37 | state_dict = torch.load(file_path)
38 | new_state_dict = {}
39 |
40 | current_size = 0
41 | for name, param in state_dict.items():
42 | param_size = param.numel() * param.element_size()
43 |
44 | if current_size + param_size > split_size:
45 | new_file_name = f"pytorch_model-{part}.bin"
46 | new_file_path = os.path.join(tmp_path, new_file_name)
47 | torch.save(new_state_dict, new_file_path)
48 | current_size = 0
49 | new_state_dict = None
50 | gc.collect()
51 | new_state_dict = {}
52 | part += 1
53 |
54 | new_state_dict[name] = param
55 | current_size += param_size
56 |
57 | new_file_name = f"pytorch_model-{part}.bin"
58 | new_file_path = os.path.join(tmp_path, new_file_name)
59 | torch.save(new_state_dict, new_file_path)
60 | new_state_dict = None
61 | gc.collect()
62 | new_state_dict = {}
63 | part += 1
64 | except Exception as e:
65 | print(f"An error occurred during split_files: {e}")
66 | shutil.rmtree(tmp_path)
67 | raise
68 |
69 |
70 | def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path):
71 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
72 | delta_config = AutoConfig.from_pretrained(delta_path)
73 |
74 | if os.path.exists(target_model_path):
75 | shutil.rmtree(target_model_path)
76 | os.makedirs(target_model_path)
77 |
78 | split_size = 4 * GB
79 |
80 | with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path:
81 | print(f"Split files for the base model to {tmp_base_path}")
82 | split_files(base_model_path, tmp_base_path, split_size)
83 | print(f"Split files for the delta weights to {tmp_delta_path}")
84 | split_files(delta_path, tmp_delta_path, split_size)
85 |
86 | base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin")
87 | base_files = glob.glob(base_pattern)
88 | delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin")
89 | delta_files = glob.glob(delta_pattern)
90 | delta_state_dict = torch.load(delta_files[0])
91 |
92 | print("Applying the delta")
93 | weight_map = {}
94 | total_size = 0
95 |
96 | for i, base_file in tqdm(enumerate(base_files)):
97 | state_dict = torch.load(base_file)
98 | file_name = f"pytorch_model-{i}.bin"
99 | for name, param in state_dict.items():
100 | if name not in delta_state_dict:
101 | for delta_file in delta_files:
102 | delta_state_dict = torch.load(delta_file)
103 | gc.collect()
104 | if name in delta_state_dict:
105 | break
106 |
107 | state_dict[name] += delta_state_dict[name]
108 | weight_map[name] = file_name
109 | total_size += param.numel() * param.element_size()
110 | gc.collect()
111 | torch.save(state_dict, os.path.join(target_model_path, file_name))
112 |
113 | with open(
114 | os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w"
115 | ) as f:
116 | json.dump(
117 | {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f
118 | )
119 |
120 | print(f"Saving the target model to {target_model_path}")
121 | delta_tokenizer.save_pretrained(target_model_path)
122 | delta_config.save_pretrained(target_model_path)
123 |
124 |
125 | def apply_delta(base_model_path, target_model_path, delta_path):
126 | print(f"Loading the delta weights from {delta_path}")
127 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
128 | delta = AutoModelForCausalLM.from_pretrained(
129 | delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
130 | )
131 |
132 | print(f"Loading the base model from {base_model_path}")
133 | base = AutoModelForCausalLM.from_pretrained(
134 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
135 | )
136 |
137 | print("Applying the delta")
138 | for name, param in tqdm(base.state_dict().items(), desc="Applying delta"):
139 | assert name in delta.state_dict()
140 | param.data += delta.state_dict()[name]
141 |
142 | print(f"Saving the target model to {target_model_path}")
143 | base.save_pretrained(target_model_path)
144 | delta_tokenizer.save_pretrained(target_model_path)
145 |
146 |
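# Name-mapping helpers between HuggingFace LLaMA checkpoints
# (model.layers.*.self_attn.q_proj, ...) and the original LLaMA naming used in
# this repo (layers.*.attention.wq, ...), applied in both directions below.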
147 | def huggingface2llama(key):
148 | key=key.replace('model.layers','layers')
149 | if 'embed_tokens' in key:
150 | key=key.replace('embed_tokens','tok_embeddings')
151 | if '.self_attn.' in key:
152 | key = key.replace( '.self_attn.q_proj','.attention.wq')
153 | key = key.replace( '.self_attn.k_proj','.attention.wk')
154 | key = key.replace( '.self_attn.v_proj','.attention.wv')
155 | key = key.replace( '.self_attn.o_proj','.attention.wo')
156 | if '.input_layernorm.' in key:
157 | key=key.replace('.input_layernorm.','.attention_norm.')
158 | if '.post_attention_layernorm.' in key:
159 | key = key.replace('.post_attention_layernorm.', '.ffn_norm.')
160 | if '.mlp.' in key:
161 | key=key.replace('.mlp.','.feed_forward.')
162 | if '.down_proj.' in key:
163 | key = key.replace('.down_proj.', '.w2.')
164 | if '.up_proj.' in key:
165 | key = key.replace('.up_proj.', '.w3.')
166 | if '.gate_proj.' in key:
167 | key = key.replace('.gate_proj.', '.w1.')
168 | if key=='model.norm.weight':
169 | key=key.replace('model.norm.','norm.')
170 | if key=='lm_head.weight':
171 | key=key.replace('lm_head.','output.')
172 | return key
173 |
174 | def llama2huggingface(key):
175 | key = key.replace('layers', 'model.layers')
176 | if 'tok_embeddings' in key:
177 | key = key.replace('tok_embeddings', 'model.embed_tokens')
178 | if '.attention.wq' in key:
179 | key = key.replace('.attention.wq', '.self_attn.q_proj')
180 | if '.attention.wk' in key:
181 | key = key.replace('.attention.wk', '.self_attn.k_proj')
182 | if '.attention.wv' in key:
183 | key = key.replace('.attention.wv', '.self_attn.v_proj')
184 | if '.attention.wo' in key:
185 | key = key.replace('.attention.wo', '.self_attn.o_proj')
186 | if '.attention_norm.' in key:
187 | key = key.replace('.attention_norm.', '.input_layernorm.')
188 | if '.ffn_norm.' in key:
189 | key = key.replace('.ffn_norm.', '.post_attention_layernorm.')
190 | if '.w3.' in key:
191 | key = key.replace('.w3.', '.up_proj.')
192 | if '.w2.' in key:
193 | key = key.replace('.w2.', '.down_proj.')
194 | if '.w1.' in key:
195 | key = key.replace('.w1.', '.gate_proj.')
196 | if '.feed_forward.' in key:
197 | key = key.replace('.feed_forward.', '.mlp.')
198 | if key=='norm.weight':
199 | key=key.replace('norm.','model.norm.')
200 | if key=='output.weight':
201 | key=key.replace('output.','lm_head.')
202 | return key
203 |
204 | def apply_model_delta_online(base_model, delta_path):
205 | print(f"Loading the delta weights from {delta_path}")
206 | # delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
207 | delta = AutoModelForCausalLM.from_pretrained(
208 | delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
209 | )
210 |
211 | candidate_weight=set()
212 | exclude_weight=[]
213 | for name, param in base_model.state_dict().items():
214 | if llama2huggingface(name) in delta.state_dict():
215 | candidate_weight.add(name)
216 | else:
217 | exclude_weight.append(name)
218 | print("excluding these weights in llama: ",exclude_weight)
219 |
220 | print("Applying the delta")
221 | for name, param in base_model.named_parameters():
222 | if name in candidate_weight:
223 | assert llama2huggingface(name) in delta.state_dict()
224 | param.data += delta.state_dict()[llama2huggingface(name)].to(param.data.device)
225 |
226 |
227 |
228 | if __name__ == "__main__":
229 | parser = argparse.ArgumentParser()
230 | parser.add_argument("--base-model-path", type=str, required=True)
231 | parser.add_argument("--target-model-path", type=str, required=True)
232 | parser.add_argument("--delta-path", type=str, required=True)
233 | parser.add_argument(
234 | "--low-cpu-mem",
235 | action="store_true",
236 | help="Lower the cpu memory usage. This will split large files and use "
237 | "disk as swap to reduce the memory usage below 10GB.",
238 | )
239 | args = parser.parse_args()
240 |
241 | if args.low_cpu_mem:
242 | apply_delta_low_cpu_mem(
243 | args.base_model_path, args.target_model_path, args.delta_path
244 | )
245 | else:
246 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
--------------------------------------------------------------------------------
/util/base_prompt.py:
--------------------------------------------------------------------------------
1 | def get_question_text(problem):
2 | question = problem['question']
3 | return question
4 |
5 |
6 | def get_context_text(problem, use_caption):
7 | txt_context = problem['hint']
8 | img_context = problem['caption'] if use_caption else ""
9 | context = " ".join([txt_context, img_context]).strip()
10 | if context == "":
11 | context = "N/A"
12 | return context
13 |
14 |
15 | def get_choice_text(problem, options):
16 | choices = problem['choices']
17 | choice_list = []
18 | for i, c in enumerate(choices):
19 | choice_list.append("({}) {}".format(options[i], c))
20 | choice_txt = " ".join(choice_list)
21 | #print(choice_txt)
22 | return choice_txt
23 |
24 |
25 | def get_answer(problem, options):
26 | return options[problem['answer']]
27 |
28 |
29 | def get_lecture_text(problem):
30 | # \\n: GPT-3 can generate the lecture with more tokens.
31 | lecture = problem['lecture'].replace("\n", "\\n")
32 | return lecture
33 |
34 |
35 | def get_solution_text(problem):
36 | # \\n: GPT-3 can generate the solution with more tokens
37 | solution = problem['solution'].replace("\n", "\\n")
38 | return solution
39 |
40 |
41 | def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True):
42 |
43 | input_format, output_format = format.split("-")
44 |
45 | ## Inputs
46 | if input_format == "CQM":
47 | input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
48 | elif input_format == "QCM":
49 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
50 | # upper bound experiment
51 | elif input_format == "QCML":
52 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
53 | elif input_format == "QCME":
54 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
55 | elif input_format == "QCMLE":
56 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
57 |
58 | elif input_format == "QCLM":
59 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
60 | elif input_format == "QCEM":
61 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
62 | elif input_format == "QCLEM":
63 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
64 |
65 | # Outputs
66 | if test_example:
67 | output = "Answer:"
68 | elif output_format == 'A':
69 | output = f"Answer: The answer is {answer}."
70 |
71 | elif output_format == 'AL':
72 | output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
73 | elif output_format == 'AE':
74 | output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
75 | elif output_format == 'ALE':
76 | output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
77 | elif output_format == 'AEL':
78 | output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
79 |
80 | elif output_format == 'LA':
81 | output = f"Answer: {lecture} The answer is {answer}."
82 | elif output_format == 'EA':
83 | output = f"Answer: {solution} The answer is {answer}."
84 | elif output_format == 'LEA':
85 | output = f"Answer: {lecture} {solution} The answer is {answer}."
86 | elif output_format == 'ELA':
87 | output = f"Answer: {solution} {lecture} The answer is {answer}."
88 |
89 | text = input + output
90 | text = text.replace("  ", " ").strip()
91 | if text.endswith("BECAUSE:"):
92 | text = text.replace("BECAUSE:", "").strip()
93 | return text
94 |
95 |
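# Variant of create_one_example that returns the prompt and the target response
# separately, so the dataset can mask prompt tokens out of the training labels.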
96 | def create_training_example(format, question, context, choice, answer, lecture, solution):
97 |
98 | input_format, output_format = format.split("-")
99 |
100 | ## Inputs
101 | if input_format == "CQM":
102 | input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
103 | elif input_format == "QCM":
104 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
105 | # upper bound experiment
106 | elif input_format == "QCML":
107 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
108 | elif input_format == "QCME":
109 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
110 | elif input_format == "QCMLE":
111 | input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
112 |
113 | elif input_format == "QCLM":
114 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
115 | elif input_format == "QCEM":
116 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
117 | elif input_format == "QCLEM":
118 | input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
119 |
120 | input+="Response:"
121 | input='\n'+input
122 |
123 |
124 | # Outputs
125 | if output_format == 'A':
126 | output = f"The answer is {answer}."
127 |
128 | elif output_format == 'AL':
129 | output = f"The answer is {answer}. BECAUSE: {solution}"
130 | elif output_format == 'AE':
131 | output = f"The answer is {answer}. BECAUSE: {lecture}"
132 | elif output_format == 'ALE':
133 | output = f"The answer is {answer}. BECAUSE: {lecture} {solution}"
134 | elif output_format == 'AEL':
135 | output = f"The answer is {answer}. BECAUSE: {solution} {lecture}"
136 |
137 | elif output_format == 'LA':
138 | output = f"{lecture} The answer is {answer}."
139 | elif output_format == 'EA':
140 | output = f"{solution} The answer is {answer}."
141 | elif output_format == 'LEA':
142 | output = f"{lecture} {solution} The answer is {answer}."
143 | elif output_format == 'ELA':
144 | output = f"{solution} {lecture} The answer is {answer}."
145 |
146 | input = input.replace("  ", " ").strip()
147 | output = output.replace("  ", " ").strip()
148 | if output.endswith("BECAUSE:"):
149 | output = output.replace("BECAUSE:", "").strip()
150 |
151 | # print(input)
152 | return input, output
153 |
154 | def build_few_shot_prompt(problems, shot_qids, test_qid, args):
155 |
156 | examples = []
157 |
158 | # n-shot training examples
159 | for qid in shot_qids:
160 | question = get_question_text(problems[qid])
161 | context = get_context_text(problems[qid], args.use_caption)
162 | choice = get_choice_text(problems[qid], args.options)
163 | answer = get_answer(problems[qid], args.options)
164 | lecture = get_lecture_text(problems[qid])
165 | solution = get_solution_text(problems[qid])
166 |
167 | train_example = create_one_example(args.prompt_format,
168 | question,
169 | context,
170 | choice,
171 | answer,
172 | lecture,
173 | solution,
174 | test_example=False)
175 | examples.append(train_example)
176 |
177 | # test example
178 | question = get_question_text(problems[test_qid])
179 | context = get_context_text(problems[test_qid], args.use_caption)
180 | choice = get_choice_text(problems[test_qid], args.options)
181 | answer = get_answer(problems[test_qid], args.options)
182 | lecture = get_lecture_text(problems[test_qid])
183 | solution = get_solution_text(problems[test_qid])
184 |
185 | test_example = create_one_example(args.prompt_format,
186 | question,
187 | context,
188 | choice,
189 | answer,
190 | lecture,
191 | solution,
192 | test_example=True)
193 | examples.append(test_example)
194 |
195 | # create the prompt input
196 | prompt_input = '\n\n'.join(examples)
197 |
198 | return prompt_input
199 |
200 | def build_prompt(problems, test_qid, args):
201 |
202 | # test example
203 | question = get_question_text(problems[test_qid])
204 | context = get_context_text(problems[test_qid], args.use_caption)
205 | choice = get_choice_text(problems[test_qid], args.options)
206 | answer = get_answer(problems[test_qid], args.options)
207 | lecture = get_lecture_text(problems[test_qid])
208 | solution = get_solution_text(problems[test_qid])
209 |
210 | test_example = create_training_example(args.prompt_format,
211 | question,
212 | context,
213 | choice,
214 | answer,
215 | lecture,
216 | solution)
217 | return test_example
218 |
--------------------------------------------------------------------------------
/util/datasets.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 Gen Luo. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import json, re,random
16 | import torch.utils.data as Data
17 | from torchvision.transforms import transforms
18 | import os
19 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
20 | from PIL import Image
21 | from util.base_prompt import *
22 | import torch
23 | from lavin import Tokenizer
24 | import copy
25 |
26 | class ScienceQADataSet(Data.Dataset):
27 | def __init__(self, args,split,model_path,max_words=512,max_image_feats=1):
28 | super(ScienceQADataSet, self).__init__()
29 | self.args = args
30 | # --------------------------
31 | # ---- Raw data loading ---
32 | # --------------------------
33 | self.problems = json.load(open(os.path.join(args.data_root, 'problems.json')))
34 | pid_splits = json.load(open(os.path.join(args.data_root, 'pid_splits.json')))
35 | captions = json.load(open(args.caption_file))["captions"]
36 | self.image_path=os.path.join(args.data_root,'images',split)
37 | self.tokenizer = Tokenizer(model_path=model_path + '/tokenizer.model')
38 | self.max_words = max_words
39 | self.max_image_feats=max_image_feats
40 | self.split=split
41 | for qid in self.problems:
42 | self.problems[qid]['caption'] = captions[qid] if qid in captions else ""
43 |
44 | self.qids = pid_splits['%s' % (split)]
45 |
46 | print(f"number of problems in split {split}: {len(self.qids)}\n")
47 |
48 | self.transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])
49 |
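# Tokenization for training: the prompt and the full prompt+answer sequence are
# encoded separately; label positions covering the prompt (and padding) are
# zeroed out so that only answer tokens contribute to the loss.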
50 | def tokenize(self,prompt,answer):
51 | example=prompt+answer
52 | # print(prompt)
53 | prompt=torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.int64)
54 | example = torch.tensor(self.tokenizer.encode(example, bos=True, eos=True), dtype=torch.int64)
55 | padding = self.max_words - example.shape[0]
56 | if padding > 0:
57 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
58 | elif padding < 0:
59 | example = example[:self.max_words]
60 | labels = copy.deepcopy(example)
61 | labels[:len(prompt)] = -1
62 | example_mask = example.ge(0)
63 | label_mask = labels.ge(0)
64 | example[~example_mask] = 0
65 | labels[~label_mask] = 0
66 | example_mask = example_mask.float()
67 | label_mask = label_mask.float()
68 | return example, labels, example_mask,label_mask
69 |
70 |
71 | def __getitem__(self, idx):
72 |
73 | prompt_question,prompt_answer= build_prompt(self.problems,self.qids[idx],self.args)
74 | answer,choices,qid=self.problems[self.qids[idx]]["answer"], self.problems[self.qids[idx]]["choices"],self.qids[idx]
75 |
76 | if self.problems[self.qids[idx]]['image'] is not None:
77 | image = Image.open(os.path.join(self.image_path, self.qids[idx], 'image.png')).convert('RGB')
78 | image = self.transforms(image)
79 | image_mask=torch.cat([torch.Tensor([float('-inf')]*self.max_image_feats),torch.zeros(self.max_words)])
80 | indicator=1
81 | else:
82 | image=torch.Tensor(torch.zeros(3,224,224).float())
83 | image_mask=torch.zeros(self.max_words+self.max_image_feats)
84 | indicator=0
85 |
86 | example, labels, example_mask, label_mask=self.tokenize(prompt_question,prompt_answer)
87 |
88 | return example, labels, example_mask, image,indicator
89 |
90 | def __len__(self):
91 | return len(self.qids)
92 |
93 | def shuffle_list(self, list):
94 | random.shuffle(list)
95 |
96 |
97 |
98 | class InstrcutDataSet(Data.Dataset):
99 | def __init__(self, args,split,model_path,max_words=512,max_image_feats=1):
100 | super(InstrcutDataSet, self).__init__()
101 | self.args = args
102 | # --------------------------
103 | # ---- Raw data loading ---
104 | # --------------------------
105 | self.data = json.load(open(os.path.join(args.data_root, 'all_data.json')))[split]
106 |
107 | self.tokenizer = Tokenizer(model_path=model_path + '/tokenizer.model')
108 | self.max_words = max_words
109 | self.max_image_feats=max_image_feats
110 | self.split=split
111 | self.qids = [item['qid'] for item in self.data]
112 |
113 | print(f"number of problems in split {split}: {len(self.qids)}\n")
114 |
115 | self.transforms=transforms.Compose([transforms.Resize((224, 224), interpolation=Image.BICUBIC),transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)])
116 |
117 | def tokenize(self,prompt,answer,max_words=512):
118 | example=prompt+answer
119 | # print(prompt)
120 | prompt=torch.tensor(self.tokenizer.encode(prompt, bos=True, eos=False), dtype=torch.int64)
121 | example = torch.tensor(self.tokenizer.encode(example, bos=True, eos=True), dtype=torch.int64)
122 | padding = max_words - example.shape[0]
123 | if padding > 0:
124 | example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
125 | elif padding < 0:
126 | example = example[:self.max_words]
127 | labels = copy.deepcopy(example)
128 | labels[:len(prompt)] = -1
129 | example_mask = example.ge(0)
130 | label_mask = labels.ge(0)
131 | example[~example_mask] = 0
132 | labels[~label_mask] = 0
133 | example_mask = example_mask.float()
134 | label_mask = label_mask.float()
135 | return example, labels, example_mask,label_mask
136 |
137 |
138 | def __getitem__(self, idx):
139 |
140 | prompt_question=self.data[idx]['instruction']
141 | prompt_answer=self.data[idx]['answer']
142 |
143 | if self.data[idx]['image'] is not None:
144 | # image_path='../data/images/train' if self.data[idx]['image_source']=='sqa' else '../data/images/train2014'
145 | if self.data[idx]['image_source'] == 'sqa':
146 | image = Image.open(os.path.join('../data/images/train', self.qids[idx], 'image.png')).convert('RGB')
147 | else:
148 | image = Image.open(os.path.join('../data/images/train2014', 'COCO_train2014_'+self.data[idx]['image'])).convert('RGB')
149 | image = self.transforms(image)
150 | indicator=1
151 | else:
152 | image=torch.Tensor(torch.zeros(3,224,224).float())
153 | indicator=0
154 |
155 | # print(prompt_question,prompt_answer)
156 | example, labels, example_mask, label_mask=self.tokenize(prompt_question,prompt_answer)
157 |
158 | return example, labels, example_mask, image,indicator
159 |
160 | def __len__(self):
161 | return len(self.qids)
162 |
163 | def shuffle_list(self, list):
164 | random.shuffle(list)
165 |
166 | if __name__ == '__main__':
167 | from torch.utils.data import DataLoader
168 | class Cfg():
169 | def __init__(self):
170 | super(Cfg, self).__init__()
171 | self.options = ["A", "B", "C", "D", "E"]
172 | self.use_caption = True
173 | self.prompt_format = 'CQM-A'
174 | self.data_root = './data'
175 | self.output_root = './output'
176 | self.caption_file = './data/captions.json'
177 | cfg=Cfg()
178 | dataset=ScienceQADataSet(cfg,'val','./data/weights')
179 | data_loader = DataLoader(dataset,
180 | batch_size=1,
181 | shuffle=False,
182 | pin_memory=True)
183 |     max_example_len = 0
184 |     max_label_len = 0
185 |     # each batch is (example, labels, example_mask, image, indicator)
186 |     for example, labels, example_mask, image, indicator in data_loader:
187 |         print(example.shape, labels.shape)
188 |         print(image.shape, indicator)
189 |         if int(example_mask.sum()) > max_example_len:
190 |             max_example_len = int(example_mask.sum())
191 |         if int((labels > 0).sum()) > max_label_len:
192 |             max_label_len = int((labels > 0).sum())
193 |     print(max_example_len, max_label_len)
194 |
195 |
196 |
197 |
198 |
199 |
200 |
--------------------------------------------------------------------------------
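
As a side note on the tokenize() convention used by both dataset classes above: padding and prompt positions are first marked with -1, then turned into 0/1 masks while the masked ids are zeroed. A minimal, self-contained sketch with made-up token ids (no real tokenizer involved) illustrates the effect:

import torch

# made-up token ids: BOS=1, EOS=2, prompt tokens 10..12, answer tokens 20..21
prompt_ids  = torch.tensor([1, 10, 11, 12], dtype=torch.int64)
example_ids = torch.tensor([1, 10, 11, 12, 20, 21, 2], dtype=torch.int64)

max_words = 10
padding = max_words - example_ids.shape[0]
example = torch.cat((example_ids, torch.zeros(padding, dtype=torch.int64) - 1))  # pad with -1

labels = example.clone()
labels[:len(prompt_ids)] = -1           # no loss on the prompt tokens

example_mask = example.ge(0)            # True for real tokens, False for padding
label_mask = labels.ge(0)               # True only on answer tokens
example[~example_mask] = 0              # padded ids -> 0
labels[~label_mask] = 0                 # masked label ids -> 0

print(example)  # tensor([ 1, 10, 11, 12, 20, 21,  2,  0,  0,  0])
print(labels)   # tensor([ 0,  0,  0,  0, 20, 21,  2,  0,  0,  0])
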
/util/lars.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # LARS optimizer, implementation from MoCo v3:
8 | # https://github.com/facebookresearch/moco-v3
9 | # --------------------------------------------------------
10 |
11 | import torch
12 |
13 |
14 | class LARS(torch.optim.Optimizer):
15 | """
16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D.
17 | """
18 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
19 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient)
20 | super().__init__(params, defaults)
21 |
22 | @torch.no_grad()
23 | def step(self):
24 | for g in self.param_groups:
25 | for p in g['params']:
26 | dp = p.grad
27 |
28 | if dp is None:
29 | continue
30 |
31 | if p.ndim > 1: # if not normalization gamma/beta or bias
32 | dp = dp.add(p, alpha=g['weight_decay'])
33 | param_norm = torch.norm(p)
34 | update_norm = torch.norm(dp)
35 | one = torch.ones_like(param_norm)
36 | q = torch.where(param_norm > 0.,
37 | torch.where(update_norm > 0,
38 | (g['trust_coefficient'] * param_norm / update_norm), one),
39 | one)
40 | dp = dp.mul(q)
41 |
42 | param_state = self.state[p]
43 | if 'mu' not in param_state:
44 | param_state['mu'] = torch.zeros_like(p)
45 | mu = param_state['mu']
46 | mu.mul_(g['momentum']).add_(dp)
47 | p.add_(mu, alpha=-g['lr'])
--------------------------------------------------------------------------------
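
A minimal usage sketch for the LARS optimizer above, with a toy model and random data; it assumes the repository root is on PYTHONPATH so util.lars is importable. Parameters with ndim <= 1 (biases, norm scales) skip the trust-ratio scaling, as the implementation notes.

import torch
import torch.nn as nn

from util.lars import LARS  # assumes the repository root is on PYTHONPATH

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9)

x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
for _ in range(3):
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # 1-D params (biases) fall through to plain momentum SGD
    print(loss.item())
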
/util/lr_decay.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # ELECTRA https://github.com/google-research/electra
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import json
13 |
14 |
15 | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
16 | """
17 | Parameter groups for layer-wise lr decay
18 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
19 | """
20 | param_group_names = {}
21 | param_groups = {}
22 |
23 | num_layers = len(model.blocks) + 1
24 |
25 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
26 |
27 | for n, p in model.named_parameters():
28 | if not p.requires_grad:
29 | continue
30 |
31 | # no decay: all 1D parameters and model specific ones
32 | if p.ndim == 1 or n in no_weight_decay_list:
33 | g_decay = "no_decay"
34 | this_decay = 0.
35 | else:
36 | g_decay = "decay"
37 | this_decay = weight_decay
38 |
39 | layer_id = get_layer_id_for_vit(n, num_layers)
40 | group_name = "layer_%d_%s" % (layer_id, g_decay)
41 |
42 | if group_name not in param_group_names:
43 | this_scale = layer_scales[layer_id]
44 |
45 | param_group_names[group_name] = {
46 | "lr_scale": this_scale,
47 | "weight_decay": this_decay,
48 | "params": [],
49 | }
50 | param_groups[group_name] = {
51 | "lr_scale": this_scale,
52 | "weight_decay": this_decay,
53 | "params": [],
54 | }
55 |
56 | param_group_names[group_name]["params"].append(n)
57 | param_groups[group_name]["params"].append(p)
58 |
59 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
60 |
61 | return list(param_groups.values())
62 |
63 |
64 | def get_layer_id_for_vit(name, num_layers):
65 | """
66 | Assign a parameter with its layer id
67 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
68 | """
69 | if name in ['cls_token', 'pos_embed']:
70 | return 0
71 | elif name.startswith('patch_embed'):
72 | return 0
73 | elif name.startswith('blocks'):
74 | return int(name.split('.')[1]) + 1
75 | else:
76 | return num_layers
--------------------------------------------------------------------------------
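
A hedged usage sketch for param_groups_lrd: it expects a ViT-style model exposing a blocks list and timm-style parameter names (cls_token, pos_embed, patch_embed.*, blocks.N.*). The timm model below is used purely for illustration and is not implied by this repository.

import timm
import torch

from util.lr_decay import param_groups_lrd  # assumes the repository root is on PYTHONPATH

model = timm.create_model('vit_base_patch16_224', pretrained=False)
param_groups = param_groups_lrd(
    model,
    weight_decay=0.05,
    no_weight_decay_list=model.no_weight_decay(),  # e.g. {'cls_token', 'pos_embed', ...}
    layer_decay=0.75,
)
optimizer = torch.optim.AdamW(param_groups, lr=1e-3)
# each group keeps an extra "lr_scale" key; a scheduler such as adjust_learning_rate in
# util/lr_sched.py multiplies the base lr by that scale whenever it updates the groups
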
/util/lr_sched.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 |
9 | def adjust_learning_rate(optimizer, epoch, args):
10 | """Decay the learning rate with half-cycle cosine after warmup"""
11 | if epoch < args.warmup_epochs:
12 | lr = args.lr * epoch / args.warmup_epochs
13 | else:
14 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
15 | (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
16 | for param_group in optimizer.param_groups:
17 | if "lr_scale" in param_group:
18 | param_group["lr"] = lr * param_group["lr_scale"]
19 | else:
20 | param_group["lr"] = lr
21 | return lr
22 | #
23 | # def adjust_learning_rate_v2(optimizer, steps, args,step_per_epoch):
24 | # """Decay the learning rate with half-cycle cosine after warmup"""
25 | # if steps < args.warmup_epochs:
26 | # lr = args.lr * steps / (args.warmup_epochs *step_per_epoch)
27 | # else:
28 | # lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
29 | # (1. + math.cos(math.pi * (steps - args.warmup_epochs*step_per_epoch) / (args.epochs*step_per_epoch - step_per_epoch*args.warmup_epochs)))
30 | # for param_group in optimizer.param_groups:
31 | # if "lr_scale" in param_group:
32 | # param_group["lr"] = lr * param_group["lr_scale"]
33 | # else:
34 | # param_group["lr"] = lr
35 | # return lr
36 |
37 |
--------------------------------------------------------------------------------
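
A quick sketch of the warmup + half-cycle cosine schedule implemented by adjust_learning_rate; the Namespace fields below mirror the argument names the function reads (lr, min_lr, warmup_epochs, epochs), but the values are invented for illustration.

import argparse
import torch

from util.lr_sched import adjust_learning_rate  # assumes the repository root is on PYTHONPATH

args = argparse.Namespace(lr=2e-3, min_lr=1e-5, warmup_epochs=2, epochs=20)
optimizer = torch.optim.AdamW(torch.nn.Linear(4, 4).parameters(), lr=args.lr)

for epoch in range(args.epochs):
    # passing a fractional epoch (epoch + step / steps_per_epoch) gives per-step warmup
    lr = adjust_learning_rate(optimizer, epoch, args)
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
# linear warmup to 2e-3 over the first 2 epochs, then cosine decay towards 1e-5
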
/util/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # References:
8 | # DeiT: https://github.com/facebookresearch/deit
9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit
10 | # --------------------------------------------------------
11 |
12 | import builtins
13 | import datetime
14 | import os
15 | import time
16 | from collections import defaultdict, deque
17 | from pathlib import Path
18 |
19 | import torch
20 | import torch.distributed as dist
21 | from torch import inf
22 |
23 |
24 | class SmoothedValue(object):
25 | """Track a series of values and provide access to smoothed values over a
26 | window or the global series average.
27 | """
28 |
29 | def __init__(self, window_size=20, fmt=None):
30 | if fmt is None:
31 | fmt = "{median:.4f} ({global_avg:.4f})"
32 | self.deque = deque(maxlen=window_size)
33 | self.total = 0.0
34 | self.count = 0
35 | self.fmt = fmt
36 |
37 | def update(self, value, n=1):
38 | self.deque.append(value)
39 | self.count += n
40 | self.total += value * n
41 |
42 | def synchronize_between_processes(self):
43 | """
44 | Warning: does not synchronize the deque!
45 | """
46 | if not is_dist_avail_and_initialized():
47 | return
48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
49 | dist.barrier()
50 | dist.all_reduce(t)
51 | t = t.tolist()
52 | self.count = int(t[0])
53 | self.total = t[1]
54 |
55 | @property
56 | def median(self):
57 | d = torch.tensor(list(self.deque))
58 | return d.median().item()
59 |
60 | @property
61 | def avg(self):
62 | d = torch.tensor(list(self.deque), dtype=torch.float32)
63 | return d.mean().item()
64 |
65 | @property
66 | def global_avg(self):
67 | return self.total / self.count
68 |
69 | @property
70 | def max(self):
71 | return max(self.deque)
72 |
73 | @property
74 | def value(self):
75 | return self.deque[-1]
76 |
77 | def __str__(self):
78 | return self.fmt.format(
79 | median=self.median,
80 | avg=self.avg,
81 | global_avg=self.global_avg,
82 | max=self.max,
83 | value=self.value)
84 |
85 |
86 | class MetricLogger(object):
87 | def __init__(self, delimiter="\t"):
88 | self.meters = defaultdict(SmoothedValue)
89 | self.delimiter = delimiter
90 |
91 | def update(self, **kwargs):
92 | for k, v in kwargs.items():
93 | if v is None:
94 | continue
95 | if isinstance(v, torch.Tensor):
96 | v = v.item()
97 | assert isinstance(v, (float, int))
98 | self.meters[k].update(v)
99 |
100 | def __getattr__(self, attr):
101 | if attr in self.meters:
102 | return self.meters[attr]
103 | if attr in self.__dict__:
104 | return self.__dict__[attr]
105 | raise AttributeError("'{}' object has no attribute '{}'".format(
106 | type(self).__name__, attr))
107 |
108 | def __str__(self):
109 | loss_str = []
110 | for name, meter in self.meters.items():
111 | loss_str.append(
112 | "{}: {}".format(name, str(meter))
113 | )
114 | return self.delimiter.join(loss_str)
115 |
116 | def synchronize_between_processes(self):
117 | for meter in self.meters.values():
118 | meter.synchronize_between_processes()
119 |
120 | def add_meter(self, name, meter):
121 | self.meters[name] = meter
122 |
123 | def log_every(self, iterable, print_freq, header=None):
124 | i = 0
125 | if not header:
126 | header = ''
127 | start_time = time.time()
128 | end = time.time()
129 | iter_time = SmoothedValue(fmt='{avg:.4f}')
130 | data_time = SmoothedValue(fmt='{avg:.4f}')
131 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
132 | log_msg = [
133 | header,
134 | '[{0' + space_fmt + '}/{1}]',
135 | 'eta: {eta}',
136 | '{meters}',
137 | 'time: {time}',
138 | 'data: {data}'
139 | ]
140 | if torch.cuda.is_available():
141 | log_msg.append('max mem: {memory:.0f}')
142 | log_msg = self.delimiter.join(log_msg)
143 | MB = 1024.0 * 1024.0
144 | for obj in iterable:
145 | data_time.update(time.time() - end)
146 | yield obj
147 | iter_time.update(time.time() - end)
148 | if i % print_freq == 0 or i == len(iterable) - 1:
149 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
151 | if torch.cuda.is_available():
152 | print(log_msg.format(
153 | i, len(iterable), eta=eta_string,
154 | meters=str(self),
155 | time=str(iter_time), data=str(data_time),
156 | memory=torch.cuda.max_memory_allocated() / MB))
157 | else:
158 | print(log_msg.format(
159 | i, len(iterable), eta=eta_string,
160 | meters=str(self),
161 | time=str(iter_time), data=str(data_time)))
162 | i += 1
163 | end = time.time()
164 | total_time = time.time() - start_time
165 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
166 | print('{} Total time: {} ({:.4f} s / it)'.format(
167 | header, total_time_str, total_time / len(iterable)))
168 |
169 |
170 | def setup_for_distributed(is_master):
171 | """
172 | This function disables printing when not in master process
173 | """
174 | builtin_print = builtins.print
175 |
176 | def print(*args, **kwargs):
177 | force = kwargs.pop('force', False)
178 | force = force or (get_world_size() > 8)
179 | if is_master or force:
180 | now = datetime.datetime.now().time()
181 | builtin_print('[{}] '.format(now), end='') # print with time stamp
182 | builtin_print(*args, **kwargs)
183 |
184 | builtins.print = print
185 |
186 |
187 | def is_dist_avail_and_initialized():
188 | if not dist.is_available():
189 | return False
190 | if not dist.is_initialized():
191 | return False
192 | return True
193 |
194 |
195 | def get_world_size():
196 | if not is_dist_avail_and_initialized():
197 | return 1
198 | return dist.get_world_size()
199 |
200 |
201 | def get_rank():
202 | if not is_dist_avail_and_initialized():
203 | return 0
204 | return dist.get_rank()
205 |
206 |
207 | def is_main_process():
208 | return get_rank() == 0
209 |
210 |
211 | def save_on_master(*args, **kwargs):
212 | if is_main_process():
213 | torch.save(*args, **kwargs)
214 |
215 | from fairscale.nn.model_parallel.initialize import initialize_model_parallel
216 | def init_distributed_mode(args):
217 | if args.dist_on_itp:
218 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
219 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
220 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
221 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
222 | os.environ['LOCAL_RANK'] = str(args.gpu)
223 | os.environ['RANK'] = str(args.rank)
224 | os.environ['WORLD_SIZE'] = str(args.world_size)
225 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
226 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
227 | args.rank = int(os.environ["RANK"])
228 | args.world_size = int(os.environ['WORLD_SIZE'])
229 | args.gpu = int(os.environ['LOCAL_RANK'])
230 | elif 'SLURM_PROCID' in os.environ:
231 | args.rank = int(os.environ['SLURM_PROCID'])
232 | args.gpu = args.rank % torch.cuda.device_count()
233 | else:
234 | print('Not using distributed mode')
235 | setup_for_distributed(is_master=True) # hack
236 | args.distributed = False
237 | return
238 |
239 | args.distributed = True
240 |
241 | torch.cuda.set_device(args.gpu)
242 | args.dist_backend = 'nccl'
243 | print('| distributed init (rank {}): {}, gpu {}'.format(
244 | args.rank, args.dist_url, args.gpu), flush=True)
245 |
246 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
247 | world_size=args.world_size, rank=args.rank)
248 | torch.distributed.barrier()
249 | setup_for_distributed(args.rank == 0)
250 |
251 |
252 | class NativeScalerWithGradNormCount:
253 | state_dict_key = "amp_scaler"
254 |
255 | def __init__(self):
256 | self._scaler = torch.cuda.amp.GradScaler()
257 |
258 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
259 | self._scaler.scale(loss).backward(create_graph=create_graph)
260 | if update_grad:
261 | if clip_grad is not None:
262 | assert parameters is not None
263 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
264 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
265 | else:
266 | self._scaler.unscale_(optimizer)
267 | norm = get_grad_norm_(parameters)
268 | self._scaler.step(optimizer)
269 | self._scaler.update()
270 | else:
271 | norm = None
272 | return norm
273 |
274 | def state_dict(self):
275 | return self._scaler.state_dict()
276 |
277 | def load_state_dict(self, state_dict):
278 | self._scaler.load_state_dict(state_dict)
279 |
280 |
281 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
282 | if isinstance(parameters, torch.Tensor):
283 | parameters = [parameters]
284 | parameters = [p for p in parameters if p.grad is not None]
285 | norm_type = float(norm_type)
286 | if len(parameters) == 0:
287 | return torch.tensor(0.)
288 | device = parameters[0].grad.device
289 | if norm_type == inf:
290 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
291 | else:
292 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
293 | return total_norm
294 |
295 |
296 | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
297 | output_dir = Path(args.output_dir)
298 | epoch_name = str(epoch)
299 | model_without_ddp.eval()
300 | trainable = {}
301 | for n, p in model.named_parameters():
302 | if 'adapter' in n:
303 | trainable[n] = p.data
304 | # if loss_scaler is not None:
305 | checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
306 | for checkpoint_path in checkpoint_paths:
307 | to_save = {
308 | 'model': trainable,
309 | 'optimizer': optimizer.state_dict(),
310 | 'epoch': epoch,
311 | 'scaler': loss_scaler.state_dict() if loss_scaler is not None else None,
312 | 'args': args,
313 | }
314 | save_on_master(to_save, checkpoint_path)
315 |
316 |
317 | def load_model(args, model_without_ddp, optimizer, loss_scaler):
318 | if args.resume:
319 | if args.resume.startswith('https'):
320 | checkpoint = torch.hub.load_state_dict_from_url(
321 | args.resume, map_location='cpu', check_hash=True)
322 | else:
323 | checkpoint = torch.load(args.resume, map_location='cpu')
324 | model_without_ddp.load_state_dict(checkpoint['model'],strict=False)
325 | print("Resume checkpoint %s" % args.resume)
326 | if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
327 | optimizer.load_state_dict(checkpoint['optimizer'])
328 | args.start_epoch = checkpoint['epoch'] + 1
329 | if 'scaler' in checkpoint:
330 | loss_scaler.load_state_dict(checkpoint['scaler'])
331 | print("With optim & sched!")
332 |
333 |
334 | def all_reduce_mean(x):
335 | world_size = get_world_size()
336 | if world_size > 1:
337 | x_reduce = torch.tensor(x).cuda()
338 | dist.all_reduce(x_reduce)
339 | x_reduce /= world_size
340 | return x_reduce.item()
341 | else:
342 | return x
343 |
--------------------------------------------------------------------------------
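
A hedged sketch of how the main pieces of misc.py fit together in a single-process training loop (toy model and data, CUDA assumed); the distributed helpers silently become no-ops when torch.distributed is not initialized.

import torch
import torch.nn as nn

from util import misc  # assumes the repository root is on PYTHONPATH

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_scaler = misc.NativeScalerWithGradNormCount()
metric_logger = misc.MetricLogger(delimiter="  ")

data = [(torch.randn(8, 16), torch.randint(0, 4, (8,))) for _ in range(20)]

for x, y in metric_logger.log_every(data, print_freq=5, header='Epoch: [0]'):
    x, y = x.cuda(), y.cuda()
    with torch.cuda.amp.autocast():
        loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    # scales the loss, backprops, clips, steps the optimizer, and returns the grad norm
    grad_norm = loss_scaler(loss, optimizer, clip_grad=1.0, parameters=model.parameters())
    metric_logger.update(loss=loss.item(), grad_norm=grad_norm)

metric_logger.synchronize_between_processes()  # no-op without torch.distributed
print("Averaged stats:", metric_logger)
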
/util/pos_embed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | # --------------------------------------------------------
7 | # Position embedding utils
8 | # --------------------------------------------------------
9 |
10 | import numpy as np
11 |
12 | import torch
13 |
14 | # --------------------------------------------------------
15 | # 2D sine-cosine position embedding
16 | # References:
17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18 | # MoCo v3: https://github.com/facebookresearch/moco-v3
19 | # --------------------------------------------------------
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 | """
22 | grid_size: int of the grid height and width
23 | return:
24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 | """
26 | grid_h = np.arange(grid_size, dtype=np.float32)
27 | grid_w = np.arange(grid_size, dtype=np.float32)
28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
29 | grid = np.stack(grid, axis=0)
30 |
31 | grid = grid.reshape([2, 1, grid_size, grid_size])
32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 | if cls_token:
34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 | return pos_embed
36 |
37 |
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 | assert embed_dim % 2 == 0
40 |
41 | # use half of dimensions to encode grid_h
42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44 |
45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46 | return emb
47 |
48 |
49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
50 | """
51 | embed_dim: output dimension for each position
52 | pos: a list of positions to be encoded: size (M,)
53 | out: (M, D)
54 | """
55 | assert embed_dim % 2 == 0
56 |     omega = np.arange(embed_dim // 2, dtype=np.float64)
57 | omega /= embed_dim / 2.
58 | omega = 1. / 10000**omega # (D/2,)
59 |
60 | pos = pos.reshape(-1) # (M,)
61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
62 |
63 | emb_sin = np.sin(out) # (M, D/2)
64 | emb_cos = np.cos(out) # (M, D/2)
65 |
66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
67 | return emb
68 |
69 |
70 | # --------------------------------------------------------
71 | # Interpolate position embeddings for high-resolution
72 | # References:
73 | # DeiT: https://github.com/facebookresearch/deit
74 | # --------------------------------------------------------
75 | def interpolate_pos_embed(model, checkpoint_model):
76 | if 'pos_embed' in checkpoint_model:
77 | pos_embed_checkpoint = checkpoint_model['pos_embed']
78 | embedding_size = pos_embed_checkpoint.shape[-1]
79 | num_patches = model.patch_embed.num_patches
80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches
81 | # height (== width) for the checkpoint position embedding
82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
83 | # height (== width) for the new position embedding
84 | new_size = int(num_patches ** 0.5)
85 | # class_token and dist_token are kept unchanged
86 | if orig_size != new_size:
87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
89 | # only the position tokens are interpolated
90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
92 | pos_tokens = torch.nn.functional.interpolate(
93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
96 | checkpoint_model['pos_embed'] = new_pos_embed
97 |
--------------------------------------------------------------------------------
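
A quick shape check for the helpers above, using ViT-B/16-style numbers (14x14 patch grid, 768-dim embedding) chosen only for illustration:

import torch

from util.pos_embed import get_2d_sincos_pos_embed  # assumes the repository root is on PYTHONPATH

embed_dim, grid_size = 768, 14  # ViT-B/16 at 224x224 -> 14x14 patches
pos_embed = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
print(pos_embed.shape)          # (1 + 14*14, 768) = (197, 768), a numpy array

# typical use: copy into a frozen position-embedding parameter of shape [1, 197, 768]
pe = torch.from_numpy(pos_embed).float().unsqueeze(0)
print(pe.shape)                 # torch.Size([1, 197, 768])
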
/util/quantization.py:
--------------------------------------------------------------------------------
1 | from transformers.utils.bitsandbytes import *
2 | from transformers import BitsAndBytesConfig
3 | import torch
4 | from torch import nn
5 | import bitsandbytes as bnb
6 |
7 | from fairscale.nn.model_parallel.layers import (
8 | ParallelEmbedding,
9 | RowParallelLinear,
10 | ColumnParallelLinear,
11 | )
12 | def _replace_with_bnb_linear(
13 |     model, modules_to_not_convert=(), current_key_name=None, quantization_config=None, has_been_replaced=False
14 | ):
15 | """
16 | Private method that wraps the recursion for module replacement.
17 |
18 |     Returns the converted model and a boolean that indicates if the conversion has been successful or not.
19 | """
20 | for name, module in model.named_children():
21 | if current_key_name is None:
22 | current_key_name = []
23 | current_key_name.append(name)
24 |
25 | if (isinstance(module, nn.Linear) or isinstance(module, ColumnParallelLinear) or isinstance(module, RowParallelLinear) ) and name not in modules_to_not_convert:
26 | # Check if the current key is not in the `modules_to_not_convert`
27 | if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
28 | # with init_empty_weights():
29 | if quantization_config.quantization_method() == "llm_int8":
30 | model._modules[name] = bnb.nn.Linear8bitLt(
31 | module.in_features,
32 | module.out_features,
33 | module.bias is not None,
34 | has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
35 | threshold=quantization_config.llm_int8_threshold,
36 | )
37 | has_been_replaced = True
38 | else:
39 | if (
40 | quantization_config.llm_int8_skip_modules is not None
41 | and name in quantization_config.llm_int8_skip_modules
42 | ):
43 | pass
44 | else:
45 | model._modules[name] = bnb.nn.Linear4bit(
46 | module.in_features,
47 | module.out_features,
48 | module.bias is not None,
49 | quantization_config.bnb_4bit_compute_dtype,
50 | compress_statistics=quantization_config.bnb_4bit_use_double_quant,
51 | quant_type=quantization_config.bnb_4bit_quant_type,
52 | )
53 | has_been_replaced = True
54 | # Force requires grad to False to avoid unexpected errors
55 | model._modules[name].requires_grad_(False)
56 | if len(list(module.children())) > 0:
57 | _, has_been_replaced = _replace_with_bnb_linear(
58 | module,
59 | modules_to_not_convert,
60 | current_key_name,
61 | quantization_config,
62 | has_been_replaced=has_been_replaced,
63 | )
64 | # Remove the last key for recursion
65 | current_key_name.pop(-1)
66 | return model, has_been_replaced
67 |
68 |
69 | def quant_model_bnb(model, quant_bit='4bit', keep_in_fp32_modules=[],
70 | quantization_config=None):
71 | if quantization_config is None:
72 | # set default quantization config
73 | # compute_dtype = (torch.float16 if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32))
74 | quantization_config = BitsAndBytesConfig(
75 | load_in_4bit=quant_bit == '4bit',
76 | load_in_8bit=quant_bit == '8bit',
77 | llm_int8_threshold=6.0,
78 | llm_int8_has_fp16_weight=False,
79 | bnb_4bit_compute_dtype=torch.float16,
80 | bnb_4bit_use_double_quant=True,
81 | bnb_4bit_quant_type='nf4',
82 | )
83 |
84 | model,_ = _replace_with_bnb_linear(
85 | model, modules_to_not_convert=keep_in_fp32_modules, quantization_config=quantization_config
86 | )
87 |
88 | return model
89 |
90 |
--------------------------------------------------------------------------------
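
A hedged usage sketch for quant_model_bnb; it needs transformers and a working bitsandbytes install (CUDA for actual quantized inference), and the toy nn.Sequential below is illustrative only. The function also handles fairscale ColumnParallelLinear/RowParallelLinear layers, which this example does not exercise.

import torch.nn as nn

from util.quantization import quant_model_bnb  # assumes the repository root is on PYTHONPATH

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))

# replace every eligible linear layer with a 4-bit NF4 bitsandbytes module;
# an empty keep_in_fp32_modules list means nothing is excluded here
model = quant_model_bnb(model, quant_bit='4bit', keep_in_fp32_modules=[])
print(model)  # the nn.Linear layers now appear as bnb.nn.Linear4bit modules
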