├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── dataset
│   └── preprocessor.py
├── model
│   ├── bash.py
│   ├── data
│   │   └── README.md
│   └── src
│       ├── data
│       │   ├── loader.py
│       │   ├── preprocess.py
│       │   ├── template.py
│       │   └── utils.py
│       ├── extras
│       │   ├── callbacks.py
│       │   ├── constants.py
│       │   ├── logging.py
│       │   ├── misc.py
│       │   ├── packages.py
│       │   └── ploting.py
│       ├── hparams
│       │   ├── data_args.py
│       │   ├── evaluation_args.py
│       │   ├── finetuning_args.py
│       │   ├── generating_args.py
│       │   └── model_args.py
│       ├── model
│       │   ├── adapter.py
│       │   ├── loader.py
│       │   ├── parser.py
│       │   ├── patcher.py
│       │   └── utils.py
│       └── train
│           ├── dpo
│           │   ├── collator.py
│           │   ├── trainer.py
│           │   └── workflow.py
│           ├── sft
│           │   ├── metric.py
│           │   ├── trainer.py
│           │   └── workflow.py
│           ├── tuner.py
│           └── utils.py
├── requirements.txt
└── slides
    ├── Framework.pdf
    ├── Framework.png
    ├── Poster-DELLM.pdf
    └── Slides-DELLM.pdf

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | .DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Rcrossmeister
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Knowledge-to-SQL
2 | **[2024/10] Check our video presentation on [Underline](https://underline.io/events/466/posters/18354/poster/102023-knowledge-to-sql-enhancing-sql-generation-with-data-expert-llm)!**
3 |
4 | **[2024/08] The video presentation of our paper will be available soon.**
5 |
6 | **[2024/08] The presentation of our paper is scheduled at Virtual Poster Session 2; check the poster and slides [here](./slides).**
7 |
8 | **[2024/05] Our paper is accepted as a findings paper in ACL 2024!**
9 |
10 | We propose a novel framework **Knowledge-to-SQL** that leverages a **Data Expert Large Language Model (DELLM)** to enhance SQL generation. The paper is available [here](https://aclanthology.org/2024.findings-acl.653.pdf).
11 |
12 | ![Framework](./slides/Framework.png)
13 |
14 | ## Setup
15 |
16 | ### Environment
17 |
18 | **The GPU resources used in our study are 4*A800-SXM4-80G with CUDA version 12.1.** We strongly recommend using a torch version above 2.0.
19 |
20 | ```shell
21 | # Clone the repository
22 | git clone https://github.com/Rcrossmeister/Knowledge-to-SQL.git
23 | cd ./Knowledge-to-SQL
24 |
25 | # Create the conda environment
26 | conda create -n dellm python=3.11.3
27 | conda activate dellm
28 |
29 | # Install the required packages
30 | pip install -r requirements.txt
31 | ```
32 |
33 | ### Dataset
34 |
35 | We mainly focus on the **[BIRD](https://bird-bench.github.io/)** dataset in our study; we also support the **[Spider](https://yale-lily.github.io/spider)** dataset for the robustness study.
36 |
37 | ## Training
38 |
39 | The training implementation was inspired by **[LLaMA Factory](https://github.com/hiyouga/LLaMA-Factory)**; you can check their technical report [here](https://arxiv.org/abs/2403.13372).
40 |
41 | ### Quick Start
42 |
43 | We provide a script for a quick start on the BIRD dataset.
44 |
45 | ## Citation
46 |
47 | Please cite our paper if you include Knowledge-to-SQL in your work:
48 |
49 | ```
50 | @inproceedings{hong2024knowledge,
51 |     title = "Knowledge-to-{SQL}: Enhancing {SQL} Generation with Data Expert {LLM}",
52 |     author = "Hong, Zijin and
53 |       Yuan, Zheng and
54 |       Chen, Hao and
55 |       Zhang, Qinggang and
56 |       Huang, Feiran and
57 |       Huang, Xiao",
58 |     booktitle = "Findings of the Association for Computational Linguistics ACL 2024",
59 |     year = "2024"
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
/dataset/preprocessor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import fnmatch
4 | import json
5 | import os
6 | import pdb
7 | import pickle
8 | import re
9 | import sqlite3
10 | from typing import Dict, List, Tuple
11 |
12 | import backoff
13 | import openai
14 | import pandas as pd
15 | import sqlparse
16 | from tqdm import tqdm
17 |
18 | import spacy
19 | import json
20 | import sqlite3
21 | from nltk.translate.bleu_score import sentence_bleu
22 | from nltk.translate.bleu_score import SmoothingFunction
23 | from tqdm import tqdm
24 | import numpy as np
25 |
26 | '''openai configure'''
27 |
28 | openai.debug = True
29 |
30 | def new_directory(path):
31 |     if not os.path.exists(path):
32 |         os.makedirs(path)
33 |
34 | def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
35 |     """
36 |     Read an sqlite file, and return the CREATE commands for each of the tables in the database.
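    Assumes the BIRD-style layout `{bench_root}/databases/{db_name}/{db_name}.sqlite`
    (Spider keeps its files under `database/` instead), and returns a mapping from
    each table name to its CREATE statement.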
37 | """ 38 | asdf = 'database' if bench_root == 'spider' else 'databases' 39 | with sqlite3.connect(f'file:{bench_root}/{asdf}/{db_name}/{db_name}.sqlite?mode=ro', uri=True) as conn: 40 | # conn.text_factory = bytes 41 | cursor = conn.cursor() 42 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 43 | tables = cursor.fetchall() 44 | schemas = {} 45 | for table in tables: 46 | cursor.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(table[0])) 47 | schemas[table[0]] = cursor.fetchone()[0] 48 | 49 | return schemas 50 | 51 | def nice_look_table(column_names: list, values: list): 52 | rows = [] 53 | # Determine the maximum width of each column 54 | widths = [max(len(str(value[i])) for value in values + [column_names]) for i in range(len(column_names))] 55 | 56 | # Print the column names 57 | header = ''.join(f'{column.rjust(width)} ' for column, width in zip(column_names, widths)) 58 | # print(header) 59 | # Print the values 60 | for value in values: 61 | row = ''.join(f'{str(v).rjust(width)} ' for v, width in zip(value, widths)) 62 | rows.append(row) 63 | rows = "\n".join(rows) 64 | final_output = header + '\n' + rows 65 | return final_output 66 | 67 | 68 | def generate_schema_prompt(db_path, num_rows=None): 69 | # extract create ddls 70 | ''' 71 | :param root_place: 72 | :param db_name: 73 | :return: 74 | ''' 75 | full_schema_prompt_list = [] 76 | conn = sqlite3.connect(db_path) 77 | # Create a cursor object 78 | cursor = conn.cursor() 79 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") 80 | tables = cursor.fetchall() 81 | schemas = {} 82 | for table in tables: 83 | if table == 'sqlite_sequence': 84 | continue 85 | cursor.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(table[0])) 86 | create_prompt = cursor.fetchone()[0] 87 | schemas[table[0]] = create_prompt 88 | if num_rows: 89 | cur_table = table[0] 90 | if cur_table in ['order', 'by', 'group']: 91 | cur_table = "`{}`".format(cur_table) 92 | 93 | cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows)) 94 | column_names = [description[0] for description in cursor.description] 95 | values = cursor.fetchall() 96 | rows_prompt = nice_look_table(column_names=column_names, values=values) 97 | verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(num_rows, 98 | cur_table, 99 | num_rows, 100 | rows_prompt) 101 | schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt) 102 | 103 | for k, v in schemas.items(): 104 | full_schema_prompt_list.append(v) 105 | 106 | schema_prompt = "\n\n".join(full_schema_prompt_list) 107 | 108 | return schema_prompt 109 | 110 | 111 | def generate_comment_prompt(question, knowledge=None): 112 | question_prompt = "{}".format(question) 113 | knowledge_prompt = "{}".format(knowledge) 114 | 115 | return question_prompt, knowledge_prompt 116 | 117 | def generate_combined_prompts_one(db_path, question, knowledge=None): 118 | schema_prompt = generate_schema_prompt(db_path, num_rows=None) # This is the entry to collect values 119 | question_prompt, knowledge_prompt = generate_comment_prompt(question, knowledge) 120 | 121 | return question_prompt, knowledge_prompt, schema_prompt 122 | 123 | def semantic_similarity(column, question): 124 | nlp = spacy.load("en_core_web_lg") 125 | column_doc = nlp(column) 126 | question_doc = nlp(question) 127 | similarity = question_doc.similarity(column_doc) 128 | return similarity 129 | 130 | def nice_look_table(column_names: 
list, values: list): 131 | rows = [] 132 | # Determine the maximum width of each column 133 | widths = [max(len(str(value[i])) for value in values + [column_names]) for i in range(len(column_names))] 134 | 135 | # Print the column names 136 | header = ''.join(f'{column.rjust(width)} ' for column, width in zip(column_names, widths)) 137 | # print(header) 138 | # Print the values 139 | for value in values: 140 | row = ''.join(f'{str(v).rjust(width)} ' for v, width in zip(value, widths)) 141 | rows.append(row) 142 | rows = "\n".join(rows) 143 | final_output = header + '\n' + rows 144 | return final_output 145 | 146 | def get_tablename_columnList(db_path): 147 | full_schema_prompt_list = [] 148 | conn = sqlite3.connect(db_path) 149 | cursor = conn.cursor() 150 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") 151 | tables = cursor.fetchall() 152 | schemas = {} 153 | for table in tables: 154 | if table == 'sqlite_sequence': 155 | continue 156 | cursor.execute("SELECT * FROM `{}`".format(table[0])) 157 | col_name_list = [tuple[0] for tuple in cursor.description] 158 | schemas[table[0]] = col_name_list 159 | for k, v in schemas.items(): 160 | full_schema_prompt_list.append(v) 161 | return schemas 162 | 163 | def get_sub_table(db_path, selected_tables_columns, num_rows=3): 164 | if len(selected_tables_columns) == 0: 165 | return "" 166 | conn = sqlite3.connect(db_path) 167 | cursor = conn.cursor() 168 | subtable_prompt_list = [] 169 | for table_name, column_list in selected_tables_columns.items(): 170 | execute_query = "SELECT {} FROM `{}` LIMIT {}".format(", ".join(column_list), table_name, num_rows) 171 | cursor.execute(execute_query) 172 | column_names = [description[0] for description in cursor.description] 173 | values = cursor.fetchall() 174 | rows_prompt = nice_look_table(column_names=column_names, values=values) 175 | verbose_prompt = "/*\n{} rows of {} key-columns in Table {}:\n{}\n*/". 
\ 176 | format(num_rows, len(column_list), table_name, rows_prompt) 177 | subtable_prompt_list.append(verbose_prompt) 178 | subtable_prompt = "\n".join(subtable_prompt_list) 179 | return subtable_prompt 180 | 181 | def get_subtable_prompt(db_path, question): 182 | tableName_columnList_dict = get_tablename_columnList(db_path) 183 | smooth_fn = SmoothingFunction().method7 184 | bleu_score = [] 185 | table_column_pair_list = [] 186 | for table_name, columnList in tableName_columnList_dict.items(): 187 | for column in columnList: 188 | table_column_pair = table_name + " <_> " + column 189 | bleu = sentence_bleu([column], question, smoothing_function=smooth_fn) 190 | bleu_score.append(bleu) 191 | table_column_pair_list.append(table_column_pair) 192 | table_column_pair_list = [table_column_pair_list[i] for i in range(len(bleu_score)) if bleu_score[i] >= 0.08] 193 | bleu_score = [s for s in bleu_score if s >= 0.08] 194 | if len(bleu_score) == 0: 195 | return "" 196 | sorted_id = sorted(range(len(bleu_score)), key=lambda k: bleu_score[k], reverse=True) 197 | sorted_bleu_score = [bleu_score[i] for i in sorted_id] 198 | sorted_table_column_pair = [table_column_pair_list[i] for i in sorted_id] 199 | top_K_table_column_pair = sorted_table_column_pair[:3] 200 | 201 | selected_tables_columns = {} 202 | for table_column_pair in top_K_table_column_pair: 203 | table_name, column_name = table_column_pair.split(" <_> ") 204 | column_name = "`{}`".format(column_name) 205 | if table_name in selected_tables_columns: 206 | selected_tables_columns[table_name].append(column_name) 207 | else: 208 | selected_tables_columns[table_name] = [column_name] 209 | subtable_prompt = get_sub_table(db_path, selected_tables_columns, num_rows=3) 210 | return subtable_prompt 211 | 212 | def construct_ekg_data(db_path_list, question_list, knowledge_list=None): 213 | ''' 214 | :param db_path: str 215 | :param question_list: [] 216 | :return: dict of responses collected from openai 217 | ''' 218 | 219 | output_list = [] 220 | for i, question in tqdm(enumerate(question_list)): 221 | # print('--------------------- processing {}th question ---------------------'.format(i)) 222 | # print('the question is: {}'.format(question)) 223 | 224 | question, knowledge, schema = generate_combined_prompts_one(db_path=db_path_list[i], question=question, 225 | knowledge=knowledge_list[i]) 226 | knowledge.replace(';', '') 227 | output = { 228 | 'instruction': 'You are a helpful assistant. ' 229 | 'Please generate a evidence base on the given database schema and question ' 230 | 'The evidence should use the database information(schema) to explain the question ' 231 | 'The evidence aim to help language model to generate a more accurate SQL to answer the question. 
',
232 |             'input': '--schema: ' + schema + ' '
233 |                      '--question: ' + question,
234 |             'output': knowledge
235 |         }
236 |         output_list.append(output)
237 |
238 |     return output_list
239 |
240 | def question_package(data_json, knowledge=False):
241 |     question_list = []
242 |     for data in data_json:
243 |         question_list.append(data['question'])
244 |
245 |     return question_list
246 |
247 | def knowledge_package(data_json, knowledge=False):
248 |     knowledge_list = []
249 |     for data in data_json:
250 |         knowledge_list.append(data['evidence'])
251 |
252 |     return knowledge_list
253 |
254 | def decouple_question_schema(datasets, db_root_path):
255 |     question_list = []
256 |     db_path_list = []
257 |     knowledge_list = []
258 |     for i, data in enumerate(datasets):
259 |         question_list.append(data['question'])
260 |         cur_db_path = db_root_path + data['db_id'] + '/' + data['db_id'] + '.sqlite'
261 |         db_path_list.append(cur_db_path)
262 |         knowledge_list.append(data['evidence'])
263 |
264 |     return question_list, db_path_list, knowledge_list
265 |
266 | if __name__ == '__main__':
267 |     args_parser = argparse.ArgumentParser()
268 |     args_parser.add_argument('--data_path', type=str, default='')
269 |     args_parser.add_argument('--db_root_path', type=str, default='')
270 |     args_parser.add_argument('--output_path', type=str, default='')
271 |     args = args_parser.parse_args()
272 |
273 |     eval_data = json.load(open(args.data_path, 'r'))
274 |
275 |     question_list, db_path_list, knowledge_list = decouple_question_schema(datasets=eval_data,
276 |                                                                            db_root_path=args.db_root_path)
277 |     assert len(question_list) == len(db_path_list) == len(knowledge_list)
278 |
279 |     json_withSubTable = []
280 |     for i in tqdm(range(len(eval_data))):
281 |         instance = eval_data[i]
282 |         db_id = instance['db_id']
283 |         question = instance['question']
284 |         db_path = args.db_root_path + db_id + "/" + db_id + ".sqlite"
285 |         subtable_prompt = get_subtable_prompt(db_path, question)
286 |         instance["subtable_prompt"] = subtable_prompt
287 |         if "question_id" not in instance:
288 |             instance["question_id"] = i
289 |         json_withSubTable.append(instance)
290 |
291 |     ekg_data = construct_ekg_data(db_path_list=db_path_list, question_list=question_list, knowledge_list=knowledge_list)
292 |
293 |     with open(args.output_path, 'w', encoding='utf-8') as file:
294 |         json.dump(ekg_data, file, indent=4)
295 |     file.close()
296 |
297 |     print('successfully constructed results')
--------------------------------------------------------------------------------
/model/bash.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import TYPE_CHECKING, Any, Dict, List, Optional
3 | from transformers import PreTrainedModel
4 |
5 | from src.extras.callbacks import LogCallback
6 | from src.extras.logging import get_logger
7 | from src.model.loader import load_model_and_tokenizer
8 | from src.model.parser import get_train_args, get_infer_args
9 | from src.train.sft.workflow import run_sft
10 | from src.train.dpo.workflow import run_dpo
11 |
12 | if TYPE_CHECKING:
13 |     from transformers import TrainerCallback
14 |
15 |
16 | logger = get_logger(__name__)
17 |
18 | def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
19 |     model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
20 |     callbacks = [LogCallback()] if callbacks is None else callbacks
21 |
22 |     if finetuning_args.stage == "sft":
23 |         run_sft(model_args, data_args,
training_args, finetuning_args, generating_args, callbacks) 24 | elif finetuning_args.stage == "dpo": 25 | run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) 26 | else: 27 | raise ValueError("Unknown task.") 28 | 29 | 30 | def export_model(args: Optional[Dict[str, Any]] = None): 31 | model_args, _, finetuning_args, _ = get_infer_args(args) 32 | 33 | if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None: 34 | raise ValueError("Please merge adapters before quantizing the model.") 35 | 36 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) 37 | 38 | if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None: 39 | raise ValueError("Cannot merge adapters to a quantized model.") 40 | 41 | if not isinstance(model, PreTrainedModel): 42 | raise ValueError("The model is not a `PreTrainedModel`, export aborted.") 43 | 44 | model.config.use_cache = True 45 | if getattr(model.config, "torch_dtype", None) == "bfloat16": 46 | model = model.to(torch.bfloat16).to("cpu") 47 | else: 48 | model = model.to(torch.float16).to("cpu") 49 | setattr(model.config, "torch_dtype", "float16") 50 | 51 | model.save_pretrained( 52 | save_directory=model_args.export_dir, 53 | max_shard_size="{}GB".format(model_args.export_size), 54 | safe_serialization=(not model_args.export_legacy_format) 55 | ) 56 | 57 | try: 58 | tokenizer.padding_side = "left" # restore padding side 59 | tokenizer.init_kwargs["padding_side"] = "left" 60 | tokenizer.save_pretrained(model_args.export_dir) 61 | except: 62 | logger.warning("Cannot save tokenizer, please copy the files manually.") 63 | 64 | 65 | def main(): 66 | run_exp() 67 | 68 | if __name__ == "__main__": 69 | main() -------------------------------------------------------------------------------- /model/data/README.md: -------------------------------------------------------------------------------- 1 | # Data Description 2 | 3 | To support your private data, please refer to this README for more details and revise the `./dataset_info.json` file. 
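
For example, a private dataset stored in the alpaca format could be registered with an entry like the one below. This is only an illustrative sketch following the LLaMA-Factory-style schema that `model/src/data/loader.py` expects; the dataset name, file name, and column mapping are placeholders for your own data.

```json
{
  "my_private_data": {
    "file_name": "my_private_data.json",
    "columns": {
      "prompt": "instruction",
      "query": "input",
      "response": "output"
    }
  }
}
```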
-------------------------------------------------------------------------------- /model/src/data/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import TYPE_CHECKING, Any, Dict, List, Union 3 | 4 | from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk 5 | 6 | from .utils import checksum 7 | from ..extras.constants import FILEEXT2TYPE 8 | from ..extras.logging import get_logger 9 | 10 | if TYPE_CHECKING: 11 | from datasets import Dataset, IterableDataset 12 | from ..hparams.model_args import ModelArguments 13 | from ..hparams.data_args import DataArguments 14 | 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | def get_dataset( 20 | model_args: "ModelArguments", 21 | data_args: "DataArguments" 22 | ) -> Union["Dataset", "IterableDataset"]: 23 | max_samples = data_args.max_samples 24 | all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets 25 | 26 | if data_args.cache_path is not None: 27 | if os.path.exists(data_args.cache_path): 28 | logger.warning("Loading dataset from disk will ignore other data arguments.") 29 | dataset = load_from_disk(data_args.cache_path) 30 | if data_args.streaming: 31 | dataset = dataset.to_iterable_dataset() 32 | return dataset 33 | elif data_args.streaming: 34 | raise ValueError("Turn off dataset streaming to save cache files.") 35 | 36 | for dataset_attr in data_args.dataset_list: 37 | logger.info("Loading dataset {}...".format(dataset_attr)) 38 | 39 | data_path, data_name, data_dir, data_files = None, None, None, None 40 | if dataset_attr.load_from in ["hf_hub", "ms_hub"]: 41 | data_path = dataset_attr.dataset_name 42 | data_name = dataset_attr.subset 43 | data_dir = dataset_attr.folder 44 | elif dataset_attr.load_from == "script": 45 | data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) 46 | data_name = dataset_attr.subset 47 | elif dataset_attr.load_from == "file": 48 | data_files = [] 49 | local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) 50 | if os.path.isdir(local_path): # is directory 51 | for file_name in os.listdir(local_path): 52 | data_files.append(os.path.join(local_path, file_name)) 53 | if data_path is None: 54 | data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None) 55 | else: 56 | assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical." 57 | elif os.path.isfile(local_path): # is file 58 | data_files.append(local_path) 59 | data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None) 60 | else: 61 | raise ValueError("File not found.") 62 | 63 | assert data_path, "File extension must be txt, csv, json or jsonl." 
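            # optional integrity check: `dataset_sha1` comes from dataset_info.json; a mismatch only logs a warning (see data/utils.checksum)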
64 | checksum(data_files, dataset_attr.dataset_sha1) 65 | else: 66 | raise NotImplementedError 67 | 68 | if dataset_attr.load_from == "ms_hub": 69 | try: 70 | from modelscope import MsDataset 71 | from modelscope.utils.config_ds import MS_DATASETS_CACHE 72 | 73 | cache_dir = model_args.cache_dir or MS_DATASETS_CACHE 74 | dataset = MsDataset.load( 75 | dataset_name=data_path, 76 | subset_name=data_name, 77 | data_dir=data_dir, 78 | data_files=data_files, 79 | split=data_args.split, 80 | cache_dir=cache_dir, 81 | token=model_args.ms_hub_token, 82 | use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")) 83 | ).to_hf_dataset() 84 | except ImportError: 85 | raise ImportError("Please install modelscope via `pip install modelscope -U`") 86 | else: 87 | dataset = load_dataset( 88 | path=data_path, 89 | name=data_name, 90 | data_dir=data_dir, 91 | data_files=data_files, 92 | split=data_args.split, 93 | cache_dir=model_args.cache_dir, 94 | token=model_args.hf_hub_token, 95 | streaming=(data_args.streaming and (dataset_attr.load_from != "file")) 96 | ) 97 | 98 | if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True 99 | dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter 100 | 101 | if max_samples is not None: # truncate dataset 102 | dataset = dataset.select(range(min(len(dataset), max_samples))) 103 | 104 | def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]: 105 | # convert dataset from sharegpt format to alpaca format 106 | outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []} 107 | for i, msg_list in enumerate(examples[dataset_attr.messages]): 108 | msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2 109 | if len(msg_list) == 0: 110 | continue 111 | 112 | msg_pairs = [] 113 | user_role, assistant_role = None, None 114 | for idx in range(0, len(msg_list), 2): 115 | if user_role is None and assistant_role is None: 116 | user_role = msg_list[idx][dataset_attr.role] 117 | assistant_role = msg_list[idx + 1][dataset_attr.role] 118 | else: 119 | if ( 120 | msg_list[idx][dataset_attr.role] != user_role 121 | or msg_list[idx+1][dataset_attr.role] != assistant_role 122 | ): 123 | raise ValueError("Only accepts conversation in u/a/u/a/u/a order.") 124 | msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content])) 125 | 126 | if len(msg_pairs) != 0: 127 | outputs["prompt"].append(msg_pairs[-1][0]) 128 | outputs["query"].append("") 129 | outputs["response"].append(msg_pairs[-1][1]) 130 | outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None) 131 | outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") 132 | 133 | return outputs 134 | 135 | if dataset_attr.formatting == "sharegpt": # convert format 136 | column_names = list(next(iter(dataset)).keys()) 137 | kwargs = {} 138 | if not data_args.streaming: 139 | kwargs = dict( 140 | num_proc=data_args.preprocessing_num_workers, 141 | load_from_cache_file=(not data_args.overwrite_cache), 142 | desc="Converting format of dataset" 143 | ) 144 | 145 | dataset = dataset.map( 146 | convert_format, 147 | batched=True, 148 | remove_columns=column_names, 149 | **kwargs 150 | ) 151 | else: 152 | for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset 153 | if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name: 154 | dataset = 
dataset.rename_column(getattr(dataset_attr, column_name), column_name) 155 | 156 | all_datasets.append(dataset) 157 | 158 | if len(data_args.dataset_list) == 1: 159 | return all_datasets[0] 160 | elif data_args.mix_strategy == "concat": 161 | if data_args.streaming: 162 | logger.warning("The samples between different datasets will not be mixed in streaming mode.") 163 | return concatenate_datasets(all_datasets) 164 | elif data_args.mix_strategy.startswith("interleave"): 165 | if not data_args.streaming: 166 | logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") 167 | return interleave_datasets( 168 | datasets=all_datasets, 169 | probabilities=data_args.interleave_probs, 170 | seed=data_args.seed, 171 | stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted" 172 | ) 173 | else: 174 | raise ValueError("Unknown mixing strategy.") 175 | -------------------------------------------------------------------------------- /model/src/data/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tiktoken 3 | from itertools import chain 4 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union 5 | 6 | from .template import get_template_and_fix_tokenizer 7 | from ..extras.constants import IGNORE_INDEX 8 | from ..extras.logging import get_logger 9 | 10 | if TYPE_CHECKING: 11 | from datasets import Dataset, IterableDataset 12 | from transformers import Seq2SeqTrainingArguments 13 | from transformers.tokenization_utils import PreTrainedTokenizer 14 | from ..hparams.data_args import DataArguments 15 | 16 | 17 | logger = get_logger(__name__) 18 | 19 | 20 | def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: 21 | for i in range(len(examples["prompt"])): 22 | query, response = examples["prompt"][i], examples["response"][i] 23 | query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query 24 | history = examples["history"][i] if "history" in examples else None 25 | system = examples["system"][i] if "system" in examples else None 26 | yield query, response, history, system 27 | 28 | 29 | def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]: 30 | max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len))) 31 | max_target_len = max(max_target_len, data_args.reserved_label_len) 32 | max_source_len = data_args.cutoff_len - max_target_len 33 | return max_source_len, max_target_len 34 | 35 | 36 | def preprocess_dataset( 37 | dataset: Union["Dataset", "IterableDataset"], 38 | tokenizer: "PreTrainedTokenizer", 39 | data_args: "DataArguments", 40 | training_args: "Seq2SeqTrainingArguments", 41 | stage: Literal["pt", "sft", "rm", "ppo"] 42 | ) -> Union["Dataset", "IterableDataset"]: 43 | template = get_template_and_fix_tokenizer(data_args.template, tokenizer) 44 | 45 | if data_args.cache_path is not None and os.path.exists(data_args.cache_path): 46 | return dataset # already preprocessed 47 | 48 | if data_args.train_on_prompt and template.efficient_eos: 49 | raise ValueError("Current template does not support `train_on_prompt`.") 50 | 51 | def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: 52 | # build grouped texts with format `X1 X2 X3 ...` 53 | if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) 54 | kwargs = 
dict(allowed_special="all") 55 | else: 56 | kwargs = dict(add_special_tokens=True) 57 | 58 | if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer 59 | add_eos_token_flag = getattr(tokenizer, "add_eos_token") 60 | setattr(tokenizer, "add_eos_token", True) 61 | 62 | tokenized_examples = tokenizer(examples["prompt"], **kwargs) 63 | concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} 64 | total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]]) 65 | block_size = data_args.cutoff_len 66 | # we drop the small remainder, and if the total_length < block_size, we exclude this batch 67 | total_length = (total_length // block_size) * block_size 68 | # split by chunks of cutoff_len 69 | result = { 70 | k: [t[i: i + block_size] for i in range(0, total_length, block_size)] 71 | for k, t in concatenated_examples.items() 72 | } 73 | # make sure the saved tokenizer is the same as the original one 74 | if hasattr(tokenizer, "add_eos_token"): 75 | setattr(tokenizer, "add_eos_token", add_eos_token_flag) 76 | return result 77 | 78 | def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: 79 | # build inputs with format ` X Y ` and labels with format ` ... Y ` 80 | # for multiturn examples, we only mask the prompt part in each prompt-response pair. 81 | model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} 82 | 83 | for query, response, history, system in construct_example(examples): 84 | if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""): 85 | continue 86 | 87 | input_ids, labels = [], [] 88 | for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( 89 | tokenizer, query, response, history, system 90 | )): 91 | source_len, target_len = len(source_ids), len(target_ids) 92 | max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args) 93 | if source_len > max_source_len: 94 | source_ids = source_ids[:max_source_len] 95 | if target_len > max_target_len: 96 | target_ids = target_ids[:max_target_len] 97 | 98 | if data_args.train_on_prompt: 99 | source_mask = source_ids 100 | elif turn_idx != 0 and template.efficient_eos: 101 | source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) 102 | else: 103 | source_mask = [IGNORE_INDEX] * len(source_ids) 104 | 105 | input_ids += source_ids + target_ids 106 | labels += source_mask + target_ids 107 | 108 | if template.efficient_eos: 109 | input_ids += [tokenizer.eos_token_id] 110 | labels += [tokenizer.eos_token_id] 111 | 112 | if len(input_ids) > data_args.cutoff_len: 113 | input_ids = input_ids[:data_args.cutoff_len] 114 | labels = labels[:data_args.cutoff_len] 115 | 116 | model_inputs["input_ids"].append(input_ids) 117 | model_inputs["attention_mask"].append([1] * len(input_ids)) 118 | model_inputs["labels"].append(labels) 119 | 120 | return model_inputs 121 | 122 | def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: 123 | # build inputs with format ` X1 Y1 X2 Y2 ` 124 | # and labels with format ` ... Y1 ... 
Y2 ` 125 | model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} 126 | input_ids, labels = [], [] 127 | for query, response, history, system in construct_example(examples): 128 | if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""): 129 | continue 130 | 131 | for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( 132 | tokenizer, query, response, history, system 133 | )): 134 | if data_args.train_on_prompt: 135 | source_mask = source_ids 136 | elif turn_idx != 0 and template.efficient_eos: 137 | source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) 138 | else: 139 | source_mask = [IGNORE_INDEX] * len(source_ids) 140 | input_ids += source_ids + target_ids 141 | labels += source_mask + target_ids 142 | 143 | if template.efficient_eos: 144 | input_ids += [tokenizer.eos_token_id] 145 | labels += [tokenizer.eos_token_id] 146 | 147 | total_length = len(input_ids) 148 | block_size = data_args.cutoff_len 149 | # we drop the small remainder, and if the total_length < block_size, we exclude this batch 150 | total_length = (total_length // block_size) * block_size 151 | # split by chunks of cutoff_len 152 | for i in range(0, total_length, block_size): 153 | model_inputs["input_ids"].append(input_ids[i: i + block_size]) 154 | model_inputs["attention_mask"].append([1] * block_size) 155 | model_inputs["labels"].append(labels[i: i + block_size]) 156 | 157 | return model_inputs 158 | 159 | def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: 160 | # build inputs with format ` X` and labels with format `Y ` 161 | model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} 162 | 163 | for query, response, history, system in construct_example(examples): 164 | if not (isinstance(query, str) and query != ""): 165 | continue 166 | 167 | input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system) 168 | 169 | if template.efficient_eos: 170 | labels += [tokenizer.eos_token_id] 171 | 172 | if len(input_ids) > data_args.cutoff_len: 173 | input_ids = input_ids[:data_args.cutoff_len] 174 | if len(labels) > data_args.cutoff_len: 175 | labels = labels[:data_args.cutoff_len] 176 | 177 | model_inputs["input_ids"].append(input_ids) 178 | model_inputs["attention_mask"].append([1] * len(input_ids)) 179 | model_inputs["labels"].append(labels) 180 | 181 | return model_inputs 182 | 183 | def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]: 184 | # build input pairs with format ` X`, `Y1 ` and `Y2 ` 185 | model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []} 186 | for query, response, history, system in construct_example(examples): 187 | if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1): 188 | continue 189 | 190 | prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system) 191 | _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system) 192 | 193 | if template.efficient_eos: 194 | chosen_ids += [tokenizer.eos_token_id] 195 | rejected_ids += [tokenizer.eos_token_id] 196 | 197 | source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids)) 198 | max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args) 199 | if source_len > max_source_len: 200 | prompt_ids = prompt_ids[:max_source_len] 201 | if target_len > max_target_len: 
202 | chosen_ids = chosen_ids[:max_target_len] 203 | rejected_ids = rejected_ids[:max_target_len] 204 | 205 | model_inputs["prompt_ids"].append(prompt_ids) 206 | model_inputs["chosen_ids"].append(chosen_ids) 207 | model_inputs["rejected_ids"].append(rejected_ids) 208 | 209 | return model_inputs 210 | 211 | def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None: 212 | print("input_ids:\n{}".format(example["input_ids"])) 213 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) 214 | print("label_ids:\n{}".format(example["labels"])) 215 | print("labels:\n{}".format( 216 | tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) 217 | )) 218 | 219 | def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None: 220 | print("prompt_ids:\n{}".format(example["prompt_ids"])) 221 | print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False))) 222 | print("chosen_ids:\n{}".format(example["chosen_ids"])) 223 | print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False))) 224 | print("rejected_ids:\n{}".format(example["rejected_ids"])) 225 | print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False))) 226 | 227 | def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None: 228 | print("input_ids:\n{}".format(example["input_ids"])) 229 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) 230 | 231 | if stage == "pt": 232 | preprocess_func = preprocess_pretrain_dataset 233 | print_function = print_unsupervised_dataset_example 234 | elif stage == "sft" and not training_args.predict_with_generate: 235 | preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset 236 | print_function = print_supervised_dataset_example 237 | elif stage == "rm": 238 | preprocess_func = preprocess_pairwise_dataset 239 | print_function = print_pairwise_dataset_example 240 | else: 241 | preprocess_func = preprocess_unsupervised_dataset 242 | print_function = print_unsupervised_dataset_example 243 | 244 | with training_args.main_process_first(desc="dataset map pre-processing"): 245 | column_names = list(next(iter(dataset)).keys()) 246 | kwargs = {} 247 | if not data_args.streaming: 248 | kwargs = dict( 249 | num_proc=data_args.preprocessing_num_workers, 250 | load_from_cache_file=(not data_args.overwrite_cache), 251 | desc="Running tokenizer on dataset" 252 | ) 253 | 254 | dataset = dataset.map( 255 | preprocess_func, 256 | batched=True, 257 | remove_columns=column_names, 258 | **kwargs 259 | ) 260 | 261 | if data_args.cache_path is not None and not os.path.exists(data_args.cache_path): 262 | if training_args.should_save: 263 | dataset.save_to_disk(data_args.cache_path) 264 | logger.info("Dataset cache saved at {}.".format(data_args.cache_path)) 265 | 266 | if training_args.should_log: 267 | try: 268 | print_function(next(iter(dataset))) 269 | except StopIteration: 270 | raise RuntimeError("Empty dataset!") 271 | 272 | return dataset 273 | -------------------------------------------------------------------------------- /model/src/data/template.py: -------------------------------------------------------------------------------- 1 | import tiktoken 2 | from dataclasses import dataclass 3 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union 4 | 5 | from ..extras.logging 
import get_logger 6 | 7 | if TYPE_CHECKING: 8 | from transformers import PreTrainedTokenizer 9 | 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | @dataclass 15 | class Template: 16 | 17 | prefix: List[Union[str, Dict[str, str]]] 18 | prompt: List[Union[str, Dict[str, str]]] 19 | system: str 20 | sep: List[Union[str, Dict[str, str]]] 21 | stop_words: List[str] 22 | use_history: bool 23 | efficient_eos: bool 24 | replace_eos: bool 25 | 26 | def encode_oneturn( 27 | self, 28 | tokenizer: "PreTrainedTokenizer", 29 | query: str, 30 | resp: str, 31 | history: Optional[List[Tuple[str, str]]] = None, 32 | system: Optional[str] = None 33 | ) -> Tuple[List[int], List[int]]: 34 | r""" 35 | Returns a single pair of token ids representing prompt and response respectively. 36 | """ 37 | system, history = self._format(query, resp, history, system) 38 | encoded_pairs = self._encode(tokenizer, system, history) 39 | prompt_ids = [] 40 | for query_ids, resp_ids in encoded_pairs[:-1]: 41 | prompt_ids = prompt_ids + query_ids + resp_ids 42 | prompt_ids = prompt_ids + encoded_pairs[-1][0] 43 | answer_ids = encoded_pairs[-1][1] 44 | return prompt_ids, answer_ids 45 | 46 | def encode_multiturn( 47 | self, 48 | tokenizer: "PreTrainedTokenizer", 49 | query: str, 50 | resp: str, 51 | history: Optional[List[Tuple[str, str]]] = None, 52 | system: Optional[str] = None 53 | ) -> List[Tuple[List[int], List[int]]]: 54 | r""" 55 | Returns multiple pairs of token ids representing prompts and responses respectively. 56 | """ 57 | system, history = self._format(query, resp, history, system) 58 | encoded_pairs = self._encode(tokenizer, system, history) 59 | return encoded_pairs 60 | 61 | def _format( 62 | self, 63 | query: str, 64 | resp: str, 65 | history: Optional[List[Tuple[str, str]]] = None, 66 | system: Optional[str] = None 67 | ) -> Tuple[str, List[Tuple[str, str]]]: 68 | r""" 69 | Aligns inputs to the standard format. 70 | """ 71 | system = system or self.system # use system if provided 72 | history = history if (history and self.use_history) else [] 73 | history = history + [(query, resp)] 74 | return system, history 75 | 76 | def _get_special_ids( 77 | self, 78 | tokenizer: "PreTrainedTokenizer" 79 | ) -> Tuple[List[int], List[int]]: 80 | if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True): 81 | bos_ids = [tokenizer.bos_token_id] 82 | else: # baichuan, gpt2, qwen, yi models have no bos token 83 | bos_ids = [] 84 | 85 | if tokenizer.eos_token_id is None: 86 | raise ValueError("EOS token is required.") 87 | 88 | if self.efficient_eos: 89 | eos_ids = [] 90 | else: 91 | eos_ids = [tokenizer.eos_token_id] 92 | 93 | return bos_ids, eos_ids 94 | 95 | def _encode( 96 | self, 97 | tokenizer: "PreTrainedTokenizer", 98 | system: str, 99 | history: List[Tuple[str, str]] 100 | ) -> List[Tuple[List[int], List[int]]]: 101 | r""" 102 | Encodes formatted inputs to pairs of token ids. 
103 | Turn 0: bos + prefix + sep + query resp + eos 104 | Turn t: sep + bos + query resp + eos 105 | """ 106 | bos_ids, eos_ids = self._get_special_ids(tokenizer) 107 | sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) 108 | encoded_pairs = [] 109 | for turn_idx, (query, resp) in enumerate(history): 110 | if turn_idx == 0: 111 | prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system) 112 | if len(prefix_ids) != 0: # has prefix 113 | prefix_ids = bos_ids + prefix_ids + sep_ids 114 | else: 115 | prefix_ids = bos_ids 116 | else: 117 | prefix_ids = sep_ids + bos_ids 118 | 119 | query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx+1)) 120 | resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) 121 | encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids)) 122 | return encoded_pairs 123 | 124 | def _convert_inputs_to_ids( 125 | self, 126 | tokenizer: "PreTrainedTokenizer", 127 | context: List[Union[str, Dict[str, str]]], 128 | system: Optional[str] = None, 129 | query: Optional[str] = None, 130 | idx: Optional[str] = None 131 | ) -> List[int]: 132 | r""" 133 | Converts context to token ids. 134 | """ 135 | if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen) 136 | kwargs = dict(allowed_special="all") 137 | else: 138 | kwargs = dict(add_special_tokens=False) 139 | 140 | token_ids = [] 141 | for elem in context: 142 | if isinstance(elem, str): 143 | elem = elem.replace("{{system}}", system, 1) if system is not None else elem 144 | elem = elem.replace("{{query}}", query, 1) if query is not None else elem 145 | elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem 146 | if len(elem) != 0: 147 | token_ids = token_ids + tokenizer.encode(elem, **kwargs) 148 | elif isinstance(elem, dict): 149 | token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] 150 | else: 151 | raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem))) 152 | 153 | return token_ids 154 | 155 | 156 | @dataclass 157 | class Llama2Template(Template): 158 | 159 | def _encode( 160 | self, 161 | tokenizer: "PreTrainedTokenizer", 162 | system: str, 163 | history: List[Tuple[str, str]] 164 | ) -> List[Tuple[List[int], List[int]]]: 165 | r""" 166 | Encodes formatted inputs to pairs of token ids. 
167 | Turn 0: bos + prefix + query resp + eos 168 | Turn t: bos + query resp + eos 169 | """ 170 | bos_ids, eos_ids = self._get_special_ids(tokenizer) 171 | encoded_pairs = [] 172 | for turn_idx, (query, resp) in enumerate(history): 173 | if turn_idx == 0: # llama2 template has no sep_ids 174 | query = self.prefix[0].replace("{{system}}", system) + query 175 | query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query) 176 | resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) 177 | encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids)) 178 | return encoded_pairs 179 | 180 | 181 | templates: Dict[str, Template] = {} 182 | 183 | 184 | def register_template( 185 | name: str, 186 | prefix: List[Union[str, Dict[str, str]]], 187 | prompt: List[Union[str, Dict[str, str]]], 188 | system: str, 189 | sep: List[Union[str, Dict[str, str]]], 190 | stop_words: Optional[List[str]] = [], 191 | use_history: Optional[bool] = True, 192 | efficient_eos: Optional[bool] = False, 193 | replace_eos: Optional[bool] = False 194 | ) -> None: 195 | template_class = Llama2Template if name.startswith("llama2") else Template 196 | templates[name] = template_class( 197 | prefix=prefix, 198 | prompt=prompt, 199 | system=system, 200 | sep=sep, 201 | stop_words=stop_words, 202 | use_history=use_history, 203 | efficient_eos=efficient_eos, 204 | replace_eos=replace_eos 205 | ) 206 | 207 | 208 | def get_template_and_fix_tokenizer( 209 | name: str, 210 | tokenizer: "PreTrainedTokenizer" 211 | ) -> Template: 212 | if tokenizer.eos_token_id is None: 213 | tokenizer.eos_token = "<|endoftext|>" 214 | logger.info("Add eos token: {}".format(tokenizer.eos_token)) 215 | 216 | if tokenizer.pad_token_id is None: 217 | tokenizer.pad_token = tokenizer.eos_token 218 | logger.info("Add pad token: {}".format(tokenizer.pad_token)) 219 | 220 | if name is None: # for pre-training 221 | return None 222 | 223 | template = templates.get(name, None) 224 | assert template is not None, "Template {} does not exist.".format(name) 225 | 226 | stop_words = template.stop_words 227 | if template.replace_eos: 228 | if not stop_words: 229 | raise ValueError("Stop words are required to replace the EOS token.") 230 | 231 | tokenizer.eos_token = stop_words[0] 232 | stop_words = stop_words[1:] 233 | logger.info("Replace eos token: {}".format(tokenizer.eos_token)) 234 | 235 | if stop_words: 236 | tokenizer.add_special_tokens( 237 | dict(additional_special_tokens=stop_words), 238 | replace_additional_special_tokens=False 239 | ) 240 | logger.info("Add {} to stop words.".format(",".join(stop_words))) 241 | 242 | return template 243 | 244 | 245 | register_template( 246 | name="alpaca", 247 | prefix=[ 248 | "{{system}}" 249 | ], 250 | prompt=[ 251 | "### Instruction:\n{{query}}\n\n### Response:\n" 252 | ], 253 | system=( 254 | "Below is an instruction that describes a task. " 255 | "Write a response that appropriately completes the request." 256 | ), 257 | sep=[ 258 | "\n\n" 259 | ] 260 | ) 261 | 262 | 263 | register_template( 264 | name="aquila", 265 | prefix=[ 266 | "{{system}}" 267 | ], 268 | prompt=[ 269 | "Human: {{query}}###Assistant:" 270 | ], 271 | system=( 272 | "A chat between a curious human and an artificial intelligence assistant. " 273 | "The assistant gives helpful, detailed, and polite answers to the human's questions." 
274 | ), 275 | sep=[ 276 | "###" 277 | ], 278 | stop_words=[ 279 | "" 280 | ], 281 | efficient_eos=True 282 | ) 283 | 284 | 285 | register_template( 286 | name="baichuan", 287 | prefix=[ 288 | "{{system}}" 289 | ], 290 | prompt=[ 291 | {"token": ""}, # user token 292 | "{{query}}", 293 | {"token": ""} # assistant token 294 | ], 295 | system="", 296 | sep=[], 297 | efficient_eos=True 298 | ) 299 | 300 | 301 | register_template( 302 | name="baichuan2", 303 | prefix=[ 304 | "{{system}}" 305 | ], 306 | prompt=[ 307 | {"token": ""}, # user token 308 | "{{query}}", 309 | {"token": ""} # assistant token 310 | ], 311 | system="", 312 | sep=[], 313 | efficient_eos=True 314 | ) 315 | 316 | 317 | register_template( 318 | name="belle", 319 | prefix=[ 320 | "{{system}}" 321 | ], 322 | prompt=[ 323 | "Human: {{query}}\n\nBelle: " 324 | ], 325 | system="", 326 | sep=[ 327 | "\n\n" 328 | ] 329 | ) 330 | 331 | 332 | register_template( 333 | name="bluelm", 334 | prefix=[ 335 | "{{system}}" 336 | ], 337 | prompt=[ 338 | {"token": "[|Human|]:"}, 339 | "{{query}}", 340 | {"token": "[|AI|]:"} 341 | ], 342 | system="", 343 | sep=[] 344 | ) 345 | 346 | 347 | register_template( 348 | name="chatglm2", 349 | prefix=[ 350 | {"token": "[gMASK]"}, 351 | {"token": "sop"}, 352 | "{{system}}" 353 | ], 354 | prompt=[ 355 | "[Round {{idx}}]\n\n问:{{query}}\n\n答:" 356 | ], 357 | system="", 358 | sep=[ 359 | "\n\n" 360 | ], 361 | efficient_eos=True 362 | ) 363 | 364 | 365 | register_template( 366 | name="chatglm3", 367 | prefix=[ 368 | {"token": "[gMASK]"}, 369 | {"token": "sop"}, 370 | {"token": "<|system|>"}, 371 | "\n", 372 | "{{system}}" 373 | ], 374 | prompt=[ 375 | {"token": "<|user|>"}, 376 | "\n", 377 | "{{query}}", 378 | {"token": "<|assistant|>"}, 379 | "\n" # add an extra newline to avoid error in ChatGLM's process_response method 380 | ], 381 | system=( 382 | "You are ChatGLM3, a large language model trained by Zhipu.AI. " 383 | "Follow the user's instructions carefully. Respond using markdown." 384 | ), 385 | sep=[], 386 | stop_words=[ 387 | "<|user|>", 388 | "<|observation|>" 389 | ], 390 | efficient_eos=True 391 | ) 392 | 393 | 394 | register_template( 395 | name="chatglm3_raw", # the raw template for tool tuning 396 | prefix=[ 397 | {"token": "[gMASK]"}, 398 | {"token": "sop"}, 399 | {"token": "<|system|>"}, 400 | "\n", 401 | "{{system}}" 402 | ], 403 | prompt=[ 404 | {"token": "<|user|>"}, 405 | "\n", 406 | "{{query}}", 407 | {"token": "<|assistant|>"} 408 | ], 409 | system=( 410 | "You are ChatGLM3, a large language model trained by Zhipu.AI. " 411 | "Follow the user's instructions carefully. Respond using markdown." 
412 | ), 413 | sep=[], 414 | stop_words=[ 415 | "<|user|>", 416 | "<|observation|>" 417 | ], 418 | efficient_eos=True 419 | ) 420 | 421 | 422 | register_template( 423 | name="codegeex2", 424 | prefix=[ 425 | {"token": "[gMASK]"}, 426 | {"token": "sop"}, 427 | "{{system}}" 428 | ], 429 | prompt=[ 430 | "{{query}}" 431 | ], 432 | system="", 433 | sep=[] 434 | ) 435 | 436 | 437 | register_template( 438 | name="deepseek", 439 | prefix=[ 440 | "{{system}}" 441 | ], 442 | prompt=[ 443 | "User: {{query}}\n\nAssistant:" 444 | ], 445 | system="", 446 | sep=[] 447 | ) 448 | 449 | 450 | register_template( 451 | name="deepseekcoder", 452 | prefix=[ 453 | "{{system}}" 454 | ], 455 | prompt=[ 456 | "### Instruction:\n{{query}}\n### Response:\n" 457 | ], 458 | system=( 459 | "You are an AI programming assistant, utilizing the Deepseek Coder model, " 460 | "developed by Deepseek Company, and you only answer questions related to computer science. " 461 | "For politically sensitive questions, security and privacy issues, " 462 | "and other non-computer science questions, you will refuse to answer\n" 463 | ), 464 | sep=[ 465 | "\n", 466 | {"token": "<|EOT|>"}, 467 | "\n" 468 | ], 469 | stop_words=[ 470 | "<|EOT|>" 471 | ], 472 | efficient_eos=True 473 | ) 474 | 475 | 476 | register_template( 477 | name="default", 478 | prefix=[ 479 | "{{system}}" 480 | ], 481 | prompt=[ 482 | "Human: {{query}}\nAssistant:" 483 | ], 484 | system=( 485 | "A chat between a curious user and an artificial intelligence assistant. " 486 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 487 | ), 488 | sep=[ 489 | "\n" 490 | ] 491 | ) 492 | 493 | 494 | register_template( 495 | name="falcon", 496 | prefix=[ 497 | "{{system}}" 498 | ], 499 | prompt=[ 500 | "User: {{query}}\nFalcon:" 501 | ], 502 | system="", 503 | sep=[ 504 | "\n" 505 | ], 506 | efficient_eos=True 507 | ) 508 | 509 | 510 | register_template( 511 | name="intern", 512 | prefix=[ 513 | "{{system}}" 514 | ], 515 | prompt=[ 516 | "<|User|>:{{query}}", 517 | {"token": ""}, 518 | "\n<|Bot|>:" 519 | ], 520 | system="", 521 | sep=[ 522 | {"token": ""}, 523 | "\n" 524 | ], 525 | stop_words=[ 526 | "" 527 | ], 528 | efficient_eos=True 529 | ) 530 | 531 | 532 | register_template( 533 | name="llama2", 534 | prefix=[ 535 | "<>\n{{system}}\n<>\n\n" 536 | ], 537 | prompt=[ 538 | "[INST] {{query}} [/INST]" 539 | ], 540 | system=( 541 | "You are a helpful, respectful and honest assistant. " 542 | "Always answer as helpfully as possible, while being safe. " 543 | "Your answers should not include any harmful, unethical, " 544 | "racist, sexist, toxic, dangerous, or illegal content. " 545 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n" 546 | "If a question does not make any sense, or is not factually coherent, " 547 | "explain why instead of answering something not correct. " 548 | "If you don't know the answer to a question, please don't share false information." 549 | ), 550 | sep=[] 551 | ) 552 | 553 | 554 | register_template( 555 | name="llama2_zh", 556 | prefix=[ 557 | "<>\n{{system}}\n<>\n\n" 558 | ], 559 | prompt=[ 560 | "[INST] {{query}} [/INST]" 561 | ], 562 | system="You are a helpful assistant. 
你是一个乐于助人的助手。", 563 | sep=[] 564 | ) 565 | 566 | 567 | register_template( 568 | name="mistral", 569 | prefix=[ 570 | "{{system}}" 571 | ], 572 | prompt=[ 573 | "[INST] {{query}} [/INST]" 574 | ], 575 | system="", 576 | sep=[] 577 | ) 578 | 579 | 580 | register_template( 581 | name="openchat", 582 | prefix=[ 583 | "{{system}}" 584 | ], 585 | prompt=[ 586 | "GPT4 Correct User: {{query}}", 587 | {"token": "<|end_of_turn|>"}, 588 | "GPT4 Correct Assistant:" 589 | ], 590 | system="", 591 | sep=[ 592 | {"token": "<|end_of_turn|>"} 593 | ], 594 | stop_words=[ 595 | "<|end_of_turn|>" 596 | ], 597 | efficient_eos=True 598 | ) 599 | 600 | 601 | register_template( 602 | name="qwen", 603 | prefix=[ 604 | "<|im_start|>system\n{{system}}<|im_end|>" 605 | ], 606 | prompt=[ 607 | "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n" 608 | ], 609 | system="You are a helpful assistant.", 610 | sep=[ 611 | "\n" 612 | ], 613 | stop_words=[ 614 | "<|im_end|>" 615 | ], 616 | replace_eos=True 617 | ) 618 | 619 | 620 | register_template( 621 | name="starchat", 622 | prefix=[ 623 | {"token": "<|system|>"}, 624 | "\n{{system}}", 625 | ], 626 | prompt=[ 627 | {"token": "<|user|>"}, 628 | "\n{{query}}", 629 | {"token": "<|end|>"}, 630 | "\n", 631 | {"token": "<|assistant|>"} 632 | ], 633 | system="", 634 | sep=[ 635 | {"token": "<|end|>"}, 636 | "\n" 637 | ], 638 | stop_words=[ 639 | "<|end|>" 640 | ], 641 | efficient_eos=True 642 | ) 643 | 644 | 645 | register_template( 646 | name="vanilla", 647 | prefix=[], 648 | prompt=[ 649 | "{{query}}" 650 | ], 651 | system="", 652 | sep=[], 653 | use_history=False 654 | ) 655 | 656 | 657 | register_template( 658 | name="vicuna", 659 | prefix=[ 660 | "{{system}}" 661 | ], 662 | prompt=[ 663 | "USER: {{query}} ASSISTANT:" 664 | ], 665 | system=( 666 | "A chat between a curious user and an artificial intelligence assistant. " 667 | "The assistant gives helpful, detailed, and polite answers to the user's questions." 668 | ), 669 | sep=[] 670 | ) 671 | 672 | 673 | register_template( 674 | name="xuanyuan", 675 | prefix=[ 676 | "{{system}}" 677 | ], 678 | prompt=[ 679 | "Human: {{query}} Assistant:" 680 | ], 681 | system=( 682 | "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头," 683 | "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、" 684 | "不安全、有争议、政治敏感等相关的话题、问题和指示。\n" 685 | ), 686 | sep=[] 687 | ) 688 | 689 | 690 | register_template( 691 | name="xverse", 692 | prefix=[ 693 | "{{system}}" 694 | ], 695 | prompt=[ 696 | "Human: {{query}}\n\nAssistant: " 697 | ], 698 | system="", 699 | sep=[] 700 | ) 701 | 702 | 703 | register_template( 704 | name="yayi", 705 | prefix=[ 706 | {"token": "<|System|>"}, 707 | ":\n{{system}}" 708 | ], 709 | prompt=[ 710 | {"token": "<|Human|>"}, 711 | ":\n{{query}}\n\n", 712 | {"token": "<|YaYi|>"}, 713 | ":" 714 | ], 715 | system=( 716 | "You are a helpful, respectful and honest assistant named YaYi " 717 | "developed by Beijing Wenge Technology Co.,Ltd. " 718 | "Always answer as helpfully as possible, while being safe. " 719 | "Your answers should not include any harmful, unethical, " 720 | "racist, sexist, toxic, dangerous, or illegal content. " 721 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n" 722 | "If a question does not make any sense, or is not factually coherent, " 723 | "explain why instead of answering something not correct. " 724 | "If you don't know the answer to a question, please don't share false information." 
725 | ), 726 | sep=[ 727 | "\n\n" 728 | ], 729 | stop_words=[ 730 | "<|End|>" 731 | ] 732 | ) 733 | 734 | 735 | register_template( 736 | name="yi", 737 | prefix=[ 738 | "{{system}}" 739 | ], 740 | prompt=[ 741 | "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n" 742 | ], 743 | system="", 744 | sep=[ 745 | "\n" 746 | ], 747 | stop_words=[ 748 | "<|im_end|>" 749 | ], 750 | replace_eos=True 751 | ) 752 | 753 | 754 | register_template( 755 | name="yuan", 756 | prefix=[ 757 | "{{system}}" 758 | ], 759 | prompt=[ 760 | "{{query}}", 761 | {"token": "<sep>"} 762 | ], 763 | system="", 764 | sep=[ 765 | "\n" 766 | ], 767 | stop_words=[ 768 | "<eod>" 769 | ], 770 | replace_eos=True 771 | ) 772 | 773 | 774 | register_template( 775 | name="zephyr", 776 | prefix=[ 777 | {"token": "<|system|>"}, 778 | "\n{{system}}", 779 | {"token": "</s>"} 780 | ], 781 | prompt=[ 782 | {"token": "<|user|>"}, 783 | "\n{{query}}", 784 | {"token": "</s>"}, 785 | {"token": "<|assistant|>"} 786 | ], 787 | system="You are a friendly chatbot who always responds in the style of a pirate", 788 | sep=[] 789 | ) 790 | 791 | 792 | register_template( 793 | name="ziya", 794 | prefix=[ 795 | "{{system}}" 796 | ], 797 | prompt=[ 798 | {"token": "<human>"}, 799 | ":{{query}}\n", 800 | {"token": "<bot>"}, 801 | ":" 802 | ], 803 | system="", 804 | sep=[ 805 | "\n" 806 | ] 807 | ) 808 | -------------------------------------------------------------------------------- /model/src/data/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | from typing import TYPE_CHECKING, Dict, List, Optional, Union 3 | 4 | from ..extras.logging import get_logger 5 | 6 | if TYPE_CHECKING: 7 | from datasets import Dataset, IterableDataset 8 | from transformers import TrainingArguments 9 | from ..hparams.data_args import DataArguments 10 | 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: 16 | if file_sha1 is None: 17 | logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") 18 | return 19 | 20 | if len(data_files) != 1: 21 | logger.warning("Checksum failed: too many files.") 22 | return 23 | 24 | with open(data_files[0], "rb") as f: 25 | sha1 = hashlib.sha1(f.read()).hexdigest() 26 | if sha1 != file_sha1: 27 | logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) 28 | 29 | 30 | def split_dataset( 31 | dataset: Union["Dataset", "IterableDataset"], 32 | data_args: "DataArguments", 33 | training_args: "TrainingArguments" 34 | ) -> Dict[str, "Dataset"]: 35 | if training_args.do_train: 36 | if data_args.val_size > 1e-6: # Split the dataset 37 | if data_args.streaming: 38 | val_set = dataset.take(int(data_args.val_size)) 39 | train_set = dataset.skip(int(data_args.val_size)) 40 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) 41 | return {"train_dataset": train_set, "eval_dataset": val_set} 42 | else: 43 | val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size 44 | dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed) 45 | return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} 46 | else: 47 | if data_args.streaming: 48 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) 49 | return {"train_dataset": dataset} 50 | else: # do_eval or do_predict 51 | return {"eval_dataset": dataset} 52 | 
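A minimal usage sketch for the `split_dataset` helper above, assuming the package is importable as `model.src` and that the argument objects can be built directly; the import path and toy data are illustrative, not code from this repository:

```python
# Illustrative sketch only (not part of this repository): exercises the
# non-streaming branch of split_dataset() with a toy HF dataset.
from datasets import Dataset
from transformers import TrainingArguments

from model.src.data.utils import split_dataset          # assumed import path
from model.src.hparams.data_args import DataArguments   # defined later in this listing

toy = Dataset.from_dict({
    "instruction": ["q1", "q2", "q3", "q4"],
    "output": ["a1", "a2", "a3", "a4"],
})
data_args = DataArguments(val_size=0.25)  # hold out 25% for evaluation
training_args = TrainingArguments(output_dir="./tmp", do_train=True, seed=42)

splits = split_dataset(toy, data_args, training_args)
print(len(splits["train_dataset"]), len(splits["eval_dataset"]))  # -> 3 1
```

With `streaming=True`, `val_size` is instead interpreted as an absolute number of examples taken from the head of the stream rather than a fraction.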
-------------------------------------------------------------------------------- /model/src/extras/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | from typing import TYPE_CHECKING 5 | from datetime import timedelta 6 | from transformers import TrainerCallback 7 | from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR 8 | 9 | from .constants import LOG_FILE_NAME 10 | from .logging import get_logger 11 | from .misc import fix_valuehead_checkpoint 12 | 13 | 14 | if TYPE_CHECKING: 15 | from transformers import TrainingArguments, TrainerState, TrainerControl 16 | 17 | 18 | logger = get_logger(__name__) 19 | 20 | class FixValueHeadModelCallback(TrainerCallback): 21 | 22 | def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 23 | r""" 24 | Event called after a checkpoint save. 25 | """ 26 | if args.should_save: 27 | fix_valuehead_checkpoint( 28 | model=kwargs.pop("model"), 29 | output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)), 30 | safe_serialization=args.save_safetensors 31 | ) 32 | 33 | class LogCallback(TrainerCallback): 34 | 35 | def __init__(self, runner=None): 36 | self.runner = runner 37 | self.in_training = False 38 | self.start_time = time.time() 39 | self.cur_steps = 0 40 | self.max_steps = 0 41 | self.elapsed_time = "" 42 | self.remaining_time = "" 43 | 44 | def timing(self): 45 | cur_time = time.time() 46 | elapsed_time = cur_time - self.start_time 47 | avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0 48 | remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step 49 | self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) 50 | self.remaining_time = str(timedelta(seconds=int(remaining_time))) 51 | 52 | def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 53 | r""" 54 | Event called at the beginning of training. 55 | """ 56 | if state.is_local_process_zero: 57 | self.in_training = True 58 | self.start_time = time.time() 59 | self.max_steps = state.max_steps 60 | if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir: 61 | logger.warning("Previous log file in this folder will be deleted.") 62 | os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) 63 | 64 | def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 65 | r""" 66 | Event called at the end of training. 67 | """ 68 | if state.is_local_process_zero: 69 | self.in_training = False 70 | self.cur_steps = 0 71 | self.max_steps = 0 72 | 73 | def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 74 | r""" 75 | Event called at the end of an substep during gradient accumulation. 76 | """ 77 | if state.is_local_process_zero and self.runner is not None and self.runner.aborted: 78 | control.should_epoch_stop = True 79 | control.should_training_stop = True 80 | 81 | def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 82 | r""" 83 | Event called at the end of a training step. 
84 | """ 85 | if state.is_local_process_zero: 86 | self.cur_steps = state.global_step 87 | self.timing() 88 | if self.runner is not None and self.runner.aborted: 89 | control.should_epoch_stop = True 90 | control.should_training_stop = True 91 | 92 | def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 93 | r""" 94 | Event called after an evaluation phase. 95 | """ 96 | if state.is_local_process_zero and not self.in_training: 97 | self.cur_steps = 0 98 | self.max_steps = 0 99 | 100 | def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs): 101 | r""" 102 | Event called after a successful prediction. 103 | """ 104 | if state.is_local_process_zero and not self.in_training: 105 | self.cur_steps = 0 106 | self.max_steps = 0 107 | 108 | def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: 109 | r""" 110 | Event called after logging the last logs. 111 | """ 112 | if not state.is_local_process_zero: 113 | return 114 | 115 | logs = dict( 116 | current_steps=self.cur_steps, 117 | total_steps=self.max_steps, 118 | loss=state.log_history[-1].get("loss", None), 119 | eval_loss=state.log_history[-1].get("eval_loss", None), 120 | predict_loss=state.log_history[-1].get("predict_loss", None), 121 | reward=state.log_history[-1].get("reward", None), 122 | learning_rate=state.log_history[-1].get("learning_rate", None), 123 | epoch=state.log_history[-1].get("epoch", None), 124 | percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, 125 | elapsed_time=self.elapsed_time, 126 | remaining_time=self.remaining_time 127 | ) 128 | if self.runner is not None: 129 | logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format( 130 | logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0 131 | )) 132 | 133 | os.makedirs(args.output_dir, exist_ok=True) 134 | with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: 135 | f.write(json.dumps(logs) + "\n") 136 | 137 | def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 138 | r""" 139 | Event called after a prediction step. 
140 | """ 141 | eval_dataloader = kwargs.pop("eval_dataloader", None) 142 | if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training: 143 | if self.max_steps == 0: 144 | self.max_steps = len(eval_dataloader) 145 | self.cur_steps += 1 146 | self.timing() 147 | -------------------------------------------------------------------------------- /model/src/extras/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from collections import defaultdict, OrderedDict 3 | from typing import Dict, Optional 4 | 5 | 6 | CHOICES = ["A", "B", "C", "D"] 7 | 8 | DEFAULT_MODULE = defaultdict(str) 9 | 10 | DEFAULT_TEMPLATE = defaultdict(str) 11 | 12 | FILEEXT2TYPE = { 13 | "arrow": "arrow", 14 | "csv": "csv", 15 | "json": "json", 16 | "jsonl": "json", 17 | "parquet": "parquet", 18 | "txt": "text" 19 | } 20 | 21 | IGNORE_INDEX = -100 22 | 23 | LAYERNORM_NAMES = {"norm", "ln"} 24 | 25 | LOG_FILE_NAME = "trainer_log.jsonl" 26 | 27 | METHODS = ["full", "freeze", "lora"] 28 | 29 | PEFT_METHODS = ["lora"] 30 | 31 | SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] 32 | 33 | SUPPORTED_MODELS = OrderedDict() 34 | 35 | TRAINING_STAGES = { 36 | "Supervised Fine-Tuning": "sft", 37 | "Reward Modeling": "rm", 38 | "PPO": "ppo", 39 | "DPO": "dpo", 40 | "Pre-Training": "pt" 41 | } 42 | 43 | V_HEAD_WEIGHTS_NAME = "value_head.bin" 44 | 45 | V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" 46 | 47 | class DownloadSource(str, Enum): 48 | DEFAULT = "hf" 49 | MODELSCOPE = "ms" 50 | 51 | 52 | def register_model_group( 53 | models: Dict[str, Dict[DownloadSource, str]], 54 | module: Optional[str] = None, 55 | template: Optional[str] = None 56 | ) -> None: 57 | prefix = None 58 | for name, path in models.items(): 59 | if prefix is None: 60 | prefix = name.split("-")[0] 61 | else: 62 | assert prefix == name.split("-")[0], "prefix should be identical." 
63 | SUPPORTED_MODELS[name] = path 64 | if module is not None: 65 | DEFAULT_MODULE[prefix] = module 66 | if template is not None: 67 | DEFAULT_TEMPLATE[prefix] = template 68 | 69 | 70 | register_model_group( 71 | models={ 72 | "Baichuan-7B-Base": { 73 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B", 74 | DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B" 75 | }, 76 | "Baichuan-13B-Base": { 77 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base", 78 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base" 79 | }, 80 | "Baichuan-13B-Chat": { 81 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat", 82 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat" 83 | } 84 | }, 85 | module="W_pack", 86 | template="baichuan" 87 | ) 88 | 89 | 90 | register_model_group( 91 | models={ 92 | "Baichuan2-7B-Base": { 93 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base", 94 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base" 95 | }, 96 | "Baichuan2-13B-Base": { 97 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base", 98 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base" 99 | }, 100 | "Baichuan2-7B-Chat": { 101 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat", 102 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat" 103 | }, 104 | "Baichuan2-13B-Chat": { 105 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat", 106 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat" 107 | } 108 | }, 109 | module="W_pack", 110 | template="baichuan2" 111 | ) 112 | 113 | 114 | register_model_group( 115 | models={ 116 | "BLOOM-560M": { 117 | DownloadSource.DEFAULT: "bigscience/bloom-560m", 118 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m" 119 | }, 120 | "BLOOM-3B": { 121 | DownloadSource.DEFAULT: "bigscience/bloom-3b", 122 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b" 123 | }, 124 | "BLOOM-7B1": { 125 | DownloadSource.DEFAULT: "bigscience/bloom-7b1", 126 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1" 127 | } 128 | }, 129 | module="query_key_value" 130 | ) 131 | 132 | 133 | register_model_group( 134 | models={ 135 | "BLOOMZ-560M": { 136 | DownloadSource.DEFAULT: "bigscience/bloomz-560m", 137 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m" 138 | }, 139 | "BLOOMZ-3B": { 140 | DownloadSource.DEFAULT: "bigscience/bloomz-3b", 141 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b" 142 | }, 143 | "BLOOMZ-7B1-mt": { 144 | DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt", 145 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt" 146 | } 147 | }, 148 | module="query_key_value" 149 | ) 150 | 151 | 152 | register_model_group( 153 | models={ 154 | "BlueLM-7B-Base": { 155 | DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base", 156 | DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base" 157 | }, 158 | "BlueLM-7B-Chat": { 159 | DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat", 160 | DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat" 161 | } 162 | }, 163 | template="bluelm" 164 | ) 165 | 166 | 167 | register_model_group( 168 | models={ 169 | "ChatGLM2-6B-Chat": { 170 | DownloadSource.DEFAULT: "THUDM/chatglm2-6b", 171 | DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b" 172 | } 173 | }, 174 | module="query_key_value", 175 | template="chatglm2" 176 | ) 177 | 178 | 179 | register_model_group( 180 | models={ 181 | "ChatGLM3-6B-Base": { 182 | DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base", 183 | DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base" 184 | }, 185 | "ChatGLM3-6B-Chat": { 186 | 
DownloadSource.DEFAULT: "THUDM/chatglm3-6b", 187 | DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b" 188 | } 189 | }, 190 | module="query_key_value", 191 | template="chatglm3" 192 | ) 193 | 194 | 195 | register_model_group( 196 | models={ 197 | "ChineseLLaMA2-1.3B": { 198 | DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b", 199 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b" 200 | }, 201 | "ChineseLLaMA2-7B": { 202 | DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b", 203 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b" 204 | }, 205 | "ChineseLLaMA2-13B": { 206 | DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b", 207 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b" 208 | }, 209 | "ChineseLLaMA2-1.3B-Chat": { 210 | DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b", 211 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b" 212 | }, 213 | "ChineseLLaMA2-7B-Chat": { 214 | DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b", 215 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b" 216 | }, 217 | "ChineseLLaMA2-13B-Chat": { 218 | DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b", 219 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b" 220 | } 221 | }, 222 | template="llama2_zh" 223 | ) 224 | 225 | 226 | register_model_group( 227 | models={ 228 | "DeepseekLLM-7B-Base": { 229 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base", 230 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base" 231 | }, 232 | "DeepseekLLM-67B-Base": { 233 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base", 234 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base" 235 | }, 236 | "DeepseekLLM-7B-Chat": { 237 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat", 238 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat" 239 | }, 240 | "DeepseekLLM-67B-Chat": { 241 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat", 242 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat" 243 | } 244 | }, 245 | template="deepseek" 246 | ) 247 | 248 | 249 | register_model_group( 250 | models={ 251 | "DeepseekCoder-6.7B-Base": { 252 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base", 253 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base" 254 | }, 255 | "DeepseekCoder-33B-Base": { 256 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base", 257 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base" 258 | }, 259 | "DeepseekCoder-6.7B-Chat": { 260 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct", 261 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct" 262 | }, 263 | "DeepseekCoder-33B-Chat": { 264 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct", 265 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct" 266 | } 267 | }, 268 | template="deepseekcoder" 269 | ) 270 | 271 | 272 | register_model_group( 273 | models={ 274 | "Falcon-7B": { 275 | DownloadSource.DEFAULT: "tiiuae/falcon-7b", 276 | DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b" 277 | }, 278 | "Falcon-40B": { 279 | DownloadSource.DEFAULT: "tiiuae/falcon-40b", 280 | DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b" 281 | }, 282 | "Falcon-180B": { 283 | DownloadSource.DEFAULT: "tiiuae/falcon-180b", 284 | DownloadSource.MODELSCOPE: "modelscope/falcon-180B" 285 | }, 286 | "Falcon-7B-Chat": { 287 | DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct", 288 | DownloadSource.MODELSCOPE: 
"AI-ModelScope/falcon-7b-instruct" 289 | }, 290 | "Falcon-40B-Chat": { 291 | DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct", 292 | DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct" 293 | }, 294 | "Falcon-180B-Chat": { 295 | DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat", 296 | DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat" 297 | } 298 | }, 299 | module="query_key_value", 300 | template="falcon" 301 | ) 302 | 303 | 304 | register_model_group( 305 | models={ 306 | "InternLM-7B": { 307 | DownloadSource.DEFAULT: "internlm/internlm-7b", 308 | DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b" 309 | }, 310 | "InternLM-20B": { 311 | DownloadSource.DEFAULT: "internlm/internlm-20b", 312 | DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b" 313 | }, 314 | "InternLM-7B-Chat": { 315 | DownloadSource.DEFAULT: "internlm/internlm-chat-7b", 316 | DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b" 317 | }, 318 | "InternLM-20B-Chat": { 319 | DownloadSource.DEFAULT: "internlm/internlm-chat-20b", 320 | DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b" 321 | } 322 | }, 323 | template="intern" 324 | ) 325 | 326 | 327 | register_model_group( 328 | models={ 329 | "LingoWhale-8B": { 330 | DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B", 331 | DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B" 332 | } 333 | }, 334 | module="qkv_proj" 335 | ) 336 | 337 | 338 | register_model_group( 339 | models={ 340 | "LLaMA-7B": { 341 | DownloadSource.DEFAULT: "huggyllama/llama-7b", 342 | DownloadSource.MODELSCOPE: "skyline2006/llama-7b" 343 | }, 344 | "LLaMA-13B": { 345 | DownloadSource.DEFAULT: "huggyllama/llama-13b", 346 | DownloadSource.MODELSCOPE: "skyline2006/llama-13b" 347 | }, 348 | "LLaMA-30B": { 349 | DownloadSource.DEFAULT: "huggyllama/llama-30b", 350 | DownloadSource.MODELSCOPE: "skyline2006/llama-30b" 351 | }, 352 | "LLaMA-65B": { 353 | DownloadSource.DEFAULT: "huggyllama/llama-65b", 354 | DownloadSource.MODELSCOPE: "skyline2006/llama-65b" 355 | } 356 | } 357 | ) 358 | 359 | 360 | register_model_group( 361 | models={ 362 | "LLaMA2-7B": { 363 | DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf", 364 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms" 365 | }, 366 | "LLaMA2-13B": { 367 | DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf", 368 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms" 369 | }, 370 | "LLaMA2-70B": { 371 | DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf", 372 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms" 373 | }, 374 | "LLaMA2-7B-Chat": { 375 | DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf", 376 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms" 377 | }, 378 | "LLaMA2-13B-Chat": { 379 | DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf", 380 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms" 381 | }, 382 | "LLaMA2-70B-Chat": { 383 | DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf", 384 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms" 385 | } 386 | }, 387 | template="llama2" 388 | ) 389 | 390 | 391 | register_model_group( 392 | models={ 393 | "Mistral-7B": { 394 | DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1", 395 | DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1" 396 | }, 397 | "Mistral-7B-Chat": { 398 | DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1", 399 | DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1" 400 | }, 401 | "Mistral-7B-v0.2-Chat": { 402 | 
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2", 403 | DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2" 404 | } 405 | }, 406 | template="mistral" 407 | ) 408 | 409 | 410 | register_model_group( 411 | models={ 412 | "Mixtral-8x7B": { 413 | DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1", 414 | DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1" 415 | }, 416 | "Mixtral-8x7B-Chat": { 417 | DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1", 418 | DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1" 419 | } 420 | }, 421 | template="mistral" 422 | ) 423 | 424 | 425 | register_model_group( 426 | models={ 427 | "OpenChat3.5-7B-Chat": { 428 | DownloadSource.DEFAULT: "openchat/openchat_3.5", 429 | DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5" 430 | } 431 | }, 432 | template="openchat" 433 | ) 434 | 435 | 436 | register_model_group( 437 | models={ 438 | "Phi-1.5-1.3B": { 439 | DownloadSource.DEFAULT: "microsoft/phi-1_5", 440 | DownloadSource.MODELSCOPE: "allspace/PHI_1-5" 441 | }, 442 | "Phi-2-2.7B": { 443 | DownloadSource.DEFAULT: "microsoft/phi-2", 444 | DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2" 445 | } 446 | }, 447 | module="Wqkv" 448 | ) 449 | 450 | 451 | register_model_group( 452 | models={ 453 | "Qwen-1.8B": { 454 | DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", 455 | DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B" 456 | }, 457 | "Qwen-7B": { 458 | DownloadSource.DEFAULT: "Qwen/Qwen-7B", 459 | DownloadSource.MODELSCOPE: "qwen/Qwen-7B" 460 | }, 461 | "Qwen-14B": { 462 | DownloadSource.DEFAULT: "Qwen/Qwen-14B", 463 | DownloadSource.MODELSCOPE: "qwen/Qwen-14B" 464 | }, 465 | "Qwen-72B": { 466 | DownloadSource.DEFAULT: "Qwen/Qwen-72B", 467 | DownloadSource.MODELSCOPE: "qwen/Qwen-72B" 468 | }, 469 | "Qwen-1.8B-Chat": { 470 | DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat", 471 | DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat" 472 | }, 473 | "Qwen-7B-Chat": { 474 | DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", 475 | DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat" 476 | }, 477 | "Qwen-14B-Chat": { 478 | DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat", 479 | DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat" 480 | }, 481 | "Qwen-72B-Chat": { 482 | DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat", 483 | DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat" 484 | }, 485 | "Qwen-1.8B-int8-Chat": { 486 | DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8", 487 | DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8" 488 | }, 489 | "Qwen-1.8B-int4-Chat": { 490 | DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4", 491 | DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4" 492 | }, 493 | "Qwen-7B-int8-Chat": { 494 | DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8", 495 | DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8" 496 | }, 497 | "Qwen-7B-int4-Chat": { 498 | DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4", 499 | DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4" 500 | }, 501 | "Qwen-14B-int8-Chat": { 502 | DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8", 503 | DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8" 504 | }, 505 | "Qwen-14B-int4-Chat": { 506 | DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4", 507 | DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4" 508 | }, 509 | "Qwen-72B-int8-Chat": { 510 | DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8", 511 | DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8" 512 | }, 513 | "Qwen-72B-int4-Chat": { 514 | DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4", 515 | 
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4" 516 | } 517 | }, 518 | module="c_attn", 519 | template="qwen" 520 | ) 521 | 522 | 523 | register_model_group( 524 | models={ 525 | "Skywork-13B-Base": { 526 | DownloadSource.DEFAULT: "Skywork/Skywork-13B-base", 527 | DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base" 528 | } 529 | } 530 | ) 531 | 532 | 533 | register_model_group( 534 | models={ 535 | "Vicuna1.5-7B-Chat": { 536 | DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5", 537 | DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5" 538 | }, 539 | "Vicuna1.5-13B-Chat": { 540 | DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5", 541 | DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5" 542 | } 543 | }, 544 | template="vicuna" 545 | ) 546 | 547 | 548 | register_model_group( 549 | models={ 550 | "XuanYuan-70B": { 551 | DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B" 552 | }, 553 | "XuanYuan-70B-Chat": { 554 | DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat" 555 | }, 556 | "XuanYuan-70B-int8-Chat": { 557 | DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit" 558 | }, 559 | "XuanYuan-70B-int4-Chat": { 560 | DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit" 561 | } 562 | }, 563 | template="xuanyuan" 564 | ) 565 | 566 | 567 | register_model_group( 568 | models={ 569 | "XVERSE-7B": { 570 | DownloadSource.DEFAULT: "xverse/XVERSE-7B", 571 | DownloadSource.MODELSCOPE: "xverse/XVERSE-7B" 572 | }, 573 | "XVERSE-13B": { 574 | DownloadSource.DEFAULT: "xverse/XVERSE-13B", 575 | DownloadSource.MODELSCOPE: "xverse/XVERSE-13B" 576 | }, 577 | "XVERSE-65B": { 578 | DownloadSource.DEFAULT: "xverse/XVERSE-65B", 579 | DownloadSource.MODELSCOPE: "xverse/XVERSE-65B" 580 | }, 581 | "XVERSE-65B-2": { 582 | DownloadSource.DEFAULT: "xverse/XVERSE-65B-2", 583 | DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2" 584 | }, 585 | "XVERSE-7B-Chat": { 586 | DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat", 587 | DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat" 588 | }, 589 | "XVERSE-13B-Chat": { 590 | DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat", 591 | DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat" 592 | }, 593 | "XVERSE-65B-Chat": { 594 | DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat", 595 | DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat" 596 | } 597 | }, 598 | template="xverse" 599 | ) 600 | 601 | 602 | register_model_group( 603 | models={ 604 | "Yayi-7B": { 605 | DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2", 606 | DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2" 607 | }, 608 | "Yayi-13B": { 609 | DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2", 610 | DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2" 611 | } 612 | }, 613 | template="yayi" 614 | ) 615 | 616 | 617 | register_model_group( 618 | models={ 619 | "Yi-6B": { 620 | DownloadSource.DEFAULT: "01-ai/Yi-6B", 621 | DownloadSource.MODELSCOPE: "01ai/Yi-6B" 622 | }, 623 | "Yi-34B": { 624 | DownloadSource.DEFAULT: "01-ai/Yi-34B", 625 | DownloadSource.MODELSCOPE: "01ai/Yi-34B" 626 | }, 627 | "Yi-6B-Chat": { 628 | DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", 629 | DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat" 630 | }, 631 | "Yi-34B-Chat": { 632 | DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", 633 | DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat" 634 | }, 635 | "Yi-6B-int8-Chat": { 636 | DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits", 637 | DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits" 638 | }, 639 | "Yi-34B-int8-Chat": { 640 | DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits", 641 | 
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits" 642 | } 643 | }, 644 | template="yi" 645 | ) 646 | 647 | 648 | register_model_group( 649 | models={ 650 | "Yuan2-2B-Chat": { 651 | DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf", 652 | DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf" 653 | }, 654 | "Yuan2-51B-Chat": { 655 | DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf", 656 | DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf" 657 | }, 658 | "Yuan2-102B-Chat": { 659 | DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf", 660 | DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf" 661 | } 662 | }, 663 | template="yuan" 664 | ) 665 | 666 | 667 | register_model_group( 668 | models={ 669 | "Zephyr-7B-Alpha-Chat": { 670 | DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha", 671 | DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha" 672 | }, 673 | "Zephyr-7B-Beta-Chat": { 674 | DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta", 675 | DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta" 676 | } 677 | }, 678 | template="zephyr" 679 | ) 680 | -------------------------------------------------------------------------------- /model/src/extras/logging.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | 4 | 5 | class LoggerHandler(logging.Handler): 6 | r""" 7 | Logger handler used in Web UI. 8 | """ 9 | 10 | def __init__(self): 11 | super().__init__() 12 | self.log = "" 13 | 14 | def reset(self): 15 | self.log = "" 16 | 17 | def emit(self, record): 18 | if record.name == "httpx": 19 | return 20 | log_entry = self.format(record) 21 | self.log += log_entry 22 | self.log += "\n\n" 23 | 24 | 25 | def get_logger(name: str) -> logging.Logger: 26 | r""" 27 | Gets a standard logger with a stream hander to stdout. 28 | """ 29 | formatter = logging.Formatter( 30 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 31 | datefmt="%m/%d/%Y %H:%M:%S" 32 | ) 33 | handler = logging.StreamHandler(sys.stdout) 34 | handler.setFormatter(formatter) 35 | 36 | logger = logging.getLogger(name) 37 | logger.setLevel(logging.INFO) 38 | logger.addHandler(handler) 39 | 40 | return logger 41 | 42 | 43 | def reset_logging() -> None: 44 | r""" 45 | Removes basic config of root logger. 
(unused in script) 46 | """ 47 | root = logging.getLogger() 48 | list(map(root.removeHandler, root.handlers)) 49 | list(map(root.removeFilter, root.filters)) 50 | -------------------------------------------------------------------------------- /model/src/extras/misc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import torch 4 | from typing import TYPE_CHECKING, Dict, Tuple 5 | from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel 6 | from transformers.utils import ( 7 | WEIGHTS_NAME, 8 | SAFE_WEIGHTS_NAME, 9 | is_torch_bf16_gpu_available, 10 | is_torch_cuda_available, 11 | is_torch_npu_available, 12 | is_torch_xpu_available 13 | ) 14 | from peft import PeftModel 15 | 16 | from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME 17 | from .logging import get_logger 18 | 19 | 20 | _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() 21 | try: 22 | _is_bf16_available = is_torch_bf16_gpu_available() 23 | except: 24 | _is_bf16_available = False 25 | 26 | 27 | if TYPE_CHECKING: 28 | from trl import AutoModelForCausalLMWithValueHead 29 | from ..hparams.model_args import ModelArguments 30 | 31 | 32 | logger = get_logger(__name__) 33 | 34 | 35 | class AverageMeter: 36 | r""" 37 | Computes and stores the average and current value. 38 | """ 39 | def __init__(self): 40 | self.reset() 41 | 42 | def reset(self): 43 | self.val = 0 44 | self.avg = 0 45 | self.sum = 0 46 | self.count = 0 47 | 48 | def update(self, val, n=1): 49 | self.val = val 50 | self.sum += val * n 51 | self.count += n 52 | self.avg = self.sum / self.count 53 | 54 | 55 | def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: 56 | r""" 57 | Returns the number of trainable parameters and number of all parameters in the model. 58 | """ 59 | trainable_params, all_param = 0, 0 60 | for param in model.parameters(): 61 | num_params = param.numel() 62 | # if using DS Zero 3 and the weights are initialized empty 63 | if num_params == 0 and hasattr(param, "ds_numel"): 64 | num_params = param.ds_numel 65 | 66 | # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2 67 | if param.__class__.__name__ == "Params4bit": 68 | num_params = num_params * 2 69 | 70 | all_param += num_params 71 | if param.requires_grad: 72 | trainable_params += num_params 73 | 74 | return trainable_params, all_param 75 | 76 | 77 | def fix_valuehead_checkpoint( 78 | model: "AutoModelForCausalLMWithValueHead", 79 | output_dir: str, 80 | safe_serialization: bool 81 | ) -> None: 82 | r""" 83 | The model is already unwrapped. 84 | 85 | There are three cases: 86 | 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...} 87 | 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...} 88 | 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...} 89 | 90 | We assume `stage3_gather_16bit_weights_on_model_save=true`. 
91 | """ 92 | if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)): 93 | return 94 | 95 | if safe_serialization: 96 | from safetensors import safe_open 97 | from safetensors.torch import save_file 98 | path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) 99 | with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: 100 | state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} 101 | else: 102 | path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) 103 | state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") 104 | 105 | decoder_state_dict = {} 106 | v_head_state_dict = {} 107 | for name, param in state_dict.items(): 108 | if name.startswith("v_head."): 109 | v_head_state_dict[name] = param 110 | else: 111 | decoder_state_dict[name.replace("pretrained_model.", "")] = param 112 | 113 | os.remove(path_to_checkpoint) 114 | model.pretrained_model.save_pretrained( 115 | output_dir, 116 | state_dict=decoder_state_dict or None, 117 | safe_serialization=safe_serialization 118 | ) 119 | 120 | if safe_serialization: 121 | save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) 122 | else: 123 | torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) 124 | 125 | logger.info("Value head model saved at: {}".format(output_dir)) 126 | 127 | 128 | def get_current_device() -> torch.device: 129 | r""" 130 | Gets the current available device. 131 | """ 132 | if is_torch_xpu_available(): 133 | device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) 134 | elif is_torch_npu_available(): 135 | device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0")) 136 | elif is_torch_cuda_available(): 137 | device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0")) 138 | else: 139 | device = "cpu" 140 | 141 | return torch.device(device) 142 | 143 | 144 | def get_device_count() -> int: 145 | return torch.cuda.device_count() 146 | 147 | 148 | def get_logits_processor() -> "LogitsProcessorList": 149 | r""" 150 | Gets logits processor that removes NaN and Inf logits. 151 | """ 152 | logits_processor = LogitsProcessorList() 153 | logits_processor.append(InfNanRemoveLogitsProcessor()) 154 | return logits_processor 155 | 156 | 157 | def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: 158 | r""" 159 | Infers the optimal dtype according to the model_dtype and device compatibility. 160 | """ 161 | if _is_bf16_available and model_dtype == torch.bfloat16: 162 | return torch.bfloat16 163 | elif _is_fp16_available: 164 | return torch.float16 165 | else: 166 | return torch.float32 167 | 168 | 169 | def torch_gc() -> None: 170 | r""" 171 | Collects GPU memory. 
172 | """ 173 | gc.collect() 174 | if torch.cuda.is_available(): 175 | torch.cuda.empty_cache() 176 | torch.cuda.ipc_collect() 177 | 178 | 179 | def try_download_model_from_ms(model_args: "ModelArguments") -> None: 180 | if not use_modelscope() or os.path.exists(model_args.model_name_or_path): 181 | return 182 | 183 | try: 184 | from modelscope import snapshot_download 185 | revision = "master" if model_args.model_revision == "main" else model_args.model_revision 186 | model_args.model_name_or_path = snapshot_download( 187 | model_args.model_name_or_path, 188 | revision=revision, 189 | cache_dir=model_args.cache_dir 190 | ) 191 | except ImportError: 192 | raise ImportError("Please install modelscope via `pip install modelscope -U`") 193 | 194 | 195 | def use_modelscope() -> bool: 196 | return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0"))) 197 | -------------------------------------------------------------------------------- /model/src/extras/packages.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import importlib.util 3 | 4 | 5 | def is_package_available(name: str) -> bool: 6 | return importlib.util.find_spec(name) is not None 7 | 8 | 9 | def get_package_version(name: str) -> str: 10 | try: 11 | return importlib.metadata.version(name) 12 | except: 13 | return "0.0.0" 14 | 15 | 16 | def is_fastapi_availble(): 17 | return is_package_available("fastapi") 18 | 19 | 20 | def is_flash_attn2_available(): 21 | return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2") 22 | 23 | 24 | def is_jieba_available(): 25 | return is_package_available("jieba") 26 | 27 | 28 | def is_matplotlib_available(): 29 | return is_package_available("matplotlib") 30 | 31 | 32 | def is_nltk_available(): 33 | return is_package_available("nltk") 34 | 35 | 36 | def is_requests_available(): 37 | return is_package_available("requests") 38 | 39 | 40 | def is_rouge_available(): 41 | return is_package_available("rouge_chinese") 42 | 43 | 44 | def is_starlette_available(): 45 | return is_package_available("sse_starlette") 46 | 47 | 48 | def is_uvicorn_available(): 49 | return is_package_available("uvicorn") 50 | -------------------------------------------------------------------------------- /model/src/extras/ploting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | from typing import List, Optional 5 | from transformers.trainer import TRAINER_STATE_NAME 6 | 7 | from .logging import get_logger 8 | from .packages import is_matplotlib_available 9 | 10 | if is_matplotlib_available(): 11 | import matplotlib.pyplot as plt 12 | 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | def smooth(scalars: List[float]) -> List[float]: 18 | r""" 19 | EMA implementation according to TensorBoard. 
20 | """ 21 | last = scalars[0] 22 | smoothed = list() 23 | weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function 24 | for next_val in scalars: 25 | smoothed_val = last * weight + (1 - weight) * next_val 26 | smoothed.append(smoothed_val) 27 | last = smoothed_val 28 | return smoothed 29 | 30 | 31 | def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None: 32 | 33 | with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: 34 | data = json.load(f) 35 | 36 | for key in keys: 37 | steps, metrics = [], [] 38 | for i in range(len(data["log_history"])): 39 | if key in data["log_history"][i]: 40 | steps.append(data["log_history"][i]["step"]) 41 | metrics.append(data["log_history"][i][key]) 42 | 43 | if len(metrics) == 0: 44 | logger.warning(f"No metric {key} to plot.") 45 | continue 46 | 47 | plt.figure() 48 | plt.plot(steps, metrics, alpha=0.4, label="original") 49 | plt.plot(steps, smooth(metrics), label="smoothed") 50 | plt.title("training {} of {}".format(key, save_dictionary)) 51 | plt.xlabel("step") 52 | plt.ylabel(key) 53 | plt.legend() 54 | plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100) 55 | print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key))) 56 | -------------------------------------------------------------------------------- /model/src/hparams/data_args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Literal, Optional 4 | from dataclasses import dataclass, field 5 | 6 | 7 | DATA_CONFIG = "dataset_info.json" 8 | 9 | 10 | def use_modelscope() -> bool: 11 | return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0"))) 12 | 13 | 14 | @dataclass 15 | class DatasetAttr: 16 | 17 | load_from: Literal["hf_hub", "ms_hub", "script", "file"] 18 | dataset_name: Optional[str] = None 19 | dataset_sha1: Optional[str] = None 20 | subset: Optional[str] = None 21 | folder: Optional[str] = None 22 | ranking: Optional[bool] = False 23 | formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca" 24 | 25 | prompt: Optional[str] = "instruction" 26 | query: Optional[str] = "input" 27 | response: Optional[str] = "output" 28 | history: Optional[str] = None 29 | messages: Optional[str] = "conversations" 30 | role: Optional[str] = "from" 31 | content: Optional[str] = "value" 32 | system: Optional[str] = None 33 | 34 | def __repr__(self) -> str: 35 | return self.dataset_name 36 | 37 | 38 | @dataclass 39 | class DataArguments: 40 | r""" 41 | Arguments pertaining to what data we are going to input our model for training and evaluation. 42 | """ 43 | template: Optional[str] = field( 44 | default=None, 45 | metadata={"help": "Which template to use for constructing prompts in training and inference."} 46 | ) 47 | dataset: Optional[str] = field( 48 | default=None, 49 | metadata={"help": "The name of provided dataset(s) to use. 
Use commas to separate multiple datasets."} 50 | ) 51 | dataset_dir: Optional[str] = field( 52 | default="data", 53 | metadata={"help": "Path to the folder containing the datasets."} 54 | ) 55 | split: Optional[str] = field( 56 | default="train", 57 | metadata={"help": "Which dataset split to use for training and evaluation."} 58 | ) 59 | cutoff_len: Optional[int] = field( 60 | default=1024, 61 | metadata={"help": "The maximum length of the model inputs after tokenization."} 62 | ) 63 | reserved_label_len: Optional[int] = field( 64 | default=1, 65 | metadata={"help": "The maximum length reserved for label after tokenization."} 66 | ) 67 | train_on_prompt: Optional[bool] = field( 68 | default=False, 69 | metadata={"help": "Whether to disable the mask on the prompt or not."} 70 | ) 71 | streaming: Optional[bool] = field( 72 | default=False, 73 | metadata={"help": "Enable dataset streaming."} 74 | ) 75 | buffer_size: Optional[int] = field( 76 | default=16384, 77 | metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."} 78 | ) 79 | mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( 80 | default="concat", 81 | metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."} 82 | ) 83 | interleave_probs: Optional[str] = field( 84 | default=None, 85 | metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."} 86 | ) 87 | overwrite_cache: Optional[bool] = field( 88 | default=False, 89 | metadata={"help": "Overwrite the cached training and evaluation sets."} 90 | ) 91 | preprocessing_num_workers: Optional[int] = field( 92 | default=None, 93 | metadata={"help": "The number of processes to use for the preprocessing."} 94 | ) 95 | max_samples: Optional[int] = field( 96 | default=None, 97 | metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} 98 | ) 99 | eval_num_beams: Optional[int] = field( 100 | default=None, 101 | metadata={"help": "Number of beams to use for evaluation. 
This argument will be passed to `model.generate`"} 102 | ) 103 | ignore_pad_token_for_loss: Optional[bool] = field( 104 | default=True, 105 | metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} 106 | ) 107 | val_size: Optional[float] = field( 108 | default=0, 109 | metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."} 110 | ) 111 | sft_packing: Optional[bool] = field( 112 | default=False, 113 | metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."} 114 | ) 115 | cache_path: Optional[str] = field( 116 | default=None, 117 | metadata={"help": "Path to save or load the preprocessed datasets."} 118 | ) 119 | 120 | def __post_init__(self): 121 | if self.reserved_label_len >= self.cutoff_len: 122 | raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.") 123 | 124 | if self.streaming and self.val_size > 1e-6 and self.val_size < 1: 125 | raise ValueError("Streaming mode should have an integer val size.") 126 | 127 | if self.streaming and self.max_samples is not None: 128 | raise ValueError("`max_samples` is incompatible with `streaming`.") 129 | 130 | def init_for_training(self, seed: int): # support mixing multiple datasets 131 | self.seed = seed 132 | dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else [] 133 | try: 134 | with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f: 135 | dataset_info = json.load(f) 136 | except Exception as err: 137 | if self.dataset is not None: 138 | raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err))) 139 | dataset_info = None 140 | 141 | if self.interleave_probs is not None: 142 | self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")] 143 | 144 | self.dataset_list: List[DatasetAttr] = [] 145 | for name in dataset_names: 146 | if name not in dataset_info: 147 | raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG)) 148 | 149 | has_hf_url = "hf_hub_url" in dataset_info[name] 150 | has_ms_url = "ms_hub_url" in dataset_info[name] 151 | 152 | if has_hf_url or has_ms_url: 153 | if (use_modelscope() and has_ms_url) or (not has_hf_url): 154 | dataset_attr = DatasetAttr( 155 | "ms_hub", 156 | dataset_name=dataset_info[name]["ms_hub_url"] 157 | ) 158 | else: 159 | dataset_attr = DatasetAttr( 160 | "hf_hub", 161 | dataset_name=dataset_info[name]["hf_hub_url"] 162 | ) 163 | elif "script_url" in dataset_info[name]: 164 | dataset_attr = DatasetAttr( 165 | "script", 166 | dataset_name=dataset_info[name]["script_url"] 167 | ) 168 | else: 169 | dataset_attr = DatasetAttr( 170 | "file", 171 | dataset_name=dataset_info[name]["file_name"], 172 | dataset_sha1=dataset_info[name].get("file_sha1", None) 173 | ) 174 | 175 | if "columns" in dataset_info[name]: 176 | dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None) 177 | dataset_attr.query = dataset_info[name]["columns"].get("query", None) 178 | dataset_attr.response = dataset_info[name]["columns"].get("response", None) 179 | dataset_attr.history = dataset_info[name]["columns"].get("history", None) 180 | dataset_attr.messages = dataset_info[name]["columns"].get("messages", None) 181 | dataset_attr.role = dataset_info[name]["columns"].get("role", None) 182 | dataset_attr.content = dataset_info[name]["columns"].get("content", None) 183 | dataset_attr.system = 
dataset_info[name]["columns"].get("system", None) 184 | 185 | dataset_attr.subset = dataset_info[name].get("subset", None) 186 | dataset_attr.folder = dataset_info[name].get("folder", None) 187 | dataset_attr.ranking = dataset_info[name].get("ranking", False) 188 | dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca") 189 | self.dataset_list.append(dataset_attr) 190 | -------------------------------------------------------------------------------- /model/src/hparams/evaluation_args.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Literal, Optional 3 | from dataclasses import dataclass, field 4 | 5 | from datasets import DownloadMode 6 | 7 | 8 | @dataclass 9 | class EvaluationArguments: 10 | r""" 11 | Arguments pertaining to specify the evaluation parameters. 12 | """ 13 | task: str = field( 14 | metadata={"help": "Name of the evaluation task."} 15 | ) 16 | task_dir: Optional[str] = field( 17 | default="evaluation", 18 | metadata={"help": "Path to the folder containing the evaluation datasets."} 19 | ) 20 | batch_size: Optional[int] = field( 21 | default=4, 22 | metadata={"help": "The batch size per GPU for evaluation."} 23 | ) 24 | seed: Optional[int] = field( 25 | default=42, 26 | metadata={"help": "Random seed to be used with data loaders."} 27 | ) 28 | lang: Optional[Literal["en", "zh"]] = field( 29 | default="en", 30 | metadata={"help": "Language used at evaluation."} 31 | ) 32 | n_shot: Optional[int] = field( 33 | default=5, 34 | metadata={"help": "Number of examplars for few-shot learning."} 35 | ) 36 | save_dir: Optional[str] = field( 37 | default=None, 38 | metadata={"help": "Path to save the evaluation results."} 39 | ) 40 | download_mode: Optional[DownloadMode] = field( 41 | default=DownloadMode.REUSE_DATASET_IF_EXISTS, 42 | metadata={"help": "Download mode used for the evaluation datasets."} 43 | ) 44 | 45 | def __post_init__(self): 46 | task_available = [] 47 | for folder in os.listdir(self.task_dir): 48 | if os.path.isdir(os.path.join(self.task_dir, folder)): 49 | task_available.append(folder) 50 | 51 | if self.task not in task_available: 52 | raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir)) 53 | 54 | if self.save_dir is not None and os.path.exists(self.save_dir): 55 | raise ValueError("`save_dir` already exists, use another one.") 56 | -------------------------------------------------------------------------------- /model/src/hparams/finetuning_args.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Literal, Optional 3 | from dataclasses import asdict, dataclass, field 4 | 5 | 6 | @dataclass 7 | class FreezeArguments: 8 | r""" 9 | Arguments pertaining to the freeze (partial-parameter) training. 10 | """ 11 | name_module_trainable: Optional[str] = field( 12 | default="mlp", 13 | metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \ 14 | Use commas to separate multiple modules. 
\ 15 | LLaMA choices: [\"mlp\", \"self_attn\"], \ 16 | BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \ 17 | Qwen choices: [\"mlp\", \"attn\"], \ 18 | Phi choices: [\"mlp\", \"mixer\"], \ 19 | Others choices: the same as LLaMA."} 20 | ) 21 | num_layer_trainable: Optional[int] = field( 22 | default=3, 23 | metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."} 24 | ) 25 | 26 | 27 | @dataclass 28 | class LoraArguments: 29 | r""" 30 | Arguments pertaining to the LoRA training. 31 | """ 32 | additional_target: Optional[str] = field( 33 | default=None, 34 | metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."} 35 | ) 36 | lora_alpha: Optional[int] = field( 37 | default=None, 38 | metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."} 39 | ) 40 | lora_dropout: Optional[float] = field( 41 | default=0.0, 42 | metadata={"help": "Dropout rate for the LoRA fine-tuning."} 43 | ) 44 | lora_rank: Optional[int] = field( 45 | default=8, 46 | metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} 47 | ) 48 | lora_target: Optional[str] = field( 49 | default=None, 50 | metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ 51 | LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ 52 | BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \ 53 | Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ 54 | Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ 55 | Phi choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \ 56 | Others choices: the same as LLaMA."} 57 | ) 58 | create_new_adapter: Optional[bool] = field( 59 | default=False, 60 | metadata={"help": "Whether to create a new adapter with randomly initialized weight or not."} 61 | ) 62 | 63 | 64 | @dataclass 65 | class RLHFArguments: 66 | r""" 67 | Arguments pertaining to the PPO and DPO training. 
68 | """ 69 | dpo_beta: Optional[float] = field( 70 | default=0.1, 71 | metadata={"help": "The beta parameter for the DPO loss."} 72 | ) 73 | dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field( 74 | default="sigmoid", 75 | metadata={"help": "The type of DPO loss to use."} 76 | ) 77 | dpo_ftx: Optional[float] = field( 78 | default=0, 79 | metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."} 80 | ) 81 | ppo_buffer_size: Optional[int] = field( 82 | default=1, 83 | metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."} 84 | ) 85 | ppo_epochs: Optional[int] = field( 86 | default=4, 87 | metadata={"help": "The number of epochs to perform in a PPO optimization step."} 88 | ) 89 | ppo_logger: Optional[str] = field( 90 | default=None, 91 | metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."} 92 | ) 93 | ppo_score_norm: Optional[bool] = field( 94 | default=False, 95 | metadata={"help": "Use score normalization in PPO training."} 96 | ) 97 | ppo_target: Optional[float] = field( 98 | default=6.0, 99 | metadata={"help": "Target KL value for adaptive KL control in PPO training."} 100 | ) 101 | ppo_whiten_rewards: Optional[bool] = field( 102 | default=False, 103 | metadata={"help": "Whiten the rewards before compute advantages in PPO training."} 104 | ) 105 | ref_model: Optional[str] = field( 106 | default=None, 107 | metadata={"help": "Path to the reference model used for the PPO or DPO training."} 108 | ) 109 | ref_model_adapters: Optional[str] = field( 110 | default=None, 111 | metadata={"help": "Path to the adapters of the reference model."} 112 | ) 113 | ref_model_quantization_bit: Optional[int] = field( 114 | default=None, 115 | metadata={"help": "The number of bits to quantize the reference model."} 116 | ) 117 | reward_model: Optional[str] = field( 118 | default=None, 119 | metadata={"help": "Path to the reward model used for the PPO training."} 120 | ) 121 | reward_model_adapters: Optional[str] = field( 122 | default=None, 123 | metadata={"help": "Path to the adapters of the reward model."} 124 | ) 125 | reward_model_quantization_bit: Optional[int] = field( 126 | default=None, 127 | metadata={"help": "The number of bits to quantize the reward model."} 128 | ) 129 | reward_model_type: Optional[Literal["lora", "full", "api"]] = field( 130 | default="lora", 131 | metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."} 132 | ) 133 | 134 | 135 | @dataclass 136 | class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments): 137 | r""" 138 | Arguments pertaining to which techniques we are going to fine-tuning with. 
139 | """ 140 | stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( 141 | default="sft", 142 | metadata={"help": "Which stage will be performed in training."} 143 | ) 144 | finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field( 145 | default="lora", 146 | metadata={"help": "Which fine-tuning method to use."} 147 | ) 148 | plot_loss: Optional[bool] = field( 149 | default=False, 150 | metadata={"help": "Whether or not to save the training loss curves."} 151 | ) 152 | 153 | def __post_init__(self): 154 | def split_arg(arg): 155 | if isinstance(arg, str): 156 | return [item.strip() for item in arg.split(",")] 157 | return arg 158 | 159 | self.name_module_trainable = split_arg(self.name_module_trainable) 160 | self.lora_alpha = self.lora_alpha or self.lora_rank * 2 161 | self.lora_target = split_arg(self.lora_target) 162 | self.additional_target = split_arg(self.additional_target) 163 | self.ref_model_adapters = split_arg(self.ref_model_adapters) 164 | self.reward_model_adapters = split_arg(self.reward_model_adapters) 165 | 166 | assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." 167 | assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." 168 | assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." 169 | 170 | if self.stage == "ppo" and self.reward_model is None: 171 | raise ValueError("Reward model is necessary for PPO training.") 172 | 173 | if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": 174 | raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.") 175 | 176 | def save_to_json(self, json_path: str): 177 | r"""Saves the content of this instance in JSON format inside `json_path`.""" 178 | json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" 179 | with open(json_path, "w", encoding="utf-8") as f: 180 | f.write(json_string) 181 | 182 | @classmethod 183 | def load_from_json(cls, json_path: str): 184 | r"""Creates an instance from the content of `json_path`.""" 185 | with open(json_path, "r", encoding="utf-8") as f: 186 | text = f.read() 187 | 188 | return cls(**json.loads(text)) 189 | -------------------------------------------------------------------------------- /model/src/hparams/generating_args.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | from dataclasses import asdict, dataclass, field 3 | 4 | 5 | @dataclass 6 | class GeneratingArguments: 7 | r""" 8 | Arguments pertaining to specify the decoding parameters. 9 | """ 10 | do_sample: Optional[bool] = field( 11 | default=True, 12 | metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} 13 | ) 14 | temperature: Optional[float] = field( 15 | default=0.95, 16 | metadata={"help": "The value used to modulate the next token probabilities."} 17 | ) 18 | top_p: Optional[float] = field( 19 | default=0.7, 20 | metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} 21 | ) 22 | top_k: Optional[int] = field( 23 | default=50, 24 | metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} 25 | ) 26 | num_beams: Optional[int] = field( 27 | default=1, 28 | metadata={"help": "Number of beams for beam search. 
1 means no beam search."} 29 | ) 30 | num_return_sequences: Optional[int] = field( 31 | default=1, 32 | metadata={"help": "Number of generated sentences"} 33 | ) 34 | max_length: Optional[int] = field( 35 | default=512, 36 | metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} 37 | ) 38 | max_new_tokens: Optional[int] = field( 39 | default=512, 40 | metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} 41 | ) 42 | repetition_penalty: Optional[float] = field( 43 | default=1.0, 44 | metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."} 45 | ) 46 | length_penalty: Optional[float] = field( 47 | default=1.0, 48 | metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} 49 | ) 50 | 51 | def to_dict(self) -> Dict[str, Any]: 52 | args = asdict(self) 53 | if args.get("max_new_tokens", -1) > 0: 54 | args.pop("max_length", None) 55 | else: 56 | args.pop("max_new_tokens", None) 57 | return args 58 | -------------------------------------------------------------------------------- /model/src/hparams/model_args.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Literal, Optional 2 | from dataclasses import asdict, dataclass, field 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | r""" 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 9 | """ 10 | model_name_or_path: str = field( 11 | metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."} 12 | ) 13 | adapter_name_or_path: Optional[str] = field( 14 | default=None, 15 | metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."} 16 | ) 17 | cache_dir: Optional[str] = field( 18 | default=None, 19 | metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."} 20 | ) 21 | use_fast_tokenizer: Optional[bool] = field( 22 | default=False, 23 | metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."} 24 | ) 25 | resize_vocab: Optional[bool] = field( 26 | default=False, 27 | metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."} 28 | ) 29 | split_special_tokens: Optional[bool] = field( 30 | default=False, 31 | metadata={"help": "Whether or not the special tokens should be split during the tokenization process."} 32 | ) 33 | model_revision: Optional[str] = field( 34 | default="main", 35 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} 36 | ) 37 | quantization_bit: Optional[int] = field( 38 | default=None, 39 | metadata={"help": "The number of bits to quantize the model."} 40 | ) 41 | quantization_type: Optional[Literal["fp4", "nf4"]] = field( 42 | default="nf4", 43 | metadata={"help": "Quantization data type to use in int4 training."} 44 | ) 45 | double_quantization: Optional[bool] = field( 46 | default=True, 47 | metadata={"help": "Whether or not to use double quantization in int4 training."} 48 | ) 49 | rope_scaling: Optional[Literal["linear", "dynamic"]] = field( 50 | default=None, 51 | metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."} 52 | ) 53 | flash_attn: Optional[bool] = field( 54 | default=False, 55 | metadata={"help": "Enable FlashAttention-2 for faster training."} 56 | ) 
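    # Note: `flash_attn` only takes effect when the flash-attn 2 package is installed;
    # otherwise patcher._configure_flashattn() logs a warning and the default attention is used.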
57 | shift_attn: Optional[bool] = field( 58 | default=False, 59 | metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} 60 | ) 61 | use_unsloth: Optional[bool] = field( 62 | default=False, 63 | metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."} 64 | ) 65 | disable_gradient_checkpointing: Optional[bool] = field( 66 | default=False, 67 | metadata={"help": "Whether or not to disable gradient checkpointing."} 68 | ) 69 | upcast_layernorm: Optional[bool] = field( 70 | default=False, 71 | metadata={"help": "Whether or not to upcast the layernorm weights in fp32."} 72 | ) 73 | hf_hub_token: Optional[str] = field( 74 | default=None, 75 | metadata={"help": "Auth token to log in with Hugging Face Hub."} 76 | ) 77 | ms_hub_token: Optional[str] = field( 78 | default=None, 79 | metadata={"help": "Auth token to log in with ModelScope Hub."} 80 | ) 81 | export_dir: Optional[str] = field( 82 | default=None, 83 | metadata={"help": "Path to the directory to save the exported model."} 84 | ) 85 | export_size: Optional[int] = field( 86 | default=1, 87 | metadata={"help": "The file shard size (in GB) of the exported model."} 88 | ) 89 | export_quantization_bit: Optional[int] = field( 90 | default=None, 91 | metadata={"help": "The number of bits to quantize the exported model."} 92 | ) 93 | export_quantization_dataset: Optional[str] = field( 94 | default=None, 95 | metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."} 96 | ) 97 | export_quantization_nsamples: Optional[int] = field( 98 | default=128, 99 | metadata={"help": "The number of samples used for quantization."} 100 | ) 101 | export_quantization_maxlen: Optional[int] = field( 102 | default=1024, 103 | metadata={"help": "The maximum length of the model inputs used for quantization."} 104 | ) 105 | export_legacy_format: Optional[bool] = field( 106 | default=False, 107 | metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."} 108 | ) 109 | 110 | def __post_init__(self): 111 | self.compute_dtype = None 112 | self.model_max_length = None 113 | 114 | if self.split_special_tokens and self.use_fast_tokenizer: 115 | raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") 116 | 117 | if self.adapter_name_or_path is not None: # support merging multiple lora weights 118 | self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] 119 | 120 | assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." 121 | assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization." 
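        # The adapter list above is comma-separated, e.g. a hypothetical value of
        # "saves/lora-a,saves/lora-b" is split into two adapter paths; whether they are merged
        # into the base model or resumed for training is decided later in adapter.init_adapter().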
122 | 123 | if self.export_quantization_bit is not None and self.export_quantization_dataset is None: 124 | raise ValueError("Quantization dataset is necessary for exporting.") 125 | 126 | def to_dict(self) -> Dict[str, Any]: 127 | return asdict(self) 128 | -------------------------------------------------------------------------------- /model/src/model/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import TYPE_CHECKING 3 | from transformers.integrations import is_deepspeed_zero3_enabled 4 | from peft import PeftModel, TaskType, LoraConfig, get_peft_model 5 | 6 | from ..extras.logging import get_logger 7 | from .utils import find_all_linear_modules 8 | 9 | if TYPE_CHECKING: 10 | from transformers.modeling_utils import PreTrainedModel 11 | from ..hparams.model_args import ModelArguments 12 | from ..hparams.finetuning_args import FinetuningArguments 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | def init_adapter( 18 | model: "PreTrainedModel", 19 | model_args: "ModelArguments", 20 | finetuning_args: "FinetuningArguments", 21 | is_trainable: bool 22 | ) -> "PreTrainedModel": 23 | r""" 24 | Initializes the adapters. 25 | 26 | Support full-parameter, freeze and LoRA training. 27 | 28 | Note that the trainable parameters must be cast to float32. 29 | """ 30 | 31 | if (not is_trainable) and model_args.adapter_name_or_path is None: 32 | logger.info("Adapter is not found at evaluation, load the base model.") 33 | return model 34 | 35 | if finetuning_args.finetuning_type == "full" and is_trainable: 36 | logger.info("Fine-tuning method: Full") 37 | model = model.float() 38 | 39 | if finetuning_args.finetuning_type == "freeze" and is_trainable: 40 | logger.info("Fine-tuning method: Freeze") 41 | num_layers = ( 42 | getattr(model.config, "num_hidden_layers", None) 43 | or getattr(model.config, "num_layers", None) 44 | or getattr(model.config, "n_layer", None) 45 | ) 46 | if not num_layers: 47 | raise ValueError("Current model does not support freeze tuning.") 48 | 49 | if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 50 | trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)] 51 | else: # fine-tuning the first n layers if num_layer_trainable < 0 52 | trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] 53 | 54 | trainable_layers = [] 55 | for module_name in finetuning_args.name_module_trainable: 56 | for idx in trainable_layer_ids: 57 | trainable_layers.append("{:d}.{}".format(idx, module_name)) 58 | 59 | for name, param in model.named_parameters(): 60 | if not any(trainable_layer in name for trainable_layer in trainable_layers): 61 | param.requires_grad_(False) 62 | else: 63 | param.data = param.data.to(torch.float32) 64 | 65 | if finetuning_args.finetuning_type == "lora": 66 | logger.info("Fine-tuning method: LoRA") 67 | adapter_to_resume = None 68 | 69 | if model_args.adapter_name_or_path is not None: 70 | is_mergeable = True 71 | if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable 72 | assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter." 73 | is_mergeable = False 74 | 75 | if is_deepspeed_zero3_enabled(): 76 | assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." 
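            # Merging is skipped here because ZeRO-3 shards the base weights across ranks
            # (and, as noted above, merging into a quantized backbone is unstable), so the single
            # adapter is kept separate and loaded as a resumable PeftModel instead.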
77 | is_mergeable = False 78 | 79 | if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable): 80 | adapter_to_merge = model_args.adapter_name_or_path[:-1] 81 | adapter_to_resume = model_args.adapter_name_or_path[-1] 82 | else: 83 | adapter_to_merge = model_args.adapter_name_or_path 84 | 85 | for adapter in adapter_to_merge: 86 | model = PeftModel.from_pretrained(model, adapter) 87 | model = model.merge_and_unload() 88 | 89 | if len(adapter_to_merge) > 0: 90 | logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) 91 | 92 | if adapter_to_resume is not None: # resume lora training 93 | model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable) 94 | 95 | if is_trainable and adapter_to_resume is None: # create new lora weights while training 96 | if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": 97 | target_modules = find_all_linear_modules(model) 98 | else: 99 | target_modules = finetuning_args.lora_target 100 | 101 | peft_kwargs = { 102 | "r": finetuning_args.lora_rank, 103 | "target_modules": target_modules, 104 | "lora_alpha": finetuning_args.lora_alpha, 105 | "lora_dropout": finetuning_args.lora_dropout 106 | } 107 | 108 | if model_args.use_unsloth: 109 | from unsloth import FastLlamaModel, FastMistralModel # type: ignore 110 | unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length} 111 | if getattr(model.config, "model_type", None) == "llama": 112 | model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) 113 | elif getattr(model.config, "model_type", None) == "mistral": 114 | model = FastMistralModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) 115 | else: 116 | raise NotImplementedError 117 | 118 | else: 119 | lora_config = LoraConfig( 120 | task_type=TaskType.CAUSAL_LM, 121 | inference_mode=False, 122 | modules_to_save=finetuning_args.additional_target, 123 | **peft_kwargs 124 | ) 125 | model = get_peft_model(model, lora_config) 126 | 127 | for param in filter(lambda p: p.requires_grad, model.parameters()): 128 | param.data = param.data.to(torch.float32) 129 | 130 | if model_args.adapter_name_or_path is not None: 131 | logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) 132 | 133 | return model 134 | -------------------------------------------------------------------------------- /model/src/model/loader.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Optional, Tuple 2 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer 3 | from transformers.integrations import is_deepspeed_zero3_enabled 4 | from transformers.utils.versions import require_version 5 | from trl import AutoModelForCausalLMWithValueHead 6 | 7 | from ..extras.logging import get_logger 8 | from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms 9 | from .adapter import init_adapter 10 | from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model 11 | from .utils import load_valuehead_params, register_autoclass 12 | 13 | if TYPE_CHECKING: 14 | from transformers import PreTrainedModel, PreTrainedTokenizer 15 | from ..hparams.model_args import ModelArguments 16 | from ..hparams.finetuning_args import FinetuningArguments 17 | 18 | 19 | logger = get_logger(__name__) 20 | 21 | 22 | require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36.2") 23 | 
require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3") 24 | require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") 25 | require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0") 26 | require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6") 27 | 28 | 29 | def load_model_and_tokenizer( 30 | model_args: "ModelArguments", 31 | finetuning_args: "FinetuningArguments", 32 | is_trainable: Optional[bool] = False, 33 | add_valuehead: Optional[bool] = False 34 | ) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]: 35 | r""" 36 | Loads pretrained model and tokenizer. 37 | 38 | Support both training and inference. 39 | """ 40 | 41 | try_download_model_from_ms(model_args) 42 | 43 | config_kwargs = { 44 | "trust_remote_code": True, 45 | "cache_dir": model_args.cache_dir, 46 | "revision": model_args.model_revision, 47 | "token": model_args.hf_hub_token 48 | } 49 | 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | model_args.model_name_or_path, 52 | use_fast=model_args.use_fast_tokenizer, 53 | split_special_tokens=model_args.split_special_tokens, 54 | padding_side="right", 55 | **config_kwargs 56 | ) 57 | patch_tokenizer(tokenizer) 58 | 59 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) 60 | patch_config(config, tokenizer, model_args, config_kwargs, is_trainable) 61 | 62 | model = None 63 | if is_trainable and model_args.use_unsloth: 64 | require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth") 65 | from unsloth import FastLlamaModel, FastMistralModel # type: ignore 66 | unsloth_kwargs = { 67 | "model_name": model_args.model_name_or_path, 68 | "max_seq_length": model_args.model_max_length, 69 | "dtype": model_args.compute_dtype, 70 | "load_in_4bit": model_args.quantization_bit == 4, 71 | "token": model_args.hf_hub_token, 72 | "device_map": get_current_device(), 73 | "rope_scaling": getattr(config, "rope_scaling", None) 74 | } 75 | if getattr(config, "model_type", None) == "llama": 76 | model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs) 77 | elif getattr(config, "model_type", None) == "mistral": 78 | model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs) 79 | else: 80 | logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) 81 | model_args.use_unsloth = False 82 | 83 | if model_args.adapter_name_or_path: 84 | model_args.adapter_name_or_path = None 85 | logger.warning("Unsloth does not support loading adapters.") 86 | 87 | if model is None: 88 | model = AutoModelForCausalLM.from_pretrained( 89 | model_args.model_name_or_path, 90 | config=config, 91 | torch_dtype=model_args.compute_dtype, 92 | low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), 93 | **config_kwargs 94 | ) 95 | 96 | patch_model(model, tokenizer, model_args, is_trainable) 97 | register_autoclass(config, model, tokenizer) 98 | 99 | model = init_adapter(model, model_args, finetuning_args, is_trainable) 100 | 101 | if add_valuehead: 102 | model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model) 103 | patch_valuehead_model(model) 104 | 105 | if model_args.adapter_name_or_path is not None: 106 | vhead_path = model_args.adapter_name_or_path[-1] 107 | else: 108 | vhead_path = model_args.model_name_or_path 109 | 110 | vhead_params = load_valuehead_params(vhead_path, model_args) 111 | if vhead_params is not None: 112 | model.load_state_dict(vhead_params, strict=False) 113 | logger.info("Loaded valuehead 
from checkpoint: {}".format(vhead_path)) 114 | 115 | if not is_trainable: 116 | model.requires_grad_(False) 117 | model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model 118 | model.eval() 119 | else: 120 | model.train() 121 | 122 | trainable_params, all_param = count_parameters(model) 123 | logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( 124 | trainable_params, all_param, 100 * trainable_params / all_param 125 | )) 126 | 127 | if not is_trainable: 128 | logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.") 129 | 130 | return model, tokenizer 131 | -------------------------------------------------------------------------------- /model/src/model/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import logging 5 | import datasets 6 | import transformers 7 | from typing import Any, Dict, Optional, Tuple 8 | from transformers import HfArgumentParser, Seq2SeqTrainingArguments 9 | from transformers.trainer_utils import get_last_checkpoint 10 | 11 | from ..extras.logging import get_logger 12 | from ..hparams.model_args import ModelArguments 13 | from ..hparams.data_args import DataArguments 14 | from ..hparams.evaluation_args import EvaluationArguments 15 | from ..hparams.finetuning_args import FinetuningArguments 16 | from ..hparams.generating_args import GeneratingArguments 17 | 18 | logger = get_logger(__name__) 19 | 20 | 21 | _TRAIN_ARGS = [ 22 | ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments 23 | ] 24 | _TRAIN_CLS = Tuple[ 25 | ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments 26 | ] 27 | _INFER_ARGS = [ 28 | ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments 29 | ] 30 | _INFER_CLS = Tuple[ 31 | ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments 32 | ] 33 | _EVAL_ARGS = [ 34 | ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments 35 | ] 36 | _EVAL_CLS = Tuple[ 37 | ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments 38 | ] 39 | 40 | 41 | def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: 42 | if args is not None: 43 | return parser.parse_dict(args) 44 | 45 | if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 46 | return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) 47 | 48 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 49 | return parser.parse_json_file(os.path.abspath(sys.argv[1])) 50 | 51 | (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True) 52 | 53 | if unknown_args: 54 | print(parser.format_help()) 55 | print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args)) 56 | raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args)) 57 | 58 | return (*parsed_args,) 59 | 60 | 61 | def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None: 62 | datasets.utils.logging.set_verbosity(log_level) 63 | transformers.utils.logging.set_verbosity(log_level) 64 | transformers.utils.logging.enable_default_handler() 65 | transformers.utils.logging.enable_explicit_format() 66 | 67 | 68 | def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None: 69 | if 
model_args.quantization_bit is not None: 70 | if finetuning_args.finetuning_type != "lora": 71 | raise ValueError("Quantization is only compatible with the LoRA method.") 72 | 73 | if finetuning_args.create_new_adapter: 74 | raise ValueError("Cannot create new adapter upon a quantized model.") 75 | 76 | if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: 77 | if finetuning_args.finetuning_type != "lora": 78 | raise ValueError("Multiple adapters are only available for LoRA tuning.") 79 | 80 | if model_args.quantization_bit is not None: 81 | raise ValueError("Quantized model only accepts a single adapter. Merge them first.") 82 | 83 | 84 | def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: 85 | parser = HfArgumentParser(_TRAIN_ARGS) 86 | return _parse_args(parser, args) 87 | 88 | 89 | def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: 90 | parser = HfArgumentParser(_INFER_ARGS) 91 | return _parse_args(parser, args) 92 | 93 | 94 | def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: 95 | parser = HfArgumentParser(_EVAL_ARGS) 96 | return _parse_args(parser, args) 97 | 98 | 99 | def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: 100 | model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) 101 | 102 | # Setup logging 103 | if training_args.should_log: 104 | _set_transformers_logging() 105 | 106 | # Check arguments 107 | data_args.init_for_training(training_args.seed) 108 | 109 | if finetuning_args.stage != "pt" and data_args.template is None: 110 | raise ValueError("Please specify which `template` to use.") 111 | 112 | if finetuning_args.stage != "sft" and training_args.predict_with_generate: 113 | raise ValueError("`predict_with_generate` cannot be set as True except SFT.") 114 | 115 | if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: 116 | raise ValueError("Please enable `predict_with_generate` to save model predictions.") 117 | 118 | if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end: 119 | raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.") 120 | 121 | if finetuning_args.stage == "ppo" and not training_args.do_train: 122 | raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.") 123 | 124 | if finetuning_args.stage in ["rm", "dpo"] and (not all([data_attr.ranking for data_attr in data_args.dataset_list])): 125 | raise ValueError("Please use ranked datasets for reward modeling or DPO training.") 126 | 127 | if finetuning_args.stage == "ppo" and model_args.shift_attn: 128 | raise ValueError("PPO training is incompatible with S^2-Attn.") 129 | 130 | if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth: 131 | raise ValueError("Unsloth does not support lora reward model.") 132 | 133 | if training_args.max_steps == -1 and data_args.streaming: 134 | raise ValueError("Please specify `max_steps` in streaming mode.") 135 | 136 | if training_args.do_train and training_args.predict_with_generate: 137 | raise ValueError("`predict_with_generate` cannot be set as True while training.") 138 | 139 | if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None: 140 | raise ValueError("Please specify `lora_target` in LoRA training.") 141 | 142 | 
_verify_model_args(model_args, finetuning_args) 143 | 144 | if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): 145 | logger.warning("We recommend enable `upcast_layernorm` in quantized training.") 146 | 147 | if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): 148 | logger.warning("We recommend enable mixed precision training.") 149 | 150 | if (not training_args.do_train) and model_args.quantization_bit is not None: 151 | logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") 152 | 153 | if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None: 154 | logger.warning("Specify `ref_model` for computing rewards at evaluation.") 155 | 156 | # postprocess training_args 157 | if ( 158 | training_args.local_rank != -1 159 | and training_args.ddp_find_unused_parameters is None 160 | and finetuning_args.finetuning_type == "lora" 161 | ): 162 | logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") 163 | training_args_dict = training_args.to_dict() 164 | training_args_dict.update(dict(ddp_find_unused_parameters=False)) 165 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 166 | 167 | if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]: 168 | can_resume_from_checkpoint = False 169 | training_args.resume_from_checkpoint = None 170 | else: 171 | can_resume_from_checkpoint = True 172 | 173 | if ( 174 | training_args.resume_from_checkpoint is None 175 | and training_args.do_train 176 | and os.path.isdir(training_args.output_dir) 177 | and not training_args.overwrite_output_dir 178 | and can_resume_from_checkpoint 179 | ): 180 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 181 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 182 | raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.") 183 | 184 | if last_checkpoint is not None: 185 | training_args_dict = training_args.to_dict() 186 | training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint)) 187 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 188 | logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format( 189 | training_args.resume_from_checkpoint 190 | )) 191 | 192 | if ( 193 | finetuning_args.stage in ["rm", "ppo"] 194 | and finetuning_args.finetuning_type == "lora" 195 | and training_args.resume_from_checkpoint is not None 196 | ): 197 | logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( 198 | training_args.resume_from_checkpoint 199 | )) 200 | 201 | # postprocess model_args 202 | model_args.compute_dtype = ( 203 | torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None) 204 | ) 205 | model_args.model_max_length = data_args.cutoff_len 206 | 207 | # Log on each process the small summary: 208 | logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format( 209 | training_args.local_rank, training_args.device, training_args.n_gpu, 210 | bool(training_args.local_rank != -1), str(model_args.compute_dtype) 211 | )) 212 | logger.info(f"Training/evaluation parameters {training_args}") 213 | 214 | # Set seed before initializing model. 
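    # (transformers.set_seed seeds Python's `random`, NumPy and PyTorch, including CUDA, for reproducibility.)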
215 | transformers.set_seed(training_args.seed) 216 | 217 | return model_args, data_args, training_args, finetuning_args, generating_args 218 | 219 | 220 | def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: 221 | model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) 222 | _set_transformers_logging() 223 | 224 | if data_args.template is None: 225 | raise ValueError("Please specify which `template` to use.") 226 | 227 | _verify_model_args(model_args, finetuning_args) 228 | 229 | return model_args, data_args, finetuning_args, generating_args 230 | 231 | 232 | def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: 233 | model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args) 234 | _set_transformers_logging() 235 | 236 | if data_args.template is None: 237 | raise ValueError("Please specify which `template` to use.") 238 | 239 | _verify_model_args(model_args, finetuning_args) 240 | 241 | transformers.set_seed(eval_args.seed) 242 | 243 | return model_args, data_args, eval_args, finetuning_args 244 | -------------------------------------------------------------------------------- /model/src/model/patcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | import random 5 | from types import MethodType 6 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple 7 | from datasets import load_dataset 8 | 9 | from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase 10 | from transformers.integrations import is_deepspeed_zero3_enabled 11 | from transformers.utils.versions import require_version 12 | 13 | from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES 14 | from ..extras.logging import get_logger 15 | from ..extras.misc import get_current_device, infer_optim_dtype 16 | from ..extras.packages import is_flash_attn2_available 17 | 18 | if TYPE_CHECKING: 19 | from transformers import PretrainedConfig, PreTrainedTokenizer 20 | from trl import AutoModelForCausalLMWithValueHead 21 | from ..hparams.model_args import ModelArguments 22 | 23 | 24 | logger = get_logger(__name__) 25 | SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama 26 | 27 | 28 | def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int): 29 | embedding_dim = embed_weight.size(1) 30 | avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) 31 | noise_weight = torch.empty_like(avg_weight[-num_new_tokens:]) 32 | noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) 33 | embed_weight[-num_new_tokens:] = avg_weight + noise_weight 34 | 35 | 36 | def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: 37 | r""" 38 | Resize token embeddings. 
39 | """ 40 | current_embedding_size = model.get_input_embeddings().weight.size(0) 41 | if len(tokenizer) > current_embedding_size: 42 | if not isinstance(model.get_output_embeddings(), torch.nn.Linear): 43 | logger.warning("Current model does not support resizing token embeddings.") 44 | return 45 | 46 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) 47 | new_embedding_size = model.get_input_embeddings().weight.size(0) 48 | num_new_tokens = new_embedding_size - current_embedding_size 49 | _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) 50 | _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) 51 | 52 | logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) 53 | 54 | 55 | def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]: 56 | r""" 57 | Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133 58 | TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600 59 | """ 60 | if os.path.isfile(model_args.export_quantization_dataset): 61 | data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) 62 | data_files = model_args.export_quantization_dataset 63 | else: 64 | data_path = model_args.export_quantization_dataset 65 | data_files = None 66 | 67 | dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir) 68 | maxlen = model_args.export_quantization_maxlen 69 | 70 | samples = [] 71 | for _ in range(model_args.export_quantization_nsamples): 72 | while True: 73 | sample_idx = random.randint(0, len(dataset) - 1) 74 | sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") 75 | if sample["input_ids"].size(1) >= maxlen: 76 | break # TODO: fix large maxlen 77 | 78 | word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) 79 | input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] 80 | samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True)) 81 | 82 | return samples 83 | 84 | 85 | def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: 86 | if not hasattr(config, "rope_scaling"): 87 | logger.warning("Current model does not support RoPE scaling.") 88 | return 89 | 90 | if is_trainable: 91 | if model_args.rope_scaling == "dynamic": 92 | logger.warning( 93 | "Dynamic NTK scaling may not work well with fine-tuning. " 94 | "See: https://github.com/huggingface/transformers/pull/24653" 95 | ) 96 | 97 | current_max_length = getattr(config, "max_position_embeddings", None) 98 | if current_max_length and model_args.model_max_length > current_max_length: 99 | scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) 100 | else: 101 | logger.warning("Input length is smaller than max length. 
Consider increase input length.") 102 | scaling_factor = 1.0 103 | else: 104 | scaling_factor = 2.0 105 | 106 | setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) 107 | logger.info("Using {} scaling strategy and setting scaling factor to {}".format( 108 | model_args.rope_scaling, scaling_factor 109 | )) 110 | 111 | 112 | def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None: 113 | if not is_flash_attn2_available(): 114 | logger.warning("FlashAttention2 is not installed.") 115 | return 116 | 117 | config_kwargs["use_flash_attention_2"] = True 118 | logger.info("Using FlashAttention-2 for faster training and inference.") 119 | 120 | 121 | def _configure_longlora(config: "PretrainedConfig") -> None: 122 | if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: 123 | setattr(config, "group_size_ratio", 0.25) 124 | logger.info("Using shift short attention with group_size_ratio=1/4.") 125 | else: 126 | logger.warning("Current model does not support shift short attention.") 127 | 128 | 129 | def _configure_quantization( 130 | config: "PretrainedConfig", 131 | tokenizer: "PreTrainedTokenizer", 132 | model_args: "ModelArguments", 133 | config_kwargs: Dict[str, Any] 134 | ) -> None: 135 | r""" 136 | Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training) 137 | """ 138 | if getattr(config, "quantization_config", None): # gptq 139 | if is_deepspeed_zero3_enabled(): 140 | raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") 141 | 142 | config_kwargs["device_map"] = {"": get_current_device()} 143 | quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) 144 | if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4: 145 | quantization_config["use_exllama"] = False # disable exllama 146 | logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1))) 147 | 148 | elif model_args.export_quantization_bit is not None: # auto-gptq 149 | require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0") 150 | require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") 151 | from accelerate.utils import get_max_memory 152 | 153 | if getattr(config, "model_type", None) == "chatglm": 154 | raise ValueError("ChatGLM model is not supported.") 155 | 156 | config_kwargs["quantization_config"] = GPTQConfig( 157 | bits=model_args.export_quantization_bit, 158 | tokenizer=tokenizer, 159 | dataset=_get_quantization_dataset(tokenizer, model_args) 160 | ) 161 | config_kwargs["device_map"] = "auto" 162 | config_kwargs["max_memory"] = get_max_memory() 163 | logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit)) 164 | 165 | elif model_args.quantization_bit is not None: # bnb 166 | if is_deepspeed_zero3_enabled(): 167 | raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.") 168 | 169 | if model_args.quantization_bit == 8: 170 | require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") 171 | config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) 172 | 173 | elif model_args.quantization_bit == 4: 174 | require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") 175 | config_kwargs["quantization_config"] = BitsAndBytesConfig( 176 | load_in_4bit=True, 177 | bnb_4bit_compute_dtype=model_args.compute_dtype, 178 | bnb_4bit_use_double_quant=model_args.double_quantization, 179 | 
bnb_4bit_quant_type=model_args.quantization_type 180 | ) 181 | 182 | config_kwargs["device_map"] = {"": get_current_device()} 183 | logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) 184 | 185 | 186 | def _prepare_model_for_training( 187 | model: "PreTrainedModel", 188 | model_args: "ModelArguments", 189 | output_layer_name: Optional[str] = "lm_head" 190 | ) -> None: 191 | r""" 192 | Includes: 193 | (1) cast the layernorm in fp32 194 | (2) make output embedding layer require grads 195 | (3) add the upcasting of the lm_head in fp32 196 | Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72 197 | """ 198 | if model_args.upcast_layernorm: 199 | for name, param in model.named_parameters(): 200 | if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): 201 | param.data = param.data.to(torch.float32) 202 | logger.info("Upcasting layernorm weights in float32.") 203 | 204 | if not model_args.disable_gradient_checkpointing: 205 | if not getattr(model, "supports_gradient_checkpointing", False): 206 | logger.warning("Current model does not support gradient checkpointing.") 207 | else: 208 | model.enable_input_require_grads() 209 | model.gradient_checkpointing_enable() 210 | model.config.use_cache = False # turn off when gradient checkpointing is enabled 211 | logger.info("Gradient checkpointing enabled.") 212 | 213 | if hasattr(model, output_layer_name): 214 | def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor): 215 | return output.to(torch.float32) 216 | 217 | output_layer = getattr(model, output_layer_name) 218 | if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: 219 | output_layer.register_forward_hook(fp32_forward_post_hook) 220 | 221 | 222 | def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None: 223 | if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__): 224 | tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) 225 | 226 | 227 | def patch_config( 228 | config: "PretrainedConfig", 229 | tokenizer: "PreTrainedTokenizer", 230 | model_args: "ModelArguments", 231 | config_kwargs: Dict[str, Any], 232 | is_trainable: bool 233 | ) -> None: 234 | if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 235 | model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) 236 | 237 | if getattr(config, "model_type", None) == "qwen": 238 | for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: 239 | setattr(config, dtype_name, model_args.compute_dtype == dtype) 240 | 241 | if model_args.rope_scaling is not None: 242 | _configure_rope(config, model_args, is_trainable) 243 | 244 | if model_args.flash_attn: 245 | _configure_flashattn(config_kwargs) 246 | 247 | if is_trainable and model_args.shift_attn: 248 | _configure_longlora(config) 249 | 250 | _configure_quantization(config, tokenizer, model_args, config_kwargs) 251 | 252 | 253 | def patch_model( 254 | model: "PreTrainedModel", 255 | tokenizer: "PreTrainedTokenizer", 256 | model_args: "ModelArguments", 257 | is_trainable: bool 258 | ) -> None: 259 | if "GenerationMixin" not in str(model.generate.__func__): 260 | model.generate = MethodType(PreTrainedModel.generate, model) 261 | 262 | if getattr(model.config, "model_type", None) == "chatglm": 263 | setattr(model, "lm_head", model.transformer.output_layer) 264 | setattr(model, "_keys_to_ignore_on_save", 
["lm_head.weight"]) 265 | 266 | if model_args.resize_vocab: 267 | if is_deepspeed_zero3_enabled(): 268 | raise ValueError("DeepSpeed ZeRO-3 is incompatible with vocab resizing.") 269 | 270 | _resize_embedding_layer(model, tokenizer) 271 | 272 | if is_trainable: 273 | _prepare_model_for_training(model, model_args) 274 | 275 | 276 | def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: 277 | def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: 278 | if isinstance(self.pretrained_model, PreTrainedModel): 279 | self.pretrained_model.tie_weights() 280 | 281 | def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: 282 | if isinstance(self.pretrained_model, PreTrainedModel): 283 | return self.pretrained_model.get_input_embeddings() 284 | 285 | ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] 286 | setattr(model, "_keys_to_ignore_on_save", ignore_modules) 287 | setattr(model, "tie_weights", MethodType(tie_weights, model)) 288 | setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model)) 289 | -------------------------------------------------------------------------------- /model/src/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import inspect 3 | from typing import TYPE_CHECKING, Any, Dict, List 4 | from transformers import PreTrainedModel 5 | from transformers.utils import cached_file 6 | 7 | from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME 8 | from ..extras.logging import get_logger 9 | from ..extras.misc import get_current_device 10 | 11 | if TYPE_CHECKING: 12 | from transformers import PretrainedConfig, PreTrainedTokenizer 13 | from ..hparams.model_args import ModelArguments 14 | from ..hparams.data_args import DataArguments 15 | from ..hparams.finetuning_args import FinetuningArguments 16 | 17 | 18 | logger = get_logger(__name__) 19 | 20 | 21 | def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": 22 | r""" 23 | Dispatches a pre-trained model to GPUs with balanced memory when the GPU is available. 24 | Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570 25 | """ 26 | if getattr(model, "quantization_method", None): # already set on current device 27 | return model 28 | 29 | if ( 30 | torch.cuda.device_count() > 1 31 | and isinstance(model, PreTrainedModel) 32 | and model._no_split_modules is not None 33 | and model.config.model_type != "chatglm" 34 | ): 35 | from accelerate import dispatch_model 36 | from accelerate.utils import infer_auto_device_map, get_balanced_memory 37 | 38 | kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")} 39 | max_memory = get_balanced_memory(model, **kwargs) 40 | # Make sure tied weights are tied before creating the device map. 41 | model.tie_weights() 42 | device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) 43 | device_map_kwargs = {"device_map": device_map} 44 | if "skip_keys" in inspect.signature(dispatch_model).parameters: 45 | device_map_kwargs["skip_keys"] = model._skip_keys_device_placement 46 | return dispatch_model(model, **device_map_kwargs) 47 | else: 48 | return model.to(device=get_current_device()) 49 | 50 | 51 | def find_all_linear_modules(model: "PreTrainedModel") -> List[str]: 52 | r""" 53 | Finds all available modules to apply lora. 
54 | """ 55 | quantization_method = getattr(model, "quantization_method", None) 56 | if quantization_method is None: 57 | linear_cls = torch.nn.Linear 58 | elif quantization_method == "bitsandbytes": 59 | import bitsandbytes as bnb 60 | linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt 61 | else: 62 | raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method)) 63 | 64 | output_layer_names = ["lm_head"] 65 | if model.config.model_type == "chatglm": 66 | output_layer_names.append("output_layer") 67 | 68 | module_names = set() 69 | for name, module in model.named_modules(): 70 | if ( 71 | isinstance(module, linear_cls) 72 | and not any([output_layer in name for output_layer in output_layer_names]) 73 | ): 74 | module_names.add(name.split(".")[-1]) 75 | 76 | logger.info("Found linear modules: {}".format(",".join(module_names))) 77 | return list(module_names) 78 | 79 | 80 | def get_modelcard_args( 81 | model_args: "ModelArguments", 82 | data_args: "DataArguments", 83 | finetuning_args: "FinetuningArguments" 84 | ) -> Dict[str, Any]: 85 | return { 86 | "tasks": "text-generation", 87 | "license": "other", 88 | "finetuned_from": model_args.model_name_or_path, 89 | "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")], 90 | "tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else []) 91 | } 92 | 93 | 94 | def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: 95 | r""" 96 | Loads value head parameters from Hugging Face Hub or local disk. 97 | 98 | Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. 99 | """ 100 | kwargs = { 101 | "path_or_repo_id": path_or_repo_id, 102 | "cache_dir": model_args.cache_dir, 103 | "token": model_args.hf_hub_token 104 | } 105 | 106 | try: 107 | from safetensors import safe_open 108 | vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs) 109 | with safe_open(vhead_file, framework="pt", device="cpu") as f: 110 | return {key: f.get_tensor(key) for key in f.keys()} 111 | except Exception as err: 112 | logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err))) 113 | 114 | try: 115 | vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs) 116 | return torch.load(vhead_file, map_location="cpu") 117 | except Exception as err: 118 | logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err))) 119 | 120 | logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id)) 121 | logger.info("Ignore these messages if you are not resuming the training of a value head model.") 122 | return None 123 | 124 | 125 | def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"): 126 | if "AutoConfig" in getattr(config, "auto_map", {}): 127 | config.__class__.register_for_auto_class() 128 | if "AutoModelForCausalLM" in getattr(config, "auto_map", {}): 129 | model.__class__.register_for_auto_class() 130 | if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): 131 | tokenizer.__class__.register_for_auto_class() 132 | -------------------------------------------------------------------------------- /model/src/train/dpo/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, List, Sequence, Tuple 4 | from 
transformers import DataCollatorForSeq2Seq 5 | 6 | 7 | @dataclass 8 | class DPODataCollatorWithPadding(DataCollatorForSeq2Seq): 9 | r""" 10 | Data collator for pairwise data. 11 | """ 12 | 13 | def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor: 14 | padded_labels = [] 15 | for feature, (prompt_len, answer_len) in zip(batch, positions): 16 | if self.tokenizer.padding_side == "left": 17 | start, end = feature.size(0) - answer_len, feature.size(0) 18 | else: 19 | start, end = prompt_len, prompt_len + answer_len 20 | padded_tensor = self.label_pad_token_id * torch.ones_like(feature) 21 | padded_tensor[start:end] = feature[start:end] 22 | padded_labels.append(padded_tensor) 23 | return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory 24 | 25 | def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: 26 | r""" 27 | Pads batched data to the longest sequence in the batch. 28 | 29 | We generate 2 * n examples where the first n examples represent chosen examples and 30 | the last n examples represent rejected examples. 31 | """ 32 | concatenated_features = [] 33 | label_positions = [] 34 | for key in ("chosen_ids", "rejected_ids"): 35 | for feature in features: 36 | prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key]) 37 | concatenated_features.append({ 38 | "input_ids": feature["prompt_ids"] + feature[key], 39 | "attention_mask": [1] * (prompt_len + answer_len) 40 | }) 41 | label_positions.append((prompt_len, answer_len)) 42 | 43 | batch = self.tokenizer.pad( 44 | concatenated_features, 45 | padding=self.padding, 46 | max_length=self.max_length, 47 | pad_to_multiple_of=self.pad_to_multiple_of, 48 | return_tensors=self.return_tensors, 49 | ) 50 | batch["labels"] = self._pad_labels(batch["input_ids"], label_positions) 51 | return batch 52 | -------------------------------------------------------------------------------- /model/src/train/dpo/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union 4 | from transformers import BatchEncoding, Trainer 5 | from trl import DPOTrainer 6 | from trl.trainer.utils import disable_dropout_in_model 7 | 8 | from ...extras.constants import IGNORE_INDEX 9 | 10 | if TYPE_CHECKING: 11 | from transformers import PreTrainedModel 12 | 13 | 14 | class CustomDPOTrainer(DPOTrainer): 15 | 16 | def __init__( 17 | self, 18 | beta: float, 19 | loss_type: Literal["sigmoid", "hinge", "ipo", "kto"], 20 | ftx_gamma: float, 21 | model: Union["PreTrainedModel", torch.nn.Module], 22 | ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, 23 | disable_dropout: Optional[bool] = True, 24 | **kwargs 25 | ): 26 | if disable_dropout: 27 | disable_dropout_in_model(model) 28 | if ref_model is not None: 29 | disable_dropout_in_model(ref_model) 30 | 31 | self.use_dpo_data_collator = True # hack to avoid warning 32 | self.generate_during_eval = False # disable at evaluation 33 | self.label_pad_token_id = IGNORE_INDEX 34 | self.padding_value = 0 35 | self.is_encoder_decoder = model.config.is_encoder_decoder 36 | self.precompute_ref_log_probs = False 37 | self._precomputed_train_ref_log_probs = False 38 | self._precomputed_eval_ref_log_probs = False 39 | self._peft_has_been_casted_to_bf16 = False 40 | 41 | self.ref_model = ref_model 42 | self.beta = beta 43 | self.label_smoothing = 0 44 | self.loss_type = 
loss_type 45 | self.ftx_gamma = ftx_gamma 46 | self._stored_metrics = defaultdict(lambda: defaultdict(list)) 47 | 48 | Trainer.__init__(self, model=model, **kwargs) 49 | if not hasattr(self, "accelerator"): 50 | raise AttributeError("Please update `transformers`.") 51 | 52 | if ref_model is not None: 53 | if self.is_deepspeed_enabled: 54 | if not ( 55 | getattr(ref_model, "is_loaded_in_8bit", False) 56 | or getattr(ref_model, "is_loaded_in_4bit", False) 57 | ): # quantized models are already set on the correct device 58 | self.ref_model = self._prepare_deepspeed(self.ref_model) 59 | else: 60 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) 61 | 62 | def sft_loss( 63 | self, 64 | chosen_logits: torch.FloatTensor, 65 | chosen_labels: torch.LongTensor 66 | ) -> torch.Tensor: 67 | r""" 68 | Computes supervised cross-entropy loss of given labels under the given logits. 69 | 70 | Returns: 71 | A tensor of shape (batch_size,) containing the cross-entropy loss of each samples. 72 | """ 73 | all_logps = self.get_batch_logps( 74 | chosen_logits, 75 | chosen_labels, 76 | average_log_prob=True 77 | ) 78 | return -all_logps 79 | 80 | def concatenated_forward( 81 | self, 82 | model: "PreTrainedModel", 83 | batch: Dict[str, torch.Tensor] 84 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 85 | batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error 86 | 87 | all_logits = model( 88 | input_ids=batch_copied["input_ids"], 89 | attention_mask=batch_copied["attention_mask"], 90 | return_dict=True 91 | ).logits.to(torch.float32) 92 | 93 | all_logps = self.get_batch_logps( 94 | all_logits, 95 | batch["labels"], 96 | average_log_prob=False 97 | ) 98 | batch_size = batch["input_ids"].size(0) // 2 99 | chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) 100 | chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) 101 | return chosen_logps, rejected_logps, chosen_logits, rejected_logits 102 | 103 | def get_batch_loss_metrics( 104 | self, 105 | model: "PreTrainedModel", 106 | batch: Dict[str, torch.Tensor], 107 | train_eval: Optional[Literal["train", "eval"]] = "train" 108 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 109 | r""" 110 | Computes the DPO loss and other metrics for the given batch of inputs for train or test. 
111 | """ 112 | metrics = {} 113 | ( 114 | policy_chosen_logps, 115 | policy_rejected_logps, 116 | policy_chosen_logits, 117 | policy_rejected_logits, 118 | ) = self.concatenated_forward(model, batch) 119 | with torch.no_grad(): 120 | if self.ref_model is None: 121 | with self.accelerator.unwrap_model(self.model).disable_adapter(): 122 | ( 123 | reference_chosen_logps, 124 | reference_rejected_logps, 125 | _, 126 | _, 127 | ) = self.concatenated_forward(self.model, batch) 128 | else: 129 | ( 130 | reference_chosen_logps, 131 | reference_rejected_logps, 132 | _, 133 | _, 134 | ) = self.concatenated_forward(self.ref_model, batch) 135 | 136 | losses, chosen_rewards, rejected_rewards = self.dpo_loss( 137 | policy_chosen_logps, 138 | policy_rejected_logps, 139 | reference_chosen_logps, 140 | reference_rejected_logps, 141 | ) 142 | if self.ftx_gamma > 1e-6: 143 | batch_size = batch["input_ids"].size(0) // 2 144 | chosen_labels, _ = batch["labels"].split(batch_size, dim=0) 145 | losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels) 146 | 147 | reward_accuracies = (chosen_rewards > rejected_rewards).float() 148 | 149 | prefix = "eval_" if train_eval == "eval" else "" 150 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean() 151 | metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean() 152 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean() 153 | metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean() 154 | metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean() 155 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean() 156 | metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean() 157 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean() 158 | 159 | return losses.mean(), metrics 160 | -------------------------------------------------------------------------------- /model/src/train/dpo/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py 2 | 3 | from typing import TYPE_CHECKING, Optional, List 4 | from transformers import Seq2SeqTrainingArguments 5 | 6 | from ...data.loader import get_dataset 7 | from ...data.utils import split_dataset 8 | from ...extras.constants import IGNORE_INDEX 9 | from ...extras.ploting import plot_loss 10 | from ...hparams.model_args import ModelArguments 11 | from ...model.loader import load_model_and_tokenizer 12 | from ...train.dpo.collator import DPODataCollatorWithPadding 13 | from ...train.dpo.trainer import CustomDPOTrainer 14 | from ...train.utils import create_modelcard_and_push, create_ref_model 15 | 16 | if TYPE_CHECKING: 17 | from transformers import TrainerCallback 18 | from ...hparams import DataArguments, FinetuningArguments 19 | 20 | 21 | def run_dpo( 22 | model_args: "ModelArguments", 23 | data_args: "DataArguments", 24 | training_args: "Seq2SeqTrainingArguments", 25 | finetuning_args: "FinetuningArguments", 26 | callbacks: Optional[List["TrainerCallback"]] = None 27 | ): 28 | 29 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) 30 | dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm") 31 | data_collator = DPODataCollatorWithPadding( 32 | tokenizer=tokenizer, 33 | pad_to_multiple_of=8, 34 | 
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 35 | ) 36 | 37 | # Create reference model 38 | if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself 39 | ref_model = model 40 | else: 41 | ref_model = create_ref_model(model_args, finetuning_args) 42 | 43 | # Update arguments 44 | training_args_dict = training_args.to_dict() 45 | training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset 46 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 47 | 48 | # Initialize our Trainer 49 | trainer = CustomDPOTrainer( 50 | beta=finetuning_args.dpo_beta, 51 | loss_type=finetuning_args.dpo_loss, 52 | ftx_gamma=finetuning_args.dpo_ftx, 53 | model=model, 54 | ref_model=ref_model, 55 | args=training_args, 56 | tokenizer=tokenizer, 57 | data_collator=data_collator, 58 | callbacks=callbacks, 59 | **split_dataset(dataset, data_args, training_args) 60 | ) 61 | 62 | # Training 63 | if training_args.do_train: 64 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 65 | trainer.save_model() 66 | trainer.log_metrics("train", train_result.metrics) 67 | trainer.save_metrics("train", train_result.metrics) 68 | trainer.save_state() 69 | if trainer.is_world_process_zero() and finetuning_args.plot_loss: 70 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 71 | 72 | # Evaluation 73 | if training_args.do_eval: 74 | metrics = trainer.evaluate(metric_key_prefix="eval") 75 | if id(model) == id(ref_model): # unable to compute rewards without a reference model 76 | remove_keys = [key for key in metrics.keys() if "rewards" in key] 77 | for key in remove_keys: 78 | metrics.pop(key) 79 | trainer.log_metrics("eval", metrics) 80 | trainer.save_metrics("eval", metrics) 81 | 82 | # Create model card 83 | create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) 84 | -------------------------------------------------------------------------------- /model/src/train/sft/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dataclasses import dataclass 3 | from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union 4 | 5 | from ...extras.constants import IGNORE_INDEX 6 | from ...extras.packages import ( 7 | is_jieba_available, is_nltk_available, is_rouge_available 8 | ) 9 | 10 | if TYPE_CHECKING: 11 | from transformers.tokenization_utils import PreTrainedTokenizer 12 | 13 | if is_jieba_available(): 14 | import jieba 15 | 16 | if is_nltk_available(): 17 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 18 | 19 | if is_rouge_available(): 20 | from rouge_chinese import Rouge 21 | 22 | 23 | @dataclass 24 | class ComputeMetrics: 25 | r""" 26 | Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. 27 | """ 28 | 29 | tokenizer: "PreTrainedTokenizer" 30 | 31 | def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: 32 | r""" 33 | Uses the model predictions to compute metrics. 
34 | """ 35 | preds, labels = eval_preds 36 | score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} 37 | 38 | preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) 39 | labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) 40 | 41 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) 42 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 43 | 44 | for pred, label in zip(decoded_preds, decoded_labels): 45 | hypothesis = list(jieba.cut(pred)) 46 | reference = list(jieba.cut(label)) 47 | 48 | if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: 49 | result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} 50 | else: 51 | rouge = Rouge() 52 | scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) 53 | result = scores[0] 54 | 55 | for k, v in result.items(): 56 | score_dict[k].append(round(v["f"] * 100, 4)) 57 | 58 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 59 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 60 | 61 | return {k: float(np.mean(v)) for k, v in score_dict.items()} 62 | -------------------------------------------------------------------------------- /model/src/train/sft/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union 7 | 8 | import tqdm 9 | from transformers import Seq2SeqTrainer 10 | 11 | from ...extras.constants import IGNORE_INDEX 12 | from ...extras.logging import get_logger 13 | 14 | if TYPE_CHECKING: 15 | from transformers.trainer import PredictionOutput 16 | 17 | 18 | logger = get_logger(__name__) 19 | 20 | 21 | class CustomSeq2SeqTrainer(Seq2SeqTrainer): 22 | r""" 23 | Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE. 24 | """ 25 | 26 | def prediction_step( 27 | self, 28 | model: nn.Module, 29 | inputs: Dict[str, Union[torch.Tensor, Any]], 30 | prediction_loss_only: bool, 31 | ignore_keys: Optional[List[str]] = None, 32 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 33 | r""" 34 | Removes the prompt part in the generated tokens. 35 | 36 | Subclass and override to inject custom behavior. 37 | """ 38 | labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels 39 | if self.args.predict_with_generate: 40 | assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." 
41 | prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) 42 | if prompt_len > label_len: 43 | inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) 44 | if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility) 45 | inputs["labels"] = inputs["labels"][:, :prompt_len] 46 | 47 | loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated) 48 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 49 | ) 50 | if generated_tokens is not None and self.args.predict_with_generate: 51 | generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id 52 | generated_tokens = generated_tokens.contiguous() 53 | 54 | return loss, generated_tokens, labels 55 | 56 | def _pad_tensors_to_target_len( 57 | self, 58 | src_tensor: torch.Tensor, 59 | tgt_tensor: torch.Tensor 60 | ) -> torch.Tensor: 61 | r""" 62 | Pads the tensor to the same length as the target tensor. 63 | """ 64 | assert self.tokenizer.pad_token_id is not None, "Pad token is required." 65 | padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) 66 | padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding 67 | return padded_tensor.contiguous() # in contiguous memory 68 | 69 | def save_predictions( 70 | self, 71 | predict_results: "PredictionOutput", 72 | sentence_num: int = 1 73 | ) -> None: 74 | r""" 75 | Saves model predictions to `output_dir`. 76 | 77 | A custom behavior that not contained in Seq2SeqTrainer. 78 | """ 79 | if not self.is_world_process_zero(): 80 | return 81 | 82 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") 83 | logger.info(f"Saving prediction results to {output_prediction_file}") 84 | 85 | labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, 86 | self.tokenizer.pad_token_id) 87 | preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, 88 | self.tokenizer.pad_token_id) 89 | 90 | for i in range(len(preds)): 91 | pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0] 92 | if len(pad_len): 93 | preds[i] = np.concatenate((preds[i][pad_len[0]:], preds[i][:pad_len[0]]), 94 | axis=-1) # move pad token to last 95 | 96 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, 97 | clean_up_tokenization_spaces=False) 98 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) 99 | 100 | if sentence_num != 1: 101 | decoded_labels = [element for element in decoded_labels for _ in range(sentence_num)] 102 | 103 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 104 | res: List[str] = [] 105 | for label, pred in zip(decoded_labels, decoded_preds): 106 | res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) 107 | writer.write("\n".join(res)) -------------------------------------------------------------------------------- /model/src/train/sft/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py 2 | 3 | from typing import TYPE_CHECKING, Optional, List 4 | from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments 5 | 6 | from ...data.loader import get_dataset 7 | from ...data.preprocess import 
preprocess_dataset 8 | from ...data.utils import split_dataset 9 | from ...extras.constants import IGNORE_INDEX 10 | from ...extras.misc import get_logits_processor 11 | from ...extras.ploting import plot_loss 12 | from ...model.loader import load_model_and_tokenizer 13 | from ...train.sft.metric import ComputeMetrics 14 | from ...train.sft.trainer import CustomSeq2SeqTrainer 15 | from ...train.utils import create_modelcard_and_push 16 | 17 | if TYPE_CHECKING: 18 | from transformers import TrainerCallback 19 | from ...hparams.model_args import ModelArguments 20 | from ...hparams.data_args import DataArguments 21 | from ...hparams.evaluation_args import EvaluationArguments 22 | from ...hparams.finetuning_args import FinetuningArguments 23 | from ...hparams.generating_args import GeneratingArguments 24 | 25 | 26 | def run_sft( 27 | model_args: "ModelArguments", 28 | data_args: "DataArguments", 29 | training_args: "Seq2SeqTrainingArguments", 30 | finetuning_args: "FinetuningArguments", 31 | generating_args: "GeneratingArguments", 32 | callbacks: Optional[List["TrainerCallback"]] = None 33 | ): 34 | dataset = get_dataset(model_args, data_args) 35 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train) 36 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft") 37 | 38 | if training_args.predict_with_generate: 39 | tokenizer.padding_side = "left" # use left-padding in generation 40 | 41 | if getattr(model, "is_quantized", False) and not training_args.do_train: 42 | setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction 43 | 44 | data_collator = DataCollatorForSeq2Seq( 45 | tokenizer=tokenizer, 46 | pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention 47 | label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 48 | ) 49 | 50 | # Override the decoding parameters of Seq2SeqTrainer 51 | training_args_dict = training_args.to_dict() 52 | training_args_dict.update(dict( 53 | generation_max_length=training_args.generation_max_length or data_args.cutoff_len, 54 | generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams 55 | )) 56 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 57 | 58 | # Initialize our Trainer 59 | trainer = CustomSeq2SeqTrainer( 60 | model=model, 61 | args=training_args, 62 | tokenizer=tokenizer, 63 | data_collator=data_collator, 64 | callbacks=callbacks, 65 | compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, 66 | **split_dataset(dataset, data_args, training_args) 67 | ) 68 | 69 | # Keyword arguments for `model.generate` 70 | gen_kwargs = generating_args.to_dict() 71 | gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids 72 | gen_kwargs["pad_token_id"] = tokenizer.pad_token_id 73 | gen_kwargs["logits_processor"] = get_logits_processor() 74 | 75 | # Training 76 | if training_args.do_train: 77 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 78 | trainer.save_model() 79 | trainer.log_metrics("train", train_result.metrics) 80 | trainer.save_metrics("train", train_result.metrics) 81 | trainer.save_state() 82 | if trainer.is_world_process_zero() and finetuning_args.plot_loss: 83 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 84 | 85 | # Evaluation 86 | if training_args.do_eval: 87 | metrics = 
trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) 88 | if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled 89 | metrics.pop("eval_loss", None) 90 | trainer.log_metrics("eval", metrics) 91 | trainer.save_metrics("eval", metrics) 92 | 93 | # Predict 94 | if training_args.do_predict: 95 | predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) 96 | if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled 97 | predict_results.metrics.pop("predict_loss", None) 98 | trainer.log_metrics("predict", predict_results.metrics) 99 | trainer.save_metrics("predict", predict_results.metrics) 100 | trainer.save_predictions(predict_results, sentence_num=gen_kwargs["num_return_sequences"]) 101 | 102 | # Create model card 103 | create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) -------------------------------------------------------------------------------- /model/src/train/tuner.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import TYPE_CHECKING, Any, Dict, List, Optional 3 | from transformers import PreTrainedModel 4 | 5 | from llmtuner.extras.callbacks import LogCallback 6 | from llmtuner.extras.logging import get_logger 7 | from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer 8 | from llmtuner.train.pt import run_pt 9 | from llmtuner.train.sft import run_sft 10 | from llmtuner.train.rm import run_rm 11 | from llmtuner.train.ppo import run_ppo 12 | from llmtuner.train.dpo import run_dpo 13 | 14 | if TYPE_CHECKING: 15 | from transformers import TrainerCallback 16 | 17 | 18 | logger = get_logger(__name__) 19 | 20 | 21 | def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None): 22 | model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) 23 | callbacks = [LogCallback()] if callbacks is None else callbacks 24 | 25 | if finetuning_args.stage == "pt": 26 | run_pt(model_args, data_args, training_args, finetuning_args, callbacks) 27 | elif finetuning_args.stage == "sft": 28 | run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) 29 | elif finetuning_args.stage == "rm": 30 | run_rm(model_args, data_args, training_args, finetuning_args, callbacks) 31 | elif finetuning_args.stage == "ppo": 32 | run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) 33 | elif finetuning_args.stage == "dpo": 34 | run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) 35 | else: 36 | raise ValueError("Unknown task.") 37 | 38 | 39 | def export_model(args: Optional[Dict[str, Any]] = None): 40 | model_args, _, finetuning_args, _ = get_infer_args(args) 41 | 42 | if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None: 43 | raise ValueError("Please merge adapters before quantizing the model.") 44 | 45 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) 46 | 47 | if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None: 48 | raise ValueError("Cannot merge adapters to a quantized model.") 49 | 50 | if not isinstance(model, PreTrainedModel): 51 | raise ValueError("The model is not a `PreTrainedModel`, export aborted.") 52 | 53 | model.config.use_cache = True 54 | if getattr(model.config, 
"torch_dtype", None) == "bfloat16": 55 | model = model.to(torch.bfloat16).to("cpu") 56 | else: 57 | model = model.to(torch.float16).to("cpu") 58 | setattr(model.config, "torch_dtype", "float16") 59 | 60 | model.save_pretrained( 61 | save_directory=model_args.export_dir, 62 | max_shard_size="{}GB".format(model_args.export_size), 63 | safe_serialization=(not model_args.export_legacy_format) 64 | ) 65 | 66 | try: 67 | tokenizer.padding_side = "left" # restore padding side 68 | tokenizer.init_kwargs["padding_side"] = "left" 69 | tokenizer.save_pretrained(model_args.export_dir) 70 | except: 71 | logger.warning("Cannot save tokenizer, please copy the files manually.") 72 | 73 | 74 | if __name__ == "__main__": 75 | run_exp() 76 | -------------------------------------------------------------------------------- /model/src/train/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import TYPE_CHECKING, Optional, Union 3 | 4 | from ..extras.logging import get_logger 5 | from ..hparams.model_args import ModelArguments 6 | from ..hparams.finetuning_args import FinetuningArguments 7 | from ..model.utils import get_modelcard_args, load_valuehead_params 8 | from ..model.loader import load_model_and_tokenizer 9 | 10 | if TYPE_CHECKING: 11 | from transformers import Seq2SeqTrainingArguments, Trainer 12 | from transformers.modeling_utils import PreTrainedModel 13 | from trl import AutoModelForCausalLMWithValueHead 14 | from ..hparams.data_args import DataArguments 15 | 16 | 17 | logger = get_logger(__name__) 18 | 19 | 20 | def create_modelcard_and_push( 21 | trainer: "Trainer", 22 | model_args: "ModelArguments", 23 | data_args: "DataArguments", 24 | training_args: "Seq2SeqTrainingArguments", 25 | finetuning_args: "FinetuningArguments" 26 | ) -> None: 27 | if training_args.do_train: 28 | if training_args.push_to_hub: 29 | trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args)) 30 | return 31 | try: 32 | trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args)) 33 | except Exception as err: 34 | logger.warning("Failed to create model card: {}".format(str(err))) 35 | 36 | 37 | def create_ref_model( 38 | model_args: "ModelArguments", 39 | finetuning_args: "FinetuningArguments", 40 | add_valuehead: Optional[bool] = False 41 | ) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]: 42 | r""" 43 | Creates reference model for PPO/DPO training. Evaluation mode is not supported. 44 | 45 | The valuehead parameter is randomly initialized since it is useless for PPO training. 
46 | """ 47 | if finetuning_args.ref_model is not None: 48 | ref_model_args_dict = model_args.to_dict() 49 | ref_model_args_dict.update(dict( 50 | model_name_or_path=finetuning_args.ref_model, 51 | adapter_name_or_path=finetuning_args.ref_model_adapters, 52 | quantization_bit=finetuning_args.ref_model_quantization_bit 53 | )) 54 | ref_model_args = ModelArguments(**ref_model_args_dict) 55 | ref_finetuning_args = FinetuningArguments(finetuning_type="lora") 56 | ref_model, _ = load_model_and_tokenizer( 57 | ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead 58 | ) 59 | logger.info("Created reference model from {}".format(finetuning_args.ref_model)) 60 | else: 61 | if finetuning_args.finetuning_type == "lora": 62 | ref_model = None 63 | else: 64 | ref_model, _ = load_model_and_tokenizer( 65 | model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead 66 | ) 67 | logger.info("Created reference model from the model itself.") 68 | 69 | return ref_model 70 | 71 | 72 | def create_reward_model( 73 | model: "AutoModelForCausalLMWithValueHead", 74 | model_args: "ModelArguments", 75 | finetuning_args: "FinetuningArguments" 76 | ) -> "AutoModelForCausalLMWithValueHead": 77 | r""" 78 | Creates reward model for PPO training. 79 | """ 80 | if finetuning_args.reward_model_type == "api": 81 | assert finetuning_args.reward_model.startswith("http"), "Please provide full url." 82 | logger.info("Use reward server {}".format(finetuning_args.reward_model)) 83 | return finetuning_args.reward_model 84 | elif finetuning_args.reward_model_type == "lora": 85 | model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward") 86 | for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090 87 | if "default" in name: 88 | param.data = param.data.to(torch.float32) # trainable params should in fp32 89 | vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args) 90 | assert vhead_params is not None, "Reward model is not correctly loaded." 
91 | model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) 92 | model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) 93 | model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False) 94 | model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False) 95 | logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model)) 96 | return None 97 | else: 98 | reward_model_args_dict = model_args.to_dict() 99 | reward_model_args_dict.update(dict( 100 | model_name_or_path=finetuning_args.reward_model, 101 | adapter_name_or_path=finetuning_args.reward_model_adapters, 102 | quantization_bit=finetuning_args.reward_model_quantization_bit 103 | )) 104 | reward_model_args = ModelArguments(**reward_model_args_dict) 105 | reward_finetuning_args = FinetuningArguments(finetuning_type="lora") 106 | reward_model, _ = load_model_and_tokenizer( 107 | reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True 108 | ) 109 | logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model)) 110 | logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.") 111 | return reward_model 112 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.26.0 2 | aiofiles==23.2.1 3 | aiohttp==3.9.0 4 | aiosignal==1.3.1 5 | altair==5.1.2 6 | annotated-types==0.6.0 7 | anyio==3.7.1 8 | appdirs==1.4.4 9 | attrs==23.1.0 10 | bird==0.1.2 11 | certifi==2023.11.17 12 | charset-normalizer==3.3.2 13 | click==8.1.7 14 | contourpy==1.2.0 15 | cycler==0.12.1 16 | datasets==2.16.1 17 | deepspeed==0.12.3 18 | dill==0.3.7 19 | distro==1.8.0 20 | docker-pycreds==0.4.0 21 | docstring-parser==0.15 22 | einops==0.7.0 23 | fastapi==0.109.0 24 | ffmpy==0.3.1 25 | filelock==3.13.1 26 | fonttools==4.45.0 27 | frozenlist==1.4.0 28 | fsspec==2023.10.0 29 | gitdb==4.0.11 30 | GitPython==3.1.42 31 | gradio==3.50.2 32 | gradio_client==0.6.1 33 | h11==0.14.0 34 | hjson==3.1.0 35 | httpcore==1.0.2 36 | httpx==0.25.1 37 | huggingface-hub==0.20.1 38 | idna==3.4 39 | importlib-resources==6.1.1 40 | jieba==0.42.1 41 | Jinja2==3.1.2 42 | joblib==1.3.2 43 | jsonschema==4.20.0 44 | jsonschema-specifications==2023.11.1 45 | kiwisolver==1.4.5 46 | markdown-it-py==3.0.0 47 | MarkupSafe==2.1.3 48 | matplotlib==3.8.2 49 | mdurl==0.1.2 50 | mpmath==1.3.0 51 | multidict==6.0.4 52 | multiprocess==0.70.15 53 | networkx==3.2.1 54 | ninja==1.11.1.1 55 | nltk==3.8.1 56 | numpy==1.26.2 57 | nvidia-cublas-cu12==12.1.3.1 58 | nvidia-cuda-cupti-cu12==12.1.105 59 | nvidia-cuda-nvrtc-cu12==12.1.105 60 | nvidia-cuda-runtime-cu12==12.1.105 61 | nvidia-cudnn-cu12==8.9.2.26 62 | nvidia-cufft-cu12==11.0.2.54 63 | nvidia-curand-cu12==10.3.2.106 64 | nvidia-cusolver-cu12==11.4.5.107 65 | nvidia-cusparse-cu12==12.1.0.106 66 | nvidia-nccl-cu12==2.18.1 67 | nvidia-nvjitlink-cu12==12.3.101 68 | nvidia-nvtx-cu12==12.1.105 69 | orjson==3.9.10 70 | openai==0.28.1 71 | packaging==23.2 72 | pandas==2.1.3 73 | peft==0.7.1 74 | Pillow==10.1.0 75 | protobuf==4.25.2 76 | psutil==5.9.6 77 | py-cpuinfo==9.0.0 78 | pyarrow==14.0.1 79 | pyarrow-hotfix==0.5 80 | pydantic==2.5.3 81 | pydantic_core==2.14.6 82 | pydub==0.25.1 83 | Pygments==2.17.1 
84 | pynvml==11.5.0 85 | pyparsing==3.1.1 86 | python-dateutil==2.8.2 87 | python-multipart==0.0.6 88 | pytz==2023.3.post1 89 | PyYAML==6.0.1 90 | referencing==0.31.0 91 | regex==2023.10.3 92 | requests==2.31.0 93 | rich==13.7.0 94 | rouge-chinese==1.0.3 95 | rpds-py==0.13.1 96 | safetensors==0.4.0 97 | scipy==1.11.4 98 | semantic-version==2.10.0 99 | sentencepiece==0.1.99 100 | sentry-sdk==1.40.5 101 | setproctitle==1.3.3 102 | shortuuid==1.0.11 103 | shtab==1.6.4 104 | six==1.16.0 105 | smmap==5.0.1 106 | sniffio==1.3.0 107 | sse-starlette==1.8.2 108 | starlette==0.35.1 109 | sympy==1.12 110 | tiktoken==0.5.2 111 | tokenizers==0.15.0 112 | toolz==0.12.0 113 | torch==2.1.2 114 | torchaudio==2.1.2 115 | torchvision==0.16.2 116 | tqdm==4.66.1 117 | transformers==4.36.2 118 | transformers-stream-generator==0.0.4 119 | triton==2.1.0 120 | trl==0.7.9 121 | typing_extensions==4.8.0 122 | tyro==0.5.17 123 | tzdata==2023.3 124 | urllib3==2.1.0 125 | uvicorn==0.26.0 126 | wandb==0.16.3 127 | websockets==11.0.3 128 | xformers==0.0.23.post1 129 | xxhash==3.4.1 130 | yarl==1.9.3 131 | -------------------------------------------------------------------------------- /slides/Framework.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rcrossmeister/Knowledge-to-SQL/4fe8d0588eb853ef44088c24dcfbdabfd1a6478c/slides/Framework.pdf -------------------------------------------------------------------------------- /slides/Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rcrossmeister/Knowledge-to-SQL/4fe8d0588eb853ef44088c24dcfbdabfd1a6478c/slides/Framework.png -------------------------------------------------------------------------------- /slides/Poster-DELLM.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rcrossmeister/Knowledge-to-SQL/4fe8d0588eb853ef44088c24dcfbdabfd1a6478c/slides/Poster-DELLM.pdf -------------------------------------------------------------------------------- /slides/Slides-DELLM.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rcrossmeister/Knowledge-to-SQL/4fe8d0588eb853ef44088c24dcfbdabfd1a6478c/slides/Slides-DELLM.pdf --------------------------------------------------------------------------------
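As a closing usage note, below is a minimal sketch of how the `run_exp` entry point in `model/src/train/tuner.py` might be invoked programmatically. It is not the repository's launch script: it assumes the working directory is `model/` so that `src` is importable, that the `llmtuner` helpers imported inside `tuner.py` are installed, and that argument parsing accepts a plain dict (as in the LLaMA-Factory code this module builds on). The model path, dataset name, and hyperparameter values are placeholders, not the settings used in the paper.

```python
# A minimal sketch, not the repository's official launch script.
# Assumptions: `src` is importable from the current working directory, the
# `llmtuner` package used inside tuner.py is available, and hyperparameter
# parsing accepts a plain dict. Every value below is a placeholder.
from src.train.tuner import run_exp

if __name__ == "__main__":
    run_exp(dict(
        stage="sft",                              # use "dpo" for preference tuning
        do_train=True,
        model_name_or_path="path/to/base-model",  # placeholder base LLM
        dataset="placeholder_dataset",            # placeholder dataset name
        finetuning_type="lora",
        output_dir="outputs/dellm-sft",
        per_device_train_batch_size=4,
        learning_rate=2e-5,
        num_train_epochs=3.0,
        plot_loss=True,
    ))
```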