├── .flake8
├── .gitignore
├── .isort.cfg
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE.txt
├── README.md
├── data_pipeline
│   ├── llm_utils.py
│   ├── mllm_as_a_judge.py
│   ├── omegaprm.py
│   ├── prm_data_format.py
│   ├── process_json.py
│   ├── run_data_pipeline.py
│   ├── run_data_pipeline.sh
│   └── traverse.py
├── docs
│   ├── case_study.png
│   ├── logo.png
│   ├── performance.png
│   └── wechat_qr.png
├── eval
│   └── prm
│       ├── evaluate_k12_prm.py
│       ├── evaluate_mathverse_prm.py
│       ├── evaluate_mathvision_prm.py
│       ├── evaluate_mathvista_prm.py
│       ├── evaluate_olympiadbench_prm.py
│       └── extract_calculate.py
├── evaluate.sh
├── internvl
│   ├── conversation.py
│   ├── dist_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── internlm2
│   │   │   ├── configuration_internlm2.py
│   │   │   ├── modeling_internlm2.py
│   │   │   ├── tokenization_internlm2.py
│   │   │   └── tokenization_internlm2_fast.py
│   │   ├── internvl_chat
│   │   │   ├── __init__.py
│   │   │   ├── configuration_intern_vit.py
│   │   │   ├── configuration_internvl_chat.py
│   │   │   ├── modeling_intern_vit.py
│   │   │   └── modeling_internvl_chat.py
│   │   └── phi3
│   │       ├── configuration_phi3.py
│   │       └── modeling_phi3.py
│   ├── patch
│   │   ├── __init__.py
│   │   ├── internlm2_packed_training_patch.py
│   │   ├── internvit_liger_monkey_patch.py
│   │   ├── llama2_flash_attn_monkey_patch.py
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── llama_packed_training_patch.py
│   │   ├── llama_rmsnorm_monkey_patch.py
│   │   ├── pad_data_collator.py
│   │   ├── phi3_packed_training_patch.py
│   │   ├── qwen2_packed_training_patch.py
│   │   ├── train_dataloader_patch.py
│   │   └── train_sampler_patch.py
│   └── train
│       ├── __init__.py
│       ├── constants.py
│       ├── dataset.py
│       ├── dataset_packed.py
│       └── internvl_chat_finetune.py
├── requirements.txt
├── requirements
│   ├── classification.txt
│   ├── clip_benchmark.txt
│   ├── internvl_chat.txt
│   ├── segmentation.txt
│   └── streamlit_demo.txt
├── shell
│   └── internvl2.5
│       └── 2nd_finetune
│           └── internvl2_5_38b_dynamic_res_2nd_finetune_full_prm.sh
├── zero_stage1_config.json
├── zero_stage2_config.json
├── zero_stage3_config.json
├── zero_stage3_config_100b.json
├── zero_stage3_config_100b_1e8.json
├── zero_stage3_config_34b.json
└── zero_stage3_config_70b.json
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E501, F403, C901, W504, W605, E251, E122, E126, E127, E722, W503, E128, E741, E731, E701
3 | select = E1, E3, E502, E7, E9, W1, W5, W6
4 | max-line-length = 180
5 | exclude=*.egg/*,build,dist,detection/configs/*
6 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | .idea/
163 |
164 | .DS_Store
165 | data_process/
166 | internvl_chat/work_dirs/
167 | internvl_chat/unittest/
168 | internvl_chat/data/
169 | Husky2/*
170 | data_process/
171 | *distillation*
172 |
173 | batchscript-*
174 | results/
175 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | line-length = 180
3 | multi_line_output = 0
4 | extra_standard_library = setuptools
5 | known_third_party = PIL,asynctest,cityscapesscripts,cv2,gather_models,matplotlib,mmcv,numpy,onnx,onnxruntime,pycocotools,pytest,pytorch_sphinx_theme,requests,scipy,seaborn,six,terminaltables,torch,ts,yaml
6 | no_lines_before = STDLIB,LOCALFOLDER
7 | default_section = THIRDPARTY
8 |
9 | [yapf]
10 | BASED_ON_STYLE = pep8
11 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
12 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
13 |
14 | [codespell]
15 | skip = *.ipynb
16 | quiet-level = 3
17 | ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood
18 | 
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/PyCQA/flake8
3 | rev: 5.0.4
4 | hooks:
5 | - id: flake8
6 | - repo: https://github.com/PyCQA/isort
7 | rev: 5.11.5
8 | hooks:
9 | - id: isort
10 | - repo: https://github.com/pre-commit/pre-commit-hooks
11 | rev: v4.3.0
12 | hooks:
13 | - id: trailing-whitespace
14 | - id: check-yaml
15 | - id: end-of-file-fixer
16 | - id: requirements-txt-fixer
17 | - id: double-quote-string-fixer
18 | - id: check-merge-conflict
19 | - id: fix-encoding-pragma
20 | args: ["--remove"]
21 | - id: mixed-line-ending
22 | args: ["--fix=lf"]
23 | - repo: https://github.com/executablebooks/mdformat
24 | rev: 0.7.9
25 | hooks:
26 | - id: mdformat
27 | args: ["--number"]
28 | additional_dependencies:
29 | - mdformat-openmmlab
30 | - mdformat_frontmatter
31 | - linkify-it-py
32 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to MM-Eureka
2 |
3 | After cloning the repository, please install pre-commit hooks with:
4 |
5 | ```
6 | pip install pre-commit
7 | pre-commit install
8 | ```
9 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 ModalMinds Team
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | ![logo](docs/logo.png)
3 | 
7 | # MM-PRM
8 | 
21 | MM-PRM: Enhancing Multimodal Mathematical Reasoning with Scalable Step-Level Supervision
22 | 
25 | ## 🎯Overview
26 |
27 | While Multimodal Large Language Models (MLLMs) have achieved impressive progress in vision-language understanding, they still struggle with complex multi-step reasoning, often producing logically inconsistent or partially correct solutions. A key limitation lies in the lack of fine-grained supervision over intermediate reasoning steps. To address this, we propose **MM-PRM**, a process reward model trained within a fully automated, scalable framework. We first build **MM-Policy**, a strong multimodal model trained on diverse mathematical reasoning data. Then, we construct **MM-K12**, a curated dataset of 10,000 multimodal math problems with verifiable answers, which serves as seed data. Leveraging a Monte Carlo Tree Search (MCTS)-based pipeline, we generate over 700k step-level annotations without human labeling. The resulting PRM is used to score candidate reasoning paths in the Best-of-N inference setup and achieves significant improvements across both in-domain (MM-K12 test set) and out-of-domain (OlympiadBench, MathVista, etc.) benchmarks. Further analysis confirms the effectiveness of soft labels, smaller learning rates, and path diversity in optimizing PRM performance. MM-PRM demonstrates that process supervision is a powerful tool for enhancing the logical robustness of multimodal reasoning systems. We release all our code and data at [MM-PRM](https://github.com/ModalMinds/MM-PRM).
28 |
29 | ## 🗞️ News
30 |
31 | - **\[2025/05/19\]** We released `MM-PRM`.
32 | - 📖 Paper: [MM-PRM-Paper](https://arxiv.org/abs/2505.13427)
33 | - 📊 Data: [MM-K12](https://huggingface.co/datasets/Cierra0506/MM-K12)
34 | - 🤗 Model: [MM-PRM](https://huggingface.co/Cierra0506/MM-PRM)
35 |
36 | ## 📊 MM-K12 Dataset
37 |
38 | We have released the **MM-K12** dataset at [MM-K12](https://huggingface.co/datasets/Cierra0506/MM-K12).
39 |
40 | ## 🤖 Models
41 |
42 | 
43 | ![case_study](docs/case_study.png)
44 | 
45 | 
46 | *Figure 1 | A qualitative example of MM-PRM accurately identifying erroneous steps in a multimodal reasoning process.*
47 | 
48 | 
49 | ![performance](docs/performance.png)
50 | 
51 | 
52 | *Figure 2 | Performance improvements across various benchmarks when applying the MM-PRM to different models.*
53 |
54 | - 🤗 [MM-PRM](https://huggingface.co/Cierra0506/MM-PRM)
55 |
56 | ## 🏁 Getting Started
57 |
58 | ### 📦 Installation
59 |
60 | ```shell
61 | git clone https://github.com/ModalMinds/MM-PRM.git
62 | cd MM-PRM
63 | pip install -r requirements.txt
64 |
65 | # install flash-attn==2.3.6:
66 |
67 | pip install flash-attn==2.3.6 --no-build-isolation
68 |
69 | # Alternatively you can compile from source:
70 |
71 | git clone https://github.com/Dao-AILab/flash-attention.git
72 | cd flash-attention
73 | git checkout v2.3.6
74 | python setup.py install
75 | ```
76 |
77 | ### 📂 Data Pipeline
78 |
79 | 1. **Seed dataset preparation**
80 |
81 | To begin, prepare a seed dataset consisting of verifiable problems. Each example should be formatted as a JSON object containing the following fields:
82 |
83 | ```json
84 | [
85 | {
86 | "id": "unique identifier for the problem",
87 | "question": "problem statement",
88 | "correct_answer": "ground-truth final answer for evaluation and verification",
89 | "image_path": "/path/to/image.png"
90 | },
91 | ...
92 | ]
93 | ```
94 |
95 | This dataset will be used as input to the data pipeline to generate annotated solution trees with step-wise correctness labels.
96 |
97 | To enable parallel data generation, you need to split the seed dataset into smaller chunks.
98 |
99 | ```shell
100 | cd data_pipeline
101 | python process_json.py
102 | ```
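
The split count and file paths are hard-coded in the `__main__` block of `process_json.py`. Below is a minimal usage sketch of the same helper with your own settings, run from inside `data_pipeline/`; the paths and split count are placeholders for your setup.

```python
# Hypothetical usage sketch of split_questions_uniformly from process_json.py.
# Paths and the number of splits are placeholders, not project defaults.
from process_json import split_questions_uniformly

split_questions_uniformly(
    input_file='seed_questions.json',  # seed dataset: a JSON list of problem dicts
    output_dir='split_dir',            # run_data_pipeline.sh reads questions_part_{j}.json from here
    num_splits=16,                     # e.g. 2 nodes x 8 GPUs, one chunk per worker
)
```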
103 |
104 | 2. **API endpoint setup (Optional)**
105 |
106 | The data generation process requires an API endpoint to automatically verify whether the final answer in a rollout is correct. You can deploy a model (e.g., Qwen2.5) locally to act as the answer judge.
107 |
108 | We recommend using [vLLM](https://docs.vllm.ai/) to deploy a local model.
109 |
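The pipeline talks to this endpoint through the OpenAI-compatible API (the same pattern used in `data_pipeline/mllm_as_a_judge.py` and passed as `--api_endpoint` in `run_data_pipeline.sh`). A minimal sketch for sanity-checking a locally served judge, assuming vLLM exposes an OpenAI-compatible server on port 8000; the model name is a placeholder:

```python
# Hypothetical sanity check for a locally deployed judge endpoint.
# Base URL, API key, and model name are placeholders for your deployment.
import openai

openai.base_url = 'http://127.0.0.1:8000/v1/'
openai.api_key = 'FAKE_API_KEY'  # vLLM does not check the key, but the client requires one

completion = openai.chat.completions.create(
    model='Qwen2.5-72B-Instruct',
    messages=[{'role': 'user', 'content': 'Is 1 + 1 = 2? Answer Yes or No.'}],
    temperature=0.0,
    max_tokens=1,
)
print(completion.choices[0].message.content)  # expected: "Yes"
```
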
110 | 3. **Run data pipeline**
111 |
112 | Once everything is set up, you can run the data pipeline to generate step-level supervision data.
113 |
114 | Before running, ensure that all necessary parameters are correctly set in the script or passed through the environment.
115 |
116 | ```shell
117 | sh run_data_pipeline.sh
118 | ```
119 |
120 | 4. **Sampling Training Data from annotation trees**
121 |
122 | After generating annotated reasoning trees, you need to sample step-by-step solution paths from these trees to construct the training data for the Process Reward Model (PRM). This can be done using the script:
123 |
124 | ```shell
125 | python traverse.py
126 | ```
127 |
128 | The next step is to convert this data into the format required for PRM training. Use the following script to perform the formatting:
129 |
130 | ```shell
131 | python prm_data_format.py
132 | ```
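
After formatting, each line of the output JSONL holds one training sample whose structure follows `prm_data_format.py`; the values below are illustrative placeholders (the `gpt` value is the list of step-level labels produced by `traverse.py`):

```json
{
  "id": 0,
  "image": "/path/to/image.png",
  "conversations": [
    {"from": "human", "value": "Question: ...\nProcess: ..."},
    {"from": "gpt", "value": [0.9, 0.75, 0.0]}
  ]
}
```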
133 |
134 | ### 🌐 Start PRM Training
135 |
136 | Create a JSON file in `internvl_chat/shell/data/`
137 |
138 | The format for the JSON file should be:
139 |
140 | ```json
141 | {
142 | "your-custom-prm_dataset": {
143 | "root": "/path/to/the/image/root",
144 | "annotation": "/path/to/the/jsonl/annotation",
145 | "data_augment": false,
146 | "repeat_time": 1,
147 | "length": "number of samples in the dataset"
148 | }
149 | }
150 | ```
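
For instance, a filled-in entry might look like the following; the dataset key, paths, and sample count are placeholders (note that `length` is an integer):

```json
{
  "mm_k12_prm": {
    "root": "/path/to/the/image/root",
    "annotation": "/path/to/the/jsonl/annotation",
    "data_augment": false,
    "repeat_time": 1,
    "length": 700000
  }
}
```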
151 |
152 | Once the dataset configuration is in place, you can start training the PRM model with:
153 |
154 | ```shell
155 | GPUS=8 sh shell/internvl2.5/2nd_finetune/internvl2_5_38b_dynamic_res_2nd_finetune_full_prm.sh
156 | ```
157 |
158 | ### 📊 Evaluation
159 |
160 | We provide our **evaluation code** in the `eval/` directory.
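
The scripts in `eval/prm/` write per-step PRM scores for every candidate solution into the `prm_scores` field of the results JSON, and `eval/prm/extract_calculate.py` then computes the Best-of-N accuracy. As a rough illustration only (the min-score aggregation here is an assumption, not necessarily what `extract_calculate.py` uses), selecting the best candidate from its step scores could look like:

```python
# Rough sketch of Best-of-N selection from per-step PRM scores.
# Aggregating a candidate by its minimum step score is an assumption for
# illustration; see eval/prm/extract_calculate.py for the actual scoring.

def select_best_candidate(prm_scores: list[list[float]]) -> int:
    """Return the index of the candidate whose weakest step scores highest."""
    aggregated = [min(steps) for steps in prm_scores]
    return max(range(len(aggregated)), key=aggregated.__getitem__)

# Example: three candidate solutions with step-level scores -> candidate 1 wins.
print(select_best_candidate([[0.9, 0.2], [0.8, 0.7, 0.6], [0.95, 0.1, 0.9]]))
```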
161 |
162 | ## ⭐ Starchart
163 |
164 | [](https://star-history.com/#ModalMinds/MM-PRM&Date)
165 |
166 | ## 🤝 Contribution
167 |
168 | If you want to contribute, please feel free to make a pull request or create an issue.
169 |
170 | Please refer to `CONTRIBUTING.md` before you dive in!
171 |
172 | ## 📬 Contact
173 |
174 | If you have any questions or would like to engage with our community, feel free to scan the QR code below to join our WeChat group.
175 |
176 | 
177 | ![wechat_qr](docs/wechat_qr.png)
178 | 
179 | 
180 | ## 🎓 Acknowledgements
181 |
182 | We acknowledge the outstanding open-source contributions from [OpenR](https://github.com/openreasoner/openr) and [vLLM](https://github.com/vllm-project/vllm). We also extend our gratitude to [InternVL](https://github.com/OpenGVLab/InternVL) for their open-source techniques and base models, which have enabled us to further our exploration.
183 |
184 | ## 📜 Citation
185 |
186 | ```
187 | @article{du2025mmprm,
188 | title={MM-PRM: Enhancing Multimodal Mathematical Reasoning with Scalable Step-Level Supervision},
189 | author={Lingxiao Du and Fanqing Meng and Zongkai Liu and Zhixiang Zhou and Ping Luo and Qiaosheng Zhang and Wenqi Shao},
190 | year={2025},
191 | journal={arXiv preprint arXiv:2505.13427},
192 | }
193 | @article{meng2025mmeureka,
194 | title={MM-Eureka: Exploring the Frontiers of Multimodal Reasoning with Rule-based Reinforcement Learning},
195 | author={Fanqing Meng and Lingxiao Du and Zongkai Liu and Zhixiang Zhou and Quanfeng Lu and Daocheng Fu and Tiancheng Han and Botian Shi and Wenhai Wang and Junjun He and Kaipeng Zhang and Ping Luo and Yu Qiao and Qiaosheng Zhang and Wenqi Shao},
196 | year={2025},
197 | journal={arXiv preprint arXiv:2503.07365},
198 | }
199 | @article{liu2025cpgd,
200 | title={CPGD: Toward Stable Rule-based Reinforcement Learning for Language Models},
201 | author={Zongkai Liu and Fanqing Meng and Lingxiao Du and Zhixiang Zhou and Chao Yu and Wenqi Shao and Qiaosheng Zhang},
202 | year={2025},
203 | journal={arXiv preprint arXiv:2505.12504},
204 | }
205 | ```
206 |
--------------------------------------------------------------------------------
/data_pipeline/llm_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import re
3 |
4 | from transformers import AutoTokenizer
5 | from vllm import LLM, SamplingParams
6 | from vllm.multimodal.utils import fetch_image
7 |
8 | logger = logging.getLogger('main')
9 |
10 |
11 | class LanguageModel:
12 | def __init__(
13 | self,
14 | model='OpenGVLab/InternVL2_5-8B',
15 | max_new_tokens=4096,
16 | temperature=1.0,
17 | top_k=50,
18 | top_p=0.9,
19 | ):
20 | self.model = model
21 | self.max_new_tokens = max_new_tokens
22 | self.temperature = temperature
23 | self.top_k = top_k
24 | self.top_p = top_p
25 | self.repetition_penalty = 1.05
26 |
27 | logger.info(f'Loading model {self.model}...')
28 | self.model = LLM(
29 | model=model,
30 | trust_remote_code=True,
31 | tensor_parallel_size=1,
32 | limit_mm_per_prompt={'image': 8},
33 | )
34 | self.tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
35 | self.stop_tokens = ['<|im_end|>\n'.strip()]
36 | self.stop_token_ids = [
37 | self.tokenizer.convert_tokens_to_ids(i) for i in self.stop_tokens
38 | ]
39 | self.special_tokens = self.tokenizer.all_special_tokens
40 | self.custom_tokens = ['', '', '', '']
41 | self.special_tokens = [
42 | token for token in self.special_tokens if token not in self.custom_tokens
43 | ]
44 | self.pattern1 = r'|'.join(map(re.escape, self.special_tokens))
45 | self.pattern2 = r'(.*?)|(.*?)'
46 | logger.info('Model loaded successfully.')
47 |
48 | def generate_results(self, prompt, image_path=None, num_copies=16):
49 | if '<image>' not in prompt:
50 | prompt = '<image>\n' + prompt
51 |
52 | messages = [
53 | {
54 | 'role': 'system',
55 | 'content': '你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。',
56 | },
57 | {'role': 'user', 'content': prompt},
58 | ]
59 | prompt = self.tokenizer.apply_chat_template(
60 | messages, tokenize=False, add_generation_prompt=True
61 | )
62 |
63 | inputs = []
64 | if image_path:
65 | image = fetch_image('file://' + image_path, allowed_local_media_path='/')
66 | for _ in range(num_copies):
67 | inputs.append(
68 | {
69 | 'prompt': prompt,
70 | 'multi_modal_data': {'image': image},
71 | }
72 | )
73 | else:
74 | for _ in range(num_copies):
75 | inputs.append({'prompt': prompt})
76 |
77 | sampling_params = SamplingParams(
78 | temperature=self.temperature,
79 | max_tokens=self.max_new_tokens,
80 | top_p=self.top_p,
81 | top_k=self.top_k,
82 | repetition_penalty=self.repetition_penalty,
83 | stop_token_ids=self.stop_token_ids,
84 | skip_special_tokens=False,
85 | )
86 | model_outputs = self.model.generate(inputs, sampling_params=sampling_params)
87 | batch_results = []
88 | for model_output in model_outputs:
89 | response = re.sub(self.pattern1, '', model_output.outputs[0].text)
90 | matches = re.findall(self.pattern2, response, re.DOTALL)
91 | res = (
92 | [match[0] if match[0] else match[1] for match in matches]
93 | if matches
94 | else []
95 | )
96 | res = list(map(str.strip, res))
97 | batch_results.append(res)
98 |
99 | for result in batch_results:
100 | logger.debug(f'Prompt: {prompt}\nGenerated rollout: {result}')
101 |
102 | return batch_results
103 |
--------------------------------------------------------------------------------
/data_pipeline/mllm_as_a_judge.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import logging
4 | import os
5 | from concurrent.futures import ThreadPoolExecutor
6 |
7 | import numpy as np
8 | import openai
9 | from tqdm import tqdm
10 |
11 | logging.basicConfig(
12 | level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s'
13 | )
14 |
15 | openai._utils._logs.logger.setLevel(logging.WARNING)
16 | openai._utils._logs.httpx_logger.setLevel(logging.WARNING)
17 | openai.base_url = 'http://127.0.0.1:10022/v1/'
18 | openai.api_key = 'FAKE_API_KEY'
19 | MODEL_NAME = 'Qwen2-VL-72B-Instruct'
20 | OUTPUT_DIR = 'mllm_as_a_judge_outputs'
21 |
22 |
23 | def traverse(node):
24 | def dfs(node, question, encoded_image, previous_solution, answer):
25 | partial_solution = (
26 | '\n\n'.join(previous_solution)
27 | if previous_solution
28 | else 'No partial solution'
29 | )
30 | following_steps = '\n\n'.join(node['partial_solution'])
31 | prompt = f"""I will provide a problem, its corresponding answer, a partial solution to the problem and some steps that continue from the partial solution. They will be formatted as follows:
32 |
33 | [Problem]
34 |
35 | ...(problem)...
36 |
37 | [Correct Answer]
38 |
39 | ...(problem's correct answer)...
40 |
41 | [Partial Solution]
42 |
43 | ...(partial solution)...
44 |
45 | [Following Steps]
46 |
47 | ...(some steps that continue from the partial solution)...
48 |
49 | Your task is to evaluate the Following Steps to determine whether they are logically and mathematically valid. If they are valid, respond with "Yes"; otherwise, respond with "No".
50 |
51 | * Respond with "Yes" or "No" only.
52 |
53 | ------------------------------------------------
54 |
55 | The following is the information for your task:
56 |
57 | [Problem]
58 |
59 | {question}
60 |
61 | [Correct Answer]
62 |
63 | {answer}
64 |
65 | [Partial Solution]
66 |
67 | {partial_solution}
68 |
69 | [Following Steps]
70 |
71 | {following_steps}
72 | """
73 |
74 | messages = [
75 | {
76 | 'role': 'user',
77 | 'content': [
78 | {'type': 'text', 'text': prompt},
79 | {
80 | 'type': 'image_url',
81 | 'image_url': {'url': encoded_image},
82 | },
83 | ],
84 | }
85 | ]
86 |
87 | completion = openai.chat.completions.create(
88 | model=MODEL_NAME,
89 | messages=messages,
90 | temperature=0.0,
91 | max_tokens=1,
92 | top_p=0.95,
93 | logprobs=True,
94 | top_logprobs=5,
95 | )
96 |
97 | top_logprobs = completion.choices[0].logprobs.content[0].top_logprobs
98 | top_logprobs = {i.token: i.logprob for i in top_logprobs}
99 |
100 | logprob_yes = max(top_logprobs.get('YES', -100), top_logprobs.get('Yes', -100))
101 | logprob_no = max(top_logprobs.get('NO', -100), top_logprobs.get('No', -100))
102 | node['logprob_yes'] = float(np.exp(logprob_yes))
103 | node['logprob_no'] = float(np.exp(logprob_no))
104 | node['llm_as_a_judge'] = float(
105 | np.exp(logprob_yes) / (np.exp(logprob_yes) + np.exp(logprob_no))
106 | )
107 | for child in node['children']:
108 | dfs(
109 | child,
110 | question,
111 | encoded_image,
112 | previous_solution + node['partial_solution'],
113 | answer,
114 | )
115 |
116 | try:
117 | with open(node['image_path'], 'rb') as f:
118 | encoded_image = (
119 | f'data:image;base64,{base64.b64encode(f.read()).decode("utf-8")}'
120 | )
121 | for child in node['children']:
122 | dfs(child, node['question'], encoded_image, [], node['answer'])
123 |
124 | with open(os.path.join(OUTPUT_DIR, f'{node["id"]}.json'), 'w') as f:
125 | json.dump(node, f, ensure_ascii=False, indent=4)
126 | logging.info(node['id'])
127 | except Exception:
128 | logging.exception(f"Failed to process tree {node.get('id', 'unknown')}")
129 |
130 |
131 | if __name__ == '__main__':
132 | os.makedirs(OUTPUT_DIR, exist_ok=True)
133 |
134 | root_dir = '/path/to/outputs'
135 | roots = []
136 | for file in os.listdir(root_dir):
137 | logging.info(file)
138 | with open(os.path.join(root_dir, file), 'r') as f:
139 | roots.append(json.load(f))
140 |
141 | with ThreadPoolExecutor(max_workers=16) as executor:
142 | results = list(tqdm(executor.map(traverse, roots), total=len(roots)))
143 |
--------------------------------------------------------------------------------
/data_pipeline/prm_data_format.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import jsonlines
4 |
5 | input_file = '/path/to/input/file.json'
6 | output_file = '/path/to/output/file.jsonl'
7 |
8 | with open(input_file, 'r') as f:
9 | data = json.load(f)
10 | res = []
11 |
12 | for idx, item in enumerate(data):
13 | question = item['question']
14 | process = item['process']
15 | labels = item['labels']
16 | image_path = item['image_path']
17 |
18 | combined_value = f'Question: {question}\nProcess: {process}'
19 |
20 | conversations = [
21 | {'from': 'human', 'value': combined_value},
22 | {'from': 'gpt', 'value': labels},
23 | ]
24 |
25 | new_item = {'id': idx, 'image': image_path, 'conversations': conversations}
26 |
27 | res.append(new_item)
28 |
29 | with jsonlines.open(output_file, 'w') as writer:
30 | writer.write_all(res)
31 |
--------------------------------------------------------------------------------
/data_pipeline/process_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import math
3 | import os
4 |
5 |
6 | def split_questions_uniformly(input_file, output_dir, num_splits):
7 | with open(input_file, 'r') as f:
8 | input_data = json.load(f)
9 |
10 | num_data = len(input_data)
11 | data_per_file = math.ceil(num_data / num_splits)
12 |
13 | os.makedirs(output_dir, exist_ok=True)
14 |
15 | for i in range(num_splits):
16 | start_idx = i * data_per_file
17 | end_idx = min(start_idx + data_per_file, num_data)
18 | data_subset = input_data[start_idx:end_idx]
19 |
20 | output_filepath = os.path.join(output_dir, f'questions_part_{i + 1}.json')
21 | print(f'Saving {len(data_subset)} questions to {output_filepath}')
22 |
23 | with open(output_filepath, 'w') as f_out:
24 | json.dump(data_subset, f_out, indent=4, ensure_ascii=False)
25 |
26 |
27 | if __name__ == '__main__':
28 | split_questions_uniformly('/path/to/input/file', 'split_dir', 16)
29 |
--------------------------------------------------------------------------------
/data_pipeline/run_data_pipeline.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import hashlib
3 | import json
4 | import logging
5 | import os
6 |
7 | from llm_utils import LanguageModel
8 | from omegaprm import OmegaPRM
9 |
10 | if __name__ == '__main__':
11 | parser = argparse.ArgumentParser()
12 |
13 | parser.add_argument('--input_file', type=str)
14 | parser.add_argument('--log_file', type=str, default='log.txt')
15 | parser.add_argument('--output_dir', type=str, default='output')
16 |
17 | parser.add_argument('--model', type=str, default='OpenGVLab/InternVL2_5-8B')
18 | parser.add_argument('--max_new_tokens', type=int, default=2048)
19 | parser.add_argument('--temperature', type=float, default=0.7)
20 | parser.add_argument('--top_k', type=int, default=30)
21 | parser.add_argument('--top_p', type=float, default=0.9)
22 |
23 | parser.add_argument('--c_puct', type=float, default=0.125)
24 | parser.add_argument('--alpha', type=float, default=0.5)
25 | parser.add_argument('--beta', type=float, default=0.9)
26 | parser.add_argument('--length_scale', type=int, default=500)
27 | parser.add_argument('--num_rollouts', type=int, default=16)
28 | parser.add_argument('--max_search_count', type=int, default=20)
29 | parser.add_argument('--rollout_budget', type=int, default=200)
30 |
31 | parser.add_argument('--api_endpoint', type=str)
32 |
33 | args = parser.parse_args()
34 |
35 | logging.basicConfig(
36 | level=logging.DEBUG,
37 | format='%(asctime)s [%(levelname)s] %(message)s',
38 | handlers=[logging.FileHandler(args.log_file), logging.StreamHandler()],
39 | )
40 | logger = logging.getLogger('main')
41 |
42 | logger.info('Start OmegaPRM')
43 | logger.info(f'Using model: {args.model}')
44 | logger.info(f'Input file: {args.input_file}')
45 | logger.info(f'Output directory: {args.output_dir}')
46 |
47 | with open(args.input_file, 'r') as f:
48 | input_data = json.load(f)
49 |
50 | LLM = LanguageModel(
51 | model=args.model,
52 | max_new_tokens=args.max_new_tokens,
53 | temperature=args.temperature,
54 | top_k=args.top_k,
55 | top_p=args.top_p,
56 | )
57 |
58 | for question in input_data:
59 | hash_value = hashlib.md5(json.dumps(question).encode()).hexdigest()
60 | if os.path.exists(os.path.join(args.output_dir, f'{hash_value}.json')):
61 | logger.info(f"Already processed: {question['question']}")
62 | continue
63 | try:
64 | omega_prm = OmegaPRM(
65 | LLM=LLM,
66 | question=question['question'],
67 | image_path=question['image_path'],
68 | correct_answer=question['correct_answer'],
69 | c_puct=args.c_puct,
70 | alpha=args.alpha,
71 | beta=args.beta,
72 | length_scale=args.length_scale,
73 | num_rollouts=args.num_rollouts,
74 | max_search_count=args.max_search_count,
75 | rollout_budget=args.rollout_budget,
76 | api_endpoint=args.api_endpoint,
77 | )
78 | data = omega_prm.run()
79 | data['id'] = question['id']
80 |
81 | filename = os.path.join(args.output_dir, f'{hash_value}.json')
82 | with open(filename, 'w') as f:
83 | json.dump(data, f, indent=4, ensure_ascii=False)
84 | logger.info(f'Processed: {hash_value}.json')
85 | except Exception as e:
86 | logger.error(f"Error processing {question['question']}: {e}")
87 |
--------------------------------------------------------------------------------
/data_pipeline/run_data_pipeline.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SPLIT_DIR="split_dir"
4 | LOG_DIR="logs"
5 | OUTPUT_DIR="outputs"
6 | MODEL="/path/to/model"
7 |
8 | MAX_NEW_TOKENS=4096
9 | TEMPERATURE=1.0
10 | TOP_K=50
11 | TOP_P=0.9
12 |
13 | C_PUCT=0.125
14 | ALPHA=0.5
15 | BETA=0.9
16 | LENGTH_SCALE=2000
17 | NUM_ROLLOUTS=16
18 | MAX_SEARCH_COUNT=200
19 | ROLLOUT_BUDGET=1000
20 |
21 | API_ENDPOINT="http://127.0.0.1:8000/v1/"
22 |
23 |
24 | mkdir -p $OUTPUT_DIR
25 | mkdir -p $LOG_DIR
26 |
27 | START=$((NODE_RANK * 8 + 1))
28 |
29 | for i in {0..7}
30 | do
31 | j=$((START + i))
32 | GPU_ID=$i
33 | INPUT_FILE="$SPLIT_DIR/questions_part_${j}.json"
34 | LOG_FILE="$LOG_DIR/part_${j}.log"
35 |
36 | CUDA_VISIBLE_DEVICES="${GPU_ID}" python run_data_pipeline.py \
37 | --input_file $INPUT_FILE \
38 | --log_file $LOG_FILE \
39 | --output_dir $OUTPUT_DIR \
40 | --model $MODEL \
41 | --max_new_tokens $MAX_NEW_TOKENS \
42 | --temperature $TEMPERATURE \
43 | --top_k $TOP_K \
44 | --top_p $TOP_P \
45 | --c_puct $C_PUCT \
46 | --alpha $ALPHA \
47 | --beta $BETA \
48 | --length_scale $LENGTH_SCALE \
49 | --num_rollouts $NUM_ROLLOUTS \
50 | --max_search_count $MAX_SEARCH_COUNT \
51 | --rollout_budget $ROLLOUT_BUDGET \
52 | --api_endpoint $API_ENDPOINT &
53 | done
54 |
55 | wait
56 |
--------------------------------------------------------------------------------
/data_pipeline/traverse.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 |
5 | SEP = ''
6 |
7 | THRESHOLD = 0.00
8 |
9 |
10 | def random_select_solutions(solutions, n):
11 | if len(solutions) <= n:
12 | return solutions
13 | return random.sample(solutions, n)
14 |
15 |
16 | def completion_too_short(string, word_count_thres=10):
17 | return len(string.split(' ')) <= word_count_thres
18 |
19 |
20 | def remove_redundant(solutions):
21 | unique_solutions = {}
22 | for solution in solutions:
23 | key = solution['process']
24 | if key not in unique_solutions:
25 | unique_solutions[key] = solution
26 |
27 | return list(unique_solutions.values())
28 |
29 |
30 | def traverse(root):
31 | question = root['question']
32 | answer = root['answer']
33 | image_path = root['image_path']
34 | positive = []
35 | negative = []
36 | res = []
37 |
38 | def dfs(node, solution_prefix, labels):
39 | partial_solution = list(map(str.strip, node['partial_solution']))
40 | partial_solution = '\n\n'.join(partial_solution)
41 | solution_prefix = solution_prefix + partial_solution + SEP + '\n\n'
42 | labels = labels + [node['mc_value']]
43 | if node['mc_value'] <= THRESHOLD:
44 | if not completion_too_short(solution_prefix):
45 | negative.append(
46 | {
47 | 'question': question,
48 | 'answer': answer,
49 | 'image_path': image_path,
50 | 'process': solution_prefix.strip(),
51 | 'labels': labels,
52 | }
53 | )
54 | return
55 | if node['children']:
56 | for child in node['children']:
57 | dfs(child, solution_prefix, labels)
58 | else:
59 | if not completion_too_short(solution_prefix):
60 | positive.append(
61 | {
62 | 'question': question,
63 | 'answer': answer,
64 | 'image_path': image_path,
65 | 'process': solution_prefix.strip(),
66 | 'labels': labels,
67 | }
68 | )
69 | return
70 |
71 | if root['children']:
72 | for child in root['children']:
73 | dfs(child, '', [])
74 |
75 | negative = remove_redundant(negative)
76 | positive = remove_redundant(positive)
77 |
78 | res.extend(negative)
79 |
80 | if len(positive) < len(negative):
81 | res.extend(positive)
82 | else:
83 | res.extend(random.sample(positive, len(negative)))
84 | return res
85 |
86 |
87 | if __name__ == '__main__':
88 | root_dir = 'outputs'
89 | res = []
90 | for file in os.listdir(root_dir):
91 | print(file)
92 | with open(os.path.join(root_dir, file), 'r') as f:
93 | data = json.load(f)
94 | if data['mc_value'] != 0.0 and data['mc_value'] != 1.0:
95 | res.extend(traverse(data))
96 |
97 | print(len(res))
98 | json.dump(res, open('output_file.json', 'w'), ensure_ascii=False, indent=4)
99 |
--------------------------------------------------------------------------------
/docs/case_study.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModalMinds/MM-PRM/b914fd66c8efb3d6667610fba1f8fd1e55cde095/docs/case_study.png
--------------------------------------------------------------------------------
/docs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModalMinds/MM-PRM/b914fd66c8efb3d6667610fba1f8fd1e55cde095/docs/logo.png
--------------------------------------------------------------------------------
/docs/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModalMinds/MM-PRM/b914fd66c8efb3d6667610fba1f8fd1e55cde095/docs/performance.png
--------------------------------------------------------------------------------
/docs/wechat_qr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModalMinds/MM-PRM/b914fd66c8efb3d6667610fba1f8fd1e55cde095/docs/wechat_qr.png
--------------------------------------------------------------------------------
/eval/prm/evaluate_k12_prm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 |
8 | import torch
9 | from PIL import Image
10 | from tqdm import tqdm
11 |
12 | from internvl.model import load_model_and_tokenizer
13 | from internvl.train.dataset import build_transform, dynamic_preprocess
14 |
15 | ds_collections = {
16 | 'k12_prm': {'root': '/path/to/image/root', 'annotation': '/path/to/rollout/file'}
17 | }
18 |
19 |
20 | def collate_fn(batches):
21 | pixel_values = batches[0]['pixel_values']
22 | prompts = batches[0]['prompts']
23 | steps_lens = batches[0]['steps_lens']
24 | data_items = batches[0]['data_item']
25 | return pixel_values, prompts, steps_lens, data_items
26 |
27 |
28 | class K12PRMDataset(torch.utils.data.Dataset):
29 |
30 | def __init__(
31 | self,
32 | root,
33 | annotation,
34 | input_size=224,
35 | dynamic_image_size=False,
36 | use_thumbnail=False,
37 | max_num=6,
38 | ):
39 | self.root = root
40 | self.data = json.load(open(annotation))
41 | self.input_size = input_size
42 | self.dynamic_image_size = dynamic_image_size
43 | self.use_thumbnail = use_thumbnail
44 | self.max_num = max_num
45 | self.transform = build_transform(is_train=False, input_size=input_size)
46 |
47 | def __len__(self):
48 | return len(self.data)
49 |
50 | def __getitem__(self, idx):
51 | data_item = self.data[idx]
52 | image = Image.open(os.path.join(self.root, data_item['image_path'])).convert(
53 | 'RGB'
54 | )
55 |
56 | if self.dynamic_image_size:
57 | images = dynamic_preprocess(
58 | image,
59 | image_size=self.input_size,
60 | use_thumbnail=self.use_thumbnail,
61 | max_num=self.max_num,
62 | )
63 | else:
64 | images = [image]
65 | pixel_values = [self.transform(image) for image in images]
66 | pixel_values = torch.stack(pixel_values)
67 |
68 | question = data_item['question']
69 |
70 | prompts = []
71 | steps_lens = []
72 | for solution_split in data_item['solutions_splits']:
73 | solution = ''.join(solution_split) + ''
74 | prompt = f'Question: {question}\nProcess: {solution}'
75 | prompts.append(prompt)
76 | steps_lens.append(len(solution_split))
77 |
78 | return {
79 | 'pixel_values': pixel_values,
80 | 'prompts': prompts,
81 | 'steps_lens': steps_lens,
82 | 'data_item': data_item,
83 | }
84 |
85 |
86 | class InferenceSampler(torch.utils.data.sampler.Sampler):
87 |
88 | def __init__(self, size):
89 | self._size = int(size)
90 | assert size > 0
91 | self._rank = torch.distributed.get_rank()
92 | self._world_size = torch.distributed.get_world_size()
93 | self._local_indices = self._get_local_indices(
94 | size, self._world_size, self._rank
95 | )
96 |
97 | @staticmethod
98 | def _get_local_indices(total_size, world_size, rank):
99 | shard_size = total_size // world_size
100 | left = total_size % world_size
101 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
102 |
103 | begin = sum(shard_sizes[:rank])
104 | end = min(sum(shard_sizes[: rank + 1]), total_size)
105 | return range(begin, end)
106 |
107 | def __iter__(self):
108 | yield from self._local_indices
109 |
110 | def __len__(self):
111 | return len(self._local_indices)
112 |
113 |
114 | def evaluate_chat_model():
115 | random.seed(args.seed)
116 |
117 | for ds_name in args.datasets:
118 | dataset = K12PRMDataset(
119 | root=ds_collections[ds_name]['root'],
120 | annotation=ds_collections[ds_name]['annotation'],
121 | input_size=image_size,
122 | dynamic_image_size=args.dynamic,
123 | use_thumbnail=use_thumbnail,
124 | max_num=args.max_num,
125 | )
126 | dataloader = torch.utils.data.DataLoader(
127 | dataset=dataset,
128 | sampler=InferenceSampler(len(dataset)),
129 | batch_size=1,
130 | num_workers=args.num_workers,
131 | pin_memory=True,
132 | drop_last=False,
133 | collate_fn=collate_fn,
134 | )
135 |
136 | outputs = []
137 | for idx, (pixel_values, prompts, steps_lens, data_item) in tqdm(
138 | enumerate(dataloader)
139 | ):
140 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
141 |
142 | prm_scores_flattened = []
143 | for i in range(0, len(prompts), args.mini_batch_size):
144 | curr_bs = min(args.mini_batch_size, len(prompts) - i)
145 | output = model.batch_prm(
146 | tokenizer=tokenizer,
147 | pixel_values=torch.cat([pixel_values] * curr_bs, dim=0),
148 | questions=prompts[i : i + curr_bs],
149 | num_patches_list=[pixel_values.shape[0]] * curr_bs,
150 | verbose=True,
151 | )
152 | prm_scores_flattened.extend(output.tolist())
153 |
154 | data_item['prm_scores'] = []
155 | curr_len = 0
156 | for i in range(len(steps_lens)):
157 | data_item['prm_scores'].append(
158 | prm_scores_flattened[curr_len : curr_len + steps_lens[i]]
159 | )
160 | curr_len += steps_lens[i]
161 |
162 | for i in range(len(data_item['prm_scores'])):
163 | assert len(data_item['prm_scores'][i]) == steps_lens[i]
164 |
165 | print(f'Pred: {data_item["prm_scores"]}')
166 | outputs.append(data_item)
167 |
168 | if idx % 50 == 0:
169 | torch.distributed.barrier()
170 |
171 | torch.distributed.barrier()
172 |
173 | world_size = torch.distributed.get_world_size()
174 | merged_outputs = [None for _ in range(world_size)]
175 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
176 |
177 | merged_outputs = [json.loads(_) for _ in merged_outputs]
178 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
179 |
180 | if torch.distributed.get_rank() == 0:
181 | print(f'Evaluating {ds_name} ...')
182 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
183 | results_file = f'{ds_name}_{time_prefix}.json'
184 | output_path = os.path.join(args.out_dir, results_file)
185 | json.dump(
186 | merged_outputs, open(output_path, 'w'), indent=4, ensure_ascii=False
187 | )
188 | print('Results saved to {}'.format(output_path))
189 |
190 | cmd = f'python eval/prm/extract_calculate.py --output_file {results_file} --output_dir {args.out_dir}'
191 | print(cmd)
192 | os.system(cmd)
193 |
194 |
195 | if __name__ == '__main__':
196 | parser = argparse.ArgumentParser()
197 | parser.add_argument('--checkpoint', type=str, default='')
198 | parser.add_argument('--datasets', type=str, default='k12_prm')
199 | parser.add_argument('--mini-batch-size', type=int, default=4)
200 | parser.add_argument('--num-workers', type=int, default=2)
201 | parser.add_argument('--out-dir', type=str, default='results')
202 | parser.add_argument('--seed', type=int, default=0)
203 | parser.add_argument('--dynamic', action='store_true', default=True)
204 | parser.add_argument('--max-num', type=int, default=6)
205 | parser.add_argument('--load-in-8bit', action='store_true')
206 | parser.add_argument('--load-in-4bit', action='store_true')
207 | parser.add_argument('--auto', action='store_true')
208 | args = parser.parse_args()
209 |
210 | if not os.path.exists(args.out_dir):
211 | os.makedirs(args.out_dir, exist_ok=True)
212 |
213 | args.datasets = args.datasets.split(',')
214 | print('datasets:', args.datasets)
215 |
216 | torch.distributed.init_process_group(
217 | backend='nccl',
218 | world_size=int(os.getenv('WORLD_SIZE', '1')),
219 | rank=int(os.getenv('RANK', '0')),
220 | )
221 |
222 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
223 |
224 | model, tokenizer = load_model_and_tokenizer(args)
225 |
226 | image_size = model.config.force_image_size or model.config.vision_config.image_size
227 | use_thumbnail = model.config.use_thumbnail
228 |
229 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
230 | print(f'[test] total_params: {total_params}B')
231 | print(f'[test] image_size: {image_size}')
232 | print(f'[test] template: {model.config.template}')
233 | print(f'[test] dynamic_image_size: {args.dynamic}')
234 | print(f'[test] use_thumbnail: {use_thumbnail}')
235 |
236 | evaluate_chat_model()
237 |
--------------------------------------------------------------------------------
/eval/prm/evaluate_mathverse_prm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 |
8 | import torch
9 | from PIL import Image
10 | from tqdm import tqdm
11 |
12 | from internvl.model import load_model_and_tokenizer
13 | from internvl.train.dataset import build_transform, dynamic_preprocess
14 |
15 | ds_collections = {
16 | 'mathverse_prm': {
17 | 'root': '/path/to/image/root',
18 | 'annotation': '/path/to/rollout/file',
19 | }
20 | }
21 |
22 |
23 | def collate_fn(batches):
24 | pixel_values = batches[0]['pixel_values']
25 | prompts = batches[0]['prompts']
26 | steps_lens = batches[0]['steps_lens']
27 | data_items = batches[0]['data_item']
28 | return pixel_values, prompts, steps_lens, data_items
29 |
30 |
31 | class MathVersePRMDataset(torch.utils.data.Dataset):
32 |
33 | def __init__(
34 | self,
35 | root,
36 | annotation,
37 | input_size=224,
38 | dynamic_image_size=False,
39 | use_thumbnail=False,
40 | max_num=6,
41 | ):
42 | self.root = root
43 | self.data = json.load(open(annotation))
44 | self.input_size = input_size
45 | self.dynamic_image_size = dynamic_image_size
46 | self.use_thumbnail = use_thumbnail
47 | self.max_num = max_num
48 | self.transform = build_transform(is_train=False, input_size=input_size)
49 |
50 | def __len__(self):
51 | return len(self.data)
52 |
53 | def __getitem__(self, idx):
54 | data_item = self.data[idx]
55 | image = os.path.join(self.root, data_item['image'])
56 |
57 | image = Image.open(image).convert('RGB')
58 | if self.dynamic_image_size:
59 | images = dynamic_preprocess(
60 | image,
61 | image_size=self.input_size,
62 | use_thumbnail=self.use_thumbnail,
63 | max_num=self.max_num,
64 | )
65 | else:
66 | images = [image]
67 | pixel_values = [self.transform(image) for image in images]
68 | pixel_values = torch.stack(pixel_values)
69 |
70 | question = data_item['query_cot']
71 |
72 | prompts = []
73 | steps_lens = []
74 | for solution_split in data_item['solutions_splits']:
75 | solution = ''.join(solution_split) + ''
76 | prompt = f'Question: {question}\nProcess: {solution}'
77 | prompts.append(prompt)
78 | steps_lens.append(len(solution_split))
79 |
80 | return {
81 | 'pixel_values': pixel_values,
82 | 'prompts': prompts,
83 | 'steps_lens': steps_lens,
84 | 'data_item': data_item,
85 | }
86 |
87 |
88 | class InferenceSampler(torch.utils.data.sampler.Sampler):
89 |
90 | def __init__(self, size):
91 | self._size = int(size)
92 | assert size > 0
93 | self._rank = torch.distributed.get_rank()
94 | self._world_size = torch.distributed.get_world_size()
95 | self._local_indices = self._get_local_indices(
96 | size, self._world_size, self._rank
97 | )
98 |
99 | @staticmethod
100 | def _get_local_indices(total_size, world_size, rank):
101 | shard_size = total_size // world_size
102 | left = total_size % world_size
103 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
104 |
105 | begin = sum(shard_sizes[:rank])
106 | end = min(sum(shard_sizes[: rank + 1]), total_size)
107 | return range(begin, end)
108 |
109 | def __iter__(self):
110 | yield from self._local_indices
111 |
112 | def __len__(self):
113 | return len(self._local_indices)
114 |
115 |
116 | def evaluate_chat_model():
117 | random.seed(args.seed)
118 |
119 | for ds_name in args.datasets:
120 | dataset = MathVersePRMDataset(
121 | root=ds_collections[ds_name]['root'],
122 | annotation=ds_collections[ds_name]['annotation'],
123 | input_size=image_size,
124 | dynamic_image_size=args.dynamic,
125 | use_thumbnail=use_thumbnail,
126 | max_num=args.max_num,
127 | )
128 | dataloader = torch.utils.data.DataLoader(
129 | dataset=dataset,
130 | sampler=InferenceSampler(len(dataset)),
131 | batch_size=1,
132 | num_workers=args.num_workers,
133 | pin_memory=True,
134 | drop_last=False,
135 | collate_fn=collate_fn,
136 | )
137 |
138 | outputs = []
139 | for idx, (pixel_values, prompts, steps_lens, data_item) in tqdm(
140 | enumerate(dataloader)
141 | ):
142 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
143 |
144 | prm_scores_flattened = []
145 | for i in range(0, len(prompts), args.mini_batch_size):
146 | curr_bs = min(args.mini_batch_size, len(prompts) - i)
147 | output = model.batch_prm(
148 | tokenizer=tokenizer,
149 | pixel_values=torch.cat([pixel_values] * curr_bs, dim=0),
150 | questions=prompts[i : i + curr_bs],
151 | num_patches_list=[pixel_values.shape[0]] * curr_bs,
152 | verbose=True,
153 | )
154 | prm_scores_flattened.extend(output.tolist())
155 |
156 | data_item['prm_scores'] = []
157 | curr_len = 0
158 | for i in range(len(steps_lens)):
159 | data_item['prm_scores'].append(
160 | prm_scores_flattened[curr_len : curr_len + steps_lens[i]]
161 | )
162 | curr_len += steps_lens[i]
163 |
164 | for i in range(len(data_item['prm_scores'])):
165 | assert len(data_item['prm_scores'][i]) == steps_lens[i]
166 |
167 | print(f'Pred: {data_item["prm_scores"]}')
168 | outputs.append(data_item)
169 |
170 | if idx % 50 == 0:
171 | torch.distributed.barrier()
172 |
173 | torch.distributed.barrier()
174 |
175 | world_size = torch.distributed.get_world_size()
176 | merged_outputs = [None for _ in range(world_size)]
177 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
178 |
179 | merged_outputs = [json.loads(_) for _ in merged_outputs]
180 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
181 |
182 | if torch.distributed.get_rank() == 0:
183 | print(f'Evaluating {ds_name} ...')
184 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
185 | results_file = f'{ds_name}_{time_prefix}.json'
186 | output_path = os.path.join(args.out_dir, results_file)
187 | json.dump(
188 | merged_outputs, open(output_path, 'w'), indent=4, ensure_ascii=False
189 | )
190 | print('Results saved to {}'.format(output_path))
191 |
192 | cmd = f'python eval/prm/extract_calculate.py --output_file {results_file} --output_dir {args.out_dir}'
193 | print(cmd)
194 | os.system(cmd)
195 |
196 |
197 | if __name__ == '__main__':
198 | parser = argparse.ArgumentParser()
199 | parser.add_argument('--checkpoint', type=str, default='')
200 | parser.add_argument('--datasets', type=str, default='')
201 | parser.add_argument('--mini-batch-size', type=int, default=4)
202 | parser.add_argument('--num-workers', type=int, default=2)
203 | parser.add_argument('--out-dir', type=str, default='results')
204 | parser.add_argument('--seed', type=int, default=0)
205 | parser.add_argument('--dynamic', action='store_true', default=True)
206 | parser.add_argument('--max-num', type=int, default=6)
207 | parser.add_argument('--load-in-8bit', action='store_true')
208 | parser.add_argument('--load-in-4bit', action='store_true')
209 | parser.add_argument('--auto', action='store_true')
210 | args = parser.parse_args()
211 |
212 | if not os.path.exists(args.out_dir):
213 | os.makedirs(args.out_dir, exist_ok=True)
214 |
215 | args.datasets = args.datasets.split(',')
216 | print('datasets:', args.datasets)
217 |
218 | torch.distributed.init_process_group(
219 | backend='nccl',
220 | world_size=int(os.getenv('WORLD_SIZE', '1')),
221 | rank=int(os.getenv('RANK', '0')),
222 | )
223 |
224 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
225 |
226 | model, tokenizer = load_model_and_tokenizer(args)
227 |
228 | image_size = model.config.force_image_size or model.config.vision_config.image_size
229 | use_thumbnail = model.config.use_thumbnail
230 |
231 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
232 | print(f'[test] total_params: {total_params}B')
233 | print(f'[test] image_size: {image_size}')
234 | print(f'[test] template: {model.config.template}')
235 | print(f'[test] dynamic_image_size: {args.dynamic}')
236 | print(f'[test] use_thumbnail: {use_thumbnail}')
237 |
238 | evaluate_chat_model()
239 |
--------------------------------------------------------------------------------
/eval/prm/evaluate_mathvision_prm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 |
8 | import torch
9 | from PIL import Image
10 | from tqdm import tqdm
11 |
12 | from internvl.model import load_model_and_tokenizer
13 | from internvl.train.dataset import build_transform, dynamic_preprocess
14 |
15 | ds_collections = {
16 | 'mathvision_prm': {
17 | 'root': '/path/to/image/root',
18 | 'annotation': '/path/to/rollout/file',
19 | }
20 | }
21 |
22 |
23 | def collate_fn(batches):
24 | pixel_values = batches[0]['pixel_values']
25 | prompts = batches[0]['prompts']
26 | steps_lens = batches[0]['steps_lens']
27 | data_items = batches[0]['data_item']
28 | return pixel_values, prompts, steps_lens, data_items
29 |
30 |
31 | class MathVisionPRMDataset(torch.utils.data.Dataset):
32 |
33 | def __init__(
34 | self,
35 | root,
36 | annotation,
37 | input_size=224,
38 | dynamic_image_size=False,
39 | use_thumbnail=False,
40 | max_num=6,
41 | ):
42 | self.root = root
43 | self.data = json.load(open(annotation))
44 | self.input_size = input_size
45 | self.dynamic_image_size = dynamic_image_size
46 | self.use_thumbnail = use_thumbnail
47 | self.max_num = max_num
48 | self.transform = build_transform(is_train=False, input_size=input_size)
49 |
50 | def __len__(self):
51 | return len(self.data)
52 |
53 | def __getitem__(self, idx):
54 | data_item = self.data[idx]
55 | image = os.path.join(self.root, os.path.basename(data_item['image']))
56 |
57 | image = Image.open(image).convert('RGB')
58 | if self.dynamic_image_size:
59 | images = dynamic_preprocess(
60 | image,
61 | image_size=self.input_size,
62 | use_thumbnail=self.use_thumbnail,
63 | max_num=self.max_num,
64 | )
65 | else:
66 | images = [image]
67 | pixel_values = [self.transform(image) for image in images]
68 | pixel_values = torch.stack(pixel_values)
69 |
70 | options = ''
71 | if len(data_item['options']) > 0:
72 | assert len(data_item['options']) == 5, data_item
73 | if ''.join(data_item['options']) != 'ABCDE':
74 | options = f"(A) {data_item['options'][0]}\n(B) {data_item['options'][1]}\n(C) {data_item['options'][2]}\n(D) {data_item['options'][3]}\n(E) {data_item['options'][4]}\n"
75 | question = f"{data_item['question']}\n{options}"
76 |
77 | prompts = []
78 | steps_lens = []
79 | for solution_split in data_item['solutions_splits']:
80 | solution = ''.join(solution_split) + ''
81 | prompt = f'Question: {question}\nProcess: {solution}'
82 | prompts.append(prompt)
83 | steps_lens.append(len(solution_split))
84 |
85 | return {
86 | 'pixel_values': pixel_values,
87 | 'prompts': prompts,
88 | 'steps_lens': steps_lens,
89 | 'data_item': data_item,
90 | }
91 |
92 |
93 | class InferenceSampler(torch.utils.data.sampler.Sampler):
94 |
95 | def __init__(self, size):
96 | self._size = int(size)
97 | assert size > 0
98 | self._rank = torch.distributed.get_rank()
99 | self._world_size = torch.distributed.get_world_size()
100 | self._local_indices = self._get_local_indices(
101 | size, self._world_size, self._rank
102 | )
103 |
104 | @staticmethod
105 | def _get_local_indices(total_size, world_size, rank):
106 | shard_size = total_size // world_size
107 | left = total_size % world_size
108 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
109 |
110 | begin = sum(shard_sizes[:rank])
111 | end = min(sum(shard_sizes[: rank + 1]), total_size)
112 | return range(begin, end)
113 |
114 | def __iter__(self):
115 | yield from self._local_indices
116 |
117 | def __len__(self):
118 | return len(self._local_indices)
119 |
120 |
121 | def evaluate_chat_model():
122 | random.seed(args.seed)
123 |
124 | for ds_name in args.datasets:
125 | dataset = MathVisionPRMDataset(
126 | root=ds_collections[ds_name]['root'],
127 | annotation=ds_collections[ds_name]['annotation'],
128 | input_size=image_size,
129 | dynamic_image_size=args.dynamic,
130 | use_thumbnail=use_thumbnail,
131 | max_num=args.max_num,
132 | )
133 | dataloader = torch.utils.data.DataLoader(
134 | dataset=dataset,
135 | sampler=InferenceSampler(len(dataset)),
136 | batch_size=1,
137 | num_workers=args.num_workers,
138 | pin_memory=True,
139 | drop_last=False,
140 | collate_fn=collate_fn,
141 | )
142 |
143 | outputs = []
144 | for idx, (pixel_values, prompts, steps_lens, data_item) in tqdm(
145 | enumerate(dataloader)
146 | ):
147 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
148 |
149 | prm_scores_flattened = []
150 | for i in range(0, len(prompts), args.mini_batch_size):
151 | curr_bs = min(args.mini_batch_size, len(prompts) - i)
152 | output = model.batch_prm(
153 | tokenizer=tokenizer,
154 | pixel_values=torch.cat([pixel_values] * curr_bs, dim=0),
155 | questions=prompts[i : i + curr_bs],
156 | num_patches_list=[pixel_values.shape[0]] * curr_bs,
157 | verbose=True,
158 | )
159 | prm_scores_flattened.extend(output.tolist())
160 |
161 | data_item['prm_scores'] = []
162 | curr_len = 0
163 | for i in range(len(steps_lens)):
164 | data_item['prm_scores'].append(
165 | prm_scores_flattened[curr_len : curr_len + steps_lens[i]]
166 | )
167 | curr_len += steps_lens[i]
168 |
169 | for i in range(len(data_item['prm_scores'])):
170 | assert len(data_item['prm_scores'][i]) == steps_lens[i]
171 |
172 | print(f'Pred: {data_item["prm_scores"]}')
173 | outputs.append(data_item)
174 |
175 | if idx % 50 == 0:
176 | torch.distributed.barrier()
177 |
178 | torch.distributed.barrier()
179 |
180 | world_size = torch.distributed.get_world_size()
181 | merged_outputs = [None for _ in range(world_size)]
182 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
183 |
184 | merged_outputs = [json.loads(_) for _ in merged_outputs]
185 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
186 |
187 | if torch.distributed.get_rank() == 0:
188 | print(f'Evaluating {ds_name} ...')
189 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
190 | results_file = f'{ds_name}_{time_prefix}.json'
191 | output_path = os.path.join(args.out_dir, results_file)
192 | json.dump(
193 | merged_outputs, open(output_path, 'w'), indent=4, ensure_ascii=False
194 | )
195 | print('Results saved to {}'.format(output_path))
196 |
197 | cmd = f'python eval/prm/extract_calculate.py --output_file {results_file} --output_dir {args.out_dir}'
198 | print(cmd)
199 | os.system(cmd)
200 |
201 |
202 | if __name__ == '__main__':
203 | parser = argparse.ArgumentParser()
204 | parser.add_argument('--checkpoint', type=str, default='')
205 | parser.add_argument('--datasets', type=str, default='')
206 | parser.add_argument('--mini-batch-size', type=int, default=4)
207 | parser.add_argument('--num-workers', type=int, default=2)
208 | parser.add_argument('--out-dir', type=str, default='results')
209 | parser.add_argument('--seed', type=int, default=0)
210 | parser.add_argument('--dynamic', action='store_true', default=True)
211 | parser.add_argument('--max-num', type=int, default=6)
212 | parser.add_argument('--load-in-8bit', action='store_true')
213 | parser.add_argument('--load-in-4bit', action='store_true')
214 | parser.add_argument('--auto', action='store_true')
215 | args = parser.parse_args()
216 |
217 | if not os.path.exists(args.out_dir):
218 | os.makedirs(args.out_dir, exist_ok=True)
219 |
220 | args.datasets = args.datasets.split(',')
221 | print('datasets:', args.datasets)
222 |
223 | torch.distributed.init_process_group(
224 | backend='nccl',
225 | world_size=int(os.getenv('WORLD_SIZE', '1')),
226 | rank=int(os.getenv('RANK', '0')),
227 | )
228 |
229 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
230 |
231 | model, tokenizer = load_model_and_tokenizer(args)
232 |
233 | image_size = model.config.force_image_size or model.config.vision_config.image_size
234 | use_thumbnail = model.config.use_thumbnail
235 |
236 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
237 | print(f'[test] total_params: {total_params}B')
238 | print(f'[test] image_size: {image_size}')
239 | print(f'[test] template: {model.config.template}')
240 | print(f'[test] dynamic_image_size: {args.dynamic}')
241 | print(f'[test] use_thumbnail: {use_thumbnail}')
242 |
243 | evaluate_chat_model()
244 |
--------------------------------------------------------------------------------
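Note: the evaluation loop above flattens per-step PRM scores across all sampled solutions and then regroups them with `steps_lens`. A minimal standalone sketch of that regrouping (the scores below are made-up illustrative values, not outputs of the model):

# Standalone illustration of regrouping flattened per-step PRM scores by steps_lens.
steps_lens = [2, 3, 1]                                   # steps per sampled solution
prm_scores_flattened = [0.9, 0.8, 0.7, 0.2, 0.6, 0.95]   # made-up per-step scores

prm_scores, curr_len = [], 0
for n in steps_lens:
    prm_scores.append(prm_scores_flattened[curr_len:curr_len + n])
    curr_len += n

print(prm_scores)  # [[0.9, 0.8], [0.7, 0.2, 0.6], [0.95]]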
/eval/prm/evaluate_mathvista_prm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 |
8 | import torch
9 | from PIL import Image
10 | from tqdm import tqdm
11 |
12 | from internvl.model import load_model_and_tokenizer
13 | from internvl.train.dataset import build_transform, dynamic_preprocess
14 |
15 | ds_collections = {
16 | 'mathvista_prm': {
17 | 'root': '/path/to/image/root',
18 | 'annotation': '/path/to/rollout/file',
19 | }
20 | }
21 |
22 |
23 | def collate_fn(batches):
24 | pixel_values = batches[0]['pixel_values']
25 | prompts = batches[0]['prompts']
26 | steps_lens = batches[0]['steps_lens']
27 | data_items = batches[0]['data_item']
28 | return pixel_values, prompts, steps_lens, data_items
29 |
30 |
31 | class MathVistaPRMDataset(torch.utils.data.Dataset):
32 |
33 | def __init__(
34 | self,
35 | root,
36 | annotation,
37 | input_size=224,
38 | dynamic_image_size=False,
39 | use_thumbnail=False,
40 | max_num=6,
41 | ):
42 | self.root = root
43 | self.data = json.load(open(annotation))
44 | self.input_size = input_size
45 | self.dynamic_image_size = dynamic_image_size
46 | self.use_thumbnail = use_thumbnail
47 | self.max_num = max_num
48 | self.transform = build_transform(is_train=False, input_size=input_size)
49 |
50 | def __len__(self):
51 | return len(self.data)
52 |
53 | def __getitem__(self, idx):
54 | data_item = self.data[idx]
55 | image = os.path.join(self.root, os.path.basename(data_item['image']))
56 |
57 | image = Image.open(image).convert('RGB')
58 | if self.dynamic_image_size:
59 | images = dynamic_preprocess(
60 | image,
61 | image_size=self.input_size,
62 | use_thumbnail=self.use_thumbnail,
63 | max_num=self.max_num,
64 | )
65 | else:
66 | images = [image]
67 | pixel_values = [self.transform(image) for image in images]
68 | pixel_values = torch.stack(pixel_values)
69 |
70 | question = data_item['query']
71 |
72 | prompts = []
73 | steps_lens = []
74 | for solution_split in data_item['solutions_splits']:
75 | solution = ''.join(solution_split) + ''
76 | prompt = f'Question: {question}\nProcess: {solution}'
77 | prompts.append(prompt)
78 | steps_lens.append(len(solution_split))
79 |
80 | return {
81 | 'pixel_values': pixel_values,
82 | 'prompts': prompts,
83 | 'steps_lens': steps_lens,
84 | 'data_item': data_item,
85 | }
86 |
87 |
88 | class InferenceSampler(torch.utils.data.sampler.Sampler):
89 |
90 | def __init__(self, size):
91 | self._size = int(size)
92 | assert size > 0
93 | self._rank = torch.distributed.get_rank()
94 | self._world_size = torch.distributed.get_world_size()
95 | self._local_indices = self._get_local_indices(
96 | size, self._world_size, self._rank
97 | )
98 |
99 | @staticmethod
100 | def _get_local_indices(total_size, world_size, rank):
101 | shard_size = total_size // world_size
102 | left = total_size % world_size
103 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
104 |
105 | begin = sum(shard_sizes[:rank])
106 | end = min(sum(shard_sizes[: rank + 1]), total_size)
107 | return range(begin, end)
108 |
109 | def __iter__(self):
110 | yield from self._local_indices
111 |
112 | def __len__(self):
113 | return len(self._local_indices)
114 |
115 |
116 | def evaluate_chat_model():
117 | random.seed(args.seed)
118 |
119 | for ds_name in args.datasets:
120 | dataset = MathVistaPRMDataset(
121 | root=ds_collections[ds_name]['root'],
122 | annotation=ds_collections[ds_name]['annotation'],
123 | input_size=image_size,
124 | dynamic_image_size=args.dynamic,
125 | use_thumbnail=use_thumbnail,
126 | max_num=args.max_num,
127 | )
128 | dataloader = torch.utils.data.DataLoader(
129 | dataset=dataset,
130 | sampler=InferenceSampler(len(dataset)),
131 | batch_size=1,
132 | num_workers=args.num_workers,
133 | pin_memory=True,
134 | drop_last=False,
135 | collate_fn=collate_fn,
136 | )
137 |
138 | outputs = []
139 | for idx, (pixel_values, prompts, steps_lens, data_item) in tqdm(
140 | enumerate(dataloader)
141 | ):
142 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
143 |
144 | prm_scores_flattened = []
145 | for i in range(0, len(prompts), args.mini_batch_size):
146 | curr_bs = min(args.mini_batch_size, len(prompts) - i)
147 | output = model.batch_prm(
148 | tokenizer=tokenizer,
149 | pixel_values=torch.cat([pixel_values] * curr_bs, dim=0),
150 | questions=prompts[i : i + curr_bs],
151 | num_patches_list=[pixel_values.shape[0]] * curr_bs,
152 | verbose=True,
153 | )
154 | prm_scores_flattened.extend(output.tolist())
155 |
156 | data_item['prm_scores'] = []
157 | curr_len = 0
158 | for i in range(len(steps_lens)):
159 | data_item['prm_scores'].append(
160 | prm_scores_flattened[curr_len : curr_len + steps_lens[i]]
161 | )
162 | curr_len += steps_lens[i]
163 |
164 | for i in range(len(data_item['prm_scores'])):
165 | assert len(data_item['prm_scores'][i]) == steps_lens[i]
166 |
167 | print(f'Pred: {data_item["prm_scores"]}')
168 | outputs.append(data_item)
169 |
170 | if idx % 50 == 0:
171 | torch.distributed.barrier()
172 |
173 | torch.distributed.barrier()
174 |
175 | world_size = torch.distributed.get_world_size()
176 | merged_outputs = [None for _ in range(world_size)]
177 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
178 |
179 | merged_outputs = [json.loads(_) for _ in merged_outputs]
180 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
181 |
182 | if torch.distributed.get_rank() == 0:
183 | print(f'Evaluating {ds_name} ...')
184 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
185 | results_file = f'{ds_name}_{time_prefix}.json'
186 | output_path = os.path.join(args.out_dir, results_file)
187 | json.dump(
188 | merged_outputs, open(output_path, 'w'), indent=4, ensure_ascii=False
189 | )
190 | print('Results saved to {}'.format(output_path))
191 |
192 | cmd = f'python eval/prm/extract_calculate.py --output_file {results_file} --output_dir {args.out_dir}'
193 | print(cmd)
194 | os.system(cmd)
195 |
196 |
197 | if __name__ == '__main__':
198 | parser = argparse.ArgumentParser()
199 | parser.add_argument('--checkpoint', type=str, default='')
200 | parser.add_argument('--datasets', type=str, default='')
201 | parser.add_argument('--mini-batch-size', type=int, default=4)
202 | parser.add_argument('--num-workers', type=int, default=2)
203 | parser.add_argument('--out-dir', type=str, default='results')
204 | parser.add_argument('--seed', type=int, default=0)
205 | parser.add_argument('--dynamic', action='store_true', default=True)
206 | parser.add_argument('--max-num', type=int, default=6)
207 | parser.add_argument('--load-in-8bit', action='store_true')
208 | parser.add_argument('--load-in-4bit', action='store_true')
209 | parser.add_argument('--auto', action='store_true')
210 | args = parser.parse_args()
211 |
212 | if not os.path.exists(args.out_dir):
213 | os.makedirs(args.out_dir, exist_ok=True)
214 |
215 | args.datasets = args.datasets.split(',')
216 | print('datasets:', args.datasets)
217 |
218 | torch.distributed.init_process_group(
219 | backend='nccl',
220 | world_size=int(os.getenv('WORLD_SIZE', '1')),
221 | rank=int(os.getenv('RANK', '0')),
222 | )
223 |
224 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
225 |
226 | model, tokenizer = load_model_and_tokenizer(args)
227 |
228 | image_size = model.config.force_image_size or model.config.vision_config.image_size
229 | use_thumbnail = model.config.use_thumbnail
230 |
231 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
232 | print(f'[test] total_params: {total_params}B')
233 | print(f'[test] image_size: {image_size}')
234 | print(f'[test] template: {model.config.template}')
235 | print(f'[test] dynamic_image_size: {args.dynamic}')
236 | print(f'[test] use_thumbnail: {use_thumbnail}')
237 |
238 | evaluate_chat_model()
239 |
--------------------------------------------------------------------------------
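Note: `InferenceSampler._get_local_indices` gives the first `total_size % world_size` ranks one extra sample, so every index is evaluated exactly once across ranks. A standalone sketch of the same arithmetic (no distributed setup required):

def get_local_indices(total_size, world_size, rank):
    shard_size = total_size // world_size
    left = total_size % world_size
    shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
    begin = sum(shard_sizes[:rank])
    end = min(sum(shard_sizes[:rank + 1]), total_size)
    return range(begin, end)

# 10 samples on 4 ranks -> shard sizes [3, 3, 2, 2]
for rank in range(4):
    print(rank, list(get_local_indices(10, 4, rank)))
# 0 [0, 1, 2]
# 1 [3, 4, 5]
# 2 [6, 7]
# 3 [8, 9]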
/eval/prm/evaluate_olympiadbench_prm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 |
8 | import torch
9 | from PIL import Image
10 | from tqdm import tqdm
11 |
12 | from internvl.model import load_model_and_tokenizer
13 | from internvl.train.dataset import build_transform, dynamic_preprocess
14 |
15 | ds_collections = {
16 | 'olympiadbench_prm': {
17 | 'root': '/path/to/image/root',
18 | 'annotation': '/path/to/rollout/file',
19 | }
20 | }
21 |
22 |
23 | def collate_fn(batches):
24 | pixel_values = batches[0]['pixel_values']
25 | prompts = batches[0]['prompts']
26 | steps_lens = batches[0]['steps_lens']
27 | num_image_patch = batches[0]['num_image_patch']
28 | data_items = batches[0]['data_item']
29 | return pixel_values, prompts, steps_lens, num_image_patch, data_items
30 |
31 |
32 | class OlympiadBenchPRMDataset(torch.utils.data.Dataset):
33 |
34 | def __init__(
35 | self,
36 | root,
37 | annotation,
38 | input_size=224,
39 | dynamic_image_size=False,
40 | use_thumbnail=False,
41 | max_num=6,
42 | ):
43 | self.root = root
44 | self.data = json.load(open(annotation))
45 | self.input_size = input_size
46 | self.dynamic_image_size = dynamic_image_size
47 | self.use_thumbnail = use_thumbnail
48 | self.max_num = max_num
49 | self.transform = build_transform(is_train=False, input_size=input_size)
50 |
51 | def __len__(self):
52 | return len(self.data)
53 |
54 | def __getitem__(self, idx):
55 | data_item = self.data[idx]
56 |
57 | images, num_tiles = [], []
58 | for i in range(1, 6):
59 | key = f'image_{i}'
60 | if data_item[key] is None:
61 | continue
62 |
63 | image = Image.open(os.path.join(self.root, data_item[key])).convert('RGB')
64 |
65 | if self.dynamic_image_size:
66 | image = dynamic_preprocess(
67 | image,
68 | image_size=self.input_size,
69 | use_thumbnail=self.use_thumbnail,
70 | max_num=self.max_num,
71 | )
72 | images += image
73 | num_tiles.append(len(image))
74 | else:
75 | images.append(image)
76 | num_tiles.append(1)
77 | pixel_values = [self.transform(image) for image in images]
78 | pixel_values = torch.stack(pixel_values)
79 |
80 | question = data_item['question']
81 |
82 | prompts = []
83 | steps_lens = []
84 | for solution_split in data_item['solutions_splits']:
85 | solution = ''.join(solution_split) + ''
86 | prompt = f'Question: {question}\nProcess: {solution}'
87 | prompts.append(prompt)
88 | steps_lens.append(len(solution_split))
89 |
90 | return {
91 | 'pixel_values': pixel_values,
92 | 'num_image_patch': num_tiles,
93 | 'prompts': prompts,
94 | 'steps_lens': steps_lens,
95 | 'data_item': data_item,
96 | }
97 |
98 |
99 | class InferenceSampler(torch.utils.data.sampler.Sampler):
100 |
101 | def __init__(self, size):
102 | self._size = int(size)
103 | assert size > 0
104 | self._rank = torch.distributed.get_rank()
105 | self._world_size = torch.distributed.get_world_size()
106 | self._local_indices = self._get_local_indices(
107 | size, self._world_size, self._rank
108 | )
109 |
110 | @staticmethod
111 | def _get_local_indices(total_size, world_size, rank):
112 | shard_size = total_size // world_size
113 | left = total_size % world_size
114 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
115 |
116 | begin = sum(shard_sizes[:rank])
117 | end = min(sum(shard_sizes[: rank + 1]), total_size)
118 | return range(begin, end)
119 |
120 | def __iter__(self):
121 | yield from self._local_indices
122 |
123 | def __len__(self):
124 | return len(self._local_indices)
125 |
126 |
127 | def evaluate_chat_model():
128 | random.seed(args.seed)
129 |
130 | for ds_name in args.datasets:
131 | dataset = OlympiadBenchPRMDataset(
132 | root=ds_collections[ds_name]['root'],
133 | annotation=ds_collections[ds_name]['annotation'],
134 | input_size=image_size,
135 | dynamic_image_size=args.dynamic,
136 | use_thumbnail=use_thumbnail,
137 | max_num=args.max_num,
138 | )
139 | dataloader = torch.utils.data.DataLoader(
140 | dataset=dataset,
141 | sampler=InferenceSampler(len(dataset)),
142 | batch_size=1,
143 | num_workers=args.num_workers,
144 | pin_memory=True,
145 | drop_last=False,
146 | collate_fn=collate_fn,
147 | )
148 |
149 | outputs = []
150 | for idx, (
151 | pixel_values,
152 | prompts,
153 | steps_lens,
154 | num_image_patch,
155 | data_item,
156 | ) in tqdm(enumerate(dataloader)):
157 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
158 |
159 | prm_scores_flattened = []
160 | for i in range(0, len(prompts), args.mini_batch_size):
161 | curr_bs = min(args.mini_batch_size, len(prompts) - i)
162 | output = model.prm(
163 | tokenizer=tokenizer,
164 | pixel_values=pixel_values,
165 | question=prompts[i : i + curr_bs][0],
166 | num_patches_list=num_image_patch,
167 | verbose=True,
168 | )
169 | prm_scores_flattened.extend(output.tolist())
170 |
171 | data_item['prm_scores'] = []
172 | curr_len = 0
173 | for i in range(len(steps_lens)):
174 | data_item['prm_scores'].append(
175 | prm_scores_flattened[curr_len : curr_len + steps_lens[i]]
176 | )
177 | curr_len += steps_lens[i]
178 |
179 | for i in range(len(data_item['prm_scores'])):
180 | assert len(data_item['prm_scores'][i]) == steps_lens[i]
181 |
182 | print(f'Pred: {data_item["prm_scores"]}')
183 | outputs.append(data_item)
184 |
185 | if idx % 50 == 0:
186 | torch.distributed.barrier()
187 |
188 | torch.distributed.barrier()
189 |
190 | world_size = torch.distributed.get_world_size()
191 | merged_outputs = [None for _ in range(world_size)]
192 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
193 |
194 | merged_outputs = [json.loads(_) for _ in merged_outputs]
195 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
196 |
197 | if torch.distributed.get_rank() == 0:
198 | print(f'Evaluating {ds_name} ...')
199 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
200 | results_file = f'{ds_name}_{time_prefix}.json'
201 | output_path = os.path.join(args.out_dir, results_file)
202 | json.dump(
203 | merged_outputs, open(output_path, 'w'), indent=4, ensure_ascii=False
204 | )
205 | print('Results saved to {}'.format(output_path))
206 |
207 | cmd = f'python eval/prm/extract_calculate.py --output_file {results_file} --output_dir {args.out_dir}'
208 | print(cmd)
209 | os.system(cmd)
210 |
211 |
212 | if __name__ == '__main__':
213 | parser = argparse.ArgumentParser()
214 | parser.add_argument('--checkpoint', type=str, default='')
215 | parser.add_argument('--datasets', type=str, default='')
216 | parser.add_argument('--mini-batch-size', type=int, default=1)
217 | parser.add_argument('--num-workers', type=int, default=2)
218 | parser.add_argument('--out-dir', type=str, default='results')
219 | parser.add_argument('--seed', type=int, default=0)
220 | parser.add_argument('--dynamic', action='store_true', default=True)
221 | parser.add_argument('--max-num', type=int, default=6)
222 | parser.add_argument('--load-in-8bit', action='store_true')
223 | parser.add_argument('--load-in-4bit', action='store_true')
224 | parser.add_argument('--auto', action='store_true')
225 | args = parser.parse_args()
226 |
227 | if not os.path.exists(args.out_dir):
228 | os.makedirs(args.out_dir, exist_ok=True)
229 |
230 | args.datasets = args.datasets.split(',')
231 | print('datasets:', args.datasets)
232 |
233 | torch.distributed.init_process_group(
234 | backend='nccl',
235 | world_size=int(os.getenv('WORLD_SIZE', '1')),
236 | rank=int(os.getenv('RANK', '0')),
237 | )
238 |
239 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
240 |
241 | model, tokenizer = load_model_and_tokenizer(args)
242 |
243 | image_size = model.config.force_image_size or model.config.vision_config.image_size
244 | use_thumbnail = model.config.use_thumbnail
245 |
246 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
247 | print(f'[test] total_params: {total_params}B')
248 | print(f'[test] image_size: {image_size}')
249 | print(f'[test] template: {model.config.template}')
250 | print(f'[test] dynamic_image_size: {args.dynamic}')
251 | print(f'[test] use_thumbnail: {use_thumbnail}')
252 |
253 | evaluate_chat_model()
254 |
--------------------------------------------------------------------------------
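Note: unlike the single-image datasets, `OlympiadBenchPRMDataset` stacks the tiles of up to five images into one `pixel_values` tensor and records the per-image tile counts as `num_image_patch`, which the script passes to the model as `num_patches_list`. A standalone sketch of that bookkeeping (the tile counts and the 448x448 tile size are illustrative assumptions, not values taken from this repo):

import torch

tiles_per_image = [7, 3]          # e.g. image_1 -> 7 tiles, image_2 -> 3 tiles (illustrative)
tile = torch.zeros(3, 448, 448)   # one preprocessed tile (C, H, W); 448 is an assumed size

pixel_values = torch.stack([tile] * sum(tiles_per_image))
num_patches_list = tiles_per_image    # returned by the dataset as 'num_image_patch'

print(pixel_values.shape)   # torch.Size([10, 3, 448, 448])
print(num_patches_list)     # [7, 3]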
/eval/prm/extract_calculate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import math
4 | import os
5 | import random
6 | from functools import reduce
7 |
8 |
9 | def calculate_accuracy(data):
10 | results = {}
11 |
12 | cnts = []
13 | for i in range(5000):
14 | cnt = 0
15 | for item in data:
16 | if random.choice(item['labels']) == 1:
17 | cnt += 1
18 | cnts.append(cnt)
19 |
20 | cnt = sum(cnts) / len(cnts)
21 | print(f'random {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
22 | results['random'] = {
23 | 'correct': cnt,
24 | 'total': len(data),
25 | 'accuracy': cnt / len(data),
26 | }
27 |
28 | cnt = 0
29 | for item in data:
30 | labels = random.sample(item['labels'], min(16, len(item['labels'])))
31 | if sum(labels) >= 1:
32 | cnt += 1
33 | print(f'pass@16 {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
34 | results['pass@16'] = {
35 | 'correct': cnt,
36 | 'total': len(data),
37 | 'accuracy': cnt / len(data),
38 | }
39 |
40 | cnt = 0
41 | for item in data:
42 | labels = random.sample(item['labels'], min(8, len(item['labels'])))
43 | if sum(labels) >= 1:
44 | cnt += 1
45 | print(f'pass@8 {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
46 | results['pass@8'] = {
47 | 'correct': cnt,
48 | 'total': len(data),
49 | 'accuracy': cnt / len(data),
50 | }
51 |
52 | cnt = 0
53 | for item in data:
54 | labels = random.sample(item['labels'], min(4, len(item['labels'])))
55 | if sum(labels) >= 1:
56 | cnt += 1
57 | print(f'pass@4 {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
58 | results['pass@4'] = {
59 | 'correct': cnt,
60 | 'total': len(data),
61 | 'accuracy': cnt / len(data),
62 | }
63 |
64 | cnt = 0
65 | for item in data:
66 | labels = random.sample(item['labels'], min(2, len(item['labels'])))
67 | if sum(labels) >= 1:
68 | cnt += 1
69 | print(f'pass@2 {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
70 | results['pass@2'] = {
71 | 'correct': cnt,
72 | 'total': len(data),
73 | 'accuracy': cnt / len(data),
74 | }
75 |
76 | cnt = 0
77 | for item in data:
78 | prm_score = list(map(lambda x: min(x) if x else 0, item['prm_scores']))
79 | if item['labels'][prm_score.index(max(prm_score))] == 1:
80 | cnt += 1
81 | print(f'prm accuracy min {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
82 | results['min'] = {'correct': cnt, 'total': len(data), 'accuracy': cnt / len(data)}
83 |
84 | cnt = 0
85 | for item in data:
86 | prm_score = list(map(lambda x: x[-1] if x else 0, item['prm_scores']))
87 | if item['labels'][prm_score.index(max(prm_score))] == 1:
88 | cnt += 1
89 | print(f'prm accuracy last {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
90 | results['last'] = {'correct': cnt, 'total': len(data), 'accuracy': cnt / len(data)}
91 |
92 | cnt = 0
93 | for item in data:
94 | prm_score = list(
95 | map(lambda x: reduce(lambda a, b: a * b, x) if x else 0, item['prm_scores'])
96 | )
97 | if item['labels'][prm_score.index(max(prm_score))] == 1:
98 | cnt += 1
99 | print(f'prm accuracy product {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
100 | results['product'] = {
101 | 'correct': cnt,
102 | 'total': len(data),
103 | 'accuracy': cnt / len(data),
104 | }
105 |
106 | cnt = 0
107 | for item in data:
108 | prm_score = list(map(lambda x: sum(x) / len(x) if x else 0, item['prm_scores']))
109 | if item['labels'][prm_score.index(max(prm_score))] == 1:
110 | cnt += 1
111 | print(f'prm accuracy average {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
112 | results['average'] = {
113 | 'correct': cnt,
114 | 'total': len(data),
115 | 'accuracy': cnt / len(data),
116 | }
117 |
118 | cnt = 0
119 | for item in data:
120 | prm_score = list(
121 | map(
122 | lambda x: (
123 | sum(list(map(lambda a: math.log(a) if a != 0 else -9999, x)))
124 | if x
125 | else 0
126 | ),
127 | item['prm_scores'],
128 | )
129 | )
130 | if item['labels'][prm_score.index(max(prm_score))] == 1:
131 | cnt += 1
132 | print(
133 | f'prm accuracy sum_logprob {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%'
134 | )
135 | results['sum_logprob'] = {
136 | 'correct': cnt,
137 | 'total': len(data),
138 | 'accuracy': cnt / len(data),
139 | }
140 |
141 | cnt = 0
142 | for item in data:
143 | prm_score = list(map(lambda x: max(x) if x else 0, item['prm_scores']))
144 | if item['labels'][prm_score.index(max(prm_score))] == 1:
145 | cnt += 1
146 | print(f'prm accuracy max {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
147 | results['max'] = {'correct': cnt, 'total': len(data), 'accuracy': cnt / len(data)}
148 |
149 | cnt = 0
150 | for item in data:
151 | try:
152 | prm_score = list(
153 | map(
154 | lambda x: (
155 | sum(
156 | list(
157 | map(
158 | lambda a: math.log(a / (1 - a)) if a != 1 else 9999,
159 | x,
160 | )
161 | )
162 | )
163 | if x
164 | else 0
165 | ),
166 | item['prm_scores'],
167 | )
168 | )
169 | except Exception as e:
170 | print(e)
171 | print(item['prm_scores'])
172 | raise e
173 | if item['labels'][prm_score.index(max(prm_score))] == 1:
174 | cnt += 1
175 | print(f'prm accuracy sum_logit {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
176 | results['sum_logit'] = {
177 | 'correct': cnt,
178 | 'total': len(data),
179 | 'accuracy': cnt / len(data),
180 | }
181 |
182 | cnt = 0
183 | for item in data:
184 | prm_score = list(
185 | map(
186 | lambda x: (
187 | sum(list(map(lambda a: (a / (1 - a)) if a != 1 else 9999, x)))
188 | / len(x)
189 | if x
190 | else 0
191 | ),
192 | item['prm_scores'],
193 | )
194 | )
195 | if item['labels'][prm_score.index(max(prm_score))] == 1:
196 | cnt += 1
197 | print(f'prm accuracy mean_odd {cnt} / {len(data)}, rate: {cnt / len(data) * 100}%')
198 | results['mean_odd'] = {
199 | 'correct': cnt,
200 | 'total': len(data),
201 | 'accuracy': cnt / len(data),
202 | }
203 |
204 | return results
205 |
206 |
207 | if __name__ == '__main__':
208 | parser = argparse.ArgumentParser()
209 | parser.add_argument('--output_dir', type=str, default='./results')
210 | parser.add_argument('--output_file', type=str, default='')
211 | args = parser.parse_args()
212 |
213 | result_file = os.path.join(args.output_dir, args.output_file)
214 |
215 | print(f'Reading {result_file}...')
216 | results = calculate_accuracy(json.load(open(result_file)))
217 |
218 | print(f"Saving results to {result_file.replace('.json', f'_score.json')}...")
219 | json.dump(
220 | results,
221 | open(result_file.replace('.json', f'_score.json'), 'w'),
222 | indent=4,
223 | ensure_ascii=False,
224 | )
225 | print(f'Results saved.')
226 |
--------------------------------------------------------------------------------
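Note: `calculate_accuracy` scores Best-of-N selection by collapsing each solution's per-step scores with different aggregators (min, last, product, average, sum_logprob, ...) and checking whether the argmax solution is labelled correct. A standalone worked example with made-up scores and labels:

from functools import reduce

item = {
    'prm_scores': [[0.9, 0.8], [0.6, 0.95, 0.99], [0.4]],  # per-step scores of 3 solutions
    'labels': [1, 0, 0],                                    # only solution 0 is correct
}

aggregators = {
    'min': lambda x: min(x) if x else 0,
    'last': lambda x: x[-1] if x else 0,
    'product': lambda x: reduce(lambda a, b: a * b, x) if x else 0,
    'average': lambda x: sum(x) / len(x) if x else 0,
}

for name, fn in aggregators.items():
    scores = [fn(x) for x in item['prm_scores']]
    best = scores.index(max(scores))
    print(name, scores, '-> picks solution', best,
          '(correct)' if item['labels'][best] == 1 else '(wrong)')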
/evaluate.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | CHECKPOINT=${1}
4 | DATASET=${2}
5 | CHECKPOINT="$(pwd)/${CHECKPOINT}"
6 | export PYTHONPATH="$(pwd):${PYTHONPATH}"
7 | echo "CHECKPOINT: ${CHECKPOINT}"
8 |
9 | MASTER_PORT=${MASTER_PORT:-63669}
10 | PORT=${PORT:-63665}
11 | GPUS=${GPUS:-8}
12 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
13 | NODES=$((GPUS / GPUS_PER_NODE))
14 | export MASTER_PORT=${MASTER_PORT}
15 | export PORT=${PORT}
16 |
17 | # Save original arguments
18 | ARGS=("$@")
19 |
20 | # Parse options
21 | while [[ $# -gt 0 ]]; do
22 | case "$1" in
23 | --auto)
24 | GPUS=1
25 | shift
26 | ;;
27 | *)
28 | shift
29 | ;;
30 | esac
31 | done
32 | echo "GPUS: ${GPUS}"
33 |
34 | if [[ "${DATASET}" == *"k12"* ]]; then
35 | torchrun \
36 | --nnodes=1 \
37 | --node_rank=0 \
38 | --master_addr=127.0.0.1 \
39 | --nproc_per_node=${GPUS} \
40 | --master_port=${MASTER_PORT} \
41 | eval/prm/evaluate_k12_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} "${ARGS[@]:2}"
42 | fi
43 |
44 | if [[ "${DATASET}" == *"mathvista"* ]]; then
45 | torchrun \
46 | --nnodes=1 \
47 | --node_rank=0 \
48 | --master_addr=127.0.0.1 \
49 | --nproc_per_node=${GPUS} \
50 | --master_port=${MASTER_PORT} \
51 | eval/prm/evaluate_mathvista_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} "${ARGS[@]:2}"
52 | fi
53 |
54 | if [[ "${DATASET}" == *"mathverse"* ]]; then
55 | torchrun \
56 | --nnodes=1 \
57 | --node_rank=0 \
58 | --master_addr=127.0.0.1 \
59 | --nproc_per_node=${GPUS} \
60 | --master_port=${MASTER_PORT} \
61 | eval/prm/evaluate_mathverse_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} "${ARGS[@]:2}"
62 | fi
63 |
64 | if [[ "${DATASET}" == *"mathvision"* ]]; then
65 | torchrun \
66 | --nnodes=1 \
67 | --node_rank=0 \
68 | --master_addr=127.0.0.1 \
69 | --nproc_per_node=${GPUS} \
70 | --master_port=${MASTER_PORT} \
71 | eval/prm/evaluate_mathvision_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} "${ARGS[@]:2}"
72 | fi
73 |
74 | if [[ "${DATASET}" == *"olympiadbench"* ]]; then
75 | torchrun \
76 | --nnodes=1 \
77 | --node_rank=0 \
78 | --master_addr=127.0.0.1 \
79 | --nproc_per_node=${GPUS} \
80 | --master_port=${MASTER_PORT} \
81 | eval/prm/evaluate_olympiadbench_prm.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} "${ARGS[@]:2}"
82 | fi
83 |
--------------------------------------------------------------------------------

/internvl/dist_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import socket
3 | import subprocess
4 | from datetime import timedelta
5 |
6 | import deepspeed
7 | import torch
8 | import torch.multiprocessing as mp
9 | from torch import distributed as dist
10 |
11 | timeout = timedelta(minutes=60)
12 |
13 |
14 | def _find_free_port():
15 | # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
16 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
17 | # Binding to port 0 will cause the OS to find an available port for us
18 | sock.bind(('', 0))
19 | port = sock.getsockname()[1]
20 | sock.close()
21 | # NOTE: there is still a chance the port could be taken by other processes.
22 | return port
23 |
24 |
25 | def _is_free_port(port):
26 | ips = socket.gethostbyname_ex(socket.gethostname())[-1]
27 | ips.append('localhost')
28 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
29 | return all(s.connect_ex((ip, port)) != 0 for ip in ips)
30 |
31 |
32 | def init_dist(launcher, backend='nccl', **kwargs):
33 | if mp.get_start_method(allow_none=True) is None:
34 | mp.set_start_method('spawn')
35 | if launcher == 'pytorch':
36 | _init_dist_pytorch(backend, **kwargs)
37 | elif launcher == 'mpi':
38 | _init_dist_mpi(backend, **kwargs)
39 | elif launcher == 'slurm':
40 | _init_dist_slurm(backend, **kwargs)
41 | else:
42 | raise ValueError(f'Invalid launcher type: {launcher}')
43 |
44 |
45 | def _init_dist_pytorch(backend, **kwargs):
46 | # TODO: use local_rank instead of rank % num_gpus
47 | rank = int(os.environ['RANK'])
48 | num_gpus = torch.cuda.device_count()
49 | torch.cuda.set_device(rank % num_gpus)
50 | # dist.init_process_group(backend=backend, **kwargs)
51 | deepspeed.init_distributed(dist_backend=backend)
52 |
53 |
54 | def _init_dist_mpi(backend, **kwargs):
55 | local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
56 | torch.cuda.set_device(local_rank)
57 | if 'MASTER_PORT' not in os.environ:
58 | # 29500 is torch.distributed default port
59 | os.environ['MASTER_PORT'] = '29500'
60 | if 'MASTER_ADDR' not in os.environ:
61 | raise KeyError('The environment variable MASTER_ADDR is not set')
62 | os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
63 | os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
64 | dist.init_process_group(backend=backend, **kwargs)
65 |
66 |
67 | def _init_dist_slurm(backend, port=None):
68 | """Initialize slurm distributed training environment.
69 |
70 | If argument ``port`` is not specified, then the master port will be system
71 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
72 | environment variable, then a default port ``29500`` will be used.
73 |
74 | Args:
75 | backend (str): Backend of torch.distributed.
76 | port (int, optional): Master port. Defaults to None.
77 | """
78 | proc_id = int(os.environ['SLURM_PROCID'])
79 | ntasks = int(os.environ['SLURM_NTASKS'])
80 | node_list = os.environ['SLURM_NODELIST']
81 | num_gpus = torch.cuda.device_count()
82 | torch.cuda.set_device(proc_id % num_gpus)
83 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
84 | # specify master port
85 | if port is not None:
86 | os.environ['MASTER_PORT'] = str(port)
87 | elif 'MASTER_PORT' in os.environ:
88 | pass # use MASTER_PORT in the environment variable
89 | else:
90 | # if torch.distributed default port(29500) is available
91 | # then use it, else find a free port
92 | if _is_free_port(29500):
93 | os.environ['MASTER_PORT'] = '29500'
94 | else:
95 | os.environ['MASTER_PORT'] = str(_find_free_port())
96 | # use MASTER_ADDR in the environment variable if it already exists
97 | if 'MASTER_ADDR' not in os.environ:
98 | os.environ['MASTER_ADDR'] = addr
99 | os.environ['WORLD_SIZE'] = str(ntasks)
100 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
101 | os.environ['RANK'] = str(proc_id)
102 | # dist.init_process_group(backend=backend, timeout=timeout)
103 | deepspeed.init_distributed(dist_backend=backend)
104 |
--------------------------------------------------------------------------------
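Note: as the `_init_dist_slurm` docstring describes, the master port is resolved with a fixed precedence: an explicit `port` argument, then the `MASTER_PORT` environment variable, then 29500 if it is free, otherwise a freshly discovered free port. A minimal sketch of that precedence (the `find_free_port` default here is a placeholder, not the repo's `_find_free_port`):

import os

def choose_master_port(port=None, default_port_is_free=True, find_free_port=lambda: 54321):
    if port is not None:                 # explicit argument wins
        return str(port)
    if 'MASTER_PORT' in os.environ:      # otherwise respect the environment
        return os.environ['MASTER_PORT']
    if default_port_is_free:             # otherwise prefer torch's default 29500
        return '29500'
    return str(find_free_port())         # last resort: any free port

print(choose_master_port(port=12345))  # '12345'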
/internvl/model/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import math
8 |
9 | import torch
10 | from transformers import AutoTokenizer
11 |
12 | from internvl.model.internvl_chat import InternVLChatConfig, InternVLChatModel
13 |
14 |
15 | def split_model(num_layers, vit_alpha=0.5):
16 | device_map = {}
17 | world_size = torch.cuda.device_count()
18 | # Since the first GPU will be used for ViT, treat it as half a GPU.
19 | num_layers_per_gpu = math.ceil(num_layers / (world_size - vit_alpha))
20 | num_layers_per_gpu = [num_layers_per_gpu] * world_size
21 | num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * (1 - vit_alpha))
22 | layer_cnt = 0
23 | for i, num_layer in enumerate(num_layers_per_gpu):
24 | for j in range(num_layer):
25 | device_map[f'language_model.model.layers.{layer_cnt}'] = i
26 | layer_cnt += 1
27 | device_map['vision_model'] = 0
28 | device_map['mlp1'] = 0
29 | device_map['language_model.model.tok_embeddings'] = 0
30 | device_map['language_model.model.embed_tokens'] = 0
31 | device_map['language_model.output'] = 0
32 | device_map['language_model.model.norm'] = 0
33 | device_map['language_model.lm_head'] = 0
34 | device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
35 | device_map['language_model.model.rotary_emb'] = 0
36 |
37 | return device_map
38 |
39 |
40 | def load_model_and_tokenizer(args):
41 | if args.auto:
42 | config = InternVLChatConfig.from_pretrained(args.checkpoint)
43 | num_hidden_layers = config.llm_config.num_hidden_layers
44 | device_map = split_model(num_hidden_layers)
45 | kwargs = {'device_map': device_map} if args.auto else {}
46 | tokenizer = AutoTokenizer.from_pretrained(
47 | args.checkpoint, trust_remote_code=True, use_fast=False
48 | )
49 | model = InternVLChatModel.from_pretrained(
50 | args.checkpoint,
51 | low_cpu_mem_usage=True,
52 | torch_dtype=torch.bfloat16,
53 | load_in_8bit=args.load_in_8bit,
54 | load_in_4bit=args.load_in_4bit,
55 | **kwargs,
56 | ).eval()
57 | if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
58 | model = model.cuda()
59 | return model, tokenizer
60 |
--------------------------------------------------------------------------------
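Note: `split_model` balances LLM layers across GPUs while treating GPU 0 as roughly half a GPU, since it also hosts the vision encoder and embeddings. A standalone sketch of just the layer-count arithmetic (48 layers on 4 GPUs is an illustrative example, not a claim about any particular checkpoint):

import math

def layers_per_gpu(num_layers, world_size, vit_alpha=0.5):
    per_gpu = math.ceil(num_layers / (world_size - vit_alpha))
    counts = [per_gpu] * world_size
    counts[0] = math.ceil(counts[0] * (1 - vit_alpha))  # GPU 0 also holds the ViT
    return counts

print(layers_per_gpu(48, 4))  # [7, 14, 14, 14]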
/internvl/model/internlm2/configuration_internlm2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # This code is based on transformers/src/transformers/models/llama/configuration_llama.py
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """InternLM2 model configuration"""
17 |
18 | from transformers.configuration_utils import PretrainedConfig
19 | from transformers.utils import logging
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | INTERNLM2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24 |
25 |
26 | # Modified from transformers.model.llama.configuration_llama.LlamaConfig
27 | class InternLM2Config(PretrainedConfig):
28 | r"""
29 | This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
30 | an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
31 | configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
32 |
33 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
34 | documentation from [`PretrainedConfig`] for more information.
35 |
36 |
37 | Args:
38 | vocab_size (`int`, *optional*, defaults to 32000):
39 | Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
40 | `inputs_ids` passed when calling [`InternLM2Model`]
41 | hidden_size (`int`, *optional*, defaults to 4096):
42 | Dimension of the hidden representations.
43 | intermediate_size (`int`, *optional*, defaults to 11008):
44 | Dimension of the MLP representations.
45 | num_hidden_layers (`int`, *optional*, defaults to 32):
46 | Number of hidden layers in the Transformer encoder.
47 | num_attention_heads (`int`, *optional*, defaults to 32):
48 | Number of attention heads for each attention layer in the Transformer encoder.
49 | num_key_value_heads (`int`, *optional*):
50 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
51 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
52 |             `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
53 |             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
54 |             by meanpooling all the original heads within that group. For more details, check out [this
55 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
56 | `num_attention_heads`.
57 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
58 | The non-linear activation function (function or string) in the decoder.
59 | max_position_embeddings (`int`, *optional*, defaults to 2048):
60 | The maximum sequence length that this model might ever be used with. Typically set this to something large
61 | just in case (e.g., 512 or 1024 or 2048).
62 | initializer_range (`float`, *optional*, defaults to 0.02):
63 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
64 |         rms_norm_eps (`float`, *optional*, defaults to 1e-6):
65 | The epsilon used by the rms normalization layers.
66 | use_cache (`bool`, *optional*, defaults to `True`):
67 | Whether or not the model should return the last key/values attentions (not used by all models). Only
68 | relevant if `config.is_decoder=True`.
69 |         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
70 | Whether to tie weight embeddings
71 | Example:
72 |
73 | """
74 |
75 | model_type = 'internlm2'
76 | _auto_class = 'AutoConfig'
77 |
78 | def __init__( # pylint: disable=W0102
79 | self,
80 | vocab_size=103168,
81 | hidden_size=4096,
82 | intermediate_size=11008,
83 | num_hidden_layers=32,
84 | num_attention_heads=32,
85 | num_key_value_heads=None,
86 | hidden_act='silu',
87 | max_position_embeddings=2048,
88 | initializer_range=0.02,
89 | rms_norm_eps=1e-6,
90 | use_cache=True,
91 | pad_token_id=0,
92 | bos_token_id=1,
93 | eos_token_id=2,
94 | tie_word_embeddings=False,
95 | bias=True,
96 | rope_theta=10000,
97 | rope_scaling=None,
98 | attn_implementation='eager',
99 | **kwargs,
100 | ):
101 | self.vocab_size = vocab_size
102 | self.max_position_embeddings = max_position_embeddings
103 | self.hidden_size = hidden_size
104 | self.intermediate_size = intermediate_size
105 | self.num_hidden_layers = num_hidden_layers
106 | self.num_attention_heads = num_attention_heads
107 | self.bias = bias
108 |
109 | if num_key_value_heads is None:
110 | num_key_value_heads = num_attention_heads
111 | self.num_key_value_heads = num_key_value_heads
112 |
113 | self.hidden_act = hidden_act
114 | self.initializer_range = initializer_range
115 | self.rms_norm_eps = rms_norm_eps
116 | self.use_cache = use_cache
117 | self.rope_theta = rope_theta
118 | self.rope_scaling = rope_scaling
119 | self._rope_scaling_validation()
120 |
121 | self.attn_implementation = attn_implementation
122 | if self.attn_implementation is None:
123 | self.attn_implementation = 'eager'
124 | super().__init__(
125 | pad_token_id=pad_token_id,
126 | bos_token_id=bos_token_id,
127 | eos_token_id=eos_token_id,
128 | tie_word_embeddings=tie_word_embeddings,
129 | **kwargs,
130 | )
131 |
132 | def _rope_scaling_validation(self):
133 | """
134 | Validate the `rope_scaling` configuration.
135 | """
136 | if self.rope_scaling is None:
137 | return
138 |
139 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
140 | raise ValueError(
141 |                 '`rope_scaling` must be a dictionary with two fields, `type` and `factor`, '
142 | f'got {self.rope_scaling}'
143 | )
144 | rope_scaling_type = self.rope_scaling.get('type', None)
145 | rope_scaling_factor = self.rope_scaling.get('factor', None)
146 | if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
147 | raise ValueError(
148 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
149 | )
150 | if (
151 | rope_scaling_factor is None
152 | or not isinstance(rope_scaling_factor, float)
153 | or rope_scaling_factor < 1.0
154 | ):
155 | raise ValueError(
156 | f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}"
157 | )
158 |
--------------------------------------------------------------------------------
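Note: for reference, examples of `rope_scaling` values that `_rope_scaling_validation` accepts or rejects (the values are illustrative only):

accepted_linear = {'type': 'linear', 'factor': 2.0}   # type in {'linear', 'dynamic'}, float factor >= 1
accepted_none = None                                   # rope scaling disabled entirely
rejected_type = {'type': 'yarn', 'factor': 2.0}        # unknown type -> ValueError
rejected_factor = {'type': 'dynamic', 'factor': 0.5}   # factor < 1 -> ValueError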
/internvl/model/internlm2/tokenization_internlm2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # This code is based on transformers/src/transformers/models/llama/tokenization_llama.py
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Tokenization classes for InternLM."""
18 | import os
19 | from shutil import copyfile
20 | from typing import Any, Dict, List, Optional, Tuple
21 |
22 | import sentencepiece as spm
23 | from transformers.tokenization_utils import PreTrainedTokenizer
24 | from transformers.utils import logging
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 | VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
29 |
30 | PRETRAINED_VOCAB_FILES_MAP = {}
31 |
32 |
33 | # Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
34 | class InternLM2Tokenizer(PreTrainedTokenizer):
35 | """
36 |     Construct an InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
37 |
38 | Args:
39 | vocab_file (`str`):
40 | Path to the vocabulary file.
41 | """
42 |
43 | vocab_files_names = VOCAB_FILES_NAMES
44 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
45 | model_input_names = ['input_ids', 'attention_mask']
46 | _auto_class = 'AutoTokenizer'
47 |
48 | def __init__(
49 | self,
50 | vocab_file,
51 |         unk_token='<unk>',
52 |         bos_token='<s>',
53 |         eos_token='</s>',
54 |         pad_token='</s>',
55 | sp_model_kwargs: Optional[Dict[str, Any]] = None,
56 | add_bos_token=True,
57 | add_eos_token=False,
58 | decode_with_prefix_space=False,
59 | clean_up_tokenization_spaces=False,
60 | **kwargs,
61 | ):
62 | self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
63 | self.vocab_file = vocab_file
64 | self.add_bos_token = add_bos_token
65 | self.add_eos_token = add_eos_token
66 | self.decode_with_prefix_space = decode_with_prefix_space
67 | self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
68 | self.sp_model.Load(vocab_file)
69 | self._no_prefix_space_tokens = None
70 | super().__init__(
71 | bos_token=bos_token,
72 | eos_token=eos_token,
73 | unk_token=unk_token,
74 | pad_token=pad_token,
75 | clean_up_tokenization_spaces=clean_up_tokenization_spaces,
76 | **kwargs,
77 | )
78 |
79 | @property
80 | def no_prefix_space_tokens(self):
81 | if self._no_prefix_space_tokens is None:
82 | vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
83 | self._no_prefix_space_tokens = {
84 | i for i, tok in enumerate(vocab) if not tok.startswith('▁')
85 | }
86 | return self._no_prefix_space_tokens
87 |
88 | @property
89 | def vocab_size(self):
90 | """Returns vocab size"""
91 | return self.sp_model.get_piece_size()
92 |
93 | @property
94 | def bos_token_id(self) -> Optional[int]:
95 | return self.sp_model.bos_id()
96 |
97 | @property
98 | def eos_token_id(self) -> Optional[int]:
99 | return self.sp_model.eos_id()
100 |
101 | def get_vocab(self):
102 | """Returns vocab as a dict"""
103 | vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
104 | vocab.update(self.added_tokens_encoder)
105 | return vocab
106 |
107 | def _tokenize(self, text):
108 | """Returns a tokenized string."""
109 | return self.sp_model.encode(text, out_type=str)
110 |
111 | def _convert_token_to_id(self, token):
112 | """Converts a token (str) in an id using the vocab."""
113 | return self.sp_model.piece_to_id(token)
114 |
115 | def _convert_id_to_token(self, index):
116 | """Converts an index (integer) in a token (str) using the vocab."""
117 | token = self.sp_model.IdToPiece(index)
118 | return token
119 |
120 | def _maybe_add_prefix_space(self, tokens, decoded):
121 | if tokens and tokens[0] not in self.no_prefix_space_tokens:
122 | return ' ' + decoded
123 | else:
124 | return decoded
125 |
126 | def convert_tokens_to_string(self, tokens):
127 | """Converts a sequence of tokens (string) in a single string."""
128 | current_sub_tokens = []
129 | out_string = ''
130 | prev_is_special = False
131 | for token in tokens:
132 | # make sure that special tokens are not decoded using sentencepiece model
133 | if token in self.all_special_tokens:
134 | if not prev_is_special:
135 | out_string += ' '
136 | out_string += self.sp_model.decode(current_sub_tokens) + token
137 | prev_is_special = True
138 | current_sub_tokens = []
139 | else:
140 | current_sub_tokens.append(token)
141 | prev_is_special = False
142 | out_string += self.sp_model.decode(current_sub_tokens)
143 | out_string = self.clean_up_tokenization(out_string)
144 | out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
145 | return out_string[1:]
146 |
147 | def save_vocabulary(
148 | self, save_directory, filename_prefix: Optional[str] = None
149 | ) -> Tuple[str]:
150 | """
151 | Save the vocabulary and special tokens file to a directory.
152 |
153 | Args:
154 | save_directory (`str`):
155 | The directory in which to save the vocabulary.
156 |
157 | Returns:
158 | `Tuple(str)`: Paths to the files saved.
159 | """
160 | if not os.path.isdir(save_directory):
161 | logger.error(f'Vocabulary path ({save_directory}) should be a directory')
162 | return
163 | out_vocab_file = os.path.join(
164 | save_directory,
165 | (filename_prefix + '-' if filename_prefix else '')
166 | + VOCAB_FILES_NAMES['vocab_file'],
167 | )
168 |
169 | if os.path.abspath(self.vocab_file) != os.path.abspath(
170 | out_vocab_file
171 | ) and os.path.isfile(self.vocab_file):
172 | copyfile(self.vocab_file, out_vocab_file)
173 | elif not os.path.isfile(self.vocab_file):
174 | with open(out_vocab_file, 'wb') as fi:
175 | content_spiece_model = self.sp_model.serialized_model_proto()
176 | fi.write(content_spiece_model)
177 |
178 | return (out_vocab_file,)
179 |
180 | def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
181 | if self.add_bos_token:
182 | bos_token_ids = [self.bos_token_id]
183 | else:
184 | bos_token_ids = []
185 |
186 | output = bos_token_ids + token_ids_0
187 |
188 | if token_ids_1 is not None:
189 | output = output + token_ids_1
190 |
191 | if self.add_eos_token:
192 | output = output + [self.eos_token_id]
193 |
194 | return output
195 |
196 | def get_special_tokens_mask(
197 | self,
198 | token_ids_0: List[int],
199 | token_ids_1: Optional[List[int]] = None,
200 | already_has_special_tokens: bool = False,
201 | ) -> List[int]:
202 | """
203 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
204 | special tokens using the tokenizer `prepare_for_model` method.
205 |
206 | Args:
207 | token_ids_0 (`List[int]`):
208 | List of IDs.
209 | token_ids_1 (`List[int]`, *optional*):
210 | Optional second list of IDs for sequence pairs.
211 | already_has_special_tokens (`bool`, *optional*, defaults to `False`):
212 | Whether or not the token list is already formatted with special tokens for the model.
213 |
214 | Returns:
215 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
216 | """
217 | if already_has_special_tokens:
218 | return super().get_special_tokens_mask(
219 | token_ids_0=token_ids_0,
220 | token_ids_1=token_ids_1,
221 | already_has_special_tokens=True,
222 | )
223 |
224 | if token_ids_1 is None:
225 | return [1] + ([0] * len(token_ids_0)) + [1]
226 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
227 |
228 | def create_token_type_ids_from_sequences(
229 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
230 | ) -> List[int]:
231 | """
232 |         Create a mask from the two sequences passed to be used in a sequence-pair classification task. InternLM2 does not
233 |         make use of token type ids, therefore a list of zeros is returned.
234 |
235 | Args:
236 | token_ids_0 (`List[int]`):
237 | List of IDs.
238 | token_ids_1 (`List[int]`, *optional*):
239 | Optional second list of IDs for sequence pairs.
240 |
241 | Returns:
242 | `List[int]`: List of zeros.
243 | """
244 | eos = [self.eos_token_id]
245 |
246 | if token_ids_1 is None:
247 | return len(token_ids_0 + eos) * [0]
248 | return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
249 |
--------------------------------------------------------------------------------
/internvl/model/internlm2/tokenization_internlm2_fast.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # This code is based on transformers/src/transformers/models/llama/tokenization_llama_fast.py
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | """Tokenization Fast class for InternLM."""
18 | import os
19 | from shutil import copyfile
20 | from typing import Any, Dict, Optional, Tuple
21 |
22 | from tokenizers import Tokenizer, decoders, normalizers, processors
23 | from tokenizers.models import BPE
24 | from transformers.convert_slow_tokenizer import (SLOW_TO_FAST_CONVERTERS,
25 | SentencePieceExtractor,
26 | SpmConverter)
27 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
28 | from transformers.utils import logging
29 |
30 | from .tokenization_internlm2 import InternLM2Tokenizer
31 |
32 | logger = logging.get_logger(__name__)
33 |
34 | VOCAB_FILES_NAMES = {'vocab_file': './tokenizer.model'}
35 |
36 |
37 | # Modified from transformers.convert_slow_tokenizer.LlamaConverter
38 | class InternLM2Converter(SpmConverter):
39 | handle_byte_fallback = True
40 |
41 | def vocab(self, proto):
42 | vocab = [
43 |             ('<unk>', 0.0),
44 |             ('<s>', 0.0),
45 |             ('</s>', 0.0),
46 | ]
47 | vocab += [(piece.piece, piece.score) for piece in proto.pieces[3:]]
48 | return vocab
49 |
50 | def unk_id(self, proto):
51 | unk_id = 0
52 | return unk_id
53 |
54 | def decoder(self, replacement, add_prefix_space):
55 | return decoders.Sequence(
56 | [
57 | decoders.Replace('▁', ' '),
58 | decoders.ByteFallback(),
59 | decoders.Fuse(),
60 | decoders.Strip(content=' ', left=1),
61 | ]
62 | )
63 |
64 | def tokenizer(self, proto):
65 | model_type = proto.trainer_spec.model_type
66 | vocab_scores = self.vocab(proto)
67 | # special tokens
68 | added_tokens = self.original_tokenizer.added_tokens_decoder
69 | for i in range(len(vocab_scores)):
70 | piece, score = vocab_scores[i]
71 | if i in added_tokens:
72 | vocab_scores[i] = (added_tokens[i].content, score)
73 | if model_type == 1:
74 | raise RuntimeError('InternLM2 is supposed to be a BPE model!')
75 |
76 | elif model_type == 2:
77 | _, merges = SentencePieceExtractor(
78 | self.original_tokenizer.vocab_file
79 | ).extract(vocab_scores)
80 | bpe_vocab = {word: i for i, (word, _score) in enumerate(vocab_scores)}
81 | tokenizer = Tokenizer(
82 | BPE(
83 | bpe_vocab,
84 | merges,
85 | unk_token=proto.trainer_spec.unk_piece,
86 | fuse_unk=True,
87 | byte_fallback=True,
88 | )
89 | )
90 | tokenizer.add_special_tokens(
91 | [added_token for index, added_token in added_tokens.items()]
92 | )
93 | else:
94 | raise Exception(
95 | "You're trying to run a `Unigram` model but you're file was trained with a different algorithm"
96 | )
97 |
98 | return tokenizer
99 |
100 | def normalizer(self, proto):
101 | normalizers_list = []
102 | if proto.normalizer_spec.add_dummy_prefix:
103 | normalizers_list.append(normalizers.Prepend(prepend='▁'))
104 | normalizers_list.append(normalizers.Replace(pattern=' ', content='▁'))
105 | return normalizers.Sequence(normalizers_list)
106 |
107 | def pre_tokenizer(self, replacement, add_prefix_space):
108 | return None
109 |
110 |
111 | SLOW_TO_FAST_CONVERTERS['InternLM2Tokenizer'] = InternLM2Converter
112 |
113 |
114 | # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
115 | class InternLM2TokenizerFast(PreTrainedTokenizerFast):
116 | vocab_files_names = VOCAB_FILES_NAMES
117 | slow_tokenizer_class = InternLM2Tokenizer
118 | padding_side = 'left'
119 | model_input_names = ['input_ids', 'attention_mask']
120 | _auto_class = 'AutoTokenizer'
121 |
122 | def __init__(
123 | self,
124 | vocab_file,
125 |         unk_token='<unk>',
126 |         bos_token='<s>',
127 |         eos_token='</s>',
128 |         pad_token='</s>',
129 | sp_model_kwargs: Optional[Dict[str, Any]] = None,
130 | add_bos_token=True,
131 | add_eos_token=False,
132 | decode_with_prefix_space=False,
133 | clean_up_tokenization_spaces=False,
134 | **kwargs,
135 | ):
136 | super().__init__(
137 | vocab_file=vocab_file,
138 | unk_token=unk_token,
139 | bos_token=bos_token,
140 | eos_token=eos_token,
141 | pad_token=pad_token,
142 | sp_model_kwargs=sp_model_kwargs,
143 | add_bos_token=add_bos_token,
144 | add_eos_token=add_eos_token,
145 | decode_with_prefix_space=decode_with_prefix_space,
146 | clean_up_tokenization_spaces=clean_up_tokenization_spaces,
147 | **kwargs,
148 | )
149 | self._add_bos_token = add_bos_token
150 | self._add_eos_token = add_eos_token
151 | self.update_post_processor()
152 | self.vocab_file = vocab_file
153 |
154 | @property
155 | def can_save_slow_tokenizer(self) -> bool:
156 | return os.path.isfile(self.vocab_file) if self.vocab_file else False
157 |
158 | def update_post_processor(self):
159 | """
160 | Updates the underlying post processor with the current `bos_token` and `eos_token`.
161 | """
162 | bos = self.bos_token
163 | bos_token_id = self.bos_token_id
164 | if bos is None and self.add_bos_token:
165 | raise ValueError('add_bos_token = True but bos_token = None')
166 |
167 | eos = self.eos_token
168 | eos_token_id = self.eos_token_id
169 | if eos is None and self.add_eos_token:
170 | raise ValueError('add_eos_token = True but eos_token = None')
171 |
172 | single = f"{(bos+':0 ') if self.add_bos_token else ''}$A:0{(' '+eos+':0') if self.add_eos_token else ''}"
173 | pair = f"{single}{(' '+bos+':1') if self.add_bos_token else ''} $B:1{(' '+eos+':1') if self.add_eos_token else ''}"
174 |
175 | special_tokens = []
176 | if self.add_bos_token:
177 | special_tokens.append((bos, bos_token_id))
178 | if self.add_eos_token:
179 | special_tokens.append((eos, eos_token_id))
180 | self._tokenizer.post_processor = processors.TemplateProcessing(
181 | single=single, pair=pair, special_tokens=special_tokens
182 | )
183 |
184 | @property
185 | def add_eos_token(self):
186 | return self._add_eos_token
187 |
188 | @property
189 | def add_bos_token(self):
190 | return self._add_bos_token
191 |
192 | @add_eos_token.setter
193 | def add_eos_token(self, value):
194 | self._add_eos_token = value
195 | self.update_post_processor()
196 |
197 | @add_bos_token.setter
198 | def add_bos_token(self, value):
199 | self._add_bos_token = value
200 | self.update_post_processor()
201 |
202 | def save_vocabulary(
203 | self, save_directory: str, filename_prefix: Optional[str] = None
204 | ) -> Tuple[str]:
205 | if not self.can_save_slow_tokenizer:
206 | raise ValueError(
207 | 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
208 | 'tokenizer.'
209 | )
210 |
211 | if not os.path.isdir(save_directory):
212 | logger.error(f'Vocabulary path ({save_directory}) should be a directory')
213 | return
214 | out_vocab_file = os.path.join(
215 | save_directory,
216 | (filename_prefix + '-' if filename_prefix else '')
217 | + VOCAB_FILES_NAMES['vocab_file'],
218 | )
219 |
220 | if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
221 | copyfile(self.vocab_file, out_vocab_file)
222 |
223 | return (out_vocab_file,)
224 |
--------------------------------------------------------------------------------
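A minimal sketch (not part of the repository) of the `TemplateProcessing` strings that `update_post_processor` builds, assuming the usual InternLM2 special tokens `<s>`/`</s>` and the defaults `add_bos_token=True`, `add_eos_token=False`:

bos, eos = '<s>', '</s>'
add_bos_token, add_eos_token = True, False

single = f"{(bos + ':0 ') if add_bos_token else ''}$A:0{(' ' + eos + ':0') if add_eos_token else ''}"
pair = f"{single}{(' ' + bos + ':1') if add_bos_token else ''} $B:1{(' ' + eos + ':1') if add_eos_token else ''}"

print(single)  # <s>:0 $A:0            -> a single sequence gets a BOS token with type id 0
print(pair)    # <s>:0 $A:0 <s>:1 $B:1 -> the second sequence also starts with BOS, type id 1

Changing `add_bos_token` or `add_eos_token` through their setters rebuilds these templates via `update_post_processor`, so the change takes effect on the next encode call.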
/internvl/model/internvl_chat/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | from .configuration_intern_vit import InternVisionConfig
8 | from .configuration_internvl_chat import InternVLChatConfig
9 | from .modeling_intern_vit import InternVisionModel
10 | from .modeling_internvl_chat import InternVLChatModel
11 |
12 | __all__ = [
13 | 'InternVisionConfig',
14 | 'InternVisionModel',
15 | 'InternVLChatConfig',
16 | 'InternVLChatModel',
17 | ]
18 |
--------------------------------------------------------------------------------
/internvl/model/internvl_chat/configuration_intern_vit.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import os
8 | from typing import Union
9 |
10 | from transformers.configuration_utils import PretrainedConfig
11 | from transformers.utils import logging
12 |
13 | logger = logging.get_logger(__name__)
14 |
15 |
16 | class InternVisionConfig(PretrainedConfig):
17 | r"""
18 | This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
19 | instantiate a vision encoder according to the specified arguments, defining the model architecture.
20 |
21 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22 | documentation from [`PretrainedConfig`] for more information.
23 |
24 | Args:
25 | num_channels (`int`, *optional*, defaults to 3):
26 | Number of color channels in the input images (e.g., 3 for RGB).
27 | patch_size (`int`, *optional*, defaults to 14):
28 | The size (resolution) of each patch.
29 | image_size (`int`, *optional*, defaults to 224):
30 | The size (resolution) of each image.
31 | qkv_bias (`bool`, *optional*, defaults to `False`):
32 | Whether to add a bias to the queries and values in the self-attention layers.
33 | hidden_size (`int`, *optional*, defaults to 3200):
34 | Dimensionality of the encoder layers and the pooler layer.
35 | num_attention_heads (`int`, *optional*, defaults to 25):
36 | Number of attention heads for each attention layer in the Transformer encoder.
37 | intermediate_size (`int`, *optional*, defaults to 12800):
38 | Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
39 | qk_normalization (`bool`, *optional*, defaults to `True`):
40 | Whether to normalize the queries and keys in the self-attention layers.
41 | num_hidden_layers (`int`, *optional*, defaults to 48):
42 | Number of hidden layers in the Transformer encoder.
43 | use_flash_attn (`bool`, *optional*, defaults to `True`):
44 | Whether to use flash attention mechanism.
45 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
46 | The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
47 | `"relu"`, `"selu"` and `"gelu_new"` are supported.
48 | layer_norm_eps (`float`, *optional*, defaults to 1e-6):
49 | The epsilon used by the layer normalization layers.
50 | dropout (`float`, *optional*, defaults to 0.0):
51 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
52 | drop_path_rate (`float`, *optional*, defaults to 0.0):
53 | Dropout rate for stochastic depth.
54 | attention_dropout (`float`, *optional*, defaults to 0.0):
55 | The dropout ratio for the attention probabilities.
56 | initializer_range (`float`, *optional*, defaults to 0.02):
57 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
58 | initializer_factor (`float`, *optional*, defaults to 0.1):
59 | A factor for layer scale.
60 | """
61 |
62 | model_type = 'intern_vit_6b'
63 |
64 | def __init__(
65 | self,
66 | num_channels=3,
67 | patch_size=14,
68 | image_size=224,
69 | qkv_bias=False,
70 | hidden_size=3200,
71 | num_attention_heads=25,
72 | intermediate_size=12800,
73 | qk_normalization=True,
74 | num_hidden_layers=48,
75 | use_flash_attn=True,
76 | hidden_act='gelu',
77 | norm_type='rms_norm',
78 | layer_norm_eps=1e-6,
79 | dropout=0.0,
80 | drop_path_rate=0.0,
81 | attention_dropout=0.0,
82 | initializer_range=0.02,
83 | initializer_factor=0.1,
84 | **kwargs,
85 | ):
86 | super().__init__(**kwargs)
87 |
88 | self.hidden_size = hidden_size
89 | self.intermediate_size = intermediate_size
90 | self.dropout = dropout
91 | self.drop_path_rate = drop_path_rate
92 | self.num_hidden_layers = num_hidden_layers
93 | self.num_attention_heads = num_attention_heads
94 | self.num_channels = num_channels
95 | self.patch_size = patch_size
96 | self.image_size = image_size
97 | self.initializer_range = initializer_range
98 | self.initializer_factor = initializer_factor
99 | self.attention_dropout = attention_dropout
100 | self.layer_norm_eps = layer_norm_eps
101 | self.hidden_act = hidden_act
102 | self.norm_type = norm_type
103 | self.qkv_bias = qkv_bias
104 | self.qk_normalization = qk_normalization
105 | self.use_flash_attn = use_flash_attn
106 |
107 | @classmethod
108 | def from_pretrained(
109 | cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
110 | ) -> 'PretrainedConfig':
111 | config_dict, kwargs = cls.get_config_dict(
112 | pretrained_model_name_or_path, **kwargs
113 | )
114 |
115 | if 'vision_config' in config_dict:
116 | config_dict = config_dict['vision_config']
117 |
118 | if (
119 | 'model_type' in config_dict
120 | and hasattr(cls, 'model_type')
121 | and config_dict['model_type'] != cls.model_type
122 | ):
123 | logger.warning(
124 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
125 | f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
126 | )
127 |
128 | return cls.from_dict(config_dict, **kwargs)
129 |
--------------------------------------------------------------------------------
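A minimal usage sketch (assuming the `internvl` package from this repository is importable); the field names follow the docstring above:

from internvl.model.internvl_chat import InternVisionConfig

# Override a few documented fields; everything else keeps the defaults listed above.
vision_config = InternVisionConfig(image_size=448, patch_size=14, use_flash_attn=False)
print(vision_config.model_type)                              # intern_vit_6b
print(vision_config.image_size // vision_config.patch_size)  # 32 patches per side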
/internvl/model/internvl_chat/configuration_internvl_chat.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import copy
8 |
9 | from transformers import AutoConfig, LlamaConfig, Qwen2Config
10 | from transformers.configuration_utils import PretrainedConfig
11 | from transformers.utils import logging
12 |
13 | from internvl.model.internlm2.configuration_internlm2 import InternLM2Config
14 | from internvl.model.phi3.configuration_phi3 import Phi3Config
15 |
16 | from .configuration_intern_vit import InternVisionConfig
17 |
18 | logger = logging.get_logger(__name__)
19 |
20 |
21 | class InternVLChatConfig(PretrainedConfig):
22 | model_type = 'internvl_chat'
23 | is_composition = True
24 |
25 | def __init__(
26 | self,
27 | vision_config=None,
28 | llm_config=None,
29 | use_backbone_lora=0,
30 | use_llm_lora=0,
31 | pad2square=False,
32 | select_layer=-1,
33 | force_image_size=None,
34 | downsample_ratio=0.5,
35 | template=None,
36 | dynamic_image_size=False,
37 | use_thumbnail=False,
38 | ps_version='v1',
39 | min_dynamic_patch=1,
40 | max_dynamic_patch=6,
41 | **kwargs,
42 | ):
43 | super().__init__(**kwargs)
44 |
45 | if vision_config is None:
46 | vision_config = {'architectures': ['InternVisionModel']}
47 | logger.info(
48 | 'vision_config is None. Initializing the InternVisionConfig with default values.'
49 | )
50 |
51 | if llm_config is None:
52 | # TODO: There might still be a bug in transformers version 4.44 and above.
53 | llm_config = {'architectures': ['']}
54 | logger.info(
55 | 'llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`).'
56 | )
57 |
58 | self.vision_config = InternVisionConfig(**vision_config)
59 | if llm_config['architectures'][0] == 'LlamaForCausalLM':
60 | self.llm_config = LlamaConfig(**llm_config)
61 | elif llm_config['architectures'][0] == 'InternLM2ForCausalLM':
62 | self.llm_config = InternLM2Config(**llm_config)
63 | elif llm_config['architectures'][0] == 'Phi3ForCausalLM':
64 | self.llm_config = Phi3Config(**llm_config)
65 | elif llm_config['architectures'][0] == 'Qwen2ForCausalLM':
66 | self.llm_config = Qwen2Config(**llm_config)
67 | else:
68 | raise ValueError(
69 | 'Unsupported architecture: {}'.format(llm_config['architectures'][0])
70 | )
71 | self.use_backbone_lora = use_backbone_lora
72 | self.use_llm_lora = use_llm_lora
73 | self.pad2square = pad2square
74 | self.select_layer = select_layer
75 | self.force_image_size = force_image_size
76 | self.downsample_ratio = downsample_ratio
77 | self.template = template
78 | self.dynamic_image_size = dynamic_image_size
79 | self.use_thumbnail = use_thumbnail
80 | self.ps_version = ps_version # pixel shuffle version
81 | self.min_dynamic_patch = min_dynamic_patch
82 | self.max_dynamic_patch = max_dynamic_patch
83 |
84 | self.hidden_size = self.llm_config.hidden_size
85 | # By default, we use tie_word_embeddings=False for models of all sizes.
86 | self.tie_word_embeddings = False
87 | self.llm_config.tie_word_embeddings = self.tie_word_embeddings
88 |
89 | logger.info(f'vision_select_layer: {self.select_layer}')
90 | logger.info(f'ps_version: {self.ps_version}')
91 | logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}')
92 | logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}')
93 |
94 | def to_dict(self):
95 | """
96 | Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
97 |
98 | Returns:
99 | `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
100 | """
101 | output = copy.deepcopy(self.__dict__)
102 | output['vision_config'] = self.vision_config.to_dict()
103 | output['llm_config'] = self.llm_config.to_dict()
104 | output['model_type'] = self.__class__.model_type
105 | output['use_backbone_lora'] = self.use_backbone_lora
106 | output['use_llm_lora'] = self.use_llm_lora
107 | output['select_layer'] = self.select_layer
108 | output['force_image_size'] = self.force_image_size
109 | output['downsample_ratio'] = self.downsample_ratio
110 | output['template'] = self.template
111 | output['dynamic_image_size'] = self.dynamic_image_size
112 | output['use_thumbnail'] = self.use_thumbnail
113 | output['ps_version'] = self.ps_version
114 | output['min_dynamic_patch'] = self.min_dynamic_patch
115 | output['max_dynamic_patch'] = self.max_dynamic_patch
116 |
117 | return output
118 |
--------------------------------------------------------------------------------
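A minimal sketch (assuming the `internvl` package is importable) showing how the `architectures` entry of `llm_config` selects the LLM config class in the constructor above:

from internvl.model.internvl_chat import InternVLChatConfig

config = InternVLChatConfig(
    vision_config={'architectures': ['InternVisionModel'], 'image_size': 448},
    llm_config={'architectures': ['InternLM2ForCausalLM'], 'hidden_size': 2048},
    downsample_ratio=0.5,
    dynamic_image_size=True,
    max_dynamic_patch=6,
)
print(type(config.llm_config).__name__)  # InternLM2Config
print(config.hidden_size)                # 2048, mirrored from llm_config.hidden_size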
/internvl/model/phi3/configuration_phi3.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Phi-3 model configuration"""
16 |
17 |
18 | from transformers.configuration_utils import PretrainedConfig
19 | from transformers.utils import logging
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24 | 'microsoft/Phi-3-mini-4k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json',
25 | 'microsoft/Phi-3-mini-128k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json',
26 | }
27 |
28 |
29 | class Phi3Config(PretrainedConfig):
30 | r"""
31 | This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
32 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
33 | defaults will yield a similar configuration to that of the
34 | [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
35 |
36 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
37 | documentation from [`PretrainedConfig`] for more information.
38 |
39 | Args:
40 | vocab_size (`int`, *optional*, defaults to 32064):
41 | Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
42 | `inputs_ids` passed when calling [`Phi3Model`].
43 | hidden_size (`int`, *optional*, defaults to 3072):
44 | Dimension of the hidden representations.
45 | intermediate_size (`int`, *optional*, defaults to 8192):
46 | Dimension of the MLP representations.
47 | num_hidden_layers (`int`, *optional*, defaults to 32):
48 | Number of hidden layers in the Transformer decoder.
49 | num_attention_heads (`int`, *optional*, defaults to 32):
50 | Number of attention heads for each attention layer in the Transformer decoder.
51 | num_key_value_heads (`int`, *optional*):
52 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
53 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
54 | `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
55 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
56 | by meanpooling all the original heads within that group. For more details, check out [this
57 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
58 | `num_attention_heads`.
59 | resid_pdrop (`float`, *optional*, defaults to 0.0):
60 | Dropout probability for mlp outputs.
61 | embd_pdrop (`int`, *optional*, defaults to 0.0):
62 | The dropout ratio for the embeddings.
63 | attention_dropout (`float`, *optional*, defaults to 0.0):
64 | The dropout ratio after computing the attention scores.
65 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
66 | The non-linear activation function (function or string) in the decoder.
67 | max_position_embeddings (`int`, *optional*, defaults to 4096):
68 | The maximum sequence length that this model might ever be used with.
69 | original_max_position_embeddings (`int`, *optional*, defaults to 4096):
70 | The maximum sequence length that this model was trained with. This is used to determine the size of the
71 | original RoPE embeddings when using long scaling.
72 | initializer_range (`float`, *optional*, defaults to 0.02):
73 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
74 | rms_norm_eps (`float`, *optional*, defaults to 1e-05):
75 | The epsilon value used for the RMSNorm.
76 | use_cache (`bool`, *optional*, defaults to `True`):
77 | Whether or not the model should return the last key/values attentions (not used by all models). Only
78 | relevant if `config.is_decoder=True`.
79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`):
80 | Whether to tie weight embeddings
81 | rope_theta (`float`, *optional*, defaults to 10000.0):
82 | The base period of the RoPE embeddings.
83 | rope_scaling (`dict`, *optional*):
84 | The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
85 | contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
86 | the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
87 | divided by the number of attention heads divided by 2.
88 | bos_token_id (`int`, *optional*, defaults to 1):
89 | The id of the "beginning-of-sequence" token.
90 | eos_token_id (`int`, *optional*, defaults to 32000):
91 | The id of the "end-of-sequence" token.
92 | pad_token_id (`int`, *optional*, defaults to 32000):
93 | The id of the padding token.
94 | sliding_window (`int`, *optional*):
95 | Sliding window attention window size. If `None`, no sliding window is applied.
96 |
97 | Example:
98 |
99 | ```python
100 | >>> from transformers import Phi3Model, Phi3Config
101 |
102 | >>> # Initializing a Phi-3 style configuration
103 | >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
104 |
105 | >>> # Initializing a model from the configuration
106 | >>> model = Phi3Model(configuration)
107 |
108 | >>> # Accessing the model configuration
109 | >>> configuration = model.config
110 | ```"""
111 |
112 | model_type = 'phi3'
113 | keys_to_ignore_at_inference = ['past_key_values']
114 |
115 | def __init__(
116 | self,
117 | vocab_size=32064,
118 | hidden_size=3072,
119 | intermediate_size=8192,
120 | num_hidden_layers=32,
121 | num_attention_heads=32,
122 | num_key_value_heads=None,
123 | resid_pdrop=0.0,
124 | embd_pdrop=0.0,
125 | attention_dropout=0.0,
126 | hidden_act='silu',
127 | max_position_embeddings=4096,
128 | original_max_position_embeddings=4096,
129 | initializer_range=0.02,
130 | rms_norm_eps=1e-5,
131 | use_cache=True,
132 | tie_word_embeddings=False,
133 | rope_theta=10000.0,
134 | rope_scaling=None,
135 | bos_token_id=1,
136 | eos_token_id=32000,
137 | pad_token_id=32000,
138 | sliding_window=None,
139 | **kwargs,
140 | ):
141 | self.vocab_size = vocab_size
142 | self.hidden_size = hidden_size
143 | self.intermediate_size = intermediate_size
144 | self.num_hidden_layers = num_hidden_layers
145 | self.num_attention_heads = num_attention_heads
146 |
147 | if num_key_value_heads is None:
148 | num_key_value_heads = num_attention_heads
149 |
150 | self.num_key_value_heads = num_key_value_heads
151 | self.resid_pdrop = resid_pdrop
152 | self.embd_pdrop = embd_pdrop
153 | self.attention_dropout = attention_dropout
154 | self.hidden_act = hidden_act
155 | self.max_position_embeddings = max_position_embeddings
156 | self.original_max_position_embeddings = original_max_position_embeddings
157 | self.initializer_range = initializer_range
158 | self.rms_norm_eps = rms_norm_eps
159 | self.use_cache = use_cache
160 | self.rope_theta = rope_theta
161 | self.rope_scaling = rope_scaling
162 | self._rope_scaling_validation()
163 | self.sliding_window = sliding_window
164 |
165 | super().__init__(
166 | bos_token_id=bos_token_id,
167 | eos_token_id=eos_token_id,
168 | pad_token_id=pad_token_id,
169 | tie_word_embeddings=tie_word_embeddings,
170 | **kwargs,
171 | )
172 |
173 | def _rope_scaling_validation(self):
174 | """
175 | Validate the `rope_scaling` configuration.
176 | """
177 | if self.rope_scaling is None:
178 | return
179 |
180 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
181 | raise ValueError(
182 | '`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, '
183 | f'got {self.rope_scaling}'
184 | )
185 | rope_scaling_type = self.rope_scaling.get('type', None)
186 | rope_scaling_short_factor = self.rope_scaling.get('short_factor', None)
187 | rope_scaling_long_factor = self.rope_scaling.get('long_factor', None)
188 | if rope_scaling_type is None or rope_scaling_type not in ['su', 'yarn']:
189 | raise ValueError(
190 | f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}"
191 | )
192 | if not (
193 | isinstance(rope_scaling_short_factor, list)
194 | and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
195 | ):
196 | raise ValueError(
197 | f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
198 | )
199 | if (
200 | not len(rope_scaling_short_factor)
201 | == self.hidden_size // self.num_attention_heads // 2
202 | ):
203 | raise ValueError(
204 | f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
205 | )
206 | if not (
207 | isinstance(rope_scaling_long_factor, list)
208 | and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
209 | ):
210 | raise ValueError(
211 | f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
212 | )
213 | if (
214 | not len(rope_scaling_long_factor)
215 | == self.hidden_size // self.num_attention_heads // 2
216 | ):
217 | raise ValueError(
218 | f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
219 | )
220 |
--------------------------------------------------------------------------------
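A worked example of the `rope_scaling` shape that `_rope_scaling_validation` accepts: with the default `hidden_size=3072` and `num_attention_heads=32`, each factor list must have length 3072 // 32 // 2 = 48. A minimal sketch (assuming the `internvl` package is importable):

from internvl.model.phi3.configuration_phi3 import Phi3Config

factor_len = 3072 // 32 // 2  # 48 entries per factor list
rope_scaling = {
    'type': 'su',
    'short_factor': [1.0] * factor_len,
    'long_factor': [1.0] * factor_len,
}
config = Phi3Config(rope_scaling=rope_scaling)  # validates without raising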
/internvl/patch/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | from .internlm2_packed_training_patch import replace_internlm2_attention_class
8 | from .internvit_liger_monkey_patch import apply_liger_kernel_to_internvit
9 | from .llama2_flash_attn_monkey_patch import replace_llama2_attn_with_flash_attn
10 | from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
11 | from .llama_packed_training_patch import replace_llama_attention_class
12 | from .llama_rmsnorm_monkey_patch import \
13 | replace_llama_rmsnorm_with_fused_rmsnorm
14 | from .pad_data_collator import (concat_pad_data_collator,
15 | dpo_concat_pad_data_collator,
16 | pad_data_collator)
17 | from .phi3_packed_training_patch import replace_phi3_attention_class
18 | from .qwen2_packed_training_patch import replace_qwen2_attention_class
19 | from .train_dataloader_patch import replace_train_dataloader
20 | from .train_sampler_patch import replace_train_sampler
21 |
22 | __all__ = [
23 | 'replace_llama_attn_with_flash_attn',
24 | 'replace_llama_rmsnorm_with_fused_rmsnorm',
25 | 'replace_llama2_attn_with_flash_attn',
26 | 'replace_train_sampler',
27 | 'replace_train_dataloader',
28 | 'replace_internlm2_attention_class',
29 | 'replace_qwen2_attention_class',
30 | 'replace_phi3_attention_class',
31 | 'replace_llama_attention_class',
32 | 'pad_data_collator',
33 | 'dpo_concat_pad_data_collator',
34 | 'concat_pad_data_collator',
35 | 'apply_liger_kernel_to_internvit',
36 | ]
37 |
--------------------------------------------------------------------------------
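The exported patch functions are intended to be applied once, before the model and trainer are constructed, so the patched classes are picked up at build time. A hypothetical ordering sketch (the actual call sites live in the training entry point, e.g. internvl/train/internvl_chat_finetune.py, and may differ):

from internvl.patch import (concat_pad_data_collator,
                            replace_internlm2_attention_class,
                            replace_llama_rmsnorm_with_fused_rmsnorm,
                            replace_train_dataloader, replace_train_sampler)

replace_llama_rmsnorm_with_fused_rmsnorm()  # use apex FusedRMSNorm if available (see llama_rmsnorm_monkey_patch.py)
replace_train_sampler()                     # patch the HF Trainer sampler (see train_sampler_patch.py)
replace_train_dataloader()                  # patch the HF Trainer dataloader (see train_dataloader_patch.py)
replace_internlm2_attention_class()         # packed-sequence flash attention for InternLM2
collate_fn = concat_pad_data_collator       # pads input_ids/labels, concatenates pixel_values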
/internvl/patch/internlm2_packed_training_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import torch
8 | from flash_attn.flash_attn_interface import flash_attn_varlen_func
9 |
10 | from internvl.model.internlm2.modeling_internlm2 import (
11 | INTERNLM2_ATTENTION_CLASSES, InternLM2FlashAttention2,
12 | apply_rotary_pos_emb)
13 |
14 |
15 | # Modified from internvl.model.internlm2.modeling_internlm2.InternLM2FlashAttention2
16 | class InternLM2FlashAttention2ForPackedTraining(InternLM2FlashAttention2):
17 |
18 | def _flash_attention_forward(
19 | self,
20 | query_states,
21 | key_states,
22 | value_states,
23 | attention_mask,
24 | query_length,
25 | dropout=0.0,
26 | softmax_scale=None,
27 | ):
28 | """
29 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
30 | first unpad the input, then computes the attention scores and pad the final attention scores.
31 |
32 | Args:
33 | query_states (`torch.Tensor`):
34 | Input query states to be passed to Flash Attention API
35 | key_states (`torch.Tensor`):
36 | Input key states to be passed to Flash Attention API
37 | value_states (`torch.Tensor`):
38 | Input value states to be passed to Flash Attention API
39 | attention_mask (`torch.Tensor`):
40 | Renamed from cu_seqlens to keep compatibility - (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
41 | of the sequences in the batch.
42 | dropout (`int`, *optional*):
43 | Attention dropout
44 | softmax_scale (`float`, *optional*):
45 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
46 | """
47 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
48 | query_states = query_states.squeeze(0)
49 | key_states = key_states.squeeze(0)
50 | value_states = value_states.squeeze(0)
51 | cu_seqlens = attention_mask.squeeze(0)
52 |
53 | with torch.no_grad():
54 | max_seqlen = max(
55 | [
56 | cu_seqlens[idx + 1] - cu_seqlens[idx]
57 | for idx in range(cu_seqlens.size(0) - 1)
58 | ]
59 | ).item()
60 |
61 | # Contains at least one padding token in the sequence
62 | causal = self.is_causal and query_length != 1
63 | attn_output = flash_attn_varlen_func(
64 | q=query_states,
65 | k=key_states,
66 | v=value_states,
67 | cu_seqlens_q=cu_seqlens,
68 | cu_seqlens_k=cu_seqlens,
69 | max_seqlen_q=max_seqlen,
70 | max_seqlen_k=max_seqlen,
71 | dropout_p=dropout,
72 | softmax_scale=softmax_scale,
73 | causal=causal,
74 | )
75 |
76 | query_states = query_states.unsqueeze(0)
77 | key_states = key_states.unsqueeze(0)
78 | value_states = value_states.unsqueeze(0)
79 | return attn_output
80 |
81 |
82 | def replace_internlm2_attention_class():
83 | INTERNLM2_ATTENTION_CLASSES['flash_attention_2'] = (
84 | InternLM2FlashAttention2ForPackedTraining
85 | )
86 | print('Replace INTERNLM2_ATTENTION_CLASSES to support packed training!!')
87 |
--------------------------------------------------------------------------------
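In this packed-training patch the `attention_mask` argument actually carries `cu_seqlens`, the cumulative sequence lengths of the samples packed into a single row. A minimal sketch (not part of the repository) of how such a tensor is built and how `max_seqlen` is derived, mirroring the computation in `_flash_attention_forward` above:

import torch

seq_lens = [5, 3, 7]  # three samples packed into one row of length 15
cu = [0]
for n in seq_lens:
    cu.append(cu[-1] + n)
cu_seqlens = torch.tensor(cu, dtype=torch.int32)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)

max_seqlen = max(
    cu_seqlens[idx + 1] - cu_seqlens[idx] for idx in range(cu_seqlens.size(0) - 1)
).item()
print(max_seqlen)  # 7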
/internvl/patch/internvit_liger_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 |
8 | def apply_liger_kernel_to_internvit() -> None:
9 | from liger_kernel.transformers.layer_norm import LigerLayerNorm
10 | from liger_kernel.transformers.rms_norm import LigerRMSNorm
11 |
12 | from internvl.model.internvl_chat import modeling_intern_vit
13 |
14 | modeling_intern_vit.NORM2FN['rms_norm'] = LigerRMSNorm
15 | modeling_intern_vit.NORM2FN['layer_norm'] = LigerLayerNorm
16 | print('Liger kernel applied to InternViT')
17 |
--------------------------------------------------------------------------------
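Because the patch rewrites the `NORM2FN` lookup table, it only affects models constructed after it runs. A minimal ordering sketch (assuming `liger_kernel` is installed and the `internvl` package is importable):

from internvl.patch import apply_liger_kernel_to_internvit

apply_liger_kernel_to_internvit()  # swaps the rms_norm / layer_norm entries in NORM2FN
# ...construct InternVisionModel / InternVLChatModel afterwards so the Liger norms are used.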
/internvl/patch/llama2_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | """
2 | This file is copied from: https://github.com/lm-sys/FastChat
3 | """
4 |
5 | import warnings
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 | from flash_attn import __version__ as flash_attn_version
10 | from flash_attn.bert_padding import pad_input, unpad_input
11 | from flash_attn.flash_attn_interface import (flash_attn_func,
12 | flash_attn_varlen_kvpacked_func)
13 | from transformers.models.llama.modeling_llama import (LlamaAttention,
14 | LlamaModel, rotate_half)
15 |
16 |
17 | def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
18 | gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
19 | gather_indices = gather_indices.repeat(
20 | 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
21 | )
22 | bsz = gather_indices.shape[0]
23 | cos, sin = (
24 | torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
25 | for x in cos_sin
26 | )
27 | q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
28 | return q, k
29 |
30 |
31 | def forward(
32 | self,
33 | hidden_states: torch.Tensor,
34 | attention_mask: Optional[torch.Tensor] = None,
35 | position_ids: Optional[torch.Tensor] = None,
36 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
37 | output_attentions: bool = False,
38 | use_cache: bool = False,
39 | padding_mask: Optional[torch.Tensor] = None,
40 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
41 | if output_attentions:
42 | warnings.warn(
43 | 'Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.'
44 | )
45 |
46 | bsz, q_len, _ = hidden_states.size()
47 | kv_heads = getattr(self, 'num_key_value_heads', self.num_heads)
48 |
49 | q, k, v = (
50 | op(hidden_states).view(bsz, q_len, nh, self.head_dim)
51 | for op, nh in (
52 | (self.q_proj, self.num_heads),
53 | (self.k_proj, kv_heads),
54 | (self.v_proj, kv_heads),
55 | )
56 | )
57 | # shape: (b, s, num_heads, head_dim)
58 |
59 | kv_seq_len = k.shape[1]
60 | past_kv_len = 0
61 | if past_key_value is not None:
62 | past_kv_len = past_key_value[0].shape[2]
63 | kv_seq_len += past_kv_len
64 |
65 | cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
66 | q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
67 |
68 | if past_key_value is not None:
69 | assert (
70 | flash_attn_version >= '2.1.0'
71 | ), 'past_key_value support requires flash-attn >= 2.1.0'
72 | # reuse k, v
73 | k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
74 | v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
75 |
76 | past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
77 |
78 | if attention_mask is None:
79 | output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
80 | bsz, q_len, -1
81 | )
82 | else:
83 | q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
84 | # We can skip concat and call unpad twice but seems better to call unpad only once.
85 | kv, _, cu_k_lens, max_k = unpad_input(
86 | torch.stack((k, v), dim=2), attention_mask
87 | )
88 | output_unpad = flash_attn_varlen_kvpacked_func(
89 | q,
90 | kv,
91 | cu_q_lens,
92 | cu_k_lens,
93 | max_s,
94 | max_k,
95 | 0.0,
96 | softmax_scale=None,
97 | causal=True,
98 | )
99 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
100 | output = pad_input(output_unpad, indices, bsz, q_len)
101 |
102 | return self.o_proj(output), None, past_key_value
103 |
104 |
105 | # Disable the transformation of the attention mask in LlamaModel as flash attention
106 | # takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
107 | def _prepare_decoder_attention_mask(
108 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
109 | ):
110 | # [bsz, seq_len]
111 | if past_key_values_length > 0 and attention_mask is not None:
112 | attention_mask = torch.cat(
113 | (
114 | torch.full(
115 | (input_shape[0], past_key_values_length),
116 | True,
117 | dtype=attention_mask.dtype,
118 | device=attention_mask.device,
119 | ),
120 | attention_mask,
121 | ),
122 | dim=-1,
123 | )
124 |
125 | if attention_mask is not None and torch.all(attention_mask):
126 | return None # This uses the faster call when training with full samples
127 |
128 | return attention_mask
129 |
130 |
131 | def replace_llama2_attn_with_flash_attn():
132 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
133 | if cuda_major < 8:
134 | warnings.warn(
135 | 'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. '
136 | 'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593'
137 | )
138 |
139 | LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
140 | LlamaAttention.forward = forward
141 |
142 |
143 | def test():
144 | from fastchat.train.llama_flash_attn_monkey_patch import \
145 | forward as fastchat_forward
146 | from transformers.models.llama.configuration_llama import LlamaConfig
147 |
148 | config = LlamaConfig(
149 | hidden_size=1024,
150 | intermediate_size=128,
151 | num_hidden_layers=1,
152 | num_attention_heads=8,
153 | max_position_embeddings=16,
154 | )
155 | device = torch.device('cuda')
156 | model = LlamaModel(config)
157 | attn = LlamaAttention(config).to(device).half()
158 | bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
159 | position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
160 | -1, seqlen
161 | )
162 |
163 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
164 | for i in range(4):
165 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
166 | if i:
167 | mask[0, -i:] = False
168 | mask[1, :i] = False
169 |
170 | lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
171 | ref, _, _ = attn.forward(
172 | hidden, attention_mask=lmask, position_ids=position_ids
173 | )
174 |
175 | fast, _, _ = fastchat_forward(
176 | attn, hidden, attention_mask=mask, position_ids=position_ids
177 | )
178 |
179 | lmask = _prepare_decoder_attention_mask(
180 | model, mask, hidden.shape[:2], hidden, 0
181 | )
182 | test, _, _ = forward(
183 | attn, hidden, attention_mask=lmask, position_ids=position_ids
184 | )
185 |
186 | print(f'Mean(abs(ref)) = {torch.mean(torch.abs(ref))}')
187 | print(f'Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}')
188 | print(f'Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}')
189 | print(f'Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}')
190 | print(f'allclose(fast, test) = {torch.allclose(fast, test)}')
191 |
192 | with torch.no_grad():
193 | # Also check that past_kv is handled properly
194 | hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
195 | part_len = seqlen // 4
196 | assert part_len * 4 == seqlen
197 | mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
198 | mask[0, -2:] = False
199 | lmask = _prepare_decoder_attention_mask(
200 | model, mask, hidden.shape[:2], hidden, 0
201 | )
202 | oneshot, _, _ = forward(
203 | attn, hidden, attention_mask=lmask, position_ids=position_ids
204 | )
205 | parts = []
206 | past_kv, past_kv_len = None, 0
207 | for i in range(4):
208 | start = part_len * i
209 | end = start + part_len
210 | hidden_part = hidden[:, start:end, ...]
211 | lmask = _prepare_decoder_attention_mask(
212 | model,
213 | mask[:, start:end],
214 | hidden_part.shape[:2],
215 | hidden_part,
216 | past_kv_len,
217 | )
218 | part, _, past_kv = forward(
219 | attn,
220 | hidden_part.clone(),
221 | attention_mask=lmask,
222 | position_ids=position_ids[:, start:end],
223 | past_key_value=past_kv,
224 | use_cache=True,
225 | )
226 | parts.append(part)
227 | past_kv_len = past_kv[0].shape[2]
228 |
229 | print(
230 | f'allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}'
231 | )
232 | print(
233 | f'allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}'
234 | )
235 |
236 |
237 | if __name__ == '__main__':
238 | test()
239 |
--------------------------------------------------------------------------------
/internvl/patch/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import math
8 | from typing import Optional, Tuple
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | import transformers
13 | from torch import nn
14 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
15 |
16 |
17 | def forward(
18 | self,
19 | hidden_states: torch.Tensor,
20 | attention_mask: Optional[torch.Tensor] = None,
21 | position_ids: Optional[torch.Tensor] = None,
22 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
23 | output_attentions: bool = False,
24 | use_cache: bool = False,
25 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
26 | """Input shape: Batch x Time x Channel
27 |
28 | attention_mask: [bsz, q_len]
29 | """
30 | from einops import rearrange
31 |
32 | try: # v1
33 | from flash_attn.flash_attn_interface import \
34 | flash_attn_unpadded_qkvpacked_func
35 | except: # v2
36 | from flash_attn.flash_attn_interface import (
37 | flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func,
38 | )
39 | from flash_attn.bert_padding import pad_input, unpad_input
40 |
41 | bsz, q_len, _ = hidden_states.size()
42 |
43 | query_states = (
44 | self.q_proj(hidden_states)
45 | .view(bsz, q_len, self.num_heads, self.head_dim)
46 | .transpose(1, 2)
47 | )
48 | key_states = (
49 | self.k_proj(hidden_states)
50 | .view(bsz, q_len, self.num_heads, self.head_dim)
51 | .transpose(1, 2)
52 | )
53 | value_states = (
54 | self.v_proj(hidden_states)
55 | .view(bsz, q_len, self.num_heads, self.head_dim)
56 | .transpose(1, 2)
57 | )
58 | # [bsz, q_len, nh, hd]
59 | # [bsz, nh, q_len, hd]
60 |
61 | kv_seq_len = key_states.shape[-2]
62 | assert past_key_value is None, 'past_key_value is not supported'
63 |
64 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
65 | query_states, key_states = apply_rotary_pos_emb(
66 | query_states, key_states, cos, sin, position_ids
67 | )
68 | # [bsz, nh, t, hd]
69 | assert not output_attentions, 'output_attentions is not supported'
70 | assert not use_cache, 'use_cache is not supported'
71 |
72 | # Flash attention codes from
73 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
74 |
75 | # transform the data into the format required by flash attention
76 | qkv = torch.stack(
77 | [query_states, key_states, value_states], dim=2
78 | ) # [bsz, nh, 3, q_len, hd]
79 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
80 | # We have disabled _prepare_decoder_attention_mask in LlamaModel
81 | # the attention_mask should be the same as the key_padding_mask
82 | key_padding_mask = attention_mask
83 |
84 | if key_padding_mask is None:
85 | qkv = rearrange(qkv, 'b s ... -> (b s) ...')
86 | max_s = q_len
87 | cu_q_lens = torch.arange(
88 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
89 | )
90 | output = flash_attn_unpadded_qkvpacked_func(
91 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
92 | )
93 | output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
94 | else:
95 | nheads = qkv.shape[-2]
96 | x = rearrange(qkv, 'b s three h d -> b s (three h d)')
97 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
98 | x_unpad = rearrange(
99 | x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads
100 | )
101 | output_unpad = flash_attn_unpadded_qkvpacked_func(
102 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
103 | )
104 | output = rearrange(
105 | pad_input(
106 | rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices, bsz, q_len
107 | ),
108 | 'b s (h d) -> b s h d',
109 | h=nheads,
110 | )
111 | return self.o_proj(rearrange(output, 'b s h d -> b s (h d)')), None, None
112 |
113 |
114 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
115 | # requires the attention mask to be the same as the key_padding_mask
116 | def _prepare_decoder_attention_mask(
117 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
118 | ):
119 | # [bsz, seq_len]
120 | return attention_mask
121 |
122 |
123 | def forward_2(
124 | self,
125 | hidden_states: torch.Tensor,
126 | attention_mask: Optional[torch.Tensor] = None,
127 | position_ids: Optional[torch.LongTensor] = None,
128 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
129 | output_attentions: bool = False,
130 | use_cache: bool = False,
131 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
132 | bsz, q_len, _ = hidden_states.size()
133 |
134 | query_states = (
135 | self.q_proj(hidden_states)
136 | .view(bsz, q_len, self.num_heads, self.head_dim)
137 | .transpose(1, 2)
138 | )
139 | key_states = (
140 | self.k_proj(hidden_states)
141 | .view(bsz, q_len, self.num_heads, self.head_dim)
142 | .transpose(1, 2)
143 | )
144 | value_states = (
145 | self.v_proj(hidden_states)
146 | .view(bsz, q_len, self.num_heads, self.head_dim)
147 | .transpose(1, 2)
148 | )
149 |
150 | kv_seq_len = key_states.shape[-2]
151 | if past_key_value is not None:
152 | kv_seq_len += past_key_value[0].shape[-2]
153 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
154 | query_states, key_states = apply_rotary_pos_emb(
155 | query_states, key_states, cos, sin, position_ids
156 | )
157 |
158 | assert not output_attentions, 'output_attentions is not supported'
159 | assert not use_cache, 'use_cache is not supported'
160 | assert past_key_value is None, 'past_key_value is not supported'
161 |
162 | if past_key_value is not None:
163 | # reuse k, v, self_attention
164 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
165 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
166 |
167 | past_key_value = (key_states, value_states) if use_cache else None
168 | if self.training:
169 | attn_output = F.scaled_dot_product_attention(
170 | query_states, key_states, value_states, dropout_p=0.0, is_causal=True
171 | )
172 | attn_weights = None
173 | else:
174 | attn_weights = torch.matmul(
175 | query_states, key_states.transpose(2, 3)
176 | ) / math.sqrt(self.head_dim)
177 |
178 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
179 | raise ValueError(
180 | f'Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is'
181 | f' {attn_weights.size()}'
182 | )
183 |
184 | if attention_mask is not None:
185 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
186 | raise ValueError(
187 | f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
188 | )
189 | attn_weights = attn_weights + attention_mask
190 | attn_weights = torch.max(
191 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
192 | )
193 |
194 | # upcast attention to fp32
195 | attn_weights = nn.functional.softmax(
196 | attn_weights, dim=-1, dtype=torch.float32
197 | ).to(query_states.dtype)
198 | attn_output = torch.matmul(attn_weights, value_states)
199 |
200 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
201 | raise ValueError(
202 | f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
203 | f' {attn_output.size()}'
204 | )
205 |
206 | attn_output = attn_output.transpose(1, 2)
207 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
208 |
209 | attn_output = self.o_proj(attn_output)
210 |
211 | if not output_attentions:
212 | attn_weights = None
213 |
214 | return attn_output, attn_weights, past_key_value
215 |
216 |
217 | def replace_llama_attn_with_flash_attn():
218 | if hasattr(F, 'scaled_dot_product_attention'):
219 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_2
220 | else:
221 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
222 | _prepare_decoder_attention_mask
223 | )
224 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
225 |
--------------------------------------------------------------------------------
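`forward_2` above routes training through `F.scaled_dot_product_attention` with `is_causal=True`, while the eval branch computes masked softmax attention explicitly. A small self-contained check (toy shapes, not from the repository) that the two formulations agree:

import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, nh, q_len, hd = 2, 4, 8, 16
q, k, v = (torch.randn(bsz, nh, q_len, hd) for _ in range(3))

sdpa_out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)

# Explicit causal attention: scaled scores, upper-triangular -inf mask, softmax, weighted sum.
attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(hd)
causal_mask = torch.triu(torch.full((q_len, q_len), float('-inf')), diagonal=1)
attn = torch.softmax(attn + causal_mask, dim=-1)
manual_out = torch.matmul(attn, v)

print(torch.allclose(sdpa_out, manual_out, atol=1e-5))  # True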
/internvl/patch/llama_packed_training_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import torch
8 | from flash_attn.flash_attn_interface import flash_attn_varlen_func
9 | from transformers.models.llama.modeling_llama import (LLAMA_ATTENTION_CLASSES,
10 | LlamaFlashAttention2)
11 |
12 |
13 | # Modified from transformers.models.llama.modeling_llama.LlamaFlashAttention2
14 | class LlamaFlashAttention2ForPackedTraining(LlamaFlashAttention2):
15 |
16 | def _flash_attention_forward(
17 | self,
18 | query_states,
19 | key_states,
20 | value_states,
21 | attention_mask,
22 | query_length,
23 | dropout=0.0,
24 | softmax_scale=None,
25 | use_sliding_windows=False,
26 | ):
27 | """
28 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
29 | first unpad the input, then computes the attention scores and pad the final attention scores.
30 |
31 | Args:
32 | query_states (`torch.Tensor`):
33 | Input query states to be passed to Flash Attention API
34 | key_states (`torch.Tensor`):
35 | Input key states to be passed to Flash Attention API
36 | value_states (`torch.Tensor`):
37 | Input value states to be passed to Flash Attention API
38 | attention_mask (`torch.Tensor`):
39 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
40 | position of padding tokens and 1 for the position of non-padding tokens.
41 | dropout (`int`, *optional*):
42 | Attention dropout
43 | softmax_scale (`float`, *optional*):
44 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
45 | use_sliding_windows (`bool`, *optional*):
46 | Whether to activate sliding window attention.
47 | """
48 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
49 | query_states = query_states.squeeze(0)
50 | key_states = key_states.squeeze(0)
51 | value_states = value_states.squeeze(0)
52 | cu_seqlens = attention_mask.squeeze(0)
53 |
54 | with torch.no_grad():
55 | max_seqlen = max(
56 | [
57 | cu_seqlens[idx + 1] - cu_seqlens[idx]
58 | for idx in range(cu_seqlens.size(0) - 1)
59 | ]
60 | ).item()
61 |
62 | if not self._flash_attn_uses_top_left_mask:
63 | causal = self.is_causal
64 | else:
65 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
66 | causal = self.is_causal and query_length != 1
67 |
68 | # Decide whether to use SWA or not by layer index.
69 | if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
70 | use_sliding_windows = False
71 |
72 | if not use_sliding_windows:
73 | attn_output = flash_attn_varlen_func(
74 | q=query_states,
75 | k=key_states,
76 | v=value_states,
77 | cu_seqlens_q=cu_seqlens,
78 | cu_seqlens_k=cu_seqlens,
79 | max_seqlen_q=max_seqlen,
80 | max_seqlen_k=max_seqlen,
81 | dropout_p=dropout,
82 | softmax_scale=softmax_scale,
83 | causal=causal,
84 | )
85 | else:
86 | attn_output = flash_attn_varlen_func(
87 | q=query_states,
88 | k=key_states,
89 | v=value_states,
90 | cu_seqlens_q=cu_seqlens,
91 | cu_seqlens_k=cu_seqlens,
92 | max_seqlen_q=max_seqlen,
93 | max_seqlen_k=max_seqlen,
94 | dropout_p=dropout,
95 | softmax_scale=softmax_scale,
96 | causal=causal,
97 | window_size=(self.config.sliding_window, self.config.sliding_window),
98 | )
99 |
100 | query_states = query_states.unsqueeze(0)
101 | key_states = key_states.unsqueeze(0)
102 | value_states = value_states.unsqueeze(0)
103 | return attn_output
104 |
105 |
106 | def replace_llama_attention_class():
107 | LLAMA_ATTENTION_CLASSES['flash_attention_2'] = LlamaFlashAttention2ForPackedTraining
108 | print('Replace LLAMA_ATTENTION_CLASSES to support packed training!!')
109 |
--------------------------------------------------------------------------------
/internvl/patch/llama_rmsnorm_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import transformers
8 |
9 |
10 | def replace_llama_rmsnorm_with_fused_rmsnorm():
11 | try:
12 | from functools import partial
13 |
14 | from apex.normalization import FusedRMSNorm
15 |
16 | LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
17 | transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
18 | print(
19 | 'Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm'
20 | )
21 | except ImportError:
22 | # using the normal LlamaRMSNorm
23 | pass
24 | except Exception:
25 | print('discovered apex but it failed to load, falling back to LlamaRMSNorm')
26 | pass
27 |
--------------------------------------------------------------------------------
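For reference, a minimal sketch (an assumption, not copied from either library) of the RMSNorm operation that both `LlamaRMSNorm` and apex's `FusedRMSNorm` implement, using the same `eps=1e-6` as the patch above:

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then apply the learned scale.
    variance = x.float().pow(2).mean(-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return (weight * x_normed).to(x.dtype)

hidden = torch.randn(2, 5, 8)
weight = torch.ones(8)
print(rms_norm(hidden, weight).shape)  # torch.Size([2, 5, 8])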
/internvl/patch/pad_data_collator.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import numpy as np
8 | import torch
9 |
10 | IGNORE_INDEX = -100
11 |
12 |
13 | def pad_data_collator(features, pad_id=0):
14 |
15 | first = features[0]
16 | batch = {}
17 |
18 | batch_lens = [feat['input_ids'].shape for feat in features]
19 | max_item_length = max(batch_lens)[0]
20 | for idx in range(len(features)):
21 | feat = features[idx]
22 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
23 | temp_input_ids[: feat['input_ids'].shape[0]] = feat['input_ids']
24 | feat['input_ids'] = temp_input_ids
25 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
26 | temp_labels[: feat['labels'].shape[0]] = feat['labels']
27 | feat['labels'] = temp_labels
28 | feat['attention_mask'] = feat['input_ids'].ne(pad_id)
29 |
30 | # Special handling for labels.
31 | # Ensure that tensor is created with the correct type
32 | # (it should be automatically the case, but let's make sure of it.)
33 | if 'label' in first and first['label'] is not None:
34 | label = (
35 | first['label'].item()
36 | if isinstance(first['label'], torch.Tensor)
37 | else first['label']
38 | )
39 | dtype = torch.long if isinstance(label, int) else torch.float
40 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
41 | elif 'label_ids' in first and first['label_ids'] is not None:
42 | if isinstance(first['label_ids'], torch.Tensor):
43 | batch['labels'] = torch.stack([f['label_ids'] for f in features])
44 | else:
45 | dtype = (
46 | torch.long if isinstance(first['label_ids'][0], int) else torch.float
47 | )
48 | batch['labels'] = torch.tensor(
49 | [f['label_ids'] for f in features], dtype=dtype
50 | )
51 |
52 | # Handling of all other possible keys.
53 | # Again, we will use the first element to figure out which key/values are not None for this model.
54 | for k, v in first.items():
55 | if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
56 | if isinstance(v, torch.Tensor):
57 | batch[k] = torch.stack([f[k] for f in features])
58 | elif isinstance(v, np.ndarray):
59 | batch[k] = torch.tensor(np.stack([f[k] for f in features]))
60 | else:
61 | batch[k] = torch.tensor([f[k] for f in features])
62 | return batch
63 |
64 |
65 | def concat_pad_data_collator(features, max_item_length=None, pad_id=0):
66 |
67 | first = features[0]
68 | batch = {}
69 |
70 | batch_lens = [feat['input_ids'].shape for feat in features]
71 | max_item_length = max_item_length or max(batch_lens)[0]
72 | for idx in range(len(features)):
73 | feat = features[idx]
74 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
75 | temp_input_ids[: feat['input_ids'].shape[0]] = feat['input_ids']
76 | feat['input_ids'] = temp_input_ids
77 | temp_labels = torch.FloatTensor([IGNORE_INDEX] * max_item_length)
78 | temp_labels[: feat['labels'].shape[0]] = feat['labels']
79 | feat['labels'] = temp_labels
80 | feat['attention_mask'] = feat['input_ids'].ne(pad_id)
81 |
82 | if 'position_ids' in feat:
83 | temp_position_ids = torch.LongTensor([pad_id] * max_item_length)
84 | temp_position_ids[: feat['position_ids'].shape[0]] = feat['position_ids']
85 | feat['position_ids'] = temp_position_ids
86 |
87 | if 'loss_weight' in feat:
88 | temp_loss_weight = torch.FloatTensor([pad_id] * max_item_length)
89 | temp_loss_weight[: feat['loss_weight'].shape[0]] = feat['loss_weight']
90 | feat['loss_weight'] = temp_loss_weight
91 |
92 | # Special handling for labels.
93 | # Ensure that tensor is created with the correct type
94 | # (it should be automatically the case, but let's make sure of it.)
95 | if 'label' in first and first['label'] is not None:
96 | label = (
97 | first['label'].item()
98 | if isinstance(first['label'], torch.Tensor)
99 | else first['label']
100 | )
101 | dtype = torch.float
102 | batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
103 | elif 'label_ids' in first and first['label_ids'] is not None:
104 | if isinstance(first['label_ids'], torch.Tensor):
105 | batch['labels'] = torch.stack([f['label_ids'] for f in features])
106 | else:
107 | dtype = torch.float
108 | batch['labels'] = torch.tensor(
109 | [f['label_ids'] for f in features], dtype=dtype
110 | )
111 |
112 | # Handling of all other possible keys.
113 | # Again, we will use the first element to figure out which key/values are not None for this model.
114 | for k, v in first.items():
115 | if (
116 | k not in ('label', 'label_ids', 'pixel_values', 'image_flags')
117 | and v is not None
118 | and not isinstance(v, str)
119 | ):
120 | if isinstance(v, torch.Tensor):
121 | batch[k] = torch.stack([f[k] for f in features])
122 | elif isinstance(v, np.ndarray):
123 | batch[k] = torch.tensor(np.stack([f[k] for f in features]))
124 | else:
125 | batch[k] = torch.tensor([f[k] for f in features])
126 | if k in ('pixel_values', 'image_flags'):
127 | if isinstance(v, torch.Tensor):
128 | batch[k] = torch.concat([f[k] for f in features])
129 | elif isinstance(v, np.ndarray):
130 |                 batch[k] = torch.concat([torch.from_numpy(f[k]) for f in features])
131 | else:
132 | batch[k] = torch.concat([f[k] for f in features])
133 | return batch
134 |
135 |
136 | def dpo_concat_pad_data_collator(features, pad_id=0):
137 |
138 | first = features[0]
139 | batch = {}
140 |
141 | for prefix in ['chosen_', 'rejected_']:
142 | batch_lens = [feat[f'{prefix}input_ids'].shape[0] for feat in features]
143 | max_item_length = max(batch_lens)
144 | for idx in range(len(features)):
145 | feat = features[idx]
146 | temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
147 | temp_input_ids[: feat[f'{prefix}input_ids'].shape[0]] = feat[
148 | f'{prefix}input_ids'
149 | ]
150 | feat[f'{prefix}input_ids'] = temp_input_ids
151 | temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
152 | temp_labels[: feat[f'{prefix}labels'].shape[0]] = feat[f'{prefix}labels']
153 | feat[f'{prefix}labels'] = temp_labels
154 | feat[f'{prefix}attention_mask'] = feat[f'{prefix}input_ids'].ne(pad_id)
155 |
156 | # Handling of all other possible keys.
157 | # Again, we will use the first element to figure out which key/values are not None for this model.
158 | for k, v in first.items():
159 | if (
160 | k not in ('pixel_values', 'image_flags')
161 | and v is not None
162 | and not isinstance(v, str)
163 | ):
164 | if isinstance(v, torch.Tensor):
165 | batch[k] = torch.stack([f[k] for f in features])
166 | elif isinstance(v, np.ndarray):
167 | batch[k] = torch.tensor(np.stack([f[k] for f in features]))
168 | else:
169 | batch[k] = torch.tensor([f[k] for f in features])
170 | if k in ('pixel_values', 'image_flags'):
171 | if isinstance(v, torch.Tensor):
172 | batch[k] = torch.concat([f[k] for f in features])
173 | elif isinstance(v, np.ndarray):
174 |                 batch[k] = torch.concat([torch.from_numpy(f[k]) for f in features])
175 | else:
176 | batch[k] = torch.concat([f[k] for f in features])
177 | return batch
178 |
--------------------------------------------------------------------------------
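A minimal usage sketch (not part of the repository) for `concat_pad_data_collator` above: each sample's `input_ids` is right-padded with `pad_id`, `labels` is padded with `IGNORE_INDEX`, `attention_mask` is derived from the padding, and `pixel_values`/`image_flags` are concatenated along the first dimension rather than stacked. The feature contents below are illustrative assumptions.

```python
import torch

from internvl.patch.pad_data_collator import concat_pad_data_collator

# Two hypothetical samples of different lengths; pixel_values are dummy image tiles.
features = [
    {
        'input_ids': torch.tensor([101, 7, 8, 9]),
        'labels': torch.tensor([-100.0, 1.0, 0.0, 1.0]),
        'pixel_values': torch.randn(2, 3, 448, 448),  # 2 tiles
        'image_flags': torch.tensor([1, 1]),
    },
    {
        'input_ids': torch.tensor([101, 7]),
        'labels': torch.tensor([-100.0, 1.0]),
        'pixel_values': torch.randn(1, 3, 448, 448),  # 1 tile
        'image_flags': torch.tensor([1]),
    },
]

batch = concat_pad_data_collator(features, pad_id=0)
print(batch['input_ids'].shape)       # torch.Size([2, 4]) -- padded to the longest sample
print(batch['attention_mask'].shape)  # torch.Size([2, 4])
print(batch['pixel_values'].shape)    # torch.Size([3, 3, 448, 448]) -- tiles concatenated, not stacked
```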
/internvl/patch/phi3_packed_training_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import torch
8 | from flash_attn.flash_attn_interface import flash_attn_varlen_func
9 |
10 | from internvl.model.phi3.modeling_phi3 import (PHI3_ATTENTION_CLASSES,
11 | Phi3FlashAttention2)
12 |
13 |
14 | class Phi3FlashAttention2ForPackedTraining(Phi3FlashAttention2):
15 |
16 | def _flash_attention_forward(
17 | self,
18 | query_states,
19 | key_states,
20 | value_states,
21 | attention_mask,
22 | query_length,
23 | dropout=0.0,
24 | softmax_scale=None,
25 | use_sliding_windows=False,
26 | ):
27 | """
28 |         Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
29 |         the input is first unpadded, then the attention scores are computed and the final attention scores are padded back.
30 |
31 | Args:
32 | query_states (`torch.Tensor`):
33 | Input query states to be passed to Flash Attention API
34 | key_states (`torch.Tensor`):
35 | Input key states to be passed to Flash Attention API
36 | value_states (`torch.Tensor`):
37 | Input value states to be passed to Flash Attention API
38 | attention_mask (`torch.Tensor`):
39 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
40 | position of padding tokens and 1 for the position of non-padding tokens.
41 | dropout (`float`):
42 | Attention dropout
43 | softmax_scale (`float`, *optional*):
44 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
45 | use_sliding_windows (`bool`, *optional*):
46 | Whether to activate sliding window attention.
47 | """
48 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
49 | query_states = query_states.squeeze(0)
50 | key_states = key_states.squeeze(0)
51 | value_states = value_states.squeeze(0)
52 | cu_seqlens = attention_mask.squeeze(0)
53 |
54 | with torch.no_grad():
55 | max_seqlen = max(
56 | [
57 | cu_seqlens[idx + 1] - cu_seqlens[idx]
58 | for idx in range(cu_seqlens.size(0) - 1)
59 | ]
60 | ).item()
61 |
62 | if not self._flash_attn_uses_top_left_mask:
63 | causal = self.is_causal
64 | else:
65 |             # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
66 | causal = self.is_causal and query_length != 1
67 |
68 | # Decide whether to use SWA or not by layer index.
69 | if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
70 | use_sliding_windows = False
71 |
72 | if not use_sliding_windows:
73 | attn_output = flash_attn_varlen_func(
74 | q=query_states,
75 | k=key_states,
76 | v=value_states,
77 | cu_seqlens_q=cu_seqlens,
78 | cu_seqlens_k=cu_seqlens,
79 | max_seqlen_q=max_seqlen,
80 | max_seqlen_k=max_seqlen,
81 | dropout_p=dropout,
82 | softmax_scale=softmax_scale,
83 | causal=causal,
84 | )
85 | else:
86 | attn_output = flash_attn_varlen_func(
87 | q=query_states,
88 | k=key_states,
89 | v=value_states,
90 | cu_seqlens_q=cu_seqlens,
91 | cu_seqlens_k=cu_seqlens,
92 | max_seqlen_q=max_seqlen,
93 | max_seqlen_k=max_seqlen,
94 | dropout_p=dropout,
95 | softmax_scale=softmax_scale,
96 | causal=causal,
97 | window_size=(self.config.sliding_window, self.config.sliding_window),
98 | )
99 |
100 | query_states = query_states.unsqueeze(0)
101 | key_states = key_states.unsqueeze(0)
102 | value_states = value_states.unsqueeze(0)
103 | return attn_output
104 |
105 |
106 | def replace_phi3_attention_class():
107 | PHI3_ATTENTION_CLASSES['flash_attention_2'] = Phi3FlashAttention2ForPackedTraining
108 | print('Replace PHI3_ATTENTION_CLASSES to support packed training!!')
109 |
--------------------------------------------------------------------------------
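In the packed-training patch above, `attention_mask` does not carry a 0/1 padding mask; after `squeeze(0)` it is treated as the cumulative sequence lengths (`cu_seqlens`) that `flash_attn_varlen_func` needs to keep the packed sequences from attending across their boundaries. A small sketch of that layout, with illustrative lengths:

```python
import torch

# Three sequences of lengths 5, 3 and 4 packed into a single row of 12 tokens.
seq_lens = [5, 3, 4]
cu_seqlens = torch.tensor([0, 5, 8, 12], dtype=torch.int32)  # cumulative boundaries

# How the patched _flash_attention_forward derives max_seqlen from cu_seqlens:
max_seqlen = max(
    (cu_seqlens[i + 1] - cu_seqlens[i]).item() for i in range(cu_seqlens.size(0) - 1)
)
assert max_seqlen == 5

# query/key/value states arrive with a leading batch dimension of 1,
# i.e. (1, total_tokens, num_heads, head_dim), and are squeezed to
# (total_tokens, num_heads, head_dim) before calling flash_attn_varlen_func.
assert sum(seq_lens) == cu_seqlens[-1].item() == 12
```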
/internvl/patch/qwen2_packed_training_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import torch
8 | from flash_attn.flash_attn_interface import flash_attn_varlen_func
9 | from transformers.models.qwen2.modeling_qwen2 import (QWEN2_ATTENTION_CLASSES,
10 | Qwen2FlashAttention2)
11 |
12 |
13 | # Modified from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2
14 | class Qwen2FlashAttention2ForPackedTraining(Qwen2FlashAttention2):
15 |
16 | def _flash_attention_forward(
17 | self,
18 | query_states,
19 | key_states,
20 | value_states,
21 | attention_mask,
22 | query_length,
23 | dropout=0.0,
24 | softmax_scale=None,
25 | use_sliding_windows=False,
26 | ):
27 | """
28 |         Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
29 |         the input is first unpadded, then the attention scores are computed and the final attention scores are padded back.
30 |
31 | Args:
32 | query_states (`torch.Tensor`):
33 | Input query states to be passed to Flash Attention API
34 | key_states (`torch.Tensor`):
35 | Input key states to be passed to Flash Attention API
36 | value_states (`torch.Tensor`):
37 | Input value states to be passed to Flash Attention API
38 | attention_mask (`torch.Tensor`):
39 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
40 | position of padding tokens and 1 for the position of non-padding tokens.
41 |             dropout (`float`, *optional*):
42 | Attention dropout
43 | softmax_scale (`float`, *optional*):
44 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
45 | use_sliding_windows (`bool`, *optional*):
46 | Whether to activate sliding window attention.
47 | """
48 | assert query_states.size(0) == key_states.size(0) == value_states.size(0) == 1
49 | query_states = query_states.squeeze(0)
50 | key_states = key_states.squeeze(0)
51 | value_states = value_states.squeeze(0)
52 | cu_seqlens = attention_mask.squeeze(0)
53 |
54 | with torch.no_grad():
55 | max_seqlen = max(
56 | [
57 | cu_seqlens[idx + 1] - cu_seqlens[idx]
58 | for idx in range(cu_seqlens.size(0) - 1)
59 | ]
60 | ).item()
61 |
62 | if not self._flash_attn_uses_top_left_mask:
63 | causal = self.is_causal
64 | else:
65 |             # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
66 | causal = self.is_causal and query_length != 1
67 |
68 | # Decide whether to use SWA or not by layer index.
69 | if use_sliding_windows and self.layer_idx >= self.config.max_window_layers:
70 | use_sliding_windows = False
71 |
72 | if not use_sliding_windows:
73 | attn_output = flash_attn_varlen_func(
74 | q=query_states,
75 | k=key_states,
76 | v=value_states,
77 | cu_seqlens_q=cu_seqlens,
78 | cu_seqlens_k=cu_seqlens,
79 | max_seqlen_q=max_seqlen,
80 | max_seqlen_k=max_seqlen,
81 | dropout_p=dropout,
82 | softmax_scale=softmax_scale,
83 | causal=causal,
84 | )
85 | else:
86 | attn_output = flash_attn_varlen_func(
87 | q=query_states,
88 | k=key_states,
89 | v=value_states,
90 | cu_seqlens_q=cu_seqlens,
91 | cu_seqlens_k=cu_seqlens,
92 | max_seqlen_q=max_seqlen,
93 | max_seqlen_k=max_seqlen,
94 | dropout_p=dropout,
95 | softmax_scale=softmax_scale,
96 | causal=causal,
97 | window_size=(self.config.sliding_window, self.config.sliding_window),
98 | )
99 |
100 | query_states = query_states.unsqueeze(0)
101 | key_states = key_states.unsqueeze(0)
102 | value_states = value_states.unsqueeze(0)
103 | return attn_output
104 |
105 |
106 | def replace_qwen2_attention_class():
107 | QWEN2_ATTENTION_CLASSES['flash_attention_2'] = Qwen2FlashAttention2ForPackedTraining
108 | print('Replace QWEN2_ATTENTION_CLASSES to support packed training!!')
109 |
--------------------------------------------------------------------------------
/internvl/patch/train_dataloader_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import datasets
8 | import torch
9 | import transformers
10 | from torch.utils.data import DataLoader
11 | from transformers.trainer import is_datasets_available, seed_worker
12 |
13 |
14 | def get_train_dataloader(self) -> DataLoader:
15 | """
16 | Returns the training [`~torch.utils.data.DataLoader`].
17 |
18 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
19 | training if necessary) otherwise.
20 |
21 | Subclass and override this method if you want to inject some custom behavior.
22 | """
23 | if self.train_dataset is None:
24 | raise ValueError('Trainer: training requires a train_dataset.')
25 |
26 | train_dataset = self.train_dataset
27 | data_collator = self.data_collator
28 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
29 | train_dataset = self._remove_unused_columns(
30 | train_dataset, description='training'
31 | )
32 | else:
33 | data_collator = self._get_collator_with_removed_columns(
34 | data_collator, description='training'
35 | )
36 |
37 | dataloader_params = {
38 | 'batch_size': self._train_batch_size,
39 | 'collate_fn': data_collator,
40 | 'num_workers': self.args.dataloader_num_workers,
41 | 'pin_memory': self.args.dataloader_pin_memory,
42 | 'persistent_workers': self.args.dataloader_persistent_workers,
43 | }
44 |
45 | if not isinstance(train_dataset, torch.utils.data.IterableDataset):
46 | dataloader_params['sampler'] = self._get_train_sampler()
47 | dataloader_params['drop_last'] = self.args.dataloader_drop_last
48 | dataloader_params['worker_init_fn'] = seed_worker
49 |
50 | if self.args.use_packed_ds:
51 | return DataLoader(train_dataset, **dataloader_params)
52 | return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
53 |
54 |
55 | def replace_train_dataloader():
56 | transformers.Trainer.get_train_dataloader = get_train_dataloader
57 | # print('Replace train dataloader!!')
58 |
--------------------------------------------------------------------------------
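The patch above swaps in a `get_train_dataloader` that skips `accelerator.prepare` when `args.use_packed_ds` is set (a field this repository adds to its training arguments), so the packed dataset handles its own sharding. A minimal sketch of how it is meant to be applied; the Trainer construction is only indicated, not reproduced:

```python
import transformers

from internvl.patch.train_dataloader_patch import replace_train_dataloader

# Swap the method on transformers.Trainer before the training loop asks for a dataloader.
replace_train_dataloader()
assert transformers.Trainer.get_train_dataloader.__module__ == 'internvl.patch.train_dataloader_patch'

# trainer = transformers.Trainer(...)      # hypothetical Trainer construction
# loader = trainer.get_train_dataloader()  # plain DataLoader when args.use_packed_ds is True,
#                                          # accelerator-prepared DataLoader otherwise
```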
/internvl/patch/train_sampler_patch.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | from typing import List, Optional
8 |
9 | import torch
10 | import transformers
11 | from torch.utils.data import Dataset, Sampler
12 | from transformers.tokenization_utils_base import BatchEncoding
13 | from transformers.trainer import (LengthGroupedSampler, RandomSampler,
14 | has_length)
15 | from transformers.trainer_pt_utils import logger
16 |
17 |
18 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38
19 | def split_to_even_chunks(indices, lengths, num_chunks):
20 | """
21 |     Split a list of indices into `num_chunks` chunks of roughly equal total lengths.
22 | """
23 |
24 | if len(indices) % num_chunks != 0:
25 | return [indices[i::num_chunks] for i in range(num_chunks)]
26 |
27 | num_indices_per_chunk = len(indices) // num_chunks
28 |
29 | chunks = [[] for _ in range(num_chunks)]
30 | chunks_lengths = [0 for _ in range(num_chunks)]
31 | for index in indices:
32 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
33 | chunks[shortest_chunk].append(index)
34 | chunks_lengths[shortest_chunk] += lengths[index]
35 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
36 | chunks_lengths[shortest_chunk] = float('inf')
37 |
38 | return chunks
39 |
40 |
41 | # copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88
42 | def get_length_grouped_indices(
43 | lengths, batch_size, world_size, generator=None, merge=True
44 | ):
45 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
46 | indices = torch.randperm(len(lengths), generator=generator)
47 | megabatch_size = world_size * batch_size
48 | megabatches = [
49 | indices[i : i + megabatch_size].tolist()
50 | for i in range(0, len(lengths), megabatch_size)
51 | ]
52 | megabatches = [
53 | sorted(megabatch, key=lambda i: lengths[i], reverse=True)
54 | for megabatch in megabatches
55 | ]
56 | megabatches = [
57 | split_to_even_chunks(megabatch, lengths, world_size)
58 | for megabatch in megabatches
59 | ]
60 |
61 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
62 |
63 |
64 | # modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99
65 | class LengthGroupedSampler(Sampler):
66 | r"""
67 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
68 | keeping a bit of randomness.
69 | """
70 |
71 | def __init__(
72 | self,
73 | batch_size: int,
74 | world_size: int,
75 | dataset: Optional[Dataset] = None,
76 | lengths: Optional[List[int]] = None,
77 | model_input_name: Optional[str] = None,
78 | generator=None,
79 | ):
80 | if dataset is None and lengths is None:
81 | raise ValueError('One of dataset and lengths must be provided.')
82 |
83 | self.batch_size = batch_size
84 | if lengths is None:
85 | model_input_name = (
86 | model_input_name if model_input_name is not None else 'input_ids'
87 | )
88 | if (
89 | not (
90 | isinstance(dataset[0], dict)
91 | or isinstance(dataset[0], BatchEncoding)
92 | )
93 | or model_input_name not in dataset[0]
94 | ):
95 | raise ValueError(
96 | 'Can only automatically infer lengths for datasets whose items are dictionaries with an '
97 | f"'{model_input_name}' key."
98 | )
99 | lengths = [len(feature[model_input_name]) for feature in dataset]
100 | elif isinstance(lengths, torch.Tensor):
101 | logger.info(
102 | 'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...'
103 | )
104 | lengths = lengths.tolist()
105 | self.world_size = world_size
106 | self.lengths = lengths
107 | self.generator = generator
108 |
109 | def __len__(self):
110 | return len(self.lengths)
111 |
112 | def __iter__(self):
113 | indices = get_length_grouped_indices(
114 | self.lengths, self.batch_size, self.world_size, generator=self.generator
115 | )
116 | return iter(indices)
117 |
118 |
119 | # patch trainer
120 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
121 | if self.train_dataset is None or not has_length(self.train_dataset):
122 | return None
123 | # Build the sampler.
124 | if self.args.group_by_length:
125 | lengths = []
126 | for dataset in self.train_dataset.datasets:
127 | lengths = lengths + dataset.length
128 | model_input_name = (
129 | self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
130 | )
131 | return LengthGroupedSampler(
132 | self.args.train_batch_size,
133 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
134 | # self.args.train_batch_size * self.args.gradient_accumulation_steps,
135 | dataset=self.train_dataset,
136 | lengths=lengths,
137 | model_input_name=model_input_name,
138 | )
139 | else:
140 | return RandomSampler(self.train_dataset)
141 |
142 |
143 | def replace_train_sampler():
144 | transformers.Trainer._get_train_sampler = _get_train_sampler
145 | # print('Replace train sampler!!')
146 |
--------------------------------------------------------------------------------
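A small sketch (illustrative lengths, not repository code) of the two helpers defined above: `get_length_grouped_indices` shuffles the indices, slices them into megabatches of `world_size * batch_size`, sorts each megabatch by length, and `split_to_even_chunks` then assigns indices greedily to the per-rank chunk with the smallest running total length.

```python
import torch

from internvl.patch.train_sampler_patch import (get_length_grouped_indices,
                                                split_to_even_chunks)

# Toy lengths for 8 samples; batch_size=2, world_size=2 -> megabatches of 4 indices.
lengths = [5, 1, 9, 3, 7, 2, 8, 4]
indices = get_length_grouped_indices(
    lengths, batch_size=2, world_size=2, generator=torch.Generator().manual_seed(0)
)
assert sorted(indices) == list(range(len(lengths)))  # a permutation of all indices

# Within one megabatch, each index goes to the chunk with the smallest total length so far.
chunks = split_to_even_chunks([2, 6, 4, 0], lengths, num_chunks=2)
print(chunks)  # [[2, 0], [6, 4]] -- total lengths 14 and 15
```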
/internvl/train/__init__.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
--------------------------------------------------------------------------------
/internvl/train/constants.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InternVL
3 | # Copyright (c) 2024 OpenGVLab
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
8 | IMG_START_TOKEN = '<img>'
9 | IMG_END_TOKEN = '</img>'
10 | QUAD_START_TOKEN = '<quad>'
11 | QUAD_END_TOKEN = '</quad>'
12 | REF_START_TOKEN = '['
13 | REF_END_TOKEN = ']'
14 | BOX_START_TOKEN = '<box>'
15 | BOX_END_TOKEN = '</box>'
16 | PRM_TOKEN = ''
17 | REWARD_TOKENS = ['Yes', 'No']
18 | IMAGENET_MEAN = (0.485, 0.456, 0.406)
19 | IMAGENET_STD = (0.229, 0.224, 0.225)
20 | CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)
21 | CLIP_STD = (0.2686295, 0.2613025, 0.2757711)
22 | SIGLIP_MEAN = (0.5, 0.5, 0.5)
23 | SIGLIP_STD = (0.5, 0.5, 0.5)
24 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | -r requirements/internvl_chat.txt
2 | -r requirements/streamlit_demo.txt
3 | -r requirements/classification.txt
4 | -r requirements/segmentation.txt
5 |
--------------------------------------------------------------------------------
/requirements/classification.txt:
--------------------------------------------------------------------------------
1 | gdown
2 | termcolor
3 | yacs
4 |
--------------------------------------------------------------------------------
/requirements/clip_benchmark.txt:
--------------------------------------------------------------------------------
1 | open_clip_torch>=0.2.1
2 | opencv-python
3 | peft>=0.6.2
4 | protobuf
5 | pycocoevalcap
6 | pyyaml
7 | scikit-learn>=1.0,<2
8 | scikit-learn
9 | scipy
10 | task_adaptation
11 | tensorflow==2.11.0
12 | termcolor
13 | tqdm>=2
14 | transformers>=4.32.0
15 | webdataset>=0.2.31
16 | yacs
17 |
--------------------------------------------------------------------------------
/requirements/internvl_chat.txt:
--------------------------------------------------------------------------------
1 | accelerate<1
2 | bitsandbytes==0.42.0
3 | decord
4 | deepspeed>=0.13.5
5 | einops==0.6.1
6 | einops-exts==0.0.4
7 | huggingface_hub
8 | imageio
9 | numpy==1.26.4
10 | opencv-python
11 | orjson
12 | peft==0.10.0
13 | pycocoevalcap
14 | pyyaml
15 | scikit-learn>=1.2.2
16 | scipy
17 | sentencepiece==0.1.99
18 | shortuuid
19 | tensorboardX
20 | termcolor
21 | timm==0.9.12
22 | tokenizers==0.15.1
23 | torch>=2
24 | torchvision>=0.15
25 | tqdm
26 | transformers==4.37.2
27 | yacs
28 |
--------------------------------------------------------------------------------
/requirements/segmentation.txt:
--------------------------------------------------------------------------------
1 | future
2 | importlib_metadata
3 | mmcv-full==1.6.2
4 | mmsegmentation==0.30.0
5 | openmim
6 | ordered-set
7 | platformdirs
8 | tensorboard
9 | tomli
10 | yapf==0.40.1
11 |
--------------------------------------------------------------------------------
/requirements/streamlit_demo.txt:
--------------------------------------------------------------------------------
1 | fastapi
2 | gradio==3.35.2
3 | gradio_client==0.2.9
4 | httpx==0.24.0
5 | markdown2[all]
6 | pydantic
7 | requests
8 | streamlit
9 | streamlit-image-select
10 | uvicorn
11 |
--------------------------------------------------------------------------------
/shell/internvl2.5/2nd_finetune/internvl2_5_38b_dynamic_res_2nd_finetune_full_prm.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | GPUS=${GPUS:-8}
4 | BATCH_SIZE=${BATCH_SIZE:-512}
5 | PER_DEVICE_BATCH_SIZE=${PER_DEVICE_BATCH_SIZE:-2}
6 | GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS))
7 |
8 |
9 | export PYTHONPATH="${PYTHONPATH}:$(pwd)"
10 | export MASTER_PORT=34229
11 | export TF_CPP_MIN_LOG_LEVEL=3
12 | export LAUNCHER=pytorch
13 |
14 | OUTPUT_DIR='work_dirs/internvl_chat_v2_5/prm-8b'
15 |
16 | if [ ! -d "$OUTPUT_DIR" ]; then
17 | mkdir -p "$OUTPUT_DIR"
18 | fi
19 |
20 | torchrun \
21 | --nnodes=${NNODES} \
22 | --node_rank=${NODE_RANK} \
23 | --master_addr=${MASTER_ADDR} \
24 | --nproc_per_node=${NPROC_PER_NODE} \
25 | --master_port=${MASTER_PORT} \
26 | internvl/train/internvl_chat_finetune.py \
27 | --model_name_or_path "/path/to/model" \
28 | --conv_style "internvl2_5" \
29 | --use_fast_tokenizer False \
30 | --output_dir ${OUTPUT_DIR} \
31 | --meta_path "./shell/data/meta.json" \
32 | --overwrite_output_dir True \
33 | --force_image_size 448 \
34 | --max_dynamic_patch 6 \
35 | --down_sample_ratio 0.5 \
36 | --drop_path_rate 0.4 \
37 | --freeze_llm False \
38 | --freeze_mlp False \
39 | --freeze_backbone True \
40 | --vision_select_layer -1 \
41 | --dataloader_num_workers 4 \
42 | --bf16 True \
43 | --num_train_epochs 3 \
44 | --per_device_train_batch_size ${PER_DEVICE_BATCH_SIZE} \
45 | --gradient_accumulation_steps ${GRADIENT_ACC} \
46 | --save_strategy "epoch" \
47 | --save_total_limit 5 \
48 | --learning_rate 2e-6 \
49 | --weight_decay 0.05 \
50 | --warmup_ratio 0.03 \
51 | --lr_scheduler_type "cosine" \
52 | --logging_steps 1 \
53 | --max_seq_length 8192 \
54 | --do_train True \
55 | --grad_checkpoint True \
56 | --group_by_length True \
57 | --dynamic_image_size True \
58 | --use_thumbnail True \
59 | --ps_version 'v2' \
60 | --deepspeed "zero_stage1_config.json" \
61 | --report_to "tensorboard" \
62 | 2>&1 | tee -a "${OUTPUT_DIR}/training_log.txt"
63 |
--------------------------------------------------------------------------------
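For the defaults above, the derived values work out as follows (a quick check that mirrors the shell arithmetic in Python, assuming `GPUS` counts all GPUs across the participating nodes):

```python
# Defaults from the script header.
GPUS, BATCH_SIZE, PER_DEVICE_BATCH_SIZE = 8, 512, 2

# GRADIENT_ACC=$((BATCH_SIZE / PER_DEVICE_BATCH_SIZE / GPUS))
GRADIENT_ACC = BATCH_SIZE // PER_DEVICE_BATCH_SIZE // GPUS
assert GRADIENT_ACC == 32

# Effective global batch size seen by the optimizer each step.
assert PER_DEVICE_BATCH_SIZE * GPUS * GRADIENT_ACC == BATCH_SIZE == 512
```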
/zero_stage1_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 1,
4 | "allgather_partitions": true,
5 | "allgather_bucket_size": 1e9,
6 | "overlap_comm": true,
7 | "reduce_scatter": true,
8 | "reduce_bucket_size": 1e9,
9 | "contiguous_gradients": true
10 | },
11 | "fp16": {
12 | "enabled": "auto",
13 | "auto_cast": true,
14 | "loss_scale": 0,
15 | "initial_scale_power": 32,
16 | "loss_scale_window": 1000,
17 | "hysteresis": 2,
18 | "min_loss_scale": 1
19 | },
20 | "bf16": {
21 | "enabled": "auto"
22 | },
23 | "optimizer": {
24 | "type": "AdamW",
25 | "params": {
26 | "lr": "auto",
27 | "betas": [
28 | 0.9,
29 | 0.999
30 | ],
31 | "eps": 1e-8,
32 | "weight_decay": "auto"
33 | }
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 2000,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": true
41 | }
42 |
--------------------------------------------------------------------------------
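The `"auto"` entries in zero_stage1_config.json above (and in the other zero_stage*_config.json files that follow) are resolved by the Hugging Face Trainer from the corresponding `TrainingArguments` when the DeepSpeed engine is initialized: learning rate, weight decay, gradient accumulation, gradient clipping, micro/global batch sizes, and the fp16/bf16 flags. A hedged sketch of that mapping, with argument values mirroring the fine-tuning script above:

```python
from transformers import TrainingArguments

# "auto" fields in zero_stage1_config.json are filled in from these arguments
# when the Trainer builds its DeepSpeed config.
args = TrainingArguments(
    output_dir='work_dirs/internvl_chat_v2_5/prm-8b',
    deepspeed='zero_stage1_config.json',
    learning_rate=2e-6,
    weight_decay=0.05,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=32,
    bf16=True,
)
```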
/zero_stage2_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 2,
4 | "allgather_partitions": true,
5 | "allgather_bucket_size": 1e8,
6 | "overlap_comm": true,
7 | "reduce_scatter": true,
8 | "reduce_bucket_size": 1e8,
9 | "contiguous_gradients": true
10 | },
11 | "fp16": {
12 | "enabled": "auto",
13 | "auto_cast": true,
14 | "loss_scale": 0,
15 | "initial_scale_power": 32,
16 | "loss_scale_window": 1000,
17 | "hysteresis": 2,
18 | "min_loss_scale": 1
19 | },
20 | "bf16": {
21 | "enabled": "auto"
22 | },
23 | "optimizer": {
24 | "type": "AdamW",
25 | "params": {
26 | "lr": "auto",
27 | "betas": [
28 | 0.9,
29 | 0.999
30 | ],
31 | "eps": 1e-8,
32 | "weight_decay": "auto"
33 | }
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 2000,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
42 |
--------------------------------------------------------------------------------
/zero_stage3_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "overlap_comm": true,
5 | "contiguous_gradients": true,
6 | "sub_group_size": 1e9,
7 | "reduce_bucket_size": 1e9,
8 | "stage3_prefetch_bucket_size": 1e9,
9 | "stage3_param_persistence_threshold": 1e7,
10 | "stage3_max_live_parameters": 1e9,
11 | "stage3_max_reuse_distance": 1e9,
12 | "stage3_gather_16bit_weights_on_model_save": true
13 | },
14 | "fp16": {
15 | "enabled": "auto",
16 | "auto_cast": true,
17 | "loss_scale": 0,
18 | "initial_scale_power": 32,
19 | "loss_scale_window": 1000,
20 | "hysteresis": 2,
21 | "min_loss_scale": 1
22 | },
23 | "bf16": {
24 | "enabled": "auto"
25 | },
26 | "optimizer": {
27 | "type": "AdamW",
28 | "params": {
29 | "lr": "auto",
30 | "betas": [
31 | 0.9,
32 | 0.999
33 | ],
34 | "eps": 1e-8,
35 | "weight_decay": "auto"
36 | }
37 | },
38 | "gradient_accumulation_steps": "auto",
39 | "gradient_clipping": "auto",
40 | "steps_per_print": 2000,
41 | "train_batch_size": "auto",
42 | "train_micro_batch_size_per_gpu": "auto",
43 | "wall_clock_breakdown": true
44 | }
45 |
--------------------------------------------------------------------------------
/zero_stage3_config_100b.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "overlap_comm": true,
5 | "contiguous_gradients": true,
6 | "sub_group_size": 1e9,
7 | "reduce_bucket_size": 1e9,
8 | "stage3_prefetch_bucket_size": 1e9,
9 | "stage3_param_persistence_threshold": 1e4,
10 | "stage3_max_live_parameters": 1e9,
11 | "stage3_max_reuse_distance": 1e9,
12 | "stage3_gather_16bit_weights_on_model_save": true
13 | },
14 | "fp16": {
15 | "enabled": "auto",
16 | "auto_cast": true,
17 | "loss_scale": 0,
18 | "initial_scale_power": 32,
19 | "loss_scale_window": 1000,
20 | "hysteresis": 2,
21 | "min_loss_scale": 1
22 | },
23 | "bf16": {
24 | "enabled": "auto"
25 | },
26 | "optimizer": {
27 | "type": "AdamW",
28 | "params": {
29 | "lr": "auto",
30 | "betas": [
31 | 0.9,
32 | 0.999
33 | ],
34 | "eps": 1e-8,
35 | "weight_decay": "auto"
36 | }
37 | },
38 | "gradient_accumulation_steps": "auto",
39 | "gradient_clipping": "auto",
40 | "steps_per_print": 2000,
41 | "train_batch_size": "auto",
42 | "train_micro_batch_size_per_gpu": "auto",
43 | "wall_clock_breakdown": true
44 | }
45 |
--------------------------------------------------------------------------------
/zero_stage3_config_100b_1e8.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "overlap_comm": true,
5 | "contiguous_gradients": true,
6 | "sub_group_size": 1e8,
7 | "reduce_bucket_size": 1e8,
8 | "stage3_prefetch_bucket_size": 1e8,
9 | "stage3_param_persistence_threshold": 1e4,
10 | "stage3_max_live_parameters": 1e9,
11 | "stage3_max_reuse_distance": 1e9,
12 | "stage3_gather_16bit_weights_on_model_save": true
13 | },
14 | "fp16": {
15 | "enabled": "auto",
16 | "auto_cast": true,
17 | "loss_scale": 0,
18 | "initial_scale_power": 32,
19 | "loss_scale_window": 1000,
20 | "hysteresis": 2,
21 | "min_loss_scale": 1
22 | },
23 | "bf16": {
24 | "enabled": "auto"
25 | },
26 | "optimizer": {
27 | "type": "AdamW",
28 | "params": {
29 | "lr": "auto",
30 | "betas": [
31 | 0.9,
32 | 0.999
33 | ],
34 | "eps": 1e-8,
35 | "weight_decay": "auto"
36 | }
37 | },
38 | "gradient_accumulation_steps": "auto",
39 | "gradient_clipping": "auto",
40 | "steps_per_print": 2000,
41 | "train_batch_size": "auto",
42 | "train_micro_batch_size_per_gpu": "auto",
43 | "wall_clock_breakdown": true
44 | }
45 |
--------------------------------------------------------------------------------
/zero_stage3_config_34b.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "overlap_comm": true,
5 | "contiguous_gradients": true,
6 | "sub_group_size": 1e9,
7 | "reduce_bucket_size": 1e9,
8 | "stage3_prefetch_bucket_size": 1e9,
9 | "stage3_param_persistence_threshold": 1e5,
10 | "stage3_max_live_parameters": 1e9,
11 | "stage3_max_reuse_distance": 1e9,
12 | "stage3_gather_16bit_weights_on_model_save": true
13 | },
14 | "fp16": {
15 | "enabled": "auto",
16 | "auto_cast": true,
17 | "loss_scale": 0,
18 | "initial_scale_power": 32,
19 | "loss_scale_window": 1000,
20 | "hysteresis": 2,
21 | "min_loss_scale": 1
22 | },
23 | "bf16": {
24 | "enabled": "auto"
25 | },
26 | "optimizer": {
27 | "type": "AdamW",
28 | "params": {
29 | "lr": "auto",
30 | "betas": [
31 | 0.9,
32 | 0.999
33 | ],
34 | "eps": 1e-8,
35 | "weight_decay": "auto"
36 | }
37 | },
38 | "gradient_accumulation_steps": "auto",
39 | "gradient_clipping": "auto",
40 | "steps_per_print": 2000,
41 | "train_batch_size": "auto",
42 | "train_micro_batch_size_per_gpu": "auto",
43 | "wall_clock_breakdown": true
44 | }
45 |
--------------------------------------------------------------------------------
/zero_stage3_config_70b.json:
--------------------------------------------------------------------------------
1 | {
2 | "zero_optimization": {
3 | "stage": 3,
4 | "overlap_comm": true,
5 | "contiguous_gradients": true,
6 | "sub_group_size": 1e9,
7 | "reduce_bucket_size": 1e9,
8 | "stage3_prefetch_bucket_size": 1e9,
9 | "stage3_param_persistence_threshold": 1e5,
10 | "stage3_max_live_parameters": 1e9,
11 | "stage3_max_reuse_distance": 1e9,
12 | "stage3_gather_16bit_weights_on_model_save": true
13 | },
14 | "fp16": {
15 | "enabled": "auto",
16 | "auto_cast": true,
17 | "loss_scale": 0,
18 | "initial_scale_power": 32,
19 | "loss_scale_window": 1000,
20 | "hysteresis": 2,
21 | "min_loss_scale": 1
22 | },
23 | "bf16": {
24 | "enabled": "auto"
25 | },
26 | "optimizer": {
27 | "type": "AdamW",
28 | "params": {
29 | "lr": "auto",
30 | "betas": [
31 | 0.9,
32 | 0.999
33 | ],
34 | "eps": 1e-8,
35 | "weight_decay": "auto"
36 | }
37 | },
38 | "gradient_accumulation_steps": "auto",
39 | "gradient_clipping": "auto",
40 | "steps_per_print": 2000,
41 | "train_batch_size": "auto",
42 | "train_micro_batch_size_per_gpu": "auto",
43 | "wall_clock_breakdown": true
44 | }
45 |
--------------------------------------------------------------------------------