├── .gitignore ├── README.md ├── asset ├── chartmoe-logo.jpg ├── teaser.png └── train_pipeline.png ├── chartmoe ├── __init__.py ├── eval_ChartQA.py ├── eval_MME.ipynb ├── generation_utils.py ├── train │ ├── README.md │ ├── chartmoe_construction.py │ ├── chartmoe_trainer.py │ ├── data │ │ ├── code_align.txt │ │ ├── json_align.txt │ │ ├── sft.txt │ │ └── table_align.txt │ ├── data_mix.py │ ├── ds_config_zero2.json │ ├── mlp_moe.py │ ├── moe_construction.py │ ├── scripts │ │ ├── chartmoe_construction.sh │ │ ├── chartmoe_data_download.py │ │ ├── chartmoe_download.py │ │ ├── code_align.sh │ │ ├── internlm_xc2_download.py │ │ ├── json_align.sh │ │ ├── moe_construction.sh │ │ ├── multi_align.sh │ │ ├── sft.sh │ │ └── table_align.sh │ └── train.py └── utils │ └── custom_path.py ├── examples ├── bar2.png ├── bar2_highlight.png ├── line.png ├── line3.png ├── line3_edit.png ├── pie1-to-bar.png └── pie1.png ├── gradio_demo.py ├── gradio_demo_pics ├── gradio_demo1.jpg ├── gradio_demo2.jpg ├── gradio_demo3.jpg ├── robot.png └── user.png ├── quickstart.py ├── requirements.txt ├── setup.py └── vis_tokens.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # project-specific 148 | output/ 149 | debug*/ 150 | results/ 151 | 152 | # cache root 153 | cache/ 154 | 155 | # DS_Store 156 | **/.DS_Store 157 | 158 | # pycharm .idea 159 | .idea/ 160 | 161 | # wandb 162 | wandb/ 163 | 164 | # training dirs 165 | chartmoe/train/debug/ 166 | chartmoe/train/wandb/ 167 | chartmoe/train/output/ 168 | chartmoe/train/logs/ 169 | chartmoe/train/ckpt/* 170 | chartmoe/train/data/*.json 171 | chartmoe/train/data/ChartMoE-Align 172 | chartmoe/train/data/SFT -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

ChartMoE

4 |

Mixture of Diversely Aligned Expert Connector for Chart Understanding

5 | 6 | [Zhengzhuo Xu](https://github.com/XuZhengzhuo)1,2\*, [Bowen Qu](https://github.com/Coobiw)1,3\*, Yiyan Qi1\*, Sinan Du2, Chengjin Xu1, Chun Yuan2, Jian Guo1,4 7 | 8 | 1 International Digital Economy Academy (IDEA), 9 | 2 Tsinghua University, 10 | 3 Peking University, 11 | 12 | 4 Hong Kong University of Science and Technology, Guangzhou 13 | 14 | ICLR 2025 Oral 15 | 16 | (\* equal contribution) 17 | 18 | [![arXiv](https://img.shields.io/badge/ArXiv-Preprint-red)](https://arxiv.org/abs/2409.03277) 19 | [![Project Page](https://img.shields.io/badge/Project-Page-brightgreen)](https://chartmoe.github.io/) 20 | [![Hugging Face Model](https://img.shields.io/badge/Hugging%20Face-Model-blue)](https://huggingface.co/IDEA-FinAI/chartmoe) 21 | [![Hugging Face Dataset](https://img.shields.io/badge/Hugging%20Face-Dataset-8A2BE2)](https://huggingface.co/datasets/Coobiw/ChartMoE-Data) 22 | [![Zhihu](https://img.shields.io/badge/Blog-Zhihu-00BFFF)](https://zhuanlan.zhihu.com/p/31634026232) 23 | [![机器之心](https://img.shields.io/badge/Blog-机器之心-black)](https://mp.weixin.qq.com/s/9anQbcCahVLnXhNj7aU48Q)
24 | [![closed issue](https://img.shields.io/github/issues-closed-raw/IDEA-FinAI/ChartMoE)](https://github.com/IDEA-FinAI/ChartMoE/issues) 25 | [![open issues](https://img.shields.io/github/issues-raw/IDEA-FinAI/ChartMoE)](https://github.com/IDEA-FinAI/ChartMoE/issues) 26 | 27 | *If you have any questions, feel free to contact [📧](mailto:brian.bw.qu@gmail.com).* 28 | 29 |
30 | 31 | ![](./asset/teaser.png) 32 | 33 | **ChartMoE** is a multimodal large language model with a Mixture-of-Experts connector for advanced chart 1) understanding, 2) replotting, 3) editing, 4) highlighting and 5) transformation. 34 | 35 | ## News 36 | - 2025.3.6: A reproduction of the diversely aligned MoE connector is released at [🤗HF Link](https://huggingface.co/Coobiw/ChartMoE-Aligned-Connector/tree/main/moe_aligned). 37 | - 2025.2.16: ChartMoE-Data has been released at [🤗](https://huggingface.co/datasets/Coobiw/ChartMoE-Data). Please download it according to [our instructions](#download-and-organize-the-chartmoe-data). 38 | - 2025.2.15: Training code and recipes are released! Please refer to [📖](chartmoe/train/)! 39 | - 2025.2.11: 🎉🎉🎉 ChartMoE is selected as **ICLR2025 Oral (1.8%)**! 40 | - 2025.1.23: 🎉🎉🎉 ChartMoE is accepted by **ICLR2025**! 41 | - 2024.9.10: We release ChartMoE! 42 | 43 | ## Training of ChartMoE 44 | Please refer to [📖training readme](chartmoe/train/)! 45 | 46 | ## Download and Organize the ChartMoE-Data 47 | [🤗ChartMoE Data](https://huggingface.co/datasets/Coobiw/ChartMoE-Data) has been released! You can download it by running: 48 | 49 | ```bash 50 | cd chartmoe/train 51 | python scripts/chartmoe_data_download.py 52 | ``` 53 | Datasets will appear at `chartmoe/train/data`. 54 | 55 | Then, please unzip these two files. 56 | ```bash 57 | unzip ChartMoE-Align.zip 58 | unzip SFT.zip 59 | ``` 60 | 61 | Additionally, note that `ChartY_replot` in `ChartMoE-Align` contains higher-quality data with bilingual texts! It may be a good choice to sample more from `ChartY_replot`. 62 | 63 | ## Installation 64 | **Step 1.** Create a conda environment and activate it. 65 | 66 | ```bash 67 | conda create -n chartmoe_env python=3.9 68 | conda activate chartmoe_env 69 | ``` 70 | 71 | **Step 2.** Install PyTorch (We use PyTorch 2.1.0 / CUDA 12.1) 72 | 73 | ```bash 74 | pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121 75 | ``` 76 | 77 | **Step 3.** Install required packages 78 | 79 | ```bash 80 | pip install -r requirements.txt 81 | ``` 82 | 83 | **Step 4.** Install the editable ChartMoE package 84 | 85 | ```bash 86 | pip install -e . 87 | ``` 88 | 89 | **Step 5.** (Optional) Install Flash-Attn (CUDA > 11.7) 90 | 91 | ```bash 92 | pip install flash-attn==2.7.0.post2 93 | ``` 94 | 95 | *Flash-Attn can bring ~30% acceleration on training and ~20% on evaluation in our experiments.* 96 | 97 | p.s.: If you cannot install `flash-attn`, please set `attn_implementation` to `eager` in ChartMoE's [`config.json`](https://huggingface.co/IDEA-FinAI/chartmoe/blob/main/config.json#L10). 98 | 99 | ## Quick Start 100 | ### Huggingface Download Script of ChartMoE 101 | 102 | *Note: I've supported `flash-attn` for ChartMoE on Feb. 15. If you downloaded ChartMoE before this date, you can re-download it for acceleration.* 103 | 104 | Run: 105 | ```bash 106 | cd chartmoe/train 107 | python scripts/chartmoe_download.py 108 | ``` 109 | Then, ChartMoE will appear at `chartmoe/train/ckpt/chartmoe`. 110 | 111 | ### Customize the weight path of ChartMoE 112 | 113 | Set your own [ChartMoE_HF_PATH](https://github.com/Coobiw/ChartMoE/tree/master/chartmoe/utils/custom_path.py#L2). I suggest using the absolute path of `chartmoe/train/ckpt/chartmoe`.
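For reference, here is a minimal sketch of what `chartmoe/utils/custom_path.py` contains (the concrete values below are placeholders, so replace them with your own absolute paths; `ChartQA_ROOT` and `ChartQA_TEST_IMG_ROOT` are used later in the Evaluation section):

```python
# chartmoe/utils/custom_path.py (example values, replace with your own paths)
ChartMoE_HF_PATH = "/abs/path/to/chartmoe/train/ckpt/chartmoe"

# used by chartmoe/eval_ChartQA.py; keep the trailing slash, since the
# evaluator concatenates ChartQA_ROOT with 'test/test_human.json' etc.
ChartQA_ROOT = "/abs/path/to/ChartQA/"
ChartQA_TEST_IMG_ROOT = "/abs/path/to/ChartQA/test/png/"  # directory with the test images
```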
114 | 115 | ### Code Demo 116 | 117 | ```python 118 | from chartmoe import ChartMoE_Robot 119 | import torch 120 | 121 | robot = ChartMoE_Robot() 122 | image_path = "examples/bar2.png" 123 | question = "Redraw the chart with python matplotlib, giving the code to highlight the column corresponding to the year in which the student got the highest score (painting it red). Please keep the same colors and legend as the input chart." 124 | 125 | history = "" 126 | with torch.cuda.amp.autocast(): 127 | response, history = robot.chat(image_path=image_path, question=question, history=history) 128 | 129 | print(response) 130 | ``` 131 | 132 | ## Evaluation 133 | 134 | ### ChartQA 135 | **Customize the path of ChartQA:** 136 | 137 | Set your own [ChartQA_ROOT](https://github.com/Coobiw/ChartMoE/tree/master/chartmoe/utils/custom_path.py#L5) (including `test_human.json` and `test_augmented.json`) and [ChartQA_TEST_IMG_ROOT](https://github.com/Coobiw/ChartMoE/tree/master/chartmoe/utils/custom_path.py#L6) (including the test images). 138 | 139 | **w/ PoT:** 140 | 141 | ```bash 142 | CUDA_VISIBLE_DEVICES=0 python chartmoe/eval_ChartQA.py --save_path ./results/chartqa_results_pot --pot 143 | ``` 144 | 145 | **w/o PoT:** 146 | 147 | ```bash 148 | CUDA_VISIBLE_DEVICES=0 python chartmoe/eval_ChartQA.py --save_path ./results/chartqa_results 149 | ``` 150 | 151 | ### MME 152 | Run `chartmoe/eval_MME.ipynb` for MME scores. 153 | 154 | ## WebUI Demo 155 | 156 | ```bash 157 | CUDA_VISIBLE_DEVICES=0 python gradio_demo.py 158 | ``` 159 | 160 | ![](./gradio_demo_pics/gradio_demo1.jpg) 161 | 162 | ## FAQs 163 | Q1: [CLIP: Input image size (490x490) doesn't match model (336x336)](https://github.com/IDEA-FinAI/ChartMoE/issues/6) 164 | 165 | A1: Please downgrade your `transformers` according to `requirements.txt`. 166 | 167 | ## Acknowledgement 168 | Thanks to [InternLM-XComposer2](https://github.com/InternLM/InternLM-XComposer/tree/main/InternLM-XComposer-2.0) and [CuMo](https://github.com/SHI-Labs/CuMo) for their releases of model weights and source code! And thanks to [MMC](https://github.com/FuxiaoLiu/MMC) and [ChartGemma](https://github.com/vis-nlp/ChartGemma) for their releases of the high-quality instruction-tuning data!
169 | 170 | ## Citation 171 | If you find our idea or code inspiring, please cite our paper: 172 | ```bibtex 173 | @article{ChartMoE, 174 | title={ChartMoE: Mixture of Diversely Aligned Expert Connector for Chart Understanding}, 175 | author={Zhengzhuo Xu and Bowen Qu and Yiyan Qi and Sinan Du and Chengjin Xu and Chun Yuan and Jian Guo}, 176 | journal={ArXiv}, 177 | year={2024}, 178 | volume={abs/2409.03277}, 179 | } 180 | ``` 181 | This code is partially based on [ChartBench](https://chartbench.github.io/); if you use our code, please also cite: 182 | ```bibtex 183 | @article{ChartBench, 184 | title={ChartBench: A Benchmark for Complex Visual Reasoning in Charts}, 185 | author={Zhengzhuo Xu and Sinan Du and Yiyan Qi and Chengjin Xu and Chun Yuan and Jian Guo}, 186 | journal={ArXiv}, 187 | year={2023}, 188 | volume={abs/2312.15915}, 189 | } 190 | ``` 191 | -------------------------------------------------------------------------------- /asset/chartmoe-logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/asset/chartmoe-logo.jpg -------------------------------------------------------------------------------- /asset/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/asset/teaser.png -------------------------------------------------------------------------------- /asset/train_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/asset/train_pipeline.png -------------------------------------------------------------------------------- /chartmoe/__init__.py: -------------------------------------------------------------------------------- 1 | from chartmoe.generation_utils import ChartMoE_Robot -------------------------------------------------------------------------------- /chartmoe/eval_ChartQA.py: -------------------------------------------------------------------------------- 1 | """ 2 | FEATURE: ChartQA Evaluation of ChartMoE, with/without PoT (Program of Thoughts) 3 | AUTHOR: Brian Qu 4 | URL: https://arxiv.org/abs/2409.03277 5 | """ 6 | from chartmoe import ChartMoE_Robot 7 | from chartmoe.utils.custom_path import ChartQA_ROOT, ChartQA_TEST_IMG_ROOT 8 | 9 | import os, sys, json, re, io 10 | import argparse 11 | import torch 12 | from tqdm import tqdm 13 | from typing import Optional 14 | from prettytable import PrettyTable 15 | 16 | def relaxed_acc(prediction: str, target: str, 17 | max_relative_change: float = 0.05) -> bool: 18 | 19 | def _to_float(text: str) -> Optional[float]: 20 | try: 21 | match = re.search(r'[\d.]+', text.replace(',', '')) 22 | if match: return float(match.group()) 23 | return None 24 | except ValueError: 25 | return None 26 | 27 | prediction_float = _to_float(prediction) 28 | target_float = _to_float(target) 29 | 30 | if prediction_float is not None and target_float is not None: 31 | if target_float == 0: 32 | relative_change = abs(prediction_float - target_float) 33 | else: 34 | relative_change = abs(prediction_float - target_float) / abs(target_float) 35 | return relative_change <= max_relative_change 36 | else: 37 | lp = prediction.lower() 38 | tp = target.lower() 39 | 40 | if ("yes" in lp and "yes" in tp) or ("no" in lp and "no" in tp): return True 41 | if lp in tp:
return True 42 | return lp == tp 43 | 44 | def evaluate_relaxed_accuracy(entries, margin=0.05): 45 | scores = [] 46 | for elem in entries: 47 | if isinstance(elem['annotation'], str): 48 | elem['annotation'] = [elem['annotation']] 49 | score = max([ 50 | relaxed_acc(elem['answer'].strip(), ann, margin) 51 | for ann in elem['annotation'] 52 | ]) 53 | scores.append(score) 54 | 55 | return sum(scores) / len(scores) 56 | 57 | def execute_python_code(code): 58 | old_stdout = sys.stdout 59 | new_stdout = io.StringIO() 60 | sys.stdout = new_stdout 61 | 62 | status = True 63 | try: 64 | exec(code) 65 | except Exception: 66 | status = False 67 | finally: 68 | sys.stdout = old_stdout 69 | 70 | if status: 71 | output = new_stdout.getvalue() 72 | else: 73 | output = None 74 | return output, status 75 | 76 | def extract_python_content(text): 77 | pattern = r"```python(.*?)```" 78 | matches = re.findall(pattern, text, re.DOTALL) 79 | return matches 80 | 81 | class ChartQATester: 82 | 83 | def __init__(self, ckpt_path=None, pot=False, pot_idx=0): 84 | # ChartQA root 85 | self.root = ChartQA_ROOT 86 | self.vis_root = ChartQA_TEST_IMG_ROOT 87 | 88 | self.robot = ChartMoE_Robot(ckpt_path=ckpt_path) 89 | self.prompt = '[UNUSED_TOKEN_146]user\nAnswer the question using a single word or phrase.{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n' 90 | pot_prompts = [ 91 | '[UNUSED_TOKEN_146]user\nPlease give the program of thought.{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n', 92 | '[UNUSED_TOKEN_146]user\nPlease give the program of thought in python code. Use print function to output the answer in the end.{}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n', 93 | ] 94 | self.pot_prompt = pot_prompts[pot_idx] 95 | 96 | self.pot = pot 97 | if self.pot: 98 | self.robot.reset_prompt(prompt=self.pot_prompt) 99 | else: 100 | self.robot.reset_prompt(prompt=self.prompt) 101 | 102 | def reset_prompt(self, p): 103 | self.system_prompt = p 104 | 105 | def infer_all_answers(self, output_path): 106 | 107 | os.makedirs(output_path, exist_ok=True) 108 | print(f"Result will be saved at: {output_path}") 109 | 110 | part_acc = [] 111 | for part_name in ['human', 'augmented']: 112 | part_json = os.path.join(output_path, f"{part_name}.json") 113 | if os.path.exists(part_json): 114 | print(f"Load result from: {part_json}") 115 | part = json.load(open(part_json, 'r')) 116 | else: 117 | part = [] 118 | samples = json.load(open(self.root+f'test/test_{part_name}.json')) 119 | for q in tqdm(samples): 120 | im_path = os.path.join(self.vis_root, q['imgname']) 121 | question = q['query'] 122 | 123 | with torch.cuda.amp.autocast(): 124 | response, _ = self.robot.chat( 125 | image_path=im_path, 126 | question=question, 127 | max_new_tokens=500, 128 | num_beams=1, 129 | ) 130 | if self.pot: 131 | extracted_result = extract_python_content(response) 132 | if extracted_result: 133 | code = extracted_result[0] 134 | else: 135 | code = response 136 | response, status = execute_python_code(code) 137 | 138 | if not status: 139 | response = "error running..."
140 | response = response.replace("True","Yes").replace("False","No") 141 | response = response.strip() 142 | part.append({ 143 | 'image': im_path, 144 | 'query': question, 145 | 'answer': response, 146 | 'annotation': q['label'], 147 | 'code': code 148 | }) 149 | else: 150 | part.append({ 151 | 'image': im_path, 152 | 'query': question, 153 | 'answer': response, 154 | 'annotation': q['label'] 155 | }) 156 | with open(part_json, 'w') as f: 157 | json.dump(part, f, indent=4) 158 | part_acc.append(part) 159 | 160 | table = PrettyTable() 161 | table.field_names = ["@AP", "0.05", "0.1", "0.2"] 162 | human_row = ["Human"] 163 | augmented_row = ["Augmented"] 164 | averaged_row = ["Averaged"] 165 | for ap in [0.05, 0.1, 0.2]: 166 | part_acc_ap = [evaluate_relaxed_accuracy(p, ap) for p in part_acc] 167 | human_acc = part_acc_ap[0] 168 | augmented_acc = part_acc_ap[1] 169 | averaged_acc = (human_acc + augmented_acc) / 2 170 | human_row.append(human_acc) 171 | augmented_row.append(augmented_acc) 172 | averaged_row.append(averaged_acc) 173 | 174 | table.add_row(human_row) 175 | table.add_row(augmented_row) 176 | table.add_row(averaged_row) 177 | 178 | table_path = os.path.join(output_path, 'table.txt') 179 | with open(table_path, 'w') as f: 180 | f.write(str(table)) 181 | 182 | print(table) 183 | 184 | if __name__ == "__main__": 185 | parser = argparse.ArgumentParser() 186 | parser.add_argument("--save_path", type=str, required=True) 187 | parser.add_argument("--ckpt_path", type=str, default=None) 188 | parser.add_argument('--pot', action='store_true') 189 | parser.add_argument('--pot_idx', type=int, default=0, choices=[0, 1]) 190 | args = parser.parse_args() 191 | 192 | tester = ChartQATester(ckpt_path=args.ckpt_path, pot=args.pot, pot_idx=args.pot_idx) 193 | tester.infer_all_answers(output_path=args.save_path) -------------------------------------------------------------------------------- /chartmoe/eval_MME.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 13 | " from .autonotebook import tqdm as notebook_tqdm\n", 14 | "/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", 15 | " warnings.warn(\n", 16 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n", 17 | "- tokenization_internlm_xcomposer2.py\n", 18 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n", 19 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n", 20 | "- configuration_chartmoe.py\n", 21 | ". Make sure to double-check they do not contain any added malicious code. 
To avoid downloading new versions of the code file, you can pin a revision.\n", 22 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n", 23 | "- build_mlp.py\n", 24 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n", 25 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n", 26 | "- modeling_internlm2.py\n", 27 | "- build_mlp.py\n", 28 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n", 29 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n", 30 | "- build_moe_connector.py\n", 31 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n", 32 | "A new version of the following files was downloaded from https://huggingface.co/IDEA-FinAI/chartmoe:\n", 33 | "- modeling_chartmoe.py\n", 34 | "- modeling_internlm2.py\n", 35 | "- build_moe_connector.py\n", 36 | ". Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.\n", 37 | "/data/FinAi_Mapping_Knowledge/qiyiyan/qbw/anaconda3/envs/intern_clean/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", 38 | " warnings.warn(\n", 39 | "Downloading shards: 100%|██████████| 2/2 [00:00<00:00, 4.23it/s]\n" 40 | ] 41 | }, 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "Set max length to 4096\n" 47 | ] 48 | }, 49 | { 50 | "name": "stderr", 51 | "output_type": "stream", 52 | "text": [ 53 | "Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00, 4.16s/it]\n" 54 | ] 55 | } 56 | ], 57 | "source": [ 58 | "import os\n", 59 | "os.environ['CUDA_VISIBLE_DEVICES']='2'\n", 60 | "import sys \n", 61 | "import json\n", 62 | "import torch\n", 63 | "import numpy as np\n", 64 | "from PIL import Image \n", 65 | "from tqdm import tqdm\n", 66 | "import datetime\n", 67 | "from collections import defaultdict\n", 68 | "\n", 69 | "\n", 70 | "from datasets import load_dataset\n", 71 | "from chartmoe import ChartMoE_Robot\n", 72 | "\n", 73 | "mme_data = load_dataset(\"lmms-lab/MME\")['test']\n", 74 | "\n", 75 | "robot = ChartMoE_Robot()" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 2, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "eval_type_dict = {\n", 85 | " \"Perception\": [\n", 86 | " \"existence\",\n", 87 | " \"count\",\n", 88 | " \"position\",\n", 89 | " \"color\",\n", 90 | " \"posters\",\n", 91 | " \"celebrity\",\n", 92 | " \"scene\",\n", 93 | " \"landmark\",\n", 94 | " \"artwork\",\n", 95 | " \"OCR\",\n", 96 | " ],\n", 97 | " \"Cognition\": [\n", 98 | " \"commonsense_reasoning\",\n", 99 | " \"numerical_calculation\",\n", 100 | " \"text_translation\",\n", 101 | " \"code_reasoning\",\n", 102 | " ],\n", 103 | "}" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 3, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "def parse_pred_ans(pred_ans):\n", 113 | " \"\"\"Brought from Otter Eval\"\"\"\n", 114 | " 
pred_ans = pred_ans.lower().strip().replace(\".\", \"\")\n", 115 | " pred_label = None\n", 116 | " if pred_ans in [\"yes\", \"no\"]:\n", 117 | " pred_label = pred_ans\n", 118 | " elif len(pred_ans) == 1:\n", 119 | " if pred_ans == \"y\":\n", 120 | " pred_label = \"yes\"\n", 121 | " elif pred_ans == \"n\":\n", 122 | " pred_label = \"no\"\n", 123 | " else:\n", 124 | " pred_label = \"other\"\n", 125 | " else:\n", 126 | " prefix_pred_ans = pred_ans[:4]\n", 127 | " if \"yes\" in prefix_pred_ans:\n", 128 | " pred_label = \"yes\"\n", 129 | " elif \"no\" in prefix_pred_ans:\n", 130 | " pred_label = \"no\"\n", 131 | " else:\n", 132 | " pred_label = \"other\"\n", 133 | " return pred_label" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 4, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "name": "stderr", 143 | "output_type": "stream", 144 | "text": [ 145 | " 0%| | 0/2374 [00:00 0: 80 | text_embeds = self.model.encode_text(subtext, add_special_tokens=need_bos) 81 | embeds.append(text_embeds) 82 | im_mask.append(torch.zeros(text_embeds.shape[:2]).cuda()) 83 | need_bos = False 84 | if i < len(images): 85 | try: 86 | image = Image.open(images[i]).convert('RGB') 87 | except: 88 | image = images[i].convert('RGB') 89 | if self.img_padding: 90 | image = __padding__(image) 91 | image = self.model.vis_processor(image).unsqueeze(0).cuda() 92 | image_embeds = self.model.encode_img(image) 93 | embeds.append(image_embeds) 94 | im_mask.append(torch.ones(image_embeds.shape[:2]).cuda()) 95 | pt1 = pts 96 | embeds = torch.cat(embeds, dim=1) 97 | im_mask = torch.cat(im_mask, dim=1) 98 | im_mask = im_mask.bool() 99 | 100 | eos_token_id = [ 101 | self.tokenizer.convert_tokens_to_ids(['[UNUSED_TOKEN_145]'])[0], 102 | self.tokenizer.eos_token_id, 103 | ] 104 | outputs = self.model.generate( 105 | inputs_embeds=embeds, 106 | im_mask=im_mask, 107 | temperature=temperature, 108 | max_new_tokens=max_new_tokens, 109 | num_beams=num_beams, 110 | do_sample=do_sample, 111 | repetition_penalty=repetition_penalty, 112 | eos_token_id=eos_token_id, 113 | ) 114 | 115 | output_token = outputs[0] 116 | if output_token[0] == 0 or output_token[0] == 1: 117 | output_token = output_token[1:] 118 | output_text = self.model.tokenizer.decode(output_token, add_special_tokens=False) 119 | history += output_text 120 | output_text = output_text.split('[UNUSED_TOKEN_145]')[0].strip() 121 | return output_text, history -------------------------------------------------------------------------------- /chartmoe/train/README.md: -------------------------------------------------------------------------------- 1 |
2 |

Training Recipes of ChartMoE

3 |
4 | 5 |
6 |

Datasets are released at 🤗https://huggingface.co/datasets/Coobiw/ChartMoE-Data!

7 |
8 | 9 | In this part, I'll introduce the training recipes for reproducing ChartMoE. Besides the training recipes, I also provide a checkpoint that can be reproduced according to the following instructions. You can find it at [🤗](https://huggingface.co/Coobiw/ChartMoE_Reproduced). **This version has better performance on ChartQA (both with & without PoT).** 10 | 11 | 12 | ## Download and Organize the ChartMoE-Data 13 | [🤗ChartMoE Data](https://huggingface.co/datasets/Coobiw/ChartMoE-Data) has been released! You can download it by running: 14 | 15 | ```bash 16 | cd chartmoe/train 17 | python scripts/chartmoe_data_download.py 18 | ``` 19 | Datasets will appear at `chartmoe/train/data`. 20 | 21 | Then, please unzip these two files. 22 | ```bash 23 | unzip ChartMoE-Align.zip 24 | unzip SFT.zip 25 | ``` 26 | 27 | Additionally, note that `ChartY_replot` in `ChartMoE-Align` contains higher-quality data with bilingual texts! It may be a good choice to sample more from `ChartY_replot`. 28 | 29 | ### Data Format 30 | ```python 31 | [ 32 | { 33 | "id": "0", 34 | "image": ['path/to/image_0.jpg'], 35 | "conversations": [ 36 | { 37 | "from": "user", 38 | "value": "<ImageHere> Please describe these two images in detail." 39 | }, 40 | { 41 | "from": "assistant", 42 | "value": "......" 43 | } 44 | ] 45 | }, 46 | { 47 | "id": "1", 48 | "image": ['path/to/image_1.jpg'], 49 | "conversations": [ 50 | { 51 | "from": "user", 52 | "value": "<ImageHere> what is the color of the dog" 53 | }, 54 | { 55 | "from": "assistant", 56 | "value": "it is ...." 57 | } 58 | ] 59 | } 60 | ] 61 | ``` 62 | 63 | ## Download InternLM_XComposer2_Enhanced 64 | 65 | **Note: I've supported `flash-attn` and `batchified training` for InternLM-XComposer2 on [Coobiw/InternLM-XComposer2_Enhanced](https://huggingface.co/Coobiw/InternLM-XComposer2_Enhanced). This will indeed accelerate training.** 66 | 67 | Run: 68 | 69 | ```bash 70 | cd chartmoe/train 71 | python scripts/internlm_xc2_download.py 72 | ``` 73 | 74 | Then, InternLM-XComposer2_Enhanced will appear at `chartmoe/train/ckpt/InternLM-XComposer2_Enhanced`. 75 | 76 | ## Diversely-Aligned MoE-MLP Training 77 | 78 | ### Download the intermediate checkpoint 79 | I've uploaded the weights of the diversely aligned MoE connector (each expert has been trained, but the router is randomly initialized): [🤗HF Link](https://huggingface.co/Coobiw/ChartMoE-Aligned-Connector/tree/main). Please put `mlp_moe.pth` at `chartmoe/train/output/moe_aligned/mlp_moe.pth`! Then you can directly run the SFT script~ 80 | 81 | ### Training Pipeline of ChartMoE 82 | If you want to train your own MoE connector, feel free to follow these instructions! 83 | 84 | ![Overview](../../asset/train_pipeline.png) 85 | 86 | Run: 87 | 88 | ```bash 89 | cd chartmoe/train 90 | bash scripts/multi_align.sh 91 | ``` 92 | 93 | Then, the table/json/code MLP connectors will appear at `chartmoe/train/output/table_proj`, `chartmoe/train/output/json_proj` and `chartmoe/train/output/code_proj`! 94 | 95 | After the diverse alignment, we can construct the MoE-MLP connector by running: 96 | 97 | ```bash 98 | cd chartmoe/train 99 | bash scripts/moe_construction.sh 100 | ``` 101 | 102 | The MoE-MLP connector will appear at `chartmoe/train/output/moe_aligned/mlp_moe.pth`. 103 | 104 | ## SFT 105 | 106 | *Note: In this Repo, we don't add "High-Quality Knowledge Learning" mid-training.* 107 | 108 | Please note [the path of the MoE-MLP connector](./scripts/sft.sh#L24).
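Optionally, you can sanity-check the aligned connector checkpoint before launching SFT. Here is a minimal sketch (run from `chartmoe/train`; it mirrors how `chartmoe_construction.py` reads this file, and the expected values assume the default 4-expert / top-2 setting):

```python
import torch

from mlp_moe import MLPMoE

# the checkpoint produced by moe_construction.py stores `num_selected`
# alongside the weights; the expert count is recovered from the router (gate) shape
state_dict = torch.load("output/moe_aligned/mlp_moe.pth", map_location="cpu")
num_selected = state_dict.pop("num_selected")
num_experts = state_dict["gate.weight"].size(0)
print(f"num_experts={num_experts}, num_selected={num_selected}")  # expect 4 and 2

# 1024 = CLIP feature dim, 4096 = InternLM2 hidden dim
mlp_moe = MLPMoE(num_experts, num_selected, 1024, 4096)
mlp_moe.load_state_dict(state_dict)  # raises if any expert weights are missing
```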
109 | 110 | Run: 111 | 112 | ```bash 113 | cd chartmoe/train 114 | mkdir -p logs/sft 115 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/sft.sh 2>&1 | tee logs/sft/tee_logs.txt 116 | ``` 117 | 118 | ## Merge MLP-MoE Connector and LoRA Weights for ChartMoE Construction 119 | Run: 120 | 121 | ```bash 122 | cd chartmoe/train 123 | bash scripts/chartmoe_construction.sh 124 | ``` 125 | 126 | ## Evaluation on ChartQA 127 | w/o PoT: 128 | 129 | ```bash 130 | CUDA_VISIBLE_DEVICES=0 python chartmoe/eval_ChartQA.py --ckpt_path chartmoe/train/output/sft/chartmoe_reproduced --save_path chartmoe/train/output/sft/chartmoe_reproduced/ChartQA_wo-PoT 131 | ``` 132 | 133 | Result: 134 | ``` 135 | +-----------+--------+--------+--------+ 136 | | @AP | 0.05 | 0.1 | 0.2 | 137 | +-----------+--------+--------+--------+ 138 | | Human | 0.704 | 0.7376 | 0.772 | 139 | | Augmented | 0.9056 | 0.9192 | 0.9352 | 140 | | Averaged | 0.8048 | 0.8284 | 0.8536 | 141 | +-----------+--------+--------+--------+ 142 | ``` 143 | 144 | PoT: 145 | ```bash 146 | CUDA_VISIBLE_DEVICES=0 python chartmoe/eval_ChartQA.py --ckpt_path chartmoe/train/output/sft/chartmoe_reproduced --save_path chartmoe/train/output/sft/chartmoe_reproduced/ChartQA_PoT --pot --pot_idx 1 147 | ``` 148 | 149 | Result: 150 | ``` 151 | +-----------+--------+--------+-------+ 152 | | @AP | 0.05 | 0.1 | 0.2 | 153 | +-----------+--------+--------+-------+ 154 | | Human | 0.7952 | 0.8128 | 0.828 | 155 | | Augmented | 0.904 | 0.9176 | 0.932 | 156 | | Averaged | 0.8496 | 0.8652 | 0.88 | 157 | +-----------+--------+--------+-------+ 158 | ``` 159 | -------------------------------------------------------------------------------- /chartmoe/train/chartmoe_construction.py: -------------------------------------------------------------------------------- 1 | """ 2 | FEATURE: Construct ChartMoE after Post-Training, including: 1. Merge LoRA to LLM 2. 
Adapt to ChartMoE HF Implementation 3 | AUTHOR: Brian Qu 4 | URL: https://arxiv.org/abs/2409.03277 5 | """ 6 | from dataclasses import dataclass, field 7 | from typing import Optional 8 | 9 | import torch 10 | from peft import PeftConfig, PeftModel 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser 12 | 13 | from mlp_moe import MLPMoE 14 | import os 15 | import shutil 16 | from glob import glob 17 | 18 | 19 | @dataclass 20 | class ScriptArguments: 21 | """The input names representing the Adapter and Base model fine-tuned with 22 | PEFT, and the output name representing the merged model.""" 23 | 24 | moe_aligned_pth_path: Optional[str] = field( 25 | default=None, metadata={'help': 'the path of aligned moe .pth file'} 26 | ) 27 | chartmoe_hf_dir: Optional[str] = field( 28 | default=None, metadata={'help': 'the path of downloaded chartmoe hf dir'} 29 | ) 30 | adapter_model_name: Optional[str] = field( 31 | default=None, metadata={'help': 'the adapter name'} 32 | ) 33 | output_path: Optional[str] = field( 34 | default=None, metadata={'help': 'the merged model saved path'} 35 | ) 36 | 37 | 38 | parser = HfArgumentParser(ScriptArguments) 39 | script_args = parser.parse_args_into_dataclasses()[0] 40 | assert script_args.moe_aligned_pth_path is not None, 'please provide the path of aligned moe .pth file' 41 | assert script_args.adapter_model_name is not None, 'please provide the name of the Adapter you would like to merge' 42 | assert script_args.output_path is not None, 'please provide the merged model saved path' 43 | assert script_args.chartmoe_hf_dir is not None, 'please provide the path of downloaded chartmoe hf dir' 44 | 45 | adapter_model_name = glob(os.path.join(script_args.adapter_model_name, "checkpoint-*"))[0] 46 | 47 | # get base model path from adapter_config.json 48 | peft_config = PeftConfig.from_pretrained(adapter_model_name) 49 | base_model_path = peft_config.base_model_name_or_path 50 | print(f"\033[31mLoad base model from {base_model_path}\033[0m") 51 | model = AutoModelForCausalLM.from_pretrained( 52 | base_model_path, 53 | return_dict=True, 54 | torch_dtype=torch.bfloat16, 55 | trust_remote_code=True, 56 | device_map="cuda", 57 | attn_implementation="eager", 58 | ) 59 | 60 | mlp_moe_state_dict = torch.load(script_args.moe_aligned_pth_path, map_location="cpu") 61 | 62 | num_experts = mlp_moe_state_dict['gate.weight'].size(0) 63 | num_selected = mlp_moe_state_dict.pop('num_selected') 64 | 65 | mlp_moe = MLPMoE(num_experts, num_selected, 1024, 4096).to(model.device) 66 | mlp_moe.load_state_dict(mlp_moe_state_dict) 67 | 68 | print("\033[32mload aligned moe...\033[0m") 69 | model.vision_proj = mlp_moe 70 | 71 | tokenizer = AutoTokenizer.from_pretrained( 72 | base_model_path, trust_remote_code=True 73 | ) 74 | 75 | # Load the PEFT model 76 | model = PeftModel.from_pretrained(model, adapter_model_name) 77 | model.eval() 78 | 79 | print("\033[33mmerge the lora weights and `modules_to_save` weights\033[0m") 80 | model = model.merge_and_unload() 81 | 82 | model.save_pretrained(f'{script_args.output_path}') 83 | tokenizer.save_pretrained(f'{script_args.output_path}') 84 | 85 | print("\033[34msave and adapt to `ChartMoE` format...\033[0m") 86 | # adjust the contained files 87 | chartmoe_files = [ 88 | 'special_tokens_map.json', 89 | 'configuration_chartmoe.py', 90 | 'modeling_internlm2.py', 91 | 'README.md', 92 | 'config.json', 93 | 'generation_config.json', 94 | '.gitattributes', 95 | 'teaser.png', 96 | 'zero_to_fp32.py', 97 |
'pytorch_model.bin.index.json', 98 | 'tokenization_internlm_xcomposer2.py', 99 | 'build_mlp.py', 100 | 'tokenizer_config.json', 101 | 'build_moe_connector.py', 102 | 'tokenizer.model', 103 | 'modeling_chartmoe.py', 104 | ] 105 | keep_files = [ 106 | 'pytorch_model-00001-of-00002.bin', 107 | 'pytorch_model-00002-of-00002.bin' 108 | ] 109 | 110 | for fn in os.listdir(script_args.output_path): 111 | if fn not in keep_files: 112 | os.remove(os.path.join(script_args.output_path, fn)) 113 | 114 | for fn in chartmoe_files: 115 | if fn != 'config.json': 116 | shutil.copy(os.path.join(script_args.chartmoe_hf_dir, fn), os.path.join(script_args.output_path, fn)) 117 | else: 118 | import json 119 | config = json.load(open(os.path.join(script_args.chartmoe_hf_dir, fn), encoding='utf-8')) 120 | config["num_experts"] = num_experts 121 | config["num_selected"] = num_selected 122 | with open(os.path.join(script_args.output_path, fn), 'w', encoding='utf-8') as fo: 123 | json.dump(config, fo, indent=4, ensure_ascii=False) 124 | 125 | -------------------------------------------------------------------------------- /chartmoe/train/chartmoe_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | FEATURE: Trainer of ChartMoE 3 | AUTHOR: Brian Qu 4 | URL: https://arxiv.org/abs/2409.03277 5 | REFERENCE: https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py 6 | """ 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | 11 | from transformers import Trainer 12 | from transformers.trainer import ( 13 | is_sagemaker_mp_enabled, 14 | get_parameter_names, 15 | has_length, 16 | ALL_LAYERNORM_LAYERS, 17 | logger, 18 | ) 19 | from typing import List, Optional 20 | 21 | def maybe_zero_3(param, ignore_status=False, name=None): 22 | from deepspeed import zero 23 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 24 | if hasattr(param, "ds_id"): 25 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 26 | if not ignore_status: 27 | print(name, 'no ignore status') 28 | with zero.GatheredParameters([param]): 29 | param = param.data.detach().cpu().clone() 30 | else: 31 | param = param.detach().cpu().clone() 32 | return param 33 | 34 | 35 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 36 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 37 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 38 | return to_return 39 | 40 | class ChartMoETrainer(Trainer): 41 | 42 | def _save_checkpoint(self, model, trial, metrics=None): 43 | if getattr(self.args, 'tune_mm_mlp', False): 44 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 45 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 46 | 47 | run_dir = self._get_output_dir(trial=trial) 48 | output_dir = os.path.join(run_dir, checkpoint_folder) 49 | 50 | # Only save Adapter 51 | keys_to_match = ['vision_proj'] 52 | 53 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 54 | 55 | if self.args.local_rank == 0 or self.args.local_rank == -1: 56 | self.model.config.save_pretrained(output_dir) 57 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_mlp.bin')) 58 | else: 59 | super(ChartMoETrainer, self)._save_checkpoint(model, trial, metrics) 60 | 61 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 62 | if getattr(self.args, 'tune_mm_mlp', False): 63 | pass 64 | else: 65 
| super(ChartMoETrainer, self)._save(output_dir, state_dict) -------------------------------------------------------------------------------- /chartmoe/train/data/code_align.txt: -------------------------------------------------------------------------------- 1 | ./data/chart2code.json 300 -------------------------------------------------------------------------------- /chartmoe/train/data/json_align.txt: -------------------------------------------------------------------------------- 1 | ./data/chart2json.json 200 -------------------------------------------------------------------------------- /chartmoe/train/data/sft.txt: -------------------------------------------------------------------------------- 1 | ./data/chartqa.json 70 2 | ./data/chartgemma.json -1 -------------------------------------------------------------------------------- /chartmoe/train/data/table_align.txt: -------------------------------------------------------------------------------- 1 | ./data/chart2table.json 500 -------------------------------------------------------------------------------- /chartmoe/train/data_mix.py: -------------------------------------------------------------------------------- 1 | """ 2 | FEATURE: Mixture of Common/Sampling Dataset 3 | AUTHOR: Brian Qu 4 | URL: https://arxiv.org/abs/2409.03277 5 | REFERENCE: https://github.com/InternLM/InternLM-XComposer 6 | """ 7 | import random 8 | import bisect 9 | 10 | import numpy as np 11 | import torch 12 | from PIL import Image 13 | from torch.utils.data import Dataset 14 | from torchvision import transforms 15 | from torchvision.transforms.functional import InterpolationMode 16 | 17 | 18 | def conv2text(sources): 19 | END_HUMAN = '[UNUSED_TOKEN_145]\n' 20 | END_BOT = '[UNUSED_TOKEN_145]\n' 21 | conversation = '' 22 | 23 | for idx, sentence in enumerate(sources): 24 | BEGIN_SIGNAL = '' 25 | 26 | from_str = sentence['from'] 27 | if from_str.lower() == 'human' or from_str.lower() == 'user': 28 | from_str = '[UNUSED_TOKEN_146]user\n' 29 | temp = ( 30 | BEGIN_SIGNAL + from_str + sentence['value'].strip() + 31 | END_HUMAN) 32 | else: 33 | from_str = '[UNUSED_TOKEN_146]assistant\n' 34 | temp = ( 35 | BEGIN_SIGNAL + from_str + sentence['value'].strip() + END_BOT) 36 | conversation += temp 37 | 38 | return conversation + '' 39 | 40 | 41 | class ImageProcessor: 42 | 43 | def __init__(self, image_size=490): 44 | mean = (0.48145466, 0.4578275, 0.40821073) 45 | std = (0.26862954, 0.26130258, 0.27577711) 46 | self.normalize = transforms.Normalize(mean, std) 47 | 48 | self.transform = transforms.Compose([ 49 | transforms.Resize((image_size, image_size), 50 | interpolation=InterpolationMode.BICUBIC), 51 | transforms.ToTensor(), 52 | self.normalize, 53 | ]) 54 | 55 | def __call__(self, item): 56 | item = Image.open(item).convert('RGB') 57 | return self.transform(item) 58 | 59 | 60 | class Mix_dataset(Dataset): 61 | 62 | def __init__(self, 63 | json_datas, 64 | img_size=490, 65 | local_rank=0, 66 | hd_num=-1): 67 | """vis_root (string): Root directory of images (e.g. 
coco/images/) 68 | ann_root (string): directory to store the annotation file.""" 69 | super().__init__() 70 | print(f'init mix data at rank {local_rank}') 71 | self.local_rank = local_rank 72 | 73 | self.datasets = [] 74 | self.start_idx_per = [0] 75 | self.data_num = 0 76 | for _, d in json_datas.items(): 77 | if 'image' in d[0].keys(): 78 | has_img = True 79 | else: 80 | has_img = False 81 | sub_data_set = Common_dataset( 82 | d, 83 | has_img=has_img, 84 | img_size=img_size, 85 | hd_num=hd_num) 86 | self.datasets.append(sub_data_set) 87 | self.start_idx_per.append(self.start_idx_per[-1] + len(sub_data_set)) 88 | self.data_num += len(sub_data_set) 89 | 90 | self.start_idx_per.pop(-1) 91 | 92 | if len(self.datasets) == 0: 93 | raise ValueError( 94 | 'Both _multi and _text are empty. Cannot sample any data.') 95 | 96 | def __len__(self): 97 | return self.data_num 98 | 99 | def __getitem__(self, index): 100 | index = index % self.data_num # avoid some indices which are outside the interval 101 | dataset_idx = bisect.bisect_right(self.start_idx_per, index) - 1 102 | sample_idx = index - self.start_idx_per[dataset_idx] 103 | sample = self.datasets[dataset_idx].get_item(sample_idx) 104 | return dict(samples=sample) 105 | 106 | 107 | class Common_dataset(Dataset): 108 | 109 | def __init__(self, 110 | raw_data, 111 | has_img=True, 112 | img_size=490, 113 | hd_num=-1): 114 | self.raw_data = raw_data 115 | print(f'load {len(self.raw_data)} data') 116 | assert hd_num == -1, "please set `hd_num` to `-1`" 117 | 118 | self.vis_processor = ImageProcessor(image_size=img_size) 119 | self.text_processor = conv2text 120 | self.has_img = has_img 121 | 122 | def __len__(self): 123 | return len(self.raw_data) 124 | 125 | def __get_item__(self, i): 126 | conv_text = conv2text(self.raw_data[i]['conversations']) 127 | sample = dict(text_input=conv_text, ) 128 | if self.has_img: 129 | image_file = self.raw_data[i]['image'] 130 | image = [self.vis_processor(i) for i in image_file] 131 | sample['image'] = torch.stack(image) 132 | else: 133 | sample['image'] = None 134 | 135 | return sample 136 | 137 | def get_item(self, idx): 138 | text_input = [] 139 | images = [] 140 | 141 | sample = self.__get_item__(idx) 142 | text_input.append(sample['text_input']) 143 | images.append(sample['image']) 144 | sample = { 145 | 'text_input': text_input, 146 | 'data_type': 'multi' if self.has_img else 'text', 147 | } 148 | if self.has_img: 149 | sample['image'] = torch.cat(images) 150 | return sample 151 | 152 | 153 | 154 | 155 | class Mix_sampling_dataset(Dataset): 156 | 157 | def __init__(self, 158 | json_datas, 159 | seq_packing_size=1, 160 | img_size=490, 161 | local_rank=0, 162 | hd_num=-1): 163 | """vis_root (string): Root directory of images (e.g. 
coco/images/) 164 | ann_root (string): directory to store the annotation file.""" 165 | super().__init__() 166 | print(f'init mix sampling_data at rank {local_rank}') 167 | self.datasets_text, self.datasets_multi = [], [] 168 | self.data_num_text, self.data_num_multi = [], [] 169 | 170 | self.seq_packing_size = seq_packing_size 171 | self.set_seed = False 172 | self.local_rank = local_rank 173 | for _, d in json_datas.items(): 174 | if 'image' in d[0].keys(): 175 | has_img = True 176 | else: 177 | has_img = False 178 | sub_data_set = Sample_dataset( 179 | d, 180 | seq_packing_size, 181 | has_img=has_img, 182 | img_size=img_size, 183 | hd_num=hd_num) 184 | if has_img: 185 | self.datasets_multi.append(sub_data_set) 186 | self.data_num_multi.append(len(sub_data_set)) 187 | else: 188 | self.datasets_text.append(sub_data_set) 189 | self.data_num_text.append(len(sub_data_set)) 190 | 191 | self.data_ratio_multi = [ 192 | float(ratio) / sum(self.data_num_multi) 193 | for ratio in self.data_num_multi 194 | ] 195 | self.data_ratio_text = [ 196 | float(ratio) / sum(self.data_num_text) 197 | for ratio in self.data_num_text 198 | ] 199 | self.data_num = np.sum(self.data_num_multi) + np.sum( 200 | self.data_num_text) 201 | self.use_multi = 0 202 | 203 | def __len__(self): 204 | return int(np.sum(self.data_num) / self.seq_packing_size) 205 | 206 | def __getitem__(self, index): 207 | if not self.set_seed: 208 | random.seed(index) 209 | self.set_seed = True 210 | print(f'Set seed {index} for rank {self.local_rank}') 211 | 212 | if len(self.datasets_multi) == 0 and len(self.datasets_text) == 0: 213 | raise ValueError( 214 | 'Both _multi and _text are empty. Cannot sample any data.') 215 | 216 | if len(self.datasets_multi) > 0 and (self.use_multi < self.seq_packing_size 217 | or len(self.datasets_text) == 0): 218 | data_idx = random.choices( 219 | range(len(self.data_ratio_multi)), 220 | weights=self.data_ratio_multi, 221 | k=1)[0] 222 | sample = self.datasets_multi[data_idx].get_item() 223 | elif len(self.datasets_text) > 0: 224 | data_idx = random.choices( 225 | range(len(self.data_ratio_text)), 226 | weights=self.data_ratio_text, 227 | k=1)[0] 228 | sample = self.datasets_text[data_idx].get_item() 229 | else: 230 | raise ValueError('Unable to select a dataset for sampling.') 231 | 232 | self.use_multi += 1 233 | if self.use_multi > self.seq_packing_size * 2: 234 | self.use_multi = 0 235 | return dict(samples=sample) 236 | 237 | 238 | class Sample_dataset(Dataset): 239 | 240 | def __init__(self, 241 | raw_data, 242 | seq_packing_size, 243 | has_img=True, 244 | img_size=490, 245 | hd_num=-1): 246 | self.raw_data = raw_data 247 | print(f'load {len(self.raw_data)} data') 248 | self.seq_packing_size = seq_packing_size 249 | assert hd_num == -1, "please set `hd_num` to `-1`" 250 | 251 | self.vis_processor = ImageProcessor(image_size=img_size) 252 | self.text_processor = conv2text 253 | self.has_img = has_img 254 | 255 | def __len__(self): 256 | return len(self.raw_data) 257 | 258 | def __get_item__(self, i): 259 | conv_text = conv2text(self.raw_data[i]['conversations']) 260 | sample = dict(text_input=conv_text, ) 261 | if self.has_img: 262 | image_file = self.raw_data[i]['image'] 263 | image = [self.vis_processor(i) for i in image_file] 264 | sample['image'] = torch.stack(image) 265 | else: 266 | sample['image'] = None 267 | 268 | return sample 269 | 270 | def get_item(self, ): 271 | text_input = [] 272 | images = [] 273 | for i in range(self.seq_packing_size): 274 | # Randomly select an index from raw_data to 
get a random sample 275 | idx = random.randrange(len(self.raw_data)) 276 | sample = self.__get_item__(idx) 277 | text_input.append(sample['text_input']) 278 | images.append(sample['image']) 279 | sample = { 280 | 'text_input': text_input, 281 | 'data_type': 'multi' if self.has_img else 'text', 282 | } 283 | if self.has_img: 284 | sample['image'] = torch.cat(images) 285 | return sample 286 | -------------------------------------------------------------------------------- /chartmoe/train/ds_config_zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 2, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "allgather_partitions": true, 21 | "allgather_bucket_size": 2e8, 22 | "overlap_comm": true, 23 | "reduce_scatter": true, 24 | "reduce_bucket_size": 2e8, 25 | "contiguous_gradients": true 26 | }, 27 | 28 | "gradient_accumulation_steps": "auto", 29 | "gradient_clipping": "auto", 30 | "steps_per_print": 100, 31 | "train_batch_size": "auto", 32 | "train_micro_batch_size_per_gpu": "auto", 33 | "wall_clock_breakdown": false 34 | } 35 | -------------------------------------------------------------------------------- /chartmoe/train/mlp_moe.py: -------------------------------------------------------------------------------- 1 | """ 2 | FEATURE: Implementation of MLP-MoE 3 | AUTHOR: Brian Qu 4 | URL: https://arxiv.org/abs/2409.03277 5 | REFERENCE: https://github.com/SHI-Labs/CuMo 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from einops import rearrange, repeat, reduce, pack, unpack 11 | 12 | class MLPMoE(nn.Module): 13 | def __init__(self, num_experts, num_selected, mm_channels, channels): 14 | super().__init__() 15 | self.num_experts = num_experts 16 | self.num_selected = num_selected 17 | self.mm_channels = mm_channels 18 | self.channels = channels 19 | 20 | self.gate = nn.Linear(mm_channels, num_experts, bias=False) 21 | # nn.init.zeros_(self.gate.weight) 22 | 23 | self.num_selected = num_selected 24 | self.num_experts = num_experts 25 | 26 | self.experts = nn.ModuleList( 27 | [ 28 | nn.Sequential( 29 | nn.Linear(mm_channels, channels, bias=True), 30 | nn.GELU(), 31 | nn.Linear(channels, channels, bias=True) 32 | ) 33 | for _ in range(num_experts) 34 | ] 35 | ) 36 | 37 | def forward(self, x_img): 38 | gate_logits = self.gate(x_img) 39 | gate_softmax = F.softmax(gate_logits, dim=-1, dtype=torch.float).to(x_img.dtype) 40 | 41 | weights, selected_experts = torch.topk(gate_softmax, self.num_selected) 42 | weights = weights / torch.sum(weights, dim=-1, keepdim=True).to(x_img.dtype) 43 | 44 | results = torch.zeros((x_img.shape[0], x_img.shape[1], self.channels)).to(x_img.device, x_img.dtype) 45 | for b in range(x_img.shape[0]): 46 | for i, expert in enumerate(self.experts): 47 | token_idx, nth_expert = torch.where(selected_experts[b] == i) 48 | results[b][token_idx] += weights[b][token_idx, nth_expert, None] * expert(x_img[b][token_idx]) 49 | 50 | return results 51 | 52 | 53 | class MLPMoE_bzloss(nn.Module): 54 | def __init__(self, num_experts, num_selected, mm_channels, channels): 55 | super().__init__() 56 | self.num_experts = num_experts 57 | self.num_selected = num_selected 58 | self.mm_channels = mm_channels 59 | 
self.channels = channels 60 | 61 | self.gate = nn.Linear(mm_channels, num_experts, bias=False) 62 | # nn.init.zeros_(self.gate.weight) 63 | 64 | self.num_selected = num_selected 65 | self.num_experts = num_experts 66 | self.experts = nn.ModuleList([nn.Sequential(nn.Linear(mm_channels, channels, bias=True), nn.GELU(), nn.Linear(channels, channels, bias=True)) for _ in range(num_experts)]) 67 | 68 | def forward(self, x_img): 69 | gate_logits = self.gate(x_img) 70 | 71 | router_z_loss = torch.logsumexp(gate_logits, dim = -1) 72 | router_z_loss = torch.square(router_z_loss) 73 | router_z_loss = router_z_loss.mean() 74 | 75 | gate_softmax = F.softmax(gate_logits, dim=-1, dtype=torch.float).to(x_img.dtype) 76 | 77 | density_1_proxy = reduce(gate_softmax, '... n e -> ... e', 'mean') 78 | 79 | weights, selected_experts = torch.topk(gate_softmax, self.num_selected) 80 | 81 | one_hot_gate_indices = F.one_hot(rearrange(selected_experts, '... k -> k ...'), self.num_experts).float()[0] 82 | density_1 = reduce(one_hot_gate_indices, '... n e -> ... e', 'mean') 83 | balance_loss = (density_1_proxy * density_1).mean() * float(self.num_experts ** 2) 84 | 85 | weights = weights / torch.sum(weights, dim=-1, keepdim=True).to(x_img.dtype) 86 | 87 | results = torch.zeros((x_img.shape[0], x_img.shape[1], self.channels)).to(x_img.device, x_img.dtype) 88 | 89 | for b in range(x_img.shape[0]): 90 | for i, expert in enumerate(self.experts): 91 | token_idx, nth_expert = torch.where(selected_experts[b] == i) 92 | results[b][token_idx] += weights[b][token_idx, nth_expert, None] * expert(x_img[b][token_idx]) 93 | 94 | return results, 0.1 * balance_loss, 0.01 * router_z_loss -------------------------------------------------------------------------------- /chartmoe/train/moe_construction.py: -------------------------------------------------------------------------------- 1 | """ 2 | FEATURE: Construct MLP-MoE Connector with three aligned MLP and the general MLP Connector 3 | AUTHOR: Brian Qu 4 | URL: https://arxiv.org/abs/2409.03277 5 | """ 6 | from mlp_moe import MLPMoE 7 | import argparse 8 | from glob import glob 9 | import os 10 | from copy import deepcopy 11 | 12 | import torch 13 | import torch.nn as nn 14 | import transformers 15 | 16 | def main(args): 17 | root = args.root_dir 18 | table_ckpt = glob(f"{root}/table_proj/checkpoint-*/mm_mlp.bin")[0] 19 | json_ckpt = glob(f"{root}/json_proj/checkpoint-*/mm_mlp.bin")[0] 20 | code_ckpt = glob(f"{root}/code_proj/checkpoint-*/mm_mlp.bin")[0] 21 | 22 | base_model = transformers.AutoModel.from_pretrained( 23 | args.base_model, 24 | trust_remote_code=True, 25 | device_map="cuda", 26 | attn_implementation="eager", 27 | ) 28 | base_proj = deepcopy(base_model.vision_proj) 29 | del base_model 30 | 31 | table_proj = torch.load(table_ckpt) 32 | json_proj = torch.load(json_ckpt) 33 | code_proj = torch.load(code_ckpt) 34 | 35 | mlp_moe = MLPMoE(args.mlp_smoe_experts, args.mlp_smoe_topk, 1024, 4096) 36 | for idx, expert in enumerate(mlp_moe.experts): 37 | print(idx % args.mlp_smoe_experts) 38 | if idx % args.mlp_smoe_experts == 0: 39 | for target_layer, source_layer in zip(expert, base_proj): 40 | if isinstance(target_layer, nn.Linear) and isinstance(source_layer, nn.Linear): 41 | target_layer.weight = deepcopy(source_layer.weight) 42 | target_layer.bias = deepcopy(source_layer.bias) 43 | print(f"{idx} expert: load base_proj") 44 | if idx % args.mlp_smoe_experts == 1: 45 | for ii in [0,2]: 46 | expert[ii].weight.data = table_proj[f'vision_proj.{ii}.weight'] 47 | expert[ii].bias.data 
= table_proj[f'vision_proj.{ii}.bias'] 48 | print(f"{idx} expert: load table_proj") 49 | if idx % args.mlp_smoe_experts == 2: 50 | for ii in [0,2]: 51 | expert[ii].weight.data = json_proj[f'vision_proj.{ii}.weight'] 52 | expert[ii].bias.data = json_proj[f'vision_proj.{ii}.bias'] 53 | print(f"{idx} expert: load json_proj") 54 | if idx % args.mlp_smoe_experts == 3: 55 | for ii in [0,2]: 56 | expert[ii].weight.data = code_proj[f'vision_proj.{ii}.weight'] 57 | expert[ii].bias.data = code_proj[f'vision_proj.{ii}.bias'] 58 | print(f"{idx} expert: load code_proj") 59 | 60 | os.makedirs(f"{root}/{args.save_name}", exist_ok=True) 61 | mlp_moe_state_dict = mlp_moe.state_dict() 62 | mlp_moe_state_dict['num_selected'] = args.mlp_smoe_topk 63 | torch.save(mlp_moe_state_dict, f"{root}/{args.save_name}/mlp_moe.pth") 64 | 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument("--base_model", type=str) 70 | parser.add_argument("--mlp_smoe_experts", type=int, default=4) 71 | parser.add_argument("--mlp_smoe_topk", type=int, default=2) 72 | parser.add_argument("--root_dir", type=str, required=True) 73 | parser.add_argument("--save_name", type=str, required=True) 74 | args = parser.parse_args() 75 | 76 | main(args) 77 | -------------------------------------------------------------------------------- /chartmoe/train/scripts/chartmoe_construction.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python chartmoe_construction.py \ 2 | --moe_aligned_pth_path output/moe_aligned/mlp_moe.pth \ 3 | --chartmoe_hf_dir ckpt/chartmoe \ 4 | --adapter_model_name output/sft \ 5 | --output_path output/sft/chartmoe_reproduced -------------------------------------------------------------------------------- /chartmoe/train/scripts/chartmoe_data_download.py: -------------------------------------------------------------------------------- 1 | import huggingface_hub 2 | import time 3 | 4 | def download(): 5 | try: 6 | huggingface_hub.snapshot_download("Coobiw/ChartMoE-Data", local_dir="./data/", repo_type="dataset", resume_download=True) 7 | return True 8 | except: 9 | print("Caught an exception! Retrying...") 10 | return False 11 | 12 | while True: 13 | result = download() 14 | if result: 15 | print("success") 16 | break # Exit the loop if the function ran successfully 17 | time.sleep(1) # Wait for 1 second before retrying 18 | -------------------------------------------------------------------------------- /chartmoe/train/scripts/chartmoe_download.py: -------------------------------------------------------------------------------- 1 | import huggingface_hub 2 | import time 3 | 4 | def download(): 5 | try: 6 | huggingface_hub.snapshot_download("IDEA-FinAI/chartmoe", local_dir="./ckpt/chartmoe",resume_download=True,max_workers=4) 7 | return True 8 | except: 9 | print("Caught an exception! 
Retrying...") 10 | return False 11 | 12 | while True: 13 | result = download() 14 | if result: 15 | print("success") 16 | break # Exit the loop if the function ran successfully 17 | time.sleep(1) # Wait for 1 second before retrying 18 | -------------------------------------------------------------------------------- /chartmoe/train/scripts/code_align.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | export MODEL="ckpt/InternLM-XComposer2_Enhanced" 6 | export DATA="data/code_align.txt" 7 | 8 | GPUS_PER_NODE=4 9 | NNODES=1 10 | NODE_RANK=0 11 | MASTER_ADDR=localhost 12 | MASTER_PORT=12700 13 | 14 | DISTRIBUTED_ARGS=" 15 | --nproc_per_node $GPUS_PER_NODE \ 16 | --nnodes $NNODES \ 17 | --node_rank $NODE_RANK \ 18 | --master_addr $MASTER_ADDR \ 19 | --master_port $MASTER_PORT 20 | " 21 | 22 | torchrun $DISTRIBUTED_ARGS train.py \ 23 | --model_name_or_path $MODEL \ 24 | --data_path $DATA \ 25 | --dataloader_num_workers 4 \ 26 | --img_size 490 \ 27 | --hd_num -1 \ 28 | --given_num True \ 29 | --bf16 True \ 30 | --fix_vit True \ 31 | --fix_sampler False \ 32 | --fix_llm True \ 33 | --use_lora False \ 34 | --num_train_epochs 1 \ 35 | --per_device_train_batch_size 4 \ 36 | --per_device_eval_batch_size 1 \ 37 | --gradient_accumulation_steps 16 \ 38 | --evaluation_strategy "no" \ 39 | --save_strategy "epoch" \ 40 | --save_total_limit 5 \ 41 | --learning_rate 5e-5 \ 42 | --weight_decay 0.1 \ 43 | --adam_beta2 0.95 \ 44 | --warmup_ratio 0.0 \ 45 | --lr_scheduler_type "cosine" \ 46 | --logging_steps 1 \ 47 | --max_length 4096 \ 48 | --gradient_checkpointing True \ 49 | --deepspeed ds_config_zero2.json \ 50 | --output_dir output/code_proj \ 51 | --report_to none 52 | -------------------------------------------------------------------------------- /chartmoe/train/scripts/internlm_xc2_download.py: -------------------------------------------------------------------------------- 1 | import huggingface_hub 2 | import time 3 | 4 | def download(): 5 | try: 6 | huggingface_hub.snapshot_download("Coobiw/InternLM-XComposer2_Enhanced", local_dir="./ckpt/InternLM-XComposer2_Enhanced",resume_download=True,max_workers=4) 7 | return True 8 | except: 9 | print("Caught an exception! 
Retrying...") 10 | return False 11 | 12 | while True: 13 | result = download() 14 | if result: 15 | print("success") 16 | break # Exit the loop if the function ran successfully 17 | time.sleep(1) # Wait for 1 second before retrying 18 | -------------------------------------------------------------------------------- /chartmoe/train/scripts/json_align.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | export MODEL="ckpt/InternLM-XComposer2_Enhanced" 6 | export DATA="data/json_align.txt" 7 | 8 | GPUS_PER_NODE=4 9 | NNODES=1 10 | NODE_RANK=0 11 | MASTER_ADDR=localhost 12 | MASTER_PORT=12700 13 | 14 | DISTRIBUTED_ARGS=" 15 | --nproc_per_node $GPUS_PER_NODE \ 16 | --nnodes $NNODES \ 17 | --node_rank $NODE_RANK \ 18 | --master_addr $MASTER_ADDR \ 19 | --master_port $MASTER_PORT 20 | " 21 | 22 | torchrun $DISTRIBUTED_ARGS train.py \ 23 | --model_name_or_path $MODEL \ 24 | --data_path $DATA \ 25 | --dataloader_num_workers 4 \ 26 | --img_size 490 \ 27 | --hd_num -1 \ 28 | --given_num True \ 29 | --bf16 True \ 30 | --fix_vit True \ 31 | --fix_sampler False \ 32 | --fix_llm True \ 33 | --use_lora False \ 34 | --num_train_epochs 1 \ 35 | --per_device_train_batch_size 4 \ 36 | --per_device_eval_batch_size 1 \ 37 | --gradient_accumulation_steps 16 \ 38 | --evaluation_strategy "no" \ 39 | --save_strategy "epoch" \ 40 | --save_total_limit 5 \ 41 | --learning_rate 5e-5 \ 42 | --weight_decay 0.1 \ 43 | --adam_beta2 0.95 \ 44 | --warmup_ratio 0.0 \ 45 | --lr_scheduler_type "cosine" \ 46 | --logging_steps 1 \ 47 | --max_length 4096 \ 48 | --gradient_checkpointing True \ 49 | --deepspeed ds_config_zero2.json \ 50 | --output_dir output/json_proj \ 51 | --report_to none 52 | -------------------------------------------------------------------------------- /chartmoe/train/scripts/moe_construction.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python moe_construction.py \ 2 | --base_model ckpt/InternLM-XComposer2_Enhanced \ 3 | --root_dir ./output \ 4 | --save_name moe_aligned -------------------------------------------------------------------------------- /chartmoe/train/scripts/multi_align.sh: -------------------------------------------------------------------------------- 1 | export CUDA_LAUNCH_BLOCKING=1 2 | 3 | mkdir -p logs/json_proj 4 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/json_align.sh 2>&1 | tee logs/json_proj/tee_logs.txt 5 | sleep 1m 6 | 7 | mkdir -p logs/code_proj 8 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/code_align.sh 2>&1 | tee logs/code_proj/tee_logs.txt 9 | sleep 1m 10 | 11 | mkdir -p logs/table_proj 12 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/table_align.sh 2>&1 | tee logs/table_proj/tee_logs.txt -------------------------------------------------------------------------------- /chartmoe/train/scripts/sft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | export MODEL="ckpt/InternLM-XComposer2_Enhanced" 6 | export DATA="data/sft.txt" 7 | 8 | GPUS_PER_NODE=4 9 | NNODES=1 10 | NODE_RANK=0 11 | MASTER_ADDR=localhost 12 | MASTER_PORT=6001 13 | 14 | DISTRIBUTED_ARGS=" 15 | --nproc_per_node $GPUS_PER_NODE \ 16 | --nnodes $NNODES \ 17 | --node_rank $NODE_RANK \ 18 | --master_addr $MASTER_ADDR \ 19 | --master_port $MASTER_PORT 20 | " 21 | 22 | torchrun $DISTRIBUTED_ARGS train.py \ 23 | --model_name_or_path 
$MODEL \ 24 | --moe_aligned_pth_path output/moe_aligned/mlp_moe.pth \ 25 | --data_path $DATA \ 26 | --dataloader_num_workers 4 \ 27 | --img_size 490 \ 28 | --hd_num -1 \ 29 | --given_num True \ 30 | --bf16 True \ 31 | --fix_vit True \ 32 | --fix_sampler False \ 33 | --use_lora True \ 34 | --num_train_epochs 1 \ 35 | --per_device_train_batch_size 2 \ 36 | --per_device_eval_batch_size 1 \ 37 | --gradient_accumulation_steps 8 \ 38 | --evaluation_strategy "no" \ 39 | --save_strategy "epoch" \ 40 | --save_total_limit 5 \ 41 | --learning_rate 5e-5 \ 42 | --weight_decay 0.1 \ 43 | --adam_beta2 0.95 \ 44 | --warmup_ratio 0.0 \ 45 | --lr_scheduler_type "cosine" \ 46 | --logging_steps 1 \ 47 | --max_length 4096 \ 48 | --gradient_checkpointing True \ 49 | --deepspeed ds_config_zero2.json \ 50 | --output_dir output/sft \ 51 | --report_to none -------------------------------------------------------------------------------- /chartmoe/train/scripts/table_align.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export CUDA_DEVICE_MAX_CONNECTIONS=1 3 | DIR=`pwd` 4 | 5 | export MODEL="ckpt/InternLM-XComposer2_Enhanced" 6 | export DATA="data/table_align.txt" 7 | 8 | GPUS_PER_NODE=4 9 | NNODES=1 10 | NODE_RANK=0 11 | MASTER_ADDR=localhost 12 | MASTER_PORT=12700 13 | 14 | DISTRIBUTED_ARGS=" 15 | --nproc_per_node $GPUS_PER_NODE \ 16 | --nnodes $NNODES \ 17 | --node_rank $NODE_RANK \ 18 | --master_addr $MASTER_ADDR \ 19 | --master_port $MASTER_PORT 20 | " 21 | 22 | torchrun $DISTRIBUTED_ARGS train.py \ 23 | --model_name_or_path $MODEL \ 24 | --data_path $DATA \ 25 | --dataloader_num_workers 4 \ 26 | --img_size 490 \ 27 | --hd_num -1 \ 28 | --given_num True \ 29 | --bf16 True \ 30 | --fix_vit True \ 31 | --fix_sampler False \ 32 | --fix_llm True \ 33 | --use_lora False \ 34 | --num_train_epochs 1 \ 35 | --per_device_train_batch_size 4 \ 36 | --per_device_eval_batch_size 1 \ 37 | --gradient_accumulation_steps 16 \ 38 | --evaluation_strategy "no" \ 39 | --save_strategy "epoch" \ 40 | --save_total_limit 5 \ 41 | --learning_rate 5e-5 \ 42 | --weight_decay 0.1 \ 43 | --adam_beta2 0.95 \ 44 | --warmup_ratio 0.0 \ 45 | --lr_scheduler_type "cosine" \ 46 | --logging_steps 1 \ 47 | --max_length 4096 \ 48 | --gradient_checkpointing True \ 49 | --deepspeed ds_config_zero2.json \ 50 | --output_dir output/table_proj \ 51 | --report_to none 52 | -------------------------------------------------------------------------------- /chartmoe/train/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | FEATURE: Training Script of ChartMoE 3 | AUTHOR: Brian Qu 4 | URL: https://arxiv.org/abs/2409.03277 5 | REFERENCE: https://github.com/InternLM/InternLM-XComposer 6 | """ 7 | import json 8 | import random 9 | from dataclasses import dataclass, field 10 | from typing import Dict, List, Optional, Sequence 11 | 12 | import torch 13 | import torch.nn as nn 14 | import transformers 15 | from accelerate.utils import DistributedType 16 | 17 | from data_mix import Mix_dataset 18 | from mlp_moe import MLPMoE 19 | 20 | from deepspeed import zero 21 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 22 | from peft import LoraConfig, get_peft_model 23 | from transformers import deepspeed 24 | from chartmoe_trainer import ChartMoETrainer 25 | from transformers.trainer_pt_utils import LabelSmoother 26 | 27 | from copy import deepcopy 28 | import os 29 | 30 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index # -100 31 | 32 | @dataclass 
33 | class ModelArguments:
34 |     model_name_or_path: Optional[str] = field(default='')
35 |     moe_aligned_pth_path: Optional[str] = field(default='')
36 | 
37 | 
38 | @dataclass
39 | class DataArguments:
40 |     data_path: str = field(
41 |         default='data.txt', metadata={'help': 'Path to the training data.'})
42 |     given_num: bool = False
43 |     img_size: int = 490
44 |     hd_num: int = -1
45 | 
46 | 
47 | @dataclass
48 | class TrainingArguments(transformers.TrainingArguments):
49 |     cache_dir: Optional[str] = field(default=None)
50 |     optim: str = field(default='adamw_torch')
51 |     max_length: int = field(
52 |         default=4096,
53 |         metadata={
54 |             'help':
55 |             'Maximum sequence length. Sequences will be right padded (and possibly truncated).'
56 |         },
57 |     )
58 |     use_lora: bool = False
59 |     fix_vit: bool = True
60 |     fix_sampler: bool = False
61 |     fix_llm: bool = True
62 |     label_names: List[str] = field(default_factory=lambda: ['samples'])
63 | 
64 | 
65 | @dataclass
66 | class LoraArguments:
67 |     lora_r: int = 64
68 |     lora_alpha: int = 64
69 |     lora_dropout: float = 0.05
70 |     lora_target_modules: List[str] = field(default_factory=lambda: [
71 |         'attention.wqkv',
72 |         'attention.wo',
73 |         'feed_forward.w1',
74 |         'feed_forward.w2',
75 |         'feed_forward.w3',
76 |     ])
77 |     lora_weight_path: str = ''
78 |     lora_bias: str = 'none'
79 | 
80 | def maybe_zero_3(param):
81 |     if hasattr(param, 'ds_id'):
82 |         assert param.ds_status == ZeroParamStatus.NOT_AVAILABLE
83 |         with zero.GatheredParameters([param]):
84 |             param = param.data.detach().cpu().clone()
85 |     else:
86 |         param = param.detach().cpu().clone()
87 |     return param
88 | 
89 | 
90 | # Borrowed from peft.utils.get_peft_model_state_dict
91 | def get_peft_state_maybe_zero_3(named_params, bias):
92 |     if bias == 'none':
93 |         to_return = {k: t for k, t in named_params if 'lora_' in k}
94 |     elif bias == 'all':
95 |         to_return = {
96 |             k: t
97 |             for k, t in named_params if 'lora_' in k or 'bias' in k
98 |         }
99 |     elif bias == 'lora_only':
100 |         to_return = {}
101 |         maybe_lora_bias = {}
102 |         lora_bias_names = set()
103 |         for k, t in named_params:
104 |             if 'lora_' in k:
105 |                 to_return[k] = t
106 |                 bias_name = k.split('lora_')[0] + 'bias'
107 |                 lora_bias_names.add(bias_name)
108 |             elif 'bias' in k:
109 |                 maybe_lora_bias[k] = t
110 |         for k, t in maybe_lora_bias.items():  # keep only the bias tensors that pair with a LoRA-adapted layer
111 |             if k in lora_bias_names:
112 |                 to_return[k] = t
113 |     else:
114 |         raise NotImplementedError
115 |     to_return = {k: maybe_zero_3(v) for k, v in to_return.items()}
116 |     return to_return
117 | 
118 | 
119 | local_rank = None
120 | 
121 | 
122 | def rank0_print(*args):
123 |     if local_rank == 0:
124 |         print(*args)
125 | 
126 | 
127 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
128 |                                    output_dir: str,
129 |                                    bias='none'):
130 |     """Collect the state dict and dump it to disk."""
131 |     # check if zero3 mode enabled
132 |     if deepspeed.is_deepspeed_zero3_enabled():
133 |         state_dict = trainer.model_wrapped._zero3_consolidated_16bit_state_dict(
134 |         )
135 |     else:
136 |         if trainer.args.use_lora:
137 |             state_dict = get_peft_state_maybe_zero_3(
138 |                 trainer.model.named_parameters(), bias)
139 |         else:
140 |             state_dict = trainer.model.state_dict()
141 |     if trainer.args.should_save and trainer.args.local_rank == 0:
142 |         trainer._save(output_dir, state_dict=state_dict)
143 | 
144 | 
145 | @dataclass
146 | class DataCollatorForSupervisedDataset:
147 |     """Collate examples for supervised fine-tuning."""
148 | 
149 |     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
150 |         instances = 
[instance['samples'] for instance in instances]
151 |         text_input, data_type = tuple(
152 |             [instance[key] for instance in instances]
153 |             for key in ('text_input', 'data_type'))
154 |         if 'image' not in instances[0]:
155 |             text_input = [instance['text_input'][0] for instance in instances]
156 |         batch = dict(
157 |             text_input=text_input,
158 |             data_type=data_type,
159 |         )
160 |         if 'image' in instances[0]:
161 |             images = [instance['image'] for instance in instances]
162 |             batch['image'] = images
163 | 
164 |         return dict(samples=batch)
165 | 
166 | 
167 | def make_supervised_data_module(
168 |         tokenizer: transformers.PreTrainedTokenizer,
169 |         data_args,
170 | ) -> Dict:
171 |     """Make dataset and collator for supervised fine-tuning."""
172 | 
173 |     rank0_print('Loading data...')
174 |     if data_args.data_path.endswith('json'):
175 |         train_json = json.load(open(data_args.data_path))
176 |     elif data_args.data_path.endswith('txt'):
177 |         train_json = {}
178 |         with open(data_args.data_path) as f:
179 |             lines = f.readlines()
180 | 
181 |         for line in lines:
182 |             line = line.strip()
183 |             line = line.split(' ')
184 |             with open(line[0]) as f:
185 |                 temp = json.load(f)
186 |             if data_args.given_num:
187 |                 assert len(line) == 2
188 |                 if int(line[1]) == -1:
189 |                     print(f"Load total data from {line[0]} ...")
190 |                 else:
191 |                     num = int(float(line[1]) * 1000)  # sample counts in the txt file are given in thousands
192 |                     if len(temp) > num:
193 |                         temp = random.sample(temp, num)
194 |                     else:
195 |                         ex_temp = []
196 |                         for i in range(num - len(temp)):
197 |                             ex_temp.append(random.choice(temp))
198 |                         temp.extend(ex_temp)
199 |             else:
200 |                 if len(line) == 2:
201 |                     ratio = float(line[1])
202 |                     new_len = int(len(temp) * ratio)
203 |                     if ratio < 1:
204 |                         temp = random.sample(temp, new_len)
205 |                     elif ratio > 1:
206 |                         ex_temp = []
207 |                         for i in range(new_len - len(temp)):
208 |                             ex_temp.append(random.choice(temp))
209 |                         temp.extend(ex_temp)
210 |             rank0_print(f'Load {len(temp)} samples from {line}')
211 |             train_json[line[0]] = temp
212 |     train_dataset = Mix_dataset(
213 |         train_json,
214 |         img_size=data_args.img_size,
215 |         hd_num=data_args.hd_num,
216 |         local_rank=local_rank)
217 |     rank0_print(str(len(train_dataset)) + ' samples are loaded')
218 |     eval_dataset = None
219 | 
220 |     data_collator = DataCollatorForSupervisedDataset()
221 |     return dict(
222 |         train_dataset=train_dataset,
223 |         eval_dataset=eval_dataset,
224 |         data_collator=data_collator,
225 |     )
226 | 
227 | 
228 | def train():
229 |     global local_rank
230 | 
231 |     parser = transformers.HfArgumentParser(
232 |         (ModelArguments, DataArguments, TrainingArguments, LoraArguments))
233 |     (
234 |         model_args,
235 |         data_args,
236 |         training_args,
237 |         lora_args,
238 |     ) = parser.parse_args_into_dataclasses()
239 | 
240 |     if getattr(training_args, 'deepspeed', None):
241 |         training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED
242 | 
243 |     local_rank = training_args.local_rank
244 | 
245 |     device_map = None
246 | 
247 |     # Set RoPE scaling factor
248 |     config = transformers.AutoConfig.from_pretrained(
249 |         model_args.model_name_or_path,
250 |         cache_dir=training_args.cache_dir,
251 |         trust_remote_code=True,
252 |     )
253 |     config.use_cache = False
254 |     config.max_length = training_args.max_length
255 | 
256 |     if config.attn_implementation == "flash_attention_2":
257 |         rank0_print("Use Flash-Attn!!!")
258 |     else:
259 |         rank0_print("Use Eager Attn!!!")
260 | 
261 |     # Load model and tokenizer
262 |     rank0_print(f'Load model from: {model_args.model_name_or_path}')
263 |     model = 
transformers.AutoModelForCausalLM.from_pretrained(
264 |         model_args.model_name_or_path,
265 |         config=config,
266 |         cache_dir=training_args.cache_dir,
267 |         device_map=device_map,
268 |         trust_remote_code=True,
269 |     )
270 | 
271 |     if data_args.img_size != 336:
272 |         model.vit.resize_pos()
273 | 
274 |     tokenizer = transformers.AutoTokenizer.from_pretrained(
275 |         model_args.model_name_or_path,
276 |         cache_dir=training_args.cache_dir,
277 |         padding_side='right',
278 |         use_fast=False,
279 |         trust_remote_code=True,
280 |     )
281 |     model.tokenizer = tokenizer
282 | 
283 |     if model_args.moe_aligned_pth_path:
284 |         rank0_print("Load Multi-Aligned MoE-MLP...")
285 |         mlp_moe_state_dict = torch.load(model_args.moe_aligned_pth_path, map_location="cpu")
286 |         num_experts = mlp_moe_state_dict['gate.weight'].size(0)
287 |         num_selected = mlp_moe_state_dict.pop('num_selected')
288 |         mlp_moe = MLPMoE(num_experts, num_selected, 1024, 4096).to(model.device)
289 |         mlp_moe.load_state_dict(mlp_moe_state_dict)
290 |         model.vision_proj = deepcopy(mlp_moe)
291 |         del mlp_moe
292 | 
293 |         assert not training_args.fix_sampler, \
294 |             "If loading the aligned MoE, set `fix_sampler` to `False`, because the router must be trained!!!"
295 | 
296 | 
297 |     if training_args.fix_vit:
298 |         model.vit.requires_grad_(False)
299 |     else:
300 |         assert False, "please fix vit~"
301 |         # model.vit.requires_grad_(True)
302 |         # model.vit.vision_tower.vision_model.post_layernorm = torch.nn.Identity(
303 |         # )
304 | 
305 |     if training_args.fix_sampler:
306 |         model.vision_proj.requires_grad_(False)
307 |     else:
308 |         model.vision_proj.requires_grad_(True)
309 | 
310 |     if training_args.use_lora:
311 |         for name, param in model.model.named_parameters():
312 |             param.requires_grad = False
313 | 
314 |         lora_config = LoraConfig(
315 |             r=lora_args.lora_r,
316 |             lora_alpha=lora_args.lora_alpha,
317 |             target_modules=lora_args.lora_target_modules,
318 |             lora_dropout=lora_args.lora_dropout,
319 |             bias=lora_args.lora_bias,
320 |             task_type='CAUSAL_LM',
321 |             modules_to_save=['vision_proj'] if not training_args.fix_sampler else None
322 |         )
323 | 
324 |         model = get_peft_model(model, lora_config)
325 |         model.print_trainable_parameters()
326 | 
327 |         if training_args.gradient_checkpointing:
328 |             model.enable_input_require_grads()
329 |     else:
330 |         if training_args.fix_llm:
331 |             for name, param in model.model.named_parameters():
332 |                 param.requires_grad = False
333 |             for name, param in model.output.named_parameters():
334 |                 param.requires_grad = False
335 | 
336 |     total = 0
337 |     training = 0
338 |     training_name = []
339 |     for name, param in model.named_parameters():
340 |         param_num = param.numel()
341 |         if param.requires_grad:
342 |             training += param_num
343 |             training_name.append(name)
344 |         total += param_num
345 | 
346 |     rank0_print(f"Total Params:\t{total / 1e9:.2f}B")
347 |     rank0_print(f"Training Params:\t{training / 1e6:.2f}M")
348 |     rank0_print(training_name)
349 | 
350 |     # # Name of Trainable Params
351 |     # trainable_params_names = []
352 |     # for name,param in model.named_parameters():
353 |     #     if param.requires_grad:
354 |     #         trainable_params_names.append(name)
355 |     # print(trainable_params_names)
356 | 
357 |     # Load data
358 |     data_module = make_supervised_data_module(
359 |         tokenizer=tokenizer, data_args=data_args)
360 |     print(f"transformers logging bar enabled status: {transformers.processing_utils.logging.is_progress_bar_enabled()}")
361 |     transformers.processing_utils.logging.enable_progress_bar()
362 | 
363 |     # whether `tune_mm_mlp` or not
364 |     if 
training_args.fix_vit and training_args.fix_llm and (not training_args.fix_sampler) and (not training_args.use_lora):
365 |         rank0_print("`tune_mm_mlp` is True!!!")
366 |         training_args.tune_mm_mlp = True
367 |     # Start trainer
368 |     trainer = ChartMoETrainer(
369 |         model=model, tokenizer=tokenizer, args=training_args, **data_module
370 |     )
371 | 
372 |     trainer.train()
373 |     trainer.save_state()
374 | 
375 |     safe_save_model_for_hf_trainer(
376 |         trainer=trainer,
377 |         output_dir=training_args.output_dir,
378 |         bias=lora_args.lora_bias)
379 | 
380 | 
381 | if __name__ == '__main__':
382 |     train()
--------------------------------------------------------------------------------
/chartmoe/utils/custom_path.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: Configuration of customized paths
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | """
6 | # Model
7 | ChartMoE_HF_PATH = 'IDEA-FinAI/chartmoe'
8 | 
9 | # Data
10 | ChartQA_ROOT = '/path/to/ChartQA/'
11 | ChartQA_TEST_IMG_ROOT = '/path/to/ChartQA/test/png/'
--------------------------------------------------------------------------------
/examples/bar2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/bar2.png
--------------------------------------------------------------------------------
/examples/bar2_highlight.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/bar2_highlight.png
--------------------------------------------------------------------------------
/examples/line.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/line.png
--------------------------------------------------------------------------------
/examples/line3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/line3.png
--------------------------------------------------------------------------------
/examples/line3_edit.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/line3_edit.png
--------------------------------------------------------------------------------
/examples/pie1-to-bar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/pie1-to-bar.png
--------------------------------------------------------------------------------
/examples/pie1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/examples/pie1.png
--------------------------------------------------------------------------------
/gradio_demo.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: WebUI of ChartMoE
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | REFERENCE: https://github.com/Coobiw/MPP-LLaVA
6 | """
7 | import 
argparse
8 | import os
9 | import random
10 | 
11 | import numpy as np
12 | import torch
13 | import torch.backends.cudnn as cudnn
14 | import gradio as gr
15 | from PIL import Image
16 | 
17 | from functools import partial
18 | from copy import deepcopy
19 | 
20 | from chartmoe import ChartMoE_Robot
21 | 
22 | def disable_torch_init():
23 |     """
24 |     Disable the redundant torch default initialization to accelerate model creation.
25 |     """
26 |     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
27 |     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
28 | 
29 | # ========================================
30 | # Model Initialization
31 | # ========================================
32 | 
33 | print('Initializing Chat')
34 | 
35 | 
36 | disable_torch_init()
37 | chat_robot = ChartMoE_Robot()
38 | 
39 | print('Initialization Finished')
40 | 
41 | # ========================================
42 | # Gradio Setting
43 | # ========================================
44 | 
45 | 
46 | def gradio_reset(history, img_list):
47 |     if history is not None:
48 |         history = ""
49 |     if img_list is not None:
50 |         img_list = []
51 |     return None, \
52 |         gr.update(value=None, interactive=True), \
53 |         gr.update(placeholder='Please upload your image first', interactive=False), \
54 |         gr.update(value="Upload & Start Chat", interactive=True), \
55 |         history, \
56 |         img_list
57 | 
58 | def upload_img(gr_img, text_input, history, img_list):
59 |     def load_img(image, img_list):
60 |         if isinstance(image, str): # is an image path
61 |             image = Image.open(image).convert('RGB')
62 |         elif isinstance(image, Image.Image):
63 |             image = image.convert('RGB')
64 | 
65 |         img_list.append(image)
66 |         msg = "Received."
67 |         return msg
68 |     if gr_img is None:
69 |         return None, None, gr.update(interactive=True), history, img_list  # five values, matching the upload_button.click outputs
70 | 
71 |     load_img(gr_img, img_list)
72 |     return gr.update(interactive=False), \
73 |         gr.update(interactive=True, placeholder='Type and press Enter'), \
74 |         gr.update(value="Start Chatting", interactive=False), \
75 |         history, \
76 |         img_list
77 | 
78 | def gradio_ask(user_message, chatbot):
79 |     if len(user_message) == 0:
80 |         return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot  # two values, matching the [text_input, chatbot] outputs
81 |     chatbot = chatbot + [[user_message, None]]
82 |     return user_message, chatbot
83 | 
84 | 
85 | def gradio_answer(chatbot, text_input, history, img_list, do_sample, num_beams, temperature, max_new_tokens):
86 |     generation_config = \
87 |     {
88 |         "do_sample": do_sample=='True',
89 |         "num_beams": num_beams,
90 |         'temperature': temperature,
91 |         'max_new_tokens': max_new_tokens,
92 |     }
93 | 
94 |     image = img_list[0]
95 |     with torch.cuda.amp.autocast():
96 |         response, history = chat_robot.chat(image=image, question=text_input, history=history, **generation_config)
97 |     chatbot[-1][1] = response
98 |     text_input = ''
99 |     return chatbot, history, img_list, text_input
100 | 
101 | title = """Demo of ChartMoE"""
102 | description = """This is the demo of ChartMoE. Upload your image and start chatting! To use example questions, click example image, hit upload, and press enter in the chatbox."""
103 | 
104 | from transformers.trainer_utils import set_seed
105 | set_seed(42)
106 | #TODO show examples below
107 | 
108 | with gr.Blocks() as demo:
109 |     gr.Markdown(title)
110 |     gr.Markdown(description)
111 | 
112 |     with gr.Row():
113 |         with gr.Column(scale=0.25):
114 |             image = gr.Image(type="pil")
115 |             with gr.Row():
116 |                 upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
117 |                 clear = gr.Button("Restart 🔄")
118 | 
119 |             examples_placeholder = gr.Column()
120 | 
121 |         with gr.Column(scale=0.75):
122 |             history = gr.State(value="")
123 |             img_list = gr.State(value=[])
124 |             chatbot = gr.Chatbot(
125 |                 label='ChartMoE',
126 |                 height=700,
127 |                 avatar_images=['gradio_demo_pics/user.png','gradio_demo_pics/robot.png']
128 |             )
129 | 
130 |             with gr.Row():
131 |                 text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False, scale=8)
132 |                 submit_button = gr.Button(value="Submit", variant="primary", scale=2)
133 | 
134 |             with gr.Row():
135 |                 do_sample = gr.components.Radio(['True', 'False'],
136 |                                                 label='do_sample',
137 |                                                 value='False')
138 | 
139 |                 num_beams = gr.Slider(
140 |                     minimum=1,
141 |                     maximum=5,
142 |                     value=1,
143 |                     step=1,
144 |                     interactive=True,
145 |                     label="num beams",
146 |                 )
147 | 
148 |                 temperature = gr.Slider(
149 |                     minimum=0.1,
150 |                     maximum=2.0,
151 |                     value=1.0,
152 |                     step=0.1,
153 |                     interactive=True,
154 |                     label="Temperature",
155 |                 )
156 | 
157 |                 max_new_tokens = gr.Slider(
158 |                     minimum=128,
159 |                     maximum=4096,
160 |                     value=1024,
161 |                     step=128,
162 |                     interactive=True,
163 |                     label="max new tokens",
164 |                 )
165 | 
166 |     with examples_placeholder:
167 |         gr.Examples(examples=[
168 |             ["examples/bar2.png", "Redraw the chart with python matplotlib, giving the code to highlight the column corresponding to the year in which the student got the highest score (painting it red). Please keep the same colors and legend as the input chart."],
169 |             ["examples/line3.png", "Redraw the chart with python matplotlib, giving the code to highlight the data point with the lowest growth rate (draw a horizontal dotted line parallel to the x-axis, through the lowest point and add \'lowest\' label in the legend anchor). Please keep the same colors and legend as the input chart."],
170 |             ["examples/pie1.png", "Redraw the chart with python matplotlib, convert it into a bar chart, giving the code to reflect the fact that the price of \'Gold\' has been reduced to 27% and the \'Silver\' has been increased to 28%. 
Please keep the colors and legend according to the input chart."]
171 |         ], inputs=[image, text_input])
172 | 
173 |     upload_button.click(upload_img, [image, text_input, history, img_list], [image, text_input, upload_button, history, img_list])
174 | 
175 |     # print(list(map(type,[text_input, chatbot])))
176 |     # print(list(map(type,[chatbot, history, img_list, do_sample, num_beams, temperature, max_new_tokens])))
177 |     text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
178 |         gradio_answer, [chatbot, text_input, history, img_list, do_sample, num_beams, temperature, max_new_tokens], [chatbot, history, img_list, text_input]
179 |     )
180 |     submit_button.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
181 |         gradio_answer, [chatbot, text_input, history, img_list, do_sample, num_beams, temperature, max_new_tokens], [chatbot, history, img_list, text_input]
182 |     )
183 |     clear.click(gradio_reset, [history, img_list], [chatbot, image, text_input, upload_button, history, img_list], queue=False)
184 | 
185 | demo.launch(share=True, inbrowser=True)
--------------------------------------------------------------------------------
/gradio_demo_pics/gradio_demo1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/gradio_demo_pics/gradio_demo1.jpg
--------------------------------------------------------------------------------
/gradio_demo_pics/gradio_demo2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/gradio_demo_pics/gradio_demo2.jpg
--------------------------------------------------------------------------------
/gradio_demo_pics/gradio_demo3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/gradio_demo_pics/gradio_demo3.jpg
--------------------------------------------------------------------------------
/gradio_demo_pics/robot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/gradio_demo_pics/robot.png
--------------------------------------------------------------------------------
/gradio_demo_pics/user.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IDEA-FinAI/ChartMoE/132b0361a97e887f37de38b5bbaedc5290acaef3/gradio_demo_pics/user.png
--------------------------------------------------------------------------------
/quickstart.py:
--------------------------------------------------------------------------------
1 | """
2 | FEATURE: QuickTour of ChartMoE
3 | AUTHOR: Brian Qu
4 | URL: https://arxiv.org/abs/2409.03277
5 | """
6 | from chartmoe import ChartMoE_Robot
7 | import torch
8 | 
9 | robot = ChartMoE_Robot()
10 | image_path = "examples/bar2.png"
11 | question = "Redraw the chart with python matplotlib, giving the code to highlight the column corresponding to the year in which the student got the highest score (painting it red). Please keep the same colors and legend as the input chart."
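# Note: `chat` also accepts the generation kwargs that gradio_demo.py forwards
# via its `generation_config` dict (do_sample, num_beams, temperature,
# max_new_tokens). A hedged sketch -- the values here are illustrative, not
# repo defaults:
#   response, history = robot.chat(
#       image_path=image_path, question=question, history="",
#       do_sample=False, num_beams=2, max_new_tokens=1024,
#   )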
12 | 
13 | history = ""
14 | with torch.cuda.amp.autocast():
15 |     response, history = robot.chat(image_path=image_path, question=question, history=history)
16 | 
17 | print(response)
18 | 
19 | '''Response:
20 | ```python
21 | import matplotlib.pyplot as plt
22 | 
23 | data = [3.3, 3.5, 3.6, 3.8, 3.7, 3.6, 3.8]
24 | years = ['2016', '2017', '2018', '2019', '2020', '2021', '2022']
25 | labels = ['Student A Average GPA']
26 | colors = ['blue']
27 | 
28 | plt.bar(years, data, color=colors)
29 | plt.title('Student Performance')
30 | plt.xlabel('Year')
31 | plt.ylabel('Student A Average GPA')
32 | plt.legend(labels)
33 | 
34 | # Highlight the year with the highest score
35 | highest_score_index = data.index(max(data))
36 | plt.bar(years[highest_score_index], data[highest_score_index], color='red')
37 | 
38 | plt.show()
39 | ```
40 | 
41 | '''
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.33.2
2 | timm==0.4.12
3 | sentencepiece==0.1.99
4 | gradio==4.13.0
5 | markdown2==2.4.10
6 | xlsxwriter==3.1.2
7 | einops==0.8.0
8 | deepspeed==0.14.2
9 | peft==0.10.0
10 | prettytable==3.10.2
11 | tqdm
12 | datasets==2.21.0
13 | python-Levenshtein
14 | opencv-python>=4.10.0
15 | wandb==0.19.6
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_namespace_packages
2 | import platform
3 | 
4 | DEPENDENCY_LINKS = []
5 | if platform.system() == "Windows":
6 |     DEPENDENCY_LINKS.append("https://download.pytorch.org/whl/torch_stable.html")
7 | 
8 | 
9 | def fetch_requirements(filename):
10 |     with open(filename) as f:
11 |         return [ln.strip() for ln in f.read().split("\n")]
12 | 
13 | 
14 | setup(
15 |     name="chartmoe",
16 |     version="0.1.0",
17 |     author="Coobiw",
18 |     description="ChartMoE: Mixture of Expert Connector for Better Chart Understanding",
19 |     keywords="Multimodal Large Language Model (MLLM), Chart Understanding, Mixture of Expert (MoE)",
20 |     license="3-Clause BSD",
21 |     packages=find_namespace_packages(include="chartmoe.*"),
22 |     install_requires=fetch_requirements("requirements.txt"),
23 |     python_requires=">=3.9",
24 |     include_package_data=True,
25 |     dependency_links=DEPENDENCY_LINKS,
26 |     zip_safe=False,
27 | )
--------------------------------------------------------------------------------
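A minimal usage sketch of the merged connector, mirroring the checkpoint layout written by moe_construction.py and the loading logic in train.py above. It assumes the working directory is chartmoe/train/ after scripts/moe_construction.sh has produced output/moe_aligned/mlp_moe.pth; the batch size and vision-token count below are illustrative assumptions, not values fixed by the code:

import torch
from mlp_moe import MLPMoE  # chartmoe/train/mlp_moe.py

# Load the merged connector saved by moe_construction.py (see moe_construction.sh).
state = torch.load("output/moe_aligned/mlp_moe.pth", map_location="cpu")
num_experts = state["gate.weight"].size(0)  # the gate is nn.Linear(mm_channels, num_experts)
num_selected = state.pop("num_selected")    # plain int stored alongside the weights by moe_construction.py
mlp_moe = MLPMoE(num_experts, num_selected, 1024, 4096)
mlp_moe.load_state_dict(state)

# Route a dummy batch of vision tokens; 1225 is an assumed token count, not fixed by the repo.
x_img = torch.randn(2, 1225, 1024)  # (batch, vision tokens, mm_channels)
out, balance_loss, router_z_loss = mlp_moe(x_img)
print(out.shape)  # torch.Size([2, 1225, 4096])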