├── .gitignore ├── LICENSE ├── README.md ├── examples └── image1.jpg ├── images ├── fig1.jpg └── fig2.jpg ├── internvl ├── conversation.py ├── dist_utils.py ├── model │ ├── internlm2 │ │ ├── configuration_internlm2.py │ │ ├── modeling_internlm2.py │ │ ├── modeling_internlm2_ve.py │ │ ├── tokenization_internlm2.py │ │ └── tokenization_internlm2_fast.py │ └── internvl_chat │ │ ├── __init__.py │ │ ├── configuration_intern_vit.py │ │ ├── configuration_internvl_chat.py │ │ ├── flash_attention.py │ │ ├── modeling_intern_vit.py │ │ └── modeling_internvl_chat.py ├── patch │ ├── __init__.py │ ├── internlm2_packed_training_patch.py │ ├── llama2_flash_attn_monkey_patch.py │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_rmsnorm_monkey_patch.py │ ├── pad_data_collator.py │ ├── qwen2_packed_training_patch.py │ ├── train_dataloader_patch.py │ └── train_sampler_patch.py ├── serve │ ├── __init__.py │ ├── constants.py │ ├── mm_utils.py │ ├── model_worker.py │ └── utils.py └── train │ ├── __init__.py │ ├── constants.py │ ├── dataset.py │ ├── dataset_packed.py │ ├── internvl_chat_finetune.py │ └── trainer_monkey_patch.py ├── requirements.txt └── shell ├── data_llava_finetune.json ├── mono_internvl_finetune_llava_slurm.sh ├── mono_internvl_finetune_llava_torchrun.sh └── zero_stage1_config.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | .idea/ 163 | 164 | .DS_Store 165 | 166 | 167 | /playground/data/ 168 | /playground/*.jsonl -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 OpenGVLab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mono-InternVL: Pushing the Boundaries of Monolithic Multimodal Large Language Models with Endogenous Visual Pre-training 2 | 3 | [[📜 Paper]](https://arxiv.org/abs/2410.08202) [[⭐️Project Page]](https://internvl.github.io/blog/2024-10-10-Mono-InternVL/) [[🤗 Model]](https://huggingface.co/collections/OpenGVLab/mono-internvl-6707cb402afb22f1e29f4d2b) [[📝 Chinese Post]](https://mp.weixin.qq.com/s/FmjG0Gp5ow7mm2Vzd9ppPg) 4 | 5 | 6 | 7 |

8 | <a id="radar"></a>![radar chart](images/fig1.jpg)
9 |
10 |
11 | ![architecture](images/fig2.jpg)
12 |

13 |
14 | ## 📰 News
15 | - **2025.3**: We release the SFT code on the LLaVA-v1.5-mix665k dataset. We also release the [258M synthetic data](https://huggingface.co/datasets/OpenGVLab/Mono-InternVL-2B-Synthetic-Data) used in S1.2 to boost future research.
16 | - **2025.2**: 🎉🎉 Mono-InternVL is accepted by **CVPR 2025**. Also check out our [**SynerGen-VL**](https://huggingface.co/papers/2412.09604) (CVPR 2025), which extends the monolithic structure to unified image generation and multimodal understanding and will be open-sourced soon.
17 | - **2024.11**: Mono-InternVL is supported by [LMDeploy](https://github.com/InternLM/lmdeploy/pull/2727).
18 | - **2024.11**: Mono-InternVL is supported by [vLLM](https://github.com/vllm-project/vllm/pull/9528).
19 |
20 |
21 | ## ⭐️ Introduction
22 |
23 | We release Mono-InternVL, a **monolithic** multimodal large language model (MLLM) that integrates visual encoding and textual decoding into a single LLM. In Mono-InternVL, a set of visual experts is embedded into the pre-trained LLM via a **mixture-of-experts (MoE) mechanism**. By freezing the LLM, Mono-InternVL ensures that visual capabilities are optimized without compromising the pre-trained language knowledge. Based on this structure, an innovative **Endogenous Visual Pre-training (EViP)** is introduced to realize coarse-to-fine visual learning.
24 |
25 |
26 | Mono-InternVL achieves superior performance compared to the state-of-the-art MLLM Mini-InternVL-2B-1.5 and significantly outperforms other monolithic MLLMs, as shown in the [radar chart](#radar) above. Meanwhile, it achieves better deployment efficiency, with first-token latency reduced by up to 67%.
27 |
28 |
29 | For more details, please refer to our [paper](https://arxiv.org/abs/2410.08202).
30 |
31 |
32 | ## 📊 Performance
33 | | Benchmark | Chameleon-7B | EVE-7B (HD) | Emu3 | Mini-InternVL-2B-1.5 | Mono-InternVL-2B |
34 | | :--------------------------: | :----------: | :---------: | :--------: | :------------------: | :--------------: |
35 | | Type | Monolithic | Monolithic | Monolithic | Modular | Monolithic |
36 | | #Activated Params | 7B | 7B | 8B | 2.2B | 1.8B |
37 | | | | | | | |
38 | | MMVet | 8.3 | 25.7 | 37.2 | 39.3 | 40.1 |
39 | | MMMU<sub>val</sub> | 25.4 | 32.6 | 31.6 | 34.6 | 33.7 |
40 | | MME<sub>sum</sub> | 170 | 1628 | — | 1902 | 1875 |
41 | | MMBench-EN<sub>test</sub> | 31.1 | 52.3 | 58.5 | 70.9 | 65.5 |
42 | | MathVista<sub>testmini</sub> | 22.3 | 34.2 | — | 41.1 | 45.7 |
43 | | SEED-Image | 30.6 | 64.6 | 68.2 | 69.8 | 67.4 |
44 | | OCRBench | 7 | 398 | 687 | 654 | 767 |
45 | | Hallusion-Bench | 17.1 | 26.4 | — | 37.5 | 34.8 |
46 | | CCBench<sub>dev</sub> | 3.5 | 16.3 | — | 63.5 | 66.3 |
47 | | Avg<sub>multimodal</sub> | 16.1 | 38.9 | — | 54.4 | 55.2 |
48 | | | | | | | |
49 | | TextVQA<sub>val</sub> | 4.8 | 56.8 | 64.7 | 70.5 | 72.6 |
50 | | SQA-I<sub>test</sub> | 47.2 | 64.9 | 89.2 | 84.9 | 93.6 |
51 | | GQA<sub>test</sub> | — | 62.6 | 60.3 | 61.6 | 59.5 |
52 | | DocVQA<sub>test</sub> | 1.5 | 53.0 | 76.3 | 85.0 | 80.0 |
53 | | AI2D<sub>test</sub> | 46.0 | 61.0 | 70.0 | 69.8 | 68.6 |
54 | | ChartQA<sub>test</sub> | 2.9 | 59.1 | 68.6 | 74.8 | 73.7 |
55 | | InfoVQA<sub>test</sub> | 5.0 | 25.0 | 43.8 | 55.4 | 43.0 |
56 | | Avg<sub>VQA</sub> | 17.9 | 54.6 | 67.6 | 71.7 | 70.1 |
57 |
58 | > * Sources of the results include the original papers, our evaluation with [VLMEvalKit](https://github.com/open-compass/VLMEvalKit), and [OpenCompass](https://rank.opencompass.org.cn/leaderboard-multimodal/?m=REALTIME).
59 | > * Average scores are computed by normalizing each metric to a range between 0 and 100.
60 | > * Please note that evaluating the same model with different testing toolkits can result in slight differences, which is normal. Updates to code versions and variations in environment and hardware can also cause minor discrepancies in results.
61 |
62 |
63 | ## 🚀 Inference
64 |
65 | We provide example code below to run Mono-InternVL-2B inference with `transformers`.
66 |
67 | > Please use `transformers==4.37.2` to ensure the model works as expected.
68 |
69 |
70 | Inference with Transformers (click to expand) 71 | 72 | ```python 73 | import numpy as np 74 | import torch 75 | import torchvision.transforms as T 76 | from decord import VideoReader, cpu 77 | from PIL import Image 78 | from torchvision.transforms.functional import InterpolationMode 79 | from transformers import AutoModel, AutoTokenizer 80 | 81 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 82 | IMAGENET_STD = (0.229, 0.224, 0.225) 83 | 84 | def build_transform(input_size): 85 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD 86 | transform = T.Compose([ 87 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 88 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 89 | T.ToTensor(), 90 | T.Normalize(mean=MEAN, std=STD) 91 | ]) 92 | return transform 93 | 94 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 95 | best_ratio_diff = float('inf') 96 | best_ratio = (1, 1) 97 | area = width * height 98 | for ratio in target_ratios: 99 | target_aspect_ratio = ratio[0] / ratio[1] 100 | ratio_diff = abs(aspect_ratio - target_aspect_ratio) 101 | if ratio_diff < best_ratio_diff: 102 | best_ratio_diff = ratio_diff 103 | best_ratio = ratio 104 | elif ratio_diff == best_ratio_diff: 105 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 106 | best_ratio = ratio 107 | return best_ratio 108 | 109 | def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): 110 | orig_width, orig_height = image.size 111 | aspect_ratio = orig_width / orig_height 112 | 113 | # calculate the existing image aspect ratio 114 | target_ratios = set( 115 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 116 | i * j <= max_num and i * j >= min_num) 117 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 118 | 119 | # find the closest aspect ratio to the target 120 | target_aspect_ratio = find_closest_aspect_ratio( 121 | aspect_ratio, target_ratios, orig_width, orig_height, image_size) 122 | 123 | # calculate the target width and height 124 | target_width = image_size * target_aspect_ratio[0] 125 | target_height = image_size * target_aspect_ratio[1] 126 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 127 | 128 | # resize the image 129 | resized_img = image.resize((target_width, target_height)) 130 | processed_images = [] 131 | for i in range(blocks): 132 | box = ( 133 | (i % (target_width // image_size)) * image_size, 134 | (i // (target_width // image_size)) * image_size, 135 | ((i % (target_width // image_size)) + 1) * image_size, 136 | ((i // (target_width // image_size)) + 1) * image_size 137 | ) 138 | # split the image 139 | split_img = resized_img.crop(box) 140 | processed_images.append(split_img) 141 | assert len(processed_images) == blocks 142 | if use_thumbnail and len(processed_images) != 1: 143 | thumbnail_img = image.resize((image_size, image_size)) 144 | processed_images.append(thumbnail_img) 145 | return processed_images 146 | 147 | def load_image(image_file, input_size=448, max_num=12): 148 | image = Image.open(image_file).convert('RGB') 149 | transform = build_transform(input_size=input_size) 150 | images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) 151 | pixel_values = [transform(image) for image in images] 152 | pixel_values = torch.stack(pixel_values) 153 | return pixel_values 154 | 155 | 156 | path = 'OpenGVLab/Mono-InternVL-2B' 157 | model = AutoModel.from_pretrained( 158 | 
    path,
159 |     torch_dtype=torch.bfloat16,
160 |     low_cpu_mem_usage=True,
161 |     trust_remote_code=True).eval().cuda()
162 | tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
163 |
164 | # set the max number of tiles in `max_num`
165 | pixel_values = load_image('./examples/image1.jpg', max_num=12).to(torch.bfloat16).cuda()
166 | generation_config = dict(max_new_tokens=1024, do_sample=True)
167 |
168 | # pure-text conversation
169 | question = 'Hello, who are you?'
170 | response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
171 | print(f'User: {question}\nAssistant: {response}')
172 |
173 | question = 'Can you tell me a story?'
174 | response, history = model.chat(tokenizer, None, question, generation_config, history=history, return_history=True)
175 | print(f'User: {question}\nAssistant: {response}')
176 |
177 | # single-image single-round conversation
178 | question = '<image>\nPlease describe the image shortly.'
179 | response = model.chat(tokenizer, pixel_values, question, generation_config)
180 | print(f'User: {question}\nAssistant: {response}')
181 |
182 | # single-image multi-round conversation
183 | question = '<image>\nPlease describe the image in detail.'
184 | response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
185 | print(f'User: {question}\nAssistant: {response}')
186 |
187 | question = 'Please write a poem according to the image.'
188 | response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
189 | print(f'User: {question}\nAssistant: {response}')
190 | ```
191 |
192 |
193 | 194 | 195 |
196 | Inference with LMDeploy 197 | 198 | Please install lmdeploy>=0.6.3 for Mono-InternVL support. 199 | 200 | ```python 201 | from lmdeploy import pipeline 202 | from lmdeploy.vl import load_image 203 | 204 | image = load_image('./examples/image1.jpg') 205 | pipe = pipeline('OpenGVLab/Mono-InternVL-2B') 206 | response = pipe(('Please describe the image shortly.', image)) 207 | print(response.text) 208 | ``` 209 |
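
The pipeline also accepts a list of (prompt, image) pairs for batched inference. Below is a minimal sketch; the duplicated example image and the `GenerationConfig` values are placeholders, and the available generation options may differ across lmdeploy versions.

```python
from lmdeploy import GenerationConfig, pipeline
from lmdeploy.vl import load_image

pipe = pipeline('OpenGVLab/Mono-InternVL-2B')

# placeholder batch: the bundled example image is reused twice; substitute your own files
images = [load_image('./examples/image1.jpg'), load_image('./examples/image1.jpg')]
prompts = [('Please describe the image shortly.', img) for img in images]

# optional sampling settings; omit gen_config to use the defaults
gen_config = GenerationConfig(max_new_tokens=512, top_p=0.9)
responses = pipe(prompts, gen_config=gen_config)
for res in responses:
    print(res.text)
```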
210 | 211 | ## 🔥 Supervised Finetuning 212 | 213 | Currently we provide the supervised finetuning (S2 instruction tuning) code on the LLaVA-v1.5-mix665k dataset. For details on the dataset, please refer to [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA). 214 | 215 |
216 | Installation
217 |
218 | - Clone this repository:
219 |
220 | ```bash
221 | git clone https://github.com/OpenGVLab/Mono-InternVL.git
222 | ```
223 |
224 | - Create a conda virtual environment and activate it:
225 |
226 | ```bash
227 | conda create -n monointernvl python=3.9 -y
228 | conda activate monointernvl
229 | ```
230 |
231 | - Install dependencies using `requirements.txt`:
232 |
233 | ```bash
234 | pip install -r requirements.txt
235 | ```
236 |
237 | - Additionally, install `flash-attn==2.5.6`:
238 |
239 | ```bash
240 | pip install flash-attn==2.5.6 --no-build-isolation
241 | ```
242 |
243 |   Alternatively, you can compile it from source:
244 |
245 | ```bash
246 | git clone https://github.com/Dao-AILab/flash-attention.git
247 | cd flash-attention
248 | git checkout v2.5.6
249 | python setup.py install
250 | ```
251 |
252 | 253 |
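
Before moving on to data preparation, it can help to confirm that the environment matches the versions mentioned above. The sketch below only touches packages already named in this README (`torch`, `transformers`, `flash-attn`); the tiny attention call is just a smoke test and not part of the training pipeline.

```python
import flash_attn
import torch
import transformers
from flash_attn import flash_attn_func

print('torch        :', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('transformers :', transformers.__version__)  # 4.37.2 is the version pinned for inference above
print('flash-attn   :', flash_attn.__version__)    # 2.5.6 is the version suggested above

if torch.cuda.is_available():
    # smoke test: (batch, seq_len, num_heads, head_dim) tensors in fp16 on the GPU
    q = k = v = torch.randn(1, 128, 8, 64, dtype=torch.float16, device='cuda')
    out = flash_attn_func(q, k, v, causal=True)
    print('flash_attn_func output shape:', tuple(out.shape))
```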
254 | Dataset Preparation
255 |
256 | #### LLaVA-v1.5-mix665k Dataset
257 |
258 | 1. Download the instruction tuning data:
259 | ```sh
260 | mkdir playground
261 | wget https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json -P playground/
262 | ```
263 |
264 | 2. Download image datasets:
265 |
266 | - COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip)
267 | - GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip)
268 | - OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing)
269 | - TextVQA: [train_val_images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip)
270 | - VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip)
271 |
272 | 3. Organize the data as follows:
273 |
274 | ```none
275 | playground/
276 | ├── data/
277 | │   ├── coco/train2017/
278 | │   ├── gqa/images/
279 | │   ├── ocr_vqa/images/
280 | │   ├── textvqa/train_images/
281 | │   └── vg/
282 | │       ├── VG_100K/
283 | │       └── VG_100K_2/
284 | └── llava_v1_5_mix665k.json
285 | ```
286 |
287 | #### Custom Dataset
288 |
289 | For a custom dataset, format your data into a JSONL file, where each entry is a dictionary organized in the following format (similar to `llava_v1_5_mix665k.json`):
290 |
291 | ```python
292 | {
293 |   "id": "000000120375",
294 |   "image": "coco/train2017/000000120375.jpg",
295 |   "conversations": [
296 |     {
297 |       "from": "human",
298 |       "value": "<image>\nWhat type of vehicle is driving down the street in the image?"
299 |     },
300 |     {
301 |       "from": "gpt",
302 |       "value": "A red sports utility vehicle (SUV) is driving down the street in the image."
303 |     },
304 |     {
305 |       "from": "human",
306 |       "value": "Is the street crowded with people?"
307 |     },
308 |     {
309 |       "from": "gpt",
310 |       "value": "Yes, the street is filled with a considerable number of people, which indicates that the area is busy."
311 |     }
312 |     # (more turns ...)
313 |   ]
314 | }
315 | ```
316 |
317 | Then modify the metadata file `shell/data_llava_finetune.json`:
318 |
319 | ```python
320 | {
321 |   "name of your dataset": {
322 |     "root": "playground/data/",  # combination of "root" and "image" in the JSONL gives the complete image path
323 |     "annotation": "path to your JSONL",
324 |     "data_augment": false,
325 |     "repeat_time": 1,
326 |     "length": 12345  # change to the actual number of samples in your dataset
327 |   }
328 | }
329 | ```
330 |
331 |
332 | 333 |
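
To build the JSONL annotation for a custom dataset and fill in the `length` field of `shell/data_llava_finetune.json`, a short script is enough. This is a minimal sketch following the format shown above; the sample record, the dataset name `my_dataset`, and all paths are placeholders for your own data.

```python
import json

# placeholder records following the conversation format shown above; replace with your own data
records = [
    {
        'id': '000000120375',
        'image': 'coco/train2017/000000120375.jpg',
        'conversations': [
            {'from': 'human', 'value': '<image>\nWhat type of vehicle is driving down the street in the image?'},
            {'from': 'gpt', 'value': 'A red sports utility vehicle (SUV) is driving down the street in the image.'},
        ],
    },
]

annotation_path = 'playground/my_dataset.jsonl'  # hypothetical output path
with open(annotation_path, 'w', encoding='utf-8') as f:
    for rec in records:
        f.write(json.dumps(rec, ensure_ascii=False) + '\n')

# count the samples; this number goes into the "length" field of shell/data_llava_finetune.json
with open(annotation_path, encoding='utf-8') as f:
    num_samples = sum(1 for _ in f)

metadata = {
    'my_dataset': {                      # hypothetical dataset name
        'root': 'playground/data/',      # "root" + "image" gives the full image path
        'annotation': annotation_path,
        'data_augment': False,
        'repeat_time': 1,
        'length': num_samples,
    }
}
print(json.dumps(metadata, indent=2))
```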
334 | Model Preparation 335 | 336 | We provide pretrained models of different stages (S1.1 concept learning, S1.2 semantic learning, S1.3 alignment learning). 337 | Choose from the following models and download the weights to `workdirs/` folder. 338 | 339 | 340 | | model name | download | size | 341 | | ----------------------- | ---------------------------------------------------------------------- |:------:| 342 | | Mono-InternVL-2B-S1-1 | 🤗 [HF link](https://huggingface.co/OpenGVLab/Mono-InternVL-2B-S1-1) | 6.2 GB | 343 | | Mono-InternVL-2B-S1-2 | 🤗 [HF link](https://huggingface.co/OpenGVLab/Mono-InternVL-2B-S1-2) | 6.2 GB | 344 | | Mono-InternVL-2B-S1-3 | 🤗 [HF link](https://huggingface.co/OpenGVLab/Mono-InternVL-2B-S1-3) | 6.2 GB | 345 | 346 | 347 | ```sh 348 | mkdir workdirs 349 | cd workdirs/ 350 | # pip install -U huggingface_hub 351 | huggingface-cli download --resume-download --local-dir-use-symlinks False OpenGVLab/Mono-InternVL-2B-S1-1 --local-dir Mono-InternVL-2B-S1-1 352 | ``` 353 | 354 | The directory structure is: 355 | 356 | ```sh 357 | workdirs/ 358 | ├── Mono-InternVL-2B-S1-1/ 359 | ├── Mono-InternVL-2B-S1-2/ 360 | └── Mono-InternVL-2B-S1-3/ 361 | ``` 362 |
363 | 364 |
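
The checkpoints can also be fetched from Python with `huggingface_hub` instead of the CLI; a minimal sketch that mirrors the command above, shown here for the S1.3 checkpoint used as the finetuning starting point:

```python
from huggingface_hub import snapshot_download

# download the S1.3 (alignment learning) weights into workdirs/; repeat for S1.1/S1.2 if needed
snapshot_download(
    repo_id='OpenGVLab/Mono-InternVL-2B-S1-3',
    local_dir='workdirs/Mono-InternVL-2B-S1-3',
)
```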
365 | Training 366 | 367 | Finetuning takes around 12 hours on 8x A100 (80G) GPUs. 368 | 369 | #### Single Node Multi-GPU 370 | ```sh 371 | MODEL="./workdirs/Mono-InternVL-2B-S1-3" OUTPUT_DIR="./workdirs/mono_internvl_llava_sft" sh shell/mono_internvl_finetune_llava_torchrun.sh 372 | ``` 373 | 374 | #### Slurm Cluster 375 | ```sh 376 | PARTITION="your partition" MODEL="./workdirs/Mono-InternVL-2B-S1-3" OUTPUT_DIR="./workdirs/mono_internvl_llava_sft" sh shell/mono_internvl_finetune_llava_slurm.sh 377 | ``` 378 | 379 |
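
Once finetuning completes, the checkpoint in `OUTPUT_DIR` can be loaded exactly like the released weights in the Inference section, assuming the trainer saved the full model and tokenizer files there; a minimal sketch:

```python
import torch
from transformers import AutoModel, AutoTokenizer

path = './workdirs/mono_internvl_llava_sft'  # the OUTPUT_DIR used above
model = AutoModel.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# quick text-only check that the finetuned weights load and generate
question = 'Hello, who are you?'
response = model.chat(tokenizer, None, question, dict(max_new_tokens=256, do_sample=False))
print(response)
```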
380 | 381 | 382 | ## 🎫 License 383 | 384 | This project is released under the [MIT License](LICENSE). 385 | 386 | ## 🖊️ Citation 387 | 388 | If you find this work helpful in your research, please consider giving this repo a star ⭐ and citing our paper: 389 | 390 | ```bibtex 391 | @article{luo2024mono, 392 | title={Mono-InternVL: Pushing the Boundaries of Monolithic Multimodal Large Language Models with Endogenous Visual Pre-training}, 393 | author={Luo, Gen and Yang, Xue and Dou, Wenhan and Wang, Zhaokai and Liu, Jiawen and Dai, Jifeng and Qiao, Yu and Zhu, Xizhou}, 394 | journal={arXiv preprint arXiv:2410.08202}, 395 | year={2024} 396 | } 397 | ``` 398 | -------------------------------------------------------------------------------- /examples/image1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/Mono-InternVL/fcd0381544fc83ac544bade229a70d7a05aa4614/examples/image1.jpg -------------------------------------------------------------------------------- /images/fig1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/Mono-InternVL/fcd0381544fc83ac544bade229a70d7a05aa4614/images/fig1.jpg -------------------------------------------------------------------------------- /images/fig2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/Mono-InternVL/fcd0381544fc83ac544bade229a70d7a05aa4614/images/fig2.jpg -------------------------------------------------------------------------------- /internvl/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | import subprocess 4 | from datetime import timedelta 5 | 6 | import deepspeed 7 | import torch 8 | import torch.multiprocessing as mp 9 | from torch import distributed as dist 10 | 11 | timeout = timedelta(minutes=60) 12 | 13 | 14 | def _find_free_port(): 15 | # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501 16 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 17 | # Binding to port 0 will cause the OS to find an available port for us 18 | sock.bind(('', 0)) 19 | port = sock.getsockname()[1] 20 | sock.close() 21 | # NOTE: there is still a chance the port could be taken by other processes. 
22 | return port 23 | 24 | 25 | def _is_free_port(port): 26 | ips = socket.gethostbyname_ex(socket.gethostname())[-1] 27 | ips.append('localhost') 28 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 29 | return all(s.connect_ex((ip, port)) != 0 for ip in ips) 30 | 31 | 32 | def init_dist(launcher, backend='nccl', **kwargs): 33 | if mp.get_start_method(allow_none=True) is None: 34 | mp.set_start_method('spawn') 35 | if launcher == 'pytorch': 36 | _init_dist_pytorch(backend, **kwargs) 37 | elif launcher == 'mpi': 38 | _init_dist_mpi(backend, **kwargs) 39 | elif launcher == 'slurm': 40 | _init_dist_slurm(backend, **kwargs) 41 | else: 42 | raise ValueError(f'Invalid launcher type: {launcher}') 43 | 44 | 45 | def _init_dist_pytorch(backend, **kwargs): 46 | 47 | rank = int(os.environ['RANK']) 48 | num_gpus = torch.cuda.device_count() 49 | torch.cuda.set_device(rank % num_gpus) 50 | # dist.init_process_group(backend=backend, **kwargs) 51 | deepspeed.init_distributed(dist_backend=backend) 52 | 53 | 54 | def _init_dist_mpi(backend, **kwargs): 55 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 56 | torch.cuda.set_device(local_rank) 57 | if 'MASTER_PORT' not in os.environ: 58 | # 29500 is torch.distributed default port 59 | os.environ['MASTER_PORT'] = '29500' 60 | if 'MASTER_ADDR' not in os.environ: 61 | raise KeyError('The environment variable MASTER_ADDR is not set') 62 | os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE'] 63 | os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK'] 64 | dist.init_process_group(backend=backend, **kwargs) 65 | 66 | 67 | def _init_dist_slurm(backend, port=None): 68 | """Initialize slurm distributed training environment. 69 | 70 | If argument ``port`` is not specified, then the master port will be system 71 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 72 | environment variable, then a default port ``29500`` will be used. 73 | 74 | Args: 75 | backend (str): Backend of torch.distributed. 76 | port (int, optional): Master port. Defaults to None. 77 | """ 78 | proc_id = int(os.environ['SLURM_PROCID']) 79 | ntasks = int(os.environ['SLURM_NTASKS']) 80 | node_list = os.environ['SLURM_NODELIST'] 81 | num_gpus = torch.cuda.device_count() 82 | torch.cuda.set_device(proc_id % num_gpus) 83 | addr = subprocess.getoutput( 84 | f'scontrol show hostname {node_list} | head -n1') 85 | # specify master port 86 | if port is not None: 87 | os.environ['MASTER_PORT'] = str(port) 88 | elif 'MASTER_PORT' in os.environ: 89 | pass # use MASTER_PORT in the environment variable 90 | else: 91 | # if torch.distributed default port(29500) is available 92 | # then use it, else find a free port 93 | if _is_free_port(29500): 94 | os.environ['MASTER_PORT'] = '29500' 95 | else: 96 | os.environ['MASTER_PORT'] = str(_find_free_port()) 97 | # use MASTER_ADDR in the environment variable if it already exists 98 | if 'MASTER_ADDR' not in os.environ: 99 | os.environ['MASTER_ADDR'] = addr 100 | os.environ['WORLD_SIZE'] = str(ntasks) 101 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 102 | os.environ['RANK'] = str(proc_id) 103 | # dist.init_process_group(backend=backend, timeout=timeout) 104 | deepspeed.init_distributed(dist_backend=backend) 105 | -------------------------------------------------------------------------------- /internvl/model/internlm2/configuration_internlm2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. 
All rights reserved. 2 | # 3 | # This code is based on transformers/src/transformers/models/llama/configuration_llama.py 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ InternLM2 model configuration""" 17 | 18 | from transformers.configuration_utils import PretrainedConfig 19 | from transformers.utils import logging 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 24 | 25 | 26 | # Modified from transformers.model.llama.configuration_llama.LlamaConfig 27 | class InternLM2Config(PretrainedConfig): 28 | r""" 29 | This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate 30 | an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a 31 | configuration with the defaults will yield a similar configuration to that of the InternLM2-7B. 32 | 33 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 34 | documentation from [`PretrainedConfig`] for more information. 35 | 36 | 37 | Args: 38 | vocab_size (`int`, *optional*, defaults to 32000): 39 | Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the 40 | `inputs_ids` passed when calling [`InternLM2Model`] 41 | hidden_size (`int`, *optional*, defaults to 4096): 42 | Dimension of the hidden representations. 43 | intermediate_size (`int`, *optional*, defaults to 11008): 44 | Dimension of the MLP representations. 45 | num_hidden_layers (`int`, *optional*, defaults to 32): 46 | Number of hidden layers in the Transformer encoder. 47 | num_attention_heads (`int`, *optional*, defaults to 32): 48 | Number of attention heads for each attention layer in the Transformer encoder. 49 | num_key_value_heads (`int`, *optional*): 50 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 51 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 52 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 53 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 54 | by meanpooling all the original heads within that group. For more details checkout [this 55 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 56 | `num_attention_heads`. 57 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 58 | The non-linear activation function (function or string) in the decoder. 59 | max_position_embeddings (`int`, *optional*, defaults to 2048): 60 | The maximum sequence length that this model might ever be used with. Typically set this to something large 61 | just in case (e.g., 512 or 1024 or 2048). 
62 | initializer_range (`float`, *optional*, defaults to 0.02): 63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 64 | rms_norm_eps (`float`, *optional*, defaults to 1e-12): 65 | The epsilon used by the rms normalization layers. 66 | use_cache (`bool`, *optional*, defaults to `True`): 67 | Whether or not the model should return the last key/values attentions (not used by all models). Only 68 | relevant if `config.is_decoder=True`. 69 | tie_word_embeddings(`bool`, *optional*, defaults to `False`): 70 | Whether to tie weight embeddings 71 | Example: 72 | 73 | """ 74 | model_type = 'internlm2' 75 | _auto_class = 'AutoConfig' 76 | 77 | def __init__( # pylint: disable=W0102 78 | self, 79 | vocab_size=103168, 80 | hidden_size=4096, 81 | intermediate_size=11008, 82 | num_hidden_layers=32, 83 | num_attention_heads=32, 84 | num_key_value_heads=None, 85 | hidden_act='silu', 86 | max_position_embeddings=2048, 87 | initializer_range=0.02, 88 | rms_norm_eps=1e-6, 89 | use_cache=True, 90 | pad_token_id=0, 91 | bos_token_id=1, 92 | eos_token_id=2, 93 | tie_word_embeddings=False, 94 | bias=True, 95 | rope_theta=10000, 96 | rope_scaling=None, 97 | attn_implementation='eager', 98 | **kwargs, 99 | ): 100 | self.vocab_size = vocab_size 101 | self.max_position_embeddings = max_position_embeddings 102 | self.hidden_size = hidden_size 103 | self.intermediate_size = intermediate_size 104 | self.num_hidden_layers = num_hidden_layers 105 | self.num_attention_heads = num_attention_heads 106 | self.bias = bias 107 | 108 | if num_key_value_heads is None: 109 | num_key_value_heads = num_attention_heads 110 | self.num_key_value_heads = num_key_value_heads 111 | 112 | self.hidden_act = hidden_act 113 | self.initializer_range = initializer_range 114 | self.rms_norm_eps = rms_norm_eps 115 | self.use_cache = use_cache 116 | self.rope_theta = rope_theta 117 | self.rope_scaling = rope_scaling 118 | self._rope_scaling_validation() 119 | 120 | 121 | self.attn_implementation = attn_implementation 122 | if self.attn_implementation is None: 123 | self.attn_implementation = 'eager' 124 | super().__init__( 125 | pad_token_id=pad_token_id, 126 | bos_token_id=bos_token_id, 127 | eos_token_id=eos_token_id, 128 | tie_word_embeddings=tie_word_embeddings, 129 | **kwargs, 130 | ) 131 | 132 | def _rope_scaling_validation(self): 133 | """ 134 | Validate the `rope_scaling` configuration. 
135 | """ 136 | if self.rope_scaling is None: 137 | return 138 | 139 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 140 | raise ValueError( 141 | '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, ' 142 | f'got {self.rope_scaling}' 143 | ) 144 | rope_scaling_type = self.rope_scaling.get('type', None) 145 | rope_scaling_factor = self.rope_scaling.get('factor', None) 146 | if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']: 147 | raise ValueError( 148 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 149 | ) 150 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0: 151 | raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}") 152 | -------------------------------------------------------------------------------- /internvl/model/internlm2/tokenization_internlm2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Tokenization classes for InternLM.""" 18 | import os 19 | from shutil import copyfile 20 | from typing import Any, Dict, List, Optional, Tuple 21 | 22 | import sentencepiece as spm 23 | from transformers.tokenization_utils import PreTrainedTokenizer 24 | from transformers.utils import logging 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'} 29 | 30 | PRETRAINED_VOCAB_FILES_MAP = {} 31 | 32 | 33 | # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer 34 | class InternLM2Tokenizer(PreTrainedTokenizer): 35 | """ 36 | Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding. 37 | 38 | Args: 39 | vocab_file (`str`): 40 | Path to the vocabulary file. 
41 | """ 42 | 43 | vocab_files_names = VOCAB_FILES_NAMES 44 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 45 | model_input_names = ['input_ids', 'attention_mask'] 46 | _auto_class = 'AutoTokenizer' 47 | 48 | def __init__( 49 | self, 50 | vocab_file, 51 | unk_token='', 52 | bos_token='', 53 | eos_token='', 54 | pad_token='', 55 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 56 | add_bos_token=True, 57 | add_eos_token=False, 58 | decode_with_prefix_space=False, 59 | clean_up_tokenization_spaces=False, 60 | **kwargs, 61 | ): 62 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs 63 | self.vocab_file = vocab_file 64 | self.add_bos_token = add_bos_token 65 | self.add_eos_token = add_eos_token 66 | self.decode_with_prefix_space = decode_with_prefix_space 67 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) 68 | self.sp_model.Load(vocab_file) 69 | self._no_prefix_space_tokens = None 70 | super().__init__( 71 | bos_token=bos_token, 72 | eos_token=eos_token, 73 | unk_token=unk_token, 74 | pad_token=pad_token, 75 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 76 | **kwargs, 77 | ) 78 | 79 | @property 80 | def no_prefix_space_tokens(self): 81 | if self._no_prefix_space_tokens is None: 82 | vocab = self.convert_ids_to_tokens(list(range(self.vocab_size))) 83 | self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith('▁')} 84 | return self._no_prefix_space_tokens 85 | 86 | @property 87 | def vocab_size(self): 88 | """Returns vocab size""" 89 | return self.sp_model.get_piece_size() 90 | 91 | @property 92 | def bos_token_id(self) -> Optional[int]: 93 | return self.sp_model.bos_id() 94 | 95 | @property 96 | def eos_token_id(self) -> Optional[int]: 97 | return self.sp_model.eos_id() 98 | 99 | def get_vocab(self): 100 | """Returns vocab as a dict""" 101 | vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} 102 | vocab.update(self.added_tokens_encoder) 103 | return vocab 104 | 105 | def _tokenize(self, text): 106 | """Returns a tokenized string.""" 107 | return self.sp_model.encode(text, out_type=str) 108 | 109 | def _convert_token_to_id(self, token): 110 | """Converts a token (str) in an id using the vocab.""" 111 | return self.sp_model.piece_to_id(token) 112 | 113 | def _convert_id_to_token(self, index): 114 | """Converts an index (integer) in a token (str) using the vocab.""" 115 | token = self.sp_model.IdToPiece(index) 116 | return token 117 | 118 | def _maybe_add_prefix_space(self, tokens, decoded): 119 | if tokens and tokens[0] not in self.no_prefix_space_tokens: 120 | return ' ' + decoded 121 | else: 122 | return decoded 123 | 124 | def convert_tokens_to_string(self, tokens): 125 | """Converts a sequence of tokens (string) in a single string.""" 126 | current_sub_tokens = [] 127 | out_string = '' 128 | prev_is_special = False 129 | for token in tokens: 130 | # make sure that special tokens are not decoded using sentencepiece model 131 | if token in self.all_special_tokens: 132 | if not prev_is_special: 133 | out_string += ' ' 134 | out_string += self.sp_model.decode(current_sub_tokens) + token 135 | prev_is_special = True 136 | current_sub_tokens = [] 137 | else: 138 | current_sub_tokens.append(token) 139 | prev_is_special = False 140 | out_string += self.sp_model.decode(current_sub_tokens) 141 | out_string = self.clean_up_tokenization(out_string) 142 | out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string) 143 | return out_string[1:] 144 | 145 | def 
save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: 146 | """ 147 | Save the vocabulary and special tokens file to a directory. 148 | 149 | Args: 150 | save_directory (`str`): 151 | The directory in which to save the vocabulary. 152 | 153 | Returns: 154 | `Tuple(str)`: Paths to the files saved. 155 | """ 156 | if not os.path.isdir(save_directory): 157 | logger.error(f'Vocabulary path ({save_directory}) should be a directory') 158 | return 159 | out_vocab_file = os.path.join( 160 | save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file'] 161 | ) 162 | 163 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): 164 | copyfile(self.vocab_file, out_vocab_file) 165 | elif not os.path.isfile(self.vocab_file): 166 | with open(out_vocab_file, 'wb') as fi: 167 | content_spiece_model = self.sp_model.serialized_model_proto() 168 | fi.write(content_spiece_model) 169 | 170 | return (out_vocab_file,) 171 | 172 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 173 | if self.add_bos_token: 174 | bos_token_ids = [self.bos_token_id] 175 | else: 176 | bos_token_ids = [] 177 | 178 | output = bos_token_ids + token_ids_0 179 | 180 | if token_ids_1 is not None: 181 | output = output + token_ids_1 182 | 183 | if self.add_eos_token: 184 | output = output + [self.eos_token_id] 185 | 186 | return output 187 | 188 | def get_special_tokens_mask( 189 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 190 | ) -> List[int]: 191 | """ 192 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 193 | special tokens using the tokenizer `prepare_for_model` method. 194 | 195 | Args: 196 | token_ids_0 (`List[int]`): 197 | List of IDs. 198 | token_ids_1 (`List[int]`, *optional*): 199 | Optional second list of IDs for sequence pairs. 200 | already_has_special_tokens (`bool`, *optional*, defaults to `False`): 201 | Whether or not the token list is already formatted with special tokens for the model. 202 | 203 | Returns: 204 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 205 | """ 206 | if already_has_special_tokens: 207 | return super().get_special_tokens_mask( 208 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True 209 | ) 210 | 211 | if token_ids_1 is None: 212 | return [1] + ([0] * len(token_ids_0)) + [1] 213 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] 214 | 215 | def create_token_type_ids_from_sequences( 216 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 217 | ) -> List[int]: 218 | """ 219 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make 220 | use of token type ids, therefore a list of zeros is returned. 221 | 222 | Args: 223 | token_ids_0 (`List[int]`): 224 | List of IDs. 225 | token_ids_1 (`List[int]`, *optional*): 226 | Optional second list of IDs for sequence pairs. 227 | 228 | Returns: 229 | `List[int]`: List of zeros. 
230 | """ 231 | eos = [self.eos_token_id] 232 | 233 | if token_ids_1 is None: 234 | return len(token_ids_0 + eos) * [0] 235 | return len(token_ids_0 + eos + token_ids_1 + eos) * [0] 236 | -------------------------------------------------------------------------------- /internvl/model/internlm2/tokenization_internlm2_fast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Tokenization Fast class for InternLM.""" 18 | import os 19 | from shutil import copyfile 20 | from typing import Any, Dict, Optional, Tuple 21 | 22 | from tokenizers import Tokenizer, decoders, normalizers, processors 23 | from tokenizers.models import BPE 24 | from transformers.convert_slow_tokenizer import (SLOW_TO_FAST_CONVERTERS, 25 | SentencePieceExtractor, 26 | SpmConverter) 27 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 28 | from transformers.utils import logging 29 | 30 | from .tokenization_internlm2 import InternLM2Tokenizer 31 | 32 | logger = logging.get_logger(__name__) 33 | 34 | VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'} 35 | 36 | 37 | # Modified from transformers.convert_slow_tokenizer.LlamaConverter 38 | class InternLM2Converter(SpmConverter): 39 | handle_byte_fallback = True 40 | 41 | def vocab(self, proto): 42 | vocab = [ 43 | ('', 0.0), 44 | ('', 0.0), 45 | ('', 0.0), 46 | ] 47 | vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] 48 | return vocab 49 | 50 | def unk_id(self, proto): 51 | unk_id = 0 52 | return unk_id 53 | 54 | def decoder(self, replacement, add_prefix_space): 55 | return decoders.Sequence( 56 | [ 57 | decoders.Replace('▁', ' '), 58 | decoders.ByteFallback(), 59 | decoders.Fuse(), 60 | decoders.Strip(content=' ', left=1), 61 | ] 62 | ) 63 | 64 | def tokenizer(self, proto): 65 | model_type = proto.trainer_spec.model_type 66 | vocab_scores = self.vocab(proto) 67 | # special tokens 68 | added_tokens = self.original_tokenizer.added_tokens_decoder 69 | for i in range(len(vocab_scores)): 70 | piece, score = vocab_scores[i] 71 | if i in added_tokens: 72 | vocab_scores[i] = (added_tokens[i].content, score) 73 | if model_type == 1: 74 | raise RuntimeError('InternLM2 is supposed to be a BPE model!') 75 | 76 | elif model_type == 2: 77 | _, merges = SentencePieceExtractor(self.original_tokenizer.vocab_file).extract(vocab_scores) 78 | bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)} 79 | tokenizer = Tokenizer( 80 | BPE(bpe_vocab, merges, unk_token=proto.trainer_spec.unk_piece, fuse_unk=True, byte_fallback=True) 81 | ) 82 | tokenizer.add_special_tokens( 83 | [ added_token for index, added_token in added_tokens.items()] 84 | ) 85 | else: 86 | raise Exception( 87 | "You're trying to run a `Unigram` model 
but you're file was trained with a different algorithm" 88 | ) 89 | 90 | return tokenizer 91 | 92 | def normalizer(self, proto): 93 | normalizers_list = [] 94 | if proto.normalizer_spec.add_dummy_prefix: 95 | normalizers_list.append(normalizers.Prepend(prepend='▁')) 96 | normalizers_list.append(normalizers.Replace(pattern=' ', content='▁')) 97 | return normalizers.Sequence(normalizers_list) 98 | 99 | def pre_tokenizer(self, replacement, add_prefix_space): 100 | return None 101 | 102 | 103 | SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] = InternLM2Converter 104 | 105 | 106 | # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast 107 | class InternLM2TokenizerFast(PreTrainedTokenizerFast): 108 | vocab_files_names = VOCAB_FILES_NAMES 109 | slow_tokenizer_class = InternLM2Tokenizer 110 | padding_side = 'left' 111 | model_input_names = ['input_ids', 'attention_mask'] 112 | _auto_class = 'AutoTokenizer' 113 | 114 | def __init__( 115 | self, 116 | vocab_file, 117 | unk_token='', 118 | bos_token='', 119 | eos_token='', 120 | pad_token='', 121 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 122 | add_bos_token=True, 123 | add_eos_token=False, 124 | decode_with_prefix_space=False, 125 | clean_up_tokenization_spaces=False, 126 | **kwargs, 127 | ): 128 | super().__init__( 129 | vocab_file=vocab_file, 130 | unk_token=unk_token, 131 | bos_token=bos_token, 132 | eos_token=eos_token, 133 | pad_token=pad_token, 134 | sp_model_kwargs=sp_model_kwargs, 135 | add_bos_token=add_bos_token, 136 | add_eos_token=add_eos_token, 137 | decode_with_prefix_space=decode_with_prefix_space, 138 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 139 | **kwargs, 140 | ) 141 | self._add_bos_token = add_bos_token 142 | self._add_eos_token = add_eos_token 143 | self.update_post_processor() 144 | self.vocab_file = vocab_file 145 | 146 | @property 147 | def can_save_slow_tokenizer(self) -> bool: 148 | return os.path.isfile(self.vocab_file) if self.vocab_file else False 149 | 150 | def update_post_processor(self): 151 | """ 152 | Updates the underlying post processor with the current `bos_token` and `eos_token`. 
153 | """ 154 | bos = self.bos_token 155 | bos_token_id = self.bos_token_id 156 | if bos is None and self.add_bos_token: 157 | raise ValueError('add_bos_token = True but bos_token = None') 158 | 159 | eos = self.eos_token 160 | eos_token_id = self.eos_token_id 161 | if eos is None and self.add_eos_token: 162 | raise ValueError('add_eos_token = True but eos_token = None') 163 | 164 | single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" 165 | pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" 166 | 167 | special_tokens = [] 168 | if self.add_bos_token: 169 | special_tokens.append((bos, bos_token_id)) 170 | if self.add_eos_token: 171 | special_tokens.append((eos, eos_token_id)) 172 | self._tokenizer.post_processor = processors.TemplateProcessing( 173 | single=single, pair=pair, special_tokens=special_tokens 174 | ) 175 | 176 | @property 177 | def add_eos_token(self): 178 | return self._add_eos_token 179 | 180 | @property 181 | def add_bos_token(self): 182 | return self._add_bos_token 183 | 184 | @add_eos_token.setter 185 | def add_eos_token(self, value): 186 | self._add_eos_token = value 187 | self.update_post_processor() 188 | 189 | @add_bos_token.setter 190 | def add_bos_token(self, value): 191 | self._add_bos_token = value 192 | self.update_post_processor() 193 | 194 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 195 | if not self.can_save_slow_tokenizer: 196 | raise ValueError( 197 | 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' 198 | 'tokenizer.' 199 | ) 200 | 201 | if not os.path.isdir(save_directory): 202 | logger.error(f'Vocabulary path ({save_directory}) should be a directory') 203 | return 204 | out_vocab_file = os.path.join( 205 | save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file'] 206 | ) 207 | 208 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): 209 | copyfile(self.vocab_file, out_vocab_file) 210 | 211 | return (out_vocab_file,) 212 | -------------------------------------------------------------------------------- /internvl/model/internvl_chat/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2025 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .configuration_intern_vit import InternVisionConfig,InternVisionPatchConfig 8 | from .configuration_internvl_chat import InternVLChatConfig 9 | from .modeling_intern_vit import InternVisionModel,InternVisionPatchModel 10 | from .modeling_internvl_chat import InternVLChatModel 11 | 12 | __all__ = ['InternVisionConfig', 'InternVisionModel', 'InternVisionPatchModel', 13 | 'InternVLChatConfig', 'InternVisionPatchConfig','InternVLChatModel'] 14 | -------------------------------------------------------------------------------- /internvl/model/internvl_chat/configuration_intern_vit.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2025 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | import os 7 
| from typing import Union 8 | 9 | from transformers.configuration_utils import PretrainedConfig 10 | from transformers.utils import logging 11 | 12 | logger = logging.get_logger(__name__) 13 | 14 | 15 | class InternVisionConfig(PretrainedConfig): 16 | r""" 17 | This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to 18 | instantiate a vision encoder according to the specified arguments, defining the model architecture. 19 | 20 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 21 | documentation from [`PretrainedConfig`] for more information. 22 | 23 | Args: 24 | num_channels (`int`, *optional*, defaults to 3): 25 | Number of color channels in the input images (e.g., 3 for RGB). 26 | patch_size (`int`, *optional*, defaults to 14): 27 | The size (resolution) of each patch. 28 | image_size (`int`, *optional*, defaults to 224): 29 | The size (resolution) of each image. 30 | qkv_bias (`bool`, *optional*, defaults to `False`): 31 | Whether to add a bias to the queries and values in the self-attention layers. 32 | hidden_size (`int`, *optional*, defaults to 3200): 33 | Dimensionality of the encoder layers and the pooler layer. 34 | num_attention_heads (`int`, *optional*, defaults to 25): 35 | Number of attention heads for each attention layer in the Transformer encoder. 36 | intermediate_size (`int`, *optional*, defaults to 12800): 37 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 38 | qk_normalization (`bool`, *optional*, defaults to `True`): 39 | Whether to normalize the queries and keys in the self-attention layers. 40 | num_hidden_layers (`int`, *optional*, defaults to 48): 41 | Number of hidden layers in the Transformer encoder. 42 | use_flash_attn (`bool`, *optional*, defaults to `True`): 43 | Whether to use flash attention mechanism. 44 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): 45 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 46 | `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. 47 | layer_norm_eps (`float`, *optional*, defaults to 1e-6): 48 | The epsilon used by the layer normalization layers. 49 | dropout (`float`, *optional*, defaults to 0.0): 50 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 51 | drop_path_rate (`float`, *optional*, defaults to 0.0): 52 | Dropout rate for stochastic depth. 53 | attention_dropout (`float`, *optional*, defaults to 0.0): 54 | The dropout ratio for the attention probabilities. 55 | initializer_range (`float`, *optional*, defaults to 0.02): 56 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 57 | initializer_factor (`float`, *optional*, defaults to 0.1): 58 | A factor for layer scale. 
59 | """ 60 | 61 | model_type = 'intern_vit_6b' 62 | 63 | def __init__( 64 | self, 65 | num_channels=3, 66 | patch_size=14, 67 | image_size=224, 68 | qkv_bias=False, 69 | hidden_size=3200, 70 | num_attention_heads=25, 71 | intermediate_size=12800, 72 | qk_normalization=True, 73 | num_hidden_layers=48, 74 | use_flash_attn=True, 75 | hidden_act='gelu', 76 | norm_type='rms_norm', 77 | layer_norm_eps=1e-6, 78 | dropout=0.0, 79 | drop_path_rate=0.0, 80 | attention_dropout=0.0, 81 | initializer_range=0.02, 82 | initializer_factor=0.1, 83 | **kwargs, 84 | ): 85 | super().__init__(**kwargs) 86 | 87 | self.hidden_size = hidden_size 88 | self.intermediate_size = intermediate_size 89 | self.dropout = dropout 90 | self.drop_path_rate = drop_path_rate 91 | self.num_hidden_layers = num_hidden_layers 92 | self.num_attention_heads = num_attention_heads 93 | self.num_channels = num_channels 94 | self.patch_size = patch_size 95 | self.image_size = image_size 96 | self.initializer_range = initializer_range 97 | self.initializer_factor = initializer_factor 98 | self.attention_dropout = attention_dropout 99 | self.layer_norm_eps = layer_norm_eps 100 | self.hidden_act = hidden_act 101 | self.norm_type = norm_type 102 | self.qkv_bias = qkv_bias 103 | self.qk_normalization = qk_normalization 104 | self.use_flash_attn = use_flash_attn 105 | 106 | @classmethod 107 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig': 108 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 109 | 110 | if 'vision_config' in config_dict: 111 | config_dict = config_dict['vision_config'] 112 | 113 | if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type: 114 | logger.warning( 115 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 116 | f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.' 117 | ) 118 | 119 | 120 | return cls.from_dict(config_dict, **kwargs) 121 | 122 | 123 | 124 | 125 | class InternVisionPatchConfig(PretrainedConfig): 126 | r""" 127 | This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to 128 | instantiate a vision encoder according to the specified arguments, defining the model architecture. 129 | 130 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 131 | documentation from [`PretrainedConfig`] for more information. 132 | 133 | Args: 134 | num_channels (`int`, *optional*, defaults to 3): 135 | Number of color channels in the input images (e.g., 3 for RGB). 136 | patch_size (`int`, *optional*, defaults to 14): 137 | The size (resolution) of each patch. 138 | image_size (`int`, *optional*, defaults to 224): 139 | The size (resolution) of each image. 140 | qkv_bias (`bool`, *optional*, defaults to `False`): 141 | Whether to add a bias to the queries and values in the self-attention layers. 142 | hidden_size (`int`, *optional*, defaults to 3200): 143 | Dimensionality of the encoder layers and the pooler layer. 144 | num_attention_heads (`int`, *optional*, defaults to 25): 145 | Number of attention heads for each attention layer in the Transformer encoder. 146 | intermediate_size (`int`, *optional*, defaults to 12800): 147 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 
148 | qk_normalization (`bool`, *optional*, defaults to `True`): 149 | Whether to normalize the queries and keys in the self-attention layers. 150 | num_hidden_layers (`int`, *optional*, defaults to 48): 151 | Number of hidden layers in the Transformer encoder. 152 | use_flash_attn (`bool`, *optional*, defaults to `True`): 153 | Whether to use flash attention mechanism. 154 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): 155 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 156 | `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. 157 | layer_norm_eps (`float`, *optional*, defaults to 1e-6): 158 | The epsilon used by the layer normalization layers. 159 | dropout (`float`, *optional*, defaults to 0.0): 160 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 161 | drop_path_rate (`float`, *optional*, defaults to 0.0): 162 | Dropout rate for stochastic depth. 163 | attention_dropout (`float`, *optional*, defaults to 0.0): 164 | The dropout ratio for the attention probabilities. 165 | initializer_range (`float`, *optional*, defaults to 0.02): 166 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 167 | initializer_factor (`float`, *optional*, defaults to 0.1): 168 | A factor for layer scale. 169 | """ 170 | 171 | model_type = 'intern_vit_patch' 172 | 173 | def __init__( 174 | self, 175 | num_channels=3, 176 | patch_size=14, 177 | image_size=224, 178 | qkv_bias=False, 179 | hidden_size=3200, 180 | num_attention_heads=25, 181 | intermediate_size=12800, 182 | qk_normalization=True, 183 | num_hidden_layers=48, 184 | use_flash_attn=True, 185 | hidden_act='gelu', 186 | norm_type='rms_norm', 187 | layer_norm_eps=1e-6, 188 | dropout=0.0, 189 | drop_path_rate=0.0, 190 | attention_dropout=0.0, 191 | initializer_range=0.02, 192 | initializer_factor=0.1, 193 | **kwargs, 194 | ): 195 | super().__init__(**kwargs) 196 | 197 | self.hidden_size = hidden_size 198 | self.intermediate_size = intermediate_size 199 | self.dropout = dropout 200 | self.drop_path_rate = drop_path_rate 201 | self.num_hidden_layers = num_hidden_layers 202 | self.num_attention_heads = num_attention_heads 203 | self.num_channels = num_channels 204 | self.patch_size = patch_size 205 | self.image_size = image_size 206 | self.initializer_range = initializer_range 207 | self.initializer_factor = initializer_factor 208 | self.attention_dropout = attention_dropout 209 | self.layer_norm_eps = layer_norm_eps 210 | self.hidden_act = hidden_act 211 | self.norm_type = norm_type 212 | self.qkv_bias = qkv_bias 213 | self.qk_normalization = qk_normalization 214 | self.use_flash_attn = use_flash_attn 215 | 216 | @classmethod 217 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig': 218 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 219 | 220 | if 'vision_config' in config_dict: 221 | config_dict = config_dict['vision_config'] 222 | 223 | if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type: 224 | logger.warning( 225 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 226 | f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.' 
227 | ) 228 | 229 | 230 | return cls.from_dict(config_dict, **kwargs) 231 | -------------------------------------------------------------------------------- /internvl/model/internvl_chat/configuration_internvl_chat.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2025 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import copy 8 | 9 | from internvl.model.internlm2.configuration_internlm2 import InternLM2Config 10 | from transformers import AutoConfig, LlamaConfig, Qwen2Config 11 | from transformers.configuration_utils import PretrainedConfig 12 | from transformers.utils import logging 13 | 14 | from .configuration_intern_vit import InternVisionConfig, InternVisionPatchConfig 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | 19 | class InternVLChatConfig(PretrainedConfig): 20 | model_type = 'internvl_chat' 21 | is_composition = True 22 | 23 | def __init__( 24 | self, 25 | vision_config=None, 26 | llm_config=None, 27 | use_backbone_lora=0, 28 | use_llm_lora=0, 29 | pad2square=False, 30 | select_layer=-4, 31 | force_image_size=None, 32 | downsample_ratio=0.5, 33 | template=None, 34 | dynamic_image_size=False, 35 | use_thumbnail=False, 36 | ps_version='v1', 37 | min_dynamic_patch=1, 38 | max_dynamic_patch=6, 39 | **kwargs): 40 | super().__init__(**kwargs) 41 | 42 | if vision_config is None: 43 | vision_config = {} 44 | logger.info('vision_config is None. Initializing the InternVisionConfig with default values.') 45 | 46 | if llm_config is None: 47 | llm_config = {} 48 | logger.info('llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).') 49 | 50 | if vision_config and vision_config['model_type']=='intern_vit_patch': 51 | self.vision_config = InternVisionPatchConfig(**vision_config) 52 | else: 53 | self.vision_config = InternVisionConfig(**vision_config) 54 | if llm_config['architectures'][0] == 'LlamaForCausalLM': 55 | self.llm_config = LlamaConfig(**llm_config) 56 | elif llm_config['architectures'][0] == 'InternLM2ForCausalLM': 57 | self.llm_config = InternLM2Config(**llm_config) 58 | elif llm_config['architectures'][0] == 'InternLM2VEForCausalLM': 59 | self.llm_config = InternLM2Config(**llm_config) 60 | elif llm_config['architectures'][0] == 'Qwen2ForCausalLM': 61 | self.llm_config = Qwen2Config(**llm_config) 62 | else: 63 | raise ValueError('Unsupported architecture: {}'.format(llm_config['architectures'][0])) 64 | self.use_backbone_lora = use_backbone_lora 65 | self.use_llm_lora = use_llm_lora 66 | self.pad2square = pad2square 67 | self.select_layer = select_layer 68 | self.force_image_size = force_image_size 69 | self.downsample_ratio = downsample_ratio 70 | self.template = template 71 | self.dynamic_image_size = dynamic_image_size 72 | self.use_thumbnail = use_thumbnail 73 | self.ps_version = ps_version # pixel shuffle version 74 | self.min_dynamic_patch = min_dynamic_patch 75 | self.max_dynamic_patch = max_dynamic_patch 76 | 77 | logger.info(f'vision_select_layer: {self.select_layer}') 78 | logger.info(f'ps_version: {self.ps_version}') 79 | logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}') 80 | logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}') 81 | 82 | def to_dict(self): 83 | """ 84 | Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 
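# Illustrative sketch (not from the original file; field values are arbitrary):
# the constructor above selects InternVisionPatchConfig only when the nested
# vision_config declares model_type 'intern_vit_patch', and picks the LLM
# config class from llm_config['architectures'][0]. Note that a non-empty
# llm_config without an 'architectures' entry would raise a KeyError.
cfg = InternVLChatConfig(
    vision_config={'model_type': 'intern_vit_patch', 'hidden_size': 1024},
    llm_config={'architectures': ['InternLM2VEForCausalLM']},
    force_image_size=448,
)
print(type(cfg.vision_config).__name__)  # InternVisionPatchConfig
print(type(cfg.llm_config).__name__)     # InternLM2Config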
85 | 86 | Returns: 87 | `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 88 | """ 89 | output = copy.deepcopy(self.__dict__) 90 | output['vision_config'] = self.vision_config.to_dict() 91 | output['llm_config'] = self.llm_config.to_dict() 92 | output['model_type'] = self.__class__.model_type 93 | output['use_backbone_lora'] = self.use_backbone_lora 94 | output['use_llm_lora'] = self.use_llm_lora 95 | output['pad2square'] = self.pad2square 96 | output['select_layer'] = self.select_layer 97 | output['force_image_size'] = self.force_image_size 98 | output['downsample_ratio'] = self.downsample_ratio 99 | output['template'] = self.template 100 | output['dynamic_image_size'] = self.dynamic_image_size 101 | output['use_thumbnail'] = self.use_thumbnail 102 | output['ps_version'] = self.ps_version 103 | output['min_dynamic_patch'] = self.min_dynamic_patch 104 | output['max_dynamic_patch'] = self.max_dynamic_patch 105 | 106 | return output 107 | -------------------------------------------------------------------------------- /internvl/model/internvl_chat/flash_attention.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py 2 | import torch 3 | import torch.nn as nn 4 | from einops import rearrange 5 | 6 | try: # v1 7 | from flash_attn.flash_attn_interface import \ 8 | flash_attn_unpadded_qkvpacked_func 9 | except: # v2 10 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 11 | 12 | from flash_attn.bert_padding import pad_input, unpad_input 13 | 14 | 15 | class FlashAttention(nn.Module): 16 | """Implement the scaled dot product attention with softmax. 17 | Arguments 18 | --------- 19 | softmax_scale: The temperature to use for the softmax attention. 20 | (default: 1/sqrt(d_keys) where d_keys is computed at 21 | runtime) 22 | attention_dropout: The dropout rate to apply to the attention 23 | (default: 0.0) 24 | """ 25 | 26 | def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): 27 | super().__init__() 28 | self.softmax_scale = softmax_scale 29 | self.dropout_p = attention_dropout 30 | 31 | def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None, 32 | max_s=None, need_weights=False): 33 | """Implements the multihead softmax attention. 34 | Arguments 35 | --------- 36 | qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None 37 | if unpadded: (nnz, 3, h, d) 38 | key_padding_mask: a bool tensor of shape (B, S) 39 | """ 40 | assert not need_weights 41 | assert qkv.dtype in [torch.float16, torch.bfloat16] 42 | assert qkv.is_cuda 43 | 44 | if cu_seqlens is None: 45 | batch_size = qkv.shape[0] 46 | seqlen = qkv.shape[1] 47 | if key_padding_mask is None: 48 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 49 | max_s = seqlen 50 | cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, 51 | device=qkv.device) 52 | output = flash_attn_unpadded_qkvpacked_func( 53 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 54 | softmax_scale=self.softmax_scale, causal=causal 55 | ) 56 | output = rearrange(output, '(b s) ... 
-> b s ...', b=batch_size) 57 | else: 58 | nheads = qkv.shape[-2] 59 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 60 | x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask) 61 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 62 | output_unpad = flash_attn_unpadded_qkvpacked_func( 63 | x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 64 | softmax_scale=self.softmax_scale, causal=causal 65 | ) 66 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), 67 | indices, batch_size, seqlen), 68 | 'b s (h d) -> b s h d', h=nheads) 69 | else: 70 | assert max_s is not None 71 | output = flash_attn_unpadded_qkvpacked_func( 72 | qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0, 73 | softmax_scale=self.softmax_scale, causal=causal 74 | ) 75 | 76 | return output, None 77 | -------------------------------------------------------------------------------- /internvl/model/internvl_chat/modeling_intern_vit.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2025 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | from typing import Optional, Tuple, Union 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint 11 | from einops import rearrange 12 | from timm.models.layers import DropPath 13 | from torch import nn 14 | from transformers.activations import ACT2FN 15 | from transformers.modeling_outputs import (BaseModelOutput, 16 | BaseModelOutputWithPooling) 17 | from transformers.modeling_utils import PreTrainedModel 18 | from transformers.utils import logging 19 | 20 | from .configuration_intern_vit import InternVisionConfig, InternVisionPatchConfig 21 | 22 | try: 23 | from .flash_attention import FlashAttention 24 | has_flash_attn = True 25 | except: 26 | print('FlashAttention is not installed.') 27 | has_flash_attn = False 28 | 29 | 30 | logger = logging.get_logger(__name__) 31 | 32 | 33 | class InternRMSNorm(nn.Module): 34 | def __init__(self, hidden_size, eps=1e-6): 35 | super().__init__() 36 | self.weight = nn.Parameter(torch.ones(hidden_size)) 37 | self.variance_epsilon = eps 38 | 39 | def forward(self, hidden_states): 40 | input_dtype = hidden_states.dtype 41 | hidden_states = hidden_states.to(torch.float32) 42 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 43 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 44 | return self.weight * hidden_states.to(input_dtype) 45 | 46 | 47 | try: 48 | from apex.normalization import FusedRMSNorm 49 | 50 | InternRMSNorm = FusedRMSNorm # noqa 51 | 52 | logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm') 53 | except ImportError: 54 | # using the normal InternRMSNorm 55 | pass 56 | except Exception: 57 | logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm') 58 | pass 59 | 60 | 61 | NORM2FN = { 62 | 'rms_norm': InternRMSNorm, 63 | 'layer_norm': nn.LayerNorm, 64 | } 65 | 66 | 67 | class InternVisionEmbeddings(nn.Module): 68 | def __init__(self, config: InternVisionConfig): 69 | super().__init__() 70 | self.config = config 71 | self.embed_dim = config.hidden_size 72 | self.image_size = config.image_size 73 | self.patch_size = config.patch_size 74 | 75 | self.class_embedding = 
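# Illustrative sketch (not from the original file; assumes flash-attn is
# installed and a CUDA device is available): FlashAttention above consumes a
# packed qkv tensor of shape (B, S, 3, num_heads, head_dim) in fp16/bf16 and
# returns attention output of shape (B, S, num_heads, head_dim).
attn = FlashAttention(attention_dropout=0.0)
qkv = torch.randn(2, 1025, 3, 16, 64, dtype=torch.float16, device='cuda')
out, _ = attn(qkv, key_padding_mask=None, causal=False)
print(out.shape)  # torch.Size([2, 1025, 16, 64])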
nn.Parameter( 76 | torch.randn(1, 1, self.embed_dim), 77 | ) 78 | 79 | self.patch_embedding = nn.Conv2d( 80 | in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size 81 | ) 82 | 83 | self.num_patches = (self.image_size // self.patch_size) ** 2 84 | self.num_positions = self.num_patches + 1 85 | 86 | self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim)) 87 | 88 | def _get_pos_embed(self, pos_embed, H, W): 89 | target_dtype = pos_embed.dtype 90 | pos_embed = pos_embed.float().reshape( 91 | 1, self.image_size // self.patch_size, self.image_size // self.patch_size, -1).permute(0, 3, 1, 2) 92 | pos_embed = F.interpolate(pos_embed, size=(H, W), mode='bicubic', align_corners=False).\ 93 | reshape(1, -1, H * W).permute(0, 2, 1).to(target_dtype) 94 | return pos_embed 95 | 96 | def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: 97 | target_dtype = self.patch_embedding.weight.dtype 98 | patch_embeds = self.patch_embedding(pixel_values) # shape = [*, channel, width, height] 99 | batch_size, _, height, width = patch_embeds.shape 100 | patch_embeds = patch_embeds.flatten(2).transpose(1, 2) 101 | class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype) 102 | embeddings = torch.cat([class_embeds, patch_embeds], dim=1) 103 | position_embedding = torch.cat([ 104 | self.position_embedding[:, :1, :], 105 | self._get_pos_embed(self.position_embedding[:, 1:, :], height, width) 106 | ], dim=1) 107 | embeddings = embeddings + position_embedding.to(target_dtype) 108 | return embeddings 109 | 110 | 111 | class InternAttention(nn.Module): 112 | """Multi-headed attention from 'Attention Is All You Need' paper""" 113 | 114 | def __init__(self, config: InternVisionConfig): 115 | super().__init__() 116 | self.config = config 117 | self.embed_dim = config.hidden_size 118 | self.num_heads = config.num_attention_heads 119 | self.use_flash_attn = config.use_flash_attn and has_flash_attn 120 | if config.use_flash_attn and not has_flash_attn: 121 | print('Warning: Flash Attention is not available, use_flash_attn is set to False.') 122 | self.head_dim = self.embed_dim // self.num_heads 123 | if self.head_dim * self.num_heads != self.embed_dim: 124 | raise ValueError( 125 | f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:' 126 | f' {self.num_heads}).' 
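# Illustrative sketch (not from the original file; sizes are arbitrary): with
# the default image_size=224 and patch_size=14, the position embedding is
# learned on a 16x16 grid; for a 448x448 input the patch conv produces a 32x32
# grid and _get_pos_embed bicubically resizes the learned grid to match, so the
# module emits 1 CLS token + 32*32 patch tokens without retraining.
emb = InternVisionEmbeddings(InternVisionConfig(hidden_size=1024))
tokens = emb(torch.randn(1, 3, 448, 448))
print(tokens.shape)  # torch.Size([1, 1025, 1024])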
127 | ) 128 | 129 | self.scale = self.head_dim ** -0.5 130 | self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias) 131 | self.attn_drop = nn.Dropout(config.attention_dropout) 132 | self.proj_drop = nn.Dropout(config.dropout) 133 | 134 | self.qk_normalization = config.qk_normalization 135 | 136 | if self.qk_normalization: 137 | self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) 138 | self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps) 139 | 140 | if self.use_flash_attn: 141 | self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout) 142 | self.proj = nn.Linear(self.embed_dim, self.embed_dim) 143 | 144 | def _naive_attn(self, x): 145 | B, N, C = x.shape 146 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 147 | q, k, v = qkv.unbind(0) 148 | 149 | if self.qk_normalization: 150 | B_, H_, N_, D_ = q.shape 151 | q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) 152 | k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2) 153 | 154 | attn = ((q * self.scale) @ k.transpose(-2, -1)) 155 | attn = attn.softmax(dim=-1) 156 | attn = self.attn_drop(attn) 157 | 158 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 159 | x = self.proj(x) 160 | x = self.proj_drop(x) 161 | return x 162 | 163 | def _flash_attn(self, x, key_padding_mask=None, need_weights=False): 164 | qkv = self.qkv(x) 165 | qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads) 166 | 167 | if self.qk_normalization: 168 | q, k, v = qkv.unbind(2) 169 | q = self.q_norm(q.flatten(-2, -1)).view(q.shape) 170 | k = self.k_norm(k.flatten(-2, -1)).view(k.shape) 171 | qkv = torch.stack([q, k, v], dim=2) 172 | 173 | context, _ = self.inner_attn( 174 | qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False 175 | ) 176 | outs = self.proj(rearrange(context, 'b s h d -> b s (h d)')) 177 | outs = self.proj_drop(outs) 178 | return outs 179 | 180 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 181 | x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states) 182 | return x 183 | 184 | 185 | class InternMLP(nn.Module): 186 | def __init__(self, config: InternVisionConfig): 187 | super().__init__() 188 | self.config = config 189 | self.act = ACT2FN[config.hidden_act] 190 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 191 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 192 | 193 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 194 | hidden_states = self.fc1(hidden_states) 195 | hidden_states = self.act(hidden_states) 196 | hidden_states = self.fc2(hidden_states) 197 | return hidden_states 198 | 199 | 200 | class InternVisionEncoderLayer(nn.Module): 201 | def __init__(self, config: InternVisionConfig, drop_path_rate: float): 202 | super().__init__() 203 | self.embed_dim = config.hidden_size 204 | self.intermediate_size = config.intermediate_size 205 | self.norm_type = config.norm_type 206 | 207 | self.attn = InternAttention(config) 208 | self.mlp = InternMLP(config) 209 | self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) 210 | self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) 211 | 212 | self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim)) 213 | self.ls2 = nn.Parameter(config.initializer_factor * 
torch.ones(self.embed_dim)) 214 | self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() 215 | self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() 216 | 217 | def forward( 218 | self, 219 | hidden_states: torch.Tensor, 220 | ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]: 221 | """ 222 | Args: 223 | hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)` 224 | """ 225 | hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1) 226 | 227 | hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2) 228 | 229 | return hidden_states 230 | 231 | 232 | class InternVisionEncoder(nn.Module): 233 | """ 234 | Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a 235 | [`InternEncoderLayer`]. 236 | 237 | Args: 238 | config (`InternConfig`): 239 | The corresponding vision configuration for the `InternEncoder`. 240 | """ 241 | 242 | def __init__(self, config: InternVisionConfig): 243 | super().__init__() 244 | self.config = config 245 | # stochastic depth decay rule 246 | dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)] 247 | self.layers = nn.ModuleList([ 248 | InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)]) 249 | self.gradient_checkpointing = True 250 | 251 | def forward( 252 | self, 253 | inputs_embeds, 254 | output_hidden_states: Optional[bool] = None, 255 | return_dict: Optional[bool] = None, 256 | ) -> Union[Tuple, BaseModelOutput]: 257 | r""" 258 | Args: 259 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 260 | Embedded representation of the inputs. Should be float, not int tokens. 261 | output_hidden_states (`bool`, *optional*): 262 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 263 | for more detail. 264 | return_dict (`bool`, *optional*): 265 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
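# Illustrative sketch (not from the original file): the stochastic-depth decay
# rule in InternVisionEncoder assigns linearly increasing drop-path rates, so
# deeper layers are skipped more often during training. For example, with
# drop_path_rate=0.1 and 4 layers:
rates = [round(x.item(), 3) for x in torch.linspace(0, 0.1, 4)]
print(rates)  # [0.0, 0.033, 0.067, 0.1]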
266 | """ 267 | output_hidden_states = ( 268 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 269 | ) 270 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 271 | 272 | encoder_states = () if output_hidden_states else None 273 | hidden_states = inputs_embeds 274 | for idx, encoder_layer in enumerate(self.layers): 275 | if output_hidden_states: 276 | encoder_states = encoder_states + (hidden_states,) 277 | if self.gradient_checkpointing and self.training: 278 | layer_outputs = torch.utils.checkpoint.checkpoint( 279 | encoder_layer, 280 | hidden_states) 281 | else: 282 | layer_outputs = encoder_layer( 283 | hidden_states, 284 | ) 285 | hidden_states = layer_outputs 286 | 287 | if output_hidden_states: 288 | encoder_states = encoder_states + (hidden_states,) 289 | 290 | if not return_dict: 291 | return tuple(v for v in [hidden_states, encoder_states] if v is not None) 292 | return BaseModelOutput( 293 | last_hidden_state=hidden_states, hidden_states=encoder_states 294 | ) 295 | 296 | 297 | class InternVisionModel(PreTrainedModel): 298 | main_input_name = 'pixel_values' 299 | config_class = InternVisionConfig 300 | _no_split_modules = ['InternVisionEncoderLayer'] 301 | 302 | def __init__(self, config: InternVisionConfig): 303 | super().__init__(config) 304 | self.config = config 305 | 306 | self.embeddings = InternVisionEmbeddings(config) 307 | self.encoder = InternVisionEncoder(config) 308 | 309 | def resize_pos_embeddings(self, old_size, new_size, patch_size): 310 | pos_emb = self.embeddings.position_embedding 311 | _, num_positions, embed_dim = pos_emb.shape 312 | cls_emb = pos_emb[:, :1, :] 313 | pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2) 314 | pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False) 315 | pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1) 316 | pos_emb = torch.cat([cls_emb, pos_emb], dim=1) 317 | self.embeddings.position_embedding = nn.Parameter(pos_emb) 318 | self.embeddings.image_size = new_size 319 | logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size)) 320 | 321 | def get_input_embeddings(self): 322 | return self.embeddings 323 | 324 | def forward( 325 | self, 326 | pixel_values: Optional[torch.FloatTensor] = None, 327 | output_hidden_states: Optional[bool] = None, 328 | return_dict: Optional[bool] = None, 329 | pixel_embeds: Optional[torch.FloatTensor] = None, 330 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 331 | output_hidden_states = ( 332 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 333 | ) 334 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 335 | 336 | if pixel_values is None and pixel_embeds is None: 337 | raise ValueError('You have to specify pixel_values or pixel_embeds') 338 | 339 | if pixel_embeds is not None: 340 | hidden_states = pixel_embeds 341 | else: 342 | if len(pixel_values.shape) == 4: 343 | hidden_states = self.embeddings(pixel_values) 344 | else: 345 | raise ValueError(f'wrong pixel_values size: {pixel_values.shape}') 346 | encoder_outputs = self.encoder( 347 | inputs_embeds=hidden_states, 348 | output_hidden_states=output_hidden_states, 349 | return_dict=return_dict, 350 | ) 351 | last_hidden_state = encoder_outputs.last_hidden_state 352 | pooled_output = last_hidden_state[:, 0, :] 353 | 354 | 
if not return_dict: 355 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 356 | 357 | return BaseModelOutputWithPooling( 358 | last_hidden_state=last_hidden_state, 359 | pooler_output=pooled_output, 360 | hidden_states=encoder_outputs.hidden_states, 361 | attentions=encoder_outputs.attentions, 362 | ) 363 | 364 | 365 | class InternVisionPatchModel(PreTrainedModel): 366 | main_input_name = 'pixel_values' 367 | config_class = InternVisionPatchConfig 368 | _no_split_modules = ['InternVisionEncoderLayer'] 369 | 370 | def __init__(self, config: InternVisionPatchConfig): 371 | super().__init__(config) 372 | self.config = config 373 | self.embeddings = InternVisionEmbeddings(config) 374 | def resize_pos_embeddings(self, old_size, new_size, patch_size): 375 | pos_emb = self.embeddings.position_embedding 376 | _, num_positions, embed_dim = pos_emb.shape 377 | cls_emb = pos_emb[:, :1, :] 378 | pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2) 379 | pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False) 380 | pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1) 381 | pos_emb = torch.cat([cls_emb, pos_emb], dim=1) 382 | self.embeddings.position_embedding = nn.Parameter(pos_emb) 383 | self.embeddings.image_size = new_size 384 | logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size)) 385 | 386 | def get_input_embeddings(self): 387 | return self.embeddings 388 | 389 | 390 | def forward( 391 | self, 392 | pixel_values: Optional[torch.FloatTensor] = None, 393 | output_hidden_states: Optional[bool] = None, 394 | return_dict: Optional[bool] = None, 395 | pixel_embeds: Optional[torch.FloatTensor] = None, 396 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 397 | 398 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 399 | 400 | if pixel_values is None: 401 | raise ValueError('You have to specify pixel_values') 402 | 403 | 404 | if len(pixel_values.shape) == 4: 405 | hidden_states = self.embeddings(pixel_values)[:,1:] 406 | else: 407 | raise ValueError(f'wrong pixel_values size: {pixel_values.shape}') 408 | 409 | 410 | if not return_dict: 411 | return (hidden_states, None,None) 412 | 413 | return BaseModelOutputWithPooling( 414 | last_hidden_state=hidden_states, 415 | pooler_output=None, 416 | hidden_states=None, 417 | attentions=None, 418 | ) 419 | 420 | 421 | -------------------------------------------------------------------------------- /internvl/model/internvl_chat/modeling_internvl_chat.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2025 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | import warnings 7 | from typing import Any, List, Optional, Tuple, Union 8 | 9 | import torch.distributed as dist 10 | import torch.utils.checkpoint 11 | from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM 12 | from internvl.model.internlm2.modeling_internlm2_ve import InternLM2VEForCausalLM 13 | from peft import LoraConfig, get_peft_model 14 | from torch import nn 15 | from torch.nn import CrossEntropyLoss 16 | from transformers import (AutoModel, GenerationConfig, LlamaForCausalLM, 17 | LlamaTokenizer, Qwen2ForCausalLM) 18 | from 
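# Illustrative sketch (not from the original file; refers to the classes in
# modeling_intern_vit.py above, with arbitrary sizes): unlike InternVisionModel,
# InternVisionPatchModel only embeds patches and drops the CLS token, so a
# 448x448 input yields 32*32 = 1024 visual tokens before any pixel-shuffle
# downsampling in the chat model.
patch_model = InternVisionPatchModel(InternVisionPatchConfig(hidden_size=1024))
out = patch_model(pixel_values=torch.randn(1, 3, 448, 448), return_dict=True)
print(out.last_hidden_state.shape)  # torch.Size([1, 1024, 1024])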
transformers.modeling_outputs import CausalLMOutputWithPast 19 | from transformers.modeling_utils import PreTrainedModel 20 | from transformers.utils import ModelOutput, logging 21 | 22 | from .configuration_internvl_chat import InternVLChatConfig 23 | from .modeling_intern_vit import InternVisionModel,InternVisionPatchModel 24 | from dataclasses import dataclass 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | @dataclass 29 | class CausalLMOutputWithVisualMask(ModelOutput): 30 | """ 31 | Base class for causal language model (or autoregressive) outputs. 32 | 33 | Args: 34 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): 35 | Language modeling loss (for next-token prediction). 36 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): 37 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). 38 | past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): 39 | Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape 40 | `(batch_size, num_heads, sequence_length, embed_size_per_head)`) 41 | 42 | Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see 43 | `past_key_values` input) to speed up sequential decoding. 44 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 45 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 46 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 47 | 48 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 49 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 50 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 51 | sequence_length)`. 52 | 53 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 54 | heads. 
55 | """ 56 | 57 | loss: Optional[torch.FloatTensor] = None 58 | logits: torch.FloatTensor = None 59 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 60 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 61 | attentions: Optional[Tuple[torch.FloatTensor]] = None 62 | visual_token_mask : Optional[torch.FloatTensor] = None 63 | 64 | class InternVLChatModel(PreTrainedModel): 65 | config_class = InternVLChatConfig 66 | main_input_name = 'pixel_values' 67 | _no_split_modules = ['InternVision', 'LlamaDecoderLayer', 'InternLM2DecoderLayer', 68 | 'Phi3DecoderLayer', 'Qwen2DecoderLayer'] 69 | 70 | def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None): 71 | super().__init__(config) 72 | 73 | image_size = config.force_image_size or config.vision_config.image_size 74 | patch_size = config.vision_config.patch_size 75 | self.patch_size = patch_size 76 | self.select_layer = config.select_layer 77 | self.template = config.template 78 | self.num_image_token = int((image_size // patch_size) ** 2 * (config.downsample_ratio ** 2)) 79 | self.downsample_ratio = config.downsample_ratio 80 | self.ps_version = config.ps_version 81 | self.use_visual_token_mask=False 82 | 83 | logger.info(f'num_image_token: {self.num_image_token}') 84 | logger.info(f'ps_version: {self.ps_version}') 85 | if vision_model is not None: 86 | self.vision_model = vision_model 87 | elif 'intern_vit_6b' in config.vision_config.model_type: 88 | self.vision_model = InternVisionModel(config.vision_config) 89 | elif 'intern_vit_patch' in config.vision_config.model_type: 90 | self.vision_model = InternVisionPatchModel(config.vision_config) 91 | else: 92 | assert NotImplementedError 93 | 94 | 95 | if language_model is not None: 96 | self.language_model = language_model 97 | else: 98 | if config.llm_config.architectures[0] == 'LlamaForCausalLM': 99 | self.language_model = LlamaForCausalLM(config.llm_config) 100 | elif config.llm_config.architectures[0] == 'InternLM2ForCausalLM': 101 | self.language_model = InternLM2ForCausalLM(config.llm_config) 102 | elif config.llm_config.architectures[0] == 'Qwen2ForCausalLM': 103 | self.language_model = Qwen2ForCausalLM(config.llm_config) 104 | elif config.llm_config.architectures[0]=='InternLM2VEForCausalLM': 105 | self.language_model=InternLM2VEForCausalLM(config.llm_config) 106 | self.use_visual_token_mask=True 107 | else: 108 | raise NotImplementedError(f'{config.llm_config.architectures[0]} is not implemented.') 109 | 110 | self.language_arch=config.llm_config.architectures[0] 111 | vit_hidden_size = config.vision_config.hidden_size 112 | llm_hidden_size = config.llm_config.hidden_size 113 | 114 | self.mlp1 = nn.Sequential( 115 | nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2), 116 | nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size), 117 | nn.GELU(), 118 | nn.Linear(llm_hidden_size, llm_hidden_size) 119 | ) 120 | 121 | self.img_context_token_id = None 122 | self.neftune_alpha = None 123 | self.num_samples = 0 124 | 125 | if config.use_backbone_lora: 126 | self.wrap_backbone_lora(r=config.use_backbone_lora, lora_alpha=2 * config.use_backbone_lora) 127 | 128 | if config.use_llm_lora: 129 | self.wrap_llm_lora(r=config.use_llm_lora, lora_alpha=2 * config.use_llm_lora) 130 | 131 | def wrap_backbone_lora(self, r=128, lora_alpha=256, lora_dropout=0.05): 132 | lora_config = LoraConfig( 133 | r=r, 134 | target_modules=['attn.qkv', 'attn.proj', 'mlp.fc1', 'mlp.fc2'], 135 | lora_alpha=lora_alpha, 136 | 
lora_dropout=lora_dropout, 137 | ) 138 | self.vision_model = get_peft_model(self.vision_model, lora_config) 139 | self.vision_model.print_trainable_parameters() 140 | 141 | def wrap_llm_lora(self, r=128, lora_alpha=256, lora_dropout=0.05): 142 | lora_config = LoraConfig( 143 | r=r, 144 | target_modules=['self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj', 145 | 'mlp.gate_proj', 'mlp.down_proj', 'mlp.up_proj'], 146 | lora_alpha=lora_alpha, 147 | lora_dropout=lora_dropout, 148 | task_type='CAUSAL_LM' 149 | ) 150 | self.language_model = get_peft_model(self.language_model, lora_config) 151 | self.language_model.enable_input_require_grads() 152 | self.language_model.print_trainable_parameters() 153 | 154 | def forward( 155 | self, 156 | pixel_values: torch.FloatTensor, 157 | input_ids: torch.LongTensor = None, 158 | attention_mask: Optional[torch.Tensor] = None, 159 | position_ids: Optional[torch.LongTensor] = None, 160 | image_flags: Optional[torch.LongTensor] = None, 161 | past_key_values: Optional[List[torch.FloatTensor]] = None, 162 | labels: Optional[torch.LongTensor] = None, 163 | use_cache: Optional[bool] = None, 164 | output_attentions: Optional[bool] = None, 165 | output_hidden_states: Optional[bool] = None, 166 | return_dict: Optional[bool] = None, 167 | statistics: Optional[torch.LongTensor] = None, 168 | loss_weight: Optional[List] = None, 169 | loss_reduction_all_gather: Optional[bool] = False, 170 | ) -> Union[Tuple, CausalLMOutputWithVisualMask]: 171 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 172 | 173 | image_flags = image_flags.squeeze(-1) 174 | input_embeds = self.language_model.get_input_embeddings()(input_ids).clone() 175 | 176 | vit_embeds = self.extract_feature(pixel_values).to(input_embeds.dtype) 177 | vit_embeds = vit_embeds[image_flags == 1] 178 | vit_batch_size = pixel_values.shape[0] 179 | 180 | B, N, C = input_embeds.shape 181 | input_embeds = input_embeds.reshape(B * N, C) 182 | 183 | if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: 184 | print(f'dynamic ViT batch size: {vit_batch_size}, images per sample: {vit_batch_size / B}, dynamic token length: {N}') 185 | if statistics is not None: 186 | num_samples, num_padding_tokens, num_padding_images = statistics.tolist() 187 | self.num_samples += num_samples 188 | print(f'total_samples={self.num_samples}, {num_samples=}, {num_padding_tokens=}, {num_padding_images=}') 189 | 190 | input_ids = input_ids.reshape(B * N) 191 | selected = (input_ids == self.img_context_token_id) 192 | 193 | try: 194 | input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds.reshape(-1, C) 195 | ignore_flag = False 196 | except Exception as e: 197 | vit_embeds = vit_embeds.reshape(-1, C) 198 | print(f'warning: {e}, input_embeds[selected].shape={input_embeds[selected].shape}, ' 199 | f'vit_embeds.shape={vit_embeds.shape}') 200 | n_token = selected.sum() 201 | input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds[:n_token] 202 | ignore_flag = True 203 | 204 | input_embeds = input_embeds.reshape(B, N, C) 205 | 206 | if self.language_arch=='InternLM2VEForCausalLM': 207 | visual_token_mask = selected.reshape(B,N,1).to(input_embeds.dtype) 208 | outputs = self.language_model( 209 | inputs_embeds=input_embeds, 210 | attention_mask=attention_mask, 211 | position_ids=position_ids, 212 | past_key_values=past_key_values, 213 | use_cache=use_cache, 214 | output_attentions=output_attentions, 215 | 
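# Illustrative toy sketch (not from the original file; the token id 92546 is an
# assumption used only for illustration): every position whose input id equals
# img_context_token_id has its text embedding overwritten by one ViT embedding,
# in order, mirroring the try-block in forward above.
input_ids = torch.tensor([1, 92546, 92546, 5])   # flattened (B*N,)
input_embeds = torch.zeros(4, 8)                 # (B*N, C)
vit_embeds = torch.ones(2, 8)                    # one row per image token
selected = input_ids == 92546
input_embeds[selected] = input_embeds[selected] * 0.0 + vit_embeds
print(input_embeds.sum(dim=1))  # tensor([0., 8., 8., 0.])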
output_hidden_states=output_hidden_states, 216 | return_dict=return_dict, 217 | visual_token_mask=visual_token_mask 218 | ) 219 | else: 220 | outputs = self.language_model( 221 | inputs_embeds=input_embeds, 222 | attention_mask=attention_mask, 223 | position_ids=position_ids, 224 | past_key_values=past_key_values, 225 | use_cache=use_cache, 226 | output_attentions=output_attentions, 227 | output_hidden_states=output_hidden_states, 228 | return_dict=return_dict 229 | ) 230 | logits = outputs.logits 231 | 232 | loss = None 233 | if labels is not None and loss_weight is not None: 234 | loss_weight = torch.tensor(loss_weight, dtype=torch.float32, device=labels.device) 235 | # Shift so that tokens < n predict n 236 | shift_logits = logits[..., :-1, :].contiguous() 237 | shift_labels = labels[..., 1:].contiguous() 238 | shift_weights = loss_weight[..., 1:].contiguous() 239 | # Flatten the tokens 240 | loss_fct = CrossEntropyLoss(reduction='none') 241 | shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size) 242 | shift_labels = shift_labels.view(-1) 243 | shift_weights = shift_weights.view(-1) 244 | # Enable model parallelism 245 | shift_labels = shift_labels.to(shift_logits.device) 246 | shift_weights = shift_weights.to(shift_logits.device) 247 | loss = loss_fct(shift_logits, shift_labels) 248 | 249 | shift_weights_sum = shift_weights.sum() 250 | if loss_reduction_all_gather: 251 | dist.all_reduce(shift_weights_sum, op=dist.ReduceOp.AVG) 252 | 253 | loss = loss * shift_weights 254 | loss = loss.sum() / shift_weights_sum 255 | if ignore_flag: 256 | loss = loss * 0.0 257 | elif labels is not None: 258 | # Shift so that tokens < n predict n 259 | shift_logits = logits[..., :-1, :].contiguous() 260 | shift_labels = labels[..., 1:].contiguous() 261 | # Flatten the tokens 262 | loss_fct = CrossEntropyLoss() 263 | shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size) 264 | shift_labels = shift_labels.view(-1) 265 | # Enable model parallelism 266 | shift_labels = shift_labels.to(shift_logits.device) 267 | loss = loss_fct(shift_logits, shift_labels) 268 | if ignore_flag: 269 | loss = loss * 0.0 270 | 271 | if not return_dict: 272 | output = (logits,) + outputs[1:] 273 | return (loss,) + output if loss is not None else output 274 | 275 | return CausalLMOutputWithVisualMask( 276 | loss=loss, 277 | logits=logits, 278 | past_key_values=outputs.past_key_values, 279 | hidden_states=outputs.hidden_states, 280 | attentions=outputs.attentions, 281 | visual_token_mask=selected.reshape(B,N,1).to(input_embeds.dtype) 282 | ) 283 | 284 | def pixel_shuffle(self, x, scale_factor=0.5): 285 | n, w, h, c = x.size() 286 | # N, W, H, C --> N, W, H * scale, C // scale 287 | x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) 288 | # N, W, H * scale, C // scale --> N, H * scale, W, C // scale 289 | x = x.permute(0, 2, 1, 3).contiguous() 290 | # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2) 291 | x = x.view(n, int(h * scale_factor), int(w * scale_factor), 292 | int(c / (scale_factor * scale_factor))) 293 | if self.ps_version == 'v1': 294 | warnings.warn("In ps_version 'v1', the height and width have not been swapped back, " 295 | 'which results in a transposed image.') 296 | else: 297 | x = x.permute(0, 2, 1, 3).contiguous() 298 | return x 299 | 300 | def noised_embed(self, vit_embeds, noise_alpha=5): 301 | dims = torch.tensor(vit_embeds.size(1) * vit_embeds.size(2)) 302 | mag_norm = noise_alpha / torch.sqrt(dims) 303 | noise = 
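# Illustrative shape check (not from the original file; assumes a constructed
# InternVLChatModel instance `model` -- pixel_shuffle only reads
# model.ps_version): a 32x32 grid of 1024-d features becomes a 16x16 grid of
# 4096-d features, i.e. 4x fewer visual tokens at 4x the channel width.
grid = torch.randn(1, 32, 32, 1024)
shuffled = model.pixel_shuffle(grid, scale_factor=0.5)
print(shuffled.shape)  # torch.Size([1, 16, 16, 4096])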
torch.zeros_like(vit_embeds).uniform_(-mag_norm, mag_norm) 304 | return vit_embeds + noise 305 | 306 | def extract_feature(self, pixel_values): 307 | if self.select_layer == -1: 308 | vit_embeds = self.vision_model( 309 | pixel_values=pixel_values, 310 | output_hidden_states=False, 311 | return_dict=True).last_hidden_state 312 | else: 313 | vit_embeds = self.vision_model( 314 | pixel_values=pixel_values, 315 | output_hidden_states=True, 316 | return_dict=True).hidden_states[self.select_layer] 317 | if int(vit_embeds.shape[1] ** 0.5)**2 != vit_embeds.shape[1]: 318 | vit_embeds = vit_embeds[:, 1:, :] 319 | 320 | if self.training and self.neftune_alpha is not None: 321 | vit_embeds = self.noised_embed(vit_embeds, self.neftune_alpha) 322 | 323 | h = w = int(vit_embeds.shape[1] ** 0.5) 324 | vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) 325 | vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) 326 | vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) 327 | vit_embeds = self.mlp1(vit_embeds)#.to(pixel_values.device) 328 | return vit_embeds 329 | 330 | def batch_chat(self, tokenizer, pixel_values, image_counts, questions, generation_config, history=None, 331 | return_history=False, IMG_START_TOKEN='', IMG_END_TOKEN='', 332 | IMG_CONTEXT_TOKEN=''): 333 | if history is not None or return_history: 334 | print('Now multi-turn chat is not supported in batch_chat.') 335 | raise NotImplementedError 336 | img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) 337 | self.img_context_token_id = img_context_token_id 338 | 339 | from internvl.conversation import get_conv_template 340 | 341 | queries = [] 342 | image_bs = pixel_values.shape[0] 343 | 344 | for idx, image_count in enumerate(image_counts): 345 | image_token = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * image_count + IMG_END_TOKEN 346 | question = image_token + '\n' + questions[idx] 347 | template = get_conv_template(self.template) 348 | template.append_message(template.roles[0], question) 349 | template.append_message(template.roles[1], None) 350 | query = template.get_prompt() 351 | queries.append(query) 352 | tokenizer.padding_side = 'left' 353 | model_inputs = tokenizer(queries, return_tensors='pt', padding=True) 354 | input_ids = model_inputs['input_ids'].cuda() 355 | attention_mask = model_inputs['attention_mask'].cuda() 356 | eos_token_id = tokenizer.convert_tokens_to_ids(template.sep) 357 | generation_config['eos_token_id'] = eos_token_id 358 | 359 | generation_output = self.generate( 360 | pixel_values=pixel_values, 361 | input_ids=input_ids, 362 | attention_mask=attention_mask, 363 | **generation_config 364 | ) 365 | responses = tokenizer.batch_decode(generation_output, skip_special_tokens=True) 366 | responses = [response.split(template.sep)[0].strip() for response in responses] 367 | return responses 368 | 369 | def chat(self, tokenizer, pixel_values, question, generation_config, history=None, return_history=False, 370 | IMG_START_TOKEN='', IMG_END_TOKEN='', IMG_CONTEXT_TOKEN=''): 371 | 372 | img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) 373 | self.img_context_token_id = img_context_token_id 374 | 375 | from internvl.conversation import get_conv_template 376 | 377 | template = get_conv_template(self.template) 378 | eos_token_id = tokenizer.convert_tokens_to_ids(template.sep) 379 | 380 | image_bs = pixel_values.shape[0] 381 | print(f'dynamic ViT batch size: {image_bs}') 382 | if history is None: 383 | 
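# Illustrative sketch (not from the original file; arbitrary sizes): the
# NEFTune-style noise injected by noised_embed is uniform, with a bound that
# shrinks as num_tokens * hidden_size grows, so larger feature maps receive
# proportionally smaller per-element perturbations.
import math
feats = torch.zeros(1, 256, 2048)
mag = 5 / math.sqrt(256 * 2048)
noisy = feats + torch.zeros_like(feats).uniform_(-mag, mag)
assert noisy.abs().max().item() <= mag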
history = [] 384 | image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * image_bs + IMG_END_TOKEN 385 | question = image_tokens + '\n' + question 386 | else: 387 | for (old_question, old_answer) in history: 388 | template.append_message(template.roles[0], old_question) 389 | template.append_message(template.roles[1], old_answer) 390 | template.append_message(template.roles[0], question) 391 | template.append_message(template.roles[1], None) 392 | query = template.get_prompt() 393 | model_inputs = tokenizer(query, return_tensors='pt') 394 | input_ids = model_inputs['input_ids'].cuda() 395 | attention_mask = model_inputs['attention_mask'].cuda() 396 | generation_config['eos_token_id'] = eos_token_id 397 | generation_output = self.generate( 398 | pixel_values=pixel_values, 399 | input_ids=input_ids, 400 | attention_mask=attention_mask, 401 | **generation_config 402 | ) 403 | response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] 404 | response = response.split(template.sep)[0].strip() 405 | history.append((question, response)) 406 | if return_history: 407 | return response, history 408 | else: 409 | query_to_print = query.replace(image_tokens, '') 410 | print(query_to_print, response) 411 | return response 412 | return response 413 | 414 | def multi_image_chat(self, tokenizer, pixel_values, image_counts, question, generation_config, history=None, 415 | return_history=False, IMG_START_TOKEN='', IMG_END_TOKEN='', IMG_CONTEXT_TOKEN=''): 416 | 417 | img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) 418 | self.img_context_token_id = img_context_token_id 419 | 420 | from internvl.conversation import get_conv_template 421 | 422 | template = get_conv_template(self.template) 423 | eos_token_id = tokenizer.convert_tokens_to_ids(template.sep) 424 | 425 | if history is None: 426 | history = [] 427 | image_tokens = '' 428 | image_bs = pixel_values.shape[0] 429 | print(f'dynamic ViT batch size: {image_bs}, image_counts: {image_counts}') 430 | for idx, image_count in enumerate(image_counts): 431 | image_tokens += f' (图{idx+1}):' + IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * image_count + IMG_END_TOKEN 432 | question = image_tokens + '\n' + question 433 | else: 434 | for (old_question, old_answer) in history: 435 | template.append_message(template.roles[0], old_question) 436 | template.append_message(template.roles[1], old_answer) 437 | template.append_message(template.roles[0], question) 438 | template.append_message(template.roles[1], None) 439 | query = template.get_prompt() 440 | model_inputs = tokenizer(query, return_tensors='pt') 441 | input_ids = model_inputs['input_ids'].cuda() 442 | attention_mask = model_inputs['attention_mask'].cuda() 443 | generation_config['eos_token_id'] = eos_token_id 444 | 445 | generation_output = self.generate( 446 | pixel_values=pixel_values, 447 | input_ids=input_ids, 448 | attention_mask=attention_mask, 449 | **generation_config 450 | ) 451 | response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] 452 | response = response.split(template.sep)[0].strip() 453 | history.append((question, response)) 454 | if return_history: 455 | return response, history 456 | else: 457 | query_to_print = query.replace(image_tokens, '') 458 | print(query_to_print, response) 459 | return response 460 | return response 461 | 462 | @torch.no_grad() 463 | def generate( 464 | self, 465 | pixel_values: Optional[torch.FloatTensor] = None, 466 | input_ids: Optional[torch.FloatTensor] = 
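# Illustrative usage sketch (not from the original file; assumes a loaded
# model, tokenizer and a preprocessed pixel_values tensor on the GPU; the
# image-token strings below are assumptions and must match the tokenizer's
# special tokens, since the defaults in the signature above are empty):
generation_config = dict(max_new_tokens=256, do_sample=False)
response = model.chat(
    tokenizer, pixel_values, 'Describe the image in detail.', generation_config,
    IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>',
    IMG_CONTEXT_TOKEN='<IMG_CONTEXT>',
)
print(response)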
None, 467 | attention_mask: Optional[torch.LongTensor] = None, 468 | visual_features: Optional[torch.FloatTensor] = None, 469 | generation_config: Optional[GenerationConfig] = None, 470 | output_hidden_states: Optional[bool] = None, 471 | return_dict: Optional[bool] = None, 472 | **generate_kwargs, 473 | ) -> torch.LongTensor: 474 | 475 | assert self.img_context_token_id is not None 476 | if pixel_values is not None: 477 | if visual_features is not None: 478 | vit_embeds = visual_features 479 | else: 480 | vit_embeds = self.extract_feature(pixel_values) 481 | 482 | input_embeds = self.language_model.get_input_embeddings()(input_ids) 483 | B, N, C = input_embeds.shape 484 | input_embeds = input_embeds.reshape(B * N, C) 485 | 486 | input_ids = input_ids.reshape(B * N) 487 | selected = (input_ids == self.img_context_token_id) 488 | assert selected.sum() != 0 489 | input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device) 490 | 491 | input_embeds = input_embeds.reshape(B, N, C) 492 | visual_token_mask = selected.reshape(B,N,1).to(input_embeds.dtype) 493 | else: 494 | input_embeds = self.language_model.get_input_embeddings()(input_ids) 495 | visual_token_mask = torch.zero_like(input_ids).reshape(B,N,1) 496 | 497 | if self.use_visual_token_mask: 498 | outputs = self.language_model.generate( 499 | inputs_embeds=input_embeds, 500 | attention_mask=attention_mask, 501 | generation_config=generation_config, 502 | output_hidden_states=output_hidden_states, 503 | return_dict=return_dict, 504 | use_cache=True, 505 | visual_token_mask=visual_token_mask, 506 | **generate_kwargs, 507 | ) 508 | else: 509 | outputs = self.language_model.generate( 510 | inputs_embeds=input_embeds, 511 | attention_mask=attention_mask, 512 | generation_config=generation_config, 513 | output_hidden_states=output_hidden_states, 514 | return_dict=return_dict, 515 | use_cache=True, 516 | **generate_kwargs, 517 | ) 518 | return outputs 519 | -------------------------------------------------------------------------------- /internvl/patch/__init__.py: -------------------------------------------------------------------------------- 1 | from .internlm2_packed_training_patch import replace_internlm2_attention_class 2 | from .llama2_flash_attn_monkey_patch import replace_llama2_attn_with_flash_attn 3 | from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 4 | from .llama_rmsnorm_monkey_patch import \ 5 | replace_llama_rmsnorm_with_fused_rmsnorm 6 | from .pad_data_collator import concat_pad_data_collator, pad_data_collator 7 | from .qwen2_packed_training_patch import replace_qwen2_attention_class 8 | from .train_dataloader_patch import replace_train_dataloader, replace_unwapper_train_dataloader 9 | from .train_sampler_patch import replace_train_sampler, replace_sequence_train_sampler 10 | 11 | __all__ = ['replace_llama_attn_with_flash_attn', 12 | 'replace_llama_rmsnorm_with_fused_rmsnorm', 13 | 'replace_llama2_attn_with_flash_attn', 14 | 'replace_train_sampler', 15 | 'replace_sequence_train_sampler', 16 | 'replace_train_dataloader', 17 | 'replace_unwapper_train_dataloader', 18 | 'replace_internlm2_attention_class', 19 | 'replace_qwen2_attention_class', 20 | 'pad_data_collator', 21 | 'concat_pad_data_collator'] 22 | -------------------------------------------------------------------------------- /internvl/patch/internlm2_packed_training_patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flash_attn.flash_attn_interface import 
flash_attn_varlen_func 3 | from internvl.model.internlm2.modeling_internlm2 import ( 4 | INTERNLM2_ATTENTION_CLASSES, InternLM2FlashAttention2, 5 | apply_rotary_pos_emb) 6 | 7 | 8 | # Modified from internvl.model.internlm2.modeling_internlm2.InternLM2FlashAttention2 9 | class InternLM2FlashAttention2ForPackedTraining(InternLM2FlashAttention2): 10 | 11 | def _flash_attention_forward( 12 | self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None 13 | ): 14 | """ 15 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 16 | first unpad the input, then computes the attention scores and pad the final attention scores. 17 | 18 | Args: 19 | query_states (`torch.Tensor`): 20 | Input query states to be passed to Flash Attention API 21 | key_states (`torch.Tensor`): 22 | Input key states to be passed to Flash Attention API 23 | value_states (`torch.Tensor`): 24 | Input value states to be passed to Flash Attention API 25 | attention_mask (`torch.Tensor`): 26 | rename from cu_seqlens to keep compatability - (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 27 | of the sequences in the batch. 28 | dropout (`int`, *optional*): 29 | Attention dropout 30 | softmax_scale (`float`, *optional*): 31 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 32 | """ 33 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 34 | query_states = query_states.squeeze(0) 35 | key_states = key_states.squeeze(0) 36 | value_states = value_states.squeeze(0) 37 | cu_seqlens = attention_mask.squeeze(0) 38 | 39 | with torch.no_grad(): 40 | max_seqlen = max([ 41 | cu_seqlens[idx+1] - cu_seqlens[idx] 42 | for idx in range(cu_seqlens.size(0) - 1) 43 | ]).item() 44 | 45 | # Contains at least one padding token in the sequence 46 | causal = self.is_causal and query_length != 1 47 | attn_output = flash_attn_varlen_func( 48 | q=query_states, 49 | k=key_states, 50 | v=value_states, 51 | cu_seqlens_q=cu_seqlens, 52 | cu_seqlens_k=cu_seqlens, 53 | max_seqlen_q=max_seqlen, 54 | max_seqlen_k=max_seqlen, 55 | dropout_p=dropout, 56 | softmax_scale=softmax_scale, 57 | causal=causal, 58 | ) 59 | 60 | query_states = query_states.unsqueeze(0) 61 | key_states = key_states.unsqueeze(0) 62 | value_states = value_states.unsqueeze(0) 63 | return attn_output 64 | 65 | 66 | def replace_internlm2_attention_class(): 67 | INTERNLM2_ATTENTION_CLASSES['flash_attention_2'] = InternLM2FlashAttention2ForPackedTraining 68 | print('Replace INTERNLM2_ATTENTION_CLASSES to support packed training!!') 69 | -------------------------------------------------------------------------------- /internvl/patch/llama2_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is copied from: https://github.com/lm-sys/FastChat 3 | """ 4 | import warnings 5 | from typing import Optional, Tuple 6 | 7 | import torch 8 | from flash_attn import __version__ as flash_attn_version 9 | from flash_attn.bert_padding import pad_input, unpad_input 10 | from flash_attn.flash_attn_interface import (flash_attn_func, 11 | flash_attn_varlen_kvpacked_func) 12 | from transformers.models.llama.modeling_llama import (LlamaAttention, 13 | LlamaModel, rotate_half) 14 | 15 | 16 | def apply_rotary_pos_emb(q, k, cos_sin, position_ids): 17 | gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1] 18 | gather_indices = gather_indices.repeat( 19 | 
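# Illustrative sketch (not from the original file; toy numbers): in the
# packed-training patch above, three sequences of lengths 5, 3 and 4 are
# concatenated into a single row, and the attention_mask argument carries their
# cumulative lengths, from which the per-batch max_seqlen is recovered.
cu_seqlens = torch.tensor([0, 5, 8, 12], dtype=torch.int32)
max_seqlen = max(int(cu_seqlens[i + 1] - cu_seqlens[i])
                 for i in range(cu_seqlens.size(0) - 1))
print(max_seqlen)  # 5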
1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3] 20 | ) 21 | bsz = gather_indices.shape[0] 22 | cos, sin = ( 23 | torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices) 24 | for x in cos_sin 25 | ) 26 | q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k)) 27 | return q, k 28 | 29 | 30 | def forward( 31 | self, 32 | hidden_states: torch.Tensor, 33 | attention_mask: Optional[torch.Tensor] = None, 34 | position_ids: Optional[torch.Tensor] = None, 35 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 36 | output_attentions: bool = False, 37 | use_cache: bool = False, 38 | padding_mask: Optional[torch.Tensor] = None, 39 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 40 | if output_attentions: 41 | warnings.warn( 42 | 'Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.' 43 | ) 44 | 45 | bsz, q_len, _ = hidden_states.size() 46 | kv_heads = getattr(self, 'num_key_value_heads', self.num_heads) 47 | 48 | q, k, v = ( 49 | op(hidden_states).view(bsz, q_len, nh, self.head_dim) 50 | for op, nh in ( 51 | (self.q_proj, self.num_heads), 52 | (self.k_proj, kv_heads), 53 | (self.v_proj, kv_heads), 54 | ) 55 | ) 56 | # shape: (b, s, num_heads, head_dim) 57 | 58 | kv_seq_len = k.shape[1] 59 | past_kv_len = 0 60 | if past_key_value is not None: 61 | past_kv_len = past_key_value[0].shape[2] 62 | kv_seq_len += past_kv_len 63 | 64 | cos_sin = self.rotary_emb(v, seq_len=kv_seq_len) 65 | q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids) 66 | 67 | if past_key_value is not None: 68 | assert ( 69 | flash_attn_version >= '2.1.0' 70 | ), 'past_key_value support requires flash-attn >= 2.1.0' 71 | # reuse k, v 72 | k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1) 73 | v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1) 74 | 75 | past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None 76 | 77 | if attention_mask is None: 78 | output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view( 79 | bsz, q_len, -1 80 | ) 81 | else: 82 | q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) 83 | # We can skip concat and call unpad twice but seems better to call unpad only once. 84 | kv, _, cu_k_lens, max_k = unpad_input( 85 | torch.stack((k, v), dim=2), attention_mask 86 | ) 87 | output_unpad = flash_attn_varlen_kvpacked_func( 88 | q, 89 | kv, 90 | cu_q_lens, 91 | cu_k_lens, 92 | max_s, 93 | max_k, 94 | 0.0, 95 | softmax_scale=None, 96 | causal=True, 97 | ) 98 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 99 | output = pad_input(output_unpad, indices, bsz, q_len) 100 | 101 | return self.o_proj(output), None, past_key_value 102 | 103 | 104 | # Disable the transformation of the attention mask in LlamaModel as flash attention 105 | # takes a boolean key_padding_mask. Fills in the past kv length for use in forward. 
106 | def _prepare_decoder_attention_mask( 107 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 108 | ): 109 | # [bsz, seq_len] 110 | if past_key_values_length > 0 and attention_mask is not None: 111 | attention_mask = torch.cat( 112 | ( 113 | torch.full( 114 | (input_shape[0], past_key_values_length), 115 | True, 116 | dtype=attention_mask.dtype, 117 | device=attention_mask.device, 118 | ), 119 | attention_mask, 120 | ), 121 | dim=-1, 122 | ) 123 | 124 | if attention_mask is not None and torch.all(attention_mask): 125 | return None # This uses the faster call when training with full samples 126 | 127 | return attention_mask 128 | 129 | 130 | def replace_llama2_attn_with_flash_attn(): 131 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 132 | if cuda_major < 8: 133 | warnings.warn( 134 | 'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward.' 135 | 'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593' 136 | ) 137 | 138 | LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 139 | LlamaAttention.forward = forward 140 | 141 | 142 | def test(): 143 | from fastchat.train.llama_flash_attn_monkey_patch import \ 144 | forward as fastchat_forward 145 | from transformers.models.llama.configuration_llama import LlamaConfig 146 | 147 | config = LlamaConfig( 148 | hidden_size=1024, 149 | intermediate_size=128, 150 | num_hidden_layers=1, 151 | num_attention_heads=8, 152 | max_position_embeddings=16, 153 | ) 154 | device = torch.device('cuda') 155 | model = LlamaModel(config) 156 | attn = LlamaAttention(config).to(device).half() 157 | bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings 158 | position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view( 159 | -1, seqlen 160 | ) 161 | 162 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) 163 | for i in range(4): 164 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) 165 | if i: 166 | mask[0, -i:] = False 167 | mask[1, :i] = False 168 | 169 | lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0) 170 | ref, _, _ = attn.forward( 171 | hidden, attention_mask=lmask, position_ids=position_ids 172 | ) 173 | 174 | fast, _, _ = fastchat_forward( 175 | attn, hidden, attention_mask=mask, position_ids=position_ids 176 | ) 177 | 178 | lmask = _prepare_decoder_attention_mask( 179 | model, mask, hidden.shape[:2], hidden, 0 180 | ) 181 | test, _, _ = forward( 182 | attn, hidden, attention_mask=lmask, position_ids=position_ids 183 | ) 184 | 185 | print(f'Mean(abs(ref)) = {torch.mean(torch.abs(ref))}') 186 | print(f'Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}') 187 | print(f'Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}') 188 | print(f'Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}') 189 | print(f'allclose(fast, test) = {torch.allclose(fast, test)}') 190 | 191 | with torch.no_grad(): 192 | # Also check that past_kv is handled properly 193 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) 194 | part_len = seqlen // 4 195 | assert part_len * 4 == seqlen 196 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) 197 | mask[0, -2:] = False 198 | lmask = _prepare_decoder_attention_mask( 199 | model, mask, hidden.shape[:2], hidden, 0 200 | ) 201 | oneshot, _, _ = forward( 202 | attn, hidden, attention_mask=lmask, position_ids=position_ids 
203 | ) 204 | parts = [] 205 | past_kv, past_kv_len = None, 0 206 | for i in range(4): 207 | start = part_len * i 208 | end = start + part_len 209 | hidden_part = hidden[:, start:end, ...] 210 | lmask = _prepare_decoder_attention_mask( 211 | model, 212 | mask[:, start:end], 213 | hidden_part.shape[:2], 214 | hidden_part, 215 | past_kv_len, 216 | ) 217 | part, _, past_kv = forward( 218 | attn, 219 | hidden_part.clone(), 220 | attention_mask=lmask, 221 | position_ids=position_ids[:, start:end], 222 | past_key_value=past_kv, 223 | use_cache=True, 224 | ) 225 | parts.append(part) 226 | past_kv_len = past_kv[0].shape[2] 227 | 228 | print( 229 | f'allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}' 230 | ) 231 | print( 232 | f'allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}' 233 | ) 234 | 235 | 236 | if __name__ == '__main__': 237 | test() 238 | -------------------------------------------------------------------------------- /internvl/patch/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import transformers 7 | from torch import nn 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | 11 | def forward( 12 | self, 13 | hidden_states: torch.Tensor, 14 | attention_mask: Optional[torch.Tensor] = None, 15 | position_ids: Optional[torch.Tensor] = None, 16 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 17 | output_attentions: bool = False, 18 | use_cache: bool = False, 19 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 20 | """Input shape: Batch x Time x Channel 21 | 22 | attention_mask: [bsz, q_len] 23 | """ 24 | from einops import rearrange 25 | try: # v1 26 | from flash_attn.flash_attn_interface import \ 27 | flash_attn_unpadded_qkvpacked_func 28 | except: # v2 29 | from flash_attn.flash_attn_interface import \ 30 | flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 31 | from flash_attn.bert_padding import pad_input, unpad_input 32 | 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | # [bsz, q_len, nh, hd] 51 | # [bsz, nh, q_len, hd] 52 | 53 | kv_seq_len = key_states.shape[-2] 54 | assert past_key_value is None, 'past_key_value is not supported' 55 | 56 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 57 | query_states, key_states = apply_rotary_pos_emb( 58 | query_states, key_states, cos, sin, position_ids 59 | ) 60 | # [bsz, nh, t, hd] 61 | assert not output_attentions, 'output_attentions is not supported' 62 | assert not use_cache, 'use_cache is not supported' 63 | 64 | # Flash attention codes from 65 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 66 | 67 | # transform the data into the format required by flash attention 68 | qkv = torch.stack( 69 | [query_states, key_states, value_states], dim=2 70 | ) # [bsz, nh, 3, q_len, hd] 71 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 72 | # We 
have disabled _prepare_decoder_attention_mask in LlamaModel 73 | # the attention_mask should be the same as the key_padding_mask 74 | key_padding_mask = attention_mask 75 | 76 | if key_padding_mask is None: 77 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 78 | max_s = q_len 79 | cu_q_lens = torch.arange( 80 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 81 | ) 82 | output = flash_attn_unpadded_qkvpacked_func( 83 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 84 | ) 85 | output = rearrange(output, '(b s) ... -> b s ...', b=bsz) 86 | else: 87 | nheads = qkv.shape[-2] 88 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 89 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 90 | x_unpad = rearrange( 91 | x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads 92 | ) 93 | output_unpad = flash_attn_unpadded_qkvpacked_func( 94 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 95 | ) 96 | output = rearrange( 97 | pad_input( 98 | rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices, bsz, q_len 99 | ), 100 | 'b s (h d) -> b s h d', 101 | h=nheads, 102 | ) 103 | return self.o_proj(rearrange(output, 'b s h d -> b s (h d)')), None, None 104 | 105 | 106 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 107 | # requires the attention mask to be the same as the key_padding_mask 108 | def _prepare_decoder_attention_mask( 109 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 110 | ): 111 | # [bsz, seq_len] 112 | return attention_mask 113 | 114 | 115 | def forward_2( 116 | self, 117 | hidden_states: torch.Tensor, 118 | attention_mask: Optional[torch.Tensor] = None, 119 | position_ids: Optional[torch.LongTensor] = None, 120 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 121 | output_attentions: bool = False, 122 | use_cache: bool = False, 123 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 124 | bsz, q_len, _ = hidden_states.size() 125 | 126 | query_states = ( 127 | self.q_proj(hidden_states) 128 | .view(bsz, q_len, self.num_heads, self.head_dim) 129 | .transpose(1, 2) 130 | ) 131 | key_states = ( 132 | self.k_proj(hidden_states) 133 | .view(bsz, q_len, self.num_heads, self.head_dim) 134 | .transpose(1, 2) 135 | ) 136 | value_states = ( 137 | self.v_proj(hidden_states) 138 | .view(bsz, q_len, self.num_heads, self.head_dim) 139 | .transpose(1, 2) 140 | ) 141 | 142 | kv_seq_len = key_states.shape[-2] 143 | if past_key_value is not None: 144 | kv_seq_len += past_key_value[0].shape[-2] 145 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 146 | query_states, key_states = apply_rotary_pos_emb( 147 | query_states, key_states, cos, sin, position_ids 148 | ) 149 | 150 | assert not output_attentions, 'output_attentions is not supported' 151 | assert not use_cache, 'use_cache is not supported' 152 | assert past_key_value is None, 'past_key_value is not supported' 153 | 154 | if past_key_value is not None: 155 | # reuse k, v, self_attention 156 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 157 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 158 | 159 | past_key_value = (key_states, value_states) if use_cache else None 160 | if self.training: 161 | attn_output = F.scaled_dot_product_attention( 162 | query_states, key_states, value_states, dropout_p=0.0, is_causal=True 163 | ) 164 | attn_weights = None 165 | else: 166 | attn_weights = torch.matmul( 167 | query_states, 
key_states.transpose(2, 3) 168 | ) / math.sqrt(self.head_dim) 169 | 170 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 171 | raise ValueError( 172 | f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is' 173 | f' {attn_weights.size()}' 174 | ) 175 | 176 | if attention_mask is not None: 177 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 178 | raise ValueError( 179 | f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' 180 | ) 181 | attn_weights = attn_weights + attention_mask 182 | attn_weights = torch.max( 183 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 184 | ) 185 | 186 | # upcast attention to fp32 187 | attn_weights = nn.functional.softmax( 188 | attn_weights, dim=-1, dtype=torch.float32 189 | ).to(query_states.dtype) 190 | attn_output = torch.matmul(attn_weights, value_states) 191 | 192 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 193 | raise ValueError( 194 | f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' 195 | f' {attn_output.size()}' 196 | ) 197 | 198 | attn_output = attn_output.transpose(1, 2) 199 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 200 | 201 | attn_output = self.o_proj(attn_output) 202 | 203 | if not output_attentions: 204 | attn_weights = None 205 | 206 | return attn_output, attn_weights, past_key_value 207 | 208 | 209 | def replace_llama_attn_with_flash_attn(): 210 | if hasattr(F, 'scaled_dot_product_attention'): 211 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_2 212 | else: 213 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 214 | _prepare_decoder_attention_mask 215 | ) 216 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 217 | -------------------------------------------------------------------------------- /internvl/patch/llama_rmsnorm_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | 4 | def replace_llama_rmsnorm_with_fused_rmsnorm(): 5 | try: 6 | from functools import partial 7 | 8 | from apex.normalization import FusedRMSNorm 9 | LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa 10 | transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm 11 | print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm') 12 | except ImportError: 13 | # using the normal LlamaRMSNorm 14 | pass 15 | except Exception: 16 | print('discovered apex but it failed to load, falling back to LlamaRMSNorm') 17 | pass 18 | -------------------------------------------------------------------------------- /internvl/patch/pad_data_collator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | IGNORE_INDEX = -100 5 | 6 | 7 | def pad_data_collator(features, pad_id=0): 8 | 9 | first = features[0] 10 | batch = {} 11 | 12 | batch_lens = [feat['input_ids'].shape for feat in features] 13 | max_item_length = max(batch_lens)[0] 14 | for idx in range(len(features)): 15 | feat = features[idx] 16 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length) 17 | temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids'] 18 | feat['input_ids'] = temp_input_ids 19 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) 20 | temp_labels[:feat['labels'].shape[0]] = 
feat['labels'] 21 | feat['labels'] = temp_labels 22 | feat['attention_mask'] = feat['input_ids'].ne(pad_id) 23 | 24 | # Special handling for labels. 25 | # Ensure that tensor is created with the correct type 26 | # (it should be automatically the case, but let's make sure of it.) 27 | if 'label' in first and first['label'] is not None: 28 | label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label'] 29 | dtype = torch.long if isinstance(label, int) else torch.float 30 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) 31 | elif 'label_ids' in first and first['label_ids'] is not None: 32 | if isinstance(first['label_ids'], torch.Tensor): 33 | batch['labels'] = torch.stack([f['label_ids'] for f in features]) 34 | else: 35 | dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float 36 | batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype) 37 | 38 | # Handling of all other possible keys. 39 | # Again, we will use the first element to figure out which key/values are not None for this model. 40 | for k, v in first.items(): 41 | if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str): 42 | if isinstance(v, torch.Tensor): 43 | batch[k] = torch.stack([f[k] for f in features]) 44 | elif isinstance(v, np.ndarray): 45 | batch[k] = torch.tensor(np.stack([f[k] for f in features])) 46 | else: 47 | batch[k] = torch.tensor([f[k] for f in features]) 48 | return batch 49 | 50 | 51 | def concat_pad_data_collator(features, max_item_length=None, pad_id=0): 52 | 53 | first = features[0] 54 | batch = {} 55 | 56 | batch_lens = [feat['input_ids'].shape for feat in features] 57 | max_item_length = max_item_length or max(batch_lens)[0] 58 | for idx in range(len(features)): 59 | feat = features[idx] 60 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length) 61 | temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids'] 62 | feat['input_ids'] = temp_input_ids 63 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) 64 | temp_labels[:feat['labels'].shape[0]] = feat['labels'] 65 | feat['labels'] = temp_labels 66 | feat['attention_mask'] = feat['input_ids'].ne(pad_id) 67 | 68 | if 'position_ids' in feat: 69 | temp_position_ids = torch.LongTensor([pad_id] * max_item_length) 70 | temp_position_ids[:feat['position_ids'].shape[0]] = feat['position_ids'] 71 | feat['position_ids'] = temp_position_ids 72 | 73 | if 'loss_weight' in feat: 74 | temp_loss_weight = torch.FloatTensor([pad_id] * max_item_length) 75 | temp_loss_weight[:feat['loss_weight'].shape[0]] = feat['loss_weight'] 76 | feat['loss_weight'] = temp_loss_weight 77 | 78 | # Special handling for labels. 79 | # Ensure that tensor is created with the correct type 80 | # (it should be automatically the case, but let's make sure of it.) 
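    # The per-token 'labels' padded with IGNORE_INDEX above are distinct from the optional
    # 'label'/'label_ids' keys handled below; they flow through the generic tensor handling
    # further down, while 'pixel_values' and 'image_flags' are concatenated along dim 0 so
    # each sample may contribute a variable number of image tiles.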
81 | if 'label' in first and first['label'] is not None: 82 | label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label'] 83 | dtype = torch.long if isinstance(label, int) else torch.float 84 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) 85 | elif 'label_ids' in first and first['label_ids'] is not None: 86 | if isinstance(first['label_ids'], torch.Tensor): 87 | batch['labels'] = torch.stack([f['label_ids'] for f in features]) 88 | else: 89 | dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float 90 | batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype) 91 | 92 | # Handling of all other possible keys. 93 | # Again, we will use the first element to figure out which key/values are not None for this model. 94 | for k, v in first.items(): 95 | if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \ 96 | v is not None and not isinstance(v, str): 97 | if isinstance(v, torch.Tensor): 98 | batch[k] = torch.stack([f[k] for f in features]) 99 | elif isinstance(v, np.ndarray): 100 | batch[k] = torch.tensor(np.stack([f[k] for f in features])) 101 | else: 102 | batch[k] = torch.tensor([f[k] for f in features]) 103 | if k in ('pixel_values', 'image_flags'): 104 | if isinstance(v, torch.Tensor): 105 | batch[k] = torch.concat([f[k] for f in features]) 106 | elif isinstance(v, np.ndarray): 107 | batch[k] = torch.concat(np.stack([f[k] for f in features])) 108 | else: 109 | batch[k] = torch.concat([f[k] for f in features]) 110 | return batch 111 | -------------------------------------------------------------------------------- /internvl/patch/qwen2_packed_training_patch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 3 | from transformers.models.qwen2.modeling_qwen2 import (QWEN2_ATTENTION_CLASSES, 4 | Qwen2FlashAttention2) 5 | 6 | 7 | # Modified from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 8 | class Qwen2FlashAttention2ForPackedTraining(Qwen2FlashAttention2): 9 | 10 | def _flash_attention_forward( 11 | self, 12 | query_states, 13 | key_states, 14 | value_states, 15 | attention_mask, 16 | query_length, 17 | dropout=0.0, 18 | softmax_scale=None, 19 | use_sliding_windows=False, 20 | ): 21 | """ 22 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 23 | first unpad the input, then computes the attention scores and pad the final attention scores. 24 | 25 | Args: 26 | query_states (`torch.Tensor`): 27 | Input query states to be passed to Flash Attention API 28 | key_states (`torch.Tensor`): 29 | Input key states to be passed to Flash Attention API 30 | value_states (`torch.Tensor`): 31 | Input value states to be passed to Flash Attention API 32 | attention_mask (`torch.Tensor`): 33 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 34 | position of padding tokens and 1 for the position of non-padding tokens. 35 | dropout (`int`, *optional*): 36 | Attention dropout 37 | softmax_scale (`float`, *optional*): 38 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 39 | use_sliding_windows (`bool`, *optional*): 40 | Whether to activate sliding window attention. 
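        Note:
            In this packed-training patch, `attention_mask` is not a padding mask: it carries the
            cumulative sequence lengths (cu_seqlens) of the sequences packed into the single batch
            row, e.g. tensor([0, 812, 1624, 2048], dtype=torch.int32) for three packed sequences,
            and is passed to `flash_attn_varlen_func` as both `cu_seqlens_q` and `cu_seqlens_k`.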
41 | """ 42 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 43 | query_states = query_states.squeeze(0) 44 | key_states = key_states.squeeze(0) 45 | value_states = value_states.squeeze(0) 46 | cu_seqlens = attention_mask.squeeze(0) 47 | 48 | with torch.no_grad(): 49 | max_seqlen = max([ 50 | cu_seqlens[idx+1] - cu_seqlens[idx] 51 | for idx in range(cu_seqlens.size(0) - 1) 52 | ]).item() 53 | 54 | if not self._flash_attn_uses_top_left_mask: 55 | causal = self.is_causal 56 | else: 57 | causal = self.is_causal and query_length != 1 58 | 59 | # Decide whether to use SWA or not by layer index. 60 | if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: 61 | use_sliding_windows = False 62 | 63 | if not use_sliding_windows: 64 | attn_output = flash_attn_varlen_func( 65 | q=query_states, 66 | k=key_states, 67 | v=value_states, 68 | cu_seqlens_q=cu_seqlens, 69 | cu_seqlens_k=cu_seqlens, 70 | max_seqlen_q=max_seqlen, 71 | max_seqlen_k=max_seqlen, 72 | dropout_p=dropout, 73 | softmax_scale=softmax_scale, 74 | causal=causal, 75 | ) 76 | else: 77 | attn_output = flash_attn_varlen_func( 78 | q=query_states, 79 | k=key_states, 80 | v=value_states, 81 | cu_seqlens_q=cu_seqlens, 82 | cu_seqlens_k=cu_seqlens, 83 | max_seqlen_q=max_seqlen, 84 | max_seqlen_k=max_seqlen, 85 | dropout_p=dropout, 86 | softmax_scale=softmax_scale, 87 | causal=causal, 88 | window_size=(self.config.sliding_window, self.config.sliding_window), 89 | ) 90 | 91 | query_states = query_states.unsqueeze(0) 92 | key_states = key_states.unsqueeze(0) 93 | value_states = value_states.unsqueeze(0) 94 | return attn_output 95 | 96 | 97 | def replace_qwen2_attention_class(): 98 | QWEN2_ATTENTION_CLASSES['flash_attention_2'] = Qwen2FlashAttention2ForPackedTraining 99 | print('Replace QWEN2_ATTENTION_CLASSES to support packed training!!') 100 | -------------------------------------------------------------------------------- /internvl/patch/train_dataloader_patch.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import torch 3 | import transformers 4 | from torch.utils.data import DataLoader 5 | from transformers.trainer import is_datasets_available, seed_worker 6 | 7 | 8 | def get_train_dataloader(self) -> DataLoader: 9 | """ 10 | Returns the training [`~torch.utils.data.DataLoader`]. 11 | 12 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 13 | training if necessary) otherwise. 14 | 15 | Subclass and override this method if you want to inject some custom behavior. 
16 | """ 17 | if self.train_dataset is None: 18 | raise ValueError('Trainer: training requires a train_dataset.') 19 | 20 | train_dataset = self.train_dataset 21 | data_collator = self.data_collator 22 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): 23 | train_dataset = self._remove_unused_columns(train_dataset, description='training') 24 | else: 25 | data_collator = self._get_collator_with_removed_columns(data_collator, description='training') 26 | 27 | dataloader_params = { 28 | 'batch_size': self._train_batch_size, 29 | 'collate_fn': data_collator, 30 | 'num_workers': self.args.dataloader_num_workers, 31 | 'pin_memory': self.args.dataloader_pin_memory, 32 | 'persistent_workers': self.args.dataloader_persistent_workers, 33 | } 34 | 35 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 36 | dataloader_params['sampler'] = self._get_train_sampler() 37 | dataloader_params['drop_last'] = self.args.dataloader_drop_last 38 | dataloader_params['worker_init_fn'] = seed_worker 39 | 40 | if self.args.use_packed_ds: 41 | return DataLoader(train_dataset, **dataloader_params) 42 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) 43 | 44 | 45 | 46 | def get_unwapper_train_dataloader(self) -> DataLoader: 47 | """ 48 | Returns the training [`~torch.utils.data.DataLoader`]. 49 | 50 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 51 | training if necessary) otherwise. 52 | 53 | Subclass and override this method if you want to inject some custom behavior. 54 | """ 55 | if self.train_dataset is None: 56 | raise ValueError('Trainer: training requires a train_dataset.') 57 | 58 | train_dataset = self.train_dataset 59 | data_collator = self.data_collator 60 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): 61 | train_dataset = self._remove_unused_columns(train_dataset, description='training') 62 | else: 63 | data_collator = self._get_collator_with_removed_columns(data_collator, description='training') 64 | 65 | dataloader_params = { 66 | 'batch_size': self._train_batch_size, 67 | 'collate_fn': data_collator, 68 | 'num_workers': self.args.dataloader_num_workers, 69 | 'pin_memory': self.args.dataloader_pin_memory, 70 | 'persistent_workers': self.args.dataloader_persistent_workers, 71 | } 72 | 73 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 74 | dataloader_params['sampler'] = self._get_train_sampler() 75 | dataloader_params['drop_last'] = self.args.dataloader_drop_last 76 | dataloader_params['worker_init_fn'] = seed_worker 77 | 78 | return DataLoader(train_dataset, **dataloader_params) 79 | 80 | 81 | def replace_train_dataloader(): 82 | transformers.Trainer.get_train_dataloader = get_train_dataloader 83 | print('Replace train dataloader!!') 84 | 85 | 86 | def replace_unwapper_train_dataloader(): 87 | transformers.Trainer.get_train_dataloader = get_unwapper_train_dataloader 88 | print('Replace train dataloader!!') -------------------------------------------------------------------------------- /internvl/patch/train_sampler_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | import transformers 5 | from torch.utils.data import Dataset, Sampler,SequentialSampler 6 | from transformers.tokenization_utils_base import BatchEncoding 7 | from transformers.trainer import (LengthGroupedSampler, RandomSampler, 8 | has_length) 9 | 
from transformers.trainer_pt_utils import logger 10 | 11 | 12 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38 13 | def split_to_even_chunks(indices, lengths, num_chunks): 14 | """ 15 | Split a list of indices into `chunks` chunks of roughly equal lengths. 16 | """ 17 | 18 | if len(indices) % num_chunks != 0: 19 | return [indices[i::num_chunks] for i in range(num_chunks)] 20 | 21 | num_indices_per_chunk = len(indices) // num_chunks 22 | 23 | chunks = [[] for _ in range(num_chunks)] 24 | chunks_lengths = [0 for _ in range(num_chunks)] 25 | for index in indices: 26 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 27 | chunks[shortest_chunk].append(index) 28 | chunks_lengths[shortest_chunk] += lengths[index] 29 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 30 | chunks_lengths[shortest_chunk] = float('inf') 31 | 32 | return chunks 33 | 34 | 35 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88 36 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 37 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 38 | indices = torch.randperm(len(lengths), generator=generator) 39 | megabatch_size = world_size * batch_size 40 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 41 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 42 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 43 | 44 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 45 | 46 | 47 | # modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99 48 | class LengthGroupedSampler(Sampler): 49 | r""" 50 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 51 | keeping a bit of randomness. 52 | """ 53 | 54 | def __init__( 55 | self, 56 | batch_size: int, 57 | world_size: int, 58 | dataset: Optional[Dataset] = None, 59 | lengths: Optional[List[int]] = None, 60 | model_input_name: Optional[str] = None, 61 | generator=None, 62 | ): 63 | if dataset is None and lengths is None: 64 | raise ValueError('One of dataset and lengths must be provided.') 65 | 66 | self.batch_size = batch_size 67 | if lengths is None: 68 | model_input_name = model_input_name if model_input_name is not None else 'input_ids' 69 | if ( 70 | not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding)) 71 | or model_input_name not in dataset[0] 72 | ): 73 | raise ValueError( 74 | 'Can only automatically infer lengths for datasets whose items are dictionaries with an ' 75 | f"'{model_input_name}' key." 76 | ) 77 | lengths = [len(feature[model_input_name]) for feature in dataset] 78 | elif isinstance(lengths, torch.Tensor): 79 | logger.info( 80 | 'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...' 
81 | ) 82 | lengths = lengths.tolist() 83 | self.world_size = world_size 84 | self.lengths = lengths 85 | self.generator = generator 86 | 87 | def __len__(self): 88 | return len(self.lengths) 89 | 90 | def __iter__(self): 91 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 92 | return iter(indices) 93 | 94 | 95 | # patch trainer 96 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 97 | if self.train_dataset is None or not has_length(self.train_dataset): 98 | return None 99 | # Build the sampler. 100 | if self.args.group_by_length: 101 | lengths = [] 102 | for dataset in self.train_dataset.datasets: 103 | lengths = lengths + dataset.length 104 | model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None 105 | return LengthGroupedSampler( 106 | self.args.train_batch_size, 107 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 108 | # self.args.train_batch_size * self.args.gradient_accumulation_steps, 109 | dataset=self.train_dataset, 110 | lengths=lengths, 111 | model_input_name=model_input_name, 112 | ) 113 | else: 114 | return RandomSampler(self.train_dataset) 115 | 116 | # patch trainer 117 | def _get_sequence_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 118 | if self.train_dataset is None or not has_length(self.train_dataset): 119 | return None 120 | 121 | # Build the sampler. 122 | print("Using SequentialSampler") 123 | return SequentialSampler(self.train_dataset) 124 | 125 | 126 | def replace_train_sampler(): 127 | transformers.Trainer._get_train_sampler = _get_train_sampler 128 | print('Replace train sampler!!') 129 | 130 | def replace_sequence_train_sampler(): 131 | transformers.Trainer._get_train_sampler = _get_sequence_train_sampler 132 | print('Replace train sampler!!') -------------------------------------------------------------------------------- /internvl/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/Mono-InternVL/fcd0381544fc83ac544bade229a70d7a05aa4614/internvl/serve/__init__.py -------------------------------------------------------------------------------- /internvl/serve/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = '.' 
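# build_logger() in internvl/serve/utils.py creates its rotating log files under LOGDIR;
# '.' means the current working directory.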
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = '' 10 | DEFAULT_IMAGE_PATCH_TOKEN = '' 11 | DEFAULT_IM_START_TOKEN = '' 12 | DEFAULT_IM_END_TOKEN = '' 13 | IMAGE_PLACEHOLDER = '' 14 | -------------------------------------------------------------------------------- /internvl/serve/mm_utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from io import BytesIO 3 | 4 | import torch 5 | from PIL import Image 6 | from transformers import StoppingCriteria 7 | 8 | from .constants import IMAGE_TOKEN_INDEX 9 | 10 | 11 | def load_image_from_base64(image): 12 | return Image.open(BytesIO(base64.b64decode(image))) 13 | 14 | 15 | def expand2square(pil_img, background_color): 16 | width, height = pil_img.size 17 | if width == height: 18 | return pil_img 19 | elif width > height: 20 | result = Image.new(pil_img.mode, (width, width), background_color) 21 | result.paste(pil_img, (0, (width - height) // 2)) 22 | return result 23 | else: 24 | result = Image.new(pil_img.mode, (height, height), background_color) 25 | result.paste(pil_img, ((height - width) // 2, 0)) 26 | return result 27 | 28 | 29 | def process_images(images, image_processor, model_cfg): 30 | image_aspect_ratio = getattr(model_cfg, 'image_aspect_ratio', None) 31 | new_images = [] 32 | if image_aspect_ratio == 'pad': 33 | for image in images: 34 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 35 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 36 | new_images.append(image) 37 | else: 38 | return image_processor(images, return_tensors='pt')['pixel_values'] 39 | if all(x.shape == new_images[0].shape for x in new_images): 40 | new_images = torch.stack(new_images, dim=0) 41 | return new_images 42 | 43 | 44 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, 45 | num_image_tokens=None, return_tensors=None): 46 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 47 | 48 | def insert_separator(X, sep): 49 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 50 | 51 | input_ids = [] 52 | offset = 0 53 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 54 | offset = 1 55 | input_ids.append(prompt_chunks[0][0]) 56 | 57 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + num_image_tokens)): 58 | input_ids.extend(x[offset:]) 59 | 60 | if return_tensors is not None: 61 | if return_tensors == 'pt': 62 | return torch.tensor(input_ids, dtype=torch.long) 63 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 64 | return input_ids 65 | 66 | 67 | def get_model_name_from_path(model_path): 68 | model_path = model_path.strip('/') 69 | model_paths = model_path.split('/') 70 | if model_paths[-1].startswith('checkpoint-'): 71 | return model_paths[-2] + '_' + model_paths[-1] 72 | else: 73 | return model_paths[-1] 74 | 75 | 76 | class KeywordsStoppingCriteria(StoppingCriteria): 77 | def __init__(self, keywords, tokenizer, input_ids): 78 | self.keywords = keywords 79 | self.keyword_ids = [] 80 | self.max_keyword_len = 0 81 | for keyword in keywords: 82 | cur_keyword_ids = tokenizer(keyword).input_ids 83 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 84 | cur_keyword_ids = cur_keyword_ids[1:] 85 | if len(cur_keyword_ids) > self.max_keyword_len: 86 | self.max_keyword_len = 
len(cur_keyword_ids) 87 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 88 | self.tokenizer = tokenizer 89 | self.start_len = input_ids.shape[1] 90 | 91 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 92 | assert output_ids.shape[0] == 1, 'Only support batch size 1 (yet)' 93 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 94 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 95 | for keyword_id in self.keyword_ids: 96 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 97 | return True 98 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 99 | for keyword in self.keywords: 100 | if keyword in outputs: 101 | return True 102 | return False 103 | -------------------------------------------------------------------------------- /internvl/serve/model_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker executes the model. 3 | """ 4 | import argparse 5 | import asyncio 6 | import json 7 | import threading 8 | import time 9 | import uuid 10 | from functools import partial 11 | from threading import Thread 12 | 13 | import requests 14 | import torch 15 | import uvicorn 16 | from fastapi import BackgroundTasks, FastAPI, Request 17 | from fastapi.responses import StreamingResponse 18 | from internvl.train.dataset import dynamic_preprocess 19 | from transformers import (AutoTokenizer, CLIPImageProcessor, 20 | TextIteratorStreamer) 21 | 22 | from ..model.internvl_chat import InternVLChatModel 23 | from .constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, 24 | DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN, 25 | IMAGE_TOKEN_INDEX, WORKER_HEART_BEAT_INTERVAL) 26 | from .mm_utils import (KeywordsStoppingCriteria, load_image_from_base64, 27 | process_images, tokenizer_image_token) 28 | from .utils import build_logger, pretty_print_semaphore, server_error_msg 29 | 30 | GB = 1 << 30 31 | 32 | worker_id = str(uuid.uuid4())[:6] 33 | logger = build_logger('model_worker', f'model_worker_{worker_id}.log') 34 | global_counter = 0 35 | 36 | model_semaphore = None 37 | 38 | 39 | def heart_beat_worker(controller): 40 | 41 | while True: 42 | time.sleep(WORKER_HEART_BEAT_INTERVAL) 43 | controller.send_heart_beat() 44 | 45 | 46 | class ModelWorker: 47 | def __init__(self, controller_addr, worker_addr, 48 | worker_id, no_register, 49 | model_path, model_base, model_name, 50 | load_8bit, load_4bit, device): 51 | self.controller_addr = controller_addr 52 | self.worker_addr = worker_addr 53 | self.worker_id = worker_id 54 | if model_path.endswith('/'): 55 | model_path = model_path[:-1] 56 | if model_name is None: 57 | model_paths = model_path.split('/') 58 | if model_paths[-1].startswith('checkpoint-'): 59 | self.model_name = model_paths[-2] + '_' + model_paths[-1] 60 | else: 61 | self.model_name = model_paths[-1] 62 | else: 63 | self.model_name = model_name 64 | 65 | logger.info(f'Loading the model {self.model_name} on worker {worker_id} ...') 66 | 67 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=False) 68 | if device == 'auto': 69 | import os 70 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 71 | # This can make distributed deployment work properly, wonder why 72 | self.model = InternVLChatModel.from_pretrained( 73 | model_path, load_in_8bit=load_8bit, torch_dtype=torch.float16, device_map='auto').eval() 74 | else: 75 | 
self.model = InternVLChatModel.from_pretrained( 76 | model_path, load_in_8bit=load_8bit, torch_dtype=torch.float16).eval() 77 | if not load_8bit and not device == 'auto': 78 | self.model = self.model.cuda() 79 | self.image_size = self.model.config.force_image_size 80 | self.image_processor = CLIPImageProcessor( 81 | crop_size=self.image_size, do_center_crop=True, do_normalize=True, do_resize=True, 82 | image_mean=[0.485, 0.456, 0.406], image_std=[0.229, 0.224, 0.225], size=self.image_size 83 | ) 84 | self.context_len = 12800 85 | self.is_multimodal = True 86 | 87 | if not no_register: 88 | self.register_to_controller() 89 | self.heart_beat_thread = threading.Thread( 90 | target=heart_beat_worker, args=(self,)) 91 | self.heart_beat_thread.start() 92 | 93 | def register_to_controller(self): 94 | logger.info('Register to controller') 95 | 96 | url = self.controller_addr + '/register_worker' 97 | data = { 98 | 'worker_name': self.worker_addr, 99 | 'check_heart_beat': True, 100 | 'worker_status': self.get_status() 101 | } 102 | r = requests.post(url, json=data) 103 | assert r.status_code == 200 104 | 105 | def send_heart_beat(self): 106 | logger.info(f'Send heart beat. Models: {[self.model_name]}. ' 107 | f'Semaphore: {pretty_print_semaphore(model_semaphore)}. ' 108 | f'global_counter: {global_counter}') 109 | 110 | url = self.controller_addr + '/receive_heart_beat' 111 | 112 | while True: 113 | try: 114 | ret = requests.post(url, json={ 115 | 'worker_name': self.worker_addr, 116 | 'queue_length': self.get_queue_length()}, timeout=5) 117 | exist = ret.json()['exist'] 118 | break 119 | except requests.exceptions.RequestException as e: 120 | logger.error(f'heart beat error: {e}') 121 | time.sleep(5) 122 | 123 | if not exist: 124 | self.register_to_controller() 125 | 126 | def get_queue_length(self): 127 | if model_semaphore is None: 128 | return 0 129 | else: 130 | return args.limit_model_concurrency - model_semaphore._value + (len( 131 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0) 132 | 133 | def get_status(self): 134 | return { 135 | 'model_names': [self.model_name], 136 | 'speed': 1, 137 | 'queue_length': self.get_queue_length(), 138 | } 139 | 140 | @torch.inference_mode() 141 | def generate_stream(self, params): 142 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor 143 | 144 | prompt = params['prompt'] 145 | max_input_tiles = params['max_input_tiles'] 146 | logger.info(f'max_input_tiles: {max_input_tiles}') 147 | ori_prompt = prompt 148 | images = params.get('images', None) 149 | num_image_tokens = 0 150 | if images is not None and len(images) > 0 and self.is_multimodal: 151 | if len(images) > 0: 152 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): 153 | raise ValueError('Number of images does not match number of tokens in prompt') 154 | logger.info(f'dynamic_image_size: {model.config.dynamic_image_size}') 155 | logger.info(f'use_thumbnail: {model.config.use_thumbnail}') 156 | images = [load_image_from_base64(image) for image in images] 157 | if model.config.dynamic_image_size: 158 | images = dynamic_preprocess( 159 | images[0], image_size=self.image_size, max_num=max_input_tiles, 160 | use_thumbnail=model.config.use_thumbnail) 161 | images = [item.resize((self.image_size, self.image_size)) for item in images] 162 | logger.info(f'Resize images to {self.image_size}x{self.image_size}') 163 | images = process_images(images, image_processor, model.config) 164 | 165 | if type(images) is list: 166 | images = 
[image.to(self.model.device, dtype=torch.float16) for image in images] 167 | else: 168 | images = images.to(self.model.device, dtype=torch.float16) 169 | # images = torch.concat(images) 170 | logger.info(f'Split images to {images.shape}') 171 | 172 | replace_token = DEFAULT_IMAGE_TOKEN 173 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 174 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 175 | logger.info(prompt) 176 | num_image_tokens = model.num_image_token * images.size(0) 177 | model.img_context_token_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_PATCH_TOKEN) 178 | else: 179 | images = None 180 | image_args = {'pixel_values': images} 181 | else: 182 | images = None 183 | image_args = {} 184 | 185 | temperature = float(params.get('temperature', 1.0)) 186 | top_p = float(params.get('top_p', 1.0)) 187 | max_context_length = getattr(model.config, 'max_position_embeddings', 16384) 188 | max_new_tokens = int(params.get('max_new_tokens', 1024)) 189 | stop_str = params.get('stop', None) 190 | do_sample = True if temperature > 0.001 else False 191 | logger.info(f'num_image_tokens: {num_image_tokens}') 192 | logger.info(f'stop_str: {stop_str}') 193 | eos_token_id = tokenizer.convert_tokens_to_ids(stop_str) 194 | 195 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, num_image_tokens, return_tensors='pt').unsqueeze(0).cuda() 196 | input_ids[input_ids==IMAGE_TOKEN_INDEX] = model.img_context_token_id 197 | 198 | keywords = [stop_str] 199 | # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 200 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) 201 | 202 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1]) 203 | logger.info(f'max_new_tokens: {max_new_tokens}') 204 | if max_new_tokens < 1: 205 | yield json.dumps({'text': ori_prompt + 'Exceeds max token length. 
Please start a new conversation, thanks.', 'error_code': 0}).encode() + b'\0' 206 | return 207 | 208 | thread = Thread(target=model.generate, kwargs=dict( 209 | input_ids=input_ids, 210 | do_sample=do_sample, 211 | temperature=temperature, 212 | repetition_penalty=1.0, 213 | top_p=top_p, 214 | max_new_tokens=max_new_tokens, 215 | streamer=streamer, 216 | eos_token_id=eos_token_id, 217 | **image_args 218 | )) 219 | thread.start() 220 | 221 | generated_text = ori_prompt 222 | for new_text in streamer: 223 | generated_text += new_text 224 | if generated_text.endswith(stop_str): 225 | generated_text = generated_text[:-len(stop_str)] 226 | yield json.dumps({'text': generated_text, 'error_code': 0}).encode() + b'\0' 227 | 228 | def generate_stream_gate(self, params): 229 | try: 230 | for x in self.generate_stream(params): 231 | yield x 232 | except ValueError as e: 233 | print('Caught ValueError:', e) 234 | ret = { 235 | 'text': server_error_msg, 236 | 'error_code': 1, 237 | } 238 | yield json.dumps(ret).encode() + b'\0' 239 | except torch.cuda.CudaError as e: 240 | print('Caught torch.cuda.CudaError:', e) 241 | ret = { 242 | 'text': server_error_msg, 243 | 'error_code': 1, 244 | } 245 | yield json.dumps(ret).encode() + b'\0' 246 | except Exception as e: 247 | print('Caught Unknown Error', e) 248 | ret = { 249 | 'text': server_error_msg, 250 | 'error_code': 1, 251 | } 252 | yield json.dumps(ret).encode() + b'\0' 253 | 254 | 255 | app = FastAPI() 256 | 257 | 258 | def release_model_semaphore(fn=None): 259 | model_semaphore.release() 260 | if fn is not None: 261 | fn() 262 | 263 | 264 | @app.post('/worker_generate_stream') 265 | async def generate_stream(request: Request): 266 | global model_semaphore, global_counter 267 | global_counter += 1 268 | params = await request.json() 269 | 270 | if model_semaphore is None: 271 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) 272 | await model_semaphore.acquire() 273 | worker.send_heart_beat() 274 | generator = worker.generate_stream_gate(params) 275 | background_tasks = BackgroundTasks() 276 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) 277 | return StreamingResponse(generator, background=background_tasks) 278 | 279 | 280 | @app.post('/worker_get_status') 281 | async def get_status(request: Request): 282 | return worker.get_status() 283 | 284 | 285 | if __name__ == '__main__': 286 | parser = argparse.ArgumentParser() 287 | parser.add_argument('--host', type=str, default='localhost') 288 | parser.add_argument('--port', type=int, default=21002) 289 | parser.add_argument('--worker-address', type=str, 290 | default='http://localhost:21002') 291 | parser.add_argument('--controller-address', type=str, 292 | default='http://localhost:21001') 293 | parser.add_argument('--model-path', type=str, default='facebook/opt-350m') 294 | parser.add_argument('--model-base', type=str, default=None) 295 | parser.add_argument('--model-name', type=str) 296 | parser.add_argument('--device', type=str, default='cuda') 297 | parser.add_argument('--multi-modal', action='store_true', help='Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.') 298 | parser.add_argument('--limit-model-concurrency', type=int, default=5) 299 | parser.add_argument('--stream-interval', type=int, default=1) 300 | parser.add_argument('--no-register', action='store_true') 301 | parser.add_argument('--load-8bit', action='store_true') 302 | parser.add_argument('--load-4bit', 
action='store_true') 303 | args = parser.parse_args() 304 | logger.info(f'args: {args}') 305 | 306 | if args.multi_modal: 307 | logger.warning('Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.') 308 | 309 | worker = ModelWorker(args.controller_address, 310 | args.worker_address, 311 | worker_id, 312 | args.no_register, 313 | args.model_path, 314 | args.model_base, 315 | args.model_name, 316 | args.load_8bit, 317 | args.load_4bit, 318 | args.device) 319 | uvicorn.run(app, host=args.host, port=args.port, log_level='info') 320 | -------------------------------------------------------------------------------- /internvl/serve/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from .constants import LOGDIR 10 | 11 | server_error_msg = '**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**' 12 | moderation_msg = 'YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN.' 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt='%(asctime)s | %(levelname)s | %(name)s | %(message)s', 22 | datefmt='%Y-%m-%d %H:%M:%S', 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger('stdout') 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger('stderr') 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 
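            # Complete lines (ending in '\n') are logged immediately; a trailing partial line
            # is buffered in self.linebuf and only emitted by flush() or a later write().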
82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) 99 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = 'https://api.openai.com/v1/moderations' 107 | headers = {'Content-Type': 'application/json', 108 | 'Authorization': 'Bearer ' + os.environ['OPENAI_API_KEY']} 109 | text = text.replace('\n', '') 110 | data = '{' + '"input": ' + f'"{text}"' + '}' 111 | data = data.encode('utf-8') 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()['results'][0]['flagged'] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return 'None' 126 | return f'Semaphore(value={semaphore._value}, locked={semaphore.locked()})' 127 | -------------------------------------------------------------------------------- /internvl/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/Mono-InternVL/fcd0381544fc83ac544bade229a70d7a05aa4614/internvl/train/__init__.py -------------------------------------------------------------------------------- /internvl/train/constants.py: -------------------------------------------------------------------------------- 1 | IMG_CONTEXT_TOKEN = '' 2 | IMG_START_TOKEN = '' 3 | IMG_END_TOKEN = '' 4 | QUAD_START_TOKEN = '' 5 | QUAD_END_TOKEN = '' 6 | REF_START_TOKEN = '' 7 | REF_END_TOKEN = '' 8 | BOX_START_TOKEN = '' 9 | BOX_END_TOKEN = '' 10 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 11 | IMAGENET_STD = (0.229, 0.224, 0.225) 12 | CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073) 13 | CLIP_STD = (0.2686295, 0.2613025, 0.2757711) 14 | SIGLIP_MEAN = (0.5, 0.5, 0.5) 15 | SIGLIP_STD = (0.5, 0.5, 0.5) 16 | -------------------------------------------------------------------------------- /internvl/train/dataset.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | from transformers.trainer_pt_utils import LabelSmoother 4 | 5 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 6 | from typing import Dict 7 | 8 | import torch 9 | import torchvision.transforms as T 10 | import transformers 11 | from internvl.conversation import get_conv_template 12 | from PIL import Image 13 | from torch.utils.data import ConcatDataset, WeightedRandomSampler 14 | from torchvision.transforms.functional import InterpolationMode 15 | 16 | from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD, 17 | IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN, 18 | SIGLIP_MEAN, SIGLIP_STD) 19 | 20 | import sys 21 | 22 | 23 | class WeightedConcatDataset(ConcatDataset): 24 | def __init__(self, datasets, weights): 25 | super().__init__(datasets) 26 | self.weights = torch.DoubleTensor(weights) 27 | self.total_size = sum(len(d) for d in datasets) 28 | self.sampler = 
WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True) 29 | 30 | def __iter__(self): 31 | return iter(self.sampler) 32 | 33 | def __len__(self): 34 | return self.total_size 35 | 36 | def expand2square(pil_img, background_color): 37 | width, height = pil_img.size 38 | if width == height: 39 | return pil_img 40 | elif width > height: 41 | result = Image.new(pil_img.mode, (width, width), background_color) 42 | result.paste(pil_img, (0, (width - height) // 2)) 43 | return result 44 | else: 45 | result = Image.new(pil_img.mode, (height, height), background_color) 46 | result.paste(pil_img, ((height - width) // 2, 0)) 47 | return result 48 | 49 | 50 | def simulate_jpeg_degradation(quality): 51 | def jpeg_degrade(img): 52 | with io.BytesIO() as output: 53 | img.convert('RGB').save(output, format='JPEG', quality=quality) 54 | output.seek(0) 55 | img_jpeg = Image.open(output).copy() 56 | return img_jpeg 57 | return jpeg_degrade 58 | 59 | 60 | # Define the JPEG compression quality range, pre-create all JPEG compression functions 61 | qualities = list(range(75, 101)) 62 | jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities} 63 | 64 | 65 | def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'): 66 | if normalize_type == 'imagenet': 67 | MEAN, STD = IMAGENET_MEAN, IMAGENET_STD 68 | elif normalize_type == 'clip': 69 | MEAN, STD = CLIP_MEAN, CLIP_STD 70 | elif normalize_type == 'siglip': 71 | MEAN, STD = SIGLIP_MEAN, SIGLIP_STD 72 | else: 73 | raise NotImplementedError 74 | if is_train: # use data augumentation 75 | transform = T.Compose([ 76 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 77 | T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]), 78 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 79 | T.ToTensor(), 80 | T.Normalize(mean=MEAN, std=STD) 81 | ]) 82 | else: 83 | if pad2square is False: # now we use this transform function by default 84 | transform = T.Compose([ 85 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 86 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 87 | T.ToTensor(), 88 | T.Normalize(mean=MEAN, std=STD) 89 | ]) 90 | else: 91 | transform = T.Compose([ 92 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 93 | T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))), 94 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 95 | T.ToTensor(), 96 | T.Normalize(mean=MEAN, std=STD) 97 | ]) 98 | 99 | return transform 100 | 101 | 102 | def preprocess( 103 | template_name, 104 | sources, 105 | tokenizer: transformers.PreTrainedTokenizer, 106 | num_image_token_list: list, 107 | text_only: bool = False, 108 | group_by_length: bool = False, 109 | use_packed_ds: bool = False, 110 | ds_name: str = None, 111 | num_image: int = 1 112 | ) -> Dict: 113 | conv = get_conv_template(template_name) 114 | roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} 115 | 116 | # Apply prompt templates 117 | conversations = [] 118 | for i, source in enumerate(sources): 119 | if roles[source[0]['from']] != conv.roles[0]: 120 | # Skip the first one if it is not from human 121 | source = source[1:] 122 | 123 | conv.messages = [] 124 | for j, sentence in enumerate(source): 125 | role = roles[sentence['from']] 126 | assert role == conv.roles[j % 2], f'{i}' 127 | conv.append_message(role, 
sentence['value']) 128 | conversations.append(conv.get_prompt()) 129 | 130 | if not text_only: 131 | new_conversations = [] 132 | for conversation in conversations: 133 | for i in range(num_image): 134 | image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' 135 | conversation = conversation.replace('<image>', image_tokens, 1) 136 | new_conversations.append(conversation) 137 | conversations = new_conversations 138 | 139 | # Tokenize conversations 140 | input_ids = tokenizer( 141 | conversations, 142 | return_tensors='pt', 143 | padding=False if group_by_length or use_packed_ds else 'max_length', 144 | max_length=tokenizer.model_max_length, 145 | truncation=True, 146 | ).input_ids 147 | targets = input_ids.clone() 148 | 149 | # Mask targets. Only compute loss on the assistant outputs. 150 | sep = conv.sep + conv.roles[1] + ': ' 151 | for conversation, target in zip(conversations, targets): 152 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 153 | 154 | turns = conversation.split(conv.sep2) 155 | cur_len = 1 156 | target[:cur_len] = IGNORE_TOKEN_ID 157 | for i, turn in enumerate(turns): 158 | if turn == '': 159 | break 160 | turn_len = len(tokenizer(turn).input_ids) 161 | 162 | parts = turn.split(sep) 163 | if len(parts) != 2: 164 | break 165 | parts[0] += sep 166 | # "-2" is hardcoded for the Llama tokenizer to make the offset correct. 167 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 168 | 169 | if i != 0 and not tokenizer.legacy: 170 | # The legacy and non-legacy modes handle special tokens differently 171 | instruction_len -= 1 172 | 173 | # Ignore the user instructions 174 | target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID 175 | cur_len += turn_len 176 | 177 | if i != 0 and not tokenizer.legacy: 178 | # The legacy and non-legacy modes handle special tokens differently 179 | cur_len -= 1 180 | 181 | target[cur_len:] = IGNORE_TOKEN_ID 182 | 183 | if False: # Inspect and check the correctness of masking 184 | z = target.clone() 185 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 186 | logger.info(tokenizer.decode(z)) 187 | exit() 188 | 189 | if cur_len < tokenizer.model_max_length: 190 | if cur_len != total_len: 191 | target[:] = IGNORE_TOKEN_ID 192 | print( 193 | f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' 194 | f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' 
195 | ) 196 | sys.stdout.flush() 197 | 198 | return dict( 199 | input_ids=input_ids, 200 | labels=targets, 201 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 202 | ) 203 | 204 | 205 | def preprocess_mpt( 206 | template_name, 207 | sources, 208 | tokenizer: transformers.PreTrainedTokenizer, 209 | num_image_token_list: list, 210 | text_only: bool = False, 211 | group_by_length: bool = False, 212 | use_packed_ds: bool = False, 213 | ds_name: str = None, 214 | num_image: int = 1 215 | ) -> Dict: 216 | conv = get_conv_template(template_name) 217 | roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} 218 | 219 | # Apply prompt templates 220 | conversations = [] 221 | for i, source in enumerate(sources): 222 | if roles[source[0]['from']] != conv.roles[0]: 223 | # Skip the first one if it is not from human 224 | source = source[1:] 225 | 226 | conv.messages = [] 227 | for j, sentence in enumerate(source): 228 | role = roles[sentence['from']] 229 | assert role == conv.roles[j % 2], f'{i}' 230 | conv.append_message(role, sentence['value']) 231 | conversations.append(conv.get_prompt()) 232 | 233 | if not text_only: 234 | new_conversations = [] 235 | for conversation in conversations: 236 | for i in range(num_image): 237 | image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' 238 | conversation = conversation.replace('<image>', image_tokens, 1) 239 | new_conversations.append(conversation) 240 | conversations = new_conversations 241 | 242 | # Tokenize conversations 243 | input_ids = tokenizer( 244 | conversations, 245 | return_tensors='pt', 246 | padding=False if group_by_length or use_packed_ds else 'max_length', 247 | max_length=tokenizer.model_max_length, 248 | truncation=True, 249 | ).input_ids 250 | targets = input_ids.clone() 251 | 252 | # Mask targets. Only compute loss on the assistant outputs. 253 | sep = conv.sep + conv.roles[1] 254 | for conversation, target in zip(conversations, targets): 255 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 256 | 257 | turns = conversation.split(conv.sep) 258 | re_turns = [conv.sep.join(turns[:3])] 259 | for conv_idx in range(3, len(turns), 2): 260 | re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) 261 | cur_len = 0 262 | target[:cur_len] = IGNORE_TOKEN_ID 263 | for i, turn in enumerate(re_turns): 264 | if turn == '': 265 | break 266 | turn_len = len(tokenizer(turn).input_ids) + 1 267 | 268 | parts = turn.split(sep) 269 | if len(parts) != 2: 270 | break 271 | parts[0] += sep 272 | instruction_len = len(tokenizer(parts[0]).input_ids) 273 | 274 | # Ignore the user instructions 275 | target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID 276 | 277 | cur_len += turn_len 278 | 279 | target[cur_len:] = IGNORE_TOKEN_ID 280 | 281 | if cur_len < tokenizer.model_max_length: 282 | if cur_len != total_len: 283 | target[:] = IGNORE_TOKEN_ID 284 | print( 285 | f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' 286 | f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.' 
287 | ) 288 | sys.stdout.flush() 289 | 290 | return dict( 291 | input_ids=input_ids, 292 | labels=targets, 293 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 294 | ) 295 | 296 | 297 | def preprocess_phi3( 298 | template_name, 299 | sources, 300 | tokenizer: transformers.PreTrainedTokenizer, 301 | num_image_token_list: list, 302 | text_only: bool = False, 303 | group_by_length: bool = False, 304 | use_packed_ds: bool = False, 305 | ds_name: str = None, 306 | num_image: int = 1 307 | ) -> Dict: 308 | conv = get_conv_template(template_name) 309 | roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} 310 | 311 | # Apply prompt templates 312 | conversations = [] 313 | for i, source in enumerate(sources): 314 | if roles[source[0]['from']] != conv.roles[0]: 315 | # Skip the first one if it is not from human 316 | source = source[1:] 317 | 318 | conv.messages = [] 319 | for j, sentence in enumerate(source): 320 | role = roles[sentence['from']] 321 | assert role == conv.roles[j % 2], f'{i}' 322 | conv.append_message(role, sentence['value']) 323 | conversations.append(conv.get_prompt()) 324 | 325 | if not text_only: 326 | new_conversations = [] 327 | for conversation in conversations: 328 | for i in range(num_image): 329 | image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' 330 | conversation = conversation.replace('<image>', image_tokens, 1) 331 | new_conversations.append(conversation) 332 | conversations = new_conversations 333 | 334 | # Tokenize conversations 335 | tokenizer.padding_side = 'right' 336 | input_ids = tokenizer( 337 | conversations, 338 | return_tensors='pt', 339 | padding=False if group_by_length or use_packed_ds else 'max_length', 340 | max_length=tokenizer.model_max_length, 341 | truncation=True, 342 | ).input_ids 343 | targets = input_ids.clone() 344 | 345 | # Mask targets. Only compute loss on the assistant outputs. 346 | sep = conv.sep + conv.roles[1] 347 | for conversation, target in zip(conversations, targets): 348 | total_len = int(target.ne(int(tokenizer.pad_token_id)).sum()) 349 | 350 | turns = conversation.split(conv.sep) 351 | re_turns = [conv.sep.join(turns[:3])] 352 | for conv_idx in range(3, len(turns), 2): 353 | re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) 354 | cur_len = 1 355 | target[:cur_len] = IGNORE_TOKEN_ID 356 | endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>') 357 | target[target == endoftext_id] = IGNORE_TOKEN_ID 358 | 359 | for i, turn in enumerate(re_turns): 360 | if turn == '': 361 | break 362 | if i == 0: 363 | turn_len = len(tokenizer(turn).input_ids) 364 | else: 365 | turn_len = len(tokenizer(turn).input_ids) - 1 366 | parts = turn.split(sep) 367 | if len(parts) != 2: 368 | break 369 | parts[0] += sep 370 | 371 | if i == 0: 372 | instruction_len = len(tokenizer(parts[0]).input_ids) - 1 373 | else: 374 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 375 | 376 | # Ignore the user instructions 377 | target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID 378 | 379 | cur_len += turn_len 380 | 381 | target[cur_len:] = IGNORE_TOKEN_ID 382 | 383 | if False: # Inspect and check the correctness of masking 384 | z = target.clone() 385 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 386 | print(repr(tokenizer.decode(z))) 387 | 388 | if cur_len < tokenizer.model_max_length: 389 | if cur_len != total_len: 390 | target[:] = IGNORE_TOKEN_ID 391 | print( 392 | f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.' 393 | f' #turn = {len(turns) - 1}. 
(ignored). This dataset is {ds_name}.' 394 | ) 395 | sys.stdout.flush() 396 | 397 | return dict( 398 | input_ids=input_ids, 399 | labels=targets, 400 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 401 | ) 402 | 403 | 404 | def preprocess_internlm( 405 | template_name, 406 | sources, 407 | tokenizer: transformers.PreTrainedTokenizer, 408 | num_image_token_list: list, 409 | text_only: bool = False, 410 | group_by_length: bool = False, 411 | use_packed_ds: bool = False, 412 | ds_name: str = None, 413 | num_image: int = 1 414 | ) -> Dict: 415 | conv = get_conv_template(template_name) 416 | roles = {'human': conv.roles[0], 'gpt': conv.roles[1]} 417 | 418 | # Apply prompt templates 419 | conversations = [] 420 | for i, source in enumerate(sources): 421 | if roles[source[0]['from']] != conv.roles[0]: 422 | # Skip the first one if it is not from human 423 | source = source[1:] 424 | 425 | conv.messages = [] 426 | for j, sentence in enumerate(source): 427 | role = roles[sentence['from']] 428 | assert role == conv.roles[j % 2], f'{i}' 429 | sentence['value'] = sentence['value'].strip() 430 | conv.append_message(role, sentence['value']) 431 | conversations.append(conv.get_prompt()) 432 | 433 | if not text_only: 434 | new_conversations = [] 435 | for conversation in conversations: 436 | for i in range(num_image): 437 | image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}' 438 | conversation = conversation.replace('<image>', image_tokens, 1) 439 | new_conversations.append(conversation) 440 | conversations = new_conversations 441 | 442 | # Tokenize conversations 443 | input_ids = tokenizer( 444 | conversations, 445 | return_tensors='pt', 446 | padding=False if group_by_length or use_packed_ds else 'max_length', 447 | max_length=tokenizer.model_max_length, 448 | truncation=True, 449 | ).input_ids 450 | targets = input_ids.clone() 451 | 452 | for conversation, target in zip(conversations, targets): 453 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 454 | cur_len = 1 455 | target[:cur_len] = IGNORE_TOKEN_ID 456 | parts = conversation.split(conv.roles[1]) 457 | info = parts[0] + conv.roles[1] 458 | temp_len = len(tokenizer(info).input_ids) - 1 459 | target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID 460 | cur_len = cur_len + temp_len 461 | 462 | for index in range(1, len(parts) - 1): 463 | info = parts[index] 464 | part1, part2 = info.split(conv.roles[0]) 465 | temp_len = len(tokenizer(part1).input_ids) - 1 466 | cur_len = cur_len + temp_len 467 | part = conv.roles[0] + part2 + conv.roles[1] 468 | temp_len = len(tokenizer(part).input_ids) - 1 469 | target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID 470 | cur_len = cur_len + temp_len 471 | last_info = parts[-1] 472 | temp_len = len(tokenizer(last_info).input_ids) - 1 473 | cur_len = cur_len + temp_len 474 | 475 | target[cur_len:] = IGNORE_TOKEN_ID 476 | if False: # Inspect and check the correctness of masking 477 | z = target.clone() 478 | z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z) 479 | print(repr(tokenizer.decode(z))) 480 | 481 | if cur_len < tokenizer.model_max_length: 482 | if cur_len != total_len: 483 | target[:] = IGNORE_TOKEN_ID 484 | print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. 
This dataset is {ds_name}.') 485 | sys.stdout.flush() 486 | 487 | return dict( 488 | input_ids=input_ids, 489 | labels=targets, 490 | attention_mask=input_ids.ne(tokenizer.pad_token_id), 491 | ) 492 | 493 | 494 | def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size): 495 | best_ratio_diff = float('inf') 496 | best_ratio = (1, 1) 497 | area = width * height 498 | for ratio in target_ratios: 499 | target_aspect_ratio = ratio[0] / ratio[1] 500 | ratio_diff = abs(aspect_ratio - target_aspect_ratio) 501 | if ratio_diff < best_ratio_diff: 502 | best_ratio_diff = ratio_diff 503 | best_ratio = ratio 504 | elif ratio_diff == best_ratio_diff: 505 | if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: 506 | best_ratio = ratio 507 | return best_ratio 508 | 509 | 510 | def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False): 511 | orig_width, orig_height = image.size 512 | aspect_ratio = orig_width / orig_height 513 | 514 | # calculate the existing image aspect ratio 515 | target_ratios = set( 516 | (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if 517 | i * j <= max_num and i * j >= min_num) 518 | target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) 519 | 520 | # find the closest aspect ratio to the target 521 | target_aspect_ratio = find_closest_aspect_ratio( 522 | aspect_ratio, target_ratios, orig_width, orig_height, image_size) 523 | 524 | # calculate the target width and height 525 | target_width = image_size * target_aspect_ratio[0] 526 | target_height = image_size * target_aspect_ratio[1] 527 | blocks = target_aspect_ratio[0] * target_aspect_ratio[1] 528 | 529 | # resize the image 530 | resized_img = image.resize((target_width, target_height)) 531 | processed_images = [] 532 | for i in range(blocks): 533 | box = ( 534 | (i % (target_width // image_size)) * image_size, 535 | (i // (target_width // image_size)) * image_size, 536 | ((i % (target_width // image_size)) + 1) * image_size, 537 | ((i // (target_width // image_size)) + 1) * image_size 538 | ) 539 | # split the image 540 | split_img = resized_img.crop(box) 541 | processed_images.append(split_img) 542 | assert len(processed_images) == blocks 543 | if use_thumbnail and len(processed_images) != 1: 544 | thumbnail_img = image.resize((image_size, image_size)) 545 | processed_images.append(thumbnail_img) 546 | return processed_images 547 | -------------------------------------------------------------------------------- /internvl/train/trainer_monkey_patch.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | from transformers import Trainer, logging 8 | from transformers.trainer import is_sagemaker_mp_enabled 9 | 10 | logger = logging.get_logger(__name__) 11 | 12 | 13 | def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer): 14 | if var_name.startswith('internvl.'): 15 | var_name = var_name[len('internvl.'):] 16 | if var_name in ('query_tokens', 'logit_scale',): 17 | return 0 18 | if var_name.startswith('clip_projector.'): 19 | return vit_num_max_layer 20 | if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \ 21 | var_name == 'text_projection': 22 | return llama_num_max_layer 23 | if var_name.startswith('vision_model.'): 24 | if 'embeddings.' in var_name: 25 | return 0 26 | if 'layers.' 
in var_name: 27 | var_name = var_name.split('layers.')[-1] 28 | layer_id = int(var_name.split('.')[0]) 29 | return layer_id + 1 30 | if var_name.startswith('qllama.'): 31 | if 'embed_tokens' in var_name: 32 | return 0 33 | if 'layers.' in var_name: 34 | var_name = var_name.split('layers.')[-1] 35 | layer_id = int(var_name.split('.')[0]) 36 | return layer_id + 1 37 | else: 38 | return llama_num_max_layer 39 | return 0 40 | 41 | 42 | def param_classification(name): 43 | if name.startswith('internvl.'): 44 | name = name[len('internvl.'):] 45 | if name in ['query_tokens', 'text_projection', 'logit_scale']: 46 | return 'qllama' 47 | elif name.startswith('vision_model.'): 48 | return 'vit' 49 | elif name.startswith('qllama.'): 50 | return 'qllama' 51 | elif name.startswith('clip_projector.'): 52 | return 'vit' 53 | elif name.startswith('clip_projector2.'): 54 | return 'qllama' 55 | elif name.startswith('itm_head.'): 56 | return 'qllama' 57 | else: 58 | return 'other' 59 | 60 | 61 | def create_optimizer(self): 62 | """ 63 | Setup the optimizer. 64 | 65 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 66 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 67 | """ 68 | opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model 69 | 70 | parameter_groups = {} 71 | try: # for stage2 model 72 | vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2 73 | qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2 74 | except: # for stage3 model 75 | vit_num_layers = opt_model.internvl.config.vision_config.num_hidden_layers + 2 76 | qllama_num_layers = opt_model.internvl.config.qllama_config.num_hidden_layers + 2 77 | print('vit_num_layers:', vit_num_layers) 78 | print('qllama_num_layers:', qllama_num_layers) 79 | 80 | vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0)) 81 | qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0)) 82 | qllama_lr_scale = float(os.getenv('QLLAMA_LR_SCALE', 1.0)) 83 | print('vit_layer_decay_rate:', vit_layer_decay_rate) 84 | print('qllama_layer_decay_rate:', qllama_layer_decay_rate) 85 | print('qllama_lr_scale:', qllama_lr_scale) 86 | 87 | for name, param in opt_model.named_parameters(): 88 | if not param.requires_grad: 89 | continue # frozen weights 90 | if len(param.shape) == 1 or name.endswith('.bias'): 91 | group_name = 'no_decay' 92 | this_weight_decay = 0. 
93 | else: 94 | group_name = 'decay' 95 | this_weight_decay = self.args.weight_decay 96 | 97 | cls = param_classification(name) 98 | layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers) 99 | group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name) 100 | if group_name not in parameter_groups: 101 | if cls == 'vit': 102 | scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1) 103 | elif cls == 'qllama': 104 | scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1) 105 | scale = scale * qllama_lr_scale 106 | else: 107 | scale = 1.0 108 | scale = min(1.0, scale) 109 | parameter_groups[group_name] = { 110 | 'weight_decay': this_weight_decay, 111 | 'params': [], 112 | 'param_names': [], 113 | 'lr_scale': scale, 114 | 'group_name': group_name, 115 | 'lr': scale * self.args.learning_rate, 116 | } 117 | parameter_groups[group_name]['params'].append(param) 118 | parameter_groups[group_name]['param_names'].append(name) 119 | 120 | rank = torch.distributed.get_rank() 121 | if rank == 0: 122 | to_display = {} 123 | for key in parameter_groups: 124 | to_display[key] = { 125 | 'param_names': parameter_groups[key]['param_names'], 126 | 'lr_scale': parameter_groups[key]['lr_scale'], 127 | 'lr': parameter_groups[key]['lr'], 128 | 'weight_decay': parameter_groups[key]['weight_decay'], 129 | } 130 | print('Param groups = %s' % json.dumps(to_display, indent=2)) 131 | 132 | optimizer_grouped_parameters = list(parameter_groups.values()) 133 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 134 | 135 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 136 | if optimizer_cls.__name__ == 'Adam8bit': 137 | import bitsandbytes 138 | 139 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 140 | 141 | skipped = 0 142 | for module in opt_model.modules(): 143 | if isinstance(module, nn.Embedding): 144 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 145 | logger.info(f'skipped {module}: {skipped / 2 ** 20}M params') 146 | manager.register_module_override(module, 'weight', {'optim_bits': 32}) 147 | logger.debug(f'bitsandbytes: will optimize {module} in fp32') 148 | logger.info(f'skipped: {skipped / 2 ** 20}M params') 149 | 150 | if is_sagemaker_mp_enabled(): 151 | import smdistributed.modelparallel.torch as smp 152 | self.optimizer = smp.DistributedOptimizer(self.optimizer) 153 | 154 | return self.optimizer 155 | 156 | 157 | def replace_create_optimizer(): 158 | print('Replace original create_optimizer with custom create_optimizer') 159 | transformers.Trainer.create_optimizer = create_optimizer 160 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate<1 2 | bitsandbytes==0.42.0 3 | datasets 4 | decord 5 | deepspeed==0.13.5 6 | einops==0.6.1 7 | einops-exts==0.0.4 8 | huggingface_hub 9 | imageio 10 | numpy==1.26.4 11 | opencv-python 12 | orjson 13 | peft==0.10.0 14 | pycocoevalcap 15 | pyyaml 16 | scikit-learn>=1.2.2 17 | scipy 18 | sentencepiece==0.1.99 19 | shortuuid 20 | tensorboardX 21 | termcolor 22 | timm==0.9.12 23 | tokenizers==0.15.1 24 | torch==2.1.0 25 | torchvision==0.16.0 26 | tqdm 27 | transformers==4.37.2 28 | yacs 29 | -------------------------------------------------------------------------------- /shell/data_llava_finetune.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "llava_665k": { 3 | "root": "playground/data/", 4 | "annotation": "playground/llava_v1_5_mix665k.jsonl", 5 | "data_augment": false, 6 | "repeat_time": 1, 7 | "length": 665298 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /shell/mono_internvl_finetune_llava_slurm.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | PARTITION=${PARTITION:-"Your partition"} 4 | GPUS=${GPUS:-8} 5 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 6 | QUOTA_TYPE=${QUOTA_TYPE:-"spot"} 7 | NODES=$((GPUS / GPUS_PER_NODE)) 8 | CPUS_PER_TASK=${CPUS_PER_TASK:-12} 9 | SRUN_ARGS=${SRUN_ARGS:-""} 10 | BATCH_SIZE=${BATCH_SIZE:-128} 11 | PER_DEVICE_BATCH_SIZE=${PER_DEVICE_BATCH_SIZE:-4} 12 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS)) 13 | 14 | MODEL=${MODEL:-"Path to your model"} 15 | OUTPUT_DIR=${OUTPUT_DIR:-"Path to your output directory"} 16 | mkdir -p "$OUTPUT_DIR" 17 | 18 | export PYTHONPATH="${PYTHONPATH}:$(pwd)" 19 | 20 | srun -p ${PARTITION} \ 21 | --gres=gpu:${GPUS_PER_NODE} \ 22 | --nodes=${NODES} \ 23 | --ntasks=${GPUS} \ 24 | --ntasks-per-node=${GPUS_PER_NODE} \ 25 | --cpus-per-task=${CPUS_PER_TASK} \ 26 | --kill-on-bad-exit=1 \ 27 | --quotatype=${QUOTA_TYPE} \ 28 | ${SRUN_ARGS} \ 29 | python -u internvl/train/internvl_chat_finetune.py \ 30 | --model_name_or_path ${MODEL} \ 31 | --vision_type patch \ 32 | --conv_style "internlm2-chat" \ 33 | --output_dir ${OUTPUT_DIR} \ 34 | --meta_path "./shell/data_llava_finetune.json" \ 35 | --overwrite_output_dir True \ 36 | --force_image_size 448 \ 37 | --max_dynamic_patch 6 \ 38 | --down_sample_ratio 0.5 \ 39 | --drop_path_rate 0.1 \ 40 | --pad2square False \ 41 | --freeze_llm False \ 42 | --freeze_mlp False \ 43 | --freeze_backbone False \ 44 | --unfreeze_ve True \ 45 | --vision_select_layer -1 \ 46 | --use_data_resampling False \ 47 | --dataloader_num_workers 4 \ 48 | --bf16 True \ 49 | --num_train_epochs 1 \ 50 | --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \ 51 | --gradient_accumulation_steps ${GRADIENT_ACC} \ 52 | --evaluation_strategy "no" \ 53 | --save_strategy "steps" \ 54 | --save_steps 3000 \ 55 | --save_total_limit 3 \ 56 | --learning_rate 4e-5 \ 57 | --weight_decay 0.01 \ 58 | --warmup_ratio 0.03 \ 59 | --lr_scheduler_type "cosine" \ 60 | --logging_steps 1 \ 61 | --max_seq_length 2048 \ 62 | --do_train True \ 63 | --grad_checkpoint True \ 64 | --group_by_length True \ 65 | --dynamic_image_size True \ 66 | --use_thumbnail True \ 67 | --ps_version 'v2' \ 68 | --deepspeed "./shell/zero_stage1_config.json" \ 69 | 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" 70 | -------------------------------------------------------------------------------- /shell/mono_internvl_finetune_llava_torchrun.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPUS=8 4 | PER_DEVICE_BATCH_SIZE=4 5 | BATCH_SIZE=128 6 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS)) 7 | 8 | MODEL=${MODEL:-"Path to your model"} 9 | OUTPUT_DIR=${OUTPUT_DIR:-"Path to your output directory"} 10 | mkdir -p "$OUTPUT_DIR" 11 | 12 | export PYTHONPATH="${PYTHONPATH}:$(pwd)" 13 | 14 | torchrun --nproc_per_node=$GPUS --master_port=29501 \ 15 | internvl/train/internvl_chat_finetune.py \ 16 | --model_name_or_path ${MODEL} \ 17 | --vision_type patch \ 18 | --conv_style "internlm2-chat" \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --meta_path 
"./shell/data_llava_finetune.json" \ 21 | --overwrite_output_dir True \ 22 | --force_image_size 448 \ 23 | --max_dynamic_patch 6 \ 24 | --down_sample_ratio 0.5 \ 25 | --drop_path_rate 0.1 \ 26 | --pad2square False \ 27 | --freeze_llm False \ 28 | --freeze_mlp False \ 29 | --freeze_backbone False \ 30 | --unfreeze_ve True \ 31 | --vision_select_layer -1 \ 32 | --use_data_resampling False \ 33 | --dataloader_num_workers 4 \ 34 | --bf16 True \ 35 | --num_train_epochs 1 \ 36 | --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \ 37 | --gradient_accumulation_steps ${GRADIENT_ACC} \ 38 | --evaluation_strategy "no" \ 39 | --save_strategy "steps" \ 40 | --save_steps 3000 \ 41 | --save_total_limit 3 \ 42 | --learning_rate 4e-5 \ 43 | --weight_decay 0.01 \ 44 | --warmup_ratio 0.03 \ 45 | --lr_scheduler_type "cosine" \ 46 | --logging_steps 1 \ 47 | --max_seq_length 2048 \ 48 | --do_train True \ 49 | --grad_checkpoint True \ 50 | --group_by_length True \ 51 | --dynamic_image_size True \ 52 | --use_thumbnail True \ 53 | --ps_version 'v2' \ 54 | --deepspeed "./shell/zero_stage1_config.json" \ 55 | 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" 56 | -------------------------------------------------------------------------------- /shell/zero_stage1_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "allgather_partitions": true, 5 | "allgather_bucket_size": 1e9, 6 | "overlap_comm": true, 7 | "reduce_scatter": true, 8 | "reduce_bucket_size": 1e9, 9 | "contiguous_gradients": true 10 | }, 11 | "fp16": { 12 | "enabled": "auto", 13 | "auto_cast": true, 14 | "loss_scale": 0, 15 | "initial_scale_power": 32, 16 | "loss_scale_window": 1000, 17 | "hysteresis": 2, 18 | "min_loss_scale": 1 19 | }, 20 | "bf16": { 21 | "enabled": "auto" 22 | }, 23 | "optimizer": { 24 | "type": "AdamW", 25 | "params": { 26 | "lr": "auto", 27 | "betas": [ 28 | 0.9, 29 | 0.999 30 | ], 31 | "eps": 1e-8, 32 | "weight_decay": "auto" 33 | } 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 2000, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": true 41 | } 42 | --------------------------------------------------------------------------------