├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── llm_structure_tool.iml ├── modules.xml └── vcs.xml ├── README.md ├── README_en.md ├── app.py ├── config └── common_config.py ├── environment.yml ├── finetune ├── llm_utils.py └── pulse_utils.py ├── img ├── 1.jpg ├── 2.jpg ├── 3.jpg ├── 4.jpg └── 5.jpg ├── llmtuner ├── __init__.py ├── api │ ├── __init__.py │ ├── app.py │ └── protocol.py ├── chat │ ├── __init__.py │ └── stream_chat.py ├── dsets │ ├── __init__.py │ ├── loader.py │ ├── preprocess.py │ └── utils.py ├── extras │ ├── __init__.py │ ├── callbacks.py │ ├── constants.py │ ├── logging.py │ ├── misc.py │ ├── patches │ │ ├── __init__.py │ │ └── llama_patch.py │ ├── ploting.py │ ├── save_and_load.py │ └── template.py ├── hparams │ ├── __init__.py │ ├── data_args.py │ ├── finetuning_args.py │ ├── general_args.py │ ├── generating_args.py │ └── model_args.py ├── tuner │ ├── __init__.py │ ├── core │ │ ├── __init__.py │ │ ├── adapter.py │ │ ├── loader.py │ │ ├── parser.py │ │ └── utils.py │ ├── dpo │ │ ├── __init__.py │ │ ├── collator.py │ │ ├── trainer.py │ │ └── workflow.py │ ├── ppo │ │ ├── __init__.py │ │ ├── trainer.py │ │ ├── utils.py │ │ └── workflow.py │ ├── pt │ │ ├── __init__.py │ │ └── workflow.py │ ├── rm │ │ ├── __init__.py │ │ ├── collator.py │ │ ├── metric.py │ │ ├── trainer.py │ │ └── workflow.py │ ├── sft │ │ ├── __init__.py │ │ ├── metric.py │ │ ├── trainer.py │ │ └── workflow.py │ └── tune.py └── webui │ ├── __init__.py │ ├── chatter.py │ ├── common.py │ ├── components │ ├── __init__.py │ ├── chatbot.py │ ├── data.py │ ├── eval.py │ ├── export.py │ ├── infer.py │ ├── top.py │ └── train.py │ ├── css.py │ ├── engine.py │ ├── interface.py │ ├── locales.py │ ├── manager.py │ ├── runner.py │ └── utils.py ├── requirements.txt ├── train_bash.py └── utils └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | # 忽略Python缓存文件 3 | __pycache__/ 4 | *.pyc 5 | *.pyo 6 | *.pyd 7 | 8 | .idea/ 9 | 10 | docker/* 11 | 12 | # 忽略一些用户产生的文件 13 | data/* 14 | finetune/*/ 15 | !finetune/*.py -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/llm_structure_tool.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [[中文版](https://github.com/JuneYaooo/llm_structure_tool/blob/main/README.md)] 
[[English](https://github.com/JuneYaooo/llm_structure_tool/blob/main/README_en.md)] 2 | 3 | # 大模型结构化工具 4 | 5 | 该工具是一个可基于常见开源模型进行微调的结构化工具,旨在帮助用户处理和分析文本数据,目前提供了训练,预测,评估一体化的功能。训练预测部分采用了[[llmtuner](https://github.com/hiyouga/LLaMA-Factory)],作为一个核心包引入。 6 | 7 | 它提供了以下常见结构化类型,适用于各种场景下的结构化使用,如病例结构化场景。 8 | 9 | - **单选** 10 | 11 | ![单选案例](img/2.jpg) 12 | 13 | - **多选** 14 | 15 | ![多选案例](img/3.jpg) 16 | 17 | - **提取** 18 | 19 | ![提取案例](img/1.jpg) 20 | 21 | ### 安装 22 | 23 | 首先,克隆本项目到本地计算机: 24 | 25 | ``` 26 | git clone https://github.com/JuneYaooo/llm_structure_tool.git 27 | ``` 28 | 29 | #### 建议 conda 安装 30 | ##### 方法一 31 | ``` 32 | cd llm_structure_tool 33 | conda env create -f environment.yml 34 | ``` 35 | ##### 方法二 36 | ``` 37 | conda create -n llm_structure python=3.9 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | 激活conda环境: 42 | 43 | ``` 44 | conda activate llm_structure 45 | ``` 46 | 47 | 然后运行前端demo: 48 | 49 | ``` 50 | python app.py 51 | ``` 52 | 53 | ### 模型配置 54 | 在config/config.py中,填入自己想要使用的模型地址 55 | 56 | ## 使用方法 57 | 结构化工具将在终端上提供一个简单的交互界面。您可以根据提示输入相关信息,选择要执行的功能。 58 | 59 | ### 单句测试 60 | 61 | 输入一段话,设定规则,进行单选、多选或提取 62 | 63 | **示例:** 64 | 65 | 字段类型:提取 66 | 67 | 字段名:肾上腺肿物大小 68 | 69 | 原文:CT检查示左肾上腺区见大小约5.5 cm×5.7 cm不均匀低密度肿块,边界清楚,增强扫描实性成分中度强化,内见无强化低密度,静脉期明显强化。CT诊断:考虑左肾上腺区肿瘤。B超检查示左肾上腺区见4.6 cm×4.2 cm的低回声区,边界清,有包膜,提示左肾上腺实质性占位声像。 70 | 71 | 72 | 输入不相关的字段,如胃部肿物大小,结果为“未提及” 73 | ![提取案例-对比1](img/4.jpg) 74 | 75 | 输入相关的字段,如肾上腺肿物大小,结果为“约5.5 cm×5.7 cm” 76 | ![提取案例-对比2](img/5.jpg) 77 | 78 | ### 训练 79 | 待填充 80 | 81 | ### 预测 82 | 待填充 83 | 84 | ### 评估 85 | 待填充 86 | 87 | ## 致谢 88 | 89 | - [PULSE](https://github.com/openmedlab/PULSE): 本项目使用了PULSE模型(上海人工智能实验室的医疗开源大模型) 90 | - [llmtuner](https://github.com/hiyouga/LLaMA-Factory): 本项目训练预测代码基于llmtuner 91 | 92 | ## 贡献 93 | 94 | 如果您对该项目感兴趣,欢迎贡献您的代码和改进建议。您可以通过以下方式参与: 95 | 96 | 1. 提交问题和建议到本项目的 Issue 页面。 97 | 2. Fork 本项目并提交您的改进建议,我们将会审查并合并合适的改动。 98 | -------------------------------------------------------------------------------- /README_en.md: -------------------------------------------------------------------------------- 1 | [[中文版](https://github.com/JuneYaooo/llm_structure_tool/blob/main/README.md)] [[English](https://github.com/JuneYaooo/llm_structure_tool/blob/main/README_en.md)] 2 | 3 | # Medical Record Structuring Tool (Under Continuous Update) 4 | 5 | This tool is a versatile structuring tool that allows fine-tuning common open-source models for various text data processing and analysis tasks. It currently provides integrated functionalities for training, prediction, and evaluation, with the training and prediction components utilizing [[llmtuner](https://github.com/hiyouga/LLaMA-Factory)] as a core package. 
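For reference, the single-query path used by the web UI can also be called directly from Python. The sketch below mirrors the `on_query` handler in `app.py`; the LoRA/project name is hypothetical, and the exact return value of `query_model` (defined in `finetune/llm_utils.py`) should be verified before relying on this snippet.

```python
# Illustrative sketch (not shipped with the project): a single "extraction"
# query issued the same way app.py's on_query handler does. The LoRA name
# below is hypothetical; check query_model in finetune/llm_utils.py for the
# exact return type.
from finetune.llm_utils import query_model

result = query_model(
    "PULSE",            # model name: a key of llm_model_dict in config/common_config.py
    "my_lora",          # LoRA/project name created in the web UI (hypothetical)
    "提取",             # field type: 单选 / 多选 / 提取
    "肾上腺肿物大小",    # field name to extract
    "",                 # value range, comma-separated (not needed for extraction)
    "",                 # special requirements, if any
    "CT检查示左肾上腺区见大小约5.5 cm×5.7 cm不均匀低密度肿块,边界清楚,增强扫描实性成分中度强化,内见无强化低密度,静脉期明显强化。"
)
print(result)  # expected to be along the lines of "约5.5 cm×5.7 cm"
```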
6 | 7 | It offers the following common structuring types applicable to various scenarios, such as medical case structuring: 8 | 9 | - Single selection 10 | 11 | ![Single selection example](img/2.jpg) 12 | 13 | - Multiple selection 14 | 15 | ![Multiple selection example](img/3.jpg) 16 | 17 | - Extraction 18 | 19 | ![Extraction example](img/1.jpg) 20 | 21 | ### Installation 22 | 23 | First, clone this project to your local computer: 24 | 25 | ``` 26 | git clone https://github.com/JuneYaooo/llm_structure_tool.git 27 | ``` 28 | 29 | #### Conda Installation (Recommended) 30 | ##### Method 1 31 | ```shell 32 | cd llm_structure_tool 33 | conda env create -f environment.yml 34 | ``` 35 | ##### Method 2 36 | ```shell 37 | conda create -n llm_structure python=3.9 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | Activate the newly created environment: 42 | 43 | ``` 44 | conda activate llm_structure 45 | ``` 46 | 47 | Then run the frontend demo: 48 | 49 | ``` 50 | python app.py 51 | ``` 52 | 53 | ## Usage 54 | 55 | The structuring tool provides a simple interactive interface in the terminal. You can enter relevant information and select the desired functionality as prompted. 56 | 57 | ### Single Sentence Testing 58 | 59 | Enter a paragraph, set the rules, and perform single selection, multiple selection, or extraction. 60 | 61 | **Example:** 62 | 63 | Field Type: 提取 64 | 65 | Field Name: 肾上腺肿物大小 66 | 67 | Original Text: CT检查示左肾上腺区见大小约5.5 cm×5.7 cm不均匀低密度肿块,边界清楚,增强扫描实性成分中度强化,内见无强化低密度,静脉期明显强化。CT诊断:考虑左肾上腺区肿瘤。B超检查示左肾上腺区见4.6 cm×4.2 cm的低回声区,边界清,有包膜,提示左肾上腺实质性占位声像。 68 | 69 | Entering an unrelated field, such as "Gastric Tumor Size," will result in "Not mentioned." 70 | ![Extraction example - Comparison 1](img/4.jpg) 71 | 72 | Entering a related field, such as "Adrenal Tumor Size," will result in "Approximately 5.5 cm × 5.7 cm." 73 | ![Extraction example - Comparison 2](img/5.jpg) 74 | 75 | ### Training 76 | To be filled 77 | 78 | ### Prediction 79 | To be filled 80 | 81 | ### Evaluation 82 | To be filled 83 | 84 | ## Acknowledgments 85 | 86 | - [PULSE](https://github.com/openmedlab/PULSE): This project uses the PULSE model (a medical open-source large language model from the Shanghai Artificial Intelligence Laboratory). 87 | - [llmtuner](https://github.com/hiyouga/LLaMA-Factory): The training and prediction code for this project is based on llmtuner. 88 | 89 | ## Contribution 90 | 91 | If you are interested in this project, you are welcome to contribute your code and improvement suggestions. You can participate in the following ways: 92 | 93 | 1. Submit issues and suggestions to the Issue page of this project. 94 | 2. Fork this project and submit your improvement suggestions. We will review and merge appropriate changes. 
-------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import os 3 | from finetune.llm_utils import infer_model, query_model, train_model 4 | import shutil 5 | import time 6 | import datetime 7 | from config.common_config import * 8 | from utils.utils import stop_train_process,evaluate_model 9 | 10 | llm_model_dict_list = list(llm_model_dict.keys()) 11 | 12 | def get_file_modify_time(filename): 13 | try: 14 | return datetime.datetime.fromtimestamp(os.stat(filename).st_mtime).strftime("%Y-%m-%d %H:%M:%S") 15 | except Exception as e: 16 | print('Failed to get modification time for {}'.format(filename)) 17 | print(e) 18 | return 'not available' 19 | 20 | def get_model_update_time(model_name, lora_name): 21 | model_file_name = llm_model_dict[model_name]['name'] 22 | print('get_model_update_time model_file_name',model_file_name) 23 | print('get_model_update_time lora_name',lora_name) 24 | model_lora_dir = os.path.join(f"finetune", model_file_name,'checkpoints',lora_name,'adapter_model.bin') 25 | print('model_lora_dir',model_lora_dir) 26 | update_time = get_file_modify_time(model_lora_dir) 27 | return update_time 28 | 29 | def on_train(model_name, lora_name, config_file, training_data_file): 30 | config_path = 'data/'+os.path.basename(config_file.name) 31 | training_data_path = 'data/'+os.path.basename(training_data_file.name) 32 | msg = train_model(model_name, lora_name, config_path, training_data_path) 33 | return msg 34 | 35 | def format_duration(seconds): 36 | hours = int(seconds // 3600) 37 | minutes = int((seconds % 3600) // 60) 38 | seconds = round(seconds % 60,2) 39 | if hours > 0: 40 | return f"{hours}时{minutes}分{seconds}秒" 41 | elif minutes > 0: 42 | return f"{minutes}分{seconds}秒" 43 | else: 44 | return f"{seconds}秒" 45 | 46 | 47 | def on_test(model_name, select_lora, config_file, test_data_file): 48 | start_time = time.time() 49 | config_path = 'data/'+os.path.basename(config_file.name) 50 | test_data_path = 'data/'+os.path.basename(test_data_file.name) 51 | 52 | result_path,info = infer_model(model_name, select_lora, config_path, test_data_path) 53 | end_time = time.time() 54 | cost_time = end_time-start_time 55 | 56 | info = '用时:'+format_duration(cost_time)+f" ({round(cost_time,2)}秒)" if info=='success' else info 57 | return result_path,info 58 | 59 | def on_evaluate(model_name, select_lora, test_result_file, test_label_file): 60 | test_result_path = 'data/'+os.path.basename(test_result_file.name) 61 | test_label_path = 'data/'+os.path.basename( test_label_file.name) 62 | result_path = evaluate_model(test_result_path, test_label_path) 63 | return result_path 64 | 65 | def on_query(model_name,project_name, field_type, field_name, value_range,special_requirement, query): 66 | res = query_model(model_name,project_name, field_type, field_name, value_range,special_requirement, query) 67 | return res 68 | 69 | def on_stop(model_name,select_lora): 70 | res = stop_train_process() 71 | return res 72 | 73 | def upload_file(file): 74 | print('file',file) 75 | if not os.path.exists("data"): 76 | os.mkdir("data") 77 | filename = os.path.basename(file.name) 78 | shutil.move(file.name, "data/" + filename) 79 | # file_list首位插入新上传的文件 80 | filedir = "data/" + filename 81 | return filedir 82 | 83 | def change_lora_name_input(model_name,lora_name_en): 84 | if lora_name_en == "新建": 85 | return gr.update(visible=True), gr.update(visible=True), 'not avilable' 86 | 
else: 87 | file_status = f"已加载{lora_name_en}" 88 | model_update_time = get_model_update_time(model_name, lora_name_en) 89 | return gr.update(visible=False), gr.update(visible=False), model_update_time 90 | 91 | 92 | def add_lora(lora_name_en,lora_list): 93 | if lora_name_en in lora_list: 94 | print('名称冲突,不新建') 95 | return gr.update(visible=True,value=lora_name_en), gr.update(visible=False), gr.update(visible=False), lora_list 96 | else: 97 | return gr.update(visible=True, choices=[lora_name_en] + lora_list, value=lora_name_en), gr.update(visible=False), gr.update(visible=False),[lora_name_en] + lora_list 98 | 99 | 100 | def find_folders(directory): 101 | folders = [] 102 | for item in os.listdir(directory): 103 | item_path = os.path.join(directory, item) 104 | if os.path.isdir(item_path): 105 | folders.append(item) 106 | return folders 107 | 108 | 109 | def get_lora_init_list(model_name): 110 | model_file_name = llm_model_dict[model_name]['name'] 111 | model_dir = os.path.join(f"finetune", model_file_name,'checkpoints') 112 | if not os.path.exists(model_dir): 113 | os.makedirs(model_dir) 114 | lora_list = find_folders(model_dir) 115 | return lora_list 116 | 117 | 118 | def get_lora_list(model_name): 119 | model_file_name = llm_model_dict[model_name]['name'] 120 | model_dir = os.path.join(f"finetune", model_file_name,'checkpoints') 121 | if not os.path.exists(model_dir): 122 | os.makedirs(model_dir) 123 | lora_list = find_folders(model_dir) 124 | return gr.update(visible=True, choices=lora_list+ ['新建'], value=lora_list[0] if len(lora_list) > 0 else '新建'), lora_list + ['新建'] 125 | 126 | lora_init_list = get_lora_init_list(llm_model_dict_list[0]) 127 | 128 | webui_title = """ 129 | # 🎉病历结构化🎉 130 | 131 | 可以选择案例测试和[使用excel配置文件进行训练-预测-评估](https://zg0f0ipp6j.feishu.cn/wiki/XC16wwvGgiVSNbkzSPUczqFMn6e) 132 | """ 133 | 134 | def create_tab(): 135 | # 初始化 136 | with gr.Blocks() as demo: 137 | set_lora_list = gr.State(lora_init_list+ ['新建']) 138 | gr.Markdown(webui_title) 139 | with gr.Row(): 140 | with gr.Column(): 141 | model_name = gr.Radio(llm_model_dict_list, 142 | label="选择模型", 143 | value= llm_model_dict_list[0] if len(llm_model_dict_list)>0 else '暂无可选模型', 144 | interactive=True) 145 | with gr.Column(): 146 | select_lora = gr.Dropdown(set_lora_list.value, 147 | label= "选择或者新建一个Lora", 148 | value= set_lora_list.value[0] if len(set_lora_list.value) > 0 else '新建', 149 | interactive=True, 150 | visible=True) 151 | lora_name_en = gr.Textbox(label="请输入Lora英文名称,中间不能有空格,小写字母,单词间可用下划线分开", 152 | lines=1, 153 | interactive=True, 154 | visible=False) 155 | lora_add = gr.Button(value="确认添加Lora", visible=False) 156 | with gr.Row(): 157 | lastest_model = gr.Textbox(type="text", label='模型更新时间(请切换模型或项目刷新显示)') 158 | with gr.Tab("案例测试"): 159 | with gr.Column(): 160 | gr.Markdown(f"初次加载模型可能比较慢,后续会变快") 161 | field_type = gr.Radio(['单选','多选','提取'], 162 | label="字段类型", 163 | value='提取', 164 | interactive=True) 165 | field_name = gr.Textbox(label="字段名", 166 | lines=1, 167 | interactive=True) 168 | value_range = gr.Textbox(label="请输入值域,以','分隔开(对于提取不必输入值域)", 169 | lines=1, 170 | interactive=True) 171 | special_requirement= gr.Textbox(label="特殊说明,假如有的话请填上", 172 | lines=1, 173 | interactive=True) 174 | query = gr.Textbox(label="请输入原文", 175 | lines=1, 176 | interactive=True) 177 | query_button = gr.Button(label="获得结果") 178 | query_res = gr.Textbox(type="text", label='') 179 | 180 | with gr.Tab("训练-预测-评估", visible=False): 181 | gr.Markdown(f""" 182 | Step1:选择一个Lora 183 | Step2:根据任务选择训练 
预测或评估,上传对应的参数文件或者数据标准文件,请等待文件上传成功后再开始执行!""") 184 | with gr.Row(): 185 | with gr.Column(): 186 | gr.Markdown("## 训练") 187 | train_config_file = gr.File(label="上传配置文件", file_types=['.xlsx']) 188 | train_data_file = gr.File(label="上传标注数据文件", file_types=['.xlsx']) 189 | train_button = gr.Button("开始训练", label="训练") 190 | kill_train_button = gr.Button("停止所有训练进程", label="训练") 191 | train_res = gr.Textbox(type="text", label='') 192 | 193 | 194 | with gr.Column(): 195 | gr.Markdown("## 预测") 196 | test_config_file = gr.File(label="上传配置文件", file_types=['.xlsx']) 197 | test_data_file = gr.File(label="上传测试数据文件", file_types=['.xlsx']) 198 | test_button = gr.Button(label="评估") 199 | test_res = gr.Textbox(type="text", label='') 200 | download_test = gr.File(label="下载结果文件") 201 | 202 | with gr.Column(): 203 | gr.Markdown("## 评估") 204 | test_result_file = gr.File(label="上传测试结果文件", file_types=['.xlsx']) 205 | test_label_file = gr.File(label="上传标准结果文件", file_types=['.xlsx']) 206 | evaluate_button = gr.Button(label="评估") 207 | download_evaluate = gr.File(label="下载评估结果") 208 | 209 | select_lora.change(fn=change_lora_name_input, 210 | inputs=[model_name,select_lora], 211 | outputs=[lora_name_en, lora_add,lastest_model]) 212 | lora_add.click(fn=add_lora, 213 | inputs=[lora_name_en,set_lora_list], 214 | outputs=[select_lora, lora_name_en, lora_add,set_lora_list]) 215 | model_name.change(fn=get_lora_list, inputs=[model_name], outputs=[select_lora, set_lora_list]) 216 | train_config_file.upload(upload_file, 217 | inputs=train_config_file) 218 | train_data_file.upload(upload_file, 219 | inputs=train_data_file) 220 | test_config_file.upload(upload_file, 221 | inputs=test_config_file) 222 | test_data_file.upload(upload_file, 223 | inputs=test_data_file) 224 | test_result_file.upload(upload_file, 225 | inputs=test_result_file) 226 | test_label_file.upload(upload_file, 227 | inputs=test_label_file) 228 | train_button.click(on_train, inputs=[model_name, select_lora, train_config_file, train_data_file],outputs=[train_res]) 229 | kill_train_button.click(on_stop, inputs=[model_name, select_lora],outputs=[train_res]) 230 | test_button.click(on_test,show_progress=True, inputs=[model_name, select_lora, test_config_file, test_data_file], outputs=[download_test,test_res]) 231 | evaluate_button.click(on_evaluate,show_progress=True, inputs=[model_name, select_lora,test_result_file, test_label_file], outputs=[download_evaluate]) 232 | query_button.click(on_query,show_progress=True, inputs=[model_name, select_lora, field_type, field_name, value_range, special_requirement, query], outputs=[query_res]) 233 | return demo 234 | 235 | tab = create_tab() 236 | 237 | if __name__ == "__main__": 238 | tab.queue(concurrency_count=5).launch(server_name='0.0.0.0',server_port=33366,share=True, inbrowser=True) # 239 | -------------------------------------------------------------------------------- /config/common_config.py: -------------------------------------------------------------------------------- 1 | 2 | ## 这里配置本地可用的开源模型 3 | llm_model_dict = { 4 | "PULSE": {"name": "pulse", 5 | "model_path": "/path/to/your/model", 6 | "template":"default", 7 | "lora_target":"query_key_value", 8 | "per_device_train_batch_size":2 9 | }, 10 | "InternLM": {"name": "internlm", 11 | "model_path": "/path/to/your/model", 12 | "template":"intern", 13 | "lora_target":"q_proj,v_proj", 14 | "per_device_train_batch_size":2 15 | }, 16 | "ChatGLM2": {"name": "chatglm2", 17 | "model_path": "/path/to/your/model", 18 | "template":"chatglm2", 19 | 
"lora_target":"query_key_value", 20 | "per_device_train_batch_size":4 21 | }, 22 | "ChatGLM3": {"name": "chatglm3", 23 | "model_path": "/path/to/your/model", 24 | "template":"chatglm3", 25 | "lora_target":"query_key_value", 26 | "per_device_train_batch_size":4 27 | }, 28 | } 29 | 30 | # 找到 profile.d/conda.sh 文件的绝对路径,填进来 31 | conda_env_file = '/path-to-your-conda/etc/profile.d/conda.sh' 32 | 33 | # 生成参数 34 | max_length=1500 35 | do_sample=False 36 | temperature=0 -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: llm_structure 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2023.08.22=h06a4308_0 8 | - ld_impl_linux-64=2.38=h1181459_1 9 | - libffi=3.4.4=h6a678d5_0 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=3.0.12=h7f8727e_0 15 | - pip=23.3=py39h06a4308_0 16 | - python=3.9.18=h955ad1f_0 17 | - readline=8.2=h5eee18b_0 18 | - setuptools=68.0.0=py39h06a4308_0 19 | - sqlite=3.41.2=h5eee18b_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - tzdata=2023c=h04d1e81_0 22 | - wheel=0.41.2=py39h06a4308_0 23 | - xz=5.4.2=h5eee18b_0 24 | - zlib=1.2.13=h5eee18b_0 25 | - pip: 26 | - accelerate==0.24.1 27 | - aiofiles==23.2.1 28 | - aiohttp==3.8.6 29 | - aiosignal==1.3.1 30 | - altair==5.1.2 31 | - anyio==4.0.0 32 | - async-timeout==4.0.3 33 | - attrs==23.1.0 34 | - certifi==2023.7.22 35 | - charset-normalizer==3.3.2 36 | - click==8.1.7 37 | - cmake==3.27.7 38 | - contourpy==1.2.0 39 | - cycler==0.12.1 40 | - datasets==2.14.6 41 | - dill==0.3.7 42 | - docstring-parser==0.15 43 | - et-xmlfile==1.1.0 44 | - fastapi==0.95.1 45 | - ffmpy==0.3.1 46 | - filelock==3.13.1 47 | - fire==0.5.0 48 | - fonttools==4.44.0 49 | - frozenlist==1.4.0 50 | - fsspec==2023.10.0 51 | - gradio==3.50.2 52 | - gradio-client==0.6.1 53 | - h11==0.14.0 54 | - httpcore==1.0.1 55 | - httpx==0.25.1 56 | - huggingface-hub==0.18.0 57 | - idna==3.4 58 | - importlib-resources==6.1.1 59 | - jieba==0.42.1 60 | - jinja2==3.1.2 61 | - joblib==1.3.2 62 | - jsonschema==4.19.2 63 | - jsonschema-specifications==2023.7.1 64 | - kiwisolver==1.4.5 65 | - lit==17.0.4 66 | - markdown-it-py==3.0.0 67 | - markupsafe==2.1.3 68 | - matplotlib==3.8.1 69 | - mdurl==0.1.2 70 | - mpmath==1.3.0 71 | - multidict==6.0.4 72 | - multiprocess==0.70.15 73 | - networkx==3.2.1 74 | - nltk==3.8.1 75 | - numpy==1.26.1 76 | - nvidia-cublas-cu11==11.10.3.66 77 | - nvidia-cuda-cupti-cu11==11.7.101 78 | - nvidia-cuda-nvrtc-cu11==11.7.99 79 | - nvidia-cuda-runtime-cu11==11.7.99 80 | - nvidia-cudnn-cu11==8.5.0.96 81 | - nvidia-cufft-cu11==10.9.0.58 82 | - nvidia-curand-cu11==10.2.10.91 83 | - nvidia-cusolver-cu11==11.4.0.1 84 | - nvidia-cusparse-cu11==11.7.4.91 85 | - nvidia-nccl-cu11==2.14.3 86 | - nvidia-nvtx-cu11==11.7.91 87 | - openpyxl==3.0.10 88 | - orjson==3.9.10 89 | - pandas==1.5.3 90 | - peft==0.6.0 91 | - pillow==10.1.0 92 | - protobuf==4.25.0 93 | - pyarrow==14.0.0 94 | - pydantic==1.10.11 95 | - pydub==0.25.1 96 | - pynvml==11.5.0 97 | - pyparsing==3.1.1 98 | - python-multipart==0.0.6 99 | - pytz==2023.3.post1 100 | - pyyaml==6.0.1 101 | - referencing==0.30.2 102 | - regex==2023.10.3 103 | - requests==2.31.0 104 | - rich==13.6.0 105 | - rouge-chinese==1.0.3 106 | - rpds-py==0.12.0 107 | - safetensors==0.4.0 108 | - scikit-learn==1.2.2 109 | - scipy==1.10.1 110 
| - semantic-version==2.10.0 111 | - sentencepiece==0.1.99 112 | - shtab==1.6.4 113 | - sniffio==1.3.0 114 | - sse-starlette==1.6.5 115 | - starlette==0.26.1 116 | - sympy==1.12 117 | - termcolor==2.3.0 118 | - threadpoolctl==3.2.0 119 | - tiktoken==0.5.1 120 | - tokenizers==0.13.3 121 | - toolz==0.12.0 122 | - torch==2.0.0 123 | - torchvision==0.15.1 124 | - tqdm==4.66.1 125 | - transformers==4.33.2 126 | - triton==2.0.0 127 | - trl==0.7.2 128 | - tyro==0.5.12 129 | - urllib3==2.0.7 130 | - uvicorn==0.24.0.post1 131 | - websockets==11.0.3 132 | - xlsxwriter==3.1.9 133 | - xxhash==3.4.1 134 | - yarl==1.9.2 -------------------------------------------------------------------------------- /img/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/img/1.jpg -------------------------------------------------------------------------------- /img/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/img/2.jpg -------------------------------------------------------------------------------- /img/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/img/3.jpg -------------------------------------------------------------------------------- /img/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/img/4.jpg -------------------------------------------------------------------------------- /img/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/img/5.jpg -------------------------------------------------------------------------------- /llmtuner/__init__.py: -------------------------------------------------------------------------------- 1 | # Level: api, webui > chat > tuner > dsets > extras, hparams 2 | 3 | from llmtuner.api import create_app 4 | from llmtuner.chat import ChatModel 5 | from llmtuner.tuner import export_model, run_exp 6 | from llmtuner.webui import create_ui, create_web_demo 7 | 8 | 9 | __version__ = "0.2.0" 10 | -------------------------------------------------------------------------------- /llmtuner/api/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.api.app import create_app 2 | -------------------------------------------------------------------------------- /llmtuner/api/app.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uvicorn 3 | from fastapi import FastAPI, HTTPException, status 4 | from fastapi.middleware.cors import CORSMiddleware 5 | from contextlib import asynccontextmanager 6 | from sse_starlette import EventSourceResponse 7 | from typing import List, Tuple 8 | from pydantic import BaseModel 9 | 10 | from llmtuner.extras.misc import torch_gc 11 | from llmtuner.chat import ChatModel 12 | from llmtuner.api.protocol import ( 13 | Role, 14 | Finish, 15 | ModelCard, 16 | ModelList, 17 | ChatMessage, 18 | DeltaMessage, 19 | ChatCompletionRequest, 20 | ChatCompletionResponse, 
21 | ChatCompletionStreamResponse, 22 | ChatCompletionResponseChoice, 23 | ChatCompletionResponseStreamChoice, 24 | ChatCompletionResponseUsage 25 | ) 26 | 27 | 28 | @asynccontextmanager 29 | async def lifespan(app: FastAPI): # collects GPU memory 30 | yield 31 | torch_gc() 32 | 33 | 34 | def to_json(data: BaseModel) -> str: 35 | try: 36 | return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False) 37 | except: 38 | return data.json(exclude_unset=True, ensure_ascii=False) 39 | 40 | 41 | def create_app(chat_model: ChatModel) -> FastAPI: 42 | app = FastAPI(lifespan=lifespan) 43 | 44 | app.add_middleware( 45 | CORSMiddleware, 46 | allow_origins=["*"], 47 | allow_credentials=True, 48 | allow_methods=["*"], 49 | allow_headers=["*"], 50 | ) 51 | 52 | @app.get("/v1/models", response_model=ModelList) 53 | async def list_models(): 54 | model_card = ModelCard(id="gpt-3.5-turbo") 55 | return ModelList(data=[model_card]) 56 | 57 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK) 58 | async def create_chat_completion(request: ChatCompletionRequest): 59 | if len(request.messages) < 1 or request.messages[-1].role != Role.USER: 60 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") 61 | 62 | query = request.messages[-1].content 63 | prev_messages = request.messages[:-1] 64 | if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM: 65 | system = prev_messages.pop(0).content 66 | else: 67 | system = None 68 | 69 | history = [] 70 | if len(prev_messages) % 2 == 0: 71 | for i in range(0, len(prev_messages), 2): 72 | if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT: 73 | history.append([prev_messages[i].content, prev_messages[i+1].content]) 74 | else: 75 | raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") 76 | 77 | if request.stream: 78 | generate = predict(query, history, system, request) 79 | return EventSourceResponse(generate, media_type="text/event-stream") 80 | 81 | response, (prompt_length, response_length) = chat_model.chat( 82 | query, history, system, 83 | do_sample=request.do_sample, 84 | temperature=request.temperature, 85 | top_p=request.top_p, 86 | max_new_tokens=request.max_tokens, 87 | num_return_sequences=request.n 88 | ) 89 | 90 | usage = ChatCompletionResponseUsage( 91 | prompt_tokens=prompt_length, 92 | completion_tokens=response_length, 93 | total_tokens=prompt_length+response_length 94 | ) 95 | 96 | choices = [ChatCompletionResponseChoice( 97 | index=i, 98 | message=ChatMessage(role=Role.ASSISTANT, content=choice), 99 | finish_reason=Finish.STOP 100 | ) for i, choice in enumerate(response)] 101 | 102 | return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) 103 | 104 | async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest): 105 | choice_data = ChatCompletionResponseStreamChoice( 106 | index=0, 107 | delta=DeltaMessage(role=Role.ASSISTANT), 108 | finish_reason=None 109 | ) 110 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) 111 | yield to_json(chunk) 112 | 113 | for new_text in chat_model.stream_chat( 114 | query, history, system, 115 | do_sample=request.do_sample, 116 | temperature=request.temperature, 117 | top_p=request.top_p, 118 | max_new_tokens=request.max_tokens 119 | ): 120 | if len(new_text) == 0: 121 | continue 122 | 123 | choice_data = ChatCompletionResponseStreamChoice( 
124 | index=0, 125 | delta=DeltaMessage(content=new_text), 126 | finish_reason=None 127 | ) 128 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) 129 | yield to_json(chunk) 130 | 131 | choice_data = ChatCompletionResponseStreamChoice( 132 | index=0, 133 | delta=DeltaMessage(), 134 | finish_reason=Finish.STOP 135 | ) 136 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) 137 | yield to_json(chunk) 138 | yield "[DONE]" 139 | 140 | return app 141 | 142 | 143 | if __name__ == "__main__": 144 | chat_model = ChatModel() 145 | app = create_app(chat_model) 146 | uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) 147 | -------------------------------------------------------------------------------- /llmtuner/api/protocol.py: -------------------------------------------------------------------------------- 1 | import time 2 | from enum import Enum 3 | from pydantic import BaseModel, Field 4 | from typing import List, Optional 5 | 6 | 7 | class Role(str, Enum): 8 | USER = "user" 9 | ASSISTANT = "assistant" 10 | SYSTEM = "system" 11 | 12 | 13 | class Finish(str, Enum): 14 | STOP = "stop" 15 | LENGTH = "length" 16 | 17 | 18 | class ModelCard(BaseModel): 19 | id: str 20 | object: Optional[str] = "model" 21 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 22 | owned_by: Optional[str] = "owner" 23 | 24 | 25 | class ModelList(BaseModel): 26 | object: Optional[str] = "list" 27 | data: Optional[List[ModelCard]] = [] 28 | 29 | 30 | class ChatMessage(BaseModel): 31 | role: Role 32 | content: str 33 | 34 | 35 | class DeltaMessage(BaseModel): 36 | role: Optional[Role] = None 37 | content: Optional[str] = None 38 | 39 | 40 | class ChatCompletionRequest(BaseModel): 41 | model: str 42 | messages: List[ChatMessage] 43 | do_sample: Optional[bool] = True 44 | temperature: Optional[float] = None 45 | top_p: Optional[float] = None 46 | n: Optional[int] = 1 47 | max_tokens: Optional[int] = None 48 | stream: Optional[bool] = False 49 | 50 | 51 | class ChatCompletionResponseChoice(BaseModel): 52 | index: int 53 | message: ChatMessage 54 | finish_reason: Finish 55 | 56 | 57 | class ChatCompletionResponseStreamChoice(BaseModel): 58 | index: int 59 | delta: DeltaMessage 60 | finish_reason: Optional[Finish] = None 61 | 62 | 63 | class ChatCompletionResponseUsage(BaseModel): 64 | prompt_tokens: int 65 | completion_tokens: int 66 | total_tokens: int 67 | 68 | 69 | class ChatCompletionResponse(BaseModel): 70 | id: Optional[str] = "chatcmpl-default" 71 | object: Optional[str] = "chat.completion" 72 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 73 | model: str 74 | choices: List[ChatCompletionResponseChoice] 75 | usage: ChatCompletionResponseUsage 76 | 77 | 78 | class ChatCompletionStreamResponse(BaseModel): 79 | id: Optional[str] = "chatcmpl-default" 80 | object: Optional[str] = "chat.completion.chunk" 81 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 82 | model: str 83 | choices: List[ChatCompletionResponseStreamChoice] 84 | -------------------------------------------------------------------------------- /llmtuner/chat/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.chat.stream_chat import ChatModel 2 | -------------------------------------------------------------------------------- /llmtuner/chat/stream_chat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from 
typing import Any, Dict, Generator, List, Optional, Tuple 3 | from threading import Thread 4 | from transformers import GenerationConfig, TextIteratorStreamer 5 | 6 | from llmtuner.extras.misc import dispatch_model, get_logits_processor 7 | from llmtuner.extras.template import get_template_and_fix_tokenizer 8 | from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer 9 | import re 10 | 11 | class ChatModel: 12 | 13 | def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: 14 | model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args) 15 | self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) 16 | self.tokenizer.padding_side = "left" 17 | self.model = dispatch_model(self.model) 18 | self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer) 19 | self.system_prompt = data_args.system_prompt 20 | 21 | def process_args( 22 | self, 23 | query: str, 24 | history: Optional[List[Tuple[str, str]]] = None, 25 | system: Optional[str] = None, 26 | **input_kwargs 27 | ) -> Tuple[Dict[str, Any], int]: 28 | system = system or self.system_prompt 29 | prompt, _ = self.template.encode_oneturn( 30 | tokenizer=self.tokenizer, query=query, resp="", history=history, system=system 31 | ) 32 | prompt_length = len(prompt) 33 | input_ids = torch.tensor([prompt], device=self.model.device) 34 | 35 | do_sample = input_kwargs.pop("do_sample", None) 36 | temperature = input_kwargs.pop("temperature", None) 37 | top_p = input_kwargs.pop("top_p", None) 38 | top_k = input_kwargs.pop("top_k", None) 39 | num_return_sequences = input_kwargs.pop("num_return_sequences", None) 40 | repetition_penalty = input_kwargs.pop("repetition_penalty", None) 41 | max_length = input_kwargs.pop("max_length", None) 42 | max_new_tokens = input_kwargs.pop("max_new_tokens", None) 43 | 44 | generating_args = self.generating_args.to_dict() 45 | generating_args.update(dict( 46 | do_sample=do_sample if do_sample is not None else generating_args["do_sample"], 47 | temperature=temperature or generating_args["temperature"], 48 | top_p=top_p or generating_args["top_p"], 49 | top_k=top_k or generating_args["top_k"], 50 | num_return_sequences=num_return_sequences or 1, 51 | repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], 52 | eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, 53 | pad_token_id=self.tokenizer.pad_token_id 54 | )) 55 | 56 | if isinstance(num_return_sequences, int) and num_return_sequences > 1: 57 | generating_args["do_sample"] = True 58 | 59 | if max_length: 60 | generating_args.pop("max_new_tokens", None) 61 | generating_args["max_length"] = max_length 62 | 63 | if max_new_tokens: 64 | generating_args.pop("max_length", None) 65 | generating_args["max_new_tokens"] = max_new_tokens 66 | 67 | gen_kwargs = dict( 68 | inputs=input_ids, 69 | generation_config=GenerationConfig(**generating_args), 70 | logits_processor=get_logits_processor() 71 | ) 72 | 73 | return gen_kwargs, prompt_length 74 | 75 | @torch.inference_mode() 76 | def chat( 77 | self, 78 | query: str, 79 | history: Optional[List[Tuple[str, str]]] = None, 80 | system: Optional[str] = None, 81 | **input_kwargs 82 | ) -> Tuple[List[str], Tuple[int, int]]: 83 | gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs) 84 | generate_output = self.model.generate(**gen_kwargs) 85 | outputs = generate_output.tolist()[0][prompt_length:] 86 | response = self.tokenizer.decode(outputs, 
skip_special_tokens=True) 87 | response_length = len(outputs) 88 | response=re.sub(r'Helper: ?', '', response) 89 | return response, (prompt_length, response_length) 90 | 91 | @torch.inference_mode() 92 | def batch_chat( 93 | self, 94 | query: List[str], 95 | history: Optional[List[Tuple[str, str]]] = None, 96 | system: Optional[str] = None, 97 | **input_kwargs 98 | ) -> Tuple[List[str], Tuple[int, int]]: 99 | gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs) 100 | generate_output = self.model.generate(**gen_kwargs) 101 | response_ids = generate_output[:, prompt_length:] 102 | response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) 103 | response_length = 0 104 | for i in range(len(response_ids)): 105 | eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero() 106 | response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i]) 107 | return response, (prompt_length, response_length) 108 | 109 | @torch.inference_mode() 110 | def stream_chat( 111 | self, 112 | query: str, 113 | history: Optional[List[Tuple[str, str]]] = None, 114 | system: Optional[str] = None, 115 | **input_kwargs 116 | ) -> Generator[str, None, None]: 117 | gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs) 118 | streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) 119 | gen_kwargs["streamer"] = streamer 120 | 121 | thread = Thread(target=self.model.generate, kwargs=gen_kwargs) 122 | thread.start() 123 | 124 | yield from streamer 125 | 126 | -------------------------------------------------------------------------------- /llmtuner/dsets/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.dsets.loader import get_dataset 2 | from llmtuner.dsets.preprocess import preprocess_dataset 3 | from llmtuner.dsets.utils import split_dataset 4 | -------------------------------------------------------------------------------- /llmtuner/dsets/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import TYPE_CHECKING, List, Union 3 | 4 | from datasets import concatenate_datasets, interleave_datasets, load_dataset 5 | 6 | from llmtuner.dsets.utils import checksum, EXT2TYPE 7 | from llmtuner.extras.logging import get_logger 8 | 9 | if TYPE_CHECKING: 10 | from datasets import Dataset, IterableDataset 11 | from llmtuner.hparams import ModelArguments, DataArguments 12 | 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | def get_dataset( 18 | model_args: "ModelArguments", 19 | data_args: "DataArguments" 20 | ) -> Union["Dataset", "IterableDataset"]: 21 | max_samples = data_args.max_samples 22 | all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets 23 | 24 | for dataset_attr in data_args.dataset_list: 25 | logger.info("Loading dataset {}...".format(dataset_attr)) 26 | 27 | if dataset_attr.load_from == "hf_hub": 28 | data_path = dataset_attr.dataset_name 29 | data_files = None 30 | elif dataset_attr.load_from == "script": 31 | data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) 32 | data_files = None 33 | elif dataset_attr.load_from == "file": 34 | data_path = None 35 | data_files: List[str] = [] 36 | 37 | if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory 38 | for file_name in 
os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): 39 | data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name)) 40 | if data_path is None: 41 | data_path = EXT2TYPE.get(file_name.split(".")[-1], None) 42 | else: 43 | assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match." 44 | elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file 45 | data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)) 46 | data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None) 47 | else: 48 | raise ValueError("File not found.") 49 | 50 | assert data_path, "File extension must be txt, csv, json or jsonl." 51 | checksum(data_files, dataset_attr.dataset_sha1) 52 | else: 53 | raise NotImplementedError 54 | 55 | dataset = load_dataset( 56 | data_path, 57 | data_files=data_files, 58 | split=data_args.split, 59 | cache_dir=model_args.cache_dir, 60 | streaming=data_args.streaming, 61 | use_auth_token=True if model_args.use_auth_token else None 62 | ) 63 | 64 | if max_samples is not None: 65 | max_samples_temp = min(len(dataset), max_samples) 66 | dataset = dataset.select(range(max_samples_temp)) 67 | 68 | # TODO: adapt to the sharegpt format 69 | 70 | for column_name in ["prompt", "query", "response", "history"]: # align datasets 71 | if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name: 72 | dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) 73 | 74 | if dataset_attr.system_prompt: # add system prompt 75 | system_prompt = dataset_attr.system_prompt 76 | if data_args.streaming: 77 | dataset = dataset.map(lambda _: {"system": system_prompt}) 78 | else: 79 | dataset = dataset.add_column("system", [system_prompt] * len(dataset)) 80 | 81 | all_datasets.append(dataset) 82 | 83 | if len(data_args.dataset_list) == 1: 84 | return all_datasets[0] 85 | elif data_args.mix_strategy == "concat": 86 | if data_args.streaming: 87 | logger.warning("The samples between different datasets will not be mixed in streaming mode.") 88 | return concatenate_datasets(all_datasets) 89 | elif data_args.mix_strategy.startswith("interleave"): 90 | if not data_args.streaming: 91 | logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") 92 | return interleave_datasets( 93 | datasets=all_datasets, 94 | probabilities=data_args.interleave_probs, 95 | seed=data_args.seed, 96 | stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted" 97 | ) 98 | else: 99 | raise ValueError("Unknown mixing strategy.") 100 | -------------------------------------------------------------------------------- /llmtuner/dsets/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | from typing import TYPE_CHECKING, Dict, List, Optional, Union 3 | 4 | from llmtuner.extras.logging import get_logger 5 | 6 | if TYPE_CHECKING: 7 | from datasets import Dataset, IterableDataset 8 | from transformers import TrainingArguments 9 | from llmtuner.hparams import DataArguments 10 | 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | EXT2TYPE = { 16 | "csv": "csv", 17 | "json": "json", 18 | "jsonl": "json", 19 | "txt": "text" 20 | } 21 | 22 | 23 | def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: 24 | if file_sha1 is None: 25 | logger.warning("Checksum failed: missing SHA-1 hash value in 
dataset_info.json.") 26 | return 27 | 28 | if len(data_files) != 1: 29 | logger.warning("Checksum failed: too many files.") 30 | return 31 | 32 | with open(data_files[0], "rb") as f: 33 | sha1 = hashlib.sha1(f.read()).hexdigest() 34 | if sha1 != file_sha1: 35 | logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) 36 | 37 | 38 | def split_dataset( 39 | dataset: Union["Dataset", "IterableDataset"], 40 | data_args: "DataArguments", 41 | training_args: "TrainingArguments" 42 | ) -> Dict[str, "Dataset"]: 43 | if training_args.do_train: 44 | if data_args.val_size > 1e-6: # Split the dataset 45 | if data_args.streaming: 46 | val_set = dataset.take(int(data_args.val_size)) 47 | train_set = dataset.skip(int(data_args.val_size)) 48 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) 49 | return {"train_dataset": train_set, "eval_dataset": val_set} 50 | else: 51 | val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size 52 | dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed) 53 | return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} 54 | else: 55 | if data_args.streaming: 56 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) 57 | return {"train_dataset": dataset} 58 | else: # do_eval or do_predict 59 | return {"eval_dataset": dataset} 60 | -------------------------------------------------------------------------------- /llmtuner/extras/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/llmtuner/extras/__init__.py -------------------------------------------------------------------------------- /llmtuner/extras/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | from typing import TYPE_CHECKING 5 | from datetime import timedelta 6 | 7 | from transformers import TrainerCallback 8 | from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR 9 | 10 | from llmtuner.extras.constants import LOG_FILE_NAME 11 | from llmtuner.extras.logging import get_logger 12 | 13 | if TYPE_CHECKING: 14 | from transformers import TrainingArguments, TrainerState, TrainerControl 15 | 16 | 17 | logger = get_logger(__name__) 18 | 19 | 20 | class SavePeftModelCallback(TrainerCallback): 21 | 22 | def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 23 | r""" 24 | Event called after a checkpoint save. 25 | """ 26 | if args.should_save: 27 | output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)) 28 | model = kwargs.pop("model") 29 | if getattr(model, "is_peft_model", False): 30 | getattr(model, "pretrained_model").save_pretrained(output_dir) 31 | 32 | def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 33 | r""" 34 | Event called at the end of training. 
35 | """ 36 | if args.should_save: 37 | model = kwargs.pop("model") 38 | if getattr(model, "is_peft_model", False): 39 | getattr(model, "pretrained_model").save_pretrained(args.output_dir) 40 | 41 | 42 | class LogCallback(TrainerCallback): 43 | 44 | def __init__(self, runner=None): 45 | self.runner = runner 46 | self.in_training = False 47 | self.start_time = time.time() 48 | self.cur_steps = 0 49 | self.max_steps = 0 50 | self.elapsed_time = "" 51 | self.remaining_time = "" 52 | 53 | def timing(self): 54 | cur_time = time.time() 55 | elapsed_time = cur_time - self.start_time 56 | avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0 57 | remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step 58 | self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) 59 | self.remaining_time = str(timedelta(seconds=int(remaining_time))) 60 | 61 | def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 62 | r""" 63 | Event called at the beginning of training. 64 | """ 65 | if state.is_local_process_zero: 66 | self.in_training = True 67 | self.start_time = time.time() 68 | self.max_steps = state.max_steps 69 | if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir: 70 | logger.warning("Previous log file in this folder will be deleted.") 71 | os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) 72 | 73 | def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 74 | r""" 75 | Event called at the end of training. 76 | """ 77 | if state.is_local_process_zero: 78 | self.in_training = False 79 | self.cur_steps = 0 80 | self.max_steps = 0 81 | 82 | def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 83 | r""" 84 | Event called at the end of an substep during gradient accumulation. 85 | """ 86 | if state.is_local_process_zero and self.runner is not None and self.runner.aborted: 87 | control.should_epoch_stop = True 88 | control.should_training_stop = True 89 | 90 | def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 91 | r""" 92 | Event called at the end of a training step. 93 | """ 94 | if state.is_local_process_zero: 95 | self.cur_steps = state.global_step 96 | self.timing() 97 | if self.runner is not None and self.runner.aborted: 98 | control.should_epoch_stop = True 99 | control.should_training_stop = True 100 | 101 | def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 102 | r""" 103 | Event called after an evaluation phase. 104 | """ 105 | if state.is_local_process_zero and not self.in_training: 106 | self.cur_steps = 0 107 | self.max_steps = 0 108 | 109 | def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs): 110 | r""" 111 | Event called after a successful prediction. 112 | """ 113 | if state.is_local_process_zero and not self.in_training: 114 | self.cur_steps = 0 115 | self.max_steps = 0 116 | 117 | def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: 118 | r""" 119 | Event called after logging the last logs. 
120 | """ 121 | if not state.is_local_process_zero: 122 | return 123 | 124 | logs = dict( 125 | current_steps=self.cur_steps, 126 | total_steps=self.max_steps, 127 | loss=state.log_history[-1].get("loss", None), 128 | eval_loss=state.log_history[-1].get("eval_loss", None), 129 | predict_loss=state.log_history[-1].get("predict_loss", None), 130 | reward=state.log_history[-1].get("reward", None), 131 | learning_rate=state.log_history[-1].get("learning_rate", None), 132 | epoch=state.log_history[-1].get("epoch", None), 133 | percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, 134 | elapsed_time=self.elapsed_time, 135 | remaining_time=self.remaining_time 136 | ) 137 | if self.runner is not None: 138 | logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format( 139 | logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0 140 | )) 141 | 142 | os.makedirs(args.output_dir, exist_ok=True) 143 | with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: 144 | f.write(json.dumps(logs) + "\n") 145 | 146 | def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 147 | r""" 148 | Event called after a prediction step. 149 | """ 150 | eval_dataloader = kwargs.pop("eval_dataloader", None) 151 | if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training: 152 | if self.max_steps == 0: 153 | self.max_steps = len(eval_dataloader) 154 | self.cur_steps += 1 155 | self.timing() 156 | -------------------------------------------------------------------------------- /llmtuner/extras/constants.py: -------------------------------------------------------------------------------- 1 | IGNORE_INDEX = -100 2 | 3 | LOG_FILE_NAME = "trainer_log.jsonl" 4 | 5 | LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"] 6 | 7 | METHODS = ["full", "freeze", "lora"] 8 | 9 | TRAINING_STAGES = { 10 | "Supervised Fine-Tuning": "sft", 11 | "Reward Modeling": "rm", 12 | "PPO": "ppo", 13 | "DPO": "dpo", 14 | "Pre-Training": "pt" 15 | } 16 | 17 | SUPPORTED_MODELS = { 18 | "LLaMA-7B": "huggyllama/llama-7b", 19 | "LLaMA-13B": "huggyllama/llama-13b", 20 | "LLaMA-30B": "huggyllama/llama-30b", 21 | "LLaMA-65B": "huggyllama/llama-65b", 22 | "LLaMA2-7B": "meta-llama/Llama-2-7b-hf", 23 | "LLaMA2-13B": "meta-llama/Llama-2-13b-hf", 24 | "LLaMA2-70B": "meta-llama/Llama-2-70b-hf", 25 | "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf", 26 | "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf", 27 | "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf", 28 | "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b", 29 | "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b", 30 | "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b", 31 | "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b", 32 | "BLOOM-560M": "bigscience/bloom-560m", 33 | "BLOOM-3B": "bigscience/bloom-3b", 34 | "BLOOM-7B1": "bigscience/bloom-7b1", 35 | "BLOOMZ-560M": "bigscience/bloomz-560m", 36 | "BLOOMZ-3B": "bigscience/bloomz-3b", 37 | "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt", 38 | "Falcon-7B": "tiiuae/falcon-7b", 39 | "Falcon-40B": "tiiuae/falcon-40b", 40 | "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct", 41 | "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct", 42 | "Baichuan-7B": "baichuan-inc/Baichuan-7B", 43 | "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base", 44 | "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat", 45 | "Baichuan2-7B": 
"baichuan-inc/Baichuan2-7B-Base", 46 | "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base", 47 | "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat", 48 | "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat", 49 | "InternLM-7B": "internlm/internlm-7b", 50 | "InternLM-20B": "internlm/internlm-20b", 51 | "InternLM-7B-Chat": "internlm/internlm-chat-7b", 52 | "InternLM-20B-Chat": "internlm/internlm-chat-20b", 53 | "Qwen-7B": "Qwen/Qwen-7B", 54 | "Qwen-14B": "Qwen/Qwen-14B", 55 | "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", 56 | "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", 57 | "XVERSE-13B": "xverse/XVERSE-13B", 58 | "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat", 59 | "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b", 60 | "ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base", 61 | "ChatGLM3-6B-Chat": "THUDM/chatglm3-6b", 62 | "Phi1.5-1.3B": "microsoft/phi-1_5" 63 | } 64 | 65 | DEFAULT_MODULE = { 66 | "LLaMA": "q_proj,v_proj", 67 | "LLaMA2": "q_proj,v_proj", 68 | "ChineseLLaMA2": "q_proj,v_proj", 69 | "BLOOM": "query_key_value", 70 | "BLOOMZ": "query_key_value", 71 | "Falcon": "query_key_value", 72 | "Baichuan": "W_pack", 73 | "Baichuan2": "W_pack", 74 | "InternLM": "q_proj,v_proj", 75 | "Qwen": "c_attn", 76 | "XVERSE": "q_proj,v_proj", 77 | "ChatGLM2": "query_key_value", 78 | "ChatGLM3": "query_key_value", 79 | "Phi1.5": "Wqkv" 80 | } 81 | 82 | DEFAULT_TEMPLATE = { 83 | "LLaMA2": "llama2", 84 | "ChineseLLaMA2": "llama2_zh", 85 | "Baichuan": "baichuan", 86 | "Baichuan2": "baichuan2", 87 | "InternLM": "intern", 88 | "Qwen": "chatml", 89 | "XVERSE": "xverse", 90 | "ChatGLM2": "chatglm2", 91 | "ChatGLM3": "chatglm3" 92 | } 93 | -------------------------------------------------------------------------------- /llmtuner/extras/logging.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | 4 | 5 | class LoggerHandler(logging.Handler): 6 | 7 | def __init__(self): 8 | super().__init__() 9 | self.log = "" 10 | 11 | def reset(self): 12 | self.log = "" 13 | 14 | def emit(self, record): 15 | if record.name == "httpx": 16 | return 17 | log_entry = self.format(record) 18 | self.log += log_entry 19 | self.log += "\n\n" 20 | 21 | 22 | def reset_logging(): 23 | r""" 24 | Removes basic config of root logger 25 | """ 26 | root = logging.getLogger() 27 | list(map(root.removeHandler, root.handlers)) 28 | list(map(root.removeFilter, root.filters)) 29 | 30 | 31 | def get_logger(name: str) -> logging.Logger: 32 | formatter = logging.Formatter( 33 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 34 | datefmt="%m/%d/%Y %H:%M:%S" 35 | ) 36 | handler = logging.StreamHandler(sys.stdout) 37 | handler.setFormatter(formatter) 38 | 39 | logger = logging.getLogger(name) 40 | logger.setLevel(logging.INFO) 41 | logger.addHandler(handler) 42 | 43 | return logger 44 | -------------------------------------------------------------------------------- /llmtuner/extras/misc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | from typing import TYPE_CHECKING, Tuple 4 | from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList 5 | 6 | try: 7 | from transformers.utils import ( 8 | is_torch_bf16_cpu_available, 9 | is_torch_bf16_gpu_available, 10 | is_torch_cuda_available, 11 | is_torch_npu_available 12 | ) 13 | _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() 14 | _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available 15 | except ImportError: 16 | 
_is_fp16_available = torch.cuda.is_available() 17 | _is_bf16_available = torch.cuda.is_bf16_supported() 18 | 19 | if TYPE_CHECKING: 20 | from transformers.modeling_utils import PreTrainedModel 21 | 22 | 23 | class AverageMeter: 24 | r""" 25 | Computes and stores the average and current value. 26 | """ 27 | def __init__(self): 28 | self.reset() 29 | 30 | def reset(self): 31 | self.val = 0 32 | self.avg = 0 33 | self.sum = 0 34 | self.count = 0 35 | 36 | def update(self, val, n=1): 37 | self.val = val 38 | self.sum += val * n 39 | self.count += n 40 | self.avg = self.sum / self.count 41 | 42 | 43 | def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: 44 | r""" 45 | Returns the number of trainable parameters and number of all parameters in the model. 46 | """ 47 | trainable_params, all_param = 0, 0 48 | for param in model.parameters(): 49 | num_params = param.numel() 50 | # if using DS Zero 3 and the weights are initialized empty 51 | if num_params == 0 and hasattr(param, "ds_numel"): 52 | num_params = param.ds_numel 53 | 54 | # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2 55 | if param.__class__.__name__ == "Params4bit": 56 | num_params = num_params * 2 57 | 58 | all_param += num_params 59 | if param.requires_grad: 60 | trainable_params += num_params 61 | 62 | return trainable_params, all_param 63 | 64 | 65 | def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: 66 | r""" 67 | Infers the optimal dtype according to the model_dtype and device compatibility. 68 | """ 69 | if _is_bf16_available and model_dtype == torch.bfloat16: 70 | return torch.bfloat16 71 | elif _is_fp16_available: 72 | return torch.float16 73 | else: 74 | return torch.float32 75 | 76 | 77 | def get_logits_processor() -> LogitsProcessorList: 78 | r""" 79 | Gets logits processor that removes NaN and Inf logits. 80 | """ 81 | logits_processor = LogitsProcessorList() 82 | logits_processor.append(InfNanRemoveLogitsProcessor()) 83 | return logits_processor 84 | 85 | 86 | def torch_gc() -> None: 87 | r""" 88 | Collects GPU memory. 89 | """ 90 | gc.collect() 91 | if torch.cuda.is_available(): 92 | torch.cuda.empty_cache() 93 | torch.cuda.ipc_collect() 94 | 95 | 96 | def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": 97 | r""" 98 | Dispatches a pre-trained model to GPUs with balanced memory. 99 | Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 100 | """ 101 | if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing 102 | return model 103 | 104 | if torch.cuda.device_count() > 1: 105 | from accelerate import dispatch_model 106 | from accelerate.utils import infer_auto_device_map, get_balanced_memory 107 | 108 | if model._no_split_modules is None: 109 | raise ValueError("The model class needs to implement the `_no_split_modules` attribute.") 110 | 111 | kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules} 112 | max_memory = get_balanced_memory(model, **kwargs) 113 | # Make sure tied weights are tied before creating the device map. 
114 | model.tie_weights() 115 | device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) 116 | return dispatch_model(model, device_map) 117 | else: 118 | return model.cuda() 119 | -------------------------------------------------------------------------------- /llmtuner/extras/patches/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JuneYaooo/llm_structure_tool/e142cf9dd6a85cffecc291d7ebe12166bea73d72/llmtuner/extras/patches/__init__.py -------------------------------------------------------------------------------- /llmtuner/extras/patches/llama_patch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from typing import Optional, Tuple 5 | from transformers.utils import logging 6 | from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv 7 | 8 | try: 9 | from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore 10 | from flash_attn.bert_padding import pad_input, unpad_input # type: ignore 11 | except ImportError: 12 | print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.") 13 | 14 | 15 | logger = logging.get_logger(__name__) 16 | 17 | 18 | # Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py 19 | class LlamaShiftShortAttention(LlamaAttention): 20 | 21 | def forward( 22 | self, 23 | hidden_states: torch.Tensor, 24 | attention_mask: Optional[torch.Tensor] = None, 25 | position_ids: Optional[torch.LongTensor] = None, 26 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 27 | output_attentions: bool = False, 28 | use_cache: bool = False, 29 | **kwargs 30 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 31 | bsz, q_len, _ = hidden_states.size() 32 | 33 | query_states = self.q_proj(hidden_states) 34 | key_states = self.k_proj(hidden_states) 35 | value_states = self.v_proj(hidden_states) 36 | 37 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 38 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 39 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 40 | 41 | kv_seq_len = key_states.shape[-2] 42 | if past_key_value is not None: 43 | kv_seq_len += past_key_value[0].shape[-2] 44 | 45 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 46 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 47 | 48 | if past_key_value is not None: # reuse k, v, self_attention 49 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 50 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 51 | 52 | past_key_value = (key_states, value_states) if use_cache else None 53 | 54 | if getattr(self, "num_key_value_groups"): 55 | key_states = repeat_kv(key_states, self.num_key_value_groups) 56 | value_states = repeat_kv(value_states, self.num_key_value_groups) 57 | 58 | if getattr(self.config, "group_size_ratio", None) and self.training: # shift 59 | groupsz = int(q_len * getattr(self.config, "group_size_ratio")) 60 | assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) 61 | num_groups = q_len // groupsz 62 | def shift(state: torch.Tensor) -> torch.Tensor: 63 
| state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) 64 | state = torch.cat(( 65 | state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1) 66 | ), dim=2) 67 | return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) 68 | 69 | query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) 70 | if attention_mask is not None: 71 | attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) 72 | 73 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 74 | 75 | if attention_mask is not None: 76 | attn_weights = attn_weights + attention_mask 77 | 78 | # upcast attention to fp32 79 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 80 | attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :) 81 | attn_output = attn_output.transpose(1, 2).contiguous() 82 | 83 | if getattr(self.config, "group_size_ratio", None) and self.training: # shift back 84 | attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) 85 | attn_output = torch.cat(( 86 | attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1) 87 | )) 88 | 89 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 90 | attn_output = self.o_proj(attn_output) 91 | 92 | if not output_attentions: 93 | attn_weights = None 94 | 95 | return attn_output, attn_weights, past_key_value 96 | 97 | 98 | class LlamaFlashAttention2(LlamaAttention): 99 | 100 | def forward( 101 | self, 102 | hidden_states: torch.Tensor, 103 | attention_mask: Optional[torch.Tensor] = None, 104 | position_ids: Optional[torch.LongTensor] = None, 105 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 106 | output_attentions: bool = False, 107 | use_cache: bool = False, 108 | **kwargs 109 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 110 | # LlamaFlashAttention2 attention does not support output_attentions 111 | output_attentions = False 112 | 113 | bsz, q_len, _ = hidden_states.size() 114 | 115 | query_states = self.q_proj(hidden_states) 116 | key_states = self.k_proj(hidden_states) 117 | value_states = self.v_proj(hidden_states) 118 | 119 | # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) 120 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 121 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 122 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 123 | 124 | kv_seq_len = key_states.shape[-2] 125 | if past_key_value is not None: 126 | kv_seq_len += past_key_value[0].shape[-2] 127 | 128 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 129 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 130 | 131 | if past_key_value is not None: # reuse k, v, self_attention 132 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 133 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 134 | 135 | past_key_value = (key_states, value_states) if use_cache else None 136 | 137 | # cast to half precision 138 | input_dtype = query_states.dtype 139 | if input_dtype == torch.float32: 140 | logger.warning_once("The input hidden 
states seems to be silently casted in float32.") 141 | query_states = query_states.to(self.config.torch_dtype) 142 | key_states = key_states.to(self.config.torch_dtype) 143 | value_states = value_states.to(self.config.torch_dtype) 144 | 145 | if getattr(self, "num_key_value_groups", None): 146 | key_states = repeat_kv(key_states, self.num_key_value_groups) 147 | value_states = repeat_kv(value_states, self.num_key_value_groups) 148 | 149 | query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) 150 | key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) 151 | value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim) 152 | 153 | if getattr(self.config, "group_size_ratio", None) and self.training: # shift 154 | groupsz = int(q_len * getattr(self.config, "group_size_ratio")) 155 | assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) 156 | num_groups = q_len // groupsz 157 | def shift(state: torch.Tensor) -> torch.Tensor: 158 | state = torch.cat(( 159 | state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1) 160 | ), dim=2) 161 | return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim) 162 | 163 | query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) 164 | if attention_mask is not None: 165 | attention_mask = attention_mask.reshape(bsz * num_groups, groupsz) 166 | 167 | if attention_mask is not None: 168 | logger.warning_once("Padded sequences are less efficient in FlashAttention.") 169 | # -q_len: assumes left padding when q_len != kv_len 170 | unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:]) 171 | unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask) 172 | unpadded_v, _, _, _ = unpad_input(value_states, attention_mask) 173 | attn_output_unpad = flash_attn_varlen_func( 174 | unpadded_q, 175 | unpadded_k, 176 | unpadded_v, 177 | cu_seqlens_q=cu_seqlens_q, 178 | cu_seqlens_k=cu_seqlens_k, 179 | max_seqlen_q=max_seqlen_q, 180 | max_seqlen_k=max_seqlen_k, 181 | dropout_p=0.0, 182 | softmax_scale=None, 183 | causal=True, 184 | ) 185 | attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len) 186 | else: 187 | attn_output = flash_attn_func( 188 | query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True 189 | ) 190 | 191 | if getattr(self.config, "group_size_ratio", None) and self.training: # shift back 192 | attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) 193 | attn_output = torch.cat(( 194 | attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1) 195 | )) 196 | 197 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 198 | attn_output = self.o_proj(attn_output) 199 | 200 | if not output_attentions: 201 | attn_weights = None 202 | 203 | return attn_output, attn_weights, past_key_value 204 | 205 | 206 | # Disable the transformation of the attention mask in LlamaModel as flash attention 207 | # takes a boolean padding_mask. Fills in the past kv length for use in forward. 
208 | def _prepare_decoder_attention_mask( 209 | self, 210 | attention_mask: torch.Tensor, 211 | input_shape: torch.Tensor, 212 | inputs_embeds: torch.Tensor, 213 | past_key_values_length: int 214 | ) -> torch.Tensor: 215 | if attention_mask is not None and torch.all(attention_mask): 216 | return None # This uses the faster call when training with full samples 217 | 218 | return attention_mask 219 | -------------------------------------------------------------------------------- /llmtuner/extras/ploting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import matplotlib.pyplot as plt 5 | from typing import List, Optional 6 | from transformers.trainer import TRAINER_STATE_NAME 7 | 8 | from llmtuner.extras.logging import get_logger 9 | 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | def smooth(scalars: List[float]) -> List[float]: 15 | r""" 16 | EMA implementation according to TensorBoard. 17 | """ 18 | last = scalars[0] 19 | smoothed = list() 20 | weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function 21 | for next_val in scalars: 22 | smoothed_val = last * weight + (1 - weight) * next_val 23 | smoothed.append(smoothed_val) 24 | last = smoothed_val 25 | return smoothed 26 | 27 | 28 | def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None: 29 | 30 | with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: 31 | data = json.load(f) 32 | 33 | for key in keys: 34 | steps, metrics = [], [] 35 | for i in range(len(data["log_history"])): 36 | if key in data["log_history"][i]: 37 | steps.append(data["log_history"][i]["step"]) 38 | metrics.append(data["log_history"][i][key]) 39 | 40 | if len(metrics) == 0: 41 | logger.warning(f"No metric {key} to plot.") 42 | continue 43 | 44 | plt.figure() 45 | plt.plot(steps, metrics, alpha=0.4, label="original") 46 | plt.plot(steps, smooth(metrics), label="smoothed") 47 | plt.title("training {} of {}".format(key, save_dictionary)) 48 | plt.xlabel("step") 49 | plt.ylabel(key) 50 | plt.legend() 51 | plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100) 52 | print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key))) 53 | -------------------------------------------------------------------------------- /llmtuner/extras/save_and_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers.trainer import WEIGHTS_NAME 4 | 5 | from llmtuner.extras.logging import get_logger 6 | 7 | 8 | logger = get_logger(__name__) 9 | 10 | 11 | def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool: 12 | vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME) 13 | if not os.path.exists(vhead_file): 14 | logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir)) 15 | return False 16 | vhead_params = torch.load(vhead_file, map_location="cpu") 17 | model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) 18 | model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) 19 | model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False) 20 | model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False) 
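# The four buffers registered above are consumed later by replace_model() in
# llmtuner/tuner/ppo/utils.py, which calls model.get_buffer("{target}_head_weight")
# and model.get_buffer("{target}_head_bias") to swap the value head between the
# default policy head and the reward head during PPO training.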
21 | return True 22 | -------------------------------------------------------------------------------- /llmtuner/hparams/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_args import DataArguments 2 | from .finetuning_args import FinetuningArguments 3 | from .generating_args import GeneratingArguments 4 | from .model_args import ModelArguments 5 | -------------------------------------------------------------------------------- /llmtuner/hparams/data_args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Literal, Optional 4 | from dataclasses import dataclass, field 5 | 6 | 7 | @dataclass 8 | class DatasetAttr: 9 | 10 | load_from: str 11 | dataset_name: Optional[str] = None 12 | dataset_sha1: Optional[str] = None 13 | system_prompt: Optional[str] = None 14 | ranking: Optional[bool] = False 15 | formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" 16 | 17 | prompt: Optional[str] = "instruction" 18 | query: Optional[str] = "input" 19 | response: Optional[str] = "output" 20 | history: Optional[str] = None 21 | 22 | def __repr__(self) -> str: 23 | return self.dataset_name 24 | 25 | 26 | @dataclass 27 | class DataArguments: 28 | r""" 29 | Arguments pertaining to what data we are going to input our model for training and evaluation. 30 | """ 31 | template: Optional[str] = field( 32 | default=None, 33 | metadata={"help": "Which template to use for constructing prompts in training and inference."} 34 | ) 35 | dataset: Optional[str] = field( 36 | default=None, 37 | metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."} 38 | ) 39 | dataset_dir: Optional[str] = field( 40 | default="data", 41 | metadata={"help": "The name of the folder containing datasets."} 42 | ) 43 | split: Optional[str] = field( 44 | default="train", 45 | metadata={"help": "Which dataset split to use for training and evaluation."} 46 | ) 47 | cutoff_len: Optional[int] = field( 48 | default=1024, 49 | metadata={"help": "The maximum length of the model inputs after tokenization."} 50 | ) 51 | train_on_prompt: Optional[bool] = field( 52 | default=False, 53 | metadata={"help": "Whether to disable the mask on the prompt or not."} 54 | ) 55 | streaming: Optional[bool] = field( 56 | default=False, 57 | metadata={"help": "Enable dataset streaming."} 58 | ) 59 | buffer_size: Optional[int] = field( 60 | default=16384, 61 | metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."} 62 | ) 63 | mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( 64 | default="concat", 65 | metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."} 66 | ) 67 | interleave_probs: Optional[str] = field( 68 | default=None, 69 | metadata={"help": "Probabilities to sample data from datasets. 
Use commas to separate multiple datasets."} 70 | ) 71 | overwrite_cache: Optional[bool] = field( 72 | default=False, 73 | metadata={"help": "Overwrite the cached training and evaluation sets."} 74 | ) 75 | preprocessing_num_workers: Optional[int] = field( 76 | default=None, 77 | metadata={"help": "The number of processes to use for the preprocessing."} 78 | ) 79 | max_samples: Optional[int] = field( 80 | default=None, 81 | metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} 82 | ) 83 | eval_num_beams: Optional[int] = field( 84 | default=None, 85 | metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"} 86 | ) 87 | ignore_pad_token_for_loss: Optional[bool] = field( 88 | default=True, 89 | metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} 90 | ) 91 | system_prompt: Optional[str] = field( 92 | default=None, 93 | metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."} 94 | ) 95 | val_size: Optional[float] = field( 96 | default=0, 97 | metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."} 98 | ) 99 | sft_packing: Optional[bool] = field( 100 | default=False, 101 | metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."} 102 | ) 103 | cache_path: Optional[str] = field( 104 | default=None, 105 | metadata={"help": "Path to save or load the preprocessed datasets."} 106 | ) 107 | 108 | def __post_init__(self): 109 | if self.streaming and self.val_size > 1e-6 and self.val_size < 1: 110 | raise ValueError("Streaming mode should have an integer val size.") 111 | 112 | if self.streaming and self.max_samples is not None: 113 | raise ValueError("`max_samples` is incompatible with `streaming`.") 114 | 115 | if self.streaming and self.cache_path: 116 | raise ValueError("`cache_path` is incompatible with `streaming`.") 117 | 118 | def init_for_training(self, seed: int): # support mixing multiple datasets 119 | self.seed = seed 120 | dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else [] 121 | try: 122 | with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: 123 | dataset_info = json.load(f) 124 | except Exception: 125 | if self.dataset is not None: 126 | raise ValueError("Cannot find dataset_info.json in `dataset_dir`.") 127 | dataset_info = None 128 | 129 | prompt_list = self.system_prompt.split("|") if self.system_prompt else [None] 130 | prompt_list = prompt_list * (len(dataset_names) // len(prompt_list)) 131 | assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1." 
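# Worked example of the broadcast above (dataset names are placeholders only):
# with dataset="a,b,c" and a single system_prompt "p", prompt_list becomes
# ["p"] * (3 // 1) == ["p", "p", "p"]; with system_prompt="p1|p2|p3" each dataset
# keeps its own prompt; two prompts for three datasets give
# ["p1", "p2"] * (3 // 2) == ["p1", "p2"], so the assertion fails.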
132 | 133 | if self.interleave_probs is not None: 134 | self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")] 135 | 136 | self.dataset_list: List[DatasetAttr] = [] 137 | for i, name in enumerate(dataset_names): 138 | if name not in dataset_info: 139 | raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) 140 | 141 | if "hf_hub_url" in dataset_info[name]: 142 | dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) 143 | elif "script_url" in dataset_info[name]: 144 | dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) 145 | else: 146 | dataset_attr = DatasetAttr( 147 | "file", 148 | dataset_name=dataset_info[name]["file_name"], 149 | dataset_sha1=dataset_info[name].get("file_sha1", None) 150 | ) 151 | 152 | if "columns" in dataset_info[name]: 153 | dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None) 154 | dataset_attr.query = dataset_info[name]["columns"].get("query", None) 155 | dataset_attr.response = dataset_info[name]["columns"].get("response", None) 156 | dataset_attr.history = dataset_info[name]["columns"].get("history", None) 157 | 158 | dataset_attr.ranking = dataset_info[name].get("ranking", False) 159 | dataset_attr.system_prompt = prompt_list[i] 160 | self.dataset_list.append(dataset_attr) 161 | -------------------------------------------------------------------------------- /llmtuner/hparams/finetuning_args.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Literal, Optional 3 | from dataclasses import asdict, dataclass, field 4 | 5 | 6 | @dataclass 7 | class FinetuningArguments: 8 | r""" 9 | Arguments pertaining to which techniques we are going to fine-tuning with. 10 | """ 11 | stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( 12 | default="sft", 13 | metadata={"help": "Which stage will be performed in training."} 14 | ) 15 | finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field( 16 | default="lora", 17 | metadata={"help": "Which fine-tuning method to use."} 18 | ) 19 | num_layer_trainable: Optional[int] = field( 20 | default=3, 21 | metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."} 22 | ) 23 | name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( 24 | default="mlp", 25 | metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ 26 | LLaMA choices: [\"mlp\", \"self_attn\"], \ 27 | BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \ 28 | Qwen choices: [\"mlp\", \"attn\"], \ 29 | Phi-1.5 choices: [\"mlp\", \"mixer\"], \ 30 | LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."} 31 | ) 32 | lora_rank: Optional[int] = field( 33 | default=8, 34 | metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} 35 | ) 36 | lora_alpha: Optional[float] = field( 37 | default=32.0, 38 | metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."} 39 | ) 40 | lora_dropout: Optional[float] = field( 41 | default=0.1, 42 | metadata={"help": "Dropout rate for the LoRA fine-tuning."} 43 | ) 44 | lora_target: Optional[str] = field( 45 | default=None, 46 | metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. 
\ 47 | LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ 48 | BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ 49 | Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ 50 | Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ 51 | Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \ 52 | LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} 53 | ) 54 | additional_target: Optional[str] = field( 55 | default=None, 56 | metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."} 57 | ) 58 | resume_lora_training: Optional[bool] = field( 59 | default=True, 60 | metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} 61 | ) 62 | ppo_score_norm: Optional[bool] = field( 63 | default=False, 64 | metadata={"help": "Use score normalization in PPO training."} 65 | ) 66 | ppo_logger: Optional[str] = field( 67 | default=None, 68 | metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."} 69 | ) 70 | ppo_target: Optional[float] = field( 71 | default=6.0, 72 | metadata={"help": "Target KL value for adaptive KL control in PPO training."} 73 | ) 74 | dpo_beta: Optional[float] = field( 75 | default=0.1, 76 | metadata={"help": "The beta parameter for the DPO loss."} 77 | ) 78 | upcast_layernorm: Optional[bool] = field( 79 | default=False, 80 | metadata={"help": "Whether to upcast the layernorm weights in fp32."} 81 | ) 82 | neft_alpha: Optional[float] = field( 83 | default=0, 84 | metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."} 85 | ) 86 | 87 | def __post_init__(self): 88 | if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA 89 | self.lora_target = [target.strip() for target in self.lora_target.split(",")] 90 | 91 | if isinstance(self.additional_target, str): 92 | self.additional_target = [target.strip() for target in self.additional_target.split(",")] 93 | 94 | assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method." 95 | 96 | def save_to_json(self, json_path: str): 97 | r"""Saves the content of this instance in JSON format inside `json_path`.""" 98 | json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" 99 | with open(json_path, "w", encoding="utf-8") as f: 100 | f.write(json_string) 101 | 102 | @classmethod 103 | def load_from_json(cls, json_path: str): 104 | r"""Creates an instance from the content of `json_path`.""" 105 | with open(json_path, "r", encoding="utf-8") as f: 106 | text = f.read() 107 | return cls(**json.loads(text)) 108 | -------------------------------------------------------------------------------- /llmtuner/hparams/general_args.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | from dataclasses import dataclass, field 3 | 4 | 5 | @dataclass 6 | class GeneralArguments: 7 | r""" 8 | Arguments pertaining to which stage we are going to perform. 
9 | """ 10 | stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( 11 | default="sft", 12 | metadata={"help": "Which stage will be performed in training."} 13 | ) 14 | -------------------------------------------------------------------------------- /llmtuner/hparams/generating_args.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | from dataclasses import asdict, dataclass, field 3 | 4 | 5 | @dataclass 6 | class GeneratingArguments: 7 | r""" 8 | Arguments pertaining to specify the decoding parameters. 9 | """ 10 | do_sample: Optional[bool] = field( 11 | default=True, 12 | metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} 13 | ) 14 | temperature: Optional[float] = field( 15 | default=0.95, 16 | metadata={"help": "The value used to modulate the next token probabilities."} 17 | ) 18 | top_p: Optional[float] = field( 19 | default=0.7, 20 | metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} 21 | ) 22 | top_k: Optional[int] = field( 23 | default=50, 24 | metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} 25 | ) 26 | num_beams: Optional[int] = field( 27 | default=1, 28 | metadata={"help": "Number of beams for beam search. 1 means no beam search."} 29 | ) 30 | max_length: Optional[int] = field( 31 | default=None, 32 | metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} 33 | ) 34 | max_new_tokens: Optional[int] = field( 35 | default=512, 36 | metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} 37 | ) 38 | repetition_penalty: Optional[float] = field( 39 | default=1.0, 40 | metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."} 41 | ) 42 | length_penalty: Optional[float] = field( 43 | default=1.0, 44 | metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} 45 | ) 46 | 47 | def to_dict(self) -> Dict[str, Any]: 48 | args = asdict(self) 49 | if args.get("max_new_tokens", None): 50 | args.pop("max_length", None) 51 | return args 52 | -------------------------------------------------------------------------------- /llmtuner/hparams/model_args.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | from dataclasses import dataclass, field 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | r""" 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 
9 | """ 10 | model_name_or_path: str = field( 11 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} 12 | ) 13 | cache_dir: Optional[str] = field( 14 | default=None, 15 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} 16 | ) 17 | use_fast_tokenizer: Optional[bool] = field( 18 | default=True, 19 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} 20 | ) 21 | split_special_tokens: Optional[bool] = field( 22 | default=False, 23 | metadata={"help": "Whether or not the special tokens should be split during the tokenization process."} 24 | ) 25 | use_auth_token: Optional[bool] = field( 26 | default=False, 27 | metadata={"help": "Will use the token generated when running `huggingface-cli login`."} 28 | ) 29 | model_revision: Optional[str] = field( 30 | default="main", 31 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} 32 | ) 33 | quantization_bit: Optional[int] = field( 34 | default=None, 35 | metadata={"help": "The number of bits to quantize the model."} 36 | ) 37 | quantization_type: Optional[Literal["fp4", "nf4"]] = field( 38 | default="nf4", 39 | metadata={"help": "Quantization data type to use in int4 training."} 40 | ) 41 | double_quantization: Optional[bool] = field( 42 | default=True, 43 | metadata={"help": "Whether to use double quantization in int4 training or not."} 44 | ) 45 | rope_scaling: Optional[Literal["linear", "dynamic"]] = field( 46 | default=None, 47 | metadata={"help": "Adopt scaled rotary positional embeddings."} 48 | ) 49 | checkpoint_dir: Optional[str] = field( 50 | default=None, 51 | metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} 52 | ) 53 | flash_attn: Optional[bool] = field( 54 | default=False, 55 | metadata={"help": "Enable FlashAttention-2 for faster training."} 56 | ) 57 | shift_attn: Optional[bool] = field( 58 | default=False, 59 | metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} 60 | ) 61 | reward_model: Optional[str] = field( 62 | default=None, 63 | metadata={"help": "Path to the directory containing the checkpoints of the reward model."} 64 | ) 65 | plot_loss: Optional[bool] = field( 66 | default=False, 67 | metadata={"help": "Whether to plot the training loss after fine-tuning or not."} 68 | ) 69 | hf_auth_token: Optional[str] = field( 70 | default=None, 71 | metadata={"help": "Auth token to log in with Hugging Face Hub."} 72 | ) 73 | export_dir: Optional[str] = field( 74 | default=None, 75 | metadata={"help": "Path to the directory to save the exported model."} 76 | ) 77 | 78 | def __post_init__(self): 79 | self.compute_dtype = None 80 | self.model_max_length = None 81 | 82 | if self.split_special_tokens and self.use_fast_tokenizer: 83 | raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") 84 | 85 | if self.checkpoint_dir is not None: # support merging multiple lora weights 86 | self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] 87 | 88 | if self.quantization_bit is not None: 89 | assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." 
90 | 91 | if self.use_auth_token == True and self.hf_auth_token is not None: 92 | from huggingface_hub.hf_api import HfFolder # lazy load 93 | HfFolder.save_token(self.hf_auth_token) 94 | -------------------------------------------------------------------------------- /llmtuner/tuner/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.tune import export_model, run_exp 2 | -------------------------------------------------------------------------------- /llmtuner/tuner/core/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.core.parser import get_train_args, get_infer_args 2 | from llmtuner.tuner.core.loader import load_model_and_tokenizer 3 | -------------------------------------------------------------------------------- /llmtuner/tuner/core/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import TYPE_CHECKING 3 | 4 | from peft import ( 5 | PeftModel, 6 | TaskType, 7 | LoraConfig, 8 | get_peft_model 9 | ) 10 | 11 | from llmtuner.extras.logging import get_logger 12 | from llmtuner.tuner.core.utils import find_all_linear_modules 13 | 14 | if TYPE_CHECKING: 15 | from transformers.modeling_utils import PreTrainedModel 16 | from llmtuner.hparams import ModelArguments, FinetuningArguments 17 | 18 | 19 | logger = get_logger(__name__) 20 | 21 | 22 | def init_adapter( 23 | model: "PreTrainedModel", 24 | model_args: "ModelArguments", 25 | finetuning_args: "FinetuningArguments", 26 | is_trainable: bool, 27 | is_mergeable: bool 28 | ) -> "PreTrainedModel": 29 | r""" 30 | Initializes the adapters. 31 | 32 | Support full-parameter, freeze and LoRA training. 33 | 34 | Note that the trainable parameters must be cast to float32. 
35 | """ 36 | 37 | if finetuning_args.finetuning_type == "none" and is_trainable: 38 | raise ValueError("You cannot use finetuning_type=none while training.") 39 | 40 | if finetuning_args.finetuning_type == "full" and is_trainable: 41 | logger.info("Fine-tuning method: Full") 42 | model = model.float() 43 | 44 | if finetuning_args.finetuning_type == "freeze": 45 | logger.info("Fine-tuning method: Freeze") 46 | num_layers = getattr(model.config, "num_layers") 47 | if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 48 | trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)] 49 | else: # fine-tuning the first n layers if num_layer_trainable < 0 50 | trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] 51 | 52 | trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids] 53 | for name, param in model.named_parameters(): 54 | if not any(trainable_layer in name for trainable_layer in trainable_layers): 55 | param.requires_grad_(False) 56 | else: 57 | param.data = param.data.to(torch.float32) 58 | 59 | if finetuning_args.finetuning_type == "lora": 60 | logger.info("Fine-tuning method: LoRA") 61 | latest_checkpoint = None 62 | 63 | if model_args.checkpoint_dir is not None: 64 | if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning 65 | checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] 66 | else: 67 | checkpoints_to_merge = model_args.checkpoint_dir 68 | 69 | for checkpoint in checkpoints_to_merge: 70 | model = PeftModel.from_pretrained(model, checkpoint) 71 | model = model.merge_and_unload() 72 | 73 | if len(checkpoints_to_merge) > 0: 74 | logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) 75 | 76 | if latest_checkpoint is not None: # resume lora training or quantized inference 77 | model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable) 78 | 79 | if is_trainable and latest_checkpoint is None: # create new lora weights while training 80 | if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": 81 | target_modules = find_all_linear_modules(model, model_args.quantization_bit) 82 | else: 83 | target_modules = finetuning_args.lora_target 84 | 85 | lora_config = LoraConfig( 86 | task_type=TaskType.CAUSAL_LM, 87 | inference_mode=False, 88 | r=finetuning_args.lora_rank, 89 | lora_alpha=finetuning_args.lora_alpha, 90 | lora_dropout=finetuning_args.lora_dropout, 91 | target_modules=target_modules, 92 | modules_to_save=finetuning_args.additional_target 93 | ) 94 | model = get_peft_model(model, lora_config) 95 | if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923 96 | model.base_model.peft_config = model.peft_config 97 | 98 | if model_args.checkpoint_dir is not None: 99 | logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) 100 | 101 | return model 102 | -------------------------------------------------------------------------------- /llmtuner/tuner/core/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import datasets 5 | import transformers 6 | from typing import Any, Dict, Optional, Tuple 7 | from transformers import HfArgumentParser, 
Seq2SeqTrainingArguments 8 | from transformers.trainer_utils import get_last_checkpoint 9 | 10 | from llmtuner.extras.logging import get_logger 11 | from llmtuner.hparams import ( 12 | ModelArguments, 13 | DataArguments, 14 | FinetuningArguments, 15 | GeneratingArguments 16 | ) 17 | 18 | 19 | logger = get_logger(__name__) 20 | 21 | 22 | def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: 23 | if args is not None: 24 | return parser.parse_dict(args) 25 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 26 | return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) 27 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 28 | return parser.parse_json_file(os.path.abspath(sys.argv[1])) 29 | else: 30 | return parser.parse_args_into_dataclasses() 31 | 32 | 33 | def parse_train_args( 34 | args: Optional[Dict[str, Any]] = None 35 | ) -> Tuple[ 36 | ModelArguments, 37 | DataArguments, 38 | Seq2SeqTrainingArguments, 39 | FinetuningArguments, 40 | GeneratingArguments 41 | ]: 42 | parser = HfArgumentParser(( 43 | ModelArguments, 44 | DataArguments, 45 | Seq2SeqTrainingArguments, 46 | FinetuningArguments, 47 | GeneratingArguments 48 | )) 49 | return _parse_args(parser, args) 50 | 51 | 52 | def parse_infer_args( 53 | args: Optional[Dict[str, Any]] = None 54 | ) -> Tuple[ 55 | ModelArguments, 56 | DataArguments, 57 | FinetuningArguments, 58 | GeneratingArguments 59 | ]: 60 | parser = HfArgumentParser(( 61 | ModelArguments, 62 | DataArguments, 63 | FinetuningArguments, 64 | GeneratingArguments 65 | )) 66 | return _parse_args(parser, args) 67 | 68 | 69 | def get_train_args( 70 | args: Optional[Dict[str, Any]] = None 71 | ) -> Tuple[ 72 | ModelArguments, 73 | DataArguments, 74 | Seq2SeqTrainingArguments, 75 | FinetuningArguments, 76 | GeneratingArguments 77 | ]: 78 | model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args) 79 | 80 | # Setup logging 81 | if training_args.should_log: 82 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
83 | transformers.utils.logging.set_verbosity_info() 84 | 85 | log_level = training_args.get_process_log_level() 86 | datasets.utils.logging.set_verbosity(log_level) 87 | transformers.utils.logging.set_verbosity(log_level) 88 | transformers.utils.logging.enable_default_handler() 89 | transformers.utils.logging.enable_explicit_format() 90 | 91 | # Check arguments 92 | data_args.init_for_training(training_args.seed) 93 | 94 | if finetuning_args.stage != "pt" and data_args.template is None: 95 | raise ValueError("Please specify which `template` to use.") 96 | 97 | if finetuning_args.stage != "sft" and training_args.predict_with_generate: 98 | raise ValueError("`predict_with_generate` cannot be set as True except SFT.") 99 | 100 | if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: 101 | raise ValueError("Please enable `predict_with_generate` to save model predictions.") 102 | 103 | if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora": 104 | raise ValueError("RM and PPO stages can only be performed with the LoRA method.") 105 | 106 | if finetuning_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None: 107 | raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.") 108 | 109 | if finetuning_args.stage in ["ppo", "dpo"] and not training_args.do_train: 110 | raise ValueError("PPO and DPO stages can only be performed at training.") 111 | 112 | if finetuning_args.stage in ["rm", "dpo"]: 113 | for dataset_attr in data_args.dataset_list: 114 | if not dataset_attr.ranking: 115 | raise ValueError("Please use ranked datasets for reward modeling or DPO training.") 116 | 117 | if finetuning_args.stage == "ppo" and model_args.reward_model is None: 118 | raise ValueError("Reward model is necessary for PPO training.") 119 | 120 | if finetuning_args.stage == "ppo" and model_args.shift_attn: 121 | raise ValueError("PPO training is incompatible with S^2-Attn.") 122 | 123 | if training_args.max_steps == -1 and data_args.streaming: 124 | raise ValueError("Please specify `max_steps` in streaming mode.") 125 | 126 | if training_args.do_train and training_args.predict_with_generate: 127 | raise ValueError("`predict_with_generate` cannot be set as True while training.") 128 | 129 | if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None: 130 | raise ValueError("Please specify `lora_target` in LoRA training.") 131 | 132 | if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": 133 | raise ValueError("Quantization is only compatible with the LoRA method.") 134 | 135 | if model_args.checkpoint_dir is not None: 136 | if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1: 137 | raise ValueError("Only LoRA tuning accepts multiple checkpoints.") 138 | 139 | if model_args.quantization_bit is not None: 140 | if len(model_args.checkpoint_dir) != 1: 141 | raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.") 142 | 143 | if not finetuning_args.resume_lora_training: 144 | raise ValueError("Quantized model cannot create new LoRA weight. 
Merge them first.") 145 | 146 | if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm): 147 | logger.warning("We recommend enable `upcast_layernorm` in quantized training.") 148 | 149 | if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): 150 | logger.warning("We recommend enable mixed precision training.") 151 | 152 | if (not training_args.do_train) and model_args.quantization_bit is not None: 153 | logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") 154 | 155 | # postprocess training_args 156 | if ( 157 | training_args.local_rank != -1 158 | and training_args.ddp_find_unused_parameters is None 159 | and finetuning_args.finetuning_type == "lora" 160 | ): 161 | logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") 162 | training_args_dict = training_args.to_dict() 163 | training_args_dict.update(dict(ddp_find_unused_parameters=False)) 164 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 165 | 166 | if ( 167 | training_args.resume_from_checkpoint is None 168 | and training_args.do_train 169 | and os.path.isdir(training_args.output_dir) 170 | and not training_args.overwrite_output_dir 171 | ): 172 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 173 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 174 | raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.") 175 | 176 | if last_checkpoint is not None: 177 | training_args_dict = training_args.to_dict() 178 | training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint)) 179 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 180 | logger.info( 181 | "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid." 182 | ) 183 | 184 | # postprocess model_args 185 | model_args.compute_dtype = ( 186 | torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None) 187 | ) 188 | model_args.model_max_length = data_args.cutoff_len 189 | 190 | # Log on each process the small summary: 191 | logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format( 192 | training_args.local_rank, training_args.device, training_args.n_gpu, 193 | bool(training_args.local_rank != -1), str(model_args.compute_dtype) 194 | )) 195 | logger.info(f"Training/evaluation parameters {training_args}") 196 | 197 | # Set seed before initializing model. 
198 | transformers.set_seed(training_args.seed) 199 | 200 | return model_args, data_args, training_args, finetuning_args, generating_args 201 | 202 | 203 | def get_infer_args( 204 | args: Optional[Dict[str, Any]] = None 205 | ) -> Tuple[ 206 | ModelArguments, 207 | DataArguments, 208 | FinetuningArguments, 209 | GeneratingArguments 210 | ]: 211 | model_args, data_args, finetuning_args, generating_args = parse_infer_args(args) 212 | 213 | if data_args.template is None: 214 | raise ValueError("Please specify which `template` to use.") 215 | 216 | if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": 217 | raise ValueError("Quantization is only compatible with the LoRA method.") 218 | 219 | if model_args.checkpoint_dir is not None: 220 | if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) != 1: 221 | raise ValueError("Only LoRA tuning accepts multiple checkpoints.") 222 | 223 | if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1: 224 | raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.") 225 | 226 | return model_args, data_args, finetuning_args, generating_args 227 | -------------------------------------------------------------------------------- /llmtuner/tuner/core/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from types import MethodType 3 | from typing import TYPE_CHECKING, List, Optional 4 | 5 | from llmtuner.extras.constants import LAYERNORM_NAMES 6 | from llmtuner.extras.logging import get_logger 7 | 8 | if TYPE_CHECKING: 9 | from transformers.modeling_utils import PreTrainedModel 10 | from llmtuner.hparams import FinetuningArguments 11 | 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | def find_all_linear_modules( 17 | model: "PreTrainedModel", 18 | quantization_bit: Optional[int] = None, 19 | output_layer_name: Optional[str] = "lm_head" 20 | ) -> List[str]: 21 | if quantization_bit is not None: 22 | import bitsandbytes as bnb 23 | linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt 24 | else: 25 | linear_cls = torch.nn.Linear 26 | 27 | module_names = set() 28 | for name, module in model.named_modules(): 29 | if output_layer_name not in name and isinstance(module, linear_cls): 30 | module_names.add(name.split(".")[-1]) 31 | 32 | if output_layer_name in module_names: 33 | module_names.pop(output_layer_name) 34 | 35 | return list(module_names) 36 | 37 | 38 | def prepare_model_for_training( 39 | model: "PreTrainedModel", 40 | finetuning_args: "FinetuningArguments", 41 | output_layer_name: Optional[str] = "lm_head", 42 | use_gradient_checkpointing: Optional[bool] = True, 43 | layernorm_names: Optional[List[str]] = LAYERNORM_NAMES 44 | ) -> "PreTrainedModel": 45 | r""" 46 | Includes: 47 | (1) cast the layernorm in fp32 48 | (2) make output embedding layer require grads 49 | (3) upcast the lm_head to fp32 50 | Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33 51 | """ 52 | if finetuning_args.upcast_layernorm: 53 | for name, param in model.named_parameters(): 54 | if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names): 55 | param.data = param.data.to(torch.float32) 56 | logger.info("Upcasting weights in layernorm in float32.") 57 | 58 | if finetuning_args.neft_alpha > 1e-6: 59 | input_embed = model.get_input_embeddings() 60 | if isinstance(input_embed, torch.nn.Embedding): 61 | def noisy_forward(self: 
torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor: 62 | embeddings = input_embed.__class__.forward(self, x) 63 | if self.training: 64 | dims = self.num_embeddings * self.embedding_dim 65 | mag_norm = finetuning_args.neft_alpha / (dims ** 0.5) 66 | embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm) 67 | return embeddings 68 | 69 | input_embed.forward = MethodType(noisy_forward, input_embed) 70 | logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha)) 71 | else: 72 | logger.warning("Input embeddings are not normal nn.Embedding, cannot transform into noisy embedding.") 73 | 74 | if use_gradient_checkpointing: 75 | if hasattr(model, "enable_input_require_grads"): 76 | model.enable_input_require_grads() 77 | else: 78 | def make_inputs_require_grad(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor): 79 | output.requires_grad_(True) 80 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 81 | 82 | model.gradient_checkpointing_enable() 83 | model.config.use_cache = False # turn off when gradient checkpointing is enabled 84 | logger.info("Gradient checkpointing enabled.") 85 | 86 | if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name): 87 | output_layer = getattr(model, output_layer_name) 88 | if isinstance(output_layer, torch.nn.Linear): 89 | def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor: 90 | return output_layer.__class__.forward(self, x.to(output_layer.weight.dtype)).to(torch.float32) 91 | 92 | output_layer.forward = MethodType(forward_in_fp32, output_layer) 93 | 94 | return model 95 | -------------------------------------------------------------------------------- /llmtuner/tuner/dpo/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.dpo.workflow import run_dpo 2 | -------------------------------------------------------------------------------- /llmtuner/tuner/dpo/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, List, Sequence, Tuple 4 | from transformers import DataCollatorForSeq2Seq 5 | 6 | 7 | @dataclass 8 | class DPODataCollatorWithPadding(DataCollatorForSeq2Seq): 9 | r""" 10 | Data collator for pairwise data. 11 | """ 12 | 13 | def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor: 14 | padded_labels = [] 15 | for feature, (prompt_len, answer_len) in zip(batch, positions): 16 | if self.tokenizer.padding_side == "left": 17 | start, end = feature.size(0) - answer_len, feature.size(0) 18 | else: 19 | start, end = prompt_len, prompt_len + answer_len 20 | padded_tensor = self.label_pad_token_id * torch.ones_like(feature) 21 | padded_tensor[start:end] = feature[start:end] 22 | padded_labels.append(padded_tensor) 23 | return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory 24 | 25 | def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: 26 | r""" 27 | Pads batched data to the longest sequence in the batch. 28 | 29 | We generate 2 * n examples where the first n examples represent chosen examples and 30 | the last n examples represent rejected examples. 
31 | """ 32 | concatenated_features = [] 33 | label_positions = [] 34 | for key in ("chosen_ids", "rejected_ids"): 35 | for feature in features: 36 | prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key]) 37 | concatenated_features.append({ 38 | "input_ids": feature["prompt_ids"] + feature[key], 39 | "attention_mask": [1] * (prompt_len + answer_len) 40 | }) 41 | label_positions.append((prompt_len, answer_len)) 42 | 43 | batch = self.tokenizer.pad( 44 | concatenated_features, 45 | padding=self.padding, 46 | max_length=self.max_length, 47 | pad_to_multiple_of=self.pad_to_multiple_of, 48 | return_tensors=self.return_tensors, 49 | ) 50 | batch["labels"] = self._pad_labels(batch["input_ids"], label_positions) 51 | return batch 52 | -------------------------------------------------------------------------------- /llmtuner/tuner/dpo/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union 4 | from transformers import BatchEncoding, Trainer 5 | from trl import DPOTrainer 6 | from trl.trainer.utils import disable_dropout_in_model 7 | 8 | from llmtuner.extras.constants import IGNORE_INDEX 9 | 10 | if TYPE_CHECKING: 11 | from transformers import PreTrainedModel 12 | 13 | 14 | class CustomDPOTrainer(DPOTrainer): 15 | 16 | def __init__( 17 | self, 18 | beta: float, 19 | model: Union["PreTrainedModel", torch.nn.Module], 20 | ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, 21 | disable_dropout: Optional[bool] = True, 22 | loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid", 23 | **kwargs 24 | ): 25 | if disable_dropout: 26 | disable_dropout_in_model(model) 27 | if ref_model is not None: 28 | disable_dropout_in_model(ref_model) 29 | 30 | self.is_encoder_decoder = model.config.is_encoder_decoder 31 | self.ref_model = ref_model 32 | self.use_dpo_data_collator = True # hack to avoid warning 33 | self.label_pad_token_id = IGNORE_INDEX 34 | self.padding_value = 0 35 | self.beta = beta 36 | self.loss_type = loss_type 37 | self._stored_metrics = defaultdict(lambda: defaultdict(list)) 38 | 39 | Trainer.__init__(self, model=model, **kwargs) 40 | if not hasattr(self, "accelerator"): 41 | raise AttributeError("Please update `transformers`.") 42 | 43 | if ref_model is not None: 44 | if self.is_deepspeed_enabled: 45 | self.ref_model = self._prepare_deepspeed(self.ref_model) 46 | else: 47 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) 48 | 49 | def concatenated_forward( 50 | self, 51 | model: Optional[torch.nn.Module] = None, 52 | batch: Optional[Dict[str, torch.Tensor]] = None 53 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 54 | batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error 55 | 56 | all_logits = model( 57 | input_ids=batch_copied["input_ids"], 58 | attention_mask=batch_copied["attention_mask"], 59 | return_dict=True 60 | ).logits.to(torch.float32) 61 | 62 | all_logps = self._get_batch_logps( 63 | all_logits, 64 | batch["labels"], 65 | average_log_prob=False 66 | ) 67 | batch_size = batch["input_ids"].size(0) // 2 68 | chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) 69 | chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) 70 | return chosen_logps, rejected_logps, chosen_logits, rejected_logits 71 | 
-------------------------------------------------------------------------------- /llmtuner/tuner/dpo/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py 2 | 3 | from copy import deepcopy 4 | from peft import PeftModel 5 | from typing import TYPE_CHECKING, Optional, List 6 | from transformers import Seq2SeqTrainingArguments 7 | 8 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset 9 | from llmtuner.extras.constants import IGNORE_INDEX 10 | from llmtuner.extras.ploting import plot_loss 11 | from llmtuner.tuner.core import load_model_and_tokenizer 12 | from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding 13 | from llmtuner.tuner.dpo.trainer import CustomDPOTrainer 14 | 15 | if TYPE_CHECKING: 16 | from transformers import TrainerCallback 17 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments 18 | 19 | 20 | def run_dpo( 21 | model_args: "ModelArguments", 22 | data_args: "DataArguments", 23 | training_args: "Seq2SeqTrainingArguments", 24 | finetuning_args: "FinetuningArguments", 25 | callbacks: Optional[List["TrainerCallback"]] = None 26 | ): 27 | dataset = get_dataset(model_args, data_args) 28 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") 29 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") 30 | data_collator = DPODataCollatorWithPadding( 31 | tokenizer=tokenizer, 32 | pad_to_multiple_of=4, 33 | label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 34 | ) 35 | 36 | training_args_dict = training_args.to_dict() 37 | training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset 38 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 39 | 40 | # Initialize our Trainer 41 | trainer = CustomDPOTrainer( 42 | beta=finetuning_args.dpo_beta, 43 | model=model, 44 | ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None, 45 | args=training_args, 46 | tokenizer=tokenizer, 47 | data_collator=data_collator, 48 | callbacks=callbacks, 49 | **split_dataset(dataset, data_args, training_args) 50 | ) 51 | 52 | # Training 53 | if training_args.do_train: 54 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 55 | trainer.log_metrics("train", train_result.metrics) 56 | trainer.save_metrics("train", train_result.metrics) 57 | trainer.save_state() 58 | trainer.save_model() 59 | if trainer.is_world_process_zero() and model_args.plot_loss: 60 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 61 | -------------------------------------------------------------------------------- /llmtuner/tuner/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.ppo.workflow import run_ppo 2 | -------------------------------------------------------------------------------- /llmtuner/tuner/ppo/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import TYPE_CHECKING, Dict, Literal, Optional 3 | 4 | if TYPE_CHECKING: 5 | from transformers import PreTrainedModel 6 | from trl import AutoModelForCausalLMWithValueHead 7 | 8 | 9 | def replace_model(model: "AutoModelForCausalLMWithValueHead", target: 
Literal["default", "reward"]) -> None: 10 | if target == "reward": # save default head temporarily 11 | valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict() 12 | setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone()) 13 | setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone()) 14 | 15 | model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active 16 | model.v_head.load_state_dict({ 17 | "summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(), 18 | "summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone() 19 | }) 20 | 21 | 22 | def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]: 23 | layer_norm_params = {} 24 | for name, param in model.named_parameters(): 25 | if param.data.dtype == torch.float32: 26 | layer_norm_params[name] = param.data.detach().clone() 27 | param.data = param.data.to(model.config.torch_dtype) 28 | 29 | return layer_norm_params 30 | 31 | 32 | def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None: 33 | for name, param in model.named_parameters(): 34 | if name in layernorm_params: 35 | param.data = layernorm_params[name] 36 | -------------------------------------------------------------------------------- /llmtuner/tuner/ppo/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py 2 | 3 | import math 4 | from trl import PPOConfig 5 | from torch.optim import AdamW 6 | from typing import TYPE_CHECKING, Optional, List 7 | from transformers import DataCollatorWithPadding 8 | from transformers.optimization import get_scheduler 9 | 10 | from llmtuner.dsets import get_dataset, preprocess_dataset 11 | from llmtuner.extras.callbacks import SavePeftModelCallback 12 | from llmtuner.extras.ploting import plot_loss 13 | from llmtuner.tuner.core import load_model_and_tokenizer 14 | from llmtuner.tuner.ppo.trainer import CustomPPOTrainer 15 | 16 | if TYPE_CHECKING: 17 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 18 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments 19 | 20 | 21 | def run_ppo( 22 | model_args: "ModelArguments", 23 | data_args: "DataArguments", 24 | training_args: "Seq2SeqTrainingArguments", 25 | finetuning_args: "FinetuningArguments", 26 | generating_args: "GeneratingArguments", 27 | callbacks: Optional[List["TrainerCallback"]] = None 28 | ): 29 | dataset = get_dataset(model_args, data_args) 30 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo") 31 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo") 32 | 33 | tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training 34 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 35 | 36 | ppo_config = PPOConfig( 37 | model_name=model_args.model_name_or_path, 38 | learning_rate=training_args.learning_rate, 39 | mini_batch_size=training_args.per_device_train_batch_size, 40 | batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps, 41 | gradient_accumulation_steps=training_args.gradient_accumulation_steps, 42 | ppo_epochs=1, 43 | 
max_grad_norm=training_args.max_grad_norm, 44 | seed=training_args.seed, 45 | optimize_cuda_cache=True, 46 | target=finetuning_args.ppo_target, 47 | log_with=finetuning_args.ppo_logger, 48 | use_score_scaling=finetuning_args.ppo_score_norm, 49 | use_score_norm=finetuning_args.ppo_score_norm, 50 | accelerator_kwargs={"step_scheduler_with_optimizer": False} 51 | ) 52 | 53 | optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) 54 | if training_args.max_steps > 0: 55 | num_training_steps = training_args.max_steps 56 | else: 57 | total_train_batch_size = ( 58 | training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size 59 | ) 60 | num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) 61 | 62 | lr_scheduler = get_scheduler( 63 | training_args.lr_scheduler_type, 64 | optimizer=optimizer, 65 | num_warmup_steps=training_args.get_warmup_steps(num_training_steps), 66 | num_training_steps=num_training_steps 67 | ) 68 | 69 | # Initialize our Trainer 70 | ppo_trainer = CustomPPOTrainer( 71 | model_args=model_args, 72 | training_args=training_args, 73 | finetuning_args=finetuning_args, 74 | generating_args=generating_args, 75 | callbacks=callbacks + [SavePeftModelCallback()], 76 | config=ppo_config, 77 | model=model, 78 | ref_model=None, 79 | tokenizer=tokenizer, 80 | dataset=dataset, 81 | data_collator=data_collator, 82 | optimizer=optimizer, 83 | lr_scheduler=lr_scheduler 84 | ) 85 | 86 | # Training 87 | if training_args.do_train: 88 | ppo_trainer.ppo_train() 89 | ppo_trainer.save_model() 90 | ppo_trainer.save_state() # must be called after save_model to have a folder 91 | if ppo_trainer.is_world_process_zero() and model_args.plot_loss: 92 | plot_loss(training_args.output_dir, keys=["loss", "reward"]) 93 | -------------------------------------------------------------------------------- /llmtuner/tuner/pt/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.pt.workflow import run_pt 2 | -------------------------------------------------------------------------------- /llmtuner/tuner/pt/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py 2 | 3 | import math 4 | from typing import TYPE_CHECKING, Optional, List 5 | from transformers import DataCollatorForLanguageModeling, Trainer 6 | 7 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset 8 | from llmtuner.extras.ploting import plot_loss 9 | from llmtuner.tuner.core import load_model_and_tokenizer 10 | 11 | if TYPE_CHECKING: 12 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 13 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments 14 | 15 | 16 | def run_pt( 17 | model_args: "ModelArguments", 18 | data_args: "DataArguments", 19 | training_args: "Seq2SeqTrainingArguments", 20 | finetuning_args: "FinetuningArguments", 21 | callbacks: Optional[List["TrainerCallback"]] = None 22 | ): 23 | dataset = get_dataset(model_args, data_args) 24 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") 25 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt") 26 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, 
mlm=False) 27 | 28 | # Initialize our Trainer 29 | trainer = Trainer( 30 | model=model, 31 | args=training_args, 32 | tokenizer=tokenizer, 33 | data_collator=data_collator, 34 | callbacks=callbacks, 35 | **split_dataset(dataset, data_args, training_args) 36 | ) 37 | 38 | # Training 39 | if training_args.do_train: 40 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 41 | trainer.log_metrics("train", train_result.metrics) 42 | trainer.save_metrics("train", train_result.metrics) 43 | trainer.save_state() 44 | trainer.save_model() 45 | if trainer.is_world_process_zero() and model_args.plot_loss: 46 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 47 | 48 | # Evaluation 49 | if training_args.do_eval: 50 | metrics = trainer.evaluate(metric_key_prefix="eval") 51 | try: 52 | perplexity = math.exp(metrics["eval_loss"]) 53 | except OverflowError: 54 | perplexity = float("inf") 55 | 56 | metrics["perplexity"] = perplexity 57 | trainer.log_metrics("eval", metrics) 58 | trainer.save_metrics("eval", metrics) 59 | -------------------------------------------------------------------------------- /llmtuner/tuner/rm/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.rm.workflow import run_rm 2 | -------------------------------------------------------------------------------- /llmtuner/tuner/rm/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Sequence 4 | from transformers import DataCollatorWithPadding 5 | 6 | 7 | @dataclass 8 | class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): 9 | r""" 10 | Data collator for pairwise data. 11 | """ 12 | 13 | def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: 14 | r""" 15 | Pads batched data to the longest sequence in the batch. 16 | 17 | We generate 2 * n examples where the first n examples represent chosen examples and 18 | the last n examples represent rejected examples. 
19 | """ 20 | features = [ 21 | { 22 | "input_ids": feature["prompt_ids"] + feature[key], 23 | "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])) 24 | } 25 | for key in ("chosen_ids", "rejected_ids") for feature in features 26 | ] 27 | return super().__call__(features) 28 | -------------------------------------------------------------------------------- /llmtuner/tuner/rm/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict, Sequence, Tuple, Union 3 | 4 | 5 | def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: 6 | preds, _ = eval_preds 7 | return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])} 8 | -------------------------------------------------------------------------------- /llmtuner/tuner/rm/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union 5 | from transformers import Trainer 6 | 7 | from llmtuner.extras.logging import get_logger 8 | 9 | if TYPE_CHECKING: 10 | from transformers.trainer import PredictionOutput 11 | from transformers.modeling_utils import PreTrainedModel 12 | 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | class PairwiseTrainer(Trainer): 18 | r""" 19 | Inherits PeftTrainer to compute pairwise loss. 20 | """ 21 | 22 | def __init__(self, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.can_return_loss = True # override property to return eval_loss 25 | 26 | def compute_loss( 27 | self, 28 | model: "PreTrainedModel", 29 | inputs: Dict[str, torch.Tensor], 30 | return_outputs: Optional[bool] = False 31 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: 32 | r""" 33 | Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. 34 | 35 | Subclass and override to inject custom behavior. 36 | 37 | Note that the first element will be removed from the output tuple. 38 | See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509 39 | """ 40 | # Compute rewards 41 | _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) 42 | if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2 43 | values = torch.transpose(values, 0, 1) 44 | 45 | # Split the inputs and rewards into two parts, chosen and rejected 46 | batch_size = inputs["input_ids"].size(0) // 2 47 | chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:] 48 | chosen_attn_mask, rejected_attn_mask = ( 49 | inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:] 50 | ) 51 | chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:] 52 | chosen_scores, rejected_scores = [], [] 53 | 54 | # Compute pairwise loss. 
Only backprop on the different tokens before padding 55 | # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py 56 | loss = 0 57 | for i in range(batch_size): 58 | chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1 59 | rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1 60 | check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero() 61 | 62 | if len(check_divergence) == 0: 63 | end_index = chosen_length 64 | div_index = end_index - 1 65 | else: 66 | end_index = max(chosen_length, rejected_length) 67 | div_index = check_divergence[0] 68 | 69 | assert div_index > 0 70 | chosen_trunc_rewards = chosen_rewards[i, div_index:end_index] 71 | rejected_trunc_rewards = rejected_rewards[i, div_index:end_index] 72 | if return_outputs: # use the score on the EOS token for inference 73 | chosen_scores.append(chosen_rewards[i, chosen_length-1]) 74 | rejected_scores.append(rejected_rewards[i, rejected_length-1]) 75 | loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean() 76 | 77 | loss = loss / batch_size 78 | if return_outputs: 79 | chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores) 80 | return loss, [loss, chosen_scores, rejected_scores] 81 | 82 | return loss 83 | 84 | def save_predictions( 85 | self, 86 | predict_results: "PredictionOutput" 87 | ) -> None: 88 | r""" 89 | Saves model predictions to `output_dir`. 90 | 91 | A custom behavior that not contained in Seq2SeqTrainer. 92 | """ 93 | if not self.is_world_process_zero(): 94 | return 95 | 96 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") 97 | logger.info(f"Saving prediction results to {output_prediction_file}") 98 | 99 | chosen_scores, rejected_scores = predict_results.predictions 100 | 101 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 102 | res: List[str] = [] 103 | for c_score, r_score in zip(chosen_scores, rejected_scores): 104 | res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)})) 105 | writer.write("\n".join(res)) 106 | -------------------------------------------------------------------------------- /llmtuner/tuner/rm/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: 2 | # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py 3 | 4 | from typing import TYPE_CHECKING, Optional, List 5 | from transformers import Seq2SeqTrainingArguments 6 | 7 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset 8 | from llmtuner.extras.callbacks import SavePeftModelCallback 9 | from llmtuner.extras.ploting import plot_loss 10 | from llmtuner.tuner.core import load_model_and_tokenizer 11 | from llmtuner.tuner.rm.metric import compute_accuracy 12 | from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding 13 | from llmtuner.tuner.rm.trainer import PairwiseTrainer 14 | 15 | if TYPE_CHECKING: 16 | from transformers import TrainerCallback 17 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments 18 | 19 | 20 | def run_rm( 21 | model_args: "ModelArguments", 22 | data_args: "DataArguments", 23 | training_args: "Seq2SeqTrainingArguments", 24 | finetuning_args: "FinetuningArguments", 25 | callbacks: Optional[List["TrainerCallback"]] = None 26 | ): 27 | dataset = get_dataset(model_args, data_args) 28 | model, tokenizer = 
load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm") 29 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") 30 | data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4) 31 | 32 | training_args_dict = training_args.to_dict() 33 | training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset 34 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 35 | 36 | # Initialize our Trainer 37 | trainer = PairwiseTrainer( 38 | model=model, 39 | args=training_args, 40 | tokenizer=tokenizer, 41 | data_collator=data_collator, 42 | callbacks=callbacks + [SavePeftModelCallback()], 43 | compute_metrics=compute_accuracy, 44 | **split_dataset(dataset, data_args, training_args) 45 | ) 46 | 47 | # Training 48 | if training_args.do_train: 49 | train_result = trainer.train() 50 | trainer.log_metrics("train", train_result.metrics) 51 | trainer.save_metrics("train", train_result.metrics) 52 | trainer.save_state() 53 | trainer.save_model() 54 | if trainer.is_world_process_zero() and model_args.plot_loss: 55 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 56 | 57 | # Evaluation 58 | if training_args.do_eval: 59 | metrics = trainer.evaluate(metric_key_prefix="eval") 60 | trainer.log_metrics("eval", metrics) 61 | trainer.save_metrics("eval", metrics) 62 | 63 | # Predict 64 | if training_args.do_predict: 65 | predict_results = trainer.predict(dataset, metric_key_prefix="predict") 66 | trainer.log_metrics("predict", predict_results.metrics) 67 | trainer.save_metrics("predict", predict_results.metrics) 68 | trainer.save_predictions(predict_results) 69 | -------------------------------------------------------------------------------- /llmtuner/tuner/sft/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.sft.workflow import run_sft 2 | -------------------------------------------------------------------------------- /llmtuner/tuner/sft/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dataclasses import dataclass 3 | from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union 4 | 5 | import jieba 6 | from rouge_chinese import Rouge 7 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 8 | 9 | from llmtuner.extras.constants import IGNORE_INDEX 10 | 11 | if TYPE_CHECKING: 12 | from transformers.tokenization_utils import PreTrainedTokenizer 13 | 14 | 15 | @dataclass 16 | class ComputeMetrics: 17 | r""" 18 | Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. 19 | """ 20 | 21 | tokenizer: "PreTrainedTokenizer" 22 | 23 | def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: 24 | r""" 25 | Uses the model predictions to compute metrics. 
26 | """ 27 | preds, labels = eval_preds 28 | score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} 29 | 30 | preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) 31 | labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) 32 | 33 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) 34 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 35 | 36 | for pred, label in zip(decoded_preds, decoded_labels): 37 | hypothesis = list(jieba.cut(pred)) 38 | reference = list(jieba.cut(label)) 39 | 40 | if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: 41 | result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} 42 | else: 43 | rouge = Rouge() 44 | scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) 45 | result = scores[0] 46 | 47 | for k, v in result.items(): 48 | score_dict[k].append(round(v["f"] * 100, 4)) 49 | 50 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 51 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 52 | 53 | return {k: float(np.mean(v)) for k, v in score_dict.items()} 54 | -------------------------------------------------------------------------------- /llmtuner/tuner/sft/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union 7 | from transformers import Seq2SeqTrainer 8 | 9 | from llmtuner.extras.constants import IGNORE_INDEX 10 | from llmtuner.extras.logging import get_logger 11 | 12 | if TYPE_CHECKING: 13 | from transformers.trainer import PredictionOutput 14 | 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | class CustomSeq2SeqTrainer(Seq2SeqTrainer): 20 | r""" 21 | Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE. 22 | """ 23 | 24 | def prediction_step( 25 | self, 26 | model: nn.Module, 27 | inputs: Dict[str, Union[torch.Tensor, Any]], 28 | prediction_loss_only: bool, 29 | ignore_keys: Optional[List[str]] = None, 30 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 31 | r""" 32 | Removes the prompt part in the generated tokens. 33 | 34 | Subclass and override to inject custom behavior. 35 | """ 36 | labels = inputs["labels"].clone() if "labels" in inputs else None # backup labels 37 | if self.args.predict_with_generate: 38 | assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." 
39 | prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) 40 | if prompt_len > label_len: 41 | inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) 42 | if label_len > prompt_len: 43 | inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs 44 | 45 | loss, generated_tokens, _ = super().prediction_step( 46 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 47 | ) 48 | if generated_tokens is not None and self.args.predict_with_generate: 49 | generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id 50 | generated_tokens = generated_tokens.contiguous() 51 | 52 | return loss, generated_tokens, labels 53 | 54 | def _pad_tensors_to_target_len( 55 | self, 56 | src_tensor: torch.Tensor, 57 | tgt_tensor: torch.Tensor 58 | ) -> torch.Tensor: 59 | r""" 60 | Pads the tensor to the same length as the target tensor. 61 | """ 62 | assert self.tokenizer.pad_token_id is not None, "Pad token is required." 63 | padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) 64 | padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding 65 | return padded_tensor.contiguous() # in contiguous memory 66 | 67 | def save_predictions( 68 | self, 69 | predict_results: "PredictionOutput" 70 | ) -> None: 71 | r""" 72 | Saves model predictions to `output_dir`. 73 | 74 | A custom behavior that not contained in Seq2SeqTrainer. 75 | """ 76 | if not self.is_world_process_zero(): 77 | return 78 | 79 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") 80 | logger.info(f"Saving prediction results to {output_prediction_file}") 81 | 82 | preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) 83 | labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) 84 | 85 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) 86 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True) 87 | 88 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 89 | res: List[str] = [] 90 | for pred, label in zip(decoded_preds, decoded_labels): 91 | res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) 92 | writer.write("\n".join(res)) 93 | -------------------------------------------------------------------------------- /llmtuner/tuner/sft/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py 2 | 3 | from typing import TYPE_CHECKING, Optional, List 4 | from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments 5 | 6 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset 7 | from llmtuner.extras.constants import IGNORE_INDEX 8 | from llmtuner.extras.misc import get_logits_processor 9 | from llmtuner.extras.ploting import plot_loss 10 | from llmtuner.tuner.core import load_model_and_tokenizer 11 | from llmtuner.tuner.sft.metric import ComputeMetrics 12 | from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer 13 | 14 | if TYPE_CHECKING: 15 | from transformers import TrainerCallback 16 | from llmtuner.hparams import ModelArguments, DataArguments, 
FinetuningArguments, GeneratingArguments 17 | 18 | 19 | def run_sft( 20 | model_args: "ModelArguments", 21 | data_args: "DataArguments", 22 | training_args: "Seq2SeqTrainingArguments", 23 | finetuning_args: "FinetuningArguments", 24 | generating_args: "GeneratingArguments", 25 | callbacks: Optional[List["TrainerCallback"]] = None 26 | ): 27 | dataset = get_dataset(model_args, data_args) 28 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") 29 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft") 30 | 31 | if training_args.predict_with_generate: 32 | tokenizer.padding_side = "left" # use left-padding in generation 33 | 34 | data_collator = DataCollatorForSeq2Seq( 35 | tokenizer=tokenizer, 36 | pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention 37 | label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 38 | ) 39 | 40 | # Override the decoding parameters of Seq2SeqTrainer 41 | training_args_dict = training_args.to_dict() 42 | training_args_dict.update(dict( 43 | generation_max_length=training_args.generation_max_length or data_args.cutoff_len, 44 | generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams 45 | )) 46 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 47 | 48 | # Initialize our Trainer 49 | trainer = CustomSeq2SeqTrainer( 50 | model=model, 51 | args=training_args, 52 | tokenizer=tokenizer, 53 | data_collator=data_collator, 54 | callbacks=callbacks, 55 | compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, 56 | **split_dataset(dataset, data_args, training_args) 57 | ) 58 | 59 | # Keyword arguments for `model.generate` 60 | gen_kwargs = generating_args.to_dict() 61 | gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids 62 | gen_kwargs["pad_token_id"] = tokenizer.pad_token_id 63 | gen_kwargs["logits_processor"] = get_logits_processor() 64 | 65 | # Training 66 | if training_args.do_train: 67 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 68 | trainer.log_metrics("train", train_result.metrics) 69 | trainer.save_metrics("train", train_result.metrics) 70 | trainer.save_state() 71 | trainer.save_model() 72 | if trainer.is_world_process_zero() and model_args.plot_loss: 73 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 74 | 75 | # Evaluation 76 | if training_args.do_eval: 77 | metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) 78 | if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled 79 | metrics.pop("eval_loss", None) 80 | trainer.log_metrics("eval", metrics) 81 | trainer.save_metrics("eval", metrics) 82 | 83 | # Predict 84 | if training_args.do_predict: 85 | predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) 86 | if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled 87 | predict_results.metrics.pop("predict_loss", None) 88 | trainer.log_metrics("predict", predict_results.metrics) 89 | trainer.save_metrics("predict", predict_results.metrics) 90 | trainer.save_predictions(predict_results) 91 | -------------------------------------------------------------------------------- /llmtuner/tuner/tune.py: 
-------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, List, Optional 2 | 3 | from llmtuner.extras.callbacks import LogCallback 4 | from llmtuner.extras.logging import get_logger 5 | from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer 6 | from llmtuner.tuner.pt import run_pt 7 | from llmtuner.tuner.sft import run_sft 8 | from llmtuner.tuner.rm import run_rm 9 | from llmtuner.tuner.ppo import run_ppo 10 | from llmtuner.tuner.dpo import run_dpo 11 | 12 | if TYPE_CHECKING: 13 | from transformers import TrainerCallback 14 | 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None): 20 | model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) 21 | callbacks = [LogCallback()] if callbacks is None else callbacks 22 | 23 | if finetuning_args.stage == "pt": 24 | run_pt(model_args, data_args, training_args, finetuning_args, callbacks) 25 | elif finetuning_args.stage == "sft": 26 | run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) 27 | elif finetuning_args.stage == "rm": 28 | run_rm(model_args, data_args, training_args, finetuning_args, callbacks) 29 | elif finetuning_args.stage == "ppo": 30 | run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) 31 | elif finetuning_args.stage == "dpo": 32 | run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) 33 | else: 34 | raise ValueError("Unknown task.") 35 | 36 | 37 | def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"): 38 | model_args, _, finetuning_args, _ = get_infer_args(args) 39 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) 40 | model.config.use_cache = True 41 | tokenizer.padding_side = "left" # restore padding side 42 | tokenizer.init_kwargs["padding_side"] = "left" 43 | model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size) 44 | try: 45 | tokenizer.save_pretrained(model_args.export_dir) 46 | except: 47 | logger.warning("Cannot save tokenizer, please copy the files manually.") 48 | 49 | 50 | if __name__ == "__main__": 51 | run_exp() 52 | -------------------------------------------------------------------------------- /llmtuner/webui/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.webui.interface import create_ui, create_web_demo 2 | -------------------------------------------------------------------------------- /llmtuner/webui/chatter.py: -------------------------------------------------------------------------------- 1 | from gradio.components import Component # cannot use TYPE_CHECKING here 2 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple 3 | 4 | from llmtuner.chat.stream_chat import ChatModel 5 | from llmtuner.extras.misc import torch_gc 6 | from llmtuner.hparams import GeneratingArguments 7 | from llmtuner.webui.common import get_save_dir 8 | from llmtuner.webui.locales import ALERTS 9 | 10 | if TYPE_CHECKING: 11 | from llmtuner.webui.manager import Manager 12 | 13 | 14 | class WebChatModel(ChatModel): 15 | 16 | def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None: 17 | self.manager = manager 18 | self.model = None 19 | self.tokenizer = None 20 | self.generating_args = GeneratingArguments() 
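# With the default lazy_init=True, the base ChatModel is not constructed here: model and tokenizer stay
# None until load_model() gathers the settings from the web UI and calls super().__init__() with them.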
21 | if not lazy_init: 22 | super().__init__() 23 | 24 | @property 25 | def loaded(self) -> bool: 26 | return self.model is not None 27 | 28 | def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: 29 | get = lambda name: data[self.manager.get_elem(name)] 30 | lang = get("top.lang") 31 | 32 | if self.loaded: 33 | yield ALERTS["err_exists"][lang] 34 | return 35 | 36 | if not get("top.model_name"): 37 | yield ALERTS["err_no_model"][lang] 38 | return 39 | 40 | if not get("top.model_path"): 41 | yield ALERTS["err_no_path"][lang] 42 | return 43 | 44 | if get("top.checkpoints"): 45 | checkpoint_dir = ",".join([ 46 | get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints") 47 | ]) 48 | else: 49 | checkpoint_dir = None 50 | 51 | yield ALERTS["info_loading"][lang] 52 | args = dict( 53 | model_name_or_path=get("top.model_path"), 54 | checkpoint_dir=checkpoint_dir, 55 | finetuning_type=get("top.finetuning_type"), 56 | quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, 57 | template=get("top.template"), 58 | system_prompt=get("top.system_prompt"), 59 | flash_attn=get("top.flash_attn"), 60 | shift_attn=get("top.shift_attn"), 61 | rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None 62 | ) 63 | super().__init__(args) 64 | 65 | yield ALERTS["info_loaded"][lang] 66 | 67 | def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]: 68 | get = lambda name: data[self.manager.get_elem(name)] 69 | lang = get("top.lang") 70 | 71 | yield ALERTS["info_unloading"][lang] 72 | self.model = None 73 | self.tokenizer = None 74 | torch_gc() 75 | yield ALERTS["info_unloaded"][lang] 76 | 77 | def predict( 78 | self, 79 | chatbot: List[Tuple[str, str]], 80 | query: str, 81 | history: List[Tuple[str, str]], 82 | system: str, 83 | max_new_tokens: int, 84 | top_p: float, 85 | temperature: float 86 | ) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]: 87 | chatbot.append([query, ""]) 88 | response = "" 89 | for new_text in self.stream_chat( 90 | query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature 91 | ): 92 | response += new_text 93 | new_history = history + [(query, response)] 94 | chatbot[-1] = [query, self.postprocess(response)] 95 | yield chatbot, new_history 96 | 97 | def postprocess(self, response: str) -> str: 98 | blocks = response.split("```") 99 | for i, block in enumerate(blocks): 100 | if i % 2 == 0: 101 | blocks[i] = block.replace("<", "<").replace(">", ">") 102 | return "```".join(blocks) 103 | -------------------------------------------------------------------------------- /llmtuner/webui/common.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import gradio as gr 4 | from typing import Any, Dict, Optional 5 | from transformers.utils import ( 6 | WEIGHTS_NAME, 7 | WEIGHTS_INDEX_NAME, 8 | SAFE_WEIGHTS_NAME, 9 | SAFE_WEIGHTS_INDEX_NAME, 10 | ADAPTER_WEIGHTS_NAME, 11 | ADAPTER_SAFE_WEIGHTS_NAME 12 | ) 13 | 14 | from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES 15 | 16 | 17 | DEFAULT_CACHE_DIR = "cache" 18 | DEFAULT_DATA_DIR = "data" 19 | DEFAULT_SAVE_DIR = "saves" 20 | USER_CONFIG = "user.config" 21 | DATA_CONFIG = "dataset_info.json" 22 | CKPT_NAMES = [ 23 | WEIGHTS_NAME, 24 | WEIGHTS_INDEX_NAME, 25 | SAFE_WEIGHTS_NAME, 26 | 
SAFE_WEIGHTS_INDEX_NAME, 27 | ADAPTER_WEIGHTS_NAME, 28 | ADAPTER_SAFE_WEIGHTS_NAME 29 | ] 30 | 31 | 32 | def get_save_dir(*args) -> os.PathLike: 33 | return os.path.join(DEFAULT_SAVE_DIR, *args) 34 | 35 | 36 | def get_config_path() -> os.PathLike: 37 | return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) 38 | 39 | 40 | def load_config() -> Dict[str, Any]: 41 | try: 42 | with open(get_config_path(), "r", encoding="utf-8") as f: 43 | return json.load(f) 44 | except: 45 | return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None} 46 | 47 | 48 | def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None: 49 | os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) 50 | user_config = load_config() 51 | user_config["lang"] = lang or user_config["lang"] 52 | if model_name: 53 | user_config["last_model"] = model_name 54 | user_config["path_dict"][model_name] = model_path 55 | with open(get_config_path(), "w", encoding="utf-8") as f: 56 | json.dump(user_config, f, indent=2, ensure_ascii=False) 57 | 58 | 59 | def get_model_path(model_name: str) -> str: 60 | user_config = load_config() 61 | return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "") 62 | 63 | 64 | def get_module(model_name: str) -> str: 65 | return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj") 66 | 67 | 68 | def get_template(model_name: str) -> str: 69 | if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE: 70 | return DEFAULT_TEMPLATE[model_name.split("-")[0]] 71 | return "default" 72 | 73 | 74 | def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]: 75 | checkpoints = [] 76 | save_dir = get_save_dir(model_name, finetuning_type) 77 | if save_dir and os.path.isdir(save_dir): 78 | for checkpoint in os.listdir(save_dir): 79 | if ( 80 | os.path.isdir(os.path.join(save_dir, checkpoint)) 81 | and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES]) 82 | ): 83 | checkpoints.append(checkpoint) 84 | return gr.update(value=[], choices=checkpoints) 85 | 86 | 87 | def load_dataset_info(dataset_dir: str) -> Dict[str, Any]: 88 | try: 89 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: 90 | return json.load(f) 91 | except: 92 | print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir)) 93 | return {} 94 | 95 | 96 | def list_dataset( 97 | dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0] 98 | ) -> Dict[str, Any]: 99 | dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) 100 | ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"] 101 | datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking] 102 | return gr.update(value=[], choices=datasets) 103 | -------------------------------------------------------------------------------- /llmtuner/webui/components/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.webui.components.top import create_top 2 | from llmtuner.webui.components.train import create_train_tab 3 | from llmtuner.webui.components.eval import create_eval_tab 4 | from llmtuner.webui.components.infer import create_infer_tab 5 | from llmtuner.webui.components.export import create_export_tab 6 | from llmtuner.webui.components.chatbot import create_chat_box 7 | -------------------------------------------------------------------------------- 
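# Annotation (not part of the repository): a small usage sketch of the path helpers above; the model and
# checkpoint names are hypothetical. A checkpoint is listed by list_checkpoint() once its directory
# contains any of the weight files named in CKPT_NAMES.
# get_save_dir("PULSE-7B", "lora", "checkpoint-100")  ->  "saves/PULSE-7B/lora/checkpoint-100"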
/llmtuner/webui/components/chatbot.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from typing import TYPE_CHECKING, Dict, Optional, Tuple 3 | 4 | if TYPE_CHECKING: 5 | from gradio.blocks import Block 6 | from gradio.components import Component 7 | from llmtuner.webui.engine import Engine 8 | 9 | 10 | def create_chat_box( 11 | engine: "Engine", 12 | visible: Optional[bool] = False 13 | ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]: 14 | elem_dict = dict() 15 | 16 | with gr.Box(visible=visible) as chat_box: 17 | chatbot = gr.Chatbot() 18 | 19 | with gr.Row(): 20 | with gr.Column(scale=4): 21 | system = gr.Textbox(show_label=False) 22 | query = gr.Textbox(show_label=False, lines=8) 23 | submit_btn = gr.Button(variant="primary") 24 | 25 | with gr.Column(scale=1): 26 | clear_btn = gr.Button() 27 | gen_kwargs = engine.chatter.generating_args 28 | max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1) 29 | top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01) 30 | temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01) 31 | 32 | elem_dict.update(dict( 33 | system=system, query=query, submit_btn=submit_btn, clear_btn=clear_btn, 34 | max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature 35 | )) 36 | 37 | history = gr.State([]) 38 | 39 | submit_btn.click( 40 | engine.chatter.predict, 41 | [chatbot, query, history, system, max_new_tokens, top_p, temperature], 42 | [chatbot, history], 43 | show_progress=True 44 | ).then( 45 | lambda: gr.update(value=""), outputs=[query] 46 | ) 47 | 48 | clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) 49 | 50 | return chat_box, chatbot, history, elem_dict 51 | -------------------------------------------------------------------------------- /llmtuner/webui/components/data.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from typing import TYPE_CHECKING, Tuple 3 | 4 | if TYPE_CHECKING: 5 | from gradio.blocks import Block 6 | from gradio.components import Component 7 | 8 | 9 | def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]: 10 | with gr.Column(visible=False, elem_classes="modal-box") as preview_box: 11 | preview_count = gr.Number(interactive=False) 12 | preview_samples = gr.JSON(interactive=False) 13 | close_btn = gr.Button() 14 | 15 | close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False) 16 | 17 | return preview_box, preview_count, preview_samples, close_btn 18 | -------------------------------------------------------------------------------- /llmtuner/webui/components/eval.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from typing import TYPE_CHECKING, Dict 3 | 4 | from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR 5 | from llmtuner.webui.components.data import create_preview_box 6 | from llmtuner.webui.utils import can_preview, get_preview 7 | 8 | if TYPE_CHECKING: 9 | from gradio.components import Component 10 | from llmtuner.webui.engine import Engine 11 | 12 | 13 | def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]: 14 | input_elems = engine.manager.get_base_elems() 15 | elem_dict = dict() 16 | 17 | with gr.Row(): 18 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) 19 | dataset = gr.Dropdown(multiselect=True, scale=4) 20 | data_preview_btn = 
gr.Button(interactive=False, scale=1) 21 | 22 | dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False) 23 | dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False) 24 | 25 | input_elems.update({dataset_dir, dataset}) 26 | elem_dict.update(dict( 27 | dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn 28 | )) 29 | 30 | preview_box, preview_count, preview_samples, close_btn = create_preview_box() 31 | 32 | data_preview_btn.click( 33 | get_preview, 34 | [dataset_dir, dataset], 35 | [preview_count, preview_samples, preview_box], 36 | queue=False 37 | ) 38 | 39 | elem_dict.update(dict( 40 | preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn 41 | )) 42 | 43 | with gr.Row(): 44 | cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1) 45 | max_samples = gr.Textbox(value="100000") 46 | batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1) 47 | predict = gr.Checkbox(value=True) 48 | 49 | input_elems.update({cutoff_len, max_samples, batch_size, predict}) 50 | elem_dict.update(dict( 51 | cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict 52 | )) 53 | 54 | with gr.Row(): 55 | max_new_tokens = gr.Slider(10, 2048, value=128, step=1) 56 | top_p = gr.Slider(0.01, 1, value=0.7, step=0.01) 57 | temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01) 58 | 59 | input_elems.update({max_new_tokens, top_p, temperature}) 60 | elem_dict.update(dict( 61 | max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature 62 | )) 63 | 64 | with gr.Row(): 65 | cmd_preview_btn = gr.Button() 66 | start_btn = gr.Button() 67 | stop_btn = gr.Button() 68 | 69 | with gr.Row(): 70 | resume_btn = gr.Checkbox(visible=False, interactive=False, value=False) 71 | process_bar = gr.Slider(visible=False, interactive=False) 72 | 73 | with gr.Box(): 74 | output_box = gr.Markdown() 75 | 76 | output_elems = [output_box, process_bar] 77 | elem_dict.update(dict( 78 | cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, 79 | resume_btn=resume_btn, process_bar=process_bar, output_box=output_box 80 | )) 81 | 82 | cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems) 83 | start_btn.click(engine.runner.run_eval, input_elems, output_elems) 84 | stop_btn.click(engine.runner.set_abort, queue=False) 85 | resume_btn.change(engine.runner.monitor, outputs=output_elems) 86 | 87 | return elem_dict 88 | -------------------------------------------------------------------------------- /llmtuner/webui/components/export.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from typing import TYPE_CHECKING, Dict 3 | 4 | from llmtuner.webui.utils import save_model 5 | 6 | if TYPE_CHECKING: 7 | from gradio.components import Component 8 | from llmtuner.webui.engine import Engine 9 | 10 | 11 | def create_export_tab(engine: "Engine") -> Dict[str, "Component"]: 12 | elem_dict = dict() 13 | 14 | with gr.Row(): 15 | export_dir = gr.Textbox() 16 | max_shard_size = gr.Slider(value=10, minimum=1, maximum=100) 17 | 18 | export_btn = gr.Button() 19 | info_box = gr.Textbox(show_label=False, interactive=False) 20 | 21 | export_btn.click( 22 | save_model, 23 | [ 24 | engine.manager.get_elem("top.lang"), 25 | engine.manager.get_elem("top.model_name"), 26 | engine.manager.get_elem("top.model_path"), 27 | engine.manager.get_elem("top.checkpoints"), 28 | engine.manager.get_elem("top.finetuning_type"), 29 | 
engine.manager.get_elem("top.template"), 30 | max_shard_size, 31 | export_dir 32 | ], 33 | [info_box] 34 | ) 35 | 36 | elem_dict.update(dict( 37 | export_dir=export_dir, 38 | max_shard_size=max_shard_size, 39 | export_btn=export_btn, 40 | info_box=info_box 41 | )) 42 | 43 | return elem_dict 44 | -------------------------------------------------------------------------------- /llmtuner/webui/components/infer.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from typing import TYPE_CHECKING, Dict 3 | 4 | from llmtuner.webui.components.chatbot import create_chat_box 5 | 6 | if TYPE_CHECKING: 7 | from gradio.components import Component 8 | from llmtuner.webui.engine import Engine 9 | 10 | 11 | def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]: 12 | input_elems = engine.manager.get_base_elems() 13 | elem_dict = dict() 14 | 15 | with gr.Row(): 16 | load_btn = gr.Button() 17 | unload_btn = gr.Button() 18 | 19 | info_box = gr.Textbox(show_label=False, interactive=False) 20 | elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box)) 21 | 22 | chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False) 23 | elem_dict.update(dict(chat_box=chat_box, **chat_elems)) 24 | 25 | load_btn.click( 26 | engine.chatter.load_model, input_elems, [info_box] 27 | ).then( 28 | lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box] 29 | ) 30 | 31 | unload_btn.click( 32 | engine.chatter.unload_model, input_elems, [info_box] 33 | ).then( 34 | lambda: ([], []), outputs=[chatbot, history] 35 | ).then( 36 | lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box] 37 | ) 38 | 39 | return elem_dict 40 | -------------------------------------------------------------------------------- /llmtuner/webui/components/top.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from typing import TYPE_CHECKING, Dict 3 | 4 | from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS 5 | from llmtuner.extras.template import templates 6 | from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config 7 | from llmtuner.webui.utils import can_quantize 8 | 9 | if TYPE_CHECKING: 10 | from gradio.components import Component 11 | 12 | 13 | def create_top() -> Dict[str, "Component"]: 14 | available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] 15 | 16 | with gr.Row(): 17 | lang = gr.Dropdown(choices=["en", "zh"], scale=1) 18 | model_name = gr.Dropdown(choices=available_models, scale=3) 19 | model_path = gr.Textbox(scale=3) 20 | 21 | with gr.Row(): 22 | finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1) 23 | checkpoints = gr.Dropdown(multiselect=True, scale=5) 24 | refresh_btn = gr.Button(scale=1) 25 | 26 | with gr.Accordion(label="Advanced config", open=False) as advanced_tab: 27 | with gr.Row(): 28 | quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1) 29 | template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1) 30 | system_prompt = gr.Textbox(scale=2) 31 | 32 | with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab: 33 | with gr.Row(): 34 | with gr.Column(): 35 | flash_attn = gr.Checkbox(value=False) 36 | shift_attn = gr.Checkbox(value=False) 37 | rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none") 38 | 39 | model_name.change( 40 | list_checkpoint, [model_name, finetuning_type], 
[checkpoints], queue=False 41 | ).then( 42 | get_model_path, [model_name], [model_path], queue=False 43 | ).then( 44 | get_template, [model_name], [template], queue=False 45 | ) # do not save config since the below line will save 46 | 47 | model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False) 48 | 49 | finetuning_type.change( 50 | list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False 51 | ).then( 52 | can_quantize, [finetuning_type], [quantization_bit], queue=False 53 | ) 54 | 55 | refresh_btn.click( 56 | list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False 57 | ) 58 | 59 | return dict( 60 | lang=lang, 61 | model_name=model_name, 62 | model_path=model_path, 63 | finetuning_type=finetuning_type, 64 | checkpoints=checkpoints, 65 | refresh_btn=refresh_btn, 66 | advanced_tab=advanced_tab, 67 | quantization_bit=quantization_bit, 68 | template=template, 69 | system_prompt=system_prompt, 70 | llama_tab=llama_tab, 71 | flash_attn=flash_attn, 72 | shift_attn=shift_attn, 73 | rope_scaling=rope_scaling 74 | ) 75 | -------------------------------------------------------------------------------- /llmtuner/webui/components/train.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from typing import TYPE_CHECKING, Dict 3 | from transformers.trainer_utils import SchedulerType 4 | 5 | from llmtuner.extras.constants import TRAINING_STAGES 6 | from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR 7 | from llmtuner.webui.components.data import create_preview_box 8 | from llmtuner.webui.utils import can_preview, get_preview, gen_plot 9 | 10 | if TYPE_CHECKING: 11 | from gradio.components import Component 12 | from llmtuner.webui.engine import Engine 13 | 14 | 15 | def create_train_tab(engine: "Engine") -> Dict[str, "Component"]: 16 | input_elems = engine.manager.get_base_elems() 17 | elem_dict = dict() 18 | 19 | with gr.Row(): 20 | training_stage = gr.Dropdown( 21 | choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2 22 | ) 23 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) 24 | dataset = gr.Dropdown(multiselect=True, scale=4) 25 | data_preview_btn = gr.Button(interactive=False, scale=1) 26 | 27 | training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False) 28 | dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False) 29 | dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False) 30 | 31 | input_elems.update({training_stage, dataset_dir, dataset}) 32 | elem_dict.update(dict( 33 | training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, data_preview_btn=data_preview_btn 34 | )) 35 | 36 | preview_box, preview_count, preview_samples, close_btn = create_preview_box() 37 | 38 | data_preview_btn.click( 39 | get_preview, 40 | [dataset_dir, dataset], 41 | [preview_count, preview_samples, preview_box], 42 | queue=False 43 | ) 44 | 45 | elem_dict.update(dict( 46 | preview_count=preview_count, preview_samples=preview_samples, close_btn=close_btn 47 | )) 48 | 49 | with gr.Row(): 50 | cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1) 51 | learning_rate = gr.Textbox(value="5e-5") 52 | num_train_epochs = gr.Textbox(value="3.0") 53 | max_samples = gr.Textbox(value="100000") 54 | compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16") 55 | 56 | input_elems.update({cutoff_len, learning_rate, 
num_train_epochs, max_samples, compute_type}) 57 | elem_dict.update(dict( 58 | cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs, 59 | max_samples=max_samples, compute_type=compute_type 60 | )) 61 | 62 | with gr.Row(): 63 | batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1) 64 | gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1) 65 | lr_scheduler_type = gr.Dropdown( 66 | choices=[scheduler.value for scheduler in SchedulerType], value="cosine" 67 | ) 68 | max_grad_norm = gr.Textbox(value="1.0") 69 | val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001) 70 | 71 | input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size}) 72 | elem_dict.update(dict( 73 | batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps, 74 | lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size 75 | )) 76 | 77 | with gr.Accordion(label="Advanced config", open=False) as advanced_tab: 78 | with gr.Row(): 79 | logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5) 80 | save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10) 81 | warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1) 82 | neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1) 83 | 84 | with gr.Column(): 85 | train_on_prompt = gr.Checkbox(value=False) 86 | upcast_layernorm = gr.Checkbox(value=False) 87 | 88 | input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm}) 89 | elem_dict.update(dict( 90 | advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps, 91 | neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm 92 | )) 93 | 94 | with gr.Accordion(label="LoRA config", open=False) as lora_tab: 95 | with gr.Row(): 96 | lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1) 97 | lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) 98 | lora_target = gr.Textbox(scale=1) 99 | additional_target = gr.Textbox(scale=1) 100 | resume_lora_training = gr.Checkbox(value=True, scale=1) 101 | 102 | input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, resume_lora_training}) 103 | elem_dict.update(dict( 104 | lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target, 105 | additional_target=additional_target, resume_lora_training=resume_lora_training, 106 | )) 107 | 108 | with gr.Accordion(label="RLHF config", open=False) as rlhf_tab: 109 | with gr.Row(): 110 | dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) 111 | reward_model = gr.Dropdown(scale=3) 112 | refresh_btn = gr.Button(scale=1) 113 | 114 | refresh_btn.click( 115 | list_checkpoint, 116 | [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type")], 117 | [reward_model], 118 | queue=False 119 | ) 120 | 121 | input_elems.update({dpo_beta, reward_model}) 122 | elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn)) 123 | 124 | with gr.Row(): 125 | cmd_preview_btn = gr.Button() 126 | start_btn = gr.Button() 127 | stop_btn = gr.Button() 128 | 129 | with gr.Row(): 130 | with gr.Column(scale=3): 131 | with gr.Row(): 132 | output_dir = gr.Textbox() 133 | 134 | with gr.Row(): 135 | resume_btn = gr.Checkbox(visible=False, interactive=False, 
value=False) 136 | process_bar = gr.Slider(visible=False, interactive=False) 137 | 138 | with gr.Box(): 139 | output_box = gr.Markdown() 140 | 141 | with gr.Column(scale=1): 142 | loss_viewer = gr.Plot() 143 | 144 | input_elems.add(output_dir) 145 | output_elems = [output_box, process_bar] 146 | elem_dict.update(dict( 147 | cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir, 148 | resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer 149 | )) 150 | 151 | cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems) 152 | start_btn.click(engine.runner.run_train, input_elems, output_elems) 153 | stop_btn.click(engine.runner.set_abort, queue=False) 154 | resume_btn.change(engine.runner.monitor, outputs=output_elems) 155 | 156 | output_box.change( 157 | gen_plot, 158 | [engine.manager.get_elem("top.model_name"), engine.manager.get_elem("top.finetuning_type"), output_dir], 159 | loss_viewer, 160 | queue=False 161 | ) 162 | 163 | return elem_dict 164 | -------------------------------------------------------------------------------- /llmtuner/webui/css.py: -------------------------------------------------------------------------------- 1 | CSS = r""" 2 | .modal-box { 3 | position: fixed !important; 4 | top: 50%; 5 | left: 50%; 6 | transform: translate(-50%, -50%); /* center horizontally */ 7 | max-width: 1000px; 8 | max-height: 750px; 9 | background-color: var(--input-background-fill); 10 | border: 2px solid black !important; 11 | z-index: 1000; 12 | padding: 10px; 13 | } 14 | 15 | .dark .modal-box { 16 | border: 2px solid white !important; 17 | } 18 | """ 19 | -------------------------------------------------------------------------------- /llmtuner/webui/engine.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from gradio.components import Component # cannot use TYPE_CHECKING here 3 | from typing import Any, Dict, Generator, Optional 4 | 5 | from llmtuner.webui.chatter import WebChatModel 6 | from llmtuner.webui.common import get_model_path, list_dataset, load_config 7 | from llmtuner.webui.locales import LOCALES 8 | from llmtuner.webui.manager import Manager 9 | from llmtuner.webui.runner import Runner 10 | from llmtuner.webui.utils import get_time 11 | 12 | 13 | class Engine: 14 | 15 | def __init__(self, pure_chat: Optional[bool] = False) -> None: 16 | self.pure_chat = pure_chat 17 | self.manager: "Manager" = Manager() 18 | self.runner: "Runner" = Runner(self.manager) 19 | self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat)) 20 | 21 | def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]): 22 | return {self.manager.get_elem(k): gr.update(**v) for k, v in resume_dict.items()} 23 | 24 | def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]: 25 | user_config = load_config() 26 | lang = user_config.get("lang", None) or "en" 27 | 28 | init_dict = { 29 | "top.lang": {"value": lang}, 30 | "infer.chat_box": {"visible": self.chatter.loaded} 31 | } 32 | 33 | if not self.pure_chat: 34 | init_dict["train.dataset"] = {"choices": list_dataset()["choices"]} 35 | init_dict["eval.dataset"] = {"choices": list_dataset()["choices"]} 36 | 37 | if user_config.get("last_model", None): 38 | init_dict["top.model_name"] = {"value": user_config["last_model"]} 39 | init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])} 40 | 41 | yield self._form_dict(init_dict) 42 | 
43 | if not self.pure_chat: 44 | if self.runner.alive: 45 | yield {elem: gr.update(value=value) for elem, value in self.runner.data.items()} 46 | if self.runner.do_train: 47 | yield self._form_dict({"train.resume_btn": {"value": True}}) 48 | else: 49 | yield self._form_dict({"eval.resume_btn": {"value": True}}) 50 | else: 51 | yield self._form_dict({"train.output_dir": {"value": get_time()}}) 52 | 53 | def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]: 54 | return { 55 | component: gr.update(**LOCALES[name][lang]) 56 | for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES 57 | } 58 | -------------------------------------------------------------------------------- /llmtuner/webui/interface.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from transformers.utils.versions import require_version 3 | 4 | from llmtuner.webui.components import ( 5 | create_top, 6 | create_train_tab, 7 | create_eval_tab, 8 | create_infer_tab, 9 | create_export_tab, 10 | create_chat_box 11 | ) 12 | from llmtuner.webui.common import save_config 13 | from llmtuner.webui.css import CSS 14 | from llmtuner.webui.engine import Engine 15 | 16 | 17 | require_version("gradio==3.50.2", "To fix: pip install gradio==3.50.2") 18 | 19 | 20 | def create_ui() -> gr.Blocks: 21 | engine = Engine(pure_chat=False) 22 | 23 | with gr.Blocks(title="LLaMA Board", css=CSS) as demo: 24 | engine.manager.all_elems["top"] = create_top() 25 | lang: "gr.Dropdown" = engine.manager.get_elem("top.lang") 26 | 27 | with gr.Tab("Train"): 28 | engine.manager.all_elems["train"] = create_train_tab(engine) 29 | 30 | with gr.Tab("Evaluate"): 31 | engine.manager.all_elems["eval"] = create_eval_tab(engine) 32 | 33 | with gr.Tab("Chat"): 34 | engine.manager.all_elems["infer"] = create_infer_tab(engine) 35 | 36 | with gr.Tab("Export"): 37 | engine.manager.all_elems["export"] = create_export_tab(engine) 38 | 39 | demo.load(engine.resume, outputs=engine.manager.list_elems()) 40 | lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False) 41 | lang.input(save_config, inputs=[lang], queue=False) 42 | 43 | return demo 44 | 45 | 46 | def create_web_demo() -> gr.Blocks: 47 | engine = Engine(pure_chat=True) 48 | 49 | with gr.Blocks(title="Web Demo", css=CSS) as demo: 50 | lang = gr.Dropdown(choices=["en", "zh"]) 51 | engine.manager.all_elems["top"] = dict(lang=lang) 52 | 53 | chat_box, _, _, chat_elems = create_chat_box(engine, visible=True) 54 | engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems) 55 | 56 | demo.load(engine.resume, outputs=engine.manager.list_elems()) 57 | lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False) 58 | lang.input(save_config, inputs=[lang], queue=False) 59 | 60 | return demo 61 | 62 | 63 | if __name__ == "__main__": 64 | demo = create_ui() 65 | demo.queue() 66 | demo.launch(server_name="0.0.0.0", server_port=7866, share=False, inbrowser=True) 67 | -------------------------------------------------------------------------------- /llmtuner/webui/manager.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Dict, List 2 | 3 | if TYPE_CHECKING: 4 | from gradio.components import Component 5 | 6 | 7 | class Manager: 8 | 9 | def __init__(self) -> None: 10 | self.all_elems: Dict[str, Dict[str, "Component"]] = {} 11 | 12 | def get_elem(self, name: str) -> "Component": 13 | r""" 14 | 
Example: top.lang, train.dataset 15 | """ 16 | tab_name, elem_name = name.split(".") 17 | return self.all_elems[tab_name][elem_name] 18 | 19 | def get_base_elems(self): 20 | return { 21 | self.all_elems["top"]["lang"], 22 | self.all_elems["top"]["model_name"], 23 | self.all_elems["top"]["model_path"], 24 | self.all_elems["top"]["checkpoints"], 25 | self.all_elems["top"]["finetuning_type"], 26 | self.all_elems["top"]["quantization_bit"], 27 | self.all_elems["top"]["template"], 28 | self.all_elems["top"]["system_prompt"], 29 | self.all_elems["top"]["flash_attn"], 30 | self.all_elems["top"]["shift_attn"], 31 | self.all_elems["top"]["rope_scaling"] 32 | } 33 | 34 | def list_elems(self) -> List["Component"]: 35 | return [elem for elems in self.all_elems.values() for elem in elems.values()] 36 | -------------------------------------------------------------------------------- /llmtuner/webui/runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import logging 4 | import gradio as gr 5 | from threading import Thread 6 | from gradio.components import Component # cannot use TYPE_CHECKING here 7 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple 8 | 9 | import transformers 10 | from transformers.trainer import TRAINING_ARGS_NAME 11 | 12 | from llmtuner.extras.callbacks import LogCallback 13 | from llmtuner.extras.constants import TRAINING_STAGES 14 | from llmtuner.extras.logging import LoggerHandler 15 | from llmtuner.extras.misc import torch_gc 16 | from llmtuner.tuner import run_exp 17 | from llmtuner.webui.common import get_module, get_save_dir, load_config 18 | from llmtuner.webui.locales import ALERTS 19 | from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar 20 | 21 | if TYPE_CHECKING: 22 | from llmtuner.webui.manager import Manager 23 | 24 | 25 | class Runner: 26 | 27 | def __init__(self, manager: "Manager") -> None: 28 | self.manager = manager 29 | self.thread: "Thread" = None 30 | self.data: Dict["Component", Any] = None 31 | self.do_train = True 32 | self.monitor_inputs: Dict[str, str] = None 33 | self.aborted = False 34 | self.running = False 35 | self.logger_handler = LoggerHandler() 36 | self.logger_handler.setLevel(logging.INFO) 37 | logging.root.addHandler(self.logger_handler) 38 | transformers.logging.add_handler(self.logger_handler) 39 | 40 | @property 41 | def alive(self) -> bool: 42 | return self.thread is not None 43 | 44 | def set_abort(self) -> None: 45 | self.aborted = True 46 | self.running = False 47 | 48 | def _initialize(self, lang: str, model_name: str, model_path: str, dataset: List[str]) -> str: 49 | if self.running: 50 | return ALERTS["err_conflict"][lang] 51 | 52 | if not model_name: 53 | return ALERTS["err_no_model"][lang] 54 | 55 | if not model_path: 56 | return ALERTS["err_no_path"][lang] 57 | 58 | if len(dataset) == 0: 59 | return ALERTS["err_no_dataset"][lang] 60 | 61 | self.aborted = False 62 | self.logger_handler.reset() 63 | self.trainer_callback = LogCallback(self) 64 | return "" 65 | 66 | def _finalize(self, lang: str, finish_info: str) -> str: 67 | self.thread = None 68 | self.running = False 69 | torch_gc() 70 | if self.aborted: 71 | return ALERTS["info_aborted"][lang] 72 | else: 73 | return finish_info 74 | 75 | def _parse_train_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]: 76 | get = lambda name: data[self.manager.get_elem(name)] 77 | user_config = load_config() 78 | 79 | if 
get("top.checkpoints"): 80 | checkpoint_dir = ",".join([ 81 | get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints") 82 | ]) 83 | else: 84 | checkpoint_dir = None 85 | 86 | output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir")) 87 | 88 | args = dict( 89 | stage=TRAINING_STAGES[get("train.training_stage")], 90 | model_name_or_path=get("top.model_path"), 91 | do_train=True, 92 | cache_dir=user_config.get("cache_dir", None), 93 | checkpoint_dir=checkpoint_dir, 94 | finetuning_type=get("top.finetuning_type"), 95 | quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, 96 | template=get("top.template"), 97 | system_prompt=get("top.system_prompt"), 98 | flash_attn=get("top.flash_attn"), 99 | shift_attn=get("top.shift_attn"), 100 | rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, 101 | dataset_dir=get("train.dataset_dir"), 102 | dataset=",".join(get("train.dataset")), 103 | cutoff_len=get("train.cutoff_len"), 104 | learning_rate=float(get("train.learning_rate")), 105 | num_train_epochs=float(get("train.num_train_epochs")), 106 | max_samples=int(get("train.max_samples")), 107 | per_device_train_batch_size=get("train.batch_size"), 108 | gradient_accumulation_steps=get("train.gradient_accumulation_steps"), 109 | lr_scheduler_type=get("train.lr_scheduler_type"), 110 | max_grad_norm=float(get("train.max_grad_norm")), 111 | logging_steps=get("train.logging_steps"), 112 | save_steps=get("train.save_steps"), 113 | warmup_steps=get("train.warmup_steps"), 114 | neft_alpha=get("train.neft_alpha"), 115 | train_on_prompt=get("train.train_on_prompt"), 116 | upcast_layernorm=get("train.upcast_layernorm"), 117 | lora_rank=get("train.lora_rank"), 118 | lora_dropout=get("train.lora_dropout"), 119 | lora_target=get("train.lora_target") or get_module(get("top.model_name")), 120 | additional_target=get("train.additional_target") if get("train.additional_target") else None, 121 | resume_lora_training=get("train.resume_lora_training"), 122 | output_dir=output_dir 123 | ) 124 | args[get("train.compute_type")] = True 125 | args["disable_tqdm"] = True 126 | 127 | if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]: 128 | args["resume_lora_training"] = (args["quantization_bit"] is not None) 129 | 130 | if args["quantization_bit"] is not None: 131 | args["upcast_layernorm"] = True 132 | 133 | if args["stage"] == "ppo": 134 | args["reward_model"] = get("train.reward_model") 135 | 136 | if args["stage"] == "dpo": 137 | args["dpo_beta"] = get("train.dpo_beta") 138 | 139 | if get("train.val_size") > 1e-6 and args["stage"] != "ppo": 140 | args["val_size"] = get("train.val_size") 141 | args["evaluation_strategy"] = "steps" 142 | args["eval_steps"] = get("train.save_steps") 143 | args["load_best_model_at_end"] = True 144 | 145 | return get("top.lang"), get("top.model_name"), get("top.model_path"), get("train.dataset"), output_dir, args 146 | 147 | def _parse_eval_args(self, data: Dict[Component, Any]) -> Tuple[str, str, str, List[str], str, Dict[str, Any]]: 148 | get = lambda name: data[self.manager.get_elem(name)] 149 | user_config = load_config() 150 | 151 | if get("top.checkpoints"): 152 | checkpoint_dir = ",".join([ 153 | get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints") 154 | ]) 155 | output_dir = get_save_dir( 156 | get("top.model_name"), 
get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints")) 157 | ) 158 | else: 159 | checkpoint_dir = None 160 | output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base") 161 | 162 | args = dict( 163 | stage="sft", 164 | model_name_or_path=get("top.model_path"), 165 | do_eval=True, 166 | predict_with_generate=True, 167 | cache_dir=user_config.get("cache_dir", None), 168 | checkpoint_dir=checkpoint_dir, 169 | finetuning_type=get("top.finetuning_type"), 170 | quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None, 171 | template=get("top.template"), 172 | system_prompt=get("top.system_prompt"), 173 | flash_attn=get("top.flash_attn"), 174 | shift_attn=get("top.shift_attn"), 175 | rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None, 176 | dataset_dir=get("eval.dataset_dir"), 177 | dataset=",".join(get("eval.dataset")), 178 | cutoff_len=get("eval.cutoff_len"), 179 | max_samples=int(get("eval.max_samples")), 180 | per_device_eval_batch_size=get("eval.batch_size"), 181 | max_new_tokens=get("eval.max_new_tokens"), 182 | top_p=get("eval.top_p"), 183 | temperature=get("eval.temperature"), 184 | output_dir=output_dir 185 | ) 186 | 187 | if get("eval.predict"): 188 | args.pop("do_eval", None) 189 | args["do_predict"] = True 190 | 191 | return get("top.lang"), get("top.model_name"), get("top.model_path"), get("eval.dataset"), output_dir, args 192 | 193 | def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: 194 | parse_func = self._parse_train_args if do_train else self._parse_eval_args 195 | lang, model_name, model_path, dataset, _, args = parse_func(data) 196 | error = self._initialize(lang, model_name, model_path, dataset) 197 | if error: 198 | yield error, gr.update(visible=False) 199 | else: 200 | yield gen_cmd(args), gr.update(visible=False) 201 | 202 | def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]: 203 | parse_func = self._parse_train_args if do_train else self._parse_eval_args 204 | lang, model_name, model_path, dataset, output_dir, args = parse_func(data) 205 | self.data, self.do_train, self.monitor_inputs = data, do_train, dict(lang=lang, output_dir=output_dir) 206 | error = self._initialize(lang, model_name, model_path, dataset) 207 | if error: 208 | yield error, gr.update(visible=False) 209 | else: 210 | self.running = True 211 | run_kwargs = dict(args=args, callbacks=[self.trainer_callback]) 212 | self.thread = Thread(target=run_exp, kwargs=run_kwargs) 213 | self.thread.start() 214 | yield from self.monitor() 215 | 216 | def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: 217 | yield from self._preview(data, do_train=True) 218 | 219 | def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: 220 | yield from self._preview(data, do_train=False) 221 | 222 | def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: 223 | yield from self._launch(data, do_train=True) 224 | 225 | def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]: 226 | yield from self._launch(data, do_train=False) 227 | 228 | def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]: 229 | lang, output_dir = self.monitor_inputs["lang"], 
self.monitor_inputs["output_dir"] 230 | while self.thread.is_alive(): 231 | time.sleep(2) 232 | if self.aborted: 233 | yield ALERTS["info_aborting"][lang], gr.update(visible=False) 234 | else: 235 | yield self.logger_handler.log, update_process_bar(self.trainer_callback) 236 | 237 | if self.do_train: 238 | if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)): 239 | finish_info = ALERTS["info_finished"][lang] 240 | else: 241 | finish_info = ALERTS["err_failed"][lang] 242 | else: 243 | if os.path.exists(os.path.join(output_dir, "all_results.json")): 244 | finish_info = get_eval_results(os.path.join(output_dir, "all_results.json")) 245 | else: 246 | finish_info = ALERTS["err_failed"][lang] 247 | 248 | yield self._finalize(lang, finish_info), gr.update(visible=False) 249 | -------------------------------------------------------------------------------- /llmtuner/webui/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import gradio as gr 4 | import matplotlib.figure 5 | import matplotlib.pyplot as plt 6 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple 7 | from datetime import datetime 8 | 9 | from llmtuner.extras.ploting import smooth 10 | from llmtuner.tuner import export_model 11 | from llmtuner.webui.common import get_save_dir, DATA_CONFIG 12 | from llmtuner.webui.locales import ALERTS 13 | 14 | if TYPE_CHECKING: 15 | from llmtuner.extras.callbacks import LogCallback 16 | 17 | 18 | def update_process_bar(callback: "LogCallback") -> Dict[str, Any]: 19 | if not callback.max_steps: 20 | return gr.update(visible=False) 21 | 22 | percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0 23 | label = "Running {:d}/{:d}: {} < {}".format( 24 | callback.cur_steps, 25 | callback.max_steps, 26 | callback.elapsed_time, 27 | callback.remaining_time 28 | ) 29 | return gr.update(label=label, value=percentage, visible=True) 30 | 31 | 32 | def get_time() -> str: 33 | return datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 34 | 35 | 36 | def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]: 37 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: 38 | dataset_info = json.load(f) 39 | 40 | if ( 41 | len(dataset) > 0 42 | and "file_name" in dataset_info[dataset[0]] 43 | and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])) 44 | ): 45 | return gr.update(interactive=True) 46 | else: 47 | return gr.update(interactive=False) 48 | 49 | 50 | def get_preview( 51 | dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2 52 | ) -> Tuple[int, list, Dict[str, Any]]: 53 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: 54 | dataset_info = json.load(f) 55 | 56 | data_file: str = dataset_info[dataset[0]]["file_name"] 57 | with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f: 58 | if data_file.endswith(".json"): 59 | data = json.load(f) 60 | elif data_file.endswith(".jsonl"): 61 | data = [json.loads(line) for line in f] 62 | else: 63 | data = [line for line in f] 64 | return len(data), data[start:end], gr.update(visible=True) 65 | 66 | 67 | def can_quantize(finetuning_type: str) -> Dict[str, Any]: 68 | if finetuning_type != "lora": 69 | return gr.update(value="None", interactive=False) 70 | else: 71 | return gr.update(interactive=True) 72 | 73 | 74 | def gen_cmd(args: Dict[str, Any]) -> str: 75 | 
args.pop("disable_tqdm", None) 76 | args["plot_loss"] = args.get("do_train", None) 77 | cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "] 78 | for k, v in args.items(): 79 | if v is not None and v != "": 80 | cmd_lines.append(" --{} {} ".format(k, str(v))) 81 | cmd_text = "\\\n".join(cmd_lines) 82 | cmd_text = "```bash\n{}\n```".format(cmd_text) 83 | return cmd_text 84 | 85 | 86 | def get_eval_results(path: os.PathLike) -> str: 87 | with open(path, "r", encoding="utf-8") as f: 88 | result = json.dumps(json.load(f), indent=4) 89 | return "```json\n{}\n```\n".format(result) 90 | 91 | 92 | def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure: 93 | log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl") 94 | if not os.path.isfile(log_file): 95 | return None 96 | 97 | plt.close("all") 98 | fig = plt.figure() 99 | ax = fig.add_subplot(111) 100 | steps, losses = [], [] 101 | with open(log_file, "r", encoding="utf-8") as f: 102 | for line in f: 103 | log_info = json.loads(line) 104 | if log_info.get("loss", None): 105 | steps.append(log_info["current_steps"]) 106 | losses.append(log_info["loss"]) 107 | 108 | if len(losses) == 0: 109 | return None 110 | 111 | ax.plot(steps, losses, alpha=0.4, label="original") 112 | ax.plot(steps, smooth(losses), label="smoothed") 113 | ax.legend() 114 | ax.set_xlabel("step") 115 | ax.set_ylabel("loss") 116 | return fig 117 | 118 | 119 | def save_model( 120 | lang: str, 121 | model_name: str, 122 | model_path: str, 123 | checkpoints: List[str], 124 | finetuning_type: str, 125 | template: str, 126 | max_shard_size: int, 127 | export_dir: str 128 | ) -> Generator[str, None, None]: 129 | if not model_name: 130 | yield ALERTS["err_no_model"][lang] 131 | return 132 | 133 | if not model_path: 134 | yield ALERTS["err_no_path"][lang] 135 | return 136 | 137 | if not checkpoints: 138 | yield ALERTS["err_no_checkpoint"][lang] 139 | return 140 | 141 | if not export_dir: 142 | yield ALERTS["err_no_export_dir"][lang] 143 | return 144 | 145 | args = dict( 146 | model_name_or_path=model_path, 147 | checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]), 148 | finetuning_type=finetuning_type, 149 | template=template, 150 | export_dir=export_dir 151 | ) 152 | 153 | yield ALERTS["info_exporting"][lang] 154 | export_model(args, max_shard_size="{}GB".format(max_shard_size)) 155 | yield ALERTS["info_exported"][lang] 156 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | torchvision==0.15.1 3 | transformers==4.33.2 4 | datasets==2.14.6 5 | accelerate==0.24.1 6 | peft==0.6.0 7 | trl>=0.7.2 8 | scipy==1.10.1 9 | sentencepiece 10 | protobuf 11 | tiktoken 12 | fire 13 | jieba==0.42.1 14 | rouge-chinese 15 | nltk 16 | gradio==3.50.2 17 | uvicorn 18 | pydantic==1.10.11 19 | fastapi==0.95.1 20 | sse-starlette 21 | matplotlib 22 | scikit-learn==1.2.2 23 | openpyxl==3.0.10 24 | pandas==1.5.3 25 | pynvml 26 | xlsxwriter -------------------------------------------------------------------------------- /train_bash.py: -------------------------------------------------------------------------------- 1 | from llmtuner import run_exp 2 | 3 | 4 | def main(): 5 | run_exp() 6 | 7 | 8 | def _mp_fn(index): 9 | # For xla_spawn (TPUs) 10 | main() 11 | 12 | 13 | if __name__ == "__main__": 14 | main() 15 | 
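16 | 
17 | # Usage sketch: train_bash.py is a thin CLI wrapper around llmtuner.run_exp, so it
18 | # accepts the same flags that the web UI assembles in Runner._parse_train_args and
19 | # previews via gen_cmd. The invocation below is illustrative only; the model path,
20 | # dataset name, template and output directory are placeholders:
21 | #
22 | #   CUDA_VISIBLE_DEVICES=0 python train_bash.py \
23 | #       --stage sft \
24 | #       --model_name_or_path /path/to/base_model \
25 | #       --do_train \
26 | #       --finetuning_type lora \
27 | #       --template default \
28 | #       --dataset_dir data \
29 | #       --dataset my_dataset \
30 | #       --output_dir finetune/sft_output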
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import json
3 | import datetime,time
4 | import shutil
5 | import os
6 | import re
7 | import subprocess
8 | from sklearn.metrics import classification_report
9 | from pynvml import (nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex,
10 |                     nvmlDeviceGetName, nvmlDeviceGetMemoryInfo, nvmlShutdown)
11 | 
12 | def read_excel_file(file_path):
13 |     df = pd.read_excel(file_path)
14 |     return df
15 | 
16 | def save_to_excel(df, file_path):
17 |     df.to_excel(file_path, index=False)
18 | 
19 | def get_available_gpu(threshold=20000):
20 |     # Initialize NVML
21 |     nvmlInit()
22 | 
23 |     # Get the number of GPU devices
24 |     device_count = nvmlDeviceGetCount()
25 | 
26 |     # Find GPU devices with available memory greater than the threshold
27 |     available_gpus = []
28 |     for i in range(device_count):
29 |         handle = nvmlDeviceGetHandleByIndex(i)
30 |         info = nvmlDeviceGetMemoryInfo(handle)
31 |         free_memory_mb = info.free / 1024 / 1024
32 | 
33 |         if free_memory_mb > threshold:
34 |             available_gpus.append(i)
35 | 
36 |     # Shutdown NVML
37 |     nvmlShutdown()
38 | 
39 |     return available_gpus
40 | 
41 | 
42 | def is_numeric(value):
43 |     try:
44 |         float(value)  # try to convert the value to a float
45 |         return True  # conversion succeeded, so the value can be treated as a number
46 |     except (ValueError, TypeError):
47 |         return False  # conversion failed or the type is unsupported, so the value is not numeric
48 | 
49 | 
50 | def accuracy_cal(list1, list2):
51 |     count = 0
52 |     for i in range(len(list1)):
53 |         l = list1[i]
54 |         r = list2[i]
55 |         l = float(l) if is_numeric(l) else str(l)
56 |         # print('l',is_numeric(l),l)
57 |         r = float(r) if is_numeric(r) else str(r)
58 |         # print('r',is_numeric(r),r)
59 |         if l == r:
60 |             count += 1
61 |     accuracy = count / len(list1)
62 |     return accuracy
63 | 
64 | def evaluate_model(pred_path, gold_path):
65 |     output_path = 'data/evaluate_scores.xlsx'
66 | 
67 |     # read the gold-label Excel file
68 |     gold_data = pd.read_excel(gold_path, sheet_name=None)
69 | 
70 |     # read the prediction Excel file
71 |     pred_data = pd.read_excel(pred_path, sheet_name=None)
72 | 
73 |     # create the output Excel file for the evaluation results
74 |     # writer = pd.ExcelWriter(output_path)
75 |     writer = pd.ExcelWriter(output_path, engine='xlsxwriter')
76 | 
77 |     # define the result DataFrame
78 |     result_df = pd.DataFrame(columns=['sheet', 'field', 'precision', 'recall', 'f1-score', 'accuracy', 'support'])
79 | 
80 |     # iterate over every sheet of the gold-label Excel file
81 |     for sheet_name in gold_data.keys():
82 |         # get the corresponding sheets from the gold labels and the predictions
83 |         gold_sheet = gold_data[sheet_name]
84 |         pred_sheet = pred_data.get(sheet_name, None)
85 | 
86 |         # if the sheet is missing from the predictions, print a warning
87 |         if pred_sheet is None:
88 |             print(f"Warning: Sheet '{sheet_name}' not found in prediction file.")
89 |             continue
90 | 
91 |         # get all columns of the gold labels and the predictions
92 |         gold_columns = sorted(gold_sheet.columns.tolist())
93 |         pred_columns = sorted(pred_sheet.columns.tolist())
94 | 
95 |         # make sure the gold labels and the predictions share the same columns
96 |         if set(gold_columns) != set(pred_columns):
97 |             print(f"Warning: Columns mismatch in sheet '{sheet_name}'.")
98 |             continue
99 | 
100 |         # extract the label data from the gold labels and the predictions
101 |         gold_labels = gold_sheet.fillna(value='未提及')
102 |         pred_labels = pred_sheet.fillna(value='未提及')
103 |         # check whether the predictions still contain missing values
104 |         if pred_labels.isnull().any().any():
105 |             print(f"Warning: Missing values detected in sheet '{sheet_name}'.")
106 | 
107 |         # cast the prediction columns to the same dtype as the corresponding gold-label columns
108 |         for col in gold_columns:
109 |             if gold_labels[col].dtype != pred_labels[col].dtype:
110 |                 pred_labels[col] = pred_labels[col].fillna('').astype(str)
111 |                 gold_labels[col] = gold_labels[col].fillna('').astype(str)
112 |         max_lengths = {col: pred_labels[col].apply(lambda x: str(x)).str.len().max() for col in pred_labels.columns}
113 |         pred_labels = pred_labels.fillna('').astype(str)
114 |         gold_labels = gold_labels.fillna('').astype(str)
115 |         # compute precision, recall, F1-score and accuracy
116 |         for col in gold_columns:
117 |             report = classification_report(gold_labels[col], pred_labels[col], output_dict=True, zero_division=0)
118 |             accuracy = accuracy_cal(list(gold_labels[col]), list(pred_labels[col])) #report['accuracy']
119 |             new_row = pd.DataFrame({'sheet': sheet_name, 'field': col,
120 |                                     'precision': report['weighted avg']['precision'],
121 |                                     'recall': report['weighted avg']['recall'],
122 |                                     'f1-score': report['weighted avg']['f1-score'],
123 |                                     'accuracy': accuracy,
124 |                                     'support': report['weighted avg']['support']}, index=[0])
125 |             result_df = pd.concat([result_df, new_row], axis=0, ignore_index=True)
126 |     # write the results to a new sheet of the output Excel file
127 |     result_df.to_excel(writer, sheet_name='result', index=False)
128 |     workbook = writer.book
129 |     worksheet = writer.sheets['result']
130 |     percent_format = workbook.add_format({'num_format': '0%'})
131 |     worksheet.set_column('C:F', None, percent_format)
132 |     worksheet.conditional_format('C2:F500', {'type': 'data_bar',
133 |                                              'bar_color': '#FFA500'})
134 |     # save the output Excel file
135 |     writer.close()
136 |     print(f"Comparison complete. Results saved to '{output_path}'.")
137 |     return output_path
138 | 
139 | 
140 | def stop_train_process():
141 |     process = subprocess.Popen('ps -ef | grep train_bash.py', shell=True, stdout=subprocess.PIPE)
142 |     output, _ = process.communicate()
143 |     process.kill()
144 |     n = 0
145 |     # parse the output to get the process IDs
146 |     print('output',output)
147 |     try:
148 |         lines = output.decode().split('\n')
149 |         for line in lines:
150 |             if 'train_bash.py' in line:
151 |                 parts = line.split()
152 |                 pid = parts[1]
153 |                 # kill the process
154 |                 subprocess.call(['kill', '-9', pid])
155 |                 n+=1
156 |     except Exception as e:
157 |         print('error!!',e)
158 | 
159 |     return f'停止了{n//2}个进程'
160 | 
--------------------------------------------------------------------------------