├── .gitattributes
├── .gitignore
├── LICENSE
├── README.md
├── dataset
│   └── preprocessor.py
├── model
│   ├── bash.py
│   ├── data
│   │   └── README.md
│   └── src
│       ├── data
│       │   ├── loader.py
│       │   ├── preprocess.py
│       │   ├── template.py
│       │   └── utils.py
│       ├── extras
│       │   ├── callbacks.py
│       │   ├── constants.py
│       │   ├── logging.py
│       │   ├── misc.py
│       │   ├── packages.py
│       │   └── ploting.py
│       ├── hparams
│       │   ├── data_args.py
│       │   ├── evaluation_args.py
│       │   ├── finetuning_args.py
│       │   ├── generating_args.py
│       │   └── model_args.py
│       ├── model
│       │   ├── adapter.py
│       │   ├── loader.py
│       │   ├── parser.py
│       │   ├── patcher.py
│       │   └── utils.py
│       └── train
│           ├── dpo
│           │   ├── collator.py
│           │   ├── trainer.py
│           │   └── workflow.py
│           ├── sft
│           │   ├── metric.py
│           │   ├── trainer.py
│           │   └── workflow.py
│           ├── tuner.py
│           └── utils.py
├── requirements.txt
└── slides
    ├── Framework.pdf
    ├── Framework.png
    ├── Poster-DELLM.pdf
    └── Slides-DELLM.pdf
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | .DS_Store
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Rcrossmeister
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Knowledge-to-SQL
2 | **[2024/10] Check our video presentation in [Underline](https://underline.io/events/466/posters/18354/poster/102023-knowledge-to-sql-enhancing-sql-generation-with-data-expert-llm)!**
3 |
4 | **[2024/08] The video presentation of our paper will be available soon.**
5 |
6 | **[2024/08] The presentation of our paper is scheduled at Virtual Poster Session 2; check the poster and slides [here](./slides).**
7 |
8 | **[2024/05] Our paper is accepted as a findings paper at ACL 2024!**
9 |
10 | We propose a novel framework, **Knowledge-to-SQL**, which leverages a **Data Expert Large Language Model (DELLM)** to enhance SQL generation. The paper is available [here](https://aclanthology.org/2024.findings-acl.653.pdf).
11 |
12 |
13 |
14 | ## Setup
15 |
16 | ### Environment
17 |
18 | **The GPU setup we use in our study is 4*A800-SXM4-80G with CUDA version 12.1.** We strongly recommend using torch 2.0 or above.
19 |
20 | ```shell
21 | # Clone the repository
22 | git clone https://github.com/Rcrossmeister/Knowledge-to-SQL.git
23 | cd ./Knowledge-to-SQL
24 |
25 | # Create the conda environment
26 | conda create -n dellm python=3.11.3
27 | conda activate dellm
28 |
29 | # Install the required packages
30 | pip install -r requirements.txt
31 | ```
32 |
33 | ### Dataset
34 |
35 | We mainly focus on the **[BIRD](https://bird-bench.github.io/)** dataset in our study; we also support the **[Spider](https://yale-lily.github.io/spider)** dataset for robustness studies.
36 |
37 | ## Training
38 |
39 | The training implementation is inspired by **[LLaMA Factory](https://github.com/hiyouga/LLaMA-Factory)**; you can check their technical report [here](https://arxiv.org/abs/2403.13372).
40 |
41 | ### Quick Start
42 |
43 | We provide a script for a quick start on the BIRD dataset; a sketch of the workflow is shown below.
44 |
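The following is a minimal sketch, not a definitive command line: the preprocessing flags come from `dataset/preprocessor.py`, while the data paths, the dataset name `bird_ekg`, the base model, and the training arguments are placeholders (the available options follow the LLaMA-Factory-style hyperparameters defined in `model/src/hparams`).

```shell
# Construct the expert knowledge (evidence) training data from BIRD
python dataset/preprocessor.py \
    --data_path ./data/bird/train/train.json \
    --db_root_path ./data/bird/train/train_databases/ \
    --output_path ./model/data/bird_ekg.json

# Supervised fine-tuning of DELLM with LLaMA-Factory-style arguments
python model/bash.py \
    --stage sft \
    --model_name_or_path meta-llama/Llama-2-13b-hf \
    --dataset bird_ekg \
    --template llama2 \
    --finetuning_type lora \
    --output_dir ./checkpoints/dellm-sft
```
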
45 | ## Citation
46 |
47 | Please cite our paper if you include Knowledge-to-SQL in your work:
48 |
49 | ```
50 | @inproceedings{hong2024knowledge,
51 | title = "Knowledge-to-{SQL}: Enhancing {SQL} Generation with Data Expert {LLM}",
52 | author = "Hong, Zijin and
53 | Yuan, Zheng and
54 | Chen, Hao and
55 | Zhang, Qinggang and
56 | Huang, Feiran and
57 | Huang, Xiao",
58 | booktitle = "Findings of the Association for Computational Linguistics ACL 2024",
59 | year = "2024"
60 | }
61 | ```
62 |
--------------------------------------------------------------------------------
/dataset/preprocessor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import argparse
3 | import fnmatch
4 | import json
5 | import os
6 | import pdb
7 | import pickle
8 | import re
9 | import sqlite3
10 | from typing import Dict, List, Tuple
11 |
12 | import backoff
13 | import openai
14 | import pandas as pd
15 | import sqlparse
16 | from tqdm import tqdm
17 |
18 | import spacy
19 | import json
20 | import sqlite3
21 | from nltk.translate.bleu_score import sentence_bleu
22 | from nltk.translate.bleu_score import SmoothingFunction
23 | from tqdm import tqdm
24 | import numpy as np
25 |
26 | '''openai configure'''
27 |
28 | openai.debug = True
29 |
30 | def new_directory(path):
31 | if not os.path.exists(path):
32 | os.makedirs(path)
33 |
34 | def get_db_schemas(bench_root: str, db_name: str) -> Dict[str, str]:
35 | """
36 | Read an sqlite file, and return the CREATE commands for each of the tables in the database.
37 | """
    38 |     db_dir = 'database' if bench_root == 'spider' else 'databases'
    39 |     with sqlite3.connect(f'file:{bench_root}/{db_dir}/{db_name}/{db_name}.sqlite?mode=ro', uri=True) as conn:
40 | # conn.text_factory = bytes
41 | cursor = conn.cursor()
42 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
43 | tables = cursor.fetchall()
44 | schemas = {}
45 | for table in tables:
46 | cursor.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(table[0]))
47 | schemas[table[0]] = cursor.fetchone()[0]
48 |
49 | return schemas
50 |
51 | def nice_look_table(column_names: list, values: list):
52 | rows = []
53 | # Determine the maximum width of each column
54 | widths = [max(len(str(value[i])) for value in values + [column_names]) for i in range(len(column_names))]
55 |
56 | # Print the column names
57 | header = ''.join(f'{column.rjust(width)} ' for column, width in zip(column_names, widths))
58 | # print(header)
59 | # Print the values
60 | for value in values:
61 | row = ''.join(f'{str(v).rjust(width)} ' for v, width in zip(value, widths))
62 | rows.append(row)
63 | rows = "\n".join(rows)
64 | final_output = header + '\n' + rows
65 | return final_output
66 |
67 |
68 | def generate_schema_prompt(db_path, num_rows=None):
69 | # extract create ddls
    70 |     '''
    71 |     :param db_path: path to the sqlite database file
    72 |     :param num_rows: number of example rows to include per table (None to skip)
    73 |     :return: schema prompt containing the CREATE TABLE statement of each table
    74 |     '''
75 | full_schema_prompt_list = []
76 | conn = sqlite3.connect(db_path)
77 | # Create a cursor object
78 | cursor = conn.cursor()
79 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
80 | tables = cursor.fetchall()
81 | schemas = {}
82 | for table in tables:
    83 |         if table[0] == 'sqlite_sequence':
84 | continue
85 | cursor.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name='{}';".format(table[0]))
86 | create_prompt = cursor.fetchone()[0]
87 | schemas[table[0]] = create_prompt
88 | if num_rows:
89 | cur_table = table[0]
90 | if cur_table in ['order', 'by', 'group']:
91 | cur_table = "`{}`".format(cur_table)
92 |
93 | cursor.execute("SELECT * FROM {} LIMIT {}".format(cur_table, num_rows))
94 | column_names = [description[0] for description in cursor.description]
95 | values = cursor.fetchall()
96 | rows_prompt = nice_look_table(column_names=column_names, values=values)
97 | verbose_prompt = "/* \n {} example rows: \n SELECT * FROM {} LIMIT {}; \n {} \n */".format(num_rows,
98 | cur_table,
99 | num_rows,
100 | rows_prompt)
101 | schemas[table[0]] = "{} \n {}".format(create_prompt, verbose_prompt)
102 |
103 | for k, v in schemas.items():
104 | full_schema_prompt_list.append(v)
105 |
106 | schema_prompt = "\n\n".join(full_schema_prompt_list)
107 |
108 | return schema_prompt
109 |
110 |
111 | def generate_comment_prompt(question, knowledge=None):
112 | question_prompt = "{}".format(question)
113 | knowledge_prompt = "{}".format(knowledge)
114 |
115 | return question_prompt, knowledge_prompt
116 |
117 | def generate_combined_prompts_one(db_path, question, knowledge=None):
118 | schema_prompt = generate_schema_prompt(db_path, num_rows=None) # This is the entry to collect values
119 | question_prompt, knowledge_prompt = generate_comment_prompt(question, knowledge)
120 |
121 | return question_prompt, knowledge_prompt, schema_prompt
122 |
123 | def semantic_similarity(column, question):
124 | nlp = spacy.load("en_core_web_lg")
125 | column_doc = nlp(column)
126 | question_doc = nlp(question)
127 | similarity = question_doc.similarity(column_doc)
128 | return similarity
129 |
130 | def nice_look_table(column_names: list, values: list):
131 | rows = []
132 | # Determine the maximum width of each column
133 | widths = [max(len(str(value[i])) for value in values + [column_names]) for i in range(len(column_names))]
134 |
135 | # Print the column names
136 | header = ''.join(f'{column.rjust(width)} ' for column, width in zip(column_names, widths))
137 | # print(header)
138 | # Print the values
139 | for value in values:
140 | row = ''.join(f'{str(v).rjust(width)} ' for v, width in zip(value, widths))
141 | rows.append(row)
142 | rows = "\n".join(rows)
143 | final_output = header + '\n' + rows
144 | return final_output
145 |
146 | def get_tablename_columnList(db_path):
147 | full_schema_prompt_list = []
148 | conn = sqlite3.connect(db_path)
149 | cursor = conn.cursor()
150 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
151 | tables = cursor.fetchall()
152 | schemas = {}
153 | for table in tables:
   154 |         if table[0] == 'sqlite_sequence':
155 | continue
156 | cursor.execute("SELECT * FROM `{}`".format(table[0]))
   157 |         col_name_list = [description[0] for description in cursor.description]
158 | schemas[table[0]] = col_name_list
159 | for k, v in schemas.items():
160 | full_schema_prompt_list.append(v)
161 | return schemas
162 |
163 | def get_sub_table(db_path, selected_tables_columns, num_rows=3):
164 | if len(selected_tables_columns) == 0:
165 | return ""
166 | conn = sqlite3.connect(db_path)
167 | cursor = conn.cursor()
168 | subtable_prompt_list = []
169 | for table_name, column_list in selected_tables_columns.items():
170 | execute_query = "SELECT {} FROM `{}` LIMIT {}".format(", ".join(column_list), table_name, num_rows)
171 | cursor.execute(execute_query)
172 | column_names = [description[0] for description in cursor.description]
173 | values = cursor.fetchall()
174 | rows_prompt = nice_look_table(column_names=column_names, values=values)
175 | verbose_prompt = "/*\n{} rows of {} key-columns in Table {}:\n{}\n*/". \
176 | format(num_rows, len(column_list), table_name, rows_prompt)
177 | subtable_prompt_list.append(verbose_prompt)
178 | subtable_prompt = "\n".join(subtable_prompt_list)
179 | return subtable_prompt
180 |
181 | def get_subtable_prompt(db_path, question):
182 | tableName_columnList_dict = get_tablename_columnList(db_path)
183 | smooth_fn = SmoothingFunction().method7
184 | bleu_score = []
185 | table_column_pair_list = []
186 | for table_name, columnList in tableName_columnList_dict.items():
187 | for column in columnList:
188 | table_column_pair = table_name + " <_> " + column
189 | bleu = sentence_bleu([column], question, smoothing_function=smooth_fn)
190 | bleu_score.append(bleu)
191 | table_column_pair_list.append(table_column_pair)
192 | table_column_pair_list = [table_column_pair_list[i] for i in range(len(bleu_score)) if bleu_score[i] >= 0.08]
193 | bleu_score = [s for s in bleu_score if s >= 0.08]
194 | if len(bleu_score) == 0:
195 | return ""
196 | sorted_id = sorted(range(len(bleu_score)), key=lambda k: bleu_score[k], reverse=True)
197 | sorted_bleu_score = [bleu_score[i] for i in sorted_id]
198 | sorted_table_column_pair = [table_column_pair_list[i] for i in sorted_id]
199 | top_K_table_column_pair = sorted_table_column_pair[:3]
200 |
201 | selected_tables_columns = {}
202 | for table_column_pair in top_K_table_column_pair:
203 | table_name, column_name = table_column_pair.split(" <_> ")
204 | column_name = "`{}`".format(column_name)
205 | if table_name in selected_tables_columns:
206 | selected_tables_columns[table_name].append(column_name)
207 | else:
208 | selected_tables_columns[table_name] = [column_name]
209 | subtable_prompt = get_sub_table(db_path, selected_tables_columns, num_rows=3)
210 | return subtable_prompt
211 |
212 | def construct_ekg_data(db_path_list, question_list, knowledge_list=None):
   213 |     '''
   214 |     :param db_path_list: list of database file paths
   215 |     :param question_list: list of questions (aligned with knowledge_list)
   216 |     :return: list of instruction-tuning examples for expert knowledge generation
   217 |     '''
218 |
219 | output_list = []
220 | for i, question in tqdm(enumerate(question_list)):
221 | # print('--------------------- processing {}th question ---------------------'.format(i))
222 | # print('the question is: {}'.format(question))
223 |
224 | question, knowledge, schema = generate_combined_prompts_one(db_path=db_path_list[i], question=question,
225 | knowledge=knowledge_list[i])
   226 |         knowledge = knowledge.replace(';', '')
227 | output = {
   228 |             'instruction': 'You are a helpful assistant. '
   229 |                            'Please generate evidence based on the given database schema and question. '
   230 |                            'The evidence should use the database information (schema) to explain the question. '
   231 |                            'The evidence aims to help the language model generate a more accurate SQL query to answer the question. ',
232 | 'input': '--schema: ' + schema + ' '
233 | '--question: ' + question,
234 | 'output': knowledge
235 | }
236 | output_list.append(output)
237 |
238 | return output_list
239 |
240 | def question_package(data_json, knowledge=False):
241 | question_list = []
242 | for data in data_json:
243 | question_list.append(data['question'])
244 |
245 | return question_list
246 |
247 | def knowledge_package(data_json, knowledge=False):
248 | knowledge_list = []
249 | for data in data_json:
250 | knowledge_list.append(data['evidence'])
251 |
252 | return knowledge_list
253 |
254 | def decouple_question_schema(datasets, db_root_path):
255 | question_list = []
256 | db_path_list = []
257 | knowledge_list = []
258 | for i, data in enumerate(datasets):
259 | question_list.append(data['question'])
260 | cur_db_path = db_root_path + data['db_id'] + '/' + data['db_id'] + '.sqlite'
261 | db_path_list.append(cur_db_path)
262 | knowledge_list.append(data['evidence'])
263 |
264 | return question_list, db_path_list, knowledge_list
265 |
266 | if __name__ == '__main__':
267 | args_parser = argparse.ArgumentParser()
268 | args_parser.add_argument('--data_path', type=str, default='')
269 | args_parser.add_argument('--db_root_path', type=str, default='')
270 | args_parser.add_argument('--output_path', type=str, default='')
271 | args = args_parser.parse_args()
272 |
   273 |     eval_data = json.load(open(args.data_path, 'r'))
274 |
275 | question_list, db_path_list, knowledge_list = decouple_question_schema(datasets=eval_data,
276 | db_root_path=args.db_root_path)
277 | assert len(question_list) == len(db_path_list) == len(knowledge_list)
278 |
279 | json_withSubTable = []
   280 |     for i in tqdm(range(len(eval_data))):
   281 |         instance = eval_data[i]
282 | db_id = instance['db_id']
283 | question = instance['question']
284 | db_path = args.db_root_path + db_id + "/" + db_id + ".sqlite"
285 | subtable_prompt = get_subtable_prompt(db_path, question)
286 | instance["subtable_prompt"] = subtable_prompt
287 | if "question_id" not in instance:
288 | instance["question_id"] = i
289 | json_withSubTable.append(instance)
290 |
291 | ekg_data = construct_ekg_data(db_path_list=db_path_list, question_list=question_list, knowledge_list=knowledge_list)
292 |
293 | with open(args.output_path, 'w', encoding='utf-8') as file:
294 | json.dump(ekg_data, file, indent=4)
295 | file.close()
296 |
   297 |     print('Successfully constructed results.')
--------------------------------------------------------------------------------
/model/bash.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import TYPE_CHECKING, Any, Dict, List, Optional
3 | from transformers import PreTrainedModel
4 |
5 | from src.extras.callbacks import LogCallback
6 | from src.extras.logging import get_logger
7 | from src.model.loader import load_model_and_tokenizer
8 | from src.model.parser import get_train_args, get_infer_args
9 | from src.train.sft.workflow import run_sft
10 | from src.train.dpo.workflow import run_dpo
11 |
12 | if TYPE_CHECKING:
13 | from transformers import TrainerCallback
14 |
15 |
16 | logger = get_logger(__name__)
17 |
18 | def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
19 | model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
20 | callbacks = [LogCallback()] if callbacks is None else callbacks
21 |
22 | if finetuning_args.stage == "sft":
23 | run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
24 | elif finetuning_args.stage == "dpo":
25 | run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
26 | else:
27 | raise ValueError("Unknown task.")
28 |
29 |
30 | def export_model(args: Optional[Dict[str, Any]] = None):
31 | model_args, _, finetuning_args, _ = get_infer_args(args)
32 |
33 | if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
34 | raise ValueError("Please merge adapters before quantizing the model.")
35 |
36 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
37 |
38 | if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
39 | raise ValueError("Cannot merge adapters to a quantized model.")
40 |
41 | if not isinstance(model, PreTrainedModel):
42 | raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
43 |
44 | model.config.use_cache = True
45 | if getattr(model.config, "torch_dtype", None) == "bfloat16":
46 | model = model.to(torch.bfloat16).to("cpu")
47 | else:
48 | model = model.to(torch.float16).to("cpu")
49 | setattr(model.config, "torch_dtype", "float16")
50 |
51 | model.save_pretrained(
52 | save_directory=model_args.export_dir,
53 | max_shard_size="{}GB".format(model_args.export_size),
54 | safe_serialization=(not model_args.export_legacy_format)
55 | )
56 |
57 | try:
58 | tokenizer.padding_side = "left" # restore padding side
59 | tokenizer.init_kwargs["padding_side"] = "left"
60 | tokenizer.save_pretrained(model_args.export_dir)
    61 |     except Exception:
62 | logger.warning("Cannot save tokenizer, please copy the files manually.")
63 |
64 |
65 | def main():
66 | run_exp()
67 |
68 | if __name__ == "__main__":
69 | main()
--------------------------------------------------------------------------------
/model/data/README.md:
--------------------------------------------------------------------------------
1 | # Data Description
2 |
3 | To support your private data, register it by adding an entry to the `./dataset_info.json` file in this directory, as sketched below.
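
A minimal sketch of registering a private dataset in the alpaca format, assuming the LLaMA-Factory-style schema consumed by `model/src/data/loader.py`; the dataset name, file name, and column mapping below are placeholders:

```python
import json

# Hypothetical entry for a private alpaca-format dataset (all names are placeholders).
entry = {
    "my_private_data": {
        "file_name": "my_private_data.json",  # data file placed in this directory
        "columns": {                          # map your JSON fields to the expected columns
            "prompt": "instruction",
            "query": "input",
            "response": "output"
        }
    }
}

# Merge the entry into ./dataset_info.json so that `--dataset my_private_data` resolves to it.
with open("dataset_info.json", "r+", encoding="utf-8") as f:
    info = json.load(f)
    info.update(entry)
    f.seek(0)
    json.dump(info, f, indent=2, ensure_ascii=False)
    f.truncate()
```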
--------------------------------------------------------------------------------
/model/src/data/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import TYPE_CHECKING, Any, Dict, List, Union
3 |
4 | from datasets import concatenate_datasets, interleave_datasets, load_dataset, load_from_disk
5 |
6 | from .utils import checksum
7 | from ..extras.constants import FILEEXT2TYPE
8 | from ..extras.logging import get_logger
9 |
10 | if TYPE_CHECKING:
11 | from datasets import Dataset, IterableDataset
12 | from ..hparams.model_args import ModelArguments
13 | from ..hparams.data_args import DataArguments
14 |
15 |
16 | logger = get_logger(__name__)
17 |
18 |
19 | def get_dataset(
20 | model_args: "ModelArguments",
21 | data_args: "DataArguments"
22 | ) -> Union["Dataset", "IterableDataset"]:
23 | max_samples = data_args.max_samples
24 | all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
25 |
26 | if data_args.cache_path is not None:
27 | if os.path.exists(data_args.cache_path):
28 | logger.warning("Loading dataset from disk will ignore other data arguments.")
29 | dataset = load_from_disk(data_args.cache_path)
30 | if data_args.streaming:
31 | dataset = dataset.to_iterable_dataset()
32 | return dataset
33 | elif data_args.streaming:
34 | raise ValueError("Turn off dataset streaming to save cache files.")
35 |
36 | for dataset_attr in data_args.dataset_list:
37 | logger.info("Loading dataset {}...".format(dataset_attr))
38 |
39 | data_path, data_name, data_dir, data_files = None, None, None, None
40 | if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
41 | data_path = dataset_attr.dataset_name
42 | data_name = dataset_attr.subset
43 | data_dir = dataset_attr.folder
44 | elif dataset_attr.load_from == "script":
45 | data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
46 | data_name = dataset_attr.subset
47 | elif dataset_attr.load_from == "file":
48 | data_files = []
49 | local_path: str = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
50 | if os.path.isdir(local_path): # is directory
51 | for file_name in os.listdir(local_path):
52 | data_files.append(os.path.join(local_path, file_name))
53 | if data_path is None:
54 | data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
55 | else:
56 | assert data_path == FILEEXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
57 | elif os.path.isfile(local_path): # is file
58 | data_files.append(local_path)
59 | data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
60 | else:
61 | raise ValueError("File not found.")
62 |
63 | assert data_path, "File extension must be txt, csv, json or jsonl."
64 | checksum(data_files, dataset_attr.dataset_sha1)
65 | else:
66 | raise NotImplementedError
67 |
68 | if dataset_attr.load_from == "ms_hub":
69 | try:
70 | from modelscope import MsDataset
71 | from modelscope.utils.config_ds import MS_DATASETS_CACHE
72 |
73 | cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
74 | dataset = MsDataset.load(
75 | dataset_name=data_path,
76 | subset_name=data_name,
77 | data_dir=data_dir,
78 | data_files=data_files,
79 | split=data_args.split,
80 | cache_dir=cache_dir,
81 | token=model_args.ms_hub_token,
82 | use_streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
83 | ).to_hf_dataset()
84 | except ImportError:
85 | raise ImportError("Please install modelscope via `pip install modelscope -U`")
86 | else:
87 | dataset = load_dataset(
88 | path=data_path,
89 | name=data_name,
90 | data_dir=data_dir,
91 | data_files=data_files,
92 | split=data_args.split,
93 | cache_dir=model_args.cache_dir,
94 | token=model_args.hf_hub_token,
95 | streaming=(data_args.streaming and (dataset_attr.load_from != "file"))
96 | )
97 |
98 | if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
99 | dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
100 |
101 | if max_samples is not None: # truncate dataset
102 | dataset = dataset.select(range(min(len(dataset), max_samples)))
103 |
104 | def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
105 | # convert dataset from sharegpt format to alpaca format
106 | outputs = {"prompt": [], "query": [], "response": [], "history": [], "system": []}
107 | for i, msg_list in enumerate(examples[dataset_attr.messages]):
108 | msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
109 | if len(msg_list) == 0:
110 | continue
111 |
112 | msg_pairs = []
113 | user_role, assistant_role = None, None
114 | for idx in range(0, len(msg_list), 2):
115 | if user_role is None and assistant_role is None:
116 | user_role = msg_list[idx][dataset_attr.role]
117 | assistant_role = msg_list[idx + 1][dataset_attr.role]
118 | else:
119 | if (
120 | msg_list[idx][dataset_attr.role] != user_role
121 | or msg_list[idx+1][dataset_attr.role] != assistant_role
122 | ):
123 | raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
124 | msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
125 |
126 | if len(msg_pairs) != 0:
127 | outputs["prompt"].append(msg_pairs[-1][0])
128 | outputs["query"].append("")
129 | outputs["response"].append(msg_pairs[-1][1])
130 | outputs["history"].append(msg_pairs[:-1] if len(msg_pairs) > 1 else None)
131 | outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
132 |
133 | return outputs
134 |
135 | if dataset_attr.formatting == "sharegpt": # convert format
136 | column_names = list(next(iter(dataset)).keys())
137 | kwargs = {}
138 | if not data_args.streaming:
139 | kwargs = dict(
140 | num_proc=data_args.preprocessing_num_workers,
141 | load_from_cache_file=(not data_args.overwrite_cache),
142 | desc="Converting format of dataset"
143 | )
144 |
145 | dataset = dataset.map(
146 | convert_format,
147 | batched=True,
148 | remove_columns=column_names,
149 | **kwargs
150 | )
151 | else:
152 | for column_name in ["prompt", "query", "response", "history", "system"]: # align dataset
153 | if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
154 | dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
155 |
156 | all_datasets.append(dataset)
157 |
158 | if len(data_args.dataset_list) == 1:
159 | return all_datasets[0]
160 | elif data_args.mix_strategy == "concat":
161 | if data_args.streaming:
162 | logger.warning("The samples between different datasets will not be mixed in streaming mode.")
163 | return concatenate_datasets(all_datasets)
164 | elif data_args.mix_strategy.startswith("interleave"):
165 | if not data_args.streaming:
166 | logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
167 | return interleave_datasets(
168 | datasets=all_datasets,
169 | probabilities=data_args.interleave_probs,
170 | seed=data_args.seed,
171 | stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
172 | )
173 | else:
174 | raise ValueError("Unknown mixing strategy.")
175 |
--------------------------------------------------------------------------------
/model/src/data/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tiktoken
3 | from itertools import chain
4 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Union
5 |
6 | from .template import get_template_and_fix_tokenizer
7 | from ..extras.constants import IGNORE_INDEX
8 | from ..extras.logging import get_logger
9 |
10 | if TYPE_CHECKING:
11 | from datasets import Dataset, IterableDataset
12 | from transformers import Seq2SeqTrainingArguments
13 | from transformers.tokenization_utils import PreTrainedTokenizer
14 | from ..hparams.data_args import DataArguments
15 |
16 |
17 | logger = get_logger(__name__)
18 |
19 |
20 | def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
21 | for i in range(len(examples["prompt"])):
22 | query, response = examples["prompt"][i], examples["response"][i]
23 | query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
24 | history = examples["history"][i] if "history" in examples else None
25 | system = examples["system"][i] if "system" in examples else None
26 | yield query, response, history, system
27 |
28 |
29 | def infer_max_len(source_len: int, target_len: int, data_args: "DataArguments") -> Tuple[int, int]:
30 | max_target_len = int(data_args.cutoff_len * (target_len / (source_len + target_len)))
31 | max_target_len = max(max_target_len, data_args.reserved_label_len)
32 | max_source_len = data_args.cutoff_len - max_target_len
33 | return max_source_len, max_target_len
34 |
35 |
36 | def preprocess_dataset(
37 | dataset: Union["Dataset", "IterableDataset"],
38 | tokenizer: "PreTrainedTokenizer",
39 | data_args: "DataArguments",
40 | training_args: "Seq2SeqTrainingArguments",
41 | stage: Literal["pt", "sft", "rm", "ppo"]
42 | ) -> Union["Dataset", "IterableDataset"]:
43 | template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
44 |
45 | if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
46 | return dataset # already preprocessed
47 |
48 | if data_args.train_on_prompt and template.efficient_eos:
49 | raise ValueError("Current template does not support `train_on_prompt`.")
50 |
51 | def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
52 | # build grouped texts with format `X1 X2 X3 ...`
53 | if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
54 | kwargs = dict(allowed_special="all")
55 | else:
56 | kwargs = dict(add_special_tokens=True)
57 |
58 | if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
59 | add_eos_token_flag = getattr(tokenizer, "add_eos_token")
60 | setattr(tokenizer, "add_eos_token", True)
61 |
62 | tokenized_examples = tokenizer(examples["prompt"], **kwargs)
63 | concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
64 | total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
65 | block_size = data_args.cutoff_len
66 | # we drop the small remainder, and if the total_length < block_size, we exclude this batch
67 | total_length = (total_length // block_size) * block_size
68 | # split by chunks of cutoff_len
69 | result = {
70 | k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
71 | for k, t in concatenated_examples.items()
72 | }
73 | # make sure the saved tokenizer is the same as the original one
74 | if hasattr(tokenizer, "add_eos_token"):
75 | setattr(tokenizer, "add_eos_token", add_eos_token_flag)
76 | return result
77 |
78 | def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
    79 |         # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
80 | # for multiturn examples, we only mask the prompt part in each prompt-response pair.
81 | model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
82 |
83 | for query, response, history, system in construct_example(examples):
84 | if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
85 | continue
86 |
87 | input_ids, labels = [], []
88 | for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
89 | tokenizer, query, response, history, system
90 | )):
91 | source_len, target_len = len(source_ids), len(target_ids)
92 | max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
93 | if source_len > max_source_len:
94 | source_ids = source_ids[:max_source_len]
95 | if target_len > max_target_len:
96 | target_ids = target_ids[:max_target_len]
97 |
98 | if data_args.train_on_prompt:
99 | source_mask = source_ids
100 | elif turn_idx != 0 and template.efficient_eos:
101 | source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
102 | else:
103 | source_mask = [IGNORE_INDEX] * len(source_ids)
104 |
105 | input_ids += source_ids + target_ids
106 | labels += source_mask + target_ids
107 |
108 | if template.efficient_eos:
109 | input_ids += [tokenizer.eos_token_id]
110 | labels += [tokenizer.eos_token_id]
111 |
112 | if len(input_ids) > data_args.cutoff_len:
113 | input_ids = input_ids[:data_args.cutoff_len]
114 | labels = labels[:data_args.cutoff_len]
115 |
116 | model_inputs["input_ids"].append(input_ids)
117 | model_inputs["attention_mask"].append([1] * len(input_ids))
118 | model_inputs["labels"].append(labels)
119 |
120 | return model_inputs
121 |
122 | def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
   123 |         # build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
   124 |         # and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
125 | model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
126 | input_ids, labels = [], []
127 | for query, response, history, system in construct_example(examples):
128 | if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
129 | continue
130 |
131 | for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
132 | tokenizer, query, response, history, system
133 | )):
134 | if data_args.train_on_prompt:
135 | source_mask = source_ids
136 | elif turn_idx != 0 and template.efficient_eos:
137 | source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
138 | else:
139 | source_mask = [IGNORE_INDEX] * len(source_ids)
140 | input_ids += source_ids + target_ids
141 | labels += source_mask + target_ids
142 |
143 | if template.efficient_eos:
144 | input_ids += [tokenizer.eos_token_id]
145 | labels += [tokenizer.eos_token_id]
146 |
147 | total_length = len(input_ids)
148 | block_size = data_args.cutoff_len
149 | # we drop the small remainder, and if the total_length < block_size, we exclude this batch
150 | total_length = (total_length // block_size) * block_size
151 | # split by chunks of cutoff_len
152 | for i in range(0, total_length, block_size):
153 | model_inputs["input_ids"].append(input_ids[i: i + block_size])
154 | model_inputs["attention_mask"].append([1] * block_size)
155 | model_inputs["labels"].append(labels[i: i + block_size])
156 |
157 | return model_inputs
158 |
159 | def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
   160 |         # build inputs with format `<bos> X` and labels with format `Y <eos>`
161 | model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
162 |
163 | for query, response, history, system in construct_example(examples):
164 | if not (isinstance(query, str) and query != ""):
165 | continue
166 |
167 | input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
168 |
169 | if template.efficient_eos:
170 | labels += [tokenizer.eos_token_id]
171 |
172 | if len(input_ids) > data_args.cutoff_len:
173 | input_ids = input_ids[:data_args.cutoff_len]
174 | if len(labels) > data_args.cutoff_len:
175 | labels = labels[:data_args.cutoff_len]
176 |
177 | model_inputs["input_ids"].append(input_ids)
178 | model_inputs["attention_mask"].append([1] * len(input_ids))
179 | model_inputs["labels"].append(labels)
180 |
181 | return model_inputs
182 |
183 | def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
   184 |         # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
185 | model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
186 | for query, response, history, system in construct_example(examples):
187 | if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
188 | continue
189 |
190 | prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
191 | _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
192 |
193 | if template.efficient_eos:
194 | chosen_ids += [tokenizer.eos_token_id]
195 | rejected_ids += [tokenizer.eos_token_id]
196 |
197 | source_len, target_len = len(prompt_ids), max(len(chosen_ids), len(rejected_ids))
198 | max_source_len, max_target_len = infer_max_len(source_len, target_len, data_args)
199 | if source_len > max_source_len:
200 | prompt_ids = prompt_ids[:max_source_len]
201 | if target_len > max_target_len:
202 | chosen_ids = chosen_ids[:max_target_len]
203 | rejected_ids = rejected_ids[:max_target_len]
204 |
205 | model_inputs["prompt_ids"].append(prompt_ids)
206 | model_inputs["chosen_ids"].append(chosen_ids)
207 | model_inputs["rejected_ids"].append(rejected_ids)
208 |
209 | return model_inputs
210 |
211 | def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
212 | print("input_ids:\n{}".format(example["input_ids"]))
213 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
214 | print("label_ids:\n{}".format(example["labels"]))
215 | print("labels:\n{}".format(
216 | tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
217 | ))
218 |
219 | def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
220 | print("prompt_ids:\n{}".format(example["prompt_ids"]))
221 | print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
222 | print("chosen_ids:\n{}".format(example["chosen_ids"]))
223 | print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
224 | print("rejected_ids:\n{}".format(example["rejected_ids"]))
225 | print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
226 |
227 | def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
228 | print("input_ids:\n{}".format(example["input_ids"]))
229 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
230 |
231 | if stage == "pt":
232 | preprocess_func = preprocess_pretrain_dataset
233 | print_function = print_unsupervised_dataset_example
234 | elif stage == "sft" and not training_args.predict_with_generate:
235 | preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
236 | print_function = print_supervised_dataset_example
237 | elif stage == "rm":
238 | preprocess_func = preprocess_pairwise_dataset
239 | print_function = print_pairwise_dataset_example
240 | else:
241 | preprocess_func = preprocess_unsupervised_dataset
242 | print_function = print_unsupervised_dataset_example
243 |
244 | with training_args.main_process_first(desc="dataset map pre-processing"):
245 | column_names = list(next(iter(dataset)).keys())
246 | kwargs = {}
247 | if not data_args.streaming:
248 | kwargs = dict(
249 | num_proc=data_args.preprocessing_num_workers,
250 | load_from_cache_file=(not data_args.overwrite_cache),
251 | desc="Running tokenizer on dataset"
252 | )
253 |
254 | dataset = dataset.map(
255 | preprocess_func,
256 | batched=True,
257 | remove_columns=column_names,
258 | **kwargs
259 | )
260 |
261 | if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
262 | if training_args.should_save:
263 | dataset.save_to_disk(data_args.cache_path)
264 | logger.info("Dataset cache saved at {}.".format(data_args.cache_path))
265 |
266 | if training_args.should_log:
267 | try:
268 | print_function(next(iter(dataset)))
269 | except StopIteration:
270 | raise RuntimeError("Empty dataset!")
271 |
272 | return dataset
273 |
--------------------------------------------------------------------------------
/model/src/data/template.py:
--------------------------------------------------------------------------------
1 | import tiktoken
2 | from dataclasses import dataclass
3 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
4 |
5 | from ..extras.logging import get_logger
6 |
7 | if TYPE_CHECKING:
8 | from transformers import PreTrainedTokenizer
9 |
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | @dataclass
15 | class Template:
16 |
17 | prefix: List[Union[str, Dict[str, str]]]
18 | prompt: List[Union[str, Dict[str, str]]]
19 | system: str
20 | sep: List[Union[str, Dict[str, str]]]
21 | stop_words: List[str]
22 | use_history: bool
23 | efficient_eos: bool
24 | replace_eos: bool
25 |
26 | def encode_oneturn(
27 | self,
28 | tokenizer: "PreTrainedTokenizer",
29 | query: str,
30 | resp: str,
31 | history: Optional[List[Tuple[str, str]]] = None,
32 | system: Optional[str] = None
33 | ) -> Tuple[List[int], List[int]]:
34 | r"""
35 | Returns a single pair of token ids representing prompt and response respectively.
36 | """
37 | system, history = self._format(query, resp, history, system)
38 | encoded_pairs = self._encode(tokenizer, system, history)
39 | prompt_ids = []
40 | for query_ids, resp_ids in encoded_pairs[:-1]:
41 | prompt_ids = prompt_ids + query_ids + resp_ids
42 | prompt_ids = prompt_ids + encoded_pairs[-1][0]
43 | answer_ids = encoded_pairs[-1][1]
44 | return prompt_ids, answer_ids
45 |
46 | def encode_multiturn(
47 | self,
48 | tokenizer: "PreTrainedTokenizer",
49 | query: str,
50 | resp: str,
51 | history: Optional[List[Tuple[str, str]]] = None,
52 | system: Optional[str] = None
53 | ) -> List[Tuple[List[int], List[int]]]:
54 | r"""
55 | Returns multiple pairs of token ids representing prompts and responses respectively.
56 | """
57 | system, history = self._format(query, resp, history, system)
58 | encoded_pairs = self._encode(tokenizer, system, history)
59 | return encoded_pairs
60 |
61 | def _format(
62 | self,
63 | query: str,
64 | resp: str,
65 | history: Optional[List[Tuple[str, str]]] = None,
66 | system: Optional[str] = None
67 | ) -> Tuple[str, List[Tuple[str, str]]]:
68 | r"""
69 | Aligns inputs to the standard format.
70 | """
71 | system = system or self.system # use system if provided
72 | history = history if (history and self.use_history) else []
73 | history = history + [(query, resp)]
74 | return system, history
75 |
76 | def _get_special_ids(
77 | self,
78 | tokenizer: "PreTrainedTokenizer"
79 | ) -> Tuple[List[int], List[int]]:
80 | if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
81 | bos_ids = [tokenizer.bos_token_id]
82 | else: # baichuan, gpt2, qwen, yi models have no bos token
83 | bos_ids = []
84 |
85 | if tokenizer.eos_token_id is None:
86 | raise ValueError("EOS token is required.")
87 |
88 | if self.efficient_eos:
89 | eos_ids = []
90 | else:
91 | eos_ids = [tokenizer.eos_token_id]
92 |
93 | return bos_ids, eos_ids
94 |
95 | def _encode(
96 | self,
97 | tokenizer: "PreTrainedTokenizer",
98 | system: str,
99 | history: List[Tuple[str, str]]
100 | ) -> List[Tuple[List[int], List[int]]]:
101 | r"""
102 | Encodes formatted inputs to pairs of token ids.
103 | Turn 0: bos + prefix + sep + query resp + eos
104 | Turn t: sep + bos + query resp + eos
105 | """
106 | bos_ids, eos_ids = self._get_special_ids(tokenizer)
107 | sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
108 | encoded_pairs = []
109 | for turn_idx, (query, resp) in enumerate(history):
110 | if turn_idx == 0:
111 | prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
112 | if len(prefix_ids) != 0: # has prefix
113 | prefix_ids = bos_ids + prefix_ids + sep_ids
114 | else:
115 | prefix_ids = bos_ids
116 | else:
117 | prefix_ids = sep_ids + bos_ids
118 |
119 | query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx+1))
120 | resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
121 | encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
122 | return encoded_pairs
123 |
124 | def _convert_inputs_to_ids(
125 | self,
126 | tokenizer: "PreTrainedTokenizer",
127 | context: List[Union[str, Dict[str, str]]],
128 | system: Optional[str] = None,
129 | query: Optional[str] = None,
130 | idx: Optional[str] = None
131 | ) -> List[int]:
132 | r"""
133 | Converts context to token ids.
134 | """
135 | if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
136 | kwargs = dict(allowed_special="all")
137 | else:
138 | kwargs = dict(add_special_tokens=False)
139 |
140 | token_ids = []
141 | for elem in context:
142 | if isinstance(elem, str):
143 | elem = elem.replace("{{system}}", system, 1) if system is not None else elem
144 | elem = elem.replace("{{query}}", query, 1) if query is not None else elem
145 | elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
146 | if len(elem) != 0:
147 | token_ids = token_ids + tokenizer.encode(elem, **kwargs)
148 | elif isinstance(elem, dict):
149 | token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
150 | else:
151 | raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
152 |
153 | return token_ids
154 |
155 |
156 | @dataclass
157 | class Llama2Template(Template):
158 |
159 | def _encode(
160 | self,
161 | tokenizer: "PreTrainedTokenizer",
162 | system: str,
163 | history: List[Tuple[str, str]]
164 | ) -> List[Tuple[List[int], List[int]]]:
165 | r"""
166 | Encodes formatted inputs to pairs of token ids.
167 | Turn 0: bos + prefix + query resp + eos
168 | Turn t: bos + query resp + eos
169 | """
170 | bos_ids, eos_ids = self._get_special_ids(tokenizer)
171 | encoded_pairs = []
172 | for turn_idx, (query, resp) in enumerate(history):
173 | if turn_idx == 0: # llama2 template has no sep_ids
174 | query = self.prefix[0].replace("{{system}}", system) + query
175 | query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
176 | resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
177 | encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
178 | return encoded_pairs
179 |
180 |
181 | templates: Dict[str, Template] = {}
182 |
183 |
184 | def register_template(
185 | name: str,
186 | prefix: List[Union[str, Dict[str, str]]],
187 | prompt: List[Union[str, Dict[str, str]]],
188 | system: str,
189 | sep: List[Union[str, Dict[str, str]]],
190 | stop_words: Optional[List[str]] = [],
191 | use_history: Optional[bool] = True,
192 | efficient_eos: Optional[bool] = False,
193 | replace_eos: Optional[bool] = False
194 | ) -> None:
195 | template_class = Llama2Template if name.startswith("llama2") else Template
196 | templates[name] = template_class(
197 | prefix=prefix,
198 | prompt=prompt,
199 | system=system,
200 | sep=sep,
201 | stop_words=stop_words,
202 | use_history=use_history,
203 | efficient_eos=efficient_eos,
204 | replace_eos=replace_eos
205 | )
206 |
207 |
208 | def get_template_and_fix_tokenizer(
209 | name: str,
210 | tokenizer: "PreTrainedTokenizer"
211 | ) -> Template:
212 | if tokenizer.eos_token_id is None:
213 | tokenizer.eos_token = "<|endoftext|>"
214 | logger.info("Add eos token: {}".format(tokenizer.eos_token))
215 |
216 | if tokenizer.pad_token_id is None:
217 | tokenizer.pad_token = tokenizer.eos_token
218 | logger.info("Add pad token: {}".format(tokenizer.pad_token))
219 |
220 | if name is None: # for pre-training
221 | return None
222 |
223 | template = templates.get(name, None)
224 | assert template is not None, "Template {} does not exist.".format(name)
225 |
226 | stop_words = template.stop_words
227 | if template.replace_eos:
228 | if not stop_words:
229 | raise ValueError("Stop words are required to replace the EOS token.")
230 |
231 | tokenizer.eos_token = stop_words[0]
232 | stop_words = stop_words[1:]
233 | logger.info("Replace eos token: {}".format(tokenizer.eos_token))
234 |
235 | if stop_words:
236 | tokenizer.add_special_tokens(
237 | dict(additional_special_tokens=stop_words),
238 | replace_additional_special_tokens=False
239 | )
240 | logger.info("Add {} to stop words.".format(",".join(stop_words)))
241 |
242 | return template
243 |
244 |
245 | register_template(
246 | name="alpaca",
247 | prefix=[
248 | "{{system}}"
249 | ],
250 | prompt=[
251 | "### Instruction:\n{{query}}\n\n### Response:\n"
252 | ],
253 | system=(
254 | "Below is an instruction that describes a task. "
255 | "Write a response that appropriately completes the request."
256 | ),
257 | sep=[
258 | "\n\n"
259 | ]
260 | )
261 |
262 |
263 | register_template(
264 | name="aquila",
265 | prefix=[
266 | "{{system}}"
267 | ],
268 | prompt=[
269 | "Human: {{query}}###Assistant:"
270 | ],
271 | system=(
272 | "A chat between a curious human and an artificial intelligence assistant. "
273 | "The assistant gives helpful, detailed, and polite answers to the human's questions."
274 | ),
275 | sep=[
276 | "###"
277 | ],
278 | stop_words=[
279 | ""
280 | ],
281 | efficient_eos=True
282 | )
283 |
284 |
285 | register_template(
286 | name="baichuan",
287 | prefix=[
288 | "{{system}}"
289 | ],
290 | prompt=[
291 | {"token": ""}, # user token
292 | "{{query}}",
293 | {"token": ""} # assistant token
294 | ],
295 | system="",
296 | sep=[],
297 | efficient_eos=True
298 | )
299 |
300 |
301 | register_template(
302 | name="baichuan2",
303 | prefix=[
304 | "{{system}}"
305 | ],
306 | prompt=[
307 | {"token": ""}, # user token
308 | "{{query}}",
309 | {"token": ""} # assistant token
310 | ],
311 | system="",
312 | sep=[],
313 | efficient_eos=True
314 | )
315 |
316 |
317 | register_template(
318 | name="belle",
319 | prefix=[
320 | "{{system}}"
321 | ],
322 | prompt=[
323 | "Human: {{query}}\n\nBelle: "
324 | ],
325 | system="",
326 | sep=[
327 | "\n\n"
328 | ]
329 | )
330 |
331 |
332 | register_template(
333 | name="bluelm",
334 | prefix=[
335 | "{{system}}"
336 | ],
337 | prompt=[
338 | {"token": "[|Human|]:"},
339 | "{{query}}",
340 | {"token": "[|AI|]:"}
341 | ],
342 | system="",
343 | sep=[]
344 | )
345 |
346 |
347 | register_template(
348 | name="chatglm2",
349 | prefix=[
350 | {"token": "[gMASK]"},
351 | {"token": "sop"},
352 | "{{system}}"
353 | ],
354 | prompt=[
355 | "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
356 | ],
357 | system="",
358 | sep=[
359 | "\n\n"
360 | ],
361 | efficient_eos=True
362 | )
363 |
364 |
365 | register_template(
366 | name="chatglm3",
367 | prefix=[
368 | {"token": "[gMASK]"},
369 | {"token": "sop"},
370 | {"token": "<|system|>"},
371 | "\n",
372 | "{{system}}"
373 | ],
374 | prompt=[
375 | {"token": "<|user|>"},
376 | "\n",
377 | "{{query}}",
378 | {"token": "<|assistant|>"},
379 | "\n" # add an extra newline to avoid error in ChatGLM's process_response method
380 | ],
381 | system=(
382 | "You are ChatGLM3, a large language model trained by Zhipu.AI. "
383 | "Follow the user's instructions carefully. Respond using markdown."
384 | ),
385 | sep=[],
386 | stop_words=[
387 | "<|user|>",
388 | "<|observation|>"
389 | ],
390 | efficient_eos=True
391 | )
392 |
393 |
394 | register_template(
395 | name="chatglm3_raw", # the raw template for tool tuning
396 | prefix=[
397 | {"token": "[gMASK]"},
398 | {"token": "sop"},
399 | {"token": "<|system|>"},
400 | "\n",
401 | "{{system}}"
402 | ],
403 | prompt=[
404 | {"token": "<|user|>"},
405 | "\n",
406 | "{{query}}",
407 | {"token": "<|assistant|>"}
408 | ],
409 | system=(
410 | "You are ChatGLM3, a large language model trained by Zhipu.AI. "
411 | "Follow the user's instructions carefully. Respond using markdown."
412 | ),
413 | sep=[],
414 | stop_words=[
415 | "<|user|>",
416 | "<|observation|>"
417 | ],
418 | efficient_eos=True
419 | )
420 |
421 |
422 | register_template(
423 | name="codegeex2",
424 | prefix=[
425 | {"token": "[gMASK]"},
426 | {"token": "sop"},
427 | "{{system}}"
428 | ],
429 | prompt=[
430 | "{{query}}"
431 | ],
432 | system="",
433 | sep=[]
434 | )
435 |
436 |
437 | register_template(
438 | name="deepseek",
439 | prefix=[
440 | "{{system}}"
441 | ],
442 | prompt=[
443 | "User: {{query}}\n\nAssistant:"
444 | ],
445 | system="",
446 | sep=[]
447 | )
448 |
449 |
450 | register_template(
451 | name="deepseekcoder",
452 | prefix=[
453 | "{{system}}"
454 | ],
455 | prompt=[
456 | "### Instruction:\n{{query}}\n### Response:\n"
457 | ],
458 | system=(
459 | "You are an AI programming assistant, utilizing the Deepseek Coder model, "
460 | "developed by Deepseek Company, and you only answer questions related to computer science. "
461 | "For politically sensitive questions, security and privacy issues, "
462 | "and other non-computer science questions, you will refuse to answer\n"
463 | ),
464 | sep=[
465 | "\n",
466 | {"token": "<|EOT|>"},
467 | "\n"
468 | ],
469 | stop_words=[
470 | "<|EOT|>"
471 | ],
472 | efficient_eos=True
473 | )
474 |
475 |
476 | register_template(
477 | name="default",
478 | prefix=[
479 | "{{system}}"
480 | ],
481 | prompt=[
482 | "Human: {{query}}\nAssistant:"
483 | ],
484 | system=(
485 | "A chat between a curious user and an artificial intelligence assistant. "
486 | "The assistant gives helpful, detailed, and polite answers to the user's questions."
487 | ),
488 | sep=[
489 | "\n"
490 | ]
491 | )
492 |
493 |
494 | register_template(
495 | name="falcon",
496 | prefix=[
497 | "{{system}}"
498 | ],
499 | prompt=[
500 | "User: {{query}}\nFalcon:"
501 | ],
502 | system="",
503 | sep=[
504 | "\n"
505 | ],
506 | efficient_eos=True
507 | )
508 |
509 |
510 | register_template(
511 | name="intern",
512 | prefix=[
513 | "{{system}}"
514 | ],
515 | prompt=[
516 | "<|User|>:{{query}}",
517 | {"token": ""},
518 | "\n<|Bot|>:"
519 | ],
520 | system="",
521 | sep=[
522 | {"token": ""},
523 | "\n"
524 | ],
525 | stop_words=[
526 | ""
527 | ],
528 | efficient_eos=True
529 | )
530 |
531 |
532 | register_template(
533 | name="llama2",
534 | prefix=[
535 | "<>\n{{system}}\n<>\n\n"
536 | ],
537 | prompt=[
538 | "[INST] {{query}} [/INST]"
539 | ],
540 | system=(
541 | "You are a helpful, respectful and honest assistant. "
542 | "Always answer as helpfully as possible, while being safe. "
543 | "Your answers should not include any harmful, unethical, "
544 | "racist, sexist, toxic, dangerous, or illegal content. "
545 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
546 | "If a question does not make any sense, or is not factually coherent, "
547 | "explain why instead of answering something not correct. "
548 | "If you don't know the answer to a question, please don't share false information."
549 | ),
550 | sep=[]
551 | )
552 |
553 |
554 | register_template(
555 | name="llama2_zh",
556 | prefix=[
557 | "<>\n{{system}}\n<>\n\n"
558 | ],
559 | prompt=[
560 | "[INST] {{query}} [/INST]"
561 | ],
562 | system="You are a helpful assistant. 你是一个乐于助人的助手。",
563 | sep=[]
564 | )
565 |
566 |
567 | register_template(
568 | name="mistral",
569 | prefix=[
570 | "{{system}}"
571 | ],
572 | prompt=[
573 | "[INST] {{query}} [/INST]"
574 | ],
575 | system="",
576 | sep=[]
577 | )
578 |
579 |
580 | register_template(
581 | name="openchat",
582 | prefix=[
583 | "{{system}}"
584 | ],
585 | prompt=[
586 | "GPT4 Correct User: {{query}}",
587 | {"token": "<|end_of_turn|>"},
588 | "GPT4 Correct Assistant:"
589 | ],
590 | system="",
591 | sep=[
592 | {"token": "<|end_of_turn|>"}
593 | ],
594 | stop_words=[
595 | "<|end_of_turn|>"
596 | ],
597 | efficient_eos=True
598 | )
599 |
600 |
601 | register_template(
602 | name="qwen",
603 | prefix=[
604 | "<|im_start|>system\n{{system}}<|im_end|>"
605 | ],
606 | prompt=[
607 | "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
608 | ],
609 | system="You are a helpful assistant.",
610 | sep=[
611 | "\n"
612 | ],
613 | stop_words=[
614 | "<|im_end|>"
615 | ],
616 | replace_eos=True
617 | )
618 |
619 |
620 | register_template(
621 | name="starchat",
622 | prefix=[
623 | {"token": "<|system|>"},
624 | "\n{{system}}",
625 | ],
626 | prompt=[
627 | {"token": "<|user|>"},
628 | "\n{{query}}",
629 | {"token": "<|end|>"},
630 | "\n",
631 | {"token": "<|assistant|>"}
632 | ],
633 | system="",
634 | sep=[
635 | {"token": "<|end|>"},
636 | "\n"
637 | ],
638 | stop_words=[
639 | "<|end|>"
640 | ],
641 | efficient_eos=True
642 | )
643 |
644 |
645 | register_template(
646 | name="vanilla",
647 | prefix=[],
648 | prompt=[
649 | "{{query}}"
650 | ],
651 | system="",
652 | sep=[],
653 | use_history=False
654 | )
655 |
656 |
657 | register_template(
658 | name="vicuna",
659 | prefix=[
660 | "{{system}}"
661 | ],
662 | prompt=[
663 | "USER: {{query}} ASSISTANT:"
664 | ],
665 | system=(
666 | "A chat between a curious user and an artificial intelligence assistant. "
667 | "The assistant gives helpful, detailed, and polite answers to the user's questions."
668 | ),
669 | sep=[]
670 | )
671 |
672 |
673 | register_template(
674 | name="xuanyuan",
675 | prefix=[
676 | "{{system}}"
677 | ],
678 | prompt=[
679 | "Human: {{query}} Assistant:"
680 | ],
681 | system=(
682 | "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
683 | "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
684 | "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
685 | ),
686 | sep=[]
687 | )
688 |
689 |
690 | register_template(
691 | name="xverse",
692 | prefix=[
693 | "{{system}}"
694 | ],
695 | prompt=[
696 | "Human: {{query}}\n\nAssistant: "
697 | ],
698 | system="",
699 | sep=[]
700 | )
701 |
702 |
703 | register_template(
704 | name="yayi",
705 | prefix=[
706 | {"token": "<|System|>"},
707 | ":\n{{system}}"
708 | ],
709 | prompt=[
710 | {"token": "<|Human|>"},
711 | ":\n{{query}}\n\n",
712 | {"token": "<|YaYi|>"},
713 | ":"
714 | ],
715 | system=(
716 | "You are a helpful, respectful and honest assistant named YaYi "
717 | "developed by Beijing Wenge Technology Co.,Ltd. "
718 | "Always answer as helpfully as possible, while being safe. "
719 | "Your answers should not include any harmful, unethical, "
720 | "racist, sexist, toxic, dangerous, or illegal content. "
721 | "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
722 | "If a question does not make any sense, or is not factually coherent, "
723 | "explain why instead of answering something not correct. "
724 | "If you don't know the answer to a question, please don't share false information."
725 | ),
726 | sep=[
727 | "\n\n"
728 | ],
729 | stop_words=[
730 | "<|End|>"
731 | ]
732 | )
733 |
734 |
735 | register_template(
736 | name="yi",
737 | prefix=[
738 | "{{system}}"
739 | ],
740 | prompt=[
741 | "<|im_start|>user\n{{query}}<|im_end|>\n<|im_start|>assistant\n"
742 | ],
743 | system="",
744 | sep=[
745 | "\n"
746 | ],
747 | stop_words=[
748 | "<|im_end|>"
749 | ],
750 | replace_eos=True
751 | )
752 |
753 |
754 | register_template(
755 | name="yuan",
756 | prefix=[
757 | "{{system}}"
758 | ],
759 | prompt=[
760 | "{{query}}",
761 |         {"token": "<sep>"}
762 | ],
763 | system="",
764 | sep=[
765 | "\n"
766 | ],
767 | stop_words=[
768 |         "<eod>"
769 | ],
770 | replace_eos=True
771 | )
772 |
773 |
774 | register_template(
775 | name="zephyr",
776 | prefix=[
777 | {"token": "<|system|>"},
778 | "\n{{system}}",
779 |         {"token": "</s>"}
780 | ],
781 | prompt=[
782 | {"token": "<|user|>"},
783 | "\n{{query}}",
784 |         {"token": "</s>"},
785 | {"token": "<|assistant|>"}
786 | ],
787 | system="You are a friendly chatbot who always responds in the style of a pirate",
788 | sep=[]
789 | )
790 |
791 |
792 | register_template(
793 | name="ziya",
794 | prefix=[
795 | "{{system}}"
796 | ],
797 | prompt=[
798 |         {"token": "<human>"},
799 |         ":{{query}}\n",
800 |         {"token": "<bot>"},
801 |         ":"
802 | ],
803 | system="",
804 | sep=[
805 | "\n"
806 | ]
807 | )
808 |
--------------------------------------------------------------------------------
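
A quick sanity check on the registrations above: the sketch below is not the repo's own `Template` class (that implementation appears earlier in `template.py`); it only mimics how the `llama2` prefix/prompt fields expand into a single-turn prompt. The helper name `render_llama2` is purely illustrative.

```python
# Illustrative only: fills the "llama2" registration's prefix/prompt fields by hand;
# the real formatting logic lives in the Template class of template.py.
LLAMA2_PREFIX = "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
LLAMA2_PROMPT = "[INST] {{query}} [/INST]"


def render_llama2(system: str, query: str) -> str:
    # The system prefix is prepended to the first user turn before it is
    # wrapped in the [INST] ... [/INST] tags.
    prefix = LLAMA2_PREFIX.replace("{{system}}", system)
    return LLAMA2_PROMPT.replace("{{query}}", prefix + query)


print(render_llama2("You are a helpful assistant.", "List all tables in the schema."))
```
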
/model/src/data/utils.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | from typing import TYPE_CHECKING, Dict, List, Optional, Union
3 |
4 | from ..extras.logging import get_logger
5 |
6 | if TYPE_CHECKING:
7 | from datasets import Dataset, IterableDataset
8 | from transformers import TrainingArguments
9 | from ..hparams.data_args import DataArguments
10 |
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
16 | if file_sha1 is None:
17 | logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
18 | return
19 |
20 | if len(data_files) != 1:
21 | logger.warning("Checksum failed: too many files.")
22 | return
23 |
24 | with open(data_files[0], "rb") as f:
25 | sha1 = hashlib.sha1(f.read()).hexdigest()
26 | if sha1 != file_sha1:
27 | logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
28 |
29 |
30 | def split_dataset(
31 | dataset: Union["Dataset", "IterableDataset"],
32 | data_args: "DataArguments",
33 | training_args: "TrainingArguments"
34 | ) -> Dict[str, "Dataset"]:
35 | if training_args.do_train:
36 | if data_args.val_size > 1e-6: # Split the dataset
37 | if data_args.streaming:
38 | val_set = dataset.take(int(data_args.val_size))
39 | train_set = dataset.skip(int(data_args.val_size))
40 |                 train_set = train_set.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
41 | return {"train_dataset": train_set, "eval_dataset": val_set}
42 | else:
43 | val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
44 | dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
45 | return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
46 | else:
47 | if data_args.streaming:
48 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
49 | return {"train_dataset": dataset}
50 | else: # do_eval or do_predict
51 | return {"eval_dataset": dataset}
52 |
--------------------------------------------------------------------------------
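
A minimal usage sketch for `split_dataset`; the `DataArguments`/`TrainingArguments` objects below are simplified `SimpleNamespace` stand-ins rather than the repo's real dataclasses, and the import path is illustrative.

```python
from types import SimpleNamespace

from datasets import Dataset

from model.src.data.utils import split_dataset  # illustrative import path

dataset = Dataset.from_dict({"text": [f"example {i}" for i in range(100)]})
data_args = SimpleNamespace(val_size=0.1, streaming=False, buffer_size=16384)
training_args = SimpleNamespace(do_train=True, seed=42)

splits = split_dataset(dataset, data_args, training_args)
print(len(splits["train_dataset"]), len(splits["eval_dataset"]))  # 90 10
```
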
/model/src/extras/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import time
4 | from typing import TYPE_CHECKING
5 | from datetime import timedelta
6 | from transformers import TrainerCallback
7 | from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
8 |
9 | from .constants import LOG_FILE_NAME
10 | from .logging import get_logger
11 | from .misc import fix_valuehead_checkpoint
12 |
13 |
14 | if TYPE_CHECKING:
15 | from transformers import TrainingArguments, TrainerState, TrainerControl
16 |
17 |
18 | logger = get_logger(__name__)
19 |
20 | class FixValueHeadModelCallback(TrainerCallback):
21 |
22 | def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
23 | r"""
24 | Event called after a checkpoint save.
25 | """
26 | if args.should_save:
27 | fix_valuehead_checkpoint(
28 | model=kwargs.pop("model"),
29 | output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
30 | safe_serialization=args.save_safetensors
31 | )
32 |
33 | class LogCallback(TrainerCallback):
34 |
35 | def __init__(self, runner=None):
36 | self.runner = runner
37 | self.in_training = False
38 | self.start_time = time.time()
39 | self.cur_steps = 0
40 | self.max_steps = 0
41 | self.elapsed_time = ""
42 | self.remaining_time = ""
43 |
44 | def timing(self):
45 | cur_time = time.time()
46 | elapsed_time = cur_time - self.start_time
47 | avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
48 | remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
49 | self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
50 | self.remaining_time = str(timedelta(seconds=int(remaining_time)))
51 |
52 | def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
53 | r"""
54 | Event called at the beginning of training.
55 | """
56 | if state.is_local_process_zero:
57 | self.in_training = True
58 | self.start_time = time.time()
59 | self.max_steps = state.max_steps
60 | if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
61 | logger.warning("Previous log file in this folder will be deleted.")
62 | os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
63 |
64 | def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
65 | r"""
66 | Event called at the end of training.
67 | """
68 | if state.is_local_process_zero:
69 | self.in_training = False
70 | self.cur_steps = 0
71 | self.max_steps = 0
72 |
73 | def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
74 | r"""
75 |         Event called at the end of a substep during gradient accumulation.
76 | """
77 | if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
78 | control.should_epoch_stop = True
79 | control.should_training_stop = True
80 |
81 | def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
82 | r"""
83 | Event called at the end of a training step.
84 | """
85 | if state.is_local_process_zero:
86 | self.cur_steps = state.global_step
87 | self.timing()
88 | if self.runner is not None and self.runner.aborted:
89 | control.should_epoch_stop = True
90 | control.should_training_stop = True
91 |
92 | def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
93 | r"""
94 | Event called after an evaluation phase.
95 | """
96 | if state.is_local_process_zero and not self.in_training:
97 | self.cur_steps = 0
98 | self.max_steps = 0
99 |
100 | def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
101 | r"""
102 | Event called after a successful prediction.
103 | """
104 | if state.is_local_process_zero and not self.in_training:
105 | self.cur_steps = 0
106 | self.max_steps = 0
107 |
108 | def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
109 | r"""
110 | Event called after logging the last logs.
111 | """
112 | if not state.is_local_process_zero:
113 | return
114 |
115 | logs = dict(
116 | current_steps=self.cur_steps,
117 | total_steps=self.max_steps,
118 | loss=state.log_history[-1].get("loss", None),
119 | eval_loss=state.log_history[-1].get("eval_loss", None),
120 | predict_loss=state.log_history[-1].get("predict_loss", None),
121 | reward=state.log_history[-1].get("reward", None),
122 | learning_rate=state.log_history[-1].get("learning_rate", None),
123 | epoch=state.log_history[-1].get("epoch", None),
124 | percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
125 | elapsed_time=self.elapsed_time,
126 | remaining_time=self.remaining_time
127 | )
128 | if self.runner is not None:
129 | logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
130 | logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
131 | ))
132 |
133 | os.makedirs(args.output_dir, exist_ok=True)
134 | with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
135 | f.write(json.dumps(logs) + "\n")
136 |
137 | def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
138 | r"""
139 | Event called after a prediction step.
140 | """
141 | eval_dataloader = kwargs.pop("eval_dataloader", None)
142 | if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
143 | if self.max_steps == 0:
144 | self.max_steps = len(eval_dataloader)
145 | self.cur_steps += 1
146 | self.timing()
147 |
--------------------------------------------------------------------------------
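
The snippet below is a hedged sketch that exercises `LogCallback.on_log()` in isolation to show the JSONL record appended to `trainer_log.jsonl`; in a real run the callback is simply passed to a `transformers.Trainer` via `callbacks=[LogCallback()]`. The import path and output directory are illustrative.

```python
from transformers import TrainerControl, TrainerState, TrainingArguments

from model.src.extras.callbacks import LogCallback  # illustrative import path

args = TrainingArguments(output_dir="outputs")
state = TrainerState()
state.log_history = [{"loss": 1.23, "learning_rate": 5e-5, "epoch": 0.5}]

callback = LogCallback()
callback.max_steps, callback.cur_steps = 100, 10

# Appends one JSON line with loss, learning rate, epoch, progress and timing fields.
callback.on_log(args, state, TrainerControl())
```
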
/model/src/extras/constants.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from collections import defaultdict, OrderedDict
3 | from typing import Dict, Optional
4 |
5 |
6 | CHOICES = ["A", "B", "C", "D"]
7 |
8 | DEFAULT_MODULE = defaultdict(str)
9 |
10 | DEFAULT_TEMPLATE = defaultdict(str)
11 |
12 | FILEEXT2TYPE = {
13 | "arrow": "arrow",
14 | "csv": "csv",
15 | "json": "json",
16 | "jsonl": "json",
17 | "parquet": "parquet",
18 | "txt": "text"
19 | }
20 |
21 | IGNORE_INDEX = -100
22 |
23 | LAYERNORM_NAMES = {"norm", "ln"}
24 |
25 | LOG_FILE_NAME = "trainer_log.jsonl"
26 |
27 | METHODS = ["full", "freeze", "lora"]
28 |
29 | PEFT_METHODS = ["lora"]
30 |
31 | SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
32 |
33 | SUPPORTED_MODELS = OrderedDict()
34 |
35 | TRAINING_STAGES = {
36 | "Supervised Fine-Tuning": "sft",
37 | "Reward Modeling": "rm",
38 | "PPO": "ppo",
39 | "DPO": "dpo",
40 | "Pre-Training": "pt"
41 | }
42 |
43 | V_HEAD_WEIGHTS_NAME = "value_head.bin"
44 |
45 | V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
46 |
47 | class DownloadSource(str, Enum):
48 | DEFAULT = "hf"
49 | MODELSCOPE = "ms"
50 |
51 |
52 | def register_model_group(
53 | models: Dict[str, Dict[DownloadSource, str]],
54 | module: Optional[str] = None,
55 | template: Optional[str] = None
56 | ) -> None:
57 | prefix = None
58 | for name, path in models.items():
59 | if prefix is None:
60 | prefix = name.split("-")[0]
61 | else:
62 | assert prefix == name.split("-")[0], "prefix should be identical."
63 | SUPPORTED_MODELS[name] = path
64 | if module is not None:
65 | DEFAULT_MODULE[prefix] = module
66 | if template is not None:
67 | DEFAULT_TEMPLATE[prefix] = template
68 |
69 |
70 | register_model_group(
71 | models={
72 | "Baichuan-7B-Base": {
73 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
74 | DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B"
75 | },
76 | "Baichuan-13B-Base": {
77 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
78 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base"
79 | },
80 | "Baichuan-13B-Chat": {
81 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
82 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat"
83 | }
84 | },
85 | module="W_pack",
86 | template="baichuan"
87 | )
88 |
89 |
90 | register_model_group(
91 | models={
92 | "Baichuan2-7B-Base": {
93 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
94 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base"
95 | },
96 | "Baichuan2-13B-Base": {
97 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
98 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base"
99 | },
100 | "Baichuan2-7B-Chat": {
101 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
102 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat"
103 | },
104 | "Baichuan2-13B-Chat": {
105 | DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
106 | DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat"
107 | }
108 | },
109 | module="W_pack",
110 | template="baichuan2"
111 | )
112 |
113 |
114 | register_model_group(
115 | models={
116 | "BLOOM-560M": {
117 | DownloadSource.DEFAULT: "bigscience/bloom-560m",
118 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m"
119 | },
120 | "BLOOM-3B": {
121 | DownloadSource.DEFAULT: "bigscience/bloom-3b",
122 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b"
123 | },
124 | "BLOOM-7B1": {
125 | DownloadSource.DEFAULT: "bigscience/bloom-7b1",
126 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1"
127 | }
128 | },
129 | module="query_key_value"
130 | )
131 |
132 |
133 | register_model_group(
134 | models={
135 | "BLOOMZ-560M": {
136 | DownloadSource.DEFAULT: "bigscience/bloomz-560m",
137 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m"
138 | },
139 | "BLOOMZ-3B": {
140 | DownloadSource.DEFAULT: "bigscience/bloomz-3b",
141 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b"
142 | },
143 | "BLOOMZ-7B1-mt": {
144 | DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
145 | DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt"
146 | }
147 | },
148 | module="query_key_value"
149 | )
150 |
151 |
152 | register_model_group(
153 | models={
154 | "BlueLM-7B-Base": {
155 | DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
156 | DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base"
157 | },
158 | "BlueLM-7B-Chat": {
159 | DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
160 | DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat"
161 | }
162 | },
163 | template="bluelm"
164 | )
165 |
166 |
167 | register_model_group(
168 | models={
169 | "ChatGLM2-6B-Chat": {
170 | DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
171 | DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b"
172 | }
173 | },
174 | module="query_key_value",
175 | template="chatglm2"
176 | )
177 |
178 |
179 | register_model_group(
180 | models={
181 | "ChatGLM3-6B-Base": {
182 | DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
183 | DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base"
184 | },
185 | "ChatGLM3-6B-Chat": {
186 | DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
187 | DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b"
188 | }
189 | },
190 | module="query_key_value",
191 | template="chatglm3"
192 | )
193 |
194 |
195 | register_model_group(
196 | models={
197 | "ChineseLLaMA2-1.3B": {
198 | DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
199 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b"
200 | },
201 | "ChineseLLaMA2-7B": {
202 | DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
203 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b"
204 | },
205 | "ChineseLLaMA2-13B": {
206 | DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
207 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b"
208 | },
209 | "ChineseLLaMA2-1.3B-Chat": {
210 | DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
211 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b"
212 | },
213 | "ChineseLLaMA2-7B-Chat": {
214 | DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
215 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b"
216 | },
217 | "ChineseLLaMA2-13B-Chat": {
218 | DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
219 | DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b"
220 | }
221 | },
222 | template="llama2_zh"
223 | )
224 |
225 |
226 | register_model_group(
227 | models={
228 | "DeepseekLLM-7B-Base": {
229 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
230 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base"
231 | },
232 | "DeepseekLLM-67B-Base": {
233 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
234 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base"
235 | },
236 | "DeepseekLLM-7B-Chat": {
237 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
238 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat"
239 | },
240 | "DeepseekLLM-67B-Chat": {
241 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
242 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat"
243 | }
244 | },
245 | template="deepseek"
246 | )
247 |
248 |
249 | register_model_group(
250 | models={
251 | "DeepseekCoder-6.7B-Base": {
252 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
253 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base"
254 | },
255 | "DeepseekCoder-33B-Base": {
256 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
257 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base"
258 | },
259 | "DeepseekCoder-6.7B-Chat": {
260 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
261 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct"
262 | },
263 | "DeepseekCoder-33B-Chat": {
264 | DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
265 | DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct"
266 | }
267 | },
268 | template="deepseekcoder"
269 | )
270 |
271 |
272 | register_model_group(
273 | models={
274 | "Falcon-7B": {
275 | DownloadSource.DEFAULT: "tiiuae/falcon-7b",
276 | DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b"
277 | },
278 | "Falcon-40B": {
279 | DownloadSource.DEFAULT: "tiiuae/falcon-40b",
280 | DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b"
281 | },
282 | "Falcon-180B": {
283 | DownloadSource.DEFAULT: "tiiuae/falcon-180b",
284 | DownloadSource.MODELSCOPE: "modelscope/falcon-180B"
285 | },
286 | "Falcon-7B-Chat": {
287 | DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
288 | DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct"
289 | },
290 | "Falcon-40B-Chat": {
291 | DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
292 | DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct"
293 | },
294 | "Falcon-180B-Chat": {
295 | DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
296 | DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat"
297 | }
298 | },
299 | module="query_key_value",
300 | template="falcon"
301 | )
302 |
303 |
304 | register_model_group(
305 | models={
306 | "InternLM-7B": {
307 | DownloadSource.DEFAULT: "internlm/internlm-7b",
308 | DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b"
309 | },
310 | "InternLM-20B": {
311 | DownloadSource.DEFAULT: "internlm/internlm-20b",
312 | DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b"
313 | },
314 | "InternLM-7B-Chat": {
315 | DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
316 | DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b"
317 | },
318 | "InternLM-20B-Chat": {
319 | DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
320 | DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b"
321 | }
322 | },
323 | template="intern"
324 | )
325 |
326 |
327 | register_model_group(
328 | models={
329 | "LingoWhale-8B": {
330 | DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
331 | DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B"
332 | }
333 | },
334 | module="qkv_proj"
335 | )
336 |
337 |
338 | register_model_group(
339 | models={
340 | "LLaMA-7B": {
341 | DownloadSource.DEFAULT: "huggyllama/llama-7b",
342 | DownloadSource.MODELSCOPE: "skyline2006/llama-7b"
343 | },
344 | "LLaMA-13B": {
345 | DownloadSource.DEFAULT: "huggyllama/llama-13b",
346 | DownloadSource.MODELSCOPE: "skyline2006/llama-13b"
347 | },
348 | "LLaMA-30B": {
349 | DownloadSource.DEFAULT: "huggyllama/llama-30b",
350 | DownloadSource.MODELSCOPE: "skyline2006/llama-30b"
351 | },
352 | "LLaMA-65B": {
353 | DownloadSource.DEFAULT: "huggyllama/llama-65b",
354 | DownloadSource.MODELSCOPE: "skyline2006/llama-65b"
355 | }
356 | }
357 | )
358 |
359 |
360 | register_model_group(
361 | models={
362 | "LLaMA2-7B": {
363 | DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
364 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms"
365 | },
366 | "LLaMA2-13B": {
367 | DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
368 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms"
369 | },
370 | "LLaMA2-70B": {
371 | DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
372 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms"
373 | },
374 | "LLaMA2-7B-Chat": {
375 | DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
376 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms"
377 | },
378 | "LLaMA2-13B-Chat": {
379 | DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
380 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms"
381 | },
382 | "LLaMA2-70B-Chat": {
383 | DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
384 | DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms"
385 | }
386 | },
387 | template="llama2"
388 | )
389 |
390 |
391 | register_model_group(
392 | models={
393 | "Mistral-7B": {
394 | DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
395 | DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1"
396 | },
397 | "Mistral-7B-Chat": {
398 | DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
399 | DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1"
400 | },
401 | "Mistral-7B-v0.2-Chat": {
402 | DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
403 | DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2"
404 | }
405 | },
406 | template="mistral"
407 | )
408 |
409 |
410 | register_model_group(
411 | models={
412 | "Mixtral-8x7B": {
413 | DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
414 | DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1"
415 | },
416 | "Mixtral-8x7B-Chat": {
417 | DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
418 | DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1"
419 | }
420 | },
421 | template="mistral"
422 | )
423 |
424 |
425 | register_model_group(
426 | models={
427 | "OpenChat3.5-7B-Chat": {
428 | DownloadSource.DEFAULT: "openchat/openchat_3.5",
429 | DownloadSource.MODELSCOPE: "myxiongmodel/openchat_3.5"
430 | }
431 | },
432 | template="openchat"
433 | )
434 |
435 |
436 | register_model_group(
437 | models={
438 | "Phi-1.5-1.3B": {
439 | DownloadSource.DEFAULT: "microsoft/phi-1_5",
440 | DownloadSource.MODELSCOPE: "allspace/PHI_1-5"
441 | },
442 | "Phi-2-2.7B": {
443 | DownloadSource.DEFAULT: "microsoft/phi-2",
444 | DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2"
445 | }
446 | },
447 | module="Wqkv"
448 | )
449 |
450 |
451 | register_model_group(
452 | models={
453 | "Qwen-1.8B": {
454 | DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
455 | DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B"
456 | },
457 | "Qwen-7B": {
458 | DownloadSource.DEFAULT: "Qwen/Qwen-7B",
459 | DownloadSource.MODELSCOPE: "qwen/Qwen-7B"
460 | },
461 | "Qwen-14B": {
462 | DownloadSource.DEFAULT: "Qwen/Qwen-14B",
463 | DownloadSource.MODELSCOPE: "qwen/Qwen-14B"
464 | },
465 | "Qwen-72B": {
466 | DownloadSource.DEFAULT: "Qwen/Qwen-72B",
467 | DownloadSource.MODELSCOPE: "qwen/Qwen-72B"
468 | },
469 | "Qwen-1.8B-Chat": {
470 | DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
471 | DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat"
472 | },
473 | "Qwen-7B-Chat": {
474 | DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
475 | DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat"
476 | },
477 | "Qwen-14B-Chat": {
478 | DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
479 | DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat"
480 | },
481 | "Qwen-72B-Chat": {
482 | DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
483 | DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat"
484 | },
485 | "Qwen-1.8B-int8-Chat": {
486 | DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
487 | DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8"
488 | },
489 | "Qwen-1.8B-int4-Chat": {
490 | DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
491 | DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4"
492 | },
493 | "Qwen-7B-int8-Chat": {
494 | DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
495 | DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8"
496 | },
497 | "Qwen-7B-int4-Chat": {
498 | DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
499 | DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4"
500 | },
501 | "Qwen-14B-int8-Chat": {
502 | DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
503 | DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8"
504 | },
505 | "Qwen-14B-int4-Chat": {
506 | DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
507 | DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4"
508 | },
509 | "Qwen-72B-int8-Chat": {
510 | DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
511 | DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8"
512 | },
513 | "Qwen-72B-int4-Chat": {
514 | DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
515 | DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4"
516 | }
517 | },
518 | module="c_attn",
519 | template="qwen"
520 | )
521 |
522 |
523 | register_model_group(
524 | models={
525 | "Skywork-13B-Base": {
526 | DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
527 | DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base"
528 | }
529 | }
530 | )
531 |
532 |
533 | register_model_group(
534 | models={
535 | "Vicuna1.5-7B-Chat": {
536 | DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
537 | DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5"
538 | },
539 | "Vicuna1.5-13B-Chat": {
540 | DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
541 | DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5"
542 | }
543 | },
544 | template="vicuna"
545 | )
546 |
547 |
548 | register_model_group(
549 | models={
550 | "XuanYuan-70B": {
551 | DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B"
552 | },
553 | "XuanYuan-70B-Chat": {
554 | DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat"
555 | },
556 | "XuanYuan-70B-int8-Chat": {
557 | DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit"
558 | },
559 | "XuanYuan-70B-int4-Chat": {
560 | DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit"
561 | }
562 | },
563 | template="xuanyuan"
564 | )
565 |
566 |
567 | register_model_group(
568 | models={
569 | "XVERSE-7B": {
570 | DownloadSource.DEFAULT: "xverse/XVERSE-7B",
571 | DownloadSource.MODELSCOPE: "xverse/XVERSE-7B"
572 | },
573 | "XVERSE-13B": {
574 | DownloadSource.DEFAULT: "xverse/XVERSE-13B",
575 | DownloadSource.MODELSCOPE: "xverse/XVERSE-13B"
576 | },
577 | "XVERSE-65B": {
578 | DownloadSource.DEFAULT: "xverse/XVERSE-65B",
579 | DownloadSource.MODELSCOPE: "xverse/XVERSE-65B"
580 | },
581 | "XVERSE-65B-2": {
582 | DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
583 | DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2"
584 | },
585 | "XVERSE-7B-Chat": {
586 | DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
587 | DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat"
588 | },
589 | "XVERSE-13B-Chat": {
590 | DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
591 | DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat"
592 | },
593 | "XVERSE-65B-Chat": {
594 | DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
595 | DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat"
596 | }
597 | },
598 | template="xverse"
599 | )
600 |
601 |
602 | register_model_group(
603 | models={
604 | "Yayi-7B": {
605 | DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
606 | DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2"
607 | },
608 | "Yayi-13B": {
609 | DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
610 | DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2"
611 | }
612 | },
613 | template="yayi"
614 | )
615 |
616 |
617 | register_model_group(
618 | models={
619 | "Yi-6B": {
620 | DownloadSource.DEFAULT: "01-ai/Yi-6B",
621 | DownloadSource.MODELSCOPE: "01ai/Yi-6B"
622 | },
623 | "Yi-34B": {
624 | DownloadSource.DEFAULT: "01-ai/Yi-34B",
625 | DownloadSource.MODELSCOPE: "01ai/Yi-34B"
626 | },
627 | "Yi-6B-Chat": {
628 | DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
629 | DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat"
630 | },
631 | "Yi-34B-Chat": {
632 | DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
633 | DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat"
634 | },
635 | "Yi-6B-int8-Chat": {
636 | DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
637 | DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits"
638 | },
639 | "Yi-34B-int8-Chat": {
640 | DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
641 | DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits"
642 | }
643 | },
644 | template="yi"
645 | )
646 |
647 |
648 | register_model_group(
649 | models={
650 | "Yuan2-2B-Chat": {
651 | DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
652 | DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf"
653 | },
654 | "Yuan2-51B-Chat": {
655 | DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
656 | DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf"
657 | },
658 | "Yuan2-102B-Chat": {
659 | DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
660 | DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf"
661 | }
662 | },
663 | template="yuan"
664 | )
665 |
666 |
667 | register_model_group(
668 | models={
669 | "Zephyr-7B-Alpha-Chat": {
670 | DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
671 | DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha"
672 | },
673 | "Zephyr-7B-Beta-Chat": {
674 | DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
675 | DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta"
676 | }
677 | },
678 | template="zephyr"
679 | )
680 |
--------------------------------------------------------------------------------
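
A short sketch of how the registries filled by `register_model_group` are looked up when resolving a model name; the import path is illustrative.

```python
from model.src.extras.constants import (  # illustrative import path
    DEFAULT_MODULE,
    DEFAULT_TEMPLATE,
    SUPPORTED_MODELS,
    DownloadSource,
)

name = "LLaMA2-7B-Chat"
prefix = name.split("-")[0]                                # "LLaMA2", the registry key

hub_path = SUPPORTED_MODELS[name][DownloadSource.DEFAULT]  # "meta-llama/Llama-2-7b-chat-hf"
template = DEFAULT_TEMPLATE[prefix]                        # "llama2"
module = DEFAULT_MODULE[prefix]                            # "" here: no module override registered

print(hub_path, template, module)
```
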
/model/src/extras/logging.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 |
4 |
5 | class LoggerHandler(logging.Handler):
6 | r"""
7 | Logger handler used in Web UI.
8 | """
9 |
10 | def __init__(self):
11 | super().__init__()
12 | self.log = ""
13 |
14 | def reset(self):
15 | self.log = ""
16 |
17 | def emit(self, record):
18 | if record.name == "httpx":
19 | return
20 | log_entry = self.format(record)
21 | self.log += log_entry
22 | self.log += "\n\n"
23 |
24 |
25 | def get_logger(name: str) -> logging.Logger:
26 | r"""
27 |     Gets a standard logger with a stream handler to stdout.
28 | """
29 | formatter = logging.Formatter(
30 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
31 | datefmt="%m/%d/%Y %H:%M:%S"
32 | )
33 | handler = logging.StreamHandler(sys.stdout)
34 | handler.setFormatter(formatter)
35 |
36 | logger = logging.getLogger(name)
37 | logger.setLevel(logging.INFO)
38 | logger.addHandler(handler)
39 |
40 | return logger
41 |
42 |
43 | def reset_logging() -> None:
44 | r"""
45 | Removes basic config of root logger. (unused in script)
46 | """
47 | root = logging.getLogger()
48 | list(map(root.removeHandler, root.handlers))
49 | list(map(root.removeFilter, root.filters))
50 |
--------------------------------------------------------------------------------
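
Typical usage of this module; `LoggerHandler` is only needed when the log stream should also be buffered for a Web UI. The import path is illustrative.

```python
from model.src.extras.logging import LoggerHandler, get_logger  # illustrative import path

logger = get_logger(__name__)

ui_handler = LoggerHandler()     # accumulates formatted records in ui_handler.log
logger.addHandler(ui_handler)

logger.info("Loading DELLM checkpoint ...")
print(ui_handler.log)
```
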
/model/src/extras/misc.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 | import torch
4 | from typing import TYPE_CHECKING, Dict, Tuple
5 | from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList, PreTrainedModel
6 | from transformers.utils import (
7 | WEIGHTS_NAME,
8 | SAFE_WEIGHTS_NAME,
9 | is_torch_bf16_gpu_available,
10 | is_torch_cuda_available,
11 | is_torch_npu_available,
12 | is_torch_xpu_available
13 | )
14 | from peft import PeftModel
15 |
16 | from .constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
17 | from .logging import get_logger
18 |
19 |
20 | _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
21 | try:
22 | _is_bf16_available = is_torch_bf16_gpu_available()
23 | except Exception:
24 | _is_bf16_available = False
25 |
26 |
27 | if TYPE_CHECKING:
28 | from trl import AutoModelForCausalLMWithValueHead
29 | from ..hparams.model_args import ModelArguments
30 |
31 |
32 | logger = get_logger(__name__)
33 |
34 |
35 | class AverageMeter:
36 | r"""
37 | Computes and stores the average and current value.
38 | """
39 | def __init__(self):
40 | self.reset()
41 |
42 | def reset(self):
43 | self.val = 0
44 | self.avg = 0
45 | self.sum = 0
46 | self.count = 0
47 |
48 | def update(self, val, n=1):
49 | self.val = val
50 | self.sum += val * n
51 | self.count += n
52 | self.avg = self.sum / self.count
53 |
54 |
55 | def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
56 | r"""
57 | Returns the number of trainable parameters and number of all parameters in the model.
58 | """
59 | trainable_params, all_param = 0, 0
60 | for param in model.parameters():
61 | num_params = param.numel()
62 | # if using DS Zero 3 and the weights are initialized empty
63 | if num_params == 0 and hasattr(param, "ds_numel"):
64 | num_params = param.ds_numel
65 |
66 | # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2
67 | if param.__class__.__name__ == "Params4bit":
68 | num_params = num_params * 2
69 |
70 | all_param += num_params
71 | if param.requires_grad:
72 | trainable_params += num_params
73 |
74 | return trainable_params, all_param
75 |
76 |
77 | def fix_valuehead_checkpoint(
78 | model: "AutoModelForCausalLMWithValueHead",
79 | output_dir: str,
80 | safe_serialization: bool
81 | ) -> None:
82 | r"""
83 | The model is already unwrapped.
84 |
85 | There are three cases:
86 | 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
87 | 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
88 | 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}
89 |
90 | We assume `stage3_gather_16bit_weights_on_model_save=true`.
91 | """
92 | if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
93 | return
94 |
95 | if safe_serialization:
96 | from safetensors import safe_open
97 | from safetensors.torch import save_file
98 | path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
99 | with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
100 | state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
101 | else:
102 | path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
103 | state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")
104 |
105 | decoder_state_dict = {}
106 | v_head_state_dict = {}
107 | for name, param in state_dict.items():
108 | if name.startswith("v_head."):
109 | v_head_state_dict[name] = param
110 | else:
111 | decoder_state_dict[name.replace("pretrained_model.", "")] = param
112 |
113 | os.remove(path_to_checkpoint)
114 | model.pretrained_model.save_pretrained(
115 | output_dir,
116 | state_dict=decoder_state_dict or None,
117 | safe_serialization=safe_serialization
118 | )
119 |
120 | if safe_serialization:
121 | save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
122 | else:
123 | torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))
124 |
125 | logger.info("Value head model saved at: {}".format(output_dir))
126 |
127 |
128 | def get_current_device() -> torch.device:
129 | r"""
130 | Gets the current available device.
131 | """
132 | if is_torch_xpu_available():
133 | device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
134 | elif is_torch_npu_available():
135 | device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
136 | elif is_torch_cuda_available():
137 | device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
138 | else:
139 | device = "cpu"
140 |
141 | return torch.device(device)
142 |
143 |
144 | def get_device_count() -> int:
145 | return torch.cuda.device_count()
146 |
147 |
148 | def get_logits_processor() -> "LogitsProcessorList":
149 | r"""
150 | Gets logits processor that removes NaN and Inf logits.
151 | """
152 | logits_processor = LogitsProcessorList()
153 | logits_processor.append(InfNanRemoveLogitsProcessor())
154 | return logits_processor
155 |
156 |
157 | def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
158 | r"""
159 | Infers the optimal dtype according to the model_dtype and device compatibility.
160 | """
161 | if _is_bf16_available and model_dtype == torch.bfloat16:
162 | return torch.bfloat16
163 | elif _is_fp16_available:
164 | return torch.float16
165 | else:
166 | return torch.float32
167 |
168 |
169 | def torch_gc() -> None:
170 | r"""
171 | Collects GPU memory.
172 | """
173 | gc.collect()
174 | if torch.cuda.is_available():
175 | torch.cuda.empty_cache()
176 | torch.cuda.ipc_collect()
177 |
178 |
179 | def try_download_model_from_ms(model_args: "ModelArguments") -> None:
180 | if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
181 | return
182 |
183 | try:
184 | from modelscope import snapshot_download
185 | revision = "master" if model_args.model_revision == "main" else model_args.model_revision
186 | model_args.model_name_or_path = snapshot_download(
187 | model_args.model_name_or_path,
188 | revision=revision,
189 | cache_dir=model_args.cache_dir
190 | )
191 | except ImportError:
192 | raise ImportError("Please install modelscope via `pip install modelscope -U`")
193 |
194 |
195 | def use_modelscope() -> bool:
196 | return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
197 |
--------------------------------------------------------------------------------
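
A small sketch exercising the helpers above on a toy module; the import path is illustrative.

```python
import torch

from model.src.extras.misc import (  # illustrative import path
    count_parameters,
    get_current_device,
    infer_optim_dtype,
)

model = torch.nn.Linear(16, 4)
trainable, total = count_parameters(model)
print(f"trainable params: {trainable} / {total}")    # 68 / 68

compute_dtype = infer_optim_dtype(torch.bfloat16)    # bf16 if the GPU supports it, else fp16/fp32
print(compute_dtype, get_current_device())
```
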
/model/src/extras/packages.py:
--------------------------------------------------------------------------------
1 | import importlib.metadata
2 | import importlib.util
3 |
4 |
5 | def is_package_available(name: str) -> bool:
6 | return importlib.util.find_spec(name) is not None
7 |
8 |
9 | def get_package_version(name: str) -> str:
10 | try:
11 | return importlib.metadata.version(name)
12 |     except Exception:
13 | return "0.0.0"
14 |
15 |
16 | def is_fastapi_available():
17 | return is_package_available("fastapi")
18 |
19 |
20 | def is_flash_attn2_available():
21 | return is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2")
22 |
23 |
24 | def is_jieba_available():
25 | return is_package_available("jieba")
26 |
27 |
28 | def is_matplotlib_available():
29 | return is_package_available("matplotlib")
30 |
31 |
32 | def is_nltk_available():
33 | return is_package_available("nltk")
34 |
35 |
36 | def is_requests_available():
37 | return is_package_available("requests")
38 |
39 |
40 | def is_rouge_available():
41 | return is_package_available("rouge_chinese")
42 |
43 |
44 | def is_starlette_available():
45 | return is_package_available("sse_starlette")
46 |
47 |
48 | def is_uvicorn_available():
49 | return is_package_available("uvicorn")
50 |
--------------------------------------------------------------------------------
/model/src/extras/ploting.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import json
4 | from typing import List, Optional
5 | from transformers.trainer import TRAINER_STATE_NAME
6 |
7 | from .logging import get_logger
8 | from .packages import is_matplotlib_available
9 |
10 | if is_matplotlib_available():
11 | import matplotlib.pyplot as plt
12 |
13 |
14 | logger = get_logger(__name__)
15 |
16 |
17 | def smooth(scalars: List[float]) -> List[float]:
18 | r"""
19 | EMA implementation according to TensorBoard.
20 | """
21 | last = scalars[0]
22 | smoothed = list()
23 | weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
24 | for next_val in scalars:
25 | smoothed_val = last * weight + (1 - weight) * next_val
26 | smoothed.append(smoothed_val)
27 | last = smoothed_val
28 | return smoothed
29 |
30 |
31 | def plot_loss(save_directory: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
32 |
33 |     with open(os.path.join(save_directory, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
34 | data = json.load(f)
35 |
36 | for key in keys:
37 | steps, metrics = [], []
38 | for i in range(len(data["log_history"])):
39 | if key in data["log_history"][i]:
40 | steps.append(data["log_history"][i]["step"])
41 | metrics.append(data["log_history"][i][key])
42 |
43 | if len(metrics) == 0:
44 | logger.warning(f"No metric {key} to plot.")
45 | continue
46 |
47 | plt.figure()
48 | plt.plot(steps, metrics, alpha=0.4, label="original")
49 | plt.plot(steps, smooth(metrics), label="smoothed")
50 |         plt.title("training {} of {}".format(key, save_directory))
51 | plt.xlabel("step")
52 | plt.ylabel(key)
53 | plt.legend()
54 |         plt.savefig(os.path.join(save_directory, "training_{}.png".format(key)), format="png", dpi=100)
55 |         print("Figure saved:", os.path.join(save_directory, "training_{}.png".format(key)))
56 |
--------------------------------------------------------------------------------
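
After a finished run, the loss curves can be plotted from the `trainer_state.json` that transformers writes into the output directory; the directory name below is a placeholder and the import path is illustrative.

```python
from model.src.extras.ploting import plot_loss  # illustrative import path

# Produces training_loss.png and training_eval_loss.png inside the output directory.
plot_loss("outputs/dellm-sft", keys=["loss", "eval_loss"])
```
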
/model/src/hparams/data_args.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from typing import List, Literal, Optional
4 | from dataclasses import dataclass, field
5 |
6 |
7 | DATA_CONFIG = "dataset_info.json"
8 |
9 |
10 | def use_modelscope() -> bool:
11 | return bool(int(os.environ.get("USE_MODELSCOPE_HUB", "0")))
12 |
13 |
14 | @dataclass
15 | class DatasetAttr:
16 |
17 | load_from: Literal["hf_hub", "ms_hub", "script", "file"]
18 | dataset_name: Optional[str] = None
19 | dataset_sha1: Optional[str] = None
20 | subset: Optional[str] = None
21 | folder: Optional[str] = None
22 | ranking: Optional[bool] = False
23 | formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
24 |
25 | prompt: Optional[str] = "instruction"
26 | query: Optional[str] = "input"
27 | response: Optional[str] = "output"
28 | history: Optional[str] = None
29 | messages: Optional[str] = "conversations"
30 | role: Optional[str] = "from"
31 | content: Optional[str] = "value"
32 | system: Optional[str] = None
33 |
34 | def __repr__(self) -> str:
35 | return self.dataset_name
36 |
37 |
38 | @dataclass
39 | class DataArguments:
40 | r"""
41 | Arguments pertaining to what data we are going to input our model for training and evaluation.
42 | """
43 | template: Optional[str] = field(
44 | default=None,
45 | metadata={"help": "Which template to use for constructing prompts in training and inference."}
46 | )
47 | dataset: Optional[str] = field(
48 | default=None,
49 | metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
50 | )
51 | dataset_dir: Optional[str] = field(
52 | default="data",
53 | metadata={"help": "Path to the folder containing the datasets."}
54 | )
55 | split: Optional[str] = field(
56 | default="train",
57 | metadata={"help": "Which dataset split to use for training and evaluation."}
58 | )
59 | cutoff_len: Optional[int] = field(
60 | default=1024,
61 | metadata={"help": "The maximum length of the model inputs after tokenization."}
62 | )
63 | reserved_label_len: Optional[int] = field(
64 | default=1,
65 | metadata={"help": "The maximum length reserved for label after tokenization."}
66 | )
67 | train_on_prompt: Optional[bool] = field(
68 | default=False,
69 | metadata={"help": "Whether to disable the mask on the prompt or not."}
70 | )
71 | streaming: Optional[bool] = field(
72 | default=False,
73 | metadata={"help": "Enable dataset streaming."}
74 | )
75 | buffer_size: Optional[int] = field(
76 | default=16384,
77 | metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
78 | )
79 | mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
80 | default="concat",
81 | metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
82 | )
83 | interleave_probs: Optional[str] = field(
84 | default=None,
85 | metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
86 | )
87 | overwrite_cache: Optional[bool] = field(
88 | default=False,
89 | metadata={"help": "Overwrite the cached training and evaluation sets."}
90 | )
91 | preprocessing_num_workers: Optional[int] = field(
92 | default=None,
93 | metadata={"help": "The number of processes to use for the preprocessing."}
94 | )
95 | max_samples: Optional[int] = field(
96 | default=None,
97 | metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
98 | )
99 | eval_num_beams: Optional[int] = field(
100 | default=None,
101 | metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
102 | )
103 | ignore_pad_token_for_loss: Optional[bool] = field(
104 | default=True,
105 | metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
106 | )
107 | val_size: Optional[float] = field(
108 | default=0,
109 | metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
110 | )
111 | sft_packing: Optional[bool] = field(
112 | default=False,
113 | metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
114 | )
115 | cache_path: Optional[str] = field(
116 | default=None,
117 | metadata={"help": "Path to save or load the preprocessed datasets."}
118 | )
119 |
120 | def __post_init__(self):
121 | if self.reserved_label_len >= self.cutoff_len:
122 | raise ValueError("`reserved_label_len` must be smaller than `cutoff_len`.")
123 |
124 | if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
125 | raise ValueError("Streaming mode should have an integer val size.")
126 |
127 | if self.streaming and self.max_samples is not None:
128 | raise ValueError("`max_samples` is incompatible with `streaming`.")
129 |
130 | def init_for_training(self, seed: int): # support mixing multiple datasets
131 | self.seed = seed
132 | dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
133 | try:
134 | with open(os.path.join(self.dataset_dir, DATA_CONFIG), "r") as f:
135 | dataset_info = json.load(f)
136 | except Exception as err:
137 | if self.dataset is not None:
138 | raise ValueError("Cannot open {} due to {}.".format(os.path.join(self.dataset_dir, DATA_CONFIG), str(err)))
139 | dataset_info = None
140 |
141 | if self.interleave_probs is not None:
142 | self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
143 |
144 | self.dataset_list: List[DatasetAttr] = []
145 | for name in dataset_names:
146 | if name not in dataset_info:
147 | raise ValueError("Undefined dataset {} in {}.".format(name, DATA_CONFIG))
148 |
149 | has_hf_url = "hf_hub_url" in dataset_info[name]
150 | has_ms_url = "ms_hub_url" in dataset_info[name]
151 |
152 | if has_hf_url or has_ms_url:
153 | if (use_modelscope() and has_ms_url) or (not has_hf_url):
154 | dataset_attr = DatasetAttr(
155 | "ms_hub",
156 | dataset_name=dataset_info[name]["ms_hub_url"]
157 | )
158 | else:
159 | dataset_attr = DatasetAttr(
160 | "hf_hub",
161 | dataset_name=dataset_info[name]["hf_hub_url"]
162 | )
163 | elif "script_url" in dataset_info[name]:
164 | dataset_attr = DatasetAttr(
165 | "script",
166 | dataset_name=dataset_info[name]["script_url"]
167 | )
168 | else:
169 | dataset_attr = DatasetAttr(
170 | "file",
171 | dataset_name=dataset_info[name]["file_name"],
172 | dataset_sha1=dataset_info[name].get("file_sha1", None)
173 | )
174 |
175 | if "columns" in dataset_info[name]:
176 | dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
177 | dataset_attr.query = dataset_info[name]["columns"].get("query", None)
178 | dataset_attr.response = dataset_info[name]["columns"].get("response", None)
179 | dataset_attr.history = dataset_info[name]["columns"].get("history", None)
180 | dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
181 | dataset_attr.role = dataset_info[name]["columns"].get("role", None)
182 | dataset_attr.content = dataset_info[name]["columns"].get("content", None)
183 | dataset_attr.system = dataset_info[name]["columns"].get("system", None)
184 |
185 | dataset_attr.subset = dataset_info[name].get("subset", None)
186 | dataset_attr.folder = dataset_info[name].get("folder", None)
187 | dataset_attr.ranking = dataset_info[name].get("ranking", False)
188 | dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
189 | self.dataset_list.append(dataset_attr)
190 |
--------------------------------------------------------------------------------
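
For reference, a hedged sketch of a local-file entry in `data/dataset_info.json` in the shape that `init_for_training` above expects; the dataset name, file name, and column mapping are placeholders.

```python
import json
import os

dataset_info = {
    "dellm_sft": {                        # placeholder dataset name
        "file_name": "dellm_sft.json",    # placeholder file under dataset_dir
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output"
        },
        "formatting": "alpaca"
    }
}

os.makedirs("data", exist_ok=True)
with open("data/dataset_info.json", "w", encoding="utf-8") as f:
    json.dump(dataset_info, f, indent=2)
```
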
/model/src/hparams/evaluation_args.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Literal, Optional
3 | from dataclasses import dataclass, field
4 |
5 | from datasets import DownloadMode
6 |
7 |
8 | @dataclass
9 | class EvaluationArguments:
10 | r"""
11 | Arguments pertaining to specify the evaluation parameters.
12 | """
13 | task: str = field(
14 | metadata={"help": "Name of the evaluation task."}
15 | )
16 | task_dir: Optional[str] = field(
17 | default="evaluation",
18 | metadata={"help": "Path to the folder containing the evaluation datasets."}
19 | )
20 | batch_size: Optional[int] = field(
21 | default=4,
22 | metadata={"help": "The batch size per GPU for evaluation."}
23 | )
24 | seed: Optional[int] = field(
25 | default=42,
26 | metadata={"help": "Random seed to be used with data loaders."}
27 | )
28 | lang: Optional[Literal["en", "zh"]] = field(
29 | default="en",
30 | metadata={"help": "Language used at evaluation."}
31 | )
32 | n_shot: Optional[int] = field(
33 | default=5,
34 |         metadata={"help": "Number of exemplars for few-shot learning."}
35 | )
36 | save_dir: Optional[str] = field(
37 | default=None,
38 | metadata={"help": "Path to save the evaluation results."}
39 | )
40 | download_mode: Optional[DownloadMode] = field(
41 | default=DownloadMode.REUSE_DATASET_IF_EXISTS,
42 | metadata={"help": "Download mode used for the evaluation datasets."}
43 | )
44 |
45 | def __post_init__(self):
46 | task_available = []
47 | for folder in os.listdir(self.task_dir):
48 | if os.path.isdir(os.path.join(self.task_dir, folder)):
49 | task_available.append(folder)
50 |
51 | if self.task not in task_available:
52 | raise ValueError("Task {} not found in {}.".format(self.task, self.task_dir))
53 |
54 | if self.save_dir is not None and os.path.exists(self.save_dir):
55 | raise ValueError("`save_dir` already exists, use another one.")
56 |
--------------------------------------------------------------------------------
/model/src/hparams/finetuning_args.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Literal, Optional
3 | from dataclasses import asdict, dataclass, field
4 |
5 |
6 | @dataclass
7 | class FreezeArguments:
8 | r"""
9 | Arguments pertaining to the freeze (partial-parameter) training.
10 | """
11 | name_module_trainable: Optional[str] = field(
12 | default="mlp",
13 | metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
14 | Use commas to separate multiple modules. \
15 | LLaMA choices: [\"mlp\", \"self_attn\"], \
16 | BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
17 | Qwen choices: [\"mlp\", \"attn\"], \
18 | Phi choices: [\"mlp\", \"mixer\"], \
19 | Others choices: the same as LLaMA."}
20 | )
21 | num_layer_trainable: Optional[int] = field(
22 | default=3,
23 | metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."}
24 | )
25 |
26 |
27 | @dataclass
28 | class LoraArguments:
29 | r"""
30 | Arguments pertaining to the LoRA training.
31 | """
32 | additional_target: Optional[str] = field(
33 | default=None,
34 | metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
35 | )
36 | lora_alpha: Optional[int] = field(
37 | default=None,
38 | metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}
39 | )
40 | lora_dropout: Optional[float] = field(
41 | default=0.0,
42 | metadata={"help": "Dropout rate for the LoRA fine-tuning."}
43 | )
44 | lora_rank: Optional[int] = field(
45 | default=8,
46 | metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
47 | )
48 | lora_target: Optional[str] = field(
49 | default=None,
50 | metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
51 | LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
52 | BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"dense\", \"dense_h_to_4h\", \"dense_4h_to_h\"], \
53 | Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
54 | Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
55 | Phi choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
56 | Others choices: the same as LLaMA."}
57 | )
58 | create_new_adapter: Optional[bool] = field(
59 | default=False,
60 | metadata={"help": "Whether to create a new adapter with randomly initialized weight or not."}
61 | )
62 |
63 |
64 | @dataclass
65 | class RLHFArguments:
66 | r"""
67 | Arguments pertaining to the PPO and DPO training.
68 | """
69 | dpo_beta: Optional[float] = field(
70 | default=0.1,
71 | metadata={"help": "The beta parameter for the DPO loss."}
72 | )
73 | dpo_loss: Optional[Literal["sigmoid", "hinge", "ipo", "kto"]] = field(
74 | default="sigmoid",
75 | metadata={"help": "The type of DPO loss to use."}
76 | )
77 | dpo_ftx: Optional[float] = field(
78 | default=0,
79 | metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}
80 | )
81 | ppo_buffer_size: Optional[int] = field(
82 | default=1,
83 | metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}
84 | )
85 | ppo_epochs: Optional[int] = field(
86 | default=4,
87 | metadata={"help": "The number of epochs to perform in a PPO optimization step."}
88 | )
89 | ppo_logger: Optional[str] = field(
90 | default=None,
91 | metadata={"help": "Log with either \"wandb\" or \"tensorboard\" in PPO training."}
92 | )
93 | ppo_score_norm: Optional[bool] = field(
94 | default=False,
95 | metadata={"help": "Use score normalization in PPO training."}
96 | )
97 | ppo_target: Optional[float] = field(
98 | default=6.0,
99 | metadata={"help": "Target KL value for adaptive KL control in PPO training."}
100 | )
101 | ppo_whiten_rewards: Optional[bool] = field(
102 | default=False,
103 | metadata={"help": "Whiten the rewards before compute advantages in PPO training."}
104 | )
105 | ref_model: Optional[str] = field(
106 | default=None,
107 | metadata={"help": "Path to the reference model used for the PPO or DPO training."}
108 | )
109 | ref_model_adapters: Optional[str] = field(
110 | default=None,
111 | metadata={"help": "Path to the adapters of the reference model."}
112 | )
113 | ref_model_quantization_bit: Optional[int] = field(
114 | default=None,
115 | metadata={"help": "The number of bits to quantize the reference model."}
116 | )
117 | reward_model: Optional[str] = field(
118 | default=None,
119 | metadata={"help": "Path to the reward model used for the PPO training."}
120 | )
121 | reward_model_adapters: Optional[str] = field(
122 | default=None,
123 | metadata={"help": "Path to the adapters of the reward model."}
124 | )
125 | reward_model_quantization_bit: Optional[int] = field(
126 | default=None,
127 | metadata={"help": "The number of bits to quantize the reward model."}
128 | )
129 | reward_model_type: Optional[Literal["lora", "full", "api"]] = field(
130 | default="lora",
131 | metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}
132 | )
133 |
134 |
135 | @dataclass
136 | class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments):
137 | r"""
138 | Arguments pertaining to which fine-tuning techniques we are going to use.
139 | """
140 | stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
141 | default="sft",
142 | metadata={"help": "Which stage will be performed in training."}
143 | )
144 | finetuning_type: Optional[Literal["lora", "freeze", "full"]] = field(
145 | default="lora",
146 | metadata={"help": "Which fine-tuning method to use."}
147 | )
148 | plot_loss: Optional[bool] = field(
149 | default=False,
150 | metadata={"help": "Whether or not to save the training loss curves."}
151 | )
152 |
153 | def __post_init__(self):
154 | def split_arg(arg):
155 | if isinstance(arg, str):
156 | return [item.strip() for item in arg.split(",")]
157 | return arg
158 |
159 | self.name_module_trainable = split_arg(self.name_module_trainable)
160 | self.lora_alpha = self.lora_alpha or self.lora_rank * 2
161 | self.lora_target = split_arg(self.lora_target)
162 | self.additional_target = split_arg(self.additional_target)
163 | self.ref_model_adapters = split_arg(self.ref_model_adapters)
164 | self.reward_model_adapters = split_arg(self.reward_model_adapters)
165 |
166 | assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
167 | assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
168 | assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
169 |
170 | if self.stage == "ppo" and self.reward_model is None:
171 | raise ValueError("Reward model is necessary for PPO training.")
172 |
173 | if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
174 | raise ValueError("Freeze/Full PPO training needs `reward_model_type=full`.")
175 |
176 | def save_to_json(self, json_path: str):
177 | r"""Saves the content of this instance in JSON format inside `json_path`."""
178 | json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
179 | with open(json_path, "w", encoding="utf-8") as f:
180 | f.write(json_string)
181 |
182 | @classmethod
183 | def load_from_json(cls, json_path: str):
184 | r"""Creates an instance from the content of `json_path`."""
185 | with open(json_path, "r", encoding="utf-8") as f:
186 | text = f.read()
187 |
188 | return cls(**json.loads(text))
189 |
--------------------------------------------------------------------------------
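
A minimal usage sketch of the argument post-processing above (the `src.hparams` import path and the field values are assumptions for illustration, not part of the repo): `__post_init__` splits comma-separated module names and derives `lora_alpha` from `lora_rank` when it is left unset.

```python
# Hypothetical sketch: run from the `model/` directory; the import path is assumed.
from src.hparams.finetuning_args import FinetuningArguments

args = FinetuningArguments(
    stage="sft",
    finetuning_type="lora",
    lora_rank=8,
    lora_target="q_proj,v_proj",   # comma-separated string, as on the command line
)
assert args.lora_target == ["q_proj", "v_proj"]  # split_arg() splits and strips
assert args.lora_alpha == 16                     # defaults to lora_rank * 2
```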
/model/src/hparams/generating_args.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 | from dataclasses import asdict, dataclass, field
3 |
4 |
5 | @dataclass
6 | class GeneratingArguments:
7 | r"""
8 | Arguments pertaining to specify the decoding parameters.
9 | """
10 | do_sample: Optional[bool] = field(
11 | default=True,
12 | metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
13 | )
14 | temperature: Optional[float] = field(
15 | default=0.95,
16 | metadata={"help": "The value used to modulate the next token probabilities."}
17 | )
18 | top_p: Optional[float] = field(
19 | default=0.7,
20 | metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
21 | )
22 | top_k: Optional[int] = field(
23 | default=50,
24 | metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
25 | )
26 | num_beams: Optional[int] = field(
27 | default=1,
28 | metadata={"help": "Number of beams for beam search. 1 means no beam search."}
29 | )
30 | num_return_sequences: Optional[int] = field(
31 | default=1,
32 | metadata={"help": "Number of generated sentences"}
33 | )
34 | max_length: Optional[int] = field(
35 | default=512,
36 | metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
37 | )
38 | max_new_tokens: Optional[int] = field(
39 | default=512,
40 | metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
41 | )
42 | repetition_penalty: Optional[float] = field(
43 | default=1.0,
44 | metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
45 | )
46 | length_penalty: Optional[float] = field(
47 | default=1.0,
48 | metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}
49 | )
50 |
51 | def to_dict(self) -> Dict[str, Any]:
52 | args = asdict(self)
53 | if args.get("max_new_tokens", -1) > 0:
54 | args.pop("max_length", None)
55 | else:
56 | args.pop("max_new_tokens", None)
57 | return args
58 |
--------------------------------------------------------------------------------
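
A small sketch of how `to_dict()` resolves the two length controls (the import path is an assumption): whenever `max_new_tokens` is positive it wins and `max_length` is dropped; otherwise `max_length` is kept.

```python
# Hypothetical sketch; the import path is assumed, field defaults are those defined above.
from src.hparams.generating_args import GeneratingArguments

gen_kwargs = GeneratingArguments(max_new_tokens=256).to_dict()
assert "max_length" not in gen_kwargs and gen_kwargs["max_new_tokens"] == 256

gen_kwargs = GeneratingArguments(max_new_tokens=0).to_dict()   # falls back to max_length
assert "max_new_tokens" not in gen_kwargs and gen_kwargs["max_length"] == 512
```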
/model/src/hparams/model_args.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Literal, Optional
2 | from dataclasses import asdict, dataclass, field
3 |
4 |
5 | @dataclass
6 | class ModelArguments:
7 | r"""
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
9 | """
10 | model_name_or_path: str = field(
11 | metadata={"help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."}
12 | )
13 | adapter_name_or_path: Optional[str] = field(
14 | default=None,
15 | metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."}
16 | )
17 | cache_dir: Optional[str] = field(
18 | default=None,
19 | metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}
20 | )
21 | use_fast_tokenizer: Optional[bool] = field(
22 | default=False,
23 | metadata={"help": "Whether or not to use one of the fast tokenizer (backed by the tokenizers library)."}
24 | )
25 | resize_vocab: Optional[bool] = field(
26 | default=False,
27 | metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}
28 | )
29 | split_special_tokens: Optional[bool] = field(
30 | default=False,
31 | metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
32 | )
33 | model_revision: Optional[str] = field(
34 | default="main",
35 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
36 | )
37 | quantization_bit: Optional[int] = field(
38 | default=None,
39 | metadata={"help": "The number of bits to quantize the model."}
40 | )
41 | quantization_type: Optional[Literal["fp4", "nf4"]] = field(
42 | default="nf4",
43 | metadata={"help": "Quantization data type to use in int4 training."}
44 | )
45 | double_quantization: Optional[bool] = field(
46 | default=True,
47 | metadata={"help": "Whether or not to use double quantization in int4 training."}
48 | )
49 | rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
50 | default=None,
51 | metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}
52 | )
53 | flash_attn: Optional[bool] = field(
54 | default=False,
55 | metadata={"help": "Enable FlashAttention-2 for faster training."}
56 | )
57 | shift_attn: Optional[bool] = field(
58 | default=False,
59 | metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
60 | )
61 | use_unsloth: Optional[bool] = field(
62 | default=False,
63 | metadata={"help": "Whether or not to use unsloth's optimization for the LoRA training."}
64 | )
65 | disable_gradient_checkpointing: Optional[bool] = field(
66 | default=False,
67 | metadata={"help": "Whether or not to disable gradient checkpointing."}
68 | )
69 | upcast_layernorm: Optional[bool] = field(
70 | default=False,
71 | metadata={"help": "Whether or not to upcast the layernorm weights in fp32."}
72 | )
73 | hf_hub_token: Optional[str] = field(
74 | default=None,
75 | metadata={"help": "Auth token to log in with Hugging Face Hub."}
76 | )
77 | ms_hub_token: Optional[str] = field(
78 | default=None,
79 | metadata={"help": "Auth token to log in with ModelScope Hub."}
80 | )
81 | export_dir: Optional[str] = field(
82 | default=None,
83 | metadata={"help": "Path to the directory to save the exported model."}
84 | )
85 | export_size: Optional[int] = field(
86 | default=1,
87 | metadata={"help": "The file shard size (in GB) of the exported model."}
88 | )
89 | export_quantization_bit: Optional[int] = field(
90 | default=None,
91 | metadata={"help": "The number of bits to quantize the exported model."}
92 | )
93 | export_quantization_dataset: Optional[str] = field(
94 | default=None,
95 | metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}
96 | )
97 | export_quantization_nsamples: Optional[int] = field(
98 | default=128,
99 | metadata={"help": "The number of samples used for quantization."}
100 | )
101 | export_quantization_maxlen: Optional[int] = field(
102 | default=1024,
103 | metadata={"help": "The maximum length of the model inputs used for quantization."}
104 | )
105 | export_legacy_format: Optional[bool] = field(
106 | default=False,
107 | metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}
108 | )
109 |
110 | def __post_init__(self):
111 | self.compute_dtype = None
112 | self.model_max_length = None
113 |
114 | if self.split_special_tokens and self.use_fast_tokenizer:
115 | raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
116 |
117 | if self.adapter_name_or_path is not None: # support merging multiple lora weights
118 | self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")]
119 |
120 | assert self.quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
121 | assert self.export_quantization_bit in [None, 8, 4, 3, 2], "We only accept 2/3/4/8-bit quantization."
122 |
123 | if self.export_quantization_bit is not None and self.export_quantization_dataset is None:
124 | raise ValueError("Quantization dataset is necessary for exporting.")
125 |
126 | def to_dict(self) -> Dict[str, Any]:
127 | return asdict(self)
128 |
--------------------------------------------------------------------------------
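
A short sketch of the `__post_init__` checks above (import path and model identifier are placeholders): comma-separated adapter paths are split into a list, and unsupported quantization bit widths are rejected.

```python
# Hypothetical sketch; the import path and the model/adapter names are placeholders.
from src.hparams.model_args import ModelArguments

m = ModelArguments(
    model_name_or_path="meta-llama/Llama-2-13b-hf",
    adapter_name_or_path="adapters/step-1,adapters/step-2",
)
assert m.adapter_name_or_path == ["adapters/step-1", "adapters/step-2"]

try:
    ModelArguments(model_name_or_path="meta-llama/Llama-2-13b-hf", quantization_bit=3)
except AssertionError as err:
    print(err)  # "We only accept 4-bit or 8-bit quantization."
```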
/model/src/model/adapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import TYPE_CHECKING
3 | from transformers.integrations import is_deepspeed_zero3_enabled
4 | from peft import PeftModel, TaskType, LoraConfig, get_peft_model
5 |
6 | from ..extras.logging import get_logger
7 | from .utils import find_all_linear_modules
8 |
9 | if TYPE_CHECKING:
10 | from transformers.modeling_utils import PreTrainedModel
11 | from ..hparams.model_args import ModelArguments
12 | from ..hparams.finetuning_args import FinetuningArguments
13 |
14 | logger = get_logger(__name__)
15 |
16 |
17 | def init_adapter(
18 | model: "PreTrainedModel",
19 | model_args: "ModelArguments",
20 | finetuning_args: "FinetuningArguments",
21 | is_trainable: bool
22 | ) -> "PreTrainedModel":
23 | r"""
24 | Initializes the adapters.
25 |
26 | Supports full-parameter, freeze, and LoRA training.
27 |
28 | Note that the trainable parameters must be cast to float32.
29 | """
30 |
31 | if (not is_trainable) and model_args.adapter_name_or_path is None:
32 | logger.info("Adapter is not found at evaluation, load the base model.")
33 | return model
34 |
35 | if finetuning_args.finetuning_type == "full" and is_trainable:
36 | logger.info("Fine-tuning method: Full")
37 | model = model.float()
38 |
39 | if finetuning_args.finetuning_type == "freeze" and is_trainable:
40 | logger.info("Fine-tuning method: Freeze")
41 | num_layers = (
42 | getattr(model.config, "num_hidden_layers", None)
43 | or getattr(model.config, "num_layers", None)
44 | or getattr(model.config, "n_layer", None)
45 | )
46 | if not num_layers:
47 | raise ValueError("Current model does not support freeze tuning.")
48 |
49 | if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
50 | trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
51 | else: # fine-tuning the first n layers if num_layer_trainable < 0
52 | trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
53 |
54 | trainable_layers = []
55 | for module_name in finetuning_args.name_module_trainable:
56 | for idx in trainable_layer_ids:
57 | trainable_layers.append("{:d}.{}".format(idx, module_name))
58 |
59 | for name, param in model.named_parameters():
60 | if not any(trainable_layer in name for trainable_layer in trainable_layers):
61 | param.requires_grad_(False)
62 | else:
63 | param.data = param.data.to(torch.float32)
64 |
65 | if finetuning_args.finetuning_type == "lora":
66 | logger.info("Fine-tuning method: LoRA")
67 | adapter_to_resume = None
68 |
69 | if model_args.adapter_name_or_path is not None:
70 | is_mergeable = True
71 | if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable
72 | assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
73 | is_mergeable = False
74 |
75 | if is_deepspeed_zero3_enabled():
76 | assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
77 | is_mergeable = False
78 |
79 | if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
80 | adapter_to_merge = model_args.adapter_name_or_path[:-1]
81 | adapter_to_resume = model_args.adapter_name_or_path[-1]
82 | else:
83 | adapter_to_merge = model_args.adapter_name_or_path
84 |
85 | for adapter in adapter_to_merge:
86 | model = PeftModel.from_pretrained(model, adapter)
87 | model = model.merge_and_unload()
88 |
89 | if len(adapter_to_merge) > 0:
90 | logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))
91 |
92 | if adapter_to_resume is not None: # resume lora training
93 | model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable)
94 |
95 | if is_trainable and adapter_to_resume is None: # create new lora weights while training
96 | if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
97 | target_modules = find_all_linear_modules(model)
98 | else:
99 | target_modules = finetuning_args.lora_target
100 |
101 | peft_kwargs = {
102 | "r": finetuning_args.lora_rank,
103 | "target_modules": target_modules,
104 | "lora_alpha": finetuning_args.lora_alpha,
105 | "lora_dropout": finetuning_args.lora_dropout
106 | }
107 |
108 | if model_args.use_unsloth:
109 | from unsloth import FastLlamaModel, FastMistralModel # type: ignore
110 | unsloth_peft_kwargs = {"model": model, "max_seq_length": model_args.model_max_length}
111 | if getattr(model.config, "model_type", None) == "llama":
112 | model = FastLlamaModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
113 | elif getattr(model.config, "model_type", None) == "mistral":
114 | model = FastMistralModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs)
115 | else:
116 | raise NotImplementedError
117 |
118 | else:
119 | lora_config = LoraConfig(
120 | task_type=TaskType.CAUSAL_LM,
121 | inference_mode=False,
122 | modules_to_save=finetuning_args.additional_target,
123 | **peft_kwargs
124 | )
125 | model = get_peft_model(model, lora_config)
126 |
127 | for param in filter(lambda p: p.requires_grad, model.parameters()):
128 | param.data = param.data.to(torch.float32)
129 |
130 | if model_args.adapter_name_or_path is not None:
131 | logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
132 |
133 | return model
134 |
--------------------------------------------------------------------------------
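
A standalone sketch of the freeze-tuning layer selection in `init_adapter` (the helper name is made up for illustration): a positive `num_layer_trainable` keeps the last n layers trainable, while a negative value keeps the first n.

```python
# Illustration of the layer-id arithmetic above; `select_trainable_layers` is a
# hypothetical helper, not a function defined in the repo.
from typing import List

def select_trainable_layers(num_layers: int, num_layer_trainable: int) -> List[int]:
    if num_layer_trainable > 0:      # fine-tune the last n layers
        return [num_layers - k - 1 for k in range(num_layer_trainable)]
    return list(range(-num_layer_trainable))  # fine-tune the first n layers

assert select_trainable_layers(32, 3) == [31, 30, 29]
assert select_trainable_layers(32, -3) == [0, 1, 2]
```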
/model/src/model/loader.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Optional, Tuple
2 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
3 | from transformers.integrations import is_deepspeed_zero3_enabled
4 | from transformers.utils.versions import require_version
5 | from trl import AutoModelForCausalLMWithValueHead
6 |
7 | from ..extras.logging import get_logger
8 | from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms
9 | from .adapter import init_adapter
10 | from .patcher import patch_config, patch_tokenizer, patch_model, patch_valuehead_model
11 | from .utils import load_valuehead_params, register_autoclass
12 |
13 | if TYPE_CHECKING:
14 | from transformers import PreTrainedModel, PreTrainedTokenizer
15 | from ..hparams.model_args import ModelArguments
16 | from ..hparams.finetuning_args import FinetuningArguments
17 |
18 |
19 | logger = get_logger(__name__)
20 |
21 |
22 | require_version("transformers>=4.36.2", "To fix: pip install transformers>=4.36.2")
23 | require_version("datasets>=2.14.3", "To fix: pip install datasets>=2.14.3")
24 | require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
25 | require_version("peft>=0.7.0", "To fix: pip install peft>=0.7.0")
26 | require_version("trl>=0.7.6", "To fix: pip install trl>=0.7.6")
27 |
28 |
29 | def load_model_and_tokenizer(
30 | model_args: "ModelArguments",
31 | finetuning_args: "FinetuningArguments",
32 | is_trainable: Optional[bool] = False,
33 | add_valuehead: Optional[bool] = False
34 | ) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
35 | r"""
36 | Loads pretrained model and tokenizer.
37 |
38 | Supports both training and inference.
39 | """
40 |
41 | try_download_model_from_ms(model_args)
42 |
43 | config_kwargs = {
44 | "trust_remote_code": True,
45 | "cache_dir": model_args.cache_dir,
46 | "revision": model_args.model_revision,
47 | "token": model_args.hf_hub_token
48 | }
49 |
50 | tokenizer = AutoTokenizer.from_pretrained(
51 | model_args.model_name_or_path,
52 | use_fast=model_args.use_fast_tokenizer,
53 | split_special_tokens=model_args.split_special_tokens,
54 | padding_side="right",
55 | **config_kwargs
56 | )
57 | patch_tokenizer(tokenizer)
58 |
59 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
60 | patch_config(config, tokenizer, model_args, config_kwargs, is_trainable)
61 |
62 | model = None
63 | if is_trainable and model_args.use_unsloth:
64 | require_version("unsloth", "Follow the instructions at: https://github.com/unslothai/unsloth")
65 | from unsloth import FastLlamaModel, FastMistralModel # type: ignore
66 | unsloth_kwargs = {
67 | "model_name": model_args.model_name_or_path,
68 | "max_seq_length": model_args.model_max_length,
69 | "dtype": model_args.compute_dtype,
70 | "load_in_4bit": model_args.quantization_bit == 4,
71 | "token": model_args.hf_hub_token,
72 | "device_map": get_current_device(),
73 | "rope_scaling": getattr(config, "rope_scaling", None)
74 | }
75 | if getattr(config, "model_type", None) == "llama":
76 | model, _ = FastLlamaModel.from_pretrained(**unsloth_kwargs)
77 | elif getattr(config, "model_type", None) == "mistral":
78 | model, _ = FastMistralModel.from_pretrained(**unsloth_kwargs)
79 | else:
80 | logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
81 | model_args.use_unsloth = False
82 |
83 | if model_args.adapter_name_or_path:
84 | model_args.adapter_name_or_path = None
85 | logger.warning("Unsloth does not support loading adapters.")
86 |
87 | if model is None:
88 | model = AutoModelForCausalLM.from_pretrained(
89 | model_args.model_name_or_path,
90 | config=config,
91 | torch_dtype=model_args.compute_dtype,
92 | low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
93 | **config_kwargs
94 | )
95 |
96 | patch_model(model, tokenizer, model_args, is_trainable)
97 | register_autoclass(config, model, tokenizer)
98 |
99 | model = init_adapter(model, model_args, finetuning_args, is_trainable)
100 |
101 | if add_valuehead:
102 | model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
103 | patch_valuehead_model(model)
104 |
105 | if model_args.adapter_name_or_path is not None:
106 | vhead_path = model_args.adapter_name_or_path[-1]
107 | else:
108 | vhead_path = model_args.model_name_or_path
109 |
110 | vhead_params = load_valuehead_params(vhead_path, model_args)
111 | if vhead_params is not None:
112 | model.load_state_dict(vhead_params, strict=False)
113 | logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path))
114 |
115 | if not is_trainable:
116 | model.requires_grad_(False)
117 | model = model.to(model_args.compute_dtype) if not getattr(model, "quantization_method", None) else model
118 | model.eval()
119 | else:
120 | model.train()
121 |
122 | trainable_params, all_param = count_parameters(model)
123 | logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
124 | trainable_params, all_param, 100 * trainable_params / all_param
125 | ))
126 |
127 | if not is_trainable:
128 | logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
129 |
130 | return model, tokenizer
131 |
--------------------------------------------------------------------------------
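
A hedged usage sketch of `load_model_and_tokenizer` for inference (import paths, the base model identifier, and the adapter path are placeholders, not repo defaults): with `is_trainable=False`, the loaded model is frozen and put in eval mode.

```python
# Hypothetical sketch; identifiers and paths below are placeholders.
from src.hparams.model_args import ModelArguments
from src.hparams.finetuning_args import FinetuningArguments
from src.model.loader import load_model_and_tokenizer

model_args = ModelArguments(
    model_name_or_path="meta-llama/Llama-2-13b-hf",      # placeholder base model
    adapter_name_or_path="outputs/dellm-lora-adapter",   # placeholder LoRA adapter
)
finetuning_args = FinetuningArguments(finetuning_type="lora")
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False)
```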
/model/src/model/parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import logging
5 | import datasets
6 | import transformers
7 | from typing import Any, Dict, Optional, Tuple
8 | from transformers import HfArgumentParser, Seq2SeqTrainingArguments
9 | from transformers.trainer_utils import get_last_checkpoint
10 |
11 | from ..extras.logging import get_logger
12 | from ..hparams.model_args import ModelArguments
13 | from ..hparams.data_args import DataArguments
14 | from ..hparams.evaluation_args import EvaluationArguments
15 | from ..hparams.finetuning_args import FinetuningArguments
16 | from ..hparams.generating_args import GeneratingArguments
17 |
18 | logger = get_logger(__name__)
19 |
20 |
21 | _TRAIN_ARGS = [
22 | ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
23 | ]
24 | _TRAIN_CLS = Tuple[
25 | ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
26 | ]
27 | _INFER_ARGS = [
28 | ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
29 | ]
30 | _INFER_CLS = Tuple[
31 | ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
32 | ]
33 | _EVAL_ARGS = [
34 | ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
35 | ]
36 | _EVAL_CLS = Tuple[
37 | ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
38 | ]
39 |
40 |
41 | def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
42 | if args is not None:
43 | return parser.parse_dict(args)
44 |
45 | if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
46 | return parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
47 |
48 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
49 | return parser.parse_json_file(os.path.abspath(sys.argv[1]))
50 |
51 | (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True)
52 |
53 | if unknown_args:
54 | print(parser.format_help())
55 | print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args))
56 | raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args))
57 |
58 | return (*parsed_args,)
59 |
60 |
61 | def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
62 | datasets.utils.logging.set_verbosity(log_level)
63 | transformers.utils.logging.set_verbosity(log_level)
64 | transformers.utils.logging.enable_default_handler()
65 | transformers.utils.logging.enable_explicit_format()
66 |
67 |
68 | def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None:
69 | if model_args.quantization_bit is not None:
70 | if finetuning_args.finetuning_type != "lora":
71 | raise ValueError("Quantization is only compatible with the LoRA method.")
72 |
73 | if finetuning_args.create_new_adapter:
74 | raise ValueError("Cannot create new adapter upon a quantized model.")
75 |
76 | if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
77 | if finetuning_args.finetuning_type != "lora":
78 | raise ValueError("Multiple adapters are only available for LoRA tuning.")
79 |
80 | if model_args.quantization_bit is not None:
81 | raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
82 |
83 |
84 | def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
85 | parser = HfArgumentParser(_TRAIN_ARGS)
86 | return _parse_args(parser, args)
87 |
88 |
89 | def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
90 | parser = HfArgumentParser(_INFER_ARGS)
91 | return _parse_args(parser, args)
92 |
93 |
94 | def _parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
95 | parser = HfArgumentParser(_EVAL_ARGS)
96 | return _parse_args(parser, args)
97 |
98 |
99 | def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
100 | model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args)
101 |
102 | # Setup logging
103 | if training_args.should_log:
104 | _set_transformers_logging()
105 |
106 | # Check arguments
107 | data_args.init_for_training(training_args.seed)
108 |
109 | if finetuning_args.stage != "pt" and data_args.template is None:
110 | raise ValueError("Please specify which `template` to use.")
111 |
112 | if finetuning_args.stage != "sft" and training_args.predict_with_generate:
113 | raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
114 |
115 | if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
116 | raise ValueError("Please enable `predict_with_generate` to save model predictions.")
117 |
118 | if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end:
119 | raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
120 |
121 | if finetuning_args.stage == "ppo" and not training_args.do_train:
122 | raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.")
123 |
124 | if finetuning_args.stage in ["rm", "dpo"] and (not all([data_attr.ranking for data_attr in data_args.dataset_list])):
125 | raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
126 |
127 | if finetuning_args.stage == "ppo" and model_args.shift_attn:
128 | raise ValueError("PPO training is incompatible with S^2-Attn.")
129 |
130 | if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth:
131 | raise ValueError("Unsloth does not support lora reward model.")
132 |
133 | if training_args.max_steps == -1 and data_args.streaming:
134 | raise ValueError("Please specify `max_steps` in streaming mode.")
135 |
136 | if training_args.do_train and training_args.predict_with_generate:
137 | raise ValueError("`predict_with_generate` cannot be set as True while training.")
138 |
139 | if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
140 | raise ValueError("Please specify `lora_target` in LoRA training.")
141 |
142 | _verify_model_args(model_args, finetuning_args)
143 |
144 | if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
145 | logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
146 |
147 | if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
148 | logger.warning("We recommend enable mixed precision training.")
149 |
150 | if (not training_args.do_train) and model_args.quantization_bit is not None:
151 | logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
152 |
153 | if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
154 | logger.warning("Specify `ref_model` for computing rewards at evaluation.")
155 |
156 | # postprocess training_args
157 | if (
158 | training_args.local_rank != -1
159 | and training_args.ddp_find_unused_parameters is None
160 | and finetuning_args.finetuning_type == "lora"
161 | ):
162 | logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
163 | training_args_dict = training_args.to_dict()
164 | training_args_dict.update(dict(ddp_find_unused_parameters=False))
165 | training_args = Seq2SeqTrainingArguments(**training_args_dict)
166 |
167 | if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
168 | can_resume_from_checkpoint = False
169 | training_args.resume_from_checkpoint = None
170 | else:
171 | can_resume_from_checkpoint = True
172 |
173 | if (
174 | training_args.resume_from_checkpoint is None
175 | and training_args.do_train
176 | and os.path.isdir(training_args.output_dir)
177 | and not training_args.overwrite_output_dir
178 | and can_resume_from_checkpoint
179 | ):
180 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
181 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
182 | raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
183 |
184 | if last_checkpoint is not None:
185 | training_args_dict = training_args.to_dict()
186 | training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
187 | training_args = Seq2SeqTrainingArguments(**training_args_dict)
188 | logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
189 | training_args.resume_from_checkpoint
190 | ))
191 |
192 | if (
193 | finetuning_args.stage in ["rm", "ppo"]
194 | and finetuning_args.finetuning_type == "lora"
195 | and training_args.resume_from_checkpoint is not None
196 | ):
197 | logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
198 | training_args.resume_from_checkpoint
199 | ))
200 |
201 | # postprocess model_args
202 | model_args.compute_dtype = (
203 | torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
204 | )
205 | model_args.model_max_length = data_args.cutoff_len
206 |
207 | # Log on each process the small summary:
208 | logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
209 | training_args.local_rank, training_args.device, training_args.n_gpu,
210 | bool(training_args.local_rank != -1), str(model_args.compute_dtype)
211 | ))
212 | logger.info(f"Training/evaluation parameters {training_args}")
213 |
214 | # Set seed before initializing model.
215 | transformers.set_seed(training_args.seed)
216 |
217 | return model_args, data_args, training_args, finetuning_args, generating_args
218 |
219 |
220 | def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS:
221 | model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
222 | _set_transformers_logging()
223 |
224 | if data_args.template is None:
225 | raise ValueError("Please specify which `template` to use.")
226 |
227 | _verify_model_args(model_args, finetuning_args)
228 |
229 | return model_args, data_args, finetuning_args, generating_args
230 |
231 |
232 | def get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS:
233 | model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
234 | _set_transformers_logging()
235 |
236 | if data_args.template is None:
237 | raise ValueError("Please specify which `template` to use.")
238 |
239 | _verify_model_args(model_args, finetuning_args)
240 |
241 | transformers.set_seed(eval_args.seed)
242 |
243 | return model_args, data_args, eval_args, finetuning_args
244 |
--------------------------------------------------------------------------------
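
The parsers above accept three input styles: a Python dict, a single `.yaml`/`.json` path on the command line, or plain CLI flags. A hedged sketch of the dict path follows (every value below is a placeholder, not a repo default, and the dataset must be registered in the data configuration for the call to succeed):

```python
# Hypothetical sketch; import path, dataset, template and output_dir are placeholders.
from src.model.parser import get_train_args

model_args, data_args, training_args, finetuning_args, generating_args = get_train_args({
    "model_name_or_path": "meta-llama/Llama-2-13b-hf",
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "lora_target": "q_proj,v_proj",
    "dataset": "my_dataset",          # placeholder; must exist in the dataset registry
    "template": "llama2",             # placeholder template name
    "output_dir": "outputs/dellm-sft",
})
```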
/model/src/model/patcher.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | import random
5 | from types import MethodType
6 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
7 | from datasets import load_dataset
8 |
9 | from transformers import BitsAndBytesConfig, GPTQConfig, PreTrainedModel, PreTrainedTokenizerBase
10 | from transformers.integrations import is_deepspeed_zero3_enabled
11 | from transformers.utils.versions import require_version
12 |
13 | from ..extras.constants import FILEEXT2TYPE, LAYERNORM_NAMES
14 | from ..extras.logging import get_logger
15 | from ..extras.misc import get_current_device, infer_optim_dtype
16 | from ..extras.packages import is_flash_attn2_available
17 |
18 | if TYPE_CHECKING:
19 | from transformers import PretrainedConfig, PreTrainedTokenizer
20 | from trl import AutoModelForCausalLMWithValueHead
21 | from ..hparams.model_args import ModelArguments
22 |
23 |
24 | logger = get_logger(__name__)
25 | SUPPORTED_CLASS_FOR_S2ATTN = [] # TODO: add llama
26 |
27 |
28 | def _noisy_mean_initialization(embed_weight: torch.Tensor, num_new_tokens: int):
29 | embedding_dim = embed_weight.size(1)
30 | avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True)
31 | noise_weight = torch.empty_like(avg_weight[-num_new_tokens:])
32 | noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim)))
33 | embed_weight[-num_new_tokens:] = avg_weight + noise_weight
34 |
35 |
36 | def _resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None:
37 | r"""
38 | Resize token embeddings.
39 | """
40 | current_embedding_size = model.get_input_embeddings().weight.size(0)
41 | if len(tokenizer) > current_embedding_size:
42 | if not isinstance(model.get_output_embeddings(), torch.nn.Linear):
43 | logger.warning("Current model does not support resizing token embeddings.")
44 | return
45 |
46 | model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
47 | new_embedding_size = model.get_input_embeddings().weight.size(0)
48 | num_new_tokens = new_embedding_size - current_embedding_size
49 | _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
50 | _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)
51 |
52 | logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size))
53 |
54 |
55 | def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[str]:
56 | r"""
57 | Inspired by: https://github.com/huggingface/optimum/blob/v1.16.0/optimum/gptq/data.py#L133
58 | TODO: remove tokenizer.decode() https://github.com/huggingface/optimum/pull/1600
59 | """
60 | if os.path.isfile(model_args.export_quantization_dataset):
61 | data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None)
62 | data_files = model_args.export_quantization_dataset
63 | else:
64 | data_path = model_args.export_quantization_dataset
65 | data_files = None
66 |
67 | dataset = load_dataset(path=data_path, data_files=data_files, split="train", cache_dir=model_args.cache_dir)
68 | maxlen = model_args.export_quantization_maxlen
69 |
70 | samples = []
71 | for _ in range(model_args.export_quantization_nsamples):
72 | while True:
73 | sample_idx = random.randint(0, len(dataset) - 1)
74 | sample: Dict[str, torch.Tensor] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt")
75 | if sample["input_ids"].size(1) >= maxlen:
76 | break # TODO: fix large maxlen
77 |
78 | word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1)
79 | input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen]
80 | samples.append(tokenizer.decode(input_ids[0].tolist(), skip_special_tokens=True))
81 |
82 | return samples
83 |
84 |
85 | def _configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
86 | if not hasattr(config, "rope_scaling"):
87 | logger.warning("Current model does not support RoPE scaling.")
88 | return
89 |
90 | if is_trainable:
91 | if model_args.rope_scaling == "dynamic":
92 | logger.warning(
93 | "Dynamic NTK scaling may not work well with fine-tuning. "
94 | "See: https://github.com/huggingface/transformers/pull/24653"
95 | )
96 |
97 | current_max_length = getattr(config, "max_position_embeddings", None)
98 | if current_max_length and model_args.model_max_length > current_max_length:
99 | scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
100 | else:
101 | logger.warning("Input length is smaller than max length. Consider increase input length.")
102 | scaling_factor = 1.0
103 | else:
104 | scaling_factor = 2.0
105 |
106 | setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
107 | logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
108 | model_args.rope_scaling, scaling_factor
109 | ))
110 |
111 |
112 | def _configure_flashattn(config_kwargs: Dict[str, Any]) -> None:
113 | if not is_flash_attn2_available():
114 | logger.warning("FlashAttention2 is not installed.")
115 | return
116 |
117 | config_kwargs["use_flash_attention_2"] = True
118 | logger.info("Using FlashAttention-2 for faster training and inference.")
119 |
120 |
121 | def _configure_longlora(config: "PretrainedConfig") -> None:
122 | if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
123 | setattr(config, "group_size_ratio", 0.25)
124 | logger.info("Using shift short attention with group_size_ratio=1/4.")
125 | else:
126 | logger.warning("Current model does not support shift short attention.")
127 |
128 |
129 | def _configure_quantization(
130 | config: "PretrainedConfig",
131 | tokenizer: "PreTrainedTokenizer",
132 | model_args: "ModelArguments",
133 | config_kwargs: Dict[str, Any]
134 | ) -> None:
135 | r"""
136 | Priority: GPTQ-quantized (training) > AutoGPTQ (export) > Bitsandbytes (training)
137 | """
138 | if getattr(config, "quantization_config", None): # gptq
139 | if is_deepspeed_zero3_enabled():
140 | raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
141 |
142 | config_kwargs["device_map"] = {"": get_current_device()}
143 | quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
144 | if quantization_config.get("quant_method", None) == "gptq" and quantization_config.get("bits", -1) == 4:
145 | quantization_config["use_exllama"] = False # disable exllama
146 | logger.info("Loading {}-bit GPTQ-quantized model.".format(quantization_config.get("bits", -1)))
147 |
148 | elif model_args.export_quantization_bit is not None: # auto-gptq
149 | require_version("optimum>=1.16.0", "To fix: pip install optimum>=1.16.0")
150 | require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0")
151 | from accelerate.utils import get_max_memory
152 |
153 | if getattr(config, "model_type", None) == "chatglm":
154 | raise ValueError("ChatGLM model is not supported.")
155 |
156 | config_kwargs["quantization_config"] = GPTQConfig(
157 | bits=model_args.export_quantization_bit,
158 | tokenizer=tokenizer,
159 | dataset=_get_quantization_dataset(tokenizer, model_args)
160 | )
161 | config_kwargs["device_map"] = "auto"
162 | config_kwargs["max_memory"] = get_max_memory()
163 | logger.info("Quantizing model to {} bit.".format(model_args.export_quantization_bit))
164 |
165 | elif model_args.quantization_bit is not None: # bnb
166 | if is_deepspeed_zero3_enabled():
167 | raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
168 |
169 | if model_args.quantization_bit == 8:
170 | require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
171 | config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
172 |
173 | elif model_args.quantization_bit == 4:
174 | require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
175 | config_kwargs["quantization_config"] = BitsAndBytesConfig(
176 | load_in_4bit=True,
177 | bnb_4bit_compute_dtype=model_args.compute_dtype,
178 | bnb_4bit_use_double_quant=model_args.double_quantization,
179 | bnb_4bit_quant_type=model_args.quantization_type
180 | )
181 |
182 | config_kwargs["device_map"] = {"": get_current_device()}
183 | logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
184 |
185 |
186 | def _prepare_model_for_training(
187 | model: "PreTrainedModel",
188 | model_args: "ModelArguments",
189 | output_layer_name: Optional[str] = "lm_head"
190 | ) -> None:
191 | r"""
192 | Includes:
193 | (1) cast the layernorm weights to fp32
194 | (2) make the output embedding layer require grads
195 | (3) upcast the lm_head outputs to fp32
196 | Inspired by: https://github.com/huggingface/peft/blob/v0.7.1/src/peft/utils/other.py#L72
197 | """
198 | if model_args.upcast_layernorm:
199 | for name, param in model.named_parameters():
200 | if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
201 | param.data = param.data.to(torch.float32)
202 | logger.info("Upcasting layernorm weights in float32.")
203 |
204 | if not model_args.disable_gradient_checkpointing:
205 | if not getattr(model, "supports_gradient_checkpointing", False):
206 | logger.warning("Current model does not support gradient checkpointing.")
207 | else:
208 | model.enable_input_require_grads()
209 | model.gradient_checkpointing_enable()
210 | model.config.use_cache = False # turn off when gradient checkpointing is enabled
211 | logger.info("Gradient checkpointing enabled.")
212 |
213 | if hasattr(model, output_layer_name):
214 | def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
215 | return output.to(torch.float32)
216 |
217 | output_layer = getattr(model, output_layer_name)
218 | if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
219 | output_layer.register_forward_hook(fp32_forward_post_hook)
220 |
221 |
222 | def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
223 | if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__):
224 | tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
225 |
226 |
227 | def patch_config(
228 | config: "PretrainedConfig",
229 | tokenizer: "PreTrainedTokenizer",
230 | model_args: "ModelArguments",
231 | config_kwargs: Dict[str, Any],
232 | is_trainable: bool
233 | ) -> None:
234 | if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32
235 | model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
236 |
237 | if getattr(config, "model_type", None) == "qwen":
238 | for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
239 | setattr(config, dtype_name, model_args.compute_dtype == dtype)
240 |
241 | if model_args.rope_scaling is not None:
242 | _configure_rope(config, model_args, is_trainable)
243 |
244 | if model_args.flash_attn:
245 | _configure_flashattn(config_kwargs)
246 |
247 | if is_trainable and model_args.shift_attn:
248 | _configure_longlora(config)
249 |
250 | _configure_quantization(config, tokenizer, model_args, config_kwargs)
251 |
252 |
253 | def patch_model(
254 | model: "PreTrainedModel",
255 | tokenizer: "PreTrainedTokenizer",
256 | model_args: "ModelArguments",
257 | is_trainable: bool
258 | ) -> None:
259 | if "GenerationMixin" not in str(model.generate.__func__):
260 | model.generate = MethodType(PreTrainedModel.generate, model)
261 |
262 | if getattr(model.config, "model_type", None) == "chatglm":
263 | setattr(model, "lm_head", model.transformer.output_layer)
264 | setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
265 |
266 | if model_args.resize_vocab:
267 | if is_deepspeed_zero3_enabled():
268 | raise ValueError("DeepSpeed ZeRO-3 is incompatible with vocab resizing.")
269 |
270 | _resize_embedding_layer(model, tokenizer)
271 |
272 | if is_trainable:
273 | _prepare_model_for_training(model, model_args)
274 |
275 |
276 | def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
277 | def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None:
278 | if isinstance(self.pretrained_model, PreTrainedModel):
279 | self.pretrained_model.tie_weights()
280 |
281 | def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module:
282 | if isinstance(self.pretrained_model, PreTrainedModel):
283 | return self.pretrained_model.get_input_embeddings()
284 |
285 | ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name]
286 | setattr(model, "_keys_to_ignore_on_save", ignore_modules)
287 | setattr(model, "tie_weights", MethodType(tie_weights, model))
288 | setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model))
289 |
--------------------------------------------------------------------------------
/model/src/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import inspect
3 | from typing import TYPE_CHECKING, Any, Dict, List
4 | from transformers import PreTrainedModel
5 | from transformers.utils import cached_file
6 |
7 | from ..extras.constants import V_HEAD_WEIGHTS_NAME, V_HEAD_SAFE_WEIGHTS_NAME
8 | from ..extras.logging import get_logger
9 | from ..extras.misc import get_current_device
10 |
11 | if TYPE_CHECKING:
12 | from transformers import PretrainedConfig, PreTrainedTokenizer
13 | from ..hparams.model_args import ModelArguments
14 | from ..hparams.data_args import DataArguments
15 | from ..hparams.finetuning_args import FinetuningArguments
16 |
17 |
18 | logger = get_logger(__name__)
19 |
20 |
21 | def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
22 | r"""
23 | Dispatches a pre-trained model to GPUs with balanced memory when multiple GPUs are available.
24 | Borrowed from: https://github.com/huggingface/transformers/blob/v4.36.2/src/transformers/modeling_utils.py#L3570
25 | """
26 | if getattr(model, "quantization_method", None): # already set on current device
27 | return model
28 |
29 | if (
30 | torch.cuda.device_count() > 1
31 | and isinstance(model, PreTrainedModel)
32 | and model._no_split_modules is not None
33 | and model.config.model_type != "chatglm"
34 | ):
35 | from accelerate import dispatch_model
36 | from accelerate.utils import infer_auto_device_map, get_balanced_memory
37 |
38 | kwargs = {"dtype": model.dtype, "no_split_module_classes": model._get_no_split_modules("auto")}
39 | max_memory = get_balanced_memory(model, **kwargs)
40 | # Make sure tied weights are tied before creating the device map.
41 | model.tie_weights()
42 | device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
43 | device_map_kwargs = {"device_map": device_map}
44 | if "skip_keys" in inspect.signature(dispatch_model).parameters:
45 | device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
46 | return dispatch_model(model, **device_map_kwargs)
47 | else:
48 | return model.to(device=get_current_device())
49 |
50 |
51 | def find_all_linear_modules(model: "PreTrainedModel") -> List[str]:
52 | r"""
53 | Finds all available modules to apply lora.
54 | """
55 | quantization_method = getattr(model, "quantization_method", None)
56 | if quantization_method is None:
57 | linear_cls = torch.nn.Linear
58 | elif quantization_method == "bitsandbytes":
59 | import bitsandbytes as bnb
60 | linear_cls = bnb.nn.Linear4bit if getattr(model, "is_loaded_in_4bit", False) else bnb.nn.Linear8bitLt
61 | else:
62 | raise ValueError("Finding linear modules for {} models is not supported.".format(quantization_method))
63 |
64 | output_layer_names = ["lm_head"]
65 | if model.config.model_type == "chatglm":
66 | output_layer_names.append("output_layer")
67 |
68 | module_names = set()
69 | for name, module in model.named_modules():
70 | if (
71 | isinstance(module, linear_cls)
72 | and not any([output_layer in name for output_layer in output_layer_names])
73 | ):
74 | module_names.add(name.split(".")[-1])
75 |
76 | logger.info("Found linear modules: {}".format(",".join(module_names)))
77 | return list(module_names)
78 |
79 |
80 | def get_modelcard_args(
81 | model_args: "ModelArguments",
82 | data_args: "DataArguments",
83 | finetuning_args: "FinetuningArguments"
84 | ) -> Dict[str, Any]:
85 | return {
86 | "tasks": "text-generation",
87 | "license": "other",
88 | "finetuned_from": model_args.model_name_or_path,
89 | "dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
90 | "tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
91 | }
92 |
93 |
94 | def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
95 | r"""
96 | Loads value head parameters from Hugging Face Hub or local disk.
97 |
98 | Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`.
99 | """
100 | kwargs = {
101 | "path_or_repo_id": path_or_repo_id,
102 | "cache_dir": model_args.cache_dir,
103 | "token": model_args.hf_hub_token
104 | }
105 |
106 | try:
107 | from safetensors import safe_open
108 | vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs)
109 | with safe_open(vhead_file, framework="pt", device="cpu") as f:
110 | return {key: f.get_tensor(key) for key in f.keys()}
111 | except Exception as err:
112 | logger.info("Failed to load {}: {}".format(V_HEAD_SAFE_WEIGHTS_NAME, str(err)))
113 |
114 | try:
115 | vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs)
116 | return torch.load(vhead_file, map_location="cpu")
117 | except Exception as err:
118 | logger.info("Failed to load {}: {}".format(V_HEAD_WEIGHTS_NAME, str(err)))
119 |
120 | logger.info("Provided path ({}) does not contain value head weights.".format(path_or_repo_id))
121 | logger.info("Ignore these messages if you are not resuming the training of a value head model.")
122 | return None
123 |
124 |
125 | def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"):
126 | if "AutoConfig" in getattr(config, "auto_map", {}):
127 | config.__class__.register_for_auto_class()
128 | if "AutoModelForCausalLM" in getattr(config, "auto_map", {}):
129 | model.__class__.register_for_auto_class()
130 | if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}):
131 | tokenizer.__class__.register_for_auto_class()
132 |
--------------------------------------------------------------------------------
/model/src/train/dpo/collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, List, Sequence, Tuple
4 | from transformers import DataCollatorForSeq2Seq
5 |
6 |
7 | @dataclass
8 | class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
9 | r"""
10 | Data collator for pairwise data.
11 | """
12 |
13 | def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
14 | padded_labels = []
15 | for feature, (prompt_len, answer_len) in zip(batch, positions):
16 | if self.tokenizer.padding_side == "left":
17 | start, end = feature.size(0) - answer_len, feature.size(0)
18 | else:
19 | start, end = prompt_len, prompt_len + answer_len
20 | padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
21 | padded_tensor[start:end] = feature[start:end]
22 | padded_labels.append(padded_tensor)
23 | return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
24 |
25 | def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
26 | r"""
27 | Pads batched data to the longest sequence in the batch.
28 |
29 | We generate 2 * n examples where the first n examples represent chosen examples and
30 | the last n examples represent rejected examples.
31 | """
32 | concatenated_features = []
33 | label_positions = []
34 | for key in ("chosen_ids", "rejected_ids"):
35 | for feature in features:
36 | prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
37 | concatenated_features.append({
38 | "input_ids": feature["prompt_ids"] + feature[key],
39 | "attention_mask": [1] * (prompt_len + answer_len)
40 | })
41 | label_positions.append((prompt_len, answer_len))
42 |
43 | batch = self.tokenizer.pad(
44 | concatenated_features,
45 | padding=self.padding,
46 | max_length=self.max_length,
47 | pad_to_multiple_of=self.pad_to_multiple_of,
48 | return_tensors=self.return_tensors,
49 | )
50 | batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
51 | return batch
52 |
--------------------------------------------------------------------------------
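
A standalone numeric sketch of the label masking performed by `_pad_labels` for right-padded inputs: only the answer span keeps its token ids, while prompt and padding positions are set to the label pad id (IGNORE_INDEX, assumed to be -100 here) so they are ignored by the loss.

```python
import torch

label_pad_token_id = -100                               # IGNORE_INDEX (assumed value)
feature = torch.tensor([11, 12, 13, 21, 22, 0, 0])      # prompt (3) + answer (2) + padding (2)
prompt_len, answer_len = 3, 2

padded = label_pad_token_id * torch.ones_like(feature)  # start with every position masked
start, end = prompt_len, prompt_len + answer_len        # right-padding case
padded[start:end] = feature[start:end]                  # keep only the answer tokens
print(padded)  # tensor([-100, -100, -100,   21,   22, -100, -100])
```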
/model/src/train/dpo/trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import defaultdict
3 | from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
4 | from transformers import BatchEncoding, Trainer
5 | from trl import DPOTrainer
6 | from trl.trainer.utils import disable_dropout_in_model
7 |
8 | from ...extras.constants import IGNORE_INDEX
9 |
10 | if TYPE_CHECKING:
11 | from transformers import PreTrainedModel
12 |
13 |
14 | class CustomDPOTrainer(DPOTrainer):
15 |
16 | def __init__(
17 | self,
18 | beta: float,
19 | loss_type: Literal["sigmoid", "hinge", "ipo", "kto"],
20 | ftx_gamma: float,
21 | model: Union["PreTrainedModel", torch.nn.Module],
22 | ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
23 | disable_dropout: Optional[bool] = True,
24 | **kwargs
25 | ):
26 | if disable_dropout:
27 | disable_dropout_in_model(model)
28 | if ref_model is not None:
29 | disable_dropout_in_model(ref_model)
30 |
31 | self.use_dpo_data_collator = True # hack to avoid warning
32 | self.generate_during_eval = False # disable at evaluation
33 | self.label_pad_token_id = IGNORE_INDEX
34 | self.padding_value = 0
35 | self.is_encoder_decoder = model.config.is_encoder_decoder
36 | self.precompute_ref_log_probs = False
37 | self._precomputed_train_ref_log_probs = False
38 | self._precomputed_eval_ref_log_probs = False
39 | self._peft_has_been_casted_to_bf16 = False
40 |
41 | self.ref_model = ref_model
42 | self.beta = beta
43 | self.label_smoothing = 0
44 | self.loss_type = loss_type
45 | self.ftx_gamma = ftx_gamma
46 | self._stored_metrics = defaultdict(lambda: defaultdict(list))
47 |
48 | Trainer.__init__(self, model=model, **kwargs)
49 | if not hasattr(self, "accelerator"):
50 | raise AttributeError("Please update `transformers`.")
51 |
52 | if ref_model is not None:
53 | if self.is_deepspeed_enabled:
54 | if not (
55 | getattr(ref_model, "is_loaded_in_8bit", False)
56 | or getattr(ref_model, "is_loaded_in_4bit", False)
57 | ): # quantized models are already set on the correct device
58 | self.ref_model = self._prepare_deepspeed(self.ref_model)
59 | else:
60 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
61 |
62 | def sft_loss(
63 | self,
64 | chosen_logits: torch.FloatTensor,
65 | chosen_labels: torch.LongTensor
66 | ) -> torch.Tensor:
67 | r"""
68 | Computes the supervised cross-entropy loss of the given labels under the given logits.
69 |
70 | Returns:
71 | A tensor of shape (batch_size,) containing the cross-entropy loss of each sample.
72 | """
73 | all_logps = self.get_batch_logps(
74 | chosen_logits,
75 | chosen_labels,
76 | average_log_prob=True
77 | )
78 | return -all_logps
79 |
80 | def concatenated_forward(
81 | self,
82 | model: "PreTrainedModel",
83 | batch: Dict[str, torch.Tensor]
84 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
85 | batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
86 |
87 | all_logits = model(
88 | input_ids=batch_copied["input_ids"],
89 | attention_mask=batch_copied["attention_mask"],
90 | return_dict=True
91 | ).logits.to(torch.float32)
92 |
93 | all_logps = self.get_batch_logps(
94 | all_logits,
95 | batch["labels"],
96 | average_log_prob=False
97 | )
98 | batch_size = batch["input_ids"].size(0) // 2
99 | chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
100 | chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
101 | return chosen_logps, rejected_logps, chosen_logits, rejected_logits
102 |
103 | def get_batch_loss_metrics(
104 | self,
105 | model: "PreTrainedModel",
106 | batch: Dict[str, torch.Tensor],
107 | train_eval: Optional[Literal["train", "eval"]] = "train"
108 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
109 | r"""
110 | Computes the DPO loss and other metrics for the given batch of inputs for train or test.
111 | """
112 | metrics = {}
113 | (
114 | policy_chosen_logps,
115 | policy_rejected_logps,
116 | policy_chosen_logits,
117 | policy_rejected_logits,
118 | ) = self.concatenated_forward(model, batch)
119 | with torch.no_grad():
120 | if self.ref_model is None:
121 | with self.accelerator.unwrap_model(self.model).disable_adapter():
122 | (
123 | reference_chosen_logps,
124 | reference_rejected_logps,
125 | _,
126 | _,
127 | ) = self.concatenated_forward(self.model, batch)
128 | else:
129 | (
130 | reference_chosen_logps,
131 | reference_rejected_logps,
132 | _,
133 | _,
134 | ) = self.concatenated_forward(self.ref_model, batch)
135 |
136 | losses, chosen_rewards, rejected_rewards = self.dpo_loss(
137 | policy_chosen_logps,
138 | policy_rejected_logps,
139 | reference_chosen_logps,
140 | reference_rejected_logps,
141 | )
142 | if self.ftx_gamma > 1e-6:
143 | batch_size = batch["input_ids"].size(0) // 2
144 | chosen_labels, _ = batch["labels"].split(batch_size, dim=0)
145 | losses += self.ftx_gamma * self.sft_loss(policy_chosen_logits, chosen_labels)
146 |
147 | reward_accuracies = (chosen_rewards > rejected_rewards).float()
148 |
149 | prefix = "eval_" if train_eval == "eval" else ""
150 | metrics[f"{prefix}rewards/chosen"] = chosen_rewards.cpu().mean()
151 | metrics[f"{prefix}rewards/rejected"] = rejected_rewards.cpu().mean()
152 | metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.cpu().mean()
153 | metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().mean()
154 | metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().cpu().mean()
155 | metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().cpu().mean()
156 | metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().cpu().mean()
157 | metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().cpu().mean()
158 |
159 | return losses.mean(), metrics
160 |
--------------------------------------------------------------------------------
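For readers unfamiliar with `trl`, the following minimal sketch reproduces the sigmoid-variant preference loss that `CustomDPOTrainer` obtains from the inherited `dpo_loss`, plus the reward terms logged by `get_batch_loss_metrics`. The log-probabilities and the `beta` value below are made-up placeholders, not project outputs.

```python
import torch
import torch.nn.functional as F

# Illustrative log-probabilities for a batch of two preference pairs.
policy_chosen_logps = torch.tensor([-10.0, -12.0])
policy_rejected_logps = torch.tensor([-14.0, -13.0])
reference_chosen_logps = torch.tensor([-11.0, -12.5])
reference_rejected_logps = torch.tensor([-13.0, -12.8])
beta = 0.1  # placeholder value for finetuning_args.dpo_beta

# Sigmoid DPO loss: -log sigmoid(beta * (policy margin - reference margin)).
logits = (policy_chosen_logps - policy_rejected_logps) - (
    reference_chosen_logps - reference_rejected_logps
)
losses = -F.logsigmoid(beta * logits)

# The logged rewards are the beta-scaled log-ratios against the reference.
chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)

# With ftx_gamma > 0, get_batch_loss_metrics additionally adds the supervised
# cross-entropy of the chosen responses (sft_loss) scaled by ftx_gamma.
print(losses.mean().item(), (chosen_rewards > rejected_rewards).float().mean().item())
```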
/model/src/train/dpo/workflow.py:
--------------------------------------------------------------------------------
1 | # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
2 |
3 | from typing import TYPE_CHECKING, Optional, List
4 | from transformers import Seq2SeqTrainingArguments
5 |
6 | from ...data.loader import get_dataset
7 | from ...data.utils import split_dataset
8 | from ...extras.constants import IGNORE_INDEX
9 | from ...extras.ploting import plot_loss
10 | from ...hparams.model_args import ModelArguments
11 | from ...model.loader import load_model_and_tokenizer
12 | from ...train.dpo.collator import DPODataCollatorWithPadding
13 | from ...train.dpo.trainer import CustomDPOTrainer
14 | from ...train.utils import create_modelcard_and_push, create_ref_model
15 |
16 | if TYPE_CHECKING:
17 | from transformers import TrainerCallback
18 | from ...hparams import DataArguments, FinetuningArguments
19 |
20 |
21 | def run_dpo(
22 | model_args: "ModelArguments",
23 | data_args: "DataArguments",
24 | training_args: "Seq2SeqTrainingArguments",
25 | finetuning_args: "FinetuningArguments",
26 | callbacks: Optional[List["TrainerCallback"]] = None
27 | ):
28 |
29 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
30 | dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
31 | data_collator = DPODataCollatorWithPadding(
32 | tokenizer=tokenizer,
33 | pad_to_multiple_of=8,
34 | label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
35 | )
36 |
37 | # Create reference model
38 | if finetuning_args.ref_model is None and (not training_args.do_train): # use the model itself
39 | ref_model = model
40 | else:
41 | ref_model = create_ref_model(model_args, finetuning_args)
42 |
43 | # Update arguments
44 | training_args_dict = training_args.to_dict()
45 | training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
46 | training_args = Seq2SeqTrainingArguments(**training_args_dict)
47 |
48 | # Initialize our Trainer
49 | trainer = CustomDPOTrainer(
50 | beta=finetuning_args.dpo_beta,
51 | loss_type=finetuning_args.dpo_loss,
52 | ftx_gamma=finetuning_args.dpo_ftx,
53 | model=model,
54 | ref_model=ref_model,
55 | args=training_args,
56 | tokenizer=tokenizer,
57 | data_collator=data_collator,
58 | callbacks=callbacks,
59 | **split_dataset(dataset, data_args, training_args)
60 | )
61 |
62 | # Training
63 | if training_args.do_train:
64 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
65 | trainer.save_model()
66 | trainer.log_metrics("train", train_result.metrics)
67 | trainer.save_metrics("train", train_result.metrics)
68 | trainer.save_state()
69 | if trainer.is_world_process_zero() and finetuning_args.plot_loss:
70 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
71 |
72 | # Evaluation
73 | if training_args.do_eval:
74 | metrics = trainer.evaluate(metric_key_prefix="eval")
75 | if id(model) == id(ref_model): # unable to compute rewards without a reference model
76 | remove_keys = [key for key in metrics.keys() if "rewards" in key]
77 | for key in remove_keys:
78 | metrics.pop(key)
79 | trainer.log_metrics("eval", metrics)
80 | trainer.save_metrics("eval", metrics)
81 |
82 | # Create model card
83 | create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
84 |
--------------------------------------------------------------------------------
/model/src/train/sft/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from dataclasses import dataclass
3 | from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union
4 |
5 | from ...extras.constants import IGNORE_INDEX
6 | from ...extras.packages import (
7 | is_jieba_available, is_nltk_available, is_rouge_available
8 | )
9 |
10 | if TYPE_CHECKING:
11 | from transformers.tokenization_utils import PreTrainedTokenizer
12 |
13 | if is_jieba_available():
14 | import jieba
15 |
16 | if is_nltk_available():
17 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
18 |
19 | if is_rouge_available():
20 | from rouge_chinese import Rouge
21 |
22 |
23 | @dataclass
24 | class ComputeMetrics:
25 | r"""
26 |     Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
27 | """
28 |
29 | tokenizer: "PreTrainedTokenizer"
30 |
31 | def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
32 | r"""
33 | Uses the model predictions to compute metrics.
34 | """
35 | preds, labels = eval_preds
36 | score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
37 |
38 | preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
39 | labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
40 |
41 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
42 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
43 |
44 | for pred, label in zip(decoded_preds, decoded_labels):
45 | hypothesis = list(jieba.cut(pred))
46 | reference = list(jieba.cut(label))
47 |
48 | if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
49 | result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
50 | else:
51 | rouge = Rouge()
52 | scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
53 | result = scores[0]
54 |
55 | for k, v in result.items():
56 | score_dict[k].append(round(v["f"] * 100, 4))
57 |
58 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
59 | score_dict["bleu-4"].append(round(bleu_score * 100, 4))
60 |
61 | return {k: float(np.mean(v)) for k, v in score_dict.items()}
62 |
--------------------------------------------------------------------------------
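A standalone illustration of the per-sample scoring performed inside `ComputeMetrics.__call__` once predictions are decoded: jieba tokenisation, `rouge_chinese` F-scores, and character-level BLEU-4 with smoothing. The two strings are invented examples, not project data.

```python
import jieba
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from rouge_chinese import Rouge

pred = "SELECT name FROM singer WHERE age > 30"
label = "SELECT name FROM singer WHERE age > 30 ORDER BY age"

# Tokenise with jieba, then feed space-joined tokens to rouge_chinese.
hypothesis = list(jieba.cut(pred))
reference = list(jieba.cut(label))
scores = Rouge().get_scores(" ".join(hypothesis), " ".join(reference))[0]
rouge_f = {k: round(v["f"] * 100, 4) for k, v in scores.items()}

# BLEU-4 over character lists, smoothed to avoid zero scores on short texts.
bleu = sentence_bleu([list(label)], list(pred),
                     smoothing_function=SmoothingFunction().method3)
print(rouge_f, round(bleu * 100, 4))
```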
/model/src/train/sft/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import numpy as np
5 | import torch.nn as nn
6 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
7 |
8 | import tqdm
9 | from transformers import Seq2SeqTrainer
10 |
11 | from ...extras.constants import IGNORE_INDEX
12 | from ...extras.logging import get_logger
13 |
14 | if TYPE_CHECKING:
15 | from transformers.trainer import PredictionOutput
16 |
17 |
18 | logger = get_logger(__name__)
19 |
20 |
21 | class CustomSeq2SeqTrainer(Seq2SeqTrainer):
22 | r"""
23 |     Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE.
24 | """
25 |
26 | def prediction_step(
27 | self,
28 | model: nn.Module,
29 | inputs: Dict[str, Union[torch.Tensor, Any]],
30 | prediction_loss_only: bool,
31 | ignore_keys: Optional[List[str]] = None,
32 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
33 | r"""
34 | Removes the prompt part in the generated tokens.
35 |
36 | Subclass and override to inject custom behavior.
37 | """
38 | labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
39 | if self.args.predict_with_generate:
40 |             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensors."
41 | prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
42 | if prompt_len > label_len:
43 | inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
44 | if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility)
45 | inputs["labels"] = inputs["labels"][:, :prompt_len]
46 |
47 | loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated)
48 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
49 | )
50 | if generated_tokens is not None and self.args.predict_with_generate:
51 | generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
52 | generated_tokens = generated_tokens.contiguous()
53 |
54 | return loss, generated_tokens, labels
55 |
56 | def _pad_tensors_to_target_len(
57 | self,
58 | src_tensor: torch.Tensor,
59 | tgt_tensor: torch.Tensor
60 | ) -> torch.Tensor:
61 | r"""
62 | Pads the tensor to the same length as the target tensor.
63 | """
64 | assert self.tokenizer.pad_token_id is not None, "Pad token is required."
65 | padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
66 | padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
67 | return padded_tensor.contiguous() # in contiguous memory
68 |
69 | def save_predictions(
70 | self,
71 | predict_results: "PredictionOutput",
72 | sentence_num: int = 1
73 | ) -> None:
74 | r"""
75 | Saves model predictions to `output_dir`.
76 |
77 |         A custom behavior that is not contained in Seq2SeqTrainer.
78 | """
79 | if not self.is_world_process_zero():
80 | return
81 |
82 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
83 | logger.info(f"Saving prediction results to {output_prediction_file}")
84 |
85 | labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids,
86 | self.tokenizer.pad_token_id)
87 | preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions,
88 | self.tokenizer.pad_token_id)
89 |
90 | for i in range(len(preds)):
91 | pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0]
92 | if len(pad_len):
93 | preds[i] = np.concatenate((preds[i][pad_len[0]:], preds[i][:pad_len[0]]),
94 | axis=-1) # move pad token to last
95 |
96 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True,
97 | clean_up_tokenization_spaces=False)
98 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
99 |
100 | if sentence_num != 1:
101 | decoded_labels = [element for element in decoded_labels for _ in range(sentence_num)]
102 |
103 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
104 | res: List[str] = []
105 | for label, pred in zip(decoded_labels, decoded_preds):
106 | res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
107 | writer.write("\n".join(res))
--------------------------------------------------------------------------------
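The left-padding alignment that `prediction_step` depends on can be hard to picture; the toy tensors below mimic `_pad_tensors_to_target_len`, using 0 as a placeholder pad id.

```python
import torch

pad_token_id = 0  # placeholder; the real trainer uses tokenizer.pad_token_id
labels = torch.tensor([[5, 6, 7]])           # label ids, shorter than the prompt
input_ids = torch.tensor([[1, 2, 3, 4, 9]])  # left-padded prompt batch

# Right-align the labels inside a pad-filled tensor of the prompt's length.
padded = pad_token_id * torch.ones_like(input_ids)
padded[:, -labels.shape[-1]:] = labels
print(padded)  # tensor([[0, 0, 5, 6, 7]])
```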
/model/src/train/sft/workflow.py:
--------------------------------------------------------------------------------
1 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
2 |
3 | from typing import TYPE_CHECKING, Optional, List
4 | from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
5 |
6 | from ...data.loader import get_dataset
7 | from ...data.preprocess import preprocess_dataset
8 | from ...data.utils import split_dataset
9 | from ...extras.constants import IGNORE_INDEX
10 | from ...extras.misc import get_logits_processor
11 | from ...extras.ploting import plot_loss
12 | from ...model.loader import load_model_and_tokenizer
13 | from ...train.sft.metric import ComputeMetrics
14 | from ...train.sft.trainer import CustomSeq2SeqTrainer
15 | from ...train.utils import create_modelcard_and_push
16 |
17 | if TYPE_CHECKING:
18 | from transformers import TrainerCallback
19 | from ...hparams.model_args import ModelArguments
20 | from ...hparams.data_args import DataArguments
21 | from ...hparams.evaluation_args import EvaluationArguments
22 | from ...hparams.finetuning_args import FinetuningArguments
23 | from ...hparams.generating_args import GeneratingArguments
24 |
25 |
26 | def run_sft(
27 | model_args: "ModelArguments",
28 | data_args: "DataArguments",
29 | training_args: "Seq2SeqTrainingArguments",
30 | finetuning_args: "FinetuningArguments",
31 | generating_args: "GeneratingArguments",
32 | callbacks: Optional[List["TrainerCallback"]] = None
33 | ):
34 | dataset = get_dataset(model_args, data_args)
35 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
36 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
37 |
38 | if training_args.predict_with_generate:
39 | tokenizer.padding_side = "left" # use left-padding in generation
40 |
41 | if getattr(model, "is_quantized", False) and not training_args.do_train:
42 | setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction
43 |
44 | data_collator = DataCollatorForSeq2Seq(
45 | tokenizer=tokenizer,
46 | pad_to_multiple_of=8 if tokenizer.padding_side == "right" else None, # for shift short attention
47 | label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
48 | )
49 |
50 | # Override the decoding parameters of Seq2SeqTrainer
51 | training_args_dict = training_args.to_dict()
52 | training_args_dict.update(dict(
53 | generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
54 | generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
55 | ))
56 | training_args = Seq2SeqTrainingArguments(**training_args_dict)
57 |
58 | # Initialize our Trainer
59 | trainer = CustomSeq2SeqTrainer(
60 | model=model,
61 | args=training_args,
62 | tokenizer=tokenizer,
63 | data_collator=data_collator,
64 | callbacks=callbacks,
65 | compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
66 | **split_dataset(dataset, data_args, training_args)
67 | )
68 |
69 | # Keyword arguments for `model.generate`
70 | gen_kwargs = generating_args.to_dict()
71 | gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
72 | gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
73 | gen_kwargs["logits_processor"] = get_logits_processor()
74 |
75 | # Training
76 | if training_args.do_train:
77 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
78 | trainer.save_model()
79 | trainer.log_metrics("train", train_result.metrics)
80 | trainer.save_metrics("train", train_result.metrics)
81 | trainer.save_state()
82 | if trainer.is_world_process_zero() and finetuning_args.plot_loss:
83 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
84 |
85 | # Evaluation
86 | if training_args.do_eval:
87 | metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
88 | if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
89 | metrics.pop("eval_loss", None)
90 | trainer.log_metrics("eval", metrics)
91 | trainer.save_metrics("eval", metrics)
92 |
93 | # Predict
94 | if training_args.do_predict:
95 | predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
96 | if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
97 | predict_results.metrics.pop("predict_loss", None)
98 | trainer.log_metrics("predict", predict_results.metrics)
99 | trainer.save_metrics("predict", predict_results.metrics)
100 | trainer.save_predictions(predict_results, sentence_num=gen_kwargs["num_return_sequences"])
101 |
102 | # Create model card
103 | create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
--------------------------------------------------------------------------------
/model/src/train/tuner.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import TYPE_CHECKING, Any, Dict, List, Optional
3 | from transformers import PreTrainedModel
4 |
5 | from llmtuner.extras.callbacks import LogCallback
6 | from llmtuner.extras.logging import get_logger
7 | from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer
8 | from llmtuner.train.pt import run_pt
9 | from llmtuner.train.sft import run_sft
10 | from llmtuner.train.rm import run_rm
11 | from llmtuner.train.ppo import run_ppo
12 | from llmtuner.train.dpo import run_dpo
13 |
14 | if TYPE_CHECKING:
15 | from transformers import TrainerCallback
16 |
17 |
18 | logger = get_logger(__name__)
19 |
20 |
21 | def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
22 | model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
23 | callbacks = [LogCallback()] if callbacks is None else callbacks
24 |
25 | if finetuning_args.stage == "pt":
26 | run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
27 | elif finetuning_args.stage == "sft":
28 | run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
29 | elif finetuning_args.stage == "rm":
30 | run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
31 | elif finetuning_args.stage == "ppo":
32 | run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
33 | elif finetuning_args.stage == "dpo":
34 | run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
35 | else:
36 | raise ValueError("Unknown task.")
37 |
38 |
39 | def export_model(args: Optional[Dict[str, Any]] = None):
40 | model_args, _, finetuning_args, _ = get_infer_args(args)
41 |
42 | if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None:
43 | raise ValueError("Please merge adapters before quantizing the model.")
44 |
45 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
46 |
47 | if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None:
48 | raise ValueError("Cannot merge adapters to a quantized model.")
49 |
50 | if not isinstance(model, PreTrainedModel):
51 | raise ValueError("The model is not a `PreTrainedModel`, export aborted.")
52 |
53 | model.config.use_cache = True
54 | if getattr(model.config, "torch_dtype", None) == "bfloat16":
55 | model = model.to(torch.bfloat16).to("cpu")
56 | else:
57 | model = model.to(torch.float16).to("cpu")
58 | setattr(model.config, "torch_dtype", "float16")
59 |
60 | model.save_pretrained(
61 | save_directory=model_args.export_dir,
62 | max_shard_size="{}GB".format(model_args.export_size),
63 | safe_serialization=(not model_args.export_legacy_format)
64 | )
65 |
66 | try:
67 | tokenizer.padding_side = "left" # restore padding side
68 | tokenizer.init_kwargs["padding_side"] = "left"
69 | tokenizer.save_pretrained(model_args.export_dir)
70 |     except Exception:
71 | logger.warning("Cannot save tokenizer, please copy the files manually.")
72 |
73 |
74 | if __name__ == "__main__":
75 | run_exp()
76 |
--------------------------------------------------------------------------------
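A hypothetical way to launch training programmatically is to pass a plain dictionary to `run_exp`, which forwards it to `get_train_args`. The keys below are assumptions based on the hparams dataclasses and standard `Seq2SeqTrainingArguments` fields; check the actual argument definitions before use.

```python
# Hypothetical invocation sketch; argument keys are assumptions, not verified.
from tuner import run_exp  # adjust the import path to your package layout

run_exp(dict(
    stage="sft",                            # dispatched to run_sft by run_exp
    model_name_or_path="path/to/base-model",
    do_train=True,
    output_dir="outputs/dellm-sft",         # standard Seq2SeqTrainingArguments field
))
```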
/model/src/train/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import TYPE_CHECKING, Optional, Union
3 |
4 | from ..extras.logging import get_logger
5 | from ..hparams.model_args import ModelArguments
6 | from ..hparams.finetuning_args import FinetuningArguments
7 | from ..model.utils import get_modelcard_args, load_valuehead_params
8 | from ..model.loader import load_model_and_tokenizer
9 |
10 | if TYPE_CHECKING:
11 | from transformers import Seq2SeqTrainingArguments, Trainer
12 | from transformers.modeling_utils import PreTrainedModel
13 | from trl import AutoModelForCausalLMWithValueHead
14 | from ..hparams.data_args import DataArguments
15 |
16 |
17 | logger = get_logger(__name__)
18 |
19 |
20 | def create_modelcard_and_push(
21 | trainer: "Trainer",
22 | model_args: "ModelArguments",
23 | data_args: "DataArguments",
24 | training_args: "Seq2SeqTrainingArguments",
25 | finetuning_args: "FinetuningArguments"
26 | ) -> None:
27 | if training_args.do_train:
28 | if training_args.push_to_hub:
29 | trainer.push_to_hub(**get_modelcard_args(model_args, data_args, finetuning_args))
30 | return
31 | try:
32 | trainer.create_model_card(**get_modelcard_args(model_args, data_args, finetuning_args))
33 | except Exception as err:
34 | logger.warning("Failed to create model card: {}".format(str(err)))
35 |
36 |
37 | def create_ref_model(
38 | model_args: "ModelArguments",
39 | finetuning_args: "FinetuningArguments",
40 | add_valuehead: Optional[bool] = False
41 | ) -> Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]:
42 | r"""
43 | Creates reference model for PPO/DPO training. Evaluation mode is not supported.
44 |
45 | The valuehead parameter is randomly initialized since it is useless for PPO training.
46 | """
47 | if finetuning_args.ref_model is not None:
48 | ref_model_args_dict = model_args.to_dict()
49 | ref_model_args_dict.update(dict(
50 | model_name_or_path=finetuning_args.ref_model,
51 | adapter_name_or_path=finetuning_args.ref_model_adapters,
52 | quantization_bit=finetuning_args.ref_model_quantization_bit
53 | ))
54 | ref_model_args = ModelArguments(**ref_model_args_dict)
55 | ref_finetuning_args = FinetuningArguments(finetuning_type="lora")
56 | ref_model, _ = load_model_and_tokenizer(
57 | ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
58 | )
59 | logger.info("Created reference model from {}".format(finetuning_args.ref_model))
60 | else:
61 | if finetuning_args.finetuning_type == "lora":
62 | ref_model = None
63 | else:
64 | ref_model, _ = load_model_and_tokenizer(
65 | model_args, finetuning_args, is_trainable=False, add_valuehead=add_valuehead
66 | )
67 | logger.info("Created reference model from the model itself.")
68 |
69 | return ref_model
70 |
71 |
72 | def create_reward_model(
73 | model: "AutoModelForCausalLMWithValueHead",
74 | model_args: "ModelArguments",
75 | finetuning_args: "FinetuningArguments"
76 | ) -> "AutoModelForCausalLMWithValueHead":
77 | r"""
78 | Creates reward model for PPO training.
79 | """
80 | if finetuning_args.reward_model_type == "api":
81 | assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
82 | logger.info("Use reward server {}".format(finetuning_args.reward_model))
83 | return finetuning_args.reward_model
84 | elif finetuning_args.reward_model_type == "lora":
85 | model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
86 | for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
87 | if "default" in name:
88 |                 param.data = param.data.to(torch.float32)  # trainable params should be in fp32
89 | vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args)
90 | assert vhead_params is not None, "Reward model is not correctly loaded."
91 | model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
92 | model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
93 | model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
94 | model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
95 | logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model))
96 | return None
97 | else:
98 | reward_model_args_dict = model_args.to_dict()
99 | reward_model_args_dict.update(dict(
100 | model_name_or_path=finetuning_args.reward_model,
101 | adapter_name_or_path=finetuning_args.reward_model_adapters,
102 | quantization_bit=finetuning_args.reward_model_quantization_bit
103 | ))
104 | reward_model_args = ModelArguments(**reward_model_args_dict)
105 | reward_finetuning_args = FinetuningArguments(finetuning_type="lora")
106 | reward_model, _ = load_model_and_tokenizer(
107 | reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
108 | )
109 | logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model))
110 |         logger.warning("Please ensure the ppo model and reward model share the SAME tokenizer and vocabulary.")
111 | return reward_model
112 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.26.0
2 | aiofiles==23.2.1
3 | aiohttp==3.9.0
4 | aiosignal==1.3.1
5 | altair==5.1.2
6 | annotated-types==0.6.0
7 | anyio==3.7.1
8 | appdirs==1.4.4
9 | attrs==23.1.0
10 | bird==0.1.2
11 | certifi==2023.11.17
12 | charset-normalizer==3.3.2
13 | click==8.1.7
14 | contourpy==1.2.0
15 | cycler==0.12.1
16 | datasets==2.16.1
17 | deepspeed==0.12.3
18 | dill==0.3.7
19 | distro==1.8.0
20 | docker-pycreds==0.4.0
21 | docstring-parser==0.15
22 | einops==0.7.0
23 | fastapi==0.109.0
24 | ffmpy==0.3.1
25 | filelock==3.13.1
26 | fonttools==4.45.0
27 | frozenlist==1.4.0
28 | fsspec==2023.10.0
29 | gitdb==4.0.11
30 | GitPython==3.1.42
31 | gradio==3.50.2
32 | gradio_client==0.6.1
33 | h11==0.14.0
34 | hjson==3.1.0
35 | httpcore==1.0.2
36 | httpx==0.25.1
37 | huggingface-hub==0.20.1
38 | idna==3.4
39 | importlib-resources==6.1.1
40 | jieba==0.42.1
41 | Jinja2==3.1.2
42 | joblib==1.3.2
43 | jsonschema==4.20.0
44 | jsonschema-specifications==2023.11.1
45 | kiwisolver==1.4.5
46 | markdown-it-py==3.0.0
47 | MarkupSafe==2.1.3
48 | matplotlib==3.8.2
49 | mdurl==0.1.2
50 | mpmath==1.3.0
51 | multidict==6.0.4
52 | multiprocess==0.70.15
53 | networkx==3.2.1
54 | ninja==1.11.1.1
55 | nltk==3.8.1
56 | numpy==1.26.2
57 | nvidia-cublas-cu12==12.1.3.1
58 | nvidia-cuda-cupti-cu12==12.1.105
59 | nvidia-cuda-nvrtc-cu12==12.1.105
60 | nvidia-cuda-runtime-cu12==12.1.105
61 | nvidia-cudnn-cu12==8.9.2.26
62 | nvidia-cufft-cu12==11.0.2.54
63 | nvidia-curand-cu12==10.3.2.106
64 | nvidia-cusolver-cu12==11.4.5.107
65 | nvidia-cusparse-cu12==12.1.0.106
66 | nvidia-nccl-cu12==2.18.1
67 | nvidia-nvjitlink-cu12==12.3.101
68 | nvidia-nvtx-cu12==12.1.105
69 | orjson==3.9.10
70 | openai==0.28.1
71 | packaging==23.2
72 | pandas==2.1.3
73 | peft==0.7.1
74 | Pillow==10.1.0
75 | protobuf==4.25.2
76 | psutil==5.9.6
77 | py-cpuinfo==9.0.0
78 | pyarrow==14.0.1
79 | pyarrow-hotfix==0.5
80 | pydantic==2.5.3
81 | pydantic_core==2.14.6
82 | pydub==0.25.1
83 | Pygments==2.17.1
84 | pynvml==11.5.0
85 | pyparsing==3.1.1
86 | python-dateutil==2.8.2
87 | python-multipart==0.0.6
88 | pytz==2023.3.post1
89 | PyYAML==6.0.1
90 | referencing==0.31.0
91 | regex==2023.10.3
92 | requests==2.31.0
93 | rich==13.7.0
94 | rouge-chinese==1.0.3
95 | rpds-py==0.13.1
96 | safetensors==0.4.0
97 | scipy==1.11.4
98 | semantic-version==2.10.0
99 | sentencepiece==0.1.99
100 | sentry-sdk==1.40.5
101 | setproctitle==1.3.3
102 | shortuuid==1.0.11
103 | shtab==1.6.4
104 | six==1.16.0
105 | smmap==5.0.1
106 | sniffio==1.3.0
107 | sse-starlette==1.8.2
108 | starlette==0.35.1
109 | sympy==1.12
110 | tiktoken==0.5.2
111 | tokenizers==0.15.0
112 | toolz==0.12.0
113 | torch==2.1.2
114 | torchaudio==2.1.2
115 | torchvision==0.16.2
116 | tqdm==4.66.1
117 | transformers==4.36.2
118 | transformers-stream-generator==0.0.4
119 | triton==2.1.0
120 | trl==0.7.9
121 | typing_extensions==4.8.0
122 | tyro==0.5.17
123 | tzdata==2023.3
124 | urllib3==2.1.0
125 | uvicorn==0.26.0
126 | wandb==0.16.3
127 | websockets==11.0.3
128 | xformers==0.0.23.post1
129 | xxhash==3.4.1
130 | yarl==1.9.3
131 |
--------------------------------------------------------------------------------
/slides/Framework.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rcrossmeister/Knowledge-to-SQL/4fe8d0588eb853ef44088c24dcfbdabfd1a6478c/slides/Framework.pdf
--------------------------------------------------------------------------------
/slides/Framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rcrossmeister/Knowledge-to-SQL/4fe8d0588eb853ef44088c24dcfbdabfd1a6478c/slides/Framework.png
--------------------------------------------------------------------------------
/slides/Poster-DELLM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rcrossmeister/Knowledge-to-SQL/4fe8d0588eb853ef44088c24dcfbdabfd1a6478c/slides/Poster-DELLM.pdf
--------------------------------------------------------------------------------
/slides/Slides-DELLM.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rcrossmeister/Knowledge-to-SQL/4fe8d0588eb853ef44088c24dcfbdabfd1a6478c/slides/Slides-DELLM.pdf
--------------------------------------------------------------------------------