├── .gitignore
├── LICENSE
├── README.md
├── examples
│   └── image1.jpg
├── images
│   ├── fig1.jpg
│   └── fig2.jpg
├── internvl
│   ├── conversation.py
│   ├── dist_utils.py
│   ├── model
│   │   ├── internlm2
│   │   │   ├── configuration_internlm2.py
│   │   │   ├── modeling_internlm2.py
│   │   │   ├── modeling_internlm2_ve.py
│   │   │   ├── tokenization_internlm2.py
│   │   │   └── tokenization_internlm2_fast.py
│   │   └── internvl_chat
│   │       ├── __init__.py
│   │       ├── configuration_intern_vit.py
│   │       ├── configuration_internvl_chat.py
│   │       ├── flash_attention.py
│   │       ├── modeling_intern_vit.py
│   │       └── modeling_internvl_chat.py
│   ├── patch
│   │   ├── __init__.py
│   │   ├── internlm2_packed_training_patch.py
│   │   ├── llama2_flash_attn_monkey_patch.py
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── llama_rmsnorm_monkey_patch.py
│   │   ├── pad_data_collator.py
│   │   ├── qwen2_packed_training_patch.py
│   │   ├── train_dataloader_patch.py
│   │   └── train_sampler_patch.py
│   ├── serve
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── mm_utils.py
│   │   ├── model_worker.py
│   │   └── utils.py
│   └── train
│       ├── __init__.py
│       ├── constants.py
│       ├── dataset.py
│       ├── dataset_packed.py
│       ├── internvl_chat_finetune.py
│       └── trainer_monkey_patch.py
├── requirements.txt
└── shell
    ├── data_llava_finetune.json
    ├── mono_internvl_finetune_llava_slurm.sh
    ├── mono_internvl_finetune_llava_torchrun.sh
    └── zero_stage1_config.json
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | .idea/
163 |
164 | .DS_Store
165 |
166 |
167 | /playground/data/
168 | /playground/*.jsonl
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 OpenGVLab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mono-InternVL: Pushing the Boundaries of Monolithic Multimodal Large Language Models with Endogenous Visual Pre-training
2 |
3 | [[📜 Paper]](https://arxiv.org/abs/2410.08202) [[⭐️Project Page]](https://internvl.github.io/blog/2024-10-10-Mono-InternVL/) [[🤗 Model]](https://huggingface.co/collections/OpenGVLab/mono-internvl-6707cb402afb22f1e29f4d2b) [[📝 Chinese Post]](https://mp.weixin.qq.com/s/FmjG0Gp5ow7mm2Vzd9ppPg)
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## 📰 News
15 | - **2025.3**: We release the SFT code for the LLaVA-v1.5-mix665k dataset. We also release the [258M synthetic data samples](https://huggingface.co/datasets/OpenGVLab/Mono-InternVL-2B-Synthetic-Data) used in S1.2 to support future research.
16 | - **2025.2**: 🎉🎉 Mono-InternVL is accepted by **CVPR 2025**. Also check out our [**SynerGen-VL**](https://huggingface.co/papers/2412.09604) (CVPR 2025) that extends the monolithic structure to unified image generation and multimodal understanding, which will be open-sourced soon.
17 | - **2024.11**: Mono-InternVL is supported by [lmdeploy](https://github.com/InternLM/lmdeploy/pull/2727).
18 | - **2024.11**: Mono-InternVL is supported by [vllm](https://github.com/vllm-project/vllm/pull/9528).
19 |
20 |
21 | ## ⭐️ Introduction
22 |
23 | We release Mono-InternVL, a **monolithic** multimodal large language model (MLLM) that integrates visual encoding and textual decoding into a single LLM. In Mono-InternVL, a set of visual experts is embedded into the pre-trained LLM via a **mixture-of-experts (MoE) mechanism**. By freezing the LLM, Mono-InternVL ensures that visual capabilities are optimized without compromising the pre-trained language knowledge. Based on this structure, an innovative **Endogenous Visual Pretraining (EViP)** is introduced to realize coarse-to-fine visual learning.
24 |
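The snippet below is a minimal sketch of this routing idea (our own illustration under simplifying assumptions, not the repository's actual `modeling_internlm2_ve.py`): tokens marked as visual pass through a trainable visual-expert FFN, while text tokens keep using the frozen, pre-trained FFN.

```python
# Toy sketch of the visual-expert mixture; NOT the real Mono-InternVL code.
import torch
import torch.nn as nn

class VisualExpertFFN(nn.Module):
    def __init__(self, hidden_size=64, intermediate_size=256):
        super().__init__()
        # stands in for the frozen, pre-trained LLM FFN
        self.text_ffn = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size), nn.SiLU(),
            nn.Linear(intermediate_size, hidden_size))
        for p in self.text_ffn.parameters():
            p.requires_grad = False  # language knowledge stays frozen
        # trainable visual expert with the same shape
        self.visual_ffn = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size), nn.SiLU(),
            nn.Linear(intermediate_size, hidden_size))

    def forward(self, hidden_states, image_token_mask):
        # image_token_mask: bool [batch, seq_len], True where the token is visual
        text_out = self.text_ffn(hidden_states)
        visual_out = self.visual_ffn(hidden_states)
        return torch.where(image_token_mask.unsqueeze(-1), visual_out, text_out)

x = torch.randn(2, 8, 64)                    # 2 sequences of 8 tokens
mask = torch.zeros(2, 8, dtype=torch.bool)
mask[:, :4] = True                           # pretend the first 4 tokens are image patches
print(VisualExpertFFN()(x, mask).shape)      # torch.Size([2, 8, 64])
```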
25 |
26 | Mono-InternVL achieves superior performance compared to the state-of-the-art MLLM Mini-InternVL-2B-1.5 and significantly outperforms other monolithic MLLMs, as shown in the [radar chart](#radar) above. Meanwhile, it achieves better deployment efficiency, with first-token latency reduced by up to 67%.
27 |
28 |
29 | For more details, please refer to our [paper](https://arxiv.org/abs/2410.08202).
30 |
31 |
32 | ## 📊 Performance
33 | | Benchmark | Chameleon-7B | EVE-7B (HD) | Emu3 | Mini-InternVL-2B-1-5 | Mono-InternVL-2B |
34 | | :--------------------------: | :----------: | :---------: | :--------: | :------------------: | :--------------: |
35 | | Type | Monolithic | Monolithic | Monolithic | Modular | Monolithic |
36 | | #Activated Params | 7B | 7B | 8B | 2.2B | 1.8B |
37 | | | | | | | |
38 | | MMVet | 8.3 | 25.7 | 37.2 | 39.3 | 40.1 |
39 | | MMMU (val) | 25.4 | 32.6 | 31.6 | 34.6 | 33.7 |
40 | | MME (sum) | 170 | 1628 | — | 1902 | 1875 |
41 | | MMBench-EN (test) | 31.1 | 52.3 | 58.5 | 70.9 | 65.5 |
42 | | MathVista (testmini) | 22.3 | 34.2 | — | 41.1 | 45.7 |
43 | | SEED-Image | 30.6 | 64.6 | 68.2 | 69.8 | 67.4 |
44 | | OCRBench | 7 | 398 | 687 | 654 | 767 |
45 | | HallusionBench | 17.1 | 26.4 | — | 37.5 | 34.8 |
46 | | CCBench (dev) | 3.5 | 16.3 | — | 63.5 | 66.3 |
47 | | Avg (multimodal) | 16.1 | 38.9 | — | 54.4 | 55.2 |
48 | | | | | | | |
49 | | TextVQA (val) | 4.8 | 56.8 | 64.7 | 70.5 | 72.6 |
50 | | SQA-I (test) | 47.2 | 64.9 | 89.2 | 84.9 | 93.6 |
51 | | GQA (test) | — | 62.6 | 60.3 | 61.6 | 59.5 |
52 | | DocVQA (test) | 1.5 | 53.0 | 76.3 | 85.0 | 80.0 |
53 | | AI2D (test) | 46.0 | 61.0 | 70.0 | 69.8 | 68.6 |
54 | | ChartQA (test) | 2.9 | 59.1 | 68.6 | 74.8 | 73.7 |
55 | | InfoVQA (test) | 5.0 | 25.0 | 43.8 | 55.4 | 43.0 |
56 | | Avg (VQA) | 17.9 | 54.6 | 67.6 | 71.7 | 70.1 |
57 |
58 | > * Sources of the results include the original papers, our evaluation with [VLMEvalKit](https://github.com/open-compass/VLMEvalKit), and [OpenCompass](https://rank.opencompass.org.cn/leaderboard-multimodal/?m=REALTIME).
59 | > * Average scores are computed by normalizing each metric to a range between 0 and 100.
60 | > * Please note that evaluating the same model using different testing toolkits can result in slight differences, which is normal. Updates to code versions and variations in environment and hardware can also cause minor discrepancies in results.
61 |
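As a rough illustration of this normalization (our assumption about the scaling, not the paper's official script), benchmarks reported on non-percentage scales such as MME (2800-point maximum) and OCRBench (1000-point maximum) can be rescaled to 0–100 before averaging:

```python
# Hypothetical illustration of rescaling scores to 0-100 before averaging;
# the exact per-benchmark maxima used in the paper may differ.
def to_percent(score, max_score):
    return 100.0 * score / max_score

mono_internvl = {
    'MMVet': 40.1,                         # already on a 0-100 scale
    'MME (sum)': to_percent(1875, 2800),   # assumed 2800-point maximum
    'OCRBench': to_percent(767, 1000),     # assumed 1000-point maximum
}
print(round(sum(mono_internvl.values()) / len(mono_internvl), 1))
```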
62 |
63 | ## 🚀 Inference
64 |
65 | We provide example code for running Mono-InternVL-2B inference with `transformers`.
66 |
67 | > Please use `transformers==4.37.2` to ensure the model works correctly.
68 |
69 |
70 | ### Inference with Transformers
71 |
72 | ```python
73 | import numpy as np
74 | import torch
75 | import torchvision.transforms as T
76 | from decord import VideoReader, cpu
77 | from PIL import Image
78 | from torchvision.transforms.functional import InterpolationMode
79 | from transformers import AutoModel, AutoTokenizer
80 |
81 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
82 | IMAGENET_STD = (0.229, 0.224, 0.225)
83 |
84 | def build_transform(input_size):
85 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
86 | transform = T.Compose([
87 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
88 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
89 | T.ToTensor(),
90 | T.Normalize(mean=MEAN, std=STD)
91 | ])
92 | return transform
93 |
94 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
95 | best_ratio_diff = float('inf')
96 | best_ratio = (1, 1)
97 | area = width * height
98 | for ratio in target_ratios:
99 | target_aspect_ratio = ratio[0] / ratio[1]
100 | ratio_diff = abs(aspect_ratio - target_aspect_ratio)
101 | if ratio_diff < best_ratio_diff:
102 | best_ratio_diff = ratio_diff
103 | best_ratio = ratio
104 | elif ratio_diff == best_ratio_diff:
105 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
106 | best_ratio = ratio
107 | return best_ratio
108 |
109 | def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
110 | orig_width, orig_height = image.size
111 | aspect_ratio = orig_width / orig_height
112 |
113 | # calculate the existing image aspect ratio
114 | target_ratios = set(
115 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
116 | i * j <= max_num and i * j >= min_num)
117 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
118 |
119 | # find the closest aspect ratio to the target
120 | target_aspect_ratio = find_closest_aspect_ratio(
121 | aspect_ratio, target_ratios, orig_width, orig_height, image_size)
122 |
123 | # calculate the target width and height
124 | target_width = image_size * target_aspect_ratio[0]
125 | target_height = image_size * target_aspect_ratio[1]
126 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
127 |
128 | # resize the image
129 | resized_img = image.resize((target_width, target_height))
130 | processed_images = []
131 | for i in range(blocks):
132 | box = (
133 | (i % (target_width // image_size)) * image_size,
134 | (i // (target_width // image_size)) * image_size,
135 | ((i % (target_width // image_size)) + 1) * image_size,
136 | ((i // (target_width // image_size)) + 1) * image_size
137 | )
138 | # split the image
139 | split_img = resized_img.crop(box)
140 | processed_images.append(split_img)
141 | assert len(processed_images) == blocks
142 | if use_thumbnail and len(processed_images) != 1:
143 | thumbnail_img = image.resize((image_size, image_size))
144 | processed_images.append(thumbnail_img)
145 | return processed_images
146 |
147 | def load_image(image_file, input_size=448, max_num=12):
148 | image = Image.open(image_file).convert('RGB')
149 | transform = build_transform(input_size=input_size)
150 | images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
151 | pixel_values = [transform(image) for image in images]
152 | pixel_values = torch.stack(pixel_values)
153 | return pixel_values
154 |
155 |
156 | path = 'OpenGVLab/Mono-InternVL-2B'
157 | model = AutoModel.from_pretrained(
158 | path,
159 | torch_dtype=torch.bfloat16,
160 | low_cpu_mem_usage=True,
161 | trust_remote_code=True).eval().cuda()
162 | tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
163 |
164 | # set the max number of tiles in `max_num`
165 | pixel_values = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
166 | generation_config = dict(max_new_tokens=1024, do_sample=True)
167 |
168 | # pure-text conversation
169 | question = 'Hello, who are you?'
170 | response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
171 | print(f'User: {question}\nAssistant: {response}')
172 |
173 | question = 'Can you tell me a story?'
174 | response, history = model.chat(tokenizer, None, question, generation_config, history=history, return_history=True)
175 | print(f'User: {question}\nAssistant: {response}')
176 |
177 | # single-image single-round conversation
178 | question = '<image>\nPlease describe the image shortly.'
179 | response = model.chat(tokenizer, pixel_values, question, generation_config)
180 | print(f'User: {question}\nAssistant: {response}')
181 |
182 | # single-image multi-round conversation
183 | question = '<image>\nPlease describe the image in detail.'
184 | response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
185 | print(f'User: {question}\nAssistant: {response}')
186 |
187 | question = 'Please write a poem according to the image.'
188 | response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
189 | print(f'User: {question}\nAssistant: {response}')
190 | ```
191 |
192 |
193 |
194 |
195 |
196 | ### Inference with LMDeploy
197 |
198 | Please install `lmdeploy>=0.6.3` for Mono-InternVL support.
199 |
200 | ```python
201 | from lmdeploy import pipeline
202 | from lmdeploy.vl import load_image
203 |
204 | image = load_image('./examples/image1.jpg')
205 | pipe = pipeline('OpenGVLab/Mono-InternVL-2B')
206 | response = pipe(('Please describe the image shortly.', image))
207 | print(response.text)
208 | ```
209 |
210 |
211 | ## 🔥 Supervised Finetuning
212 |
213 | Currently, we provide the supervised finetuning (S2 instruction tuning) code for the LLaVA-v1.5-mix665k dataset. For details on the dataset, please refer to [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA).
214 |
215 |
216 | ### Installation
217 |
218 | - Clone this repository:
219 |
220 | ```bash
221 | git clone https://github.com/OpenGVLab/Mono-InternVL.git
222 | ```
223 |
224 | - Create a conda virtual environment and activate it:
225 |
226 | ```bash
227 | conda create -n monointernvl python=3.9 -y
228 | conda activate monointernvl
229 | ```
230 |
231 | - Install dependencies using `requirements.txt`:
232 |
233 | ```bash
234 | pip install -r requirements.txt
235 | ```
236 |
237 | - Additionally, install `flash-attn==2.5.6`:
238 |
239 | ```bash
240 | pip install flash-attn==2.5.6 --no-build-isolation
241 | ```
242 |
243 | Alternatively, you can compile it from source:
244 |
245 | ```bash
246 | git clone https://github.com/Dao-AILab/flash-attention.git
247 | cd flash-attention
248 | git checkout v2.5.6
249 | python setup.py install
250 | ```
251 |
252 |
253 |
254 | ### Dataset Preparation
255 |
256 | #### LLaVA-v1.5-mix665k Dataset
257 |
258 | 1. Download the instruction tuning data:
259 | ```sh
260 | mkdir playground
261 | wget https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json -P playground/
262 | ```
263 |
264 | 2. Download image datasets:
265 |
266 | - COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip)
267 | - GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip)
268 | - OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing)
269 | - TextVQA: [train_val_images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip)
270 | - VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip)
271 |
272 | 3. Organize data as follows:
273 |
274 | ```none
275 | playground/
276 | ├── data/
277 | │ ├── coco/train2017/
278 | │ ├── gqa/images/
279 | │ ├── ocr_vqa/images/
280 | │ ├── textvqa/train_images/
281 | │ └── vg/
282 | │ ├── VG_100K/
283 | │ └── VG_100K_2/
284 | └── llava_v1_5_mix665k.json
285 | ```
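As an optional sanity check (our own sketch, not part of the repository), you can verify that every image referenced in the annotation file exists under `playground/data/` after organizing the data:

```python
# Optional sanity check: confirm all referenced images are present.
import json
import os

root = 'playground/data'
with open('playground/llava_v1_5_mix665k.json') as f:
    annotations = json.load(f)

# Some entries are text-only and have no "image" field.
missing = [item['image'] for item in annotations
           if 'image' in item and not os.path.exists(os.path.join(root, item['image']))]
print(f'{len(missing)} of {len(annotations)} samples reference a missing image')
```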
286 |
287 | #### Custom Dataset
288 |
289 | For a custom dataset, format your data into a JSONL file where each entry is a dictionary organized as follows (similar to `llava_v1_5_mix665k.json`):
290 |
291 | ```python
292 | {
293 | "id": "000000120375",
294 | "image": "coco/train2017/000000120375.jpg",
295 | "conversations": [
296 | {
297 | "from": "human",
298 | "value": "<image>\nWhat type of vehicle is driving down the street in the image?"
299 | },
300 | {
301 | "from": "gpt",
302 | "value": "A red sports utility vehicle (SUV) is driving down the street in the image."
303 | },
304 | {
305 | "from": "human",
306 | "value": "Is the street crowded with people?"
307 | },
308 | {
309 | "from": "gpt",
310 | "value": "Yes, the street is filled with a considerable number of people, which indicates that the area is busy."
311 | }
312 | # (more turns ...)
313 | ]
314 | }
315 | ```
316 |
317 | Then modify the metadata file `shell/data_llava_finetune.json`:
318 |
319 | ```python
320 | {
321 | "name of your dataset": {
322 | "root": "playground/data/", # combination of "root" and "image" in the JSONL gives the complete image path
323 | "annotation": "path to your JSONL",
324 | "data_augment": false,
325 | "repeat_time": 1,
326 | "length": 12345 # change to the actual number of samples in your dataset
327 | }
328 | }
329 | ```
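A small helper sketch (our own illustration, with hypothetical paths and names) for filling in `length` automatically by counting the lines of your JSONL annotation and emitting the metadata entry:

```python
# Hypothetical helper: count samples in a custom JSONL annotation and
# print the corresponding entry for shell/data_llava_finetune.json.
import json

annotation = 'playground/my_dataset.jsonl'   # path to your JSONL file
with open(annotation) as f:
    num_samples = sum(1 for line in f if line.strip())

meta = {
    'my_dataset': {
        'root': 'playground/data/',
        'annotation': annotation,
        'data_augment': False,
        'repeat_time': 1,
        'length': num_samples,
    }
}
print(json.dumps(meta, indent=2))
```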
330 |
331 |
332 |
333 |
334 | ### Model Preparation
335 |
336 | We provide pretrained models of different stages (S1.1 concept learning, S1.2 semantic learning, S1.3 alignment learning).
337 | Choose from the following models and download the weights to the `workdirs/` folder.
338 |
339 |
340 | | model name | download | size |
341 | | ----------------------- | ---------------------------------------------------------------------- |:------:|
342 | | Mono-InternVL-2B-S1-1 | 🤗 [HF link](https://huggingface.co/OpenGVLab/Mono-InternVL-2B-S1-1) | 6.2 GB |
343 | | Mono-InternVL-2B-S1-2 | 🤗 [HF link](https://huggingface.co/OpenGVLab/Mono-InternVL-2B-S1-2) | 6.2 GB |
344 | | Mono-InternVL-2B-S1-3 | 🤗 [HF link](https://huggingface.co/OpenGVLab/Mono-InternVL-2B-S1-3) | 6.2 GB |
345 |
346 |
347 | ```sh
348 | mkdir workdirs
349 | cd workdirs/
350 | # pip install -U huggingface_hub
351 | huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/Mono-InternVL-2B-S1-1 --local-dir Mono-InternVL-2B-S1-1
352 | ```
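If you prefer the Python API, the following sketch downloads all three stage checkpoints with `huggingface_hub.snapshot_download` (equivalent in spirit to the CLI command above):

```python
# Download all three pretrained stages into workdirs/ via the Python API.
# Requires: pip install -U huggingface_hub
from huggingface_hub import snapshot_download

for stage in ('S1-1', 'S1-2', 'S1-3'):
    snapshot_download(repo_id=f'OpenGVLab/Mono-InternVL-2B-{stage}',
                      local_dir=f'workdirs/Mono-InternVL-2B-{stage}')
```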
353 |
354 | The directory structure is:
355 |
356 | ```sh
357 | workdirs/
358 | ├── Mono-InternVL-2B-S1-1/
359 | ├── Mono-InternVL-2B-S1-2/
360 | └── Mono-InternVL-2B-S1-3/
361 | ```
362 |
363 |
364 |
365 | ### Training
366 |
367 | Finetuning takes around 12 hours on 8x A100 (80G) GPUs.
368 |
369 | #### Single Node Multi-GPU
370 | ```sh
371 | MODEL="./workdirs/Mono-InternVL-2B-S1-3" OUTPUT_DIR="./workdirs/mono_internvl_llava_sft" sh shell/mono_internvl_finetune_llava_torchrun.sh
372 | ```
373 |
374 | #### Slurm Cluster
375 | ```sh
376 | PARTITION="your partition" MODEL="./workdirs/Mono-InternVL-2B-S1-3" OUTPUT_DIR="./workdirs/mono_internvl_llava_sft" sh shell/mono_internvl_finetune_llava_slurm.sh
377 | ```
378 |
379 |
380 |
381 |
382 | ## 🎫 License
383 |
384 | This project is released under the [MIT License](LICENSE).
385 |
386 | ## 🖊️ Citation
387 |
388 | If you find this work helpful in your research, please consider giving this repo a star ⭐ and citing our paper:
389 |
390 | ```bibtex
391 | @article{luo2024mono,
392 | title={Mono-InternVL: Pushing the Boundaries of Monolithic Multimodal Large Language Models with Endogenous Visual Pre-training},
393 | author={Luo, Gen and Yang, Xue and Dou, Wenhan and Wang, Zhaokai and Liu, Jiawen and Dai, Jifeng and Qiao, Yu and Zhu, Xizhou},
394 | journal={arXiv preprint arXiv:2410.08202},
395 | year={2024}
396 | }
397 | ```
398 |
--------------------------------------------------------------------------------
/examples/image1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/Mono-InternVL/fcd0381544fc83ac544bade229a70d7a05aa4614/examples/image1.jpg
--------------------------------------------------------------------------------
/images/fig1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/Mono-InternVL/fcd0381544fc83ac544bade229a70d7a05aa4614/images/fig1.jpg
--------------------------------------------------------------------------------
/images/fig2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/Mono-InternVL/fcd0381544fc83ac544bade229a70d7a05aa4614/images/fig2.jpg
--------------------------------------------------------------------------------
/internvl/dist_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import socket
3 | import subprocess
4 | from datetime import timedelta
5 |
6 | import deepspeed
7 | import torch
8 | import torch.multiprocessing as mp
9 | from torch import distributed as dist
10 |
11 | timeout = timedelta(minutes=60)
12 |
13 |
14 | def _find_free_port():
15 | # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
16 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
17 | # Binding to port 0 will cause the OS to find an available port for us
18 | sock.bind(('', 0))
19 | port = sock.getsockname()[1]
20 | sock.close()
21 | # NOTE: there is still a chance the port could be taken by other processes.
22 | return port
23 |
24 |
25 | def _is_free_port(port):
26 | ips = socket.gethostbyname_ex(socket.gethostname())[-1]
27 | ips.append('localhost')
28 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
29 | return all(s.connect_ex((ip, port)) != 0 for ip in ips)
30 |
31 |
32 | def init_dist(launcher, backend='nccl', **kwargs):
33 | if mp.get_start_method(allow_none=True) is None:
34 | mp.set_start_method('spawn')
35 | if launcher == 'pytorch':
36 | _init_dist_pytorch(backend, **kwargs)
37 | elif launcher == 'mpi':
38 | _init_dist_mpi(backend, **kwargs)
39 | elif launcher == 'slurm':
40 | _init_dist_slurm(backend, **kwargs)
41 | else:
42 | raise ValueError(f'Invalid launcher type: {launcher}')
43 |
44 |
45 | def _init_dist_pytorch(backend, **kwargs):
46 |
47 | rank = int(os.environ['RANK'])
48 | num_gpus = torch.cuda.device_count()
49 | torch.cuda.set_device(rank % num_gpus)
50 | # dist.init_process_group(backend=backend, **kwargs)
51 | deepspeed.init_distributed(dist_backend=backend)
52 |
53 |
54 | def _init_dist_mpi(backend, **kwargs):
55 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
56 | torch.cuda.set_device(local_rank)
57 | if 'MASTER_PORT' not in os.environ:
58 | # 29500 is torch.distributed default port
59 | os.environ['MASTER_PORT'] = '29500'
60 | if 'MASTER_ADDR' not in os.environ:
61 | raise KeyError('The environment variable MASTER_ADDR is not set')
62 | os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
63 | os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
64 | dist.init_process_group(backend=backend, **kwargs)
65 |
66 |
67 | def _init_dist_slurm(backend, port=None):
68 | """Initialize slurm distributed training environment.
69 |
70 | If argument ``port`` is not specified, then the master port will be system
71 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
72 | environment variable, then a default port ``29500`` will be used.
73 |
74 | Args:
75 | backend (str): Backend of torch.distributed.
76 | port (int, optional): Master port. Defaults to None.
77 | """
78 | proc_id = int(os.environ['SLURM_PROCID'])
79 | ntasks = int(os.environ['SLURM_NTASKS'])
80 | node_list = os.environ['SLURM_NODELIST']
81 | num_gpus = torch.cuda.device_count()
82 | torch.cuda.set_device(proc_id % num_gpus)
83 | addr = subprocess.getoutput(
84 | f'scontrol show hostname {node_list} | head -n1')
85 | # specify master port
86 | if port is not None:
87 | os.environ['MASTER_PORT'] = str(port)
88 | elif 'MASTER_PORT' in os.environ:
89 | pass # use MASTER_PORT in the environment variable
90 | else:
91 | # if torch.distributed default port(29500) is available
92 | # then use it, else find a free port
93 | if _is_free_port(29500):
94 | os.environ['MASTER_PORT'] = '29500'
95 | else:
96 | os.environ['MASTER_PORT'] = str(_find_free_port())
97 | # use MASTER_ADDR in the environment variable if it already exists
98 | if 'MASTER_ADDR' not in os.environ:
99 | os.environ['MASTER_ADDR'] = addr
100 | os.environ['WORLD_SIZE'] = str(ntasks)
101 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
102 | os.environ['RANK'] = str(proc_id)
103 | # dist.init_process_group(backend=backend, timeout=timeout)
104 | deepspeed.init_distributed(dist_backend=backend)
105 |
--------------------------------------------------------------------------------
/internvl/model/internlm2/configuration_internlm2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ InternLM2 model configuration"""
17 |
18 | from transformers.configuration_utils import PretrainedConfig
19 | from transformers.utils import logging
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24 |
25 |
26 | # Modified from transformers.model.llama.configuration_llama.LlamaConfig
27 | class InternLM2Config(PretrainedConfig):
28 | r"""
29 | This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
30 | an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
31 | configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
32 |
33 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34 | documentation from [`PretrainedConfig`] for more information.
35 |
36 |
37 | Args:
38 | vocab_size (`int`, *optional*, defaults to 32000):
39 | Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
40 | `inputs_ids` passed when calling [`InternLM2Model`]
41 | hidden_size (`int`, *optional*, defaults to 4096):
42 | Dimension of the hidden representations.
43 | intermediate_size (`int`, *optional*, defaults to 11008):
44 | Dimension of the MLP representations.
45 | num_hidden_layers (`int`, *optional*, defaults to 32):
46 | Number of hidden layers in the Transformer encoder.
47 | num_attention_heads (`int`, *optional*, defaults to 32):
48 | Number of attention heads for each attention layer in the Transformer encoder.
49 | num_key_value_heads (`int`, *optional*):
50 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
51 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
52 | `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
53 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
54 | by mean-pooling all the original heads within that group. For more details, check out [this
55 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
56 | `num_attention_heads`.
57 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
58 | The non-linear activation function (function or string) in the decoder.
59 | max_position_embeddings (`int`, *optional*, defaults to 2048):
60 | The maximum sequence length that this model might ever be used with. Typically set this to something large
61 | just in case (e.g., 512 or 1024 or 2048).
62 | initializer_range (`float`, *optional*, defaults to 0.02):
63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64 | rms_norm_eps (`float`, *optional*, defaults to 1e-6):
65 | The epsilon used by the rms normalization layers.
66 | use_cache (`bool`, *optional*, defaults to `True`):
67 | Whether or not the model should return the last key/values attentions (not used by all models). Only
68 | relevant if `config.is_decoder=True`.
69 | tie_word_embeddings(`bool`, *optional*, defaults to `False`):
70 | Whether to tie weight embeddings
71 | Example:
72 |
73 | """
74 | model_type = 'internlm2'
75 | _auto_class = 'AutoConfig'
76 |
77 | def __init__( # pylint: disable=W0102
78 | self,
79 | vocab_size=103168,
80 | hidden_size=4096,
81 | intermediate_size=11008,
82 | num_hidden_layers=32,
83 | num_attention_heads=32,
84 | num_key_value_heads=None,
85 | hidden_act='silu',
86 | max_position_embeddings=2048,
87 | initializer_range=0.02,
88 | rms_norm_eps=1e-6,
89 | use_cache=True,
90 | pad_token_id=0,
91 | bos_token_id=1,
92 | eos_token_id=2,
93 | tie_word_embeddings=False,
94 | bias=True,
95 | rope_theta=10000,
96 | rope_scaling=None,
97 | attn_implementation='eager',
98 | **kwargs,
99 | ):
100 | self.vocab_size = vocab_size
101 | self.max_position_embeddings = max_position_embeddings
102 | self.hidden_size = hidden_size
103 | self.intermediate_size = intermediate_size
104 | self.num_hidden_layers = num_hidden_layers
105 | self.num_attention_heads = num_attention_heads
106 | self.bias = bias
107 |
108 | if num_key_value_heads is None:
109 | num_key_value_heads = num_attention_heads
110 | self.num_key_value_heads = num_key_value_heads
111 |
112 | self.hidden_act = hidden_act
113 | self.initializer_range = initializer_range
114 | self.rms_norm_eps = rms_norm_eps
115 | self.use_cache = use_cache
116 | self.rope_theta = rope_theta
117 | self.rope_scaling = rope_scaling
118 | self._rope_scaling_validation()
119 |
120 |
121 | self.attn_implementation = attn_implementation
122 | if self.attn_implementation is None:
123 | self.attn_implementation = 'eager'
124 | super().__init__(
125 | pad_token_id=pad_token_id,
126 | bos_token_id=bos_token_id,
127 | eos_token_id=eos_token_id,
128 | tie_word_embeddings=tie_word_embeddings,
129 | **kwargs,
130 | )
131 |
132 | def _rope_scaling_validation(self):
133 | """
134 | Validate the `rope_scaling` configuration.
135 | """
136 | if self.rope_scaling is None:
137 | return
138 |
139 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
140 | raise ValueError(
141 | '`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
142 | f'got {self.rope_scaling}'
143 | )
144 | rope_scaling_type = self.rope_scaling.get('type', None)
145 | rope_scaling_factor = self.rope_scaling.get('factor', None)
146 | if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
147 | raise ValueError(
148 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
149 | )
150 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
151 | raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
152 |
--------------------------------------------------------------------------------
/internvl/model/internlm2/tokenization_internlm2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Tokenization classes for InternLM."""
18 | import os
19 | from shutil import copyfile
20 | from typing import Any, Dict, List, Optional, Tuple
21 |
22 | import sentencepiece as spm
23 | from transformers.tokenization_utils import PreTrainedTokenizer
24 | from transformers.utils import logging
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 | VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
29 |
30 | PRETRAINED_VOCAB_FILES_MAP = {}
31 |
32 |
33 | # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
34 | class InternLM2Tokenizer(PreTrainedTokenizer):
35 | """
36 | Construct an InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
37 |
38 | Args:
39 | vocab_file (`str`):
40 | Path to the vocabulary file.
41 | """
42 |
43 | vocab_files_names = VOCAB_FILES_NAMES
44 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
45 | model_input_names = ['input_ids', 'attention_mask']
46 | _auto_class = 'AutoTokenizer'
47 |
48 | def __init__(
49 | self,
50 | vocab_file,
51 | unk_token='<unk>',
52 | bos_token='<s>',
53 | eos_token='</s>',
54 | pad_token='</s>',
55 | sp_model_kwargs: Optional[Dict[str, Any]] = None,
56 | add_bos_token=True,
57 | add_eos_token=False,
58 | decode_with_prefix_space=False,
59 | clean_up_tokenization_spaces=False,
60 | **kwargs,
61 | ):
62 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
63 | self.vocab_file = vocab_file
64 | self.add_bos_token = add_bos_token
65 | self.add_eos_token = add_eos_token
66 | self.decode_with_prefix_space = decode_with_prefix_space
67 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
68 | self.sp_model.Load(vocab_file)
69 | self._no_prefix_space_tokens = None
70 | super().__init__(
71 | bos_token=bos_token,
72 | eos_token=eos_token,
73 | unk_token=unk_token,
74 | pad_token=pad_token,
75 | clean_up_tokenization_spaces=clean_up_tokenization_spaces,
76 | **kwargs,
77 | )
78 |
79 | @property
80 | def no_prefix_space_tokens(self):
81 | if self._no_prefix_space_tokens is None:
82 | vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
83 | self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith('▁')}
84 | return self._no_prefix_space_tokens
85 |
86 | @property
87 | def vocab_size(self):
88 | """Returns vocab size"""
89 | return self.sp_model.get_piece_size()
90 |
91 | @property
92 | def bos_token_id(self) -> Optional[int]:
93 | return self.sp_model.bos_id()
94 |
95 | @property
96 | def eos_token_id(self) -> Optional[int]:
97 | return self.sp_model.eos_id()
98 |
99 | def get_vocab(self):
100 | """Returns vocab as a dict"""
101 | vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
102 | vocab.update(self.added_tokens_encoder)
103 | return vocab
104 |
105 | def _tokenize(self, text):
106 | """Returns a tokenized string."""
107 | return self.sp_model.encode(text, out_type=str)
108 |
109 | def _convert_token_to_id(self, token):
110 | """Converts a token (str) to an id using the vocab."""
111 | return self.sp_model.piece_to_id(token)
112 |
113 | def _convert_id_to_token(self, index):
114 | """Converts an index (integer) to a token (str) using the vocab."""
115 | token = self.sp_model.IdToPiece(index)
116 | return token
117 |
118 | def _maybe_add_prefix_space(self, tokens, decoded):
119 | if tokens and tokens[0] not in self.no_prefix_space_tokens:
120 | return ' ' + decoded
121 | else:
122 | return decoded
123 |
124 | def convert_tokens_to_string(self, tokens):
125 | """Converts a sequence of tokens (string) into a single string."""
126 | current_sub_tokens = []
127 | out_string = ''
128 | prev_is_special = False
129 | for token in tokens:
130 | # make sure that special tokens are not decoded using sentencepiece model
131 | if token in self.all_special_tokens:
132 | if not prev_is_special:
133 | out_string += ' '
134 | out_string += self.sp_model.decode(current_sub_tokens) + token
135 | prev_is_special = True
136 | current_sub_tokens = []
137 | else:
138 | current_sub_tokens.append(token)
139 | prev_is_special = False
140 | out_string += self.sp_model.decode(current_sub_tokens)
141 | out_string = self.clean_up_tokenization(out_string)
142 | out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
143 | return out_string[1:]
144 |
145 | def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
146 | """
147 | Save the vocabulary and special tokens file to a directory.
148 |
149 | Args:
150 | save_directory (`str`):
151 | The directory in which to save the vocabulary.
152 |
153 | Returns:
154 | `Tuple(str)`: Paths to the files saved.
155 | """
156 | if not os.path.isdir(save_directory):
157 | logger.error(f'Vocabulary path ({save_directory}) should be a directory')
158 | return
159 | out_vocab_file = os.path.join(
160 | save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
161 | )
162 |
163 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
164 | copyfile(self.vocab_file, out_vocab_file)
165 | elif not os.path.isfile(self.vocab_file):
166 | with open(out_vocab_file, 'wb') as fi:
167 | content_spiece_model = self.sp_model.serialized_model_proto()
168 | fi.write(content_spiece_model)
169 |
170 | return (out_vocab_file,)
171 |
172 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
173 | if self.add_bos_token:
174 | bos_token_ids = [self.bos_token_id]
175 | else:
176 | bos_token_ids = []
177 |
178 | output = bos_token_ids + token_ids_0
179 |
180 | if token_ids_1 is not None:
181 | output = output + token_ids_1
182 |
183 | if self.add_eos_token:
184 | output = output + [self.eos_token_id]
185 |
186 | return output
187 |
188 | def get_special_tokens_mask(
189 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
190 | ) -> List[int]:
191 | """
192 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
193 | special tokens using the tokenizer `prepare_for_model` method.
194 |
195 | Args:
196 | token_ids_0 (`List[int]`):
197 | List of IDs.
198 | token_ids_1 (`List[int]`, *optional*):
199 | Optional second list of IDs for sequence pairs.
200 | already_has_special_tokens (`bool`, *optional*, defaults to `False`):
201 | Whether or not the token list is already formatted with special tokens for the model.
202 |
203 | Returns:
204 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
205 | """
206 | if already_has_special_tokens:
207 | return super().get_special_tokens_mask(
208 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
209 | )
210 |
211 | if token_ids_1 is None:
212 | return [1] + ([0] * len(token_ids_0)) + [1]
213 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
214 |
215 | def create_token_type_ids_from_sequences(
216 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
217 | ) -> List[int]:
218 | """
219 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. InternLM2 does not make
220 | use of token type ids, therefore a list of zeros is returned.
221 |
222 | Args:
223 | token_ids_0 (`List[int]`):
224 | List of IDs.
225 | token_ids_1 (`List[int]`, *optional*):
226 | Optional second list of IDs for sequence pairs.
227 |
228 | Returns:
229 | `List[int]`: List of zeros.
230 | """
231 | eos = [self.eos_token_id]
232 |
233 | if token_ids_1 is None:
234 | return len(token_ids_0 + eos) * [0]
235 | return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
236 |
--------------------------------------------------------------------------------
/internvl/model/internlm2/tokenization_internlm2_fast.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Tokenization Fast class for InternLM."""
18 | import os
19 | from shutil import copyfile
20 | from typing import Any, Dict, Optional, Tuple
21 |
22 | from tokenizers import Tokenizer, decoders, normalizers, processors
23 | from tokenizers.models import BPE
24 | from transformers.convert_slow_tokenizer import (SLOW_TO_FAST_CONVERTERS,
25 | SentencePieceExtractor,
26 | SpmConverter)
27 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
28 | from transformers.utils import logging
29 |
30 | from .tokenization_internlm2 import InternLM2Tokenizer
31 |
32 | logger = logging.get_logger(__name__)
33 |
34 | VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
35 |
36 |
37 | # Modified from transformers.convert_slow_tokenizer.LlamaConverter
38 | class InternLM2Converter(SpmConverter):
39 | handle_byte_fallback = True
40 |
41 | def vocab(self, proto):
42 | vocab = [
43 | ('<unk>', 0.0),
44 | ('<s>', 0.0),
45 | ('</s>', 0.0),
46 | ]
47 | vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
48 | return vocab
49 |
50 | def unk_id(self, proto):
51 | unk_id = 0
52 | return unk_id
53 |
54 | def decoder(self, replacement, add_prefix_space):
55 | return decoders.Sequence(
56 | [
57 | decoders.Replace('▁', ' '),
58 | decoders.ByteFallback(),
59 | decoders.Fuse(),
60 | decoders.Strip(content=' ', left=1),
61 | ]
62 | )
63 |
64 | def tokenizer(self, proto):
65 | model_type = proto.trainer_spec.model_type
66 | vocab_scores = self.vocab(proto)
67 | # special tokens
68 | added_tokens = self.original_tokenizer.added_tokens_decoder
69 | for i in range(len(vocab_scores)):
70 | piece, score = vocab_scores[i]
71 | if i in added_tokens:
72 | vocab_scores[i] = (added_tokens[i].content, score)
73 | if model_type == 1:
74 | raise RuntimeError('InternLM2 is supposed to be a BPE model!')
75 |
76 | elif model_type == 2:
77 | _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores)
78 | bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
79 | tokenizer = Tokenizer(
80 | BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True)
81 | )
82 | tokenizer.add_special_tokens(
83 | [ added_token for index, added_token in added_tokens.items()]
84 | )
85 | else:
86 | raise Exception(
87 | "You're trying to run a `Unigram` model but your file was trained with a different algorithm"
88 | )
89 |
90 | return tokenizer
91 |
92 | def normalizer(self, proto):
93 | normalizers_list = []
94 | if proto.normalizer_spec.add_dummy_prefix:
95 | normalizers_list.append(normalizers.Prepend(prepend='▁'))
96 | normalizers_list.append(normalizers.Replace(pattern=' ', content='▁'))
97 | return normalizers.Sequence(normalizers_list)
98 |
99 | def pre_tokenizer(self, replacement, add_prefix_space):
100 | return None
101 |
102 |
103 | SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] = InternLM2Converter
104 |
105 |
106 | # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
107 | class InternLM2TokenizerFast(PreTrainedTokenizerFast):
108 | vocab_files_names = VOCAB_FILES_NAMES
109 | slow_tokenizer_class = InternLM2Tokenizer
110 | padding_side = 'left'
111 | model_input_names = ['input_ids', 'attention_mask']
112 | _auto_class = 'AutoTokenizer'
113 |
114 | def __init__(
115 | self,
116 | vocab_file,
117 | unk_token='<unk>',
118 | bos_token='<s>',
119 | eos_token='</s>',
120 | pad_token='</s>',
121 | sp_model_kwargs: Optional[Dict[str, Any]] = None,
122 | add_bos_token=True,
123 | add_eos_token=False,
124 | decode_with_prefix_space=False,
125 | clean_up_tokenization_spaces=False,
126 | **kwargs,
127 | ):
128 | super().__init__(
129 | vocab_file=vocab_file,
130 | unk_token=unk_token,
131 | bos_token=bos_token,
132 | eos_token=eos_token,
133 | pad_token=pad_token,
134 | sp_model_kwargs=sp_model_kwargs,
135 | add_bos_token=add_bos_token,
136 | add_eos_token=add_eos_token,
137 | decode_with_prefix_space=decode_with_prefix_space,
138 | clean_up_tokenization_spaces=clean_up_tokenization_spaces,
139 | **kwargs,
140 | )
141 | self._add_bos_token = add_bos_token
142 | self._add_eos_token = add_eos_token
143 | self.update_post_processor()
144 | self.vocab_file = vocab_file
145 |
146 | @property
147 | def can_save_slow_tokenizer(self) -> bool:
148 | return os.path.isfile(self.vocab_file) if self.vocab_file else False
149 |
150 | def update_post_processor(self):
151 | """
152 | Updates the underlying post processor with the current `bos_token` and `eos_token`.
153 | """
154 | bos = self.bos_token
155 | bos_token_id = self.bos_token_id
156 | if bos is None and self.add_bos_token:
157 | raise ValueError('add_bos_token = True but bos_token = None')
158 |
159 | eos = self.eos_token
160 | eos_token_id = self.eos_token_id
161 | if eos is None and self.add_eos_token:
162 | raise ValueError('add_eos_token = True but eos_token = None')
163 |
164 | single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
165 | pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
166 |
167 | special_tokens = []
168 | if self.add_bos_token:
169 | special_tokens.append((bos, bos_token_id))
170 | if self.add_eos_token:
171 | special_tokens.append((eos, eos_token_id))
172 | self._tokenizer.post_processor = processors.TemplateProcessing(
173 | single=single, pair=pair, special_tokens=special_tokens
174 | )
175 |
176 | @property
177 | def add_eos_token(self):
178 | return self._add_eos_token
179 |
180 | @property
181 | def add_bos_token(self):
182 | return self._add_bos_token
183 |
184 | @add_eos_token.setter
185 | def add_eos_token(self, value):
186 | self._add_eos_token = value
187 | self.update_post_processor()
188 |
189 | @add_bos_token.setter
190 | def add_bos_token(self, value):
191 | self._add_bos_token = value
192 | self.update_post_processor()
193 |
194 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
195 | if not self.can_save_slow_tokenizer:
196 | raise ValueError(
197 | 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
198 | 'tokenizer.'
199 | )
200 |
201 | if not os.path.isdir(save_directory):
202 | logger.error(f'Vocabulary path ({save_directory}) should be a directory')
203 | return
204 | out_vocab_file = os.path.join(
205 | save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
206 | )
207 |
208 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
209 | copyfile(self.vocab_file, out_vocab_file)
210 |
211 | return (out_vocab_file,)
212 |
--------------------------------------------------------------------------------
/internvl/model/internvl_chat/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2025 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | from .configuration_intern_vit import InternVisionConfig,InternVisionPatchConfig
8 | from .configuration_internvl_chat import InternVLChatConfig
9 | from .modeling_intern_vit import InternVisionModel,InternVisionPatchModel
10 | from .modeling_internvl_chat import InternVLChatModel
11 |
12 | __all__ = ['InternVisionConfig', 'InternVisionModel', 'InternVisionPatchModel',
13 | 'InternVLChatConfig', 'InternVisionPatchConfig','InternVLChatModel']
14 |
--------------------------------------------------------------------------------
/internvl/model/internvl_chat/configuration_intern_vit.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2025 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | import os
7 | from typing import Union
8 |
9 | from transformers.configuration_utils import PretrainedConfig
10 | from transformers.utils import logging
11 |
12 | logger = logging.get_logger(__name__)
13 |
14 |
15 | class InternVisionConfig(PretrainedConfig):
16 | r"""
17 | This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
18 | instantiate a vision encoder according to the specified arguments, defining the model architecture.
19 |
20 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21 | documentation from [`PretrainedConfig`] for more information.
22 |
23 | Args:
24 | num_channels (`int`, *optional*, defaults to 3):
25 | Number of color channels in the input images (e.g., 3 for RGB).
26 | patch_size (`int`, *optional*, defaults to 14):
27 | The size (resolution) of each patch.
28 | image_size (`int`, *optional*, defaults to 224):
29 | The size (resolution) of each image.
30 | qkv_bias (`bool`, *optional*, defaults to `False`):
31 | Whether to add a bias to the queries and values in the self-attention layers.
32 | hidden_size (`int`, *optional*, defaults to 3200):
33 | Dimensionality of the encoder layers and the pooler layer.
34 | num_attention_heads (`int`, *optional*, defaults to 25):
35 | Number of attention heads for each attention layer in the Transformer encoder.
36 | intermediate_size (`int`, *optional*, defaults to 12800):
37 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
38 | qk_normalization (`bool`, *optional*, defaults to `True`):
39 | Whether to normalize the queries and keys in the self-attention layers.
40 | num_hidden_layers (`int`, *optional*, defaults to 48):
41 | Number of hidden layers in the Transformer encoder.
42 | use_flash_attn (`bool`, *optional*, defaults to `True`):
43 | Whether to use flash attention mechanism.
44 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
45 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
46 | `"relu"`, `"selu"`, `"gelu_new"` and `"gelu"` are supported.
47 | layer_norm_eps (`float`, *optional*, defaults to 1e-6):
48 | The epsilon used by the layer normalization layers.
49 | dropout (`float`, *optional*, defaults to 0.0):
50 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51 | drop_path_rate (`float`, *optional*, defaults to 0.0):
52 | Dropout rate for stochastic depth.
53 | attention_dropout (`float`, *optional*, defaults to 0.0):
54 | The dropout ratio for the attention probabilities.
55 | initializer_range (`float`, *optional*, defaults to 0.02):
56 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57 | initializer_factor (`float`, *optional*, defaults to 0.1):
58 | A factor for layer scale.
59 | """
60 |
61 | model_type = 'intern_vit_6b'
62 |
63 | def __init__(
64 | self,
65 | num_channels=3,
66 | patch_size=14,
67 | image_size=224,
68 | qkv_bias=False,
69 | hidden_size=3200,
70 | num_attention_heads=25,
71 | intermediate_size=12800,
72 | qk_normalization=True,
73 | num_hidden_layers=48,
74 | use_flash_attn=True,
75 | hidden_act='gelu',
76 | norm_type='rms_norm',
77 | layer_norm_eps=1e-6,
78 | dropout=0.0,
79 | drop_path_rate=0.0,
80 | attention_dropout=0.0,
81 | initializer_range=0.02,
82 | initializer_factor=0.1,
83 | **kwargs,
84 | ):
85 | super().__init__(**kwargs)
86 |
87 | self.hidden_size = hidden_size
88 | self.intermediate_size = intermediate_size
89 | self.dropout = dropout
90 | self.drop_path_rate = drop_path_rate
91 | self.num_hidden_layers = num_hidden_layers
92 | self.num_attention_heads = num_attention_heads
93 | self.num_channels = num_channels
94 | self.patch_size = patch_size
95 | self.image_size = image_size
96 | self.initializer_range = initializer_range
97 | self.initializer_factor = initializer_factor
98 | self.attention_dropout = attention_dropout
99 | self.layer_norm_eps = layer_norm_eps
100 | self.hidden_act = hidden_act
101 | self.norm_type = norm_type
102 | self.qkv_bias = qkv_bias
103 | self.qk_normalization = qk_normalization
104 | self.use_flash_attn = use_flash_attn
105 |
106 | @classmethod
107 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
108 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
109 |
110 | if 'vision_config' in config_dict:
111 | config_dict = config_dict['vision_config']
112 |
113 | if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
114 | logger.warning(
115 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
116 | f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
117 | )
118 |
119 |
120 | return cls.from_dict(config_dict, **kwargs)
121 |
122 |
123 |
124 |
125 | class InternVisionPatchConfig(PretrainedConfig):
126 | r"""
127 | This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
128 | instantiate a vision encoder according to the specified arguments, defining the model architecture.
129 |
130 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
131 | documentation from [`PretrainedConfig`] for more information.
132 |
133 | Args:
134 | num_channels (`int`, *optional*, defaults to 3):
135 | Number of color channels in the input images (e.g., 3 for RGB).
136 | patch_size (`int`, *optional*, defaults to 14):
137 | The size (resolution) of each patch.
138 | image_size (`int`, *optional*, defaults to 224):
139 | The size (resolution) of each image.
140 | qkv_bias (`bool`, *optional*, defaults to `False`):
141 |             Whether to add a bias to the queries, keys and values in the self-attention layers.
142 | hidden_size (`int`, *optional*, defaults to 3200):
143 | Dimensionality of the encoder layers and the pooler layer.
144 | num_attention_heads (`int`, *optional*, defaults to 25):
145 | Number of attention heads for each attention layer in the Transformer encoder.
146 | intermediate_size (`int`, *optional*, defaults to 12800):
147 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
148 | qk_normalization (`bool`, *optional*, defaults to `True`):
149 | Whether to normalize the queries and keys in the self-attention layers.
150 | num_hidden_layers (`int`, *optional*, defaults to 48):
151 | Number of hidden layers in the Transformer encoder.
152 | use_flash_attn (`bool`, *optional*, defaults to `True`):
153 |             Whether to use the flash attention mechanism.
154 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
155 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
156 |             `"relu"`, `"selu"` and `"gelu_new"` are supported.
157 | layer_norm_eps (`float`, *optional*, defaults to 1e-6):
158 | The epsilon used by the layer normalization layers.
159 | dropout (`float`, *optional*, defaults to 0.0):
160 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
161 | drop_path_rate (`float`, *optional*, defaults to 0.0):
162 | Dropout rate for stochastic depth.
163 | attention_dropout (`float`, *optional*, defaults to 0.0):
164 | The dropout ratio for the attention probabilities.
165 | initializer_range (`float`, *optional*, defaults to 0.02):
166 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
167 | initializer_factor (`float`, *optional*, defaults to 0.1):
168 | A factor for layer scale.
169 | """
170 |
171 | model_type = 'intern_vit_patch'
172 |
173 | def __init__(
174 | self,
175 | num_channels=3,
176 | patch_size=14,
177 | image_size=224,
178 | qkv_bias=False,
179 | hidden_size=3200,
180 | num_attention_heads=25,
181 | intermediate_size=12800,
182 | qk_normalization=True,
183 | num_hidden_layers=48,
184 | use_flash_attn=True,
185 | hidden_act='gelu',
186 | norm_type='rms_norm',
187 | layer_norm_eps=1e-6,
188 | dropout=0.0,
189 | drop_path_rate=0.0,
190 | attention_dropout=0.0,
191 | initializer_range=0.02,
192 | initializer_factor=0.1,
193 | **kwargs,
194 | ):
195 | super().__init__(**kwargs)
196 |
197 | self.hidden_size = hidden_size
198 | self.intermediate_size = intermediate_size
199 | self.dropout = dropout
200 | self.drop_path_rate = drop_path_rate
201 | self.num_hidden_layers = num_hidden_layers
202 | self.num_attention_heads = num_attention_heads
203 | self.num_channels = num_channels
204 | self.patch_size = patch_size
205 | self.image_size = image_size
206 | self.initializer_range = initializer_range
207 | self.initializer_factor = initializer_factor
208 | self.attention_dropout = attention_dropout
209 | self.layer_norm_eps = layer_norm_eps
210 | self.hidden_act = hidden_act
211 | self.norm_type = norm_type
212 | self.qkv_bias = qkv_bias
213 | self.qk_normalization = qk_normalization
214 | self.use_flash_attn = use_flash_attn
215 |
216 | @classmethod
217 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
218 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
219 |
220 | if 'vision_config' in config_dict:
221 | config_dict = config_dict['vision_config']
222 |
223 | if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
224 | logger.warning(
225 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
226 | f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
227 | )
228 |
229 |
230 | return cls.from_dict(config_dict, **kwargs)
231 |
--------------------------------------------------------------------------------
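
A minimal usage sketch for the two vision configs defined above. This is illustrative only: it assumes the repository root is on PYTHONPATH so the `internvl` package resolves, and the values shown simply mirror the documented defaults.

    # Hypothetical usage sketch (not part of the repo): instantiate the vision configs
    # with their documented defaults and round-trip one through to_dict().
    from internvl.model.internvl_chat.configuration_intern_vit import (
        InternVisionConfig, InternVisionPatchConfig)

    vit_cfg = InternVisionConfig()                      # 48 layers, hidden size 3200, 25 heads
    assert vit_cfg.hidden_size == 3200 and vit_cfg.patch_size == 14

    patch_cfg = InternVisionPatchConfig(image_size=448, drop_path_rate=0.1)
    print(patch_cfg.model_type)                         # 'intern_vit_patch'
    print(patch_cfg.to_dict()['image_size'])            # 448
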
/internvl/model/internvl_chat/configuration_internvl_chat.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2025 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import copy
8 |
9 | from internvl.model.internlm2.configuration_internlm2 import InternLM2Config
10 | from transformers import AutoConfig, LlamaConfig, Qwen2Config
11 | from transformers.configuration_utils import PretrainedConfig
12 | from transformers.utils import logging
13 |
14 | from .configuration_intern_vit import InternVisionConfig, InternVisionPatchConfig
15 |
16 | logger = logging.get_logger(__name__)
17 |
18 |
19 | class InternVLChatConfig(PretrainedConfig):
20 | model_type = 'internvl_chat'
21 | is_composition = True
22 |
23 | def __init__(
24 | self,
25 | vision_config=None,
26 | llm_config=None,
27 | use_backbone_lora=0,
28 | use_llm_lora=0,
29 | pad2square=False,
30 | select_layer=-4,
31 | force_image_size=None,
32 | downsample_ratio=0.5,
33 | template=None,
34 | dynamic_image_size=False,
35 | use_thumbnail=False,
36 | ps_version='v1',
37 | min_dynamic_patch=1,
38 | max_dynamic_patch=6,
39 | **kwargs):
40 | super().__init__(**kwargs)
41 |
42 | if vision_config is None:
43 | vision_config = {}
44 | logger.info('vision_config is None. Initializing the InternVisionConfig with default values.')
45 |
46 | if llm_config is None:
47 | llm_config = {}
48 |             logger.info('llm_config is None. Initializing the llm config with default values (`LlamaConfig`).')
49 |
50 |         if vision_config and vision_config.get('model_type') == 'intern_vit_patch':
51 | self.vision_config = InternVisionPatchConfig(**vision_config)
52 | else:
53 | self.vision_config = InternVisionConfig(**vision_config)
54 |         if llm_config.get('architectures', ['LlamaForCausalLM'])[0] == 'LlamaForCausalLM':
55 | self.llm_config = LlamaConfig(**llm_config)
56 | elif llm_config['architectures'][0] == 'InternLM2ForCausalLM':
57 | self.llm_config = InternLM2Config(**llm_config)
58 | elif llm_config['architectures'][0] == 'InternLM2VEForCausalLM':
59 | self.llm_config = InternLM2Config(**llm_config)
60 | elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
61 | self.llm_config = Qwen2Config(**llm_config)
62 | else:
63 | raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0]))
64 | self.use_backbone_lora = use_backbone_lora
65 | self.use_llm_lora = use_llm_lora
66 | self.pad2square = pad2square
67 | self.select_layer = select_layer
68 | self.force_image_size = force_image_size
69 | self.downsample_ratio = downsample_ratio
70 | self.template = template
71 | self.dynamic_image_size = dynamic_image_size
72 | self.use_thumbnail = use_thumbnail
73 | self.ps_version = ps_version # pixel shuffle version
74 | self.min_dynamic_patch = min_dynamic_patch
75 | self.max_dynamic_patch = max_dynamic_patch
76 |
77 | logger.info(f'vision_select_layer: {self.select_layer}')
78 | logger.info(f'ps_version: {self.ps_version}')
79 | logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
80 | logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
81 |
82 | def to_dict(self):
83 | """
84 | Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
85 |
86 | Returns:
87 | `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
88 | """
89 | output = copy.deepcopy(self.__dict__)
90 | output['vision_config'] = self.vision_config.to_dict()
91 | output['llm_config'] = self.llm_config.to_dict()
92 | output['model_type'] = self.__class__.model_type
93 | output['use_backbone_lora'] = self.use_backbone_lora
94 | output['use_llm_lora'] = self.use_llm_lora
95 | output['pad2square'] = self.pad2square
96 | output['select_layer'] = self.select_layer
97 | output['force_image_size'] = self.force_image_size
98 | output['downsample_ratio'] = self.downsample_ratio
99 | output['template'] = self.template
100 | output['dynamic_image_size'] = self.dynamic_image_size
101 | output['use_thumbnail'] = self.use_thumbnail
102 | output['ps_version'] = self.ps_version
103 | output['min_dynamic_patch'] = self.min_dynamic_patch
104 | output['max_dynamic_patch'] = self.max_dynamic_patch
105 |
106 | return output
107 |
--------------------------------------------------------------------------------
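
To make the composition logic above concrete, here is a hedged sketch of building an `InternVLChatConfig` from nested dicts, the same shape `from_dict`/`from_pretrained` would read out of a checkpoint's config.json. The sizes and template name are illustrative, not taken from a real checkpoint.

    # Illustrative sketch: nested dicts select InternVisionPatchConfig and InternLM2Config
    # via 'model_type' and 'architectures', exactly as in __init__ above.
    from internvl.model.internvl_chat.configuration_internvl_chat import InternVLChatConfig

    vision_config = {'model_type': 'intern_vit_patch', 'image_size': 448, 'patch_size': 14}
    llm_config = {'architectures': ['InternLM2VEForCausalLM'], 'hidden_size': 2048}

    chat_cfg = InternVLChatConfig(
        vision_config=vision_config,
        llm_config=llm_config,
        force_image_size=448,
        downsample_ratio=0.5,
        template='internlm2-chat',   # illustrative template name
        dynamic_image_size=True,
        use_thumbnail=True,
        max_dynamic_patch=6,
    )
    # to_dict() re-nests the sub-configs so the object round-trips through JSON.
    assert chat_cfg.to_dict()['vision_config']['image_size'] == 448
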
/internvl/model/internvl_chat/flash_attention.py:
--------------------------------------------------------------------------------
1 | # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
2 | import torch
3 | import torch.nn as nn
4 | from einops import rearrange
5 |
6 | try: # v1
7 | from flash_attn.flash_attn_interface import \
8 | flash_attn_unpadded_qkvpacked_func
9 | except ImportError:  # v2
10 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
11 |
12 | from flash_attn.bert_padding import pad_input, unpad_input
13 |
14 |
15 | class FlashAttention(nn.Module):
16 | """Implement the scaled dot product attention with softmax.
17 | Arguments
18 | ---------
19 | softmax_scale: The temperature to use for the softmax attention.
20 | (default: 1/sqrt(d_keys) where d_keys is computed at
21 | runtime)
22 | attention_dropout: The dropout rate to apply to the attention
23 | (default: 0.0)
24 | """
25 |
26 | def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
27 | super().__init__()
28 | self.softmax_scale = softmax_scale
29 | self.dropout_p = attention_dropout
30 |
31 | def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
32 | max_s=None, need_weights=False):
33 | """Implements the multihead softmax attention.
34 | Arguments
35 | ---------
36 | qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
37 | if unpadded: (nnz, 3, h, d)
38 | key_padding_mask: a bool tensor of shape (B, S)
39 | """
40 | assert not need_weights
41 | assert qkv.dtype in [torch.float16, torch.bfloat16]
42 | assert qkv.is_cuda
43 |
44 | if cu_seqlens is None:
45 | batch_size = qkv.shape[0]
46 | seqlen = qkv.shape[1]
47 | if key_padding_mask is None:
48 | qkv = rearrange(qkv, 'b s ... -> (b s) ...')
49 | max_s = seqlen
50 | cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
51 | device=qkv.device)
52 | output = flash_attn_unpadded_qkvpacked_func(
53 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
54 | softmax_scale=self.softmax_scale, causal=causal
55 | )
56 | output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
57 | else:
58 | nheads = qkv.shape[-2]
59 | x = rearrange(qkv, 'b s three h d -> b s (three h d)')
60 | x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
61 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
62 | output_unpad = flash_attn_unpadded_qkvpacked_func(
63 | x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
64 | softmax_scale=self.softmax_scale, causal=causal
65 | )
66 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
67 | indices, batch_size, seqlen),
68 | 'b s (h d) -> b s h d', h=nheads)
69 | else:
70 | assert max_s is not None
71 | output = flash_attn_unpadded_qkvpacked_func(
72 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
73 | softmax_scale=self.softmax_scale, causal=causal
74 | )
75 |
76 | return output, None
77 |
--------------------------------------------------------------------------------
/internvl/model/internvl_chat/modeling_intern_vit.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2025 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | from typing import Optional, Tuple, Union
7 |
8 | import torch
9 | import torch.nn.functional as F
10 | import torch.utils.checkpoint
11 | from einops import rearrange
12 | from timm.models.layers import DropPath
13 | from torch import nn
14 | from transformers.activations import ACT2FN
15 | from transformers.modeling_outputs import (BaseModelOutput,
16 | BaseModelOutputWithPooling)
17 | from transformers.modeling_utils import PreTrainedModel
18 | from transformers.utils import logging
19 |
20 | from .configuration_intern_vit import InternVisionConfig, InternVisionPatchConfig
21 |
22 | try:
23 | from .flash_attention import FlashAttention
24 | has_flash_attn = True
25 | except Exception:
26 | print('FlashAttention is not installed.')
27 | has_flash_attn = False
28 |
29 |
30 | logger = logging.get_logger(__name__)
31 |
32 |
33 | class InternRMSNorm(nn.Module):
34 | def __init__(self, hidden_size, eps=1e-6):
35 | super().__init__()
36 | self.weight = nn.Parameter(torch.ones(hidden_size))
37 | self.variance_epsilon = eps
38 |
39 | def forward(self, hidden_states):
40 | input_dtype = hidden_states.dtype
41 | hidden_states = hidden_states.to(torch.float32)
42 | variance = hidden_states.pow(2).mean(-1, keepdim=True)
43 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
44 | return self.weight * hidden_states.to(input_dtype)
45 |
46 |
47 | try:
48 | from apex.normalization import FusedRMSNorm
49 |
50 | InternRMSNorm = FusedRMSNorm # noqa
51 |
52 | logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
53 | except ImportError:
54 | # using the normal InternRMSNorm
55 | pass
56 | except Exception:
57 | logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
58 | pass
59 |
60 |
61 | NORM2FN = {
62 | 'rms_norm': InternRMSNorm,
63 | 'layer_norm': nn.LayerNorm,
64 | }
65 |
66 |
67 | class InternVisionEmbeddings(nn.Module):
68 | def __init__(self, config: InternVisionConfig):
69 | super().__init__()
70 | self.config = config
71 | self.embed_dim = config.hidden_size
72 | self.image_size = config.image_size
73 | self.patch_size = config.patch_size
74 |
75 | self.class_embedding = nn.Parameter(
76 | torch.randn(1, 1, self.embed_dim),
77 | )
78 |
79 | self.patch_embedding = nn.Conv2d(
80 | in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
81 | )
82 |
83 | self.num_patches = (self.image_size // self.patch_size) ** 2
84 | self.num_positions = self.num_patches + 1
85 |
86 | self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
87 |
88 | def _get_pos_embed(self, pos_embed, H, W):
89 | target_dtype = pos_embed.dtype
90 | pos_embed = pos_embed.float().reshape(
91 | 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
92 | pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False).\
93 | reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype)
94 | return pos_embed
95 |
96 | def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
97 | target_dtype = self.patch_embedding.weight.dtype
98 |         patch_embeds = self.patch_embedding(pixel_values)  # shape = [batch, embed_dim, grid_h, grid_w]
99 | batch_size, _, height, width = patch_embeds.shape
100 | patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
101 | class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
102 | embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
103 | position_embedding = torch.cat([
104 | self.position_embedding[:, :1, :],
105 | self._get_pos_embed(self.position_embedding[:, 1:, :], height, width)
106 | ], dim=1)
107 | embeddings = embeddings + position_embedding.to(target_dtype)
108 | return embeddings
109 |
110 |
111 | class InternAttention(nn.Module):
112 | """Multi-headed attention from 'Attention Is All You Need' paper"""
113 |
114 | def __init__(self, config: InternVisionConfig):
115 | super().__init__()
116 | self.config = config
117 | self.embed_dim = config.hidden_size
118 | self.num_heads = config.num_attention_heads
119 | self.use_flash_attn = config.use_flash_attn and has_flash_attn
120 | if config.use_flash_attn and not has_flash_attn:
121 | print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
122 | self.head_dim = self.embed_dim // self.num_heads
123 | if self.head_dim * self.num_heads != self.embed_dim:
124 | raise ValueError(
125 | f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
126 | f' {self.num_heads}).'
127 | )
128 |
129 | self.scale = self.head_dim ** -0.5
130 | self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
131 | self.attn_drop = nn.Dropout(config.attention_dropout)
132 | self.proj_drop = nn.Dropout(config.dropout)
133 |
134 | self.qk_normalization = config.qk_normalization
135 |
136 | if self.qk_normalization:
137 | self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
138 | self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
139 |
140 | if self.use_flash_attn:
141 | self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
142 | self.proj = nn.Linear(self.embed_dim, self.embed_dim)
143 |
144 | def _naive_attn(self, x):
145 | B, N, C = x.shape
146 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
147 | q, k, v = qkv.unbind(0)
148 |
149 | if self.qk_normalization:
150 | B_, H_, N_, D_ = q.shape
151 | q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
152 | k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
153 |
154 | attn = ((q * self.scale) @ k.transpose(-2, -1))
155 | attn = attn.softmax(dim=-1)
156 | attn = self.attn_drop(attn)
157 |
158 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
159 | x = self.proj(x)
160 | x = self.proj_drop(x)
161 | return x
162 |
163 | def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
164 | qkv = self.qkv(x)
165 | qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
166 |
167 | if self.qk_normalization:
168 | q, k, v = qkv.unbind(2)
169 | q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
170 | k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
171 | qkv = torch.stack([q, k, v], dim=2)
172 |
173 | context, _ = self.inner_attn(
174 | qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
175 | )
176 | outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
177 | outs = self.proj_drop(outs)
178 | return outs
179 |
180 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
181 | x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
182 | return x
183 |
184 |
185 | class InternMLP(nn.Module):
186 | def __init__(self, config: InternVisionConfig):
187 | super().__init__()
188 | self.config = config
189 | self.act = ACT2FN[config.hidden_act]
190 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
191 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
192 |
193 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
194 | hidden_states = self.fc1(hidden_states)
195 | hidden_states = self.act(hidden_states)
196 | hidden_states = self.fc2(hidden_states)
197 | return hidden_states
198 |
199 |
200 | class InternVisionEncoderLayer(nn.Module):
201 | def __init__(self, config: InternVisionConfig, drop_path_rate: float):
202 | super().__init__()
203 | self.embed_dim = config.hidden_size
204 | self.intermediate_size = config.intermediate_size
205 | self.norm_type = config.norm_type
206 |
207 | self.attn = InternAttention(config)
208 | self.mlp = InternMLP(config)
209 | self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
210 | self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
211 |
212 | self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
213 | self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
214 | self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
215 | self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
216 |
217 | def forward(
218 | self,
219 | hidden_states: torch.Tensor,
220 |     ) -> torch.FloatTensor:
221 |         """
222 |         Args:
223 |             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
224 | """
225 | hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
226 |
227 | hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
228 |
229 | return hidden_states
230 |
231 |
232 | class InternVisionEncoder(nn.Module):
233 | """
234 |     Transformer encoder consisting of `config.num_hidden_layers` self-attention layers. Each layer is an
235 |     [`InternVisionEncoderLayer`].
236 |
237 |     Args:
238 |         config (`InternVisionConfig`):
239 |             The corresponding vision configuration for the `InternVisionEncoder`.
240 | """
241 |
242 | def __init__(self, config: InternVisionConfig):
243 | super().__init__()
244 | self.config = config
245 | # stochastic depth decay rule
246 | dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
247 | self.layers = nn.ModuleList([
248 | InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
249 | self.gradient_checkpointing = True
250 |
251 | def forward(
252 | self,
253 | inputs_embeds,
254 | output_hidden_states: Optional[bool] = None,
255 | return_dict: Optional[bool] = None,
256 | ) -> Union[Tuple, BaseModelOutput]:
257 | r"""
258 | Args:
259 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
260 | Embedded representation of the inputs. Should be float, not int tokens.
261 | output_hidden_states (`bool`, *optional*):
262 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
263 | for more detail.
264 | return_dict (`bool`, *optional*):
265 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
266 | """
267 | output_hidden_states = (
268 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
269 | )
270 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
271 |
272 | encoder_states = () if output_hidden_states else None
273 | hidden_states = inputs_embeds
274 | for idx, encoder_layer in enumerate(self.layers):
275 | if output_hidden_states:
276 | encoder_states = encoder_states + (hidden_states,)
277 | if self.gradient_checkpointing and self.training:
278 | layer_outputs = torch.utils.checkpoint.checkpoint(
279 | encoder_layer,
280 | hidden_states)
281 | else:
282 | layer_outputs = encoder_layer(
283 | hidden_states,
284 | )
285 | hidden_states = layer_outputs
286 |
287 | if output_hidden_states:
288 | encoder_states = encoder_states + (hidden_states,)
289 |
290 | if not return_dict:
291 | return tuple(v for v in [hidden_states, encoder_states] if v is not None)
292 | return BaseModelOutput(
293 | last_hidden_state=hidden_states, hidden_states=encoder_states
294 | )
295 |
296 |
297 | class InternVisionModel(PreTrainedModel):
298 | main_input_name = 'pixel_values'
299 | config_class = InternVisionConfig
300 | _no_split_modules = ['InternVisionEncoderLayer']
301 |
302 | def __init__(self, config: InternVisionConfig):
303 | super().__init__(config)
304 | self.config = config
305 |
306 | self.embeddings = InternVisionEmbeddings(config)
307 | self.encoder = InternVisionEncoder(config)
308 |
309 | def resize_pos_embeddings(self, old_size, new_size, patch_size):
310 | pos_emb = self.embeddings.position_embedding
311 | _, num_positions, embed_dim = pos_emb.shape
312 | cls_emb = pos_emb[:, :1, :]
313 | pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
314 | pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
315 | pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
316 | pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
317 | self.embeddings.position_embedding = nn.Parameter(pos_emb)
318 | self.embeddings.image_size = new_size
319 | logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
320 |
321 | def get_input_embeddings(self):
322 | return self.embeddings
323 |
324 | def forward(
325 | self,
326 | pixel_values: Optional[torch.FloatTensor] = None,
327 | output_hidden_states: Optional[bool] = None,
328 | return_dict: Optional[bool] = None,
329 | pixel_embeds: Optional[torch.FloatTensor] = None,
330 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
331 | output_hidden_states = (
332 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
333 | )
334 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
335 |
336 | if pixel_values is None and pixel_embeds is None:
337 | raise ValueError('You have to specify pixel_values or pixel_embeds')
338 |
339 | if pixel_embeds is not None:
340 | hidden_states = pixel_embeds
341 | else:
342 | if len(pixel_values.shape) == 4:
343 | hidden_states = self.embeddings(pixel_values)
344 | else:
345 | raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
346 | encoder_outputs = self.encoder(
347 | inputs_embeds=hidden_states,
348 | output_hidden_states=output_hidden_states,
349 | return_dict=return_dict,
350 | )
351 | last_hidden_state = encoder_outputs.last_hidden_state
352 | pooled_output = last_hidden_state[:, 0, :]
353 |
354 | if not return_dict:
355 | return (last_hidden_state, pooled_output) + encoder_outputs[1:]
356 |
357 | return BaseModelOutputWithPooling(
358 | last_hidden_state=last_hidden_state,
359 | pooler_output=pooled_output,
360 | hidden_states=encoder_outputs.hidden_states,
361 | attentions=encoder_outputs.attentions,
362 | )
363 |
364 |
365 | class InternVisionPatchModel(PreTrainedModel):
366 | main_input_name = 'pixel_values'
367 | config_class = InternVisionPatchConfig
368 | _no_split_modules = ['InternVisionEncoderLayer']
369 |
370 | def __init__(self, config: InternVisionPatchConfig):
371 | super().__init__(config)
372 | self.config = config
373 | self.embeddings = InternVisionEmbeddings(config)
374 | def resize_pos_embeddings(self, old_size, new_size, patch_size):
375 | pos_emb = self.embeddings.position_embedding
376 | _, num_positions, embed_dim = pos_emb.shape
377 | cls_emb = pos_emb[:, :1, :]
378 | pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
379 | pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
380 | pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
381 | pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
382 | self.embeddings.position_embedding = nn.Parameter(pos_emb)
383 | self.embeddings.image_size = new_size
384 | logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
385 |
386 | def get_input_embeddings(self):
387 | return self.embeddings
388 |
389 |
390 | def forward(
391 | self,
392 | pixel_values: Optional[torch.FloatTensor] = None,
393 | output_hidden_states: Optional[bool] = None,
394 | return_dict: Optional[bool] = None,
395 | pixel_embeds: Optional[torch.FloatTensor] = None,
396 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
397 |
398 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
399 |
400 | if pixel_values is None:
401 | raise ValueError('You have to specify pixel_values')
402 |
403 |
404 | if len(pixel_values.shape) == 4:
405 | hidden_states = self.embeddings(pixel_values)[:,1:]
406 | else:
407 | raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
408 |
409 |
410 | if not return_dict:
411 |             return (hidden_states, None, None)
412 |
413 | return BaseModelOutputWithPooling(
414 | last_hidden_state=hidden_states,
415 | pooler_output=None,
416 | hidden_states=None,
417 | attentions=None,
418 | )
419 |
420 |
421 |
--------------------------------------------------------------------------------
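
As a sanity check on the patch-only tower above: `InternVisionPatchModel` only runs `InternVisionEmbeddings` and drops the CLS token, so its output length is `(image_size // patch_size) ** 2`. The sketch below uses deliberately tiny, CPU-friendly sizes; they are not the real model dimensions.

    # Toy-sized sketch (assumed sizes, not the released checkpoints).
    import torch
    from internvl.model.internvl_chat.configuration_intern_vit import InternVisionPatchConfig
    from internvl.model.internvl_chat.modeling_intern_vit import InternVisionPatchModel

    cfg = InternVisionPatchConfig(image_size=224, patch_size=14, hidden_size=64,
                                  num_attention_heads=4, num_hidden_layers=1,
                                  intermediate_size=256, use_flash_attn=False)
    model = InternVisionPatchModel(cfg).eval()

    pixels = torch.randn(1, 3, 224, 224)
    out = model(pixel_values=pixels, return_dict=True)
    print(out.last_hidden_state.shape)        # torch.Size([1, 256, 64]) -> (224 // 14) ** 2 tokens
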
/internvl/model/internvl_chat/modeling_internvl_chat.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2025 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 | import warnings
7 | from typing import Any, List, Optional, Tuple, Union
8 |
9 | import torch.distributed as dist
10 | import torch.utils.checkpoint
11 | from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
12 | from internvl.model.internlm2.modeling_internlm2_ve import InternLM2VEForCausalLM
13 | from peft import LoraConfig, get_peft_model
14 | from torch import nn
15 | from torch.nn import CrossEntropyLoss
16 | from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM,
17 | LlamaTokenizer, Qwen2ForCausalLM)
18 | from transformers.modeling_outputs import CausalLMOutputWithPast
19 | from transformers.modeling_utils import PreTrainedModel
20 | from transformers.utils import ModelOutput, logging
21 |
22 | from .configuration_internvl_chat import InternVLChatConfig
23 | from .modeling_intern_vit import InternVisionModel,InternVisionPatchModel
24 | from dataclasses import dataclass
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 | @dataclass
29 | class CausalLMOutputWithVisualMask(ModelOutput):
30 | """
31 | Base class for causal language model (or autoregressive) outputs.
32 |
33 | Args:
34 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
35 | Language modeling loss (for next-token prediction).
36 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
37 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
38 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
39 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
40 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
41 |
42 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
43 | `past_key_values` input) to speed up sequential decoding.
44 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
45 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
46 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
47 |
48 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
49 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
50 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
51 | sequence_length)`.
52 |
53 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
54 | heads.
55 | """
56 |
57 | loss: Optional[torch.FloatTensor] = None
58 | logits: torch.FloatTensor = None
59 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
60 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
61 | attentions: Optional[Tuple[torch.FloatTensor]] = None
62 |     visual_token_mask: Optional[torch.FloatTensor] = None
63 |
64 | class InternVLChatModel(PreTrainedModel):
65 | config_class = InternVLChatConfig
66 | main_input_name = 'pixel_values'
67 |     _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'InternLM2DecoderLayer',
68 |                          'Phi3DecoderLayer', 'Qwen2DecoderLayer']
69 |
70 | def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None):
71 | super().__init__(config)
72 |
73 | image_size = config.force_image_size or config.vision_config.image_size
74 | patch_size = config.vision_config.patch_size
75 | self.patch_size = patch_size
76 | self.select_layer = config.select_layer
77 | self.template = config.template
78 | self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2))
79 | self.downsample_ratio = config.downsample_ratio
80 | self.ps_version = config.ps_version
81 |         self.use_visual_token_mask = False
82 |
83 | logger.info(f'num_image_token: {self.num_image_token}')
84 | logger.info(f'ps_version: {self.ps_version}')
85 | if vision_model is not None:
86 | self.vision_model = vision_model
87 | elif 'intern_vit_6b' in config.vision_config.model_type:
88 | self.vision_model = InternVisionModel(config.vision_config)
89 | elif 'intern_vit_patch' in config.vision_config.model_type:
90 | self.vision_model = InternVisionPatchModel(config.vision_config)
91 | else:
92 |             raise NotImplementedError(f'{config.vision_config.model_type} is not implemented.')
93 |
94 |
95 | if language_model is not None:
96 | self.language_model = language_model
97 | else:
98 | if config.llm_config.architectures[0] == 'LlamaForCausalLM':
99 | self.language_model = LlamaForCausalLM(config.llm_config)
100 | elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM':
101 | self.language_model = InternLM2ForCausalLM(config.llm_config)
102 | elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM':
103 | self.language_model = Qwen2ForCausalLM(config.llm_config)
104 |             elif config.llm_config.architectures[0] == 'InternLM2VEForCausalLM':
105 |                 self.language_model = InternLM2VEForCausalLM(config.llm_config)
106 |                 self.use_visual_token_mask = True
107 | else:
108 | raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.')
109 |
110 |         self.language_arch = config.llm_config.architectures[0]
111 | vit_hidden_size = config.vision_config.hidden_size
112 | llm_hidden_size = config.llm_config.hidden_size
113 |
114 | self.mlp1 = nn.Sequential(
115 | nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
116 | nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
117 | nn.GELU(),
118 | nn.Linear(llm_hidden_size, llm_hidden_size)
119 | )
120 |
121 | self.img_context_token_id = None
122 | self.neftune_alpha = None
123 | self.num_samples = 0
124 |
125 | if config.use_backbone_lora:
126 | self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora)
127 |
128 | if config.use_llm_lora:
129 | self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora)
130 |
131 | def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
132 | lora_config = LoraConfig(
133 | r=r,
134 | target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'],
135 | lora_alpha=lora_alpha,
136 | lora_dropout=lora_dropout,
137 | )
138 | self.vision_model = get_peft_model(self.vision_model, lora_config)
139 | self.vision_model.print_trainable_parameters()
140 |
141 | def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05):
142 | lora_config = LoraConfig(
143 | r=r,
144 | target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj',
145 | 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'],
146 | lora_alpha=lora_alpha,
147 | lora_dropout=lora_dropout,
148 | task_type='CAUSAL_LM'
149 | )
150 | self.language_model = get_peft_model(self.language_model, lora_config)
151 | self.language_model.enable_input_require_grads()
152 | self.language_model.print_trainable_parameters()
153 |
154 | def forward(
155 | self,
156 | pixel_values: torch.FloatTensor,
157 | input_ids: torch.LongTensor = None,
158 | attention_mask: Optional[torch.Tensor] = None,
159 | position_ids: Optional[torch.LongTensor] = None,
160 | image_flags: Optional[torch.LongTensor] = None,
161 | past_key_values: Optional[List[torch.FloatTensor]] = None,
162 | labels: Optional[torch.LongTensor] = None,
163 | use_cache: Optional[bool] = None,
164 | output_attentions: Optional[bool] = None,
165 | output_hidden_states: Optional[bool] = None,
166 | return_dict: Optional[bool] = None,
167 | statistics: Optional[torch.LongTensor] = None,
168 | loss_weight: Optional[List] = None,
169 | loss_reduction_all_gather: Optional[bool] = False,
170 | ) -> Union[Tuple, CausalLMOutputWithVisualMask]:
171 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
172 |
173 | image_flags = image_flags.squeeze(-1)
174 | input_embeds = self.language_model.get_input_embeddings()(input_ids).clone()
175 |
176 | vit_embeds = self.extract_feature(pixel_values).to(input_embeds.dtype)
177 | vit_embeds = vit_embeds[image_flags == 1]
178 | vit_batch_size = pixel_values.shape[0]
179 |
180 | B, N, C = input_embeds.shape
181 | input_embeds = input_embeds.reshape(B * N, C)
182 |
183 | if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
184 | print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}')
185 | if statistics is not None:
186 | num_samples, num_padding_tokens, num_padding_images = statistics.tolist()
187 | self.num_samples += num_samples
188 | print(f'total_samples={self.num_samples}, {num_samples=}, {num_padding_tokens=}, {num_padding_images=}')
189 |
190 | input_ids = input_ids.reshape(B * N)
191 | selected = (input_ids == self.img_context_token_id)
192 |
193 | try:
194 | input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C)
195 | ignore_flag = False
196 | except Exception as e:
197 | vit_embeds = vit_embeds.reshape(-1, C)
198 | print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, '
199 | f'vit_embeds.shape={vit_embeds.shape}')
200 | n_token = selected.sum()
201 | input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token]
202 | ignore_flag = True
203 |
204 | input_embeds = input_embeds.reshape(B, N, C)
205 |
206 |         if self.language_arch == 'InternLM2VEForCausalLM':
207 |             visual_token_mask = selected.reshape(B, N, 1).to(input_embeds.dtype)
208 | outputs = self.language_model(
209 | inputs_embeds=input_embeds,
210 | attention_mask=attention_mask,
211 | position_ids=position_ids,
212 | past_key_values=past_key_values,
213 | use_cache=use_cache,
214 | output_attentions=output_attentions,
215 | output_hidden_states=output_hidden_states,
216 | return_dict=return_dict,
217 | visual_token_mask=visual_token_mask
218 | )
219 | else:
220 | outputs = self.language_model(
221 | inputs_embeds=input_embeds,
222 | attention_mask=attention_mask,
223 | position_ids=position_ids,
224 | past_key_values=past_key_values,
225 | use_cache=use_cache,
226 | output_attentions=output_attentions,
227 | output_hidden_states=output_hidden_states,
228 | return_dict=return_dict
229 | )
230 | logits = outputs.logits
231 |
232 | loss = None
233 | if labels is not None and loss_weight is not None:
234 | loss_weight = torch.tensor(loss_weight, dtype=torch.float32, device=labels.device)
235 | # Shift so that tokens < n predict n
236 | shift_logits = logits[..., :-1, :].contiguous()
237 | shift_labels = labels[..., 1:].contiguous()
238 | shift_weights = loss_weight[..., 1:].contiguous()
239 | # Flatten the tokens
240 | loss_fct = CrossEntropyLoss(reduction='none')
241 | shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
242 | shift_labels = shift_labels.view(-1)
243 | shift_weights = shift_weights.view(-1)
244 | # Enable model parallelism
245 | shift_labels = shift_labels.to(shift_logits.device)
246 | shift_weights = shift_weights.to(shift_logits.device)
247 | loss = loss_fct(shift_logits, shift_labels)
248 |
249 | shift_weights_sum = shift_weights.sum()
250 | if loss_reduction_all_gather:
251 | dist.all_reduce(shift_weights_sum, op=dist.ReduceOp.AVG)
252 |
253 | loss = loss * shift_weights
254 | loss = loss.sum() / shift_weights_sum
255 | if ignore_flag:
256 | loss = loss * 0.0
257 | elif labels is not None:
258 | # Shift so that tokens < n predict n
259 | shift_logits = logits[..., :-1, :].contiguous()
260 | shift_labels = labels[..., 1:].contiguous()
261 | # Flatten the tokens
262 | loss_fct = CrossEntropyLoss()
263 | shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
264 | shift_labels = shift_labels.view(-1)
265 | # Enable model parallelism
266 | shift_labels = shift_labels.to(shift_logits.device)
267 | loss = loss_fct(shift_logits, shift_labels)
268 | if ignore_flag:
269 | loss = loss * 0.0
270 |
271 | if not return_dict:
272 | output = (logits,) + outputs[1:]
273 | return (loss,) + output if loss is not None else output
274 |
275 | return CausalLMOutputWithVisualMask(
276 | loss=loss,
277 | logits=logits,
278 | past_key_values=outputs.past_key_values,
279 | hidden_states=outputs.hidden_states,
280 | attentions=outputs.attentions,
281 | visual_token_mask=selected.reshape(B,N,1).to(input_embeds.dtype)
282 | )
283 |
284 | def pixel_shuffle(self, x, scale_factor=0.5):
285 | n, w, h, c = x.size()
286 | # N, W, H, C --> N, W, H * scale, C // scale
287 | x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
288 | # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
289 | x = x.permute(0, 2, 1, 3).contiguous()
290 | # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
291 | x = x.view(n, int(h * scale_factor), int(w * scale_factor),
292 | int(c / (scale_factor * scale_factor)))
293 | if self.ps_version == 'v1':
294 | warnings.warn("In ps_version 'v1', the height and width have not been swapped back, "
295 | 'which results in a transposed image.')
296 | else:
297 | x = x.permute(0, 2, 1, 3).contiguous()
298 | return x
299 |
300 | def noised_embed(self, vit_embeds, noise_alpha=5):
301 | dims = torch.tensor(vit_embeds.size(1) * vit_embeds.size(2))
302 | mag_norm = noise_alpha / torch.sqrt(dims)
303 | noise = torch.zeros_like(vit_embeds).uniform_(-mag_norm, mag_norm)
304 | return vit_embeds + noise
305 |
306 | def extract_feature(self, pixel_values):
307 | if self.select_layer == -1:
308 | vit_embeds = self.vision_model(
309 | pixel_values=pixel_values,
310 | output_hidden_states=False,
311 | return_dict=True).last_hidden_state
312 | else:
313 | vit_embeds = self.vision_model(
314 | pixel_values=pixel_values,
315 | output_hidden_states=True,
316 | return_dict=True).hidden_states[self.select_layer]
317 | if int(vit_embeds.shape[1] ** 0.5)**2 != vit_embeds.shape[1]:
318 | vit_embeds = vit_embeds[:, 1:, :]
319 |
320 | if self.training and self.neftune_alpha is not None:
321 | vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha)
322 |
323 | h = w = int(vit_embeds.shape[1] ** 0.5)
324 | vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
325 | vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
326 | vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
327 | vit_embeds = self.mlp1(vit_embeds)#.to(pixel_values.device)
328 | return vit_embeds
329 |
330 | def batch_chat(self, tokenizer, pixel_values, image_counts, questions, generation_config, history=None,
331 |                    return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
332 |                    IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
333 | if history is not None or return_history:
334 | print('Now multi-turn chat is not supported in batch_chat.')
335 | raise NotImplementedError
336 | img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
337 | self.img_context_token_id = img_context_token_id
338 |
339 | from internvl.conversation import get_conv_template
340 |
341 | queries = []
342 | image_bs = pixel_values.shape[0]
343 |
344 | for idx, image_count in enumerate(image_counts):
345 | image_token = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * image_count + IMG_END_TOKEN
346 | question = image_token + '\n' + questions[idx]
347 | template = get_conv_template(self.template)
348 | template.append_message(template.roles[0], question)
349 | template.append_message(template.roles[1], None)
350 | query = template.get_prompt()
351 | queries.append(query)
352 | tokenizer.padding_side = 'left'
353 | model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
354 | input_ids = model_inputs['input_ids'].cuda()
355 | attention_mask = model_inputs['attention_mask'].cuda()
356 | eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
357 | generation_config['eos_token_id'] = eos_token_id
358 |
359 | generation_output = self.generate(
360 | pixel_values=pixel_values,
361 | input_ids=input_ids,
362 | attention_mask=attention_mask,
363 | **generation_config
364 | )
365 | responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True)
366 | responses = [response.split(template.sep)[0].strip() for response in responses]
367 | return responses
368 |
369 | def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False,
370 |              IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
371 |
372 | img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
373 | self.img_context_token_id = img_context_token_id
374 |
375 | from internvl.conversation import get_conv_template
376 |
377 | template = get_conv_template(self.template)
378 | eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
379 |
380 | image_bs = pixel_values.shape[0]
381 | print(f'dynamic ViT batch size: {image_bs}')
382 | if history is None:
383 | history = []
384 | image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * image_bs + IMG_END_TOKEN
385 | question = image_tokens + '\n' + question
386 | else:
387 | for (old_question, old_answer) in history:
388 | template.append_message(template.roles[0], old_question)
389 | template.append_message(template.roles[1], old_answer)
390 | template.append_message(template.roles[0], question)
391 | template.append_message(template.roles[1], None)
392 | query = template.get_prompt()
393 | model_inputs = tokenizer(query, return_tensors='pt')
394 | input_ids = model_inputs['input_ids'].cuda()
395 | attention_mask = model_inputs['attention_mask'].cuda()
396 | generation_config['eos_token_id'] = eos_token_id
397 | generation_output = self.generate(
398 | pixel_values=pixel_values,
399 | input_ids=input_ids,
400 | attention_mask=attention_mask,
401 | **generation_config
402 | )
403 | response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
404 | response = response.split(template.sep)[0].strip()
405 | history.append((question, response))
406 | if return_history:
407 | return response, history
408 | else:
409 | query_to_print = query.replace(image_tokens, '')
410 | print(query_to_print, response)
411 | return response
412 | return response
413 |
414 | def multi_image_chat(self, tokenizer, pixel_values, image_counts, question, generation_config, history=None,
415 |                          return_history=False, IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
416 |
417 | img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
418 | self.img_context_token_id = img_context_token_id
419 |
420 | from internvl.conversation import get_conv_template
421 |
422 | template = get_conv_template(self.template)
423 | eos_token_id = tokenizer.convert_tokens_to_ids(template.sep)
424 |
425 | if history is None:
426 | history = []
427 | image_tokens = ''
428 | image_bs = pixel_values.shape[0]
429 | print(f'dynamic ViT batch size: {image_bs}, image_counts: {image_counts}')
430 | for idx, image_count in enumerate(image_counts):
431 | image_tokens += f' (图{idx+1}):' + IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * image_count + IMG_END_TOKEN
432 | question = image_tokens + '\n' + question
433 | else:
434 | for (old_question, old_answer) in history:
435 | template.append_message(template.roles[0], old_question)
436 | template.append_message(template.roles[1], old_answer)
437 | template.append_message(template.roles[0], question)
438 | template.append_message(template.roles[1], None)
439 | query = template.get_prompt()
440 | model_inputs = tokenizer(query, return_tensors='pt')
441 | input_ids = model_inputs['input_ids'].cuda()
442 | attention_mask = model_inputs['attention_mask'].cuda()
443 | generation_config['eos_token_id'] = eos_token_id
444 |
445 | generation_output = self.generate(
446 | pixel_values=pixel_values,
447 | input_ids=input_ids,
448 | attention_mask=attention_mask,
449 | **generation_config
450 | )
451 | response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
452 | response = response.split(template.sep)[0].strip()
453 | history.append((question, response))
454 | if return_history:
455 | return response, history
456 | else:
457 | query_to_print = query.replace(image_tokens, '')
458 | print(query_to_print, response)
459 | return response
460 | return response
461 |
462 | @torch.no_grad()
463 | def generate(
464 | self,
465 | pixel_values: Optional[torch.FloatTensor] = None,
466 | input_ids: Optional[torch.FloatTensor] = None,
467 | attention_mask: Optional[torch.LongTensor] = None,
468 | visual_features: Optional[torch.FloatTensor] = None,
469 | generation_config: Optional[GenerationConfig] = None,
470 | output_hidden_states: Optional[bool] = None,
471 | return_dict: Optional[bool] = None,
472 | **generate_kwargs,
473 | ) -> torch.LongTensor:
474 |
475 | assert self.img_context_token_id is not None
476 | if pixel_values is not None:
477 | if visual_features is not None:
478 | vit_embeds = visual_features
479 | else:
480 | vit_embeds = self.extract_feature(pixel_values)
481 |
482 | input_embeds = self.language_model.get_input_embeddings()(input_ids)
483 | B, N, C = input_embeds.shape
484 | input_embeds = input_embeds.reshape(B * N, C)
485 |
486 | input_ids = input_ids.reshape(B * N)
487 | selected = (input_ids == self.img_context_token_id)
488 | assert selected.sum() != 0
489 | input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
490 |
491 | input_embeds = input_embeds.reshape(B, N, C)
492 |             visual_token_mask = selected.reshape(B, N, 1).to(input_embeds.dtype)
493 |         else:
494 |             input_embeds = self.language_model.get_input_embeddings()(input_ids)
495 |             visual_token_mask = torch.zeros_like(input_ids, dtype=input_embeds.dtype).unsqueeze(-1)
496 |
497 | if self.use_visual_token_mask:
498 | outputs = self.language_model.generate(
499 | inputs_embeds=input_embeds,
500 | attention_mask=attention_mask,
501 | generation_config=generation_config,
502 | output_hidden_states=output_hidden_states,
503 | return_dict=return_dict,
504 | use_cache=True,
505 | visual_token_mask=visual_token_mask,
506 | **generate_kwargs,
507 | )
508 | else:
509 | outputs = self.language_model.generate(
510 | inputs_embeds=input_embeds,
511 | attention_mask=attention_mask,
512 | generation_config=generation_config,
513 | output_hidden_states=output_hidden_states,
514 | return_dict=return_dict,
515 | use_cache=True,
516 | **generate_kwargs,
517 | )
518 | return outputs
519 |
--------------------------------------------------------------------------------
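
A worked example of the token arithmetic used in `InternVLChatModel.__init__` and `pixel_shuffle`: with 448x448 inputs and 14x14 patches the ViT produces a 32x32 grid, and a downsample ratio of 0.5 folds each 2x2 neighbourhood into the channel dimension, leaving 256 image tokens whose width is quadrupled before `mlp1` projects them to the LLM hidden size. The ViT width below is illustrative.

    # Pure arithmetic, mirroring the formulas in the model code above.
    image_size, patch_size, downsample_ratio = 448, 14, 0.5

    grid = image_size // patch_size                                     # 32
    num_image_token = int(grid ** 2 * downsample_ratio ** 2)            # 256
    vit_hidden_size = 1024                                              # illustrative width
    mlp_in_features = vit_hidden_size * int(1 / downsample_ratio) ** 2  # 4096
    print(num_image_token, mlp_in_features)                             # 256 4096
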
/internvl/patch/__init__.py:
--------------------------------------------------------------------------------
1 | from .internlm2_packed_training_patch import replace_internlm2_attention_class
2 | from .llama2_flash_attn_monkey_patch import replace_llama2_attn_with_flash_attn
3 | from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
4 | from .llama_rmsnorm_monkey_patch import \
5 | replace_llama_rmsnorm_with_fused_rmsnorm
6 | from .pad_data_collator import concat_pad_data_collator, pad_data_collator
7 | from .qwen2_packed_training_patch import replace_qwen2_attention_class
8 | from .train_dataloader_patch import replace_train_dataloader, replace_unwapper_train_dataloader
9 | from .train_sampler_patch import replace_train_sampler, replace_sequence_train_sampler
10 |
11 | __all__ = ['replace_llama_attn_with_flash_attn',
12 | 'replace_llama_rmsnorm_with_fused_rmsnorm',
13 | 'replace_llama2_attn_with_flash_attn',
14 | 'replace_train_sampler',
15 | 'replace_sequence_train_sampler',
16 | 'replace_train_dataloader',
17 | 'replace_unwapper_train_dataloader',
18 | 'replace_internlm2_attention_class',
19 | 'replace_qwen2_attention_class',
20 | 'pad_data_collator',
21 | 'concat_pad_data_collator']
22 |
--------------------------------------------------------------------------------
/internvl/patch/internlm2_packed_training_patch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from flash_attn.flash_attn_interface import flash_attn_varlen_func
3 | from internvl.model.internlm2.modeling_internlm2 import (
4 | INTERNLM2_ATTENTION_CLASSES, InternLM2FlashAttention2,
5 | apply_rotary_pos_emb)
6 |
7 |
8 | # Modified from internvl.model.internlm2.modeling_internlm2.InternLM2FlashAttention2
9 | class InternLM2FlashAttention2ForPackedTraining(InternLM2FlashAttention2):
10 |
11 | def _flash_attention_forward(
12 | self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
13 | ):
14 | """
15 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
16 | first unpad the input, then computes the attention scores and pad the final attention scores.
17 |
18 | Args:
19 | query_states (`torch.Tensor`):
20 | Input query states to be passed to Flash Attention API
21 | key_states (`torch.Tensor`):
22 | Input key states to be passed to Flash Attention API
23 | value_states (`torch.Tensor`):
24 | Input value states to be passed to Flash Attention API
25 | attention_mask (`torch.Tensor`):
26 |                 Renamed from cu_seqlens to keep compatibility - (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
27 | of the sequences in the batch.
28 |             dropout (`float`, *optional*):
29 | Attention dropout
30 | softmax_scale (`float`, *optional*):
31 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
32 | """
33 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
34 | query_states = query_states.squeeze(0)
35 | key_states = key_states.squeeze(0)
36 | value_states = value_states.squeeze(0)
37 | cu_seqlens = attention_mask.squeeze(0)
38 |
39 | with torch.no_grad():
40 | max_seqlen = max([
41 | cu_seqlens[idx+1] - cu_seqlens[idx]
42 | for idx in range(cu_seqlens.size(0) - 1)
43 | ]).item()
44 |
45 |         # Packed sequences contain no padding; use causal masking except for single-token (decode) queries
46 | causal = self.is_causal and query_length != 1
47 | attn_output = flash_attn_varlen_func(
48 | q=query_states,
49 | k=key_states,
50 | v=value_states,
51 | cu_seqlens_q=cu_seqlens,
52 | cu_seqlens_k=cu_seqlens,
53 | max_seqlen_q=max_seqlen,
54 | max_seqlen_k=max_seqlen,
55 | dropout_p=dropout,
56 | softmax_scale=softmax_scale,
57 | causal=causal,
58 | )
59 |
60 | query_states = query_states.unsqueeze(0)
61 | key_states = key_states.unsqueeze(0)
62 | value_states = value_states.unsqueeze(0)
63 | return attn_output
64 |
65 |
66 | def replace_internlm2_attention_class():
67 | INTERNLM2_ATTENTION_CLASSES['flash_attention_2'] = InternLM2FlashAttention2ForPackedTraining
68 | print('Replace INTERNLM2_ATTENTION_CLASSES to support packed training!!')
69 |
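The patch above assumes the whole packed batch arrives as a single row (batch dimension 1) and that `attention_mask` actually carries the cumulative sequence lengths. A small illustration of that layout, with sample lengths invented for the example:

import torch

# Three samples of length 5, 3 and 8 packed back-to-back into a single row.
lengths = torch.tensor([5, 3, 8], dtype=torch.int32)

# cu_seqlens is the prefix sum with a leading 0: sample i occupies
# positions cu_seqlens[i]:cu_seqlens[i + 1] of the packed sequence.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        torch.cumsum(lengths, dim=0, dtype=torch.int32)])

# max_seqlen is recovered inside the patch essentially like this:
max_seqlen = max(
    (cu_seqlens[i + 1] - cu_seqlens[i]).item() for i in range(cu_seqlens.size(0) - 1)
)
print(cu_seqlens.tolist(), max_seqlen)   # [0, 5, 8, 16] 8

# The packed data pipeline is expected to pass this tensor (with a leading batch dim of 1)
# through `attention_mask`, which the patched forward squeezes back into `cu_seqlens`.
attention_mask = cu_seqlens.unsqueeze(0)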
--------------------------------------------------------------------------------
/internvl/patch/llama2_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is copied from: https://github.com/lm-sys/FastChat
3 | """
4 | import warnings
5 | from typing import Optional, Tuple
6 |
7 | import torch
8 | from flash_attn import __version__ as flash_attn_version
9 | from flash_attn.bert_padding import pad_input, unpad_input
10 | from flash_attn.flash_attn_interface import (flash_attn_func,
11 | flash_attn_varlen_kvpacked_func)
12 | from transformers.models.llama.modeling_llama import (LlamaAttention,
13 | LlamaModel, rotate_half)
14 |
15 |
16 | def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
17 | gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
18 | gather_indices = gather_indices.repeat(
19 | 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
20 | )
21 | bsz = gather_indices.shape[0]
22 | cos, sin = (
23 | torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
24 | for x in cos_sin
25 | )
26 | q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
27 | return q, k
28 |
29 |
30 | def forward(
31 | self,
32 | hidden_states: torch.Tensor,
33 | attention_mask: Optional[torch.Tensor] = None,
34 | position_ids: Optional[torch.Tensor] = None,
35 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
36 | output_attentions: bool = False,
37 | use_cache: bool = False,
38 | padding_mask: Optional[torch.Tensor] = None,
39 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
40 | if output_attentions:
41 | warnings.warn(
42 | 'Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.'
43 | )
44 |
45 | bsz, q_len, _ = hidden_states.size()
46 | kv_heads = getattr(self, 'num_key_value_heads', self.num_heads)
47 |
48 | q, k, v = (
49 | op(hidden_states).view(bsz, q_len, nh, self.head_dim)
50 | for op, nh in (
51 | (self.q_proj, self.num_heads),
52 | (self.k_proj, kv_heads),
53 | (self.v_proj, kv_heads),
54 | )
55 | )
56 | # shape: (b, s, num_heads, head_dim)
57 |
58 | kv_seq_len = k.shape[1]
59 | past_kv_len = 0
60 | if past_key_value is not None:
61 | past_kv_len = past_key_value[0].shape[2]
62 | kv_seq_len += past_kv_len
63 |
64 | cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
65 | q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
66 |
67 | if past_key_value is not None:
68 | assert (
69 | flash_attn_version >= '2.1.0'
70 | ), 'past_key_value support requires flash-attn >= 2.1.0'
71 | # reuse k, v
72 | k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
73 | v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
74 |
75 | past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
76 |
77 | if attention_mask is None:
78 | output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
79 | bsz, q_len, -1
80 | )
81 | else:
82 | q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
83 |         # We could skip the concat and call unpad twice, but it seems better to call unpad only once.
84 | kv, _, cu_k_lens, max_k = unpad_input(
85 | torch.stack((k, v), dim=2), attention_mask
86 | )
87 | output_unpad = flash_attn_varlen_kvpacked_func(
88 | q,
89 | kv,
90 | cu_q_lens,
91 | cu_k_lens,
92 | max_s,
93 | max_k,
94 | 0.0,
95 | softmax_scale=None,
96 | causal=True,
97 | )
98 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
99 | output = pad_input(output_unpad, indices, bsz, q_len)
100 |
101 | return self.o_proj(output), None, past_key_value
102 |
103 |
104 | # Disable the transformation of the attention mask in LlamaModel as flash attention
105 | # takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
106 | def _prepare_decoder_attention_mask(
107 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
108 | ):
109 | # [bsz, seq_len]
110 | if past_key_values_length > 0 and attention_mask is not None:
111 | attention_mask = torch.cat(
112 | (
113 | torch.full(
114 | (input_shape[0], past_key_values_length),
115 | True,
116 | dtype=attention_mask.dtype,
117 | device=attention_mask.device,
118 | ),
119 | attention_mask,
120 | ),
121 | dim=-1,
122 | )
123 |
124 | if attention_mask is not None and torch.all(attention_mask):
125 | return None # This uses the faster call when training with full samples
126 |
127 | return attention_mask
128 |
129 |
130 | def replace_llama2_attn_with_flash_attn():
131 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
132 | if cuda_major < 8:
133 | warnings.warn(
134 |             'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. '
135 | 'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593'
136 | )
137 |
138 | LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
139 | LlamaAttention.forward = forward
140 |
141 |
142 | def test():
143 | from fastchat.train.llama_flash_attn_monkey_patch import \
144 | forward as fastchat_forward
145 | from transformers.models.llama.configuration_llama import LlamaConfig
146 |
147 | config = LlamaConfig(
148 | hidden_size=1024,
149 | intermediate_size=128,
150 | num_hidden_layers=1,
151 | num_attention_heads=8,
152 | max_position_embeddings=16,
153 | )
154 | device = torch.device('cuda')
155 | model = LlamaModel(config)
156 | attn = LlamaAttention(config).to(device).half()
157 | bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
158 | position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
159 | -1, seqlen
160 | )
161 |
162 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
163 | for i in range(4):
164 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
165 | if i:
166 | mask[0, -i:] = False
167 | mask[1, :i] = False
168 |
169 | lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
170 | ref, _, _ = attn.forward(
171 | hidden, attention_mask=lmask, position_ids=position_ids
172 | )
173 |
174 | fast, _, _ = fastchat_forward(
175 | attn, hidden, attention_mask=mask, position_ids=position_ids
176 | )
177 |
178 | lmask = _prepare_decoder_attention_mask(
179 | model, mask, hidden.shape[:2], hidden, 0
180 | )
181 | test, _, _ = forward(
182 | attn, hidden, attention_mask=lmask, position_ids=position_ids
183 | )
184 |
185 | print(f'Mean(abs(ref)) = {torch.mean(torch.abs(ref))}')
186 | print(f'Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}')
187 | print(f'Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}')
188 | print(f'Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}')
189 | print(f'allclose(fast, test) = {torch.allclose(fast, test)}')
190 |
191 | with torch.no_grad():
192 | # Also check that past_kv is handled properly
193 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
194 | part_len = seqlen // 4
195 | assert part_len * 4 == seqlen
196 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
197 | mask[0, -2:] = False
198 | lmask = _prepare_decoder_attention_mask(
199 | model, mask, hidden.shape[:2], hidden, 0
200 | )
201 | oneshot, _, _ = forward(
202 | attn, hidden, attention_mask=lmask, position_ids=position_ids
203 | )
204 | parts = []
205 | past_kv, past_kv_len = None, 0
206 | for i in range(4):
207 | start = part_len * i
208 | end = start + part_len
209 | hidden_part = hidden[:, start:end, ...]
210 | lmask = _prepare_decoder_attention_mask(
211 | model,
212 | mask[:, start:end],
213 | hidden_part.shape[:2],
214 | hidden_part,
215 | past_kv_len,
216 | )
217 | part, _, past_kv = forward(
218 | attn,
219 | hidden_part.clone(),
220 | attention_mask=lmask,
221 | position_ids=position_ids[:, start:end],
222 | past_key_value=past_kv,
223 | use_cache=True,
224 | )
225 | parts.append(part)
226 | past_kv_len = past_kv[0].shape[2]
227 |
228 | print(
229 | f'allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}'
230 | )
231 | print(
232 | f'allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}'
233 | )
234 |
235 |
236 | if __name__ == '__main__':
237 | test()
238 |
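Beyond the numerical test above, the key behavioural detail is that the patched `_prepare_decoder_attention_mask` returns `None` whenever no position is masked, which routes attention through the dense `flash_attn_func` branch instead of the unpad/varlen branch. A standalone sketch of that logic (tensors invented for illustration):

import torch

def prepared(mask, past_len=0):
    # Same logic as the patched _prepare_decoder_attention_mask: extend the mask
    # over cached tokens, then drop it entirely when nothing is actually masked.
    if past_len > 0 and mask is not None:
        mask = torch.cat(
            [torch.ones(mask.size(0), past_len, dtype=mask.dtype), mask], dim=-1)
    return None if mask is not None and torch.all(mask) else mask

full = torch.ones(2, 8, dtype=torch.bool)
padded = full.clone()
padded[0, -2:] = False

print(prepared(full))     # None -> dense flash_attn_func branch
print(prepared(padded))   # boolean mask kept -> unpad + flash_attn_varlen_kvpacked_func branch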
--------------------------------------------------------------------------------
/internvl/patch/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import transformers
7 | from torch import nn
8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
9 |
10 |
11 | def forward(
12 | self,
13 | hidden_states: torch.Tensor,
14 | attention_mask: Optional[torch.Tensor] = None,
15 | position_ids: Optional[torch.Tensor] = None,
16 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
17 | output_attentions: bool = False,
18 | use_cache: bool = False,
19 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
20 | """Input shape: Batch x Time x Channel
21 |
22 | attention_mask: [bsz, q_len]
23 | """
24 | from einops import rearrange
25 | try: # v1
26 | from flash_attn.flash_attn_interface import \
27 | flash_attn_unpadded_qkvpacked_func
28 |     except ImportError:  # v2
29 | from flash_attn.flash_attn_interface import \
30 | flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
31 | from flash_attn.bert_padding import pad_input, unpad_input
32 |
33 | bsz, q_len, _ = hidden_states.size()
34 |
35 | query_states = (
36 | self.q_proj(hidden_states)
37 | .view(bsz, q_len, self.num_heads, self.head_dim)
38 | .transpose(1, 2)
39 | )
40 | key_states = (
41 | self.k_proj(hidden_states)
42 | .view(bsz, q_len, self.num_heads, self.head_dim)
43 | .transpose(1, 2)
44 | )
45 | value_states = (
46 | self.v_proj(hidden_states)
47 | .view(bsz, q_len, self.num_heads, self.head_dim)
48 | .transpose(1, 2)
49 | )
50 | # [bsz, q_len, nh, hd]
51 | # [bsz, nh, q_len, hd]
52 |
53 | kv_seq_len = key_states.shape[-2]
54 | assert past_key_value is None, 'past_key_value is not supported'
55 |
56 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
57 | query_states, key_states = apply_rotary_pos_emb(
58 | query_states, key_states, cos, sin, position_ids
59 | )
60 | # [bsz, nh, t, hd]
61 | assert not output_attentions, 'output_attentions is not supported'
62 | assert not use_cache, 'use_cache is not supported'
63 |
64 | # Flash attention codes from
65 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
66 |
67 | # transform the data into the format required by flash attention
68 | qkv = torch.stack(
69 | [query_states, key_states, value_states], dim=2
70 | ) # [bsz, nh, 3, q_len, hd]
71 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
72 | # We have disabled _prepare_decoder_attention_mask in LlamaModel
73 | # the attention_mask should be the same as the key_padding_mask
74 | key_padding_mask = attention_mask
75 |
76 | if key_padding_mask is None:
77 | qkv = rearrange(qkv, 'b s ... -> (b s) ...')
78 | max_s = q_len
79 | cu_q_lens = torch.arange(
80 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
81 | )
82 | output = flash_attn_unpadded_qkvpacked_func(
83 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
84 | )
85 | output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
86 | else:
87 | nheads = qkv.shape[-2]
88 | x = rearrange(qkv, 'b s three h d -> b s (three h d)')
89 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
90 | x_unpad = rearrange(
91 | x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads
92 | )
93 | output_unpad = flash_attn_unpadded_qkvpacked_func(
94 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
95 | )
96 | output = rearrange(
97 | pad_input(
98 | rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices, bsz, q_len
99 | ),
100 | 'b s (h d) -> b s h d',
101 | h=nheads,
102 | )
103 | return self.o_proj(rearrange(output, 'b s h d -> b s (h d)')), None, None
104 |
105 |
106 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
107 | # requires the attention mask to be the same as the key_padding_mask
108 | def _prepare_decoder_attention_mask(
109 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
110 | ):
111 | # [bsz, seq_len]
112 | return attention_mask
113 |
114 |
115 | def forward_2(
116 | self,
117 | hidden_states: torch.Tensor,
118 | attention_mask: Optional[torch.Tensor] = None,
119 | position_ids: Optional[torch.LongTensor] = None,
120 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
121 | output_attentions: bool = False,
122 | use_cache: bool = False,
123 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
124 | bsz, q_len, _ = hidden_states.size()
125 |
126 | query_states = (
127 | self.q_proj(hidden_states)
128 | .view(bsz, q_len, self.num_heads, self.head_dim)
129 | .transpose(1, 2)
130 | )
131 | key_states = (
132 | self.k_proj(hidden_states)
133 | .view(bsz, q_len, self.num_heads, self.head_dim)
134 | .transpose(1, 2)
135 | )
136 | value_states = (
137 | self.v_proj(hidden_states)
138 | .view(bsz, q_len, self.num_heads, self.head_dim)
139 | .transpose(1, 2)
140 | )
141 |
142 | kv_seq_len = key_states.shape[-2]
143 | if past_key_value is not None:
144 | kv_seq_len += past_key_value[0].shape[-2]
145 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
146 | query_states, key_states = apply_rotary_pos_emb(
147 | query_states, key_states, cos, sin, position_ids
148 | )
149 |
150 | assert not output_attentions, 'output_attentions is not supported'
151 | assert not use_cache, 'use_cache is not supported'
152 | assert past_key_value is None, 'past_key_value is not supported'
153 |
154 | if past_key_value is not None:
155 | # reuse k, v, self_attention
156 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
157 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
158 |
159 | past_key_value = (key_states, value_states) if use_cache else None
160 | if self.training:
161 | attn_output = F.scaled_dot_product_attention(
162 | query_states, key_states, value_states, dropout_p=0.0, is_causal=True
163 | )
164 | attn_weights = None
165 | else:
166 | attn_weights = torch.matmul(
167 | query_states, key_states.transpose(2, 3)
168 | ) / math.sqrt(self.head_dim)
169 |
170 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
171 | raise ValueError(
172 |                 f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
173 | f' {attn_weights.size()}'
174 | )
175 |
176 | if attention_mask is not None:
177 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
178 | raise ValueError(
179 | f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
180 | )
181 | attn_weights = attn_weights + attention_mask
182 | attn_weights = torch.max(
183 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
184 | )
185 |
186 | # upcast attention to fp32
187 | attn_weights = nn.functional.softmax(
188 | attn_weights, dim=-1, dtype=torch.float32
189 | ).to(query_states.dtype)
190 | attn_output = torch.matmul(attn_weights, value_states)
191 |
192 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
193 | raise ValueError(
194 | f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
195 | f' {attn_output.size()}'
196 | )
197 |
198 | attn_output = attn_output.transpose(1, 2)
199 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
200 |
201 | attn_output = self.o_proj(attn_output)
202 |
203 | if not output_attentions:
204 | attn_weights = None
205 |
206 | return attn_output, attn_weights, past_key_value
207 |
208 |
209 | def replace_llama_attn_with_flash_attn():
210 | if hasattr(F, 'scaled_dot_product_attention'):
211 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_2
212 | else:
213 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
214 | _prepare_decoder_attention_mask
215 | )
216 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
217 |
--------------------------------------------------------------------------------
/internvl/patch/llama_rmsnorm_monkey_patch.py:
--------------------------------------------------------------------------------
1 | import transformers
2 |
3 |
4 | def replace_llama_rmsnorm_with_fused_rmsnorm():
5 | try:
6 | from functools import partial
7 |
8 | from apex.normalization import FusedRMSNorm
9 | LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
10 | transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
11 | print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm')
12 | except ImportError:
13 | # using the normal LlamaRMSNorm
14 | pass
15 | except Exception:
16 | print('discovered apex but it failed to load, falling back to LlamaRMSNorm')
17 | pass
18 |
--------------------------------------------------------------------------------
/internvl/patch/pad_data_collator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | IGNORE_INDEX = -100
5 |
6 |
7 | def pad_data_collator(features, pad_id=0):
8 |
9 | first = features[0]
10 | batch = {}
11 |
12 | batch_lens = [feat['input_ids'].shape for feat in features]
13 | max_item_length = max(batch_lens)[0]
14 | for idx in range(len(features)):
15 | feat = features[idx]
16 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
17 | temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
18 | feat['input_ids'] = temp_input_ids
19 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
20 | temp_labels[:feat['labels'].shape[0]] = feat['labels']
21 | feat['labels'] = temp_labels
22 | feat['attention_mask'] = feat['input_ids'].ne(pad_id)
23 |
24 | # Special handling for labels.
25 | # Ensure that tensor is created with the correct type
26 | # (it should be automatically the case, but let's make sure of it.)
27 | if 'label' in first and first['label'] is not None:
28 | label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
29 | dtype = torch.long if isinstance(label, int) else torch.float
30 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
31 | elif 'label_ids' in first and first['label_ids'] is not None:
32 | if isinstance(first['label_ids'], torch.Tensor):
33 | batch['labels'] = torch.stack([f['label_ids'] for f in features])
34 | else:
35 | dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
36 | batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
37 |
38 | # Handling of all other possible keys.
39 | # Again, we will use the first element to figure out which key/values are not None for this model.
40 | for k, v in first.items():
41 | if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
42 | if isinstance(v, torch.Tensor):
43 | batch[k] = torch.stack([f[k] for f in features])
44 | elif isinstance(v, np.ndarray):
45 | batch[k] = torch.tensor(np.stack([f[k] for f in features]))
46 | else:
47 | batch[k] = torch.tensor([f[k] for f in features])
48 | return batch
49 |
50 |
51 | def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
52 |
53 | first = features[0]
54 | batch = {}
55 |
56 | batch_lens = [feat['input_ids'].shape for feat in features]
57 | max_item_length = max_item_length or max(batch_lens)[0]
58 | for idx in range(len(features)):
59 | feat = features[idx]
60 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
61 | temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
62 | feat['input_ids'] = temp_input_ids
63 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
64 | temp_labels[:feat['labels'].shape[0]] = feat['labels']
65 | feat['labels'] = temp_labels
66 | feat['attention_mask'] = feat['input_ids'].ne(pad_id)
67 |
68 | if 'position_ids' in feat:
69 | temp_position_ids = torch.LongTensor([pad_id] * max_item_length)
70 | temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids']
71 | feat['position_ids'] = temp_position_ids
72 |
73 | if 'loss_weight' in feat:
74 | temp_loss_weight = torch.FloatTensor([pad_id] * max_item_length)
75 | temp_loss_weight[:feat['loss_weight'].shape[0]] = feat['loss_weight']
76 | feat['loss_weight'] = temp_loss_weight
77 |
78 | # Special handling for labels.
79 | # Ensure that tensor is created with the correct type
80 | # (it should be automatically the case, but let's make sure of it.)
81 | if 'label' in first and first['label'] is not None:
82 | label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
83 | dtype = torch.long if isinstance(label, int) else torch.float
84 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
85 | elif 'label_ids' in first and first['label_ids'] is not None:
86 | if isinstance(first['label_ids'], torch.Tensor):
87 | batch['labels'] = torch.stack([f['label_ids'] for f in features])
88 | else:
89 | dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
90 | batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
91 |
92 | # Handling of all other possible keys.
93 | # Again, we will use the first element to figure out which key/values are not None for this model.
94 | for k, v in first.items():
95 | if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \
96 | v is not None and not isinstance(v, str):
97 | if isinstance(v, torch.Tensor):
98 | batch[k] = torch.stack([f[k] for f in features])
99 | elif isinstance(v, np.ndarray):
100 | batch[k] = torch.tensor(np.stack([f[k] for f in features]))
101 | else:
102 | batch[k] = torch.tensor([f[k] for f in features])
103 | if k in ('pixel_values', 'image_flags'):
104 | if isinstance(v, torch.Tensor):
105 | batch[k] = torch.concat([f[k] for f in features])
106 | elif isinstance(v, np.ndarray):
107 |                 batch[k] = torch.from_numpy(np.concatenate([f[k] for f in features]))
108 | else:
109 | batch[k] = torch.concat([f[k] for f in features])
110 | return batch
111 |
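A small, hypothetical example of what `concat_pad_data_collator` produces for two samples of different lengths (the feature values and the 448x448 tile size are invented; in the real pipeline they come from the dataset in `internvl/train/dataset.py`), assuming the repo is importable:

import torch

from internvl.patch import concat_pad_data_collator

features = [
    {'input_ids': torch.tensor([5, 6, 7]), 'labels': torch.tensor([-100, 6, 7]),
     'pixel_values': torch.randn(2, 3, 448, 448), 'image_flags': torch.tensor([1, 1])},
    {'input_ids': torch.tensor([8, 9]), 'labels': torch.tensor([8, 9]),
     'pixel_values': torch.randn(1, 3, 448, 448), 'image_flags': torch.tensor([1])},
]

batch = concat_pad_data_collator(features, pad_id=0)
print(batch['input_ids'])           # [[5, 6, 7], [8, 9, 0]]           right-padded with pad_id
print(batch['labels'])              # [[-100, 6, 7], [8, 9, -100]]     padded with IGNORE_INDEX
print(batch['attention_mask'])      # [[True, True, True], [True, True, False]]
print(batch['pixel_values'].shape)  # torch.Size([3, 3, 448, 448])     tiles concatenated, not stacked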
--------------------------------------------------------------------------------
/internvl/patch/qwen2_packed_training_patch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from flash_attn.flash_attn_interface import flash_attn_varlen_func
3 | from transformers.models.qwen2.modeling_qwen2 import (QWEN2_ATTENTION_CLASSES,
4 | Qwen2FlashAttention2)
5 |
6 |
7 | # Modified from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2
8 | class Qwen2FlashAttention2ForPackedTraining(Qwen2FlashAttention2):
9 |
10 | def _flash_attention_forward(
11 | self,
12 | query_states,
13 | key_states,
14 | value_states,
15 | attention_mask,
16 | query_length,
17 | dropout=0.0,
18 | softmax_scale=None,
19 | use_sliding_windows=False,
20 | ):
21 | """
22 |         Calls the varlen Flash Attention kernel directly. The packed batch is treated as a single concatenated
23 |         sequence, so no padding or unpadding is performed; the cumulative sequence lengths are read from `attention_mask`.
24 |
25 | Args:
26 | query_states (`torch.Tensor`):
27 | Input query states to be passed to Flash Attention API
28 | key_states (`torch.Tensor`):
29 | Input key states to be passed to Flash Attention API
30 | value_states (`torch.Tensor`):
31 | Input value states to be passed to Flash Attention API
32 | attention_mask (`torch.Tensor`):
33 |                 Renamed from `cu_seqlens` to keep the signature compatible - shape `(batch_size + 1,)`, dtype `torch.int32`:
34 |                 the cumulative sequence lengths of the sequences in the packed batch.
35 |             dropout (`float`, *optional*):
36 | Attention dropout
37 | softmax_scale (`float`, *optional*):
38 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
39 | use_sliding_windows (`bool`, *optional*):
40 | Whether to activate sliding window attention.
41 | """
42 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
43 | query_states = query_states.squeeze(0)
44 | key_states = key_states.squeeze(0)
45 | value_states = value_states.squeeze(0)
46 | cu_seqlens = attention_mask.squeeze(0)
47 |
48 | with torch.no_grad():
49 | max_seqlen = max([
50 | cu_seqlens[idx+1] - cu_seqlens[idx]
51 | for idx in range(cu_seqlens.size(0) - 1)
52 | ]).item()
53 |
54 | if not self._flash_attn_uses_top_left_mask:
55 | causal = self.is_causal
56 | else:
57 | causal = self.is_causal and query_length != 1
58 |
59 | # Decide whether to use SWA or not by layer index.
60 | if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
61 | use_sliding_windows = False
62 |
63 | if not use_sliding_windows:
64 | attn_output = flash_attn_varlen_func(
65 | q=query_states,
66 | k=key_states,
67 | v=value_states,
68 | cu_seqlens_q=cu_seqlens,
69 | cu_seqlens_k=cu_seqlens,
70 | max_seqlen_q=max_seqlen,
71 | max_seqlen_k=max_seqlen,
72 | dropout_p=dropout,
73 | softmax_scale=softmax_scale,
74 | causal=causal,
75 | )
76 | else:
77 | attn_output = flash_attn_varlen_func(
78 | q=query_states,
79 | k=key_states,
80 | v=value_states,
81 | cu_seqlens_q=cu_seqlens,
82 | cu_seqlens_k=cu_seqlens,
83 | max_seqlen_q=max_seqlen,
84 | max_seqlen_k=max_seqlen,
85 | dropout_p=dropout,
86 | softmax_scale=softmax_scale,
87 | causal=causal,
88 | window_size=(self.config.sliding_window, self.config.sliding_window),
89 | )
90 |
91 | query_states = query_states.unsqueeze(0)
92 | key_states = key_states.unsqueeze(0)
93 | value_states = value_states.unsqueeze(0)
94 | return attn_output
95 |
96 |
97 | def replace_qwen2_attention_class():
98 | QWEN2_ATTENTION_CLASSES['flash_attention_2'] = Qwen2FlashAttention2ForPackedTraining
99 | print('Replace QWEN2_ATTENTION_CLASSES to support packed training!!')
100 |
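As with the InternLM2 patch above, the replacement is a pure registry swap, so its effect can be checked directly (a minimal sketch, assuming `transformers` with Qwen2 support, `flash-attn`, and this repo are importable):

from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES

from internvl.patch import replace_qwen2_attention_class

replace_qwen2_attention_class()
# Any Qwen2 model built afterwards with attn_implementation='flash_attention_2'
# instantiates the packed-training attention defined above.
print(QWEN2_ATTENTION_CLASSES['flash_attention_2'].__name__)
# -> Qwen2FlashAttention2ForPackedTraining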
--------------------------------------------------------------------------------
/internvl/patch/train_dataloader_patch.py:
--------------------------------------------------------------------------------
1 | import datasets
2 | import torch
3 | import transformers
4 | from torch.utils.data import DataLoader
5 | from transformers.trainer import is_datasets_available, seed_worker
6 |
7 |
8 | def get_train_dataloader(self) -> DataLoader:
9 | """
10 | Returns the training [`~torch.utils.data.DataLoader`].
11 |
12 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
13 | training if necessary) otherwise.
14 |
15 | Subclass and override this method if you want to inject some custom behavior.
16 | """
17 | if self.train_dataset is None:
18 | raise ValueError('Trainer: training requires a train_dataset.')
19 |
20 | train_dataset = self.train_dataset
21 | data_collator = self.data_collator
22 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
23 | train_dataset = self._remove_unused_columns(train_dataset, description='training')
24 | else:
25 | data_collator = self._get_collator_with_removed_columns(data_collator, description='training')
26 |
27 | dataloader_params = {
28 | 'batch_size': self._train_batch_size,
29 | 'collate_fn': data_collator,
30 | 'num_workers': self.args.dataloader_num_workers,
31 | 'pin_memory': self.args.dataloader_pin_memory,
32 | 'persistent_workers': self.args.dataloader_persistent_workers,
33 | }
34 |
35 | if not isinstance(train_dataset, torch.utils.data.IterableDataset):
36 | dataloader_params['sampler'] = self._get_train_sampler()
37 | dataloader_params['drop_last'] = self.args.dataloader_drop_last
38 | dataloader_params['worker_init_fn'] = seed_worker
39 |
40 | if self.args.use_packed_ds:
41 | return DataLoader(train_dataset, **dataloader_params)
42 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
43 |
44 |
45 |
46 | def get_unwapper_train_dataloader(self) -> DataLoader:
47 | """
48 | Returns the training [`~torch.utils.data.DataLoader`].
49 |
50 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
51 | training if necessary) otherwise.
52 |
53 | Subclass and override this method if you want to inject some custom behavior.
54 | """
55 | if self.train_dataset is None:
56 | raise ValueError('Trainer: training requires a train_dataset.')
57 |
58 | train_dataset = self.train_dataset
59 | data_collator = self.data_collator
60 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
61 | train_dataset = self._remove_unused_columns(train_dataset, description='training')
62 | else:
63 | data_collator = self._get_collator_with_removed_columns(data_collator, description='training')
64 |
65 | dataloader_params = {
66 | 'batch_size': self._train_batch_size,
67 | 'collate_fn': data_collator,
68 | 'num_workers': self.args.dataloader_num_workers,
69 | 'pin_memory': self.args.dataloader_pin_memory,
70 | 'persistent_workers': self.args.dataloader_persistent_workers,
71 | }
72 |
73 | if not isinstance(train_dataset, torch.utils.data.IterableDataset):
74 | dataloader_params['sampler'] = self._get_train_sampler()
75 | dataloader_params['drop_last'] = self.args.dataloader_drop_last
76 | dataloader_params['worker_init_fn'] = seed_worker
77 |
78 | return DataLoader(train_dataset, **dataloader_params)
79 |
80 |
81 | def replace_train_dataloader():
82 | transformers.Trainer.get_train_dataloader = get_train_dataloader
83 | print('Replace train dataloader!!')
84 |
85 |
86 | def replace_unwapper_train_dataloader():
87 | transformers.Trainer.get_train_dataloader = get_unwapper_train_dataloader
88 | print('Replace train dataloader!!')
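The two variants differ only in the last step: `get_train_dataloader` skips `accelerator.prepare` when `args.use_packed_ds` is set (the packed dataset is expected to handle distributed sharding itself), while `get_unwapper_train_dataloader` always returns the raw `DataLoader`. A hedged sketch of how a training script might pick between them (the `use_packed_ds` flag is assumed to live on the TrainingArguments subclass used by this repo):

from internvl.patch import (replace_train_dataloader,
                            replace_unwapper_train_dataloader)

# Conditional wrapping: skip accelerator.prepare only when args.use_packed_ds is set.
replace_train_dataloader()

# Alternative: always hand back the raw DataLoader, e.g. when the dataset and sampler
# already take care of distributed sharding themselves.
# replace_unwapper_train_dataloader()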
--------------------------------------------------------------------------------
/internvl/patch/train_sampler_patch.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | import transformers
5 | from torch.utils.data import Dataset, Sampler, SequentialSampler
6 | from transformers.tokenization_utils_base import BatchEncoding
7 | from transformers.trainer import (LengthGroupedSampler, RandomSampler,
8 | has_length)
9 | from transformers.trainer_pt_utils import logger
10 |
11 |
12 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38
13 | def split_to_even_chunks(indices, lengths, num_chunks):
14 | """
15 | Split a list of indices into `chunks` chunks of roughly equal lengths.
16 | """
17 |
18 | if len(indices) % num_chunks != 0:
19 | return [indices[i::num_chunks] for i in range(num_chunks)]
20 |
21 | num_indices_per_chunk = len(indices) // num_chunks
22 |
23 | chunks = [[] for _ in range(num_chunks)]
24 | chunks_lengths = [0 for _ in range(num_chunks)]
25 | for index in indices:
26 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
27 | chunks[shortest_chunk].append(index)
28 | chunks_lengths[shortest_chunk] += lengths[index]
29 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
30 | chunks_lengths[shortest_chunk] = float('inf')
31 |
32 | return chunks
33 |
34 |
35 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88
36 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
37 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
38 | indices = torch.randperm(len(lengths), generator=generator)
39 | megabatch_size = world_size * batch_size
40 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
41 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
42 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
43 |
44 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
45 |
46 |
47 | # modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99
48 | class LengthGroupedSampler(Sampler):
49 | r"""
50 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
51 | keeping a bit of randomness.
52 | """
53 |
54 | def __init__(
55 | self,
56 | batch_size: int,
57 | world_size: int,
58 | dataset: Optional[Dataset] = None,
59 | lengths: Optional[List[int]] = None,
60 | model_input_name: Optional[str] = None,
61 | generator=None,
62 | ):
63 | if dataset is None and lengths is None:
64 | raise ValueError('One of dataset and lengths must be provided.')
65 |
66 | self.batch_size = batch_size
67 | if lengths is None:
68 | model_input_name = model_input_name if model_input_name is not None else 'input_ids'
69 | if (
70 | not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
71 | or model_input_name not in dataset[0]
72 | ):
73 | raise ValueError(
74 | 'Can only automatically infer lengths for datasets whose items are dictionaries with an '
75 | f"'{model_input_name}' key."
76 | )
77 | lengths = [len(feature[model_input_name]) for feature in dataset]
78 | elif isinstance(lengths, torch.Tensor):
79 | logger.info(
80 | 'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...'
81 | )
82 | lengths = lengths.tolist()
83 | self.world_size = world_size
84 | self.lengths = lengths
85 | self.generator = generator
86 |
87 | def __len__(self):
88 | return len(self.lengths)
89 |
90 | def __iter__(self):
91 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
92 | return iter(indices)
93 |
94 |
95 | # patch trainer
96 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
97 | if self.train_dataset is None or not has_length(self.train_dataset):
98 | return None
99 | # Build the sampler.
100 | if self.args.group_by_length:
101 | lengths = []
102 | for dataset in self.train_dataset.datasets:
103 | lengths = lengths + dataset.length
104 | model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
105 | return LengthGroupedSampler(
106 | self.args.train_batch_size,
107 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
108 | # self.args.train_batch_size * self.args.gradient_accumulation_steps,
109 | dataset=self.train_dataset,
110 | lengths=lengths,
111 | model_input_name=model_input_name,
112 | )
113 | else:
114 | return RandomSampler(self.train_dataset)
115 |
116 | # patch trainer
117 | def _get_sequence_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
118 | if self.train_dataset is None or not has_length(self.train_dataset):
119 | return None
120 |
121 | # Build the sampler.
122 |     print('Using SequentialSampler')
123 | return SequentialSampler(self.train_dataset)
124 |
125 |
126 | def replace_train_sampler():
127 | transformers.Trainer._get_train_sampler = _get_train_sampler
128 | print('Replace train sampler!!')
129 |
130 | def replace_sequence_train_sampler():
131 | transformers.Trainer._get_train_sampler = _get_sequence_train_sampler
132 | print('Replace train sampler!!')
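A small, self-contained illustration of what the length grouping above does (per-sample lengths invented for the example; assumes the repo is importable):

import torch

from internvl.patch.train_sampler_patch import get_length_grouped_indices

lengths = [3, 9, 1, 7, 5, 2, 8, 4]   # invented per-sample token counts
g = torch.Generator().manual_seed(0)

indices = get_length_grouped_indices(lengths, batch_size=2, world_size=2, generator=g)
# The samples are shuffled, grouped into megabatches of world_size * batch_size = 4,
# sorted by length within each megabatch, then split into per-rank chunks of roughly
# equal total length, so every global batch mixes similarly sized samples.
print(indices)                          # a permutation of range(8)
print([lengths[i] for i in indices])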
--------------------------------------------------------------------------------
/internvl/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/Mono-InternVL/fcd0381544fc83ac544bade229a70d7a05aa4614/internvl/serve/__init__.py
--------------------------------------------------------------------------------
/internvl/serve/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = '.'
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = '<image>'
10 | DEFAULT_IMAGE_PATCH_TOKEN = '<IMG_CONTEXT>'
11 | DEFAULT_IM_START_TOKEN = '<img>'
12 | DEFAULT_IM_END_TOKEN = '</img>'
13 | IMAGE_PLACEHOLDER = '<image-placeholder>'
14 |
--------------------------------------------------------------------------------
/internvl/serve/mm_utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from io import BytesIO
3 |
4 | import torch
5 | from PIL import Image
6 | from transformers import StoppingCriteria
7 |
8 | from .constants import IMAGE_TOKEN_INDEX
9 |
10 |
11 | def load_image_from_base64(image):
12 | return Image.open(BytesIO(base64.b64decode(image)))
13 |
14 |
15 | def expand2square(pil_img, background_color):
16 | width, height = pil_img.size
17 | if width == height:
18 | return pil_img
19 | elif width > height:
20 | result = Image.new(pil_img.mode, (width, width), background_color)
21 | result.paste(pil_img, (0, (width - height) // 2))
22 | return result
23 | else:
24 | result = Image.new(pil_img.mode, (height, height), background_color)
25 | result.paste(pil_img, ((height - width) // 2, 0))
26 | return result
27 |
28 |
29 | def process_images(images, image_processor, model_cfg):
30 | image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None)
31 | new_images = []
32 | if image_aspect_ratio == 'pad':
33 | for image in images:
34 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
35 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
36 | new_images.append(image)
37 | else:
38 | return image_processor(images, return_tensors='pt')['pixel_values']
39 | if all(x.shape == new_images[0].shape for x in new_images):
40 | new_images = torch.stack(new_images, dim=0)
41 | return new_images
42 |
43 |
44 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
45 | num_image_tokens=None, return_tensors=None):
46 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
47 |
48 | def insert_separator(X, sep):
49 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
50 |
51 | input_ids = []
52 | offset = 0
53 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
54 | offset = 1
55 | input_ids.append(prompt_chunks[0][0])
56 |
57 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + num_image_tokens)):
58 | input_ids.extend(x[offset:])
59 |
60 | if return_tensors is not None:
61 | if return_tensors == 'pt':
62 | return torch.tensor(input_ids, dtype=torch.long)
63 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
64 | return input_ids
65 |
66 |
67 | def get_model_name_from_path(model_path):
68 | model_path = model_path.strip('/')
69 | model_paths = model_path.split('/')
70 | if model_paths[-1].startswith('checkpoint-'):
71 | return model_paths[-2] + '_' + model_paths[-1]
72 | else:
73 | return model_paths[-1]
74 |
75 |
76 | class KeywordsStoppingCriteria(StoppingCriteria):
77 | def __init__(self, keywords, tokenizer, input_ids):
78 | self.keywords = keywords
79 | self.keyword_ids = []
80 | self.max_keyword_len = 0
81 | for keyword in keywords:
82 | cur_keyword_ids = tokenizer(keyword).input_ids
83 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
84 | cur_keyword_ids = cur_keyword_ids[1:]
85 | if len(cur_keyword_ids) > self.max_keyword_len:
86 | self.max_keyword_len = len(cur_keyword_ids)
87 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
88 | self.tokenizer = tokenizer
89 | self.start_len = input_ids.shape[1]
90 |
91 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
92 | assert output_ids.shape[0] == 1, 'Only support batch size 1 (yet)'
93 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
94 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
95 | for keyword_id in self.keyword_ids:
96 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
97 | return True
98 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
99 | for keyword in self.keywords:
100 | if keyword in outputs:
101 | return True
102 | return False
103 |
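Two of the helpers above are pure functions and easy to sanity-check in isolation (the image size, colors, and model path are invented example values; assumes the repo and Pillow are importable):

from PIL import Image

from internvl.serve.mm_utils import expand2square, get_model_name_from_path

# expand2square pads a non-square image onto a square canvas, centering the original.
img = Image.new('RGB', (200, 100), (255, 0, 0))
print(expand2square(img, background_color=(0, 0, 0)).size)   # (200, 200)

# get_model_name_from_path keeps the checkpoint suffix when pointing at a checkpoint dir.
print(get_model_name_from_path('/models/Mono-InternVL-2B/checkpoint-1000'))
# -> Mono-InternVL-2B_checkpoint-1000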
--------------------------------------------------------------------------------
/internvl/serve/model_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | A model worker executes the model.
3 | """
4 | import argparse
5 | import asyncio
6 | import json
7 | import threading
8 | import time
9 | import uuid
10 | from functools import partial
11 | from threading import Thread
12 |
13 | import requests
14 | import torch
15 | import uvicorn
16 | from fastapi import BackgroundTasks, FastAPI, Request
17 | from fastapi.responses import StreamingResponse
18 | from internvl.train.dataset import dynamic_preprocess
19 | from transformers import (AutoTokenizer, CLIPImageProcessor,
20 | TextIteratorStreamer)
21 |
22 | from ..model.internvl_chat import InternVLChatModel
23 | from .constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
24 | DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN,
25 | IMAGE_TOKEN_INDEX, WORKER_HEART_BEAT_INTERVAL)
26 | from .mm_utils import (KeywordsStoppingCriteria, load_image_from_base64,
27 | process_images, tokenizer_image_token)
28 | from .utils import build_logger, pretty_print_semaphore, server_error_msg
29 |
30 | GB = 1 << 30
31 |
32 | worker_id = str(uuid.uuid4())[:6]
33 | logger = build_logger('model_worker', f'model_worker_{worker_id}.log')
34 | global_counter = 0
35 |
36 | model_semaphore = None
37 |
38 |
39 | def heart_beat_worker(controller):
40 |
41 | while True:
42 | time.sleep(WORKER_HEART_BEAT_INTERVAL)
43 | controller.send_heart_beat()
44 |
45 |
46 | class ModelWorker:
47 | def __init__(self, controller_addr, worker_addr,
48 | worker_id, no_register,
49 | model_path, model_base, model_name,
50 | load_8bit, load_4bit, device):
51 | self.controller_addr = controller_addr
52 | self.worker_addr = worker_addr
53 | self.worker_id = worker_id
54 | if model_path.endswith('/'):
55 | model_path = model_path[:-1]
56 | if model_name is None:
57 | model_paths = model_path.split('/')
58 | if model_paths[-1].startswith('checkpoint-'):
59 | self.model_name = model_paths[-2] + '_' + model_paths[-1]
60 | else:
61 | self.model_name = model_paths[-1]
62 | else:
63 | self.model_name = model_name
64 |
65 | logger.info(f'Loading the model {self.model_name} on worker {worker_id} ...')
66 |
67 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False)
68 | if device == 'auto':
69 | import os
70 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
71 | # This can make distributed deployment work properly, wonder why
72 | self.model = InternVLChatModel.from_pretrained(
73 | model_path, load_in_8bit=load_8bit, torch_dtype=torch.float16, device_map='auto').eval()
74 | else:
75 | self.model = InternVLChatModel.from_pretrained(
76 | model_path, load_in_8bit=load_8bit, torch_dtype=torch.float16).eval()
77 | if not load_8bit and not device == 'auto':
78 | self.model = self.model.cuda()
79 | self.image_size = self.model.config.force_image_size
80 | self.image_processor = CLIPImageProcessor(
81 | crop_size=self.image_size, do_center_crop=True, do_normalize=True, do_resize=True,
82 | image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225], size=self.image_size
83 | )
84 | self.context_len = 12800
85 | self.is_multimodal = True
86 |
87 | if not no_register:
88 | self.register_to_controller()
89 | self.heart_beat_thread = threading.Thread(
90 | target=heart_beat_worker, args=(self,))
91 | self.heart_beat_thread.start()
92 |
93 | def register_to_controller(self):
94 | logger.info('Register to controller')
95 |
96 | url = self.controller_addr + '/register_worker'
97 | data = {
98 | 'worker_name': self.worker_addr,
99 | 'check_heart_beat': True,
100 | 'worker_status': self.get_status()
101 | }
102 | r = requests.post(url, json=data)
103 | assert r.status_code == 200
104 |
105 | def send_heart_beat(self):
106 | logger.info(f'Send heart beat. Models: {[self.model_name]}. '
107 | f'Semaphore: {pretty_print_semaphore(model_semaphore)}. '
108 | f'global_counter: {global_counter}')
109 |
110 | url = self.controller_addr + '/receive_heart_beat'
111 |
112 | while True:
113 | try:
114 | ret = requests.post(url, json={
115 | 'worker_name': self.worker_addr,
116 | 'queue_length': self.get_queue_length()}, timeout=5)
117 | exist = ret.json()['exist']
118 | break
119 | except requests.exceptions.RequestException as e:
120 | logger.error(f'heart beat error: {e}')
121 | time.sleep(5)
122 |
123 | if not exist:
124 | self.register_to_controller()
125 |
126 | def get_queue_length(self):
127 | if model_semaphore is None:
128 | return 0
129 | else:
130 | return args.limit_model_concurrency - model_semaphore._value + (len(
131 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
132 |
133 | def get_status(self):
134 | return {
135 | 'model_names': [self.model_name],
136 | 'speed': 1,
137 | 'queue_length': self.get_queue_length(),
138 | }
139 |
140 | @torch.inference_mode()
141 | def generate_stream(self, params):
142 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
143 |
144 | prompt = params['prompt']
145 | max_input_tiles = params['max_input_tiles']
146 | logger.info(f'max_input_tiles: {max_input_tiles}')
147 | ori_prompt = prompt
148 | images = params.get('images', None)
149 | num_image_tokens = 0
150 | if images is not None and len(images) > 0 and self.is_multimodal:
151 | if len(images) > 0:
152 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
153 | raise ValueError('Number of images does not match number of tokens in prompt')
154 | logger.info(f'dynamic_image_size: {model.config.dynamic_image_size}')
155 | logger.info(f'use_thumbnail: {model.config.use_thumbnail}')
156 | images = [load_image_from_base64(image) for image in images]
157 | if model.config.dynamic_image_size:
158 | images = dynamic_preprocess(
159 | images[0], image_size=self.image_size, max_num=max_input_tiles,
160 | use_thumbnail=model.config.use_thumbnail)
161 | images = [item.resize((self.image_size, self.image_size)) for item in images]
162 | logger.info(f'Resize images to {self.image_size}x{self.image_size}')
163 | images = process_images(images, image_processor, model.config)
164 |
165 | if type(images) is list:
166 | images = [image.to(self.model.device, dtype=torch.float16) for image in images]
167 | else:
168 | images = images.to(self.model.device, dtype=torch.float16)
169 | # images = torch.concat(images)
170 | logger.info(f'Split images to {images.shape}')
171 |
172 | replace_token = DEFAULT_IMAGE_TOKEN
173 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
174 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
175 | logger.info(prompt)
176 | num_image_tokens = model.num_image_token * images.size(0)
177 | model.img_context_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_PATCH_TOKEN)
178 | else:
179 | images = None
180 | image_args = {'pixel_values': images}
181 | else:
182 | images = None
183 | image_args = {}
184 |
185 | temperature = float(params.get('temperature', 1.0))
186 | top_p = float(params.get('top_p', 1.0))
187 | max_context_length = getattr(model.config, 'max_position_embeddings', 16384)
188 | max_new_tokens = int(params.get('max_new_tokens', 1024))
189 | stop_str = params.get('stop', None)
190 | do_sample = True if temperature > 0.001 else False
191 | logger.info(f'num_image_tokens: {num_image_tokens}')
192 | logger.info(f'stop_str: {stop_str}')
193 | eos_token_id = tokenizer.convert_tokens_to_ids(stop_str)
194 |
195 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, num_image_tokens, return_tensors='pt').unsqueeze(0).cuda()
196 | input_ids[input_ids==IMAGE_TOKEN_INDEX] = model.img_context_token_id
197 |
198 | keywords = [stop_str]
199 | # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
200 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
201 |
202 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1])
203 | logger.info(f'max_new_tokens: {max_new_tokens}')
204 | if max_new_tokens < 1:
205 | yield json.dumps({'text': ori_prompt + 'Exceeds max token length. Please start a new conversation, thanks.', 'error_code': 0}).encode() + b'\0'
206 | return
207 |
208 | thread = Thread(target=model.generate, kwargs=dict(
209 | input_ids=input_ids,
210 | do_sample=do_sample,
211 | temperature=temperature,
212 | repetition_penalty=1.0,
213 | top_p=top_p,
214 | max_new_tokens=max_new_tokens,
215 | streamer=streamer,
216 | eos_token_id=eos_token_id,
217 | **image_args
218 | ))
219 | thread.start()
220 |
221 | generated_text = ori_prompt
222 | for new_text in streamer:
223 | generated_text += new_text
224 | if generated_text.endswith(stop_str):
225 | generated_text = generated_text[:-len(stop_str)]
226 | yield json.dumps({'text': generated_text, 'error_code': 0}).encode() + b'\0'
227 |
228 | def generate_stream_gate(self, params):
229 | try:
230 | for x in self.generate_stream(params):
231 | yield x
232 | except ValueError as e:
233 | print('Caught ValueError:', e)
234 | ret = {
235 | 'text': server_error_msg,
236 | 'error_code': 1,
237 | }
238 | yield json.dumps(ret).encode() + b'\0'
239 | except torch.cuda.CudaError as e:
240 | print('Caught torch.cuda.CudaError:', e)
241 | ret = {
242 | 'text': server_error_msg,
243 | 'error_code': 1,
244 | }
245 | yield json.dumps(ret).encode() + b'\0'
246 | except Exception as e:
247 | print('Caught Unknown Error', e)
248 | ret = {
249 | 'text': server_error_msg,
250 | 'error_code': 1,
251 | }
252 | yield json.dumps(ret).encode() + b'\0'
253 |
254 |
255 | app = FastAPI()
256 |
257 |
258 | def release_model_semaphore(fn=None):
259 | model_semaphore.release()
260 | if fn is not None:
261 | fn()
262 |
263 |
264 | @app.post('/worker_generate_stream')
265 | async def generate_stream(request: Request):
266 | global model_semaphore, global_counter
267 | global_counter += 1
268 | params = await request.json()
269 |
270 | if model_semaphore is None:
271 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
272 | await model_semaphore.acquire()
273 | worker.send_heart_beat()
274 | generator = worker.generate_stream_gate(params)
275 | background_tasks = BackgroundTasks()
276 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
277 | return StreamingResponse(generator, background=background_tasks)
278 |
279 |
280 | @app.post('/worker_get_status')
281 | async def get_status(request: Request):
282 | return worker.get_status()
283 |
284 |
285 | if __name__ == '__main__':
286 | parser = argparse.ArgumentParser()
287 | parser.add_argument('--host', type=str, default='localhost')
288 | parser.add_argument('--port', type=int, default=21002)
289 | parser.add_argument('--worker-address', type=str,
290 | default='http://localhost:21002')
291 | parser.add_argument('--controller-address', type=str,
292 | default='http://localhost:21001')
293 | parser.add_argument('--model-path', type=str, default='facebook/opt-350m')
294 | parser.add_argument('--model-base', type=str, default=None)
295 | parser.add_argument('--model-name', type=str)
296 | parser.add_argument('--device', type=str, default='cuda')
297 | parser.add_argument('--multi-modal', action='store_true', help='Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.')
298 | parser.add_argument('--limit-model-concurrency', type=int, default=5)
299 | parser.add_argument('--stream-interval', type=int, default=1)
300 | parser.add_argument('--no-register', action='store_true')
301 | parser.add_argument('--load-8bit', action='store_true')
302 | parser.add_argument('--load-4bit', action='store_true')
303 | args = parser.parse_args()
304 | logger.info(f'args: {args}')
305 |
306 | if args.multi_modal:
307 | logger.warning('Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.')
308 |
309 | worker = ModelWorker(args.controller_address,
310 | args.worker_address,
311 | worker_id,
312 | args.no_register,
313 | args.model_path,
314 | args.model_base,
315 | args.model_name,
316 | args.load_8bit,
317 | args.load_4bit,
318 | args.device)
319 | uvicorn.run(app, host=args.host, port=args.port, log_level='info')
320 |
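For reference, a hedged sketch of a client for the streaming endpoint above. The field names follow the `params.get(...)` calls in `generate_stream` and the host/port match the script's defaults; the prompt template and stop token are placeholders that would normally come from the conversation template in `internvl/conversation.py`:

import base64
import io
import json

import requests
from PIL import Image

# Encode a dummy image; in practice this would be the user-supplied picture.
buf = io.BytesIO()
Image.new('RGB', (448, 448), (127, 127, 127)).save(buf, format='PNG')
image_b64 = base64.b64encode(buf.getvalue()).decode()

params = {
    'prompt': 'USER: <image>\nDescribe this image. ASSISTANT:',  # placeholder template; one <image> per image
    'images': [image_b64],
    'max_input_tiles': 6,
    'temperature': 0.2,
    'top_p': 0.7,
    'max_new_tokens': 256,
    'stop': '</s>',            # placeholder stop string; must map to a token id in the tokenizer
}

with requests.post('http://localhost:21002/worker_generate_stream',
                   json=params, stream=True, timeout=60) as resp:
    # The worker yields b'\0'-delimited JSON chunks, each carrying the text generated so far.
    for chunk in resp.iter_lines(delimiter=b'\0'):
        if chunk:
            data = json.loads(chunk.decode())
            print(data.get('text', ''), data.get('error_code'))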
--------------------------------------------------------------------------------
/internvl/serve/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from .constants import LOGDIR
10 |
11 | server_error_msg = '**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**'
12 | moderation_msg = 'YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN.'
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
22 | datefmt='%Y-%m-%d %H:%M:%S',
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger('stdout')
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger('stderr')
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True, encoding='UTF-8')
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
99 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = 'https://api.openai.com/v1/moderations'
107 | headers = {'Content-Type': 'application/json',
108 | 'Authorization': 'Bearer ' + os.environ['OPENAI_API_KEY']}
109 | text = text.replace('\n', '')
110 | data = '{' + '"input": ' + f'"{text}"' + '}'
111 | data = data.encode('utf-8')
112 | try:
113 | ret = requests.post(url, headers=headers, data=data, timeout=5)
114 | flagged = ret.json()['results'][0]['flagged']
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return 'None'
126 | return f'Semaphore(value={semaphore._value}, locked={semaphore.locked()})'
127 |
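
A minimal usage sketch for the helpers in utils.py above, assuming the repository root is on PYTHONPATH and that LOGDIR (from internvl.serve.constants) points to a writable directory; the logger name and filename are placeholders:

import asyncio

from internvl.serve.utils import (build_logger, disable_torch_init,
                                  pretty_print_semaphore)

# Skip torch's default Linear/LayerNorm init; pretrained weights overwrite it anyway.
disable_torch_init()

# Creates a named logger, attaches a shared daily-rotating file handler under
# LOGDIR, and redirects stdout/stderr through StreamToLogger.
logger = build_logger('demo_worker', 'demo_worker.log')
logger.info('worker started')

# Same formatting that model_worker.py uses when reporting its concurrency limit.
semaphore = asyncio.Semaphore(5)
logger.info(pretty_print_semaphore(semaphore))
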
--------------------------------------------------------------------------------
/internvl/train/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenGVLab/Mono-InternVL/fcd0381544fc83ac544bade229a70d7a05aa4614/internvl/train/__init__.py
--------------------------------------------------------------------------------
/internvl/train/constants.py:
--------------------------------------------------------------------------------
1 | IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
2 | IMG_START_TOKEN = '<img>'
3 | IMG_END_TOKEN = '</img>'
4 | QUAD_START_TOKEN = '<quad>'
5 | QUAD_END_TOKEN = '</quad>'
6 | REF_START_TOKEN = '['
7 | REF_END_TOKEN = ']'
8 | BOX_START_TOKEN = '<box>'
9 | BOX_END_TOKEN = '</box>'
10 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
11 | IMAGENET_STD = (0.229, 0.224, 0.225)
12 | CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)
13 | CLIP_STD = (0.2686295, 0.2613025, 0.2757711)
14 | SIGLIP_MEAN = (0.5, 0.5, 0.5)
15 | SIGLIP_STD = (0.5, 0.5, 0.5)
16 |
--------------------------------------------------------------------------------
/internvl/train/dataset.py:
--------------------------------------------------------------------------------
1 | import io
2 |
3 | from transformers.trainer_pt_utils import LabelSmoother
4 |
5 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
6 | from typing import Dict
7 |
8 | import torch
9 | import torchvision.transforms as T
10 | import transformers
11 | from internvl.conversation import get_conv_template
12 | from PIL import Image
13 | from torch.utils.data import ConcatDataset, WeightedRandomSampler
14 | from torchvision.transforms.functional import InterpolationMode
15 |
16 | from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD,
17 | IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
18 | SIGLIP_MEAN, SIGLIP_STD)
19 |
20 | import sys
21 |
22 |
23 | class WeightedConcatDataset(ConcatDataset):
24 | def __init__(self, datasets, weights):
25 | super().__init__(datasets)
26 | self.weights = torch.DoubleTensor(weights)
27 | self.total_size = sum(len(d) for d in datasets)
28 | self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True)
29 |
30 | def __iter__(self):
31 | return iter(self.sampler)
32 |
33 | def __len__(self):
34 | return self.total_size
35 |
36 | def expand2square(pil_img, background_color):
37 | width, height = pil_img.size
38 | if width == height:
39 | return pil_img
40 | elif width > height:
41 | result = Image.new(pil_img.mode, (width, width), background_color)
42 | result.paste(pil_img, (0, (width - height) // 2))
43 | return result
44 | else:
45 | result = Image.new(pil_img.mode, (height, height), background_color)
46 | result.paste(pil_img, ((height - width) // 2, 0))
47 | return result
48 |
49 |
50 | def simulate_jpeg_degradation(quality):
51 | def jpeg_degrade(img):
52 | with io.BytesIO() as output:
53 | img.convert('RGB').save(output, format='JPEG', quality=quality)
54 | output.seek(0)
55 | img_jpeg = Image.open(output).copy()
56 | return img_jpeg
57 | return jpeg_degrade
58 |
59 |
60 | # Define the JPEG compression quality range, pre-create all JPEG compression functions
61 | qualities = list(range(75, 101))
62 | jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}
63 |
64 |
65 | def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
66 | if normalize_type == 'imagenet':
67 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
68 | elif normalize_type == 'clip':
69 | MEAN, STD = CLIP_MEAN, CLIP_STD
70 | elif normalize_type == 'siglip':
71 | MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
72 | else:
73 | raise NotImplementedError
74 |     if is_train:  # use data augmentation
75 | transform = T.Compose([
76 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
77 | T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]),
78 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
79 | T.ToTensor(),
80 | T.Normalize(mean=MEAN, std=STD)
81 | ])
82 | else:
83 | if pad2square is False: # now we use this transform function by default
84 | transform = T.Compose([
85 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
86 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
87 | T.ToTensor(),
88 | T.Normalize(mean=MEAN, std=STD)
89 | ])
90 | else:
91 | transform = T.Compose([
92 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
93 | T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))),
94 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
95 | T.ToTensor(),
96 | T.Normalize(mean=MEAN, std=STD)
97 | ])
98 |
99 | return transform
100 |
101 |
102 | def preprocess(
103 | template_name,
104 | sources,
105 | tokenizer: transformers.PreTrainedTokenizer,
106 | num_image_token_list: list,
107 | text_only: bool = False,
108 | group_by_length: bool = False,
109 | use_packed_ds: bool = False,
110 | ds_name: str = None,
111 | num_image: int = 1
112 | ) -> Dict:
113 | conv = get_conv_template(template_name)
114 | roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
115 |
116 | # Apply prompt templates
117 | conversations = []
118 | for i, source in enumerate(sources):
119 | if roles[source[0]['from']] != conv.roles[0]:
120 | # Skip the first one if it is not from human
121 | source = source[1:]
122 |
123 | conv.messages = []
124 | for j, sentence in enumerate(source):
125 | role = roles[sentence['from']]
126 | assert role == conv.roles[j % 2], f'{i}'
127 | conv.append_message(role, sentence['value'])
128 | conversations.append(conv.get_prompt())
129 |
130 | if not text_only:
131 | new_conversations = []
132 | for conversation in conversations:
133 | for i in range(num_image):
134 | image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
135 |                 conversation = conversation.replace('<image>', image_tokens, 1)
136 | new_conversations.append(conversation)
137 | conversations = new_conversations
138 |
139 | # Tokenize conversations
140 | input_ids = tokenizer(
141 | conversations,
142 | return_tensors='pt',
143 | padding=False if group_by_length or use_packed_ds else 'max_length',
144 | max_length=tokenizer.model_max_length,
145 | truncation=True,
146 | ).input_ids
147 | targets = input_ids.clone()
148 |
149 | # Mask targets. Only compute loss on the assistant outputs.
150 | sep = conv.sep + conv.roles[1] + ': '
151 | for conversation, target in zip(conversations, targets):
152 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
153 |
154 | turns = conversation.split(conv.sep2)
155 | cur_len = 1
156 | target[:cur_len] = IGNORE_TOKEN_ID
157 | for i, turn in enumerate(turns):
158 | if turn == '':
159 | break
160 | turn_len = len(tokenizer(turn).input_ids)
161 |
162 | parts = turn.split(sep)
163 | if len(parts) != 2:
164 | break
165 | parts[0] += sep
166 | # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
167 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2
168 |
169 | if i != 0 and not tokenizer.legacy:
170 | # The legacy and non-legacy modes handle special tokens differently
171 | instruction_len -= 1
172 |
173 | # Ignore the user instructions
174 | target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
175 | cur_len += turn_len
176 |
177 | if i != 0 and not tokenizer.legacy:
178 | # The legacy and non-legacy modes handle special tokens differently
179 | cur_len -= 1
180 |
181 | target[cur_len:] = IGNORE_TOKEN_ID
182 |
183 | if False: # Inspect and check the correctness of masking
184 | z = target.clone()
185 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
186 |             print(repr(tokenizer.decode(z)))  # no module-level logger here; print like the other debug branches
187 | exit()
188 |
189 | if cur_len < tokenizer.model_max_length:
190 | if cur_len != total_len:
191 | target[:] = IGNORE_TOKEN_ID
192 | print(
193 | f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
194 | f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
195 | )
196 | sys.stdout.flush()
197 |
198 | return dict(
199 | input_ids=input_ids,
200 | labels=targets,
201 | attention_mask=input_ids.ne(tokenizer.pad_token_id),
202 | )
203 |
204 |
205 | def preprocess_mpt(
206 | template_name,
207 | sources,
208 | tokenizer: transformers.PreTrainedTokenizer,
209 | num_image_token_list: list,
210 | text_only: bool = False,
211 | group_by_length: bool = False,
212 | use_packed_ds: bool = False,
213 | ds_name: str = None,
214 | num_image: int = 1
215 | ) -> Dict:
216 | conv = get_conv_template(template_name)
217 | roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
218 |
219 | # Apply prompt templates
220 | conversations = []
221 | for i, source in enumerate(sources):
222 | if roles[source[0]['from']] != conv.roles[0]:
223 | # Skip the first one if it is not from human
224 | source = source[1:]
225 |
226 | conv.messages = []
227 | for j, sentence in enumerate(source):
228 | role = roles[sentence['from']]
229 | assert role == conv.roles[j % 2], f'{i}'
230 | conv.append_message(role, sentence['value'])
231 | conversations.append(conv.get_prompt())
232 |
233 | if not text_only:
234 | new_conversations = []
235 | for conversation in conversations:
236 | for i in range(num_image):
237 | image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
238 |                 conversation = conversation.replace('<image>', image_tokens, 1)
239 | new_conversations.append(conversation)
240 | conversations = new_conversations
241 |
242 | # Tokenize conversations
243 | input_ids = tokenizer(
244 | conversations,
245 | return_tensors='pt',
246 | padding=False if group_by_length or use_packed_ds else 'max_length',
247 | max_length=tokenizer.model_max_length,
248 | truncation=True,
249 | ).input_ids
250 | targets = input_ids.clone()
251 |
252 | # Mask targets. Only compute loss on the assistant outputs.
253 | sep = conv.sep + conv.roles[1]
254 | for conversation, target in zip(conversations, targets):
255 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
256 |
257 | turns = conversation.split(conv.sep)
258 | re_turns = [conv.sep.join(turns[:3])]
259 | for conv_idx in range(3, len(turns), 2):
260 | re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))
261 | cur_len = 0
262 | target[:cur_len] = IGNORE_TOKEN_ID
263 | for i, turn in enumerate(re_turns):
264 | if turn == '':
265 | break
266 | turn_len = len(tokenizer(turn).input_ids) + 1
267 |
268 | parts = turn.split(sep)
269 | if len(parts) != 2:
270 | break
271 | parts[0] += sep
272 | instruction_len = len(tokenizer(parts[0]).input_ids)
273 |
274 | # Ignore the user instructions
275 | target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
276 |
277 | cur_len += turn_len
278 |
279 | target[cur_len:] = IGNORE_TOKEN_ID
280 |
281 | if cur_len < tokenizer.model_max_length:
282 | if cur_len != total_len:
283 | target[:] = IGNORE_TOKEN_ID
284 | print(
285 | f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
286 | f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
287 | )
288 | sys.stdout.flush()
289 |
290 | return dict(
291 | input_ids=input_ids,
292 | labels=targets,
293 | attention_mask=input_ids.ne(tokenizer.pad_token_id),
294 | )
295 |
296 |
297 | def preprocess_phi3(
298 | template_name,
299 | sources,
300 | tokenizer: transformers.PreTrainedTokenizer,
301 | num_image_token_list: list,
302 | text_only: bool = False,
303 | group_by_length: bool = False,
304 | use_packed_ds: bool = False,
305 | ds_name: str = None,
306 | num_image: int = 1
307 | ) -> Dict:
308 | conv = get_conv_template(template_name)
309 | roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
310 |
311 | # Apply prompt templates
312 | conversations = []
313 | for i, source in enumerate(sources):
314 | if roles[source[0]['from']] != conv.roles[0]:
315 | # Skip the first one if it is not from human
316 | source = source[1:]
317 |
318 | conv.messages = []
319 | for j, sentence in enumerate(source):
320 | role = roles[sentence['from']]
321 | assert role == conv.roles[j % 2], f'{i}'
322 | conv.append_message(role, sentence['value'])
323 | conversations.append(conv.get_prompt())
324 |
325 | if not text_only:
326 | new_conversations = []
327 | for conversation in conversations:
328 | for i in range(num_image):
329 | image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
330 |                 conversation = conversation.replace('<image>', image_tokens, 1)
331 | new_conversations.append(conversation)
332 | conversations = new_conversations
333 |
334 | # Tokenize conversations
335 | tokenizer.padding_side = 'right'
336 | input_ids = tokenizer(
337 | conversations,
338 | return_tensors='pt',
339 | padding=False if group_by_length or use_packed_ds else 'max_length',
340 | max_length=tokenizer.model_max_length,
341 | truncation=True,
342 | ).input_ids
343 | targets = input_ids.clone()
344 |
345 | # Mask targets. Only compute loss on the assistant outputs.
346 | sep = conv.sep + conv.roles[1]
347 | for conversation, target in zip(conversations, targets):
348 | total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())
349 |
350 | turns = conversation.split(conv.sep)
351 | re_turns = [conv.sep.join(turns[:3])]
352 | for conv_idx in range(3, len(turns), 2):
353 | re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2]))
354 | cur_len = 1
355 | target[:cur_len] = IGNORE_TOKEN_ID
356 | endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
357 | target[target == endoftext_id] = IGNORE_TOKEN_ID
358 |
359 | for i, turn in enumerate(re_turns):
360 | if turn == '':
361 | break
362 | if i == 0:
363 | turn_len = len(tokenizer(turn).input_ids)
364 | else:
365 | turn_len = len(tokenizer(turn).input_ids) - 1
366 | parts = turn.split(sep)
367 | if len(parts) != 2:
368 | break
369 | parts[0] += sep
370 |
371 | if i == 0:
372 | instruction_len = len(tokenizer(parts[0]).input_ids) - 1
373 | else:
374 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2
375 |
376 | # Ignore the user instructions
377 | target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
378 |
379 | cur_len += turn_len
380 |
381 | target[cur_len:] = IGNORE_TOKEN_ID
382 |
383 | if False: # Inspect and check the correctness of masking
384 | z = target.clone()
385 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
386 | print(repr(tokenizer.decode(z)))
387 |
388 | if cur_len < tokenizer.model_max_length:
389 | if cur_len != total_len:
390 | target[:] = IGNORE_TOKEN_ID
391 | print(
392 | f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
393 | f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
394 | )
395 | sys.stdout.flush()
396 |
397 | return dict(
398 | input_ids=input_ids,
399 | labels=targets,
400 | attention_mask=input_ids.ne(tokenizer.pad_token_id),
401 | )
402 |
403 |
404 | def preprocess_internlm(
405 | template_name,
406 | sources,
407 | tokenizer: transformers.PreTrainedTokenizer,
408 | num_image_token_list: list,
409 | text_only: bool = False,
410 | group_by_length: bool = False,
411 | use_packed_ds: bool = False,
412 | ds_name: str = None,
413 | num_image: int = 1
414 | ) -> Dict:
415 | conv = get_conv_template(template_name)
416 | roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
417 |
418 | # Apply prompt templates
419 | conversations = []
420 | for i, source in enumerate(sources):
421 | if roles[source[0]['from']] != conv.roles[0]:
422 | # Skip the first one if it is not from human
423 | source = source[1:]
424 |
425 | conv.messages = []
426 | for j, sentence in enumerate(source):
427 | role = roles[sentence['from']]
428 | assert role == conv.roles[j % 2], f'{i}'
429 | sentence['value'] = sentence['value'].strip()
430 | conv.append_message(role, sentence['value'])
431 | conversations.append(conv.get_prompt())
432 |
433 | if not text_only:
434 | new_conversations = []
435 | for conversation in conversations:
436 | for i in range(num_image):
437 | image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
438 |                 conversation = conversation.replace('<image>', image_tokens, 1)
439 | new_conversations.append(conversation)
440 | conversations = new_conversations
441 |
442 | # Tokenize conversations
443 | input_ids = tokenizer(
444 | conversations,
445 | return_tensors='pt',
446 | padding=False if group_by_length or use_packed_ds else 'max_length',
447 | max_length=tokenizer.model_max_length,
448 | truncation=True,
449 | ).input_ids
450 | targets = input_ids.clone()
451 |
452 | for conversation, target in zip(conversations, targets):
453 | total_len = int(target.ne(tokenizer.pad_token_id).sum())
454 | cur_len = 1
455 | target[:cur_len] = IGNORE_TOKEN_ID
456 | parts = conversation.split(conv.roles[1])
457 | info = parts[0] + conv.roles[1]
458 | temp_len = len(tokenizer(info).input_ids) - 1
459 | target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
460 | cur_len = cur_len + temp_len
461 |
462 | for index in range(1, len(parts) - 1):
463 | info = parts[index]
464 | part1, part2 = info.split(conv.roles[0])
465 | temp_len = len(tokenizer(part1).input_ids) - 1
466 | cur_len = cur_len + temp_len
467 | part = conv.roles[0] + part2 + conv.roles[1]
468 | temp_len = len(tokenizer(part).input_ids) - 1
469 | target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
470 | cur_len = cur_len + temp_len
471 | last_info = parts[-1]
472 | temp_len = len(tokenizer(last_info).input_ids) - 1
473 | cur_len = cur_len + temp_len
474 |
475 | target[cur_len:] = IGNORE_TOKEN_ID
476 | if False: # Inspect and check the correctness of masking
477 | z = target.clone()
478 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
479 | print(repr(tokenizer.decode(z)))
480 |
481 | if cur_len < tokenizer.model_max_length:
482 | if cur_len != total_len:
483 | target[:] = IGNORE_TOKEN_ID
484 | print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
485 | sys.stdout.flush()
486 |
487 | return dict(
488 | input_ids=input_ids,
489 | labels=targets,
490 | attention_mask=input_ids.ne(tokenizer.pad_token_id),
491 | )
492 |
493 |
494 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
495 | best_ratio_diff = float('inf')
496 | best_ratio = (1, 1)
497 | area = width * height
498 | for ratio in target_ratios:
499 | target_aspect_ratio = ratio[0] / ratio[1]
500 | ratio_diff = abs(aspect_ratio - target_aspect_ratio)
501 | if ratio_diff < best_ratio_diff:
502 | best_ratio_diff = ratio_diff
503 | best_ratio = ratio
504 | elif ratio_diff == best_ratio_diff:
505 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
506 | best_ratio = ratio
507 | return best_ratio
508 |
509 |
510 | def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
511 | orig_width, orig_height = image.size
512 | aspect_ratio = orig_width / orig_height
513 |
514 |     # enumerate the candidate target aspect ratios (i x j tile grids)
515 | target_ratios = set(
516 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
517 | i * j <= max_num and i * j >= min_num)
518 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
519 |
520 | # find the closest aspect ratio to the target
521 | target_aspect_ratio = find_closest_aspect_ratio(
522 | aspect_ratio, target_ratios, orig_width, orig_height, image_size)
523 |
524 | # calculate the target width and height
525 | target_width = image_size * target_aspect_ratio[0]
526 | target_height = image_size * target_aspect_ratio[1]
527 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
528 |
529 | # resize the image
530 | resized_img = image.resize((target_width, target_height))
531 | processed_images = []
532 | for i in range(blocks):
533 | box = (
534 | (i % (target_width // image_size)) * image_size,
535 | (i // (target_width // image_size)) * image_size,
536 | ((i % (target_width // image_size)) + 1) * image_size,
537 | ((i // (target_width // image_size)) + 1) * image_size
538 | )
539 | # split the image
540 | split_img = resized_img.crop(box)
541 | processed_images.append(split_img)
542 | assert len(processed_images) == blocks
543 | if use_thumbnail and len(processed_images) != 1:
544 | thumbnail_img = image.resize((image_size, image_size))
545 | processed_images.append(thumbnail_img)
546 | return processed_images
547 |
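
An illustrative sketch of the image pipeline defined in dataset.py above: dynamic_preprocess tiles an image into at most max_num 448x448 crops (plus an optional thumbnail), and build_transform yields the tensor transform applied to each tile. The file name test.jpg is a placeholder:

import torch
from PIL import Image

from internvl.train.dataset import build_transform, dynamic_preprocess

image = Image.open('test.jpg').convert('RGB')  # placeholder path

# Split the image into tiles whose grid best matches its aspect ratio,
# appending a global thumbnail when more than one tile is produced.
tiles = dynamic_preprocess(image, min_num=1, max_num=6,
                           image_size=448, use_thumbnail=True)

# Evaluation-time transform: RGB convert, resize to 448, ImageNet normalization.
transform = build_transform(is_train=False, input_size=448,
                            pad2square=False, normalize_type='imagenet')
pixel_values = torch.stack([transform(tile) for tile in tiles])
print(pixel_values.shape)  # (num_tiles [+ thumbnail], 3, 448, 448)
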
--------------------------------------------------------------------------------
/internvl/train/trainer_monkey_patch.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import torch
5 | import torch.nn as nn
6 | import transformers
7 | from transformers import Trainer, logging
8 | from transformers.trainer import is_sagemaker_mp_enabled
9 |
10 | logger = logging.get_logger(__name__)
11 |
12 |
13 | def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
14 | if var_name.startswith('internvl.'):
15 | var_name = var_name[len('internvl.'):]
16 | if var_name in ('query_tokens', 'logit_scale',):
17 | return 0
18 | if var_name.startswith('clip_projector.'):
19 | return vit_num_max_layer
20 | if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \
21 | var_name == 'text_projection':
22 | return llama_num_max_layer
23 | if var_name.startswith('vision_model.'):
24 | if 'embeddings.' in var_name:
25 | return 0
26 | if 'layers.' in var_name:
27 | var_name = var_name.split('layers.')[-1]
28 | layer_id = int(var_name.split('.')[0])
29 | return layer_id + 1
30 | if var_name.startswith('qllama.'):
31 | if 'embed_tokens' in var_name:
32 | return 0
33 | if 'layers.' in var_name:
34 | var_name = var_name.split('layers.')[-1]
35 | layer_id = int(var_name.split('.')[0])
36 | return layer_id + 1
37 | else:
38 | return llama_num_max_layer
39 | return 0
40 |
41 |
42 | def param_classification(name):
43 | if name.startswith('internvl.'):
44 | name = name[len('internvl.'):]
45 | if name in ['query_tokens', 'text_projection', 'logit_scale']:
46 | return 'qllama'
47 | elif name.startswith('vision_model.'):
48 | return 'vit'
49 | elif name.startswith('qllama.'):
50 | return 'qllama'
51 | elif name.startswith('clip_projector.'):
52 | return 'vit'
53 | elif name.startswith('clip_projector2.'):
54 | return 'qllama'
55 | elif name.startswith('itm_head.'):
56 | return 'qllama'
57 | else:
58 | return 'other'
59 |
60 |
61 | def create_optimizer(self):
62 | """
63 | Setup the optimizer.
64 |
65 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
66 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
67 | """
68 | opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
69 |
70 | parameter_groups = {}
71 | try: # for stage2 model
72 | vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
73 | qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
74 | except: # for stage3 model
75 | vit_num_layers = opt_model.internvl.config.vision_config.num_hidden_layers + 2
76 | qllama_num_layers = opt_model.internvl.config.qllama_config.num_hidden_layers + 2
77 | print('vit_num_layers:', vit_num_layers)
78 | print('qllama_num_layers:', qllama_num_layers)
79 |
80 | vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
81 | qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
82 | qllama_lr_scale = float(os.getenv('QLLAMA_LR_SCALE', 1.0))
83 | print('vit_layer_decay_rate:', vit_layer_decay_rate)
84 | print('qllama_layer_decay_rate:', qllama_layer_decay_rate)
85 | print('qllama_lr_scale:', qllama_lr_scale)
86 |
87 | for name, param in opt_model.named_parameters():
88 | if not param.requires_grad:
89 | continue # frozen weights
90 | if len(param.shape) == 1 or name.endswith('.bias'):
91 | group_name = 'no_decay'
92 | this_weight_decay = 0.
93 | else:
94 | group_name = 'decay'
95 | this_weight_decay = self.args.weight_decay
96 |
97 | cls = param_classification(name)
98 | layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
99 | group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
100 | if group_name not in parameter_groups:
101 | if cls == 'vit':
102 | scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
103 | elif cls == 'qllama':
104 | scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
105 | scale = scale * qllama_lr_scale
106 | else:
107 | scale = 1.0
108 | scale = min(1.0, scale)
109 | parameter_groups[group_name] = {
110 | 'weight_decay': this_weight_decay,
111 | 'params': [],
112 | 'param_names': [],
113 | 'lr_scale': scale,
114 | 'group_name': group_name,
115 | 'lr': scale * self.args.learning_rate,
116 | }
117 | parameter_groups[group_name]['params'].append(param)
118 | parameter_groups[group_name]['param_names'].append(name)
119 |
120 | rank = torch.distributed.get_rank()
121 | if rank == 0:
122 | to_display = {}
123 | for key in parameter_groups:
124 | to_display[key] = {
125 | 'param_names': parameter_groups[key]['param_names'],
126 | 'lr_scale': parameter_groups[key]['lr_scale'],
127 | 'lr': parameter_groups[key]['lr'],
128 | 'weight_decay': parameter_groups[key]['weight_decay'],
129 | }
130 | print('Param groups = %s' % json.dumps(to_display, indent=2))
131 |
132 | optimizer_grouped_parameters = list(parameter_groups.values())
133 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
134 |
135 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
136 | if optimizer_cls.__name__ == 'Adam8bit':
137 | import bitsandbytes
138 |
139 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
140 |
141 | skipped = 0
142 | for module in opt_model.modules():
143 | if isinstance(module, nn.Embedding):
144 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
145 | logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
146 | manager.register_module_override(module, 'weight', {'optim_bits': 32})
147 | logger.debug(f'bitsandbytes: will optimize {module} in fp32')
148 | logger.info(f'skipped: {skipped / 2 ** 20}M params')
149 |
150 | if is_sagemaker_mp_enabled():
151 | import smdistributed.modelparallel.torch as smp
152 | self.optimizer = smp.DistributedOptimizer(self.optimizer)
153 |
154 | return self.optimizer
155 |
156 |
157 | def replace_create_optimizer():
158 | print('Replace original create_optimizer with custom create_optimizer')
159 | transformers.Trainer.create_optimizer = create_optimizer
160 |
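
A sketch of how the monkey patch above could be applied before training; this is hypothetical driver code, not taken from the repo. The layer-wise decay rates are read from environment variables (all default to 1.0), and the patched create_optimizer assumes torch.distributed has been initialized, since it calls torch.distributed.get_rank():

import os

from internvl.train.trainer_monkey_patch import replace_create_optimizer

# Illustrative values; by default no layer-wise decay or extra scaling is applied.
os.environ['VIT_LAYER_DECAY_RATE'] = '0.9'     # per-layer LR decay for ViT parameters
os.environ['QLLAMA_LAYER_DECAY_RATE'] = '0.9'  # per-layer LR decay for QLLaMA parameters
os.environ['QLLAMA_LR_SCALE'] = '0.1'          # additional LR scale for QLLaMA groups

# Must run before transformers.Trainer builds its optimizer.
replace_create_optimizer()
# ... then construct the Trainer and call trainer.train() as usual.
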
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate<1
2 | bitsandbytes==0.42.0
3 | datasets
4 | decord
5 | deepspeed==0.13.5
6 | einops==0.6.1
7 | einops-exts==0.0.4
8 | huggingface_hub
9 | imageio
10 | numpy==1.26.4
11 | opencv-python
12 | orjson
13 | peft==0.10.0
14 | pycocoevalcap
15 | pyyaml
16 | scikit-learn>=1.2.2
17 | scipy
18 | sentencepiece==0.1.99
19 | shortuuid
20 | tensorboardX
21 | termcolor
22 | timm==0.9.12
23 | tokenizers==0.15.1
24 | torch==2.1.0
25 | torchvision==0.16.0
26 | tqdm
27 | transformers==4.37.2
28 | yacs
29 |
--------------------------------------------------------------------------------
/shell/data_llava_finetune.json:
--------------------------------------------------------------------------------
1 | {
2 | "llava_665k": {
3 | "root": "playground/data/",
4 | "annotation": "playground/llava_v1_5_mix665k.jsonl",
5 | "data_augment": false,
6 | "repeat_time": 1,
7 | "length": 665298
8 | }
9 | }
10 |
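
The meta file above maps each dataset name to its image root, annotation file, and sampling options; the launch scripts below pass it to the training entry point via --meta_path. A small sketch of reading it from the repository root (the real consumption logic lives in the training code and is not reproduced here):

import json

with open('shell/data_llava_finetune.json') as f:
    meta = json.load(f)

for name, cfg in meta.items():
    print(f"{name}: root={cfg['root']} annotation={cfg['annotation']} "
          f"repeat_time={cfg['repeat_time']} length={cfg['length']}")
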
--------------------------------------------------------------------------------
/shell/mono_internvl_finetune_llava_slurm.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | PARTITION=${PARTITION:-"Your partition"}
4 | GPUS=${GPUS:-8}
5 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
6 | QUOTA_TYPE=${QUOTA_TYPE:-"spot"}
7 | NODES=$((GPUS / GPUS_PER_NODE))
8 | CPUS_PER_TASK=${CPUS_PER_TASK:-12}
9 | SRUN_ARGS=${SRUN_ARGS:-""}
10 | BATCH_SIZE=${BATCH_SIZE:-128}
11 | PER_DEVICE_BATCH_SIZE=${PER_DEVICE_BATCH_SIZE:-4}
12 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS))
13 |
14 | MODEL=${MODEL:-"Path to your model"}
15 | OUTPUT_DIR=${OUTPUT_DIR:-"Path to your output directory"}
16 | mkdir -p "$OUTPUT_DIR"
17 |
18 | export PYTHONPATH="${PYTHONPATH}:$(pwd)"
19 |
20 | srun -p ${PARTITION} \
21 | --gres=gpu:${GPUS_PER_NODE} \
22 | --nodes=${NODES} \
23 | --ntasks=${GPUS} \
24 | --ntasks-per-node=${GPUS_PER_NODE} \
25 | --cpus-per-task=${CPUS_PER_TASK} \
26 | --kill-on-bad-exit=1 \
27 | --quotatype=${QUOTA_TYPE} \
28 | ${SRUN_ARGS} \
29 | python -u internvl/train/internvl_chat_finetune.py \
30 | --model_name_or_path ${MODEL} \
31 | --vision_type patch \
32 | --conv_style "internlm2-chat" \
33 | --output_dir ${OUTPUT_DIR} \
34 | --meta_path "./shell/data_llava_finetune.json" \
35 | --overwrite_output_dir True \
36 | --force_image_size 448 \
37 | --max_dynamic_patch 6 \
38 | --down_sample_ratio 0.5 \
39 | --drop_path_rate 0.1 \
40 | --pad2square False \
41 | --freeze_llm False \
42 | --freeze_mlp False \
43 | --freeze_backbone False \
44 | --unfreeze_ve True \
45 | --vision_select_layer -1 \
46 | --use_data_resampling False \
47 | --dataloader_num_workers 4 \
48 | --bf16 True \
49 | --num_train_epochs 1 \
50 | --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \
51 | --gradient_accumulation_steps ${GRADIENT_ACC} \
52 | --evaluation_strategy "no" \
53 | --save_strategy "steps" \
54 | --save_steps 3000 \
55 | --save_total_limit 3 \
56 | --learning_rate 4e-5 \
57 | --weight_decay 0.01 \
58 | --warmup_ratio 0.03 \
59 | --lr_scheduler_type "cosine" \
60 | --logging_steps 1 \
61 | --max_seq_length 2048 \
62 | --do_train True \
63 | --grad_checkpoint True \
64 | --group_by_length True \
65 | --dynamic_image_size True \
66 | --use_thumbnail True \
67 | --ps_version 'v2' \
68 | --deepspeed "./shell/zero_stage1_config.json" \
69 | 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt"
70 |
--------------------------------------------------------------------------------
/shell/mono_internvl_finetune_llava_torchrun.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | GPUS=8
4 | PER_DEVICE_BATCH_SIZE=4
5 | BATCH_SIZE=128
6 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS))
7 |
8 | MODEL=${MODEL:-"Path to your model"}
9 | OUTPUT_DIR=${OUTPUT_DIR:-"Path to your output directory"}
10 | mkdir -p "$OUTPUT_DIR"
11 |
12 | export PYTHONPATH="${PYTHONPATH}:$(pwd)"
13 |
14 | torchrun --nproc_per_node=$GPUS --master_port=29501 \
15 | internvl/train/internvl_chat_finetune.py \
16 | --model_name_or_path ${MODEL} \
17 | --vision_type patch \
18 | --conv_style "internlm2-chat" \
19 | --output_dir ${OUTPUT_DIR} \
20 | --meta_path "./shell/data_llava_finetune.json" \
21 | --overwrite_output_dir True \
22 | --force_image_size 448 \
23 | --max_dynamic_patch 6 \
24 | --down_sample_ratio 0.5 \
25 | --drop_path_rate 0.1 \
26 | --pad2square False \
27 | --freeze_llm False \
28 | --freeze_mlp False \
29 | --freeze_backbone False \
30 | --unfreeze_ve True \
31 | --vision_select_layer -1 \
32 | --use_data_resampling False \
33 | --dataloader_num_workers 4 \
34 | --bf16 True \
35 | --num_train_epochs 1 \
36 | --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \
37 | --gradient_accumulation_steps ${GRADIENT_ACC} \
38 | --evaluation_strategy "no" \
39 | --save_strategy "steps" \
40 | --save_steps 3000 \
41 | --save_total_limit 3 \
42 | --learning_rate 4e-5 \
43 | --weight_decay 0.01 \
44 | --warmup_ratio 0.03 \
45 | --lr_scheduler_type "cosine" \
46 | --logging_steps 1 \
47 | --max_seq_length 2048 \
48 | --do_train True \
49 | --grad_checkpoint True \
50 | --group_by_length True \
51 | --dynamic_image_size True \
52 | --use_thumbnail True \
53 | --ps_version 'v2' \
54 | --deepspeed "./shell/zero_stage1_config.json" \
55 | 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt"
56 |
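
Both launch scripts derive GRADIENT_ACC so that the effective global batch size stays at BATCH_SIZE; with the defaults above (8 GPUs, per-device batch 4, global batch 128) that gives 4 accumulation steps. A quick check of the arithmetic:

GPUS = 8
PER_DEVICE_BATCH_SIZE = 4
BATCH_SIZE = 128

GRADIENT_ACC = BATCH_SIZE // (PER_DEVICE_BATCH_SIZE * GPUS)  # -> 4
assert PER_DEVICE_BATCH_SIZE * GPUS * GRADIENT_ACC == BATCH_SIZE
print(GRADIENT_ACC)
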
--------------------------------------------------------------------------------
/shell/zero_stage1_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 1,
4 | "allgather_partitions": true,
5 | "allgather_bucket_size": 1e9,
6 | "overlap_comm": true,
7 | "reduce_scatter": true,
8 | "reduce_bucket_size": 1e9,
9 | "contiguous_gradients": true
10 | },
11 | "fp16": {
12 | "enabled": "auto",
13 | "auto_cast": true,
14 | "loss_scale": 0,
15 | "initial_scale_power": 32,
16 | "loss_scale_window": 1000,
17 | "hysteresis": 2,
18 | "min_loss_scale": 1
19 | },
20 | "bf16": {
21 | "enabled": "auto"
22 | },
23 | "optimizer": {
24 | "type": "AdamW",
25 | "params": {
26 | "lr": "auto",
27 | "betas": [
28 | 0.9,
29 | 0.999
30 | ],
31 | "eps": 1e-8,
32 | "weight_decay": "auto"
33 | }
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 2000,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": true
41 | }
42 |
--------------------------------------------------------------------------------