├── .gitignore
├── README.md
├── asset
│   ├── chartmoe-logo.jpg
│   ├── teaser.png
│   └── train_pipeline.png
├── chartmoe
│   ├── __init__.py
│   ├── eval_ChartQA.py
│   ├── eval_MME.ipynb
│   ├── generation_utils.py
│   ├── train
│   │   ├── README.md
│   │   ├── chartmoe_construction.py
│   │   ├── chartmoe_trainer.py
│   │   ├── data
│   │   │   ├── code_align.txt
│   │   │   ├── json_align.txt
│   │   │   ├── sft.txt
│   │   │   └── table_align.txt
│   │   ├── data_mix.py
│   │   ├── ds_config_zero2.json
│   │   ├── mlp_moe.py
│   │   ├── moe_construction.py
│   │   ├── scripts
│   │   │   ├── chartmoe_construction.sh
│   │   │   ├── chartmoe_data_download.py
│   │   │   ├── chartmoe_download.py
│   │   │   ├── code_align.sh
│   │   │   ├── internlm_xc2_download.py
│   │   │   ├── json_align.sh
│   │   │   ├── moe_construction.sh
│   │   │   ├── multi_align.sh
│   │   │   ├── sft.sh
│   │   │   └── table_align.sh
│   │   └── train.py
│   └── utils
│       └── custom_path.py
├── examples
│   ├── bar2.png
│   ├── bar2_highlight.png
│   ├── line.png
│   ├── line3.png
│   ├── line3_edit.png
│   ├── pie1-to-bar.png
│   └── pie1.png
├── gradio_demo.py
├── gradio_demo_pics
│   ├── gradio_demo1.jpg
│   ├── gradio_demo2.jpg
│   ├── gradio_demo3.jpg
│   ├── robot.png
│   └── user.png
├── quickstart.py
├── requirements.txt
├── setup.py
└── vis_tokens.ipynb
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
105 | __pypackages__/
106 |
107 | # Celery stuff
108 | celerybeat-schedule
109 | celerybeat.pid
110 |
111 | # SageMath parsed files
112 | *.sage.py
113 |
114 | # Environments
115 | .env
116 | .venv
117 | env/
118 | venv/
119 | ENV/
120 | env.bak/
121 | venv.bak/
122 |
123 | # Spyder project settings
124 | .spyderproject
125 | .spyproject
126 |
127 | # Rope project settings
128 | .ropeproject
129 |
130 | # mkdocs documentation
131 | /site
132 |
133 | # mypy
134 | .mypy_cache/
135 | .dmypy.json
136 | dmypy.json
137 |
138 | # Pyre type checker
139 | .pyre/
140 |
141 | # pytype static type analyzer
142 | .pytype/
143 |
144 | # Cython debug symbols
145 | cython_debug/
146 |
147 | # project-specific
148 | output/
149 | debug*/
150 | results/
151 |
152 | # cache root
153 | cache/
154 |
155 | # DS_Store
156 | **/.DS_Store
157 |
158 | # pycharm .idea
159 | .idea/
160 |
161 | # wandb
162 | wandb/
163 |
164 | # training dirs
165 | chartmoe/train/debug/
166 | chartmoe/train/wandb/
167 | chartmoe/train/output/
168 | chartmoe/train/logs/
169 | chartmoe/train/ckpt/*
170 | chartmoe/train/data/*.json
171 | chartmoe/train/data/ChartMoE-Align
172 | chartmoe/train/data/SFT
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # ChartMoE
3 |
4 | **Mixture of Diversely Aligned Expert Connector for Chart Understanding**
5 |
6 | [Zhengzhuo Xu](https://github.com/XuZhengzhuo)<sup>1,2*</sup>, [Bowen Qu](https://github.com/Coobiw)<sup>1,3*</sup>, Yiyan Qi<sup>1*</sup>, Sinan Du<sup>2</sup>, Chengjin Xu<sup>1</sup>, Chun Yuan<sup>2</sup>, Jian Guo<sup>1,4</sup>
7 |
8 | <sup>1</sup> International Digital Economy Academy (IDEA),
9 | <sup>2</sup> Tsinghua University,
10 | <sup>3</sup> Peking University,
11 |
12 | <sup>4</sup> Hong Kong University of Science and Technology, Guangzhou
13 |
14 | **ICLR 2025 Oral**
15 |
16 | (\* equal contribution)
17 |
18 | [Paper (arXiv)](https://arxiv.org/abs/2409.03277)
19 | [Project Page](https://chartmoe.github.io/)
20 | [Model (Hugging Face)](https://huggingface.co/IDEA-FinAI/chartmoe)
21 | [Data (Hugging Face)](https://huggingface.co/datasets/Coobiw/ChartMoE-Data)
22 | [Zhihu Blog](https://zhuanlan.zhihu.com/p/31634026232)
23 | [WeChat Article](https://mp.weixin.qq.com/s/9anQbcCahVLnXhNj7aU48Q)
24 | [GitHub Issues](https://github.com/IDEA-FinAI/ChartMoE/issues)
25 | [GitHub Issues (closed)](https://github.com/IDEA-FinAI/ChartMoE/issues)
26 |
27 | *If you have any questions, feel free to contact [📧](mailto:brian.bw.qu@gmail.com).*
28 |
29 |
30 |
31 | 
32 |
33 | **ChartMoE** is a multimodal large language model with a Mixture-of-Experts connector for advanced chart 1) understanding, 2) replotting, 3) editing, 4) highlighting, and 5) transformation.
34 |
35 | ## News
36 | - 2025.3.6: A reproduction of the diversely aligned MoE connector is released at [🤗HF Link](https://huggingface.co/Coobiw/ChartMoE-Aligned-Connector/tree/main/moe_aligned).
37 | - 2025.2.16: ChartMoE-Data has been released at [🤗](https://huggingface.co/datasets/Coobiw/ChartMoE-Data). Please download it according to [our instructions](#download-and-organize-the-chartmoe-data).
38 | - 2025.2.15: Training code and recipes are released! Please refer to [📖](chartmoe/train/)!
39 | - 2025.2.11: 🎉🎉🎉 ChartMoE is selected as **ICLR2025 Oral(1.8%)**!
40 | - 2025.1.23: 🎉🎉🎉 ChartMoE is accepted by **ICLR2025**!
41 | - 2024.9.10: We release ChartMoE!
42 |
43 | ## Training of ChartMoE
44 | Please refer to [📖training readme](chartmoe/train/)!
45 |
46 | ## Download and Organize the ChartMoE-Data
47 | [🤗ChartMoE Data](https://huggingface.co/datasets/Coobiw/ChartMoE-Data) has been released! You can download it by running:
48 |
49 | ```bash
50 | cd chartmoe/train
51 | python scripts/chartmoe_data_download.py
52 | ```
53 | Datasets will appear at `chartmoe/train/data`.
54 |
55 | Then, please unzip these two files.
56 | ```bash
57 | unzip ChartMoE-Align.zip
58 | unzip SFT.zip
59 | ```
60 |
61 | Additionally, note that `ChartY_replot` in `ChartMoE-Align` contains higher-quality data with bilingual texts! Sampling more from `ChartY_replot` may be a good choice.
62 |
63 | ## Installation
64 | **Step 1.** Create a conda environment and activate it.
65 |
66 | ```bash
67 | conda create -n chartmoe_env python=3.9
68 | conda activate chartmoe_env
69 | ```
70 |
71 | **Step 2.** Install PyTorch (We use PyTorch 2.1.0 / CUDA 12.1)
72 |
73 | ```bash
74 | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
75 | ```
76 |
77 | **Step 3.** Install the required packages
78 |
79 | ```bash
80 | pip install -r requirements.txt
81 | ```
82 |
83 | **Step 4.** Install ChartMoE as an editable package
84 |
85 | ```bash
86 | pip install -e .
87 | ```
88 |
89 | **Step 5.** (Optional) Install Flash-Attn (CUDA > 11.7)
90 |
91 | ```bash
92 | pip install flash-attn==2.7.0.post2
93 | ```
94 |
95 | *Flash-Attn brings ~30% acceleration in training and ~20% in evaluation in our experiments.*
96 |
97 | p.s.: If you cannot install `flash-attn`, please set `attn_implementation` to `eager` in ChartMoE's [`config.json`](https://huggingface.co/IDEA-FinAI/chartmoe/blob/main/config.json#L10).
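If you prefer to patch the downloaded checkpoint programmatically, here is a minimal sketch (assuming the checkpoint was downloaded to `chartmoe/train/ckpt/chartmoe`, as in the Quick Start below):

```python
# Minimal sketch: fall back to eager attention when flash-attn is unavailable.
# The checkpoint path is an assumption; adjust it to where you downloaded ChartMoE.
import json

config_path = "chartmoe/train/ckpt/chartmoe/config.json"

with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

config["attn_implementation"] = "eager"  # as suggested above when flash-attn is not installed

with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2, ensure_ascii=False)
```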
98 |
99 | ## Quick Start
100 | ### Huggingface Download Script of ChartMoE
101 |
102 | *Note: I added `flash-attn` support for ChartMoE on Feb. 15. If you downloaded ChartMoE before this date, you can re-download it for acceleration.*
103 |
104 | Run:
105 | ```bash
106 | cd chartmoe/train
107 | python scripts/chartmoe_download.py
108 | ```
109 | Then, ChartMoE will appear at `chartmoe/train/ckpt/chartmoe`.
110 |
111 | ### Customize the weight path of ChartMoE
112 |
113 | Set your own [ChartMoE_HF_PATH](https://github.com/Coobiw/ChartMoE/tree/master/chartmoe/utils/custom_path.py#L2). I suggest using the absolute path of `chartmoe/train/ckpt/chartmoe`.
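For reference, `chartmoe/utils/custom_path.py` only needs to export a few path constants that the rest of the code imports. A minimal sketch with placeholder paths (edit them for your machine):

```python
# chartmoe/utils/custom_path.py -- placeholder paths, adjust to your setup
# Weight path of ChartMoE (the absolute path of chartmoe/train/ckpt/chartmoe is recommended)
ChartMoE_HF_PATH = "/abs/path/to/ChartMoE/chartmoe/train/ckpt/chartmoe"

# ChartQA paths used by chartmoe/eval_ChartQA.py (see the Evaluation section below)
ChartQA_ROOT = "/abs/path/to/ChartQA/"                    # expects test/test_human.json and test/test_augmented.json under it
ChartQA_TEST_IMG_ROOT = "/abs/path/to/ChartQA/test/png/"  # directory containing the test images
```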
114 |
115 | ### Code Demo
116 |
117 | ```python
118 | from chartmoe import ChartMoE_Robot
119 | import torch
120 |
121 | robot = ChartMoE_Robot()
122 | image_path = "examples/bar2.png"
123 | question = "Redraw the chart with python matplotlib, giving the code to highlight the column corresponding to the year in which the student got the highest score (painting it red). Please keep the same colors and legend as the input chart."
124 |
125 | history = ""
126 | with torch.cuda.amp.autocast():
127 | response, history = robot.chat(image_path=image_path, question=question, history=history)
128 |
129 | print(response)
130 | ```
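`robot.chat` also returns the updated `history`, so you can keep the conversation going by passing it back in. A short follow-up turn (the question text is just an example):

```python
# Second turn: reuse the history returned by the first call
with torch.cuda.amp.autocast():
    response, history = robot.chat(
        image_path=image_path,
        question="Now give the code again, but highlight that column in green instead.",
        history=history,
    )

print(response)
```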
131 |
132 | ## Evaluation
133 |
134 | ### ChartQA
135 | **Customize the path of ChartQA:**
136 |
137 | Set your own [ChartQA_ROOT](https://github.com/Coobiw/ChartMoE/tree/master/chartmoe/utils/custom_path.py#L5) (containing `test_human.json` and `test_augmented.json`) and [ChartQA_TEST_IMG_ROOT](https://github.com/Coobiw/ChartMoE/tree/master/chartmoe/utils/custom_path.py#L6) (containing the test images).
138 |
139 | **w/ PoT:**
140 |
141 | ```bash
142 | CUDA_VISIBLE_DEVICES=0 python chartmoe/eval_ChartQA.py --save_path ./results/chartqa_results_pot --pot
143 | ```
144 |
145 | **w/o PoT:**
146 |
147 | ```bash
148 | CUDA_VISIBLE_DEVICES=0 python chartmoe/eval_ChartQA.py --save_path ./results/chartqa_results
149 | ```
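Under the hood, the PoT mode prompts the model for Python code, extracts the fenced Python block from the response, executes it, and uses the captured stdout as the final answer (see `extract_python_content` and `execute_python_code` in `chartmoe/eval_ChartQA.py`). A condensed sketch of that post-processing:

```python
import io, re, sys

def pot_answer(response: str):
    # Pull the fenced python block out of the response (fall back to the raw text).
    blocks = re.findall(r"```python(.*?)```", response, re.DOTALL)
    code = blocks[0] if blocks else response

    # Run the code and capture whatever it prints as the predicted answer.
    old_stdout, sys.stdout = sys.stdout, io.StringIO()
    try:
        exec(code)
        answer = sys.stdout.getvalue().strip()
    except Exception:
        answer = None  # the evaluator records "error running..." in this case
    finally:
        sys.stdout = old_stdout
    return answer
```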
150 |
151 | ### MME
152 | Run `chartmoe/eval_MME.ipynb` for MME scores.
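For reference, the notebook follows the standard MME protocol: every image comes with two yes/no questions, and each category is scored as `acc + acc_plus`, where `acc_plus` only counts images with both questions answered correctly. A minimal sketch of that aggregation (mirroring the notebook's last cells):

```python
from collections import defaultdict

def mme_category_scores(results):
    """results: list of dicts with 'category', 'question_id' and a binary 'score'."""
    per_question = defaultdict(lambda: defaultdict(list))
    for r in results:
        per_question[r["category"]][r["question_id"]].append(r["score"])

    category2avg = {}
    for category, q2scores in per_question.items():
        total = 0.0
        for scores in q2scores.values():  # two questions per image
            acc = sum(scores) / len(scores) * 100.0          # per-image accuracy
            acc_plus = (sum(scores) == len(scores)) * 100.0  # both answers correct
            total += acc + acc_plus
        category2avg[category] = total / len(q2scores)
    return category2avg  # Perception/Cognition totals are sums over their categories
```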
153 |
154 | ## WebUI Demo
155 |
156 | ```bash
157 | CUDA_VISIBLE_DEVICES=0 python gradio_demo.py
158 | ```
159 |
160 | 
161 |
162 | ## FAQs
163 | Q1: [CLIP: Input image size (490x490) doesn't match model (336x336)](https://github.com/IDEA-FinAI/ChartMoE/issues/6)
164 |
165 | A1: Please downgrade your `transformers` version according to `requirements.txt`.
166 |
167 | ## Acknowledgement
168 | Thanks to [InternLM-XComposer2](https://github.com/InternLM/InternLM-XComposer/tree/main/InternLM-XComposer-2.0) and [CuMo](https://github.com/SHI-Labs/CuMo) for releasing their model weights and source code! And thanks to [MMC](https://github.com/FuxiaoLiu/MMC) and [ChartGemma](https://github.com/vis-nlp/ChartGemma) for releasing their high-quality instruction-tuning data!
169 |
170 | ## Citation
171 | If you find our idea or code inspiring, please cite our paper:
172 | ```bibtex
173 | @article{ChartMoE,
174 | title={ChartMoE: Mixture of Diversely Aligned Expert Connector for Chart Understanding},
175 | author={Zhengzhuo Xu and Bowen Qu and Yiyan Qi and Sinan Du and Chengjin Xu and Chun Yuan and Jian Guo},
176 | journal={ArXiv},
177 | year={2024},
178 | volume={abs/2409.03277},
179 | }
180 | ```
181 | This code is partially based on [ChartBench](https://chartbench.github.io/). If you use our code, please also cite:
182 | ```bibtex
183 | @article{ChartBench,
184 | title={ChartBench: A Benchmark for Complex Visual Reasoning in Charts},
185 | author={Zhengzhuo Xu and Sinan Du and Yiyan Qi and Chengjin Xu and Chun Yuan and Jian Guo},
186 | journal={ArXiv},
187 | year={2023},
188 | volume={abs/2312.15915},
189 | }
190 | ```
191 |
--------------------------------------------------------------------------------
/asset/chartmoe-logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/asset/chartmoe-logo.jpg
--------------------------------------------------------------------------------
/asset/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/asset/teaser.png
--------------------------------------------------------------------------------
/asset/train_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/asset/train_pipeline.png
--------------------------------------------------------------------------------
/chartmoe/__init__.py:
--------------------------------------------------------------------------------
1 | from chartmoe.generation_utils import ChartMoE_Robot
--------------------------------------------------------------------------------
/chartmoe/eval_ChartQA.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: ChartQA Evaluation of ChartMoE, with/without PoT(Program of Thoughts)
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | """
6 | from chartmoe import ChartMoE_Robot
7 | from chartmoe.utils.custom_path import ChartQA_ROOT, ChartQA_TEST_IMG_ROOT
8 |
9 | import os, sys, json, re, io
10 | import argparse
11 | import torch
12 | from tqdm import tqdm
13 | from typing import Optional
14 | from prettytable import PrettyTable
15 |
16 | def relaxed_acc(prediction: str, target: str,
17 | max_relative_change: float = 0.05) -> bool:
18 |
19 | def _to_float(text: str) -> Optional[float]:
20 | try:
21 | match = re.search(r'[\d.]+', text.replace(',', ''))
22 | if match: return float(match.group())
23 | return None
24 | except ValueError:
25 | return None
26 |
27 | prediction_float = _to_float(prediction)
28 | target_float = _to_float(target)
29 |
30 | if prediction_float is not None and target_float is not None:
31 | if target_float == 0:
32 | relative_change = abs(prediction_float - target_float)
33 | else:
34 | relative_change = abs(prediction_float - target_float) / abs(target_float)
35 | return relative_change <= max_relative_change
36 | else:
37 | lp = prediction.lower()
38 | tp = target.lower()
39 |
40 | if ("yes" in lp and "yes" in tp) or ("no" in lp and "no" in tp): return True
41 | if lp in tp: return True
42 | return lp == tp
43 |
44 | def evaluate_relaxed_accuracy(entries, margin=0.05):
45 | scores = []
46 | for elem in entries:
47 | if isinstance(elem['annotation'], str):
48 | elem['annotation'] = [elem['annotation']]
49 | score = max([
50 | relaxed_acc(elem['answer'].strip(), ann, margin)
51 | for ann in elem['annotation']
52 | ])
53 | scores.append(score)
54 |
55 | return sum(scores) / len(scores)
56 |
57 | def execute_python_code(code):
58 | old_stdout = sys.stdout
59 | new_stdout = io.StringIO()
60 | sys.stdout = new_stdout
61 |
62 | status = True
63 | try:
64 | exec(code)
65 | except Exception as e:
66 | status = False
67 | finally:
68 | sys.stdout = old_stdout
69 |
70 | if status:
71 | output = new_stdout.getvalue()
72 | else:
73 | output = None
74 | return output, status
75 |
76 | def extract_python_content(text):
77 | pattern = r"```python(.*?)```"
78 | matches = re.findall(pattern, text, re.DOTALL)
79 | return matches
80 |
81 | class ChartQATester:
82 |
83 | def __init__(self, ckpt_path=None, pot=False, pot_idx=0):
84 | # ChartQA root
85 | self.root = ChartQA_ROOT
86 | self.vis_root = ChartQA_TEST_IMG_ROOT
87 |
88 | self.robot = ChartMoE_Robot(ckpt_path=ckpt_path)
89 | self.prompt = '[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
90 | pot_prompts = [
91 | '[UNUSED_TOKEN_146]user\nPlease give the program of thought.{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n',
92 | '[UNUSED_TOKEN_146]user\nPlease give the program of thought in python code. Use print function to output the answer in the end.{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n',
93 | ]
94 | self.pot_prompt = pot_prompts[pot_idx]
95 |
96 | self.pot = pot
97 | if self.pot:
98 | self.robot.reset_prompt(prompt=self.pot_prompt)
99 | else:
100 | self.robot.reset_prompt(prompt=self.prompt)
101 |
102 | def reset_prompt(self, p):
103 | self.system_prompt = p
104 |
105 | def infer_all_answers(self, output_path):
106 |
107 | os.makedirs(output_path, exist_ok=True)
108 | print(f"Result will be saved at: {output_path}")
109 |
110 | part_acc = []
111 | for part_name in ['human', 'augmented']:
112 | part_json = os.path.join(output_path, f"{part_name}.json")
113 | if os.path.exists(part_json):
114 | print(f"Load result from: {part_json}")
115 | part = json.load(open(part_json, 'r'))
116 | else:
117 | part = []
118 | samples = json.load(open(self.root+f'test/test_{part_name}.json'))
119 | for q in tqdm(samples):
120 | im_path = os.path.join(self.vis_root, q['imgname'])
121 | question = q['query']
122 |
123 | with torch.cuda.amp.autocast():
124 | response, _ = self.robot.chat(
125 | image_path=im_path,
126 | question=question,
127 | max_new_tokens=500,
128 | num_beams=1,
129 | )
130 | if self.pot:
131 | extraced_result = extract_python_content(response)
132 | if extraced_result:
133 | code = extraced_result[0]
134 | else:
135 | code = response
136 | response, status = execute_python_code(code)
137 |
138 | if not status:
139 | response = "error running..."
140 | response = response.replace("True","Yes").replace("False","No")
141 | response = response.strip()
142 | part.append({
143 | 'image': im_path,
144 | 'query': question,
145 | 'answer': response,
146 | 'annotation': q['label'],
147 | 'code': code
148 | })
149 | else:
150 | part.append({
151 | 'image': im_path,
152 | 'query': question,
153 | 'answer': response,
154 | 'annotation': q['label']
155 | })
156 | with open(part_json, 'w') as f:
157 | json.dump(part, f, indent=4)
158 | part_acc.append(part)
159 |
160 | table = PrettyTable()
161 | table.field_names = ["@AP", "0.05", "0.1", "0.2"]
162 | human_row = ["Human"]
163 | augmented_row = ["Augmented"]
164 | averaged_row = ["Averaged"]
165 | for ap in [0.05, 0.1, 0.2]:
166 | part_acc_ap = [evaluate_relaxed_accuracy(p, ap) for p in part_acc]
167 | human_acc = part_acc_ap[0]
168 | augmented_acc = part_acc_ap[1]
169 | averaged_acc = (human_acc + augmented_acc) / 2
170 | human_row.append(human_acc)
171 | augmented_row.append(augmented_acc)
172 | averaged_row.append(averaged_acc)
173 |
174 | table.add_row(human_row)
175 | table.add_row(augmented_row)
176 | table.add_row(averaged_row)
177 |
178 | table_path = os.path.join(output_path, 'table.txt')
179 | with open(table_path, 'w') as f:
180 | f.write(str(table))
181 |
182 | print(table)
183 |
184 | if __name__ == "__main__":
185 | parser = argparse.ArgumentParser()
186 | parser.add_argument("--save_path", type=str, required=True)
187 | parser.add_argument("--ckpt_path", type=str, default=None)
188 | parser.add_argument('--pot', action='store_true')
189 | parser.add_argument('--pot_idx', type=int, default=0, choices=[0, 1])
190 | args = parser.parse_args()
191 |
192 | tester = ChartQATester(ckpt_path=args.ckpt_path, pot=args.pot, pot_idx=args.pot_idx)
193 | tester.infer_all_answers(output_path=args.save_path)
--------------------------------------------------------------------------------
/chartmoe/eval_MME.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13 | " from .autonotebook import tqdm as notebook_tqdm\n",
14 | "/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
15 | " warnings.warn(\n",
16 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
17 | "- tokenization_internlm_xcomposer2.py\n",
18 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
19 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
20 | "- configuration_chartmoe.py\n",
21 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
22 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
23 | "- build_mlp.py\n",
24 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
25 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
26 | "- modeling_internlm2.py\n",
27 | "- build_mlp.py\n",
28 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
29 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
30 | "- build_moe_connector.py\n",
31 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
32 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n",
33 | "- modeling_chartmoe.py\n",
34 | "- modeling_internlm2.py\n",
35 | "- build_moe_connector.py\n",
36 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n",
37 | "/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
38 | " warnings.warn(\n",
39 | "Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 4.23it/s]\n"
40 | ]
41 | },
42 | {
43 | "name": "stdout",
44 | "output_type": "stream",
45 | "text": [
46 | "Set max length to 4096\n"
47 | ]
48 | },
49 | {
50 | "name": "stderr",
51 | "output_type": "stream",
52 | "text": [
53 | "Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00, 4.16s/it]\n"
54 | ]
55 | }
56 | ],
57 | "source": [
58 | "import os\n",
59 | "os.environ['CUDA_VISIBLE_DEVICES']='2'\n",
60 | "import sys \n",
61 | "import json\n",
62 | "import torch\n",
63 | "import numpy as np\n",
64 | "from PIL import Image \n",
65 | "from tqdm import tqdm\n",
66 | "import datetime\n",
67 | "from collections import defaultdict\n",
68 | "\n",
69 | "\n",
70 | "from datasets import load_dataset\n",
71 | "from chartmoe import ChartMoE_Robot\n",
72 | "\n",
73 | "mme_data = load_dataset(\"lmms-lab/MME\")['test']\n",
74 | "\n",
75 | "robot = ChartMoE_Robot()"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": 2,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "eval_type_dict = {\n",
85 | " \"Perception\": [\n",
86 | " \"existence\",\n",
87 | " \"count\",\n",
88 | " \"position\",\n",
89 | " \"color\",\n",
90 | " \"posters\",\n",
91 | " \"celebrity\",\n",
92 | " \"scene\",\n",
93 | " \"landmark\",\n",
94 | " \"artwork\",\n",
95 | " \"OCR\",\n",
96 | " ],\n",
97 | " \"Cognition\": [\n",
98 | " \"commonsense_reasoning\",\n",
99 | " \"numerical_calculation\",\n",
100 | " \"text_translation\",\n",
101 | " \"code_reasoning\",\n",
102 | " ],\n",
103 | "}"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": 3,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "def parse_pred_ans(pred_ans):\n",
113 | " \"\"\"Brought from Otter Eval\"\"\"\n",
114 | " pred_ans = pred_ans.lower().strip().replace(\".\", \"\")\n",
115 | " pred_label = None\n",
116 | " if pred_ans in [\"yes\", \"no\"]:\n",
117 | " pred_label = pred_ans\n",
118 | " elif len(pred_ans) == 1:\n",
119 | " if pred_ans == \"y\":\n",
120 | " pred_label = \"yes\"\n",
121 | " elif pred_ans == \"n\":\n",
122 | " pred_label = \"no\"\n",
123 | " else:\n",
124 | " pred_label = \"other\"\n",
125 | " else:\n",
126 | " prefix_pred_ans = pred_ans[:4]\n",
127 | " if \"yes\" in prefix_pred_ans:\n",
128 | " pred_label = \"yes\"\n",
129 | " elif \"no\" in prefix_pred_ans:\n",
130 | " pred_label = \"no\"\n",
131 | " else:\n",
132 | " pred_label = \"other\"\n",
133 | " return pred_label"
134 | ]
135 | },
136 | {
137 | "cell_type": "code",
138 | "execution_count": 4,
139 | "metadata": {},
140 | "outputs": [
141 | {
142 | "name": "stderr",
143 | "output_type": "stream",
144 | "text": [
145 | " 0%| | 0/2374 [00:00, ?it/s]/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/transformers/generation/utils.py:1417: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation )\n",
146 | " warnings.warn(\n",
147 | "100%|██████████| 2374/2374 [58:05<00:00, 1.47s/it]\n"
148 | ]
149 | }
150 | ],
151 | "source": [
152 | "results = []\n",
153 | "for d in tqdm(mme_data):\n",
154 | " image = d['image'].convert(\"RGB\")\n",
155 | " question = d['question']\n",
156 | " category = d['category']\n",
157 | " gt_ans = d[\"answer\"].lower().strip().replace(\".\", \"\")\n",
158 | "\n",
159 | " with torch.cuda.amp.autocast():\n",
160 | " pred, _ = robot.chat(\n",
161 | " image=image,\n",
162 | " question=question,\n",
163 | " temperature=1.0,\n",
164 | " max_new_tokens=5,\n",
165 | " num_beams=5,\n",
166 | " do_sample=False,\n",
167 | " repetition_penalty=1.0\n",
168 | " )\n",
169 | "\n",
170 | " pred_ans = parse_pred_ans(pred)\n",
171 | " assert gt_ans in [\"yes\", \"no\"]\n",
172 | " # assert pred_ans in [\"yes\", \"no\", \"other\"]\n",
173 | "\n",
174 | " score = 1.0 if pred_ans == gt_ans else 0.0\n",
175 | " key_name = \"mme_percetion_score\" if category in eval_type_dict[\"Perception\"] else \"mme_cognition_score\"\n",
176 | "\n",
177 | " results.append({key_name: {\"question_id\": d[\"question_id\"], \"category\": category, \"score\": score}})\n",
178 | "\n",
179 | "with open(\"mme_results.jsonl\", 'w') as f:\n",
180 | " for res in results:\n",
181 | " f.write(f\"{json.dumps(res)}\\n\")"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": 5,
187 | "metadata": {},
188 | "outputs": [
189 | {
190 | "name": "stdout",
191 | "output_type": "stream",
192 | "text": [
193 | "2214.1313525410164\n"
194 | ]
195 | }
196 | ],
197 | "source": [
198 | "category2score = defaultdict(dict)\n",
199 | "results = [list(res.values())[0] for res in results]\n",
200 | "for result in results:\n",
201 | " question_id = result[\"question_id\"]\n",
202 | " score = result[\"score\"]\n",
203 | " category = result[\"category\"]\n",
204 | " if question_id not in category2score[category]:\n",
205 | " category2score[category][question_id] = []\n",
206 | " category2score[category][question_id].append(score)\n",
207 | "category2avg_score = {}\n",
208 | "for category, question2scores in category2score.items():\n",
209 | " total_score = 0\n",
210 | " for question_id, scores in question2scores.items():\n",
211 | " assert len(scores) == 2, \"MME only supports pairwise evaluation\"\n",
212 | " acc = sum(scores) / len(scores) * 100.0\n",
213 | " acc_plus = (sum(scores) == 2) * 100.0\n",
214 | " score = acc_plus + acc\n",
215 | " total_score += score\n",
216 | " avg_score = total_score / len(question2scores)\n",
217 | " category2avg_score[category] = avg_score\n",
218 | "total_score = sum(category2avg_score.values())\n",
219 | "print(total_score)"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": 6,
225 | "metadata": {},
226 | "outputs": [
227 | {
228 | "data": {
229 | "text/plain": [
230 | "{'code_reasoning': 117.5,\n",
231 | " 'artwork': 186.25,\n",
232 | " 'celebrity': 163.8235294117647,\n",
233 | " 'numerical_calculation': 147.5,\n",
234 | " 'text_translation': 155.0,\n",
235 | " 'count': 170.0,\n",
236 | " 'color': 165.0,\n",
237 | " 'commonsense_reasoning': 140.71428571428572,\n",
238 | " 'position': 158.33333333333334,\n",
239 | " 'OCR': 125.0,\n",
240 | " 'landmark': 172.0,\n",
241 | " 'scene': 157.5,\n",
242 | " 'existence': 180.0,\n",
243 | " 'posters': 175.51020408163265}"
244 | ]
245 | },
246 | "execution_count": 6,
247 | "metadata": {},
248 | "output_type": "execute_result"
249 | }
250 | ],
251 | "source": [
252 | "category2avg_score"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": 7,
258 | "metadata": {},
259 | "outputs": [
260 | {
261 | "data": {
262 | "text/plain": [
263 | "defaultdict(int,\n",
264 | " {'Perception': 1653.4170668267307, 'Cognition': 560.7142857142858})"
265 | ]
266 | },
267 | "execution_count": 7,
268 | "metadata": {},
269 | "output_type": "execute_result"
270 | }
271 | ],
272 | "source": [
273 | "scores = defaultdict(int)\n",
274 | "for eval_type in eval_type_dict:\n",
275 | " for category_type in eval_type_dict[eval_type]:\n",
276 | " scores[eval_type] += category2avg_score[category_type]\n",
277 | "scores"
278 | ]
279 | },
280 | {
281 | "cell_type": "code",
282 | "execution_count": 9,
283 | "metadata": {},
284 | "outputs": [
285 | {
286 | "data": {
287 | "text/plain": [
288 | "2214.1000000000004"
289 | ]
290 | },
291 | "execution_count": 9,
292 | "metadata": {},
293 | "output_type": "execute_result"
294 | }
295 | ],
296 | "source": [
297 | "1653.4 + 560.7"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {},
304 | "outputs": [],
305 | "source": []
306 | }
307 | ],
308 | "metadata": {
309 | "kernelspec": {
310 | "display_name": "Python 3.9.19 ('intern_clean')",
311 | "language": "python",
312 | "name": "python3"
313 | },
314 | "language_info": {
315 | "codemirror_mode": {
316 | "name": "ipython",
317 | "version": 3
318 | },
319 | "file_extension": ".py",
320 | "mimetype": "text/x-python",
321 | "name": "python",
322 | "nbconvert_exporter": "python",
323 | "pygments_lexer": "ipython3",
324 | "version": "3.9.19"
325 | },
326 | "orig_nbformat": 4,
327 | "vscode": {
328 | "interpreter": {
329 | "hash": "a726b55af14ee9f10619d25e42820b32d50f8ab305998596bdf5d4abd3695153"
330 | }
331 | }
332 | },
333 | "nbformat": 4,
334 | "nbformat_minor": 2
335 | }
336 |
--------------------------------------------------------------------------------
/chartmoe/generation_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: Generation Script of ChartMoE
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | """
6 | import os
7 |
8 | import torch
9 | from transformers import AutoModel, AutoTokenizer
10 | from PIL import Image
11 | import torchvision
12 |
13 | from chartmoe.utils.custom_path import ChartMoE_HF_PATH
14 |
15 | def __padding__(image):
16 | width, height = image.size
17 | tar = max(width, height)
18 | top_padding = int((tar - height)/2)
19 | bottom_padding = tar - height - top_padding
20 | left_padding = int((tar - width)/2)
21 | right_padding = tar - width - left_padding
22 | image = torchvision.transforms.functional.pad(image, [left_padding, top_padding, right_padding, bottom_padding])
23 | return image
24 |
25 | class ChartMoE_Robot:
26 | def __init__(self, ckpt_path = None, img_padding = False):
27 | model_path = ckpt_path if ckpt_path else ChartMoE_HF_PATH
28 | tokenizer = AutoTokenizer.from_pretrained(
29 | model_path,
30 | trust_remote_code=True
31 | )
32 | print(f"\033[34mLoad model from {model_path}\033[0m")
33 | self.model = AutoModel.from_pretrained(
34 | model_path,
35 | trust_remote_code=True,
36 | ).half().cuda().eval()
37 | self.tokenizer = tokenizer
38 | self.model.tokenizer = tokenizer
39 |
40 | self.prompt = '[UNUSED_TOKEN_146]user\n{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n'
41 |
42 | self.img_padding = img_padding
43 |
44 | def reset_prompt(self, prompt):
45 | self.prompt = prompt
46 |
47 | def chat(
48 | self,
49 | image_path=None,
50 | image=None,
51 | question="",
52 | history="",
53 | temperature=1,
54 | max_new_tokens=1000,
55 | num_beams=1,
56 | do_sample=False,
57 | repetition_penalty=1.0,
58 | ):
59 | need_bos = True
60 | pt1 = 0
61 | embeds = []
62 | im_mask = []
63 | question = self.prompt.format(question)
64 | history += question
65 |
66 | if image_path and image:
67 | assert False, "Just give the `image_path` or give the `PIL.Image` to `image`!"
68 | if image_path is None and image is None:
69 | assert False, "`image_path` and `image` are both None! Please give the `image_path` or give the `PIL.Image` to `image`!"
70 |
71 | if image_path:
72 | images = [image_path]
73 | else:
74 | images = [image]
75 | images_loc = [0]
76 |
77 | for i, pts in enumerate(images_loc + [len(history)]):
78 | subtext = history[pt1:pts]
79 | if need_bos or len(subtext) > 0:
80 | text_embeds = self.model.encode_text(subtext, add_special_tokens=need_bos)
81 | embeds.append(text_embeds)
82 | im_mask.append(torch.zeros(text_embeds.shape[:2]).cuda())
83 | need_bos = False
84 | if i < len(images):
85 | try:
86 | image = Image.open(images[i]).convert('RGB')
87 | except:
88 | image = images[i].convert('RGB')
89 | if self.img_padding:
90 | image = __padding__(image)
91 | image = self.model.vis_processor(image).unsqueeze(0).cuda()
92 | image_embeds = self.model.encode_img(image)
93 | embeds.append(image_embeds)
94 | im_mask.append(torch.ones(image_embeds.shape[:2]).cuda())
95 | pt1 = pts
96 | embeds = torch.cat(embeds, dim=1)
97 | im_mask = torch.cat(im_mask, dim=1)
98 | im_mask = im_mask.bool()
99 |
100 | eos_token_id = [
101 | self.tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0],
102 | self.tokenizer.eos_token_id,
103 | ]
104 | outputs = self.model.generate(
105 | inputs_embeds=embeds,
106 | im_mask=im_mask,
107 | temperature=temperature,
108 | max_new_tokens=max_new_tokens,
109 | num_beams=num_beams,
110 | do_sample=do_sample,
111 | repetition_penalty=repetition_penalty,
112 | eos_token_id=eos_token_id,
113 | )
114 |
115 | output_token = outputs[0]
116 | if output_token[0] == 0 or output_token[0] == 1:
117 | output_token = output_token[1:]
118 | output_text = self.model.tokenizer.decode(output_token, add_special_tokens=False)
119 | history += output_text
120 | output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip()
121 | return output_text, history
--------------------------------------------------------------------------------
/chartmoe/train/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Training Recipes of ChartMoE
3 |
4 |
5 |
6 | Datasets are released at 🤗 https://huggingface.co/datasets/Coobiw/ChartMoE-Data!
7 |
8 |
9 | In this part, I'll introduce the training recipes for reproducing ChartMoE. Besides the recipes themselves, I also provide a checkpoint reproduced by following the instructions below. You can find it at [🤗](https://huggingface.co/Coobiw/ChartMoE_Reproduced). **This version has better performance on ChartQA (both with & without PoT).**
10 |
11 |
12 | ## Download and Organize the ChartMoE-Data
13 | [🤗ChartMoE Data](https://huggingface.co/datasets/Coobiw/ChartMoE-Data) has been released! You can download it by running:
14 |
15 | ```bash
16 | cd chartmoe/train
17 | python scripts/chartmoe_data_download.py
18 | ```
19 | Datasets will appear at `chartmoe/train/data`.
20 |
21 | Then, please unzip these two files.
22 | ```bash
23 | unzip ChartMoE-Align.zip
24 | unzip SFT.zip
25 | ```
26 |
27 | Additionally, note that `ChartY_replot` in `ChartMoE-Align` contains higher-quality data with bilingual texts! Sampling more from `ChartY_replot` may be a good choice.
28 |
29 | ### Data Format
30 | ```python
31 | [
32 | {
33 | "id": "0",
34 |         "image": ['path/to/image_0.jpg'],
35 | "conversations": [
36 | {
37 | "from": "user",
38 | "value": " Please describe these two images in detail."
39 | },
40 | {
41 | "from": "assistant",
42 | "value": "......"
43 | }
44 | ]
45 | },
46 | {
47 | "id": "1",
48 |         "image": ['path/to/image_1.jpg'],
49 | "conversations": [
50 | {
51 | "from": "user",
52 | "value": " what is the color of the dog"
53 | },
54 | {
55 | "from": "assistant",
56 | "value": "it is ...."
57 | }
58 | ]
59 | }
60 | ]
61 | ```
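A quick way to sanity-check an annotation file against this format before training (a minimal sketch; the file name is just an example taken from `data/table_align.txt`):

```python
import json

# Example: one of the alignment annotation files referenced by data/table_align.txt
with open("data/chart2table.json", "r", encoding="utf-8") as f:
    samples = json.load(f)

for s in samples[:3]:
    assert isinstance(s["image"], list), "`image` must be a list of image paths"
    for turn in s["conversations"]:
        assert turn["from"] in ("user", "assistant")
    print(s["id"], s["image"], s["conversations"][0]["value"][:60])
```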
62 |
63 | ## Download InternLM_XComposer2_Enhanced
64 |
65 | **Note: I've added `flash-attn` and batchified-training support for InternLM-XComposer2 at [Coobiw/InternLM-XComposer2_Enhanced](https://huggingface.co/Coobiw/InternLM-XComposer2_Enhanced). This indeed accelerates training.**
66 |
67 | Run:
68 |
69 | ```bash
70 | cd chartmoe/train
71 | python scripts/internlm_xc2_download.py
72 | ```
73 |
74 | Then, InternLM-XComposer2_Enhanced will appear at `chartmoe/train/ckpt/InternLM-XComposer2_Enhanced`.
75 |
76 | ## Diversely-Aligned MoE-MLP Training
77 |
78 | ### Download the intermediate checkpoint
79 | I've uploaded the weights of the diversely aligned MoE connector (each expert is trained, but the router is randomly initialized): [🤗HF Link](https://huggingface.co/Coobiw/ChartMoE-Aligned-Connector/tree/main). Please put `mlp_moe.pth` at `chartmoe/train/output/moe_aligned/mlp_moe.pth`; then you can directly run the SFT script.
80 |
81 | ### Training Pipeline of ChartMoE
82 | If you want to train your own MoE connector, feel free to follow these instructions!
83 |
84 | 
85 |
86 | Run:
87 |
88 | ```bash
89 | cd chartmoe/train
90 | bash scripts/multi_align.sh
91 | ```
92 |
93 | Then, the table/json/code MLP connectors will appear at `chartmoe/train/output/table_proj`, `chartmoe/train/output/json_proj`, and `chartmoe/train/output/code_proj`.
94 |
95 | After the diverse alignment stage, we can construct the MoE-MLP connector by running:
96 |
97 | ```bash
98 | cd chartmoe/train
99 | bash scripts/moe_construction.sh
100 | ```
101 |
102 | The MoE-MLP connector will appear at `chartmoe/train/output/moe_aligned/mlp_moe.pth`.
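Before moving on to SFT, you can sanity-check the constructed connector; `chartmoe_construction.py` later reads the same fields from this file. A minimal sketch (run from `chartmoe/train`):

```python
import torch

# Inspect the constructed MoE-MLP connector checkpoint
state_dict = torch.load("output/moe_aligned/mlp_moe.pth", map_location="cpu")

num_selected = state_dict["num_selected"]        # top-k experts selected per token
num_experts = state_dict["gate.weight"].size(0)  # router output dimension = number of experts
print(f"experts: {num_experts}, selected per token: {num_selected}")
```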
103 |
104 | ## SFT
105 |
106 | *Note: In this repo, we don't include the "High-Quality Knowledge Learning" mid-training stage.*
107 |
108 | Please note the path of the MoE-MLP connector in [`scripts/sft.sh`](./scripts/sft.sh#L24).
109 |
110 | Run:
111 |
112 | ```bash
113 | cd chartmoe/train
114 | mkdir -p logs/sft
115 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/sft.sh 2>&1 | tee logs/sft/tee_logs.txt
116 | ```
117 |
118 | ## Merge MLP-MoE Connector and LoRA Weights for ChartMoE Construction
119 | Run:
120 |
121 | ```bash
122 | cd chartmoe/train
123 | bash scripts/chartmoe_construction.sh
124 | ```
125 |
126 | ## Evaluation on ChartQA
127 | w/o PoT:
128 |
129 | ```bash
130 | CUDA_VISIBLE_DEVICES=0 python chartmoe/eval_ChartQA.py --ckpt_path chartmoe/train/output/sft/chartmoe_reproduced --save_path chartmoe/train/output/sft/chartmoe_reproduced/ChartQA_wo-PoT
131 | ```
132 |
133 | Result:
134 | ```
135 | +-----------+--------+--------+--------+
136 | | @AP | 0.05 | 0.1 | 0.2 |
137 | +-----------+--------+--------+--------+
138 | | Human | 0.704 | 0.7376 | 0.772 |
139 | | Augmented | 0.9056 | 0.9192 | 0.9352 |
140 | | Averaged | 0.8048 | 0.8284 | 0.8536 |
141 | +-----------+--------+--------+--------+
142 | ```
143 |
144 | PoT:
145 | ```bash
146 | CUDA_VISIBLE_DEVICES=0 python chartmoe/eval_ChartQA.py --ckpt_path chartmoe/train/output/sft/chartmoe_reproduced --save_path chartmoe/train/output/sft/chartmoe_reproduced/ChartQA_PoT --pot --pot_idx 1
147 | ```
148 |
149 | Result:
150 | ```
151 | +-----------+--------+--------+-------+
152 | | @AP | 0.05 | 0.1 | 0.2 |
153 | +-----------+--------+--------+-------+
154 | | Human | 0.7952 | 0.8128 | 0.828 |
155 | | Augmented | 0.904 | 0.9176 | 0.932 |
156 | | Averaged | 0.8496 | 0.8652 | 0.88 |
157 | +-----------+--------+--------+-------+
158 | ```
159 |
--------------------------------------------------------------------------------
/chartmoe/train/chartmoe_construction.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: Construct ChartMoE after Post-Training, including: 1. Merge LoRA to LLM 2. Adapt to ChartMoE HF Implementation
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | """
6 | from dataclasses import dataclass, field
7 | from typing import Optional
8 |
9 | import torch
10 | from peft import PeftConfig, PeftModel
11 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
12 |
13 | from mlp_moe import MLPMoE
14 | import os
15 | import shutil
16 | from glob import glob
17 |
18 |
19 | @dataclass
20 | class ScriptArguments:
21 | """The input names representing the Adapter and Base model fine-tuned with
22 | PEFT, and the output name representing the merged model."""
23 |
24 | moe_aligned_pth_path: Optional[str] = field(
25 | default=None, metadata={'help': 'the path of aligned moe .pth file'}
26 | )
27 | chartmoe_hf_dir: Optional[str] = field(
28 | default=None, metadata={'help': 'the path of downloaded chartmoe hf dir'}
29 | )
30 | adapter_model_name: Optional[str] = field(
31 | default=None, metadata={'help': 'the adapter name'}
32 | )
33 | output_path: Optional[str] = field(
34 | default=None, metadata={'help': 'the merged model saved path'}
35 | )
36 |
37 |
38 | parser = HfArgumentParser(ScriptArguments)
39 | script_args = parser.parse_args_into_dataclasses()[0]
40 | assert script_args.moe_aligned_pth_path is not None, 'please provide the path of aligned moe .pth file'
41 | assert script_args.adapter_model_name is not None, 'please provide the name of the Adapter you would like to merge'
42 | assert script_args.output_path is not None, 'please provide the merged model saved path'
43 | assert script_args.chartmoe_hf_dir is not None, 'please provide the path of downloaded chartmoe hf dir'
44 |
45 | adapter_model_name = glob(os.path.join(script_args.adapter_model_name, "checkpoint-*"))[0]
46 |
47 | # get base model path from adapter_config.json
48 | peft_config = PeftConfig.from_pretrained(adapter_model_name)
49 | base_model_path = peft_config.base_model_name_or_path
50 | print(f"\033[31mLoad base model from {base_model_path}\033[0m")
51 | model = AutoModelForCausalLM.from_pretrained(
52 | base_model_path,
53 | return_dict=True,
54 | torch_dtype=torch.bfloat16,
55 | trust_remote_code=True,
56 | device_map="cuda",
57 | attn_implementation="eager",
58 | )
59 |
60 | mlp_moe_state_dict = torch.load(script_args.moe_aligned_pth_path, map_location="cpu")
61 |
62 | num_experts = mlp_moe_state_dict['gate.weight'].size(0)
63 | num_selected = mlp_moe_state_dict.pop('num_selected')
64 |
65 | mlp_moe = MLPMoE(num_experts, num_selected, 1024, 4096).to(model.device)
66 | mlp_moe.load_state_dict(mlp_moe_state_dict)
67 |
68 | print("\033[32mload aligned moe...\033[0m")
69 | model.vision_proj = mlp_moe
70 |
71 | tokenizer = AutoTokenizer.from_pretrained(
72 | base_model_path, trust_remote_code=True
73 | )
74 |
75 | # Load the PEFT model
76 | model = PeftModel.from_pretrained(model, adapter_model_name)
77 | model.eval()
78 |
79 | print("\033[33mmerge the lora weights and `modules_to_save` weights\033[0m")
80 | model = model.merge_and_unload()
81 |
82 | model.save_pretrained(f'{script_args.output_path}')
83 | tokenizer.save_pretrained(f'{script_args.output_path}')
84 |
85 | print("\033[34msave and adapt to `ChartMoE` format...\033[0m")
86 | # adjust the contained files
87 | chartmoe_files = [
88 | 'special_tokens_map.json',
89 | 'configuration_chartmoe.py',
90 | 'modeling_internlm2.py',
91 | 'README.md',
92 | 'config.json',
93 | 'generation_config.json',
94 | '.gitattributes',
95 | 'teaser.png',
96 | 'zero_to_fp32.py',
97 | 'pytorch_model.bin.index.json',
98 | 'tokenization_internlm_xcomposer2.py',
99 | 'build_mlp.py',
100 | 'tokenizer_config.json',
101 | 'build_moe_connector.py',
102 | 'tokenizer.model',
103 | 'modeling_chartmoe.py',
104 | ]
105 | keep_files = [
106 | 'pytorch_model-00001-of-00002.bin',
107 | 'pytorch_model-00002-of-00002.bin'
108 | ]
109 |
110 | for fn in os.listdir(script_args.output_path):
111 | if fn not in keep_files:
112 | os.remove(os.path.join(script_args.output_path, fn))
113 |
114 | for fn in chartmoe_files:
115 | if fn != 'config.json':
116 | shutil.copy(os.path.join(script_args.chartmoe_hf_dir, fn), os.path.join(script_args.output_path, fn))
117 | else:
118 | import json
119 | config = json.load(open(os.path.join(script_args.chartmoe_hf_dir, fn), encoding='utf-8'))
120 | config["num_experts"] = num_experts
121 | config["num_selected"] = num_selected
122 | with open(os.path.join(script_args.output_path, fn), 'w', encoding='utf-8') as fo:
123 | json.dump(config, fo, indent=4, ensure_ascii=False)
124 |
125 |
--------------------------------------------------------------------------------
/chartmoe/train/chartmoe_trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: Trainer of ChartMoE
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | REFERENCE: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py
6 | """
7 | import os
8 | import torch
9 | import torch.nn as nn
10 |
11 | from transformers import Trainer
12 | from transformers.trainer import (
13 | is_sagemaker_mp_enabled,
14 | get_parameter_names,
15 | has_length,
16 | ALL_LAYERNORM_LAYERS,
17 | logger,
18 | )
19 | from typing import List, Optional
20 |
21 | def maybe_zero_3(param, ignore_status=False, name=None):
22 | from deepspeed import zero
23 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
24 | if hasattr(param, "ds_id"):
25 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
26 | if not ignore_status:
27 | print(name, 'no ignore status')
28 | with zero.GatheredParameters([param]):
29 | param = param.data.detach().cpu().clone()
30 | else:
31 | param = param.detach().cpu().clone()
32 | return param
33 |
34 |
35 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
36 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
37 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
38 | return to_return
39 |
40 | class ChartMoETrainer(Trainer):
41 |
42 | def _save_checkpoint(self, model, trial, metrics=None):
43 | if getattr(self.args, 'tune_mm_mlp', False):
44 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
45 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
46 |
47 | run_dir = self._get_output_dir(trial=trial)
48 | output_dir = os.path.join(run_dir, checkpoint_folder)
49 |
50 | # Only save Adapter
51 | keys_to_match = ['vision_proj']
52 |
53 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
54 |
55 | if self.args.local_rank == 0 or self.args.local_rank == -1:
56 | self.model.config.save_pretrained(output_dir)
57 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_mlp.bin'))
58 | else:
59 | super(ChartMoETrainer, self)._save_checkpoint(model, trial, metrics)
60 |
61 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
62 | if getattr(self.args, 'tune_mm_mlp', False):
63 | pass
64 | else:
65 | super(ChartMoETrainer, self)._save(output_dir, state_dict)
--------------------------------------------------------------------------------
/chartmoe/train/data/code_align.txt:
--------------------------------------------------------------------------------
1 | ./data/chart2code.json 300
--------------------------------------------------------------------------------
/chartmoe/train/data/json_align.txt:
--------------------------------------------------------------------------------
1 | ./data/chart2json.json 200
--------------------------------------------------------------------------------
/chartmoe/train/data/sft.txt:
--------------------------------------------------------------------------------
1 | ./data/chartqa.json 70
2 | ./data/chartgemma.json -1
--------------------------------------------------------------------------------
/chartmoe/train/data/table_align.txt:
--------------------------------------------------------------------------------
1 | ./data/chart2table.json 500
--------------------------------------------------------------------------------
/chartmoe/train/data_mix.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: Mixture of Common/Sampling Dataset
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | REFERENCE: https://github.com/InternLM/InternLM-XComposer
6 | """
7 | import random
8 | import bisect
9 |
10 | import numpy as np
11 | import torch
12 | from PIL import Image
13 | from torch.utils.data import Dataset
14 | from torchvision import transforms
15 | from torchvision.transforms.functional import InterpolationMode
16 |
17 |
18 | def conv2text(sources):
19 | END_HUMAN = '[UNUSED_TOKEN_145]\n'
20 | END_BOT = '[UNUSED_TOKEN_145]\n'
21 | conversation = ''
22 |
23 | for idx, sentence in enumerate(sources):
24 | BEGIN_SIGNAL = ''
25 |
26 | from_str = sentence['from']
27 | if from_str.lower() == 'human' or from_str.lower() == 'user':
28 | from_str = '[UNUSED_TOKEN_146]user\n'
29 | temp = (
30 | BEGIN_SIGNAL + from_str + sentence['value'].strip() +
31 | END_HUMAN)
32 | else:
33 | from_str = '[UNUSED_TOKEN_146]assistant\n'
34 | temp = (
35 | BEGIN_SIGNAL + from_str + sentence['value'].strip() + END_BOT)
36 | conversation += temp
37 |
38 | return conversation + ''
39 |
40 |
41 | class ImageProcessor:
42 |
43 | def __init__(self, image_size=490):
44 | mean = (0.48145466, 0.4578275, 0.40821073)
45 | std = (0.26862954, 0.26130258, 0.27577711)
46 | self.normalize = transforms.Normalize(mean, std)
47 |
48 | self.transform = transforms.Compose([
49 | transforms.Resize((image_size, image_size),
50 | interpolation=InterpolationMode.BICUBIC),
51 | transforms.ToTensor(),
52 | self.normalize,
53 | ])
54 |
55 | def __call__(self, item):
56 | item = Image.open(item).convert('RGB')
57 | return self.transform(item)
58 |
59 |
60 | class Mix_dataset(Dataset):
61 |
62 | def __init__(self,
63 | json_datas,
64 | img_size=490,
65 | local_rank=0,
66 | hd_num=-1):
67 | """vis_root (string): Root directory of images (e.g. coco/images/)
68 | ann_root (string): directory to store the annotation file."""
69 | super().__init__()
70 | print(f'init mix data at rank {local_rank}')
71 | self.local_rank = local_rank
72 |
73 | self.datasets = []
74 | self.start_idx_per = [0]
75 | self.data_num = 0
76 | for _, d in json_datas.items():
77 | if 'image' in d[0].keys():
78 | has_img = True
79 | else:
80 | has_img = False
81 | sub_data_set = Common_dataset(
82 | d,
83 | has_img=has_img,
84 | img_size=img_size,
85 | hd_num=hd_num)
86 | self.datasets.append(sub_data_set)
87 | self.start_idx_per.append(self.start_idx_per[-1] + len(sub_data_set))
88 | self.data_num += len(sub_data_set)
89 |
90 | self.start_idx_per.pop(-1)
91 |
92 | if len(self.datasets) == 0:
93 | raise ValueError(
94 | 'Both _multi and _text are empty. Cannot sample any data.')
95 |
96 | def __len__(self):
97 | return self.data_num
98 |
99 | def __getitem__(self, index):
100 | index = index % self.data_num # avoid some indices which are outside the interval
101 | dataset_idx = bisect.bisect_right(self.start_idx_per, index) - 1
102 | sample_idx = index - self.start_idx_per[dataset_idx]
103 | sample = self.datasets[dataset_idx].get_item(sample_idx)
104 | return dict(samples=sample)
105 |
106 |
107 | class Common_dataset(Dataset):
108 |
109 | def __init__(self,
110 | raw_data,
111 | has_img=True,
112 | img_size=490,
113 | hd_num=-1):
114 | self.raw_data = raw_data
115 | print(f'load {len(self.raw_data)} data')
116 | assert hd_num == -1, "please set `hd_num` to `-1`"
117 |
118 | self.vis_processor = ImageProcessor(image_size=img_size)
119 | self.text_processor = conv2text
120 | self.has_img = has_img
121 |
122 | def __len__(self):
123 | return len(self.raw_data)
124 |
125 | def __get_item__(self, i):
126 | conv_text = conv2text(self.raw_data[i]['conversations'])
127 | sample = dict(text_input=conv_text, )
128 | if self.has_img:
129 | image_file = self.raw_data[i]['image']
130 | image = [self.vis_processor(i) for i in image_file]
131 | sample['image'] = torch.stack(image)
132 | else:
133 | sample['image'] = None
134 |
135 | return sample
136 |
137 | def get_item(self, idx):
138 | text_input = []
139 | images = []
140 |
141 | sample = self.__get_item__(idx)
142 | text_input.append(sample['text_input'])
143 | images.append(sample['image'])
144 | sample = {
145 | 'text_input': text_input,
146 | 'data_type': 'multi' if self.has_img else 'text',
147 | }
148 | if self.has_img:
149 | sample['image'] = torch.cat(images)
150 | return sample
151 |
152 |
153 |
154 |
155 | class Mix_sampling_dataset(Dataset):
156 |
157 | def __init__(self,
158 | json_datas,
159 | seq_packing_size=1,
160 | img_size=490,
161 | local_rank=0,
162 | hd_num=-1):
163 | """vis_root (string): Root directory of images (e.g. coco/images/)
164 | ann_root (string): directory to store the annotation file."""
165 | super().__init__()
166 | print(f'init mix sampling_data at rank {local_rank}')
167 | self.datasets_text, self.datasets_multi = [], []
168 | self.data_num_text, self.data_num_multi = [], []
169 |
170 | self.seq_packing_size = seq_packing_size
171 | self.set_seed = False
172 | self.local_rank = local_rank
173 | for _, d in json_datas.items():
174 | if 'image' in d[0].keys():
175 | has_img = True
176 | else:
177 | has_img = False
178 | sub_data_set = Sample_dataset(
179 | d,
180 | seq_packing_size,
181 | has_img=has_img,
182 | img_size=img_size,
183 | hd_num=hd_num)
184 | if has_img:
185 | self.datasets_multi.append(sub_data_set)
186 | self.data_num_multi.append(len(sub_data_set))
187 | else:
188 | self.datasets_text.append(sub_data_set)
189 | self.data_num_text.append(len(sub_data_set))
190 |
191 | self.data_ratio_multi = [
192 | float(ratio) / sum(self.data_num_multi)
193 | for ratio in self.data_num_multi
194 | ]
195 | self.data_ratio_text = [
196 | float(ratio) / sum(self.data_num_text)
197 | for ratio in self.data_num_text
198 | ]
199 | self.data_num = np.sum(self.data_num_multi) + np.sum(
200 | self.data_num_text)
201 | self.use_multi = 0
202 |
203 | def __len__(self):
204 | return int(np.sum(self.data_num) / self.seq_packing_size)
205 |
206 | def __getitem__(self, index):
207 | if not self.set_seed:
208 | random.seed(index)
209 | self.set_seed = True
210 | print(f'Set seed {index} for rank {self.local_rank}')
211 |
212 | if len(self.datasets_multi) == 0 and len(self.datasets_text) == 0:
213 | raise ValueError(
214 | 'Both _multi and _text are empty. Cannot sample any data.')
215 |
216 | if len(self.datasets_multi) > 0 and (self.use_multi < self.seq_packing_size
217 | or len(self.datasets_text) == 0):
218 | data_idx = random.choices(
219 | range(len(self.data_ratio_multi)),
220 | weights=self.data_ratio_multi,
221 | k=1)[0]
222 | sample = self.datasets_multi[data_idx].get_item()
223 | elif len(self.datasets_text) > 0:
224 | data_idx = random.choices(
225 | range(len(self.data_ratio_text)),
226 | weights=self.data_ratio_text,
227 | k=1)[0]
228 | sample = self.datasets_text[data_idx].get_item()
229 | else:
230 | raise ValueError('Unable to select a dataset for sampling.')
231 |
232 | self.use_multi += 1
233 | if self.use_multi > self.seq_packing_size * 2:
234 | self.use_multi = 0
235 | return dict(samples=sample)
236 |
237 |
238 | class Sample_dataset(Dataset):
239 |
240 | def __init__(self,
241 | raw_data,
242 | seq_packing_size,
243 | has_img=True,
244 | img_size=490,
245 | hd_num=-1):
246 | self.raw_data = raw_data
247 |         print(f'loaded {len(self.raw_data)} samples')
248 | self.seq_packing_size = seq_packing_size
249 | assert hd_num == -1, "please set `hd_num` to `-1`"
250 |
251 | self.vis_processor = ImageProcessor(image_size=img_size)
252 | self.text_processor = conv2text
253 | self.has_img = has_img
254 |
255 | def __len__(self):
256 | return len(self.raw_data)
257 |
258 | def __get_item__(self, i):
259 | conv_text = conv2text(self.raw_data[i]['conversations'])
260 | sample = dict(text_input=conv_text, )
261 | if self.has_img:
262 | image_file = self.raw_data[i]['image']
263 |             image = [self.vis_processor(img) for img in image_file]
264 | sample['image'] = torch.stack(image)
265 | else:
266 | sample['image'] = None
267 |
268 | return sample
269 |
270 | def get_item(self, ):
271 | text_input = []
272 | images = []
273 | for i in range(self.seq_packing_size):
274 | # Randomly select an index from raw_data to get a random sample
275 | idx = random.randrange(len(self.raw_data))
276 | sample = self.__get_item__(idx)
277 | text_input.append(sample['text_input'])
278 | images.append(sample['image'])
279 | sample = {
280 | 'text_input': text_input,
281 | 'data_type': 'multi' if self.has_img else 'text',
282 | }
283 | if self.has_img:
284 | sample['image'] = torch.cat(images)
285 | return sample
286 |
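
# NOTE: a rough sketch of the `json_datas` layout expected by Mix_sampling_dataset,
# inferred from the code above rather than from an official spec. It is a dict that
# maps each annotation path to a list of records such as:
#   {"conversations": [...], "image": ["path/to/chart.png"]}   # multimodal record
#   {"conversations": [...]}                                    # text-only record
# Each __getitem__ call packs `seq_packing_size` randomly drawn records from one pool.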
--------------------------------------------------------------------------------
/chartmoe/train/ds_config_zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 2,
16 | "offload_optimizer": {
17 | "device": "none",
18 | "pin_memory": true
19 | },
20 | "allgather_partitions": true,
21 | "allgather_bucket_size": 2e8,
22 | "overlap_comm": true,
23 | "reduce_scatter": true,
24 | "reduce_bucket_size": 2e8,
25 | "contiguous_gradients": true
26 | },
27 |
28 | "gradient_accumulation_steps": "auto",
29 | "gradient_clipping": "auto",
30 | "steps_per_print": 100,
31 | "train_batch_size": "auto",
32 | "train_micro_batch_size_per_gpu": "auto",
33 | "wall_clock_breakdown": false
34 | }
35 |
--------------------------------------------------------------------------------
/chartmoe/train/mlp_moe.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: Implementation of MLP-MoE
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | REFERENCE: https://github.com/SHI-Labs/CuMo
6 | """
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from einops import rearrange, repeat, reduce, pack, unpack
11 |
12 | class MLPMoE(nn.Module):
13 | def __init__(self, num_experts, num_selected, mm_channels, channels):
14 | super().__init__()
15 | self.num_experts = num_experts
16 | self.num_selected = num_selected
17 | self.mm_channels = mm_channels
18 | self.channels = channels
19 |
20 | self.gate = nn.Linear(mm_channels, num_experts, bias=False)
21 | # nn.init.zeros_(self.gate.weight)
22 |
23 | self.num_selected = num_selected
24 | self.num_experts = num_experts
25 |
26 | self.experts = nn.ModuleList(
27 | [
28 | nn.Sequential(
29 | nn.Linear(mm_channels, channels, bias=True),
30 | nn.GELU(),
31 | nn.Linear(channels, channels, bias=True)
32 | )
33 | for _ in range(num_experts)
34 | ]
35 | )
36 |
37 | def forward(self, x_img):
38 | gate_logits = self.gate(x_img)
39 | gate_softmax = F.softmax(gate_logits, dim=-1, dtype=torch.float).to(x_img.dtype)
40 |
41 | weights, selected_experts = torch.topk(gate_softmax, self.num_selected)
42 | weights = weights / torch.sum(weights, dim=-1, keepdim=True).to(x_img.dtype)
43 |
44 | results = torch.zeros((x_img.shape[0], x_img.shape[1], self.channels)).to(x_img.device, x_img.dtype)
45 | for b in range(x_img.shape[0]):
46 | for i, expert in enumerate(self.experts):
47 | token_idx, nth_expert = torch.where(selected_experts[b] == i)
48 | results[b][token_idx] += weights[b][token_idx, nth_expert, None] * expert(x_img[b][token_idx])
49 |
50 | return results
51 |
52 |
53 | class MLPMoE_bzloss(nn.Module):
54 | def __init__(self, num_experts, num_selected, mm_channels, channels):
55 | super().__init__()
56 | self.num_experts = num_experts
57 | self.num_selected = num_selected
58 | self.mm_channels = mm_channels
59 | self.channels = channels
60 |
61 | self.gate = nn.Linear(mm_channels, num_experts, bias=False)
62 | # nn.init.zeros_(self.gate.weight)
63 |
64 | self.num_selected = num_selected
65 | self.num_experts = num_experts
66 | self.experts = nn.ModuleList([nn.Sequential(nn.Linear(mm_channels, channels, bias=True), nn.GELU(), nn.Linear(channels, channels, bias=True)) for _ in range(num_experts)])
67 |
68 | def forward(self, x_img):
69 | gate_logits = self.gate(x_img)
70 |
71 | router_z_loss = torch.logsumexp(gate_logits, dim = -1)
72 | router_z_loss = torch.square(router_z_loss)
73 | router_z_loss = router_z_loss.mean()
74 |
75 | gate_softmax = F.softmax(gate_logits, dim=-1, dtype=torch.float).to(x_img.dtype)
76 |
77 | density_1_proxy = reduce(gate_softmax, '... n e -> ... e', 'mean')
78 |
79 | weights, selected_experts = torch.topk(gate_softmax, self.num_selected)
80 |
81 | one_hot_gate_indices = F.one_hot(rearrange(selected_experts, '... k -> k ...'), self.num_experts).float()[0]
82 | density_1 = reduce(one_hot_gate_indices, '... n e -> ... e', 'mean')
83 | balance_loss = (density_1_proxy * density_1).mean() * float(self.num_experts ** 2)
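        # NOTE: `balance_loss` is the Switch-Transformer-style load-balancing auxiliary
        # (mean gate probability times mean routed fraction, scaled by num_experts^2),
        # and `router_z_loss` penalizes large gate logits for numerical stability. Both
        # are down-weighted (0.1 and 0.01) in the return statement below.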
84 |
85 | weights = weights / torch.sum(weights, dim=-1, keepdim=True).to(x_img.dtype)
86 |
87 | results = torch.zeros((x_img.shape[0], x_img.shape[1], self.channels)).to(x_img.device, x_img.dtype)
88 |
89 | for b in range(x_img.shape[0]):
90 | for i, expert in enumerate(self.experts):
91 | token_idx, nth_expert = torch.where(selected_experts[b] == i)
92 | results[b][token_idx] += weights[b][token_idx, nth_expert, None] * expert(x_img[b][token_idx])
93 |
94 | return results, 0.1 * balance_loss, 0.01 * router_z_loss
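

if __name__ == "__main__":
    # Minimal smoke test. This block is a sketch added for illustration and is not part
    # of the original training pipeline; the 1024 -> 4096 projection matches the sizes
    # used in moe_construction.py, while the batch size and token count are arbitrary.
    x = torch.randn(2, 16, 1024)

    moe = MLPMoE(num_experts=4, num_selected=2, mm_channels=1024, channels=4096)
    print("MLPMoE output:", moe(x).shape)  # expected: torch.Size([2, 16, 4096])

    moe_bz = MLPMoE_bzloss(num_experts=4, num_selected=2, mm_channels=1024, channels=4096)
    out, balance_loss, router_z_loss = moe_bz(x)
    print("MLPMoE_bzloss output:", out.shape, balance_loss.item(), router_z_loss.item())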
--------------------------------------------------------------------------------
/chartmoe/train/moe_construction.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: Construct MLP-MoE Connector with three aligned MLP and the general MLP Connector
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | """
6 | from mlp_moe import MLPMoE
7 | import argparse
8 | from glob import glob
9 | import os
10 | from copy import deepcopy
11 |
12 | import torch
13 | import torch.nn as nn
14 | import transformers
15 |
16 | def main(args):
17 | root = args.root_dir
18 | table_ckpt = glob(f"{root}/table_proj/checkpoint-*/mm_mlp.bin")[0]
19 | json_ckpt = glob(f"{root}/json_proj/checkpoint-*/mm_mlp.bin")[0]
20 | code_ckpt = glob(f"{root}/code_proj/checkpoint-*/mm_mlp.bin")[0]
21 |
22 | base_model = transformers.AutoModel.from_pretrained(
23 | args.base_model,
24 | trust_remote_code=True,
25 | device_map="cuda",
26 | attn_implementation="eager",
27 | )
28 | base_proj = deepcopy(base_model.vision_proj)
29 | del base_model
30 |
31 | table_proj = torch.load(table_ckpt)
32 | json_proj = torch.load(json_ckpt)
33 | code_proj = torch.load(code_ckpt)
34 |
35 | mlp_moe = MLPMoE(args.mlp_smoe_experts, args.mlp_smoe_topk, 1024, 4096)
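    # Expert initialization implemented by the branches below (assuming the default
    # --mlp_smoe_experts=4): expert 0 copies the base model's general vision_proj,
    # and experts 1/2/3 copy the table-, JSON- and code-aligned MLPs respectively.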
36 | for idx, expert in enumerate(mlp_moe.experts):
37 | print(idx % args.mlp_smoe_experts)
38 | if idx % args.mlp_smoe_experts == 0:
39 | for target_layer, source_layer in zip(expert, base_proj):
40 | if isinstance(target_layer, nn.Linear) and isinstance(source_layer, nn.Linear):
41 | target_layer.weight = deepcopy(source_layer.weight)
42 | target_layer.bias = deepcopy(source_layer.bias)
43 | print(f"{idx} expert: load base_proj")
44 | if idx % args.mlp_smoe_experts == 1:
45 | for ii in [0,2]:
46 | expert[ii].weight.data = table_proj[f'vision_proj.{ii}.weight']
47 | expert[ii].bias.data = table_proj[f'vision_proj.{ii}.bias']
48 | print(f"{idx} expert: load table_proj")
49 | if idx % args.mlp_smoe_experts == 2:
50 | for ii in [0,2]:
51 | expert[ii].weight.data = json_proj[f'vision_proj.{ii}.weight']
52 | expert[ii].bias.data = json_proj[f'vision_proj.{ii}.bias']
53 | print(f"{idx} expert: load json_proj")
54 | if idx % args.mlp_smoe_experts == 3:
55 | for ii in [0,2]:
56 | expert[ii].weight.data = code_proj[f'vision_proj.{ii}.weight']
57 | expert[ii].bias.data = code_proj[f'vision_proj.{ii}.bias']
58 | print(f"{idx} expert: load code_proj")
59 |
60 | os.makedirs(f"{root}/{args.save_name}", exist_ok=True)
61 | mlp_moe_state_dict = mlp_moe.state_dict()
62 | mlp_moe_state_dict['num_selected'] = args.mlp_smoe_topk
63 | torch.save(mlp_moe_state_dict, f"{root}/{args.save_name}/mlp_moe.pth")
64 |
65 |
66 |
67 | if __name__ == "__main__":
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument("--base_model", type=str)
70 | parser.add_argument("--mlp_smoe_experts", type=int, default=4)
71 | parser.add_argument("--mlp_smoe_topk", type=int, default=2)
72 | parser.add_argument("--root_dir", type=str, required=True)
73 | parser.add_argument("--save_name", type=str, required=True)
74 | args = parser.parse_args()
75 |
76 | main(args)
77 |
--------------------------------------------------------------------------------
/chartmoe/train/scripts/chartmoe_construction.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python chartmoe_construction.py \
2 | --moe_aligned_pth_path output/moe_aligned/mlp_moe.pth \
3 | --chartmoe_hf_dir ckpt/chartmoe \
4 | --adapter_model_name output/sft \
5 | --output_path output/sft/chartmoe_reproduced
--------------------------------------------------------------------------------
/chartmoe/train/scripts/chartmoe_data_download.py:
--------------------------------------------------------------------------------
1 | import huggingface_hub
2 | import time
3 |
4 | def download():
5 | try:
6 | huggingface_hub.snapshot_download("Coobiw/ChartMoE-Data", local_dir="./data/", repo_type="dataset", resume_download=True)
7 | return True
8 |     except Exception as e:
9 |         print(f"Caught an exception: {e}! Retrying...")
10 | return False
11 |
12 | while True:
13 | result = download()
14 | if result:
15 | print("success")
16 | break # Exit the loop if the function ran successfully
17 | time.sleep(1) # Wait for 1 second before retrying
18 |
--------------------------------------------------------------------------------
/chartmoe/train/scripts/chartmoe_download.py:
--------------------------------------------------------------------------------
1 | import huggingface_hub
2 | import time
3 |
4 | def download():
5 | try:
6 | huggingface_hub.snapshot_download("IDEA-FinAI/chartmoe", local_dir="./ckpt/chartmoe",resume_download=True,max_workers=4)
7 | return True
8 |     except Exception as e:
9 |         print(f"Caught an exception: {e}! Retrying...")
10 | return False
11 |
12 | while True:
13 | result = download()
14 | if result:
15 | print("success")
16 | break # Exit the loop if the function ran successfully
17 | time.sleep(1) # Wait for 1 second before retrying
18 |
--------------------------------------------------------------------------------
/chartmoe/train/scripts/code_align.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_DEVICE_MAX_CONNECTIONS=1
3 | DIR=`pwd`
4 |
5 | export MODEL="ckpt/InternLM-XComposer2_Enhanced"
6 | export DATA="data/code_align.txt"
7 |
8 | GPUS_PER_NODE=4
9 | NNODES=1
10 | NODE_RANK=0
11 | MASTER_ADDR=localhost
12 | MASTER_PORT=12700
13 |
14 | DISTRIBUTED_ARGS="
15 | --nproc_per_node $GPUS_PER_NODE \
16 | --nnodes $NNODES \
17 | --node_rank $NODE_RANK \
18 | --master_addr $MASTER_ADDR \
19 | --master_port $MASTER_PORT
20 | "
21 |
22 | torchrun $DISTRIBUTED_ARGS train.py \
23 | --model_name_or_path $MODEL \
24 | --data_path $DATA \
25 | --dataloader_num_workers 4 \
26 | --img_size 490 \
27 | --hd_num -1 \
28 | --given_num True \
29 | --bf16 True \
30 | --fix_vit True \
31 | --fix_sampler False \
32 | --fix_llm True \
33 | --use_lora False \
34 | --num_train_epochs 1 \
35 | --per_device_train_batch_size 4 \
36 | --per_device_eval_batch_size 1 \
37 | --gradient_accumulation_steps 16 \
38 | --evaluation_strategy "no" \
39 | --save_strategy "epoch" \
40 | --save_total_limit 5 \
41 | --learning_rate 5e-5 \
42 | --weight_decay 0.1 \
43 | --adam_beta2 0.95 \
44 | --warmup_ratio 0.0 \
45 | --lr_scheduler_type "cosine" \
46 | --logging_steps 1 \
47 | --max_length 4096 \
48 | --gradient_checkpointing True \
49 | --deepspeed ds_config_zero2.json \
50 | --output_dir output/code_proj \
51 | --report_to none
52 |
--------------------------------------------------------------------------------
/chartmoe/train/scripts/internlm_xc2_download.py:
--------------------------------------------------------------------------------
1 | import huggingface_hub
2 | import time
3 |
4 | def download():
5 | try:
6 | huggingface_hub.snapshot_download("Coobiw/InternLM-XComposer2_Enhanced", local_dir="./ckpt/InternLM-XComposer2_Enhanced",resume_download=True,max_workers=4)
7 | return True
8 |     except Exception as e:
9 |         print(f"Caught an exception: {e}! Retrying...")
10 | return False
11 |
12 | while True:
13 | result = download()
14 | if result:
15 | print("success")
16 | break # Exit the loop if the function ran successfully
17 | time.sleep(1) # Wait for 1 second before retrying
18 |
--------------------------------------------------------------------------------
/chartmoe/train/scripts/json_align.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_DEVICE_MAX_CONNECTIONS=1
3 | DIR=`pwd`
4 |
5 | export MODEL="ckpt/InternLM-XComposer2_Enhanced"
6 | export DATA="data/json_align.txt"
7 |
8 | GPUS_PER_NODE=4
9 | NNODES=1
10 | NODE_RANK=0
11 | MASTER_ADDR=localhost
12 | MASTER_PORT=12700
13 |
14 | DISTRIBUTED_ARGS="
15 | --nproc_per_node $GPUS_PER_NODE \
16 | --nnodes $NNODES \
17 | --node_rank $NODE_RANK \
18 | --master_addr $MASTER_ADDR \
19 | --master_port $MASTER_PORT
20 | "
21 |
22 | torchrun $DISTRIBUTED_ARGS train.py \
23 | --model_name_or_path $MODEL \
24 | --data_path $DATA \
25 | --dataloader_num_workers 4 \
26 | --img_size 490 \
27 | --hd_num -1 \
28 | --given_num True \
29 | --bf16 True \
30 | --fix_vit True \
31 | --fix_sampler False \
32 | --fix_llm True \
33 | --use_lora False \
34 | --num_train_epochs 1 \
35 | --per_device_train_batch_size 4 \
36 | --per_device_eval_batch_size 1 \
37 | --gradient_accumulation_steps 16 \
38 | --evaluation_strategy "no" \
39 | --save_strategy "epoch" \
40 | --save_total_limit 5 \
41 | --learning_rate 5e-5 \
42 | --weight_decay 0.1 \
43 | --adam_beta2 0.95 \
44 | --warmup_ratio 0.0 \
45 | --lr_scheduler_type "cosine" \
46 | --logging_steps 1 \
47 | --max_length 4096 \
48 | --gradient_checkpointing True \
49 | --deepspeed ds_config_zero2.json \
50 | --output_dir output/json_proj \
51 | --report_to none
52 |
--------------------------------------------------------------------------------
/chartmoe/train/scripts/moe_construction.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python moe_construction.py \
2 | --base_model ckpt/InternLM-XComposer2_Enhanced \
3 | --root_dir ./output \
4 | --save_name moe_aligned
--------------------------------------------------------------------------------
/chartmoe/train/scripts/multi_align.sh:
--------------------------------------------------------------------------------
1 | export CUDA_LAUNCH_BLOCKING=1
2 |
3 | mkdir -p logs/json_proj
4 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/json_align.sh 2>&1 | tee logs/json_proj/tee_logs.txt
5 | sleep 1m
6 |
7 | mkdir -p logs/code_proj
8 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/code_align.sh 2>&1 | tee logs/code_proj/tee_logs.txt
9 | sleep 1m
10 |
11 | mkdir -p logs/table_proj
12 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/table_align.sh 2>&1 | tee logs/table_proj/tee_logs.txt
--------------------------------------------------------------------------------
/chartmoe/train/scripts/sft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
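# NOTE: this SFT stage loads the aligned MoE connector produced by moe_construction.py
# (--moe_aligned_pth_path) and trains it together with LoRA adapters on the LLM
# (--use_lora True), while the ViT stays frozen (--fix_vit True).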
2 | export CUDA_DEVICE_MAX_CONNECTIONS=1
3 | DIR=`pwd`
4 |
5 | export MODEL="ckpt/InternLM-XComposer2_Enhanced"
6 | export DATA="data/sft.txt"
7 |
8 | GPUS_PER_NODE=4
9 | NNODES=1
10 | NODE_RANK=0
11 | MASTER_ADDR=localhost
12 | MASTER_PORT=6001
13 |
14 | DISTRIBUTED_ARGS="
15 | --nproc_per_node $GPUS_PER_NODE \
16 | --nnodes $NNODES \
17 | --node_rank $NODE_RANK \
18 | --master_addr $MASTER_ADDR \
19 | --master_port $MASTER_PORT
20 | "
21 |
22 | torchrun $DISTRIBUTED_ARGS train.py \
23 | --model_name_or_path $MODEL \
24 | --moe_aligned_pth_path output/moe_aligned/mlp_moe.pth \
25 | --data_path $DATA \
26 | --dataloader_num_workers 4 \
27 | --img_size 490 \
28 | --hd_num -1 \
29 | --given_num True \
30 | --bf16 True \
31 | --fix_vit True \
32 | --fix_sampler False \
33 | --use_lora True \
34 | --num_train_epochs 1 \
35 | --per_device_train_batch_size 2 \
36 | --per_device_eval_batch_size 1 \
37 | --gradient_accumulation_steps 8 \
38 | --evaluation_strategy "no" \
39 | --save_strategy "epoch" \
40 | --save_total_limit 5 \
41 | --learning_rate 5e-5 \
42 | --weight_decay 0.1 \
43 | --adam_beta2 0.95 \
44 | --warmup_ratio 0.0 \
45 | --lr_scheduler_type "cosine" \
46 | --logging_steps 1 \
47 | --max_length 4096 \
48 | --gradient_checkpointing True \
49 | --deepspeed ds_config_zero2.json \
50 | --output_dir output/sft \
51 | --report_to none
--------------------------------------------------------------------------------
/chartmoe/train/scripts/table_align.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export CUDA_DEVICE_MAX_CONNECTIONS=1
3 | DIR=`pwd`
4 |
5 | export MODEL="ckpt/InternLM-XComposer2_Enhanced"
6 | export DATA="data/table_align.txt"
7 |
8 | GPUS_PER_NODE=4
9 | NNODES=1
10 | NODE_RANK=0
11 | MASTER_ADDR=localhost
12 | MASTER_PORT=12700
13 |
14 | DISTRIBUTED_ARGS="
15 | --nproc_per_node $GPUS_PER_NODE \
16 | --nnodes $NNODES \
17 | --node_rank $NODE_RANK \
18 | --master_addr $MASTER_ADDR \
19 | --master_port $MASTER_PORT
20 | "
21 |
22 | torchrun $DISTRIBUTED_ARGS train.py \
23 | --model_name_or_path $MODEL \
24 | --data_path $DATA \
25 | --dataloader_num_workers 4 \
26 | --img_size 490 \
27 | --hd_num -1 \
28 | --given_num True \
29 | --bf16 True \
30 | --fix_vit True \
31 | --fix_sampler False \
32 | --fix_llm True \
33 | --use_lora False \
34 | --num_train_epochs 1 \
35 | --per_device_train_batch_size 4 \
36 | --per_device_eval_batch_size 1 \
37 | --gradient_accumulation_steps 16 \
38 | --evaluation_strategy "no" \
39 | --save_strategy "epoch" \
40 | --save_total_limit 5 \
41 | --learning_rate 5e-5 \
42 | --weight_decay 0.1 \
43 | --adam_beta2 0.95 \
44 | --warmup_ratio 0.0 \
45 | --lr_scheduler_type "cosine" \
46 | --logging_steps 1 \
47 | --max_length 4096 \
48 | --gradient_checkpointing True \
49 | --deepspeed ds_config_zero2.json \
50 | --output_dir output/table_proj \
51 | --report_to none
52 |
--------------------------------------------------------------------------------
/chartmoe/train/train.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: Training Script of ChartMoE
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | REFERENCE: https://github.com/InternLM/InternLM-XComposer
6 | """
7 | import json
8 | import random
9 | from dataclasses import dataclass, field
10 | from typing import Dict, List, Optional, Sequence
11 |
12 | import torch
13 | import torch.nn as nn
14 | import transformers
15 | from accelerate.utils import DistributedType
16 |
17 | from data_mix import Mix_dataset
18 | from mlp_moe import MLPMoE
19 |
20 | from deepspeed import zero
21 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
22 | from peft import LoraConfig, get_peft_model
23 | from transformers import deepspeed
24 | from chartmoe_trainer import ChartMoETrainer
25 | from transformers.trainer_pt_utils import LabelSmoother
26 |
27 | from copy import deepcopy
28 | import os
29 |
30 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index # -100
31 |
32 | @dataclass
33 | class ModelArguments:
34 | model_name_or_path: Optional[str] = field(default='')
35 | moe_aligned_pth_path: Optional[str] = field(default='')
36 |
37 |
38 | @dataclass
39 | class DataArguments:
40 | data_path: str = field(
41 | default='data.txt', metadata={'help': 'Path to the training data.'})
42 | given_num: bool = False
43 | img_size: int = 490
44 | hd_num: int = -1
45 |
46 |
47 | @dataclass
48 | class TrainingArguments(transformers.TrainingArguments):
49 | cache_dir: Optional[str] = field(default=None)
50 | optim: str = field(default='adamw_torch')
51 | max_length: int = field(
52 | default=4096,
53 | metadata={
54 | 'help':
55 | 'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
56 | },
57 | )
58 | use_lora: bool = False
59 | fix_vit: bool = True
60 | fix_sampler: bool = False
61 | fix_llm: bool = True
62 | label_names: List[str] = field(default_factory=lambda: ['samples'])
63 |
64 |
65 | @dataclass
66 | class LoraArguments:
67 | lora_r: int = 64
68 | lora_alpha: int = 64
69 | lora_dropout: float = 0.05
70 | lora_target_modules: List[str] = field(default_factory=lambda: [
71 | 'attention.wqkv',
72 | 'attention.wo',
73 | 'feed_forward.w1',
74 | 'feed_forward.w2',
75 | 'feed_forward.w3',
76 | ])
77 | lora_weight_path: str = ''
78 | lora_bias: str = 'none'
79 |
80 | def maybe_zero_3(param):
81 | if hasattr(param, 'ds_id'):
82 | assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
83 | with zero.GatheredParameters([param]):
84 | param = param.data.detach().cpu().clone()
85 | else:
86 | param = param.detach().cpu().clone()
87 | return param
88 |
89 |
90 | # Borrowed from peft.utils.get_peft_model_state_dict
91 | def get_peft_state_maybe_zero_3(named_params, bias):
92 | if bias == 'none':
93 | to_return = {k: t for k, t in named_params if 'lora_' in k}
94 | elif bias == 'all':
95 | to_return = {
96 | k: t
97 | for k, t in named_params if 'lora_' in k or 'bias' in k
98 | }
99 | elif bias == 'lora_only':
100 | to_return = {}
101 | maybe_lora_bias = {}
102 | lora_bias_names = set()
103 | for k, t in named_params:
104 | if 'lora_' in k:
105 | to_return[k] = t
106 | bias_name = k.split('lora_')[0] + 'bias'
107 | lora_bias_names.add(bias_name)
108 | elif 'bias' in k:
109 | maybe_lora_bias[k] = t
110 |         for k, t in maybe_lora_bias.items():
111 |             if k in lora_bias_names:
112 |                 to_return[k] = t
113 | else:
114 | raise NotImplementedError
115 | to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
116 | return to_return
117 |
118 |
119 | local_rank = None
120 |
121 |
122 | def rank0_print(*args):
123 | if local_rank == 0:
124 | print(*args)
125 |
126 |
127 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
128 | output_dir: str,
129 | bias='none'):
130 | """Collects the state dict and dump to disk."""
131 | # check if zero3 mode enabled
132 | if deepspeed.is_deepspeed_zero3_enabled():
133 | state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict(
134 | )
135 | else:
136 | if trainer.args.use_lora:
137 | state_dict = get_peft_state_maybe_zero_3(
138 | trainer.model.named_parameters(), bias)
139 | else:
140 | state_dict = trainer.model.state_dict()
141 | if trainer.args.should_save and trainer.args.local_rank == 0:
142 | trainer._save(output_dir, state_dict=state_dict)
143 |
144 |
145 | @dataclass
146 | class DataCollatorForSupervisedDataset:
147 | """Collate examples for supervised fine-tuning."""
148 |
149 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
150 | instances = [instance['samples'] for instance in instances]
151 | text_input, data_type = tuple(
152 | [instance[key] for instance in instances]
153 | for key in ('text_input', 'data_type'))
154 | if 'image' not in instances[0]:
155 | text_input = [instance['text_input'][0] for instance in instances]
156 | batch = dict(
157 | text_input=text_input,
158 | data_type=data_type,
159 | )
160 | if 'image' in instances[0]:
161 | images = [instance['image'] for instance in instances]
162 | batch['image'] = images
163 |
164 | return dict(samples=batch)
165 |
166 |
167 | def make_supervised_data_module(
168 | tokenizer: transformers.PreTrainedTokenizer,
169 | data_args,
170 | ) -> Dict:
171 | """Make dataset and collator for supervised fine-tuning."""
172 |
173 | rank0_print('Loading data...')
174 | if data_args.data_path.endswith('json'):
175 | train_json = json.load(open(data_args.data_path))
176 | elif data_args.data_path.endswith('txt'):
177 | train_json = {}
178 | with open(data_args.data_path) as f:
179 | lines = f.readlines()
180 |
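        # NOTE: each line of the .txt file is "<json_path> <count_or_ratio>" (see
        # data/*.txt). With --given_num, the second field is read as thousands of
        # samples (-1 keeps the full set); otherwise it is treated as a resampling ratio.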
181 | for line in lines:
182 | line = line.strip()
183 | line = line.split(' ')
184 | with open(line[0]) as f:
185 | temp = json.load(f)
186 | if data_args.given_num:
187 | assert len(line) == 2
188 | if int(line[1]) == -1:
189 | print(f"Load total data from {line[0]} ...")
190 | else:
191 | num = int(float(line[1]) * 1000)
192 | if len(temp) > num:
193 | temp = random.sample(temp, num)
194 | else:
195 | ex_temp = []
196 | for i in range(num - len(temp)):
197 | ex_temp.append(random.choice(temp))
198 | temp.extend(ex_temp)
199 | else:
200 | if len(line) == 2:
201 | ratio = float(line[1])
202 | new_len = int(len(temp) * ratio)
203 | if ratio < 1:
204 | temp = random.sample(temp, new_len)
205 | elif ratio > 1:
206 | ex_temp = []
207 | for i in range(new_len - len(temp)):
208 | ex_temp.append(random.choice(temp))
209 | temp.extend(ex_temp)
210 | rank0_print(f'Load {len(temp)} samples from {line}')
211 | train_json[line[0]] = temp
212 | train_dataset = Mix_dataset(
213 | train_json,
214 | img_size=data_args.img_size,
215 | hd_num=data_args.hd_num,
216 | local_rank=local_rank)
217 |     rank0_print(f'{len(train_dataset)} samples are loaded')
218 | eval_dataset = None
219 |
220 | data_collator = DataCollatorForSupervisedDataset()
221 | return dict(
222 | train_dataset=train_dataset,
223 | eval_dataset=eval_dataset,
224 | data_collator=data_collator,
225 | )
226 |
227 |
228 | def train():
229 | global local_rank
230 |
231 | parser = transformers.HfArgumentParser(
232 | (ModelArguments, DataArguments, TrainingArguments, LoraArguments))
233 | (
234 | model_args,
235 | data_args,
236 | training_args,
237 | lora_args,
238 | ) = parser.parse_args_into_dataclasses()
239 |
240 | if getattr(training_args, 'deepspeed', None):
241 | training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
242 |
243 | local_rank = training_args.local_rank
244 |
245 | device_map = None
246 |
247 | # Set RoPE scaling factor
248 | config = transformers.AutoConfig.from_pretrained(
249 | model_args.model_name_or_path,
250 | cache_dir=training_args.cache_dir,
251 | trust_remote_code=True,
252 | )
253 | config.use_cache = False
254 | config.max_length = training_args.max_length
255 |
256 | if config.attn_implementation == "flash_attention_2":
257 | rank0_print("Use Flash-Attn!!!")
258 | else:
259 | rank0_print("Use Eager Attn!!!")
260 |
261 | # Load model and tokenizer
262 | rank0_print(f'Load model from: {model_args.model_name_or_path}')
263 | model = transformers.AutoModelForCausalLM.from_pretrained(
264 | model_args.model_name_or_path,
265 | config=config,
266 | cache_dir=training_args.cache_dir,
267 | device_map=device_map,
268 | trust_remote_code=True,
269 | )
270 |
271 | if data_args.img_size != 336:
272 | model.vit.resize_pos()
273 |
274 | tokenizer = transformers.AutoTokenizer.from_pretrained(
275 | model_args.model_name_or_path,
276 | cache_dir=training_args.cache_dir,
277 | padding_side='right',
278 | use_fast=False,
279 | trust_remote_code=True,
280 | )
281 | model.tokenizer = tokenizer
282 |
283 | if model_args.moe_aligned_pth_path:
284 | rank0_print("Load Multi-Aligned MoE-MLP...")
285 | mlp_moe_state_dict = torch.load(model_args.moe_aligned_pth_path, map_location="cpu")
286 | num_experts = mlp_moe_state_dict['gate.weight'].size(0)
287 | num_selected = mlp_moe_state_dict.pop('num_selected')
288 | mlp_moe = MLPMoE(num_experts, num_selected, 1024, 4096).to(model.device)
289 | mlp_moe.load_state_dict(mlp_moe_state_dict)
290 | model.vision_proj = deepcopy(mlp_moe)
291 | del mlp_moe
292 |
293 | assert not training_args.fix_sampler, \
294 | "If load aligned moe, set `fix_sampler` to `False`. Because the router must be trained!!!"
295 |
296 |
297 | if training_args.fix_vit:
298 | model.vit.requires_grad_(False)
299 | else:
300 | assert False, "please fix vit~"
301 | # model.vit.requires_grad_(True)
302 | # model.vit.vision_tower.vision_model.post_layernorm = torch.nn.Identity(
303 | # )
304 |
305 | if training_args.fix_sampler:
306 | model.vision_proj.requires_grad_(False)
307 | else:
308 | model.vision_proj.requires_grad_(True)
309 |
310 | if training_args.use_lora:
311 | for name, param in model.model.named_parameters():
312 | param.requires_grad = False
313 |
314 | lora_config = LoraConfig(
315 | r=lora_args.lora_r,
316 | lora_alpha=lora_args.lora_alpha,
317 | target_modules=lora_args.lora_target_modules,
318 | lora_dropout=lora_args.lora_dropout,
319 | bias=lora_args.lora_bias,
320 | task_type='CAUSAL_LM',
321 | modules_to_save=['vision_proj'] if not training_args.fix_sampler else None
322 | )
323 |
324 | model = get_peft_model(model, lora_config)
325 | model.print_trainable_parameters()
326 |
327 | if training_args.gradient_checkpointing:
328 | model.enable_input_require_grads()
329 | else:
330 | if training_args.fix_llm:
331 | for name, param in model.model.named_parameters():
332 | param.requires_grad = False
333 | for name, param in model.output.named_parameters():
334 | param.requires_grad = False
335 |
336 | total = 0
337 | training = 0
338 | training_name = []
339 | for name, param in model.named_parameters():
340 | param_num = param.numel()
341 | if param.requires_grad:
342 | training += param_num
343 | training_name.append(name)
344 | total += param_num
345 |
346 | rank0_print(f"Total Params:\t{total / 1e9:.2f}B")
347 | rank0_print(f"Training Params:\t{training / 1e6:.2f}M")
348 | rank0_print(training_name)
349 |
350 | # # Name of Trainable Params
351 | # trainable_params_names = []
352 | # for name,param in model.named_parameters():
353 | # if param.requires_grad:
354 | # trainable_params_names.append(name)
355 | # print(trainable_params_names)
356 |
357 | # Load data
358 | data_module = make_supervised_data_module(
359 | tokenizer=tokenizer, data_args=data_args)
360 | print(f"transformers logging bar enabled status: {transformers.processing_utils.logging.is_progress_bar_enabled()}")
361 | transformers.processing_utils.logging.enable_progress_bar()
362 |
363 | # whether `tune_mm_mlp` or not
364 | if training_args.fix_vit and training_args.fix_llm and (not training_args.fix_sampler) and (not training_args.use_lora):
365 | rank0_print("`tune_mm_mlp` is True!!!")
366 | training_args.tune_mm_mlp = True
367 |     # Start trainer
368 | trainer = ChartMoETrainer(
369 | model=model, tokenizer=tokenizer, args=training_args, **data_module
370 | )
371 |
372 | trainer.train()
373 | trainer.save_state()
374 |
375 | safe_save_model_for_hf_trainer(
376 | trainer=trainer,
377 | output_dir=training_args.output_dir,
378 | bias=lora_args.lora_bias)
379 |
380 |
381 | if __name__ == '__main__':
382 | train()
--------------------------------------------------------------------------------
/chartmoe/utils/custom_path.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: Configuration of customized paths
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | """
6 | # Model
7 | ChartMoE_HF_PATH = 'IDEA-FinAI/chartmoe'
8 |
9 | # Data
10 | ChartQA_ROOT = '/path/to/ChartQA/'
11 | ChartQA_TEST_IMG_ROOT = '/path/to/ChartQA/test/png/'
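
# NOTE: point ChartQA_ROOT / ChartQA_TEST_IMG_ROOT at a local ChartQA copy before
# running the ChartQA evaluation; the placeholder paths above will not resolve as-is.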
--------------------------------------------------------------------------------
/examples/bar2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/bar2.png
--------------------------------------------------------------------------------
/examples/bar2_highlight.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/bar2_highlight.png
--------------------------------------------------------------------------------
/examples/line.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/line.png
--------------------------------------------------------------------------------
/examples/line3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/line3.png
--------------------------------------------------------------------------------
/examples/line3_edit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/line3_edit.png
--------------------------------------------------------------------------------
/examples/pie1-to-bar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/pie1-to-bar.png
--------------------------------------------------------------------------------
/examples/pie1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/pie1.png
--------------------------------------------------------------------------------
/gradio_demo.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: WebUI of ChartMoE
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | REFERENCE: https://github.com/Coobiw/MPP-LLaVA
6 | """
7 | import argparse
8 | import os
9 | import random
10 |
11 | import numpy as np
12 | import torch
13 | import torch.backends.cudnn as cudnn
14 | import gradio as gr
15 | from PIL import Image
16 |
17 | from functools import partial
18 | from copy import deepcopy
19 |
20 | from chartmoe import ChartMoE_Robot
21 |
22 | def disable_torch_init():
23 | """
24 | Disable the redundant torch default initialization to accelerate model creation.
25 | """
26 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
27 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
28 |
29 | # ========================================
30 | # Model Initialization
31 | # ========================================
32 |
33 | print('Initializing Chat')
34 |
35 |
36 | disable_torch_init()
37 | chat_robot = ChartMoE_Robot()
38 |
39 | print('Initialization Finished')
40 |
41 | # ========================================
42 | # Gradio Setting
43 | # ========================================
44 |
45 |
46 | def gradio_reset(history, img_list):
47 | if history is not None:
48 | history = ""
49 | if img_list is not None:
50 | img_list = []
51 | return None, \
52 | gr.update(value=None, interactive=True), \
53 | gr.update(placeholder='Please upload your image first', interactive=False),\
54 | gr.update(value="Upload & Start Chat", interactive=True), \
55 | history, \
56 | img_list
57 |
58 | def upload_img(gr_img, text_input, history, img_list):
59 | def load_img(image,img_list):
60 | if isinstance(image, str): # is a image path
61 | image = Image.open(image).convert('RGB')
62 | elif isinstance(image, Image.Image):
63 | image = image.convert('RGB')
64 |
65 | img_list.append(image)
66 | msg = "Received."
67 | return msg
68 | if gr_img is None:
69 |         return None, None, gr.update(interactive=True), history, img_list
70 |
71 | load_img(gr_img, img_list)
72 | return gr.update(interactive=False), \
73 | gr.update(interactive=True, placeholder='Type and press Enter'), \
74 | gr.update(value="Start Chatting", interactive=False), \
75 | history, \
76 | img_list
77 |
78 | def gradio_ask(user_message, chatbot):
79 | if len(user_message) == 0:
80 |         return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot
81 | chatbot = chatbot + [[user_message, None]]
82 | return user_message, chatbot
83 |
84 |
85 | def gradio_answer(chatbot, text_input, history, img_list, do_sample,num_beams, temperature, max_new_tokens):
86 | generation_config = \
87 | {
88 | "do_sample": do_sample=='True',
89 | "num_beams": num_beams,
90 | 'temperature': temperature,
91 | 'max_new_tokens': max_new_tokens,
92 | }
93 |
94 | image = img_list[0]
95 | with torch.cuda.amp.autocast():
96 | response, history = chat_robot.chat(image=image,question=text_input,history=history,**generation_config)
97 | chatbot[-1][1] = response
98 | text_input = ''
99 | return chatbot, history, img_list, text_input
100 |
101 | title = """Demo of ChartMoE
"""
102 | description = """This is the demo of ChartMoE. Upload your image and start chatting! To use example questions, click example image, hit upload, and press enter in the chatbox.
"""
103 |
104 | from transformers.trainer_utils import set_seed
105 | set_seed(42)
106 | #TODO show examples below
107 |
108 | with gr.Blocks() as demo:
109 | gr.Markdown(title)
110 | gr.Markdown(description)
111 |
112 | with gr.Row():
113 | with gr.Column(scale=0.25):
114 | image = gr.Image(type="pil")
115 | with gr.Row():
116 | upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
117 | clear = gr.Button("Restart 🔄")
118 |
119 | examples_placeholder = gr.Column()
120 |
121 | with gr.Column(scale=0.75):
122 | history = gr.State(value="")
123 | img_list = gr.State(value=[])
124 | chatbot = gr.Chatbot(
125 | label='ChartMoE',
126 | height=700,
127 | avatar_images=['gradio_demo_pics/user.png','gradio_demo_pics/robot.png']
128 | )
129 |
130 | with gr.Row():
131 | text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False, scale=8)
132 | submit_button = gr.Button(value="Submit", variant="primary",scale=2)
133 |
134 | with gr.Row():
135 | do_sample = gr.components.Radio(['True', 'False'],
136 | label='do_sample',
137 | value='False')
138 |
139 | num_beams = gr.Slider(
140 | minimum=1,
141 | maximum=5,
142 | value=1,
143 | step=1,
144 | interactive=True,
145 | label="num beams",
146 | )
147 |
148 | temperature = gr.Slider(
149 | minimum=0.1,
150 | maximum=2.0,
151 | value=1.0,
152 | step=0.1,
153 | interactive=True,
154 | label="Temperature",
155 | )
156 |
157 | max_new_tokens = gr.Slider(
158 | minimum=128,
159 | maximum=4096,
160 | value=1024,
161 | step=128,
162 | interactive=True,
163 | label="max new tokens",
164 | )
165 |
166 | with examples_placeholder:
167 | gr.Examples(examples=[
168 | ["examples/bar2.png", "Redraw the chart with python matplotlib, giving the code to highlight the column corresponding to the year in which the student got the highest score (painting it red). Please keep the same colors and legend as the input chart."],
169 | ["examples/line3.png", "Redraw the chart with python matplotlib, giving the code to highlight data point with lowest growth rate (draw a horizontal dotted line parallel to the x-axi, through the lowest point and add \'lowest\' label in the legend anchor). Please keep the same colors and legend as the input chart."],
170 | ["examples/pie1.png", "Redraw the chart with python matplotlib, convert it into a bar chart, giving the code to reflect the fact that the price of \'Gold\' has been reduced to 27% and the \'Silver\' has been increased to 28%. Please keep the colors and legend according to the input chart."]
171 | ], inputs=[image, text_input])
172 |
173 | upload_button.click(upload_img, [image, text_input, history,img_list], [image, text_input, upload_button, history, img_list])
174 |
175 | # print(list(map(type,[text_input, chatbot])))
176 | # print(list(map(type,[chatbot, history, img_list, do_sample, num_beams, temperature, max_new_tokens])))
177 | text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
178 | gradio_answer, [chatbot, text_input, history, img_list, do_sample, num_beams, temperature, max_new_tokens], [chatbot, history, img_list, text_input]
179 | )
180 | submit_button.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
181 | gradio_answer, [chatbot, text_input, history, img_list, do_sample, num_beams, temperature, max_new_tokens], [chatbot, history, img_list, text_input]
182 | )
183 | clear.click(gradio_reset, [history, img_list], [chatbot, image, text_input, upload_button, history, img_list], queue=False)
184 |
185 | demo.launch(share=True,inbrowser=True)
--------------------------------------------------------------------------------
/gradio_demo_pics/gradio_demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/gradio_demo_pics/gradio_demo1.jpg
--------------------------------------------------------------------------------
/gradio_demo_pics/gradio_demo2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/gradio_demo_pics/gradio_demo2.jpg
--------------------------------------------------------------------------------
/gradio_demo_pics/gradio_demo3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/gradio_demo_pics/gradio_demo3.jpg
--------------------------------------------------------------------------------
/gradio_demo_pics/robot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/gradio_demo_pics/robot.png
--------------------------------------------------------------------------------
/gradio_demo_pics/user.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/gradio_demo_pics/user.png
--------------------------------------------------------------------------------
/quickstart.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: QuickTour of ChartMoE
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | """
6 | from chartmoe import ChartMoE_Robot
7 | import torch
8 |
9 | robot = ChartMoE_Robot()
10 | image_path = "examples/bar2.png"
11 | question = "Redraw the chart with python matplotlib, giving the code to highlight the column corresponding to the year in which the student got the highest score (painting it red). Please keep the same colors and legend as the input chart."
12 |
13 | history = ""
14 | with torch.cuda.amp.autocast():
15 | response, history = robot.chat(image_path=image_path, question=question, history=history)
16 |
17 | print(response)
18 |
19 | '''Response:
20 | ```python
21 | import matplotlib.pyplot as plt
22 |
23 | data = [3.3, 3.5, 3.6, 3.8, 3.7, 3.6, 3.8]
24 | years = ['2016', '2017', '2018', '2019', '2020', '2021', '2022']
25 | labels = ['Student A Average GPA']
26 | colors = ['blue']
27 |
28 | plt.bar(years, data, color=colors)
29 | plt.title('Student Performance')
30 | plt.xlabel('Year')
31 | plt.ylabel('Student A Average GPA')
32 | plt.legend(labels)
33 |
34 | # Highlight the year with the highest score
35 | highest_score_index = data.index(max(data))
36 | plt.bar(years[highest_score_index], data[highest_score_index], color='red')
37 |
38 | plt.show()
39 | ```
40 |
41 | '''
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.33.2
2 | timm==0.4.12
3 | sentencepiece==0.1.99
4 | gradio==4.13.0
5 | markdown2==2.4.10
6 | xlsxwriter==3.1.2
7 | einops==0.8.0
8 | deepspeed==0.14.2
9 | peft==0.10.0
10 | prettytable==3.10.2
11 | tqdm
12 | datasets==2.21.0
13 | python-Levenshtein
14 | opencv-python>=4.10.0
15 | wandb==0.19.6
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_namespace_packages
2 | import platform
3 |
4 | DEPENDENCY_LINKS = []
5 | if platform.system() == "Windows":
6 | DEPENDENCY_LINKS.append("https://download.pytorch.org/whl/torch_stable.html")
7 |
8 |
9 | def fetch_requirements(filename):
10 | with open(filename) as f:
11 | return [ln.strip() for ln in f.read().split("\n")]
12 |
13 |
14 | setup(
15 | name="chartmoe",
16 | version="0.1.0",
17 | author="Coobiw",
18 | description="ChartMoE: Mixture of Expert Connector for Better Chart Understanding",
19 | keywords="Multimodal Large Language Model (MLLM), Chart Understanding, Mixture of Expert (MoE)",
20 | license="3-Clause BSD",
21 | packages=find_namespace_packages(include="chartmoe.*"),
22 | install_requires=fetch_requirements("requirements.txt"),
23 | python_requires=">=3.9",
24 | include_package_data=True,
25 | dependency_links=DEPENDENCY_LINKS,
26 | zip_safe=False,
27 | )
--------------------------------------------------------------------------------