├── README.md
├── data_process
│   └── data_process.py
├── image-1.png
├── image-2.png
├── image.png
├── image.webp
├── loss_vs_time_hours.png
├── loss_vs_tokens_millions.png
├── pretrain
│   ├── ds_config.json
│   ├── generate_pretrain_data.py
│   ├── model
│   │   ├── config.json
│   │   ├── configuration_miaomiao.py
│   │   └── modeling_miaomiao.py
│   ├── pretrain.py
│   ├── pretrain.sh
│   ├── pretrain_dataset.py
│   └── test_pretrain_model.py
├── requirements.txt
├── rlhf
│   ├── rlhf
│   │   ├── __pycache__
│   │   │   ├── ppo_trainer.cpython-311.pyc
│   │   │   └── rlhf_engine.cpython-311.pyc
│   │   ├── ppo_trainer.py
│   │   └── rlhf_engine.py
│   ├── rlhf_data_process.py
│   ├── rw_eval.py
│   ├── step2.py
│   ├── step2.sh
│   ├── step2_eval.sh
│   ├── step3.py
│   ├── step3.sh
│   └── utils
│       ├── __pycache__
│       │   ├── data_utils.cpython-311.pyc
│       │   ├── ds_utils.cpython-311.pyc
│       │   ├── model_utils.cpython-311.pyc
│       │   ├── perf.cpython-311.pyc
│       │   ├── raw_datasets.cpython-311.pyc
│       │   ├── reward_model.cpython-311.pyc
│       │   └── utils.cpython-311.pyc
│       ├── data_utils.py
│       ├── ds_utils.py
│       ├── model_utils.py
│       ├── perf.py
│       ├── raw_datasets.py
│       ├── reward_model.py
│       └── utils.py
├── sft
│   ├── ds_config.json
│   ├── model
│   │   ├── config.json
│   │   ├── configuration_miaomiao.py
│   │   ├── merges.txt
│   │   ├── modeling_miaomiao.py
│   │   ├── tokenization_miaomiao.py
│   │   ├── tokenizer.json
│   │   ├── tokenizer_config.json
│   │   └── vocab.json
│   ├── sft.py
│   ├── sft.sh
│   ├── sft_data_filted.py
│   ├── sft_dataset.py
│   └── test_sft_model.py
└── train_tokenizer
    ├── miaomiao_tokenizer
    │   ├── merges.txt
    │   ├── tokenization_miaomiao.py
    │   ├── tokenizer.json
    │   ├── tokenizer_config.json
    │   └── vocab.json
    └── train_tokenizer.py

/README.md:
--------------------------------------------------------------------------------
# Zero-Chatgpt

[Image: Zero-Chatgpt]

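
The goal of this open-source project is to run through the ChatGPT technical pipeline from scratch.
The pipeline covers: data collection -> data cleaning and deduplication -> tokenizer (vocabulary) training -> language-model pretraining -> instruction fine-tuning -> reinforcement learning (RLHF, PPO).
The priority is to get the code and the pipeline running end to end. Quality tuning will come later, time permitting.
Pretraining data: 10B tokens; instruction-tuning data: 300k samples; RLHF data: 100k samples; model size: 0.1B parameters.
The whole training pipeline and its code already run end to end. For better results, you can scale up by editing the model config file. In my training experience, larger models and more data improve quality very noticeably.

---
Also check out my other open-source vision-language project, [Zero-Qwen-VL](https://github.com/AI-Study-Han/Zero-Qwen-VL). It trains an image-text model with better Chinese support from scratch and runs through the full multimodal training pipeline. It uses the Qwen-VL image encoder and the Qwen2-0.5B-Instruct language model. With enough compute, you can swap in larger models for better results.
## 1. Training Environment
CUDA 12.1, PyTorch, transformers, DeepSpeed, and other common packages. requirements.txt lists the runtime environment.

Compute: 2 A40 GPUs. Pretraining takes about 2 days.

## 2. Training Data, Model Weights, and Container Images
[Pretraining data, fine-tuning data, RLHF data, model weights, and the pretraining/SFT container images](https://huggingface.co/My521/Zero-Chatgpt/tree/main) are all hosted here. To load a model, strip the prefix from the weight file name (rename it to model.safetensors or pytorch_model.bin). Then put the file together with the model code and config files in the model folder. The pretraining data and training images are too large and will be uploaded later.

| File | Description |
|------------------------|--------------------------------------------------------|
| [pretrain_model.safetensors](https://huggingface.co/My521/Zero-Chatgpt/blob/main/pretrain_model.safetensors) | Weights of the pretrained model |
| [sft_model.safetensors](https://huggingface.co/My521/Zero-Chatgpt/blob/main/sft_model.safetensors) | Weights of the instruction-tuned (SFT) model |
| [rlhf_pytorch_model.bin](https://huggingface.co/My521/Zero-Chatgpt/blob/main/rlhf_pytorch_model.bin) | Weights of the model after RLHF |
| [pretrain_sft.tar](https://huggingface.co/My521/Zero-Chatgpt/blob/main/pretrain_sft.tar) | Container image for pretraining and SFT |
| [rlhf.tar](https://huggingface.co/My521/Zero-Chatgpt/blob/main/rlhf.tar) | Container image for RLHF |
| [rlhf.jsonl](https://huggingface.co/My521/Zero-Chatgpt/blob/main/rlhf.jsonl) | RLHF dataset |
| [sft.jsonl](https://huggingface.co/My521/Zero-Chatgpt/blob/main/sft.jsonl) | SFT data |
| [pretrain.bin](https://huggingface.co/My521/Zero-Chatgpt/blob/main/pretrain.bin) | Pretraining data |


## 3. Data Collection and Cleaning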
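In total, this project collected about 10B tokens of Chinese training text from three sources: [Chinese Wikipedia](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered/blob/main/wikipedia-cn-20230720-filtered.json), [Chinese Baidu Baike](https://huggingface.co/datasets/xuqinyang/BaiduBaike-5.63M/blob/main/563w_baidubaike.json), and a random sample of [SkyPile-150B](https://huggingface.co/datasets/Skywork/SkyPile-150B).

Chinese Wikipedia and SkyPile-150B are fairly clean, so only Baidu Baike was cleaned and deduplicated. Biographies, product descriptions, and very short entries were removed, and the remainder was strictly deduplicated. Of the original 5.63M entries, just over 1.4M remain.

The data-processing code is in the data_process folder.

## 4. Tokenizer Training
The tokenizer was trained on a random sample drawn from all three sources. The sample size depends on your server's memory; this project used 1.5 GB of text. The vocabulary size is 32,000, following LLaMA. Because the model is small, the vocabulary is also kept small so that the embedding layer does not take an outsized share of the parameters. The special tokens follow Qwen's.

The tokenizer training code is in the train_tokenizer folder.

## 5. Pretraining
The model architecture follows LLaMA, which is what most open-source models use. The model code is based on Hugging Face's implementation. An earlier version of the training code was not compatible with Hugging Face and caused too many problems during RLHF, so it was rewritten. Given the compute currently available, the model is sized at about 0.1B parameters. With more compute, you can scale up both the model and the data. In earlier runs with larger models and more data, the improvement was clear and substantial.

The data is first tokenized into a .bin file. Training then uses the Hugging Face Trainer.

The pretraining data-generation code, training script, and training code are in the pretrain folder.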
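
A minimal sketch of the .bin layout follows. It assumes the format written by generate_pretrain_data.py: each document's token IDs followed by one EOS token, stored back to back as raw uint16 values. uint16 is enough because the 32,006-entry vocabulary in pretrain/model/config.json is below 65,536. The helpers `write_token_bin` and `read_blocks` are hypothetical, and so is the choice to split the stream into non-overlapping fixed-length blocks. The project's own PretrainDataset (pretrain_dataset.py) is not shown here and may slice the data differently.

```python
import os
import tempfile

import numpy as np


def write_token_bin(path, docs, eos_id):
    """Write token-ID lists back to back as raw uint16, each followed by EOS.

    `docs` stands in for tokenizer output (hypothetical helper that mirrors
    the layout produced by generate_pretrain_data.py).
    """
    with open(path, "wb") as f:
        for ids in docs:
            arr = np.array(list(ids) + [eos_id], dtype=np.uint16)
            f.write(arr.tobytes())


def read_blocks(path, block_size):
    """Memory-map a flat uint16 token file and cut it into fixed-length blocks.

    Leftover tokens that do not fill a whole block are dropped. This is an
    assumption for illustration, not necessarily what PretrainDataset does.
    """
    tokens = np.memmap(path, dtype=np.uint16, mode="r")
    n_blocks = len(tokens) // block_size
    blocks = np.asarray(tokens[: n_blocks * block_size]).reshape(n_blocks, block_size)
    # Widen to int64 so the blocks can later become input_ids tensors.
    return blocks.astype(np.int64)


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "toy.bin")
    # Two toy "documents"; 32005 is the eos_token_id in pretrain/model/config.json.
    write_token_bin(path, [[1, 2, 3], [4, 5]], eos_id=32005)
    # Two blocks, [1, 2, 3] and [32005, 4, 5]; the final EOS is dropped.
    print(read_blocks(path, block_size=3))
```

Memory-mapping keeps a multi-gigabyte token file off the heap, because only the slices that are actually read get paged in. This is presumably why pretrain.py constructs PretrainDataset with `memmap=True`.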

[Figure: pretraining loss curve]

[Figure: pretraining loss curve]
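
To compare against the "about 0.1B" figure above, here is a back-of-the-envelope parameter count from pretrain/model/config.json. The relevant values are hidden_size 512, intermediate_size 2752, 24 layers, 16 attention heads and 16 key/value heads, vocab_size 32006, and tie_word_embeddings false. The count assumes a bias-free LLaMA-style decoder: q/k/v/o projections, a gated MLP with gate/up/down projections, two RMSNorm weight vectors per layer plus a final norm, and an untied LM head. modeling_miaomiao.py is not shown here, so treat the result as an estimate. The function name is hypothetical, and pretrain.py logs the real total at startup.

```python
def llama_style_param_count(vocab_size, hidden_size, intermediate_size,
                            num_hidden_layers, num_attention_heads,
                            num_key_value_heads, tie_word_embeddings=False):
    """Estimate parameters of a bias-free LLaMA-style decoder (assumed layout)."""
    head_dim = hidden_size // num_attention_heads
    kv_dim = num_key_value_heads * head_dim
    # q_proj and o_proj are hidden x hidden; k_proj and v_proj are hidden x kv_dim.
    attention = 2 * hidden_size * hidden_size + 2 * hidden_size * kv_dim
    # Gated MLP: gate_proj and up_proj (hidden -> intermediate) plus down_proj.
    mlp = 3 * hidden_size * intermediate_size
    # Two RMSNorm weight vectors per layer (before attention and before the MLP).
    norms = 2 * hidden_size
    per_layer = attention + mlp + norms

    embedding = vocab_size * hidden_size
    lm_head = 0 if tie_word_embeddings else vocab_size * hidden_size
    final_norm = hidden_size
    total = num_hidden_layers * per_layer + embedding + lm_head + final_norm
    return {"per_layer": per_layer, "embedding": embedding,
            "lm_head": lm_head, "total": total}


if __name__ == "__main__":
    # Values taken from pretrain/model/config.json.
    counts = llama_style_param_count(
        vocab_size=32006, hidden_size=512, intermediate_size=2752,
        num_hidden_layers=24, num_attention_heads=16, num_key_value_heads=16,
        tie_word_embeddings=False,
    )
    print(counts)
    print(f"total: {counts['total'] / 1e9:.3f}B parameters")
```

Under these assumptions the total is 159,414,784 parameters (about 0.16B). The untied embedding and LM-head matrices account for about 32.8M of that, and the 24 transformer layers hold about 126.6M.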

## 6. Instruction Fine-Tuning (SFT)
The instruction-tuning data comes from [firefly-train-1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M/blob/main/firefly-train-1.1M.jsonl) and [ruozhiout_qa_cn.jsonl](https://www.modelscope.cn/datasets/baicai003/Llama3-Chinese-dataset/files).

firefly-train-1.1M is not especially high quality, so it was filtered by question length and deduplicated, which left just over 400k samples. Because the model is small and only single-turn dialogue was targeted, only single-turn conversations were kept.

SFT used 300k samples, and the results are still not great, probably because of the model size. In an earlier experiment, a 1.5B model pretrained on 50B tokens already had decent dialogue ability after 20k SFT samples. The 0.1B model here was still poor after 20k samples and needed much more SFT data, so 300k were used.

The SFT scripts and code are in the sft folder.

Quick test results of the fine-tuned model:
![SFT model test results](image.png)

## 7. Reinforcement Learning (RLHF)
The RL data was generated from the samples that SFT did not use. The original answer in each sample serves as "chosen", and an answer generated by the SFT model serves as "rejected", for 100k samples in total. The RLHF code is adapted from [DeepSpeed-Chat](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat#readme) with some modifications.

50k of these samples were used to train the reward model (step 2), which reaches 92% classification accuracy after 1 epoch.

All 100k samples were used for reinforcement learning from human feedback (step 3).

Judging from the final results, RLHF did not work well. As the number of training steps grew, the model got worse, although results were acceptable early in training. Refusals became more frequent, while repetitive ("parrot") output became less frequent. At first I suspected the learning rate was too high. Lowering it helped somewhat, but results were still not good, so poor data quality is probably the main cause.

The RL scripts and code are in the rlhf folder.

Answers after a small number of training steps:
![RLHF answers after few training steps](image-1.png)

Answers after many training steps:
![RLHF answers after many training steps](image-2.png)

--------------------------------------------------------------------------------
/data_process/data_process.py:
--------------------------------------------------------------------------------
import json
import time
from tqdm import tqdm
import re
import os
import pandas as pd
from datasketch import MinHash, MinHashLSH
import random


def process_baike():
    input_file = './563w_baidubaike.json'
    output_file = './baidubaike_no_depulication.json'
    batch_size = 100000

    processed_lines = 0
    start_time = time.time()
    # Regex for citation markers such as [1], [2], [3], [1-2]
    bracket_pattern = re.compile(r'\[\d+(-\d+)?\]')
    punctuation_pattern = re.compile(r'[。!?:]$')
    chinese_char_pattern = re.compile(r'[\u4e00-\u9fa5]')
    repeated_punctuation_pattern = re.compile(r'([。!?])\1+')
    whitespace_pattern = re.compile(r'\s+| +')


    def process_lines(lines, outfile):
        nonlocal processed_lines, start_time
        for line in lines:
            try:
                data = json.loads(line)
                text = ""
                title = data.get("title", "")
                summary = data.get("summary", "")
                if summary is None or summary.strip() == "":
                    text = f"{title}。"
                elif summary.startswith(title):
                    text = f"{summary}"
                    if not punctuation_pattern.search(text):
                        text += "。"
                else:
                    text = f"{title},{summary}"
                    if not punctuation_pattern.search(text):
                        text += "。"
                skip_line = False
                sections = data.get("sections", [])
                for section in sections:
                    section_title = section.get("title", "")
                    # Skip entries with spec, project, product, or personal-profile sections
                    if "重要参数" in section_title or "项目简介" in section_title or "产品简介" in section_title or "个人资料" in section_title or "个人简介" in section_title:
                        skip_line = True
                        break
                    section_content = section.get("content", "")
                    text += f"{section_title},{section_content}"
                    if not punctuation_pattern.search(text):
                        text += "。"

                # Drop entries with fewer than 30 Chinese characters or too many spaces
                chinese_chars = chinese_char_pattern.findall(text)
                if skip_line or len(chinese_chars) < 30 or text.count(' ') > 10:
                    continue

                # Remove all whitespace (including full-width spaces)
                text = re.sub(whitespace_pattern, '', text)
                # Remove citation markers such as [1], [2], [3]
                text = re.sub(bracket_pattern, '', text)
                # Collapse repeated punctuation marks
                text = re.sub(repeated_punctuation_pattern, r'\1', text)
                new_data = {
                    "text": text,
                    "source": "baidubaike"
                }

                outfile.write(json.dumps(new_data, ensure_ascii=False) + '\n')
                processed_lines += 1

            except json.JSONDecodeError as e:
                print(f"Error decoding JSON: {e}")
            except Exception as e:
                print(f"Error processing line: {e}")

        # Print total processed lines and processing speed
        elapsed_time = time.time() - start_time
        speed = processed_lines / elapsed_time if elapsed_time > 0 else 0
        tqdm.write(f"Processed {processed_lines} lines at {speed:.2f} lines/second")

    with open(input_file, 'r', encoding='utf-8') as infile, open(output_file, 'w', encoding='utf-8') as outfile:
        batch_lines = []
        for line in tqdm(infile, desc="Reading lines"):
            batch_lines.append(line)
            if len(batch_lines) == batch_size:
                process_lines(batch_lines, outfile)
                batch_lines = []

        # Process remaining lines
        if batch_lines:
            process_lines(batch_lines, outfile)

def process_cn_wiki():
    input_file = "./wikipedia-cn-20230720-filtered.json"
    output_file = "./wiki_cn.json"

    with open(input_file, 'r', encoding='utf-8') as infile, open(output_file, 'w', encoding='utf-8') as outfile:
        data = json.load(infile)
        for entry in data:
            text = entry.get("completion", "")
            new_entry = {
                "text": text,
                "source": "wiki_cn"
            }
            json.dump(new_entry, outfile, ensure_ascii=False)
            outfile.write('\n')

    print("Processing complete. Output saved to", output_file)

def process_skypile():
    input_dir = "./SkyPile-50/"
    output_file = "./skypile.json"

    # List all .jsonl files in the input directory
    jsonl_files = [f for f in os.listdir(input_dir) if f.endswith(".jsonl")]

    with open(output_file, 'w', encoding='utf-8') as outfile:
        # File-level progress bar
        for filename in tqdm(jsonl_files, desc="Processing files"):
            input_file = os.path.join(input_dir, filename)
            with open(input_file, 'r', encoding='utf-8') as infile:
                for line in infile:
                    try:
                        data = json.loads(line)
                        text = data.get("text", "")
                        new_entry = {
                            "text": text,
                            "source": "skypile"
                        }
                        json.dump(new_entry, outfile, ensure_ascii=False)
                        outfile.write('\n')
                    except json.JSONDecodeError as e:
                        print(f"Error decoding JSON in file {filename}: {e}")
                    except Exception as e:
                        print(f"Error processing line in file {filename}: {e}")

    print("Processing complete. Output saved to", output_file)

def ngrams(text, n=2):
    return [text[i:i+n] for i in range(len(text)-n+1)]

def process_line(line, num_perm):
    data = json.loads(line)
    text = data["text"]
    minhash = MinHash(num_perm=num_perm)
    # Use character bigrams as shingles
    for d in ngrams(text, 2):
        minhash.update(d.encode('utf-8'))
    return data, minhash

def depulication_cn_file(input_file, output_file, threshold):
    # MinHash-LSH parameters
    num_perm = 128
    lsh = MinHashLSH(threshold, num_perm=num_perm)
    key_counter = 0

    retained_lines = 0
    processed_lines = 0

    # Create the progress bar
    pbar = tqdm(desc="Processing lines", unit="line", mininterval=0.1)

    with open(output_file, 'w', encoding='utf-8') as out_file:
        start_time = time.time()
        with open(input_file, 'r', encoding='utf-8') as file:
            for line in file:
                data, minhash = process_line(line, num_perm)
                unique_key = f"{data['source']}_{key_counter}"
                key_counter += 1
                # Keep the line only if no near-duplicate has been kept already
                if not lsh.query(minhash):
                    lsh.insert(unique_key, minhash)
                    json.dump(data, out_file, ensure_ascii=False)
                    out_file.write('\n')
                    retained_lines += 1
                processed_lines += 1
                pbar.update(1)
                elapsed_time = time.time() - start_time
                lines_per_second = processed_lines / elapsed_time if elapsed_time > 0 else 0
                pbar.set_postfix({"Retained": retained_lines, "Processed": processed_lines, "Speed": f"{lines_per_second:.2f} lines/sec"})

    # Close the progress bar
    pbar.close()

def depulication_cn_files():
    # Define paths
    input_dir = "/home/"
    output_file = "/home/deduplicated_cn_data.json"
    # MinHash-LSH parameters
    num_perm = 128
    lsh = MinHashLSH(threshold=0.6, num_perm=num_perm)
    key_counter = 0

    retained_lines = 0
    processed_lines = 0

    # Create the progress bar
    pbar = tqdm(desc="Processing lines", unit="line", mininterval=0.1)

    with open(output_file, 'w', encoding='utf-8') as out_file:
        start_time = time.time()
        for filename in os.listdir(input_dir):
            if filename.endswith(".json"):
                file_path = os.path.join(input_dir, filename)
                with open(file_path, 'r', encoding='utf-8') as file:
                    for line in file:
                        data, minhash = process_line(line, num_perm)
                        unique_key = f"{data['source']}_{key_counter}"
                        key_counter += 1
                        if not lsh.query(minhash):
                            lsh.insert(unique_key, minhash)
                            json.dump(data, out_file, ensure_ascii=False)
                            out_file.write('\n')
                            retained_lines += 1
                        processed_lines += 1
                        pbar.update(1)
                        elapsed_time = time.time() - start_time
                        lines_per_second = processed_lines / elapsed_time if elapsed_time > 0 else 0
                        pbar.set_postfix({"Retained": retained_lines, "Processed": processed_lines, "Speed": f"{lines_per_second:.2f} lines/sec"})

    # Close the progress bar
    pbar.close()


def merge_data():
    input_files = [
        './baidubaike.json',
        './wiki_cn.json',
        './skypile.json'
    ]
    output_file = './pretrain.json'
    sampling_ratios = [1, 1, 0.57]  # keep 100%, 100%, and 57% of the lines of each file, respectively

    assert len(input_files) == len(sampling_ratios), "Number of input files and sampling ratios do not match"

    line_counts = {}

    with open(output_file, 'w', encoding='utf-8') as out_f:
        for file, ratio in zip(input_files, sampling_ratios):
            line_counts[file] = 0
            with open(file, 'r', encoding='utf-8') as in_f:
                for line in in_f:
                    if random.random() <= ratio:
                        data = json.loads(line.strip())
                        out_f.write(json.dumps(data, ensure_ascii=False) + '\n')
                        line_counts[file] += 1

    for file, count in line_counts.items():
        print(f"{file}: wrote {count} lines")


def generate_train_tokenizer_data():
    input_files = [
        './baidubaike.json',
        './wiki_cn.json',
        './skypile.json'
    ]
    output_file = './train_tokenizer.json'
    sampling_ratios = [1, 0.5, 0.02]  # keep 100%, 50%, and 2% of the lines of each file, respectively

    assert len(input_files) == len(sampling_ratios), "Number of input files and sampling ratios do not match"

    line_counts = {}

    with open(output_file, 'w', encoding='utf-8') as out_f:
        for file, ratio in zip(input_files, sampling_ratios):
            line_counts[file] = 0
            with open(file, 'r', encoding='utf-8') as in_f:
                for line in in_f:
                    if random.random() < ratio:
                        data = json.loads(line.strip())
                        out_f.write(json.dumps(data, ensure_ascii=False) + '\n')
                        line_counts[file] += 1

    for file, count in line_counts.items():
        print(f"{file}: wrote {count} lines")

def sft_process_firefly():
    input_data_path = './firefly-cn-train-1.1M.jsonl'
    output_data_path = './processed_firefly.jsonl'

    line_count = 0

    with open(input_data_path, 'r', encoding='utf-8') as infile, open(output_data_path, 'w', encoding='utf-8') as outfile:
        for line in infile:
            data = json.loads(line)
            conversations = data.get("conversations", [])
            # Keep only single-turn human -> gpt conversations
            if len(conversations) == 2 and conversations[0].get("from") == "human" and conversations[1].get("from") == "gpt":
                human_value = conversations[0].get("value", "")
                # Drop very short questions
                if len(human_value) > 5:
                    new_data = {
                        "messages": [
                            {"from": "user", "value": human_value},
                            {"from": "assistant", "value": conversations[1].get("value", "")}
                        ]
                    }
                    outfile.write(json.dumps(new_data, ensure_ascii=False) + '\n')
                    line_count += 1

    print(f"Total lines written: {line_count}")

def process_sft_line(line, num_perm):
    data = json.loads(line)
    messages = data.get("messages", [])
    combined_text = ''.join([msg['value'] for msg in messages if msg['from'] in ['user', 'assistant']])
    minhash = MinHash(num_perm=num_perm)
    for d in ngrams(combined_text, 2):
        minhash.update(d.encode('utf-8'))
    return data, minhash

def depulication_cn_firefly():
    # Define paths
    input_file = "./processed_firefly.jsonl"
    output_file = "./depulication_firefly.jsonl"
    # MinHash-LSH parameters
    num_perm = 128
    lsh = MinHashLSH(threshold=0.4, num_perm=num_perm)
    key_counter = 0

    retained_lines = 0
    processed_lines = 0

    # Create the progress bar
    pbar = tqdm(desc="Processing lines", unit="line", mininterval=0.1)

    with open(output_file, 'w', encoding='utf-8') as out_file:
        start_time = time.time()
        with open(input_file, 'r', encoding='utf-8') as file:
            for line in file:
                data, minhash = process_sft_line(line, num_perm)
                unique_key = f"{key_counter}"
                key_counter += 1
                if not lsh.query(minhash):
                    lsh.insert(unique_key, minhash)
                    json.dump(data, out_file, ensure_ascii=False)
                    out_file.write('\n')
                    retained_lines += 1
                processed_lines += 1
                pbar.update(1)
                elapsed_time = time.time() - start_time
                lines_per_second = processed_lines / elapsed_time if elapsed_time > 0 else 0
                pbar.set_postfix({"Retained": retained_lines, "Processed": processed_lines, "Speed": f"{lines_per_second:.2f} lines/sec"})

    # Close the progress bar
    pbar.close()

def generate_sft_rlfh_data():
    cn_firefly_file = './depulication_firefly.jsonl'
    ruozhiba_file = './ruozhiout_qa_cn.jsonl'
    total_count = 400000
    sft_count = 300000
    #rlhf_count = total_count - sft_count
    output_file_sft = './sft.jsonl'
    output_file_rlfh = './rlhf.jsonl'
    # Load data from files
    with open(ruozhiba_file, 'r', encoding='utf-8') as f:
        ruozhiba_data = [json.loads(line) for line in f]

    with open(cn_firefly_file, 'r', encoding='utf-8') as f:
        cn_firefly_data = [json.loads(line) for line in f]

    # Extract conversations from ruozhiba_data
    sft_data = []
    for item in ruozhiba_data:
        for conversation in item['conversations']:
            if conversation['from'] == 'human':
                prompt = conversation['value']
            elif conversation['from'] == 'gpt':
                answer = conversation['value']
        sft_data.append({'prompt': prompt, 'answer': answer})

    # Calculate the number of entries to pick from cn_firefly_data
    ruozhiba_count = len(sft_data)
    cn_firefly_count = total_count - ruozhiba_count

    # Randomly select entries from cn_firefly_data
    random_cn_firefly_data = random.sample(cn_firefly_data, cn_firefly_count)

    # Extract messages from cn_firefly_data
    for item in random_cn_firefly_data:
        for message in item['messages']:
            if message['from'] == 'user':
                prompt = message['value']
            elif message['from'] == 'assistant':
                answer = message['value']
        sft_data.append({'prompt': prompt, 'answer': answer})

    # Randomly shuffle the data
    random.shuffle(sft_data)

    # Split the data: the first sft_count samples go to SFT, the rest to RLHF
    #split_index = len(sft_data) // 2
    sft_part = sft_data[:sft_count]
    rlhf_part = sft_data[sft_count:]

    # Write data to output files
    with open(output_file_sft, 'w', encoding='utf-8') as f:
        for item in sft_part:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

    with open(output_file_rlfh, 'w', encoding='utf-8') as f:
        for item in rlhf_part:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')

def main():
    # process_baike()  # keeps 2,763,469 lines
    #process_cn_wiki()
    #process_skypile()
    # Deduplicate Baidu Baike
    # depulication_cn_file('./baidubaike_no_depulication.json', './baidubaike.json', 0.4)
    # merge_data()
    # generate_train_tokenizer_data()
    #sft_process_firefly()
    #depulication_cn_firefly()
    generate_sft_rlfh_data()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/image-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/image-1.png
--------------------------------------------------------------------------------
/image-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/image-2.png
--------------------------------------------------------------------------------
/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/image.png
--------------------------------------------------------------------------------
/image.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/image.webp
--------------------------------------------------------------------------------
/loss_vs_time_hours.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/loss_vs_time_hours.png
--------------------------------------------------------------------------------
/loss_vs_tokens_millions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/loss_vs_tokens_millions.png
--------------------------------------------------------------------------------
/pretrain/ds_config.json:
--------------------------------------------------------------------------------
{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },

  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": "auto",
      "betas": "auto",
      "eps": "auto",
      "weight_decay": "auto"
    }
  },

  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
      "warmup_min_lr": 1e-5,
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto",
      "total_num_steps": "auto"
    }
  },

  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 2e8,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 2e8,
    "contiguous_gradients": true
  },

  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}
--------------------------------------------------------------------------------
/pretrain/generate_pretrain_data.py:
--------------------------------------------------------------------------------
import os
import json
import time
import random
import numpy as np
from multiprocessing import Process, Manager
from transformers import AutoTokenizer

def split_file(data_path, num_splits=20):
    file_handles = [open(f"{data_path}.part{i}", 'w', encoding='utf-8') for i in range(num_splits)]

    try:
        total_lines = 0
        with open(data_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                part_idx = i % num_splits
                file_handles[part_idx].write(line)
                total_lines += 1
                if total_lines % 1000 == 0:  # print progress every 1000 lines
                    print(f"Processed lines: {total_lines}")
    finally:
        for handle in file_handles:
            handle.close()
    print(f"Total lines processed: {total_lines}")

def process_file(part_path, bin_path, tokenizer, ratio, result_dict):
    source_token_counts = {}
    total_token_count = 0
    line_count = 0
    start_time = time.time()

    with open(part_path, 'r', encoding='utf-8') as f, open(bin_path, 'wb') as f2:
        for line in f:
            if random.random() > ratio:
                continue
            data = json.loads(line)
            text = data['text']
            source = data['source']
            text_id = tokenizer.encode(text, add_special_tokens=False)
            text_id.append(tokenizer.eos_token_id)

            token_count = len(text_id)
            if source not in source_token_counts:
                source_token_counts[source] = 0
            source_token_counts[source] += token_count

            total_token_count += token_count

            # uint16 is enough because the vocabulary (32,006 tokens) is below 65,536
            arr = np.array(text_id, dtype=np.uint16)
            f2.write(arr.tobytes())

            line_count += 1
            elapsed_time = time.time() - start_time
            print(f"Processed lines: {line_count}, Time elapsed: {elapsed_time:.2f} seconds")

    result_dict[part_path] = (source_token_counts, total_token_count)

def merge_bins(bin_paths, final_bin_path, chunk_size=10*1024*1024):
    with open(final_bin_path, 'wb') as f_out:
        for bin_path in bin_paths:
            with open(bin_path, 'rb') as f_in:
                while True:
                    chunk = f_in.read(chunk_size)
                    if not chunk:
                        break
                    f_out.write(chunk)

def main(data_path, bin_path, ratio=1):
    tokenizer_path = './miaomiao_tokenizer'
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)  # load the tokenizer in the main process
    num_splits = 20

    # Split the file into parts
    split_file(data_path, num_splits)

    manager = Manager()
    result_dict = manager.dict()

    processes = []
    bin_paths = [f"{bin_path}.part{i}.bin" for i in range(num_splits)]

    for i in range(num_splits):
        part_path = f"{data_path}.part{i}"
        bin_part_path = bin_paths[i]
        p = Process(target=process_file, args=(part_path, bin_part_path, tokenizer, ratio, result_dict))
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    # Merge binary files
    merge_bins(bin_paths, bin_path)

    # Output combined statistics
    combined_source_token_counts = {}
    combined_total_token_count = 0

    for source_token_counts, total_token_count in result_dict.values():
        for source, count in source_token_counts.items():
            if source not in combined_source_token_counts:
                combined_source_token_counts[source] = 0
            combined_source_token_counts[source] += count
        combined_total_token_count += total_token_count

    print("Token counts by source:", combined_source_token_counts)
    print("Total token count:", combined_total_token_count)

if __name__ == "__main__":
    # 15M lines in total
    data_path = "./pretrain_data_train.json"
    bin_path = "./pretrain_data_train.bin"
    main(data_path, bin_path, ratio=1)
--------------------------------------------------------------------------------
/pretrain/model/config.json:
--------------------------------------------------------------------------------
{
  "model_type": "miaomiao",
  "architectures": [
    "MiaomiaoModel"
  ],
  "auto_map": {
    "AutoConfig": "configuration_miaomiao.MiaomiaoConfig",
    "AutoModel": "modeling_miaomiao.MiaomiaoModel",
    "AutoModelForCausalLM": "modeling_miaomiao.MiaomiaoForCausalLM"
  },
  "attention_dropout": 0.0,
  "bos_token_id": 32005,
  "eos_token_id": 32005,
  "hidden_act": "silu",
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2752,
  "max_position_embeddings": 131072,
  "max_window_layers": 28,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "num_key_value_heads": 16,
  "rms_norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "sliding_window": 131072,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.37.2",
  "use_cache": true,
  "use_sliding_window": false,
  "vocab_size": 32006
}
--------------------------------------------------------------------------------
/pretrain/model/configuration_miaomiao.py:
--------------------------------------------------------------------------------
""" Miaomiao model configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)


class MiaomiaoConfig(PretrainedConfig):

    model_type = "miaomiao"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        _attn_implementation="eager",
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self._attn_implementation = _attn_implementation
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
--------------------------------------------------------------------------------
/pretrain/pretrain.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Optional
from transformers.utils.versions import require_version
import transformers
from model.modeling_miaomiao import MiaomiaoForCausalLM
from model.configuration_miaomiao import MiaomiaoConfig
from pretrain_dataset import PretrainDataset
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    is_torch_tpu_available,
    set_seed,
)
from transformers.trainer_callback import TrainerCallback
import torch 21 | import json 22 | import os 23 | import logging 24 | import glob 25 | import random 26 | import numpy as np 27 | 28 | # Set random seeds for reproducibility 29 | def set_seed(seed): 30 | random.seed(seed) 31 | np.random.seed(seed) 32 | torch.manual_seed(seed) 33 | if torch.cuda.is_available(): 34 | torch.cuda.manual_seed_all(seed) 35 | 36 | 37 | class LoggingCallback(TrainerCallback): 38 | def __init__(self, logger): 39 | self.logger = logger 40 | 41 | def on_log(self, args, state, control, logs=None, **kwargs): 42 | if logs is not None: 43 | self.logger.info(logs) 44 | 45 | 46 | @dataclass 47 | class ModelArguments: 48 | config_file: Optional[str] = None 49 | torch_dtype: Optional[str] = None 50 | 51 | @dataclass 52 | class DataTrainingArguments: 53 | train_dataset_dir: Optional[str] = None 54 | block_size: Optional[int] = None 55 | overwrite_cache: bool = False 56 | preprocessing_num_workers: Optional[int] = None 57 | 58 | 59 | @dataclass 60 | class MyTrainingArguments(TrainingArguments): 61 | modules_to_save: Optional[str] = None 62 | 63 | 64 | # Model initialization mode ("scratch" = random initialization from config_file) 65 | init_from: Optional[str] = "scratch" 66 | use_device: Optional[str] = 'cuda' 67 | use_compile: Optional[bool] = False 68 | log_file: Optional[str] = None 69 | nnodes: Optional[int] = None 70 | nproc_per_node: Optional[int] = None 71 | 72 | def load_config(file_path): 73 | with open(file_path, 'r', encoding='utf-8') as file: 74 | config = json.load(file) 75 | return config 76 | 77 | def init_model(training_args, model_args): 78 | if training_args.init_from == "scratch": 79 | config = MiaomiaoConfig.from_pretrained(model_args.config_file) 80 | print(config) 81 | model = MiaomiaoForCausalLM(config) 82 | return model 83 | 84 | 85 | 86 | def my_data_collator(input_datas): 87 | # Stack the inputs (`X`) and labels (`Y`) of all samples into batch tensors 88 | input_ids = torch.stack([input_data[0] for input_data in input_datas]) 89 | labels = torch.stack([input_data[1] for input_data in input_datas]) 90 | 91 | # Return a dict with the keys the model's forward() expects 92 | return { 93 | "input_ids":
input_ids, 94 | "labels": labels 95 | } 96 | 97 | def main(): 98 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, MyTrainingArguments)) 99 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 100 | 101 | # Set up a dedicated logger (also pointing basicConfig at the same file would write every record twice) 102 | logger = logging.getLogger(__name__) 103 | logger.setLevel(logging.INFO) 104 | # Create a file handler in write mode (overwrites any existing log file) 105 | file_handler = logging.FileHandler(training_args.log_file, mode='w') 106 | file_handler.setLevel(logging.INFO) 107 | file_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 108 | file_handler.setFormatter(file_formatter) 109 | logger.addHandler(file_handler) 110 | # Also log to the console (optional) 111 | console_handler = logging.StreamHandler() 112 | console_handler.setLevel(logging.INFO) 113 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 114 | console_handler.setFormatter(formatter) 115 | logger.addHandler(console_handler) 116 | 117 | set_seed(training_args.seed) 118 | 119 | model=init_model(training_args, model_args) 120 | model.to(training_args.use_device) 121 | 122 | if training_args.use_compile: 123 | model = torch.compile(model) 124 | 125 | 126 | total_params = sum(p.numel() for p in model.parameters()) 127 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 128 | logger.info(f"Total parameters: {total_params}") 129 | logger.info(f"Trainable parameters: {trainable_params}") 130 | 131 | logger.info(f"torch_dtype:{model_args.torch_dtype}") 132 | logger.info(f"training_args.bf16: {training_args.bf16}") 133 | 134 | 135 | train_data_path_list = glob.glob(os.path.join(data_args.train_dataset_dir, '*.bin')) 136 | train_ds = PretrainDataset(train_data_path_list, max_length=data_args.block_size, memmap=True, seed=training_args.seed) 137 | logger.info(f"Train dataset size: {len(train_ds)}") 138 | 139 | trainer = Trainer( 140 | model=model, 141 | args=training_args, 142 | 
train_dataset=train_ds, 143 | data_collator=my_data_collator, 144 | callbacks=[LoggingCallback(logger)], # add the custom logging callback 145 | ) 146 | print(training_args.bf16) 147 | 148 | trainer.train() 149 | 150 | 151 | 152 | 153 | if __name__ == "__main__": 154 | main() 155 | -------------------------------------------------------------------------------- /pretrain/pretrain.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | lr=4e-4 6 | block_size=1024 7 | 8 | per_device_train_batch_size=24 9 | gradient_accumulation_steps=1 10 | config_file=./model/config.json 11 | train_dataset_dir=./pretrain_data 12 | log_file=./log/pretrain1.log 13 | output_dir=./output 14 | deepspeed_config_file=./ds_config.json 15 | random_seed=42 16 | 17 | torchrun --nnodes 1 --nproc_per_node 2 pretrain.py \ 18 | --deepspeed ${deepspeed_config_file} \ 19 | --config_file ${config_file} \ 20 | --train_dataset_dir ${train_dataset_dir} \ 21 | --per_device_train_batch_size ${per_device_train_batch_size} \ 22 | --do_train \ 23 | --bf16 True \ 24 | --torch_dtype bfloat16 \ 25 | --seed ${random_seed} \ 26 | --num_train_epochs 1 \ 27 | --logging_strategy steps \ 28 | --logging_steps 100 \ 29 | --log_file ${log_file} \ 30 | --logging_first_step True \ 31 | --adam_beta1 0.9 \ 32 | --adam_beta2 0.95 \ 33 | --lr_scheduler_type cosine \ 34 | --learning_rate ${lr} \ 35 | --warmup_ratio 0.05 \ 36 | --weight_decay 0.01 \ 37 | --save_strategy steps \ 38 | --save_total_limit 1 \ 39 | --save_steps 0.01 \ 40 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 41 | --block_size ${block_size} \ 42 | --output_dir ${output_dir} \ 43 | --overwrite_output_dir \ 44 | --ddp_timeout 30000 \ 45 | --init_from scratch \ 46 | --use_device cuda \ 47 | --use_compile False -------------------------------------------------------------------------------- /pretrain/pretrain_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import
random 3 | import pandas as pd 4 | import numpy as np 5 | from torch.utils.data import Dataset,DataLoader 6 | import torch 7 | from sklearn.model_selection import train_test_split 8 | 9 | class PretrainDataset(Dataset): 10 | def __init__(self, data_path_lst, max_length=512, memmap=False, seed=42): 11 | super().__init__() 12 | 13 | self.max_length = max_length 14 | self.seed = seed 15 | 16 | if memmap: 17 | with open(data_path_lst[0], 'rb') as f: # NOTE: memmap mode only reads the first file in data_path_lst 18 | nbytes = f.seek(0, 2) 19 | flen = nbytes // np.dtype('int16').itemsize # tokens are stored as int16, so token ids must be < 32768 20 | self.data = np.memmap(data_path_lst[0], dtype=np.dtype('int16'), shape=(flen // max_length, max_length), mode='r') 21 | else: 22 | data_lst = [] 23 | for data_path in data_path_lst: 24 | with open(data_path, 'rb') as f: 25 | data = np.fromfile(f, dtype=np.int16) # tokens are stored as int16 26 | data_lst.append(data) 27 | data = np.concatenate(data_lst) 28 | data = data[:max_length * (len(data) // max_length)] 29 | self.data = data.reshape(-1, max_length) 30 | 31 | self.indices = np.arange(len(self.data)) 32 | self.shuffle_indices() 33 | print("memmap:{} train data.shape:{}".format(memmap, self.data.shape)) 34 | print("loading finished.") 35 | 36 | def __len__(self): 37 | return self.data.shape[0] 38 | 39 | def shuffle_indices(self): 40 | np.random.default_rng(self.seed).shuffle(self.indices) 41 | 42 | def __getitem__(self, index: int): 43 | index = self.indices[index] 44 | sample = self.data[index] 45 | X = np.array(sample).astype(np.int64) 46 | Y = np.array(sample).astype(np.int64) 47 | return torch.from_numpy(X), torch.from_numpy(Y) 48 | -------------------------------------------------------------------------------- /pretrain/test_pretrain_model.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | 3 | 4 | device = "cuda" # the device to load the model onto 5 | 6 | model = AutoModelForCausalLM.from_pretrained( 7 | './pretrain/model', 8 | 
torch_dtype="auto", 9 | device_map="auto", 10 | trust_remote_code=True 11 | ) 12 | tokenizer = AutoTokenizer.from_pretrained('./miaomiao_tokenizer', trust_remote_code=True) 13 | text = "床前明月光," 14 | model_inputs = tokenizer([text], return_tensors="pt").to(device) 15 | print(model_inputs) 16 | generated_ids = model.generate( 17 | **model_inputs, 18 | max_new_tokens=1024 19 | ) 20 | print(generated_ids) 21 | generated_ids = [ 22 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 23 | ] 24 | print(generated_ids) 25 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 26 | print(response) 27 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.32.1 3 | aiohttp==3.9.5 4 | aiosignal==1.3.1 5 | annotated-types==0.7.0 6 | attrs==23.2.0 7 | blinker==1.4 8 | certifi==2024.7.4 9 | charset-normalizer==3.3.2 10 | contourpy==1.2.1 11 | cryptography==3.4.8 12 | cycler==0.12.1 13 | datasets==2.20.0 14 | datasketch==1.6.5 15 | dbus-python==1.2.18 16 | deepspeed==0.14.4 17 | dill==0.3.8 18 | distro==1.7.0 19 | distro-info==1.1+ubuntu0.2 20 | einops==0.8.0 21 | filelock==3.15.4 22 | flash-attn==2.5.9.post1 23 | fonttools==4.53.1 24 | frozenlist==1.4.1 25 | fsspec==2024.5.0 26 | grpcio==1.64.1 27 | hjson==3.1.0 28 | httplib2==0.20.2 29 | huggingface-hub==0.23.4 30 | idna==3.7 31 | importlib-metadata==4.6.4 32 | jeepney==0.7.1 33 | Jinja2==3.1.4 34 | joblib==1.4.2 35 | keyring==23.5.0 36 | kiwisolver==1.4.5 37 | launchpadlib==1.10.16 38 | lazr.restfulclient==0.14.4 39 | lazr.uri==1.0.6 40 | Markdown==3.6 41 | MarkupSafe==2.1.5 42 | matplotlib==3.9.1 43 | more-itertools==8.10.0 44 | mpmath==1.3.0 45 | multidict==6.0.5 46 | multiprocess==0.70.16 47 | networkx==3.3 48 | ninja==1.11.1.1 49 | numpy==1.26.4 50 | nvidia-cublas-cu12==12.1.3.1 51 | 
nvidia-cuda-cupti-cu12==12.1.105 52 | nvidia-cuda-nvrtc-cu12==12.1.105 53 | nvidia-cuda-runtime-cu12==12.1.105 54 | nvidia-cudnn-cu12==8.9.2.26 55 | nvidia-cufft-cu12==11.0.2.54 56 | nvidia-curand-cu12==10.3.2.106 57 | nvidia-cusolver-cu12==11.4.5.107 58 | nvidia-cusparse-cu12==12.1.0.106 59 | nvidia-ml-py==12.555.43 60 | nvidia-nccl-cu12==2.20.5 61 | nvidia-nvjitlink-cu12==12.5.82 62 | nvidia-nvtx-cu12==12.1.105 63 | oauthlib==3.2.0 64 | packaging==24.1 65 | pandas==2.2.2 66 | pillow==10.4.0 67 | pip==24.1.2 68 | protobuf==4.25.3 69 | psutil==6.0.0 70 | py-cpuinfo==9.0.0 71 | pyarrow==16.1.0 72 | pyarrow-hotfix==0.6 73 | pydantic==2.8.2 74 | pydantic_core==2.20.1 75 | PyGObject==3.42.1 76 | PyJWT==2.3.0 77 | pyparsing==3.1.2 78 | python-apt==2.4.0+ubuntu3 79 | python-dateutil==2.9.0.post0 80 | pytz==2024.1 81 | PyYAML==6.0.1 82 | regex==2024.5.15 83 | requests==2.32.3 84 | safetensors==0.4.3 85 | scikit-learn==1.5.1 86 | scipy==1.14.0 87 | SecretStorage==3.3.1 88 | setuptools==70.3.0 89 | six==1.16.0 90 | sympy==1.13.0 91 | tensorboard==2.17.0 92 | tensorboard-data-server==0.7.2 93 | threadpoolctl==3.5.0 94 | tiktoken==0.7.0 95 | tokenizers==0.19.1 96 | torch==2.3.1 97 | torchaudio==2.3.1 98 | torchvision==0.18.1 99 | tqdm==4.66.4 100 | transformers==4.42.3 101 | triton==2.3.1 102 | typing_extensions==4.12.2 103 | tzdata==2024.1 104 | unattended-upgrades==0.1 105 | urllib3==2.2.2 106 | wadllib==1.3.6 107 | Werkzeug==3.0.3 108 | wheel==0.43.0 109 | xxhash==3.4.1 110 | yarl==1.9.4 111 | zipp==1.0.0 112 | -------------------------------------------------------------------------------- /rlhf/rlhf/__pycache__/ppo_trainer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/rlhf/rlhf/__pycache__/ppo_trainer.cpython-311.pyc -------------------------------------------------------------------------------- 
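The `get_advantages_and_returns` method in `rlhf/rlhf/ppo_trainer.py` (credited there to trlx) runs Generalized Advantage Estimation (GAE) backwards over the response tokens. The sketch below is a minimal pure-Python version of the same recursion for a single, unbatched trajectory, meant only to make the arithmetic easy to follow. The function name `gae_advantages_and_returns` and the list-based interface are hypothetical, not part of the repository. The defaults `gamma=1.0` and `lam=0.95` mirror `self.gamma` and `self.lam` in `DeepSpeedPPOTrainer`.

```python
from typing import List, Tuple


def gae_advantages_and_returns(
    values: List[float],
    rewards: List[float],
    start: int,
    gamma: float = 1.0,
    lam: float = 0.95,
) -> Tuple[List[float], List[float]]:
    """Hypothetical single-trajectory GAE sketch (not the repository's API).

    Iterates backwards over positions t = length-1 .. start:
        delta_t = r_t + gamma * V_{t+1} - V_t   (V_{t+1} is taken as 0 at the last step)
        A_t     = delta_t + gamma * lam * A_{t+1}
    Returns (advantages, returns) for positions start..length-1,
    where returns_t = A_t + V_t.
    """
    length = len(rewards)
    if len(values) != length:
        raise ValueError("values and rewards must have the same length")

    last_gae = 0.0
    advantages_reversed: List[float] = []
    for t in reversed(range(start, length)):
        next_value = values[t + 1] if t < length - 1 else 0.0
        delta = rewards[t] + gamma * next_value - values[t]
        last_gae = delta + gamma * lam * last_gae
        advantages_reversed.append(last_gae)

    advantages = advantages_reversed[::-1]
    # Returns are the critic's regression targets: advantage + current value estimate.
    returns = [a + v for a, v in zip(advantages, values[start:])]
    return advantages, returns


if __name__ == "__main__":
    # Toy response: a single terminal reward of 1.0 on the last token.
    adv, ret = gae_advantages_and_returns(
        values=[0.0, 0.5, 0.5, 0.0],
        rewards=[0.0, 0.0, 0.0, 1.0],
        start=1,
        gamma=1.0,
        lam=1.0,
    )
    print(adv)  # [0.5, 0.5, 1.0]
    print(ret)  # [1.0, 1.0, 1.0]
```

With `gamma = lam = 1` the returns reduce to the sum of future rewards, so the toy example can be checked by hand. The repository version differs in two ways. It operates on batched `[batch, seq]` tensors. Before running the recursion, `train_rlhf` also zeroes the rewards and values past each sequence's end, and the per-token rewards already include the KL penalty from `compute_rewards`.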
/rlhf/rlhf/__pycache__/rlhf_engine.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/rlhf/rlhf/__pycache__/rlhf_engine.cpython-311.pyc -------------------------------------------------------------------------------- /rlhf/rlhf/ppo_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | import torch 6 | import torch.nn.functional as F 7 | import time 8 | import deepspeed 9 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 10 | from deepspeed.accelerator import get_accelerator 11 | 12 | from utils.utils import print_rank_0 13 | 14 | 15 | def print_all_ranks(tag, value, rank): 16 | world_size = torch.distributed.get_world_size() 17 | all_tensor = torch.zeros(world_size, dtype=torch.float32).to( 18 | get_accelerator().current_device_name()) 19 | all_tensor[rank] = value 20 | torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM) 21 | print_rank_0(f'{tag} {all_tensor}', rank) 22 | 23 | 24 | def get_model_norm(model): 25 | with torch.no_grad(): 26 | total = 0.0 27 | for param in model.parameters(): 28 | should_gather = hasattr( 29 | param, 30 | 'ds_id') and param.ds_status == ZeroParamStatus.NOT_AVAILABLE 31 | with deepspeed.zero.GatheredParameters(param, 32 | enabled=should_gather): 33 | total += float(param.float().norm()) 34 | 35 | return total 36 | 37 | 38 | def gather_log_probs(logits, labels): 39 | log_probs = F.log_softmax(logits, dim=-1) 40 | log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) 41 | return log_probs_labels.squeeze(-1) 42 | 43 | 44 | class DeepSpeedPPOTrainer(): 45 | 46 | def __init__(self, rlhf_engine, args): 47 | self.rlhf_engine = rlhf_engine 48 | self.actor_model = 
self.rlhf_engine.actor 49 | self.critic_model = self.rlhf_engine.critic 50 | self.ref_model = self.rlhf_engine.ref 51 | self.reward_model = self.rlhf_engine.reward 52 | self.tokenizer = self.rlhf_engine.tokenizer 53 | self.args = args 54 | self.max_answer_seq_len = args.max_answer_seq_len 55 | self.end_of_conversation_token_id = self.tokenizer.eos_token_id 56 | self.z3_enabled = args.actor_zero_stage == 3 57 | self.compute_fp32_loss = self.args.compute_fp32_loss 58 | 59 | # In case the generated experience is not valid (too short), we use the last valid 60 | # generated experience. Alternatively, we can skip the step (on all workers). 61 | # For now, use the last valid experience which is a simpler solution 62 | self.last_generated_experience = None 63 | 64 | # These hyperparameters can be tuned 65 | self.kl_ctl = 0.1 66 | self.clip_reward_value = 5 67 | self.cliprange = 0.2 68 | self.cliprange_value = 0.2 69 | self.gamma = 1.0 70 | self.lam = 0.95 71 | self.generate_time = 0.0 72 | 73 | def _generate_sequence(self, prompts, mask, step): 74 | 75 | max_min_length = self.max_answer_seq_len + prompts.shape[1] 76 | 77 | # This has been added due to a probability/nan error that happens after 78 | # meta-llama/Llama-2-7b-hf enabled do_sample: 79 | # https://huggingface.co/meta-llama/Llama-2-7b-hf/commit/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9 80 | 81 | kwargs = dict(do_sample=False) 82 | 83 | 84 | with torch.no_grad(): 85 | seq = self.actor_model.module.generate( 86 | prompts, 87 | attention_mask=mask, 88 | max_length=max_min_length, 89 | pad_token_id=self.tokenizer.pad_token_id, 90 | synced_gpus=self.z3_enabled, 91 | **kwargs) 92 | 93 | # Filter out seq with no answers (or very short).
This happens when users directly use the pre-training ckpt without supervised finetuning 94 | # NOTE: this can leave each GPU with a different number of examples 95 | batch_size = seq.shape[0] 96 | prompt_length = prompts.shape[1] 97 | self.prompt_length = prompt_length 98 | ans = seq[:, prompt_length:] 99 | valid_ans_len = (ans != self.tokenizer.pad_token_id).sum(dim=-1) 100 | 101 | if self.args.print_answers and (step % self.args.print_answers_interval 102 | == 0): 103 | print( 104 | f"--- prompt --> step={step}, rank={torch.distributed.get_rank()}, {self.tokenizer.batch_decode(prompts, skip_special_tokens=True)}" 105 | ) 106 | print( 107 | f"--- ans --> step={step}, rank={torch.distributed.get_rank()}, {self.tokenizer.batch_decode(ans, skip_special_tokens=True)}" 108 | ) 109 | 110 | out_seq = [] 111 | for i in range(batch_size): 112 | if valid_ans_len[ 113 | i] <= 1: # drop answers that are at most 1 token long 114 | print( 115 | f'Dropping too short generated answer: {step=}: \n' 116 | f'prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n' 117 | f'answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}' 118 | ) 119 | continue 120 | else: 121 | out_seq.append(seq[i:i + 1]) 122 | 123 | if not out_seq: 124 | print( 125 | f'All generated results are too short for rank={self.args.local_rank} step={step}\n' 126 | f'-> prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n' 127 | f'-> answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}' 128 | ) 129 | return None 130 | 131 | out_seq = torch.cat(out_seq, dim=0) # concat output in the batch dim 132 | 133 | return out_seq 134 | 135 | def generate_experience(self, prompts, mask, step): 136 | self.eval() 137 | generate_start = time.time() 138 | seq = self._generate_sequence(prompts, mask, step) 139 | generate_end = time.time() 140 | if seq is None: 141 | assert self.last_generated_experience is not None, f'Invalid generated experience at
{step=}' 142 | prompts = self.last_generated_experience['prompts'] 143 | seq = self.last_generated_experience['seq'] 144 | else: 145 | self.last_generated_experience = {'prompts': prompts, 'seq': seq} 146 | self.train() 147 | 148 | pad_token_id = self.tokenizer.pad_token_id 149 | attention_mask = seq.not_equal(pad_token_id).long() 150 | with torch.no_grad(): 151 | output = self.actor_model(seq, attention_mask=attention_mask) 152 | output_ref = self.ref_model(seq, attention_mask=attention_mask) 153 | reward_score = self.reward_model.forward_value( 154 | seq, attention_mask, 155 | prompt_length=self.prompt_length)['chosen_end_scores'].detach( 156 | ) 157 | values = self.critic_model.forward_value( 158 | seq, attention_mask, return_value_only=True).detach()[:, :-1] 159 | 160 | logits = output.logits 161 | logits_ref = output_ref.logits 162 | if self.compute_fp32_loss: 163 | logits = logits.to(torch.float) 164 | logits_ref = logits_ref.to(torch.float) 165 | 166 | self.generate_time = generate_end - generate_start 167 | 168 | return { 169 | 'prompts': prompts, 170 | 'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]), 171 | 'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:, 172 | 1:]), 173 | 'value': values, 174 | 'rewards': reward_score, 175 | 'input_ids': seq, 176 | "attention_mask": attention_mask 177 | } 178 | 179 | def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score, 180 | action_mask): 181 | 182 | kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs) 183 | rewards = kl_divergence_estimate 184 | start = prompts.shape[1] - 1 185 | ends = start + action_mask[:, start:].sum(1) + 1 186 | reward_clip = torch.clamp(reward_score, -self.clip_reward_value, 187 | self.clip_reward_value) 188 | batch_size = log_probs.shape[0] 189 | for j in range(batch_size): 190 | rewards[j, start:ends[j]][-1] += reward_clip[j] 191 | 192 | return rewards 193 | 194 | def train_rlhf(self, inputs): 195 | # train the RLHF model here 196 | ###
process the old outputs 197 | prompts = inputs['prompts'] 198 | log_probs = inputs['logprobs'] 199 | ref_log_probs = inputs['ref_logprobs'] 200 | reward_score = inputs['rewards'] 201 | values = inputs['value'] 202 | attention_mask = inputs['attention_mask'] 203 | seq = inputs['input_ids'] 204 | 205 | start = prompts.size()[-1] - 1 206 | action_mask = attention_mask[:, 1:] 207 | 208 | old_values = values 209 | with torch.no_grad(): 210 | old_rewards = self.compute_rewards(prompts, log_probs, 211 | ref_log_probs, reward_score, 212 | action_mask) 213 | ends = start + action_mask[:, start:].sum(1) + 1 214 | # we need to zero out the reward and value after the end of the conversation 215 | # otherwise the advantage/return will be wrong 216 | for i in range(old_rewards.shape[0]): 217 | old_rewards[i, ends[i]:] = 0 218 | old_values[i, ends[i]:] = 0 219 | advantages, returns = self.get_advantages_and_returns( 220 | old_values, old_rewards, start) 221 | 222 | ### process the new outputs 223 | batch = {'input_ids': seq, "attention_mask": attention_mask} 224 | actor_prob = self.actor_model(**batch, use_cache=False).logits 225 | actor_log_prob = gather_log_probs(actor_prob[:, :-1, :], seq[:, 1:]) 226 | actor_loss = self.actor_loss_fn(actor_log_prob[:, start:], 227 | log_probs[:, start:], advantages, 228 | action_mask[:, start:]) 229 | self.actor_model.backward(actor_loss) 230 | 231 | if not self.args.align_overflow: 232 | self.actor_model.step() 233 | 234 | value = self.critic_model.forward_value(**batch, 235 | return_value_only=True, 236 | use_cache=False)[:, :-1] 237 | critic_loss = self.critic_loss_fn(value[:, start:], old_values[:, 238 | start:], 239 | returns, action_mask[:, start:]) 240 | self.critic_model.backward(critic_loss) 241 | 242 | if self.args.align_overflow: 243 | actor_overflow = self.actor_model.optimizer.check_overflow( 244 | external=True) 245 | critic_overflow = self.critic_model.optimizer.check_overflow( 246 | external=True) 247 | 248 | rank = 
torch.distributed.get_rank() 249 | if actor_overflow and not critic_overflow: 250 | self.critic_model.optimizer.skip_step = True 251 | print_rank_0( 252 | "OVERFLOW: actor overflow, skipping both actor and critic steps", 253 | rank) 254 | elif not actor_overflow and critic_overflow: 255 | self.actor_model.optimizer.skip_step = True 256 | print_rank_0( 257 | "OVERFLOW: critic overflow, skipping both actor and critic steps", 258 | rank) 259 | elif actor_overflow and critic_overflow: 260 | print_rank_0( 261 | "OVERFLOW: actor and critic overflow, skipping both actor and critic steps", 262 | rank) 263 | self.actor_model.step() 264 | 265 | self.critic_model.step() 266 | 267 | return actor_loss, critic_loss 268 | 269 | def get_overflow(self): 270 | # Overflow is not expected when using bf16 271 | # Therefore, DeepSpeed's BF16_Optimizer does not maintain an overflow indication 272 | if self.args.dtype == "bf16": 273 | return False, False 274 | 275 | actor_overflow = self.actor_model.optimizer.overflow 276 | critic_overflow = self.critic_model.optimizer.overflow 277 | 278 | return actor_overflow, critic_overflow 279 | 280 | def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask): 281 | ## policy gradient loss 282 | log_ratio = (logprobs - old_logprobs) * mask 283 | ratio = torch.exp(log_ratio) 284 | pg_loss1 = -advantages * ratio 285 | pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange, 286 | 1.0 + self.cliprange) 287 | pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum() 288 | return pg_loss 289 | 290 | def critic_loss_fn(self, values, old_values, returns, mask): 291 | ## value loss 292 | values_clipped = torch.clamp( 293 | values, 294 | old_values - self.cliprange_value, 295 | old_values + self.cliprange_value, 296 | ) 297 | if self.compute_fp32_loss: 298 | values = values.float() 299 | values_clipped = values_clipped.float() 300 | vf_loss1 = (values - returns)**2 301 | vf_loss2 = (values_clipped - returns)**2 302 | vf_loss = 0.5 
* torch.sum( 303 | torch.max(vf_loss1, vf_loss2) * mask) / mask.sum() 304 | return vf_loss 305 | 306 | def get_advantages_and_returns(self, values, rewards, start): 307 | # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134 308 | lastgaelam = 0 309 | advantages_reversed = [] 310 | length = rewards.size()[-1] 311 | for t in reversed(range(start, length)): 312 | nextvalues = values[:, t + 1] if t < length - 1 else 0.0 313 | delta = rewards[:, t] + self.gamma * nextvalues - values[:, t] 314 | lastgaelam = delta + self.gamma * self.lam * lastgaelam 315 | advantages_reversed.append(lastgaelam) 316 | advantages = torch.stack(advantages_reversed[::-1], dim=1) 317 | returns = advantages + values[:, start:] 318 | return advantages.detach(), returns 319 | 320 | def _validate_training_mode(self): 321 | assert self.actor_model.module.training 322 | assert self.critic_model.module.training 323 | 324 | def _validate_evaluation_mode(self): 325 | assert not self.actor_model.module.training 326 | assert not self.critic_model.module.training 327 | assert not self.ref_model.module.training 328 | assert not self.reward_model.module.training 329 | 330 | def train(self): 331 | self.actor_model.train() 332 | self.critic_model.train() 333 | 334 | def eval(self): 335 | self.actor_model.eval() 336 | self.critic_model.eval() 337 | self.reward_model.eval() 338 | self.ref_model.eval() 339 | 340 | def dump_model_norms(self, tag): 341 | actor_model_norm = get_model_norm(self.actor_model) 342 | ref_model_norm = get_model_norm(self.ref_model) 343 | critic_model_norm = get_model_norm(self.critic_model) 344 | reward_model_norm = get_model_norm(self.reward_model) 345 | print_all_ranks(f'{tag} global_actor_model_norm', actor_model_norm, 346 | self.args.local_rank) 347 | print_all_ranks(f'{tag} global_ref_model_norm', ref_model_norm, 348 | self.args.local_rank) 349 | print_all_ranks(f'{tag} global_critic_model_norm', critic_model_norm, 350 | self.args.local_rank) 
351 | print_all_ranks(f'{tag} global_reward_model_norm', reward_model_norm, 352 | self.args.local_rank) 353 | 354 | 355 | class DeepSpeedPPOTrainerUnsupervised(DeepSpeedPPOTrainer): 356 | 357 | def __init__(self, *args, **kwargs): 358 | super().__init__(*args, **kwargs) 359 | 360 | def train_unsupervised(self, inputs, unsup_coef): 361 | # Train the unsupervised model here 362 | self._validate_training_mode() 363 | 364 | outputs = self.actor_model(**inputs, use_cache=False) 365 | loss = outputs.loss 366 | self.actor_model.backward(unsup_coef * loss) 367 | self.actor_model.step() 368 | 369 | return loss 370 | -------------------------------------------------------------------------------- /rlhf/rlhf/rlhf_engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | import time 6 | import torch 7 | import deepspeed 8 | from deepspeed.ops.adam import FusedAdam 9 | from deepspeed.ops.adam import DeepSpeedCPUAdam 10 | from transformers import AutoModelForCausalLM, get_scheduler 11 | 12 | from utils.ds_utils import get_train_ds_config, get_eval_ds_config 13 | from utils.model_utils import create_hf_model, create_critic_model 14 | from utils.utils import get_optimizer_grouped_parameters 15 | """ 16 | TODOs: 17 | * support HF models for critic (for debugging), must be a previously saved ckpt from step-2 18 | * determine ds_config/zero_stage based on model size, gpu style, world size, etc 19 | - get model size by creating simple meta model 20 | - 1.3b: zero-2 for actor/ref models, zero-0 for others 21 | - 13b+: zero-3 for all models 22 | """ 23 | 24 | 25 | def log_init(model_name, stime=None): 26 | if torch.distributed.get_rank() == 0: 27 | tag = "start" if stime is None else "end" 28 | suffix = "ing" if stime is None else "ed" 29 | duration = "" 30 | if stime is not None: 31 | duration = "(duration: {:.2f}s)".format(time.time() - 
stime) 32 | msg = f"[{tag}] Initializ{suffix} {model_name} Model [{tag}] {duration}" 33 | stars = (90 - len(msg)) // 2 34 | extra_star = "*" if (90 - len(msg)) % 2 == 1 else "" 35 | print("*" * stars + msg + "*" * stars + extra_star) 36 | return time.time() 37 | 38 | 39 | class DeepSpeedRLHFEngine(): 40 | 41 | def __init__(self, actor_model_name_or_path, critic_model_name_or_path, 42 | tokenizer, args, num_total_iters): 43 | self.args = args 44 | self.num_total_iters = num_total_iters 45 | self.tokenizer = tokenizer 46 | 47 | self.actor = self._init_actor( 48 | actor_model_name_or_path=actor_model_name_or_path) 49 | self.ref = self._init_ref( 50 | actor_model_name_or_path=actor_model_name_or_path) 51 | self.actor_ema = None 52 | if self.args.enable_ema: 53 | self.actor_ema = self._init_ema( 54 | actor_model_name_or_path=actor_model_name_or_path) 55 | self.critic = self._init_critic( 56 | critic_model_name_or_path=critic_model_name_or_path) 57 | self.reward = self._init_reward( 58 | critic_model_name_or_path=critic_model_name_or_path) 59 | if self.args.critic_gradient_checkpointing: 60 | self.critic.gradient_checkpointing_enable() 61 | 62 | def _init_actor(self, actor_model_name_or_path): 63 | stime = log_init("Actor") 64 | 65 | # DS Config 66 | ds_config = get_train_ds_config( 67 | offload=self.args.offload, 68 | dtype=self.args.dtype, 69 | stage=self.args.actor_zero_stage, 70 | enable_hybrid_engine=self.args.enable_hybrid_engine, 71 | inference_tp_size=self.args.inference_tp_size, 72 | release_inference_cache=self.args.release_inference_cache, 73 | pin_parameters=(not self.args.unpin_actor_parameters), 74 | tp_gather_partition_size=self.args.tp_gather_partition_size, 75 | max_out_tokens=self.args.max_prompt_seq_len + 76 | self.args.max_answer_seq_len, 77 | enable_tensorboard=self.args.enable_tensorboard, 78 | enable_mixed_precision_lora=self.args.enable_mixed_precision_lora, 79 | tb_path=self.args.tensorboard_path, 80 | tb_name="step3_actor") 81 | ds_config[ 82 | 
'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size 83 | #TODO(jeff): we should probably set grad accumulation steps here as well for clarity 84 | ds_config[ 85 | 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( 86 | ) * self.args.gradient_accumulation_steps_actor 87 | 88 | # Model 89 | actor_model = create_hf_model( 90 | model_class=AutoModelForCausalLM, 91 | model_name_or_path=actor_model_name_or_path, 92 | tokenizer=self.tokenizer, 93 | ds_config=ds_config, 94 | dropout=self.args.actor_dropout) 95 | 96 | 97 | # Optimizer 98 | AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam 99 | optim_params = get_optimizer_grouped_parameters( 100 | actor_model, self.args.actor_weight_decay, 101 | self.args.actor_lora_learning_rate) 102 | optim = AdamOptimizer(optim_params, 103 | lr=self.args.actor_learning_rate, 104 | betas=(0.9, 0.95)) 105 | 106 | # LR Scheduler 107 | lr_scheduler = get_scheduler( 108 | name=self.args.lr_scheduler_type, 109 | optimizer=optim, 110 | num_warmup_steps=self.args.num_warmup_steps, 111 | num_training_steps=self.num_total_iters, 112 | ) 113 | 114 | # DeepSpeed Engine 115 | #TODO: move enable_hybrid_engine and pin_parameters to ds_config 116 | actor_engine, *_ = deepspeed.initialize(model=actor_model, 117 | optimizer=optim, 118 | lr_scheduler=lr_scheduler, 119 | config=ds_config) 120 | 121 | log_init("Actor", stime=stime) 122 | 123 | return actor_engine 124 | 125 | def _init_ref(self, actor_model_name_or_path): 126 | stime = log_init("Ref") 127 | # DS Config 128 | zero_stage = self.args.actor_zero_stage 129 | if zero_stage != 3: 130 | # If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory for ref model 131 | zero_stage = 0 132 | ds_config = get_eval_ds_config(self.args.offload_reference_model, 133 | self.args.dtype, zero_stage) 134 | ds_config[ 135 | 'train_micro_batch_size_per_gpu'] =
self.args.per_device_training_batch_size 136 | #TODO(jeff): we should probably set grad accumulation steps here as well for clarity 137 | ds_config[ 138 | 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( 139 | ) * self.args.gradient_accumulation_steps_actor 140 | 141 | ref_model = create_hf_model(AutoModelForCausalLM, 142 | actor_model_name_or_path, self.tokenizer, 143 | ds_config) 144 | 145 | ref_engine, *_ = deepspeed.initialize(model=ref_model, 146 | config=ds_config) 147 | 148 | log_init("Ref", stime=stime) 149 | return ref_engine 150 | 151 | def _init_ema(self, actor_model_name_or_path): 152 | stime = log_init("EMA") 153 | # DS Config 154 | zero_stage = self.args.actor_zero_stage 155 | if zero_stage != 3: 156 | # If actor is ZeRO-3 then we use it for everything, otherwise assume we have enough memory 157 | zero_stage = 0 158 | ds_config = get_eval_ds_config(self.args.offload_reference_model, 159 | self.args.dtype, zero_stage) 160 | ds_config[ 161 | 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size 162 | #TODO(jeff): we should probably set grad accumulation steps here as well for clarity 163 | ds_config[ 164 | 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( 165 | ) * self.args.gradient_accumulation_steps_actor 166 | 167 | actor_model_ema = create_hf_model(AutoModelForCausalLM, 168 | actor_model_name_or_path, 169 | self.tokenizer, ds_config) 170 | 171 | ema_engine, *_ = deepspeed.initialize(model=actor_model_ema, 172 | config=ds_config) 173 | 174 | log_init("EMA", stime=stime) 175 | return ema_engine 176 | 177 | def _init_critic(self, critic_model_name_or_path): 178 | stime = log_init("Critic") 179 | ds_config = get_train_ds_config( 180 | offload=self.args.offload, 181 | dtype=self.args.dtype, 182 | stage=self.args.critic_zero_stage, 183 | enable_tensorboard=self.args.enable_tensorboard, 184 | tb_path=self.args.tensorboard_path, 185 | 
tb_name="step3_critic") 186 | ds_config[ 187 | 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size 188 | #TODO(jeff): we should probably set grad accumulation steps here as well for clarity 189 | ds_config[ 190 | 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( 191 | ) * self.args.gradient_accumulation_steps 192 | 193 | ds_eval_config = get_eval_ds_config(offload=False, 194 | dtype=self.args.dtype, 195 | stage=self.args.critic_zero_stage) 196 | # We need to set train batch size and micro batch size here to pass the sanity check of the DeepSpeed engine. 197 | ds_eval_config[ 198 | 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size 199 | ds_eval_config[ 200 | 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( 201 | ) * self.args.gradient_accumulation_steps 202 | 203 | # Model 204 | critic_model = create_critic_model( 205 | model_name_or_path=critic_model_name_or_path, 206 | tokenizer=self.tokenizer, 207 | ds_config=ds_eval_config, 208 | num_padding_at_beginning=self.args.num_padding_at_beginning, 209 | rlhf_training=True, 210 | dropout=self.args.critic_dropout, 211 | zero_stage=self.args.critic_zero_stage) 212 | 213 | # Optimizer 214 | AdamOptimizer = DeepSpeedCPUAdam if self.args.offload else FusedAdam 215 | optim_params = get_optimizer_grouped_parameters( 216 | critic_model, self.args.critic_weight_decay, 217 | self.args.critic_lora_learning_rate) 218 | optim = AdamOptimizer(optim_params, 219 | lr=self.args.critic_learning_rate, 220 | betas=(0.9, 0.95)) 221 | 222 | # LR Scheduler 223 | lr_scheduler = get_scheduler( 224 | name=self.args.lr_scheduler_type, 225 | optimizer=optim, 226 | num_warmup_steps=self.args.num_warmup_steps, 227 | num_training_steps=self.num_total_iters, 228 | ) 229 | 230 | # DeepSpeed Engine 231 | critic_engine, *_ = deepspeed.initialize(model=critic_model, 232 | optimizer=optim, 233 |
lr_scheduler=lr_scheduler, 234 | config=ds_config) 235 | 236 | log_init("Critic", stime=stime) 237 | return critic_engine 238 | 239 | def _init_reward(self, critic_model_name_or_path): 240 | stime = log_init("Reward") 241 | # DS Config 242 | zero_stage = self.args.critic_zero_stage 243 | if zero_stage != 3: 244 | # If the critic uses ZeRO-3, the reward model uses ZeRO-3 too; otherwise assume there is enough memory for an unpartitioned reward model (stage 0) 245 | zero_stage = 0 246 | 247 | ds_config = get_eval_ds_config(offload=self.args.offload, 248 | dtype=self.args.dtype, 249 | stage=zero_stage) 250 | ds_config[ 251 | 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size 252 | ds_config[ 253 | 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( 254 | ) * self.args.gradient_accumulation_steps 255 | 256 | ds_eval_config = get_eval_ds_config(offload=False, 257 | dtype=self.args.dtype, 258 | stage=zero_stage) 259 | 260 | # We need to set train batch size and micro batch size here to pass the sanity check of the DeepSpeed engine.
261 | ds_eval_config[ 262 | 'train_micro_batch_size_per_gpu'] = self.args.per_device_training_batch_size 263 | ds_eval_config[ 264 | 'train_batch_size'] = self.args.per_device_training_batch_size * torch.distributed.get_world_size( 265 | ) * self.args.gradient_accumulation_steps 266 | 267 | # Model 268 | reward_model = create_critic_model( 269 | model_name_or_path=critic_model_name_or_path, 270 | tokenizer=self.tokenizer, 271 | ds_config=ds_eval_config, 272 | num_padding_at_beginning=self.args.num_padding_at_beginning, 273 | rlhf_training=True, 274 | dropout=self.args.critic_dropout, 275 | zero_stage=zero_stage) 276 | 277 | reward_engine, *_ = deepspeed.initialize(model=reward_model, 278 | config=ds_config) 279 | 280 | log_init("Reward", stime=stime) 281 | return reward_engine 282 | -------------------------------------------------------------------------------- /rlhf/rlhf_data_process.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import json 3 | from tqdm import tqdm 4 | import torch 5 | import random 6 | import os 7 | def split_jsonl(): 8 | input_file = './rlhf.jsonl' 9 | output_files = ['./rlhf_part1.jsonl', './rlhf_part2.jsonl', './rlhf_part3.jsonl', './rlhf_part4.jsonl'] 10 | 11 | # Read the input file 12 | with open(input_file, 'r', encoding='utf-8') as f: 13 | lines = f.readlines() 14 | 15 | # Shuffle the lines 16 | random.shuffle(lines) 17 | 18 | # Number of lines per output file 19 | num_lines = len(lines) 20 | chunk_size = num_lines // 4 21 | 22 | # Split the lines into 4 equal groups 23 | chunks = [lines[i * chunk_size: (i + 1) * chunk_size] for i in range(4)] 24 | 25 | # Distribute any leftover lines one per file, starting with the first file 26 | for i in range(num_lines % 4): 27 | chunks[i].append(lines[4 * chunk_size + i]) 28 | 29 | # Write each group to its own output file 30 | for i, output_file in enumerate(output_files): 31 | with open(output_file, 'w', encoding='utf-8') as out_f: 32 | for line in chunks[i]: 33 | out_f.write(line) 34 | 35 | 36 | def generate_rlhf_data(): 37 | input_file =
'./rlhf_part4.jsonl' 38 | model_path = './sft_model' 39 | output_file = './rlhf_generate_part4.jsonl' 40 | device = "cuda" 41 | 42 | model = AutoModelForCausalLM.from_pretrained( 43 | model_path, 44 | torch_dtype="auto", 45 | device_map="auto", 46 | trust_remote_code=True 47 | ) 48 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 49 | 50 | # Read the input file 51 | with open(input_file, 'r', encoding='utf-8') as f: 52 | lines = f.readlines() 53 | 54 | # If output_file already exists, collect the prompts it contains so an interrupted run can resume 55 | existing_data = [] 56 | if os.path.exists(output_file): 57 | with open(output_file, 'r', encoding='utf-8') as out_f: 58 | existing_data = out_f.readlines() 59 | 60 | processed_prompts = set() 61 | for line in existing_data: 62 | data = json.loads(line) 63 | processed_prompts.add(data['prompt']) 64 | 65 | # Process each JSON line 66 | with open(output_file, 'a', encoding='utf-8') as out_f: 67 | for line in tqdm(lines, desc="Processing"): 68 | data = json.loads(line) 69 | prompt = data['prompt'] 70 | 71 | answer = data['answer'] 72 | messages = [ 73 | {"role": "user", "content": prompt} 74 | ] 75 | text = tokenizer.apply_chat_template( 76 | messages, 77 | tokenize=False, 78 | add_generation_prompt=True 79 | ) 80 | if text in processed_prompts: 81 | continue # Skip already processed prompts 82 | model_inputs = tokenizer([text], return_tensors="pt").to(device) 83 | 84 | with torch.no_grad(): 85 | generated_ids = model.generate( 86 | **model_inputs, 87 | max_new_tokens=512, 88 | do_sample=False 89 | ) 90 | generated_ids = [ 91 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 92 | ] 93 | 94 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 95 | result = { # chosen = reference answer, rejected = greedy SFT-model output 96 | 'prompt': text, 97 | 'response': answer + tokenizer.eos_token, 98 | 'chosen': answer + tokenizer.eos_token, 99 | 'rejected': response + tokenizer.eos_token 100 | } 101 | # Write the result to the output file 102 | out_f.write(json.dumps(result,
ensure_ascii=False) + '\n') 103 | 104 | 105 | def process_step_2_3_data(): 106 | input_files = [ 107 | './rlhf_generate_part1.jsonl', 108 | './rlhf_generate_part2.jsonl', 109 | './rlhf_generate_part3.jsonl', 110 | './rlhf_generate_part4.jsonl' 111 | ] 112 | output_step2_train_file = './step2_data/train.jsonl' 113 | output_step2_eval_file = './step2_data/eval.jsonl' 114 | output_step3_train_file = './step3_data/train.jsonl' 115 | output_step3_eval_file = './step3_data/eval.jsonl' 116 | 117 | data = [] 118 | 119 | # Read all input files 120 | for file in input_files: 121 | with open(file, 'r', encoding='utf-8') as f: 122 | for line in f: 123 | data.append(json.loads(line.strip())) 124 | 125 | # Shuffle the data 126 | random.shuffle(data) 127 | 128 | # Split: step 3 gets 95% train / 5% eval; step 2 reuses the first 50% of the shuffled data (47.5% train, 2.5% eval) 129 | total_size = len(data) 130 | step3_train_size = int(total_size * 0.95) 131 | step3_eval_size = total_size - step3_train_size 132 | step2_train_size = int(total_size * 0.475) 133 | step2_eval_size = int(total_size * 0.025) 134 | 135 | step3_train_data = data[:step3_train_size] 136 | step3_eval_data = data[step3_train_size:] 137 | 138 | step2_train_data = data[:step2_train_size] 139 | step2_eval_data = data[step2_train_size:step2_train_size + step2_eval_size] 140 | for out_dir in ('./step2_data', './step3_data'): os.makedirs(out_dir, exist_ok=True) 141 | # Write the output files 142 | with open(output_step3_train_file, 'w', encoding='utf-8') as f: 143 | for item in step3_train_data: 144 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 145 | 146 | with open(output_step3_eval_file, 'w', encoding='utf-8') as f: 147 | for item in step3_eval_data: 148 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 149 | 150 | with open(output_step2_train_file, 'w', encoding='utf-8') as f: 151 | for item in step2_train_data: 152 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 153 | 154 | with open(output_step2_eval_file, 'w', encoding='utf-8') as f: 155 | for item in step2_eval_data: 156 | f.write(json.dumps(item, ensure_ascii=False) + '\n') 157 | 158 | 159 | 160 | 161 | 162 | def main(): 163 | #split_jsonl() 164 | 165 |
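The slicing in `process_step_2_3_data` is easy to misread in the flattened source, so here is a minimal sketch that reproduces only its index arithmetic (the helper name `split_bounds` is hypothetical, not part of the repo). It shows that step 3 gets a 95%/5% train/eval split, while the step-2 reward-model data is the first half of the shuffled list (47.5% train, 2.5% eval) and therefore lies entirely inside the step-3 training slice.

```python
def split_bounds(total_size):
    """Return (start, stop) slice bounds for each output split.

    Hypothetical helper: mirrors the arithmetic in process_step_2_3_data
    without reading or writing any files.
    """
    step3_train_size = int(total_size * 0.95)
    step2_train_size = int(total_size * 0.475)
    step2_eval_size = int(total_size * 0.025)
    return {
        "step3_train": (0, step3_train_size),
        "step3_eval": (step3_train_size, total_size),
        "step2_train": (0, step2_train_size),
        "step2_eval": (step2_train_size, step2_train_size + step2_eval_size),
    }


if __name__ == "__main__":
    # 100,000 records, roughly the size of the RLHF set described in the README.
    for name, (start, stop) in split_bounds(100_000).items():
        print(f"{name}: data[{start}:{stop}] -> {stop - start} records")
```

Because both scripts read from the same shuffled list, the prompts used to train the reward model in step 2 are reused as PPO prompts in step 3; only the step-3 eval slice is disjoint from everything else.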
#generate_rlhf_data() 166 | process_step_2_3_data() 167 | 168 | 169 | 170 | 171 | if __name__ == "__main__": 172 | main() 173 | -------------------------------------------------------------------------------- /rlhf/rw_eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Microsoft Corporation. 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # DeepSpeed Team 6 | import argparse 7 | import torch 8 | 9 | from utils.model_utils import create_critic_model 10 | from utils.utils import to_device, load_hf_tokenizer 11 | from deepspeed import get_accelerator 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description="Evaluate the finetuned reward model") 17 | parser.add_argument( 18 | "--model_name_or_path", 19 | type=str, 20 | help= 21 | "Path to pretrained model or model identifier from huggingface.co/models.", 22 | required=True, 23 | ) 24 | parser.add_argument( 25 | "--num_padding_at_beginning", 26 | type=int, 27 | default=1, 28 | help= 29 | "OPT model has a fixed number (1) of padding tokens at the beginning of the input. 
" 30 | "We did not see this in other models but keep it as an option for now.", 31 | ) 32 | parser.add_argument( 33 | "--add_eot_token", 34 | action='store_true', 35 | help="Add <|endoftext|> as additional special token to tokenizer") 36 | args = parser.parse_args() 37 | return args 38 | 39 | 40 | def load_stuff(model_name_or_path, num_padding_at_beginning, 41 | additional_special_tokens): 42 | 43 | tokenizer = load_hf_tokenizer(model_name_or_path) 44 | model = create_critic_model(model_name_or_path, 45 | tokenizer, 46 | None, 47 | num_padding_at_beginning, 48 | dropout=0) 49 | 50 | return model, tokenizer 51 | 52 | 53 | def prepare_datapair(prompt, 54 | good_ans, 55 | bad_ans, 56 | tokenizer, 57 | max_seq_len=512, 58 | end_of_conversation_token=None): 59 | chosen_sentence = prompt + good_ans 60 | reject_sentence = prompt + bad_ans 61 | chosen_token = tokenizer(chosen_sentence, 62 | max_length=max_seq_len, 63 | padding="max_length", 64 | truncation=True, 65 | return_tensors="pt") 66 | 67 | reject_token = tokenizer(reject_sentence, 68 | max_length=max_seq_len, 69 | padding="max_length", 70 | truncation=True, 71 | return_tensors="pt") 72 | 73 | batch = {} 74 | batch["input_ids"] = torch.cat([chosen_token["input_ids"]] + 75 | [reject_token["input_ids"]], 76 | dim=0) 77 | batch["attention_mask"] = torch.cat([chosen_token["attention_mask"]] + 78 | [reject_token["attention_mask"]], 79 | dim=0) 80 | return batch 81 | 82 | 83 | def prepare_singlesample(prompt, 84 | good_ans, 85 | tokenizer, 86 | max_seq_len=512, 87 | end_of_conversation_token=None): 88 | chosen_sentence = prompt + good_ans + (end_of_conversation_token or "") # None would raise TypeError 89 | chosen_token = tokenizer(chosen_sentence, 90 | max_length=max_seq_len, 91 | padding="max_length", 92 | truncation=True, 93 | return_tensors="pt") 94 | 95 | batch = {} 96 | batch["input_ids"] = chosen_token["input_ids"] 97 | batch["attention_mask"] = chosen_token["attention_mask"] 98 | 99 | return batch 100 | 101 | 102 | def run_pair_comparison(): 103 |
args = parse_args() 104 | 105 | device = torch.device(get_accelerator().device_name(0)) 106 | 107 | args.end_of_conversation_token = None 108 | additional_special_tokens = args.end_of_conversation_token if args.add_eot_token else None 109 | 110 | rm_model, tokenizer = load_stuff(args.model_name_or_path, 111 | args.num_padding_at_beginning, 112 | additional_special_tokens) 113 | rm_model.to(device) 114 | rm_model.eval() 115 | 116 | prompt_list = [ 117 | "<|im_start|>system\n你是一个由喵阿姨开发的喵喵小助手<|im_end|>\n<|im_start|>user\n帮我生成一些音乐热评<|im_end|>\n<|im_start|>assistant\n", 118 | "<|im_start|>system\n你是一个由喵阿姨开发的喵喵小助手<|im_end|>\n<|im_start|>user\n根据开头,续写古诗:\n翠幄千章荫晚空<|im_end|>\n<|im_start|>assistant\n" 119 | ] 120 | good_ans_list = [ 121 | "1、1997年听了耀威的《有缘千里》专辑,到今年20年了,一直关注,有没有像我一样的朋友?\n2、爱的故事·上集·万屡爱意寄窗扉\n爱的故事·下集·我愿他能珍惜你\n爱的故事·曲终·只有我懂得自己<|im_end|>", 122 | "年华心赏两无穷。云头欲落催诗雨,池面微生解愠风。经笥使君谈似绮,仙舟令尹饮如虹。娵隅自适清池乐,不信参军是郝隆。<|im_end|>" 123 | ] 124 | bad_ans_list = [ 125 | "1、我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,我一直都觉得,这首歌是我的最爱,<|im_end|>", 126 | "金蟾照影照金蟾。玉兔飞来玉兔飞,玉兔飞来玉兔飞。<|im_end|>" 127 | ] 128 | 129 | for prompt, good_ans, bad_ans in zip(prompt_list, good_ans_list, 130 | bad_ans_list): 131 | batch = prepare_datapair( 132 | prompt, 133 | 
good_ans, 134 | bad_ans, 135 | tokenizer, 136 | max_seq_len=512, 137 | end_of_conversation_token=None) 138 | batch = to_device(batch, device) 139 | # Run inference 140 | with torch.no_grad(): 141 | outputs = rm_model(**batch) 142 | print("==================Eval result============================") 143 | print("prompt: ", prompt) 144 | print("\ngood_ans: ", good_ans) 145 | print("\nbad_ans:", bad_ans) 146 | print() 147 | print("=============Scores (higher, better)========================") 148 | print("good_ans score: ", outputs["chosen_mean_scores"].item()) 149 | print("bad_ans score: ", outputs["rejected_mean_scores"].item()) 150 | 151 | 152 | def run_single_sample(): 153 | args = parse_args() 154 | device = torch.device(get_accelerator().device_name()) 155 | 156 | args.end_of_conversation_token = None 157 | additional_special_tokens = args.end_of_conversation_token if args.add_eot_token else None 158 | 159 | rm_model, tokenizer = load_stuff(args.model_name_or_path, 160 | args.num_padding_at_beginning, 161 | additional_special_tokens) 162 | rm_model.to(device) 163 | 164 | prompt = "Human: Explain the moon landing to a 6 year old in a few sentences." 165 | my_ans = "Assistant: The moon landing was a major milestone in the history of human exploration of the solar system. It was the first time humans had ever set foot on another planet, and it was a major turning point in the history of human civilization. 
The astronauts, Neil Armstrong, Buzz Aldrin, and Michael Collins, successfully landed the Apollo 11 spacecraft on the moon, marking the first time humans had ever set foot on another" 166 | 167 | batch = prepare_singlesample( 168 | prompt, 169 | my_ans, 170 | tokenizer, 171 | max_seq_len=512, 172 | end_of_conversation_token=args.end_of_conversation_token) 173 | batch = to_device(batch, device) 174 | 175 | rm_model.eval() 176 | # Run inference 177 | with torch.no_grad(): 178 | outputs = rm_model.forward_value( 179 | **batch, prompt_length=max(2, args.num_padding_at_beginning) 180 | ) # we just need to skip the number of padding tokens at the beginning 181 | print("==================Eval result============================") 182 | print("prompt: ", prompt) 183 | print("my_ans: ", my_ans) 184 | print() 185 | print("=============Scores========================") 186 | print("my_ans score: ", outputs["chosen_end_scores"].item()) 187 | 188 | 189 | if __name__ == "__main__": 190 | run_pair_comparison() 191 | # run_single_sample() 192 | -------------------------------------------------------------------------------- /rlhf/step2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Microsoft Corporation. 
3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # DeepSpeed Team 6 | import argparse 7 | import math 8 | 9 | import torch 10 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 11 | from torch.utils.data.distributed import DistributedSampler 12 | 13 | from transformers import ( 14 | SchedulerType, 15 | get_scheduler, 16 | ) 17 | 18 | import deepspeed 19 | from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam 20 | from deepspeed.accelerator import get_accelerator 21 | 22 | from utils.model_utils import create_critic_model 23 | from utils.data_utils import create_prompt_dataset, DataCollatorReward 24 | from utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, load_hf_tokenizer 25 | from utils.ds_utils import get_train_ds_config 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser( 30 | description= 31 | "Train a reward model on pairwise (chosen/rejected) preference data") 32 | parser.add_argument('--data_path', 33 | type=str) 34 | parser.add_argument( 35 | '--data_output_path', 36 | type=str, 37 | default='/tmp/data_files/', 38 | help='Where to store the data-related files such as shuffle index.') 39 | parser.add_argument( 40 | "--model_name_or_path", 41 | type=str, 42 | help= 43 | "Path to pretrained model or model identifier from huggingface.co/models.", 44 | required=True, 45 | ) 46 | parser.add_argument( 47 | "--num_padding_at_beginning", 48 | type=int, 49 | default=1, 50 | help= 51 | "OPT model has a fixed number (1) of padding tokens at the beginning of the input. 
" 52 | "We did not see this in other models but keep it as an option for now.", 53 | ) 54 | parser.add_argument( 55 | "--per_device_train_batch_size", 56 | type=int, 57 | default=16, 58 | help="Batch size (per device) for the training dataloader.", 59 | ) 60 | parser.add_argument( 61 | "--per_device_eval_batch_size", 62 | type=int, 63 | default=16, 64 | help="Batch size (per device) for the evaluation dataloader.", 65 | ) 66 | parser.add_argument( 67 | "--max_seq_len", 68 | type=int, 69 | default=512, 70 | help="The maximum sequence length.", 71 | ) 72 | parser.add_argument( 73 | "--learning_rate", 74 | type=float, 75 | default=5e-5, 76 | help= 77 | "Initial learning rate (after the potential warmup period) to use.", 78 | ) 79 | parser.add_argument("--weight_decay", 80 | type=float, 81 | default=0., 82 | help="Weight decay to use.") 83 | parser.add_argument("--num_train_epochs", 84 | type=int, 85 | default=1, 86 | help="Total number of training epochs to perform.") 87 | parser.add_argument( 88 | "--gradient_accumulation_steps", 89 | type=int, 90 | default=1, 91 | help= 92 | "Number of micro-batches whose gradients are accumulated before each optimizer update.", 93 | ) 94 | parser.add_argument( 95 | "--lr_scheduler_type", 96 | type=SchedulerType, 97 | default="cosine", 98 | help="The scheduler type to use.", 99 | choices=[ 100 | "linear", "cosine", "cosine_with_restarts", "polynomial", 101 | "constant", "constant_with_warmup" 102 | ], 103 | ) 104 | parser.add_argument( 105 | "--num_warmup_steps", 106 | type=int, 107 | default=0, 108 | help="Number of steps for the warmup in the lr scheduler.") 109 | parser.add_argument("--output_dir", 110 | type=str, 111 | default=None, 112 | help="Where to store the model.") 113 | parser.add_argument("--seed", 114 | type=int, 115 | default=1234, 116 | help="A seed for reproducible training.") 117 | parser.add_argument("--local_rank", 118 | type=int, 119 | default=-1, 120 | help="local_rank for distributed training on GPUs") 121 |
parser.add_argument( 122 | '--gradient_checkpointing', 123 | action='store_true', 124 | help='Enable HF gradient checkpointing for the reward model.') 125 | parser.add_argument( 126 | "--dropout", 127 | type=float, 128 | default=None, 129 | help="If set, use this dropout value. " 130 | "Otherwise, keep the default dropout configuration of the model.") 131 | # deepspeed features 132 | parser.add_argument('--offload', 133 | action='store_true', 134 | help='Enable ZeRO Offload techniques.') 135 | parser.add_argument('--dtype', 136 | type=str, 137 | default='fp16', 138 | choices=['fp16', 'bf16'], 139 | help='Training data type') 140 | parser.add_argument( 141 | '--zero_stage', 142 | type=int, 143 | default=0, 144 | help='ZeRO optimization stage for the reward model.') 145 | ## LoRA for efficient training setting 146 | parser.add_argument("--lora_dim", 147 | type=int, 148 | default=0, 149 | help="If > 0, use LoRA for efficient training.") 150 | parser.add_argument("--lora_module_name", 151 | type=str, 152 | default="decoder.layers.", 153 | help="The scope of LoRA.") 154 | parser.add_argument('--only_optimize_lora', 155 | action='store_true', 156 | help='Only optimize the LoRA parameters.') 157 | parser.add_argument( 158 | "--lora_learning_rate", 159 | type=float, 160 | default=5e-4, 161 | help= 162 | "Initial LoRA learning rate (after the potential warmup period) to use." 163 | ) 164 | 165 | # Evaluation 166 | parser.add_argument("--eval_interval", 167 | type=int, 168 | default=0, 169 | help="If > 0, evaluate every N optimizer steps, where N is this value") 170 | parser.add_argument("--eval_iters", 171 | type=int, 172 | default=100, 173 | help="Maximum evaluation iterations") 174 | ## low precision 175 | parser.add_argument( 176 | '--compute_fp32_loss', 177 | action='store_true', 178 | help='Relevant for low precision dtypes (fp16, bf16, etc.). 
' 179 | 'If specified, loss is calculated in fp32.') 180 | 181 | ## Tensorboard logging 182 | parser.add_argument('--enable_tensorboard', 183 | action='store_true', 184 | help='Enable tensorboard logging') 185 | parser.add_argument('--tensorboard_path', 186 | type=str, 187 | default="step2_tensorboard") 188 | ## Tokenizer 189 | parser.add_argument( 190 | "--add_eot_token", 191 | action='store_true', 192 | help="Add <|endoftext|> as additional special token to tokenizer") 193 | parser = deepspeed.add_config_arguments(parser) 194 | args = parser.parse_args() 195 | 196 | return args 197 | 198 | 199 | def main(): 200 | args = parse_args() 201 | 202 | if args.local_rank == -1: 203 | device = torch.device(get_accelerator().device_name()) 204 | else: 205 | get_accelerator().set_device(args.local_rank) 206 | device = torch.device(get_accelerator().device_name(), args.local_rank) 207 | # Initializes the distributed backend, which takes care of synchronizing nodes/GPUs 208 | # torch.distributed.init_process_group(backend='nccl') 209 | deepspeed.init_distributed() 210 | 211 | args.global_rank = torch.distributed.get_rank() 212 | 213 | ds_config = get_train_ds_config(offload=args.offload, 214 | dtype=args.dtype, 215 | stage=args.zero_stage, 216 | enable_tensorboard=args.enable_tensorboard, 217 | tb_path=args.tensorboard_path, 218 | tb_name="step2_model") 219 | ds_config[ 220 | 'train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size 221 | ds_config[ 222 | 'train_batch_size'] = args.per_device_train_batch_size * torch.distributed.get_world_size( 223 | ) * args.gradient_accumulation_steps 224 | 225 | # If passed along, set the training seed now.
226 | set_random_seed(args.seed) 227 | torch.distributed.barrier() 228 | 229 | tokenizer = load_hf_tokenizer(args.model_name_or_path) 230 | rm_model = create_critic_model(args.model_name_or_path, 231 | tokenizer, 232 | ds_config, 233 | args.num_padding_at_beginning, 234 | dropout=args.dropout, 235 | zero_stage=args.zero_stage, 236 | compute_fp32_loss=args.compute_fp32_loss) 237 | 238 | 239 | print_rank_0("Creating the prompt dataset ...") 240 | train_phase = 2 241 | train_dataset, eval_dataset = create_prompt_dataset( 242 | args.local_rank, args.data_path, train_phase, args.seed, tokenizer, args.max_seq_len) 243 | 244 | 245 | 246 | 247 | 248 | data_collator = DataCollatorReward() 249 | 250 | if args.local_rank == -1: 251 | train_sampler = RandomSampler(train_dataset) 252 | eval_sampler = SequentialSampler(eval_dataset) 253 | else: 254 | train_sampler = DistributedSampler(train_dataset) 255 | eval_sampler = DistributedSampler(eval_dataset) 256 | train_dataloader = DataLoader(train_dataset, 257 | collate_fn=data_collator, 258 | sampler=train_sampler, 259 | batch_size=args.per_device_train_batch_size) 260 | eval_dataloader = DataLoader(eval_dataset, 261 | collate_fn=data_collator, 262 | sampler=eval_sampler, 263 | batch_size=args.per_device_eval_batch_size) 264 | 265 | 266 | 267 | 268 | 269 | def evaluation_reward(model, dataloader, eval_iters): 270 | model.eval() 271 | correct_predictions = 0 272 | total_predictions = 0 273 | chosen_scores = 0. 274 | rejected_scores = 0.
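`evaluation_reward`, whose definition starts just above, counts a pair as correct when the chosen response's mean score strictly beats the rejected one's, and it averages the per-batch mean scores over the evaluated steps before all-reducing across ranks. The sketch below reproduces that single-rank bookkeeping on plain Python lists so it runs without a GPU or a model; the function name and the input layout are assumptions made for illustration.

```python
def pairwise_reward_metrics(batches):
    """Single-rank version of evaluation_reward's bookkeeping (hypothetical).

    `batches` is a list of (chosen_scores, rejected_scores) pairs, one per
    evaluation step; each element is a list of per-sample reward scores.
    Returns (mean chosen score, mean rejected score, pairwise accuracy).
    """
    if not batches:
        raise ValueError("need at least one batch")
    correct = 0
    total = 0
    chosen_sum = 0.0
    rejected_sum = 0.0
    for chosen, rejected in batches:
        # A pair counts as correct only if chosen strictly beats rejected.
        correct += sum(c > r for c, r in zip(chosen, rejected))
        total += len(chosen)
        # Like the original, accumulate per-batch means, not per-sample sums.
        chosen_sum += sum(chosen) / len(chosen)
        rejected_sum += sum(rejected) / len(rejected)
    steps = len(batches)
    return chosen_sum / steps, rejected_sum / steps, correct / total


if __name__ == "__main__":
    demo = [([1.0, 0.5], [0.0, 1.0]), ([2.0, 3.0], [1.0, 1.0])]
    print(pairwise_reward_metrics(demo))
```

Because the scores are a mean of per-batch means, a smaller final batch weighs as much as a full one, whereas accuracy is computed per sample. In the distributed script, each rank's three values are then averaged with `get_all_reduce_mean`.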
275 | for _step, _batch in enumerate(dataloader): 276 | _batch = to_device(_batch, device) 277 | with torch.no_grad(): 278 | _outputs = model(**_batch) 279 | 280 | chosen = _outputs["chosen_mean_scores"] 281 | rejected = _outputs["rejected_mean_scores"] 282 | correct_predictions += (chosen > rejected).sum() 283 | total_predictions += chosen.shape[0] 284 | chosen_scores += _outputs["chosen_mean_scores"].mean().float() 285 | rejected_scores += _outputs["rejected_mean_scores"].mean().float() 286 | if (_step + 1) == eval_iters: 287 | break 288 | _acc = correct_predictions / total_predictions 289 | chosen_scores = chosen_scores / (_step + 1) 290 | rejected_scores = rejected_scores / (_step + 1) 291 | try: 292 | _acc = get_all_reduce_mean(_acc).item() 293 | chosen_scores = get_all_reduce_mean(chosen_scores).item() 294 | rejected_scores = get_all_reduce_mean(rejected_scores).item() 295 | except Exception: # e.g. not running distributed; keep the per-rank values 296 | pass 297 | return chosen_scores, rejected_scores, _acc 298 | 299 | # Split weights in two groups, one with weight decay and the other not.
300 | optimizer_grouped_parameters = get_optimizer_grouped_parameters( 301 | rm_model, args.weight_decay, args.lora_learning_rate) 302 | 303 | AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam 304 | 305 | optimizer = AdamOptimizer(optimizer_grouped_parameters, 306 | lr=args.learning_rate, 307 | betas=(0.9, 0.95)) 308 | 309 | num_update_steps_per_epoch = math.ceil( 310 | len(train_dataloader) / args.gradient_accumulation_steps) 311 | 312 | lr_scheduler = get_scheduler( 313 | name=args.lr_scheduler_type, 314 | optimizer=optimizer, 315 | num_warmup_steps=args.num_warmup_steps, 316 | num_training_steps=args.num_train_epochs * num_update_steps_per_epoch, 317 | ) 318 | 319 | rm_model, optimizer, _, lr_scheduler = deepspeed.initialize( 320 | model=rm_model, 321 | optimizer=optimizer, 322 | args=args, 323 | config=ds_config, 324 | lr_scheduler=lr_scheduler, 325 | dist_init_required=True) 326 | 327 | if args.gradient_checkpointing: 328 | rm_model.gradient_checkpointing_enable() 329 | 330 | # Train! 331 | print_rank_0("***** Running training *****", args.global_rank) 332 | 333 | print_rank_0( 334 | f"***** Evaluating reward, Epoch {0}/{args.num_train_epochs} *****", 335 | args.global_rank) 336 | reward_score, reject_score, acc = evaluation_reward( 337 | rm_model, eval_dataloader, args.eval_iters) 338 | print_rank_0( 339 | f"chosen_last_scores (higher is better) : {reward_score}, " 340 | f"rejected_last_scores (lower is better) : {reject_score}, " 341 | f"acc (higher is better) : {acc}", args.global_rank) 342 | 343 | total_micro_steps = 0 344 | for epoch in range(args.num_train_epochs): 345 | print_rank_0( 346 | f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}", 347 | args.global_rank) 348 | rm_model.train() 349 | mean_loss = 0 350 | for step, batch in enumerate(train_dataloader): 351 | batch = to_device(batch, device) 352 | outputs = rm_model(**batch, use_cache=False) 353 | loss = outputs["loss"] 354 | rm_model.backward(loss) 355 | rm_model.step()
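The optimizer, scheduler and training loop here rest on some batch arithmetic that is easy to lose in the flattened source. DeepSpeed requires `train_batch_size` to equal the per-GPU micro-batch times the world size times the gradient-accumulation steps (hence the explicit `ds_config[...]` assignments earlier in `main`); `num_update_steps_per_epoch` rounds `len(train_dataloader) / gradient_accumulation_steps` up; and `--eval_interval` counts optimizer steps, firing only on micro-steps that close an accumulation window. The sketch below reproduces those three calculations with hypothetical helper names; `world_size=2` assumes the two-GPU setup from the README.

```python
import math


def global_train_batch_size(micro_batch_per_gpu, world_size, grad_accum_steps):
    """Value written to ds_config['train_batch_size'] (hypothetical helper).

    DeepSpeed checks train_batch_size ==
    train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size.
    """
    if min(micro_batch_per_gpu, world_size, grad_accum_steps) < 1:
        raise ValueError("all batch factors must be >= 1")
    return micro_batch_per_gpu * world_size * grad_accum_steps


def scheduler_training_steps(num_micro_batches, grad_accum_steps, num_epochs):
    """num_training_steps passed to get_scheduler in step2.py."""
    return num_epochs * math.ceil(num_micro_batches / grad_accum_steps)


def eval_trigger_micro_steps(total_micro_steps, grad_accum_steps, eval_interval):
    """1-based micro-steps at which step2.py's loop would call evaluation_reward.

    The micro-step counter keeps running across epochs, as in the script.
    """
    triggers = []
    for micro_step in range(1, total_micro_steps + 1):
        gas_boundary = micro_step % grad_accum_steps == 0
        total_steps = micro_step // grad_accum_steps
        if eval_interval and gas_boundary and total_steps % eval_interval == 0:
            triggers.append(micro_step)
    return triggers


if __name__ == "__main__":
    # step2.sh: 16 samples per GPU, no accumulation; world_size=2 is assumed.
    print(global_train_batch_size(16, 2, 1))
    # 10 micro-batches per epoch, accumulate 4, 2 epochs: 2 * ceil(2.5) steps
    print(scheduler_training_steps(10, 4, 2))
    print(eval_trigger_micro_steps(16, 2, 3))
```

One consequence worth checking: because the micro-step counter is not reset between epochs, when `len(train_dataloader)` is not a multiple of the accumulation steps the ceil-based scheduler length can slightly exceed the number of optimizer steps actually taken (10 micro-batches for 2 epochs with accumulation 4 gives 5 full windows but a 6-step schedule), so the cosine schedule would likely stop short of its final value.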
356 | mean_loss += loss.item() 357 | total_micro_steps += 1 358 | gas_boundary = (total_micro_steps % 359 | args.gradient_accumulation_steps == 0) 360 | total_steps = total_micro_steps // args.gradient_accumulation_steps 361 | if args.eval_interval and gas_boundary and ( 362 | total_steps % args.eval_interval == 0): 363 | print_rank_0(f"Iter {total_steps}: Evaluating reward", 364 | args.global_rank) 365 | reward_score, reject_score, acc = evaluation_reward( 366 | rm_model, eval_dataloader, args.eval_iters) 367 | print_rank_0( 368 | f"Iter {total_steps}: c_scores: {reward_score}, r_scores: {reject_score}, " 369 | f"diff: {reward_score - reject_score}, acc: {acc}", 370 | args.global_rank) 371 | rm_model.train() 372 | 373 | print_rank_0( 374 | f"Epoch {epoch+1}/{args.num_train_epochs} with loss {mean_loss/(step+1)}", 375 | args.global_rank) 376 | # Evaluate the reward model on the validation set. 377 | print_rank_0( 378 | f"***** Evaluating reward, Epoch {epoch+1}/{args.num_train_epochs} *****", 379 | args.global_rank) 380 | reward_score, reject_score, acc = evaluation_reward( 381 | rm_model, eval_dataloader, args.eval_iters) 382 | print_rank_0( 383 | f"chosen_last_scores (higher is better) : {reward_score}, " 384 | f"rejected_last_scores (lower is better) : {reject_score}, " 385 | f"acc (higher is better) : {acc}", args.global_rank) 386 | rm_model.tput_timer.update_epoch_count() 387 | 388 | if args.output_dir is not None: 389 | print_rank_0('saving model ...', args.global_rank) 390 | 391 | if args.global_rank == 0: 392 | save_hf_format(rm_model, tokenizer, args, sub_folder=f"epoch_{epoch+1}") 393 | if args.zero_stage == 3: 394 | raise RuntimeError('ZeRO-3 is not supported') 395 | 396 | 397 | if __name__ == "__main__": 398 | main() 399 | -------------------------------------------------------------------------------- /rlhf/step2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Microsoft Corporation.
3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # DeepSpeed Team 6 | OUTPUT=$1 7 | ZERO_STAGE=$2 8 | if [ "$OUTPUT" == "" ]; then 9 | OUTPUT=./step2_output 10 | fi 11 | if [ "$ZERO_STAGE" == "" ]; then 12 | ZERO_STAGE=2 13 | fi 14 | mkdir -p $OUTPUT 15 | 16 | deepspeed step2.py \ 17 | --data_path ./step2_data \ 18 | --model_name_or_path ./sft_model \ 19 | --per_device_train_batch_size 16 \ 20 | --per_device_eval_batch_size 16 \ 21 | --max_seq_len 1024 \ 22 | --learning_rate 9.65e-6 \ 23 | --weight_decay 0.1 \ 24 | --num_padding_at_beginning 0 \ 25 | --num_train_epochs 2 \ 26 | --gradient_accumulation_steps 1 \ 27 | --lr_scheduler_type cosine \ 28 | --num_warmup_steps 0 \ 29 | --seed 1234 \ 30 | --zero_stage $ZERO_STAGE \ 31 | --deepspeed \ 32 | --output_dir $OUTPUT \ 33 | &> $OUTPUT/training.log 34 | -------------------------------------------------------------------------------- /rlhf/step2_eval.sh: -------------------------------------------------------------------------------- 1 | python rw_eval.py \ 2 | --model_name_or_path ./step2_output/epoch_1 -------------------------------------------------------------------------------- /rlhf/step3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) Microsoft Corporation. 
3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | # DeepSpeed Team 6 | ACTOR_MODEL_PATH=./sft_model 7 | CRITIC_MODEL_PATH=./step2_output/epoch_1 8 | ACTOR_ZERO_STAGE=2 9 | CRITIC_ZERO_STAGE=2 10 | OUTPUT=$5 11 | if [ "$OUTPUT" == "" ]; then 12 | OUTPUT=./step3_output 13 | fi 14 | if [ "$ACTOR_ZERO_STAGE" == "" ]; then 15 | ACTOR_ZERO_STAGE=2 16 | fi 17 | if [ "$CRITIC_ZERO_STAGE" == "" ]; then 18 | CRITIC_ZERO_STAGE=2 19 | fi 20 | mkdir -p $OUTPUT 21 | 22 | Num_Padding_at_Beginning=0 # this is model related 23 | 24 | Actor_Lr=1e-6 25 | Critic_Lr=1e-6 26 | 27 | deepspeed --master_port 12346 step3.py \ 28 | --data_path ./step3_data \ 29 | --actor_model_name_or_path $ACTOR_MODEL_PATH \ 30 | --critic_model_name_or_path $CRITIC_MODEL_PATH \ 31 | --num_padding_at_beginning 0 \ 32 | --per_device_generation_batch_size 40 \ 33 | --per_device_training_batch_size 40 \ 34 | --generation_batches 1 \ 35 | --ppo_epochs 1 \ 36 | --max_answer_seq_len 128 \ 37 | --max_prompt_seq_len 256 \ 38 | --actor_learning_rate ${Actor_Lr} \ 39 | --critic_learning_rate ${Critic_Lr} \ 40 | --actor_weight_decay 0.1 \ 41 | --critic_weight_decay 0.1 \ 42 | --num_train_epochs 3 \ 43 | --lr_scheduler_type cosine \ 44 | --gradient_accumulation_steps 1 \ 45 | --actor_gradient_checkpointing \ 46 | --critic_gradient_checkpointing \ 47 | --offload_reference_model \ 48 | --enable_ema \ 49 | --actor_dropout 0.0 \ 50 | --num_warmup_steps 100 \ 51 | --deepspeed --seed 1234 \ 52 | --actor_zero_stage $ACTOR_ZERO_STAGE \ 53 | --critic_zero_stage $CRITIC_ZERO_STAGE \ 54 | --enable_hybrid_engine \ 55 | --output_dir $OUTPUT \ 56 | &> $OUTPUT/training.log 57 | -------------------------------------------------------------------------------- /rlhf/utils/__pycache__/data_utils.cpython-311.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/rlhf/utils/__pycache__/data_utils.cpython-311.pyc -------------------------------------------------------------------------------- /rlhf/utils/__pycache__/ds_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/rlhf/utils/__pycache__/ds_utils.cpython-311.pyc -------------------------------------------------------------------------------- /rlhf/utils/__pycache__/model_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/rlhf/utils/__pycache__/model_utils.cpython-311.pyc -------------------------------------------------------------------------------- /rlhf/utils/__pycache__/perf.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/rlhf/utils/__pycache__/perf.cpython-311.pyc -------------------------------------------------------------------------------- /rlhf/utils/__pycache__/raw_datasets.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/rlhf/utils/__pycache__/raw_datasets.cpython-311.pyc -------------------------------------------------------------------------------- /rlhf/utils/__pycache__/reward_model.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/rlhf/utils/__pycache__/reward_model.cpython-311.pyc 
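For the PPO stage, `step3.sh` passes `--per_device_generation_batch_size 40`, `--generation_batches 1` and `--ppo_epochs 1`. `print_throughput_step3` in `perf.py` multiplies these by the world size, and by 2 when an unsupervised dataset is mixed in, to get the number of samples per iteration. The sketch below reproduces only that arithmetic. It assumes a 2-GPU run (the README reports 2×A40); the helper names and the dataset name `"wiki"` are hypothetical. It also shows why the multiplier must be parenthesized: in Python a conditional expression binds more loosely than `*`.

```python
def samples_per_ppo_iteration(per_device_generation_batch_size: int,
                              generation_batches: int,
                              ppo_epochs: int,
                              world_size: int,
                              unsupervised_dataset_name=None) -> int:
    """Samples per PPO iteration, following the batch_size formula in print_throughput_step3."""
    # perf.py counts twice as many samples when an unsupervised dataset is mixed in.
    multiplier = 1 if unsupervised_dataset_name is None else 2
    return (per_device_generation_batch_size * generation_batches *
            ppo_epochs * world_size * multiplier)


def unparenthesized(per_device_generation_batch_size, generation_batches,
                    ppo_epochs, world_size, unsupervised_dataset_name=None):
    # Without parentheses the conditional swallows the whole product:
    # it yields the product when the name is None and the constant 2 otherwise.
    return per_device_generation_batch_size * generation_batches * ppo_epochs * world_size * 1 if unsupervised_dataset_name is None else 2


# Values from step3.sh; world_size=2 is an assumption (two GPUs).
print(samples_per_ppo_iteration(40, 1, 1, 2))          # 80
print(samples_per_ppo_iteration(40, 1, 1, 2, "wiki"))  # 160
print(unparenthesized(40, 1, 1, 2, "wiki"))            # 2
```

With parentheses, the reported count scales correctly in both modes. Without them, it collapses to 2 as soon as unsupervised training is enabled.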
-------------------------------------------------------------------------------- /rlhf/utils/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AI-Study-Han/Zero-Chatgpt/03a1d98d5fcf879bf13eb410bdd54547bbd46095/rlhf/utils/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /rlhf/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, Subset, ConcatDataset 3 | from torch.nn.utils.rnn import pad_sequence 4 | import torch.nn.functional as F 5 | from datasets import load_dataset 6 | import numpy as np 7 | import os 8 | import hashlib 9 | from itertools import chain 10 | from utils.raw_datasets import LocalJsonFileDataset 11 | from deepspeed.accelerator import get_accelerator 12 | 13 | 14 | 15 | def get_raw_dataset(data_path): 16 | 17 | return LocalJsonFileDataset(data_path) 18 | 19 | 20 | 21 | def get_shuffle_idx(seed, size): 22 | np_rng = np.random.RandomState(seed=seed) 23 | dtype_ = np.uint32 24 | if size >= (np.iinfo(np.uint32).max - 1): 25 | dtype_ = np.int64 26 | shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_) 27 | np_rng.shuffle(shuffle_idx) 28 | return shuffle_idx 29 | 30 | def get_raw_dataset_split_index(seed, data_size): 31 | """ 32 | Generate raw dataset split indices without saving or loading. 
33 | 34 | Parameters: 35 | - seed: int, random seed for shuffling 36 | - data_size: int, size of the dataset 37 | 38 | Returns: 39 | - index_list: list, shuffled index list 40 | """ 41 | shuffle_idx = get_shuffle_idx(seed, data_size) 42 | return shuffle_idx.tolist() 43 | 44 | def create_dataset(data_path, 45 | train_phase, seed, tokenizer, end_of_conversation_token, 46 | max_seq_len): 47 | raw_dataset = get_raw_dataset(data_path) 48 | train_dataset = raw_dataset.get_train_data() 49 | train_dataset = create_dataset_split(train_dataset, raw_dataset, 50 | train_phase, tokenizer, 51 | max_seq_len) 52 | eval_dataset = raw_dataset.get_eval_data() 53 | eval_dataset = create_dataset_split(eval_dataset, raw_dataset, train_phase, 54 | tokenizer, 55 | max_seq_len) 56 | 57 | return train_dataset, eval_dataset 58 | 59 | class PromptDataset(Dataset): 60 | 61 | def __init__(self, prompt_dataset, chosen_dataset, reject_dataset, 62 | pad_token_id, train_phase) -> None: 63 | super().__init__() 64 | self.prompt_dataset = prompt_dataset 65 | self.chosen_dataset = chosen_dataset 66 | self.reject_dataset = reject_dataset 67 | self.pad_token_id = pad_token_id 68 | self.train_phase = train_phase 69 | 70 | def __len__(self): 71 | length = len(self.chosen_dataset) 72 | if self.train_phase == 3: 73 | length = len(self.prompt_dataset) 74 | return length 75 | 76 | def __getitem__(self, idx): 77 | if self.train_phase == 2: 78 | return self.chosen_dataset[idx]["input_ids"], self.chosen_dataset[idx]["attention_mask"], \ 79 | self.reject_dataset[idx]["input_ids"], self.reject_dataset[idx]["attention_mask"] 80 | elif self.train_phase == 3: 81 | return self.prompt_dataset[idx]["input_ids"],self.prompt_dataset[idx]["attention_mask"], \ 82 | self.pad_token_id 83 | 84 | 85 | def create_prompt_dataset(local_rank, 86 | data_path, 87 | train_phase, 88 | seed, 89 | tokenizer, 90 | max_seq_len, 91 | end_of_conversation_token=None): 92 | """ 93 | Creates the prompt dataset 94 | """ 95 | if local_rank <= 0 : 
96 | 97 | train_dataset, eval_dataset = create_dataset( 98 | data_path, 99 | train_phase, 100 | seed, 101 | tokenizer, 102 | end_of_conversation_token, 103 | max_seq_len) 104 | return train_dataset, eval_dataset 105 | 106 | torch.distributed.barrier() 107 | return None, None 108 | 109 | def create_dataset_split(current_dataset, raw_dataset, train_phase, tokenizer, max_seq_len): 110 | prompt_dataset = [] 111 | chosen_dataset = [] 112 | reject_dataset = [] 113 | 114 | if train_phase == 2: 115 | for i, tmp_data in enumerate(current_dataset): 116 | # tokenize the text 117 | chosen_sentence = raw_dataset.get_prompt_and_chosen( 118 | tmp_data) # the chosen response 119 | reject_sentence = raw_dataset.get_prompt_and_rejected( 120 | tmp_data) # the rejected response 121 | if chosen_sentence is not None and reject_sentence is not None: 122 | # chosen_sentence += end_of_conversation_token # the accept response 123 | # reject_sentence += end_of_conversation_token 124 | chosen_token = tokenizer(chosen_sentence, 125 | max_length=max_seq_len, 126 | padding="max_length", 127 | truncation=True, 128 | return_tensors="pt") 129 | reject_token = tokenizer(reject_sentence, 130 | max_length=max_seq_len, 131 | padding="max_length", 132 | truncation=True, 133 | return_tensors="pt") 134 | chosen_token["input_ids"] = chosen_token["input_ids"] 135 | chosen_token["attention_mask"] = chosen_token["attention_mask"] 136 | chosen_dataset.append(chosen_token) 137 | 138 | reject_token["input_ids"] = reject_token["input_ids"] 139 | reject_token["attention_mask"] = reject_token["attention_mask"] 140 | reject_dataset.append(reject_token) 141 | print( 142 | f'Creating dataset {raw_dataset.dataset_name_clean} for {train_phase=} size={len(chosen_dataset)}' 143 | ) 144 | 145 | elif train_phase == 3: 146 | filtered = 0 147 | for i, tmp_data in enumerate(current_dataset): 148 | # tokenize the text 149 | prompt = raw_dataset.get_prompt(tmp_data) 150 | if prompt is not None: 151 | prompt_token =
tokenizer(prompt, return_tensors="pt") 152 | if prompt_token["input_ids"].size()[-1] <= max_seq_len: 153 | for key_word in ["input_ids", "attention_mask"]: 154 | prompt_token[key_word] = prompt_token[ 155 | key_word].squeeze(0).flip(0) 156 | prompt_dataset.append(prompt_token) 157 | else: 158 | filtered += 1 159 | print(f'Creating dataset {raw_dataset.dataset_name_clean} ' 160 | f'for {train_phase=} size={len(prompt_dataset)} {filtered=}') 161 | 162 | return PromptDataset(prompt_dataset, chosen_dataset, reject_dataset, 163 | tokenizer.pad_token_id, train_phase) 164 | 165 | 166 | class DataCollatorReward: 167 | 168 | def __call__(self, data): 169 | batch = {} 170 | batch["input_ids"] = torch.cat([f[0] 171 | for f in data] + [f[2] for f in data], 172 | dim=0) 173 | batch["attention_mask"] = torch.cat([f[1] for f in data] + 174 | [f[3] for f in data], 175 | dim=0) 176 | return batch 177 | 178 | class MiniDataset: 179 | 180 | def __init__(self, max_size, small_batch_size): 181 | self.dataset = [] 182 | self.max_size = max_size 183 | self.small_batch_size = small_batch_size 184 | 185 | def seperate(self): 186 | small_dataset = [] 187 | for large_batch in self.dataset: 188 | if type(large_batch) == list or type(large_batch) == tuple: 189 | large_size = len(large_batch[0]) 190 | elif type(large_batch) == dict: 191 | large_size = len(large_batch[list(large_batch.keys())[0]]) 192 | else: 193 | large_size = len(large_batch) 194 | for i in range(0, large_size, self.small_batch_size): 195 | if type(large_batch) == list or type(large_batch) == tuple: 196 | small_dataset.append( 197 | [x[i:i + self.small_batch_size] for x in large_batch]) 198 | elif type(large_batch) == dict: 199 | small_dataset.append({ 200 | k: v[i:i + self.small_batch_size] 201 | for k, v in large_batch.items() 202 | }) 203 | else: 204 | small_dataset.append(large_batch[i:i + 205 | self.small_batch_size]) 206 | self.free() 207 | 208 | return small_dataset 209 | 210 | def add(self, data): 211 | if 
len(self.dataset) < self.max_size: 212 | self.dataset.append(data) 213 | if len(self.dataset) == self.max_size: 214 | return self.seperate() 215 | else: 216 | return None 217 | else: 218 | raise ValueError( 219 | "The dataset is full but we did not stop it. There is a bug in the code." 220 | ) 221 | 222 | def free(self): 223 | self.dataset = [] 224 | 225 | class DataCollatorRLHF: 226 | 227 | def __init__(self, max_token_len, inference_tp_size, pad_token_id): 228 | self.max_token_len = max_token_len 229 | self.inference_tp_size = inference_tp_size 230 | self.pad_token_id = pad_token_id 231 | 232 | def __call__(self, data): 233 | batch = {} 234 | # pad_token_id = data[-1][-1] 235 | 236 | prompt = pad_sequence([f[0] for f in data], 237 | padding_value=self.pad_token_id, 238 | batch_first=True) 239 | prompt_mask = pad_sequence([f[1] for f in data], 240 | padding_value=0, 241 | batch_first=True) 242 | 243 | ### pad to a fixed length of max_token_len (prompts are stored reversed, so right-padding here becomes left-padding after the flip below) 244 | length = prompt.size()[-1] 245 | pad_length = self.max_token_len - length 246 | if pad_length > 0: 247 | batch["prompt"] = F.pad(prompt, 248 | pad=(0, pad_length), 249 | mode='constant', 250 | value=self.pad_token_id) 251 | batch["prompt_att_mask"] = F.pad(prompt_mask, 252 | pad=(0, pad_length), 253 | mode='constant', 254 | value=0) 255 | else: 256 | batch["prompt"] = prompt 257 | batch["prompt_att_mask"] = prompt_mask 258 | batch["prompt"] = batch["prompt"].flip(1) 259 | batch["prompt_att_mask"] = batch["prompt_att_mask"].flip(1) 260 | return batch 261 | 262 | 263 | def get_unsupervised_data(args, tokenizer): 264 | unsupervised_raw_datasets = load_dataset( 265 | args.unsupervised_dataset_name, args.unsupervised_dataset_config_name) 266 | column_names = unsupervised_raw_datasets["train"].column_names 267 | text_column_name = "text" if "text" in column_names else column_names[0] 268 | 269 | def tokenize_function(examples): 270 | return tokenizer(examples[text_column_name]) 271 | 272 | tokenized_datasets =
unsupervised_raw_datasets.map( 273 | tokenize_function, 274 | batched=True, 275 | num_proc=args.preprocessing_num_workers, 276 | remove_columns=column_names, 277 | load_from_cache_file=True, 278 | desc="Running tokenizer on dataset", 279 | ) 280 | 281 | block_size = args.max_prompt_seq_len + args.max_answer_seq_len 282 | 283 | def group_texts(examples): 284 | # Concatenate all texts. 285 | concatenated_examples = { 286 | k: list(chain(*examples[k])) 287 | for k in examples.keys() 288 | } 289 | total_length = len(concatenated_examples[list(examples.keys())[0]]) 290 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 291 | # customize this part to your needs. 292 | if total_length >= block_size: 293 | total_length = (total_length // block_size) * block_size 294 | # Split by chunks of max_len. 295 | result = { 296 | k: 297 | [t[i:i + block_size] for i in range(0, total_length, block_size)] 298 | for k, t in concatenated_examples.items() 299 | } 300 | result["labels"] = result["input_ids"].copy() 301 | return result 302 | 303 | lm_datasets = tokenized_datasets.map( 304 | group_texts, 305 | batched=True, 306 | num_proc=args.preprocessing_num_workers, 307 | load_from_cache_file=True, 308 | desc=f"Grouping texts in chunks of {block_size}", 309 | ) 310 | 311 | train_dataset = lm_datasets["train"] 312 | 313 | return train_dataset -------------------------------------------------------------------------------- /rlhf/utils/ds_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | import deepspeed.comm as dist 7 | from deepspeed.accelerator import get_accelerator 8 | 9 | GLOBAL_BATCH_SIZE = 32 10 | MICRO_BATCH_SIZE = 4 11 | 12 | 13 | def get_train_ds_config(offload, 14 | dtype, 15 | stage=2, 16 | enable_hybrid_engine=False, 17 | inference_tp_size=1, 18 | release_inference_cache=False, 19 | pin_parameters=True, 20 | tp_gather_partition_size=8, 21 | max_out_tokens=512, 22 | enable_tensorboard=False, 23 | enable_mixed_precision_lora=False, 24 | tb_path="", 25 | tb_name=""): 26 | 27 | device = "cpu" if offload else "none" 28 | if dtype == "fp16": 29 | data_type = "fp16" 30 | dtype_config = {"enabled": True, "loss_scale_window": 100} 31 | elif dtype == "bf16": 32 | data_type = "bfloat16" 33 | dtype_config = {"enabled": True} 34 | zero_opt_dict = { 35 | "stage": stage, 36 | "offload_param": { 37 | "device": device 38 | }, 39 | "offload_optimizer": { 40 | "device": device 41 | }, 42 | "stage3_param_persistence_threshold": 1e4, 43 | "stage3_max_live_parameters": 3e7, 44 | "stage3_prefetch_bucket_size": 3e7, 45 | "memory_efficient_linear": False 46 | } 47 | if enable_mixed_precision_lora: 48 | zero_opt_dict["zero_quantized_nontrainable_weights"] = True 49 | if dist.get_world_size() != get_accelerator().device_count(): 50 | zero_opt_dict["zero_hpz_partition_size"] = get_accelerator( 51 | ).device_count() 52 | return { 53 | "train_batch_size": GLOBAL_BATCH_SIZE, 54 | "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE, 55 | "steps_per_print": 10, 56 | "zero_optimization": zero_opt_dict, 57 | data_type: dtype_config, 58 | "gradient_clipping": 1.0, 59 | "prescale_gradients": False, 60 | "wall_clock_breakdown": False, 61 | "hybrid_engine": { 62 | "enabled": enable_hybrid_engine, 63 | "max_out_tokens": max_out_tokens, 64 | "inference_tp_size": inference_tp_size, 65 | "release_inference_cache": release_inference_cache, 66 | "pin_parameters": pin_parameters, 67 | 
"tp_gather_partition_size": tp_gather_partition_size, 68 | }, 69 | "tensorboard": { 70 | "enabled": enable_tensorboard, 71 | "output_path": f"{tb_path}/ds_tensorboard_logs/", 72 | "job_name": f"{tb_name}_tensorboard" 73 | } 74 | } 75 | 76 | 77 | def get_eval_ds_config(offload, dtype, stage=0): 78 | device = "cpu" if offload else "none" 79 | if dtype == "fp16": 80 | data_type = "fp16" 81 | dtype_config = { 82 | "enabled": True, 83 | } 84 | elif dtype == "bf16": 85 | data_type = "bfloat16" 86 | dtype_config = {"enabled": True} 87 | zero_opt_dict = { 88 | "stage": stage, 89 | "stage3_param_persistence_threshold": 1e4, 90 | "offload_param": { 91 | "device": device 92 | }, 93 | "memory_efficient_linear": False 94 | } 95 | return { 96 | "train_batch_size": GLOBAL_BATCH_SIZE, 97 | "train_micro_batch_size_per_gpu": MICRO_BATCH_SIZE, 98 | "steps_per_print": 10, 99 | "zero_optimization": zero_opt_dict, 100 | data_type: dtype_config, 101 | "gradient_clipping": 1.0, 102 | "prescale_gradients": False, 103 | "wall_clock_breakdown": False 104 | } 105 | -------------------------------------------------------------------------------- /rlhf/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | from transformers import ( 5 | AutoConfig, 6 | AutoModel, 7 | ) 8 | from transformers.deepspeed import HfDeepSpeedConfig 9 | from utils.reward_model import RewardModel 10 | from utils.utils import load_state_dict_into_model, print_rank_0 11 | 12 | 13 | 14 | def create_hf_model(model_class, 15 | model_name_or_path, 16 | tokenizer, 17 | ds_config=None, 18 | rlhf_training=False, 19 | dropout=None): 20 | model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) 21 | 22 | 23 | # Note: dschf is defined in function scope to avoid global effects 24 | # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration 25 | if ds_config is not 
None and ds_config["zero_optimization"]["stage"] == 3: 26 | dschf = HfDeepSpeedConfig(ds_config) 27 | else: 28 | dschf = None 29 | if rlhf_training: 30 | # the weight loading is handled by create critic model 31 | model = model_class.from_config(model_config, trust_remote_code=True) 32 | else: 33 | model = model_class.from_pretrained( 34 | model_name_or_path, 35 | from_tf=bool(".ckpt" in model_name_or_path), 36 | config=model_config, 37 | trust_remote_code=True) 38 | 39 | model.config.end_token_id = tokenizer.eos_token_id 40 | model.config.pad_token_id = model.config.eos_token_id 41 | model.resize_token_embeddings(int( 42 | 8 * 43 | math.ceil(len(tokenizer) / 8.0))) # make the vocab size multiple of 8 44 | 45 | return model 46 | 47 | def create_critic_model(model_name_or_path, 48 | tokenizer, 49 | ds_config, 50 | num_padding_at_beginning=0, 51 | rlhf_training=False, 52 | dropout=None, 53 | zero_stage=0, 54 | compute_fp32_loss=False): 55 | # OPT model family always put a padding token at the beginning of the sequence, 56 | # we did not see this in other models but not sure if it is a general rule 57 | 58 | import time 59 | 60 | start = time.time() 61 | critic_model = create_hf_model(AutoModel, model_name_or_path, tokenizer, 62 | ds_config, rlhf_training, dropout) 63 | end = time.time() 64 | print_rank_0(f">Creating model from_config took {end - start} seconds", 65 | None) 66 | 67 | critic_model = RewardModel( 68 | critic_model, 69 | tokenizer, 70 | num_padding_at_beginning=num_padding_at_beginning, 71 | compute_fp32_loss=compute_fp32_loss) 72 | 73 | if rlhf_training: 74 | # load critic model from checkpoint 75 | 76 | 77 | model_ckpt_path = os.path.join(model_name_or_path, 'pytorch_model.bin') 78 | assert os.path.exists( 79 | model_ckpt_path 80 | ), f"Cannot find model checkpoint at {model_ckpt_path}" 81 | 82 | start = time.time() 83 | model_ckpt_state_dict = torch.load(model_ckpt_path, map_location='cpu') 84 | end = time.time() 85 | print_rank_0(f">Creating model 
from_config took {end - start} seconds", 86 | None) 87 | 88 | # load critic model from checkpoint with zero-stage 3 compatibility 89 | # this functionality may be moved to DS checkpoint load API in future 90 | start = time.time() 91 | load_state_dict_into_model(critic_model, 92 | model_ckpt_state_dict, 93 | "", 94 | zero_stage=zero_stage) 95 | end = time.time() 96 | 97 | print_rank_0(f">Creating model from_config took {end - start} seconds", 98 | None) 99 | 100 | return critic_model -------------------------------------------------------------------------------- /rlhf/utils/perf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | import torch 7 | 8 | 9 | # This function can be used to print throughput for Step 1 and 2 only 10 | def print_throughput(hf_model, args, e2e_time, rank=0): 11 | if rank <= 0: 12 | hf_config = hf_model.config 13 | num_layers, hidden_size, vocab_size = get_hf_configs(hf_config) 14 | 15 | gpus_per_model = torch.distributed.get_world_size() 16 | seq_length = args.max_seq_len 17 | batch_size = args.per_device_train_batch_size 18 | samples_per_second = batch_size / e2e_time 19 | checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3 20 | if args.lora_dim > 0: 21 | k = args.lora_dim * 2 / hidden_size 22 | checkpoint_activations_factor -= (1 - k) 23 | 24 | hf_model._num_params = sum([ 25 | p.ds_numel if hasattr(p, "ds_tensor") else p.numel() 26 | for p in hf_model.parameters() 27 | ]) 28 | params_in_billions = hf_model._num_params / (1e9) 29 | 30 | # Megatron paper's formula to calculate training flops 31 | train_flops_per_iteration = calculate_flops( 32 | checkpoint_activations_factor, batch_size, seq_length, hf_config) 33 | 34 | train_tflops = train_flops_per_iteration / (e2e_time * gpus_per_model * 35 | (10**12)) 36 | 37 | param_string = f"{params_in_billions:.3f} B" if 
params_in_billions != 0 else "NA" 38 | print( 39 | f"Model Parameters: {param_string}, Latency: {e2e_time:.2f}s, TFLOPs: {train_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Sequence Length: {seq_length}" 40 | ) 41 | 42 | 43 | # Enhanced version of the function above that provides calculations and printing for Step 3 44 | def print_throughput_step3(actor_model, 45 | critic_model, 46 | args, 47 | e2e_time, 48 | gen_exp_time, 49 | train_time, 50 | rank=0): 51 | if rank <= 0: 52 | # Actor model passed here is a HF model. 53 | actor_hf_config = actor_model.config 54 | # Critic model passed here is a DeepSpeed Engine. The module inside is the Reward model (that wraps a HF model). 55 | critic_hf_config = critic_model.module.config 56 | 57 | actor_num_layers, actor_hidden_size, actor_vocab_size = get_hf_configs( 58 | actor_hf_config) 59 | critic_num_layers, critic_hidden_size, critic_vocab_size = get_hf_configs( 60 | critic_hf_config) 61 | 62 | gpus_per_model = torch.distributed.get_world_size() 63 | seq_length = args.max_answer_seq_len + args.max_prompt_seq_len 64 | batch_size = args.per_device_generation_batch_size * args.generation_batches * args.ppo_epochs * gpus_per_model * (1 if args.unsupervised_dataset_name is None else 2) 65 | samples_per_second = batch_size / e2e_time 66 | 67 | actor_checkpoint_activations_factor = 4 if args.actor_gradient_checkpointing else 3 68 | critic_checkpoint_activations_factor = 4 if args.critic_gradient_checkpointing else 3 69 | if args.actor_lora_dim > 0: 70 | k = args.actor_lora_dim * 2 / actor_hidden_size 71 | actor_checkpoint_activations_factor -= (1 - k) 72 | if args.critic_lora_dim > 0: 73 | k = args.critic_lora_dim * 2 / critic_hidden_size 74 | critic_checkpoint_activations_factor -= (1 - k) 75 | 76 | actor_model._num_params = sum([ 77 | p.ds_numel if hasattr(p, "ds_tensor") else p.numel() 78 | for p in actor_model.parameters() 79 | ]) 80 |
actor_params_in_billions = actor_model._num_params / (1e9) 81 | 82 | critic_model._num_params = sum([ 83 | p.ds_numel if hasattr(p, "ds_tensor") else p.numel() 84 | for p in critic_model.parameters() 85 | ]) 86 | critic_params_in_billions = critic_model._num_params / (1e9) 87 | 88 | # Megatron paper's formula to calculate training flops 89 | 90 | actor_train_flops_per_iteration = calculate_flops( 91 | actor_checkpoint_activations_factor, batch_size, seq_length, 92 | actor_hf_config) 93 | critic_train_flops_per_iteration = calculate_flops( 94 | critic_checkpoint_activations_factor, batch_size, seq_length, 95 | critic_hf_config) 96 | 97 | total_train_flops = actor_train_flops_per_iteration + critic_train_flops_per_iteration 98 | train_tflops = total_train_flops / (train_time * gpus_per_model * 99 | (10**12)) 100 | 101 | gen_bs = args.per_device_generation_batch_size * gpus_per_model 102 | 103 | # Modified formula for calculating flops in the forward pass only 104 | gen_flops_per_iteration = ( 105 | 24 * gen_bs * seq_length * actor_num_layers * 106 | (actor_hidden_size**2)) * ( 107 | 1.0 + (seq_length / (6.0 * actor_hidden_size)) + 108 | (actor_vocab_size / 109 | (16.0 * actor_num_layers * actor_hidden_size))) 110 | 111 | gen_tflops = gen_flops_per_iteration / (gen_exp_time * gpus_per_model * 112 | (10**12)) 113 | 114 | if actor_hf_config.torch_dtype == torch.float16: 115 | num_bytes = 2 116 | elif actor_hf_config.torch_dtype == torch.float32: 117 | num_bytes = 4 118 | else: 119 | num_bytes = -1 120 | 121 | pertok_lat = gen_exp_time / args.max_answer_seq_len 122 | gen_bw = 1 / pertok_lat * actor_model._num_params * num_bytes / 1e9 123 | 124 | total_flops_per_iteration = total_train_flops + gen_flops_per_iteration * args.generation_batches 125 | total_tflops = total_flops_per_iteration / (e2e_time * gpus_per_model * 126 | (10**12)) 127 | 128 | print( 129 | f"End-to-End => Latency: {e2e_time:.2f}s, TFLOPs: {total_tflops:.2f}, Samples/sec: {samples_per_second:.2f}, 
Time/seq {e2e_time/batch_size:.2f}s, Batch Size: {batch_size}, Total Seq. Length: {seq_length}" 130 | ) 131 | print( 132 | f"Generation => Latency: {gen_exp_time:.2f}s, Per-token Latency {pertok_lat*1000:.2f} ms, TFLOPs: {gen_tflops:.2f}, BW: {gen_bw if num_bytes > 0 else num_bytes:.2f} GB/sec, Answer Seq. Length: {args.max_answer_seq_len}" 133 | ) 134 | print( 135 | f"Training => Latency: {train_time:.2f}s, TFLOPs: {train_tflops:.2f}" 136 | ) 137 | actor_param_string = f"{actor_params_in_billions:.3f} B" if actor_params_in_billions != 0 else "NA" 138 | critic_param_string = f"{critic_params_in_billions:.3f} B" if critic_params_in_billions != 0 else "NA" 139 | print( 140 | f"Actor Model Parameters => {actor_param_string}, Critic Model Parameters => {critic_param_string}" 141 | ) 142 | 143 | 144 | # Helper function to calculate FLOPs using the Megatron-LM paper's formula 145 | def calculate_flops(checkpoint_activations_factor, batch_size, seq_length, 146 | hf_config): 147 | num_layers, hidden_size, vocab_size = get_hf_configs(hf_config) 148 | flops_per_iteration = (24 * checkpoint_activations_factor * batch_size * 149 | seq_length * num_layers * (hidden_size**2)) * ( 150 | 1.0 + (seq_length / (6.0 * hidden_size)) + 151 | (vocab_size / 152 | (16.0 * num_layers * hidden_size))) 153 | return flops_per_iteration 154 | 155 | 156 | def get_hf_configs(hf_config): 157 | num_layers = getattr(hf_config, "num_hidden_layers", 158 | getattr(hf_config, "n_layer", None)) 159 | hidden_size = getattr(hf_config, "hidden_size", 160 | getattr(hf_config, "n_embd", None)) 161 | vocab_size = getattr(hf_config, "vocab_size", None) 162 | assert all( 163 | (num_layers, hidden_size, vocab_size) 164 | ), "Could not determine number of layers, hidden size, and vocab size of the model" 165 | 166 | return num_layers, hidden_size, vocab_size 167 | -------------------------------------------------------------------------------- /rlhf/utils/raw_datasets.py: 
-------------------------------------------------------------------------------- 1 | 2 | import os 3 | # DeepSpeed Team 4 | from datasets import load_dataset, load_from_disk 5 | from torch.utils.data import Subset 6 | import re 7 | 8 | class LocalJsonFileDataset(object): 9 | 10 | def __init__(self, data_path): 11 | self.dataset_name = "local/jsonfile" 12 | self.dataset_name_clean = "jsonfile" 13 | self.raw_datasets = load_dataset('json', 14 | data_files={ 15 | "train": 16 | data_path + '/train.jsonl', 17 | "eval": 18 | data_path + '/eval.jsonl' 19 | }) 20 | 21 | def get_train_data(self): 22 | if self.raw_datasets['train'] is not None: 23 | return self.raw_datasets['train'] 24 | return None 25 | 26 | def get_eval_data(self): 27 | if self.raw_datasets['eval'] is not None: 28 | return self.raw_datasets['eval'] 29 | return None 30 | 31 | # The prompt should be in the format of: " Human: " + actual_prompt_sentence + " Assistant:" 32 | def get_prompt(self, sample): 33 | if sample['prompt'] is not None: 34 | return sample['prompt'] 35 | return None 36 | 37 | # The chosen response should be in the format of: " " + actual_response_sentence 38 | def get_chosen(self, sample): 39 | if sample['chosen'] is not None: 40 | return sample['chosen'] 41 | return None 42 | 43 | # The rejected response should be in the format of: " " + actual_response_sentence 44 | # If the dataset does not have rejected response, return None 45 | def get_rejected(self, sample): 46 | if sample['rejected'] is not None: 47 | return sample['rejected'] 48 | return None 49 | 50 | def get_prompt_and_chosen(self, sample): 51 | if sample['prompt'] is not None and sample['chosen'] is not None: 52 | return sample['prompt'] + sample['chosen'] 53 | return None 54 | 55 | def get_prompt_and_rejected(self, sample): 56 | if sample['prompt'] is not None and sample['rejected'] is not None: 57 | return sample['prompt'] + sample['rejected'] 58 | return None 
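`LocalJsonFileDataset` loads `train.jsonl` and `eval.jsonl` from `data_path` and reads the `prompt`, `chosen` and `rejected` fields. Phase 2 (reward model) tokenizes `prompt + chosen` and `prompt + rejected`. Phase 3 (PPO) uses only `prompt`. The sketch below writes and reads such a directory with the standard library and mirrors the concatenation rule of `get_prompt_and_chosen`. The example record is hypothetical and follows the `" Human: ... Assistant:"` template described in the comments above; the data actually prepared for this repo may use a different template.

```python
import json
import os
import tempfile

# Hypothetical example record; the keys match what LocalJsonFileDataset reads.
records = [
    {"prompt": " Human: What is 2+2? Assistant:", "chosen": " 4", "rejected": " 5"},
]


def write_split(data_dir, split, rows):
    """Write rows to <data_dir>/<split>.jsonl, one JSON object per line."""
    path = os.path.join(data_dir, f"{split}.jsonl")
    with open(path, "w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")
    return path


def read_split(data_dir, split):
    """Read <data_dir>/<split>.jsonl back into a list of dicts."""
    with open(os.path.join(data_dir, f"{split}.jsonl"), encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def prompt_and_chosen(sample):
    # Same rule as get_prompt_and_chosen: plain concatenation, no separator added.
    # .get() is used here so a dict that lacks a key also yields None.
    if sample.get("prompt") is not None and sample.get("chosen") is not None:
        return sample["prompt"] + sample["chosen"]
    return None


with tempfile.TemporaryDirectory() as data_dir:
    # LocalJsonFileDataset expects exactly these two file names under data_path.
    for split in ("train", "eval"):
        write_split(data_dir, split, records)
    train_rows = read_split(data_dir, "train")

print(repr(prompt_and_chosen(train_rows[0])))  # ' Human: What is 2+2? Assistant: 4'
```

No separator is inserted between the fields, so the response text has to carry its own leading space (the `" " + actual_response_sentence` convention noted above).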
-------------------------------------------------------------------------------- /rlhf/utils/reward_model.py: -------------------------------------------------------------------------------- 1 | # DeepSpeed Team 2 | import torch 3 | from torch import nn 4 | 5 | 6 | ## Note that the following code is modified from 7 | ## https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py 8 | class RewardModel(nn.Module): 9 | 10 | def __init__(self, 11 | base_model, 12 | tokenizer, 13 | num_padding_at_beginning=0, 14 | compute_fp32_loss=False): 15 | super().__init__() 16 | self.config = base_model.config 17 | self.num_padding_at_beginning = num_padding_at_beginning 18 | if hasattr(self.config, "word_embed_proj_dim"): 19 | # `OPT` models use word_embed_proj_dim as final output 20 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L497 21 | self.v_head = nn.Linear(self.config.word_embed_proj_dim, 22 | 1, 23 | bias=False) 24 | else: 25 | # `gpt-neo(x)` models use the `hidden_size` attribute name instead of `n_embd` 26 | self.config.n_embd = self.config.hidden_size if hasattr( 27 | self.config, "hidden_size") else self.config.n_embd 28 | self.v_head = nn.Linear(self.config.n_embd, 1, bias=False) 29 | self.rwtransformer = base_model 30 | self.PAD_ID = tokenizer.pad_token_id 31 | self.compute_fp32_loss = compute_fp32_loss 32 | 33 | def gradient_checkpointing_enable(self): 34 | self.rwtransformer.gradient_checkpointing_enable() 35 | 36 | def gradient_checkpointing_disable(self): 37 | self.rwtransformer.gradient_checkpointing_disable() 38 | 39 | def forward(self, 40 | input_ids=None, 41 | past_key_values=None, 42 | attention_mask=None, 43 | position_ids=None, 44 | head_mask=None, 45 | inputs_embeds=None, 46 | use_cache=False): 47 | loss = None 48 | 49 | 50 | kwargs = dict() 51 | 52 | transformer_outputs = self.rwtransformer( 53 | input_ids, 54 | past_key_values=past_key_values, 55 |
attention_mask=attention_mask, 56 | inputs_embeds=inputs_embeds, 57 | use_cache=use_cache, 58 | **kwargs) 59 | 60 | hidden_states = transformer_outputs[0] 61 | rewards = self.v_head(hidden_states).squeeze(-1) 62 | chosen_mean_scores = [] 63 | rejected_mean_scores = [] 64 | 65 | # Split the inputs and rewards into two parts, chosen and rejected 66 | assert len(input_ids.shape) == 2 67 | bs = input_ids.shape[0] // 2 68 | seq_len = input_ids.shape[1] 69 | 70 | chosen_ids = input_ids[:bs] # bs x seq 71 | rejected_ids = input_ids[bs:] 72 | chosen_rewards = rewards[:bs] 73 | rejected_rewards = rewards[bs:] 74 | 75 | # Compute pairwise loss. Only backprop on the different tokens before padding 76 | loss = 0. 77 | for i in range(bs): 78 | chosen_id = chosen_ids[i] 79 | rejected_id = rejected_ids[i] 80 | chosen_reward = chosen_rewards[i] 81 | rejected_reward = rejected_rewards[i] 82 | 83 | c_inds = (chosen_id == self.PAD_ID).nonzero() 84 | c_ind = c_inds[self.num_padding_at_beginning].item() if len( 85 | c_inds 86 | ) > self.num_padding_at_beginning else seq_len # OPT model pads the first token, so we need to use the second padding token as the end of the sequence 87 | check_divergence = (chosen_id != rejected_id).nonzero() 88 | 89 | if len(check_divergence) == 0: 90 | end_ind = rejected_reward.size(-1) 91 | divergence_ind = end_ind - 1 92 | r_ind = c_ind 93 | else: 94 | # Check if there is any padding otherwise take length of sequence 95 | r_inds = (rejected_id == self.PAD_ID).nonzero() 96 | r_ind = r_inds[self.num_padding_at_beginning].item( 97 | ) if len(r_inds) > self.num_padding_at_beginning else seq_len 98 | end_ind = max(c_ind, r_ind) 99 | divergence_ind = check_divergence[0] 100 | assert divergence_ind > 0 101 | c_truncated_reward = chosen_reward[divergence_ind:end_ind] 102 | r_truncated_reward = rejected_reward[divergence_ind:end_ind] 103 | chosen_mean_scores.append( 104 | chosen_reward[c_ind - 1]) #use the end score for reference 105 |
rejected_mean_scores.append(rejected_reward[r_ind - 1]) 106 | 107 | if self.compute_fp32_loss: 108 | c_truncated_reward = c_truncated_reward.float() 109 | r_truncated_reward = r_truncated_reward.float() 110 | loss += -torch.nn.functional.logsigmoid(c_truncated_reward - 111 | r_truncated_reward).mean() 112 | 113 | loss = loss / bs 114 | chosen_mean_scores = torch.stack(chosen_mean_scores) 115 | rejected_mean_scores = torch.stack(rejected_mean_scores) 116 | return { 117 | "loss": loss, 118 | "chosen_mean_scores": chosen_mean_scores, 119 | "rejected_mean_scores": rejected_mean_scores, 120 | } 121 | 122 | def forward_value(self, 123 | input_ids=None, 124 | attention_mask=None, 125 | past_key_values=None, 126 | position_ids=None, 127 | head_mask=None, 128 | inputs_embeds=None, 129 | return_value_only=False, 130 | prompt_length=0, 131 | use_cache=False): 132 | 133 | 134 | kwargs = dict() 135 | 136 | 137 | transformer_outputs = self.rwtransformer( 138 | input_ids, 139 | past_key_values=past_key_values, 140 | attention_mask=attention_mask, 141 | inputs_embeds=inputs_embeds, 142 | use_cache=use_cache, 143 | **kwargs) 144 | hidden_states = transformer_outputs[0] 145 | values = self.v_head(hidden_states).squeeze(-1) 146 | if return_value_only: 147 | return values 148 | else: 149 | # [0 0 0 0 prompt, answer, 0 0 0 0 ] for step 3, we have padding at the beginning 150 | # [prompt, answer, 0, 0, 0, 0] this is normal 151 | assert prompt_length > 1, "prompt_length must be greater than 1 to help select the end score" 152 | bs = values.size(0) 153 | seq_len = input_ids.shape[1] 154 | chosen_end_scores = [ 155 | ] # we use this name for consistency with the original forward function 156 | for i in range(bs): 157 | input_id = input_ids[i] 158 | value = values[i] 159 | 160 | c_inds = (input_id[prompt_length:] == self.PAD_ID).nonzero() 161 | # here we only use the answer part of the sequence so we do not need to care about the padding at the beginning 162 | c_ind = c_inds[0].item() + 
prompt_length if len( 163 | c_inds) > 0 else seq_len 164 | chosen_end_scores.append(value[c_ind - 1]) 165 | return { 166 | "values": values, 167 | "chosen_end_scores": torch.stack(chosen_end_scores), 168 | } 169 | -------------------------------------------------------------------------------- /rlhf/utils/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import os 3 | import torch 4 | import random 5 | import numpy as np 6 | from transformers import set_seed, AutoTokenizer 7 | import json 8 | import deepspeed 9 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 10 | from deepspeed.accelerator import get_accelerator 11 | import torch.nn as nn 12 | 13 | def print_rank_0(msg, rank=None): 14 | if rank is not None and rank <= 0: 15 | print(msg) 16 | elif is_rank_0(): 17 | print(msg) 18 | 19 | def is_rank_0(): 20 | """Check whether it is rank 0.""" 21 | if torch.distributed.is_initialized(): 22 | if torch.distributed.get_rank() == 0: 23 | return True 24 | else: 25 | return False 26 | else: 27 | return True 28 | 29 | def load_hf_tokenizer(model_name_or_path): 30 | tokenizer = AutoTokenizer.from_pretrained( 31 | model_name_or_path, trust_remote_code=True) 32 | return tokenizer 33 | 34 | def load_state_dict_into_model(model_to_load=None, 35 | state_dict=None, 36 | start_prefix="", 37 | zero_stage=0): 38 | 39 | # copy state_dict so _load_from_state_dict can modify it 40 | metadata = getattr(state_dict, "_metadata", None) 41 | state_dict = state_dict.copy() 42 | if metadata is not None: 43 | state_dict._metadata = metadata 44 | 45 | error_msgs = [] 46 | 47 | # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants 48 | # so we need to apply the function recursively. 
49 | def load(module: nn.Module, state_dict, prefix=""): 50 | local_metadata = {} if metadata is None else metadata.get( 51 | prefix[:-1], {}) 52 | args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) 53 | # Parameters of module and children will start with prefix. We can exit early if there are none in this 54 | # state_dict 55 | if len([key for key in state_dict if key.startswith(prefix)]) > 0: 56 | if zero_stage == 3: 57 | # In sharded models, each shard has only part of the full state_dict, so only gather 58 | # parameters that are in the current state_dict. 59 | named_parameters = dict( 60 | module.named_parameters(prefix=prefix[:-1], recurse=False)) 61 | params_to_gather = [ 62 | named_parameters[k] for k in state_dict.keys() 63 | if k in named_parameters 64 | ] 65 | if len(params_to_gather) > 0: 66 | # because zero3 puts placeholders in model params, this context 67 | # manager gathers (unpartitions) the params of the current layer, then loads from 68 | # the state dict and then re-partitions them again 69 | with deepspeed.zero.GatheredParameters(params_to_gather, 70 | modifier_rank=0): 71 | if torch.distributed.get_rank() == 0: 72 | module._load_from_state_dict(*args) 73 | else: 74 | module._load_from_state_dict(*args) 75 | 76 | for name, child in module._modules.items(): 77 | if child is not None: 78 | load(child, state_dict, prefix + name + ".") 79 | 80 | load(model_to_load, state_dict, prefix=start_prefix) 81 | # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so 82 | # it's safe to delete it. 
83 | del state_dict 84 | 85 | return error_msgs 86 | 87 | def save_hf_format(model, tokenizer, args, sub_folder=""): 88 | # used to save huggingface format, so we can use it for hf.from_pretrained 89 | model_to_save = model.module if hasattr(model, 'module') else model 90 | CONFIG_NAME = "config.json" 91 | WEIGHTS_NAME = "pytorch_model.bin" 92 | output_dir = os.path.join(args.output_dir, sub_folder) 93 | os.makedirs(output_dir, exist_ok=True) 94 | output_model_file = os.path.join(output_dir, WEIGHTS_NAME) 95 | output_config_file = os.path.join(output_dir, CONFIG_NAME) 96 | save_dict = model_to_save.state_dict() 97 | for key in list(save_dict.keys()): 98 | if "lora" in key: 99 | del save_dict[key] 100 | torch.save(save_dict, output_model_file) 101 | model_to_save.config.to_json_file(output_config_file) 102 | tokenizer.save_vocabulary(output_dir) 103 | 104 | def to_device(batch, device): 105 | output = {} 106 | for k, v in batch.items(): 107 | try: 108 | output[k] = v.to(device) 109 | except: 110 | output[k] = v 111 | return output 112 | 113 | def set_random_seed(seed): 114 | if seed is not None: 115 | set_seed(seed) 116 | random.seed(seed) 117 | np.random.seed(seed) 118 | torch.manual_seed(seed) 119 | get_accelerator().manual_seed_all(seed) 120 | 121 | def get_all_reduce_mean(tensor): 122 | torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM) 123 | tensor = tensor / torch.distributed.get_world_size() 124 | return tensor 125 | 126 | def get_optimizer_grouped_parameters( 127 | model, 128 | weight_decay, 129 | lora_lr=5e-4, 130 | no_decay_name_list=[ 131 | "bias", "layer_norm.weight", "layernorm.weight", "norm.weight", 132 | "ln_f.weight" 133 | ], 134 | lora_name_list=["lora_right_weight", "lora_left_weight"], 135 | ): 136 | optimizer_grouped_parameters = [ 137 | { 138 | "params": [ 139 | p for n, p in model.named_parameters() 140 | if (not any(nd in n.lower() for nd in no_decay_name_list) 141 | and p.requires_grad and not any(nd in n.lower() 142 | 
for nd in lora_name_list)) 143 | ], 144 | "weight_decay": 145 | weight_decay, 146 | }, 147 | { 148 | "params": [ 149 | p for n, p in model.named_parameters() 150 | if (not any(nd in n.lower() for nd in no_decay_name_list) 151 | and p.requires_grad and any(nd in n.lower() 152 | for nd in lora_name_list)) 153 | ], 154 | "weight_decay": 155 | weight_decay, 156 | "lr": 157 | lora_lr 158 | }, 159 | { 160 | "params": [ 161 | p for n, p in model.named_parameters() 162 | if (any(nd in n.lower() 163 | for nd in no_decay_name_list) and p.requires_grad) 164 | ], 165 | "weight_decay": 166 | 0.0, 167 | }, 168 | ] 169 | 170 | non_empty_groups = [] 171 | for group in optimizer_grouped_parameters: 172 | if group["params"]: 173 | non_empty_groups.append(group) 174 | return non_empty_groups 175 | 176 | def _z3_params_to_fetch(param_list): 177 | return [ 178 | p for p in param_list 179 | if hasattr(p, 'ds_id') and p.ds_status == ZeroParamStatus.NOT_AVAILABLE 180 | ] 181 | 182 | def moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0): 183 | zero_stage_3 = (zero_stage == 3) 184 | with torch.no_grad(): 185 | for param, param_ema in zip(model.parameters(), 186 | model_ema.parameters()): 187 | # TODO: use prefiltering for efficiency 188 | params_to_fetch = _z3_params_to_fetch([param, param_ema 189 | ]) if zero_stage_3 else [] 190 | should_gather_param = len(params_to_fetch) > 0 191 | with deepspeed.zero.GatheredParameters( 192 | params_to_fetch, enabled=should_gather_param): 193 | data = param.data 194 | if device is not None: 195 | data = data.to(device) 196 | param_ema.data.copy_(torch.lerp(data, param_ema.data, beta)) 197 | 198 | 199 | 200 | def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0): 201 | zero_stage_3 = (zero_stage == 3) 202 | os.makedirs(save_dir, exist_ok=True) 203 | WEIGHTS_NAME = "pytorch_model.bin" 204 | output_model_file = os.path.join(save_dir, WEIGHTS_NAME) 205 | 206 | model_to_save = model_ema.module if hasattr(model_ema, 
207 | 'module') else model_ema 208 | if not zero_stage_3: 209 | if global_rank == 0: 210 | torch.save(model_to_save.state_dict(), output_model_file) 211 | else: 212 | output_state_dict = {} 213 | for k, v in model_to_save.named_parameters(): 214 | 215 | if hasattr(v, 'ds_id'): 216 | with deepspeed.zero.GatheredParameters(_z3_params_to_fetch([v 217 | ]), 218 | enabled=zero_stage_3): 219 | v_p = v.data.cpu() 220 | else: 221 | v_p = v.cpu() 222 | if global_rank == 0 and "lora" not in k: 223 | output_state_dict[k] = v_p 224 | if global_rank == 0: 225 | torch.save(output_state_dict, output_model_file) 226 | del output_state_dict 227 | 228 | class ExponentialMovingAverage: 229 | 230 | def __init__(self, alpha=0.9): 231 | self.alpha = alpha 232 | self.ema = None 233 | 234 | def update(self, num): 235 | prev_ema = num if self.ema is None else self.ema 236 | self.ema = self.alpha * prev_ema + (1.0 - self.alpha) * num 237 | return self.ema 238 | 239 | def get(self): 240 | return self.ema if self.ema is not None else 0. 
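The pairwise ranking loss computed per pair in `RewardModel.forward` (rlhf/utils/reward_model.py) can be illustrated without torch or DeepSpeed. The sketch below is a hypothetical pure-Python restatement for a single (chosen, rejected) pair, assuming right padding, `num_padding_at_beginning=0`, and plain lists in place of tensors; the helper names `log_sigmoid`, `end_index`, and `pairwise_loss` are illustrative and are not part of the repository.

```python
import math
from typing import List, Tuple


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x)), the scalar analogue of F.logsigmoid."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def end_index(ids: List[int], pad_id: int, num_padding_at_beginning: int = 0) -> int:
    """Index of the padding token that ends the sequence, or len(ids) if there is none."""
    pads = [i for i, tok in enumerate(ids) if tok == pad_id]
    if len(pads) > num_padding_at_beginning:
        return pads[num_padding_at_beginning]
    return len(ids)


def pairwise_loss(chosen_ids: List[int], rejected_ids: List[int],
                  chosen_rewards: List[float], rejected_rewards: List[float],
                  pad_id: int, num_padding_at_beginning: int = 0) -> Tuple[float, float, float]:
    """Return (loss, chosen_end_score, rejected_end_score) for one pair (hypothetical helper)."""
    c_ind = end_index(chosen_ids, pad_id, num_padding_at_beginning)
    diverge = [i for i, (c, r) in enumerate(zip(chosen_ids, rejected_ids)) if c != r]
    if not diverge:
        # Identical sequences: only the last position contributes.
        end_ind = len(rejected_rewards)
        div_ind = end_ind - 1
        r_ind = c_ind
    else:
        r_ind = end_index(rejected_ids, pad_id, num_padding_at_beginning)
        end_ind = max(c_ind, r_ind)
        div_ind = diverge[0]
        assert div_ind > 0, "chosen and rejected should share a prompt prefix"
    # Per-token reward margins from the first differing token up to the later end.
    margins = [c - r for c, r in zip(chosen_rewards[div_ind:end_ind],
                                     rejected_rewards[div_ind:end_ind])]
    loss = -sum(log_sigmoid(m) for m in margins) / len(margins)
    # End scores: reward at the last non-padding token of each sequence.
    return loss, chosen_rewards[c_ind - 1], rejected_rewards[r_ind - 1]


if __name__ == "__main__":
    loss, c_end, r_end = pairwise_loss(
        chosen_ids=[5, 6, 7, 8, 0, 0], rejected_ids=[5, 6, 9, 0, 0, 0],
        chosen_rewards=[0.0, 0.0, 1.0, 2.0, 0.0, 0.0],
        rejected_rewards=[0.0, 0.0, 0.5, -1.0, 0.0, 0.0],
        pad_id=0,
    )
    print(round(loss, 4), c_end, r_end)
```

In the repository code the per-pair losses are summed over the batch and divided by `bs`, so the batch loss is the mean of this per-pair value; a larger chosen-minus-rejected margin on the divergent tokens drives the loss toward zero.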
-------------------------------------------------------------------------------- /sft/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "optimizer": { 12 | "type": "AdamW", 13 | "params": { 14 | "lr": "auto", 15 | "betas": "auto", 16 | "eps": "auto", 17 | "weight_decay": "auto" 18 | } 19 | }, 20 | 21 | "scheduler": { 22 | "type": "WarmupDecayLR", 23 | "params": { 24 | "warmup_min_lr": 1e-5, 25 | "warmup_max_lr": "auto", 26 | "warmup_num_steps": "auto", 27 | "total_num_steps": "auto" 28 | } 29 | }, 30 | 31 | "zero_optimization": { 32 | "stage": 2, 33 | "allgather_partitions": true, 34 | "allgather_bucket_size": 2e8, 35 | "overlap_comm": true, 36 | "reduce_scatter": true, 37 | "reduce_bucket_size": 2e8, 38 | "contiguous_gradients": true 39 | }, 40 | 41 | "gradient_accumulation_steps": "auto", 42 | "gradient_clipping": "auto", 43 | "steps_per_print": 2000, 44 | "train_batch_size": "auto", 45 | "train_micro_batch_size_per_gpu": "auto", 46 | "wall_clock_breakdown": false 47 | } -------------------------------------------------------------------------------- /sft/model/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_type": "miaomiao", 3 | "architectures": [ 4 | "MiaomiaoModel" 5 | ], 6 | "auto_map": { 7 | "AutoConfig": "configuration_miaomiao.MiaomiaoConfig", 8 | "AutoModel": "modeling_miaomiao.MiaomiaoModel", 9 | "AutoModelForCausalLM": "modeling_miaomiao.MiaomiaoForCausalLM" 10 | }, 11 | "attention_dropout": 0.0, 12 | "bos_token_id": 32005, 13 | "eos_token_id": 32005, 14 | "hidden_act": "silu", 15 | "hidden_size": 512, 16 | "initializer_range": 0.02, 17 | "intermediate_size": 2752, 18 | "max_position_embeddings": 131072, 19 | "max_window_layers": 28, 20 | 
"num_attention_heads": 16, 21 | "num_hidden_layers": 24, 22 | "num_key_value_heads": 16, 23 | "rms_norm_eps": 1e-06, 24 | "rope_theta": 1000000.0, 25 | "sliding_window": 131072, 26 | "tie_word_embeddings": false, 27 | "torch_dtype": "bfloat16", 28 | "transformers_version": "4.37.2", 29 | "use_cache": true, 30 | "use_sliding_window": false, 31 | "vocab_size": 32006 32 | } -------------------------------------------------------------------------------- /sft/model/configuration_miaomiao.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | """ Miaomiao model configuration""" 4 | 5 | from transformers.configuration_utils import PretrainedConfig 6 | from transformers.utils import logging 7 | 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | 12 | class MiaomiaoConfig(PretrainedConfig): 13 | 14 | model_type = "miaomiao" 15 | keys_to_ignore_at_inference = ["past_key_values"] 16 | 17 | def __init__( 18 | self, 19 | vocab_size=32000, 20 | hidden_size=4096, 21 | intermediate_size=11008, 22 | num_hidden_layers=32, 23 | num_attention_heads=32, 24 | num_key_value_heads=None, 25 | hidden_act="silu", 26 | max_position_embeddings=2048, 27 | initializer_range=0.02, 28 | rms_norm_eps=1e-6, 29 | use_cache=True, 30 | pad_token_id=None, 31 | bos_token_id=1, 32 | eos_token_id=2, 33 | pretraining_tp=1, 34 | tie_word_embeddings=False, 35 | rope_theta=10000.0, 36 | rope_scaling=None, 37 | attention_bias=False, 38 | attention_dropout=0.0, 39 | mlp_bias=False, 40 | _attn_implementation="eager", 41 | **kwargs, 42 | ): 43 | self.vocab_size = vocab_size 44 | self.max_position_embeddings = max_position_embeddings 45 | self.hidden_size = hidden_size 46 | self.intermediate_size = intermediate_size 47 | self.num_hidden_layers = num_hidden_layers 48 | self.num_attention_heads = num_attention_heads 49 | 50 | # for backward compatibility 51 | if num_key_value_heads is None: 52 | num_key_value_heads = num_attention_heads 53 | 54 | self.num_key_value_heads 
= num_key_value_heads 55 | self.hidden_act = hidden_act 56 | self.initializer_range = initializer_range 57 | self.rms_norm_eps = rms_norm_eps 58 | self.pretraining_tp = pretraining_tp 59 | self.use_cache = use_cache 60 | self.rope_theta = rope_theta 61 | self.rope_scaling = rope_scaling 62 | self._rope_scaling_validation() 63 | self.attention_bias = attention_bias 64 | self.attention_dropout = attention_dropout 65 | self.mlp_bias = mlp_bias 66 | self._attn_implementation = _attn_implementation 67 | super().__init__( 68 | pad_token_id=pad_token_id, 69 | bos_token_id=bos_token_id, 70 | eos_token_id=eos_token_id, 71 | tie_word_embeddings=tie_word_embeddings, 72 | **kwargs, 73 | ) 74 | 75 | def _rope_scaling_validation(self): 76 | """ 77 | Validate the `rope_scaling` configuration. 78 | """ 79 | if self.rope_scaling is None: 80 | return 81 | 82 | if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: 83 | raise ValueError( 84 | "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" 85 | ) 86 | rope_scaling_type = self.rope_scaling.get("type", None) 87 | rope_scaling_factor = self.rope_scaling.get("factor", None) 88 | if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: 89 | raise ValueError( 90 | f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" 91 | ) 92 | if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: 93 | raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") 94 | -------------------------------------------------------------------------------- /sft/model/tokenization_miaomiao.py: -------------------------------------------------------------------------------- 1 | 2 | """Tokenization classes for Miaomiao.""" 3 | 4 | import json 5 | import os 6 | import unicodedata 7 | from functools import lru_cache 8 | from typing import 
Optional, Tuple 9 | 10 | import regex as re 11 | 12 | from transformers import AddedToken, PreTrainedTokenizer 13 | from transformers.utils import logging 14 | 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | VOCAB_FILES_NAMES = { 19 | "vocab_file": "vocab.json", 20 | "merges_file": "merges.txt", 21 | } 22 | 23 | 24 | MAX_MODEL_INPUT_SIZES = {"miaomiao/miaomiao-tokenizer": 1024} 25 | 26 | PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 27 | 28 | 29 | @lru_cache() 30 | # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode 31 | def bytes_to_unicode(): 32 | 33 | bs = ( 34 | list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 35 | ) 36 | cs = bs[:] 37 | n = 0 38 | for b in range(2**8): 39 | if b not in bs: 40 | bs.append(b) 41 | cs.append(2**8 + n) 42 | n += 1 43 | cs = [chr(n) for n in cs] 44 | return dict(zip(bs, cs)) 45 | 46 | 47 | # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs 48 | def get_pairs(word): 49 | """ 50 | Return set of symbol pairs in a word. 51 | 52 | Word is represented as tuple of symbols (symbols being variable-length strings). 
53 | """ 54 | pairs = set() 55 | prev_char = word[0] 56 | for char in word[1:]: 57 | pairs.add((prev_char, char)) 58 | prev_char = char 59 | return pairs 60 | 61 | 62 | class MiaomiaoTokenizer(PreTrainedTokenizer): 63 | vocab_files_names = VOCAB_FILES_NAMES 64 | model_input_names = ["input_ids", "attention_mask"] 65 | 66 | def __init__( 67 | self, 68 | vocab_file, 69 | merges_file, 70 | errors="replace", 71 | unk_token="<|endoftext|>", 72 | bos_token=None, 73 | eos_token="<|im_end|>", 74 | pad_token="<|endoftext|>", 75 | clean_up_tokenization_spaces=False, 76 | split_special_tokens=False, 77 | **kwargs, 78 | ): 79 | bos_token = ( 80 | AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False) 81 | if isinstance(bos_token, str) 82 | else bos_token 83 | ) 84 | eos_token = ( 85 | AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False) 86 | if isinstance(eos_token, str) 87 | else eos_token 88 | ) 89 | unk_token = ( 90 | AddedToken(unk_token, lstrip=False, rstrip=False, special=True, normalized=False) 91 | if isinstance(unk_token, str) 92 | else unk_token 93 | ) 94 | pad_token = ( 95 | AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False) 96 | if isinstance(pad_token, str) 97 | else pad_token 98 | ) 99 | 100 | with open(vocab_file, encoding="utf-8") as vocab_handle: 101 | self.encoder = json.load(vocab_handle) 102 | self.decoder = {v: k for k, v in self.encoder.items()} 103 | self.errors = errors # how to handle errors in decoding 104 | self.byte_encoder = bytes_to_unicode() 105 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 106 | bpe_merges = [] 107 | with open(merges_file, encoding="utf-8") as merges_handle: 108 | for i, line in enumerate(merges_handle): 109 | line = line.strip() 110 | if (i == 0 and line.startswith("#version:")) or not line: 111 | continue 112 | bpe_merges.append(tuple(line.split())) 113 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 114 
| # NOTE: the cache can grow without bound and will get really large for long-running processes 115 | # (esp. for texts in languages that do not use spaces between words, e.g. Chinese); technically 116 | # not a memory leak but appears as one. 117 | # GPT2Tokenizer has the same problem, so let's be consistent. 118 | self.cache = {} 119 | 120 | self.pat = re.compile(PRETOKENIZE_REGEX) 121 | 122 | if kwargs.get("add_prefix_space", False): 123 | logger.warning_once( 124 | f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect." 125 | ) 126 | 127 | super().__init__( 128 | errors=errors, 129 | bos_token=bos_token, 130 | eos_token=eos_token, 131 | pad_token=pad_token, 132 | unk_token=unk_token, 133 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 134 | split_special_tokens=split_special_tokens, 135 | **kwargs, 136 | ) 137 | 138 | @property 139 | def vocab_size(self) -> int: 140 | return len(self.encoder) 141 | 142 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab 143 | def get_vocab(self): 144 | return dict(self.encoder, **self.added_tokens_encoder) 145 | 146 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe 147 | @lru_cache(maxsize=100) # bound the BPE cache to 100 entries (used instead of self.cache) 148 | def bpe(self, token): 149 | # if token in self.cache: 150 | # return self.cache[token] 151 | word = tuple(token) 152 | pairs = get_pairs(word) 153 | 154 | if not pairs: 155 | return token 156 | 157 | while True: 158 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 159 | if bigram not in self.bpe_ranks: 160 | break 161 | first, second = bigram 162 | new_word = [] 163 | i = 0 164 | while i < len(word): 165 | try: 166 | j = word.index(first, i) 167 | except ValueError: 168 | new_word.extend(word[i:]) 169 | break 170 | else: 171 | new_word.extend(word[i:j]) 172 | i = j 173 | 174 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 175 | new_word.append(first + second)
176 | i += 2 177 | else: 178 | new_word.append(word[i]) 179 | i += 1 180 | new_word = tuple(new_word) 181 | word = new_word 182 | if len(word) == 1: 183 | break 184 | else: 185 | pairs = get_pairs(word) 186 | word = " ".join(word) 187 | # self.cache[token] = word 188 | return word 189 | 190 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize 191 | def _tokenize(self, text): 192 | """Tokenize a string.""" 193 | bpe_tokens = [] 194 | for token in re.findall(self.pat, text): 195 | token = "".join( 196 | self.byte_encoder[b] for b in token.encode("utf-8") 197 | ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) 198 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) 199 | return bpe_tokens 200 | 201 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id 202 | def _convert_token_to_id(self, token): 203 | """Converts a token (str) in an id using the vocab.""" 204 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 205 | 206 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token 207 | def _convert_id_to_token(self, index): 208 | """Converts an index (integer) in a token (str) using the vocab.""" 209 | return self.decoder.get(index) 210 | 211 | 212 | def convert_tokens_to_string(self, tokens): 213 | """Converts a sequence of tokens (string) in a single string.""" 214 | text = "".join(tokens) 215 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) 216 | return text 217 | 218 | def decode( 219 | self, 220 | token_ids, 221 | skip_special_tokens: bool = False, 222 | clean_up_tokenization_spaces: Optional[bool] = False, 223 | spaces_between_special_tokens: bool = False, 224 | **kwargs, 225 | ) -> str: 226 | 227 | return super().decode( 228 | token_ids, 229 | skip_special_tokens=skip_special_tokens, 230 | 
clean_up_tokenization_spaces=clean_up_tokenization_spaces, 231 | spaces_between_special_tokens=spaces_between_special_tokens, 232 | **kwargs, 233 | ) 234 | 235 | 236 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 237 | if not os.path.isdir(save_directory): 238 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 239 | return 240 | vocab_file = os.path.join( 241 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 242 | ) 243 | merge_file = os.path.join( 244 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] 245 | ) 246 | 247 | with open(vocab_file, "w", encoding="utf-8") as f: 248 | f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") 249 | 250 | index = 0 251 | with open(merge_file, "w", encoding="utf-8") as writer: 252 | writer.write("#version: 0.2\n") 253 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 254 | if index != token_index: 255 | logger.warning( 256 | f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." 257 | " Please check that the tokenizer is not corrupted!" 
258 | ) 259 | index = token_index 260 | writer.write(" ".join(bpe_tokens) + "\n") 261 | index += 1 262 | 263 | return vocab_file, merge_file 264 | 265 | def prepare_for_tokenization(self, text, **kwargs): 266 | text = unicodedata.normalize("NFC", text) 267 | return (text, kwargs) 268 | -------------------------------------------------------------------------------- /sft/model/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "auto_map": { 3 | "AutoTokenizer": [ 4 | "tokenization_miaomiao.MiaomiaoTokenizer", 5 | null 6 | ] 7 | }, 8 | "add_prefix_space": false, 9 | "added_tokens_decoder": { 10 | "32000": { 11 | "content": "system", 12 | "lstrip": false, 13 | "normalized": false, 14 | "rstrip": false, 15 | "single_word": false, 16 | "special": true 17 | }, 18 | "32001": { 19 | "content": "user", 20 | "lstrip": false, 21 | "normalized": false, 22 | "rstrip": false, 23 | "single_word": false, 24 | "special": true 25 | }, 26 | "32002": { 27 | "content": "assistant", 28 | "lstrip": false, 29 | "normalized": false, 30 | "rstrip": false, 31 | "single_word": false, 32 | "special": true 33 | }, 34 | "32003": { 35 | "content": "<|endoftext|>", 36 | "lstrip": false, 37 | "normalized": false, 38 | "rstrip": false, 39 | "single_word": false, 40 | "special": true 41 | }, 42 | "32004": { 43 | "content": "<|im_start|>", 44 | "lstrip": false, 45 | "normalized": false, 46 | "rstrip": false, 47 | "single_word": false, 48 | "special": true 49 | }, 50 | "32005": { 51 | "content": "<|im_end|>", 52 | "lstrip": false, 53 | "normalized": false, 54 | "rstrip": false, 55 | "single_word": false, 56 | "special": true 57 | } 58 | }, 59 | "additional_special_tokens": [ 60 | "<|im_start|>", 61 | "<|im_end|>" 62 | ], 63 | "bos_token": null, 64 | "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\n你是一个由喵阿姨开发的喵喵小助手<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + 
message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", 65 | "clean_up_tokenization_spaces": false, 66 | "eos_token": "<|im_end|>", 67 | "errors": "replace", 68 | "model_max_length": 32768, 69 | "pad_token": "<|endoftext|>", 70 | "split_special_tokens": false, 71 | "tokenizer_class": "MiaomiaoTokenizer", 72 | "unk_token": null 73 | } -------------------------------------------------------------------------------- /sft/sft.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | import transformers 4 | from sft_dataset import SFTDataset 5 | from transformers import ( 6 | AutoModelForCausalLM, 7 | HfArgumentParser, 8 | Trainer, 9 | TrainingArguments, 10 | AutoTokenizer, 11 | set_seed, 12 | ) 13 | from transformers.trainer_callback import TrainerCallback 14 | import torch 15 | import os 16 | import logging 17 | import glob 18 | import random 19 | import numpy as np 20 | from typing import Dict, Optional, Sequence 21 | 22 | IGNORE_INDEX = -100 23 | # Set random seeds (this local definition shadows transformers.set_seed imported above) 24 | def set_seed(seed): 25 | random.seed(seed) 26 | np.random.seed(seed) 27 | torch.manual_seed(seed) 28 | if torch.cuda.is_available(): 29 | torch.cuda.manual_seed_all(seed) 30 | 31 | 32 | class LoggingCallback(TrainerCallback): 33 | def __init__(self, logger): 34 | self.logger = logger 35 | 36 | def on_log(self, args, state, control, logs=None, **kwargs): 37 | if logs is not None: 38 | self.logger.info(logs) 39 | 40 | 41 | @dataclass 42 | class ModelArguments: 43 | model_path: Optional[str] = None 44 | torch_dtype: Optional[str] = None 45 | 46 | @dataclass 47 | class DataTrainingArguments: 48 | train_dataset_file: Optional[str] = None 49 | overwrite_cache: bool = False 50 | preprocessing_num_workers: Optional[int] = None 51 | block_size: Optional[int] = None 52 | 53 | 54 | @dataclass 55 | class
MyTrainingArguments(TrainingArguments): 56 | modules_to_save: Optional[str] = None 57 | 58 | 59 | # How the model is initialized 60 | init_from: Optional[str] = "scratch" 61 | use_device: Optional[str] = 'cuda' 62 | use_compile: Optional[bool] = False 63 | log_file: Optional[str] = None 64 | nnodes: Optional[int] = None 65 | nproc_per_node: Optional[int] = None 66 | 67 | 68 | def init_model(model_args): 69 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_path, trust_remote_code=True) 70 | model = AutoModelForCausalLM.from_pretrained( 71 | model_args.model_path, 72 | trust_remote_code=True 73 | ) 74 | return tokenizer, model 75 | 76 | 77 | 78 | @dataclass 79 | class DataCollatorForSFTDataset(object): 80 | """Collate examples for supervised fine-tuning.""" 81 | 82 | tokenizer: transformers.PreTrainedTokenizer 83 | 84 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 85 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 86 | input_ids = torch.nn.utils.rnn.pad_sequence( 87 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 88 | ) 89 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 90 | print(f"DataCollatorForSFTDataset:{input_ids}") 91 | return dict( 92 | input_ids=input_ids, 93 | labels=labels, 94 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 95 | ) 96 | 97 | 98 | def main(): 99 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, MyTrainingArguments)) 100 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 101 | 102 | # Set up the logger 103 | logging.basicConfig(filename=training_args.log_file, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 104 | logger = logging.getLogger(__name__) 105 | # Create a file handler that opens the log in write mode 106 | file_handler = logging.FileHandler(training_args.log_file, mode='w') 107 | file_handler.setLevel(logging.INFO) 108 | file_formatter =
logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 109 | file_handler.setFormatter(file_formatter) 110 | logger.addHandler(file_handler) 111 | # Also log to the console (optional) 112 | console_handler = logging.StreamHandler() 113 | console_handler.setLevel(logging.INFO) 114 | formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') 115 | console_handler.setFormatter(formatter) 116 | logger.addHandler(console_handler) 117 | 118 | set_seed(training_args.seed) 119 | 120 | tokenizer, model = init_model(model_args) 121 | model.to(training_args.use_device) 122 | 123 | if training_args.use_compile: 124 | model = torch.compile(model) 125 | 126 | 127 | total_params = sum(p.numel() for p in model.parameters()) 128 | trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 129 | logger.info(f"Total parameters: {total_params}") 130 | logger.info(f"Trainable parameters: {trainable_params}") 131 | 132 | logger.info(f"torch_dtype:{model_args.torch_dtype}") 133 | logger.info(f"training_args.bf16: {training_args.bf16}") 134 | 135 | 136 | train_ds = SFTDataset(data_path=data_args.train_dataset_file, tokenizer=tokenizer, max_length=data_args.block_size, prompt_max_len=int(data_args.block_size/2), answer_max_len=int(data_args.block_size/2), seed=training_args.seed) 137 | logger.info(f"Train dataset size: {len(train_ds)}") 138 | 139 | 140 | trainer = Trainer( 141 | model=model, 142 | args=training_args, 143 | train_dataset=train_ds, 144 | callbacks=[LoggingCallback(logger)], # add the custom logging callback 145 | ) 146 | print(training_args.bf16) 147 | 148 | trainer.train() 149 | 150 | 151 | 152 | 153 | if __name__ == "__main__": 154 | main() 155 | -------------------------------------------------------------------------------- /sft/sft.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | lr=1e-4 6 | block_size=1024 7 | 8 | per_device_train_batch_size=24 9 | gradient_accumulation_steps=1 10 | model_path=./model 11 | train_dataset_file=./sft.jsonl 12 |
log_file=./log/sft.log 13 | output_dir=./output 14 | deepspeed_config_file=./ds_config.json 15 | random_seed=42 16 | torchrun --nnodes 1 --nproc_per_node 2 sft.py \ 17 | --deepspeed ${deepspeed_config_file} \ 18 | --model_path ${model_path} \ 19 | --train_dataset_file ${train_dataset_file} \ 20 | --per_device_train_batch_size ${per_device_train_batch_size} \ 21 | --do_train \ 22 | --bf16 True \ 23 | --torch_dtype bfloat16 \ 24 | --seed ${random_seed} \ 25 | --num_train_epochs 3 \ 26 | --logging_strategy steps \ 27 | --logging_steps 100 \ 28 | --log_file ${log_file} \ 29 | --logging_first_step True \ 30 | --adam_beta1 0.9 \ 31 | --adam_beta2 0.95 \ 32 | --lr_scheduler_type cosine \ 33 | --learning_rate ${lr} \ 34 | --warmup_ratio 0.05 \ 35 | --weight_decay 0.01 \ 36 | --save_strategy epoch \ 37 | --save_total_limit 3 \ 38 | --save_steps 0.01 \ 39 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 40 | --block_size ${block_size} \ 41 | --output_dir ${output_dir} \ 42 | --overwrite_output_dir \ 43 | --ddp_timeout 30000 \ 44 | --use_device cuda \ 45 | --use_compile False -------------------------------------------------------------------------------- /sft/sft_data_filted.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | import torch 3 | import torch.nn.functional as F 4 | import json 5 | from tqdm import tqdm 6 | 7 | def calculate_perplexity(model, tokenizer, messages, device): 8 | 9 | formatted_messages = [ 10 | {"role": "user", "content": messages[0]['value']}, 11 | {"role": "assistant", "content": messages[1]['value']} 12 | ] 13 | user_input = [ 14 | {"role": "user", "content": messages[0]['value']} 15 | ] 16 | 17 | # Encode the prompt alone (user turn plus the assistant generation prompt) 18 | inputs_text = tokenizer.apply_chat_template( 19 | user_input, 20 | tokenize=False, 21 | add_generation_prompt=True 22 | ) 23 | #print(inputs_text) 24 | inputs = tokenizer(inputs_text, return_tensors="pt").to(device) 25 | 26 |
# Encode the full conversation (prompt + assistant response) 27 | full_text = tokenizer.apply_chat_template( 28 | formatted_messages, 29 | tokenize=False, 30 | add_generation_prompt=False 31 | ) 32 | #print(full_text) 33 | 34 | full_inputs = tokenizer(full_text, return_tensors="pt").to(device) 35 | 36 | # PPL(A|Q): perplexity of the assistant response conditioned on the user input 37 | with torch.no_grad(): 38 | outputs = model(**full_inputs) 39 | logits = outputs.logits 40 | 41 | # Score only the response: logits[:, t] predicts token t+1, so this slice skips the first response token (the direct computation below cannot score its first token either) 42 | start_pos = inputs.input_ids.size(1) 43 | shift_logits = logits[:, start_pos:-1, :].contiguous() 44 | shift_labels = full_inputs['input_ids'][:, start_pos+1:].contiguous() 45 | 46 | # Per-token cross-entropy loss 47 | loss_fct = torch.nn.CrossEntropyLoss(reduction='none') 48 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 49 | loss = loss.view(shift_labels.size()) 50 | perplexity_given_user = torch.exp(loss.mean()) 51 | #print(f"Conditional perplexity: {perplexity_given_user}") 52 | # PPL(A): perplexity of the assistant response on its own 53 | with torch.no_grad(): 54 | assistant_input = messages[1]["value"] 55 | assistant_inputs = tokenizer(assistant_input, return_tensors="pt").to(device) 56 | 57 | outputs = model(**assistant_inputs) 58 | logits = outputs.logits 59 | 60 | # Shift the logits and labels to ignore the first token 61 | shift_logits = logits[:, :-1, :].contiguous() 62 | shift_labels = assistant_inputs['input_ids'][:, 1:].contiguous() 63 | 64 | # Flatten the logits and labels 65 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 66 | loss = loss.view(shift_labels.size()) 67 | 68 | # Perplexity = exp(mean per-token loss) 69 | perplexity_direct = torch.exp(loss.mean()) 70 | #print(f"Direct perplexity: {perplexity_direct}") 71 | 72 | return perplexity_given_user.item(), perplexity_direct.item() 73 | 74 | def main(): 75 | model_name = "./Qwen2-0.5B-Instruct" 76 | device = "cuda" # target device 77 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto") 78 | tokenizer = AutoTokenizer.from_pretrained(model_name) 79 | input_data = './depulication_firefly.jsonl' 80 |
output_data = './depulication_firefly_ppl.jsonl' 81 | # Read all input lines 82 | with open(input_data, 'r', encoding='utf-8') as f: 83 | lines = f.readlines() 84 | 85 | with open(output_data, 'w', encoding='utf-8') as out_f: 86 | # Score each example; IFD = PPL(A|Q) / PPL(A) 87 | for line in tqdm(lines, desc="Processing"): 88 | data = json.loads(line) 89 | messages = data["messages"] 90 | perplexity_given_user, perplexity_direct = calculate_perplexity(model, tokenizer, messages, device) 91 | result = { 92 | "messages": messages, 93 | "ppl_a_q": perplexity_given_user, 94 | "ppl_a": perplexity_direct, 95 | "ifd": perplexity_given_user / perplexity_direct, 96 | } 97 | out_f.write(json.dumps(result, ensure_ascii=False) + '\n') 98 | # print(f"Perplexity given user input: {perplexity_given_user}") 99 | # print(f"Perplexity of direct assistant response: {perplexity_direct}") 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /sft/sft_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import pandas as pd 3 | import numpy as np 4 | from torch.utils.data import Dataset,DataLoader 5 | import torch 6 | from sklearn.model_selection import train_test_split 7 | import json 8 | from datasets import load_dataset,Features, Value 9 | import copy 10 | class SFTDataset(Dataset): 11 | def __init__(self, data_path, tokenizer, max_length=1024, prompt_max_len=512, answer_max_len=512, seed=42): 12 | super().__init__() 13 | IGNORE_INDEX = -100 14 | self.max_length = max_length 15 | self.prompt_max_len = prompt_max_len 16 | self.answer_max_len = answer_max_len 17 | self.tokenizer = tokenizer 18 | self.input_ids = [] 19 | self.labels = [] 20 | self.attention_mask = [] 21 | # Declare the expected JSONL fields explicitly 22 | features = Features({ 23 | 'prompt': Value('string'), 24 | 'answer': Value('string') 25 | }) 26 | sft_dataset = load_dataset('json', data_files=data_path, features=features) 27 | data = [] 28 | # Iterate over the records and build one padded, label-masked example each
29 | for example in sft_dataset['train']: 30 | prompt = example['prompt'] 31 | answer = example['answer'] 32 | messages = [ 33 | {"role": "user", "content": prompt} 34 | ] 35 | prompt_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 36 | answer_text = answer + tokenizer.eos_token 37 | 38 | prompt_id = self.tokenizer.encode(prompt_text) 39 | if (len(prompt_id) > self.prompt_max_len): 40 | prompt_id = prompt_id[:self.prompt_max_len] 41 | 42 | answer_id = tokenizer.encode(answer_text) 43 | if (len(answer_id) > self.answer_max_len): 44 | answer_id = answer_id[:self.answer_max_len] 45 | input_id = prompt_id + answer_id 46 | labels = [self.tokenizer.pad_token_id] * len(prompt_id) + answer_id # prompt positions get pad ids, mapped to IGNORE_INDEX below 47 | pad_len = self.max_length - len(input_id) 48 | input_id = input_id + [self.tokenizer.pad_token_id] * pad_len 49 | labels = labels + [self.tokenizer.pad_token_id] * pad_len 50 | labels = [(l if l != self.tokenizer.pad_token_id else IGNORE_INDEX) for l in labels] # exclude prompt and padding from the loss 51 | input_id = torch.LongTensor(input_id) 52 | labels = torch.LongTensor(labels) 53 | attention_mask = input_id.ne(self.tokenizer.pad_token_id) 54 | data.append({ 55 | "input_ids": input_id, 56 | "labels": labels, 57 | "attention_mask": attention_mask 58 | }) 59 | 60 | # Shuffle the examples reproducibly 61 | random.seed(seed) 62 | random.shuffle(data) 63 | 64 | for item in data: 65 | self.input_ids.append(item["input_ids"]) 66 | self.labels.append(item["labels"]) 67 | self.attention_mask.append(item["attention_mask"]) 68 | 69 | 70 | def __len__(self): 71 | return len(self.input_ids) 72 | 73 | def __getitem__(self, i: int): 74 | return { 75 | "input_ids": self.input_ids[i], 76 | "labels": self.labels[i], 77 | "attention_mask": self.attention_mask[i], 78 | } 79 | 80 | -------------------------------------------------------------------------------- /sft/test_sft_model.py: -------------------------------------------------------------------------------- 1 | from transformers import
AutoModelForCausalLM, AutoTokenizer 2 | 3 | 4 | device = "cuda" # the device to load the model onto 5 | 6 | model = AutoModelForCausalLM.from_pretrained( 7 | './model', 8 | torch_dtype="auto", 9 | device_map="auto", 10 | trust_remote_code=True 11 | ) 12 | tokenizer = AutoTokenizer.from_pretrained('./model', trust_remote_code=True) # tokenizer files ship in ./model 13 | 14 | prompt_list = ["你知道北京吗? ", 15 | "你知道杭州有哪些美食吗?", 16 | "你知道中国的四大名著吗?", 17 | "你了解美国的历史吗?", 18 | "左手一只鸭,右手一只鸡。交换两次后左右手里各是什么?", 19 | "鸡兔同笼,共35只头,94只脚,问鸡兔各多少?", 20 | "世界上最大的动物是什么?", 21 | "介绍一下刘德华。", 22 | "介绍一下中国。" 23 | ] 24 | for prompt in prompt_list: 25 | messages = [ 26 | {"role": "user", "content": prompt} 27 | ] 28 | text = tokenizer.apply_chat_template( 29 | messages, 30 | tokenize=False, 31 | add_generation_prompt=True 32 | ) 33 | model_inputs = tokenizer([text], return_tensors="pt").to(device) 34 | 35 | generated_ids = model.generate( 36 | **model_inputs, 37 | max_new_tokens=512, 38 | do_sample=True, 39 | temperature=0.9, 40 | top_k=30 41 | ) 42 | generated_ids = [ 43 | output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 44 | ] 45 | 46 | response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 47 | print(f"question:{prompt}") 48 | print(f"response:{response}") -------------------------------------------------------------------------------- /train_tokenizer/miaomiao_tokenizer/tokenization_miaomiao.py: -------------------------------------------------------------------------------- 1 | 2 | """Tokenization classes for Miaomiao.""" 3 | 4 | import json 5 | import os 6 | import unicodedata 7 | from functools import lru_cache 8 | from typing import Optional, Tuple 9 | 10 | import regex as re 11 | 12 | from transformers import AddedToken, PreTrainedTokenizer 13 | from transformers.utils import logging 14 | 15 | 16 | logger = logging.get_logger(__name__) 17 | 18 | VOCAB_FILES_NAMES = { 19 | "vocab_file": "vocab.json", 20 | "merges_file": "merges.txt",
21 | } 22 | 23 | 24 | MAX_MODEL_INPUT_SIZES = {"miaomiao/miaomiao-tokenizer": 1024} 25 | 26 | PRETOKENIZE_REGEX = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 27 | 28 | 29 | @lru_cache() 30 | # Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode 31 | def bytes_to_unicode(): 32 | 33 | bs = ( 34 | list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 35 | ) 36 | cs = bs[:] 37 | n = 0 38 | for b in range(2**8): 39 | if b not in bs: 40 | bs.append(b) 41 | cs.append(2**8 + n) 42 | n += 1 43 | cs = [chr(n) for n in cs] 44 | return dict(zip(bs, cs)) 45 | 46 | 47 | # Copied from transformers.models.gpt2.tokenization_gpt2.get_pairs 48 | def get_pairs(word): 49 | """ 50 | Return set of symbol pairs in a word. 51 | 52 | Word is represented as tuple of symbols (symbols being variable-length strings). 53 | """ 54 | pairs = set() 55 | prev_char = word[0] 56 | for char in word[1:]: 57 | pairs.add((prev_char, char)) 58 | prev_char = char 59 | return pairs 60 | 61 | 62 | class MiaomiaoTokenizer(PreTrainedTokenizer): 63 | vocab_files_names = VOCAB_FILES_NAMES 64 | model_input_names = ["input_ids", "attention_mask"] 65 | 66 | def __init__( 67 | self, 68 | vocab_file, 69 | merges_file, 70 | errors="replace", 71 | unk_token="<|endoftext|>", 72 | bos_token=None, 73 | eos_token="<|im_end|>", 74 | pad_token="<|endoftext|>", 75 | clean_up_tokenization_spaces=False, 76 | split_special_tokens=False, 77 | **kwargs, 78 | ): 79 | bos_token = ( 80 | AddedToken(bos_token, lstrip=False, rstrip=False, special=True, normalized=False) 81 | if isinstance(bos_token, str) 82 | else bos_token 83 | ) 84 | eos_token = ( 85 | AddedToken(eos_token, lstrip=False, rstrip=False, special=True, normalized=False) 86 | if isinstance(eos_token, str) 87 | else eos_token 88 | ) 89 | unk_token = ( 90 | AddedToken(unk_token, lstrip=False, rstrip=False, special=True, 
normalized=False) 91 | if isinstance(unk_token, str) 92 | else unk_token 93 | ) 94 | pad_token = ( 95 | AddedToken(pad_token, lstrip=False, rstrip=False, special=True, normalized=False) 96 | if isinstance(pad_token, str) 97 | else pad_token 98 | ) 99 | 100 | with open(vocab_file, encoding="utf-8") as vocab_handle: 101 | self.encoder = json.load(vocab_handle) 102 | self.decoder = {v: k for k, v in self.encoder.items()} 103 | self.errors = errors # how to handle errors in decoding 104 | self.byte_encoder = bytes_to_unicode() 105 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 106 | bpe_merges = [] 107 | with open(merges_file, encoding="utf-8") as merges_handle: 108 | for i, line in enumerate(merges_handle): 109 | line = line.strip() 110 | if (i == 0 and line.startswith("#version:")) or not line: 111 | continue 112 | bpe_merges.append(tuple(line.split())) 113 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 114 | # NOTE: upstream GPT2Tokenizer memoizes bpe() in self.cache, an unbounded dict that can grow very large in 115 | # long-running processes (esp. for languages such as Chinese that do not put spaces between words). 116 | # Here bpe() is wrapped in a bounded functools.lru_cache instead, so self.cache is currently unused 117 | # (the lookups inside bpe() are commented out). 118 | self.cache = {} 119 | 120 | self.pat = re.compile(PRETOKENIZE_REGEX) 121 | 122 | if kwargs.get("add_prefix_space", False): 123 | logger.warning_once( 124 | f"{self.__class__.__name__} does not support `add_prefix_space`, setting it to True has no effect."
125 | ) 126 | 127 | super().__init__( 128 | errors=errors, 129 | bos_token=bos_token, 130 | eos_token=eos_token, 131 | pad_token=pad_token, 132 | unk_token=unk_token, 133 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 134 | split_special_tokens=split_special_tokens, 135 | **kwargs, 136 | ) 137 | 138 | @property 139 | def vocab_size(self) -> int: 140 | return len(self.encoder) 141 | 142 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.get_vocab 143 | def get_vocab(self): 144 | return dict(self.encoder, **self.added_tokens_encoder) 145 | 146 | # Adapted from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer.bpe (self.cache replaced by lru_cache) 147 | @lru_cache(maxsize=100) # keep at most 100 memoized results 148 | def bpe(self, token): 149 | # if token in self.cache: 150 | # return self.cache[token] 151 | word = tuple(token) 152 | pairs = get_pairs(word) 153 | 154 | if not pairs: 155 | return token 156 | 157 | while True: 158 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 159 | if bigram not in self.bpe_ranks: 160 | break 161 | first, second = bigram 162 | new_word = [] 163 | i = 0 164 | while i < len(word): 165 | try: 166 | j = word.index(first, i) 167 | except ValueError: 168 | new_word.extend(word[i:]) 169 | break 170 | else: 171 | new_word.extend(word[i:j]) 172 | i = j 173 | 174 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 175 | new_word.append(first + second) 176 | i += 2 177 | else: 178 | new_word.append(word[i]) 179 | i += 1 180 | new_word = tuple(new_word) 181 | word = new_word 182 | if len(word) == 1: 183 | break 184 | else: 185 | pairs = get_pairs(word) 186 | word = " ".join(word) 187 | # self.cache[token] = word 188 | return word 189 | 190 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._tokenize 191 | def _tokenize(self, text): 192 | """Tokenize a string.""" 193 | bpe_tokens = [] 194 | for token in re.findall(self.pat, text): 195 | token = "".join( 196 | self.byte_encoder[b] for b in
token.encode("utf-8") 197 | ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case) 198 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" ")) 199 | return bpe_tokens 200 | 201 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_token_to_id 202 | def _convert_token_to_id(self, token): 203 | """Converts a token (str) in an id using the vocab.""" 204 | return self.encoder.get(token, self.encoder.get(self.unk_token)) 205 | 206 | # Copied from transformers.models.gpt2.tokenization_gpt2.GPT2Tokenizer._convert_id_to_token 207 | def _convert_id_to_token(self, index): 208 | """Converts an index (integer) in a token (str) using the vocab.""" 209 | return self.decoder.get(index) 210 | 211 | 212 | def convert_tokens_to_string(self, tokens): 213 | """Converts a sequence of tokens (string) in a single string.""" 214 | text = "".join(tokens) 215 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) 216 | return text 217 | 218 | def decode( 219 | self, 220 | token_ids, 221 | skip_special_tokens: bool = False, 222 | clean_up_tokenization_spaces: Optional[bool] = False, 223 | spaces_between_special_tokens: bool = False, 224 | **kwargs, 225 | ) -> str: 226 | 227 | return super().decode( 228 | token_ids, 229 | skip_special_tokens=skip_special_tokens, 230 | clean_up_tokenization_spaces=clean_up_tokenization_spaces, 231 | spaces_between_special_tokens=spaces_between_special_tokens, 232 | **kwargs, 233 | ) 234 | 235 | 236 | def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: 237 | if not os.path.isdir(save_directory): 238 | logger.error(f"Vocabulary path ({save_directory}) should be a directory") 239 | return 240 | vocab_file = os.path.join( 241 | save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] 242 | ) 243 | merge_file = os.path.join( 244 | save_directory, 
(filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"] 245 | ) 246 | 247 | with open(vocab_file, "w", encoding="utf-8") as f: 248 | f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n") 249 | 250 | index = 0 251 | with open(merge_file, "w", encoding="utf-8") as writer: 252 | writer.write("#version: 0.2\n") 253 | for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): 254 | if index != token_index: 255 | logger.warning( 256 | f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive." 257 | " Please check that the tokenizer is not corrupted!" 258 | ) 259 | index = token_index 260 | writer.write(" ".join(bpe_tokens) + "\n") 261 | index += 1 262 | 263 | return vocab_file, merge_file 264 | 265 | def prepare_for_tokenization(self, text, **kwargs): 266 | text = unicodedata.normalize("NFC", text) 267 | return (text, kwargs) 268 | -------------------------------------------------------------------------------- /train_tokenizer/miaomiao_tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "auto_map": { 3 | "AutoTokenizer": [ 4 | "tokenization_miaomiao.MiaomiaoTokenizer", 5 | null 6 | ] 7 | }, 8 | "add_prefix_space": false, 9 | "added_tokens_decoder": { 10 | "32000": { 11 | "content": "system", 12 | "lstrip": false, 13 | "normalized": false, 14 | "rstrip": false, 15 | "single_word": false, 16 | "special": true 17 | }, 18 | "32001": { 19 | "content": "user", 20 | "lstrip": false, 21 | "normalized": false, 22 | "rstrip": false, 23 | "single_word": false, 24 | "special": true 25 | }, 26 | "32002": { 27 | "content": "assistant", 28 | "lstrip": false, 29 | "normalized": false, 30 | "rstrip": false, 31 | "single_word": false, 32 | "special": true 33 | }, 34 | "32003": { 35 | "content": "<|endoftext|>", 36 | "lstrip": false, 37 | "normalized": false, 38 | "rstrip": false, 39 | "single_word": false, 40 |
"special": true 41 | }, 42 | "32004": { 43 | "content": "<|im_start|>", 44 | "lstrip": false, 45 | "normalized": false, 46 | "rstrip": false, 47 | "single_word": false, 48 | "special": true 49 | }, 50 | "32005": { 51 | "content": "<|im_end|>", 52 | "lstrip": false, 53 | "normalized": false, 54 | "rstrip": false, 55 | "single_word": false, 56 | "special": true 57 | } 58 | }, 59 | "additional_special_tokens": [ 60 | "<|im_start|>", 61 | "<|im_end|>" 62 | ], 63 | "bos_token": null, 64 | "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\n你是一个由喵阿姨开发的喵喵小助手<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", 65 | "clean_up_tokenization_spaces": false, 66 | "eos_token": "<|im_end|>", 67 | "errors": "replace", 68 | "model_max_length": 32768, 69 | "pad_token": "<|endoftext|>", 70 | "split_special_tokens": false, 71 | "tokenizer_class": "MiaomiaoTokenizer", 72 | "unk_token": null 73 | } -------------------------------------------------------------------------------- /train_tokenizer/train_tokenizer.py: -------------------------------------------------------------------------------- 1 | import random 2 | from tqdm import tqdm 3 | from transformers import AutoTokenizer 4 | import json 5 | from datasets import load_dataset 6 | from tokenizers import ( 7 | decoders, 8 | models, 9 | normalizers, 10 | pre_tokenizers, 11 | processors, 12 | trainers, 13 | Tokenizer, 14 | ) 15 | import os 16 | 17 | random.seed(42) 18 | 19 | def train_tokenizer(): 20 | # Yield the "text" field of each line of a JSON Lines file 21 | def read_texts_from_json(file_path): 22 | with open(file_path, 'r', encoding='utf-8') as f: 23 | for line in f: 24 | data = json.loads(line) 25 | yield data['text'] 26 | 27 | data_path = './tokenizer_data/tokenizer_data.json' 28 | 29 | # Initialize a BPE tokenizer with a byte-level pre-tokenizer 30 | tokenizer = Tokenizer(models.BPE()) 31 |
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False) 32 | 33 | # Configure the BPE trainer 34 | trainer = trainers.BpeTrainer( 35 | vocab_size=32000, 36 | show_progress=True, 37 | initial_alphabet=pre_tokenizers.ByteLevel.alphabet() 38 | ) 39 | 40 | # Stream the training text lazily from the JSONL file 41 | texts = read_texts_from_json(data_path) 42 | 43 | # Train the tokenizer 44 | tokenizer.train_from_iterator(texts, trainer=trainer) 45 | 46 | # Byte-level decoder (maps byte symbols back to text) 47 | tokenizer.decoder = decoders.ByteLevel() 48 | 49 | # Save tokenizer.json, plus vocab.json and merges.txt for MiaomiaoTokenizer 50 | tokenizer_dir = "./miaomiao_tokenizer" 51 | os.makedirs(tokenizer_dir, exist_ok=True) 52 | tokenizer.save(os.path.join(tokenizer_dir, "tokenizer.json")) 53 | tokenizer.model.save(tokenizer_dir) 54 | 55 | # Write tokenizer_config.json by hand (special tokens, chat template, tokenizer class) 56 | config = { 57 | "auto_map": { 58 | "AutoTokenizer": [ 59 | "tokenization_miaomiao.MiaomiaoTokenizer", 60 | None 61 | ] 62 | }, 63 | "add_prefix_space": False, 64 | "added_tokens_decoder": { 65 | "32000": { 66 | "content": "system", 67 | "lstrip": False, 68 | "normalized": False, 69 | "rstrip": False, 70 | "single_word": False, 71 | "special": True 72 | }, 73 | "32001": { 74 | "content": "user", 75 | "lstrip": False, 76 | "normalized": False, 77 | "rstrip": False, 78 | "single_word": False, 79 | "special": True 80 | }, 81 | "32002": { 82 | "content": "assistant", 83 | "lstrip": False, 84 | "normalized": False, 85 | "rstrip": False, 86 | "single_word": False, 87 | "special": True 88 | }, 89 | "32003": { 90 | "content": "<|endoftext|>", 91 | "lstrip": False, 92 | "normalized": False, 93 | "rstrip": False, 94 | "single_word": False, 95 | "special": True 96 | }, 97 | "32004": { 98 | "content": "<|im_start|>", 99 | "lstrip": False, 100 | "normalized": False, 101 | "rstrip": False, 102 | "single_word": False, 103 | "special": True 104 | }, 105 | "32005": { 106 | "content": "<|im_end|>", 107 | "lstrip": False, 108 | "normalized": False, 109 | "rstrip": False, 110 | "single_word": False, 111 | "special": True 112 | } 113 | }, 114 | "additional_special_tokens": ["<|im_start|>",
"<|im_end|>"], 115 | "bos_token": None, 116 | "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\n你是一个由喵阿姨开发的喵喵小助手<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", 117 | "clean_up_tokenization_spaces": False, 118 | "eos_token": "<|im_end|>", 119 | "errors": "replace", 120 | "model_max_length": 32768, 121 | "pad_token": "<|endoftext|>", 122 | "split_special_tokens": False, 123 | "tokenizer_class": "MiaomiaoTokenizer", 124 | "unk_token": None 125 | } 126 | 127 | # Save the config file 128 | with open(os.path.join(tokenizer_dir, "tokenizer_config.json"), "w", encoding="utf-8") as config_file: 129 | json.dump(config, config_file, ensure_ascii=False, indent=4) 130 | 131 | print("Tokenizer training completed and saved.") 132 | 133 | def test_tokenizer(): 134 | # Load the tokenizer.json written by train_tokenizer() 135 | tokenizer = Tokenizer.from_file("./miaomiao_tokenizer/tokenizer.json") 136 | 137 | # Encode a sample string 138 | text = "hello word.You are a helpful assistant.今天,我们来训练一个大模型<|im_end|><|endoftext|>" 139 | encoding = tokenizer.encode(text) 140 | 141 | print("Original text:", text) 142 | print("Tokens:", encoding.tokens) 143 | print("Token IDs:", encoding.ids) 144 | # Fetch the vocabulary 145 | vocab = tokenizer.get_vocab() 146 | 147 | # Look up the IDs of the special tokens that are present in the vocabulary 148 | special_tokens=["", "<|endoftext|>", "<|im_start|>", "<|im_end|>", "system", "user", "assistant"] 149 | token_ids = {token: vocab[token] for token in special_tokens if token in vocab} 150 | 151 | print("Special tokens IDs:", token_ids) 152 | eos_token_id = token_ids.get("<|im_end|>", None) 153 | print("EOS token ID:", eos_token_id) 154 | print(vocab.get('<|im_end|>')) # likely None: the special tokens are declared in tokenizer_config.json, not passed to the BPE trainer 155 | # note: tokenizers.Tokenizer has no eos_token_id attribute; load via AutoTokenizer to get one 156 | 157 | 158 | 159 | def main(): 160 | 161 | train_tokenizer() 162 | #test_tokenizer() 163 | 164 | if __name__ == '__main__': 165 | main() 166 |
--------------------------------------------------------------------------------
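The input/label construction in `SFTDataset` (sft/sft_dataset.py) can be summarized without PyTorch or a tokenizer. The sketch below is illustrative and not code from this repository: `build_sft_example` is a hypothetical helper that takes already-tokenized prompt and answer ids (the answer is assumed to already end with the EOS id, as `answer + tokenizer.eos_token` does in the dataset), truncates each part against its own budget, masks prompt and padding positions with -100 (the default `ignore_index` of PyTorch's `CrossEntropyLoss`), and right-pads to `max_length`. It also assumes that the model, like Hugging Face causal-LM heads, shifts `labels` by one position internally, so labels stay aligned with `input_ids`.

```python
from typing import List, Sequence, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def build_sft_example(
    prompt_ids: Sequence[int],
    answer_ids: Sequence[int],
    pad_id: int,
    max_length: int = 1024,
    prompt_max_len: int = 512,
    answer_max_len: int = 512,
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: build input_ids, labels and attention_mask for one pair."""
    # Truncate each part against its own budget (the answer uses answer_max_len).
    prompt = list(prompt_ids[:prompt_max_len])
    answer = list(answer_ids[:answer_max_len])

    input_ids = prompt + answer
    pad_len = max_length - len(input_ids)
    if pad_len < 0:
        raise ValueError("truncated prompt + answer is longer than max_length")

    # Only answer tokens are supervised; prompt positions are masked out.
    labels = [IGNORE_INDEX] * len(prompt) + answer
    attention_mask = [1] * len(input_ids)

    # Right-pad: padding is neither attended to nor included in the loss.
    input_ids += [pad_id] * pad_len
    labels += [IGNORE_INDEX] * pad_len
    attention_mask += [0] * pad_len
    return input_ids, labels, attention_mask


if __name__ == "__main__":
    ids, labels, mask = build_sft_example(
        [5, 6, 7], [8, 9, 2], pad_id=0, max_length=8, prompt_max_len=4, answer_max_len=4
    )
    print(ids)     # [5, 6, 7, 8, 9, 2, 0, 0]
    print(labels)  # [-100, -100, -100, 8, 9, 2, -100, -100]
    print(mask)    # [1, 1, 1, 1, 1, 1, 0, 0]
```

One design difference from the dataset class: this sketch masks by position instead of comparing ids against `pad_token_id`, so an answer token that happened to share the pad id would still be supervised. With `block_size=1024` in `sft.sh`, `sft.py` passes `prompt_max_len = answer_max_len = 512`, so the padding length is never negative there; the explicit check only matters for other budgets.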