├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── llm_structure_tool.iml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── README_en.md
├── app.py
├── config
│   └── common_config.py
├── environment.yml
├── finetune
│   ├── llm_utils.py
│   └── pulse_utils.py
├── img
│   ├── 1.jpg
│   ├── 2.jpg
│   ├── 3.jpg
│   ├── 4.jpg
│   └── 5.jpg
├── llmtuner
│   ├── __init__.py
│   ├── api
│   │   ├── __init__.py
│   │   ├── app.py
│   │   └── protocol.py
│   ├── chat
│   │   ├── __init__.py
│   │   └── stream_chat.py
│   ├── dsets
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   ├── preprocess.py
│   │   └── utils.py
│   ├── extras
│   │   ├── __init__.py
│   │   ├── callbacks.py
│   │   ├── constants.py
│   │   ├── logging.py
│   │   ├── misc.py
│   │   ├── patches
│   │   │   ├── __init__.py
│   │   │   └── llama_patch.py
│   │   ├── ploting.py
│   │   ├── save_and_load.py
│   │   └── template.py
│   ├── hparams
│   │   ├── __init__.py
│   │   ├── data_args.py
│   │   ├── finetuning_args.py
│   │   ├── general_args.py
│   │   ├── generating_args.py
│   │   └── model_args.py
│   ├── tuner
│   │   ├── __init__.py
│   │   ├── core
│   │   │   ├── __init__.py
│   │   │   ├── adapter.py
│   │   │   ├── loader.py
│   │   │   ├── parser.py
│   │   │   └── utils.py
│   │   ├── dpo
│   │   │   ├── __init__.py
│   │   │   ├── collator.py
│   │   │   ├── trainer.py
│   │   │   └── workflow.py
│   │   ├── ppo
│   │   │   ├── __init__.py
│   │   │   ├── trainer.py
│   │   │   ├── utils.py
│   │   │   └── workflow.py
│   │   ├── pt
│   │   │   ├── __init__.py
│   │   │   └── workflow.py
│   │   ├── rm
│   │   │   ├── __init__.py
│   │   │   ├── collator.py
│   │   │   ├── metric.py
│   │   │   ├── trainer.py
│   │   │   └── workflow.py
│   │   ├── sft
│   │   │   ├── __init__.py
│   │   │   ├── metric.py
│   │   │   ├── trainer.py
│   │   │   └── workflow.py
│   │   └── tune.py
│   └── webui
│       ├── __init__.py
│       ├── chatter.py
│       ├── common.py
│       ├── components
│       │   ├── __init__.py
│       │   ├── chatbot.py
│       │   ├── data.py
│       │   ├── eval.py
│       │   ├── export.py
│       │   ├── infer.py
│       │   ├── top.py
│       │   └── train.py
│       ├── css.py
│       ├── engine.py
│       ├── interface.py
│       ├── locales.py
│       ├── manager.py
│       ├── runner.py
│       └── utils.py
├── requirements.txt
├── train_bash.py
└── utils
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | # Ignore Python cache files
3 | __pycache__/
4 | *.pyc
5 | *.pyo
6 | *.pyd
7 |
8 | .idea/
9 |
10 | docker/*
11 |
12 | # Ignore user-generated files
13 | data/*
14 | finetune/*/
15 | !finetune/*.py
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/llm_structure_tool.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [[中文版](https://github.com/JuneYaooo/llm_structure_tool/blob/main/README.md)] [[English](https://github.com/JuneYaooo/llm_structure_tool/blob/main/README_en.md)]
2 |
3 | # 大模型结构化工具
4 |
5 | 该工具是一个可基于常见开源模型进行微调的结构化工具,旨在帮助用户处理和分析文本数据,目前提供了训练,预测,评估一体化的功能。训练预测部分采用了[[llmtuner](https://github.com/hiyouga/LLaMA-Factory)],作为一个核心包引入。
6 |
7 | 它提供了以下常见结构化类型,适用于各种场景下的结构化使用,如病例结构化场景。
8 |
9 | - **单选**
10 |
11 | 
12 |
13 | - **多选**
14 |
15 | 
16 |
17 | - **提取**
18 |
19 | 
20 |
21 | ### 安装
22 |
23 | 首先,克隆本项目到本地计算机:
24 |
25 | ```
26 | git clone https://github.com/JuneYaooo/llm_structure_tool.git
27 | ```
28 |
29 | #### 建议 conda 安装
30 | ##### 方法一
31 | ```
32 | cd llm_structure_tool
33 | conda env create -f environment.yml
34 | ```
35 | ##### 方法二
36 | ```
37 | conda create -n llm_structure python=3.9
38 | pip install -r requirements.txt
39 | ```
40 |
41 | 激活conda环境:
42 |
43 | ```
44 | conda activate llm_structure
45 | ```
46 |
47 | 然后运行前端demo:
48 |
49 | ```
50 | python app.py
51 | ```
52 |
53 | ### 模型配置
54 | 在config/common_config.py中,填入自己想要使用的模型地址
55 |
56 | ## 使用方法
57 | 结构化工具提供一个简单的网页交互界面。您可以根据页面提示输入相关信息,选择要执行的功能。
58 |
59 | ### 单句测试
60 |
61 | 输入一段话,设定规则,进行单选、多选或提取
62 |
63 | **示例:**
64 |
65 | 字段类型:提取
66 |
67 | 字段名:肾上腺肿物大小
68 |
69 | 原文:CT检查示左肾上腺区见大小约5.5 cm×5.7 cm不均匀低密度肿块,边界清楚,增强扫描实性成分中度强化,内见无强化低密度,静脉期明显强化。CT诊断:考虑左肾上腺区肿瘤。B超检查示左肾上腺区见4.6 cm×4.2 cm的低回声区,边界清,有包膜,提示左肾上腺实质性占位声像。
70 |
71 |
72 | 输入不相关的字段,如胃部肿物大小,结果为“未提及”
73 | 
74 |
75 | 输入相关的字段,如肾上腺肿物大小,结果为“约5.5 cm×5.7 cm”
76 | 
77 |
78 | ### 训练
79 | 待填充
80 |
81 | ### 预测
82 | 待填充
83 |
84 | ### 评估
85 | 待填充
86 |
87 | ## 致谢
88 |
89 | - [PULSE](https://github.com/openmedlab/PULSE): 本项目使用了PULSE模型(上海人工智能实验室的医疗开源大模型)
90 | - [llmtuner](https://github.com/hiyouga/LLaMA-Factory): 本项目训练预测代码基于llmtuner
91 |
92 | ## 贡献
93 |
94 | 如果您对该项目感兴趣,欢迎贡献您的代码和改进建议。您可以通过以下方式参与:
95 |
96 | 1. 提交问题和建议到本项目的 Issue 页面。
97 | 2. Fork 本项目并提交您的改进建议,我们将会审查并合并合适的改动。
98 |
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 | [[中文版](https://github.com/JuneYaooo/llm_structure_tool/blob/main/README.md)] [[English](https://github.com/JuneYaooo/llm_structure_tool/blob/main/README_en.md)]
2 |
3 | # Medical Record Structuring Tool (Under Continuous Update)
4 |
5 | This tool is a versatile structuring tool that allows fine-tuning common open-source models for various text data processing and analysis tasks. It currently provides integrated functionalities for training, prediction, and evaluation, with the training and prediction components utilizing [[llmtuner](https://github.com/hiyouga/LLaMA-Factory)] as a core package.
6 |
7 | It offers the following common structuring types applicable to various scenarios, such as medical case structuring:
8 |
9 | - Single selection
10 |
11 | 
12 |
13 | - Multiple selection
14 |
15 | 
16 |
17 | - Extraction
18 |
19 | 
20 |
21 | ### Installation
22 |
23 | First, clone this project to your local computer:
24 |
25 | ```
26 | git clone https://github.com/JuneYaooo/llm_structure_tool.git
27 | ```
28 |
29 | #### Conda Installation (Recommended)
30 | ##### Method 1
31 | ```shell
32 | cd llm_structure_tool
33 | conda env create -f environment.yml
34 | ```
35 | ##### Method 2
36 | ```shell
37 | conda create -n llm_structure python=3.9
38 | pip install -r requirements.txt
39 | ```
40 |
41 | Activate the newly created environment:
42 |
43 | ```
44 | conda activate llm_structure
45 | ```
46 |
47 | Then run the frontend demo:
48 |
49 | ```
50 | python app.py
51 | ```
52 |
53 | ## Usage
54 |
55 | The structuring tool provides a simple web-based interactive interface (Gradio). You can enter the relevant information and select the desired functionality as prompted.
56 |
57 | ### Single Sentence Testing
58 |
59 | Enter a paragraph, set the rules, and perform single selection, multiple selection, or extraction.
60 |
61 | **Example:**
62 |
63 | Field Type: 提取
64 |
65 | Field Name: 肾上腺肿物大小
66 |
67 | Original Text: CT检查示左肾上腺区见大小约5.5 cm×5.7 cm不均匀低密度肿块,边界清楚,增强扫描实性成分中度强化,内见无强化低密度,静脉期明显强化。CT诊断:考虑左肾上腺区肿瘤。B超检查示左肾上腺区见4.6 cm×4.2 cm的低回声区,边界清,有包膜,提示左肾上腺实质性占位声像。
68 |
69 | Entering an unrelated field, such as "Gastric Tumor Size," will result in "Not mentioned."
70 | 
71 |
72 | Entering a related field, such as "Adrenal Tumor Size," will result in "Approximately 5.5 cm × 5.7 cm."
73 | 
74 |
75 | ### Training
76 | To be filled
77 |
78 | ### Prediction
79 | To be filled
80 |
81 | ### Evaluation
82 | To be filled
83 |
84 | ## Acknowledgments
85 |
86 | - [PULSE](https://github.com/openmedlab/PULSE): This project uses the PULSE model (a medical open-source large language model from the Shanghai Artificial Intelligence Laboratory).
87 | - [llmtuner](https://github.com/hiyouga/LLaMA-Factory): The training and prediction code for this project is based on llmtuner.
88 |
89 | ## Contribution
90 |
91 | If you are interested in this project, you are welcome to contribute your code and improvement suggestions. You can participate in the following ways:
92 |
93 | 1. Submit issues and suggestions to the Issue page of this project.
94 | 2. Fork this project and submit your improvement suggestions. We will review and merge appropriate changes.
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | import os
3 | from finetune.llm_utils import infer_model, query_model, train_model
4 | import shutil
5 | import time
6 | import datetime
7 | from config.common_config import *
8 | from utils.utils import stop_train_process,evaluate_model
9 |
10 | llm_model_dict_list = list(llm_model_dict.keys())
11 |
12 | def get_file_modify_time(filename):
13 | try:
14 | return datetime.datetime.fromtimestamp(os.stat(filename).st_mtime).strftime("%Y-%m-%d %H:%M:%S")
15 | except Exception as e:
16 | print('Failed to get modification time for {}'.format(filename))
17 | print(e)
18 | return 'not available'
19 |
20 | def get_model_update_time(model_name, lora_name):
21 | model_file_name = llm_model_dict[model_name]['name']
22 | print('get_model_update_time model_file_name',model_file_name)
23 | print('get_model_update_time lora_name',lora_name)
24 | model_lora_dir = os.path.join(f"finetune", model_file_name,'checkpoints',lora_name,'adapter_model.bin')
25 | print('model_lora_dir',model_lora_dir)
26 | update_time = get_file_modify_time(model_lora_dir)
27 | return update_time
28 |
29 | def on_train(model_name, lora_name, config_file, training_data_file):
30 | config_path = 'data/'+os.path.basename(config_file.name)
31 | training_data_path = 'data/'+os.path.basename(training_data_file.name)
32 | msg = train_model(model_name, lora_name, config_path, training_data_path)
33 | return msg
34 |
35 | def format_duration(seconds):
36 | hours = int(seconds // 3600)
37 | minutes = int((seconds % 3600) // 60)
38 | seconds = round(seconds % 60,2)
39 | if hours > 0:
40 | return f"{hours}时{minutes}分{seconds}秒"
41 | elif minutes > 0:
42 | return f"{minutes}分{seconds}秒"
43 | else:
44 | return f"{seconds}秒"
45 |
46 |
47 | def on_test(model_name, select_lora, config_file, test_data_file):
48 | start_time = time.time()
49 | config_path = 'data/'+os.path.basename(config_file.name)
50 | test_data_path = 'data/'+os.path.basename(test_data_file.name)
51 |
52 | result_path,info = infer_model(model_name, select_lora, config_path, test_data_path)
53 | end_time = time.time()
54 | cost_time = end_time-start_time
55 |
56 | info = '用时:'+format_duration(cost_time)+f" ({round(cost_time,2)}秒)" if info=='success' else info
57 | return result_path,info
58 |
59 | def on_evaluate(model_name, select_lora, test_result_file, test_label_file):
60 | test_result_path = 'data/'+os.path.basename(test_result_file.name)
61 | test_label_path = 'data/'+os.path.basename( test_label_file.name)
62 | result_path = evaluate_model(test_result_path, test_label_path)
63 | return result_path
64 |
65 | def on_query(model_name,project_name, field_type, field_name, value_range,special_requirement, query):
66 | res = query_model(model_name,project_name, field_type, field_name, value_range,special_requirement, query)
67 | return res
68 |
69 | def on_stop(model_name,select_lora):
70 | res = stop_train_process()
71 | return res
72 |
73 | def upload_file(file):
74 | print('file',file)
75 | if not os.path.exists("data"):
76 | os.mkdir("data")
77 | filename = os.path.basename(file.name)
78 | shutil.move(file.name, "data/" + filename)
79 |     # return the path of the file that was just moved into data/
80 | filedir = "data/" + filename
81 | return filedir
82 |
83 | def change_lora_name_input(model_name,lora_name_en):
84 | if lora_name_en == "新建":
85 |         return gr.update(visible=True), gr.update(visible=True), 'not available'
86 | else:
87 | file_status = f"已加载{lora_name_en}"
88 | model_update_time = get_model_update_time(model_name, lora_name_en)
89 | return gr.update(visible=False), gr.update(visible=False), model_update_time
90 |
91 |
92 | def add_lora(lora_name_en,lora_list):
93 | if lora_name_en in lora_list:
94 | print('名称冲突,不新建')
95 | return gr.update(visible=True,value=lora_name_en), gr.update(visible=False), gr.update(visible=False), lora_list
96 | else:
97 | return gr.update(visible=True, choices=[lora_name_en] + lora_list, value=lora_name_en), gr.update(visible=False), gr.update(visible=False),[lora_name_en] + lora_list
98 |
99 |
100 | def find_folders(directory):
101 | folders = []
102 | for item in os.listdir(directory):
103 | item_path = os.path.join(directory, item)
104 | if os.path.isdir(item_path):
105 | folders.append(item)
106 | return folders
107 |
108 |
109 | def get_lora_init_list(model_name):
110 | model_file_name = llm_model_dict[model_name]['name']
111 | model_dir = os.path.join(f"finetune", model_file_name,'checkpoints')
112 | if not os.path.exists(model_dir):
113 | os.makedirs(model_dir)
114 | lora_list = find_folders(model_dir)
115 | return lora_list
116 |
117 |
118 | def get_lora_list(model_name):
119 | model_file_name = llm_model_dict[model_name]['name']
120 | model_dir = os.path.join(f"finetune", model_file_name,'checkpoints')
121 | if not os.path.exists(model_dir):
122 | os.makedirs(model_dir)
123 | lora_list = find_folders(model_dir)
124 | return gr.update(visible=True, choices=lora_list+ ['新建'], value=lora_list[0] if len(lora_list) > 0 else '新建'), lora_list + ['新建']
125 |
126 | lora_init_list = get_lora_init_list(llm_model_dict_list[0])
127 |
128 | webui_title = """
129 | # 🎉病历结构化🎉
130 |
131 | 可以选择案例测试和[使用excel配置文件进行训练-预测-评估](https://zg0f0ipp6j.feishu.cn/wiki/XC16wwvGgiVSNbkzSPUczqFMn6e)
132 | """
133 |
134 | def create_tab():
135 |     # initialization
136 | with gr.Blocks() as demo:
137 | set_lora_list = gr.State(lora_init_list+ ['新建'])
138 | gr.Markdown(webui_title)
139 | with gr.Row():
140 | with gr.Column():
141 | model_name = gr.Radio(llm_model_dict_list,
142 | label="选择模型",
143 | value= llm_model_dict_list[0] if len(llm_model_dict_list)>0 else '暂无可选模型',
144 | interactive=True)
145 | with gr.Column():
146 | select_lora = gr.Dropdown(set_lora_list.value,
147 | label= "选择或者新建一个Lora",
148 | value= set_lora_list.value[0] if len(set_lora_list.value) > 0 else '新建',
149 | interactive=True,
150 | visible=True)
151 | lora_name_en = gr.Textbox(label="请输入Lora英文名称,中间不能有空格,小写字母,单词间可用下划线分开",
152 | lines=1,
153 | interactive=True,
154 | visible=False)
155 | lora_add = gr.Button(value="确认添加Lora", visible=False)
156 | with gr.Row():
157 | lastest_model = gr.Textbox(type="text", label='模型更新时间(请切换模型或项目刷新显示)')
158 | with gr.Tab("案例测试"):
159 | with gr.Column():
160 | gr.Markdown(f"初次加载模型可能比较慢,后续会变快")
161 | field_type = gr.Radio(['单选','多选','提取'],
162 | label="字段类型",
163 | value='提取',
164 | interactive=True)
165 | field_name = gr.Textbox(label="字段名",
166 | lines=1,
167 | interactive=True)
168 | value_range = gr.Textbox(label="请输入值域,以','分隔开(对于提取不必输入值域)",
169 | lines=1,
170 | interactive=True)
171 | special_requirement= gr.Textbox(label="特殊说明,假如有的话请填上",
172 | lines=1,
173 | interactive=True)
174 | query = gr.Textbox(label="请输入原文",
175 | lines=1,
176 | interactive=True)
177 |                 query_button = gr.Button("获得结果")
178 | query_res = gr.Textbox(type="text", label='')
179 |
180 | with gr.Tab("训练-预测-评估", visible=False):
181 | gr.Markdown(f"""
182 | Step1:选择一个Lora
183 | Step2:根据任务选择训练 预测或评估,上传对应的参数文件或者数据标准文件,请等待文件上传成功后再开始执行!""")
184 | with gr.Row():
185 | with gr.Column():
186 | gr.Markdown("## 训练")
187 | train_config_file = gr.File(label="上传配置文件", file_types=['.xlsx'])
188 | train_data_file = gr.File(label="上传标注数据文件", file_types=['.xlsx'])
189 |                     train_button = gr.Button("开始训练")
190 |                     kill_train_button = gr.Button("停止所有训练进程")
191 | train_res = gr.Textbox(type="text", label='')
192 |
193 |
194 | with gr.Column():
195 | gr.Markdown("## 预测")
196 | test_config_file = gr.File(label="上传配置文件", file_types=['.xlsx'])
197 | test_data_file = gr.File(label="上传测试数据文件", file_types=['.xlsx'])
198 |                     test_button = gr.Button("开始预测")
199 | test_res = gr.Textbox(type="text", label='')
200 | download_test = gr.File(label="下载结果文件")
201 |
202 | with gr.Column():
203 | gr.Markdown("## 评估")
204 | test_result_file = gr.File(label="上传测试结果文件", file_types=['.xlsx'])
205 | test_label_file = gr.File(label="上传标准结果文件", file_types=['.xlsx'])
206 |                     evaluate_button = gr.Button("开始评估")
207 | download_evaluate = gr.File(label="下载评估结果")
208 |
209 | select_lora.change(fn=change_lora_name_input,
210 | inputs=[model_name,select_lora],
211 | outputs=[lora_name_en, lora_add,lastest_model])
212 | lora_add.click(fn=add_lora,
213 | inputs=[lora_name_en,set_lora_list],
214 | outputs=[select_lora, lora_name_en, lora_add,set_lora_list])
215 | model_name.change(fn=get_lora_list, inputs=[model_name], outputs=[select_lora, set_lora_list])
216 | train_config_file.upload(upload_file,
217 | inputs=train_config_file)
218 | train_data_file.upload(upload_file,
219 | inputs=train_data_file)
220 | test_config_file.upload(upload_file,
221 | inputs=test_config_file)
222 | test_data_file.upload(upload_file,
223 | inputs=test_data_file)
224 | test_result_file.upload(upload_file,
225 | inputs=test_result_file)
226 | test_label_file.upload(upload_file,
227 | inputs=test_label_file)
228 | train_button.click(on_train, inputs=[model_name, select_lora, train_config_file, train_data_file],outputs=[train_res])
229 | kill_train_button.click(on_stop, inputs=[model_name, select_lora],outputs=[train_res])
230 | test_button.click(on_test,show_progress=True, inputs=[model_name, select_lora, test_config_file, test_data_file], outputs=[download_test,test_res])
231 | evaluate_button.click(on_evaluate,show_progress=True, inputs=[model_name, select_lora,test_result_file, test_label_file], outputs=[download_evaluate])
232 | query_button.click(on_query,show_progress=True, inputs=[model_name, select_lora, field_type, field_name, value_range, special_requirement, query], outputs=[query_res])
233 | return demo
234 |
235 | tab = create_tab()
236 |
237 | if __name__ == "__main__":
238 | tab.queue(concurrency_count=5).launch(server_name='0.0.0.0',server_port=33366,share=True, inbrowser=True) #
239 |
--------------------------------------------------------------------------------
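
Note: for readers who want to drive the same pipeline without the Gradio front end, a minimal sketch follows. It is based only on the call sites in app.py above; finetune/llm_utils.py is not reproduced in this excerpt, so the helper signatures are inferred from on_query / on_train / on_test, and the file names and LoRA name are hypothetical.

```python
# Sketch only: call the helpers that app.py binds to the UI.
from finetune.llm_utils import infer_model, query_model, train_model

MODEL = "PULSE"      # a key of llm_model_dict in config/common_config.py
LORA = "demo_lora"   # hypothetical LoRA/project name

# Single-sentence test, mirroring the "案例测试" tab (argument order as in on_query)
answer = query_model(
    MODEL, LORA,
    "提取",                # field type
    "肾上腺肿物大小",        # field name
    "",                    # value range (not needed for extraction)
    "",                    # special requirement
    "CT检查示左肾上腺区见大小约5.5 cm×5.7 cm不均匀低密度肿块……",
)
print(answer)

# Batch training / prediction, mirroring the "训练-预测-评估" tab
# (the xlsx paths are placeholders for files uploaded into data/)
print(train_model(MODEL, LORA, "data/train_config.xlsx", "data/train_data.xlsx"))
result_path, info = infer_model(MODEL, LORA, "data/test_config.xlsx", "data/test_data.xlsx")
print(result_path, info)
```
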
/config/common_config.py:
--------------------------------------------------------------------------------
1 |
2 | ## Configure the locally available open-source models here
3 | llm_model_dict = {
4 | "PULSE": {"name": "pulse",
5 | "model_path": "/path/to/your/model",
6 | "template":"default",
7 | "lora_target":"query_key_value",
8 | "per_device_train_batch_size":2
9 | },
10 | "InternLM": {"name": "internlm",
11 | "model_path": "/path/to/your/model",
12 | "template":"intern",
13 | "lora_target":"q_proj,v_proj",
14 | "per_device_train_batch_size":2
15 | },
16 | "ChatGLM2": {"name": "chatglm2",
17 | "model_path": "/path/to/your/model",
18 | "template":"chatglm2",
19 | "lora_target":"query_key_value",
20 | "per_device_train_batch_size":4
21 | },
22 | "ChatGLM3": {"name": "chatglm3",
23 | "model_path": "/path/to/your/model",
24 | "template":"chatglm3",
25 | "lora_target":"query_key_value",
26 | "per_device_train_batch_size":4
27 | },
28 | }
29 |
30 | # Fill in the absolute path of your conda installation's etc/profile.d/conda.sh
31 | conda_env_file = '/path-to-your-conda/etc/profile.d/conda.sh'
32 |
33 | # Generation parameters
34 | max_length=1500
35 | do_sample=False
36 | temperature=0
--------------------------------------------------------------------------------
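
As a concrete illustration of extending this config, a hypothetical extra entry is sketched below. The template and lora_target values follow DEFAULT_TEMPLATE / DEFAULT_MODULE in llmtuner/extras/constants.py further down in this dump; the model path is a placeholder.

```python
# Hypothetical additional entry for llm_model_dict in config/common_config.py.
# "name" also determines where LoRA checkpoints are stored:
#   finetune/<name>/checkpoints/<lora_name>/  (see get_lora_init_list in app.py)
llm_model_dict["Baichuan2"] = {
    "name": "baichuan2",
    "model_path": "/path/to/Baichuan2-7B-Chat",
    "template": "baichuan2",     # DEFAULT_TEMPLATE["Baichuan2"]
    "lora_target": "W_pack",     # DEFAULT_MODULE["Baichuan2"]
    "per_device_train_batch_size": 2,
}
```
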
/environment.yml:
--------------------------------------------------------------------------------
1 | name: llm_structure
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=5.1=1_gnu
7 | - ca-certificates=2023.08.22=h06a4308_0
8 | - ld_impl_linux-64=2.38=h1181459_1
9 | - libffi=3.4.4=h6a678d5_0
10 | - libgcc-ng=11.2.0=h1234567_1
11 | - libgomp=11.2.0=h1234567_1
12 | - libstdcxx-ng=11.2.0=h1234567_1
13 | - ncurses=6.4=h6a678d5_0
14 | - openssl=3.0.12=h7f8727e_0
15 | - pip=23.3=py39h06a4308_0
16 | - python=3.9.18=h955ad1f_0
17 | - readline=8.2=h5eee18b_0
18 | - setuptools=68.0.0=py39h06a4308_0
19 | - sqlite=3.41.2=h5eee18b_0
20 | - tk=8.6.12=h1ccaba5_0
21 | - tzdata=2023c=h04d1e81_0
22 | - wheel=0.41.2=py39h06a4308_0
23 | - xz=5.4.2=h5eee18b_0
24 | - zlib=1.2.13=h5eee18b_0
25 | - pip:
26 | - accelerate==0.24.1
27 | - aiofiles==23.2.1
28 | - aiohttp==3.8.6
29 | - aiosignal==1.3.1
30 | - altair==5.1.2
31 | - anyio==4.0.0
32 | - async-timeout==4.0.3
33 | - attrs==23.1.0
34 | - certifi==2023.7.22
35 | - charset-normalizer==3.3.2
36 | - click==8.1.7
37 | - cmake==3.27.7
38 | - contourpy==1.2.0
39 | - cycler==0.12.1
40 | - datasets==2.14.6
41 | - dill==0.3.7
42 | - docstring-parser==0.15
43 | - et-xmlfile==1.1.0
44 | - fastapi==0.95.1
45 | - ffmpy==0.3.1
46 | - filelock==3.13.1
47 | - fire==0.5.0
48 | - fonttools==4.44.0
49 | - frozenlist==1.4.0
50 | - fsspec==2023.10.0
51 | - gradio==3.50.2
52 | - gradio-client==0.6.1
53 | - h11==0.14.0
54 | - httpcore==1.0.1
55 | - httpx==0.25.1
56 | - huggingface-hub==0.18.0
57 | - idna==3.4
58 | - importlib-resources==6.1.1
59 | - jieba==0.42.1
60 | - jinja2==3.1.2
61 | - joblib==1.3.2
62 | - jsonschema==4.19.2
63 | - jsonschema-specifications==2023.7.1
64 | - kiwisolver==1.4.5
65 | - lit==17.0.4
66 | - markdown-it-py==3.0.0
67 | - markupsafe==2.1.3
68 | - matplotlib==3.8.1
69 | - mdurl==0.1.2
70 | - mpmath==1.3.0
71 | - multidict==6.0.4
72 | - multiprocess==0.70.15
73 | - networkx==3.2.1
74 | - nltk==3.8.1
75 | - numpy==1.26.1
76 | - nvidia-cublas-cu11==11.10.3.66
77 | - nvidia-cuda-cupti-cu11==11.7.101
78 | - nvidia-cuda-nvrtc-cu11==11.7.99
79 | - nvidia-cuda-runtime-cu11==11.7.99
80 | - nvidia-cudnn-cu11==8.5.0.96
81 | - nvidia-cufft-cu11==10.9.0.58
82 | - nvidia-curand-cu11==10.2.10.91
83 | - nvidia-cusolver-cu11==11.4.0.1
84 | - nvidia-cusparse-cu11==11.7.4.91
85 | - nvidia-nccl-cu11==2.14.3
86 | - nvidia-nvtx-cu11==11.7.91
87 | - openpyxl==3.0.10
88 | - orjson==3.9.10
89 | - pandas==1.5.3
90 | - peft==0.6.0
91 | - pillow==10.1.0
92 | - protobuf==4.25.0
93 | - pyarrow==14.0.0
94 | - pydantic==1.10.11
95 | - pydub==0.25.1
96 | - pynvml==11.5.0
97 | - pyparsing==3.1.1
98 | - python-multipart==0.0.6
99 | - pytz==2023.3.post1
100 | - pyyaml==6.0.1
101 | - referencing==0.30.2
102 | - regex==2023.10.3
103 | - requests==2.31.0
104 | - rich==13.6.0
105 | - rouge-chinese==1.0.3
106 | - rpds-py==0.12.0
107 | - safetensors==0.4.0
108 | - scikit-learn==1.2.2
109 | - scipy==1.10.1
110 | - semantic-version==2.10.0
111 | - sentencepiece==0.1.99
112 | - shtab==1.6.4
113 | - sniffio==1.3.0
114 | - sse-starlette==1.6.5
115 | - starlette==0.26.1
116 | - sympy==1.12
117 | - termcolor==2.3.0
118 | - threadpoolctl==3.2.0
119 | - tiktoken==0.5.1
120 | - tokenizers==0.13.3
121 | - toolz==0.12.0
122 | - torch==2.0.0
123 | - torchvision==0.15.1
124 | - tqdm==4.66.1
125 | - transformers==4.33.2
126 | - triton==2.0.0
127 | - trl==0.7.2
128 | - tyro==0.5.12
129 | - urllib3==2.0.7
130 | - uvicorn==0.24.0.post1
131 | - websockets==11.0.3
132 | - xlsxwriter==3.1.9
133 | - xxhash==3.4.1
134 | - yarl==1.9.2
--------------------------------------------------------------------------------
/img/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/img/1.jpg
--------------------------------------------------------------------------------
/img/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/img/2.jpg
--------------------------------------------------------------------------------
/img/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/img/3.jpg
--------------------------------------------------------------------------------
/img/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/img/4.jpg
--------------------------------------------------------------------------------
/img/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/img/5.jpg
--------------------------------------------------------------------------------
/llmtuner/__init__.py:
--------------------------------------------------------------------------------
1 | # Level: api, webui > chat > tuner > dsets > extras, hparams
2 |
3 | from llmtuner.api import create_app
4 | from llmtuner.chat import ChatModel
5 | from llmtuner.tuner import export_model, run_exp
6 | from llmtuner.webui import create_ui, create_web_demo
7 |
8 |
9 | __version__ = "0.2.0"
10 |
--------------------------------------------------------------------------------
/llmtuner/api/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.api.app import create_app
2 |
--------------------------------------------------------------------------------
/llmtuner/api/app.py:
--------------------------------------------------------------------------------
1 | import json
2 | import uvicorn
3 | from fastapi import FastAPI, HTTPException, status
4 | from fastapi.middleware.cors import CORSMiddleware
5 | from contextlib import asynccontextmanager
6 | from sse_starlette import EventSourceResponse
7 | from typing import List, Tuple
8 | from pydantic import BaseModel
9 |
10 | from llmtuner.extras.misc import torch_gc
11 | from llmtuner.chat import ChatModel
12 | from llmtuner.api.protocol import (
13 | Role,
14 | Finish,
15 | ModelCard,
16 | ModelList,
17 | ChatMessage,
18 | DeltaMessage,
19 | ChatCompletionRequest,
20 | ChatCompletionResponse,
21 | ChatCompletionStreamResponse,
22 | ChatCompletionResponseChoice,
23 | ChatCompletionResponseStreamChoice,
24 | ChatCompletionResponseUsage
25 | )
26 |
27 |
28 | @asynccontextmanager
29 | async def lifespan(app: FastAPI): # collects GPU memory
30 | yield
31 | torch_gc()
32 |
33 |
34 | def to_json(data: BaseModel) -> str:
35 | try:
36 | return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
37 | except:
38 | return data.json(exclude_unset=True, ensure_ascii=False)
39 |
40 |
41 | def create_app(chat_model: ChatModel) -> FastAPI:
42 | app = FastAPI(lifespan=lifespan)
43 |
44 | app.add_middleware(
45 | CORSMiddleware,
46 | allow_origins=["*"],
47 | allow_credentials=True,
48 | allow_methods=["*"],
49 | allow_headers=["*"],
50 | )
51 |
52 | @app.get("/v1/models", response_model=ModelList)
53 | async def list_models():
54 | model_card = ModelCard(id="gpt-3.5-turbo")
55 | return ModelList(data=[model_card])
56 |
57 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
58 | async def create_chat_completion(request: ChatCompletionRequest):
59 | if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
60 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
61 |
62 | query = request.messages[-1].content
63 | prev_messages = request.messages[:-1]
64 | if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
65 | system = prev_messages.pop(0).content
66 | else:
67 | system = None
68 |
69 | history = []
70 | if len(prev_messages) % 2 == 0:
71 | for i in range(0, len(prev_messages), 2):
72 | if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
73 | history.append([prev_messages[i].content, prev_messages[i+1].content])
74 | else:
75 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
76 |
77 | if request.stream:
78 | generate = predict(query, history, system, request)
79 | return EventSourceResponse(generate, media_type="text/event-stream")
80 |
81 | response, (prompt_length, response_length) = chat_model.chat(
82 | query, history, system,
83 | do_sample=request.do_sample,
84 | temperature=request.temperature,
85 | top_p=request.top_p,
86 | max_new_tokens=request.max_tokens,
87 | num_return_sequences=request.n
88 | )
89 |
90 | usage = ChatCompletionResponseUsage(
91 | prompt_tokens=prompt_length,
92 | completion_tokens=response_length,
93 | total_tokens=prompt_length+response_length
94 | )
95 |
96 | choices = [ChatCompletionResponseChoice(
97 | index=i,
98 | message=ChatMessage(role=Role.ASSISTANT, content=choice),
99 | finish_reason=Finish.STOP
100 | ) for i, choice in enumerate(response)]
101 |
102 | return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
103 |
104 | async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
105 | choice_data = ChatCompletionResponseStreamChoice(
106 | index=0,
107 | delta=DeltaMessage(role=Role.ASSISTANT),
108 | finish_reason=None
109 | )
110 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
111 | yield to_json(chunk)
112 |
113 | for new_text in chat_model.stream_chat(
114 | query, history, system,
115 | do_sample=request.do_sample,
116 | temperature=request.temperature,
117 | top_p=request.top_p,
118 | max_new_tokens=request.max_tokens
119 | ):
120 | if len(new_text) == 0:
121 | continue
122 |
123 | choice_data = ChatCompletionResponseStreamChoice(
124 | index=0,
125 | delta=DeltaMessage(content=new_text),
126 | finish_reason=None
127 | )
128 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
129 | yield to_json(chunk)
130 |
131 | choice_data = ChatCompletionResponseStreamChoice(
132 | index=0,
133 | delta=DeltaMessage(),
134 | finish_reason=Finish.STOP
135 | )
136 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
137 | yield to_json(chunk)
138 | yield "[DONE]"
139 |
140 | return app
141 |
142 |
143 | if __name__ == "__main__":
144 | chat_model = ChatModel()
145 | app = create_app(chat_model)
146 | uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
147 |
--------------------------------------------------------------------------------
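
The app above exposes an OpenAI-style /v1/chat/completions route. A minimal client sketch, assuming the server is running on localhost:8000 (the default in the __main__ block) and using the request/response schema from llmtuner/api/protocol.py:

```python
# Client-side sketch for the OpenAI-compatible endpoint defined above.
import requests

payload = {
    "model": "gpt-3.5-turbo",   # only echoed back; see list_models above
    "messages": [
        {"role": "system", "content": "你是一个病历结构化助手。"},
        {"role": "user", "content": "提取:肾上腺肿物大小。原文:CT检查示……"},
    ],
    "do_sample": False,
    "max_tokens": 512,
    "stream": False,
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload, timeout=600)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```
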
/llmtuner/api/protocol.py:
--------------------------------------------------------------------------------
1 | import time
2 | from enum import Enum
3 | from pydantic import BaseModel, Field
4 | from typing import List, Optional
5 |
6 |
7 | class Role(str, Enum):
8 | USER = "user"
9 | ASSISTANT = "assistant"
10 | SYSTEM = "system"
11 |
12 |
13 | class Finish(str, Enum):
14 | STOP = "stop"
15 | LENGTH = "length"
16 |
17 |
18 | class ModelCard(BaseModel):
19 | id: str
20 | object: Optional[str] = "model"
21 | created: Optional[int] = Field(default_factory=lambda: int(time.time()))
22 | owned_by: Optional[str] = "owner"
23 |
24 |
25 | class ModelList(BaseModel):
26 | object: Optional[str] = "list"
27 | data: Optional[List[ModelCard]] = []
28 |
29 |
30 | class ChatMessage(BaseModel):
31 | role: Role
32 | content: str
33 |
34 |
35 | class DeltaMessage(BaseModel):
36 | role: Optional[Role] = None
37 | content: Optional[str] = None
38 |
39 |
40 | class ChatCompletionRequest(BaseModel):
41 | model: str
42 | messages: List[ChatMessage]
43 | do_sample: Optional[bool] = True
44 | temperature: Optional[float] = None
45 | top_p: Optional[float] = None
46 | n: Optional[int] = 1
47 | max_tokens: Optional[int] = None
48 | stream: Optional[bool] = False
49 |
50 |
51 | class ChatCompletionResponseChoice(BaseModel):
52 | index: int
53 | message: ChatMessage
54 | finish_reason: Finish
55 |
56 |
57 | class ChatCompletionResponseStreamChoice(BaseModel):
58 | index: int
59 | delta: DeltaMessage
60 | finish_reason: Optional[Finish] = None
61 |
62 |
63 | class ChatCompletionResponseUsage(BaseModel):
64 | prompt_tokens: int
65 | completion_tokens: int
66 | total_tokens: int
67 |
68 |
69 | class ChatCompletionResponse(BaseModel):
70 | id: Optional[str] = "chatcmpl-default"
71 | object: Optional[str] = "chat.completion"
72 | created: Optional[int] = Field(default_factory=lambda: int(time.time()))
73 | model: str
74 | choices: List[ChatCompletionResponseChoice]
75 | usage: ChatCompletionResponseUsage
76 |
77 |
78 | class ChatCompletionStreamResponse(BaseModel):
79 | id: Optional[str] = "chatcmpl-default"
80 | object: Optional[str] = "chat.completion.chunk"
81 | created: Optional[int] = Field(default_factory=lambda: int(time.time()))
82 | model: str
83 | choices: List[ChatCompletionResponseStreamChoice]
84 |
--------------------------------------------------------------------------------
/llmtuner/chat/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.chat.stream_chat import ChatModel
2 |
--------------------------------------------------------------------------------
/llmtuner/chat/stream_chat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Any, Dict, Generator, List, Optional, Tuple
3 | from threading import Thread
4 | from transformers import GenerationConfig, TextIteratorStreamer
5 |
6 | from llmtuner.extras.misc import dispatch_model, get_logits_processor
7 | from llmtuner.extras.template import get_template_and_fix_tokenizer
8 | from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
9 | import re
10 |
11 | class ChatModel:
12 |
13 | def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
14 | model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
15 | self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
16 | self.tokenizer.padding_side = "left"
17 | self.model = dispatch_model(self.model)
18 | self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
19 | self.system_prompt = data_args.system_prompt
20 |
21 | def process_args(
22 | self,
23 | query: str,
24 | history: Optional[List[Tuple[str, str]]] = None,
25 | system: Optional[str] = None,
26 | **input_kwargs
27 | ) -> Tuple[Dict[str, Any], int]:
28 | system = system or self.system_prompt
29 | prompt, _ = self.template.encode_oneturn(
30 | tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
31 | )
32 | prompt_length = len(prompt)
33 | input_ids = torch.tensor([prompt], device=self.model.device)
34 |
35 | do_sample = input_kwargs.pop("do_sample", None)
36 | temperature = input_kwargs.pop("temperature", None)
37 | top_p = input_kwargs.pop("top_p", None)
38 | top_k = input_kwargs.pop("top_k", None)
39 | num_return_sequences = input_kwargs.pop("num_return_sequences", None)
40 | repetition_penalty = input_kwargs.pop("repetition_penalty", None)
41 | max_length = input_kwargs.pop("max_length", None)
42 | max_new_tokens = input_kwargs.pop("max_new_tokens", None)
43 |
44 | generating_args = self.generating_args.to_dict()
45 | generating_args.update(dict(
46 | do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
47 | temperature=temperature or generating_args["temperature"],
48 | top_p=top_p or generating_args["top_p"],
49 | top_k=top_k or generating_args["top_k"],
50 | num_return_sequences=num_return_sequences or 1,
51 | repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
52 | eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
53 | pad_token_id=self.tokenizer.pad_token_id
54 | ))
55 |
56 | if isinstance(num_return_sequences, int) and num_return_sequences > 1:
57 | generating_args["do_sample"] = True
58 |
59 | if max_length:
60 | generating_args.pop("max_new_tokens", None)
61 | generating_args["max_length"] = max_length
62 |
63 | if max_new_tokens:
64 | generating_args.pop("max_length", None)
65 | generating_args["max_new_tokens"] = max_new_tokens
66 |
67 | gen_kwargs = dict(
68 | inputs=input_ids,
69 | generation_config=GenerationConfig(**generating_args),
70 | logits_processor=get_logits_processor()
71 | )
72 |
73 | return gen_kwargs, prompt_length
74 |
75 | @torch.inference_mode()
76 | def chat(
77 | self,
78 | query: str,
79 | history: Optional[List[Tuple[str, str]]] = None,
80 | system: Optional[str] = None,
81 | **input_kwargs
82 |     ) -> Tuple[str, Tuple[int, int]]:
83 | gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
84 | generate_output = self.model.generate(**gen_kwargs)
85 | outputs = generate_output.tolist()[0][prompt_length:]
86 | response = self.tokenizer.decode(outputs, skip_special_tokens=True)
87 | response_length = len(outputs)
88 | response=re.sub(r'Helper: ?', '', response)
89 | return response, (prompt_length, response_length)
90 |
91 | @torch.inference_mode()
92 | def batch_chat(
93 | self,
94 | query: List[str],
95 | history: Optional[List[Tuple[str, str]]] = None,
96 | system: Optional[str] = None,
97 | **input_kwargs
98 | ) -> Tuple[List[str], Tuple[int, int]]:
99 | gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
100 | generate_output = self.model.generate(**gen_kwargs)
101 | response_ids = generate_output[:, prompt_length:]
102 | response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
103 | response_length = 0
104 | for i in range(len(response_ids)):
105 | eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
106 | response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
107 | return response, (prompt_length, response_length)
108 |
109 | @torch.inference_mode()
110 | def stream_chat(
111 | self,
112 | query: str,
113 | history: Optional[List[Tuple[str, str]]] = None,
114 | system: Optional[str] = None,
115 | **input_kwargs
116 | ) -> Generator[str, None, None]:
117 | gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
118 | streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
119 | gen_kwargs["streamer"] = streamer
120 |
121 | thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
122 | thread.start()
123 |
124 | yield from streamer
125 |
126 |
--------------------------------------------------------------------------------
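
A rough usage sketch for ChatModel follows. The argument names passed to get_infer_args (model_name_or_path, template, checkpoint_dir) are the usual llmtuner inference arguments and may differ slightly in this vendored copy, so treat them as assumptions; the paths are placeholders.

```python
# Sketch: blocking and streaming inference with ChatModel.
from llmtuner import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="/path/to/your/model",   # placeholder
    template="default",
    # checkpoint_dir="finetune/pulse/checkpoints/my_lora",  # optional LoRA adapter
))

# chat() returns the decoded response plus (prompt_length, response_length)
response, (prompt_len, resp_len) = chat_model.chat("你好", history=[], do_sample=False)
print(response)

# stream_chat() yields incremental text chunks
for chunk in chat_model.stream_chat("你好", history=[]):
    print(chunk, end="", flush=True)
```
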
/llmtuner/dsets/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.dsets.loader import get_dataset
2 | from llmtuner.dsets.preprocess import preprocess_dataset
3 | from llmtuner.dsets.utils import split_dataset
4 |
--------------------------------------------------------------------------------
/llmtuner/dsets/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import TYPE_CHECKING, List, Union
3 |
4 | from datasets import concatenate_datasets, interleave_datasets, load_dataset
5 |
6 | from llmtuner.dsets.utils import checksum, EXT2TYPE
7 | from llmtuner.extras.logging import get_logger
8 |
9 | if TYPE_CHECKING:
10 | from datasets import Dataset, IterableDataset
11 | from llmtuner.hparams import ModelArguments, DataArguments
12 |
13 |
14 | logger = get_logger(__name__)
15 |
16 |
17 | def get_dataset(
18 | model_args: "ModelArguments",
19 | data_args: "DataArguments"
20 | ) -> Union["Dataset", "IterableDataset"]:
21 | max_samples = data_args.max_samples
22 | all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
23 |
24 | for dataset_attr in data_args.dataset_list:
25 | logger.info("Loading dataset {}...".format(dataset_attr))
26 |
27 | if dataset_attr.load_from == "hf_hub":
28 | data_path = dataset_attr.dataset_name
29 | data_files = None
30 | elif dataset_attr.load_from == "script":
31 | data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
32 | data_files = None
33 | elif dataset_attr.load_from == "file":
34 | data_path = None
35 | data_files: List[str] = []
36 |
37 | if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
38 | for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
39 | data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
40 | if data_path is None:
41 | data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
42 | else:
43 | assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
44 | elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
45 | data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
46 | data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
47 | else:
48 | raise ValueError("File not found.")
49 |
50 | assert data_path, "File extension must be txt, csv, json or jsonl."
51 | checksum(data_files, dataset_attr.dataset_sha1)
52 | else:
53 | raise NotImplementedError
54 |
55 | dataset = load_dataset(
56 | data_path,
57 | data_files=data_files,
58 | split=data_args.split,
59 | cache_dir=model_args.cache_dir,
60 | streaming=data_args.streaming,
61 | use_auth_token=True if model_args.use_auth_token else None
62 | )
63 |
64 | if max_samples is not None:
65 | max_samples_temp = min(len(dataset), max_samples)
66 | dataset = dataset.select(range(max_samples_temp))
67 |
68 | # TODO: adapt to the sharegpt format
69 |
70 | for column_name in ["prompt", "query", "response", "history"]: # align datasets
71 | if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
72 | dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
73 |
74 | if dataset_attr.system_prompt: # add system prompt
75 | system_prompt = dataset_attr.system_prompt
76 | if data_args.streaming:
77 | dataset = dataset.map(lambda _: {"system": system_prompt})
78 | else:
79 | dataset = dataset.add_column("system", [system_prompt] * len(dataset))
80 |
81 | all_datasets.append(dataset)
82 |
83 | if len(data_args.dataset_list) == 1:
84 | return all_datasets[0]
85 | elif data_args.mix_strategy == "concat":
86 | if data_args.streaming:
87 | logger.warning("The samples between different datasets will not be mixed in streaming mode.")
88 | return concatenate_datasets(all_datasets)
89 | elif data_args.mix_strategy.startswith("interleave"):
90 | if not data_args.streaming:
91 | logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
92 | return interleave_datasets(
93 | datasets=all_datasets,
94 | probabilities=data_args.interleave_probs,
95 | seed=data_args.seed,
96 | stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
97 | )
98 | else:
99 | raise ValueError("Unknown mixing strategy.")
100 |
--------------------------------------------------------------------------------
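
For the "file" branch above, the loader ultimately expects records that (after column renaming) carry prompt/query/response and an optional history field. A hypothetical local JSON data file could be produced as sketched below; how the dataset gets registered in data_args.dataset_list is outside this excerpt.

```python
# Sketch: write a minimal local dataset the loader can ingest as a JSON file.
import json

records = [
    {
        "prompt": "请从原文中提取字段:肾上腺肿物大小",
        "query": "CT检查示左肾上腺区见大小约5.5 cm×5.7 cm不均匀低密度肿块……",
        "response": "约5.5 cm×5.7 cm",
        "history": [],
    }
]

with open("data/structure_demo.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)
```
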
/llmtuner/dsets/utils.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | from typing import TYPE_CHECKING, Dict, List, Optional, Union
3 |
4 | from llmtuner.extras.logging import get_logger
5 |
6 | if TYPE_CHECKING:
7 | from datasets import Dataset, IterableDataset
8 | from transformers import TrainingArguments
9 | from llmtuner.hparams import DataArguments
10 |
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | EXT2TYPE = {
16 | "csv": "csv",
17 | "json": "json",
18 | "jsonl": "json",
19 | "txt": "text"
20 | }
21 |
22 |
23 | def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
24 | if file_sha1 is None:
25 | logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
26 | return
27 |
28 | if len(data_files) != 1:
29 | logger.warning("Checksum failed: too many files.")
30 | return
31 |
32 | with open(data_files[0], "rb") as f:
33 | sha1 = hashlib.sha1(f.read()).hexdigest()
34 | if sha1 != file_sha1:
35 | logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
36 |
37 |
38 | def split_dataset(
39 | dataset: Union["Dataset", "IterableDataset"],
40 | data_args: "DataArguments",
41 | training_args: "TrainingArguments"
42 | ) -> Dict[str, "Dataset"]:
43 | if training_args.do_train:
44 | if data_args.val_size > 1e-6: # Split the dataset
45 | if data_args.streaming:
46 | val_set = dataset.take(int(data_args.val_size))
47 | train_set = dataset.skip(int(data_args.val_size))
48 |                 train_set = train_set.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
49 | return {"train_dataset": train_set, "eval_dataset": val_set}
50 | else:
51 | val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
52 | dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
53 | return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
54 | else:
55 | if data_args.streaming:
56 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
57 | return {"train_dataset": dataset}
58 | else: # do_eval or do_predict
59 | return {"eval_dataset": dataset}
60 |
--------------------------------------------------------------------------------
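
To make the split semantics concrete (val_size as a fraction vs. an absolute count), here is a tiny non-streaming illustration using stand-in argument objects; the real DataArguments/TrainingArguments dataclasses live elsewhere in the package.

```python
# Sketch: exercising split_dataset's non-streaming branch with stub arguments.
from types import SimpleNamespace

from datasets import Dataset
from llmtuner.dsets import split_dataset

dataset = Dataset.from_dict({
    "prompt": [f"q{i}" for i in range(10)],
    "response": [f"a{i}" for i in range(10)],
})

data_args = SimpleNamespace(val_size=0.2, streaming=False, buffer_size=1000)
training_args = SimpleNamespace(do_train=True, seed=42)

splits = split_dataset(dataset, data_args, training_args)
print(len(splits["train_dataset"]), len(splits["eval_dataset"]))  # -> 8 2
```
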
/llmtuner/extras/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/llmtuner/extras/__init__.py
--------------------------------------------------------------------------------
/llmtuner/extras/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import time
4 | from typing import TYPE_CHECKING
5 | from datetime import timedelta
6 |
7 | from transformers import TrainerCallback
8 | from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
9 |
10 | from llmtuner.extras.constants import LOG_FILE_NAME
11 | from llmtuner.extras.logging import get_logger
12 |
13 | if TYPE_CHECKING:
14 | from transformers import TrainingArguments, TrainerState, TrainerControl
15 |
16 |
17 | logger = get_logger(__name__)
18 |
19 |
20 | class SavePeftModelCallback(TrainerCallback):
21 |
22 | def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
23 | r"""
24 | Event called after a checkpoint save.
25 | """
26 | if args.should_save:
27 | output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
28 | model = kwargs.pop("model")
29 | if getattr(model, "is_peft_model", False):
30 | getattr(model, "pretrained_model").save_pretrained(output_dir)
31 |
32 | def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
33 | r"""
34 | Event called at the end of training.
35 | """
36 | if args.should_save:
37 | model = kwargs.pop("model")
38 | if getattr(model, "is_peft_model", False):
39 | getattr(model, "pretrained_model").save_pretrained(args.output_dir)
40 |
41 |
42 | class LogCallback(TrainerCallback):
43 |
44 | def __init__(self, runner=None):
45 | self.runner = runner
46 | self.in_training = False
47 | self.start_time = time.time()
48 | self.cur_steps = 0
49 | self.max_steps = 0
50 | self.elapsed_time = ""
51 | self.remaining_time = ""
52 |
53 | def timing(self):
54 | cur_time = time.time()
55 | elapsed_time = cur_time - self.start_time
56 | avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
57 | remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
58 | self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
59 | self.remaining_time = str(timedelta(seconds=int(remaining_time)))
60 |
61 | def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
62 | r"""
63 | Event called at the beginning of training.
64 | """
65 | if state.is_local_process_zero:
66 | self.in_training = True
67 | self.start_time = time.time()
68 | self.max_steps = state.max_steps
69 | if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
70 | logger.warning("Previous log file in this folder will be deleted.")
71 | os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
72 |
73 | def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
74 | r"""
75 | Event called at the end of training.
76 | """
77 | if state.is_local_process_zero:
78 | self.in_training = False
79 | self.cur_steps = 0
80 | self.max_steps = 0
81 |
82 | def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
83 | r"""
84 |         Event called at the end of a substep during gradient accumulation.
85 | """
86 | if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
87 | control.should_epoch_stop = True
88 | control.should_training_stop = True
89 |
90 | def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
91 | r"""
92 | Event called at the end of a training step.
93 | """
94 | if state.is_local_process_zero:
95 | self.cur_steps = state.global_step
96 | self.timing()
97 | if self.runner is not None and self.runner.aborted:
98 | control.should_epoch_stop = True
99 | control.should_training_stop = True
100 |
101 | def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
102 | r"""
103 | Event called after an evaluation phase.
104 | """
105 | if state.is_local_process_zero and not self.in_training:
106 | self.cur_steps = 0
107 | self.max_steps = 0
108 |
109 | def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
110 | r"""
111 | Event called after a successful prediction.
112 | """
113 | if state.is_local_process_zero and not self.in_training:
114 | self.cur_steps = 0
115 | self.max_steps = 0
116 |
117 | def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
118 | r"""
119 | Event called after logging the last logs.
120 | """
121 | if not state.is_local_process_zero:
122 | return
123 |
124 | logs = dict(
125 | current_steps=self.cur_steps,
126 | total_steps=self.max_steps,
127 | loss=state.log_history[-1].get("loss", None),
128 | eval_loss=state.log_history[-1].get("eval_loss", None),
129 | predict_loss=state.log_history[-1].get("predict_loss", None),
130 | reward=state.log_history[-1].get("reward", None),
131 | learning_rate=state.log_history[-1].get("learning_rate", None),
132 | epoch=state.log_history[-1].get("epoch", None),
133 | percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
134 | elapsed_time=self.elapsed_time,
135 | remaining_time=self.remaining_time
136 | )
137 | if self.runner is not None:
138 | logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
139 | logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
140 | ))
141 |
142 | os.makedirs(args.output_dir, exist_ok=True)
143 | with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
144 | f.write(json.dumps(logs) + "\n")
145 |
146 | def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
147 | r"""
148 | Event called after a prediction step.
149 | """
150 | eval_dataloader = kwargs.pop("eval_dataloader", None)
151 | if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
152 | if self.max_steps == 0:
153 | self.max_steps = len(eval_dataloader)
154 | self.cur_steps += 1
155 | self.timing()
156 |
--------------------------------------------------------------------------------
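
LogCallback.on_log appends one JSON object per logging step to trainer_log.jsonl in the run's output_dir, which is what makes external progress monitoring possible. A small reader sketch (the checkpoint path below is hypothetical):

```python
# Sketch: read the trainer_log.jsonl written by LogCallback.on_log above.
import json
import os

def read_trainer_log(output_dir: str):
    entries = []
    with open(os.path.join(output_dir, "trainer_log.jsonl"), encoding="utf-8") as f:
        for line in f:
            entries.append(json.loads(line))
    return entries

for entry in read_trainer_log("finetune/pulse/checkpoints/my_lora"):
    print(f"step {entry['current_steps']}/{entry['total_steps']} "
          f"loss={entry['loss']} ({entry['percentage']}%) ETA {entry['remaining_time']}")
```
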
/llmtuner/extras/constants.py:
--------------------------------------------------------------------------------
1 | IGNORE_INDEX = -100
2 |
3 | LOG_FILE_NAME = "trainer_log.jsonl"
4 |
5 | LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"]
6 |
7 | METHODS = ["full", "freeze", "lora"]
8 |
9 | TRAINING_STAGES = {
10 | "Supervised Fine-Tuning": "sft",
11 | "Reward Modeling": "rm",
12 | "PPO": "ppo",
13 | "DPO": "dpo",
14 | "Pre-Training": "pt"
15 | }
16 |
17 | SUPPORTED_MODELS = {
18 | "LLaMA-7B": "huggyllama/llama-7b",
19 | "LLaMA-13B": "huggyllama/llama-13b",
20 | "LLaMA-30B": "huggyllama/llama-30b",
21 | "LLaMA-65B": "huggyllama/llama-65b",
22 | "LLaMA2-7B": "meta-llama/Llama-2-7b-hf",
23 | "LLaMA2-13B": "meta-llama/Llama-2-13b-hf",
24 | "LLaMA2-70B": "meta-llama/Llama-2-70b-hf",
25 | "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
26 | "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
27 | "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
28 | "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
29 | "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
30 | "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
31 | "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
32 | "BLOOM-560M": "bigscience/bloom-560m",
33 | "BLOOM-3B": "bigscience/bloom-3b",
34 | "BLOOM-7B1": "bigscience/bloom-7b1",
35 | "BLOOMZ-560M": "bigscience/bloomz-560m",
36 | "BLOOMZ-3B": "bigscience/bloomz-3b",
37 | "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
38 | "Falcon-7B": "tiiuae/falcon-7b",
39 | "Falcon-40B": "tiiuae/falcon-40b",
40 | "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
41 | "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
42 | "Baichuan-7B": "baichuan-inc/Baichuan-7B",
43 | "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
44 | "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
45 | "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
46 | "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
47 | "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
48 | "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
49 | "InternLM-7B": "internlm/internlm-7b",
50 | "InternLM-20B": "internlm/internlm-20b",
51 | "InternLM-7B-Chat": "internlm/internlm-chat-7b",
52 | "InternLM-20B-Chat": "internlm/internlm-chat-20b",
53 | "Qwen-7B": "Qwen/Qwen-7B",
54 | "Qwen-14B": "Qwen/Qwen-14B",
55 | "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
56 | "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
57 | "XVERSE-13B": "xverse/XVERSE-13B",
58 | "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
59 | "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b",
60 | "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
61 | "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b",
62 | "Phi1.5-1.3B": "microsoft/phi-1_5"
63 | }
64 |
65 | DEFAULT_MODULE = {
66 | "LLaMA": "q_proj,v_proj",
67 | "LLaMA2": "q_proj,v_proj",
68 | "ChineseLLaMA2": "q_proj,v_proj",
69 | "BLOOM": "query_key_value",
70 | "BLOOMZ": "query_key_value",
71 | "Falcon": "query_key_value",
72 | "Baichuan": "W_pack",
73 | "Baichuan2": "W_pack",
74 | "InternLM": "q_proj,v_proj",
75 | "Qwen": "c_attn",
76 | "XVERSE": "q_proj,v_proj",
77 | "ChatGLM2": "query_key_value",
78 | "ChatGLM3": "query_key_value",
79 | "Phi1.5": "Wqkv"
80 | }
81 |
82 | DEFAULT_TEMPLATE = {
83 | "LLaMA2": "llama2",
84 | "ChineseLLaMA2": "llama2_zh",
85 | "Baichuan": "baichuan",
86 | "Baichuan2": "baichuan2",
87 | "InternLM": "intern",
88 | "Qwen": "chatml",
89 | "XVERSE": "xverse",
90 | "ChatGLM2": "chatglm2",
91 | "ChatGLM3": "chatglm3"
92 | }
93 |
--------------------------------------------------------------------------------
/llmtuner/extras/logging.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 |
4 |
5 | class LoggerHandler(logging.Handler):
6 |
7 | def __init__(self):
8 | super().__init__()
9 | self.log = ""
10 |
11 | def reset(self):
12 | self.log = ""
13 |
14 | def emit(self, record):
15 | if record.name == "httpx":
16 | return
17 | log_entry = self.format(record)
18 | self.log += log_entry
19 | self.log += "\n\n"
20 |
21 |
22 | def reset_logging():
23 | r"""
24 | Removes basic config of root logger
25 | """
26 | root = logging.getLogger()
27 | list(map(root.removeHandler, root.handlers))
28 | list(map(root.removeFilter, root.filters))
29 |
30 |
31 | def get_logger(name: str) -> logging.Logger:
32 | formatter = logging.Formatter(
33 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
34 | datefmt="%m/%d/%Y %H:%M:%S"
35 | )
36 | handler = logging.StreamHandler(sys.stdout)
37 | handler.setFormatter(formatter)
38 |
39 | logger = logging.getLogger(name)
40 | logger.setLevel(logging.INFO)
41 | logger.addHandler(handler)
42 |
43 | return logger
44 |
--------------------------------------------------------------------------------
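A brief usage sketch, assuming `llmtuner` is importable: `get_logger` returns an INFO-level logger writing to stdout, while `LoggerHandler` accumulates formatted records (skipping `httpx`) in its `.log` string so a UI can poll training output.

```python
# Usage sketch for llmtuner.extras.logging (assumes the package is on PYTHONPATH).
import logging
from llmtuner.extras.logging import LoggerHandler, get_logger

handler = LoggerHandler()
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
logging.getLogger().addHandler(handler)   # capture records propagated to the root logger

logger = get_logger(__name__)
logger.info("training started")

print(handler.log)   # accumulated text, e.g. for display in a web UI
handler.reset()      # clear the buffer between runs
```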
/llmtuner/extras/misc.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import torch
3 | from typing import TYPE_CHECKING, Tuple
4 | from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
5 |
6 | try:
7 | from transformers.utils import (
8 | is_torch_bf16_cpu_available,
9 | is_torch_bf16_gpu_available,
10 | is_torch_cuda_available,
11 | is_torch_npu_available
12 | )
13 | _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
14 | _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
15 | except ImportError:
16 | _is_fp16_available = torch.cuda.is_available()
17 | _is_bf16_available = torch.cuda.is_bf16_supported()
18 |
19 | if TYPE_CHECKING:
20 | from transformers.modeling_utils import PreTrainedModel
21 |
22 |
23 | class AverageMeter:
24 | r"""
25 | Computes and stores the average and current value.
26 | """
27 | def __init__(self):
28 | self.reset()
29 |
30 | def reset(self):
31 | self.val = 0
32 | self.avg = 0
33 | self.sum = 0
34 | self.count = 0
35 |
36 | def update(self, val, n=1):
37 | self.val = val
38 | self.sum += val * n
39 | self.count += n
40 | self.avg = self.sum / self.count
41 |
42 |
43 | def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
44 | r"""
45 | Returns the number of trainable parameters and number of all parameters in the model.
46 | """
47 | trainable_params, all_param = 0, 0
48 | for param in model.parameters():
49 | num_params = param.numel()
50 | # if using DS Zero 3 and the weights are initialized empty
51 | if num_params == 0 and hasattr(param, "ds_numel"):
52 | num_params = param.ds_numel
53 |
54 | # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
55 | if param.__class__.__name__ == "Params4bit":
56 | num_params = num_params * 2
57 |
58 | all_param += num_params
59 | if param.requires_grad:
60 | trainable_params += num_params
61 |
62 | return trainable_params, all_param
63 |
64 |
65 | def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
66 | r"""
67 | Infers the optimal dtype according to the model_dtype and device compatibility.
68 | """
69 | if _is_bf16_available and model_dtype == torch.bfloat16:
70 | return torch.bfloat16
71 | elif _is_fp16_available:
72 | return torch.float16
73 | else:
74 | return torch.float32
75 |
76 |
77 | def get_logits_processor() -> LogitsProcessorList:
78 | r"""
79 | Gets logits processor that removes NaN and Inf logits.
80 | """
81 | logits_processor = LogitsProcessorList()
82 | logits_processor.append(InfNanRemoveLogitsProcessor())
83 | return logits_processor
84 |
85 |
86 | def torch_gc() -> None:
87 | r"""
88 | Collects GPU memory.
89 | """
90 | gc.collect()
91 | if torch.cuda.is_available():
92 | torch.cuda.empty_cache()
93 | torch.cuda.ipc_collect()
94 |
95 |
96 | def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
97 | r"""
98 | Dispatches a pre-trained model to GPUs with balanced memory.
99 | Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
100 | """
101 | if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
102 | return model
103 |
104 | if torch.cuda.device_count() > 1:
105 | from accelerate import dispatch_model
106 | from accelerate.utils import infer_auto_device_map, get_balanced_memory
107 |
108 | if model._no_split_modules is None:
109 | raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
110 |
111 | kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
112 | max_memory = get_balanced_memory(model, **kwargs)
113 | # Make sure tied weights are tied before creating the device map.
114 | model.tie_weights()
115 | device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
116 | return dispatch_model(model, device_map)
117 | else:
118 | return model.cuda()
119 |
--------------------------------------------------------------------------------
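A short usage sketch for the helpers above, assuming `torch` and `llmtuner` are installed; the tiny `nn.Linear` stands in for a real pre-trained model.

```python
# Usage sketch for llmtuner.extras.misc (the Linear layer is a stand-in model).
import torch
from llmtuner.extras.misc import AverageMeter, count_parameters, infer_optim_dtype

model = torch.nn.Linear(16, 4)
model.bias.requires_grad_(False)                   # freeze one parameter group
trainable, total = count_parameters(model)
print(trainable, total)                            # 64 68

meter = AverageMeter()                             # running average, e.g. of per-step loss
for loss in (0.9, 0.7, 0.5):
    meter.update(loss, n=1)
print(round(meter.avg, 2))                         # 0.7

print(infer_optim_dtype(torch.bfloat16))           # bf16 if supported, else fp16/fp32
```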
/llmtuner/extras/patches/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/llmtuner/extras/patches/__init__.py
--------------------------------------------------------------------------------
/llmtuner/extras/patches/llama_patch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from typing import Optional, Tuple
5 | from transformers.utils import logging
6 | from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
7 |
8 | try:
9 | from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
10 | from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
11 | except ImportError:
12 | print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
13 |
14 |
15 | logger = logging.get_logger(__name__)
16 |
17 |
18 | # Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
19 | class LlamaShiftShortAttention(LlamaAttention):
20 |
21 | def forward(
22 | self,
23 | hidden_states: torch.Tensor,
24 | attention_mask: Optional[torch.Tensor] = None,
25 | position_ids: Optional[torch.LongTensor] = None,
26 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
27 | output_attentions: bool = False,
28 | use_cache: bool = False,
29 | **kwargs
30 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
31 | bsz, q_len, _ = hidden_states.size()
32 |
33 | query_states = self.q_proj(hidden_states)
34 | key_states = self.k_proj(hidden_states)
35 | value_states = self.v_proj(hidden_states)
36 |
37 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
38 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
39 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
40 |
41 | kv_seq_len = key_states.shape[-2]
42 | if past_key_value is not None:
43 | kv_seq_len += past_key_value[0].shape[-2]
44 |
45 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
46 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
47 |
48 | if past_key_value is not None: # reuse k, v, self_attention
49 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
50 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
51 |
52 | past_key_value = (key_states, value_states) if use_cache else None
53 |
54 | if getattr(self, "num_key_value_groups", None):
55 | key_states = repeat_kv(key_states, self.num_key_value_groups)
56 | value_states = repeat_kv(value_states, self.num_key_value_groups)
57 |
58 | if getattr(self.config, "group_size_ratio", None) and self.training: # shift
59 | groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
60 | assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
61 | num_groups = q_len // groupsz
62 | def shift(state: torch.Tensor) -> torch.Tensor:
63 | state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
64 | state = torch.cat((
65 | state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
66 | ), dim=2)
67 | return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
68 |
69 | query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
70 | if attention_mask is not None:
71 | attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
72 |
73 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
74 |
75 | if attention_mask is not None:
76 | attn_weights = attn_weights + attention_mask
77 |
78 | # upcast attention to fp32
79 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
80 | attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
81 | attn_output = attn_output.transpose(1, 2).contiguous()
82 |
83 | if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
84 | attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
85 | attn_output = torch.cat((
86 | attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
87 | ), dim=2)
88 |
89 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
90 | attn_output = self.o_proj(attn_output)
91 |
92 | if not output_attentions:
93 | attn_weights = None
94 |
95 | return attn_output, attn_weights, past_key_value
96 |
97 |
98 | class LlamaFlashAttention2(LlamaAttention):
99 |
100 | def forward(
101 | self,
102 | hidden_states: torch.Tensor,
103 | attention_mask: Optional[torch.Tensor] = None,
104 | position_ids: Optional[torch.LongTensor] = None,
105 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
106 | output_attentions: bool = False,
107 | use_cache: bool = False,
108 | **kwargs
109 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
110 | # LlamaFlashAttention2 attention does not support output_attentions
111 | output_attentions = False
112 |
113 | bsz, q_len, _ = hidden_states.size()
114 |
115 | query_states = self.q_proj(hidden_states)
116 | key_states = self.k_proj(hidden_states)
117 | value_states = self.v_proj(hidden_states)
118 |
119 | # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
120 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
121 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
122 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
123 |
124 | kv_seq_len = key_states.shape[-2]
125 | if past_key_value is not None:
126 | kv_seq_len += past_key_value[0].shape[-2]
127 |
128 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
129 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
130 |
131 | if past_key_value is not None: # reuse k, v, self_attention
132 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
133 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
134 |
135 | past_key_value = (key_states, value_states) if use_cache else None
136 |
137 | # cast to half precision
138 | input_dtype = query_states.dtype
139 | if input_dtype == torch.float32:
140 | logger.warning_once("The input hidden states seem to have been silently cast to float32.")
141 | query_states = query_states.to(self.config.torch_dtype)
142 | key_states = key_states.to(self.config.torch_dtype)
143 | value_states = value_states.to(self.config.torch_dtype)
144 |
145 | if getattr(self, "num_key_value_groups", None):
146 | key_states = repeat_kv(key_states, self.num_key_value_groups)
147 | value_states = repeat_kv(value_states, self.num_key_value_groups)
148 |
149 | query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
150 | key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
151 | value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
152 |
153 | if getattr(self.config, "group_size_ratio", None) and self.training: # shift
154 | groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
155 | assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
156 | num_groups = q_len // groupsz
157 | def shift(state: torch.Tensor) -> torch.Tensor:
158 | state = torch.cat((
159 | state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
160 | ), dim=2)
161 | return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
162 |
163 | query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
164 | if attention_mask is not None:
165 | attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
166 |
167 | if attention_mask is not None:
168 | logger.warning_once("Padded sequences are less efficient in FlashAttention.")
169 | # -q_len: assumes left padding when q_len != kv_len
170 | unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
171 | unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
172 | unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
173 | attn_output_unpad = flash_attn_varlen_func(
174 | unpadded_q,
175 | unpadded_k,
176 | unpadded_v,
177 | cu_seqlens_q=cu_seqlens_q,
178 | cu_seqlens_k=cu_seqlens_k,
179 | max_seqlen_q=max_seqlen_q,
180 | max_seqlen_k=max_seqlen_k,
181 | dropout_p=0.0,
182 | softmax_scale=None,
183 | causal=True,
184 | )
185 | attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
186 | else:
187 | attn_output = flash_attn_func(
188 | query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
189 | )
190 |
191 | if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
192 | attn_output = attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
193 | attn_output = torch.cat((
194 | attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
195 | ), dim=2)
196 |
197 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
198 | attn_output = self.o_proj(attn_output)
199 |
200 | if not output_attentions:
201 | attn_weights = None
202 |
203 | return attn_output, attn_weights, past_key_value
204 |
205 |
206 | # Disable the transformation of the attention mask in LlamaModel as flash attention
207 | # takes a boolean padding_mask. Fills in the past kv length for use in forward.
208 | def _prepare_decoder_attention_mask(
209 | self,
210 | attention_mask: torch.Tensor,
211 | input_shape: torch.Tensor,
212 | inputs_embeds: torch.Tensor,
213 | past_key_values_length: int
214 | ) -> torch.Tensor:
215 | if attention_mask is not None and torch.all(attention_mask):
216 | return None # This uses the faster call when training with full samples
217 |
218 | return attention_mask
219 |
--------------------------------------------------------------------------------
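To make the shift-short-attention (S^2-Attn) grouping concrete: with `group_size_ratio = 0.25` the sequence is split into 4 groups, and the second half of the heads is rolled by half a group before the sequence is folded into `bsz * num_groups` independent blocks. A standalone sketch of the `shift()` closure on a dummy tensor, with toy sizes not tied to the classes above:

```python
# Standalone sketch of the S^2-Attn shift on a dummy (bsz, seq_len, n_heads, head_dim) tensor.
import torch

bsz, q_len, n_heads, head_dim = 2, 8, 4, 16
group_size_ratio = 0.25
groupsz = int(q_len * group_size_ratio)   # 2
num_groups = q_len // groupsz             # 4

state = torch.randn(bsz, q_len, n_heads, head_dim)

# Roll the second half of the heads by half a group along the sequence axis,
# then fold the sequence into num_groups blocks that attend independently.
shifted = torch.cat(
    (state[:, :, :n_heads // 2], state[:, :, n_heads // 2:].roll(-groupsz // 2, dims=1)),
    dim=2,
)
grouped = shifted.reshape(bsz * num_groups, groupsz, n_heads, head_dim)
print(grouped.shape)   # torch.Size([8, 2, 4, 16])
```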
/llmtuner/extras/ploting.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import json
4 | import matplotlib.pyplot as plt
5 | from typing import List, Optional
6 | from transformers.trainer import TRAINER_STATE_NAME
7 |
8 | from llmtuner.extras.logging import get_logger
9 |
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | def smooth(scalars: List[float]) -> List[float]:
15 | r"""
16 | EMA implementation according to TensorBoard.
17 | """
18 | last = scalars[0]
19 | smoothed = list()
20 | weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
21 | for next_val in scalars:
22 | smoothed_val = last * weight + (1 - weight) * next_val
23 | smoothed.append(smoothed_val)
24 | last = smoothed_val
25 | return smoothed
26 |
27 |
28 | def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
29 |
30 | with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
31 | data = json.load(f)
32 |
33 | for key in keys:
34 | steps, metrics = [], []
35 | for i in range(len(data["log_history"])):
36 | if key in data["log_history"][i]:
37 | steps.append(data["log_history"][i]["step"])
38 | metrics.append(data["log_history"][i][key])
39 |
40 | if len(metrics) == 0:
41 | logger.warning(f"No metric {key} to plot.")
42 | continue
43 |
44 | plt.figure()
45 | plt.plot(steps, metrics, alpha=0.4, label="original")
46 | plt.plot(steps, smooth(metrics), label="smoothed")
47 | plt.title("training {} of {}".format(key, save_dictionary))
48 | plt.xlabel("step")
49 | plt.ylabel(key)
50 | plt.legend()
51 | plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
52 | print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
53 |
--------------------------------------------------------------------------------
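A quick sketch of the smoothing behaviour, assuming `llmtuner` is importable: the EMA weight grows with the number of points via `1.8 * (sigmoid(0.05 * n) - 0.5)`, and `plot_loss` reads `trainer_state.json` from the output directory to produce `training_loss.png`.

```python
# Usage sketch for llmtuner.extras.ploting.smooth (assumes llmtuner is importable).
from llmtuner.extras.ploting import smooth

losses = [1.0, 0.8, 0.9, 0.6, 0.5, 0.4]
print([round(v, 3) for v in smooth(losses)])
# The first value is returned unchanged; later values are pulled toward the running EMA.
```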
/llmtuner/extras/save_and_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from transformers.trainer import WEIGHTS_NAME
4 |
5 | from llmtuner.extras.logging import get_logger
6 |
7 |
8 | logger = get_logger(__name__)
9 |
10 |
11 | def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
12 | vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
13 | if not os.path.exists(vhead_file):
14 | logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
15 | return False
16 | vhead_params = torch.load(vhead_file, map_location="cpu")
17 | model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
18 | model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
19 | model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
20 | model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
21 | return True
22 |
--------------------------------------------------------------------------------
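A hedged usage sketch; the checkpoint path is a hypothetical placeholder and is expected to contain a `pytorch_model.bin` with `v_head.summary.*` tensors, as saved by the reward-modeling stage.

```python
# Hedged sketch: attach reward value-head weights to a model (the path is hypothetical).
import torch
from llmtuner.extras.save_and_load import load_valuehead_params

model = torch.nn.Linear(4, 4)   # any nn.Module works; the buffers are registered on it
if load_valuehead_params(model, "saves/reward-model/checkpoint-1000"):
    print(sorted(name for name, _ in model.named_buffers()))
else:
    print("no value head weights found at that path")
```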
/llmtuner/hparams/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_args import DataArguments
2 | from .finetuning_args import FinetuningArguments
3 | from .generating_args import GeneratingArguments
4 | from .model_args import ModelArguments
5 |
--------------------------------------------------------------------------------
/llmtuner/hparams/data_args.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from typing import List, Literal, Optional
4 | from dataclasses import dataclass, field
5 |
6 |
7 | @dataclass
8 | class DatasetAttr:
9 |
10 | load_from: str
11 | dataset_name: Optional[str] = None
12 | dataset_sha1: Optional[str] = None
13 | system_prompt: Optional[str] = None
14 | ranking: Optional[bool] = False
15 | formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
16 |
17 | prompt: Optional[str] = "instruction"
18 | query: Optional[str] = "input"
19 | response: Optional[str] = "output"
20 | history: Optional[str] = None
21 |
22 | def __repr__(self) -> str:
23 | return self.dataset_name
24 |
25 |
26 | @dataclass
27 | class DataArguments:
28 | r"""
29 | Arguments pertaining to what data we are going to input into our model for training and evaluation.
30 | """
31 | template: Optional[str] = field(
32 | default=None,
33 | metadata={"help": "Which template to use for constructing prompts in training and inference."}
34 | )
35 | dataset: Optional[str] = field(
36 | default=None,
37 | metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
38 | )
39 | dataset_dir: Optional[str] = field(
40 | default="data",
41 | metadata={"help": "The name of the folder containing datasets."}
42 | )
43 | split: Optional[str] = field(
44 | default="train",
45 | metadata={"help": "Which dataset split to use for training and evaluation."}
46 | )
47 | cutoff_len: Optional[int] = field(
48 | default=1024,
49 | metadata={"help": "The maximum length of the model inputs after tokenization."}
50 | )
51 | train_on_prompt: Optional[bool] = field(
52 | default=False,
53 | metadata={"help": "Whether to disable the mask on the prompt or not."}
54 | )
55 | streaming: Optional[bool] = field(
56 | default=False,
57 | metadata={"help": "Enable dataset streaming."}
58 | )
59 | buffer_size: Optional[int] = field(
60 | default=16384,
61 | metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
62 | )
63 | mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
64 | default="concat",
65 | metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
66 | )
67 | interleave_probs: Optional[str] = field(
68 | default=None,
69 | metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
70 | )
71 | overwrite_cache: Optional[bool] = field(
72 | default=False,
73 | metadata={"help": "Overwrite the cached training and evaluation sets."}
74 | )
75 | preprocessing_num_workers: Optional[int] = field(
76 | default=None,
77 | metadata={"help": "The number of processes to use for the preprocessing."}
78 | )
79 | max_samples: Optional[int] = field(
80 | default=None,
81 | metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
82 | )
83 | eval_num_beams: Optional[int] = field(
84 | default=None,
85 | metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
86 | )
87 | ignore_pad_token_for_loss: Optional[bool] = field(
88 | default=True,
89 | metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
90 | )
91 | system_prompt: Optional[str] = field(
92 | default=None,
93 | metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
94 | )
95 | val_size: Optional[float] = field(
96 | default=0,
97 | metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
98 | )
99 | sft_packing: Optional[bool] = field(
100 | default=False,
101 | metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
102 | )
103 | cache_path: Optional[str] = field(
104 | default=None,
105 | metadata={"help": "Path to save or load the preprocessed datasets."}
106 | )
107 |
108 | def __post_init__(self):
109 | if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
110 | raise ValueError("Streaming mode should have an integer val size.")
111 |
112 | if self.streaming and self.max_samples is not None:
113 | raise ValueError("`max_samples` is incompatible with `streaming`.")
114 |
115 | if self.streaming and self.cache_path:
116 | raise ValueError("`cache_path` is incompatible with `streaming`.")
117 |
118 | def init_for_training(self, seed: int): # support mixing multiple datasets
119 | self.seed = seed
120 | dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
121 | try:
122 | with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
123 | dataset_info = json.load(f)
124 | except Exception:
125 | if self.dataset is not None:
126 | raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
127 | dataset_info = None
128 |
129 | prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
130 | prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
131 | assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
132 |
133 | if self.interleave_probs is not None:
134 | self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
135 |
136 | self.dataset_list: List[DatasetAttr] = []
137 | for i, name in enumerate(dataset_names):
138 | if name not in dataset_info:
139 | raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
140 |
141 | if "hf_hub_url" in dataset_info[name]:
142 | dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
143 | elif "script_url" in dataset_info[name]:
144 | dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
145 | else:
146 | dataset_attr = DatasetAttr(
147 | "file",
148 | dataset_name=dataset_info[name]["file_name"],
149 | dataset_sha1=dataset_info[name].get("file_sha1", None)
150 | )
151 |
152 | if "columns" in dataset_info[name]:
153 | dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
154 | dataset_attr.query = dataset_info[name]["columns"].get("query", None)
155 | dataset_attr.response = dataset_info[name]["columns"].get("response", None)
156 | dataset_attr.history = dataset_info[name]["columns"].get("history", None)
157 |
158 | dataset_attr.ranking = dataset_info[name].get("ranking", False)
159 | dataset_attr.system_prompt = prompt_list[i]
160 | self.dataset_list.append(dataset_attr)
161 |
--------------------------------------------------------------------------------
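`init_for_training` expects a `dataset_info.json` inside `dataset_dir` describing every dataset named in `dataset`. A minimal, hypothetical registration; the dataset name and file are illustrative and not shipped with the repository:

```python
# Hypothetical example: register a local file dataset, then build the arguments.
import json, os
from llmtuner.hparams import DataArguments

os.makedirs("data", exist_ok=True)
dataset_info = {
    "my_sft_data": {                                   # illustrative name
        "file_name": "my_sft_data.json",
        "columns": {"prompt": "instruction", "query": "input", "response": "output"}
    }
}
with open("data/dataset_info.json", "w", encoding="utf-8") as f:
    json.dump(dataset_info, f, ensure_ascii=False, indent=2)

data_args = DataArguments(template="chatml", dataset="my_sft_data", dataset_dir="data")
data_args.init_for_training(seed=42)
print(data_args.dataset_list)   # [my_sft_data.json] -> DatasetAttr loaded from "file"
```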
/llmtuner/hparams/finetuning_args.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Literal, Optional
3 | from dataclasses import asdict, dataclass, field
4 |
5 |
6 | @dataclass
7 | class FinetuningArguments:
8 | r"""
9 | Arguments pertaining to which techniques we are going to fine-tune with.
10 | """
11 | stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
12 | default="sft",
13 | metadata={"help": "Which stage will be performed in training."}
14 | )
15 | finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
16 | default="lora",
17 | metadata={"help": "Which fine-tuning method to use."}
18 | )
19 | num_layer_trainable: Optional[int] = field(
20 | default=3,
21 | metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
22 | )
23 | name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention", "attn", "mixer"]] = field(
24 | default="mlp",
25 | metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
26 | LLaMA choices: [\"mlp\", \"self_attn\"], \
27 | BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \
28 | Qwen choices: [\"mlp\", \"attn\"], \
29 | Phi-1.5 choices: [\"mlp\", \"mixer\"], \
30 | LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."}
31 | )
32 | lora_rank: Optional[int] = field(
33 | default=8,
34 | metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
35 | )
36 | lora_alpha: Optional[float] = field(
37 | default=32.0,
38 | metadata={"help": "The scale factor for LoRA fine-tuning (similar to the learning rate)."}
39 | )
40 | lora_dropout: Optional[float] = field(
41 | default=0.1,
42 | metadata={"help": "Dropout rate for the LoRA fine-tuning."}
43 | )
44 | lora_target: Optional[str] = field(
45 | default=None,
46 | metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
47 | LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
48 | BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
49 | Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
50 | Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
51 | Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
52 | LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
53 | )
54 | additional_target: Optional[str] = field(
55 | default=None,
56 | metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
57 | )
58 | resume_lora_training: Optional[bool] = field(
59 | default=True,
60 | metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
61 | )
62 | ppo_score_norm: Optional[bool] = field(
63 | default=False,
64 | metadata={"help": "Use score normalization in PPO training."}
65 | )
66 | ppo_logger: Optional[str] = field(
67 | default=None,
68 | metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
69 | )
70 | ppo_target: Optional[float] = field(
71 | default=6.0,
72 | metadata={"help": "Target KL value for adaptive KL control in PPO training."}
73 | )
74 | dpo_beta: Optional[float] = field(
75 | default=0.1,
76 | metadata={"help": "The beta parameter for the DPO loss."}
77 | )
78 | upcast_layernorm: Optional[bool] = field(
79 | default=False,
80 | metadata={"help": "Whether to upcast the layernorm weights to fp32."}
81 | )
82 | neft_alpha: Optional[float] = field(
83 | default=0,
84 | metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
85 | )
86 |
87 | def __post_init__(self):
88 | if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
89 | self.lora_target = [target.strip() for target in self.lora_target.split(",")]
90 |
91 | if isinstance(self.additional_target, str):
92 | self.additional_target = [target.strip() for target in self.additional_target.split(",")]
93 |
94 | assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
95 |
96 | def save_to_json(self, json_path: str):
97 | r"""Saves the content of this instance in JSON format inside `json_path`."""
98 | json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
99 | with open(json_path, "w", encoding="utf-8") as f:
100 | f.write(json_string)
101 |
102 | @classmethod
103 | def load_from_json(cls, json_path: str):
104 | r"""Creates an instance from the content of `json_path`."""
105 | with open(json_path, "r", encoding="utf-8") as f:
106 | text = f.read()
107 | return cls(**json.loads(text))
108 |
--------------------------------------------------------------------------------
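A short sketch: `lora_target` accepts a comma-separated string that `__post_init__` splits into a list, and the instance round-trips through JSON via `save_to_json` / `load_from_json`.

```python
# Sketch: LoRA fine-tuning arguments and their JSON round trip.
from llmtuner.hparams import FinetuningArguments

args = FinetuningArguments(stage="sft", finetuning_type="lora", lora_rank=8, lora_target="q_proj,v_proj")
print(args.lora_target)                # ['q_proj', 'v_proj'] after __post_init__
args.save_to_json("finetuning_args.json")
print(FinetuningArguments.load_from_json("finetuning_args.json").lora_rank)   # 8
```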
/llmtuner/hparams/general_args.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional
2 | from dataclasses import dataclass, field
3 |
4 |
5 | @dataclass
6 | class GeneralArguments:
7 | r"""
8 | Arguments pertaining to which stage we are going to perform.
9 | """
10 | stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
11 | default="sft",
12 | metadata={"help": "Which stage will be performed in training."}
13 | )
14 |
--------------------------------------------------------------------------------
/llmtuner/hparams/generating_args.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 | from dataclasses import asdict, dataclass, field
3 |
4 |
5 | @dataclass
6 | class GeneratingArguments:
7 | r"""
8 | Arguments pertaining to the decoding parameters.
9 | """
10 | do_sample: Optional[bool] = field(
11 | default=True,
12 | metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
13 | )
14 | temperature: Optional[float] = field(
15 | default=0.95,
16 | metadata={"help": "The value used to modulate the next token probabilities."}
17 | )
18 | top_p: Optional[float] = field(
19 | default=0.7,
20 | metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
21 | )
22 | top_k: Optional[int] = field(
23 | default=50,
24 | metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
25 | )
26 | num_beams: Optional[int] = field(
27 | default=1,
28 | metadata={"help": "Number of beams for beam search. 1 means no beam search."}
29 | )
30 | max_length: Optional[int] = field(
31 | default=None,
32 | metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
33 | )
34 | max_new_tokens: Optional[int] = field(
35 | default=512,
36 | metadata={"help": "The maximum number of tokens to generate, ignoring the number of tokens in the prompt."}
37 | )
38 | repetition_penalty: Optional[float] = field(
39 | default=1.0,
40 | metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
41 | )
42 | length_penalty: Optional[float] = field(
43 | default=1.0,
44 | metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
45 | )
46 |
47 | def to_dict(self) -> Dict[str, Any]:
48 | args = asdict(self)
49 | if args.get("max_new_tokens", None):
50 | args.pop("max_length", None)
51 | return args
52 |
--------------------------------------------------------------------------------
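`to_dict` drops `max_length` whenever `max_new_tokens` is set, so the two never conflict when the result is forwarded to `model.generate`. A tiny sketch:

```python
# Sketch: max_length is removed because max_new_tokens defaults to 512 (truthy).
from llmtuner.hparams import GeneratingArguments

gen_kwargs = GeneratingArguments(temperature=0.7, max_new_tokens=256).to_dict()
print("max_length" in gen_kwargs, gen_kwargs["max_new_tokens"])   # False 256
```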
/llmtuner/hparams/model_args.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional
2 | from dataclasses import dataclass, field
3 |
4 |
5 | @dataclass
6 | class ModelArguments:
7 | r"""
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
9 | """
10 | model_name_or_path: str = field(
11 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
12 | )
13 | cache_dir: Optional[str] = field(
14 | default=None,
15 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
16 | )
17 | use_fast_tokenizer: Optional[bool] = field(
18 | default=True,
19 | metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."}
20 | )
21 | split_special_tokens: Optional[bool] = field(
22 | default=False,
23 | metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
24 | )
25 | use_auth_token: Optional[bool] = field(
26 | default=False,
27 | metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
28 | )
29 | model_revision: Optional[str] = field(
30 | default="main",
31 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
32 | )
33 | quantization_bit: Optional[int] = field(
34 | default=None,
35 | metadata={"help": "The number of bits to quantize the model."}
36 | )
37 | quantization_type: Optional[Literal["fp4", "nf4"]] = field(
38 | default="nf4",
39 | metadata={"help": "Quantization data type to use in int4 training."}
40 | )
41 | double_quantization: Optional[bool] = field(
42 | default=True,
43 | metadata={"help": "Whether to use double quantization in int4 training or not."}
44 | )
45 | rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
46 | default=None,
47 | metadata={"help": "Adopt scaled rotary positional embeddings."}
48 | )
49 | checkpoint_dir: Optional[str] = field(
50 | default=None,
51 | metadata={"help": "Path to the directory(ies) containing the delta model checkpoints as well as the configurations."}
52 | )
53 | flash_attn: Optional[bool] = field(
54 | default=False,
55 | metadata={"help": "Enable FlashAttention-2 for faster training."}
56 | )
57 | shift_attn: Optional[bool] = field(
58 | default=False,
59 | metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
60 | )
61 | reward_model: Optional[str] = field(
62 | default=None,
63 | metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
64 | )
65 | plot_loss: Optional[bool] = field(
66 | default=False,
67 | metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
68 | )
69 | hf_auth_token: Optional[str] = field(
70 | default=None,
71 | metadata={"help": "Auth token to log in with Hugging Face Hub."}
72 | )
73 | export_dir: Optional[str] = field(
74 | default=None,
75 | metadata={"help": "Path to the directory to save the exported model."}
76 | )
77 |
78 | def __post_init__(self):
79 | self.compute_dtype = None
80 | self.model_max_length = None
81 |
82 | if self.split_special_tokens and self.use_fast_tokenizer:
83 | raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
84 |
85 | if self.checkpoint_dir is not None: # support merging multiple lora weights
86 | self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
87 |
88 | if self.quantization_bit is not None:
89 | assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
90 |
91 | if self.use_auth_token and self.hf_auth_token is not None:
92 | from huggingface_hub.hf_api import HfFolder # lazy load
93 | HfFolder.save_token(self.hf_auth_token)
94 |
--------------------------------------------------------------------------------
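`checkpoint_dir` takes a comma-separated list that `__post_init__` splits, which is how several LoRA checkpoints get merged in order, and `quantization_bit` must be 4 or 8. A hedged sketch with hypothetical paths:

```python
# Hedged sketch: the checkpoint paths below are hypothetical placeholders.
from llmtuner.hparams import ModelArguments

model_args = ModelArguments(
    model_name_or_path="baichuan-inc/Baichuan2-13B-Chat",
    checkpoint_dir="saves/lora-step-1000,saves/lora-step-2000",   # merged left to right
    quantization_bit=4,                                           # only 4 or 8 is accepted
)
print(model_args.checkpoint_dir)   # ['saves/lora-step-1000', 'saves/lora-step-2000']
```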
/llmtuner/tuner/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.tuner.tune import export_model, run_exp
2 |
--------------------------------------------------------------------------------
/llmtuner/tuner/core/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.tuner.core.parser import get_train_args, get_infer_args
2 | from llmtuner.tuner.core.loader import load_model_and_tokenizer
3 |
--------------------------------------------------------------------------------
/llmtuner/tuner/core/adapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import TYPE_CHECKING
3 |
4 | from peft import (
5 | PeftModel,
6 | TaskType,
7 | LoraConfig,
8 | get_peft_model
9 | )
10 |
11 | from llmtuner.extras.logging import get_logger
12 | from llmtuner.tuner.core.utils import find_all_linear_modules
13 |
14 | if TYPE_CHECKING:
15 | from transformers.modeling_utils import PreTrainedModel
16 | from llmtuner.hparams import ModelArguments, FinetuningArguments
17 |
18 |
19 | logger = get_logger(__name__)
20 |
21 |
22 | def init_adapter(
23 | model: "PreTrainedModel",
24 | model_args: "ModelArguments",
25 | finetuning_args: "FinetuningArguments",
26 | is_trainable: bool,
27 | is_mergeable: bool
28 | ) -> "PreTrainedModel":
29 | r"""
30 | Initializes the adapters.
31 |
32 | Supports full-parameter, freeze and LoRA training.
33 |
34 | Note that the trainable parameters must be cast to float32.
35 | """
36 |
37 | if finetuning_args.finetuning_type == "none" and is_trainable:
38 | raise ValueError("You cannot use finetuning_type=none while training.")
39 |
40 | if finetuning_args.finetuning_type == "full" and is_trainable:
41 | logger.info("Fine-tuning method: Full")
42 | model = model.float()
43 |
44 | if finetuning_args.finetuning_type == "freeze":
45 | logger.info("Fine-tuning method: Freeze")
46 | num_layers = getattr(model.config, "num_layers")
47 | if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
48 | trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
49 | else: # fine-tuning the first n layers if num_layer_trainable < 0
50 | trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
51 |
52 | trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
53 | for name, param in model.named_parameters():
54 | if not any(trainable_layer in name for trainable_layer in trainable_layers):
55 | param.requires_grad_(False)
56 | else:
57 | param.data = param.data.to(torch.float32)
58 |
59 | if finetuning_args.finetuning_type == "lora":
60 | logger.info("Fine-tuning method: LoRA")
61 | latest_checkpoint = None
62 |
63 | if model_args.checkpoint_dir is not None:
64 | if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
65 | checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
66 | else:
67 | checkpoints_to_merge = model_args.checkpoint_dir
68 |
69 | for checkpoint in checkpoints_to_merge:
70 | model = PeftModel.from_pretrained(model, checkpoint)
71 | model = model.merge_and_unload()
72 |
73 | if len(checkpoints_to_merge) > 0:
74 | logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
75 |
76 | if latest_checkpoint is not None: # resume lora training or quantized inference
77 | model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
78 |
79 | if is_trainable and latest_checkpoint is None: # create new lora weights while training
80 | if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
81 | target_modules = find_all_linear_modules(model, model_args.quantization_bit)
82 | else:
83 | target_modules = finetuning_args.lora_target
84 |
85 | lora_config = LoraConfig(
86 | task_type=TaskType.CAUSAL_LM,
87 | inference_mode=False,
88 | r=finetuning_args.lora_rank,
89 | lora_alpha=finetuning_args.lora_alpha,
90 | lora_dropout=finetuning_args.lora_dropout,
91 | target_modules=target_modules,
92 | modules_to_save=finetuning_args.additional_target
93 | )
94 | model = get_peft_model(model, lora_config)
95 | if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923
96 | model.base_model.peft_config = model.peft_config
97 |
98 | if model_args.checkpoint_dir is not None:
99 | logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
100 |
101 | return model
102 |
--------------------------------------------------------------------------------
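When training starts without a resumable checkpoint, `init_adapter` builds a fresh `LoraConfig` from the fine-tuning arguments. The standalone `peft` equivalent with the default hyper-parameters looks roughly like this; constructing the config needs no model download, and `get_peft_model(model, lora_config)` would then wrap the loaded base model:

```python
# Rough standalone equivalent of the LoRA branch above, using the default hyper-parameters.
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,                                   # finetuning_args.lora_rank
    lora_alpha=32,                         # finetuning_args.lora_alpha
    lora_dropout=0.1,                      # finetuning_args.lora_dropout
    target_modules=["q_proj", "v_proj"],   # e.g. the default for LLaMA-style models
)
print(lora_config)
```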
/llmtuner/tuner/core/parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import datasets
5 | import transformers
6 | from typing import Any, Dict, Optional, Tuple
7 | from transformers import HfArgumentParser, Seq2SeqTrainingArguments
8 | from transformers.trainer_utils import get_last_checkpoint
9 |
10 | from llmtuner.extras.logging import get_logger
11 | from llmtuner.hparams import (
12 | ModelArguments,
13 | DataArguments,
14 | FinetuningArguments,
15 | GeneratingArguments
16 | )
17 |
18 |
19 | logger = get_logger(__name__)
20 |
21 |
22 | def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
23 | if args is not None:
24 | return parser.parse_dict(args)
25 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
26 | return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
27 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
28 | return parser.parse_json_file(os.path.abspath(sys.argv[1]))
29 | else:
30 | return parser.parse_args_into_dataclasses()
31 |
32 |
33 | def parse_train_args(
34 | args: Optional[Dict[str, Any]] = None
35 | ) -> Tuple[
36 | ModelArguments,
37 | DataArguments,
38 | Seq2SeqTrainingArguments,
39 | FinetuningArguments,
40 | GeneratingArguments
41 | ]:
42 | parser = HfArgumentParser((
43 | ModelArguments,
44 | DataArguments,
45 | Seq2SeqTrainingArguments,
46 | FinetuningArguments,
47 | GeneratingArguments
48 | ))
49 | return _parse_args(parser, args)
50 |
51 |
52 | def parse_infer_args(
53 | args: Optional[Dict[str, Any]] = None
54 | ) -> Tuple[
55 | ModelArguments,
56 | DataArguments,
57 | FinetuningArguments,
58 | GeneratingArguments
59 | ]:
60 | parser = HfArgumentParser((
61 | ModelArguments,
62 | DataArguments,
63 | FinetuningArguments,
64 | GeneratingArguments
65 | ))
66 | return _parse_args(parser, args)
67 |
68 |
69 | def get_train_args(
70 | args: Optional[Dict[str, Any]] = None
71 | ) -> Tuple[
72 | ModelArguments,
73 | DataArguments,
74 | Seq2SeqTrainingArguments,
75 | FinetuningArguments,
76 | GeneratingArguments
77 | ]:
78 | model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
79 |
80 | # Setup logging
81 | if training_args.should_log:
82 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
83 | transformers.utils.logging.set_verbosity_info()
84 |
85 | log_level = training_args.get_process_log_level()
86 | datasets.utils.logging.set_verbosity(log_level)
87 | transformers.utils.logging.set_verbosity(log_level)
88 | transformers.utils.logging.enable_default_handler()
89 | transformers.utils.logging.enable_explicit_format()
90 |
91 | # Check arguments
92 | data_args.init_for_training(training_args.seed)
93 |
94 | if finetuning_args.stage != "pt" and data_args.template is None:
95 | raise ValueError("Please specify which `template` to use.")
96 |
97 | if finetuning_args.stage != "sft" and training_args.predict_with_generate:
98 | raise ValueError("`predict_with_generate` cannot be set as True except in the SFT stage.")
99 |
100 | if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
101 | raise ValueError("Please enable `predict_with_generate` to save model predictions.")
102 |
103 | if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
104 | raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
105 |
106 | if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
107 | raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
108 |
109 | if finetuning_args.stage in ["ppo", "dpo"] and not training_args.do_train:
110 | raise ValueError("PPO and DPO stages can only be performed during training.")
111 |
112 | if finetuning_args.stage in ["rm", "dpo"]:
113 | for dataset_attr in data_args.dataset_list:
114 | if not dataset_attr.ranking:
115 | raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
116 |
117 | if finetuning_args.stage == "ppo" and model_args.reward_model is None:
118 | raise ValueError("Reward model is necessary for PPO training.")
119 |
120 | if finetuning_args.stage == "ppo" and model_args.shift_attn:
121 | raise ValueError("PPO training is incompatible with S^2-Attn.")
122 |
123 | if training_args.max_steps == -1 and data_args.streaming:
124 | raise ValueError("Please specify `max_steps` in streaming mode.")
125 |
126 | if training_args.do_train and training_args.predict_with_generate:
127 | raise ValueError("`predict_with_generate` cannot be set as True while training.")
128 |
129 | if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
130 | raise ValueError("Please specify `lora_target` in LoRA training.")
131 |
132 | if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
133 | raise ValueError("Quantization is only compatible with the LoRA method.")
134 |
135 | if model_args.checkpoint_dir is not None:
136 | if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
137 | raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
138 |
139 | if model_args.quantization_bit is not None:
140 | if len(model_args.checkpoint_dir) != 1:
141 | raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
142 |
143 | if not finetuning_args.resume_lora_training:
144 | raise ValueError("A quantized model cannot create new LoRA weights. Merge the existing ones first.")
145 |
146 | if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
147 | logger.warning("We recommend enabling `upcast_layernorm` in quantized training.")
148 |
149 | if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
150 | logger.warning("We recommend enabling mixed precision training.")
151 |
152 | if (not training_args.do_train) and model_args.quantization_bit is not None:
153 | logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
154 |
155 | # postprocess training_args
156 | if (
157 | training_args.local_rank != -1
158 | and training_args.ddp_find_unused_parameters is None
159 | and finetuning_args.finetuning_type == "lora"
160 | ):
161 | logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
162 | training_args_dict = training_args.to_dict()
163 | training_args_dict.update(dict(ddp_find_unused_parameters=False))
164 | training_args = Seq2SeqTrainingArguments(**training_args_dict)
165 |
166 | if (
167 | training_args.resume_from_checkpoint is None
168 | and training_args.do_train
169 | and os.path.isdir(training_args.output_dir)
170 | and not training_args.overwrite_output_dir
171 | ):
172 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
173 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
174 | raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
175 |
176 | if last_checkpoint is not None:
177 | training_args_dict = training_args.to_dict()
178 | training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
179 | training_args = Seq2SeqTrainingArguments(**training_args_dict)
180 | logger.info(
181 | "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
182 | )
183 |
184 | # postprocess model_args
185 | model_args.compute_dtype = (
186 | torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
187 | )
188 | model_args.model_max_length = data_args.cutoff_len
189 |
190 | # Log on each process the small summary:
191 | logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
192 | training_args.local_rank, training_args.device, training_args.n_gpu,
193 | bool(training_args.local_rank != -1), str(model_args.compute_dtype)
194 | ))
195 | logger.info(f"Training/evaluation parameters {training_args}")
196 |
197 | # Set seed before initializing model.
198 | transformers.set_seed(training_args.seed)
199 |
200 | return model_args, data_args, training_args, finetuning_args, generating_args
201 |
202 |
203 | def get_infer_args(
204 | args: Optional[Dict[str, Any]] = None
205 | ) -> Tuple[
206 | ModelArguments,
207 | DataArguments,
208 | FinetuningArguments,
209 | GeneratingArguments
210 | ]:
211 | model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
212 |
213 | if data_args.template is None:
214 | raise ValueError("Please specify which `template` to use.")
215 |
216 | if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
217 | raise ValueError("Quantization is only compatible with the LoRA method.")
218 |
219 | if model_args.checkpoint_dir is not None:
220 | if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1:
221 | raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
222 |
223 | if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
224 | raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")
225 |
226 | return model_args, data_args, finetuning_args, generating_args
227 |
--------------------------------------------------------------------------------
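`get_train_args` also accepts a plain dict (routed through `parser.parse_dict`), which is how the web UI launches training. A hedged sketch; every value below is a placeholder, and a matching `data/dataset_info.json` entry, the model weights and a CUDA device are assumed to exist:

```python
# Hedged sketch: placeholder values; a registered dataset, model weights and a GPU are assumed.
from llmtuner.tuner.core import get_train_args

args = dict(
    stage="sft",
    model_name_or_path="baichuan-inc/Baichuan2-13B-Chat",
    dataset="my_sft_data",
    template="baichuan2",
    finetuning_type="lora",
    lora_target="W_pack",
    output_dir="saves/sft-demo",
    do_train=True,
    per_device_train_batch_size=2,
    learning_rate=5e-5,
    num_train_epochs=1.0,
    fp16=True,
)
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
print(training_args.output_dir, finetuning_args.lora_target)
```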
/llmtuner/tuner/core/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from types import MethodType
3 | from typing import TYPE_CHECKING, List, Optional
4 |
5 | from llmtuner.extras.constants import LAYERNORM_NAMES
6 | from llmtuner.extras.logging import get_logger
7 |
8 | if TYPE_CHECKING:
9 | from transformers.modeling_utils import PreTrainedModel
10 | from llmtuner.hparams import FinetuningArguments
11 |
12 |
13 | logger = get_logger(__name__)
14 |
15 |
16 | def find_all_linear_modules(
17 | model: "PreTrainedModel",
18 | quantization_bit: Optional[int] = None,
19 | output_layer_name: Optional[str] = "lm_head"
20 | ) -> List[str]:
21 | if quantization_bit is not None:
22 | import bitsandbytes as bnb
23 | linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
24 | else:
25 | linear_cls = torch.nn.Linear
26 |
27 | module_names = set()
28 | for name, module in model.named_modules():
29 | if output_layer_name not in name and isinstance(module, linear_cls):
30 | module_names.add(name.split(".")[-1])
31 |
32 | if output_layer_name in module_names:
33 | module_names.remove(output_layer_name)
34 |
35 | return list(module_names)
36 |
37 |
38 | def prepare_model_for_training(
39 | model: "PreTrainedModel",
40 | finetuning_args: "FinetuningArguments",
41 | output_layer_name: Optional[str] = "lm_head",
42 | use_gradient_checkpointing: Optional[bool] = True,
43 | layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
44 | ) -> "PreTrainedModel":
45 | r"""
46 | Includes:
47 | (1) cast the layernorm in fp32
48 | (2) make output embedding layer require grads
49 | (3) upcast the lm_head to fp32
50 | Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
51 | """
52 | if finetuning_args.upcast_layernorm:
53 | for name, param in model.named_parameters():
54 | if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
55 | param.data = param.data.to(torch.float32)
56 | logger.info("Upcasting layernorm weights to float32.")
57 |
58 | if finetuning_args.neft_alpha > 1e-6:
59 | input_embed = model.get_input_embeddings()
60 | if isinstance(input_embed, torch.nn.Embedding):
61 | def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor:
62 | embeddings = input_embed.__class__.forward(self, x)
63 | if self.training:
64 | dims = self.num_embeddings * self.embedding_dim
65 | mag_norm = finetuning_args.neft_alpha / (dims ** 0.5)
66 | embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
67 | return embeddings
68 |
69 | input_embed.forward = MethodType(noisy_forward, input_embed)
70 | logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
71 | else:
72 | logger.warning("Input embedding layer is not an nn.Embedding instance; cannot apply noisy embeddings (NEFTune).")
73 |
74 | if use_gradient_checkpointing:
75 | if hasattr(model, "enable_input_require_grads"):
76 | model.enable_input_require_grads()
77 | else:
78 | def make_inputs_require_grad(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor):
79 | output.requires_grad_(True)
80 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
81 |
82 | model.gradient_checkpointing_enable()
83 | model.config.use_cache = False # turn off when gradient checkpointing is enabled
84 | logger.info("Gradient checkpointing enabled.")
85 |
86 | if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
87 | output_layer = getattr(model, output_layer_name)
88 | if isinstance(output_layer, torch.nn.Linear):
89 | def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor:
90 | return output_layer.__class__.forward(self, x.to(output_layer.weight.dtype)).to(torch.float32)
91 |
92 | output_layer.forward = MethodType(forward_in_fp32, output_layer)
93 |
94 | return model
95 |
--------------------------------------------------------------------------------
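A worked example of the NEFTune noise bound used in `prepare_model_for_training`: the uniform noise is drawn from `(-mag_norm, mag_norm)` with `mag_norm = neft_alpha / sqrt(num_embeddings * embedding_dim)`. The alpha value and embedding sizes below are illustrative:

```python
# Worked example of the NEFTune noise magnitude (values are illustrative).
import math

neft_alpha = 5.0                                 # 0 disables the noise
num_embeddings, embedding_dim = 32000, 4096      # a LLaMA-sized embedding table
mag_norm = neft_alpha / math.sqrt(num_embeddings * embedding_dim)
print(f"{mag_norm:.6f}")                         # ~0.000437, the half-width of the uniform noise
```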
/llmtuner/tuner/dpo/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.tuner.dpo.workflow import run_dpo
2 |
--------------------------------------------------------------------------------
/llmtuner/tuner/dpo/collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, List, Sequence, Tuple
4 | from transformers import DataCollatorForSeq2Seq
5 |
6 |
7 | @dataclass
8 | class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
9 | r"""
10 | Data collator for pairwise data.
11 | """
12 |
13 | def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
14 | padded_labels = []
15 | for feature, (prompt_len, answer_len) in zip(batch, positions):
16 | if self.tokenizer.padding_side == "left":
17 | start, end = feature.size(0) - answer_len, feature.size(0)
18 | else:
19 | start, end = prompt_len, prompt_len + answer_len
20 | padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
21 | padded_tensor[start:end] = feature[start:end]
22 | padded_labels.append(padded_tensor)
23 | return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
24 |
25 | def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
26 | r"""
27 | Pads batched data to the longest sequence in the batch.
28 |
29 | We generate 2 * n examples where the first n examples represent chosen examples and
30 | the last n examples represent rejected examples.
31 | """
32 | concatenated_features = []
33 | label_positions = []
34 | for key in ("chosen_ids", "rejected_ids"):
35 | for feature in features:
36 | prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
37 | concatenated_features.append({
38 | "input_ids": feature["prompt_ids"] + feature[key],
39 | "attention_mask": [1] * (prompt_len + answer_len)
40 | })
41 | label_positions.append((prompt_len, answer_len))
42 |
43 | batch = self.tokenizer.pad(
44 | concatenated_features,
45 | padding=self.padding,
46 | max_length=self.max_length,
47 | pad_to_multiple_of=self.pad_to_multiple_of,
48 | return_tensors=self.return_tensors,
49 | )
50 | batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
51 | return batch
52 |
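Illustrative usage of the collator above (a sketch, not repository code; the tokenizer choice and token ids are placeholders — any tokenizer with a pad token works):

from transformers import AutoTokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

collator = DPODataCollatorWithPadding(tokenizer=tokenizer, label_pad_token_id=-100)
features = [{"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6]}]
batch = collator(features)
# batch["input_ids"] has shape (2 * n, max_len): the first n rows are prompt + chosen,
# the last n rows are prompt + rejected; batch["labels"] masks everything except the
# answer tokens with label_pad_token_id.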
--------------------------------------------------------------------------------
/llmtuner/tuner/dpo/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import defaultdict
3 | from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
4 | from transformers import BatchEncoding, Trainer
5 | from trl import DPOTrainer
6 | from trl.trainer.utils import disable_dropout_in_model
7 |
8 | from llmtuner.extras.constants import IGNORE_INDEX
9 |
10 | if TYPE_CHECKING:
11 | from transformers import PreTrainedModel
12 |
13 |
14 | class CustomDPOTrainer(DPOTrainer):
15 |
16 | def __init__(
17 | self,
18 | beta: float,
19 | model: Union["PreTrainedModel", torch.nn.Module],
20 | ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
21 | disable_dropout: Optional[bool] = True,
22 | loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
23 | **kwargs
24 | ):
25 | if disable_dropout:
26 | disable_dropout_in_model(model)
27 | if ref_model is not None:
28 | disable_dropout_in_model(ref_model)
29 |
30 | self.is_encoder_decoder = model.config.is_encoder_decoder
31 | self.ref_model = ref_model
32 | self.use_dpo_data_collator = True # hack to avoid warning
33 | self.label_pad_token_id = IGNORE_INDEX
34 | self.padding_value = 0
35 | self.beta = beta
36 | self.loss_type = loss_type
37 | self._stored_metrics = defaultdict(lambda: defaultdict(list))
38 |
39 | Trainer.__init__(self, model=model, **kwargs)
40 | if not hasattr(self, "accelerator"):
41 | raise AttributeError("Please update `transformers`.")
42 |
43 | if ref_model is not None:
44 | if self.is_deepspeed_enabled:
45 | self.ref_model = self._prepare_deepspeed(self.ref_model)
46 | else:
47 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
48 |
49 | def concatenated_forward(
50 | self,
51 | model: Optional[torch.nn.Module] = None,
52 | batch: Optional[Dict[str, torch.Tensor]] = None
53 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
54 | batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
55 |
56 | all_logits = model(
57 | input_ids=batch_copied["input_ids"],
58 | attention_mask=batch_copied["attention_mask"],
59 | return_dict=True
60 | ).logits.to(torch.float32)
61 |
62 | all_logps = self._get_batch_logps(
63 | all_logits,
64 | batch["labels"],
65 | average_log_prob=False
66 | )
67 | batch_size = batch["input_ids"].size(0) // 2
68 | chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
69 | chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
70 | return chosen_logps, rejected_logps, chosen_logits, rejected_logits
71 |
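For reference, the "sigmoid" preference loss selected by `loss_type` above is the standard DPO objective. A minimal sketch in PyTorch (not code from this repository):

import torch
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps: torch.Tensor, policy_rejected_logps: torch.Tensor,
                     reference_chosen_logps: torch.Tensor, reference_rejected_logps: torch.Tensor,
                     beta: float = 0.1) -> torch.Tensor:
    # implicit rewards are beta-scaled log-ratios between the policy and the frozen reference
    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)
    # maximize the margin between chosen and rejected completions
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()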
--------------------------------------------------------------------------------
/llmtuner/tuner/dpo/workflow.py:
--------------------------------------------------------------------------------
1 | # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
2 |
3 | from copy import deepcopy
4 | from peft import PeftModel
5 | from typing import TYPE_CHECKING, Optional, List
6 | from transformers import Seq2SeqTrainingArguments
7 |
8 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
9 | from llmtuner.extras.constants import IGNORE_INDEX
10 | from llmtuner.extras.ploting import plot_loss
11 | from llmtuner.tuner.core import load_model_and_tokenizer
12 | from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
13 | from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
14 |
15 | if TYPE_CHECKING:
16 | from transformers import TrainerCallback
17 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
18 |
19 |
20 | def run_dpo(
21 | model_args: "ModelArguments",
22 | data_args: "DataArguments",
23 | training_args: "Seq2SeqTrainingArguments",
24 | finetuning_args: "FinetuningArguments",
25 | callbacks: Optional[List["TrainerCallback"]] = None
26 | ):
27 | dataset = get_dataset(model_args, data_args)
28 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
29 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
30 | data_collator = DPODataCollatorWithPadding(
31 | tokenizer=tokenizer,
32 | pad_to_multiple_of=4,
33 | label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
34 | )
35 |
36 | training_args_dict = training_args.to_dict()
37 | training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
38 | training_args = Seq2SeqTrainingArguments(**training_args_dict)
39 |
40 | # Initialize our Trainer
41 | trainer = CustomDPOTrainer(
42 | beta=finetuning_args.dpo_beta,
43 | model=model,
44 | ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None,
45 | args=training_args,
46 | tokenizer=tokenizer,
47 | data_collator=data_collator,
48 | callbacks=callbacks,
49 | **split_dataset(dataset, data_args, training_args)
50 | )
51 |
52 | # Training
53 | if training_args.do_train:
54 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
55 | trainer.log_metrics("train", train_result.metrics)
56 | trainer.save_metrics("train", train_result.metrics)
57 | trainer.save_state()
58 | trainer.save_model()
59 | if trainer.is_world_process_zero() and model_args.plot_loss:
60 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
61 |
--------------------------------------------------------------------------------
/llmtuner/tuner/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.tuner.ppo.workflow import run_ppo
2 |
--------------------------------------------------------------------------------
/llmtuner/tuner/ppo/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import TYPE_CHECKING, Dict, Literal, Optional
3 |
4 | if TYPE_CHECKING:
5 | from transformers import PreTrainedModel
6 | from trl import AutoModelForCausalLMWithValueHead
7 |
8 |
9 | def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
10 | if target == "reward": # save default head temporarily
11 | valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
12 | setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
13 | setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
14 |
15 | model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
16 | model.v_head.load_state_dict({
17 | "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
18 | "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
19 | })
20 |
21 |
22 | def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
23 | layer_norm_params = {}
24 | for name, param in model.named_parameters():
25 | if param.data.dtype == torch.float32:
26 | layer_norm_params[name] = param.data.detach().clone()
27 | param.data = param.data.to(model.config.torch_dtype)
28 |
29 | return layer_norm_params
30 |
31 |
32 | def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
33 | for name, param in model.named_parameters():
34 | if layernorm_params is not None and name in layernorm_params:
35 | param.data = layernorm_params[name]
36 |
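A hedged sketch (an assumption about intended usage, since the PPO trainer itself is defined elsewhere; the helper name below is hypothetical) of how `dump_layernorm`/`restore_layernorm` would bracket a generation call when layer-norm weights are kept in fp32:

from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm

def generate_with_downcast_layernorm(model, **gen_kwargs):
    # temporarily cast the fp32 (layer-norm) weights down to the model dtype for generation,
    # then restore the saved fp32 copies afterwards
    layernorm_params = dump_layernorm(model)
    try:
        return model.generate(**gen_kwargs)
    finally:
        restore_layernorm(model, layernorm_params)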
--------------------------------------------------------------------------------
/llmtuner/tuner/ppo/workflow.py:
--------------------------------------------------------------------------------
1 | # Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
2 |
3 | import math
4 | from trl import PPOConfig
5 | from torch.optim import AdamW
6 | from typing import TYPE_CHECKING, Optional, List
7 | from transformers import DataCollatorWithPadding
8 | from transformers.optimization import get_scheduler
9 |
10 | from llmtuner.dsets import get_dataset, preprocess_dataset
11 | from llmtuner.extras.callbacks import SavePeftModelCallback
12 | from llmtuner.extras.ploting import plot_loss
13 | from llmtuner.tuner.core import load_model_and_tokenizer
14 | from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
15 |
16 | if TYPE_CHECKING:
17 | from transformers import Seq2SeqTrainingArguments, TrainerCallback
18 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
19 |
20 |
21 | def run_ppo(
22 | model_args: "ModelArguments",
23 | data_args: "DataArguments",
24 | training_args: "Seq2SeqTrainingArguments",
25 | finetuning_args: "FinetuningArguments",
26 | generating_args: "GeneratingArguments",
27 | callbacks: Optional[List["TrainerCallback"]] = None
28 | ):
29 | dataset = get_dataset(model_args, data_args)
30 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
31 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
32 |
33 | tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
34 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
35 |
36 | ppo_config = PPOConfig(
37 | model_name=model_args.model_name_or_path,
38 | learning_rate=training_args.learning_rate,
39 | mini_batch_size=training_args.per_device_train_batch_size,
40 | batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
41 | gradient_accumulation_steps=training_args.gradient_accumulation_steps,
42 | ppo_epochs=1,
43 | max_grad_norm=training_args.max_grad_norm,
44 | seed=training_args.seed,
45 | optimize_cuda_cache=True,
46 | target=finetuning_args.ppo_target,
47 | log_with=finetuning_args.ppo_logger,
48 | use_score_scaling=finetuning_args.ppo_score_norm,
49 | use_score_norm=finetuning_args.ppo_score_norm,
50 | accelerator_kwargs={"step_scheduler_with_optimizer": False}
51 | )
52 |
53 | optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
54 | if training_args.max_steps > 0:
55 | num_training_steps = training_args.max_steps
56 | else:
57 | total_train_batch_size = (
58 | training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
59 | )
60 | num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
61 |
62 | lr_scheduler = get_scheduler(
63 | training_args.lr_scheduler_type,
64 | optimizer=optimizer,
65 | num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
66 | num_training_steps=num_training_steps
67 | )
68 |
69 | # Initialize our Trainer
70 | ppo_trainer = CustomPPOTrainer(
71 | model_args=model_args,
72 | training_args=training_args,
73 | finetuning_args=finetuning_args,
74 | generating_args=generating_args,
75 | callbacks=callbacks + [SavePeftModelCallback()],
76 | config=ppo_config,
77 | model=model,
78 | ref_model=None,
79 | tokenizer=tokenizer,
80 | dataset=dataset,
81 | data_collator=data_collator,
82 | optimizer=optimizer,
83 | lr_scheduler=lr_scheduler
84 | )
85 |
86 | # Training
87 | if training_args.do_train:
88 | ppo_trainer.ppo_train()
89 | ppo_trainer.save_model()
90 | ppo_trainer.save_state() # must be called after save_model so that the output folder exists
91 | if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
92 | plot_loss(training_args.output_dir, keys=["loss", "reward"])
93 |
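A worked example of the step arithmetic above, using illustrative numbers (not taken from the repository):

import math

per_device_train_batch_size, gradient_accumulation_steps, world_size = 4, 4, 2
num_train_epochs, dataset_len = 3, 10_000

total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size  # 32 samples per optimizer step
num_training_steps = num_train_epochs * math.ceil(dataset_len / total_train_batch_size)          # 3 * 313 = 939 steps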
--------------------------------------------------------------------------------
/llmtuner/tuner/pt/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.tuner.pt.workflow import run_pt
2 |
--------------------------------------------------------------------------------
/llmtuner/tuner/pt/workflow.py:
--------------------------------------------------------------------------------
1 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
2 |
3 | import math
4 | from typing import TYPE_CHECKING, Optional, List
5 | from transformers import DataCollatorForLanguageModeling, Trainer
6 |
7 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
8 | from llmtuner.extras.ploting import plot_loss
9 | from llmtuner.tuner.core import load_model_and_tokenizer
10 |
11 | if TYPE_CHECKING:
12 | from transformers import Seq2SeqTrainingArguments, TrainerCallback
13 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
14 |
15 |
16 | def run_pt(
17 | model_args: "ModelArguments",
18 | data_args: "DataArguments",
19 | training_args: "Seq2SeqTrainingArguments",
20 | finetuning_args: "FinetuningArguments",
21 | callbacks: Optional[List["TrainerCallback"]] = None
22 | ):
23 | dataset = get_dataset(model_args, data_args)
24 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
25 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
26 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
27 |
28 | # Initialize our Trainer
29 | trainer = Trainer(
30 | model=model,
31 | args=training_args,
32 | tokenizer=tokenizer,
33 | data_collator=data_collator,
34 | callbacks=callbacks,
35 | **split_dataset(dataset, data_args, training_args)
36 | )
37 |
38 | # Training
39 | if training_args.do_train:
40 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
41 | trainer.log_metrics("train", train_result.metrics)
42 | trainer.save_metrics("train", train_result.metrics)
43 | trainer.save_state()
44 | trainer.save_model()
45 | if trainer.is_world_process_zero() and model_args.plot_loss:
46 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
47 |
48 | # Evaluation
49 | if training_args.do_eval:
50 | metrics = trainer.evaluate(metric_key_prefix="eval")
51 | try:
52 | perplexity = math.exp(metrics["eval_loss"])
53 | except OverflowError:
54 | perplexity = float("inf")
55 |
56 | metrics["perplexity"] = perplexity
57 | trainer.log_metrics("eval", metrics)
58 | trainer.save_metrics("eval", metrics)
59 |
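The evaluation block above reports perplexity as the exponential of the mean cross-entropy loss; a quick illustration with a made-up value:

import math

eval_loss = 2.0
perplexity = math.exp(eval_loss)  # ≈ 7.39; an OverflowError (extremely large loss) falls back to float("inf")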
--------------------------------------------------------------------------------
/llmtuner/tuner/rm/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.tuner.rm.workflow import run_rm
2 |
--------------------------------------------------------------------------------
/llmtuner/tuner/rm/collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, Sequence
4 | from transformers import DataCollatorWithPadding
5 |
6 |
7 | @dataclass
8 | class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
9 | r"""
10 | Data collator for pairwise data.
11 | """
12 |
13 | def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
14 | r"""
15 | Pads batched data to the longest sequence in the batch.
16 |
17 | We generate 2 * n examples where the first n examples represent chosen examples and
18 | the last n examples represent rejected examples.
19 | """
20 | features = [
21 | {
22 | "input_ids": feature["prompt_ids"] + feature[key],
23 | "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
24 | }
25 | for key in ("chosen_ids", "rejected_ids") for feature in features
26 | ]
27 | return super().__call__(features)
28 |
--------------------------------------------------------------------------------
/llmtuner/tuner/rm/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Dict, Sequence, Tuple, Union
3 |
4 |
5 | def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
6 | preds, _ = eval_preds
7 | return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
8 |
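Illustrative call of `compute_accuracy` with placeholder scores: the reward-model trainer returns `(chosen_scores, rejected_scores)` as predictions, so the metric is the fraction of pairs where the chosen completion outscores the rejected one.

import numpy as np
from llmtuner.tuner.rm.metric import compute_accuracy

chosen_scores = np.array([1.2, 0.3, 2.0])
rejected_scores = np.array([0.5, 0.9, 1.1])
compute_accuracy(((chosen_scores, rejected_scores), None))  # {"accuracy": 0.666...}, 2 of 3 pairs ranked correctly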
--------------------------------------------------------------------------------
/llmtuner/tuner/rm/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
5 | from transformers import Trainer
6 |
7 | from llmtuner.extras.logging import get_logger
8 |
9 | if TYPE_CHECKING:
10 | from transformers.trainer import PredictionOutput
11 | from transformers.modeling_utils import PreTrainedModel
12 |
13 |
14 | logger = get_logger(__name__)
15 |
16 |
17 | class PairwiseTrainer(Trainer):
18 | r"""
19 | Inherits Trainer to compute pairwise loss.
20 | """
21 |
22 | def __init__(self, *args, **kwargs):
23 | super().__init__(*args, **kwargs)
24 | self.can_return_loss = True # override property to return eval_loss
25 |
26 | def compute_loss(
27 | self,
28 | model: "PreTrainedModel",
29 | inputs: Dict[str, torch.Tensor],
30 | return_outputs: Optional[bool] = False
31 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
32 | r"""
33 | Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
34 |
35 | Subclass and override to inject custom behavior.
36 |
37 | Note that the first element will be removed from the output tuple.
38 | See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
39 | """
40 | # Compute rewards
41 | _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
42 | if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
43 | values = torch.transpose(values, 0, 1)
44 |
45 | # Split the inputs and rewards into two parts, chosen and rejected
46 | batch_size = inputs["input_ids"].size(0) // 2
47 | chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
48 | chosen_attn_mask, rejected_attn_mask = (
49 | inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
50 | )
51 | chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
52 | chosen_scores, rejected_scores = [], []
53 |
54 | # Compute pairwise loss. Only backprop on the different tokens before padding
55 | # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
56 | loss = 0
57 | for i in range(batch_size):
58 | chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
59 | rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
60 | check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
61 |
62 | if len(check_divergence) == 0:
63 | end_index = chosen_length
64 | div_index = end_index - 1
65 | else:
66 | end_index = max(chosen_length, rejected_length)
67 | div_index = check_divergence[0]
68 |
69 | assert div_index > 0
70 | chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
71 | rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
72 | if return_outputs: # use the score on the EOS token for inference
73 | chosen_scores.append(chosen_rewards[i, chosen_length-1])
74 | rejected_scores.append(rejected_rewards[i, rejected_length-1])
75 | loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
76 |
77 | loss = loss / batch_size
78 | if return_outputs:
79 | chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
80 | return loss, [loss, chosen_scores, rejected_scores]
81 |
82 | return loss
83 |
84 | def save_predictions(
85 | self,
86 | predict_results: "PredictionOutput"
87 | ) -> None:
88 | r"""
89 | Saves model predictions to `output_dir`.
90 |
91 | A custom behavior that is not contained in Trainer.
92 | """
93 | if not self.is_world_process_zero():
94 | return
95 |
96 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
97 | logger.info(f"Saving prediction results to {output_prediction_file}")
98 |
99 | chosen_scores, rejected_scores = predict_results.predictions
100 |
101 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
102 | res: List[str] = []
103 | for c_score, r_score in zip(chosen_scores, rejected_scores):
104 | res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
105 | writer.write("\n".join(res))
106 |
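The token-level objective in `compute_loss` above reduces to -log(sigmoid(chosen_reward - rejected_reward)), averaged over the tokens between the first diverging position and the end of the longer answer. A tiny numeric illustration with placeholder rewards:

import torch

chosen_trunc_rewards = torch.tensor([1.5, 2.0])
rejected_trunc_rewards = torch.tensor([0.5, 1.0])
loss = -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
# both margins are 1.0, so loss = -log(sigmoid(1.0)) ≈ 0.3133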
--------------------------------------------------------------------------------
/llmtuner/tuner/rm/workflow.py:
--------------------------------------------------------------------------------
1 | # Inspired by:
2 | # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
3 |
4 | from typing import TYPE_CHECKING, Optional, List
5 | from transformers import Seq2SeqTrainingArguments
6 |
7 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
8 | from llmtuner.extras.callbacks import SavePeftModelCallback
9 | from llmtuner.extras.ploting import plot_loss
10 | from llmtuner.tuner.core import load_model_and_tokenizer
11 | from llmtuner.tuner.rm.metric import compute_accuracy
12 | from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
13 | from llmtuner.tuner.rm.trainer import PairwiseTrainer
14 |
15 | if TYPE_CHECKING:
16 | from transformers import TrainerCallback
17 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
18 |
19 |
20 | def run_rm(
21 | model_args: "ModelArguments",
22 | data_args: "DataArguments",
23 | training_args: "Seq2SeqTrainingArguments",
24 | finetuning_args: "FinetuningArguments",
25 | callbacks: Optional[List["TrainerCallback"]] = None
26 | ):
27 | dataset = get_dataset(model_args, data_args)
28 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
29 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
30 | data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
31 |
32 | training_args_dict = training_args.to_dict()
33 | training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
34 | training_args = Seq2SeqTrainingArguments(**training_args_dict)
35 |
36 | # Initialize our Trainer
37 | trainer = PairwiseTrainer(
38 | model=model,
39 | args=training_args,
40 | tokenizer=tokenizer,
41 | data_collator=data_collator,
42 | callbacks=callbacks + [SavePeftModelCallback()],
43 | compute_metrics=compute_accuracy,
44 | **split_dataset(dataset, data_args, training_args)
45 | )
46 |
47 | # Training
48 | if training_args.do_train:
49 | train_result = trainer.train()
50 | trainer.log_metrics("train", train_result.metrics)
51 | trainer.save_metrics("train", train_result.metrics)
52 | trainer.save_state()
53 | trainer.save_model()
54 | if trainer.is_world_process_zero() and model_args.plot_loss:
55 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
56 |
57 | # Evaluation
58 | if training_args.do_eval:
59 | metrics = trainer.evaluate(metric_key_prefix="eval")
60 | trainer.log_metrics("eval", metrics)
61 | trainer.save_metrics("eval", metrics)
62 |
63 | # Predict
64 | if training_args.do_predict:
65 | predict_results = trainer.predict(dataset, metric_key_prefix="predict")
66 | trainer.log_metrics("predict", predict_results.metrics)
67 | trainer.save_metrics("predict", predict_results.metrics)
68 | trainer.save_predictions(predict_results)
69 |
--------------------------------------------------------------------------------
/llmtuner/tuner/sft/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.tuner.sft.workflow import run_sft
2 |
--------------------------------------------------------------------------------
/llmtuner/tuner/sft/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from dataclasses import dataclass
3 | from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
4 |
5 | import jieba
6 | from rouge_chinese import Rouge
7 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
8 |
9 | from llmtuner.extras.constants import IGNORE_INDEX
10 |
11 | if TYPE_CHECKING:
12 | from transformers.tokenization_utils import PreTrainedTokenizer
13 |
14 |
15 | @dataclass
16 | class ComputeMetrics:
17 | r"""
18 | Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
19 | """
20 |
21 | tokenizer: "PreTrainedTokenizer"
22 |
23 | def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
24 | r"""
25 | Uses the model predictions to compute metrics.
26 | """
27 | preds, labels = eval_preds
28 | score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
29 |
30 | preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
31 | labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
32 |
33 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
34 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
35 |
36 | for pred, label in zip(decoded_preds, decoded_labels):
37 | hypothesis = list(jieba.cut(pred))
38 | reference = list(jieba.cut(label))
39 |
40 | if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
41 | result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
42 | else:
43 | rouge = Rouge()
44 | scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
45 | result = scores[0]
46 |
47 | for k, v in result.items():
48 | score_dict[k].append(round(v["f"] * 100, 4))
49 |
50 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
51 | score_dict["bleu-4"].append(round(bleu_score * 100, 4))
52 |
53 | return {k: float(np.mean(v)) for k, v in score_dict.items()}
54 |
--------------------------------------------------------------------------------
/llmtuner/tuner/sft/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import numpy as np
5 | import torch.nn as nn
6 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
7 | from transformers import Seq2SeqTrainer
8 |
9 | from llmtuner.extras.constants import IGNORE_INDEX
10 | from llmtuner.extras.logging import get_logger
11 |
12 | if TYPE_CHECKING:
13 | from transformers.trainer import PredictionOutput
14 |
15 |
16 | logger = get_logger(__name__)
17 |
18 |
19 | class CustomSeq2SeqTrainer(Seq2SeqTrainer):
20 | r"""
21 | Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
22 | """
23 |
24 | def prediction_step(
25 | self,
26 | model: nn.Module,
27 | inputs: Dict[str, Union[torch.Tensor, Any]],
28 | prediction_loss_only: bool,
29 | ignore_keys: Optional[List[str]] = None,
30 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
31 | r"""
32 | Removes the prompt part in the generated tokens.
33 |
34 | Subclass and override to inject custom behavior.
35 | """
36 | labels = inputs["labels"].clone() if "labels" in inputs else None # backup labels
37 | if self.args.predict_with_generate:
38 | assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
39 | prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
40 | if prompt_len > label_len:
41 | inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
42 | if label_len > prompt_len:
43 | inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs
44 |
45 | loss, generated_tokens, _ = super().prediction_step(
46 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
47 | )
48 | if generated_tokens is not None and self.args.predict_with_generate:
49 | generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
50 | generated_tokens = generated_tokens.contiguous()
51 |
52 | return loss, generated_tokens, labels
53 |
54 | def _pad_tensors_to_target_len(
55 | self,
56 | src_tensor: torch.Tensor,
57 | tgt_tensor: torch.Tensor
58 | ) -> torch.Tensor:
59 | r"""
60 | Pads the tensor to the same length as the target tensor.
61 | """
62 | assert self.tokenizer.pad_token_id is not None, "Pad token is required."
63 | padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
64 | padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
65 | return padded_tensor.contiguous() # in contiguous memory
66 |
67 | def save_predictions(
68 | self,
69 | predict_results: "PredictionOutput"
70 | ) -> None:
71 | r"""
72 | Saves model predictions to `output_dir`.
73 |
74 | A custom behavior that is not contained in Seq2SeqTrainer.
75 | """
76 | if not self.is_world_process_zero():
77 | return
78 |
79 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
80 | logger.info(f"Saving prediction results to {output_prediction_file}")
81 |
82 | preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
83 | labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
84 |
85 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
86 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True)
87 |
88 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
89 | res: List[str] = []
90 | for pred, label in zip(decoded_preds, decoded_labels):
91 | res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
92 | writer.write("\n".join(res))
93 |
--------------------------------------------------------------------------------
/llmtuner/tuner/sft/workflow.py:
--------------------------------------------------------------------------------
1 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
2 |
3 | from typing import TYPE_CHECKING, Optional, List
4 | from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
5 |
6 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
7 | from llmtuner.extras.constants import IGNORE_INDEX
8 | from llmtuner.extras.misc import get_logits_processor
9 | from llmtuner.extras.ploting import plot_loss
10 | from llmtuner.tuner.core import load_model_and_tokenizer
11 | from llmtuner.tuner.sft.metric import ComputeMetrics
12 | from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
13 |
14 | if TYPE_CHECKING:
15 | from transformers import TrainerCallback
16 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
17 |
18 |
19 | def run_sft(
20 | model_args: "ModelArguments",
21 | data_args: "DataArguments",
22 | training_args: "Seq2SeqTrainingArguments",
23 | finetuning_args: "FinetuningArguments",
24 | generating_args: "GeneratingArguments",
25 | callbacks: Optional[List["TrainerCallback"]] = None
26 | ):
27 | dataset = get_dataset(model_args, data_args)
28 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
29 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
30 |
31 | if training_args.predict_with_generate:
32 | tokenizer.padding_side = "left" # use left-padding in generation
33 |
34 | data_collator = DataCollatorForSeq2Seq(
35 | tokenizer=tokenizer,
36 | pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
37 | label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
38 | )
39 |
40 | # Override the decoding parameters of Seq2SeqTrainer
41 | training_args_dict = training_args.to_dict()
42 | training_args_dict.update(dict(
43 | generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
44 | generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
45 | ))
46 | training_args = Seq2SeqTrainingArguments(**training_args_dict)
47 |
48 | # Initialize our Trainer
49 | trainer = CustomSeq2SeqTrainer(
50 | model=model,
51 | args=training_args,
52 | tokenizer=tokenizer,
53 | data_collator=data_collator,
54 | callbacks=callbacks,
55 | compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
56 | **split_dataset(dataset, data_args, training_args)
57 | )
58 |
59 | # Keyword arguments for `model.generate`
60 | gen_kwargs = generating_args.to_dict()
61 | gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
62 | gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
63 | gen_kwargs["logits_processor"] = get_logits_processor()
64 |
65 | # Training
66 | if training_args.do_train:
67 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
68 | trainer.log_metrics("train", train_result.metrics)
69 | trainer.save_metrics("train", train_result.metrics)
70 | trainer.save_state()
71 | trainer.save_model()
72 | if trainer.is_world_process_zero() and model_args.plot_loss:
73 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
74 |
75 | # Evaluation
76 | if training_args.do_eval:
77 | metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
78 | if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
79 | metrics.pop("eval_loss", None)
80 | trainer.log_metrics("eval", metrics)
81 | trainer.save_metrics("eval", metrics)
82 |
83 | # Predict
84 | if training_args.do_predict:
85 | predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
86 | if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
87 | predict_results.metrics.pop("predict_loss", None)
88 | trainer.log_metrics("predict", predict_results.metrics)
89 | trainer.save_metrics("predict", predict_results.metrics)
90 | trainer.save_predictions(predict_results)
91 |
--------------------------------------------------------------------------------
/llmtuner/tuner/tune.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Any, Dict, List, Optional
2 |
3 | from llmtuner.extras.callbacks import LogCallback
4 | from llmtuner.extras.logging import get_logger
5 | from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
6 | from llmtuner.tuner.pt import run_pt
7 | from llmtuner.tuner.sft import run_sft
8 | from llmtuner.tuner.rm import run_rm
9 | from llmtuner.tuner.ppo import run_ppo
10 | from llmtuner.tuner.dpo import run_dpo
11 |
12 | if TYPE_CHECKING:
13 | from transformers import TrainerCallback
14 |
15 |
16 | logger = get_logger(__name__)
17 |
18 |
19 | def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
20 | model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
21 | callbacks = [LogCallback()] if callbacks is None else callbacks
22 |
23 | if finetuning_args.stage == "pt":
24 | run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
25 | elif finetuning_args.stage == "sft":
26 | run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
27 | elif finetuning_args.stage == "rm":
28 | run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
29 | elif finetuning_args.stage == "ppo":
30 | run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
31 | elif finetuning_args.stage == "dpo":
32 | run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
33 | else:
34 | raise ValueError("Unknown task.")
35 |
36 |
37 | def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
38 | model_args, _, finetuning_args, _ = get_infer_args(args)
39 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
40 | model.config.use_cache = True
41 | tokenizer.padding_side = "left" # restore padding side
42 | tokenizer.init_kwargs["padding_side"] = "left"
43 | model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
44 | try:
45 | tokenizer.save_pretrained(model_args.export_dir)
46 | except Exception:
47 | logger.warning("Cannot save tokenizer, please copy the files manually.")
48 |
49 |
50 | if __name__ == "__main__":
51 | run_exp()
52 |
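A hedged sketch of launching a run programmatically through `run_exp`; the keyword names follow the argument dataclasses parsed by `get_train_args`, while the concrete values (model path, dataset name, output directory) are placeholders:

from llmtuner.tuner.tune import run_exp

run_exp(dict(
    stage="sft",                               # one of: pt, sft, rm, ppo, dpo
    model_name_or_path="path/to/base-model",   # placeholder path
    do_train=True,
    dataset="my_dataset",                      # assumed to be registered in data/dataset_info.json
    finetuning_type="lora",
    output_dir="saves/my-sft-run",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    num_train_epochs=3.0,
))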
--------------------------------------------------------------------------------
/llmtuner/webui/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.webui.interface import create_ui, create_web_demo
2 |
--------------------------------------------------------------------------------
/llmtuner/webui/chatter.py:
--------------------------------------------------------------------------------
1 | from gradio.components import Component # cannot use TYPE_CHECKING here
2 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
3 |
4 | from llmtuner.chat.stream_chat import ChatModel
5 | from llmtuner.extras.misc import torch_gc
6 | from llmtuner.hparams import GeneratingArguments
7 | from llmtuner.webui.common import get_save_dir
8 | from llmtuner.webui.locales import ALERTS
9 |
10 | if TYPE_CHECKING:
11 | from llmtuner.webui.manager import Manager
12 |
13 |
14 | class WebChatModel(ChatModel):
15 |
16 | def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
17 | self.manager = manager
18 | self.model = None
19 | self.tokenizer = None
20 | self.generating_args = GeneratingArguments()
21 | if not lazy_init:
22 | super().__init__()
23 |
24 | @property
25 | def loaded(self) -> bool:
26 | return self.model is not None
27 |
28 | def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
29 | get = lambda name: data[self.manager.get_elem(name)]
30 | lang = get("top.lang")
31 |
32 | if self.loaded:
33 | yield ALERTS["err_exists"][lang]
34 | return
35 |
36 | if not get("top.model_name"):
37 | yield ALERTS["err_no_model"][lang]
38 | return
39 |
40 | if not get("top.model_path"):
41 | yield ALERTS["err_no_path"][lang]
42 | return
43 |
44 | if get("top.checkpoints"):
45 | checkpoint_dir = ",".join([
46 | get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
47 | ])
48 | else:
49 | checkpoint_dir = None
50 |
51 | yield ALERTS["info_loading"][lang]
52 | args = dict(
53 | model_name_or_path=get("top.model_path"),
54 | checkpoint_dir=checkpoint_dir,
55 | finetuning_type=get("top.finetuning_type"),
56 | quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
57 | template=get("top.template"),
58 | system_prompt=get("top.system_prompt"),
59 | flash_attn=get("top.flash_attn"),
60 | shift_attn=get("top.shift_attn"),
61 | rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
62 | )
63 | super().__init__(args)
64 |
65 | yield ALERTS["info_loaded"][lang]
66 |
67 | def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
68 | get = lambda name: data[self.manager.get_elem(name)]
69 | lang = get("top.lang")
70 |
71 | yield ALERTS["info_unloading"][lang]
72 | self.model = None
73 | self.tokenizer = None
74 | torch_gc()
75 | yield ALERTS["info_unloaded"][lang]
76 |
77 | def predict(
78 | self,
79 | chatbot: List[Tuple[str, str]],
80 | query: str,
81 | history: List[Tuple[str, str]],
82 | system: str,
83 | max_new_tokens: int,
84 | top_p: float,
85 | temperature: float
86 | ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
87 | chatbot.append([query, ""])
88 | response = ""
89 | for new_text in self.stream_chat(
90 | query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
91 | ):
92 | response += new_text
93 | new_history = history + [(query, response)]
94 | chatbot[-1] = [query, self.postprocess(response)]
95 | yield chatbot, new_history
96 |
97 | def postprocess(self, response: str) -> str:
98 | blocks = response.split("```")
99 | for i, block in enumerate(blocks):
100 | if i % 2 == 0:
101 | blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
102 | return "```".join(blocks)
103 |
--------------------------------------------------------------------------------
/llmtuner/webui/common.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import gradio as gr
4 | from typing import Any, Dict, Optional
5 | from transformers.utils import (
6 | WEIGHTS_NAME,
7 | WEIGHTS_INDEX_NAME,
8 | SAFE_WEIGHTS_NAME,
9 | SAFE_WEIGHTS_INDEX_NAME,
10 | ADAPTER_WEIGHTS_NAME,
11 | ADAPTER_SAFE_WEIGHTS_NAME
12 | )
13 |
14 | from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
15 |
16 |
17 | DEFAULT_CACHE_DIR = "cache"
18 | DEFAULT_DATA_DIR = "data"
19 | DEFAULT_SAVE_DIR = "saves"
20 | USER_CONFIG = "user.config"
21 | DATA_CONFIG = "dataset_info.json"
22 | CKPT_NAMES = [
23 | WEIGHTS_NAME,
24 | WEIGHTS_INDEX_NAME,
25 | SAFE_WEIGHTS_NAME,
26 | SAFE_WEIGHTS_INDEX_NAME,
27 | ADAPTER_WEIGHTS_NAME,
28 | ADAPTER_SAFE_WEIGHTS_NAME
29 | ]
30 |
31 |
32 | def get_save_dir(*args) -> os.PathLike:
33 | return os.path.join(DEFAULT_SAVE_DIR, *args)
34 |
35 |
36 | def get_config_path() -> os.PathLike:
37 | return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
38 |
39 |
40 | def load_config() -> Dict[str, Any]:
41 | try:
42 | with open(get_config_path(), "r", encoding="utf-8") as f:
43 | return json.load(f)
44 | except Exception:
45 | return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
46 |
47 |
48 | def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
49 | os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
50 | user_config = load_config()
51 | user_config["lang"] = lang or user_config["lang"]
52 | if model_name:
53 | user_config["last_model"] = model_name
54 | user_config["path_dict"][model_name] = model_path
55 | with open(get_config_path(), "w", encoding="utf-8") as f:
56 | json.dump(user_config, f, indent=2, ensure_ascii=False)
57 |
58 |
59 | def get_model_path(model_name: str) -> str:
60 | user_config = load_config()
61 | return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
62 |
63 |
64 | def get_module(model_name: str) -> str:
65 | return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")
66 |
67 |
68 | def get_template(model_name: str) -> str:
69 | if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
70 | return DEFAULT_TEMPLATE[model_name.split("-")[0]]
71 | return "default"
72 |
73 |
74 | def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
75 | checkpoints = []
76 | save_dir = get_save_dir(model_name, finetuning_type)
77 | if save_dir and os.path.isdir(save_dir):
78 | for checkpoint in os.listdir(save_dir):
79 | if (
80 | os.path.isdir(os.path.join(save_dir, checkpoint))
81 | and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
82 | ):
83 | checkpoints.append(checkpoint)
84 | return gr.update(value=[], choices=checkpoints)
85 |
86 |
87 | def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
88 | try:
89 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
90 | return json.load(f)
91 | except Exception:
92 | print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
93 | return {}
94 |
95 |
96 | def list_dataset(
97 | dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
98 | ) -> Dict[str, Any]:
99 | dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
100 | ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
101 | datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
102 | return gr.update(value=[], choices=datasets)
103 |
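Illustration of the path helpers above (the model and checkpoint names are placeholders): webui artifacts are organized as saves/<model_name>/<finetuning_type>/<checkpoint>, and the user config lives under the cache directory.

from llmtuner.webui.common import get_save_dir, get_config_path

get_save_dir("Baichuan-7B", "lora", "checkpoint-1000")  # "saves/Baichuan-7B/lora/checkpoint-1000" on POSIX systems
get_config_path()                                       # "cache/user.config"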
--------------------------------------------------------------------------------
/llmtuner/webui/components/__init__.py:
--------------------------------------------------------------------------------
1 | from llmtuner.webui.components.top import create_top
2 | from llmtuner.webui.components.train import create_train_tab
3 | from llmtuner.webui.components.eval import create_eval_tab
4 | from llmtuner.webui.components.infer import create_infer_tab
5 | from llmtuner.webui.components.export import create_export_tab
6 | from llmtuner.webui.components.chatbot import create_chat_box
7 |
--------------------------------------------------------------------------------
/llmtuner/webui/components/chatbot.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from typing import TYPE_CHECKING, Dict, Optional, Tuple
3 |
4 | if TYPE_CHECKING:
5 | from gradio.blocks import Block
6 | from gradio.components import Component
7 | from llmtuner.webui.engine import Engine
8 |
9 |
10 | def create_chat_box(
11 | engine: "Engine",
12 | visible: Optional[bool] = False
13 | ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
14 | elem_dict = dict()
15 |
16 | with gr.Box(visible=visible) as chat_box:
17 | chatbot = gr.Chatbot()
18 |
19 | with gr.Row():
20 | with gr.Column(scale=4):
21 | system = gr.Textbox(show_label=False)
22 | query = gr.Textbox(show_label=False, lines=8)
23 | submit_btn = gr.Button(variant="primary")
24 |
25 | with gr.Column(scale=1):
26 | clear_btn = gr.Button()
27 | gen_kwargs = engine.chatter.generating_args
28 | max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
29 | top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
30 | temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
31 |
32 | elem_dict.update(dict(
33 | system=system, query=query, submit_btn=submit_btn, clear_btn=clear_btn,
34 | max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
35 | ))
36 |
37 | history = gr.State([])
38 |
39 | submit_btn.click(
40 | engine.chatter.predict,
41 | [chatbot, query, history, system, max_new_tokens, top_p, temperature],
42 | [chatbot, history],
43 | show_progress=True
44 | ).then(
45 | lambda: gr.update(value=""), outputs=[query]
46 | )
47 |
48 | clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
49 |
50 | return chat_box, chatbot, history, elem_dict
51 |
--------------------------------------------------------------------------------
/llmtuner/webui/components/data.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from typing import TYPE_CHECKING, Tuple
3 |
4 | if TYPE_CHECKING:
5 | from gradio.blocks import Block
6 | from gradio.components import Component
7 |
8 |
9 | def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
10 | with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
11 | preview_count = gr.Number(interactive=False)
12 | preview_samples = gr.JSON(interactive=False)
13 | close_btn = gr.Button()
14 |
15 | close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
16 |
17 | return preview_box, preview_count, preview_samples, close_btn
18 |
--------------------------------------------------------------------------------
/llmtuner/webui/components/eval.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from typing import TYPE_CHECKING, Dict
3 |
4 | from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
5 | from llmtuner.webui.components.data import create_preview_box
6 | from llmtuner.webui.utils import can_preview, get_preview
7 |
8 | if TYPE_CHECKING:
9 | from gradio.components import Component
10 | from llmtuner.webui.engine import Engine
11 |
12 |
13 | def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
14 | input_elems = engine.manager.get_base_elems()
15 | elem_dict = dict()
16 |
17 | with gr.Row():
18 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
19 | dataset = gr.Dropdown(multiselect=True, scale=4)
20 | data_preview_btn = gr.Button(interactive=False, scale=1)
21 |
22 | dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
23 | dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
24 |
25 | input_elems.update({dataset_dir, dataset})
26 | elem_dict.update(dict(
27 | dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
28 | ))
29 |
30 | preview_box, preview_count, preview_samples, close_btn = create_preview_box()
31 |
32 | data_preview_btn.click(
33 | get_preview,
34 | [dataset_dir, dataset],
35 | [preview_count, preview_samples, preview_box],
36 | queue=False
37 | )
38 |
39 | elem_dict.update(dict(
40 | preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
41 | ))
42 |
43 | with gr.Row():
44 | cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
45 | max_samples = gr.Textbox(value="100000")
46 | batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
47 | predict = gr.Checkbox(value=True)
48 |
49 | input_elems.update({cutoff_len, max_samples, batch_size, predict})
50 | elem_dict.update(dict(
51 | cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
52 | ))
53 |
54 | with gr.Row():
55 | max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
56 | top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
57 | temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
58 |
59 | input_elems.update({max_new_tokens, top_p, temperature})
60 | elem_dict.update(dict(
61 | max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
62 | ))
63 |
64 | with gr.Row():
65 | cmd_preview_btn = gr.Button()
66 | start_btn = gr.Button()
67 | stop_btn = gr.Button()
68 |
69 | with gr.Row():
70 | resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
71 | process_bar = gr.Slider(visible=False, interactive=False)
72 |
73 | with gr.Box():
74 | output_box = gr.Markdown()
75 |
76 | output_elems = [output_box, process_bar]
77 | elem_dict.update(dict(
78 | cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
79 | resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
80 | ))
81 |
82 | cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
83 | start_btn.click(engine.runner.run_eval, input_elems, output_elems)
84 | stop_btn.click(engine.runner.set_abort, queue=False)
85 | resume_btn.change(engine.runner.monitor, outputs=output_elems)
86 |
87 | return elem_dict
88 |
--------------------------------------------------------------------------------
/llmtuner/webui/components/export.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from typing import TYPE_CHECKING, Dict
3 |
4 | from llmtuner.webui.utils import save_model
5 |
6 | if TYPE_CHECKING:
7 | from gradio.components import Component
8 | from llmtuner.webui.engine import Engine
9 |
10 |
11 | def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
12 | elem_dict = dict()
13 |
14 | with gr.Row():
15 | export_dir = gr.Textbox()
16 | max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
17 |
18 | export_btn = gr.Button()
19 | info_box = gr.Textbox(show_label=False, interactive=False)
20 |
21 | export_btn.click(
22 | save_model,
23 | [
24 | engine.manager.get_elem("top.lang"),
25 | engine.manager.get_elem("top.model_name"),
26 | engine.manager.get_elem("top.model_path"),
27 | engine.manager.get_elem("top.checkpoints"),
28 | engine.manager.get_elem("top.finetuning_type"),
29 | engine.manager.get_elem("top.template"),
30 | max_shard_size,
31 | export_dir
32 | ],
33 | [info_box]
34 | )
35 |
36 | elem_dict.update(dict(
37 | export_dir=export_dir,
38 | max_shard_size=max_shard_size,
39 | export_btn=export_btn,
40 | info_box=info_box
41 | ))
42 |
43 | return elem_dict
44 |
--------------------------------------------------------------------------------
/llmtuner/webui/components/infer.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from typing import TYPE_CHECKING, Dict
3 |
4 | from llmtuner.webui.components.chatbot import create_chat_box
5 |
6 | if TYPE_CHECKING:
7 | from gradio.components import Component
8 | from llmtuner.webui.engine import Engine
9 |
10 |
11 | def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
12 | input_elems = engine.manager.get_base_elems()
13 | elem_dict = dict()
14 |
15 | with gr.Row():
16 | load_btn = gr.Button()
17 | unload_btn = gr.Button()
18 |
19 | info_box = gr.Textbox(show_label=False, interactive=False)
20 | elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
21 |
22 | chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
23 | elem_dict.update(dict(chat_box=chat_box, **chat_elems))
24 |
25 | load_btn.click(
26 | engine.chatter.load_model, input_elems, [info_box]
27 | ).then(
28 | lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
29 | )
30 |
31 | unload_btn.click(
32 | engine.chatter.unload_model, input_elems, [info_box]
33 | ).then(
34 | lambda: ([], []), outputs=[chatbot, history]
35 | ).then(
36 | lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
37 | )
38 |
39 | return elem_dict
40 |
--------------------------------------------------------------------------------
/llmtuner/webui/components/top.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from typing import TYPE_CHECKING, Dict
3 |
4 | from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
5 | from llmtuner.extras.template import templates
6 | from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
7 | from llmtuner.webui.utils import can_quantize
8 |
9 | if TYPE_CHECKING:
10 | from gradio.components import Component
11 |
12 |
13 | def create_top() -> Dict[str, "Component"]:
14 | available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
15 |
16 | with gr.Row():
17 | lang = gr.Dropdown(choices=["en", "zh"], scale=1)
18 | model_name = gr.Dropdown(choices=available_models, scale=3)
19 | model_path = gr.Textbox(scale=3)
20 |
21 | with gr.Row():
22 | finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
23 | checkpoints = gr.Dropdown(multiselect=True, scale=5)
24 | refresh_btn = gr.Button(scale=1)
25 |
26 | with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
27 | with gr.Row():
28 | quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
29 | template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
30 | system_prompt = gr.Textbox(scale=2)
31 |
32 | with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
33 | with gr.Row():
34 | with gr.Column():
35 | flash_attn = gr.Checkbox(value=False)
36 | shift_attn = gr.Checkbox(value=False)
37 | rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
38 |
39 | model_name.change(
40 | list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
41 | ).then(
42 | get_model_path, [model_name], [model_path], queue=False
43 | ).then(
44 | get_template, [model_name], [template], queue=False
45 | ) # do not save the config here; the model_path.change handler below saves it
46 |
47 | model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
48 |
49 | finetuning_type.change(
50 | list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
51 | ).then(
52 | can_quantize, [finetuning_type], [quantization_bit], queue=False
53 | )
54 |
55 | refresh_btn.click(
56 | list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
57 | )
58 |
59 | return dict(
60 | lang=lang,
61 | model_name=model_name,
62 | model_path=model_path,
63 | finetuning_type=finetuning_type,
64 | checkpoints=checkpoints,
65 | refresh_btn=refresh_btn,
66 | advanced_tab=advanced_tab,
67 | quantization_bit=quantization_bit,
68 | template=template,
69 | system_prompt=system_prompt,
70 | llama_tab=llama_tab,
71 | flash_attn=flash_attn,
72 | shift_attn=shift_attn,
73 | rope_scaling=rope_scaling
74 | )
75 |
--------------------------------------------------------------------------------
/llmtuner/webui/components/train.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from typing import TYPE_CHECKING, Dict
3 | from transformers.trainer_utils import SchedulerType
4 |
5 | from llmtuner.extras.constants import TRAINING_STAGES
6 | from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
7 | from llmtuner.webui.components.data import create_preview_box
8 | from llmtuner.webui.utils import can_preview, get_preview, gen_plot
9 |
10 | if TYPE_CHECKING:
11 | from gradio.components import Component
12 | from llmtuner.webui.engine import Engine
13 |
14 |
15 | def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
16 | input_elems = engine.manager.get_base_elems()
17 | elem_dict = dict()
18 |
19 | with gr.Row():
20 | training_stage = gr.Dropdown(
21 | choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
22 | )
23 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
24 | dataset = gr.Dropdown(multiselect=True, scale=4)
25 | data_preview_btn = gr.Button(interactive=False, scale=1)
26 |
27 | training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
28 | dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
29 | dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False)
30 |
31 | input_elems.update({training_stage, dataset_dir, dataset})
32 | elem_dict.update(dict(
33 | training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn
34 | ))
35 |
36 | preview_box, preview_count, preview_samples, close_btn = create_preview_box()
37 |
38 | data_preview_btn.click(
39 | get_preview,
40 | [dataset_dir, dataset],
41 | [preview_count, preview_samples, preview_box],
42 | queue=False
43 | )
44 |
45 | elem_dict.update(dict(
46 | preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn
47 | ))
48 |
49 | with gr.Row():
50 | cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
51 | learning_rate = gr.Textbox(value="5e-5")
52 | num_train_epochs = gr.Textbox(value="3.0")
53 | max_samples = gr.Textbox(value="100000")
54 | compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
55 |
56 | input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
57 | elem_dict.update(dict(
58 | cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
59 | max_samples=max_samples, compute_type=compute_type
60 | ))
61 |
62 | with gr.Row():
63 | batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
64 | gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
65 | lr_scheduler_type = gr.Dropdown(
66 | choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
67 | )
68 | max_grad_norm = gr.Textbox(value="1.0")
69 | val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
70 |
71 | input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
72 | elem_dict.update(dict(
73 | batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
74 | lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
75 | ))
76 |
77 | with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
78 | with gr.Row():
79 | logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
80 | save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
81 | warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
82 | neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
83 |
84 | with gr.Column():
85 | train_on_prompt = gr.Checkbox(value=False)
86 | upcast_layernorm = gr.Checkbox(value=False)
87 |
88 | input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm})
89 | elem_dict.update(dict(
90 | advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
91 | neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
92 | ))
93 |
94 | with gr.Accordion(label="LoRA config", open=False) as lora_tab:
95 | with gr.Row():
96 | lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
97 | lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
98 | lora_target = gr.Textbox(scale=1)
99 | additional_target = gr.Textbox(scale=1)
100 | resume_lora_training = gr.Checkbox(value=True, scale=1)
101 |
102 | input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, resume_lora_training})
103 | elem_dict.update(dict(
104 | lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
105 | additional_target=additional_target, resume_lora_training=resume_lora_training,
106 | ))
107 |
108 | with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
109 | with gr.Row():
110 | dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
111 | reward_model = gr.Dropdown(scale=3)
112 | refresh_btn = gr.Button(scale=1)
113 |
114 | refresh_btn.click(
115 | list_checkpoint,
116 | [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type")],
117 | [reward_model],
118 | queue=False
119 | )
120 |
121 | input_elems.update({dpo_beta, reward_model})
122 | elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
123 |
124 | with gr.Row():
125 | cmd_preview_btn = gr.Button()
126 | start_btn = gr.Button()
127 | stop_btn = gr.Button()
128 |
129 | with gr.Row():
130 | with gr.Column(scale=3):
131 | with gr.Row():
132 | output_dir = gr.Textbox()
133 |
134 | with gr.Row():
135 | resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
136 | process_bar = gr.Slider(visible=False, interactive=False)
137 |
138 | with gr.Box():
139 | output_box = gr.Markdown()
140 |
141 | with gr.Column(scale=1):
142 | loss_viewer = gr.Plot()
143 |
144 | input_elems.add(output_dir)
145 | output_elems = [output_box, process_bar]
146 | elem_dict.update(dict(
147 | cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
148 | resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
149 | ))
150 |
151 | cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
152 | start_btn.click(engine.runner.run_train, input_elems, output_elems)
153 | stop_btn.click(engine.runner.set_abort, queue=False)
154 | resume_btn.change(engine.runner.monitor, outputs=output_elems)
155 |
156 | output_box.change(
157 | gen_plot,
158 | [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir],
159 | loss_viewer,
160 | queue=False
161 | )
162 |
163 | return elem_dict
164 |
--------------------------------------------------------------------------------
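`create_train_tab` collects components into the `input_elems` set and hands that set to `engine.runner.*` as the event `inputs`. Passing a Python set (rather than a list) makes Gradio call the handler with a single dict keyed by component, which is how the Runner later reads values via `data[self.manager.get_elem(...)]`. A small sketch of this calling convention, with made-up component names:

```python
import gradio as gr

with gr.Blocks() as demo:
    learning_rate = gr.Textbox(value="5e-5")
    num_train_epochs = gr.Textbox(value="3.0")
    output_box = gr.Markdown()
    preview_btn = gr.Button("Preview")

    def preview(data):
        # `data` maps each component object to its current value
        return "lr={}, epochs={}".format(data[learning_rate], data[num_train_epochs])

    # a set of components as `inputs` selects the dict-style calling convention
    preview_btn.click(preview, {learning_rate, num_train_epochs}, [output_box])
```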
/llmtuner/webui/css.py:
--------------------------------------------------------------------------------
1 | CSS = r"""
2 | .modal-box {
3 | position: fixed !important;
4 | top: 50%;
5 | left: 50%;
6 |   transform: translate(-50%, -50%); /* center horizontally and vertically */
7 | max-width: 1000px;
8 | max-height: 750px;
9 | background-color: var(--input-background-fill);
10 | border: 2px solid black !important;
11 | z-index: 1000;
12 | padding: 10px;
13 | }
14 |
15 | .dark .modal-box {
16 | border: 2px solid white !important;
17 | }
18 | """
19 |
--------------------------------------------------------------------------------
/llmtuner/webui/engine.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from gradio.components import Component # cannot use TYPE_CHECKING here
3 | from typing import Any, Dict, Generator, Optional
4 |
5 | from llmtuner.webui.chatter import WebChatModel
6 | from llmtuner.webui.common import get_model_path, list_dataset, load_config
7 | from llmtuner.webui.locales import LOCALES
8 | from llmtuner.webui.manager import Manager
9 | from llmtuner.webui.runner import Runner
10 | from llmtuner.webui.utils import get_time
11 |
12 |
13 | class Engine:
14 |
15 | def __init__(self, pure_chat: Optional[bool] = False) -> None:
16 | self.pure_chat = pure_chat
17 | self.manager: "Manager" = Manager()
18 | self.runner: "Runner" = Runner(self.manager)
19 | self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
20 |
21 | def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
22 | return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()}
23 |
24 | def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
25 | user_config = load_config()
26 | lang = user_config.get("lang", None) or "en"
27 |
28 | init_dict = {
29 | "top.lang": {"value": lang},
30 | "infer.chat_box": {"visible": self.chatter.loaded}
31 | }
32 |
33 | if not self.pure_chat:
34 | init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
35 | init_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
36 |
37 | if user_config.get("last_model", None):
38 | init_dict["top.model_name"] = {"value": user_config["last_model"]}
39 | init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
40 |
41 | yield self._form_dict(init_dict)
42 |
43 | if not self.pure_chat:
44 | if self.runner.alive:
45 | yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()}
46 | if self.runner.do_train:
47 | yield self._form_dict({"train.resume_btn": {"value": True}})
48 | else:
49 | yield self._form_dict({"eval.resume_btn": {"value": True}})
50 | else:
51 | yield self._form_dict({"train.output_dir": {"value": get_time()}})
52 |
53 | def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
54 | return {
55 | component: gr.update(**LOCALES[name][lang])
56 | for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
57 | }
58 |
--------------------------------------------------------------------------------
/llmtuner/webui/interface.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from transformers.utils.versions import require_version
3 |
4 | from llmtuner.webui.components import (
5 | create_top,
6 | create_train_tab,
7 | create_eval_tab,
8 | create_infer_tab,
9 | create_export_tab,
10 | create_chat_box
11 | )
12 | from llmtuner.webui.common import save_config
13 | from llmtuner.webui.css import CSS
14 | from llmtuner.webui.engine import Engine
15 |
16 |
17 | require_version("gradio==3.50.2", "To fix: pip install gradio==3.50.2")
18 |
19 |
20 | def create_ui() -> gr.Blocks:
21 | engine = Engine(pure_chat=False)
22 |
23 | with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
24 | engine.manager.all_elems["top"] = create_top()
25 | lang: "gr.Dropdown" = engine.manager.get_elem("top.lang")
26 |
27 | with gr.Tab("Train"):
28 | engine.manager.all_elems["train"] = create_train_tab(engine)
29 |
30 | with gr.Tab("Evaluate"):
31 | engine.manager.all_elems["eval"] = create_eval_tab(engine)
32 |
33 | with gr.Tab("Chat"):
34 | engine.manager.all_elems["infer"] = create_infer_tab(engine)
35 |
36 | with gr.Tab("Export"):
37 | engine.manager.all_elems["export"] = create_export_tab(engine)
38 |
39 | demo.load(engine.resume, outputs=engine.manager.list_elems())
40 | lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
41 | lang.input(save_config, inputs=[lang], queue=False)
42 |
43 | return demo
44 |
45 |
46 | def create_web_demo() -> gr.Blocks:
47 | engine = Engine(pure_chat=True)
48 |
49 | with gr.Blocks(title="Web Demo", css=CSS) as demo:
50 | lang = gr.Dropdown(choices=["en", "zh"])
51 | engine.manager.all_elems["top"] = dict(lang=lang)
52 |
53 | chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
54 | engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
55 |
56 | demo.load(engine.resume, outputs=engine.manager.list_elems())
57 | lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
58 | lang.input(save_config, inputs=[lang], queue=False)
59 |
60 | return demo
61 |
62 |
63 | if __name__ == "__main__":
64 | demo = create_ui()
65 | demo.queue()
66 | demo.launch(server_name="0.0.0.0", server_port=7866, share=False, inbrowser=True)
67 |
--------------------------------------------------------------------------------
/llmtuner/webui/manager.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Dict, List, Set
2 |
3 | if TYPE_CHECKING:
4 | from gradio.components import Component
5 |
6 |
7 | class Manager:
8 |
9 | def __init__(self) -> None:
10 | self.all_elems: Dict[str, Dict[str, "Component"]] = {}
11 |
12 | def get_elem(self, name: str) -> "Component":
13 | r"""
14 | Example: top.lang, train.dataset
15 | """
16 | tab_name, elem_name = name.split(".")
17 | return self.all_elems[tab_name][elem_name]
18 |
19 |     def get_base_elems(self) -> Set["Component"]:
20 | return {
21 | self.all_elems["top"]["lang"],
22 | self.all_elems["top"]["model_name"],
23 | self.all_elems["top"]["model_path"],
24 | self.all_elems["top"]["checkpoints"],
25 | self.all_elems["top"]["finetuning_type"],
26 | self.all_elems["top"]["quantization_bit"],
27 | self.all_elems["top"]["template"],
28 | self.all_elems["top"]["system_prompt"],
29 | self.all_elems["top"]["flash_attn"],
30 | self.all_elems["top"]["shift_attn"],
31 | self.all_elems["top"]["rope_scaling"]
32 | }
33 |
34 | def list_elems(self) -> List["Component"]:
35 | return [elem for elems in self.all_elems.values() for elem in elems.values()]
36 |
--------------------------------------------------------------------------------
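A short usage sketch for `Manager` (hypothetical components; it assumes the repository root is on `PYTHONPATH` so that `llmtuner` is importable): components are registered per tab under `all_elems`, then looked up with a dotted `"tab.elem"` key.

```python
import gradio as gr
from llmtuner.webui.manager import Manager

manager = Manager()
with gr.Blocks():
    lang = gr.Dropdown(choices=["en", "zh"])
    model_name = gr.Dropdown()

# register components per tab, then resolve them with a dotted key
manager.all_elems["top"] = dict(lang=lang, model_name=model_name)
assert manager.get_elem("top.lang") is lang
print(len(manager.list_elems()))  # -> 2, a flat list of every registered component
```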
/llmtuner/webui/runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import logging
4 | import gradio as gr
5 | from threading import Thread
6 | from gradio.components import Component # cannot use TYPE_CHECKING here
7 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
8 |
9 | import transformers
10 | from transformers.trainer import TRAINING_ARGS_NAME
11 |
12 | from llmtuner.extras.callbacks import LogCallback
13 | from llmtuner.extras.constants import TRAINING_STAGES
14 | from llmtuner.extras.logging import LoggerHandler
15 | from llmtuner.extras.misc import torch_gc
16 | from llmtuner.tuner import run_exp
17 | from llmtuner.webui.common import get_module, get_save_dir, load_config
18 | from llmtuner.webui.locales import ALERTS
19 | from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
20 |
21 | if TYPE_CHECKING:
22 | from llmtuner.webui.manager import Manager
23 |
24 |
25 | class Runner:
26 |
27 | def __init__(self, manager: "Manager") -> None:
28 | self.manager = manager
29 | self.thread: "Thread" = None
30 | self.data: Dict["Component", Any] = None
31 | self.do_train = True
32 | self.monitor_inputs: Dict[str, str] = None
33 | self.aborted = False
34 | self.running = False
35 | self.logger_handler = LoggerHandler()
36 | self.logger_handler.setLevel(logging.INFO)
37 | logging.root.addHandler(self.logger_handler)
38 | transformers.logging.add_handler(self.logger_handler)
39 |
40 | @property
41 | def alive(self) -> bool:
42 | return self.thread is not None
43 |
44 | def set_abort(self) -> None:
45 | self.aborted = True
46 | self.running = False
47 |
48 | def _initialize(self, lang: str, model_name: str, model_path: str, dataset: List[str]) -> str:
49 | if self.running:
50 | return ALERTS["err_conflict"][lang]
51 |
52 | if not model_name:
53 | return ALERTS["err_no_model"][lang]
54 |
55 | if not model_path:
56 | return ALERTS["err_no_path"][lang]
57 |
58 | if len(dataset) == 0:
59 | return ALERTS["err_no_dataset"][lang]
60 |
61 | self.aborted = False
62 | self.logger_handler.reset()
63 | self.trainer_callback = LogCallback(self)
64 | return ""
65 |
66 | def _finalize(self, lang: str, finish_info: str) -> str:
67 | self.thread = None
68 | self.running = False
69 | torch_gc()
70 | if self.aborted:
71 | return ALERTS["info_aborted"][lang]
72 | else:
73 | return finish_info
74 |
75 | def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
76 | get = lambda name: data[self.manager.get_elem(name)]
77 | user_config = load_config()
78 |
79 | if get("top.checkpoints"):
80 | checkpoint_dir = ",".join([
81 | get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
82 | ])
83 | else:
84 | checkpoint_dir = None
85 |
86 | output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
87 |
88 | args = dict(
89 | stage=TRAINING_STAGES[get("train.training_stage")],
90 | model_name_or_path=get("top.model_path"),
91 | do_train=True,
92 | cache_dir=user_config.get("cache_dir", None),
93 | checkpoint_dir=checkpoint_dir,
94 | finetuning_type=get("top.finetuning_type"),
95 | quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
96 | template=get("top.template"),
97 | system_prompt=get("top.system_prompt"),
98 | flash_attn=get("top.flash_attn"),
99 | shift_attn=get("top.shift_attn"),
100 | rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
101 | dataset_dir=get("train.dataset_dir"),
102 | dataset=",".join(get("train.dataset")),
103 | cutoff_len=get("train.cutoff_len"),
104 | learning_rate=float(get("train.learning_rate")),
105 | num_train_epochs=float(get("train.num_train_epochs")),
106 | max_samples=int(get("train.max_samples")),
107 | per_device_train_batch_size=get("train.batch_size"),
108 | gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
109 | lr_scheduler_type=get("train.lr_scheduler_type"),
110 | max_grad_norm=float(get("train.max_grad_norm")),
111 | logging_steps=get("train.logging_steps"),
112 | save_steps=get("train.save_steps"),
113 | warmup_steps=get("train.warmup_steps"),
114 | neft_alpha=get("train.neft_alpha"),
115 | train_on_prompt=get("train.train_on_prompt"),
116 | upcast_layernorm=get("train.upcast_layernorm"),
117 | lora_rank=get("train.lora_rank"),
118 | lora_dropout=get("train.lora_dropout"),
119 | lora_target=get("train.lora_target") or get_module(get("top.model_name")),
120 | additional_target=get("train.additional_target") if get("train.additional_target") else None,
121 | resume_lora_training=get("train.resume_lora_training"),
122 | output_dir=output_dir
123 | )
124 | args[get("train.compute_type")] = True
125 | args["disable_tqdm"] = True
126 |
127 | if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
128 | args["resume_lora_training"] = (args["quantization_bit"] is not None)
129 |
130 | if args["quantization_bit"] is not None:
131 | args["upcast_layernorm"] = True
132 |
133 | if args["stage"] == "ppo":
134 | args["reward_model"] = get("train.reward_model")
135 |
136 | if args["stage"] == "dpo":
137 | args["dpo_beta"] = get("train.dpo_beta")
138 |
139 | if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
140 | args["val_size"] = get("train.val_size")
141 | args["evaluation_strategy"] = "steps"
142 | args["eval_steps"] = get("train.save_steps")
143 | args["load_best_model_at_end"] = True
144 |
145 | return get("top.lang"), get("top.model_name"), get("top.model_path"), get("train.dataset"), output_dir, args
146 |
147 | def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]:
148 | get = lambda name: data[self.manager.get_elem(name)]
149 | user_config = load_config()
150 |
151 | if get("top.checkpoints"):
152 | checkpoint_dir = ",".join([
153 | get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
154 | ])
155 | output_dir = get_save_dir(
156 | get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
157 | )
158 | else:
159 | checkpoint_dir = None
160 | output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")
161 |
162 | args = dict(
163 | stage="sft",
164 | model_name_or_path=get("top.model_path"),
165 | do_eval=True,
166 | predict_with_generate=True,
167 | cache_dir=user_config.get("cache_dir", None),
168 | checkpoint_dir=checkpoint_dir,
169 | finetuning_type=get("top.finetuning_type"),
170 | quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
171 | template=get("top.template"),
172 | system_prompt=get("top.system_prompt"),
173 | flash_attn=get("top.flash_attn"),
174 | shift_attn=get("top.shift_attn"),
175 | rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
176 | dataset_dir=get("eval.dataset_dir"),
177 | dataset=",".join(get("eval.dataset")),
178 | cutoff_len=get("eval.cutoff_len"),
179 | max_samples=int(get("eval.max_samples")),
180 | per_device_eval_batch_size=get("eval.batch_size"),
181 | max_new_tokens=get("eval.max_new_tokens"),
182 | top_p=get("eval.top_p"),
183 | temperature=get("eval.temperature"),
184 | output_dir=output_dir
185 | )
186 |
187 | if get("eval.predict"):
188 | args.pop("do_eval", None)
189 | args["do_predict"] = True
190 |
191 | return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args
192 |
193 | def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
194 | parse_func = self._parse_train_args if do_train else self._parse_eval_args
195 | lang, model_name, model_path, dataset, _, args = parse_func(data)
196 | error = self._initialize(lang, model_name, model_path, dataset)
197 | if error:
198 | yield error, gr.update(visible=False)
199 | else:
200 | yield gen_cmd(args), gr.update(visible=False)
201 |
202 | def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
203 | parse_func = self._parse_train_args if do_train else self._parse_eval_args
204 | lang, model_name, model_path, dataset, output_dir, args = parse_func(data)
205 | self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir)
206 | error = self._initialize(lang, model_name, model_path, dataset)
207 | if error:
208 | yield error, gr.update(visible=False)
209 | else:
210 | self.running = True
211 | run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
212 | self.thread = Thread(target=run_exp, kwargs=run_kwargs)
213 | self.thread.start()
214 | yield from self.monitor()
215 |
216 | def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
217 | yield from self._preview(data, do_train=True)
218 |
219 | def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
220 | yield from self._preview(data, do_train=False)
221 |
222 | def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
223 | yield from self._launch(data, do_train=True)
224 |
225 | def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
226 | yield from self._launch(data, do_train=False)
227 |
228 | def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
229 | lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
230 | while self.thread.is_alive():
231 | time.sleep(2)
232 | if self.aborted:
233 | yield ALERTS["info_aborting"][lang], gr.update(visible=False)
234 | else:
235 | yield self.logger_handler.log, update_process_bar(self.trainer_callback)
236 |
237 | if self.do_train:
238 | if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
239 | finish_info = ALERTS["info_finished"][lang]
240 | else:
241 | finish_info = ALERTS["err_failed"][lang]
242 | else:
243 | if os.path.exists(os.path.join(output_dir, "all_results.json")):
244 | finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
245 | else:
246 | finish_info = ALERTS["err_failed"][lang]
247 |
248 | yield self._finalize(lang, finish_info), gr.update(visible=False)
249 |
--------------------------------------------------------------------------------
/llmtuner/webui/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import gradio as gr
4 | import matplotlib.figure
5 | import matplotlib.pyplot as plt
6 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
7 | from datetime import datetime
8 |
9 | from llmtuner.extras.ploting import smooth
10 | from llmtuner.tuner import export_model
11 | from llmtuner.webui.common import get_save_dir, DATA_CONFIG
12 | from llmtuner.webui.locales import ALERTS
13 |
14 | if TYPE_CHECKING:
15 | from llmtuner.extras.callbacks import LogCallback
16 |
17 |
18 | def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
19 | if not callback.max_steps:
20 | return gr.update(visible=False)
21 |
22 | percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
23 | label = "Running {:d}/{:d}: {} < {}".format(
24 | callback.cur_steps,
25 | callback.max_steps,
26 | callback.elapsed_time,
27 | callback.remaining_time
28 | )
29 | return gr.update(label=label, value=percentage, visible=True)
30 |
31 |
32 | def get_time() -> str:
33 | return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
34 |
35 |
36 | def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
37 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
38 | dataset_info = json.load(f)
39 |
40 | if (
41 | len(dataset) > 0
42 | and "file_name" in dataset_info[dataset[0]]
43 | and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
44 | ):
45 | return gr.update(interactive=True)
46 | else:
47 | return gr.update(interactive=False)
48 |
49 |
50 | def get_preview(
51 | dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2
52 | ) -> Tuple[int, list, Dict[str, Any]]:
53 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
54 | dataset_info = json.load(f)
55 |
56 | data_file: str = dataset_info[dataset[0]]["file_name"]
57 | with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
58 | if data_file.endswith(".json"):
59 | data = json.load(f)
60 | elif data_file.endswith(".jsonl"):
61 | data = [json.loads(line) for line in f]
62 | else:
63 | data = [line for line in f]
64 | return len(data), data[start:end], gr.update(visible=True)
65 |
66 |
67 | def can_quantize(finetuning_type: str) -> Dict[str, Any]:
68 | if finetuning_type != "lora":
69 |         return gr.update(value="none", interactive=False)  # match the "none" choice of the quantization dropdown
70 | else:
71 | return gr.update(interactive=True)
72 |
73 |
74 | def gen_cmd(args: Dict[str, Any]) -> str:
75 | args.pop("disable_tqdm", None)
76 | args["plot_loss"] = args.get("do_train", None)
77 |     cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python train_bash.py "]
78 | for k, v in args.items():
79 | if v is not None and v != "":
80 | cmd_lines.append(" --{} {} ".format(k, str(v)))
81 | cmd_text = "\\\n".join(cmd_lines)
82 | cmd_text = "```bash\n{}\n```".format(cmd_text)
83 | return cmd_text
84 |
85 |
86 | def get_eval_results(path: os.PathLike) -> str:
87 | with open(path, "r", encoding="utf-8") as f:
88 | result = json.dumps(json.load(f), indent=4)
89 | return "```json\n{}\n```\n".format(result)
90 |
91 |
92 | def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
93 | log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
94 | if not os.path.isfile(log_file):
95 | return None
96 |
97 | plt.close("all")
98 | fig = plt.figure()
99 | ax = fig.add_subplot(111)
100 | steps, losses = [], []
101 | with open(log_file, "r", encoding="utf-8") as f:
102 | for line in f:
103 | log_info = json.loads(line)
104 | if log_info.get("loss", None):
105 | steps.append(log_info["current_steps"])
106 | losses.append(log_info["loss"])
107 |
108 | if len(losses) == 0:
109 | return None
110 |
111 | ax.plot(steps, losses, alpha=0.4, label="original")
112 | ax.plot(steps, smooth(losses), label="smoothed")
113 | ax.legend()
114 | ax.set_xlabel("step")
115 | ax.set_ylabel("loss")
116 | return fig
117 |
118 |
119 | def save_model(
120 | lang: str,
121 | model_name: str,
122 | model_path: str,
123 | checkpoints: List[str],
124 | finetuning_type: str,
125 | template: str,
126 | max_shard_size: int,
127 | export_dir: str
128 | ) -> Generator[str, None, None]:
129 | if not model_name:
130 | yield ALERTS["err_no_model"][lang]
131 | return
132 |
133 | if not model_path:
134 | yield ALERTS["err_no_path"][lang]
135 | return
136 |
137 | if not checkpoints:
138 | yield ALERTS["err_no_checkpoint"][lang]
139 | return
140 |
141 | if not export_dir:
142 | yield ALERTS["err_no_export_dir"][lang]
143 | return
144 |
145 | args = dict(
146 | model_name_or_path=model_path,
147 | checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
148 | finetuning_type=finetuning_type,
149 | template=template,
150 | export_dir=export_dir
151 | )
152 |
153 | yield ALERTS["info_exporting"][lang]
154 | export_model(args, max_shard_size="{}GB".format(max_shard_size))
155 | yield ALERTS["info_exported"][lang]
156 |
--------------------------------------------------------------------------------
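For reference, a hedged example of what `gen_cmd` does with an argument dict like the one `Runner._parse_train_args` builds; every value below is a placeholder. Keys whose value is `None` or an empty string are dropped from the generated command.

```python
from llmtuner.webui.utils import gen_cmd

args = dict(
    stage="sft",
    do_train=True,
    model_name_or_path="path/to/base_model",  # placeholder
    dataset="example_dataset",                # placeholder
    finetuning_type="lora",
    quantization_bit=None,                    # dropped from the command
    output_dir="finetune/example_run",
)
print(gen_cmd(args))  # a markdown-fenced "CUDA_VISIBLE_DEVICES=0 python ..." command
```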
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.0
2 | torchvision==0.15.1
3 | transformers==4.33.2
4 | datasets==2.14.6
5 | accelerate==0.24.1
6 | peft==0.6.0
7 | trl>=0.7.2
8 | scipy==1.10.1
9 | sentencepiece
10 | protobuf
11 | tiktoken
12 | fire
13 | jieba==0.42.1
14 | rouge-chinese
15 | nltk
16 | gradio==3.50.2
17 | uvicorn
18 | pydantic==1.10.11
19 | fastapi==0.95.1
20 | sse-starlette
21 | matplotlib
22 | scikit-learn==1.2.2
23 | openpyxl==3.0.10
24 | pandas==1.5.3
25 | pynvml
26 | xlsxwriter
--------------------------------------------------------------------------------
/train_bash.py:
--------------------------------------------------------------------------------
1 | from llmtuner import run_exp
2 |
3 |
4 | def main():
5 | run_exp()
6 |
7 |
8 | def _mp_fn(index):
9 | # For xla_spawn (TPUs)
10 | main()
11 |
12 |
13 | if __name__ == "__main__":
14 | main()
15 |
--------------------------------------------------------------------------------
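`train_bash.py` simply forwards to `run_exp`, which also accepts an argument dict directly, as the web UI's `Runner` does. A minimal programmatic sketch under that assumption; the model path, dataset name and LoRA target modules are placeholders:

```python
from llmtuner import run_exp

# roughly equivalent to: python train_bash.py --stage sft --do_train ... (see gen_cmd)
run_exp(args=dict(
    stage="sft",
    do_train=True,
    model_name_or_path="path/to/base_model",  # placeholder
    dataset="example_dataset",                # placeholder, must be declared in the dataset config
    dataset_dir="data",
    finetuning_type="lora",
    lora_target="q_proj,v_proj",              # placeholder target modules
    output_dir="finetune/example_run",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    num_train_epochs=3.0,
    fp16=True,
))
```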
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import json
3 | import datetime,time
4 | import shutil
5 | import os
6 | import re
7 | import subprocess
8 | from sklearn.metrics import classification_report
9 | from pynvml import (nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
10 | nvmlDeviceGetName, nvmlDeviceGetMemoryInfo, nvmlShutdown)
11 |
12 | def read_excel_file(file_path):
13 | df = pd.read_excel(file_path)
14 | return df
15 |
16 | def save_to_excel(df, file_path):
17 | df.to_excel(file_path, index=False)
18 |
19 | def get_available_gpu(threshold=20000):
20 | # Initialize NVML
21 | nvmlInit()
22 |
23 | # Get the number of GPU devices
24 | device_count = nvmlDeviceGetCount()
25 |
26 | # Find GPU devices with available memory greater than the threshold
27 | available_gpus = []
28 | for i in range(device_count):
29 | handle = nvmlDeviceGetHandleByIndex(i)
30 | info = nvmlDeviceGetMemoryInfo(handle)
31 | free_memory_mb = info.free / 1024 / 1024
32 |
33 | if free_memory_mb > threshold:
34 | available_gpus.append(i)
35 |
36 | # Shutdown NVML
37 | nvmlShutdown()
38 |
39 | return available_gpus
40 |
41 |
42 | def is_numeric(value):
43 | try:
44 |         float(value) # try converting the value to a float
45 |         return True # conversion succeeded, so the value counts as numeric
46 |     except (ValueError, TypeError):
47 |         return False # conversion failed or the type cannot be converted, so the value is not numeric
48 |
49 |
50 | def accuracy_cal(list1, list2):
51 | count = 0
52 | for i in range(len(list1)):
53 | l = list1[i]
54 | r = list2[i]
55 | l = float(l) if is_numeric(l) else str(l)
56 | # print('l',is_numeric(l),l)
57 | r = float(r) if is_numeric(r) else str(r)
58 | # print('r',is_numeric(r),r)
59 | if l == r:
60 | count += 1
61 | accuracy = count / len(list1)
62 | return accuracy
63 |
64 | def evaluate_model(pred_path, gold_path):
65 | output_path = 'data/evaluate_scores.xlsx'
66 |
67 |     # read the gold-label Excel workbook
68 | gold_data = pd.read_excel(gold_path, sheet_name=None)
69 |
70 |     # read the prediction Excel workbook
71 | pred_data = pd.read_excel(pred_path, sheet_name=None)
72 |
73 |     # create the output results Excel workbook
74 | # writer = pd.ExcelWriter(output_path)
75 | writer = pd.ExcelWriter(output_path, engine='xlsxwriter')
76 |
77 |     # define the results DataFrame
78 | result_df = pd.DataFrame(columns=['sheet', 'field', 'precision', 'recall', 'f1-score', 'accuracy', 'support'])
79 |
80 |     # iterate over every sheet in the gold-label workbook
81 | for sheet_name in gold_data.keys():
82 |         # get the matching sheet from the gold and prediction workbooks
83 | gold_sheet = gold_data[sheet_name]
84 | pred_sheet = pred_data.get(sheet_name, None)
85 |
86 |         # warn if this sheet is missing from the prediction file
87 | if pred_sheet is None:
88 | print(f"Warning: Sheet '{sheet_name}' not found in prediction file.")
89 | continue
90 |
91 |         # collect the columns of the gold and prediction sheets
92 | gold_columns = sorted(gold_sheet.columns.tolist())
93 | pred_columns = sorted(pred_sheet.columns.tolist())
94 |
95 |         # make sure the gold and prediction columns match
96 | if set(gold_columns) != set(pred_columns):
97 | print(f"Warning: Columns mismatch in sheet '{sheet_name}'.")
98 | continue
99 |
100 |         # extract the label data, filling empty cells with '未提及' ("not mentioned")
101 | gold_labels = gold_sheet.fillna(value='未提及')
102 | pred_labels = pred_sheet.fillna(value='未提及')
103 |         # check whether the predictions still contain missing values
104 | if pred_labels.isnull().any().any():
105 | print(f"Warning: Missing values detected in sheet '{sheet_name}'.")
106 |
107 |         # align the prediction column dtypes with the corresponding gold-label columns
108 | for col in gold_columns:
109 | if gold_labels[col].dtype != pred_labels[col].dtype:
110 | pred_labels[col] = pred_labels[col].fillna('').astype(str)
111 | gold_labels[col] = gold_labels[col].fillna('').astype(str)
112 | max_lengths = {col: pred_labels[col].apply(lambda x: str(x)).str.len().max() for col in pred_labels.columns}
113 | pred_labels = pred_labels.fillna('').astype(str)
114 | gold_labels = gold_labels.fillna('').astype(str)
115 |         # compute accuracy, F1 score, precision and recall per field
116 | for col in gold_columns:
117 | report = classification_report(gold_labels[col], pred_labels[col], output_dict=True, zero_division=0)
118 | accuracy = accuracy_cal(list(gold_labels[col]), list(pred_labels[col])) #report['accuracy']
119 | new_row = pd.DataFrame({'sheet': sheet_name, 'field': col,
120 | 'precision': report['weighted avg']['precision'],
121 | 'recall': report['weighted avg']['recall'],
122 | 'f1-score': report['weighted avg']['f1-score'],
123 | 'accuracy': accuracy,
124 | 'support': report['weighted avg']['support']}, index=[0])
125 | result_df = pd.concat([result_df, new_row], axis=0, ignore_index=True)
126 |     # write the results to a new sheet in the output workbook
127 | result_df.to_excel(writer, sheet_name='result', index=False)
128 | workbook = writer.book
129 | worksheet = writer.sheets['result']
130 | percent_format = workbook.add_format({'num_format': '0%'})
131 | worksheet.set_column('C:F', None, percent_format)
132 | worksheet.conditional_format('C2:F500', {'type': 'data_bar',
133 | 'bar_color': '#FFA500'})
134 |     # save the output workbook
135 | writer.close()
136 | print(f"Comparison complete. Results saved to '{output_path}'.")
137 | return output_path
138 |
139 |
140 | def stop_train_process():
141 | process = subprocess.Popen('ps -ef | grep train_bash.py', shell=True, stdout=subprocess.PIPE)
142 | output, _ = process.communicate()
143 | process.kill()
144 | n = 0
145 |     # parse the ps output to collect process IDs
146 | print('output',output)
147 | try:
148 | lines = output.decode().split('\n')
149 | for line in lines:
150 | if 'train_bash.py' in line:
151 | parts = line.split()
152 | pid = parts[1]
153 |                 # kill the matching process
154 | subprocess.call(['kill', '-9', pid])
155 | n+=1
156 | except Exception as e:
157 | print('error!!',e)
158 |
159 |     return f'停止了{n//2}个进程'  # i.e. "stopped N processes"
160 |
--------------------------------------------------------------------------------
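An illustrative way to call the helpers above from the repository root; both spreadsheet paths are placeholders, and the workbooks must follow the gold/prediction layout `evaluate_model` expects (matching sheet names and columns).

```python
from utils.utils import get_available_gpu, evaluate_model

# GPUs with more than ~20 GB of free memory (the threshold is in MB)
print("available GPUs:", get_available_gpu(threshold=20000))

scores_path = evaluate_model(
    pred_path="data/predictions.xlsx",   # placeholder
    gold_path="data/gold_labels.xlsx",   # placeholder
)
print("per-field scores written to", scores_path)
```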