├── .flake8 ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── data_pipeline ├── llm_utils.py ├── mllm_as_a_judge.py ├── omegaprm.py ├── prm_data_format.py ├── process_json.py ├── run_data_pipeline.py ├── run_data_pipeline.sh └── traverse.py ├── docs ├── case_study.png ├── logo.png ├── performance.png └── wechat_qr.png ├── eval └── prm │ ├── evaluate_k12_prm.py │ ├── evaluate_mathverse_prm.py │ ├── evaluate_mathvision_prm.py │ ├── evaluate_mathvista_prm.py │ ├── evaluate_olympiadbench_prm.py │ └── extract_calculate.py ├── evaluate.sh ├── internvl ├── conversation.py ├── dist_utils.py ├── model │ ├── __init__.py │ ├── internlm2 │ │ ├── configuration_internlm2.py │ │ ├── modeling_internlm2.py │ │ ├── tokenization_internlm2.py │ │ └── tokenization_internlm2_fast.py │ ├── internvl_chat │ │ ├── __init__.py │ │ ├── configuration_intern_vit.py │ │ ├── configuration_internvl_chat.py │ │ ├── modeling_intern_vit.py │ │ └── modeling_internvl_chat.py │ └── phi3 │ │ ├── configuration_phi3.py │ │ └── modeling_phi3.py ├── patch │ ├── __init__.py │ ├── internlm2_packed_training_patch.py │ ├── internvit_liger_monkey_patch.py │ ├── llama2_flash_attn_monkey_patch.py │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_packed_training_patch.py │ ├── llama_rmsnorm_monkey_patch.py │ ├── pad_data_collator.py │ ├── phi3_packed_training_patch.py │ ├── qwen2_packed_training_patch.py │ ├── train_dataloader_patch.py │ └── train_sampler_patch.py └── train │ ├── __init__.py │ ├── constants.py │ ├── dataset.py │ ├── dataset_packed.py │ └── internvl_chat_finetune.py ├── requirements.txt ├── requirements ├── classification.txt ├── clip_benchmark.txt ├── internvl_chat.txt ├── segmentation.txt └── streamlit_demo.txt ├── shell └── internvl2.5 │ └── 2nd_finetune │ └── internvl2_5_38b_dynamic_res_2nd_finetune_full_prm.sh ├── zero_stage1_config.json ├── zero_stage2_config.json ├── zero_stage3_config.json ├── zero_stage3_config_100b.json ├── zero_stage3_config_100b_1e8.json ├── zero_stage3_config_34b.json └── zero_stage3_config_70b.json /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501, F403, C901, W504, W605, E251, E122, E126, E127, E722, W503, E128, E741, E731, E701 3 | select = E1, E3, E502, E7, E9, W1, W5, W6 4 | max-line-length = 180 5 | exclude=*.egg/*,build,dist,detection/configs/* 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | .idea/ 163 | 164 | .DS_Store 165 | data_process/ 166 | internvl_chat/work_dirs/ 167 | internvl_chat/unittest/ 168 | internvl_chat/data/ 169 | Husky2/* 170 | data_process/ 171 | *distillation* 172 | 173 | batchscript-* 174 | results/ 175 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | line-length = 180 3 | multi_line_output = 0 4 | extra_standard_library = setuptools 5 | known_third_party = PIL,asynctest,cityscapesscripts,cv2,gather_models,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,six,terminaltables,torch,ts,yaml 6 | no_lines_before = STDLIB,LOCALFOLDER 7 | default_section = THIRDPARTY 8 | 9 | [yapf] 10 | BASED_ON_STYLE = pep8 11 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true 12 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true 13 | 14 | [codespell] 15 | skip = *.ipynb 16 | quiet-level = 3 17 | ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood 18 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/PyCQA/flake8 3 | rev: 5.0.4 4 | hooks: 5 | - id: flake8 6 | - repo: https://github.com/PyCQA/isort 7 | rev: 5.11.5 8 | hooks: 9 | - id: isort 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.3.0 12 | hooks: 13 | - id: trailing-whitespace 14 | - id: check-yaml 15 | - id: end-of-file-fixer 16 | - id: requirements-txt-fixer 17 | - id: double-quote-string-fixer 18 | - id: check-merge-conflict 19 | - id: fix-encoding-pragma 20 | args: ["--remove"] 21 | - id: mixed-line-ending 22 | args: ["--fix=lf"] 23 | - repo: https://github.com/executablebooks/mdformat 24 | rev: 0.7.9 25 | hooks: 26 | - id: mdformat 27 | args: ["--number"] 28 | additional_dependencies: 29 | - mdformat-openmmlab 30 | - mdformat_frontmatter 31 | - linkify-it-py 32 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to MM-Eureka 2 | 3 | After cloning the repository, please install pre-commit hooks with: 4 | 5 | ``` 6 | pip install pre-commit 7 | pre-commit install 8 | ``` 9 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 ModalMinds Team 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | MM-PRM logo 3 |
4 | 5 |
6 | 7 | # MM-PRM 8 | 9 |
10 | 11 |
12 |

13 | 📖Paper | 14 | 📊MM-K12 | 15 | 🤗MM-PRM 16 |

17 |
18 | 19 |
20 |
21 |

MM-PRM: Enhancing Multimodal Mathematical Reasoning with Scalable Step-Level Supervision

22 |

23 |
24 |
25 | ## 🎯 Overview
26 |
27 | While Multimodal Large Language Models (MLLMs) have achieved impressive progress in vision-language understanding, they still struggle with complex multi-step reasoning, often producing logically inconsistent or partially correct solutions. A key limitation lies in the lack of fine-grained supervision over intermediate reasoning steps. To address this, we propose **MM-PRM**, a process reward model trained within a fully automated, scalable framework. We first build **MM-Policy**, a strong multimodal model trained on diverse mathematical reasoning data. Then, we construct **MM-K12**, a curated dataset of 10,000 multimodal math problems with verifiable answers, which serves as seed data. Leveraging a Monte Carlo Tree Search (MCTS)-based pipeline, we generate over 700k step-level annotations without human labeling. The resulting PRM is used to score candidate reasoning paths in the Best-of-N inference setup and achieves significant improvements across both in-domain (MM-K12 test set) and out-of-domain (OlympiadBench, MathVista, etc.) benchmarks. Further analysis confirms the effectiveness of soft labels, smaller learning rates, and path diversity in optimizing PRM performance. MM-PRM demonstrates that process supervision is a powerful tool for enhancing the logical robustness of multimodal reasoning systems. We release all our code and data at [MM-PRM](https://github.com/ModalMinds/MM-PRM).
28 |
29 | ## 🗞️ News
30 |
31 | - **\[2025/05/19\]** We released `MM-PRM`.
32 | - 📖 Paper: [MM-PRM-Paper](https://arxiv.org/abs/2505.13427)
33 | - 📊 Data: [MM-K12](https://huggingface.co/datasets/Cierra0506/MM-K12)
34 | - 🤗 Model: [MM-PRM](https://huggingface.co/Cierra0506/MM-PRM)
35 |
36 | ## 📊 MM-K12 Dataset
37 |
38 | We released the **MM-K12** dataset at [MM-K12](https://huggingface.co/datasets/Cierra0506/MM-K12).
39 |
40 | ## 🤖 Models
41 |
42 |
43 | Case Study 44 |
45 |
46 | *Figure 1 | Qualitative example of MM-PRM accurately identifying error steps in a multimodal reasoning process.*
47 |
48 |
49 | Performance 50 |
51 |
52 | *Figure 2 | Performance improvements across various benchmarks when applying MM-PRM to different models.*
53 |
54 | - 🤗 [MM-PRM](https://huggingface.co/Cierra0506/MM-PRM)
55 |
56 | ## 🏁 Getting Started
57 |
58 | ### 📦 Installation
59 |
60 | ```shell
61 | git clone https://github.com/ModalMinds/MM-PRM.git
62 | cd MM-PRM
63 | pip install -r requirements.txt
64 |
65 | # install flash-attn==2.3.6:
66 |
67 | pip install flash-attn==2.3.6 --no-build-isolation
68 |
69 | # Alternatively you can compile from source:
70 |
71 | git clone https://github.com/Dao-AILab/flash-attention.git
72 | cd flash-attention
73 | git checkout v2.3.6
74 | python setup.py install
75 | ```
76 |
77 | ### 📂 Data Pipeline
78 |
79 | 1. **Seed dataset preparation**
80 |
81 | To begin, prepare a seed dataset consisting of verifiable problems. Each example should be formatted as a JSON object containing the following fields:
82 |
83 | ```json
84 | [
85 | {
86 | "id": "unique identifier for the problem",
87 | "question": "problem statement",
88 | "correct_answer": "ground-truth final answer for evaluation and verification",
89 | "image_path": "/path/to/image.png"
90 | },
91 | ...
92 | ]
93 | ```
94 |
95 | This dataset will be used as input to the data pipeline to generate annotated solution trees with step-wise correctness labels.
96 |
97 | To enable parallel data generation, you need to split the seed dataset into smaller chunks.
98 |
99 | ```shell
100 | cd data_pipeline
101 | python process_json.py
102 | ```
103 |
104 | 2. **API endpoint setup (Optional)**
105 |
106 | The data generation process requires an API endpoint to automatically verify whether the final answer in a rollout is correct. You can deploy a model (e.g., Qwen2.5) locally to act as the answer judge.
107 |
108 | We recommend using [vLLM](https://docs.vllm.ai/) to deploy a local model.
109 |
110 | 3. **Run data pipeline**
111 |
112 | Once everything is set up, you can run the data pipeline to generate step-level supervision data.
113 |
114 | Before running, ensure that all necessary parameters are correctly set in the script or passed through the environment.
115 |
116 | ```shell
117 | sh run_data_pipeline.sh
118 | ```
119 |
120 | 4. **Sampling training data from annotation trees**
121 |
122 | After generating annotated reasoning trees, you need to sample step-by-step solution paths from these trees to construct the training data for the Process Reward Model (PRM). This can be done using the script:
123 |
124 | ```shell
125 | python traverse.py
126 | ```
127 |
128 | The next step is to convert this data into the format required for PRM training. Use the following script to perform the formatting:
129 |
130 | ```shell
131 | python prm_data_format.py
132 | ```
133 |
134 | ### 🌐 Start PRM Training
135 |
136 | Create a JSON file in `internvl_chat/shell/data/`.
137 |
138 | The format for the JSON file should be:
139 |
140 | ```json
141 | {
142 | "your-custom-prm_dataset": {
143 | "root": "/path/to/the/image/root",
144 | "annotation": "/path/to/the/jsonl/annotation",
145 | "data_augment": false,
146 | "repeat_time": 1,
147 | "length": "number of samples in the dataset"
148 | }
149 | }
150 | ```
151 |
152 | Once the dataset configuration is in place, you can start training the PRM with:
153 |
154 | ```shell
155 | GPUS=8 sh shell/internvl2.5/2nd_finetune/internvl2_5_38b_dynamic_res_2nd_finetune_full_prm.sh
156 | ```
157 |
158 | ### 📊 Evaluation
159 |
160 | We provide our **evaluation code** in the `eval/` directory.
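Each benchmark has its own script under `eval/prm/` (e.g., `evaluate_k12_prm.py`, `evaluate_mathvista_prm.py`); the scripts read `WORLD_SIZE`, `RANK`, and `LOCAL_RANK` from the environment, so they can be launched with `torchrun` (see also `evaluate.sh` in the repository root). The command below is only a minimal sketch of one possible launch: the checkpoint path is a placeholder, the GPU count of 8 is an assumption, and the `root`/`annotation` entries in each script's `ds_collections` dictionary must first be edited to point at your image directory and rollout file.

```shell
# Hypothetical launch example for the MM-K12 PRM evaluation (adjust paths and GPU count to your setup).
torchrun --nproc_per_node=8 eval/prm/evaluate_k12_prm.py \
    --checkpoint /path/to/MM-PRM/checkpoint \
    --datasets k12_prm \
    --out-dir results
```

Each script gathers the per-rank predictions, writes a `{dataset}_{timestamp}.json` file into `--out-dir`, and then invokes `eval/prm/extract_calculate.py` on that file to compute the final metrics.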
161 |
162 | ## ⭐ Starchart
163 |
164 | [![Star History Chart](https://api.star-history.com/svg?repos=ModalMinds/MM-PRM&type=Date)](https://star-history.com/#ModalMinds/MM-PRM&Date)
165 |
166 | ## 🤝 Contribution
167 |
168 | If you want to contribute, please feel free to make a pull request or create an issue.
169 |
170 | Please refer to `CONTRIBUTING.md` before you dive in!
171 |
172 | ## 📬 Contact
173 |
174 | If you have any questions or would like to engage with our community, feel free to scan the QR code below to join our WeChat group.
175 |
176 |
177 | MM-PRM logo 178 |
179 | 180 | ## 🎓 Acknowledgements 181 | 182 | We acknowledge the outstanding open-source contributions from [OpenR](https://github.com/openreasoner/openr) and [vLLM](https://github.com/vllm-project/vllm). We also extend our gratitude to [InternVL](https://github.com/OpenGVLab/InternVL) for their open-source techniques and base models, which have enabled us to further our exploration. 183 | 184 | ## 📜 Citation 185 | 186 | ``` 187 | @article{du2025mmprm, 188 | title={MM-PRM: Enhancing Multimodal Mathematical Reasoning with Scalable Step-Level Supervision}, 189 | author={Lingxiao Du and Fanqing Meng and Zongkai Liu and Zhixiang Zhou and Ping Luo and Qiaosheng Zhang and Wenqi Shao}, 190 | year={2025}, 191 | journal={arXiv preprint arXiv:2505.13427}, 192 | } 193 | @article{meng2025mmeureka, 194 | title={MM-Eureka: Exploring the Frontiers of Multimodal Reasoning with Rule-based Reinforcement Learning}, 195 | author={Fanqing Meng and Lingxiao Du and Zongkai Liu and Zhixiang Zhou and Quanfeng Lu and Daocheng Fu and Tiancheng Han and Botian Shi and Wenhai Wang and Junjun He and Kaipeng Zhang and Ping Luo and Yu Qiao and Qiaosheng Zhang and Wenqi Shao}, 196 | year={2025}, 197 | journal={arXiv preprint arXiv:2503.07365}, 198 | } 199 | @article{liu2025cpgd, 200 | title={CPGD: Toward Stable Rule-based Reinforcement Learning for Language Models}, 201 | author={Zongkai Liu and Fanqing Meng and Lingxiao Du and Zhixiang Zhou and Chao Yu and Wenqi Shao and Qiaosheng Zhang}, 202 | year={2025}, 203 | journal={arXiv preprint arXiv:2505.12504}, 204 | } 205 | ``` 206 | -------------------------------------------------------------------------------- /data_pipeline/llm_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | 4 | from transformers import AutoTokenizer 5 | from vllm import LLM, SamplingParams 6 | from vllm.multimodal.utils import fetch_image 7 | 8 | logger = logging.getLogger('main') 9 | 10 | 11 | class LanguageModel: 12 | def __init__( 13 | self, 14 | model='OpenGVLab/InternVL2_5-8B', 15 | max_new_tokens=4096, 16 | temperature=1.0, 17 | top_k=50, 18 | top_p=0.9, 19 | ): 20 | self.model = model 21 | self.max_new_tokens = max_new_tokens 22 | self.temperature = temperature 23 | self.top_k = top_k 24 | self.top_p = top_p 25 | self.repetition_penalty = 1.05 26 | 27 | logger.info(f'Loading model {self.model}...') 28 | self.model = LLM( 29 | model=model, 30 | trust_remote_code=True, 31 | tensor_parallel_size=1, 32 | limit_mm_per_prompt={'image': 8}, 33 | ) 34 | self.tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) 35 | self.stop_tokens = ['<|im_end|>\n'.strip()] 36 | self.stop_token_ids = [ 37 | self.tokenizer.convert_tokens_to_ids(i) for i in self.stop_tokens 38 | ] 39 | self.special_tokens = self.tokenizer.all_special_tokens 40 | self.custom_tokens = ['', '', '', ''] 41 | self.special_tokens = [ 42 | token for token in self.special_tokens if token not in self.custom_tokens 43 | ] 44 | self.pattern1 = r'|'.join(map(re.escape, self.special_tokens)) 45 | self.pattern2 = r'(.*?)|(.*?)' 46 | logger.info('Model loaded successfully.') 47 | 48 | def generate_results(self, prompt, image_path=None, num_copies=16): 49 | if '' not in prompt: 50 | prompt = '\n' + prompt 51 | 52 | messages = [ 53 | { 54 | 'role': 'system', 55 | 'content': '你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。', 56 | }, 57 | {'role': 'user', 'content': prompt}, 58 | ] 59 | prompt = self.tokenizer.apply_chat_template( 60 | messages, 
tokenize=False, add_generation_prompt=True 61 | ) 62 | 63 | inputs = [] 64 | if image_path: 65 | image = fetch_image('file://' + image_path, allowed_local_media_path='/') 66 | for _ in range(num_copies): 67 | inputs.append( 68 | { 69 | 'prompt': prompt, 70 | 'multi_modal_data': {'image': image}, 71 | } 72 | ) 73 | else: 74 | for _ in range(num_copies): 75 | inputs.append({'prompt': prompt}) 76 | 77 | sampling_params = SamplingParams( 78 | temperature=self.temperature, 79 | max_tokens=self.max_new_tokens, 80 | top_p=self.top_p, 81 | top_k=self.top_k, 82 | repetition_penalty=self.repetition_penalty, 83 | stop_token_ids=self.stop_token_ids, 84 | skip_special_tokens=False, 85 | ) 86 | model_outputs = self.model.generate(inputs, sampling_params=sampling_params) 87 | batch_results = [] 88 | for model_output in model_outputs: 89 | response = re.sub(self.pattern1, '', model_output.outputs[0].text) 90 | matches = re.findall(self.pattern2, response, re.DOTALL) 91 | res = ( 92 | [match[0] if match[0] else match[1] for match in matches] 93 | if matches 94 | else [] 95 | ) 96 | res = list(map(str.strip, res)) 97 | batch_results.append(res) 98 | 99 | for result in batch_results: 100 | logger.debug(f'Prompt: {prompt}\nGenerated rollout: {result}') 101 | 102 | return batch_results 103 | -------------------------------------------------------------------------------- /data_pipeline/mllm_as_a_judge.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import logging 4 | import os 5 | from concurrent.futures import ThreadPoolExecutor 6 | 7 | import numpy as np 8 | import openai 9 | from tqdm import tqdm 10 | 11 | logging.basicConfig( 12 | level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' 13 | ) 14 | 15 | openai._utils._logs.logger.setLevel(logging.WARNING) 16 | openai._utils._logs.httpx_logger.setLevel(logging.WARNING) 17 | openai.base_url = 'http://127.0.0.1:10022/v1/' 18 | openai.api_key = 'FAKE_API_KEY' 19 | MODEL_NAME = 'Qwen2-VL-72B-Instruct' 20 | OUTPUT_DIR = 'mllm_as_a_judge_outputs' 21 | 22 | 23 | def traverse(node): 24 | def dfs(node, question, encoded_image, previous_solution, answer): 25 | partial_solution = ( 26 | '\n\n'.join(previous_solution) 27 | if previous_solution 28 | else 'No partial solution' 29 | ) 30 | following_steps = '\n\n'.join(node['partial_solution']) 31 | prompt = f"""I will provide a problem, its corresponding answer, a partial solution to the problem and some steps that continue from the partial solution. They will be formatted as follows: 32 | 33 | [Problem] 34 | 35 | ...(problem)... 36 | 37 | [Correct Answer] 38 | 39 | ...(problem's correct answer)... 40 | 41 | [Partial Solution] 42 | 43 | ...(partial solution)... 44 | 45 | [Following Steps] 46 | 47 | ...(some steps that continue from the partial solution)... 48 | 49 | Your task is to evaluate the Following Steps to determine whether they are logically and mathematically valid. If they are valid, respond with "Yes"; otherwise, respond with "No". 50 | 51 | * Respond with "Yes" or "No" only. 
52 | 53 | ------------------------------------------------ 54 | 55 | The following is the information for your task: 56 | 57 | [Problem] 58 | 59 | {question} 60 | 61 | [Correct Answer] 62 | 63 | {answer} 64 | 65 | [Partial Solution] 66 | 67 | {partial_solution} 68 | 69 | [Following Steps] 70 | 71 | {following_steps} 72 | """ 73 | 74 | messages = [ 75 | { 76 | 'role': 'user', 77 | 'content': [ 78 | {'type': 'text', 'text': prompt}, 79 | { 80 | 'type': 'image_url', 81 | 'image_url': {'url': encoded_image}, 82 | }, 83 | ], 84 | } 85 | ] 86 | 87 | completion = openai.chat.completions.create( 88 | model=MODEL_NAME, 89 | messages=messages, 90 | temperature=0.0, 91 | max_tokens=1, 92 | top_p=0.95, 93 | logprobs=True, 94 | top_logprobs=5, 95 | ) 96 | 97 | top_logprobs = completion.choices[0].logprobs.content[0].top_logprobs 98 | top_logprobs = {i.token: i.logprob for i in top_logprobs} 99 | 100 | logprob_yes = max(top_logprobs.get('YES', -100), top_logprobs.get('Yes', -100)) 101 | logprob_no = max(top_logprobs.get('NO', -100), top_logprobs.get('No', -100)) 102 | node['logprob_yes'] = float(np.exp(logprob_yes)) 103 | node['logprob_no'] = float(np.exp(logprob_no)) 104 | node['llm_as_a_judge'] = float( 105 | np.exp(logprob_yes) / (np.exp(logprob_yes) + np.exp(logprob_no)) 106 | ) 107 | for child in node['children']: 108 | dfs( 109 | child, 110 | question, 111 | encoded_image, 112 | previous_solution + node['partial_solution'], 113 | answer, 114 | ) 115 | 116 | try: 117 | with open(node['image_path'], 'rb') as f: 118 | encoded_image = ( 119 | f'data:image;base64,{base64.b64encode(f.read()).decode("utf-8")}' 120 | ) 121 | for child in node['children']: 122 | dfs(child, node['question'], encoded_image, [], node['answer']) 123 | 124 | with open(os.path.join(OUTPUT_DIR, f'{node["id"]}.json'), 'w') as f: 125 | json.dump(node, f, ensure_ascii=False, indent=4) 126 | logging.info(node['id']) 127 | except: 128 | pass 129 | 130 | 131 | if __name__ == '__main__': 132 | os.makedirs(OUTPUT_DIR, exist_ok=True) 133 | 134 | root_dir = '/path/to/outputs' 135 | roots = [] 136 | for file in os.listdir(root_dir): 137 | logging.info(file) 138 | with open(os.path.join(root_dir, file), 'r') as f: 139 | roots.append(json.load(f)) 140 | 141 | with ThreadPoolExecutor(max_workers=16) as executor: 142 | results = list(tqdm(executor.map(traverse, roots), total=len(roots))) 143 | -------------------------------------------------------------------------------- /data_pipeline/prm_data_format.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import jsonlines 4 | 5 | input_file = '/path/to/input/file.json' 6 | output_file = '/path/to/output/file.jsonl' 7 | 8 | with open(input_file, 'r') as f: 9 | data = json.load(f) 10 | res = [] 11 | 12 | for idx, item in enumerate(data): 13 | question = item['question'] 14 | process = item['process'] 15 | labels = item['labels'] 16 | image_path = item['image_path'] 17 | 18 | combined_value = f'Question: {question}\nProcess: {process}' 19 | 20 | conversations = [ 21 | {'from': 'human', 'value': combined_value}, 22 | {'from': 'gpt', 'value': labels}, 23 | ] 24 | 25 | new_item = {'id': idx, 'image': image_path, 'conversations': conversations} 26 | 27 | res.append(new_item) 28 | 29 | with jsonlines.open(output_file, 'w') as writer: 30 | writer.write_all(res) 31 | -------------------------------------------------------------------------------- /data_pipeline/process_json.py:
-------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | 5 | 6 | def split_questions_uniformly(input_file, output_dir, num_splits): 7 | with open(input_file, 'r') as f: 8 | input_data = json.load(f) 9 | 10 | num_data = len(input_data) 11 | data_per_file = math.ceil(num_data / num_splits) 12 | 13 | os.makedirs(output_dir, exist_ok=True) 14 | 15 | for i in range(num_splits): 16 | start_idx = i * data_per_file 17 | end_idx = min(start_idx + data_per_file, num_data) 18 | data_subset = input_data[start_idx:end_idx] 19 | 20 | output_filepath = os.path.join(output_dir, f'questions_part_{i + 1}.json') 21 | print(f'Saving {len(data_subset)} questions to {output_filepath}') 22 | 23 | with open(output_filepath, 'w') as f_out: 24 | json.dump(data_subset, f_out, indent=4, ensure_ascii=False) 25 | 26 | 27 | if __name__ == '__main__': 28 | split_questions_uniformly('/path/to/input/file', 'split_dir', 16) 29 | -------------------------------------------------------------------------------- /data_pipeline/run_data_pipeline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import hashlib 3 | import json 4 | import logging 5 | import os 6 | 7 | from llm_utils import LanguageModel 8 | from omegaprm import OmegaPRM 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument('--input_file', type=str) 14 | parser.add_argument('--log_file', type=str, default='log.txt') 15 | parser.add_argument('--output_dir', type=str, default='output') 16 | 17 | parser.add_argument('--model', type=str, default='OpenGVLab/InternVL2_5-8B') 18 | parser.add_argument('--max_new_tokens', type=int, default=2048) 19 | parser.add_argument('--temperature', type=float, default=0.7) 20 | parser.add_argument('--top_k', type=int, default=30) 21 | parser.add_argument('--top_p', type=float, default=0.9) 22 | 23 | parser.add_argument('--c_puct', type=float, default=0.125) 24 | parser.add_argument('--alpha', type=float, default=0.5) 25 | parser.add_argument('--beta', type=float, default=0.9) 26 | parser.add_argument('--length_scale', type=int, default=500) 27 | parser.add_argument('--num_rollouts', type=int, default=16) 28 | parser.add_argument('--max_search_count', type=int, default=20) 29 | parser.add_argument('--rollout_budget', type=int, default=200) 30 | 31 | parser.add_argument('--api_endpoint', type=str) 32 | 33 | args = parser.parse_args() 34 | 35 | logging.basicConfig( 36 | level=logging.DEBUG, 37 | format='%(asctime)s [%(levelname)s] %(message)s', 38 | handlers=[logging.FileHandler(args.log_file), logging.StreamHandler()], 39 | ) 40 | logger = logging.getLogger('main') 41 | 42 | logger.info('Start OmegaPRM') 43 | logger.info(f'Using model: {args.model}') 44 | logger.info(f'Input file: {args.input_file}') 45 | logger.info(f'Output directory: {args.output_dir}') 46 | 47 | with open(args.input_file, 'r') as f: 48 | input_data = json.load(f) 49 | 50 | LLM = LanguageModel( 51 | model=args.model, 52 | max_new_tokens=args.max_new_tokens, 53 | temperature=args.temperature, 54 | top_k=args.top_k, 55 | top_p=args.top_p, 56 | ) 57 | 58 | for question in input_data: 59 | hash_value = hashlib.md5(json.dumps(question).encode()).hexdigest() 60 | if os.path.exists(os.path.join(args.output_dir, f'{hash_value}.json')): 61 | logger.info(f"Already processed: {question['question']}") 62 | continue 63 | try: 64 | omega_prm = OmegaPRM( 65 | LLM=LLM, 66 | 
question=question['question'], 67 | image_path=question['image_path'], 68 | correct_answer=question['correct_answer'], 69 | c_puct=args.c_puct, 70 | alpha=args.alpha, 71 | beta=args.beta, 72 | length_scale=args.length_scale, 73 | num_rollouts=args.num_rollouts, 74 | max_search_count=args.max_search_count, 75 | rollout_budget=args.rollout_budget, 76 | api_endpoint=args.api_endpoint, 77 | ) 78 | data = omega_prm.run() 79 | data['id'] = question['id'] 80 | 81 | filename = os.path.join(args.output_dir, f'{hash_value}.json') 82 | with open(filename, 'w') as f: 83 | json.dump(data, f, indent=4, ensure_ascii=False) 84 | logger.info(f'Processed: {hash_value}.json') 85 | except Exception as e: 86 | logger.error(f"Error processing {question['question']}: {e}") 87 | -------------------------------------------------------------------------------- /data_pipeline/run_data_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SPLIT_DIR="split_dir" 4 | LOG_DIR="logs" 5 | OUTPUT_DIR="outputs" 6 | MODEL="/path/to/model" 7 | 8 | MAX_NEW_TOKENS=4096 9 | TEMPERATURE=1.0 10 | TOP_K=50 11 | TOP_P=0.9 12 | 13 | C_PUCT=0.125 14 | ALPHA=0.5 15 | BETA=0.9 16 | LENGTH_SCALE=2000 17 | NUM_ROLLOUTS=16 18 | MAX_SEARCH_COUNT=200 19 | ROLLOUT_BUDGET=1000 20 | 21 | API_ENDPOINT="http://127.0.0.1:8000/v1/" 22 | 23 | 24 | mkdir -p $OUTPUT_DIR 25 | mkdir -p $LOG_DIR 26 | 27 | START=$((NODE_RANK * 8 + 1)) 28 | 29 | for i in {0..7} 30 | do 31 | j=$((START + i)) 32 | GPU_ID=$i 33 | INPUT_FILE="$SPLIT_DIR/questions_part_${j}.json" 34 | LOG_FILE="$LOG_DIR/part_${j}.log" 35 | 36 | CUDA_VISIBLE_DEVICES="${GPU_ID}" python run_data_pipeline.py \ 37 | --input_file $INPUT_FILE \ 38 | --log_file $LOG_FILE \ 39 | --output_dir $OUTPUT_DIR \ 40 | --model $MODEL \ 41 | --max_new_tokens $MAX_NEW_TOKENS \ 42 | --temperature $TEMPERATURE \ 43 | --top_k $TOP_K \ 44 | --top_p $TOP_P \ 45 | --c_puct $C_PUCT \ 46 | --alpha $ALPHA \ 47 | --beta $BETA \ 48 | --length_scale $LENGTH_SCALE \ 49 | --num_rollouts $NUM_ROLLOUTS \ 50 | --max_search_count $MAX_SEARCH_COUNT \ 51 | --rollout_budget $ROLLOUT_BUDGET \ 52 | --api_endpoint $API_ENDPOINT & 53 | done 54 | 55 | wait 56 | -------------------------------------------------------------------------------- /data_pipeline/traverse.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | 5 | SEP = '' 6 | 7 | THRESHOLD = 0.00 8 | 9 | 10 | def random_select_solutions(solutions, n): 11 | if len(solutions) <= n: 12 | return solutions 13 | return random.sample(solutions, n) 14 | 15 | 16 | def completion_too_short(string, word_count_thres=10): 17 | return len(string.split(' ')) <= word_count_thres 18 | 19 | 20 | def remove_redundant(solutions): 21 | unique_solutions = {} 22 | for solution in solutions: 23 | key = solution['process'] 24 | if key not in unique_solutions: 25 | unique_solutions[key] = solution 26 | 27 | return list(unique_solutions.values()) 28 | 29 | 30 | def traverse(root): 31 | question = root['question'] 32 | answer = root['answer'] 33 | image_path = root['image_path'] 34 | positive = [] 35 | negative = [] 36 | res = [] 37 | 38 | def dfs(node, solution_prefix, labels): 39 | partial_solution = list(map(str.strip, node['partial_solution'])) 40 | partial_solution = '\n\n'.join(partial_solution) 41 | solution_prefix = solution_prefix + partial_solution + SEP + '\n\n' 42 | labels = labels + [node['mc_value']] 43 | if node['mc_value'] <= THRESHOLD: 44 | if not 
completion_too_short(solution_prefix): 45 | negative.append( 46 | { 47 | 'question': question, 48 | 'answer': answer, 49 | 'image_path': image_path, 50 | 'process': solution_prefix.strip(), 51 | 'labels': labels, 52 | } 53 | ) 54 | return 55 | if node['children']: 56 | for child in node['children']: 57 | dfs(child, solution_prefix, labels) 58 | else: 59 | if not completion_too_short(solution_prefix): 60 | positive.append( 61 | { 62 | 'question': question, 63 | 'answer': answer, 64 | 'image_path': image_path, 65 | 'process': solution_prefix.strip(), 66 | 'labels': labels, 67 | } 68 | ) 69 | return 70 | 71 | if root['children']: 72 | for child in root['children']: 73 | dfs(child, '', []) 74 | 75 | negative = remove_redundant(negative) 76 | positive = remove_redundant(positive) 77 | 78 | res.extend(negative) 79 | 80 | if len(positive) < len(negative): 81 | res.extend(positive) 82 | else: 83 | res.extend(random.sample(positive, len(negative))) 84 | return res 85 | 86 | 87 | if __name__ == '__main__': 88 | root_dir = 'outputs' 89 | res = [] 90 | for file in os.listdir(root_dir): 91 | print(file) 92 | with open(os.path.join(root_dir, file), 'r') as f: 93 | data = json.load(f) 94 | if data['mc_value'] != 0.0 and data['mc_value'] != 1.0: 95 | res.extend(traverse(data)) 96 | 97 | print(len(res)) 98 | json.dump(res, open('output_file.json', 'w'), ensure_ascii=False, indent=4) 99 | -------------------------------------------------------------------------------- /docs/case_study.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModalMinds/MM-PRM/b914fd66c8efb3d6667610fba1f8fd1e55cde095/docs/case_study.png -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModalMinds/MM-PRM/b914fd66c8efb3d6667610fba1f8fd1e55cde095/docs/logo.png -------------------------------------------------------------------------------- /docs/performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModalMinds/MM-PRM/b914fd66c8efb3d6667610fba1f8fd1e55cde095/docs/performance.png -------------------------------------------------------------------------------- /docs/wechat_qr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ModalMinds/MM-PRM/b914fd66c8efb3d6667610fba1f8fd1e55cde095/docs/wechat_qr.png -------------------------------------------------------------------------------- /eval/prm/evaluate_k12_prm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | 8 | import torch 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | from internvl.model import load_model_and_tokenizer 13 | from internvl.train.dataset import build_transform, dynamic_preprocess 14 | 15 | ds_collections = { 16 | 'k12_prm': {'root': '/path/to/image/root', 'annotation': '/path/to/rollout/file'} 17 | } 18 | 19 | 20 | def collate_fn(batches): 21 | pixel_values = batches[0]['pixel_values'] 22 | prompts = batches[0]['prompts'] 23 | steps_lens = batches[0]['steps_lens'] 24 | data_items = batches[0]['data_item'] 25 | return pixel_values, prompts, steps_lens, data_items 26 | 27 | 28 | class K12PRMDataset(torch.utils.data.Dataset): 29 | 30 | def 
__init__( 31 | self, 32 | root, 33 | annotation, 34 | input_size=224, 35 | dynamic_image_size=False, 36 | use_thumbnail=False, 37 | max_num=6, 38 | ): 39 | self.root = root 40 | self.data = json.load(open(annotation)) 41 | self.input_size = input_size 42 | self.dynamic_image_size = dynamic_image_size 43 | self.use_thumbnail = use_thumbnail 44 | self.max_num = max_num 45 | self.transform = build_transform(is_train=False, input_size=input_size) 46 | 47 | def __len__(self): 48 | return len(self.data) 49 | 50 | def __getitem__(self, idx): 51 | data_item = self.data[idx] 52 | image = Image.open(os.path.join(self.root, data_item['image_path'])).convert( 53 | 'RGB' 54 | ) 55 | 56 | if self.dynamic_image_size: 57 | images = dynamic_preprocess( 58 | image, 59 | image_size=self.input_size, 60 | use_thumbnail=self.use_thumbnail, 61 | max_num=self.max_num, 62 | ) 63 | else: 64 | images = [image] 65 | pixel_values = [self.transform(image) for image in images] 66 | pixel_values = torch.stack(pixel_values) 67 | 68 | question = data_item['question'] 69 | 70 | prompts = [] 71 | steps_lens = [] 72 | for solution_split in data_item['solutions_splits']: 73 | solution = ''.join(solution_split) + '' 74 | prompt = f'Question: {question}\nProcess: {solution}' 75 | prompts.append(prompt) 76 | steps_lens.append(len(solution_split)) 77 | 78 | return { 79 | 'pixel_values': pixel_values, 80 | 'prompts': prompts, 81 | 'steps_lens': steps_lens, 82 | 'data_item': data_item, 83 | } 84 | 85 | 86 | class InferenceSampler(torch.utils.data.sampler.Sampler): 87 | 88 | def __init__(self, size): 89 | self._size = int(size) 90 | assert size > 0 91 | self._rank = torch.distributed.get_rank() 92 | self._world_size = torch.distributed.get_world_size() 93 | self._local_indices = self._get_local_indices( 94 | size, self._world_size, self._rank 95 | ) 96 | 97 | @staticmethod 98 | def _get_local_indices(total_size, world_size, rank): 99 | shard_size = total_size // world_size 100 | left = total_size % world_size 101 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 102 | 103 | begin = sum(shard_sizes[:rank]) 104 | end = min(sum(shard_sizes[: rank + 1]), total_size) 105 | return range(begin, end) 106 | 107 | def __iter__(self): 108 | yield from self._local_indices 109 | 110 | def __len__(self): 111 | return len(self._local_indices) 112 | 113 | 114 | def evaluate_chat_model(): 115 | random.seed(args.seed) 116 | 117 | for ds_name in args.datasets: 118 | dataset = K12PRMDataset( 119 | root=ds_collections[ds_name]['root'], 120 | annotation=ds_collections[ds_name]['annotation'], 121 | input_size=image_size, 122 | dynamic_image_size=args.dynamic, 123 | use_thumbnail=use_thumbnail, 124 | max_num=args.max_num, 125 | ) 126 | dataloader = torch.utils.data.DataLoader( 127 | dataset=dataset, 128 | sampler=InferenceSampler(len(dataset)), 129 | batch_size=1, 130 | num_workers=args.num_workers, 131 | pin_memory=True, 132 | drop_last=False, 133 | collate_fn=collate_fn, 134 | ) 135 | 136 | outputs = [] 137 | for idx, (pixel_values, prompts, steps_lens, data_item) in tqdm( 138 | enumerate(dataloader) 139 | ): 140 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 141 | 142 | prm_scores_flattened = [] 143 | for i in range(0, len(prompts), args.mini_batch_size): 144 | curr_bs = min(args.mini_batch_size, len(prompts) - i) 145 | output = model.batch_prm( 146 | tokenizer=tokenizer, 147 | pixel_values=torch.cat([pixel_values] * curr_bs, dim=0), 148 | questions=prompts[i : i + curr_bs], 149 | num_patches_list=[pixel_values.shape[0]] * 
curr_bs, 150 | verbose=True, 151 | ) 152 | prm_scores_flattened.extend(output.tolist()) 153 | 154 | data_item['prm_scores'] = [] 155 | curr_len = 0 156 | for i in range(len(steps_lens)): 157 | data_item['prm_scores'].append( 158 | prm_scores_flattened[curr_len : curr_len + steps_lens[i]] 159 | ) 160 | curr_len += steps_lens[i] 161 | 162 | for i in range(len(data_item['prm_scores'])): 163 | assert len(data_item['prm_scores'][i]) == steps_lens[i] 164 | 165 | print(f'Pred: {data_item["prm_scores"]}') 166 | outputs.append(data_item) 167 | 168 | if idx % 50 == 0: 169 | torch.distributed.barrier() 170 | 171 | torch.distributed.barrier() 172 | 173 | world_size = torch.distributed.get_world_size() 174 | merged_outputs = [None for _ in range(world_size)] 175 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 176 | 177 | merged_outputs = [json.loads(_) for _ in merged_outputs] 178 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 179 | 180 | if torch.distributed.get_rank() == 0: 181 | print(f'Evaluating {ds_name} ...') 182 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 183 | results_file = f'{ds_name}_{time_prefix}.json' 184 | output_path = os.path.join(args.out_dir, results_file) 185 | json.dump( 186 | merged_outputs, open(output_path, 'w'), indent=4, ensure_ascii=False 187 | ) 188 | print('Results saved to {}'.format(output_path)) 189 | 190 | cmd = f'python eval/prm/extract_calculate.py --output_file {results_file} --output_dir {args.out_dir}' 191 | print(cmd) 192 | os.system(cmd) 193 | 194 | 195 | if __name__ == '__main__': 196 | parser = argparse.ArgumentParser() 197 | parser.add_argument('--checkpoint', type=str, default='') 198 | parser.add_argument('--datasets', type=str, default='k12_prm') 199 | parser.add_argument('--mini-batch-size', type=int, default=4) 200 | parser.add_argument('--num-workers', type=int, default=2) 201 | parser.add_argument('--out-dir', type=str, default='results') 202 | parser.add_argument('--seed', type=int, default=0) 203 | parser.add_argument('--dynamic', action='store_true', default=True) 204 | parser.add_argument('--max-num', type=int, default=6) 205 | parser.add_argument('--load-in-8bit', action='store_true') 206 | parser.add_argument('--load-in-4bit', action='store_true') 207 | parser.add_argument('--auto', action='store_true') 208 | args = parser.parse_args() 209 | 210 | if not os.path.exists(args.out_dir): 211 | os.makedirs(args.out_dir, exist_ok=True) 212 | 213 | args.datasets = args.datasets.split(',') 214 | print('datasets:', args.datasets) 215 | 216 | torch.distributed.init_process_group( 217 | backend='nccl', 218 | world_size=int(os.getenv('WORLD_SIZE', '1')), 219 | rank=int(os.getenv('RANK', '0')), 220 | ) 221 | 222 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 223 | 224 | model, tokenizer = load_model_and_tokenizer(args) 225 | 226 | image_size = model.config.force_image_size or model.config.vision_config.image_size 227 | use_thumbnail = model.config.use_thumbnail 228 | 229 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 230 | print(f'[test] total_params: {total_params}B') 231 | print(f'[test] image_size: {image_size}') 232 | print(f'[test] template: {model.config.template}') 233 | print(f'[test] dynamic_image_size: {args.dynamic}') 234 | print(f'[test] use_thumbnail: {use_thumbnail}') 235 | 236 | evaluate_chat_model() 237 | -------------------------------------------------------------------------------- /eval/prm/evaluate_mathverse_prm.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | 8 | import torch 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | from internvl.model import load_model_and_tokenizer 13 | from internvl.train.dataset import build_transform, dynamic_preprocess 14 | 15 | ds_collections = { 16 | 'mathverse_prm': { 17 | 'root': '/path/to/image/root', 18 | 'annotation': '/path/to/rollout/file', 19 | } 20 | } 21 | 22 | 23 | def collate_fn(batches): 24 | pixel_values = batches[0]['pixel_values'] 25 | prompts = batches[0]['prompts'] 26 | steps_lens = batches[0]['steps_lens'] 27 | data_items = batches[0]['data_item'] 28 | return pixel_values, prompts, steps_lens, data_items 29 | 30 | 31 | class MathVersePRMDataset(torch.utils.data.Dataset): 32 | 33 | def __init__( 34 | self, 35 | root, 36 | annotation, 37 | input_size=224, 38 | dynamic_image_size=False, 39 | use_thumbnail=False, 40 | max_num=6, 41 | ): 42 | self.root = root 43 | self.data = json.load(open(annotation)) 44 | self.input_size = input_size 45 | self.dynamic_image_size = dynamic_image_size 46 | self.use_thumbnail = use_thumbnail 47 | self.max_num = max_num 48 | self.transform = build_transform(is_train=False, input_size=input_size) 49 | 50 | def __len__(self): 51 | return len(self.data) 52 | 53 | def __getitem__(self, idx): 54 | data_item = self.data[idx] 55 | image = os.path.join(self.root, data_item['image']) 56 | 57 | image = Image.open(image).convert('RGB') 58 | if self.dynamic_image_size: 59 | images = dynamic_preprocess( 60 | image, 61 | image_size=self.input_size, 62 | use_thumbnail=self.use_thumbnail, 63 | max_num=self.max_num, 64 | ) 65 | else: 66 | images = [image] 67 | pixel_values = [self.transform(image) for image in images] 68 | pixel_values = torch.stack(pixel_values) 69 | 70 | question = data_item['query_cot'] 71 | 72 | prompts = [] 73 | steps_lens = [] 74 | for solution_split in data_item['solutions_splits']: 75 | solution = ''.join(solution_split) + '' 76 | prompt = f'Question: {question}\nProcess: {solution}' 77 | prompts.append(prompt) 78 | steps_lens.append(len(solution_split)) 79 | 80 | return { 81 | 'pixel_values': pixel_values, 82 | 'prompts': prompts, 83 | 'steps_lens': steps_lens, 84 | 'data_item': data_item, 85 | } 86 | 87 | 88 | class InferenceSampler(torch.utils.data.sampler.Sampler): 89 | 90 | def __init__(self, size): 91 | self._size = int(size) 92 | assert size > 0 93 | self._rank = torch.distributed.get_rank() 94 | self._world_size = torch.distributed.get_world_size() 95 | self._local_indices = self._get_local_indices( 96 | size, self._world_size, self._rank 97 | ) 98 | 99 | @staticmethod 100 | def _get_local_indices(total_size, world_size, rank): 101 | shard_size = total_size // world_size 102 | left = total_size % world_size 103 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 104 | 105 | begin = sum(shard_sizes[:rank]) 106 | end = min(sum(shard_sizes[: rank + 1]), total_size) 107 | return range(begin, end) 108 | 109 | def __iter__(self): 110 | yield from self._local_indices 111 | 112 | def __len__(self): 113 | return len(self._local_indices) 114 | 115 | 116 | def evaluate_chat_model(): 117 | random.seed(args.seed) 118 | 119 | for ds_name in args.datasets: 120 | dataset = MathVersePRMDataset( 121 | root=ds_collections[ds_name]['root'], 122 | annotation=ds_collections[ds_name]['annotation'], 123 | input_size=image_size, 124 | 
dynamic_image_size=args.dynamic, 125 | use_thumbnail=use_thumbnail, 126 | max_num=args.max_num, 127 | ) 128 | dataloader = torch.utils.data.DataLoader( 129 | dataset=dataset, 130 | sampler=InferenceSampler(len(dataset)), 131 | batch_size=1, 132 | num_workers=args.num_workers, 133 | pin_memory=True, 134 | drop_last=False, 135 | collate_fn=collate_fn, 136 | ) 137 | 138 | outputs = [] 139 | for idx, (pixel_values, prompts, steps_lens, data_item) in tqdm( 140 | enumerate(dataloader) 141 | ): 142 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 143 | 144 | prm_scores_flattened = [] 145 | for i in range(0, len(prompts), args.mini_batch_size): 146 | curr_bs = min(args.mini_batch_size, len(prompts) - i) 147 | output = model.batch_prm( 148 | tokenizer=tokenizer, 149 | pixel_values=torch.cat([pixel_values] * curr_bs, dim=0), 150 | questions=prompts[i : i + curr_bs], 151 | num_patches_list=[pixel_values.shape[0]] * curr_bs, 152 | verbose=True, 153 | ) 154 | prm_scores_flattened.extend(output.tolist()) 155 | 156 | data_item['prm_scores'] = [] 157 | curr_len = 0 158 | for i in range(len(steps_lens)): 159 | data_item['prm_scores'].append( 160 | prm_scores_flattened[curr_len : curr_len + steps_lens[i]] 161 | ) 162 | curr_len += steps_lens[i] 163 | 164 | for i in range(len(data_item['prm_scores'])): 165 | assert len(data_item['prm_scores'][i]) == steps_lens[i] 166 | 167 | print(f'Pred: {data_item["prm_scores"]}') 168 | outputs.append(data_item) 169 | 170 | if idx % 50 == 0: 171 | torch.distributed.barrier() 172 | 173 | torch.distributed.barrier() 174 | 175 | world_size = torch.distributed.get_world_size() 176 | merged_outputs = [None for _ in range(world_size)] 177 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 178 | 179 | merged_outputs = [json.loads(_) for _ in merged_outputs] 180 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 181 | 182 | if torch.distributed.get_rank() == 0: 183 | print(f'Evaluating {ds_name} ...') 184 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 185 | results_file = f'{ds_name}_{time_prefix}.json' 186 | output_path = os.path.join(args.out_dir, results_file) 187 | json.dump( 188 | merged_outputs, open(output_path, 'w'), indent=4, ensure_ascii=False 189 | ) 190 | print('Results saved to {}'.format(output_path)) 191 | 192 | cmd = f'python eval/prm/extract_calculate.py --output_file {results_file} --output_dir {args.out_dir}' 193 | print(cmd) 194 | os.system(cmd) 195 | 196 | 197 | if __name__ == '__main__': 198 | parser = argparse.ArgumentParser() 199 | parser.add_argument('--checkpoint', type=str, default='') 200 | parser.add_argument('--datasets', type=str, default='') 201 | parser.add_argument('--mini-batch-size', type=int, default=4) 202 | parser.add_argument('--num-workers', type=int, default=2) 203 | parser.add_argument('--out-dir', type=str, default='results') 204 | parser.add_argument('--seed', type=int, default=0) 205 | parser.add_argument('--dynamic', action='store_true', default=True) 206 | parser.add_argument('--max-num', type=int, default=6) 207 | parser.add_argument('--load-in-8bit', action='store_true') 208 | parser.add_argument('--load-in-4bit', action='store_true') 209 | parser.add_argument('--auto', action='store_true') 210 | args = parser.parse_args() 211 | 212 | if not os.path.exists(args.out_dir): 213 | os.makedirs(args.out_dir, exist_ok=True) 214 | 215 | args.datasets = args.datasets.split(',') 216 | print('datasets:', args.datasets) 217 | 218 | torch.distributed.init_process_group( 
219 | backend='nccl', 220 | world_size=int(os.getenv('WORLD_SIZE', '1')), 221 | rank=int(os.getenv('RANK', '0')), 222 | ) 223 | 224 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 225 | 226 | model, tokenizer = load_model_and_tokenizer(args) 227 | 228 | image_size = model.config.force_image_size or model.config.vision_config.image_size 229 | use_thumbnail = model.config.use_thumbnail 230 | 231 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 232 | print(f'[test] total_params: {total_params}B') 233 | print(f'[test] image_size: {image_size}') 234 | print(f'[test] template: {model.config.template}') 235 | print(f'[test] dynamic_image_size: {args.dynamic}') 236 | print(f'[test] use_thumbnail: {use_thumbnail}') 237 | 238 | evaluate_chat_model() 239 | -------------------------------------------------------------------------------- /eval/prm/evaluate_mathvision_prm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | 8 | import torch 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | from internvl.model import load_model_and_tokenizer 13 | from internvl.train.dataset import build_transform, dynamic_preprocess 14 | 15 | ds_collections = { 16 | 'mathvision_prm': { 17 | 'root': '/path/to/image/root', 18 | 'annotation': '/path/to/rollout/file', 19 | } 20 | } 21 | 22 | 23 | def collate_fn(batches): 24 | pixel_values = batches[0]['pixel_values'] 25 | prompts = batches[0]['prompts'] 26 | steps_lens = batches[0]['steps_lens'] 27 | data_items = batches[0]['data_item'] 28 | return pixel_values, prompts, steps_lens, data_items 29 | 30 | 31 | class MathVisionPRMDataset(torch.utils.data.Dataset): 32 | 33 | def __init__( 34 | self, 35 | root, 36 | annotation, 37 | input_size=224, 38 | dynamic_image_size=False, 39 | use_thumbnail=False, 40 | max_num=6, 41 | ): 42 | self.root = root 43 | self.data = json.load(open(annotation)) 44 | self.input_size = input_size 45 | self.dynamic_image_size = dynamic_image_size 46 | self.use_thumbnail = use_thumbnail 47 | self.max_num = max_num 48 | self.transform = build_transform(is_train=False, input_size=input_size) 49 | 50 | def __len__(self): 51 | return len(self.data) 52 | 53 | def __getitem__(self, idx): 54 | data_item = self.data[idx] 55 | image = os.path.join(self.root, os.path.basename(data_item['image'])) 56 | 57 | image = Image.open(image).convert('RGB') 58 | if self.dynamic_image_size: 59 | images = dynamic_preprocess( 60 | image, 61 | image_size=self.input_size, 62 | use_thumbnail=self.use_thumbnail, 63 | max_num=self.max_num, 64 | ) 65 | else: 66 | images = [image] 67 | pixel_values = [self.transform(image) for image in images] 68 | pixel_values = torch.stack(pixel_values) 69 | 70 | options = '' 71 | if len(data_item['options']) > 0: 72 | assert len(data_item['options']) == 5, data_item 73 | if ''.join(data_item['options']) != 'ABCDE': 74 | options = f"(A) {data_item['options'][0]}\n(B) {data_item['options'][1]}\n(C) {data_item['options'][2]}\n(D) {data_item['options'][3]}\n(E) {data_item['options'][4]}\n" 75 | question = f"{data_item['question']}\n{options}" 76 | 77 | prompts = [] 78 | steps_lens = [] 79 | for solution_split in data_item['solutions_splits']: 80 | solution = ''.join(solution_split) + '' 81 | prompt = f'Question: {question}\nProcess: {solution}' 82 | prompts.append(prompt) 83 | steps_lens.append(len(solution_split)) 84 | 85 | return { 86 | 'pixel_values': pixel_values, 87 | 
'prompts': prompts, 88 | 'steps_lens': steps_lens, 89 | 'data_item': data_item, 90 | } 91 | 92 | 93 | class InferenceSampler(torch.utils.data.sampler.Sampler): 94 | 95 | def __init__(self, size): 96 | self._size = int(size) 97 | assert size > 0 98 | self._rank = torch.distributed.get_rank() 99 | self._world_size = torch.distributed.get_world_size() 100 | self._local_indices = self._get_local_indices( 101 | size, self._world_size, self._rank 102 | ) 103 | 104 | @staticmethod 105 | def _get_local_indices(total_size, world_size, rank): 106 | shard_size = total_size // world_size 107 | left = total_size % world_size 108 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 109 | 110 | begin = sum(shard_sizes[:rank]) 111 | end = min(sum(shard_sizes[: rank + 1]), total_size) 112 | return range(begin, end) 113 | 114 | def __iter__(self): 115 | yield from self._local_indices 116 | 117 | def __len__(self): 118 | return len(self._local_indices) 119 | 120 | 121 | def evaluate_chat_model(): 122 | random.seed(args.seed) 123 | 124 | for ds_name in args.datasets: 125 | dataset = MathVisionPRMDataset( 126 | root=ds_collections[ds_name]['root'], 127 | annotation=ds_collections[ds_name]['annotation'], 128 | input_size=image_size, 129 | dynamic_image_size=args.dynamic, 130 | use_thumbnail=use_thumbnail, 131 | max_num=args.max_num, 132 | ) 133 | dataloader = torch.utils.data.DataLoader( 134 | dataset=dataset, 135 | sampler=InferenceSampler(len(dataset)), 136 | batch_size=1, 137 | num_workers=args.num_workers, 138 | pin_memory=True, 139 | drop_last=False, 140 | collate_fn=collate_fn, 141 | ) 142 | 143 | outputs = [] 144 | for idx, (pixel_values, prompts, steps_lens, data_item) in tqdm( 145 | enumerate(dataloader) 146 | ): 147 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 148 | 149 | prm_scores_flattened = [] 150 | for i in range(0, len(prompts), args.mini_batch_size): 151 | curr_bs = min(args.mini_batch_size, len(prompts) - i) 152 | output = model.batch_prm( 153 | tokenizer=tokenizer, 154 | pixel_values=torch.cat([pixel_values] * curr_bs, dim=0), 155 | questions=prompts[i : i + curr_bs], 156 | num_patches_list=[pixel_values.shape[0]] * curr_bs, 157 | verbose=True, 158 | ) 159 | prm_scores_flattened.extend(output.tolist()) 160 | 161 | data_item['prm_scores'] = [] 162 | curr_len = 0 163 | for i in range(len(steps_lens)): 164 | data_item['prm_scores'].append( 165 | prm_scores_flattened[curr_len : curr_len + steps_lens[i]] 166 | ) 167 | curr_len += steps_lens[i] 168 | 169 | for i in range(len(data_item['prm_scores'])): 170 | assert len(data_item['prm_scores'][i]) == steps_lens[i] 171 | 172 | print(f'Pred: {data_item["prm_scores"]}') 173 | outputs.append(data_item) 174 | 175 | if idx % 50 == 0: 176 | torch.distributed.barrier() 177 | 178 | torch.distributed.barrier() 179 | 180 | world_size = torch.distributed.get_world_size() 181 | merged_outputs = [None for _ in range(world_size)] 182 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 183 | 184 | merged_outputs = [json.loads(_) for _ in merged_outputs] 185 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 186 | 187 | if torch.distributed.get_rank() == 0: 188 | print(f'Evaluating {ds_name} ...') 189 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 190 | results_file = f'{ds_name}_{time_prefix}.json' 191 | output_path = os.path.join(args.out_dir, results_file) 192 | json.dump( 193 | merged_outputs, open(output_path, 'w'), indent=4, ensure_ascii=False 194 | ) 195 | 
print('Results saved to {}'.format(output_path)) 196 | 197 | cmd = f'python eval/prm/extract_calculate.py --output_file {results_file} --output_dir {args.out_dir}' 198 | print(cmd) 199 | os.system(cmd) 200 | 201 | 202 | if __name__ == '__main__': 203 | parser = argparse.ArgumentParser() 204 | parser.add_argument('--checkpoint', type=str, default='') 205 | parser.add_argument('--datasets', type=str, default='') 206 | parser.add_argument('--mini-batch-size', type=int, default=4) 207 | parser.add_argument('--num-workers', type=int, default=2) 208 | parser.add_argument('--out-dir', type=str, default='results') 209 | parser.add_argument('--seed', type=int, default=0) 210 | parser.add_argument('--dynamic', action='store_true', default=True) 211 | parser.add_argument('--max-num', type=int, default=6) 212 | parser.add_argument('--load-in-8bit', action='store_true') 213 | parser.add_argument('--load-in-4bit', action='store_true') 214 | parser.add_argument('--auto', action='store_true') 215 | args = parser.parse_args() 216 | 217 | if not os.path.exists(args.out_dir): 218 | os.makedirs(args.out_dir, exist_ok=True) 219 | 220 | args.datasets = args.datasets.split(',') 221 | print('datasets:', args.datasets) 222 | 223 | torch.distributed.init_process_group( 224 | backend='nccl', 225 | world_size=int(os.getenv('WORLD_SIZE', '1')), 226 | rank=int(os.getenv('RANK', '0')), 227 | ) 228 | 229 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 230 | 231 | model, tokenizer = load_model_and_tokenizer(args) 232 | 233 | image_size = model.config.force_image_size or model.config.vision_config.image_size 234 | use_thumbnail = model.config.use_thumbnail 235 | 236 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 237 | print(f'[test] total_params: {total_params}B') 238 | print(f'[test] image_size: {image_size}') 239 | print(f'[test] template: {model.config.template}') 240 | print(f'[test] dynamic_image_size: {args.dynamic}') 241 | print(f'[test] use_thumbnail: {use_thumbnail}') 242 | 243 | evaluate_chat_model() 244 | -------------------------------------------------------------------------------- /eval/prm/evaluate_mathvista_prm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | 8 | import torch 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | from internvl.model import load_model_and_tokenizer 13 | from internvl.train.dataset import build_transform, dynamic_preprocess 14 | 15 | ds_collections = { 16 | 'mathvista_prm': { 17 | 'root': '/path/to/image/root', 18 | 'annotation': '/path/to/rollout/file', 19 | } 20 | } 21 | 22 | 23 | def collate_fn(batches): 24 | pixel_values = batches[0]['pixel_values'] 25 | prompts = batches[0]['prompts'] 26 | steps_lens = batches[0]['steps_lens'] 27 | data_items = batches[0]['data_item'] 28 | return pixel_values, prompts, steps_lens, data_items 29 | 30 | 31 | class MathVistaPRMDataset(torch.utils.data.Dataset): 32 | 33 | def __init__( 34 | self, 35 | root, 36 | annotation, 37 | input_size=224, 38 | dynamic_image_size=False, 39 | use_thumbnail=False, 40 | max_num=6, 41 | ): 42 | self.root = root 43 | self.data = json.load(open(annotation)) 44 | self.input_size = input_size 45 | self.dynamic_image_size = dynamic_image_size 46 | self.use_thumbnail = use_thumbnail 47 | self.max_num = max_num 48 | self.transform = build_transform(is_train=False, input_size=input_size) 49 | 50 | def __len__(self): 51 | return 
len(self.data) 52 | 53 | def __getitem__(self, idx): 54 | data_item = self.data[idx] 55 | image = os.path.join(self.root, os.path.basename(data_item['image'])) 56 | 57 | image = Image.open(image).convert('RGB') 58 | if self.dynamic_image_size: 59 | images = dynamic_preprocess( 60 | image, 61 | image_size=self.input_size, 62 | use_thumbnail=self.use_thumbnail, 63 | max_num=self.max_num, 64 | ) 65 | else: 66 | images = [image] 67 | pixel_values = [self.transform(image) for image in images] 68 | pixel_values = torch.stack(pixel_values) 69 | 70 | question = data_item['query'] 71 | 72 | prompts = [] 73 | steps_lens = [] 74 | for solution_split in data_item['solutions_splits']: 75 | solution = ''.join(solution_split) + '' 76 | prompt = f'Question: {question}\nProcess: {solution}' 77 | prompts.append(prompt) 78 | steps_lens.append(len(solution_split)) 79 | 80 | return { 81 | 'pixel_values': pixel_values, 82 | 'prompts': prompts, 83 | 'steps_lens': steps_lens, 84 | 'data_item': data_item, 85 | } 86 | 87 | 88 | class InferenceSampler(torch.utils.data.sampler.Sampler): 89 | 90 | def __init__(self, size): 91 | self._size = int(size) 92 | assert size > 0 93 | self._rank = torch.distributed.get_rank() 94 | self._world_size = torch.distributed.get_world_size() 95 | self._local_indices = self._get_local_indices( 96 | size, self._world_size, self._rank 97 | ) 98 | 99 | @staticmethod 100 | def _get_local_indices(total_size, world_size, rank): 101 | shard_size = total_size // world_size 102 | left = total_size % world_size 103 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 104 | 105 | begin = sum(shard_sizes[:rank]) 106 | end = min(sum(shard_sizes[: rank + 1]), total_size) 107 | return range(begin, end) 108 | 109 | def __iter__(self): 110 | yield from self._local_indices 111 | 112 | def __len__(self): 113 | return len(self._local_indices) 114 | 115 | 116 | def evaluate_chat_model(): 117 | random.seed(args.seed) 118 | 119 | for ds_name in args.datasets: 120 | dataset = MathVistaPRMDataset( 121 | root=ds_collections[ds_name]['root'], 122 | annotation=ds_collections[ds_name]['annotation'], 123 | input_size=image_size, 124 | dynamic_image_size=args.dynamic, 125 | use_thumbnail=use_thumbnail, 126 | max_num=args.max_num, 127 | ) 128 | dataloader = torch.utils.data.DataLoader( 129 | dataset=dataset, 130 | sampler=InferenceSampler(len(dataset)), 131 | batch_size=1, 132 | num_workers=args.num_workers, 133 | pin_memory=True, 134 | drop_last=False, 135 | collate_fn=collate_fn, 136 | ) 137 | 138 | outputs = [] 139 | for idx, (pixel_values, prompts, steps_lens, data_item) in tqdm( 140 | enumerate(dataloader) 141 | ): 142 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 143 | 144 | prm_scores_flattened = [] 145 | for i in range(0, len(prompts), args.mini_batch_size): 146 | curr_bs = min(args.mini_batch_size, len(prompts) - i) 147 | output = model.batch_prm( 148 | tokenizer=tokenizer, 149 | pixel_values=torch.cat([pixel_values] * curr_bs, dim=0), 150 | questions=prompts[i : i + curr_bs], 151 | num_patches_list=[pixel_values.shape[0]] * curr_bs, 152 | verbose=True, 153 | ) 154 | prm_scores_flattened.extend(output.tolist()) 155 | 156 | data_item['prm_scores'] = [] 157 | curr_len = 0 158 | for i in range(len(steps_lens)): 159 | data_item['prm_scores'].append( 160 | prm_scores_flattened[curr_len : curr_len + steps_lens[i]] 161 | ) 162 | curr_len += steps_lens[i] 163 | 164 | for i in range(len(data_item['prm_scores'])): 165 | assert len(data_item['prm_scores'][i]) == steps_lens[i] 166 | 167 | 
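            # At this point data_item['prm_scores'][j] holds one score per step of candidate
            # solution j. Downstream best-of-N selection picks the candidate whose aggregated
            # score is highest; the 'min' aggregation reported by eval/prm/extract_calculate.py
            # looks roughly like this (sketch only, 'scores' is a hypothetical list of per-step
            # score lists):
            #
            #     def select_best(scores):
            #         agg = [min(s) if s else 0 for s in scores]  # weakest step bounds a solution
            #         return agg.index(max(agg))                  # index of the chosen candidate
            #
            #     select_best([[0.9, 0.8], [0.95, 0.2]])  # -> 0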
print(f'Pred: {data_item["prm_scores"]}') 168 | outputs.append(data_item) 169 | 170 | if idx % 50 == 0: 171 | torch.distributed.barrier() 172 | 173 | torch.distributed.barrier() 174 | 175 | world_size = torch.distributed.get_world_size() 176 | merged_outputs = [None for _ in range(world_size)] 177 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 178 | 179 | merged_outputs = [json.loads(_) for _ in merged_outputs] 180 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 181 | 182 | if torch.distributed.get_rank() == 0: 183 | print(f'Evaluating {ds_name} ...') 184 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 185 | results_file = f'{ds_name}_{time_prefix}.json' 186 | output_path = os.path.join(args.out_dir, results_file) 187 | json.dump( 188 | merged_outputs, open(output_path, 'w'), indent=4, ensure_ascii=False 189 | ) 190 | print('Results saved to {}'.format(output_path)) 191 | 192 | cmd = f'python eval/prm/extract_calculate.py --output_file {results_file} --output_dir {args.out_dir}' 193 | print(cmd) 194 | os.system(cmd) 195 | 196 | 197 | if __name__ == '__main__': 198 | parser = argparse.ArgumentParser() 199 | parser.add_argument('--checkpoint', type=str, default='') 200 | parser.add_argument('--datasets', type=str, default='') 201 | parser.add_argument('--mini-batch-size', type=int, default=4) 202 | parser.add_argument('--num-workers', type=int, default=2) 203 | parser.add_argument('--out-dir', type=str, default='results') 204 | parser.add_argument('--seed', type=int, default=0) 205 | parser.add_argument('--dynamic', action='store_true', default=True) 206 | parser.add_argument('--max-num', type=int, default=6) 207 | parser.add_argument('--load-in-8bit', action='store_true') 208 | parser.add_argument('--load-in-4bit', action='store_true') 209 | parser.add_argument('--auto', action='store_true') 210 | args = parser.parse_args() 211 | 212 | if not os.path.exists(args.out_dir): 213 | os.makedirs(args.out_dir, exist_ok=True) 214 | 215 | args.datasets = args.datasets.split(',') 216 | print('datasets:', args.datasets) 217 | 218 | torch.distributed.init_process_group( 219 | backend='nccl', 220 | world_size=int(os.getenv('WORLD_SIZE', '1')), 221 | rank=int(os.getenv('RANK', '0')), 222 | ) 223 | 224 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 225 | 226 | model, tokenizer = load_model_and_tokenizer(args) 227 | 228 | image_size = model.config.force_image_size or model.config.vision_config.image_size 229 | use_thumbnail = model.config.use_thumbnail 230 | 231 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 232 | print(f'[test] total_params: {total_params}B') 233 | print(f'[test] image_size: {image_size}') 234 | print(f'[test] template: {model.config.template}') 235 | print(f'[test] dynamic_image_size: {args.dynamic}') 236 | print(f'[test] use_thumbnail: {use_thumbnail}') 237 | 238 | evaluate_chat_model() 239 | -------------------------------------------------------------------------------- /eval/prm/evaluate_olympiadbench_prm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | 8 | import torch 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | from internvl.model import load_model_and_tokenizer 13 | from internvl.train.dataset import build_transform, dynamic_preprocess 14 | 15 | ds_collections = { 16 | 'olympiadbench_prm': { 17 | 'root': 
'/path/to/image/root', 18 | 'annotation': '/path/to/rollout/file', 19 | } 20 | } 21 | 22 | 23 | def collate_fn(batches): 24 | pixel_values = batches[0]['pixel_values'] 25 | prompts = batches[0]['prompts'] 26 | steps_lens = batches[0]['steps_lens'] 27 | num_image_patch = batches[0]['num_image_patch'] 28 | data_items = batches[0]['data_item'] 29 | return pixel_values, prompts, steps_lens, num_image_patch, data_items 30 | 31 | 32 | class OlympiadBenchPRMDataset(torch.utils.data.Dataset): 33 | 34 | def __init__( 35 | self, 36 | root, 37 | annotation, 38 | input_size=224, 39 | dynamic_image_size=False, 40 | use_thumbnail=False, 41 | max_num=6, 42 | ): 43 | self.root = root 44 | self.data = json.load(open(annotation)) 45 | self.input_size = input_size 46 | self.dynamic_image_size = dynamic_image_size 47 | self.use_thumbnail = use_thumbnail 48 | self.max_num = max_num 49 | self.transform = build_transform(is_train=False, input_size=input_size) 50 | 51 | def __len__(self): 52 | return len(self.data) 53 | 54 | def __getitem__(self, idx): 55 | data_item = self.data[idx] 56 | 57 | images, num_tiles = [], [] 58 | for i in range(1, 6): 59 | key = f'image_{i}' 60 | if data_item[key] is None: 61 | continue 62 | 63 | image = Image.open(os.path.join(self.root, data_item[key])).convert('RGB') 64 | 65 | if self.dynamic_image_size: 66 | image = dynamic_preprocess( 67 | image, 68 | image_size=self.input_size, 69 | use_thumbnail=self.use_thumbnail, 70 | max_num=self.max_num, 71 | ) 72 | images += image 73 | num_tiles.append(len(image)) 74 | else: 75 | images.append(image) 76 | num_tiles.append(1) 77 | pixel_values = [self.transform(image) for image in images] 78 | pixel_values = torch.stack(pixel_values) 79 | 80 | question = data_item['question'] 81 | 82 | prompts = [] 83 | steps_lens = [] 84 | for solution_split in data_item['solutions_splits']: 85 | solution = ''.join(solution_split) + '' 86 | prompt = f'Question: {question}\nProcess: {solution}' 87 | prompts.append(prompt) 88 | steps_lens.append(len(solution_split)) 89 | 90 | return { 91 | 'pixel_values': pixel_values, 92 | 'num_image_patch': num_tiles, 93 | 'prompts': prompts, 94 | 'steps_lens': steps_lens, 95 | 'data_item': data_item, 96 | } 97 | 98 | 99 | class InferenceSampler(torch.utils.data.sampler.Sampler): 100 | 101 | def __init__(self, size): 102 | self._size = int(size) 103 | assert size > 0 104 | self._rank = torch.distributed.get_rank() 105 | self._world_size = torch.distributed.get_world_size() 106 | self._local_indices = self._get_local_indices( 107 | size, self._world_size, self._rank 108 | ) 109 | 110 | @staticmethod 111 | def _get_local_indices(total_size, world_size, rank): 112 | shard_size = total_size // world_size 113 | left = total_size % world_size 114 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 115 | 116 | begin = sum(shard_sizes[:rank]) 117 | end = min(sum(shard_sizes[: rank + 1]), total_size) 118 | return range(begin, end) 119 | 120 | def __iter__(self): 121 | yield from self._local_indices 122 | 123 | def __len__(self): 124 | return len(self._local_indices) 125 | 126 | 127 | def evaluate_chat_model(): 128 | random.seed(args.seed) 129 | 130 | for ds_name in args.datasets: 131 | dataset = OlympiadBenchPRMDataset( 132 | root=ds_collections[ds_name]['root'], 133 | annotation=ds_collections[ds_name]['annotation'], 134 | input_size=image_size, 135 | dynamic_image_size=args.dynamic, 136 | use_thumbnail=use_thumbnail, 137 | max_num=args.max_num, 138 | ) 139 | dataloader = torch.utils.data.DataLoader( 140 | 
dataset=dataset, 141 | sampler=InferenceSampler(len(dataset)), 142 | batch_size=1, 143 | num_workers=args.num_workers, 144 | pin_memory=True, 145 | drop_last=False, 146 | collate_fn=collate_fn, 147 | ) 148 | 149 | outputs = [] 150 | for idx, ( 151 | pixel_values, 152 | prompts, 153 | steps_lens, 154 | num_image_patch, 155 | data_item, 156 | ) in tqdm(enumerate(dataloader)): 157 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 158 | 159 | prm_scores_flattened = [] 160 | for i in range(0, len(prompts), args.mini_batch_size): 161 | curr_bs = min(args.mini_batch_size, len(prompts) - i) 162 | output = model.prm( 163 | tokenizer=tokenizer, 164 | pixel_values=pixel_values, 165 | question=prompts[i : i + curr_bs][0], 166 | num_patches_list=num_image_patch, 167 | verbose=True, 168 | ) 169 | prm_scores_flattened.extend(output.tolist()) 170 | 171 | data_item['prm_scores'] = [] 172 | curr_len = 0 173 | for i in range(len(steps_lens)): 174 | data_item['prm_scores'].append( 175 | prm_scores_flattened[curr_len : curr_len + steps_lens[i]] 176 | ) 177 | curr_len += steps_lens[i] 178 | 179 | for i in range(len(data_item['prm_scores'])): 180 | assert len(data_item['prm_scores'][i]) == steps_lens[i] 181 | 182 | print(f'Pred: {data_item["prm_scores"]}') 183 | outputs.append(data_item) 184 | 185 | if idx % 50 == 0: 186 | torch.distributed.barrier() 187 | 188 | torch.distributed.barrier() 189 | 190 | world_size = torch.distributed.get_world_size() 191 | merged_outputs = [None for _ in range(world_size)] 192 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 193 | 194 | merged_outputs = [json.loads(_) for _ in merged_outputs] 195 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 196 | 197 | if torch.distributed.get_rank() == 0: 198 | print(f'Evaluating {ds_name} ...') 199 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 200 | results_file = f'{ds_name}_{time_prefix}.json' 201 | output_path = os.path.join(args.out_dir, results_file) 202 | json.dump( 203 | merged_outputs, open(output_path, 'w'), indent=4, ensure_ascii=False 204 | ) 205 | print('Results saved to {}'.format(output_path)) 206 | 207 | cmd = f'python eval/prm/extract_calculate.py --output_file {results_file} --output_dir {args.out_dir}' 208 | print(cmd) 209 | os.system(cmd) 210 | 211 | 212 | if __name__ == '__main__': 213 | parser = argparse.ArgumentParser() 214 | parser.add_argument('--checkpoint', type=str, default='') 215 | parser.add_argument('--datasets', type=str, default='') 216 | parser.add_argument('--mini-batch-size', type=int, default=1) 217 | parser.add_argument('--num-workers', type=int, default=2) 218 | parser.add_argument('--out-dir', type=str, default='results') 219 | parser.add_argument('--seed', type=int, default=0) 220 | parser.add_argument('--dynamic', action='store_true', default=True) 221 | parser.add_argument('--max-num', type=int, default=6) 222 | parser.add_argument('--load-in-8bit', action='store_true') 223 | parser.add_argument('--load-in-4bit', action='store_true') 224 | parser.add_argument('--auto', action='store_true') 225 | args = parser.parse_args() 226 | 227 | if not os.path.exists(args.out_dir): 228 | os.makedirs(args.out_dir, exist_ok=True) 229 | 230 | args.datasets = args.datasets.split(',') 231 | print('datasets:', args.datasets) 232 | 233 | torch.distributed.init_process_group( 234 | backend='nccl', 235 | world_size=int(os.getenv('WORLD_SIZE', '1')), 236 | rank=int(os.getenv('RANK', '0')), 237 | ) 238 | 239 | 
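    # WORLD_SIZE, RANK and LOCAL_RANK are injected by the torchrun launcher; evaluate.sh starts
    # this script roughly as follows (single node, one process per GPU):
    #
    #     torchrun --nnodes=1 --node_rank=0 --master_addr=127.0.0.1 \
    #         --nproc_per_node=${GPUS} --master_port=${MASTER_PORT} \
    #         eval/prm/evaluate_olympiadbench_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET}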
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 240 | 241 | model, tokenizer = load_model_and_tokenizer(args) 242 | 243 | image_size = model.config.force_image_size or model.config.vision_config.image_size 244 | use_thumbnail = model.config.use_thumbnail 245 | 246 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 247 | print(f'[test] total_params: {total_params}B') 248 | print(f'[test] image_size: {image_size}') 249 | print(f'[test] template: {model.config.template}') 250 | print(f'[test] dynamic_image_size: {args.dynamic}') 251 | print(f'[test] use_thumbnail: {use_thumbnail}') 252 | 253 | evaluate_chat_model() 254 | -------------------------------------------------------------------------------- /eval/prm/extract_calculate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import math 4 | import os 5 | import random 6 | from functools import reduce 7 | 8 | 9 | def calculate_accuracy(data): 10 | results = {} 11 | 12 | cnts = [] 13 | for i in range(5000): 14 | cnt = 0 15 | for item in data: 16 | if random.choice(item['labels']) == 1: 17 | cnt += 1 18 | cnts.append(cnt) 19 | 20 | cnt = sum(cnts) / len(cnts) 21 | print(f'random {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 22 | results['random'] = { 23 | 'correct': cnt, 24 | 'total': len(data), 25 | 'accuracy': cnt / len(data), 26 | } 27 | 28 | cnt = 0 29 | for item in data: 30 | labels = random.sample(item['labels'], min(16, len(item['labels']))) 31 | if sum(labels) >= 1: 32 | cnt += 1 33 | print(f'pass@16 {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 34 | results['pass@16'] = { 35 | 'correct': cnt, 36 | 'total': len(data), 37 | 'accuracy': cnt / len(data), 38 | } 39 | 40 | cnt = 0 41 | for item in data: 42 | labels = random.sample(item['labels'], min(8, len(item['labels']))) 43 | if sum(labels) >= 1: 44 | cnt += 1 45 | print(f'pass@8 {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 46 | results['pass@8'] = { 47 | 'correct': cnt, 48 | 'total': len(data), 49 | 'accuracy': cnt / len(data), 50 | } 51 | 52 | cnt = 0 53 | for item in data: 54 | labels = random.sample(item['labels'], min(4, len(item['labels']))) 55 | if sum(labels) >= 1: 56 | cnt += 1 57 | print(f'pass@4 {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 58 | results['pass@4'] = { 59 | 'correct': cnt, 60 | 'total': len(data), 61 | 'accuracy': cnt / len(data), 62 | } 63 | 64 | cnt = 0 65 | for item in data: 66 | labels = random.sample(item['labels'], min(2, len(item['labels']))) 67 | if sum(labels) >= 1: 68 | cnt += 1 69 | print(f'pass@2 {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 70 | results['pass@2'] = { 71 | 'correct': cnt, 72 | 'total': len(data), 73 | 'accuracy': cnt / len(data), 74 | } 75 | 76 | cnt = 0 77 | for item in data: 78 | prm_score = list(map(lambda x: min(x) if x else 0, item['prm_scores'])) 79 | if item['labels'][prm_score.index(max(prm_score))] == 1: 80 | cnt += 1 81 | print(f'prm accuracy min {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 82 | results['min'] = {'correct': cnt, 'total': len(data), 'accuracy': cnt / len(data)} 83 | 84 | cnt = 0 85 | for item in data: 86 | prm_score = list(map(lambda x: x[-1] if x else 0, item['prm_scores'])) 87 | if item['labels'][prm_score.index(max(prm_score))] == 1: 88 | cnt += 1 89 | print(f'prm accuracy last {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 90 | results['last'] = {'correct': cnt, 'total': len(data), 'accuracy': cnt / len(data)} 91 | 92 | cnt = 0 93 | for item in data: 
94 | prm_score = list( 95 | map(lambda x: reduce(lambda a, b: a * b, x) if x else 0, item['prm_scores']) 96 | ) 97 | if item['labels'][prm_score.index(max(prm_score))] == 1: 98 | cnt += 1 99 | print(f'prm accuracy product {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 100 | results['product'] = { 101 | 'correct': cnt, 102 | 'total': len(data), 103 | 'accuracy': cnt / len(data), 104 | } 105 | 106 | cnt = 0 107 | for item in data: 108 | prm_score = list(map(lambda x: sum(x) / len(x) if x else 0, item['prm_scores'])) 109 | if item['labels'][prm_score.index(max(prm_score))] == 1: 110 | cnt += 1 111 | print(f'prm accuracy average {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 112 | results['average'] = { 113 | 'correct': cnt, 114 | 'total': len(data), 115 | 'accuracy': cnt / len(data), 116 | } 117 | 118 | cnt = 0 119 | for item in data: 120 | prm_score = list( 121 | map( 122 | lambda x: ( 123 | sum(list(map(lambda a: math.log(a) if a != 0 else -9999, x))) 124 | if x 125 | else 0 126 | ), 127 | item['prm_scores'], 128 | ) 129 | ) 130 | if item['labels'][prm_score.index(max(prm_score))] == 1: 131 | cnt += 1 132 | print( 133 | f'prm accuracy sum_logprob {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%' 134 | ) 135 | results['sum_logprob'] = { 136 | 'correct': cnt, 137 | 'total': len(data), 138 | 'accuracy': cnt / len(data), 139 | } 140 | 141 | cnt = 0 142 | for item in data: 143 | prm_score = list(map(lambda x: max(x) if x else 0, item['prm_scores'])) 144 | if item['labels'][prm_score.index(max(prm_score))] == 1: 145 | cnt += 1 146 | print(f'prm accuracy max {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 147 | results['max'] = {'correct': cnt, 'total': len(data), 'accuracy': cnt / len(data)} 148 | 149 | cnt = 0 150 | for item in data: 151 | try: 152 | prm_score = list( 153 | map( 154 | lambda x: ( 155 | sum( 156 | list( 157 | map( 158 | lambda a: math.log(a / (1 - a)) if a != 1 else 9999, 159 | x, 160 | ) 161 | ) 162 | ) 163 | if x 164 | else 0 165 | ), 166 | item['prm_scores'], 167 | ) 168 | ) 169 | except Exception as e: 170 | print(e) 171 | print(item['prm_scores']) 172 | raise e 173 | if item['labels'][prm_score.index(max(prm_score))] == 1: 174 | cnt += 1 175 | print(f'prm accuracy sum_logit {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 176 | results['sum_logit'] = { 177 | 'correct': cnt, 178 | 'total': len(data), 179 | 'accuracy': cnt / len(data), 180 | } 181 | 182 | cnt = 0 183 | for item in data: 184 | prm_score = list( 185 | map( 186 | lambda x: ( 187 | sum(list(map(lambda a: (a / (1 - a)) if a != 1 else 9999, x))) 188 | / len(x) 189 | if x 190 | else 0 191 | ), 192 | item['prm_scores'], 193 | ) 194 | ) 195 | if item['labels'][prm_score.index(max(prm_score))] == 1: 196 | cnt += 1 197 | print(f'prm accuracy mean_odd {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%') 198 | results['mean_odd'] = { 199 | 'correct': cnt, 200 | 'total': len(data), 201 | 'accuracy': cnt / len(data), 202 | } 203 | 204 | return results 205 | 206 | 207 | if __name__ == '__main__': 208 | parser = argparse.ArgumentParser() 209 | parser.add_argument('--output_dir', type=str, default='./results') 210 | parser.add_argument('--output_file', type=str, default='') 211 | args = parser.parse_args() 212 | 213 | result_file = os.path.join(args.output_dir, args.output_file) 214 | 215 | print(f'Reading {result_file}...') 216 | results = calculate_accuracy(json.load(open(result_file))) 217 | 218 | print(f"Saving results to {result_file.replace('.json', f'_score.json')}...") 219 | json.dump( 220 | 
results, 221 | open(result_file.replace('.json', f'_score.json'), 'w'), 222 | indent=4, 223 | ensure_ascii=False, 224 | ) 225 | print(f'Results saved.') 226 | -------------------------------------------------------------------------------- /evaluate.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | CHECKPOINT=${1} 4 | DATASET=${2} 5 | CHECKPOINT="$(pwd)/${CHECKPOINT}" 6 | export PYTHONPATH="$(pwd):${PYTHONPATH}" 7 | echo "CHECKPOINT: ${CHECKPOINT}" 8 | 9 | MASTER_PORT=${MASTER_PORT:-63669} 10 | PORT=${PORT:-63665} 11 | GPUS=${GPUS:-8} 12 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 13 | NODES=$((GPUS / GPUS_PER_NODE)) 14 | export MASTER_PORT=${MASTER_PORT} 15 | export PORT=${PORT} 16 | 17 | # Save original arguments 18 | ARGS=("$@") 19 | 20 | # Parse options 21 | while [[ $# -gt 0 ]]; do 22 | case "$1" in 23 | --auto) 24 | GPUS=1 25 | shift 26 | ;; 27 | *) 28 | shift 29 | ;; 30 | esac 31 | done 32 | echo "GPUS: ${GPUS}" 33 | 34 | if [[ "${DATASET}" == *"k12"* ]]; then 35 | torchrun \ 36 | --nnodes=1 \ 37 | --node_rank=0 \ 38 | --master_addr=127.0.0.1 \ 39 | --nproc_per_node=${GPUS} \ 40 | --master_port=${MASTER_PORT} \ 41 | eval/prm/evaluate_k12_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} "${ARGS[@]:2}" 42 | fi 43 | 44 | if [[ "${DATASET}" == *"mathvista"* ]]; then 45 | torchrun \ 46 | --nnodes=1 \ 47 | --node_rank=0 \ 48 | --master_addr=127.0.0.1 \ 49 | --nproc_per_node=${GPUS} \ 50 | --master_port=${MASTER_PORT} \ 51 | eval/prm/evaluate_mathvista_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} "${ARGS[@]:2}" 52 | fi 53 | 54 | if [[ "${DATASET}" == *"mathverse"* ]]; then 55 | torchrun \ 56 | --nnodes=1 \ 57 | --node_rank=0 \ 58 | --master_addr=127.0.0.1 \ 59 | --nproc_per_node=${GPUS} \ 60 | --master_port=${MASTER_PORT} \ 61 | eval/prm/evaluate_mathverse_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} "${ARGS[@]:2}" 62 | fi 63 | 64 | if [[ "${DATASET}" == *"mathvision"* ]]; then 65 | torchrun \ 66 | --nnodes=1 \ 67 | --node_rank=0 \ 68 | --master_addr=127.0.0.1 \ 69 | --nproc_per_node=${GPUS} \ 70 | --master_port=${MASTER_PORT} \ 71 | eval/prm/evaluate_mathvision_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} "${ARGS[@]:2}" 72 | fi 73 | 74 | if [[ "${DATASET}" == *"olympiadbench"* ]]; then 75 | torchrun \ 76 | --nnodes=1 \ 77 | --node_rank=0 \ 78 | --master_addr=127.0.0.1 \ 79 | --nproc_per_node=${GPUS} \ 80 | --master_port=${MASTER_PORT} \ 81 | eval/prm/evaluate_olympiadbench_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} "${ARGS[@]:2}" 82 | fi 83 | -------------------------------------------------------------------------------- /internvl/dist_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import socket 3 | import subprocess 4 | from datetime import timedelta 5 | 6 | import deepspeed 7 | import torch 8 | import torch.multiprocessing as mp 9 | from torch import distributed as dist 10 | 11 | timeout = timedelta(minutes=60) 12 | 13 | 14 | def _find_free_port(): 15 | # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501 16 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 17 | # Binding to port 0 will cause the OS to find an available port for us 18 | sock.bind(('', 0)) 19 | port = sock.getsockname()[1] 20 | sock.close() 21 | # NOTE: there is still a chance the port could be taken by other processes. 
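    # _init_dist_slurm below therefore prefers an explicit `port` argument or MASTER_PORT from
    # the environment, falls back to the torch.distributed default port 29500 when it is free,
    # and only calls this helper as a last resort.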
22 | return port 23 | 24 | 25 | def _is_free_port(port): 26 | ips = socket.gethostbyname_ex(socket.gethostname())[-1] 27 | ips.append('localhost') 28 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 29 | return all(s.connect_ex((ip, port)) != 0 for ip in ips) 30 | 31 | 32 | def init_dist(launcher, backend='nccl', **kwargs): 33 | if mp.get_start_method(allow_none=True) is None: 34 | mp.set_start_method('spawn') 35 | if launcher == 'pytorch': 36 | _init_dist_pytorch(backend, **kwargs) 37 | elif launcher == 'mpi': 38 | _init_dist_mpi(backend, **kwargs) 39 | elif launcher == 'slurm': 40 | _init_dist_slurm(backend, **kwargs) 41 | else: 42 | raise ValueError(f'Invalid launcher type: {launcher}') 43 | 44 | 45 | def _init_dist_pytorch(backend, **kwargs): 46 | # TODO: use local_rank instead of rank % num_gpus 47 | rank = int(os.environ['RANK']) 48 | num_gpus = torch.cuda.device_count() 49 | torch.cuda.set_device(rank % num_gpus) 50 | # dist.init_process_group(backend=backend, **kwargs) 51 | deepspeed.init_distributed(dist_backend=backend) 52 | 53 | 54 | def _init_dist_mpi(backend, **kwargs): 55 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 56 | torch.cuda.set_device(local_rank) 57 | if 'MASTER_PORT' not in os.environ: 58 | # 29500 is torch.distributed default port 59 | os.environ['MASTER_PORT'] = '29500' 60 | if 'MASTER_ADDR' not in os.environ: 61 | raise KeyError('The environment variable MASTER_ADDR is not set') 62 | os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE'] 63 | os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK'] 64 | dist.init_process_group(backend=backend, **kwargs) 65 | 66 | 67 | def _init_dist_slurm(backend, port=None): 68 | """Initialize slurm distributed training environment. 69 | 70 | If argument ``port`` is not specified, then the master port will be system 71 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 72 | environment variable, then a default port ``29500`` will be used. 73 | 74 | Args: 75 | backend (str): Backend of torch.distributed. 76 | port (int, optional): Master port. Defaults to None. 
77 | """ 78 | proc_id = int(os.environ['SLURM_PROCID']) 79 | ntasks = int(os.environ['SLURM_NTASKS']) 80 | node_list = os.environ['SLURM_NODELIST'] 81 | num_gpus = torch.cuda.device_count() 82 | torch.cuda.set_device(proc_id % num_gpus) 83 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 84 | # specify master port 85 | if port is not None: 86 | os.environ['MASTER_PORT'] = str(port) 87 | elif 'MASTER_PORT' in os.environ: 88 | pass # use MASTER_PORT in the environment variable 89 | else: 90 | # if torch.distributed default port(29500) is available 91 | # then use it, else find a free port 92 | if _is_free_port(29500): 93 | os.environ['MASTER_PORT'] = '29500' 94 | else: 95 | os.environ['MASTER_PORT'] = str(_find_free_port()) 96 | # use MASTER_ADDR in the environment variable if it already exists 97 | if 'MASTER_ADDR' not in os.environ: 98 | os.environ['MASTER_ADDR'] = addr 99 | os.environ['WORLD_SIZE'] = str(ntasks) 100 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 101 | os.environ['RANK'] = str(proc_id) 102 | # dist.init_process_group(backend=backend, timeout=timeout) 103 | deepspeed.init_distributed(dist_backend=backend) 104 | -------------------------------------------------------------------------------- /internvl/model/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import math 8 | 9 | import torch 10 | from transformers import AutoTokenizer 11 | 12 | from internvl.model.internvl_chat import InternVLChatConfig, InternVLChatModel 13 | 14 | 15 | def split_model(num_layers, vit_alpha=0.5): 16 | device_map = {} 17 | world_size = torch.cuda.device_count() 18 | # Since the first GPU will be used for ViT, treat it as half a GPU. 
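    # Worked example (hypothetical sizes): with world_size=8 GPUs, num_layers=48 and
    # vit_alpha=0.5, the effective GPU count is 8 - 0.5 = 7.5, so
    # num_layers_per_gpu = ceil(48 / 7.5) = 7 layers per GPU; GPU 0 is then trimmed to
    # ceil(7 * (1 - 0.5)) = 4 LLM layers, leaving room for the vision tower and embeddings.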
19 | num_layers_per_gpu = math.ceil(num_layers / (world_size - vit_alpha)) 20 | num_layers_per_gpu = [num_layers_per_gpu] * world_size 21 | num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * (1 - vit_alpha)) 22 | layer_cnt = 0 23 | for i, num_layer in enumerate(num_layers_per_gpu): 24 | for j in range(num_layer): 25 | device_map[f'language_model.model.layers.{layer_cnt}'] = i 26 | layer_cnt += 1 27 | device_map['vision_model'] = 0 28 | device_map['mlp1'] = 0 29 | device_map['language_model.model.tok_embeddings'] = 0 30 | device_map['language_model.model.embed_tokens'] = 0 31 | device_map['language_model.output'] = 0 32 | device_map['language_model.model.norm'] = 0 33 | device_map['language_model.lm_head'] = 0 34 | device_map[f'language_model.model.layers.{num_layers - 1}'] = 0 35 | device_map['language_model.model.rotary_emb'] = 0 36 | 37 | return device_map 38 | 39 | 40 | def load_model_and_tokenizer(args): 41 | if args.auto: 42 | config = InternVLChatConfig.from_pretrained(args.checkpoint) 43 | num_hidden_layers = config.llm_config.num_hidden_layers 44 | device_map = split_model(num_hidden_layers) 45 | kwargs = {'device_map': device_map} if args.auto else {} 46 | tokenizer = AutoTokenizer.from_pretrained( 47 | args.checkpoint, trust_remote_code=True, use_fast=False 48 | ) 49 | model = InternVLChatModel.from_pretrained( 50 | args.checkpoint, 51 | low_cpu_mem_usage=True, 52 | torch_dtype=torch.bfloat16, 53 | load_in_8bit=args.load_in_8bit, 54 | load_in_4bit=args.load_in_4bit, 55 | **kwargs, 56 | ).eval() 57 | if not args.load_in_8bit and not args.load_in_4bit and not args.auto: 58 | model = model.cuda() 59 | return model, tokenizer 60 | -------------------------------------------------------------------------------- /internvl/model/internlm2/configuration_internlm2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # This code is based on transformers/src/transformers/models/llama/configuration_llama.py 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """InternLM2 model configuration""" 17 | 18 | from transformers.configuration_utils import PretrainedConfig 19 | from transformers.utils import logging 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 24 | 25 | 26 | # Modified from transformers.model.llama.configuration_llama.LlamaConfig 27 | class InternLM2Config(PretrainedConfig): 28 | r""" 29 | This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate 30 | an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a 31 | configuration with the defaults will yield a similar configuration to that of the InternLM2-7B. 32 | 33 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the 34 | documentation from [`PretrainedConfig`] for more information. 35 | 36 | 37 | Args: 38 | vocab_size (`int`, *optional*, defaults to 32000): 39 | Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the 40 | `inputs_ids` passed when calling [`InternLM2Model`] 41 | hidden_size (`int`, *optional*, defaults to 4096): 42 | Dimension of the hidden representations. 43 | intermediate_size (`int`, *optional*, defaults to 11008): 44 | Dimension of the MLP representations. 45 | num_hidden_layers (`int`, *optional*, defaults to 32): 46 | Number of hidden layers in the Transformer encoder. 47 | num_attention_heads (`int`, *optional*, defaults to 32): 48 | Number of attention heads for each attention layer in the Transformer encoder. 49 | num_key_value_heads (`int`, *optional*): 50 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 51 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 52 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 53 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 54 | by meanpooling all the original heads within that group. For more details checkout [this 55 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 56 | `num_attention_heads`. 57 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 58 | The non-linear activation function (function or string) in the decoder. 59 | max_position_embeddings (`int`, *optional*, defaults to 2048): 60 | The maximum sequence length that this model might ever be used with. Typically set this to something large 61 | just in case (e.g., 512 or 1024 or 2048). 62 | initializer_range (`float`, *optional*, defaults to 0.02): 63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 64 | rms_norm_eps (`float`, *optional*, defaults to 1e-12): 65 | The epsilon used by the rms normalization layers. 66 | use_cache (`bool`, *optional*, defaults to `True`): 67 | Whether or not the model should return the last key/values attentions (not used by all models). Only 68 | relevant if `config.is_decoder=True`. 
69 | tie_word_embeddings(`bool`, *optional*, defaults to `False`): 70 | Whether to tie weight embeddings 71 | Example: 72 | 73 | """ 74 | 75 | model_type = 'internlm2' 76 | _auto_class = 'AutoConfig' 77 | 78 | def __init__( # pylint: disable=W0102 79 | self, 80 | vocab_size=103168, 81 | hidden_size=4096, 82 | intermediate_size=11008, 83 | num_hidden_layers=32, 84 | num_attention_heads=32, 85 | num_key_value_heads=None, 86 | hidden_act='silu', 87 | max_position_embeddings=2048, 88 | initializer_range=0.02, 89 | rms_norm_eps=1e-6, 90 | use_cache=True, 91 | pad_token_id=0, 92 | bos_token_id=1, 93 | eos_token_id=2, 94 | tie_word_embeddings=False, 95 | bias=True, 96 | rope_theta=10000, 97 | rope_scaling=None, 98 | attn_implementation='eager', 99 | **kwargs, 100 | ): 101 | self.vocab_size = vocab_size 102 | self.max_position_embeddings = max_position_embeddings 103 | self.hidden_size = hidden_size 104 | self.intermediate_size = intermediate_size 105 | self.num_hidden_layers = num_hidden_layers 106 | self.num_attention_heads = num_attention_heads 107 | self.bias = bias 108 | 109 | if num_key_value_heads is None: 110 | num_key_value_heads = num_attention_heads 111 | self.num_key_value_heads = num_key_value_heads 112 | 113 | self.hidden_act = hidden_act 114 | self.initializer_range = initializer_range 115 | self.rms_norm_eps = rms_norm_eps 116 | self.use_cache = use_cache 117 | self.rope_theta = rope_theta 118 | self.rope_scaling = rope_scaling 119 | self._rope_scaling_validation() 120 | 121 | self.attn_implementation = attn_implementation 122 | if self.attn_implementation is None: 123 | self.attn_implementation = 'eager' 124 | super().__init__( 125 | pad_token_id=pad_token_id, 126 | bos_token_id=bos_token_id, 127 | eos_token_id=eos_token_id, 128 | tie_word_embeddings=tie_word_embeddings, 129 | **kwargs, 130 | ) 131 | 132 | def _rope_scaling_validation(self): 133 | """ 134 | Validate the `rope_scaling` configuration. 135 | """ 136 | if self.rope_scaling is None: 137 | return 138 | 139 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 140 | raise ValueError( 141 | '`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, ' 142 | f'got {self.rope_scaling}' 143 | ) 144 | rope_scaling_type = self.rope_scaling.get('type', None) 145 | rope_scaling_factor = self.rope_scaling.get('factor', None) 146 | if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']: 147 | raise ValueError( 148 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 149 | ) 150 | if ( 151 | rope_scaling_factor is None 152 | or not isinstance(rope_scaling_factor, float) 153 | or rope_scaling_factor < 1.0 154 | ): 155 | raise ValueError( 156 | f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}" 157 | ) 158 | -------------------------------------------------------------------------------- /internvl/model/internlm2/tokenization_internlm2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """Tokenization classes for InternLM.""" 18 | import os 19 | from shutil import copyfile 20 | from typing import Any, Dict, List, Optional, Tuple 21 | 22 | import sentencepiece as spm 23 | from transformers.tokenization_utils import PreTrainedTokenizer 24 | from transformers.utils import logging 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'} 29 | 30 | PRETRAINED_VOCAB_FILES_MAP = {} 31 | 32 | 33 | # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer 34 | class InternLM2Tokenizer(PreTrainedTokenizer): 35 | """ 36 | Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding. 37 | 38 | Args: 39 | vocab_file (`str`): 40 | Path to the vocabulary file. 41 | """ 42 | 43 | vocab_files_names = VOCAB_FILES_NAMES 44 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 45 | model_input_names = ['input_ids', 'attention_mask'] 46 | _auto_class = 'AutoTokenizer' 47 | 48 | def __init__( 49 | self, 50 | vocab_file, 51 | unk_token='', 52 | bos_token='', 53 | eos_token='', 54 | pad_token='', 55 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 56 | add_bos_token=True, 57 | add_eos_token=False, 58 | decode_with_prefix_space=False, 59 | clean_up_tokenization_spaces=False, 60 | **kwargs, 61 | ): 62 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs 63 | self.vocab_file = vocab_file 64 | self.add_bos_token = add_bos_token 65 | self.add_eos_token = add_eos_token 66 | self.decode_with_prefix_space = decode_with_prefix_space 67 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) 68 | self.sp_model.Load(vocab_file) 69 | self._no_prefix_space_tokens = None 70 | super().__init__( 71 | bos_token=bos_token, 72 | eos_token=eos_token, 73 | unk_token=unk_token, 74 | pad_token=pad_token, 75 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 76 | **kwargs, 77 | ) 78 | 79 | @property 80 | def no_prefix_space_tokens(self): 81 | if self._no_prefix_space_tokens is None: 82 | vocab = self.convert_ids_to_tokens(list(range(self.vocab_size))) 83 | self._no_prefix_space_tokens = { 84 | i for i, tok in enumerate(vocab) if not tok.startswith('▁') 85 | } 86 | return self._no_prefix_space_tokens 87 | 88 | @property 89 | def vocab_size(self): 90 | """Returns vocab size""" 91 | return self.sp_model.get_piece_size() 92 | 93 | @property 94 | def bos_token_id(self) -> Optional[int]: 95 | return self.sp_model.bos_id() 96 | 97 | @property 98 | def eos_token_id(self) -> Optional[int]: 99 | return self.sp_model.eos_id() 100 | 101 | def get_vocab(self): 102 | """Returns vocab as a dict""" 103 | vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} 104 | vocab.update(self.added_tokens_encoder) 105 | return vocab 106 | 107 | def _tokenize(self, text): 108 | """Returns a tokenized string.""" 109 | return self.sp_model.encode(text, out_type=str) 110 | 111 | def _convert_token_to_id(self, token): 112 | """Converts a token (str) in an id using the vocab.""" 113 | return self.sp_model.piece_to_id(token) 114 | 115 | def 
_convert_id_to_token(self, index): 116 | """Converts an index (integer) in a token (str) using the vocab.""" 117 | token = self.sp_model.IdToPiece(index) 118 | return token 119 | 120 | def _maybe_add_prefix_space(self, tokens, decoded): 121 | if tokens and tokens[0] not in self.no_prefix_space_tokens: 122 | return ' ' + decoded 123 | else: 124 | return decoded 125 | 126 | def convert_tokens_to_string(self, tokens): 127 | """Converts a sequence of tokens (string) in a single string.""" 128 | current_sub_tokens = [] 129 | out_string = '' 130 | prev_is_special = False 131 | for token in tokens: 132 | # make sure that special tokens are not decoded using sentencepiece model 133 | if token in self.all_special_tokens: 134 | if not prev_is_special: 135 | out_string += ' ' 136 | out_string += self.sp_model.decode(current_sub_tokens) + token 137 | prev_is_special = True 138 | current_sub_tokens = [] 139 | else: 140 | current_sub_tokens.append(token) 141 | prev_is_special = False 142 | out_string += self.sp_model.decode(current_sub_tokens) 143 | out_string = self.clean_up_tokenization(out_string) 144 | out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string) 145 | return out_string[1:] 146 | 147 | def save_vocabulary( 148 | self, save_directory, filename_prefix: Optional[str] = None 149 | ) -> Tuple[str]: 150 | """ 151 | Save the vocabulary and special tokens file to a directory. 152 | 153 | Args: 154 | save_directory (`str`): 155 | The directory in which to save the vocabulary. 156 | 157 | Returns: 158 | `Tuple(str)`: Paths to the files saved. 159 | """ 160 | if not os.path.isdir(save_directory): 161 | logger.error(f'Vocabulary path ({save_directory}) should be a directory') 162 | return 163 | out_vocab_file = os.path.join( 164 | save_directory, 165 | (filename_prefix + '-' if filename_prefix else '') 166 | + VOCAB_FILES_NAMES['vocab_file'], 167 | ) 168 | 169 | if os.path.abspath(self.vocab_file) != os.path.abspath( 170 | out_vocab_file 171 | ) and os.path.isfile(self.vocab_file): 172 | copyfile(self.vocab_file, out_vocab_file) 173 | elif not os.path.isfile(self.vocab_file): 174 | with open(out_vocab_file, 'wb') as fi: 175 | content_spiece_model = self.sp_model.serialized_model_proto() 176 | fi.write(content_spiece_model) 177 | 178 | return (out_vocab_file,) 179 | 180 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): 181 | if self.add_bos_token: 182 | bos_token_ids = [self.bos_token_id] 183 | else: 184 | bos_token_ids = [] 185 | 186 | output = bos_token_ids + token_ids_0 187 | 188 | if token_ids_1 is not None: 189 | output = output + token_ids_1 190 | 191 | if self.add_eos_token: 192 | output = output + [self.eos_token_id] 193 | 194 | return output 195 | 196 | def get_special_tokens_mask( 197 | self, 198 | token_ids_0: List[int], 199 | token_ids_1: Optional[List[int]] = None, 200 | already_has_special_tokens: bool = False, 201 | ) -> List[int]: 202 | """ 203 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 204 | special tokens using the tokenizer `prepare_for_model` method. 205 | 206 | Args: 207 | token_ids_0 (`List[int]`): 208 | List of IDs. 209 | token_ids_1 (`List[int]`, *optional*): 210 | Optional second list of IDs for sequence pairs. 211 | already_has_special_tokens (`bool`, *optional*, defaults to `False`): 212 | Whether or not the token list is already formatted with special tokens for the model. 
213 | 214 | Returns: 215 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 216 | """ 217 | if already_has_special_tokens: 218 | return super().get_special_tokens_mask( 219 | token_ids_0=token_ids_0, 220 | token_ids_1=token_ids_1, 221 | already_has_special_tokens=True, 222 | ) 223 | 224 | if token_ids_1 is None: 225 | return [1] + ([0] * len(token_ids_0)) + [1] 226 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] 227 | 228 | def create_token_type_ids_from_sequences( 229 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 230 | ) -> List[int]: 231 | """ 232 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make 233 | use of token type ids, therefore a list of zeros is returned. 234 | 235 | Args: 236 | token_ids_0 (`List[int]`): 237 | List of IDs. 238 | token_ids_1 (`List[int]`, *optional*): 239 | Optional second list of IDs for sequence pairs. 240 | 241 | Returns: 242 | `List[int]`: List of zeros. 243 | """ 244 | eos = [self.eos_token_id] 245 | 246 | if token_ids_1 is None: 247 | return len(token_ids_0 + eos) * [0] 248 | return len(token_ids_0 + eos + token_ids_1 + eos) * [0] 249 | -------------------------------------------------------------------------------- /internvl/model/internlm2/tokenization_internlm2_fast.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """Tokenization Fast class for InternLM.""" 18 | import os 19 | from shutil import copyfile 20 | from typing import Any, Dict, Optional, Tuple 21 | 22 | from tokenizers import Tokenizer, decoders, normalizers, processors 23 | from tokenizers.models import BPE 24 | from transformers.convert_slow_tokenizer import (SLOW_TO_FAST_CONVERTERS, 25 | SentencePieceExtractor, 26 | SpmConverter) 27 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast 28 | from transformers.utils import logging 29 | 30 | from .tokenization_internlm2 import InternLM2Tokenizer 31 | 32 | logger = logging.get_logger(__name__) 33 | 34 | VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'} 35 | 36 | 37 | # Modified from transformers.convert_slow_tokenizer.LlamaConverter 38 | class InternLM2Converter(SpmConverter): 39 | handle_byte_fallback = True 40 | 41 | def vocab(self, proto): 42 | vocab = [ 43 | ('', 0.0), 44 | ('', 0.0), 45 | ('', 0.0), 46 | ] 47 | vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]] 48 | return vocab 49 | 50 | def unk_id(self, proto): 51 | unk_id = 0 52 | return unk_id 53 | 54 | def decoder(self, replacement, add_prefix_space): 55 | return decoders.Sequence( 56 | [ 57 | decoders.Replace('▁', ' '), 58 | decoders.ByteFallback(), 59 | decoders.Fuse(), 60 | decoders.Strip(content=' ', left=1), 61 | ] 62 | ) 63 | 64 | def tokenizer(self, proto): 65 | model_type = proto.trainer_spec.model_type 66 | vocab_scores = self.vocab(proto) 67 | # special tokens 68 | added_tokens = self.original_tokenizer.added_tokens_decoder 69 | for i in range(len(vocab_scores)): 70 | piece, score = vocab_scores[i] 71 | if i in added_tokens: 72 | vocab_scores[i] = (added_tokens[i].content, score) 73 | if model_type == 1: 74 | raise RuntimeError('InternLM2 is supposed to be a BPE model!') 75 | 76 | elif model_type == 2: 77 | _, merges = SentencePieceExtractor( 78 | self.original_tokenizer.vocab_file 79 | ).extract(vocab_scores) 80 | bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)} 81 | tokenizer = Tokenizer( 82 | BPE( 83 | bpe_vocab, 84 | merges, 85 | unk_token=proto.trainer_spec.unk_piece, 86 | fuse_unk=True, 87 | byte_fallback=True, 88 | ) 89 | ) 90 | tokenizer.add_special_tokens( 91 | [added_token for index, added_token in added_tokens.items()] 92 | ) 93 | else: 94 | raise Exception( 95 | "You're trying to run a `Unigram` model but you're file was trained with a different algorithm" 96 | ) 97 | 98 | return tokenizer 99 | 100 | def normalizer(self, proto): 101 | normalizers_list = [] 102 | if proto.normalizer_spec.add_dummy_prefix: 103 | normalizers_list.append(normalizers.Prepend(prepend='▁')) 104 | normalizers_list.append(normalizers.Replace(pattern=' ', content='▁')) 105 | return normalizers.Sequence(normalizers_list) 106 | 107 | def pre_tokenizer(self, replacement, add_prefix_space): 108 | return None 109 | 110 | 111 | SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] = InternLM2Converter 112 | 113 | 114 | # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast 115 | class InternLM2TokenizerFast(PreTrainedTokenizerFast): 116 | vocab_files_names = VOCAB_FILES_NAMES 117 | slow_tokenizer_class = InternLM2Tokenizer 118 | padding_side = 'left' 119 | model_input_names = ['input_ids', 'attention_mask'] 120 | _auto_class = 'AutoTokenizer' 121 | 122 | def __init__( 123 | self, 124 | vocab_file, 125 | unk_token='', 126 | bos_token='', 127 | eos_token='', 128 | pad_token='', 129 | sp_model_kwargs: Optional[Dict[str, 
Any]] = None, 130 | add_bos_token=True, 131 | add_eos_token=False, 132 | decode_with_prefix_space=False, 133 | clean_up_tokenization_spaces=False, 134 | **kwargs, 135 | ): 136 | super().__init__( 137 | vocab_file=vocab_file, 138 | unk_token=unk_token, 139 | bos_token=bos_token, 140 | eos_token=eos_token, 141 | pad_token=pad_token, 142 | sp_model_kwargs=sp_model_kwargs, 143 | add_bos_token=add_bos_token, 144 | add_eos_token=add_eos_token, 145 | decode_with_prefix_space=decode_with_prefix_space, 146 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 147 | **kwargs, 148 | ) 149 | self._add_bos_token = add_bos_token 150 | self._add_eos_token = add_eos_token 151 | self.update_post_processor() 152 | self.vocab_file = vocab_file 153 | 154 | @property 155 | def can_save_slow_tokenizer(self) -> bool: 156 | return os.path.isfile(self.vocab_file) if self.vocab_file else False 157 | 158 | def update_post_processor(self): 159 | """ 160 | Updates the underlying post processor with the current `bos_token` and `eos_token`. 161 | """ 162 | bos = self.bos_token 163 | bos_token_id = self.bos_token_id 164 | if bos is None and self.add_bos_token: 165 | raise ValueError('add_bos_token = True but bos_token = None') 166 | 167 | eos = self.eos_token 168 | eos_token_id = self.eos_token_id 169 | if eos is None and self.add_eos_token: 170 | raise ValueError('add_eos_token = True but eos_token = None') 171 | 172 | single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" 173 | pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" 174 | 175 | special_tokens = [] 176 | if self.add_bos_token: 177 | special_tokens.append((bos, bos_token_id)) 178 | if self.add_eos_token: 179 | special_tokens.append((eos, eos_token_id)) 180 | self._tokenizer.post_processor = processors.TemplateProcessing( 181 | single=single, pair=pair, special_tokens=special_tokens 182 | ) 183 | 184 | @property 185 | def add_eos_token(self): 186 | return self._add_eos_token 187 | 188 | @property 189 | def add_bos_token(self): 190 | return self._add_bos_token 191 | 192 | @add_eos_token.setter 193 | def add_eos_token(self, value): 194 | self._add_eos_token = value 195 | self.update_post_processor() 196 | 197 | @add_bos_token.setter 198 | def add_bos_token(self, value): 199 | self._add_bos_token = value 200 | self.update_post_processor() 201 | 202 | def save_vocabulary( 203 | self, save_directory: str, filename_prefix: Optional[str] = None 204 | ) -> Tuple[str]: 205 | if not self.can_save_slow_tokenizer: 206 | raise ValueError( 207 | 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow ' 208 | 'tokenizer.' 
209 | ) 210 | 211 | if not os.path.isdir(save_directory): 212 | logger.error(f'Vocabulary path ({save_directory}) should be a directory') 213 | return 214 | out_vocab_file = os.path.join( 215 | save_directory, 216 | (filename_prefix + '-' if filename_prefix else '') 217 | + VOCAB_FILES_NAMES['vocab_file'], 218 | ) 219 | 220 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): 221 | copyfile(self.vocab_file, out_vocab_file) 222 | 223 | return (out_vocab_file,) 224 | -------------------------------------------------------------------------------- /internvl/model/internvl_chat/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .configuration_intern_vit import InternVisionConfig 8 | from .configuration_internvl_chat import InternVLChatConfig 9 | from .modeling_intern_vit import InternVisionModel 10 | from .modeling_internvl_chat import InternVLChatModel 11 | 12 | __all__ = [ 13 | 'InternVisionConfig', 14 | 'InternVisionModel', 15 | 'InternVLChatConfig', 16 | 'InternVLChatModel', 17 | ] 18 | -------------------------------------------------------------------------------- /internvl/model/internvl_chat/configuration_intern_vit.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import os 8 | from typing import Union 9 | 10 | from transformers.configuration_utils import PretrainedConfig 11 | from transformers.utils import logging 12 | 13 | logger = logging.get_logger(__name__) 14 | 15 | 16 | class InternVisionConfig(PretrainedConfig): 17 | r""" 18 | This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to 19 | instantiate a vision encoder according to the specified arguments, defining the model architecture. 20 | 21 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 22 | documentation from [`PretrainedConfig`] for more information. 23 | 24 | Args: 25 | num_channels (`int`, *optional*, defaults to 3): 26 | Number of color channels in the input images (e.g., 3 for RGB). 27 | patch_size (`int`, *optional*, defaults to 14): 28 | The size (resolution) of each patch. 29 | image_size (`int`, *optional*, defaults to 224): 30 | The size (resolution) of each image. 31 | qkv_bias (`bool`, *optional*, defaults to `False`): 32 | Whether to add a bias to the queries and values in the self-attention layers. 33 | hidden_size (`int`, *optional*, defaults to 3200): 34 | Dimensionality of the encoder layers and the pooler layer. 35 | num_attention_heads (`int`, *optional*, defaults to 25): 36 | Number of attention heads for each attention layer in the Transformer encoder. 37 | intermediate_size (`int`, *optional*, defaults to 12800): 38 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. 39 | qk_normalization (`bool`, *optional*, defaults to `True`): 40 | Whether to normalize the queries and keys in the self-attention layers. 
41 | num_hidden_layers (`int`, *optional*, defaults to 48): 42 | Number of hidden layers in the Transformer encoder. 43 | use_flash_attn (`bool`, *optional*, defaults to `True`): 44 | Whether to use flash attention mechanism. 45 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): 46 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, 47 | `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. 48 | layer_norm_eps (`float`, *optional*, defaults to 1e-6): 49 | The epsilon used by the layer normalization layers. 50 | dropout (`float`, *optional*, defaults to 0.0): 51 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 52 | drop_path_rate (`float`, *optional*, defaults to 0.0): 53 | Dropout rate for stochastic depth. 54 | attention_dropout (`float`, *optional*, defaults to 0.0): 55 | The dropout ratio for the attention probabilities. 56 | initializer_range (`float`, *optional*, defaults to 0.02): 57 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 58 | initializer_factor (`float`, *optional*, defaults to 0.1): 59 | A factor for layer scale. 60 | """ 61 | 62 | model_type = 'intern_vit_6b' 63 | 64 | def __init__( 65 | self, 66 | num_channels=3, 67 | patch_size=14, 68 | image_size=224, 69 | qkv_bias=False, 70 | hidden_size=3200, 71 | num_attention_heads=25, 72 | intermediate_size=12800, 73 | qk_normalization=True, 74 | num_hidden_layers=48, 75 | use_flash_attn=True, 76 | hidden_act='gelu', 77 | norm_type='rms_norm', 78 | layer_norm_eps=1e-6, 79 | dropout=0.0, 80 | drop_path_rate=0.0, 81 | attention_dropout=0.0, 82 | initializer_range=0.02, 83 | initializer_factor=0.1, 84 | **kwargs, 85 | ): 86 | super().__init__(**kwargs) 87 | 88 | self.hidden_size = hidden_size 89 | self.intermediate_size = intermediate_size 90 | self.dropout = dropout 91 | self.drop_path_rate = drop_path_rate 92 | self.num_hidden_layers = num_hidden_layers 93 | self.num_attention_heads = num_attention_heads 94 | self.num_channels = num_channels 95 | self.patch_size = patch_size 96 | self.image_size = image_size 97 | self.initializer_range = initializer_range 98 | self.initializer_factor = initializer_factor 99 | self.attention_dropout = attention_dropout 100 | self.layer_norm_eps = layer_norm_eps 101 | self.hidden_act = hidden_act 102 | self.norm_type = norm_type 103 | self.qkv_bias = qkv_bias 104 | self.qk_normalization = qk_normalization 105 | self.use_flash_attn = use_flash_attn 106 | 107 | @classmethod 108 | def from_pretrained( 109 | cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs 110 | ) -> 'PretrainedConfig': 111 | config_dict, kwargs = cls.get_config_dict( 112 | pretrained_model_name_or_path, **kwargs 113 | ) 114 | 115 | if 'vision_config' in config_dict: 116 | config_dict = config_dict['vision_config'] 117 | 118 | if ( 119 | 'model_type' in config_dict 120 | and hasattr(cls, 'model_type') 121 | and config_dict['model_type'] != cls.model_type 122 | ): 123 | logger.warning( 124 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 125 | f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.' 
126 | ) 127 | 128 | return cls.from_dict(config_dict, **kwargs) 129 | -------------------------------------------------------------------------------- /internvl/model/internvl_chat/configuration_internvl_chat.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import copy 8 | 9 | from transformers import AutoConfig, LlamaConfig, Qwen2Config 10 | from transformers.configuration_utils import PretrainedConfig 11 | from transformers.utils import logging 12 | 13 | from internvl.model.internlm2.configuration_internlm2 import InternLM2Config 14 | from internvl.model.phi3.configuration_phi3 import Phi3Config 15 | 16 | from .configuration_intern_vit import InternVisionConfig 17 | 18 | logger = logging.get_logger(__name__) 19 | 20 | 21 | class InternVLChatConfig(PretrainedConfig): 22 | model_type = 'internvl_chat' 23 | is_composition = True 24 | 25 | def __init__( 26 | self, 27 | vision_config=None, 28 | llm_config=None, 29 | use_backbone_lora=0, 30 | use_llm_lora=0, 31 | pad2square=False, 32 | select_layer=-1, 33 | force_image_size=None, 34 | downsample_ratio=0.5, 35 | template=None, 36 | dynamic_image_size=False, 37 | use_thumbnail=False, 38 | ps_version='v1', 39 | min_dynamic_patch=1, 40 | max_dynamic_patch=6, 41 | **kwargs, 42 | ): 43 | super().__init__(**kwargs) 44 | 45 | if vision_config is None: 46 | vision_config = {'architectures': ['InternVisionModel']} 47 | logger.info( 48 | 'vision_config is None. Initializing the InternVisionConfig with default values.' 49 | ) 50 | 51 | if llm_config is None: 52 | # TODO: There might still be a bug in transformers version 4.44 and above. 53 | llm_config = {'architectures': ['']} 54 | logger.info( 55 | 'llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).' 56 | ) 57 | 58 | self.vision_config = InternVisionConfig(**vision_config) 59 | if llm_config['architectures'][0] == 'LlamaForCausalLM': 60 | self.llm_config = LlamaConfig(**llm_config) 61 | elif llm_config['architectures'][0] == 'InternLM2ForCausalLM': 62 | self.llm_config = InternLM2Config(**llm_config) 63 | elif llm_config['architectures'][0] == 'Phi3ForCausalLM': 64 | self.llm_config = Phi3Config(**llm_config) 65 | elif llm_config['architectures'][0] == 'Qwen2ForCausalLM': 66 | self.llm_config = Qwen2Config(**llm_config) 67 | else: 68 | raise ValueError( 69 | 'Unsupported architecture: {}'.format(llm_config['architectures'][0]) 70 | ) 71 | self.use_backbone_lora = use_backbone_lora 72 | self.use_llm_lora = use_llm_lora 73 | self.pad2square = pad2square 74 | self.select_layer = select_layer 75 | self.force_image_size = force_image_size 76 | self.downsample_ratio = downsample_ratio 77 | self.template = template 78 | self.dynamic_image_size = dynamic_image_size 79 | self.use_thumbnail = use_thumbnail 80 | self.ps_version = ps_version # pixel shuffle version 81 | self.min_dynamic_patch = min_dynamic_patch 82 | self.max_dynamic_patch = max_dynamic_patch 83 | 84 | self.hidden_size = self.llm_config.hidden_size 85 | # By default, we use tie_word_embeddings=False for models of all sizes. 
86 | self.tie_word_embeddings = False 87 | self.llm_config.tie_word_embeddings = self.tie_word_embeddings 88 | 89 | logger.info(f'vision_select_layer: {self.select_layer}') 90 | logger.info(f'ps_version: {self.ps_version}') 91 | logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}') 92 | logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}') 93 | 94 | def to_dict(self): 95 | """ 96 | Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. 97 | 98 | Returns: 99 | `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, 100 | """ 101 | output = copy.deepcopy(self.__dict__) 102 | output['vision_config'] = self.vision_config.to_dict() 103 | output['llm_config'] = self.llm_config.to_dict() 104 | output['model_type'] = self.__class__.model_type 105 | output['use_backbone_lora'] = self.use_backbone_lora 106 | output['use_llm_lora'] = self.use_llm_lora 107 | output['select_layer'] = self.select_layer 108 | output['force_image_size'] = self.force_image_size 109 | output['downsample_ratio'] = self.downsample_ratio 110 | output['template'] = self.template 111 | output['dynamic_image_size'] = self.dynamic_image_size 112 | output['use_thumbnail'] = self.use_thumbnail 113 | output['ps_version'] = self.ps_version 114 | output['min_dynamic_patch'] = self.min_dynamic_patch 115 | output['max_dynamic_patch'] = self.max_dynamic_patch 116 | 117 | return output 118 | -------------------------------------------------------------------------------- /internvl/model/phi3/configuration_phi3.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License atd 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Phi-3 model configuration""" 16 | 17 | 18 | from transformers.configuration_utils import PretrainedConfig 19 | from transformers.utils import logging 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | 'microsoft/Phi-3-mini-4k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json', 25 | 'microsoft/Phi-3-mini-128k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json', 26 | } 27 | 28 | 29 | class Phi3Config(PretrainedConfig): 30 | r""" 31 | This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 32 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 33 | defaults will yield a similar configuration to that of the 34 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). 35 | 36 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 37 | documentation from [`PretrainedConfig`] for more information. 
38 | 39 | Args: 40 | vocab_size (`int`, *optional*, defaults to 32064): 41 | Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the 42 | `inputs_ids` passed when calling [`Phi3Model`]. 43 | hidden_size (`int`, *optional*, defaults to 3072): 44 | Dimension of the hidden representations. 45 | intermediate_size (`int`, *optional*, defaults to 8192): 46 | Dimension of the MLP representations. 47 | num_hidden_layers (`int`, *optional*, defaults to 32): 48 | Number of hidden layers in the Transformer decoder. 49 | num_attention_heads (`int`, *optional*, defaults to 32): 50 | Number of attention heads for each attention layer in the Transformer decoder. 51 | num_key_value_heads (`int`, *optional*): 52 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 53 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 54 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 55 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 56 | by meanpooling all the original heads within that group. For more details checkout [this 57 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 58 | `num_attention_heads`. 59 | resid_pdrop (`float`, *optional*, defaults to 0.0): 60 | Dropout probability for mlp outputs. 61 | embd_pdrop (`int`, *optional*, defaults to 0.0): 62 | The dropout ratio for the embeddings. 63 | attention_dropout (`float`, *optional*, defaults to 0.0): 64 | The dropout ratio after computing the attention scores. 65 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 66 | The non-linear activation function (function or string) in the decoder. 67 | max_position_embeddings (`int`, *optional*, defaults to 4096): 68 | The maximum sequence length that this model might ever be used with. 69 | original_max_position_embeddings (`int`, *optional*, defaults to 4096): 70 | The maximum sequence length that this model was trained with. This is used to determine the size of the 71 | original RoPE embeddings when using long scaling. 72 | initializer_range (`float`, *optional*, defaults to 0.02): 73 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 74 | rms_norm_eps (`float`, *optional*, defaults to 1e-05): 75 | The epsilon value used for the RMSNorm. 76 | use_cache (`bool`, *optional*, defaults to `True`): 77 | Whether or not the model should return the last key/values attentions (not used by all models). Only 78 | relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. 79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 80 | Whether to tie weight embeddings 81 | rope_theta (`float`, *optional*, defaults to 10000.0): 82 | The base period of the RoPE embeddings. 83 | rope_scaling (`dict`, *optional*): 84 | The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must 85 | contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and 86 | the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size 87 | divided by the number of attention heads divided by 2. 88 | bos_token_id (`int`, *optional*, defaults to 1): 89 | The id of the "beginning-of-sequence" token. 
90 | eos_token_id (`int`, *optional*, defaults to 32000): 91 | The id of the "end-of-sequence" token. 92 | pad_token_id (`int`, *optional*, defaults to 32000): 93 | The id of the padding token. 94 | sliding_window (`int`, *optional*): 95 | Sliding window attention window size. If `None`, no sliding window is applied. 96 | 97 | Example: 98 | 99 | ```python 100 | >>> from transformers import Phi3Model, Phi3Config 101 | 102 | >>> # Initializing a Phi-3 style configuration 103 | >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") 104 | 105 | >>> # Initializing a model from the configuration 106 | >>> model = Phi3Model(configuration) 107 | 108 | >>> # Accessing the model configuration 109 | >>> configuration = model.config 110 | ```""" 111 | 112 | model_type = 'phi3' 113 | keys_to_ignore_at_inference = ['past_key_values'] 114 | 115 | def __init__( 116 | self, 117 | vocab_size=32064, 118 | hidden_size=3072, 119 | intermediate_size=8192, 120 | num_hidden_layers=32, 121 | num_attention_heads=32, 122 | num_key_value_heads=None, 123 | resid_pdrop=0.0, 124 | embd_pdrop=0.0, 125 | attention_dropout=0.0, 126 | hidden_act='silu', 127 | max_position_embeddings=4096, 128 | original_max_position_embeddings=4096, 129 | initializer_range=0.02, 130 | rms_norm_eps=1e-5, 131 | use_cache=True, 132 | tie_word_embeddings=False, 133 | rope_theta=10000.0, 134 | rope_scaling=None, 135 | bos_token_id=1, 136 | eos_token_id=32000, 137 | pad_token_id=32000, 138 | sliding_window=None, 139 | **kwargs, 140 | ): 141 | self.vocab_size = vocab_size 142 | self.hidden_size = hidden_size 143 | self.intermediate_size = intermediate_size 144 | self.num_hidden_layers = num_hidden_layers 145 | self.num_attention_heads = num_attention_heads 146 | 147 | if num_key_value_heads is None: 148 | num_key_value_heads = num_attention_heads 149 | 150 | self.num_key_value_heads = num_key_value_heads 151 | self.resid_pdrop = resid_pdrop 152 | self.embd_pdrop = embd_pdrop 153 | self.attention_dropout = attention_dropout 154 | self.hidden_act = hidden_act 155 | self.max_position_embeddings = max_position_embeddings 156 | self.original_max_position_embeddings = original_max_position_embeddings 157 | self.initializer_range = initializer_range 158 | self.rms_norm_eps = rms_norm_eps 159 | self.use_cache = use_cache 160 | self.rope_theta = rope_theta 161 | self.rope_scaling = rope_scaling 162 | self._rope_scaling_validation() 163 | self.sliding_window = sliding_window 164 | 165 | super().__init__( 166 | bos_token_id=bos_token_id, 167 | eos_token_id=eos_token_id, 168 | pad_token_id=pad_token_id, 169 | tie_word_embeddings=tie_word_embeddings, 170 | **kwargs, 171 | ) 172 | 173 | def _rope_scaling_validation(self): 174 | """ 175 | Validate the `rope_scaling` configuration. 
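Illustrative example (added for clarity; the values are hypothetical): with the default `hidden_size=3072` and `num_attention_heads=32`, both factor lists must contain `3072 // 32 // 2 == 48` numbers, so a valid value looks like `rope_scaling = {'type': 'su', 'short_factor': [1.0] * 48, 'long_factor': [1.2] * 48}`.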
176 | """ 177 | if self.rope_scaling is None: 178 | return 179 | 180 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: 181 | raise ValueError( 182 | '`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, ' 183 | f'got {self.rope_scaling}' 184 | ) 185 | rope_scaling_type = self.rope_scaling.get('type', None) 186 | rope_scaling_short_factor = self.rope_scaling.get('short_factor', None) 187 | rope_scaling_long_factor = self.rope_scaling.get('long_factor', None) 188 | if rope_scaling_type is None or rope_scaling_type not in ['su', 'yarn']: 189 | raise ValueError( 190 | f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}" 191 | ) 192 | if not ( 193 | isinstance(rope_scaling_short_factor, list) 194 | and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) 195 | ): 196 | raise ValueError( 197 | f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" 198 | ) 199 | if ( 200 | not len(rope_scaling_short_factor) 201 | == self.hidden_size // self.num_attention_heads // 2 202 | ): 203 | raise ValueError( 204 | f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}" 205 | ) 206 | if not ( 207 | isinstance(rope_scaling_long_factor, list) 208 | and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) 209 | ): 210 | raise ValueError( 211 | f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" 212 | ) 213 | if ( 214 | not len(rope_scaling_long_factor) 215 | == self.hidden_size // self.num_attention_heads // 2 216 | ): 217 | raise ValueError( 218 | f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}" 219 | ) 220 | -------------------------------------------------------------------------------- /internvl/patch/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from .internlm2_packed_training_patch import replace_internlm2_attention_class 8 | from .internvit_liger_monkey_patch import apply_liger_kernel_to_internvit 9 | from .llama2_flash_attn_monkey_patch import replace_llama2_attn_with_flash_attn 10 | from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 11 | from .llama_packed_training_patch import replace_llama_attention_class 12 | from .llama_rmsnorm_monkey_patch import \ 13 | replace_llama_rmsnorm_with_fused_rmsnorm 14 | from .pad_data_collator import (concat_pad_data_collator, 15 | dpo_concat_pad_data_collator, 16 | pad_data_collator) 17 | from .phi3_packed_training_patch import replace_phi3_attention_class 18 | from .qwen2_packed_training_patch import replace_qwen2_attention_class 19 | from .train_dataloader_patch import replace_train_dataloader 20 | from .train_sampler_patch import replace_train_sampler 21 | 22 | __all__ = [ 23 | 'replace_llama_attn_with_flash_attn', 24 | 'replace_llama_rmsnorm_with_fused_rmsnorm', 25 | 'replace_llama2_attn_with_flash_attn', 26 | 'replace_train_sampler', 27 | 'replace_train_dataloader', 28 | 'replace_internlm2_attention_class', 29 | 'replace_qwen2_attention_class', 30 | 
'replace_phi3_attention_class', 31 | 'replace_llama_attention_class', 32 | 'pad_data_collator', 33 | 'dpo_concat_pad_data_collator', 34 | 'concat_pad_data_collator', 35 | 'apply_liger_kernel_to_internvit', 36 | ] 37 | -------------------------------------------------------------------------------- /internvl/patch/internlm2_packed_training_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 9 | 10 | from internvl.model.internlm2.modeling_internlm2 import ( 11 | INTERNLM2_ATTENTION_CLASSES, InternLM2FlashAttention2, 12 | apply_rotary_pos_emb) 13 | 14 | 15 | # Modified from internvl.model.internlm2.modeling_internlm2.InternLM2FlashAttention2 16 | class InternLM2FlashAttention2ForPackedTraining(InternLM2FlashAttention2): 17 | 18 | def _flash_attention_forward( 19 | self, 20 | query_states, 21 | key_states, 22 | value_states, 23 | attention_mask, 24 | query_length, 25 | dropout=0.0, 26 | softmax_scale=None, 27 | ): 28 | """ 29 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 30 | first unpad the input, then computes the attention scores and pad the final attention scores. 31 | 32 | Args: 33 | query_states (`torch.Tensor`): 34 | Input query states to be passed to Flash Attention API 35 | key_states (`torch.Tensor`): 36 | Input key states to be passed to Flash Attention API 37 | value_states (`torch.Tensor`): 38 | Input value states to be passed to Flash Attention API 39 | attention_mask (`torch.Tensor`): 40 | rename from cu_seqlens to keep compatability - (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths 41 | of the sequences in the batch. 42 | dropout (`int`, *optional*): 43 | Attention dropout 44 | softmax_scale (`float`, *optional*): 45 | The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) 46 | """ 47 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 48 | query_states = query_states.squeeze(0) 49 | key_states = key_states.squeeze(0) 50 | value_states = value_states.squeeze(0) 51 | cu_seqlens = attention_mask.squeeze(0) 52 | 53 | with torch.no_grad(): 54 | max_seqlen = max( 55 | [ 56 | cu_seqlens[idx + 1] - cu_seqlens[idx] 57 | for idx in range(cu_seqlens.size(0) - 1) 58 | ] 59 | ).item() 60 | 61 | # Contains at least one padding token in the sequence 62 | causal = self.is_causal and query_length != 1 63 | attn_output = flash_attn_varlen_func( 64 | q=query_states, 65 | k=key_states, 66 | v=value_states, 67 | cu_seqlens_q=cu_seqlens, 68 | cu_seqlens_k=cu_seqlens, 69 | max_seqlen_q=max_seqlen, 70 | max_seqlen_k=max_seqlen, 71 | dropout_p=dropout, 72 | softmax_scale=softmax_scale, 73 | causal=causal, 74 | ) 75 | 76 | query_states = query_states.unsqueeze(0) 77 | key_states = key_states.unsqueeze(0) 78 | value_states = value_states.unsqueeze(0) 79 | return attn_output 80 | 81 | 82 | def replace_internlm2_attention_class(): 83 | INTERNLM2_ATTENTION_CLASSES['flash_attention_2'] = ( 84 | InternLM2FlashAttention2ForPackedTraining 85 | ) 86 | print('Replace INTERNLM2_ATTENTION_CLASSES to support packed training!!') 87 | -------------------------------------------------------------------------------- /internvl/patch/internvit_liger_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | 8 | def apply_liger_kernel_to_internvit() -> None: 9 | from liger_kernel.transformers.layer_norm import LigerLayerNorm 10 | from liger_kernel.transformers.rms_norm import LigerRMSNorm 11 | 12 | from internvl.model.internvl_chat import modeling_intern_vit 13 | 14 | modeling_intern_vit.NORM2FN['rms_norm'] = LigerRMSNorm 15 | modeling_intern_vit.NORM2FN['layer_norm'] = LigerLayerNorm 16 | print('Liger kernel applied to InternViT') 17 | -------------------------------------------------------------------------------- /internvl/patch/llama2_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file is copied from: https://github.com/lm-sys/FastChat 3 | """ 4 | 5 | import warnings 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | from flash_attn import __version__ as flash_attn_version 10 | from flash_attn.bert_padding import pad_input, unpad_input 11 | from flash_attn.flash_attn_interface import (flash_attn_func, 12 | flash_attn_varlen_kvpacked_func) 13 | from transformers.models.llama.modeling_llama import (LlamaAttention, 14 | LlamaModel, rotate_half) 15 | 16 | 17 | def apply_rotary_pos_emb(q, k, cos_sin, position_ids): 18 | gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1] 19 | gather_indices = gather_indices.repeat( 20 | 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3] 21 | ) 22 | bsz = gather_indices.shape[0] 23 | cos, sin = ( 24 | torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices) 25 | for x in cos_sin 26 | ) 27 | q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k)) 28 | return q, k 29 | 30 | 31 | def forward( 32 | self, 33 | hidden_states: torch.Tensor, 34 | attention_mask: Optional[torch.Tensor] = None, 35 | position_ids: 
Optional[torch.Tensor] = None, 36 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 37 | output_attentions: bool = False, 38 | use_cache: bool = False, 39 | padding_mask: Optional[torch.Tensor] = None, 40 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 41 | if output_attentions: 42 | warnings.warn( 43 | 'Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.' 44 | ) 45 | 46 | bsz, q_len, _ = hidden_states.size() 47 | kv_heads = getattr(self, 'num_key_value_heads', self.num_heads) 48 | 49 | q, k, v = ( 50 | op(hidden_states).view(bsz, q_len, nh, self.head_dim) 51 | for op, nh in ( 52 | (self.q_proj, self.num_heads), 53 | (self.k_proj, kv_heads), 54 | (self.v_proj, kv_heads), 55 | ) 56 | ) 57 | # shape: (b, s, num_heads, head_dim) 58 | 59 | kv_seq_len = k.shape[1] 60 | past_kv_len = 0 61 | if past_key_value is not None: 62 | past_kv_len = past_key_value[0].shape[2] 63 | kv_seq_len += past_kv_len 64 | 65 | cos_sin = self.rotary_emb(v, seq_len=kv_seq_len) 66 | q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids) 67 | 68 | if past_key_value is not None: 69 | assert ( 70 | flash_attn_version >= '2.1.0' 71 | ), 'past_key_value support requires flash-attn >= 2.1.0' 72 | # reuse k, v 73 | k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1) 74 | v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1) 75 | 76 | past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None 77 | 78 | if attention_mask is None: 79 | output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view( 80 | bsz, q_len, -1 81 | ) 82 | else: 83 | q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:]) 84 | # We can skip concat and call unpad twice but seems better to call unpad only once. 85 | kv, _, cu_k_lens, max_k = unpad_input( 86 | torch.stack((k, v), dim=2), attention_mask 87 | ) 88 | output_unpad = flash_attn_varlen_kvpacked_func( 89 | q, 90 | kv, 91 | cu_q_lens, 92 | cu_k_lens, 93 | max_s, 94 | max_k, 95 | 0.0, 96 | softmax_scale=None, 97 | causal=True, 98 | ) 99 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 100 | output = pad_input(output_unpad, indices, bsz, q_len) 101 | 102 | return self.o_proj(output), None, past_key_value 103 | 104 | 105 | # Disable the transformation of the attention mask in LlamaModel as flash attention 106 | # takes a boolean key_padding_mask. Fills in the past kv length for use in forward. 107 | def _prepare_decoder_attention_mask( 108 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 109 | ): 110 | # [bsz, seq_len] 111 | if past_key_values_length > 0 and attention_mask is not None: 112 | attention_mask = torch.cat( 113 | ( 114 | torch.full( 115 | (input_shape[0], past_key_values_length), 116 | True, 117 | dtype=attention_mask.dtype, 118 | device=attention_mask.device, 119 | ), 120 | attention_mask, 121 | ), 122 | dim=-1, 123 | ) 124 | 125 | if attention_mask is not None and torch.all(attention_mask): 126 | return None # This uses the faster call when training with full samples 127 | 128 | return attention_mask 129 | 130 | 131 | def replace_llama2_attn_with_flash_attn(): 132 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 133 | if cuda_major < 8: 134 | warnings.warn( 135 | 'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward.' 
136 | 'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593' 137 | ) 138 | 139 | LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 140 | LlamaAttention.forward = forward 141 | 142 | 143 | def test(): 144 | from fastchat.train.llama_flash_attn_monkey_patch import \ 145 | forward as fastchat_forward 146 | from transformers.models.llama.configuration_llama import LlamaConfig 147 | 148 | config = LlamaConfig( 149 | hidden_size=1024, 150 | intermediate_size=128, 151 | num_hidden_layers=1, 152 | num_attention_heads=8, 153 | max_position_embeddings=16, 154 | ) 155 | device = torch.device('cuda') 156 | model = LlamaModel(config) 157 | attn = LlamaAttention(config).to(device).half() 158 | bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings 159 | position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view( 160 | -1, seqlen 161 | ) 162 | 163 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) 164 | for i in range(4): 165 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) 166 | if i: 167 | mask[0, -i:] = False 168 | mask[1, :i] = False 169 | 170 | lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0) 171 | ref, _, _ = attn.forward( 172 | hidden, attention_mask=lmask, position_ids=position_ids 173 | ) 174 | 175 | fast, _, _ = fastchat_forward( 176 | attn, hidden, attention_mask=mask, position_ids=position_ids 177 | ) 178 | 179 | lmask = _prepare_decoder_attention_mask( 180 | model, mask, hidden.shape[:2], hidden, 0 181 | ) 182 | test, _, _ = forward( 183 | attn, hidden, attention_mask=lmask, position_ids=position_ids 184 | ) 185 | 186 | print(f'Mean(abs(ref)) = {torch.mean(torch.abs(ref))}') 187 | print(f'Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}') 188 | print(f'Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}') 189 | print(f'Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}') 190 | print(f'allclose(fast, test) = {torch.allclose(fast, test)}') 191 | 192 | with torch.no_grad(): 193 | # Also check that past_kv is handled properly 194 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device) 195 | part_len = seqlen // 4 196 | assert part_len * 4 == seqlen 197 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device) 198 | mask[0, -2:] = False 199 | lmask = _prepare_decoder_attention_mask( 200 | model, mask, hidden.shape[:2], hidden, 0 201 | ) 202 | oneshot, _, _ = forward( 203 | attn, hidden, attention_mask=lmask, position_ids=position_ids 204 | ) 205 | parts = [] 206 | past_kv, past_kv_len = None, 0 207 | for i in range(4): 208 | start = part_len * i 209 | end = start + part_len 210 | hidden_part = hidden[:, start:end, ...] 
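# Rebuild the mask for this chunk (prepending past_kv_len cached positions), then attend with the cached key/values; the concatenated chunk outputs are compared against the one-shot result below.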
211 | lmask = _prepare_decoder_attention_mask( 212 | model, 213 | mask[:, start:end], 214 | hidden_part.shape[:2], 215 | hidden_part, 216 | past_kv_len, 217 | ) 218 | part, _, past_kv = forward( 219 | attn, 220 | hidden_part.clone(), 221 | attention_mask=lmask, 222 | position_ids=position_ids[:, start:end], 223 | past_key_value=past_kv, 224 | use_cache=True, 225 | ) 226 | parts.append(part) 227 | past_kv_len = past_kv[0].shape[2] 228 | 229 | print( 230 | f'allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}' 231 | ) 232 | print( 233 | f'allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}' 234 | ) 235 | 236 | 237 | if __name__ == '__main__': 238 | test() 239 | -------------------------------------------------------------------------------- /internvl/patch/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import math 8 | from typing import Optional, Tuple 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import transformers 13 | from torch import nn 14 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 15 | 16 | 17 | def forward( 18 | self, 19 | hidden_states: torch.Tensor, 20 | attention_mask: Optional[torch.Tensor] = None, 21 | position_ids: Optional[torch.Tensor] = None, 22 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 23 | output_attentions: bool = False, 24 | use_cache: bool = False, 25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 26 | """Input shape: Batch x Time x Channel 27 | 28 | attention_mask: [bsz, q_len] 29 | """ 30 | from einops import rearrange 31 | 32 | try: # v1 33 | from flash_attn.flash_attn_interface import \ 34 | flash_attn_unpadded_qkvpacked_func 35 | except: # v2 36 | from flash_attn.flash_attn_interface import ( 37 | flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func, 38 | ) 39 | from flash_attn.bert_padding import pad_input, unpad_input 40 | 41 | bsz, q_len, _ = hidden_states.size() 42 | 43 | query_states = ( 44 | self.q_proj(hidden_states) 45 | .view(bsz, q_len, self.num_heads, self.head_dim) 46 | .transpose(1, 2) 47 | ) 48 | key_states = ( 49 | self.k_proj(hidden_states) 50 | .view(bsz, q_len, self.num_heads, self.head_dim) 51 | .transpose(1, 2) 52 | ) 53 | value_states = ( 54 | self.v_proj(hidden_states) 55 | .view(bsz, q_len, self.num_heads, self.head_dim) 56 | .transpose(1, 2) 57 | ) 58 | # [bsz, q_len, nh, hd] 59 | # [bsz, nh, q_len, hd] 60 | 61 | kv_seq_len = key_states.shape[-2] 62 | assert past_key_value is None, 'past_key_value is not supported' 63 | 64 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 65 | query_states, key_states = apply_rotary_pos_emb( 66 | query_states, key_states, cos, sin, position_ids 67 | ) 68 | # [bsz, nh, t, hd] 69 | assert not output_attentions, 'output_attentions is not supported' 70 | assert not use_cache, 'use_cache is not supported' 71 | 72 | # Flash attention codes from 73 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 74 | 75 | # transform the data into the format required by flash attention 76 | qkv = torch.stack( 77 | [query_states, key_states, value_states], dim=2 78 | ) # [bsz, nh, 3, q_len, 
hd] 79 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 80 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 81 | # the attention_mask should be the same as the key_padding_mask 82 | key_padding_mask = attention_mask 83 | 84 | if key_padding_mask is None: 85 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 86 | max_s = q_len 87 | cu_q_lens = torch.arange( 88 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 89 | ) 90 | output = flash_attn_unpadded_qkvpacked_func( 91 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 92 | ) 93 | output = rearrange(output, '(b s) ... -> b s ...', b=bsz) 94 | else: 95 | nheads = qkv.shape[-2] 96 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 97 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 98 | x_unpad = rearrange( 99 | x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads 100 | ) 101 | output_unpad = flash_attn_unpadded_qkvpacked_func( 102 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 103 | ) 104 | output = rearrange( 105 | pad_input( 106 | rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices, bsz, q_len 107 | ), 108 | 'b s (h d) -> b s h d', 109 | h=nheads, 110 | ) 111 | return self.o_proj(rearrange(output, 'b s h d -> b s (h d)')), None, None 112 | 113 | 114 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 115 | # requires the attention mask to be the same as the key_padding_mask 116 | def _prepare_decoder_attention_mask( 117 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 118 | ): 119 | # [bsz, seq_len] 120 | return attention_mask 121 | 122 | 123 | def forward_2( 124 | self, 125 | hidden_states: torch.Tensor, 126 | attention_mask: Optional[torch.Tensor] = None, 127 | position_ids: Optional[torch.LongTensor] = None, 128 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 129 | output_attentions: bool = False, 130 | use_cache: bool = False, 131 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 132 | bsz, q_len, _ = hidden_states.size() 133 | 134 | query_states = ( 135 | self.q_proj(hidden_states) 136 | .view(bsz, q_len, self.num_heads, self.head_dim) 137 | .transpose(1, 2) 138 | ) 139 | key_states = ( 140 | self.k_proj(hidden_states) 141 | .view(bsz, q_len, self.num_heads, self.head_dim) 142 | .transpose(1, 2) 143 | ) 144 | value_states = ( 145 | self.v_proj(hidden_states) 146 | .view(bsz, q_len, self.num_heads, self.head_dim) 147 | .transpose(1, 2) 148 | ) 149 | 150 | kv_seq_len = key_states.shape[-2] 151 | if past_key_value is not None: 152 | kv_seq_len += past_key_value[0].shape[-2] 153 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 154 | query_states, key_states = apply_rotary_pos_emb( 155 | query_states, key_states, cos, sin, position_ids 156 | ) 157 | 158 | assert not output_attentions, 'output_attentions is not supported' 159 | assert not use_cache, 'use_cache is not supported' 160 | assert past_key_value is None, 'past_key_value is not supported' 161 | 162 | if past_key_value is not None: 163 | # reuse k, v, self_attention 164 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 165 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 166 | 167 | past_key_value = (key_states, value_states) if use_cache else None 168 | if self.training: 169 | attn_output = F.scaled_dot_product_attention( 170 | query_states, key_states, value_states, dropout_p=0.0, is_causal=True 171 | ) 172 | 
attn_weights = None 173 | else: 174 | attn_weights = torch.matmul( 175 | query_states, key_states.transpose(2, 3) 176 | ) / math.sqrt(self.head_dim) 177 | 178 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 179 | raise ValueError( 180 | f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is' 181 | f' {attn_weights.size()}' 182 | ) 183 | 184 | if attention_mask is not None: 185 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 186 | raise ValueError( 187 | f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}' 188 | ) 189 | attn_weights = attn_weights + attention_mask 190 | attn_weights = torch.max( 191 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 192 | ) 193 | 194 | # upcast attention to fp32 195 | attn_weights = nn.functional.softmax( 196 | attn_weights, dim=-1, dtype=torch.float32 197 | ).to(query_states.dtype) 198 | attn_output = torch.matmul(attn_weights, value_states) 199 | 200 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 201 | raise ValueError( 202 | f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is' 203 | f' {attn_output.size()}' 204 | ) 205 | 206 | attn_output = attn_output.transpose(1, 2) 207 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 208 | 209 | attn_output = self.o_proj(attn_output) 210 | 211 | if not output_attentions: 212 | attn_weights = None 213 | 214 | return attn_output, attn_weights, past_key_value 215 | 216 | 217 | def replace_llama_attn_with_flash_attn(): 218 | if hasattr(F, 'scaled_dot_product_attention'): 219 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_2 220 | else: 221 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 222 | _prepare_decoder_attention_mask 223 | ) 224 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 225 | -------------------------------------------------------------------------------- /internvl/patch/llama_packed_training_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 9 | from transformers.models.llama.modeling_llama import (LLAMA_ATTENTION_CLASSES, 10 | LlamaFlashAttention2) 11 | 12 | 13 | # Modified from transformers.models.llama.modeling_llama.LlamaFlashAttention2 14 | class LlamaFlashAttention2ForPackedTraining(LlamaFlashAttention2): 15 | 16 | def _flash_attention_forward( 17 | self, 18 | query_states, 19 | key_states, 20 | value_states, 21 | attention_mask, 22 | query_length, 23 | dropout=0.0, 24 | softmax_scale=None, 25 | use_sliding_windows=False, 26 | ): 27 | """ 28 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 29 | first unpad the input, then computes the attention scores and pad the final attention scores. 
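Note: although `attention_mask` is described below as a padding mask (text inherited from the upstream implementation), this packed-training override repurposes it to carry the cumulative sequence lengths (`cu_seqlens`) of the packed samples; see `cu_seqlens = attention_mask.squeeze(0)` in the body.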
30 | 31 | Args: 32 | query_states (`torch.Tensor`): 33 | Input query states to be passed to Flash Attention API 34 | key_states (`torch.Tensor`): 35 | Input key states to be passed to Flash Attention API 36 | value_states (`torch.Tensor`): 37 | Input value states to be passed to Flash Attention API 38 | attention_mask (`torch.Tensor`): 39 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 40 | position of padding tokens and 1 for the position of non-padding tokens. 41 | dropout (`int`, *optional*): 42 | Attention dropout 43 | softmax_scale (`float`, *optional*): 44 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 45 | use_sliding_windows (`bool`, *optional*): 46 | Whether to activate sliding window attention. 47 | """ 48 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 49 | query_states = query_states.squeeze(0) 50 | key_states = key_states.squeeze(0) 51 | value_states = value_states.squeeze(0) 52 | cu_seqlens = attention_mask.squeeze(0) 53 | 54 | with torch.no_grad(): 55 | max_seqlen = max( 56 | [ 57 | cu_seqlens[idx + 1] - cu_seqlens[idx] 58 | for idx in range(cu_seqlens.size(0) - 1) 59 | ] 60 | ).item() 61 | 62 | if not self._flash_attn_uses_top_left_mask: 63 | causal = self.is_causal 64 | else: 65 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 66 | causal = self.is_causal and query_length != 1 67 | 68 | # Decide whether to use SWA or not by layer index. 69 | if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: 70 | use_sliding_windows = False 71 | 72 | if not use_sliding_windows: 73 | attn_output = flash_attn_varlen_func( 74 | q=query_states, 75 | k=key_states, 76 | v=value_states, 77 | cu_seqlens_q=cu_seqlens, 78 | cu_seqlens_k=cu_seqlens, 79 | max_seqlen_q=max_seqlen, 80 | max_seqlen_k=max_seqlen, 81 | dropout_p=dropout, 82 | softmax_scale=softmax_scale, 83 | causal=causal, 84 | ) 85 | else: 86 | attn_output = flash_attn_varlen_func( 87 | q=query_states, 88 | k=key_states, 89 | v=value_states, 90 | cu_seqlens_q=cu_seqlens, 91 | cu_seqlens_k=cu_seqlens, 92 | max_seqlen_q=max_seqlen, 93 | max_seqlen_k=max_seqlen, 94 | dropout_p=dropout, 95 | softmax_scale=softmax_scale, 96 | causal=causal, 97 | window_size=(self.config.sliding_window, self.config.sliding_window), 98 | ) 99 | 100 | query_states = query_states.unsqueeze(0) 101 | key_states = key_states.unsqueeze(0) 102 | value_states = value_states.unsqueeze(0) 103 | return attn_output 104 | 105 | 106 | def replace_llama_attention_class(): 107 | LLAMA_ATTENTION_CLASSES['flash_attention_2'] = LlamaFlashAttention2ForPackedTraining 108 | print('Replace LLAMA_ATTENTION_CLASSES to support packed training!!') 109 | -------------------------------------------------------------------------------- /internvl/patch/llama_rmsnorm_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import transformers 8 | 9 | 10 | def replace_llama_rmsnorm_with_fused_rmsnorm(): 11 | try: 12 | from functools import partial 13 | 14 | from apex.normalization import FusedRMSNorm 15 | 16 | LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # 
noqa 17 | transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm 18 | print( 19 | 'Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm' 20 | ) 21 | except ImportError: 22 | # using the normal LlamaRMSNorm 23 | pass 24 | except Exception: 25 | print('discovered apex but it failed to load, falling back to LlamaRMSNorm') 26 | pass 27 | -------------------------------------------------------------------------------- /internvl/patch/pad_data_collator.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import numpy as np 8 | import torch 9 | 10 | IGNORE_INDEX = -100 11 | 12 | 13 | def pad_data_collator(features, pad_id=0): 14 | 15 | first = features[0] 16 | batch = {} 17 | 18 | batch_lens = [feat['input_ids'].shape for feat in features] 19 | max_item_length = max(batch_lens)[0] 20 | for idx in range(len(features)): 21 | feat = features[idx] 22 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length) 23 | temp_input_ids[: feat['input_ids'].shape[0]] = feat['input_ids'] 24 | feat['input_ids'] = temp_input_ids 25 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) 26 | temp_labels[: feat['labels'].shape[0]] = feat['labels'] 27 | feat['labels'] = temp_labels 28 | feat['attention_mask'] = feat['input_ids'].ne(pad_id) 29 | 30 | # Special handling for labels. 31 | # Ensure that tensor is created with the correct type 32 | # (it should be automatically the case, but let's make sure of it.) 33 | if 'label' in first and first['label'] is not None: 34 | label = ( 35 | first['label'].item() 36 | if isinstance(first['label'], torch.Tensor) 37 | else first['label'] 38 | ) 39 | dtype = torch.long if isinstance(label, int) else torch.float 40 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) 41 | elif 'label_ids' in first and first['label_ids'] is not None: 42 | if isinstance(first['label_ids'], torch.Tensor): 43 | batch['labels'] = torch.stack([f['label_ids'] for f in features]) 44 | else: 45 | dtype = ( 46 | torch.long if isinstance(first['label_ids'][0], int) else torch.float 47 | ) 48 | batch['labels'] = torch.tensor( 49 | [f['label_ids'] for f in features], dtype=dtype 50 | ) 51 | 52 | # Handling of all other possible keys. 53 | # Again, we will use the first element to figure out which key/values are not None for this model. 
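# Every remaining tensor field here is combined with torch.stack, which requires identical shapes across the batch (e.g. a fixed number of image tiles per sample); concat_pad_data_collator below handles the variable-tile case.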
54 | for k, v in first.items(): 55 | if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str): 56 | if isinstance(v, torch.Tensor): 57 | batch[k] = torch.stack([f[k] for f in features]) 58 | elif isinstance(v, np.ndarray): 59 | batch[k] = torch.tensor(np.stack([f[k] for f in features])) 60 | else: 61 | batch[k] = torch.tensor([f[k] for f in features]) 62 | return batch 63 | 64 | 65 | def concat_pad_data_collator(features, max_item_length=None, pad_id=0): 66 | 67 | first = features[0] 68 | batch = {} 69 | 70 | batch_lens = [feat['input_ids'].shape for feat in features] 71 | max_item_length = max_item_length or max(batch_lens)[0] 72 | for idx in range(len(features)): 73 | feat = features[idx] 74 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length) 75 | temp_input_ids[: feat['input_ids'].shape[0]] = feat['input_ids'] 76 | feat['input_ids'] = temp_input_ids 77 | temp_labels = torch.FloatTensor([IGNORE_INDEX] * max_item_length) 78 | temp_labels[: feat['labels'].shape[0]] = feat['labels'] 79 | feat['labels'] = temp_labels 80 | feat['attention_mask'] = feat['input_ids'].ne(pad_id) 81 | 82 | if 'position_ids' in feat: 83 | temp_position_ids = torch.LongTensor([pad_id] * max_item_length) 84 | temp_position_ids[: feat['position_ids'].shape[0]] = feat['position_ids'] 85 | feat['position_ids'] = temp_position_ids 86 | 87 | if 'loss_weight' in feat: 88 | temp_loss_weight = torch.FloatTensor([pad_id] * max_item_length) 89 | temp_loss_weight[: feat['loss_weight'].shape[0]] = feat['loss_weight'] 90 | feat['loss_weight'] = temp_loss_weight 91 | 92 | # Special handling for labels. 93 | # Ensure that tensor is created with the correct type 94 | # (it should be automatically the case, but let's make sure of it.) 95 | if 'label' in first and first['label'] is not None: 96 | label = ( 97 | first['label'].item() 98 | if isinstance(first['label'], torch.Tensor) 99 | else first['label'] 100 | ) 101 | dtype = torch.float 102 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype) 103 | elif 'label_ids' in first and first['label_ids'] is not None: 104 | if isinstance(first['label_ids'], torch.Tensor): 105 | batch['labels'] = torch.stack([f['label_ids'] for f in features]) 106 | else: 107 | dtype = torch.float 108 | batch['labels'] = torch.tensor( 109 | [f['label_ids'] for f in features], dtype=dtype 110 | ) 111 | 112 | # Handling of all other possible keys. 113 | # Again, we will use the first element to figure out which key/values are not None for this model. 
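# In the loop below, most tensor fields are stacked, while 'pixel_values' and 'image_flags' are concatenated along dim 0, since each sample may contribute a different number of image tiles.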
114 | for k, v in first.items(): 115 | if ( 116 | k not in ('label', 'label_ids', 'pixel_values', 'image_flags') 117 | and v is not None 118 | and not isinstance(v, str) 119 | ): 120 | if isinstance(v, torch.Tensor): 121 | batch[k] = torch.stack([f[k] for f in features]) 122 | elif isinstance(v, np.ndarray): 123 | batch[k] = torch.tensor(np.stack([f[k] for f in features])) 124 | else: 125 | batch[k] = torch.tensor([f[k] for f in features]) 126 | if k in ('pixel_values', 'image_flags'): 127 | if isinstance(v, torch.Tensor): 128 | batch[k] = torch.concat([f[k] for f in features]) 129 | elif isinstance(v, np.ndarray): 130 | batch[k] = torch.concat(np.stack([f[k] for f in features])) 131 | else: 132 | batch[k] = torch.concat([f[k] for f in features]) 133 | return batch 134 | 135 | 136 | def dpo_concat_pad_data_collator(features, pad_id=0): 137 | 138 | first = features[0] 139 | batch = {} 140 | 141 | for prefix in ['chosen_', 'rejected_']: 142 | batch_lens = [feat[f'{prefix}input_ids'].shape[0] for feat in features] 143 | max_item_length = max(batch_lens) 144 | for idx in range(len(features)): 145 | feat = features[idx] 146 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length) 147 | temp_input_ids[: feat[f'{prefix}input_ids'].shape[0]] = feat[ 148 | f'{prefix}input_ids' 149 | ] 150 | feat[f'{prefix}input_ids'] = temp_input_ids 151 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length) 152 | temp_labels[: feat[f'{prefix}labels'].shape[0]] = feat[f'{prefix}labels'] 153 | feat[f'{prefix}labels'] = temp_labels 154 | feat[f'{prefix}attention_mask'] = feat[f'{prefix}input_ids'].ne(pad_id) 155 | 156 | # Handling of all other possible keys. 157 | # Again, we will use the first element to figure out which key/values are not None for this model. 
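# The chosen_*/rejected_* fields padded above are now equal-length tensors, so the generic loop below can stack them; 'pixel_values' and 'image_flags' are again concatenated along dim 0, as in concat_pad_data_collator.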
158 | for k, v in first.items(): 159 | if ( 160 | k not in ('pixel_values', 'image_flags') 161 | and v is not None 162 | and not isinstance(v, str) 163 | ): 164 | if isinstance(v, torch.Tensor): 165 | batch[k] = torch.stack([f[k] for f in features]) 166 | elif isinstance(v, np.ndarray): 167 | batch[k] = torch.tensor(np.stack([f[k] for f in features])) 168 | else: 169 | batch[k] = torch.tensor([f[k] for f in features]) 170 | if k in ('pixel_values', 'image_flags'): 171 | if isinstance(v, torch.Tensor): 172 | batch[k] = torch.concat([f[k] for f in features]) 173 | elif isinstance(v, np.ndarray): 174 | batch[k] = torch.concat(np.stack([f[k] for f in features])) 175 | else: 176 | batch[k] = torch.concat([f[k] for f in features]) 177 | return batch 178 | -------------------------------------------------------------------------------- /internvl/patch/phi3_packed_training_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 9 | 10 | from internvl.model.phi3.modeling_phi3 import (PHI3_ATTENTION_CLASSES, 11 | Phi3FlashAttention2) 12 | 13 | 14 | class Phi3FlashAttention2ForPackedTraining(Phi3FlashAttention2): 15 | 16 | def _flash_attention_forward( 17 | self, 18 | query_states, 19 | key_states, 20 | value_states, 21 | attention_mask, 22 | query_length, 23 | dropout=0.0, 24 | softmax_scale=None, 25 | use_sliding_windows=False, 26 | ): 27 | """ 28 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 29 | first unpad the input, then computes the attention scores and pad the final attention scores. 30 | 31 | Args: 32 | query_states (`torch.Tensor`): 33 | Input query states to be passed to Flash Attention API 34 | key_states (`torch.Tensor`): 35 | Input key states to be passed to Flash Attention API 36 | value_states (`torch.Tensor`): 37 | Input value states to be passed to Flash Attention API 38 | attention_mask (`torch.Tensor`): 39 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 40 | position of padding tokens and 1 for the position of non-padding tokens. 41 | dropout (`float`): 42 | Attention dropout 43 | softmax_scale (`float`, *optional*): 44 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 45 | use_sliding_windows (`bool`, *optional*): 46 | Whether to activate sliding window attention. 47 | """ 48 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 49 | query_states = query_states.squeeze(0) 50 | key_states = key_states.squeeze(0) 51 | value_states = value_states.squeeze(0) 52 | cu_seqlens = attention_mask.squeeze(0) 53 | 54 | with torch.no_grad(): 55 | max_seqlen = max( 56 | [ 57 | cu_seqlens[idx + 1] - cu_seqlens[idx] 58 | for idx in range(cu_seqlens.size(0) - 1) 59 | ] 60 | ).item() 61 | 62 | if not self._flash_attn_uses_top_left_mask: 63 | causal = self.is_causal 64 | else: 65 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 66 | causal = self.is_causal and query_length != 1 67 | 68 | # Decide whether to use SWA or not by layer index. 
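# Note: this check assumes the config exposes `max_window_layers` (a Qwen2-style field that the Phi3Config defined above does not declare); the short-circuit means the attribute is only read when use_sliding_windows is True.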
69 | if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: 70 | use_sliding_windows = False 71 | 72 | if not use_sliding_windows: 73 | attn_output = flash_attn_varlen_func( 74 | q=query_states, 75 | k=key_states, 76 | v=value_states, 77 | cu_seqlens_q=cu_seqlens, 78 | cu_seqlens_k=cu_seqlens, 79 | max_seqlen_q=max_seqlen, 80 | max_seqlen_k=max_seqlen, 81 | dropout_p=dropout, 82 | softmax_scale=softmax_scale, 83 | causal=causal, 84 | ) 85 | else: 86 | attn_output = flash_attn_varlen_func( 87 | q=query_states, 88 | k=key_states, 89 | v=value_states, 90 | cu_seqlens_q=cu_seqlens, 91 | cu_seqlens_k=cu_seqlens, 92 | max_seqlen_q=max_seqlen, 93 | max_seqlen_k=max_seqlen, 94 | dropout_p=dropout, 95 | softmax_scale=softmax_scale, 96 | causal=causal, 97 | window_size=(self.config.sliding_window, self.config.sliding_window), 98 | ) 99 | 100 | query_states = query_states.unsqueeze(0) 101 | key_states = key_states.unsqueeze(0) 102 | value_states = value_states.unsqueeze(0) 103 | return attn_output 104 | 105 | 106 | def replace_phi3_attention_class(): 107 | PHI3_ATTENTION_CLASSES['flash_attention_2'] = Phi3FlashAttention2ForPackedTraining 108 | print('Replace PHI3_ATTENTION_CLASSES to support packed training!!') 109 | -------------------------------------------------------------------------------- /internvl/patch/qwen2_packed_training_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | from flash_attn.flash_attn_interface import flash_attn_varlen_func 9 | from transformers.models.qwen2.modeling_qwen2 import (QWEN2_ATTENTION_CLASSES, 10 | Qwen2FlashAttention2) 11 | 12 | 13 | # Modified from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 14 | class Qwen2FlashAttention2ForPackedTraining(Qwen2FlashAttention2): 15 | 16 | def _flash_attention_forward( 17 | self, 18 | query_states, 19 | key_states, 20 | value_states, 21 | attention_mask, 22 | query_length, 23 | dropout=0.0, 24 | softmax_scale=None, 25 | use_sliding_windows=False, 26 | ): 27 | """ 28 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 29 | first unpad the input, then computes the attention scores and pad the final attention scores. 30 | 31 | Args: 32 | query_states (`torch.Tensor`): 33 | Input query states to be passed to Flash Attention API 34 | key_states (`torch.Tensor`): 35 | Input key states to be passed to Flash Attention API 36 | value_states (`torch.Tensor`): 37 | Input value states to be passed to Flash Attention API 38 | attention_mask (`torch.Tensor`): 39 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 40 | position of padding tokens and 1 for the position of non-padding tokens. 41 | dropout (`int`, *optional*): 42 | Attention dropout 43 | softmax_scale (`float`, *optional*): 44 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 45 | use_sliding_windows (`bool`, *optional*): 46 | Whether to activate sliding window attention. 
47 | """ 48 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1 49 | query_states = query_states.squeeze(0) 50 | key_states = key_states.squeeze(0) 51 | value_states = value_states.squeeze(0) 52 | cu_seqlens = attention_mask.squeeze(0) 53 | 54 | with torch.no_grad(): 55 | max_seqlen = max( 56 | [ 57 | cu_seqlens[idx + 1] - cu_seqlens[idx] 58 | for idx in range(cu_seqlens.size(0) - 1) 59 | ] 60 | ).item() 61 | 62 | if not self._flash_attn_uses_top_left_mask: 63 | causal = self.is_causal 64 | else: 65 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 66 | causal = self.is_causal and query_length != 1 67 | 68 | # Decide whether to use SWA or not by layer index. 69 | if use_sliding_windows and self.layer_idx >= self.config.max_window_layers: 70 | use_sliding_windows = False 71 | 72 | if not use_sliding_windows: 73 | attn_output = flash_attn_varlen_func( 74 | q=query_states, 75 | k=key_states, 76 | v=value_states, 77 | cu_seqlens_q=cu_seqlens, 78 | cu_seqlens_k=cu_seqlens, 79 | max_seqlen_q=max_seqlen, 80 | max_seqlen_k=max_seqlen, 81 | dropout_p=dropout, 82 | softmax_scale=softmax_scale, 83 | causal=causal, 84 | ) 85 | else: 86 | attn_output = flash_attn_varlen_func( 87 | q=query_states, 88 | k=key_states, 89 | v=value_states, 90 | cu_seqlens_q=cu_seqlens, 91 | cu_seqlens_k=cu_seqlens, 92 | max_seqlen_q=max_seqlen, 93 | max_seqlen_k=max_seqlen, 94 | dropout_p=dropout, 95 | softmax_scale=softmax_scale, 96 | causal=causal, 97 | window_size=(self.config.sliding_window, self.config.sliding_window), 98 | ) 99 | 100 | query_states = query_states.unsqueeze(0) 101 | key_states = key_states.unsqueeze(0) 102 | value_states = value_states.unsqueeze(0) 103 | return attn_output 104 | 105 | 106 | def replace_qwen2_attention_class(): 107 | QWEN2_ATTENTION_CLASSES['flash_attention_2'] = Qwen2FlashAttention2ForPackedTraining 108 | print('Replace QWEN2_ATTENTION_CLASSES to support packed training!!') 109 | -------------------------------------------------------------------------------- /internvl/patch/train_dataloader_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import datasets 8 | import torch 9 | import transformers 10 | from torch.utils.data import DataLoader 11 | from transformers.trainer import is_datasets_available, seed_worker 12 | 13 | 14 | def get_train_dataloader(self) -> DataLoader: 15 | """ 16 | Returns the training [`~torch.utils.data.DataLoader`]. 17 | 18 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 19 | training if necessary) otherwise. 20 | 21 | Subclass and override this method if you want to inject some custom behavior. 
22 | """ 23 | if self.train_dataset is None: 24 | raise ValueError('Trainer: training requires a train_dataset.') 25 | 26 | train_dataset = self.train_dataset 27 | data_collator = self.data_collator 28 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): 29 | train_dataset = self._remove_unused_columns( 30 | train_dataset, description='training' 31 | ) 32 | else: 33 | data_collator = self._get_collator_with_removed_columns( 34 | data_collator, description='training' 35 | ) 36 | 37 | dataloader_params = { 38 | 'batch_size': self._train_batch_size, 39 | 'collate_fn': data_collator, 40 | 'num_workers': self.args.dataloader_num_workers, 41 | 'pin_memory': self.args.dataloader_pin_memory, 42 | 'persistent_workers': self.args.dataloader_persistent_workers, 43 | } 44 | 45 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 46 | dataloader_params['sampler'] = self._get_train_sampler() 47 | dataloader_params['drop_last'] = self.args.dataloader_drop_last 48 | dataloader_params['worker_init_fn'] = seed_worker 49 | 50 | if self.args.use_packed_ds: 51 | return DataLoader(train_dataset, **dataloader_params) 52 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) 53 | 54 | 55 | def replace_train_dataloader(): 56 | transformers.Trainer.get_train_dataloader = get_train_dataloader 57 | # print('Replace train dataloader!!') 58 | -------------------------------------------------------------------------------- /internvl/patch/train_sampler_patch.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import transformers 11 | from torch.utils.data import Dataset, Sampler 12 | from transformers.tokenization_utils_base import BatchEncoding 13 | from transformers.trainer import (LengthGroupedSampler, RandomSampler, 14 | has_length) 15 | from transformers.trainer_pt_utils import logger 16 | 17 | 18 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38 19 | def split_to_even_chunks(indices, lengths, num_chunks): 20 | """ 21 | Split a list of indices into `chunks` chunks of roughly equal lengths. 22 | """ 23 | 24 | if len(indices) % num_chunks != 0: 25 | return [indices[i::num_chunks] for i in range(num_chunks)] 26 | 27 | num_indices_per_chunk = len(indices) // num_chunks 28 | 29 | chunks = [[] for _ in range(num_chunks)] 30 | chunks_lengths = [0 for _ in range(num_chunks)] 31 | for index in indices: 32 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 33 | chunks[shortest_chunk].append(index) 34 | chunks_lengths[shortest_chunk] += lengths[index] 35 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 36 | chunks_lengths[shortest_chunk] = float('inf') 37 | 38 | return chunks 39 | 40 | 41 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88 42 | def get_length_grouped_indices( 43 | lengths, batch_size, world_size, generator=None, merge=True 44 | ): 45 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 
46 | indices = torch.randperm(len(lengths), generator=generator) 47 | megabatch_size = world_size * batch_size 48 | megabatches = [ 49 | indices[i : i + megabatch_size].tolist() 50 | for i in range(0, len(lengths), megabatch_size) 51 | ] 52 | megabatches = [ 53 | sorted(megabatch, key=lambda i: lengths[i], reverse=True) 54 | for megabatch in megabatches 55 | ] 56 | megabatches = [ 57 | split_to_even_chunks(megabatch, lengths, world_size) 58 | for megabatch in megabatches 59 | ] 60 | 61 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 62 | 63 | 64 | # modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99 65 | class LengthGroupedSampler(Sampler): 66 | r""" 67 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 68 | keeping a bit of randomness. 69 | """ 70 | 71 | def __init__( 72 | self, 73 | batch_size: int, 74 | world_size: int, 75 | dataset: Optional[Dataset] = None, 76 | lengths: Optional[List[int]] = None, 77 | model_input_name: Optional[str] = None, 78 | generator=None, 79 | ): 80 | if dataset is None and lengths is None: 81 | raise ValueError('One of dataset and lengths must be provided.') 82 | 83 | self.batch_size = batch_size 84 | if lengths is None: 85 | model_input_name = ( 86 | model_input_name if model_input_name is not None else 'input_ids' 87 | ) 88 | if ( 89 | not ( 90 | isinstance(dataset[0], dict) 91 | or isinstance(dataset[0], BatchEncoding) 92 | ) 93 | or model_input_name not in dataset[0] 94 | ): 95 | raise ValueError( 96 | 'Can only automatically infer lengths for datasets whose items are dictionaries with an ' 97 | f"'{model_input_name}' key." 98 | ) 99 | lengths = [len(feature[model_input_name]) for feature in dataset] 100 | elif isinstance(lengths, torch.Tensor): 101 | logger.info( 102 | 'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...' 103 | ) 104 | lengths = lengths.tolist() 105 | self.world_size = world_size 106 | self.lengths = lengths 107 | self.generator = generator 108 | 109 | def __len__(self): 110 | return len(self.lengths) 111 | 112 | def __iter__(self): 113 | indices = get_length_grouped_indices( 114 | self.lengths, self.batch_size, self.world_size, generator=self.generator 115 | ) 116 | return iter(indices) 117 | 118 | 119 | # patch trainer 120 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 121 | if self.train_dataset is None or not has_length(self.train_dataset): 122 | return None 123 | # Build the sampler. 
124 | if self.args.group_by_length: 125 | lengths = [] 126 | for dataset in self.train_dataset.datasets: 127 | lengths = lengths + dataset.length 128 | model_input_name = ( 129 | self.tokenizer.model_input_names[0] if self.tokenizer is not None else None 130 | ) 131 | return LengthGroupedSampler( 132 | self.args.train_batch_size, 133 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 134 | # self.args.train_batch_size * self.args.gradient_accumulation_steps, 135 | dataset=self.train_dataset, 136 | lengths=lengths, 137 | model_input_name=model_input_name, 138 | ) 139 | else: 140 | return RandomSampler(self.train_dataset) 141 | 142 | 143 | def replace_train_sampler(): 144 | transformers.Trainer._get_train_sampler = _get_train_sampler 145 | # print('Replace train sampler!!') 146 | -------------------------------------------------------------------------------- /internvl/train/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | -------------------------------------------------------------------------------- /internvl/train/constants.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # InternVL 3 | # Copyright (c) 2024 OpenGVLab 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | IMG_CONTEXT_TOKEN = '' 8 | IMG_START_TOKEN = '' 9 | IMG_END_TOKEN = '' 10 | QUAD_START_TOKEN = '' 11 | QUAD_END_TOKEN = '' 12 | REF_START_TOKEN = '' 13 | REF_END_TOKEN = '' 14 | BOX_START_TOKEN = '' 15 | BOX_END_TOKEN = '' 16 | PRM_TOKEN = '' 17 | REWARD_TOKENS = ['Yes', 'No'] 18 | IMAGENET_MEAN = (0.485, 0.456, 0.406) 19 | IMAGENET_STD = (0.229, 0.224, 0.225) 20 | CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073) 21 | CLIP_STD = (0.2686295, 0.2613025, 0.2757711) 22 | SIGLIP_MEAN = (0.5, 0.5, 0.5) 23 | SIGLIP_STD = (0.5, 0.5, 0.5) 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements/internvl_chat.txt 2 | -r requirements/streamlit_demo.txt 3 | -r requirements/classification.txt 4 | -r requirements/segmentation.txt 5 | -------------------------------------------------------------------------------- /requirements/classification.txt: -------------------------------------------------------------------------------- 1 | gdown 2 | termcolor 3 | yacs 4 | -------------------------------------------------------------------------------- /requirements/clip_benchmark.txt: -------------------------------------------------------------------------------- 1 | open_clip_torch>=0.2.1 2 | opencv-python 3 | peft>=0.6.2 4 | protobuf 5 | pycocoevalcap 6 | pyyaml 7 | scikit-learn>=1.0,<2 8 | scikit-learn 9 | scipy 10 | task_adaptation 11 | tensorflow==2.11.0 12 | termcolor 13 | tqdm>=2 14 | transformers>=4.32.0 15 | webdataset>=0.2.31 16 | yacs 17 | -------------------------------------------------------------------------------- /requirements/internvl_chat.txt: -------------------------------------------------------------------------------- 1 | accelerate<1 2 | bitsandbytes==0.42.0 3 | decord 4 | deepspeed>=0.13.5 
5 | einops==0.6.1 6 | einops-exts==0.0.4 7 | huggingface_hub 8 | imageio 9 | numpy==1.26.4 10 | opencv-python 11 | orjson 12 | peft==0.10.0 13 | pycocoevalcap 14 | pyyaml 15 | scikit-learn>=1.2.2 16 | scipy 17 | sentencepiece==0.1.99 18 | shortuuid 19 | tensorboardX 20 | termcolor 21 | timm==0.9.12 22 | tokenizers==0.15.1 23 | torch>=2 24 | torchvision>=0.15 25 | tqdm 26 | transformers==4.37.2 27 | yacs 28 | -------------------------------------------------------------------------------- /requirements/segmentation.txt: -------------------------------------------------------------------------------- 1 | future 2 | importlib_metadata 3 | mmcv-full==1.6.2 4 | mmsegmentation==0.30.0 5 | openmim 6 | ordered-set 7 | platformdirs 8 | tensorboard 9 | tomli 10 | yapf==0.40.1 11 | -------------------------------------------------------------------------------- /requirements/streamlit_demo.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | gradio==3.35.2 3 | gradio_client==0.2.9 4 | httpx==0.24.0 5 | markdown2[all] 6 | pydantic 7 | requests 8 | streamlit 9 | streamlit-image-select 10 | uvicorn 11 | -------------------------------------------------------------------------------- /shell/internvl2.5/2nd_finetune/internvl2_5_38b_dynamic_res_2nd_finetune_full_prm.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | GPUS=${GPUS:-8} 4 | BATCH_SIZE=${BATCH_SIZE:-512} 5 | PER_DEVICE_BATCH_SIZE=${PER_DEVICE_BATCH_SIZE:-2} 6 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS)) 7 | 8 | 9 | export PYTHONPATH="${PYTHONPATH}:$(pwd)" 10 | export MASTER_PORT=34229 11 | export TF_CPP_MIN_LOG_LEVEL=3 12 | export LAUNCHER=pytorch 13 | 14 | OUTPUT_DIR='work_dirs/internvl_chat_v2_5/prm-8b' 15 | 16 | if [ ! 
-d "$OUTPUT_DIR" ]; then 17 | mkdir -p "$OUTPUT_DIR" 18 | fi 19 | 20 | torchrun \ 21 | --nnodes=${NNODES} \ 22 | --node_rank=${NODE_RANK} \ 23 | --master_addr=${MASTER_ADDR} \ 24 | --nproc_per_node=${NPROC_PER_NODE} \ 25 | --master_port=${MASTER_PORT} \ 26 | internvl/train/internvl_chat_finetune.py \ 27 | --model_name_or_path "/path/to/model" \ 28 | --conv_style "internvl2_5" \ 29 | --use_fast_tokenizer False \ 30 | --output_dir ${OUTPUT_DIR} \ 31 | --meta_path "./shell/data/meta.json" \ 32 | --overwrite_output_dir True \ 33 | --force_image_size 448 \ 34 | --max_dynamic_patch 6 \ 35 | --down_sample_ratio 0.5 \ 36 | --drop_path_rate 0.4 \ 37 | --freeze_llm False \ 38 | --freeze_mlp False \ 39 | --freeze_backbone True \ 40 | --vision_select_layer -1 \ 41 | --dataloader_num_workers 4 \ 42 | --bf16 True \ 43 | --num_train_epochs 3 \ 44 | --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \ 45 | --gradient_accumulation_steps ${GRADIENT_ACC} \ 46 | --save_strategy "epoch" \ 47 | --save_total_limit 5 \ 48 | --learning_rate 2e-6 \ 49 | --weight_decay 0.05 \ 50 | --warmup_ratio 0.03 \ 51 | --lr_scheduler_type "cosine" \ 52 | --logging_steps 1 \ 53 | --max_seq_length 8192 \ 54 | --do_train True \ 55 | --grad_checkpoint True \ 56 | --group_by_length True \ 57 | --dynamic_image_size True \ 58 | --use_thumbnail True \ 59 | --ps_version 'v2' \ 60 | --deepspeed "zero_stage1_config.json" \ 61 | --report_to "tensorboard" \ 62 | 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt" 63 | -------------------------------------------------------------------------------- /zero_stage1_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 1, 4 | "allgather_partitions": true, 5 | "allgather_bucket_size": 1e9, 6 | "overlap_comm": true, 7 | "reduce_scatter": true, 8 | "reduce_bucket_size": 1e9, 9 | "contiguous_gradients": true 10 | }, 11 | "fp16": { 12 | "enabled": "auto", 13 | "auto_cast": true, 14 | "loss_scale": 0, 15 | "initial_scale_power": 32, 16 | "loss_scale_window": 1000, 17 | "hysteresis": 2, 18 | "min_loss_scale": 1 19 | }, 20 | "bf16": { 21 | "enabled": "auto" 22 | }, 23 | "optimizer": { 24 | "type": "AdamW", 25 | "params": { 26 | "lr": "auto", 27 | "betas": [ 28 | 0.9, 29 | 0.999 30 | ], 31 | "eps": 1e-8, 32 | "weight_decay": "auto" 33 | } 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 2000, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": true 41 | } 42 | -------------------------------------------------------------------------------- /zero_stage2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 2, 4 | "allgather_partitions": true, 5 | "allgather_bucket_size": 1e8, 6 | "overlap_comm": true, 7 | "reduce_scatter": true, 8 | "reduce_bucket_size": 1e8, 9 | "contiguous_gradients": true 10 | }, 11 | "fp16": { 12 | "enabled": "auto", 13 | "auto_cast": true, 14 | "loss_scale": 0, 15 | "initial_scale_power": 32, 16 | "loss_scale_window": 1000, 17 | "hysteresis": 2, 18 | "min_loss_scale": 1 19 | }, 20 | "bf16": { 21 | "enabled": "auto" 22 | }, 23 | "optimizer": { 24 | "type": "AdamW", 25 | "params": { 26 | "lr": "auto", 27 | "betas": [ 28 | 0.9, 29 | 0.999 30 | ], 31 | "eps": 1e-8, 32 | "weight_decay": "auto" 33 | } 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | 
"steps_per_print": 2000, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } 42 | -------------------------------------------------------------------------------- /zero_stage3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "overlap_comm": true, 5 | "contiguous_gradients": true, 6 | "sub_group_size": 1e9, 7 | "reduce_bucket_size": 1e9, 8 | "stage3_prefetch_bucket_size": 1e9, 9 | "stage3_param_persistence_threshold": 1e7, 10 | "stage3_max_live_parameters": 1e9, 11 | "stage3_max_reuse_distance": 1e9, 12 | "stage3_gather_16bit_weights_on_model_save": true 13 | }, 14 | "fp16": { 15 | "enabled": "auto", 16 | "auto_cast": true, 17 | "loss_scale": 0, 18 | "initial_scale_power": 32, 19 | "loss_scale_window": 1000, 20 | "hysteresis": 2, 21 | "min_loss_scale": 1 22 | }, 23 | "bf16": { 24 | "enabled": "auto" 25 | }, 26 | "optimizer": { 27 | "type": "AdamW", 28 | "params": { 29 | "lr": "auto", 30 | "betas": [ 31 | 0.9, 32 | 0.999 33 | ], 34 | "eps": 1e-8, 35 | "weight_decay": "auto" 36 | } 37 | }, 38 | "gradient_accumulation_steps": "auto", 39 | "gradient_clipping": "auto", 40 | "steps_per_print": 2000, 41 | "train_batch_size": "auto", 42 | "train_micro_batch_size_per_gpu": "auto", 43 | "wall_clock_breakdown": true 44 | } 45 | -------------------------------------------------------------------------------- /zero_stage3_config_100b.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "overlap_comm": true, 5 | "contiguous_gradients": true, 6 | "sub_group_size": 1e9, 7 | "reduce_bucket_size": 1e9, 8 | "stage3_prefetch_bucket_size": 1e9, 9 | "stage3_param_persistence_threshold": 1e4, 10 | "stage3_max_live_parameters": 1e9, 11 | "stage3_max_reuse_distance": 1e9, 12 | "stage3_gather_16bit_weights_on_model_save": true 13 | }, 14 | "fp16": { 15 | "enabled": "auto", 16 | "auto_cast": true, 17 | "loss_scale": 0, 18 | "initial_scale_power": 32, 19 | "loss_scale_window": 1000, 20 | "hysteresis": 2, 21 | "min_loss_scale": 1 22 | }, 23 | "bf16": { 24 | "enabled": "auto" 25 | }, 26 | "optimizer": { 27 | "type": "AdamW", 28 | "params": { 29 | "lr": "auto", 30 | "betas": [ 31 | 0.9, 32 | 0.999 33 | ], 34 | "eps": 1e-8, 35 | "weight_decay": "auto" 36 | } 37 | }, 38 | "gradient_accumulation_steps": "auto", 39 | "gradient_clipping": "auto", 40 | "steps_per_print": 2000, 41 | "train_batch_size": "auto", 42 | "train_micro_batch_size_per_gpu": "auto", 43 | "wall_clock_breakdown": true 44 | } 45 | -------------------------------------------------------------------------------- /zero_stage3_config_100b_1e8.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "overlap_comm": true, 5 | "contiguous_gradients": true, 6 | "sub_group_size": 1e8, 7 | "reduce_bucket_size": 1e8, 8 | "stage3_prefetch_bucket_size": 1e8, 9 | "stage3_param_persistence_threshold": 1e4, 10 | "stage3_max_live_parameters": 1e9, 11 | "stage3_max_reuse_distance": 1e9, 12 | "stage3_gather_16bit_weights_on_model_save": true 13 | }, 14 | "fp16": { 15 | "enabled": "auto", 16 | "auto_cast": true, 17 | "loss_scale": 0, 18 | "initial_scale_power": 32, 19 | "loss_scale_window": 1000, 20 | "hysteresis": 2, 21 | "min_loss_scale": 1 22 | }, 23 | "bf16": { 24 | "enabled": "auto" 25 | }, 26 | "optimizer": { 27 | "type": "AdamW", 28 | 
"params": { 29 | "lr": "auto", 30 | "betas": [ 31 | 0.9, 32 | 0.999 33 | ], 34 | "eps": 1e-8, 35 | "weight_decay": "auto" 36 | } 37 | }, 38 | "gradient_accumulation_steps": "auto", 39 | "gradient_clipping": "auto", 40 | "steps_per_print": 2000, 41 | "train_batch_size": "auto", 42 | "train_micro_batch_size_per_gpu": "auto", 43 | "wall_clock_breakdown": true 44 | } 45 | -------------------------------------------------------------------------------- /zero_stage3_config_34b.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "overlap_comm": true, 5 | "contiguous_gradients": true, 6 | "sub_group_size": 1e9, 7 | "reduce_bucket_size": 1e9, 8 | "stage3_prefetch_bucket_size": 1e9, 9 | "stage3_param_persistence_threshold": 1e5, 10 | "stage3_max_live_parameters": 1e9, 11 | "stage3_max_reuse_distance": 1e9, 12 | "stage3_gather_16bit_weights_on_model_save": true 13 | }, 14 | "fp16": { 15 | "enabled": "auto", 16 | "auto_cast": true, 17 | "loss_scale": 0, 18 | "initial_scale_power": 32, 19 | "loss_scale_window": 1000, 20 | "hysteresis": 2, 21 | "min_loss_scale": 1 22 | }, 23 | "bf16": { 24 | "enabled": "auto" 25 | }, 26 | "optimizer": { 27 | "type": "AdamW", 28 | "params": { 29 | "lr": "auto", 30 | "betas": [ 31 | 0.9, 32 | 0.999 33 | ], 34 | "eps": 1e-8, 35 | "weight_decay": "auto" 36 | } 37 | }, 38 | "gradient_accumulation_steps": "auto", 39 | "gradient_clipping": "auto", 40 | "steps_per_print": 2000, 41 | "train_batch_size": "auto", 42 | "train_micro_batch_size_per_gpu": "auto", 43 | "wall_clock_breakdown": true 44 | } 45 | -------------------------------------------------------------------------------- /zero_stage3_config_70b.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 3, 4 | "overlap_comm": true, 5 | "contiguous_gradients": true, 6 | "sub_group_size": 1e9, 7 | "reduce_bucket_size": 1e9, 8 | "stage3_prefetch_bucket_size": 1e9, 9 | "stage3_param_persistence_threshold": 1e5, 10 | "stage3_max_live_parameters": 1e9, 11 | "stage3_max_reuse_distance": 1e9, 12 | "stage3_gather_16bit_weights_on_model_save": true 13 | }, 14 | "fp16": { 15 | "enabled": "auto", 16 | "auto_cast": true, 17 | "loss_scale": 0, 18 | "initial_scale_power": 32, 19 | "loss_scale_window": 1000, 20 | "hysteresis": 2, 21 | "min_loss_scale": 1 22 | }, 23 | "bf16": { 24 | "enabled": "auto" 25 | }, 26 | "optimizer": { 27 | "type": "AdamW", 28 | "params": { 29 | "lr": "auto", 30 | "betas": [ 31 | 0.9, 32 | 0.999 33 | ], 34 | "eps": 1e-8, 35 | "weight_decay": "auto" 36 | } 37 | }, 38 | "gradient_accumulation_steps": "auto", 39 | "gradient_clipping": "auto", 40 | "steps_per_print": 2000, 41 | "train_batch_size": "auto", 42 | "train_micro_batch_size_per_gpu": "auto", 43 | "wall_clock_breakdown": true 44 | } 45 | --------------------------------------------------------------------------------