├── .gitignore
├── LICENSE
├── README.md
├── chatglm_tokenizer
    ├── tokenization_chatglm.py
    ├── tokenizer.model
    └── tokenizer_config.json
├── data_clean
    ├── clear.py
    ├── functions.py
    └── logger.py
├── data_process.py
├── dataset.py
├── dataset_sft.py
├── eval.py
├── eval_pretrain.py
├── images
    ├── loss_tokens-v1.png
    ├── loss_tokens-v3.png
    └── loss_tokens.png
├── model.py
├── pretrain.py
├── requirements.txt
├── sft.py
└── sft_data_process.py


/.gitignore:
--------------------------------------------------------------------------------
1 | data/


--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2023 Qu Yuli
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Baby-Llama2-Chinese
  2 | Created by Limzero & Ambrose & Guolin
  3 | ## 📝 Introduction
  4 | This project aims to build a small-parameter Chinese Llama2 repository.
  5 | 
  6 | It covers the full pipeline: pre-training, SFT instruction fine-tuning, and **reward modeling plus reinforcement learning** (to-do).
  7 | 
  8 | In addition, the project is compiling a complete set of LLM learning materials (work in progress).
  9 | 
 10 | We hope this open-source project helps LLM beginners get started as quickly as possible!
 11 | 
 12 | ## 📚 Project Goals
 13 | - Collect and consolidate Chinese pre-training corpora, and train a 500M-1B parameter Llama2-Chinese pre-trained model that performs well in at least one vertical domain
 14 | - Build an LLM code repository covering the complete pipeline of pre-training, SFT instruction fine-tuning, reward modeling, and reinforcement learning, including distributed training techniques such as DeepSpeed and Megatron
 15 | - Knowledge sharing: compile a complete set of LLM learning materials
 16 | 
 17 | ## 🌟Quick Start
 18 | ```bash
 19 | # 1. Download the pre-tokenized pre-training corpus from the "Baby-llama2-chinese Corpus" Baidu Netdisk link (download as needed; 63.4B tokens in total, about 118 GB of files)
 20 | # 2. Put the downloaded data under ./data/
 21 | # 3. Adjust data_path_list in data_process.py according to the corpora you downloaded
 22 | # 4. Run data_process.py to generate pretrain_data.bin under ./data/
 23 | python data_process.py
 24 | # 5. Adjust the model hyperparameters in pretrain.py (max_seq_len, dim, n_layers, n_heads) to fit your compute; if you run out of GPU memory, reduce batch_size
 25 | # 6. Pre-training with pretrain.py -- the example below uses 4*3090
 26 | screen -S ambrose    # (create a new screen session named ambrose)
 27 | screen -r ambrose    # (attach to the screen session named ambrose)
 28 | torchrun --standalone --nproc_per_node=4 pretrain.py
 29 | # 7. When it finishes, the pre-trained model is saved under out/pretrain
 30 | # 8. Process the alpaca-zh and bell SFT corpora (extend sft_data_process.py yourself to add new SFT corpora), then run it
 31 | python sft_data_process.py
 32 | # 9. When it finishes, sft_data.csv is generated under ./sft_data
 33 | # 10. SFT fine-tuning
 34 | python sft.py
 35 | # 11. When it finishes, the SFT model is saved under the 'out/sft' folder
 36 | 
 37 | # 12. To test the trained SFT model, run eval.py (you can customize the questions)
 38 | python eval.py
 39 | ```
 40 | 
 41 | ## 📢 Updates
 42 | - 2024-01-24: Added two new models trained on an 8.4B-token pre-training corpus, Llama2-Chinese-92M-v1-smallvocab and Llama2-Chinese-218M-v1, to compare against Llama2-Chinese-92M-v1 and analyze how model size and vocabulary size affect pre-training quality!
 43 | - 2024-02-29: Added Llama2-Chinese-218M-v3, trained on a 63.4B-token pre-training corpus, and, using it as the base model, fine-tuned it on medical-domain SFT data to obtain Llama2-Chinese-218M-v3-MedicalChat
 44 | - 2024-05-21: Added data-cleaning code, including short-text filtering, MinHash (and SimHash) deduplication, storage-format conversion, and multi-dataset merging. The code lives in the data_clean directory; the baidubaike data was cleaned as an example, and the results are described in the "Pre-training corpus preprocessing" section below.
 45 | 
 46 | ## 🤖 Pre-training
 47 | A good pre-trained base model must be able to **continue text** (next-token continuation).
 48 | 1. **Tokenizer**: There are two ways to obtain a tokenizer for an LLM: build your own vocabulary and train a tokenizer from scratch ([custom tokenizers](https://github.com/karpathy/llama2.c)), or reuse a tokenizer released with an open-source model, such as ChatGLM2-6B or Llama2.
 49 | 
 50 |    The official Llama vocabulary contains only about 700 Chinese tokens, which is why Llama's Chinese ability is close to nonexistent. For convenience, this project therefore uses the [ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) tokenizer, whose vocabulary size is 64793. Notably, this is a convenient number: it fits within the uint16 range (unsigned integers 0~65535), so every token can be stored in just two bytes. For a large corpus, this halves the storage compared with the commonly used int32.
 51 | 
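   For illustration, here is a minimal sketch (not part of the repository; the token ids below are made up) of why uint16 is enough for this vocabulary and what it saves compared with int32:

   ```python
   import numpy as np

   # The ChatGLM2-6B vocabulary has 64793 entries, and 64792 <= 65535,
   # so every token id fits into an unsigned 16-bit integer.
   VOCAB_SIZE = 64793
   assert VOCAB_SIZE - 1 <= np.iinfo(np.uint16).max

   token_ids = [30910, 55422, 35543, 64790]          # hypothetical token ids
   ids_u16 = np.array(token_ids, dtype=np.uint16)    # 2 bytes per token
   ids_i32 = np.array(token_ids, dtype=np.int32)     # 4 bytes per token
   print(ids_u16.nbytes, ids_i32.nbytes)             # 8 vs 16: half the storage
   ```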
 52 | 2. **Corpus for pre-training**: Since the LLM wave began, more and more open-source Chinese pre-training corpora have become available. Standing on the shoulders of others, this project collected and processed the following classic datasets:
 53 |       
 54 |    | Chinese pre-training corpus                                                                                                                                                                                                    | Description                                                   |
 55 |    |----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------|
 56 |    | Chinese Wikipedia: [wikipedia-cn-20230720-filtered](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)                                                                                              | Data from the Chinese Wikipedia                                |
 57 |    | BaiduBaiKe: [Baidu Netdisk](https://pan.baidu.com/s/1jIpCHnWLTNYabftavo3DVw?pwd=bwvb) extraction code: bwvb                                                                                                                                      | Data from Chinese Baidu Baike                                  |
 58 |    | C4_zh: [Baidu Netdisk part1](https://pan.baidu.com/s/18O2Tj_PPB718K8gnaWrWUQ) extraction code: zv4r; [Baidu Netdisk part2](https://pan.baidu.com/s/11PTgtUfFXvpNkOige9Iw4w) extraction code: sb83; [Baidu Netdisk part3](https://pan.baidu.com/s/1248QfTS8QHPojYW-0fd5jQ) extraction code: l89d | C4 is one of the largest available language datasets, with more than 156B tokens collected from over 365 million internet domains. C4_zh is a subset of it |
 59 |    | WuDaoCorpora: [BAAI WuDaoCorpora Text pre-training dataset](https://data.baai.ac.cn/details/WuDaoCorporaText)                                                                                                                       | 200 GB of open-source Chinese data from WuDao                  |
 60 |    | shibing624/medical: [shibing624/medical](https://huggingface.co/datasets/shibing624/medical/tree/main)                                                                                                          | The medical-domain pre-training portion of shibing624/medical  |
 61 | 
 62 |    To save everyone the time of preprocessing, this project also releases the corpus already tokenized with the ChatGLM2-6B tokenizer, **63.4B tokens** in total: [Baby-llama2-chinese Corpus](https://pan.baidu.com/s/18o4gF-G68qfgOGWQXgAg3g) extraction code: 6unr. Put the downloaded data under ./data.
 63 |    
 64 |    [Given the limits of the author's hardware (4x 3090), a 63.4B-token corpus plus a ~300M-parameter model is already the upper bound of what can be pre-trained here. Note: no distributed training frameworks such as DeepSpeed or Megatron were used.]
 65 | 
 66 | ### Pre-training corpus preprocessing
 67 | 1. **Data cleaning**: Large-scale, high-quality corpora are the key "nourishment" for training large language models. They provide world knowledge, improve a model's understanding and generation quality, and support a wide range of application scenarios. In practice, high-quality text has a major impact on LLM training and downstream capability.
 68 |    
 69 |    ### Low-quality text filtering
 70 |    **To be added...**
 71 |    
 72 |    ### Text deduplication
 73 |    ```bash
 74 |    # The script wraps short-text filtering, MinHash (and SimHash) deduplication, storage-format conversion, multi-dataset merging, etc. into separate functions; call the ones you need. Note: adjust the data paths inside the functions to your local paths.
 75 |    cd data_clean
 76 |    python clear.py
 77 |    # Using the baidubaike data as an example, the run produces baike.parquet and all_no_dulpticates.parquet (all_no_dulpticates_simhash.parquet)
 78 |    ```
 79 |    Notes on the data-cleaning experiments (test machine CPU: Intel(R) Xeon(R) Platinum 8168 CPU @ 2.70GHz):
 80 |    
 81 |    (Before deduplication the baidubaike data has 5,634,898 rows in total)
 82 | 
 83 |    | Function                                | Purpose                                          | Rows before/after  | Time Consuming |
 84 |    |-----------------------------------------|--------------------------------------------------|--------------------|----------------|
 85 |    | process_baike()                         | Short-text filtering + storage-format conversion | 5634898 to 3605212 | 552.046 s      |
 86 |    | remove_dataset_duplicate_rows()         | MinHash deduplication                            | 3605212 to 2736033 | 4 h            |
 87 |    | remove_dataset_duplicate_rows_simhash() | SimHash deduplication                            | 3605212 to 3548779 | 23 min         |
 88 | 
 89 |    - Parquet is recommended for storage; it greatly reduces disk usage
 90 |    - MinHash deduplication is recommended: it works better than SimHash, but takes much longer! A stripped-down sketch of the idea is shown below.
 91 | 
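   For reference, the MinHash deduplication in data_clean/functions.py builds on the `datasketch` package; a self-contained sketch of the idea (toy documents and an illustrative threshold, not the project's exact code) looks like this:

   ```python
   from datasketch import MinHash, MinHashLSH

   def min_hash(doc: str, num_perm: int = 256) -> MinHash:
       # Hash every character of the document into a MinHash signature
       m = MinHash(num_perm=num_perm)
       for ch in doc:
           m.update(ch.encode('utf-8'))
       return m

   docs = ['苹果是一种常见的水果,含有维生素', '苹果是一种常见的水果,含有维生素', '今天天气不错']
   lsh = MinHashLSH(threshold=0.85, num_perm=256)   # approximate Jaccard similarity threshold

   duplicates = set()
   for idx, doc in enumerate(docs):
       sig = min_hash(doc)
       if lsh.query(sig):        # a near-identical document was already indexed
           duplicates.add(idx)
       lsh.insert(str(idx), sig)

   print(duplicates)             # row indexes to drop; here {1}
   ```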
 92 | 2. **Tokenizing the data**: Preprocessing follows the usual GPT recipe: tokenize the corpus ahead of time, append an end-of-sequence token `<eos>` to each sample to separate it from the next one, then concatenate all training data into a single array (np.uint16) and store it on disk as a .bin binary file. If the corpus is too large to fit in memory, use mmap to avoid out-of-memory errors.
 93 |    ```bash
 94 |    # Each function in the script preprocesses one corpus; extend it yourself to add new corpora.
 95 |    python data_process.py
 96 |    # When it finishes, pretrain_data.bin is generated under ./data
 97 |    ```
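   As a rough illustration of what this step does, here is a minimal sketch under the assumption that samples are joined with the tokenizer's `<eos>` id; the authoritative logic lives in data_process.py:

   ```python
   import numpy as np
   from chatglm_tokenizer.tokenization_chatglm import ChatGLMTokenizer

   tokenizer = ChatGLMTokenizer(vocab_file='./chatglm_tokenizer/tokenizer.model')
   eos_id = tokenizer.special_tokens['<eos>']

   texts = ['今天天气不错。', '床前明月光,疑是地上霜。']   # toy corpus

   all_ids = []
   for text in texts:
       ids = tokenizer.encode(text, add_special_tokens=False)
       ids.append(eos_id)               # <eos> separates one sample from the next
       all_ids += ids

   # Concatenate everything into a single uint16 array and dump it as a .bin file
   arr = np.array(all_ids, dtype=np.uint16)
   with open('./data/pretrain_data.bin', 'wb') as f:
       f.write(arr.tobytes())

   # For corpora larger than RAM, read the .bin back lazily with a memory map
   data = np.memmap('./data/pretrain_data.bin', dtype=np.uint16, mode='r')
   ```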
 98 | ### Pre-training
 99 | ```bash
100 | # Pre-training runs for a very long time, so it should be run in the background; one common way is screen:
101 | screen -S ambrose    # (create a new screen session named ambrose)
102 | screen -r ambrose    # (attach to the screen session named ambrose)
103 | # Run the pre-training code inside that screen; with four GPUs, set nproc_per_node to 4
104 | torchrun --standalone --nproc_per_node=4 pretrain.py
105 | # When it finishes, the pre-trained model is saved under the 'out/pretrain' folder
106 | ```
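For reference, a generic sketch of how a script launched with `torchrun` typically picks up its distributed context (torchrun populates the RANK / LOCAL_RANK / WORLD_SIZE environment variables; pretrain.py's actual training loop contains the full details):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp() -> int:
    # torchrun sets these environment variables for every worker process
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    return local_rank

# Typical usage inside the training script:
# local_rank = setup_ddp()
# model = DDP(model.to(local_rank), device_ids=[local_rank])
```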
107 |    
108 | ## 💡 SFT Instruction Fine-tuning
109 | The purpose of LLM fine-tuning is to draw out the knowledge already stored in the pre-trained model; informally, it teaches the model to talk like a human assistant.
110 | 1. **Fine-tuning methods**: NLP currently follows an important paradigm: large-scale pre-training on general-domain data, followed by adaptation to a specific task or domain. To make a pre-trained model perform well on a specific task or domain, it therefore needs to be fine-tuned. The four mainstream approaches are:
111 | 
112 |    ### LLM fine-tuning methods
113 |    - **Full Fine-tuning**: update all LLM parameters on task-specific data.
114 |    - **Parameter-Efficient Fine-tuning (PEFT)**: update only a selected subset of parameters for more efficient adaptation, e.g. LoRA, Adapter, Prefix-tuning, P-tuning, and P-tuning v2.
115 |    - **Prompt Engineering**: improve the model input to steer the model toward the desired output.
116 |    - **Retrieval Augmented Generation (RAG)**: combine prompt engineering with database retrieval to provide rich context for the answer.
117 | 
118 |    Full Fine-tuning and Parameter-Efficient Fine-tuning adjust (all or part of) the model parameters on task- or domain-specific data;
119 |    Prompt Engineering and Retrieval Augmented Generation instead design input templates that steer the model toward the desired output, without touching the model parameters. RAG supplies domain knowledge to the model through an external database.
120 | 
121 |    Because this project's models are small (only about 218M parameters, on the same order as bert-large at 340M), Full Fine-tuning is used to adapt to specific tasks or domains. Other fine-tuning methods will be added once larger pre-trained models are available.
122 | 2. **SFT data**: Adapting LLMs to vertical domains was the dominant theme of 2023, so SFT corpora and fine-tuned models for every domain keep appearing. There is already a well-maintained, continuously updated list of the [latest progress](https://github.com/HqWu-HITCS/Awesome-Chinese-LLM) that you can consult.
123 |    
124 |    This project fine-tunes on two kinds of SFT corpora:
125 |       
126 |    **General-purpose chat SFT data**:
127 | 
128 |    | SFT corpus                                                                  | Description                                                         |
129 |    |-----------------------------------------------------------------------------|---------------------------------------------------------------------|
130 |    | alpaca-zh: [alpaca-zh](https://huggingface.co/datasets/shibing624/alpaca-zh) | A portion of shibing624's SFT data: self-instruct data generated with GPT-4 following the Alpaca recipe, about 50k examples. |
131 |    | bell: [bell](https://huggingface.co/datasets/BelleGroup/train_1M_CN)         | A portion of BelleGroup's SFT data: about 1M Chinese instruction examples generated by the BELLE project. |
132 | 
133 |    **Medical-domain SFT data**:
134 |          
135 |    | SFT corpus                                                                                                                 | Description                                                                                                                 |
136 |    |--------------------------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------|
137 |    | shibing624/medical: [shibing624/medical](https://huggingface.co/datasets/shibing624/medical/tree/main)        | From shibing624. Besides the pre-training portion mentioned above, this dataset also contains SFT data.                    |
138 |    | HuatuoGPT-sft-data-v1: [HuatuoGPT-sft-data-v1](https://huggingface.co/datasets/FreedomIntelligence/HuatuoGPT-sft-data-v1) | SFT data from HuatuoGPT                                                                                                     |
139 |    | DISC-Med-SFT: [DISC-Med-SFT](https://huggingface.co/datasets/Flmc/DISC-Med-SFT) | A subset of the DISC-Med-SFT Dataset                                                                                        |
140 |    | ChatMed_Consult-v0.3: [michaelwzhu/ChatMed_Consult-v0.3](https://huggingface.co/datasets/michaelwzhu/ChatMed_Consult_Dataset) | The queries (prompts) in ChatMed-Dataset come from online medical consultation questions (549,326 of them), reflecting the real-world medical needs of different users/patients. The responses are currently generated by the OpenAI GPT-3.5 engine. |
141 | 
142 |    ### Building SFT samples
143 |    SFT corpora are usually small, so there is no need to pre-tokenize them; tokenization and batch construction happen inside the Dataloader. See dataset_sft.py for the details.
144 |    
145 |    The basic logic is:
146 |    - The prompt and the answer must be separated by a begin-of-sequence token `<bos>`, and the answer must be followed by an end-of-sequence token `<eos>`.
147 |    - When computing the loss, mask out the prompt part and compute the loss only on the answer part.
148 |    
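   A minimal sketch of that sample construction (illustrative only; the authoritative version lives in dataset_sft.py):

   ```python
   import numpy as np
   from chatglm_tokenizer.tokenization_chatglm import ChatGLMTokenizer

   tokenizer = ChatGLMTokenizer(vocab_file='./chatglm_tokenizer/tokenizer.model')
   bos_id = tokenizer.special_tokens['<bos>']
   eos_id = tokenizer.special_tokens['<eos>']

   prompt_ids = tokenizer.encode('世界上最大的动物是什么?', add_special_tokens=False)
   answer_ids = tokenizer.encode('世界上最大的动物是蓝鲸。', add_special_tokens=False)

   # prompt <bos> answer <eos>
   input_ids = prompt_ids + [bos_id] + answer_ids + [eos_id]

   # Mask the prompt (and the <bos>) so the loss is computed on the answer only
   loss_mask = [0] * (len(prompt_ids) + 1) + [1] * (len(answer_ids) + 1)

   # One common way to build next-token-prediction inputs/targets from this
   X = np.array(input_ids[:-1], dtype=np.int64)     # model input
   Y = np.array(input_ids[1:], dtype=np.int64)      # shifted targets
   mask = np.array(loss_mask[1:], dtype=np.int64)   # aligned with Y
   ```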
149 |    ```bash
150 |    # The script processes the alpaca-zh and bell SFT corpora; extend it yourself to add new SFT corpora.
151 |    python sft_data_process.py
152 |    # When it finishes, sft_data.csv is generated under ./sft_data
153 |    ```
154 |    ### Full Fine-tuning
155 |    ```bash
156 |    # Fine-tuning is usually fast; if you still want to run it in the background, screen works:
157 |    screen -S ambrose    # (create a new screen session named ambrose)
158 |    screen -r ambrose    # (attach to the screen session named ambrose)
159 |    # Run the fine-tuning code inside that screen
160 |    python sft.py
161 |    # When it finishes, the SFT model is saved under the 'out/sft' folder
162 |    ```
163 | 
164 | ## 🥇 Model Weights and Evaluation
165 | 1. **Pre-trained models**
166 |    
167 |    | Model name                                                  | Pre-training corpus                                          | 🤗 Model parameters                                       | Download                                                        |
168 |    |-------------------------------------------------------------|------------------------------------------------------------|---------------------------------------------------------|-----------------------------------------------------------------|
169 |    | Llama2-Chinese-92M-v1                                       | (8.278B Tokens)<br/>Chinese Wikipedia<br/>+BaiduBaiKe<br/>+shibing624/medical     | max_seq_len=512<br/>dim=512<br/>n_layers=8<br/>n_heads=8    | [Download](https://pan.baidu.com/s/14hwHVvv_5YrIrJg2NWI62g) extraction code: da7h |
170 |    | Llama2-Chinese-92M-v2                                       | (14B Tokens)<br/>Chinese Wikipedia<br/>+BaiduBaiKe<br/>+shibing624/medical<br/>+C4_zh | max_seq_len=512<br/>dim=512<br/>n_layers=8<br/>n_heads=8    | [Download](https://pan.baidu.com/s/1slimqUbDsnChqFY3CsybVw) extraction code: bjal |
171 |    | Llama2-Chinese-92M-v1-smallvocab<br/>**Note: vocab size 21131** | (8.278B Tokens)<br/>Chinese Wikipedia<br/>+BaiduBaiKe<br/>+shibing624/medical     | max_seq_len=512<br/>dim=512<br/>n_layers=8<br/>n_heads=8    | [Download](https://pan.baidu.com/s/1bKtAo8MBlDur6JIDW5cSYg) extraction code: ttst |
172 |    | Llama2-Chinese-218M-v1                                      | (8.278B Tokens)<br/>Chinese Wikipedia<br/>+BaiduBaiKe<br/>+shibing624/medical     | max_seq_len=1024<br/>dim=1024<br/>n_layers=12<br/>n_heads=8 | [Download](https://pan.baidu.com/s/1wLVGFbT4OF4LG2E8Ymf6VA) extraction code: c10m |
173 |    | Llama2-Chinese-218M-v2                                      | (14B Tokens)<br/>Chinese Wikipedia<br/>+BaiduBaiKe<br/>+shibing624/medical<br/>+C4_zh | max_seq_len=1024<br/>dim=1024<br/>n_layers=12<br/>n_heads=8 | [Download](https://pan.baidu.com/s/1cud_kEyRpXLR74DTRvqjGQ) extraction code: dkne | 
174 |    | Llama2-Chinese-218M-v3                                      | (63.4B Tokens)<br/>Chinese Wikipedia<br/>+BaiduBaiKe<br/>+shibing624/medical<br/>+C4_zh<br/>+WuDaoCorpora | max_seq_len=1024<br/>dim=1024<br/>n_layers=12<br/>n_heads=8 | [Download](https://pan.baidu.com/s/1-Zd0uBUY23lCvr5s_QSwwg) extraction code: tpyy  |
175 | 
176 |    ### Comparison of the pre-trained models
177 |    Pre-training loss curves:
178 |    
179 |    Pre-training corpus v1: (8.278B Tokens) Chinese Wikipedia + BaiduBaiKe + shibing624/medical
180 |    
181 |    Models compared: **Llama2-Chinese-92M-v1** vs **Llama2-Chinese-92M-v1-smallvocab** vs **Llama2-Chinese-218M-v1**
182 |    ![loss_tokens-v1.png](images/loss_tokens-v1.png)
183 |    Pre-training corpus v2: (14B Tokens) Chinese Wikipedia + BaiduBaiKe + shibing624/medical + C4_zh
184 |    
185 |    Models compared: **Llama2-Chinese-92M-v2** vs **Llama2-Chinese-218M-v2**
186 |    ![loss_tokens.png](images/loss_tokens.png)
187 |    Pre-training corpus v3: (63.4B Tokens) Chinese Wikipedia + BaiduBaiKe + shibing624/medical + C4_zh + WuDaoCorpora
188 |    
189 |    Model shown: **Llama2-Chinese-218M-v3**
190 |    ![loss_tokens-v3.png](images/loss_tokens-v3.png)
191 |    The pre-trained models are compared mainly on their **continuation** ability:
192 |    ```bash
193 |    # Inference with a pre-trained model: eval_pretrain.py
194 |    python eval_pretrain.py
195 |    ```
196 |    
197 |    ```bash
198 |    # Example 1: Input: ‘床前明月光,疑是地上霜。举头望明月,’
199 |    Llama2-Chinese-92M-v1 response:‘倚阑干照斜阳。● 早春何处无天,谁将归家归。谁愿去依花向,谁来教谁回。● 玉泉何处无玉泉,碧水如莲开处。何处何许何许人,何许何何何许地。● 青山里,春色深,月色寒。月影轻,春草凉。月来斜阳,雨声稀’
200 |    Llama2-Chinese-92M-v2 response:‘却见天涯何处无。3、 想问你为什么对我这么好?1、 你是否在问自己,我为什么一直爱你?2、我对你有多好?3、 我是谁? 4、 我是谁? 5、 为什么我从来都不像你。’
201 |    Llama2-Chinese-218M-v1 response:‘低头思故乡。2. 泪眼相对愁成丝,眉梢泪滴泪自流。3. 斜阳外,落花深似海。4. 泪眼横飞心易碎, 千言难答是何许? 梦不成, 泪洒江天。5. 泪眼空, 伤心无处诉, 思君如忆君,泪自流。6. 泪雨声、’
202 |    Llama2-Chinese-218M-v2 response:‘低头思故乡。遥知兄弟登高处,遍插茱萸少一人。注释1、明月光:月光是秋天的明丽之色。清秋:秋季。2、知月与明影:见 知月与明影,比喻知识 知识 知识 知识 知识 知识 知识 知识 知识 秋之明影。唐·韩愈《祭李太白》诗:“一树秋光满 关’
203 |    Llama2-Chinese-218M-v3 response:‘低头思故乡。人生如风,风雨来时裹住风,风雨后时披着雨。世事茫茫难料,世间万物皆随波逐流。人生如梦,醒与醉随人自个儿。人生如风,风起时能吹动树枝摇摆不定;人生如梦,梦醒时却依然在梦乡挣扎。人生如风,风狂时能翻滚浪涛,风雨中仍能’
204 |    
205 |    # Example 2: Input: ‘《小王子》是一本畅销童话书,它讲述了:’
206 |    Llama2-Chinese-92M-v1 response:‘小王子和他的公主的故事,讲述的是一个善良、勇敢、智慧、善良的小男孩,他们之间的爱情、友情,他们之间的亲情,他们的爱情、友情,他们之间的情感,他们之间的故事,最后的故事……《小王子》以漫画的形式把小王子和善良可爱的小公主的故事呈现给读者,相信大家一定可以从中领略到小王子和公主的故事带来的乐趣,从中体会到小王子的人生真谛。。内容简介:《小王子:小王子》故事 ’
207 |    Llama2-Chinese-92M-v2 response:‘我有一座神奇的城堡。我有一个神奇的故事;我有一个美丽的童话;我有一个可爱的童话;我有一个神奇的童话。“书虫 好书 书品 美书” 内容简介: 当阅读完了“书虫,好书,书品美书”,你一定感到你读出来的书比读一本书的好;当你看完了“书虫,好书” 最后,你会发现,书的世界竟是这么地广阔……本书中的所有童话童话,都是这样一’
208 |    Llama2-Chinese-218M-v1 response:‘小王子和他的伙伴——小丑们要和人类结缘。在他们的帮助下,小王子又遇到了各种怪事,小丑们又碰来了各种奇遇,他不仅救出了人类,还帮助了人类,还救出许多小王子。小王子和孩子们一起,和孩子们一起成长,一起冒险……书中的故事内容丰富,充满了乐趣……孩子们在阅读的过程中,将会逐渐熟悉并喜爱各种神奇的东西,同时也发现许多奇妙的趣事。’
209 |    Llama2-Chinese-218M-v2 response:‘小王子和他的伙伴们在森林里玩耍。他们遇见了很多小动物,可是它们却都讨厌他们。小王子想让小动物们开心,所以要去找小动物们开心……本书故事情节生动,内容丰富多彩,非常适合小朋友阅读。《小王子》的插画很特别,除了小王子外,还有灰姑娘、爱莎、小矮人、灰姑娘、丑小鸭……’
210 |    Llama2-Chinese-218M-v3 response:‘小王子去森林探望父亲。森林中发生了很多有趣的事情,可是小王子却非常不高兴,因为小树不见了,小树被埋了地底下。小树死了,小王子非常伤心。小王子想出许多办法来把树救出来,可是树好像一点也没死掉,它又跑到森林中央去。小树被埋在沙堆里了,可是小树并没有死,小王子觉得小树好像很关心他们,便’
211 |    ```
212 | 2. **Fine-tuned models**
213 |    
214 |    | Model name                         | SFT corpus                                                                                | 🤗 Model parameters                             | Download                                                        |
215 |    |------------------------------------|-------------------------------------------------------------------------------------------|-------------------------------------------------|-----------------------------------------------------------------|
216 |    | Llama2-Chinese-92M-v1-NormalChat   | alpaca-zh+bell                                                                            | max_seq_len=512<br/>dim=512<br/>n_layers=8<br/>n_heads=8    | [Download](https://pan.baidu.com/s/14hwHVvv_5YrIrJg2NWI62g) extraction code: da7h |
217 |    | Llama2-Chinese-92M-v1-MedicalChat  | shibing624/medical<br/>+HuatuoGPT-sft-data-v1<br/>+DISC-Med-SFT<br/>+ChatMed_Consult-v0.3 | max_seq_len=512<br/>dim=512<br/>n_layers=8<br/>n_heads=8    | [Download](https://pan.baidu.com/s/14hwHVvv_5YrIrJg2NWI62g) extraction code: da7h |
218 |    | Llama2-Chinese-92M-v2-NormalChat   | alpaca-zh+bell                                                                            | max_seq_len=512<br/>dim=512<br/>n_layers=8<br/>n_heads=8    | [Download](https://pan.baidu.com/s/1slimqUbDsnChqFY3CsybVw) extraction code: bjal |
219 |    | Llama2-Chinese-92M-v2-MedicalChat  | shibing624/medical<br/>+HuatuoGPT-sft-data-v1<br/>+DISC-Med-SFT<br/>+ChatMed_Consult-v0.3 | max_seq_len=512<br/>dim=512<br/>n_layers=8<br/>n_heads=8    | Training in progress!                                           |
220 |    | Llama2-Chinese-218M-v1-NormalChat  | alpaca-zh+bell                                                                            | max_seq_len=1024<br/>dim=1024<br/>n_layers=12<br/>n_heads=8 | Training in progress!                                           |
221 |    | Llama2-Chinese-218M-v1-MedicalChat | shibing624/medical<br/>+HuatuoGPT-sft-data-v1<br/>+DISC-Med-SFT<br/>+ChatMed_Consult-v0.3 | max_seq_len=1024<br/>dim=1024<br/>n_layers=12<br/>n_heads=8 | Training in progress!                                           |
222 |    | Llama2-Chinese-218M-v2-NormalChat  | alpaca-zh+bell                                                                            | max_seq_len=1024<br/>dim=1024<br/>n_layers=12<br/>n_heads=8 | [Download](https://pan.baidu.com/s/1cud_kEyRpXLR74DTRvqjGQ) extraction code: dkne |
223 |    | Llama2-Chinese-218M-v2-MedicalChat | shibing624/medical<br/>+HuatuoGPT-sft-data-v1<br/>+DISC-Med-SFT<br/>+ChatMed_Consult-v0.3 | max_seq_len=1024<br/>dim=1024<br/>n_layers=12<br/>n_heads=8 | Training in progress!                                           |
224 |    | Llama2-Chinese-218M-v3-NormalChat  | alpaca-zh+bell                                                                            | max_seq_len=1024<br/>dim=1024<br/>n_layers=12<br/>n_heads=8 | Training in progress!                                           |
225 |    | Llama2-Chinese-218M-v3-MedicalChat | shibing624/medical<br/>+HuatuoGPT-sft-data-v1<br/>+DISC-Med-SFT<br/>+ChatMed_Consult-v0.3 | max_seq_len=1024<br/>dim=1024<br/>n_layers=12<br/>n_heads=8 | [Download](https://pan.baidu.com/s/1-Zd0uBUY23lCvr5s_QSwwg) extraction code: tpyy |
226 |    The SFT models have not yet been evaluated on standard benchmarks; only the following examples are given:
227 |    ```bash
228 |    # Inference with an SFT model: eval.py
229 |    python eval.py
230 |    ```
231 |    ```bash
232 |    # Example 1: Input: ‘最近我在办公室坐久了会感到头晕,请问这是什么原因?有什么缓解办法吗?’ -- a medical question
233 |    Llama2-Chinese-92M-v1-NormalChat response:‘是的,由于我的工作忙,我会很可能会坐在办公室里坐。你可以尝试保持安静并遵守交通规则,尝试在办公室里或周末去散步或散步,这有助于缓解紧张的语气。’
234 |    Llama2-Chinese-92M-v1-MedicalChat response:‘您的症状可能是由多种原因引起的,例如低血压、脑部疾病、颈椎病、贫血、贫血、低血糖、颈椎病等。以下是一些缓解这种情况的方法:1. 休息:在长时间使用电脑或手机前,放松身体,放松肌肉,缓解晕厥。2. 按摩:轻轻按摩头部或颈部,缓解头晕。3. 热敷:使用热毛巾或热水袋敷在头部或颈部上,缓解’
235 |    Llama2-Chinese-92M-v2-NormalChat response:‘我无法确定您所在的位置。您可以在公司办公室的电脑上查询我查询的信息。不过,如果您访问了该公司的网站或应用,您可以访问该公司的客服部门,查看最近的办公室的列表,看看是否有可能是无法进入该公司的办公室。通常办公室可能是安静的,没有噪音,而且可能不会发出任何声音。因此,我建议您与您公司客服联系以获取相关信息并咨询您所在公司的办公室。’
236 |    Llama2-Chinese-92M-v2-MedicalChat response:‘’
237 |    Llama2-Chinese-218M-v1-NormalChat response:‘’
238 |    Llama2-Chinese-218M-v1-MedicalChat response:‘’
239 |    Llama2-Chinese-218M-v2-NormalChat response:‘我无法确定您所描述的问题,因此无法回答。是一种可能的诊断是偏头痛。这通常与气候变化、心理压力或饮食因素有关。 它可能会导致以下症状:1. 热或热配菜。2. 热食物或饮料。3. 食物中添加的盐或糖。’
240 |    Llama2-Chinese-218M-v2-MedicalChat response:‘’
241 |    Llama2-Chinese-218M-v3-NormalChat response:‘’
242 |    Llama2-Chinese-218M-v3-MedicalChat response:‘长时间保持同一姿势会导致血液循环不畅,影响血使用的食用,从而导致头晕。您可以尝试以下方法缓解:1. 调整坐姿或站立的姿势,让身体的休息时间;2. 经常活动身体,如走动、做一些简单的伸展运动;3. 保持良好的睡眠质量,避免在睡前过度使用电子设备;4. 饮食要均衡,避免吃太过于油腻或刺激性的食物。’
243 |    
244 |    # Example 2: Input: ‘前列腺囊肿的症状是什么?’ -- a medical question
245 |    Llama2-Chinese-92M-v1-NormalChat response:‘人类大脑中的神经控制系统是研究人类大脑结构和功能的重要手段。它们负责控制大脑的温度、光线、温度、大气、睡眠、心理健康等方面的功能。这些控制系统在人类大脑中扮演着至关重要的角色,从而对大脑的影响非常显著。’
246 |    Llama2-Chinese-92M-v1-MedicalChat response:‘尿频、尿急、尿痛;慢性前列腺炎;排尿不畅;遗精;会阴痛;夜尿频多,频尿;夜间阴茎勃起;尿频;睾丸疼痛;尿痛;尿分心;排尿疼痛’
247 |    Llama2-Chinese-92M-v2-NormalChat response:‘影片光速是有可能划破天际的,可能带来无限可能的进展,可能会有点困难。不过,由于该故事情节并不准确,所以仅凭这些信息是不够准确诊断的。建议您咨询专业的医学e-卡路里评估方法,以获取更准确的诊断。’
248 |    Llama2-Chinese-92M-v2-MedicalChat response:‘’
249 |    Llama2-Chinese-218M-v1-NormalChat response:‘’
250 |    Llama2-Chinese-218M-v1-MedicalChat response:‘’
251 |    Llama2-Chinese-218M-v2-NormalChat response:‘也被称为“水花如水”,通常指在火上燃烧的面包或饭单,也称为“烤面包”。这些食物,即面包,通常含有脂肪,蛋白质和.’
252 |    Llama2-Chinese-218M-v2-MedicalChat response:‘’
253 |    Llama2-Chinese-218M-v3-NormalChat response:‘’
254 |    Llama2-Chinese-218M-v3-MedicalChat response:‘尿频;脓性分泌物;血尿;尿急;前列腺痛;脓肿;尿痛;尿后滴沥;排尿困难及尿潴留
255 |    ’
256 |    
257 |    # Example 3: Input: ‘请问,世界上最大的动物是什么?’ -- a general-knowledge question
258 |    Llama2-Chinese-92M-v1-NormalChat response:‘世界上最大的动物是蓝鲸。’
259 |    Llama2-Chinese-92M-v1-MedicalChat response:‘38%的动物都是天马科动物,而“天牛”的称法是 10%的,因为这种动物会作为天马长成而生存,所以,它们是家牛的一种。’
260 |    Llama2-Chinese-92M-v2-NormalChat response:‘世界上最大的动物是蓝鲸。它们是鲸类中最大的,可以长达30米,重达173吨。它们是地球上最大的动物,可以长达30米长,体重可达7吨,并且可以长达38米高。’
261 |    Llama2-Chinese-92M-v2-MedicalChat response:‘’
262 |    Llama2-Chinese-218M-v1-NormalChat response:‘’
263 |    Llama2-Chinese-218M-v1-MedicalChat response:‘’
264 |    Llama2-Chinese-218M-v2-NormalChat response:‘世界上最大的动物是蓝鲸。它们的体重可以达到4000至5000公斤,体重可达到7000至9000公斤。他们来自海洋,并且是地球上最适应 蓝鲸是一种非常适应生存由海洋环境而产生的哺乳动物。它们可以达到1.2至1.4米重。它们以鱼类为食,但也会吃小鱼。蓝鲸是肉食性的动物,但它们也可以吃小型’
265 |    Llama2-Chinese-218M-v2-MedicalChat response:‘’
266 |    Llama2-Chinese-218M-v3-MedicalChat response:‘除了导致的,在一般情况下,保持适当的中毒处理方法是首先通过服用药物。’
267 |    Llama2-Chinese-218M-v3-NormalChat response:‘’
268 |    ```
269 |    
270 |    It is clear that the models fine-tuned on the medical SFT data answer medical questions more accurately than the other models, but they show severe forgetting on everyday general-knowledge questions.
271 |    
272 |    In short, the bigger the model and the more data it is trained on, the stronger its performance.
273 |    
274 | ## 🎉 Call for Contributions
275 | Everyone is welcome to help build this small project; for anyone who wants to get into LLMs, it is a rare hands-on opportunity! If you are interested, join QQ group 716455397.
276 | 
277 | ## 🎉 References
278 | [Llama2](https://github.com/karpathy/llama2.c)
279 | 
280 | [Data cleaning - ChatLM-mini-Chinese](https://github.com/charent/ChatLM-mini-Chinese/)
281 | 


--------------------------------------------------------------------------------
/chatglm_tokenizer/tokenization_chatglm.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import torch
  3 | from typing import List, Optional, Union, Dict
  4 | from sentencepiece import SentencePieceProcessor
  5 | from transformers import PreTrainedTokenizer
  6 | from transformers.utils import logging, PaddingStrategy
  7 | from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
  8 | 
  9 | 
 10 | class SPTokenizer:
 11 |     def __init__(self, model_path: str):
 12 |         # reload tokenizer
 13 |         assert os.path.isfile(model_path), model_path
 14 |         self.sp_model = SentencePieceProcessor(model_file=model_path)
 15 | 
 16 |         # BOS / EOS token IDs
 17 |         self.n_words: int = self.sp_model.vocab_size()
 18 |         self.bos_id: int = self.sp_model.bos_id()
 19 |         self.eos_id: int = self.sp_model.eos_id()
 20 |         self.pad_id: int = self.sp_model.unk_id()
 21 |         assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
 22 | 
 23 |         special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"]
 24 |         self.special_tokens = {}
 25 |         self.index_special_tokens = {}
 26 |         for token in special_tokens:
 27 |             self.special_tokens[token] = self.n_words
 28 |             self.index_special_tokens[self.n_words] = token
 29 |             self.n_words += 1
 30 | 
 31 |     def tokenize(self, s: str):
 32 |         return self.sp_model.EncodeAsPieces(s)
 33 | 
 34 |     def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
 35 |         assert type(s) is str
 36 |         t = self.sp_model.encode(s)
 37 |         if bos:
 38 |             t = [self.bos_id] + t
 39 |         if eos:
 40 |             t = t + [self.eos_id]
 41 |         return t
 42 | 
 43 |     def decode(self, t: List[int]) -> str:
 44 |         return self.sp_model.decode(t)
 45 | 
 46 |     def decode_tokens(self, tokens: List[str]) -> str:
 47 |         text = self.sp_model.DecodePieces(tokens)
 48 |         return text
 49 | 
 50 |     def convert_token_to_id(self, token):
 51 |         """ Converts a token (str) in an id using the vocab. """
 52 |         if token in self.special_tokens:
 53 |             return self.special_tokens[token]
 54 |         return self.sp_model.PieceToId(token)
 55 | 
 56 |     def convert_id_to_token(self, index):
 57 |         """Converts an index (integer) in a token (str) using the vocab."""
 58 |         if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
 59 |             return ""
 60 |         return self.sp_model.IdToPiece(index)
 61 | 
 62 | 
 63 | class ChatGLMTokenizer(PreTrainedTokenizer):
 64 |     vocab_files_names = {"vocab_file": "tokenizer.model"}
 65 | 
 66 |     model_input_names = ["input_ids", "attention_mask", "position_ids"]
 67 | 
 68 |     def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
 69 |         super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
 70 |         self.name = "GLMTokenizer"
 71 | 
 72 |         self.vocab_file = vocab_file
 73 |         self.tokenizer = SPTokenizer(vocab_file)
 74 |         self.special_tokens = {
 75 |             "<bos>": self.tokenizer.bos_id,
 76 |             "<eos>": self.tokenizer.eos_id,
 77 |             "<pad>": self.tokenizer.pad_id
 78 |         }
 79 | 
 80 |     def get_command(self, token):
 81 |         if token in self.special_tokens:
 82 |             return self.special_tokens[token]
 83 |         assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
 84 |         return self.tokenizer.special_tokens[token]
 85 | 
 86 |     @property
 87 |     def unk_token(self) -> str:
 88 |         return "<unk>"
 89 | 
 90 |     @property
 91 |     def pad_token(self) -> str:
 92 |         return "<unk>"
 93 | 
 94 |     @property
 95 |     def pad_token_id(self):
 96 |         return self.get_command("<pad>")
 97 | 
 98 |     @property
 99 |     def eos_token(self) -> str:
100 |         return "</s>"
101 | 
102 |     @property
103 |     def eos_token_id(self):
104 |         return self.get_command("<eos>")
105 | 
106 |     @property
107 |     def vocab_size(self):
108 |         return self.tokenizer.n_words
109 | 
110 |     def get_vocab(self):
111 |         """ Returns vocab as a dict """
112 |         vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
113 |         vocab.update(self.added_tokens_encoder)
114 |         return vocab
115 | 
116 |     def _tokenize(self, text, **kwargs):
117 |         return self.tokenizer.tokenize(text)
118 | 
119 |     def _convert_token_to_id(self, token):
120 |         """ Converts a token (str) in an id using the vocab. """
121 |         return self.tokenizer.convert_token_to_id(token)
122 | 
123 |     def _convert_id_to_token(self, index):
124 |         """Converts an index (integer) in a token (str) using the vocab."""
125 |         return self.tokenizer.convert_id_to_token(index)
126 | 
127 |     def convert_tokens_to_string(self, tokens: List[str]) -> str:
128 |         return self.tokenizer.decode_tokens(tokens)
129 | 
130 |     def save_vocabulary(self, save_directory, filename_prefix=None):
131 |         """
132 |         Save the vocabulary and special tokens file to a directory.
133 | 
134 |         Args:
135 |             save_directory (`str`):
136 |                 The directory in which to save the vocabulary.
137 |             filename_prefix (`str`, *optional*):
138 |                 An optional prefix to add to the named of the saved files.
139 | 
140 |         Returns:
141 |             `Tuple(str)`: Paths to the files saved.
142 |         """
143 |         if os.path.isdir(save_directory):
144 |             vocab_file = os.path.join(
145 |                 save_directory, self.vocab_files_names["vocab_file"]
146 |             )
147 |         else:
148 |             vocab_file = save_directory
149 | 
150 |         with open(self.vocab_file, 'rb') as fin:
151 |             proto_str = fin.read()
152 | 
153 |         with open(vocab_file, "wb") as writer:
154 |             writer.write(proto_str)
155 | 
156 |         return (vocab_file,)
157 | 
158 |     def get_prefix_tokens(self):
159 |         prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
160 |         return prefix_tokens
161 | 
162 |     def build_prompt(self, query, history=None):
163 |         if history is None:
164 |             history = []
165 |         prompt = ""
166 |         for i, (old_query, response) in enumerate(history):
167 |             prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i + 1, old_query, response)
168 |         prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
169 |         return prompt
170 | 
171 |     def build_inputs_with_special_tokens(
172 |             self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
173 |     ) -> List[int]:
174 |         """
175 |         Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
176 |         adding special tokens. A BERT sequence has the following format:
177 | 
178 |         - single sequence: `[CLS] X [SEP]`
179 |         - pair of sequences: `[CLS] A [SEP] B [SEP]`
180 | 
181 |         Args:
182 |             token_ids_0 (`List[int]`):
183 |                 List of IDs to which the special tokens will be added.
184 |             token_ids_1 (`List[int]`, *optional*):
185 |                 Optional second list of IDs for sequence pairs.
186 | 
187 |         Returns:
188 |             `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
189 |         """
190 |         prefix_tokens = self.get_prefix_tokens()
191 |         token_ids_0 = prefix_tokens + token_ids_0
192 |         if token_ids_1 is not None:
193 |             token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
194 |         return token_ids_0
195 | 
196 |     def _pad(
197 |             self,
198 |             encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
199 |             max_length: Optional[int] = None,
200 |             padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
201 |             pad_to_multiple_of: Optional[int] = None,
202 |             return_attention_mask: Optional[bool] = None,
203 |     ) -> dict:
204 |         """
205 |         Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
206 | 
207 |         Args:
208 |             encoded_inputs:
209 |                 Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
210 |             max_length: maximum length of the returned list and optionally padding length (see below).
211 |                 Will truncate by taking into account the special tokens.
212 |             padding_strategy: PaddingStrategy to use for padding.
213 | 
214 |                 - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
215 |                 - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
216 |                 - PaddingStrategy.DO_NOT_PAD: Do not pad
217 |                 The tokenizer padding sides are defined in self.padding_side:
218 | 
219 |                     - 'left': pads on the left of the sequences
220 |                     - 'right': pads on the right of the sequences
221 |             pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
222 |                 This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
223 |                 `>= 7.5` (Volta).
224 |             return_attention_mask:
225 |                 (optional) Set to False to avoid returning attention mask (default: set to model specifics)
226 |         """
227 |         # Load from model defaults
228 |         assert self.padding_side == "left"
229 | 
230 |         required_input = encoded_inputs[self.model_input_names[0]]
231 |         seq_length = len(required_input)
232 | 
233 |         if padding_strategy == PaddingStrategy.LONGEST:
234 |             max_length = len(required_input)
235 | 
236 |         if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
237 |             max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
238 | 
239 |         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
240 | 
241 |         # Initialize attention mask if not present.
242 |         if "attention_mask" not in encoded_inputs:
243 |             encoded_inputs["attention_mask"] = [1] * seq_length
244 | 
245 |         if "position_ids" not in encoded_inputs:
246 |             encoded_inputs["position_ids"] = list(range(seq_length))
247 | 
248 |         if needs_to_be_padded:
249 |             difference = max_length - len(required_input)
250 | 
251 |             if "attention_mask" in encoded_inputs:
252 |                 encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
253 |             if "position_ids" in encoded_inputs:
254 |                 encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
255 |             encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
256 | 
257 |         return encoded_inputs
258 | 


--------------------------------------------------------------------------------
/chatglm_tokenizer/tokenizer.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLLXW/baby-llama2-chinese/98a20dbb35e686a62188f61f479809cb2d4f8d6e/chatglm_tokenizer/tokenizer.model


--------------------------------------------------------------------------------
/chatglm_tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
 1 | {
 2 |   "name_or_path": "THUDM/chatglm2-6b",
 3 |   "remove_space": false,
 4 |   "do_lower_case": false,
 5 |   "tokenizer_class": "ChatGLMTokenizer",
 6 |   "auto_map": {
 7 |     "AutoTokenizer": [
 8 |       "tokenization_chatglm.ChatGLMTokenizer",
 9 |       null
10 |       ]
11 |   }
12 | }
13 | 


--------------------------------------------------------------------------------
/data_clean/clear.py:
--------------------------------------------------------------------------------
  1 | import ujson
  2 | import re
  3 | from os.path import dirname, abspath, exists, isdir
  4 | from os import remove, mkdir, walk
  5 | import time
  6 | from collections import defaultdict
  7 | 
  8 | from matplotlib import pyplot as plt
  9 | import codecs, csv
 10 | import pandas as pd 
 11 | import numpy as np
 12 | from rich import progress
 13 | from rich.table import Table
 14 | from rich.console import Console
 15 | from fastparquet import ParquetFile, write
 16 | import pyarrow.parquet as pq
 17 | 
 18 | import sys
 19 | sys.path.extend(['.','..'])
 20 | 
 21 | from logger import Logger
 22 | from functions import get_path_of_suffix_files, DropDatasetDuplicate, DropDatasetDuplicate_SimHash
 23 | 
 24 | log = Logger('data_process', save2file=True, file_name='./logs/raw_data_process.log')
 25 | 
 26 | punctuation = set("!\"#$%&'()*+,-./:;<=>?@[\]^_`{|}~.,;《》?!“”‘’@#¥%…&×()——+【】{};;●,。&~、|\s::\n")
 27 | en_punctuation = ",().!;:"
 28 | zh_punctuation = ",()。!;:"
 29 | 
 30 | def delete_file(file: str)-> bool:
 31 |     '''
 32 |     Ask the user before deleting a file
 33 |     '''
 34 |     if exists(file):
 35 |         ans = input('delete file: {} ? Yes (y) or No (n)'.format(file))
 36 |         ans = ans.lower()
 37 |         if ans in ('yes', 'y'):
 38 |             remove(file)
 39 |             print('deleted.')
 40 |             return True
 41 |     return False
 42 | 
 43 | def remove_duplicate_punctuation(sentence: str) -> str:
 44 |     '''
 45 |     Remove duplicated punctuation and duplicated spaces from a sentence, and turn newlines into the special character '\n'
 46 |     '''
 47 |     # Replace (full-width) spaces with commas; this may create duplicated punctuation, which is removed below
 48 |     sentence = re.sub(' | ', ',', sentence) 
 49 | 
 50 |     ans = ''
 51 |     n = len(sentence)
 52 |     p = 0
 53 |     while p < n:
 54 |         ans += sentence[p]
 55 | 
 56 |         while p + 1 < n and sentence[p] in punctuation and sentence[p + 1] in punctuation:
 57 |             p += 1
 58 |         p += 1
 59 | 
 60 |     return ans
 61 | 
 62 | def convert_en_punctuation_to_zh_punct(sentence: str) -> str:
 63 |     '''
 64 |     Replace English punctuation in a sentence with the corresponding Chinese punctuation
 65 |     '''
 66 |     n = len(zh_punctuation)
 67 |     for i in range(n):
 68 |         sentence = sentence.replace(en_punctuation[i], zh_punctuation[i])
 69 |     return sentence
 70 | 
 71 | def get_sentences_dice_similarity(st_a: str, st_b: str) -> float:
 72 |     '''
 73 |     Compute the Dice similarity of two sentences:
 74 |     s(a, b) =  2 * len( set(a) & set(b) ) / (len(set(a)) + len(set(b)))
 75 |     '''
 76 |     set_a, set_b = set(st_a), set(st_b)
 77 |     total_len  = len(set_a) + len(set_b)
 78 |     
 79 |     if total_len == 0: return 0.0
 80 | 
 81 |     inter_set =  set_a & set_b
 82 |     
 83 |     return ( 2 * len(inter_set)) / total_len
 84 | 
 85 | def write_single_parquet_file(file_name: str, data_frame: pd.DataFrame) -> None:
 86 |     '''
 87 |     Write a dataframe to a single parquet file
 88 |     '''
 89 |     append = False
 90 |     if exists(file_name):
 91 |         append = True 
 92 | 
 93 |     write(file_name, data_frame, compression='GZIP', append=append)
 94 |     
 95 | 
 96 | def merge_dataset_as_single_file(groups_cnt: int=50000, max_len: int=512, min_len: int=3, cut_max_len: bool=False) -> None:
 97 |     '''
 98 |     Merge multiple datasets into a single one
 99 |     '''
100 |     from_parquet_files = get_path_of_suffix_files('./data/', '.parquet')
101 | 
102 |     save_file = './data/all_dataset.parquet'
103 | 
104 |     # Later writes append, so delete the file first if it already exists
105 |     if exists(save_file): 
106 |         assert delete_file(save_file)
107 | 
108 |     cur_rows = []
109 |     append = cur_rows.append
110 |     
111 |     all_cnt, keep_cnt = 0, 0
112 |     for file in from_parquet_files:
113 |         print('process file: {}'.format(file))
114 | 
115 |         parquet_table = pq.read_table(file)
116 |      
117 |         for response in progress.track(parquet_table['response'], total=parquet_table.num_rows):
118 | 
119 |             response =  response.as_py()
120 |             all_cnt += 1
121 | 
122 |             if len(response) < min_len:
123 |                 continue
124 | 
125 |             if cut_max_len and len(response) > max_len:
126 |                 response = response[0: max_len]
127 | 
128 |             keep_cnt += 1
129 |             append({'response': response})
130 | 
131 |             if len(cur_rows) >= groups_cnt:
132 |                 df = pd.DataFrame(cur_rows)
133 |                 write_single_parquet_file(save_file, df)
134 |                 cur_rows = []
135 |                 append = cur_rows.append
136 |                 
137 |     # Handle the remaining tail
138 |     if len(cur_rows) > 0:
139 |         df = pd.DataFrame(cur_rows)
140 |         write_single_parquet_file(save_file, df)
141 |         cur_rows = []
142 | 
143 |     log.info("merge into file: {}, 全部数据共{}行,清洗后剩余{}行".format(save_file, all_cnt, keep_cnt), save_to_file=True)
144 |     
145 | def remove_dataset_duplicate_rows(groups_cnt: int=50000) -> None:
146 |     '''
147 |     Remove duplicated rows from the dataset using MinHash
148 |     '''
149 |     from_parquet_files = '../data/563w_baidubaike/baike.parquet'
150 | 
151 |     save_file = '../data/563w_baidubaike/all_no_dulpticates.parquet'
152 | 
153 |     # Later writes append, so delete the file first if it already exists
154 |     if exists(save_file): 
155 |         assert delete_file(save_file)
156 | 
157 |     cur_rows = []
158 |     all_cnt, keep_cnt = 0, 0
159 |     row_index = -1
160 |     drop_dataset_duplicate = DropDatasetDuplicate(threshold=0.85, num_perm=256)
161 |     
162 |     parquet_table = pq.read_table(from_parquet_files)
163 |     all_cnt = parquet_table.num_rows
164 |     print(all_cnt)
165 |     # First pass: find which rows are duplicates
166 |     for response in progress.track(parquet_table['response'], total=parquet_table.num_rows):
167 |         row_index += 1
168 | 
169 |         doc = f"{response.as_py()}" # convert the response cell to a plain Python string
170 |         drop_dataset_duplicate.add_doc(index=row_index, doc=doc)
171 | 
172 |     row_index = -1
173 |     need_to_drop_indexs = drop_dataset_duplicate.get_duplicate_indexs()
174 | 
175 |     # Second pass: skip duplicated rows so they are not added to the new dataset
176 |     for response in progress.track(parquet_table['response'], total=parquet_table.num_rows):
177 |         row_index += 1  # row_index must be incremented whether or not the row is skipped
178 | 
179 |         # skip duplicated rows
180 |         if row_index in need_to_drop_indexs:
181 |             continue
182 | 
183 |         cur_rows.append({'response': response.as_py()})
184 |         keep_cnt += 1
185 | 
186 |         # write in chunks
187 |         if len(cur_rows) >= groups_cnt:
188 |             df = pd.DataFrame(cur_rows)
189 |             write_single_parquet_file(save_file, df)
190 |             cur_rows = []
191 |     # Handle and write the remaining tail
192 |     if len(cur_rows) > 0:
193 |         df = pd.DataFrame(cur_rows)
194 |         write_single_parquet_file(save_file, df)
195 |     log.info("merge into file: {}, 全部数据共{}行,文档去重后剩余{}行".format(save_file, all_cnt, keep_cnt), save_to_file=True)
196 | 
197 | 
198 | def remove_dataset_duplicate_rows_simhash(groups_cnt: int = 50000) -> None:
199 |     '''
200 |     Remove duplicated rows from the dataset using SimHash
201 |     '''
202 |     from_parquet_files = '../data/563w_baidubaike/baike.parquet'
203 | 
204 |     save_file = '../data/563w_baidubaike/all_no_dulpticates_simhash.parquet'
205 | 
206 |     # Later writes append, so delete the file first if it already exists
207 |     if exists(save_file):
208 |         assert delete_file(save_file)
209 | 
210 |     cur_rows = []
211 |     all_cnt, keep_cnt = 0, 0
212 |     row_index = -1
213 | 
214 |     parquet_table = pq.read_table(from_parquet_files)
215 |     all_cnt = parquet_table.num_rows
216 |     print(all_cnt)
217 |     drop_dataset_duplicate = DropDatasetDuplicate_SimHash(threshold=3, f=128)
218 |     # First pass: find which rows are duplicates
219 |     for response in progress.track(parquet_table['response'], total=parquet_table.num_rows):
220 |         row_index += 1
221 | 
222 |         doc = f"{response.as_py()}"
223 |         drop_dataset_duplicate.add_doc(index=row_index, doc=doc)
224 | 
225 |     droped_database = drop_dataset_duplicate.database
226 | 
227 |     # Write the deduplicated data
228 |     for k, v in droped_database.items():
229 |         cur_rows.append({'response': v})
230 |         keep_cnt += 1
231 | 
232 |         # write in chunks
233 |         if len(cur_rows) >= groups_cnt:
234 |             df = pd.DataFrame(cur_rows)
235 |             write_single_parquet_file(save_file, df)
236 |             cur_rows = []
237 |     # Handle and write the remaining tail
238 |     if len(cur_rows) > 0:
239 |         df = pd.DataFrame(cur_rows)
240 |         write_single_parquet_file(save_file, df)
241 |     log.info("merge into file: {}, 全部数据共{}行,文档去重后剩余{}行".format(save_file, all_cnt, keep_cnt),
242 |              save_to_file=True)
243 | 
244 | def shuffle_parquet_dataset(parquet_file: str, shuffle_file: str, seed: int=23333, groups_cnt: int=65536) -> None:
245 |     '''
246 |     Shuffle a parquet dataset file
247 |     '''
248 |     if not exists(parquet_file):
249 |         raise Exception('can not find parquet file: {}'.format(parquet_file))
250 |     
251 |     print('start shuffle...')
252 |     pf =  pq.read_table(parquet_file)
253 |     df = pf.to_pandas()
254 |     df = df.sample(frac=1.0, replace=False, random_state=seed, axis=0)
255 |     
256 |     if exists(shuffle_file): 
257 |         assert delete_file(shuffle_file)
258 |     
259 |     # Write the parquet in chunks; otherwise reading it back on a low-memory machine will OOM
260 |     n = len(df)
261 |     for i in range(0, n, groups_cnt):
262 |         cur_group_df = df[i: i + groups_cnt]
263 |         write_single_parquet_file(shuffle_file, cur_group_df)
264 | 
265 | 
266 | def read_and_write_template_wiki(read_file: str, write_to_file: str, call_back: object, group_cnt: int=10000) -> None:
267 |     '''
268 |     Template for reading and writing data; a callback function call_back must be provided.
269 |     read_file: raw data file
270 |     write_to_file: file to save the processed data to
271 |     call_back: a function that takes a string and returns a processed dict; it should return None if the input string is invalid data
272 |     group_cnt: number of rows per parquet chunk
273 |     Example:
274 |     >>> def call_back(inputs: str) -> dict:
275 |     >>>     if check(inputs) not valid:
276 |     >>>         return None
277 |     ...    
278 |     ...    do something for inputs
279 |     ...
280 |     >>>     my_dict = {
281 |     >>>             'prompt': inputs['p'],
282 |     >>>             'response': inputs['a1'] + inputs['a2'],
283 |     >>>             ...
284 |     >>>         }
285 |     >>>     return my_dict
286 |     '''
287 | 
288 |     log.info('process file:{}'.format(read_file), save_to_file=True)
289 |     start = time.time()
290 |     
291 |     raw_line_cnt = 0
292 |     keep_line_cnt = 0
293 |     with progress.open(read_file, 'r', encoding='utf-8') as f_read:
294 |         
295 |         json_list = ujson.load(f_read)
296 |         cur_rows = []
297 |         append = cur_rows.append
298 | 
299 |         for line in json_list:
300 |             try:
301 |                 #print(line)
302 |                 raw_line_cnt += 1
303 |                 write_dict = call_back(line)
304 |                 if write_dict is None: continue
305 |                 keep_line_cnt += 1
306 |                 append(write_dict)
307 |                 if len(cur_rows) >= group_cnt:
308 |                     df = pd.DataFrame(cur_rows)
309 |                     write_single_parquet_file(write_to_file, df)
310 |                     cur_rows = []
311 |                     append = cur_rows.append
312 |             except Exception as e:
313 |                 # log.error('处理文件异常:{}, content:{}'.format(str(e), line))
314 |                 print(line)
315 |                 raise e
316 |             # end for
317 |             # handle the remaining tail
318 |         if len(cur_rows) > 0:
319 |             df = pd.DataFrame(cur_rows)
320 |             write_single_parquet_file(write_to_file, df)
321 |             cur_rows = []
322 |         end = time.time()
323 |         log.info('原始文件:{},共{}行,处理后剩余{}行,保存到文件:{}。耗时:{:.6}s'\
324 |                     .format(read_file, raw_line_cnt, keep_line_cnt, write_to_file, end - start), save_to_file=True)
325 | 
326 | '''
327 | {
328 | 
329 |     "completion": "陈准,字道基,颍川郡许昌(今河南许昌)人。西晋官员。官至太尉。出身颍川陈氏,青州刺史陈佐之子,曹魏司空陈群族孙,曾祖父是陈群的叔叔陈谌。\n生平\n陈准早年居于乡里,被乡人称赞,有声望,晋惠帝元康五年(295年)官至中书令。时皇后贾南风擅权,由于张华、裴𬱟等人共同辅政,朝野安静。氐人齐万年反叛,陈准多次指斥负责赵王司马伦、梁王司马肜等不任军事,荐举由大将周处、孟观指挥作战。司马肜忌恨周处,致使周处力战而死。后来朝廷从了陈准的建议,派孟观督师征讨齐万年胜利。永康元年(300年),司马伦发动政变,废杀贾南风,陈准有功,封海陵公。司马伦意图篡位,淮南王司马允发兵讨伐,围困司马伦。陈准暗地支持司马允,骗晋惠帝说打出劝和驺虞幡,其实派人打出督战的令旗白虎幡。可是,派去的人被司马伦收买,诱杀了司马允。陈准本有心袒护司马允,到头来反而救了司马伦。司马伦不知其故,提升陈准为太尉,录尚书事,改封广陵公。不久,陈准去世,谥号元。\n家庭\n平辈\n* 弟陈徽,太子左卫率,淮南王司马允讨赵王司马伦,曾集结东宫兵在宫内响应淮南王。\n* 弟陈戴,国子助教。\n后代\n* 子陈眕,西晋左卫将军,幽州刺史。\n* 子陈匡,司马遹东宫侍读。\n* 子陈规\n* 孙陈逵,陈眕子,东晋梁淮南二郡太守。",
330 | 
331 |     "source": "wikipedia.zh2307"
332 | 
333 |   },
334 | '''
335 | 
336 | def process_wiki(response_less_word: int=15) -> None:
337 |     file_names = [
338 |         '../data/wikipedia-cn-20230720-filtered.json',
339 |     ]
340 |     save_file_name = './data/wiki.parquet'
341 |     if exists(save_file_name): 
342 |         assert delete_file(save_file_name)
343 | 
344 |     def process_function(item: dict) -> dict:
345 |         #print(item['completion'])
346 |         # data cleaning
347 |         response = item['completion'].replace('\r','')
348 |         response = remove_duplicate_punctuation(response)
349 |         # drop short samples
350 |         if len(response) < response_less_word:
351 |             return None
352 |         write_dict = {
353 |                 "response": response,
354 |             }
355 |         return write_dict
356 | 
357 |     for file_name in file_names:
358 |         read_file = file_name
359 |         read_and_write_template_wiki(read_file, save_file_name, process_function)
360 |         
361 | 
362 | def read_and_write_template_baike(read_file: str, write_to_file: str, call_back: object, group_cnt: int=10000) -> None:
363 |     '''
364 |     Template for reading and writing data; a callback function call_back must be provided.
365 |     read_file: raw data file
366 |     write_to_file: file to save the processed data to
367 |     call_back: a function that takes a string and returns a processed dict; it should return None if the input string is invalid data
368 |     group_cnt: number of rows per parquet chunk
369 |     Example:
370 |     >>> def call_back(inputs: str) -> dict:
371 |     >>>     if check(inputs) not valid:
372 |     >>>         return None
373 |     ...    
374 |     ...    do something for inputs
375 |     ...
376 |     >>>     my_dict = {
377 |     >>>             'prompt': inputs['p'],
378 |     >>>             'response': inputs['a1'] + inputs['a2'],
379 |     >>>             ...
380 |     >>>         }
381 |     >>>     return my_dict
382 |     '''
383 | 
384 |     log.info('process file:{}'.format(read_file), save_to_file=True)
385 |     start = time.time()
386 |     
387 |     raw_line_cnt = 0
388 |     keep_line_cnt = 0
389 |     with progress.open(read_file, 'r', encoding='utf-8') as f_read:
390 |         cur_rows = []
391 |         append = cur_rows.append
392 |         for line in f_read:
393 |             try:
394 |                 #print(line)
395 |                 raw_line_cnt += 1
396 |                 write_dict = call_back(line)
397 |                 if write_dict is None: continue
398 |                 keep_line_cnt += 1
399 |                 append(write_dict)
400 |                 if len(cur_rows) >= group_cnt:
401 |                     df = pd.DataFrame(cur_rows)
402 |                     write_single_parquet_file(write_to_file, df)
403 |                     cur_rows = []
404 |                     append = cur_rows.append
405 |             except Exception as e:
406 |                 # log.error('处理文件异常:{}, content:{}'.format(str(e), line))
407 |                 print(line)
408 |                 raise e
409 |             # end for
410 |             # handle the remaining tail
411 |         if len(cur_rows) > 0:
412 |             df = pd.DataFrame(cur_rows)
413 |             write_single_parquet_file(write_to_file, df)
414 |             cur_rows = []
415 |         end = time.time()
416 |         log.info('原始文件:{},共{}行,处理后剩余{}行,保存到文件:{}。耗时:{:.6}s'\
417 |                     .format(read_file, raw_line_cnt, keep_line_cnt, write_to_file, end - start), save_to_file=True)
418 | 
419 | def process_baike(response_less_word: int=15) -> None:
420 |     file_names = [
421 |         '../data/563w_baidubaike/563w_baidubaike.json',
422 |     ]
423 |     save_file_name = '../data/563w_baidubaike/baike.parquet'
424 |     if exists(save_file_name): 
425 |         assert delete_file(save_file_name)
426 | 
427 |     def process_function(line: str) -> dict:
428 |         
429 |         item = ujson.loads(line)
430 |         item_title = item['title']
431 |         item_sections = item ['sections']
432 |         for data in item_sections:
433 |             #print(item['completion'])
434 |             # data cleaning
435 |             response = data['content'].replace('\r','')
436 |             response = remove_duplicate_punctuation(response)
437 |             # drop short samples
438 |             if len(response) < response_less_word:
439 |                 return None
440 |             response = data['title']+data['content']
441 |             write_dict = {
442 |                     "response": response,
443 |                 }
444 |             return write_dict
445 | 
446 |     for file_name in file_names:
447 |         read_file = file_name
448 |         read_and_write_template_baike(read_file, save_file_name, process_function)
449 | 
450 | #https://blog.csdn.net/m0_63834988/article/details/135000567
451 | # Get familiar with the rich, pyarrow, and parquet packages, and the MinHash algorithm
452 | 
453 | if __name__ == '__main__':
454 |     
455 |     # Inspect the raw file contents
456 |     #data=open('../data/563w_baidubaike.json','r')
457 |     #for line in data.readlines()[:10]:
458 |     #    print(line)
459 |     
460 |     #process_wiki()
461 |     # inspect the parquet contents
462 |     #parquet_table = pq.read_table('./data/baike.parquet')
463 |     #data = parquet_table.to_pandas()
464 |     #print(data.head())
465 | 
466 |     # Filter short texts from the raw file and store it as .parquet, which greatly reduces disk usage
467 |     process_baike()
468 |     
469 |     # merge datasets
470 |     #merge_dataset_as_single_file()
471 |     
472 |     # deduplicate
473 |     # minhash (recommended)
474 |     remove_dataset_duplicate_rows()
475 | 
476 |     # simhash
477 |     #remove_dataset_duplicate_rows_simhash()


--------------------------------------------------------------------------------
/data_clean/functions.py:
--------------------------------------------------------------------------------
  1 | from collections import Counter
  2 | from typing import Union
  3 | from dataclasses import make_dataclass, field
  4 | from transformers import T5Config
  5 | import ctypes
  6 | import os
  7 | import platform
  8 | import re
  9 | import torch
 10 | 
 11 | from datasketch import MinHash, MinHashLSH
 12 | from collections import defaultdict
 13 | from transformers.trainer_callback import TrainerControl, TrainerState
 14 | from transformers import TrainingArguments, TrainerCallback
 15 | 
 16 | import jieba
 17 | import pandas as pd
 18 | from simhash import Simhash, SimhashIndex
 19 | 
 20 | from nltk import ngrams
 21 | from nltk.translate.bleu_score import sentence_bleu
 22 | import numpy as np
 23 | import ujson
 24 | 
 25 | # sentence-ending punctuation
 26 | END_PUN = set(".。!!))》}】??\"”")
 27 | 
 28 | class MyTrainerCallback(TrainerCallback):
 29 |     log_cnt = 0
 30 |     def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
 31 |         '''
 32 |         在打印 n 次日志后清除cuda缓存,适合低显存设备,能防止OOM
 33 |         '''
 34 |         self.log_cnt += 1
 35 |         if self.log_cnt % 2 == 0:
 36 |             torch.cuda.empty_cache()
 37 |     
 38 |     def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
 39 |         '''
 40 |         在 on_epoch_end 时保存一次模型。
 41 |         TrainingArguments的 save_strategy 中 epoch 和 steps 不兼容。要实现每隔 save_steps 步保存一次检查点,考虑到磁盘空间大小,最多只保存最近N个检查点。
 42 |         '''
 43 |         # 设置should_save=True并返回即可
 44 |         control.should_save = True
 45 |         return control
 46 | 
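# 示意:该回调的典型用法是在训练脚本中传给 transformers 的 Trainer(仅为演示,参数均为假设值)
def _demo_build_trainer(model, train_dataset):
    from transformers import Trainer
    args = TrainingArguments(output_dir='./out_demo',
                             per_device_train_batch_size=8,
                             logging_steps=50,
                             save_strategy='steps',
                             save_steps=1000,
                             save_total_limit=3)
    return Trainer(model=model, args=args, train_dataset=train_dataset,
                   callbacks=[MyTrainerCallback()])
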
 47 | 
 48 | # 保留中文和英文、下划线,不要标点符号
 49 | NON_CHAR = re.compile("[^\u4E00-\u9FA5A-Za-z_0-9]")
 50 | 
 51 | def _get_doc_mini_hash(doc, num_perm: int) -> MinHash:
 52 |     '''
 53 |     获取一段文本的mini hash
 54 |     '''
 55 |     mini_hash = MinHash(num_perm=num_perm)
 56 |     for s in doc:
 57 |         mini_hash.update(s.encode('utf-8'))
 58 |     return mini_hash
 59 | 
 60 | class DropDatasetDuplicate:
 61 | 
 62 |     def __init__(self,  threshold: float=0.85, num_perm: int=256) -> None:
 63 |         '''
 64 |         获取一个数据集中所有重复(相似的超过threshold)的index,输入为:list[str],一个str元素为一段文本(doc)
 65 |         如输入: [a, b, c, d, c, d, e] 返回:{4, 5} (后面两个 c, d 的index)
 66 | 
 67 |         MinHashLSH 参数说明:
 68 |         threshold (float):Jaccard 相似度阈值,超过该阈值判定为相似,默认为0.85
 69 |         num_perm (int, optional):哈希置换函数设定个数,在weighted-MinHash中为样本规模大小。
 70 |         weights (tuple, optional):优化Jaccard 阈值,能够弹性选择。
 71 |         params (tuple, optional):bands 的数量与规模大小。
 72 |         '''
 73 |         self.similar_index_cluster = defaultdict(set)
 74 |         #
 75 |         self.data_lsh = MinHashLSH(threshold=threshold, num_perm=num_perm)
 76 |         self.num_perm = num_perm
 77 | 
 78 |     def add_doc(self, index, doc: str):
 79 |         '''
 80 |         添加文档,
 81 |         index: 文档的索引
 82 |         doc: 文档本身
 83 |         '''
 84 | 
 85 |         # 只保留中文和英文、下划线,不要标点符号 分词!!!
 86 |         doc = ''.join(NON_CHAR.split(doc))
 87 |         # doc = [''.join(t) for t in list(ngrams(doc, 3))]
 88 | 
 89 |         doc_hash = _get_doc_mini_hash(doc, self.num_perm)
 90 |         close_duplicates = self.data_lsh.query(doc_hash)
 91 | 
 92 |         self.data_lsh.insert(index, doc_hash)
 93 | 
 94 |         # 所有相似的doc在similar_index_cluster中的key都是最早出现的idx
 95 |         # 如:data中索引index 2, 7, 8, 9, 10, 12 是相似的,则在similar_index_cluster中表现为 {2: {8, 9, 10, 12}}
 96 |         if len(close_duplicates) > 0:
 97 |             min_idx= min(close_duplicates)
 98 |             self.similar_index_cluster[min_idx].add(index)
 99 |     
100 |     def get_duplicate_indexs(self):
101 |         '''
102 |         返回所有的重复文档索引
103 |         '''
104 |         similar_index_cluster = self.similar_index_cluster
105 |         need_to_remove_idx = set()
106 |         
107 |         for key_idx in similar_index_cluster.keys():
108 |             need_to_remove_idx |= similar_index_cluster[key_idx]  # 并集
109 | 
110 |         return need_to_remove_idx
111 | 
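# 示意用法:用 MinHash/LSH 对一组文本去重(仅为演示,阈值按实际语料调整)
def _demo_drop_duplicate(docs: list) -> list:
    dropper = DropDatasetDuplicate(threshold=0.85, num_perm=256)
    for i, doc in enumerate(docs):
        dropper.add_doc(i, doc)
    need_remove = dropper.get_duplicate_indexs()
    # 例如 [a, b, c, d, c, d, e] 会移除后出现的 c、d(索引 4、5)
    return [doc for i, doc in enumerate(docs) if i not in need_remove]
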
112 | # ambrose add
113 | class DropDatasetDuplicate_SimHash:
114 |     def __init__(self, threshold: int = 3, f: int = 64) -> None:
115 |         '''
116 |         threshold: 汉明距离阈值
117 |         f: 哈希值的长度
118 |         '''
119 |         self.database = {}
120 |         self.dupcount = 0
121 |         self.index = SimhashIndex([], k=threshold, f=f)
122 | 
123 |         self.threshold = threshold
124 |         self.f = f
125 | 
126 |     def get_features(self, s: str):
127 |         '''
128 |         文本预处理(正则、分词、滑窗) 滑窗的目的是增加上下文。预处理视情况进行增删
129 |         s: 文档
130 |         return: List[str] 分词后的文档
131 |         '''
132 |         width = 3
133 |         s = s.lower()
134 |         s = re.sub(r'[^\w]+', '', s)
135 |         return [s[i:i + width] for i in range(max(len(s) - width + 1, 1))]
136 | 
137 |     def add_doc(self, index, doc: str):
138 |         '''
139 |         添加文档,
140 |         index: 文档的索引
141 |         doc: 文档本身
142 |         '''
143 | 
144 |         if index == 0:
145 |             self.database[index] = doc
146 |             self.index.add(str(index), Simhash(self.get_features(doc), f=self.f))
147 |         else:
148 |             s1 = Simhash(self.get_features(doc), f=self.f)
149 |             if self.index.get_near_dups(s1) == []:
150 |                 self.database[index] = doc
151 | 
152 |                 self.index.add(str(index), s1)
153 |             else:
154 |                 self.dupcount += 1
155 | 
156 | 
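# 示意用法:基于 SimHash 的去重,database 中保留的即为非重复文档(仅为演示)
def _demo_drop_duplicate_simhash(docs: list) -> list:
    dropper = DropDatasetDuplicate_SimHash(threshold=3, f=64)
    for i, doc in enumerate(docs):
        dropper.add_doc(i, doc)
    print('重复文档数:', dropper.dupcount)
    return list(dropper.database.values())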
157 | 
158 | def f1_p_r_compute(spo_list_pred: list, spo_list_true: list, repair: bool=False):
159 |     '''
160 |     spo_list: [ [(s,p,o)...], [(s,p,o)]], 每一行[(s,p,o)...]为一个句子中的spo
161 |     计算spo的f1分数,精确率,召回率,
162 |     '''
163 |     assert len(spo_list_pred) == len(spo_list_true)
164 | 
165 |     def repair_song_album(spo_list: list, song: list, album: list):
166 |         '''
167 |         修复一条文本的'歌曲'和'专辑'的spo。对于歌曲x(subject)的关系歌手、作词、作曲,x必须同时存在于song和album中
168 |         '''
169 |         if len(song) == 0 and len(album) == 0:
170 |             return spo_list
171 | 
172 |         ps = ['歌手', '作词', '作曲']
173 |         new_spo_list = []
174 |         for spo in spo_list:
175 |             s, p = spo[0], spo[1]
176 |             if p in ps and s in album and s not in song:
177 |                 continue
178 |             new_spo_list.append(spo)
179 |         
180 |         return new_spo_list
181 | 
182 |     def repair_song_album_list(spo_list: list):
183 |         '''
184 |         '''
185 |         new_spo_list = []
186 |         for spos in spo_list:
187 |             song, album = [], []
188 |             for spo in spos:
189 |                 s, p, o = spo
190 |                 if p == '所属专辑':
191 |                     song.append(s)
192 |                     album.append(o)
193 |             new_spo_list.append(repair_song_album(spos, song, album))
194 |         
195 |         return new_spo_list
196 |     if repair:
197 |         spo_list_pred = repair_song_album_list(spo_list_pred)
198 |         spo_list_true = repair_song_album_list(spo_list_true)
199 | 
200 |     TP = 1e-10      # 正类判定为正类, A
201 |     # TN = 1e-10    # 负类判定为负类
202 |     TP_FP = 1e-10   # 检索到的, A + B
203 |     TP_FN = 1e-10   # 真正想要的,A + C
204 |     # FP = 1e-10    # 负类判定为正类
205 |     # FN = 1e-10    # 正类判定为负类
206 | 
207 |     # p = a / (a + b)
208 |     # r = a / (a + c)
209 |     # f1 = 2pr / (p + r)
210 | 
211 |     for i in range(len(spo_list_true)):
212 |         pred_set = set(spo_list_pred[i])
213 |         true_set = set(spo_list_true[i])
214 | 
215 |         pred_true_set = pred_set & true_set     # 预测和真实取交集
216 | 
217 |         TP += len(pred_true_set)    # 检索到且是想要的, A
218 |         TP_FP += len(pred_set)      # 检索到的,包括想要的和不想要的,A + B
219 |         TP_FN += len(true_set)      # 真正想要的, 包括检索到和没检索到的,A + C
220 | 
221 |     p = TP / TP_FP
222 |     r = TP / TP_FN
223 |     f1 = (2 * p * r) / (p + r)
224 |     
225 |     return f1, p, r
226 | 
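# 示意:用一条样本验证精确率/召回率/F1 的计算(仅为演示)
def _demo_f1_p_r():
    pred = [[('七里香', '歌手', '周杰伦')]]
    true = [[('七里香', '歌手', '周杰伦'), ('七里香', '所属专辑', '七里香')]]
    f1, p, r = f1_p_r_compute(pred, true)
    # 命中 1 条、预测共 1 条、真实共 2 条 => p≈1.0, r≈0.5, f1≈0.667
    return f1, p, r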
227 | 
228 | def fixed_response(item: str) -> str:
229 |     '''
230 |     修复被截断的回答,从末尾往回找第一个结束标点
231 |     '''
232 |     if len(item) <= 1: return item
233 |     if item[-1] in END_PUN: return item
234 | 
235 |     n = len(item)
236 |     i = n - 1
237 |     while i > 0 and item[i] not in END_PUN:
238 |         i -= 1
239 | 
240 |     return ''.join(item[0: i + 1])
241 | 
242 | 
243 | def fixed_space(sentence: str)->str:
244 |     '''单个空格删除,连续两个空格保留一个
245 |     '''
246 |     n = len(sentence)
247 |     new_sentence = []
248 |     i = 0
249 |     while i < n:
250 |         word =  sentence[i]
251 |         if word != ' ':
252 |             new_sentence.append(word)
253 |         elif i + 1 < n and sentence[i + 1] == ' ':
254 |             new_sentence.append(word)
255 |             i += 1 # 两个空格保留一个,指针往下走一步
256 |         i += 1
257 | 
258 |     return ''.join(new_sentence)
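
# 示意:两个文本修复函数的预期效果(仅为演示)
def _demo_fix_text():
    assert fixed_response('今天天气不错。明天可能') == '今天天气不错。'  # 回退到最近的结束标点
    assert fixed_space('a  b c') == 'a bc'  # 连续两个空格保留一个,单个空格删除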
259 | 
260 | def get_free_space_of_disk(folder: str='./') -> float:
261 |     '''
262 |     获取指定目录所在磁盘的剩余空间,返回单位: GB
263 |     '''
264 |     res_val = 0.0
265 |     if platform.system() == 'Windows':
266 |         free_bytes = ctypes.c_ulonglong(0)
267 |         ctypes.windll.kernel32.GetDiskFreeSpaceExW(ctypes.c_wchar_p(folder), None, None, ctypes.pointer(free_bytes))
268 |         res_val = free_bytes.value 
269 |     else:
270 |         st = os.statvfs(folder)
271 |         res_val = st.f_bavail * st.f_frsize
272 |     
273 |     return res_val / (1024 ** 3)
274 | 
275 | def my_average(arry_list) -> float:
276 |     '''
277 |     自定义均值计算,空数组返回0.0
278 |     '''
279 |     if len(arry_list) == 0: return 0.0
280 | 
281 |     return np.average(arry_list)
282 | 
283 | 
284 | def json_to_dataclass(json_file: str, class_name: str='Config') -> type:
285 |     '''
286 |     将json配置文件转换为dataclass
287 |     >>> example:
288 |     >>> data_class = json_to_dataclass('my_config.json', 'Config')
289 |     >>> my_config = data_class()
290 |     >>> assert my_config.name == 'Alice'
291 |     >>> my_config.name = 'Bob' 
292 |     '''
293 |     json_dict = {}
294 |     with open(json_file, 'r', encoding='utf-8') as f:
295 |         json_dict = ujson.load(f)
296 | 
297 |     # 将dict转换为可迭代的属性名称、属性类型,默认值
298 |     fields_list = []
299 |     for k, v in json_dict.items():
300 |         fields_list.append( (k, type(v), field(default=v)) )
301 |     
302 |     data_class = make_dataclass(cls_name=class_name, fields=fields_list)
303 | 
304 |     return data_class
305 | 
306 | 
307 | def get_path_of_suffix_files(root: str, suffix: str, with_create_time: bool=False) -> list:
308 |     '''
309 |         获取指定目录下下指定后缀的所有文件的绝对路径
310 |     '''
311 |     suffix_files = []
312 |     for root, _, files in os.walk(root):
313 |         for file in files:
314 |             if file.endswith(suffix):
315 |                 full_path = '{}/{}'.format(root, file)
316 |                 if with_create_time:
317 |                     suffix_files.append( (full_path, os.path.getctime(full_path)) )
318 |                 else:
319 |                     suffix_files.append(full_path)
320 |                             
321 |     return suffix_files
322 | 
323 | def get_bleu4_score(reference, outputs, n_gram=4):
324 |     '''
325 |     获取bleu4分数
326 |     '''
327 |     
328 |     weights = np.ones(n_gram) * (1.0 / n_gram)
329 | 
330 |     outputs_len, reference_len = len(outputs), len(reference)
331 | 
332 |     if not type(reference) is list:
333 |         reference = list(reference)
334 |     if not type(outputs) is list:
335 |         outputs = list(outputs)
336 | 
337 |     outputs_counter = extract_Ngram(outputs, n_gram=n_gram)
338 |     reference_counter = extract_Ngram(reference, n_gram=n_gram)
339 | 
340 |     ngram_counter_clip = outputs_counter & reference_counter
341 | 
342 |     clip_counter = np.zeros(n_gram)
343 |     output_ngram_counter = np.zeros(n_gram)
344 | 
345 |     for (key, ngram), cnt in ngram_counter_clip.items():
346 |         clip_counter[ngram - 1] += cnt 
347 |     
348 |     for (key, ngram), cnt in outputs_counter.items():
349 |         output_ngram_counter[ngram - 1] += cnt
350 |     
351 |     # print(clip_counter, output_ngram_counter)
352 |     if np.min(clip_counter) == 0.0:
353 |         return np.array(0.0)
354 | 
355 |     precision_scores = clip_counter / output_ngram_counter
356 |    
357 |     # bleu
358 |     log_precision_scores = weights * np.log(precision_scores)
359 |     
360 |     # 几何平均形式求平均值然后加权
361 |     geometric_mean = np.exp(np.sum(log_precision_scores))
362 |     brevity_penalty = np.exp(1 - (reference_len / outputs_len))
363 | 
364 |     # brevity_penalty = 1.0,   bleu = sentence_bleu([reference], outputs)
365 |     # brevity_penalty = 1.0
366 | 
367 |     bleu = brevity_penalty * geometric_mean
368 | 
369 |     return bleu
370 | 
371 | 
372 | def extract_Ngram(words_list, n_gram):
373 |     '''
374 |     获取一个句子的 n-gram
375 |     return:
376 |         ngram_counter: key = ('w1  w2 ... wn', n_gram), value: count of key
377 |     '''
378 |     n = len(words_list)
379 |     ngram_counter = Counter()
380 | 
381 |     for i in range(1, n_gram + 1):
382 |         for j in range(n - i + 1):
383 |             key = ' '.join(words_list[j: j + i])
384 |             ngram_counter[(key, i)] += 1
385 | 
386 |     return ngram_counter
387 | 
388 | 
389 | def save_model_config(config_dict, file):
390 |     '''
391 |     将模型配置写入到json文件, 输入模型保存的目录及文件名
392 |     '''
393 |     # file = file.replace('\\', '/')
394 |     # file = '{}/model_config.json'.format('/'.join(file.split('/')[0: -1]))
395 |     
396 |     with open(file, 'w', encoding='utf-8') as f:
397 |         ujson.dump(config_dict, f, indent=4, ensure_ascii=False)
398 | 
399 | if __name__ == '__main__':
400 |     ref = '抱歉,我不知道ABB代表什么意思'
401 |     out = '我不明白ABB是什么意思'
402 |     b1 = sentence_bleu([list(out)], list(ref),  weights=(0.25, 0.25, 0.25, 0.25))
403 |     print(b1)
404 |     b2 = get_bleu4_score(out, ref)
405 |     print(b2)
406 | 
407 |     
408 |     candidate_corpus = ['i', 'have', 'a', 'pen', 'on', 'my', 'desk', 'a', 'b', 'c', 'd','f','f']
409 |     reference_corpus = ['there', 'is', 'a', 'pen', 'on', 'my', 'desk', 'a', 'b', 'd', 'd', 'fd']
410 |     
411 |     print('----')
412 |     print(sentence_bleu([reference_corpus], candidate_corpus,  weights=(0.25, 0.25, 0.25, 0.25)))
413 |     print(get_bleu4_score(reference_corpus, candidate_corpus))


--------------------------------------------------------------------------------
/data_clean/logger.py:
--------------------------------------------------------------------------------
  1 | import logging
  2 | from os.path import dirname, abspath
  3 | import os
  4 | import colorlog 
  5 | import time
  6 | 
  7 | # 自定义日志格式
  8 | class Logger(object):
  9 |     def __init__(self, logger_name: str, level=logging.DEBUG, std_out: bool=True, save2file: bool=False, file_name: str=None) ->None:
 10 |         super().__init__()
 11 | 
 12 |         if std_out == False and save2file == False:
 13 |             raise ValueError('args: [std_out, save2file], at least one of them must be True')
 14 | 
 15 |         # 默认的格式化
 16 |         datefmt = "%Y-%m-%d %H:%M:%S"
 17 |         
 18 |         # 输出到控制台
 19 |         if std_out:
 20 |             
 21 |             std_logfmt = "[%(asctime)s.%(msecs)03d] [%(levelname)s]: %(log_color)s%(message)s"
 22 | 
 23 |             self.stdout_logger = logging.getLogger('{}_std'.format(logger_name))
 24 |             self.stdout_logger.setLevel(level)
 25 | 
 26 |              # 彩色输出格式化
 27 |             log_colors_config = {
 28 |                 'DEBUG': 'cyan',
 29 |                 'INFO': 'green',
 30 |                 'WARNING': 'yellow',
 31 |                 'ERROR': 'red',
 32 |                 'CRITICAL': 'red'
 33 |             }
 34 |             formatter = colorlog.ColoredFormatter(
 35 |                         fmt=std_logfmt,
 36 |                         datefmt=datefmt,
 37 |                         log_colors=log_colors_config,
 38 |                         )
 39 |             
 40 |             sh = logging.StreamHandler()
 41 |             sh.setLevel(level)        
 42 |             sh.setFormatter(formatter)
 43 |             
 44 |             self.stdout_logger.addHandler(sh)
 45 |        
 46 |                     
 47 |          # 输出到文件
 48 |         if save2file:
 49 | 
 50 |             file_logfmt = "[%(asctime)s.%(msecs)03d] [%(levelname)s]: %(message)s"
 51 | 
 52 |             self.file_logger = logging.getLogger('{}_file'.format(logger_name))
 53 |             self.file_logger.setLevel(level)
 54 | 
 55 |             base_dir = './logs' # 日志保存目录
 56 |             if not os.path.exists(base_dir):
 57 |                 os.mkdir(base_dir)
 58 |             
 59 |             log_file = ''
 60 |             if file_name is not None:
 61 |                 log_file = file_name
 62 |             else:
 63 |                 log_file = base_dir + '/' + logger_name  + '-' + str(time.strftime('%Y%m%d', time.localtime())) +'.log'
 64 | 
 65 |             fh = logging.FileHandler(filename=log_file, mode='a', encoding='utf-8')
 66 |             fh.setLevel(level)
 67 |             save_formatter =  logging.Formatter(
 68 |                 fmt=file_logfmt,
 69 |                 datefmt=datefmt,
 70 |                 )
 71 |             fh.setFormatter(save_formatter)
 72 |             self.file_logger.addHandler(fh)
 73 | 
 74 |     def info(self, message: str, std_out: bool=True, save_to_file: bool=False) -> None:
 75 |         if std_out:
 76 |             self.stdout_logger.info(message)
 77 |         if save_to_file:
 78 |             self.file_logger.info(message)
 79 | 
 80 |     def debug(self, message: str, std_out: bool=True, save_to_file: bool=False) -> None:
 81 |         if std_out:
 82 |             self.stdout_logger.debug(message)
 83 |         if save_to_file:
 84 |             self.file_logger.debug(message)
 85 | 
 86 |     def warning(self, message: str, std_out: bool=True, save_to_file: bool=False) -> None:
 87 |         if std_out:
 88 |             self.stdout_logger.warning(message)
 89 |         if save_to_file:
 90 |             self.file_logger.warning(message)
 91 | 
 92 |     def error(self, message: str, std_out: bool=True, save_to_file: bool=False) -> None:
 93 |         if std_out:
 94 |             self.stdout_logger.error(message)
 95 |         if save_to_file:
 96 |             self.file_logger.error(message)
 97 | 
 98 | if __name__ == "__main__":
 99 |     log = Logger('test', std_out=True, save2file=True, file_name='../logs/test.log')
100 |     # log = Logger('test', save2file=True)
101 |     log.info('test info')
102 |     log.info('test file log', save_to_file=True)


--------------------------------------------------------------------------------
/data_process.py:
--------------------------------------------------------------------------------
  1 | import json
  2 | import glob
  3 | import numpy as np
  4 | from tqdm import tqdm
  5 | from chatglm_tokenizer.tokenization_chatglm import ChatGLMTokenizer
  6 | import pandas as pd
  7 | #from zhconv import convert
  8 | def process_wiki_clean():
  9 |     with open('./data/wikipedia_cn_20230720/wikipedia-cn-20230720-filtered.json','r',encoding='utf-8') as f:
 10 |         data=json.load(f)
 11 |     doc_ids=[]
 12 |     for line in tqdm(data):
 13 |         text=line['completion']
 14 |         text_id=tokenizer.encode(text,add_special_tokens=False)
 15 |         text_id.append(tokenizer.special_tokens['<eos>'])
 16 |         if len(text_id)>5:
 17 |             doc_ids+=text_id
 18 |     arr = np.array(doc_ids,dtype=np.uint16)
 19 |     with open('./data/wiki.bin','wb') as f:
 20 |         f.write(arr.tobytes())
 21 | 
 22 | def process_medical(data_path,name):
 23 |     f=open(data_path,'r',encoding='utf-8')
 24 |     doc_ids=[]
 25 |     while True:
 26 |         line=f.readline()
 27 |         if not line:
 28 |             break
 29 |         line=json.loads(line)
 30 |         text=line['text']
 31 |         text_id=tokenizer.encode(text,add_special_tokens=False)
 32 |         text_id.append(tokenizer.special_tokens['<eos>'])
 33 |         if len(text_id)>5:
 34 |             doc_ids+=text_id
 35 |     arr = np.array(doc_ids,dtype=np.uint16)
 36 |     with open('./data/medical_{}.bin'.format(name),'wb') as f:
 37 |         f.write(arr.tobytes()) 
 38 | 
 39 | def sft_to_pretrain():
 40 |     doc_ids=[]
 41 | 
 42 |     '''
 43 |     df=pd.read_csv('./data/medical_qa_144w.csv')
 44 |     for _,q,a in tqdm(df.itertuples()):
 45 |         q_id = tokenizer.encode(q,add_special_tokens=False)
 46 |         a_id = tokenizer.encode(a,add_special_tokens=False)
 47 |         #
 48 |         print(q)
 49 |         print(a)
 50 |         print('-----')
 51 |         text_id=q_id+a_id+[tokenizer.special_tokens['<eos>']]
 52 |         if len(text_id)>5:
 53 |             doc_ids+=text_id
 54 |     '''
 55 | 
 56 |     with open('./data/shibing624_medical/finetune/train_en_1.json','r',encoding='utf-8') as f:
 57 |         for row in f:
 58 |             line=json.loads(row)
 59 |             q=line['input']
 60 |             a=line['output']
 61 |             q_id=tokenizer.encode(q,add_special_tokens=False)
 62 |             a_id=tokenizer.encode(a,add_special_tokens=False)
 63 |             text_id=q_id+a_id+[tokenizer.special_tokens['<eos>']]
 64 |             if len(text_id)>5:
 65 |                 doc_ids+=text_id
 66 |     with open('./data/shibing624_medical/finetune/test_en_1.json','r',encoding='utf-8') as f:
 67 |         for row in f:
 68 |             line=json.loads(row)
 69 |             q=line['input']
 70 |             a=line['output']
 71 |             q_id=tokenizer.encode(q,add_special_tokens=False)
 72 |             a_id=tokenizer.encode(a,add_special_tokens=False)
 73 |             text_id=q_id+a_id+[tokenizer.special_tokens['<eos>']]
 74 |             if len(text_id)>5:
 75 |                 doc_ids+=text_id
 76 |     with open('./data/shibing624_medical/finetune/valid_en_1.json','r',encoding='utf-8') as f:
 77 |         for row in f:
 78 |             line=json.loads(row)
 79 |             q=line['input']
 80 |             a=line['output']
 81 |             q_id=tokenizer.encode(q,add_special_tokens=False)
 82 |             a_id=tokenizer.encode(a,add_special_tokens=False)
 83 |             text_id=q_id+a_id+[tokenizer.special_tokens['<eos>']]
 84 |             if len(text_id)>5:
 85 |                 doc_ids+=text_id
 86 | 
 87 |     with open('./data/shibing624_medical/finetune/train_zh_0.json','r',encoding='utf-8') as f:
 88 |         for row in f:
 89 |             line=json.loads(row)
 90 |             q=line['instruction']+line['input']
 91 |             a=line['output']
 92 |             q_id=tokenizer.encode(q,add_special_tokens=False)
 93 |             a_id=tokenizer.encode(a,add_special_tokens=False)
 94 |             text_id=q_id+a_id+[tokenizer.special_tokens['<eos>']]
 95 |             if len(text_id)>5:
 96 |                 doc_ids+=text_id
 97 |     with open('./data/shibing624_medical/finetune/test_zh_0.json','r',encoding='utf-8') as f:
 98 |         for row in f:
 99 |             line=json.loads(row)
100 |             q=line['instruction']+line['input']
101 |             a=line['output']
102 |             q_id=tokenizer.encode(q,add_special_tokens=False)
103 |             a_id=tokenizer.encode(a,add_special_tokens=False)
104 |             text_id=q_id+a_id+[tokenizer.special_tokens['<eos>']]
105 |             if len(text_id)>5:
106 |                 doc_ids+=text_id
107 |     with open('./data/shibing624_medical/finetune/valid_zh_0.json','r',encoding='utf-8') as f:
108 |         for row in f:
109 |             line=json.loads(row)
110 |             q=line['instruction']+line['input']
111 |             a=line['output']
112 |             q_id=tokenizer.encode(q,add_special_tokens=False)
113 |             a_id=tokenizer.encode(a,add_special_tokens=False)
114 |             text_id=q_id+a_id+[tokenizer.special_tokens['<eos>']]
115 |             if len(text_id)>5:
116 |                 doc_ids+=text_id
117 | 
118 |     arr = np.array(doc_ids,dtype=np.uint16)
119 |     print(arr.shape)
120 |     with open('./data/medical_qa.bin','wb') as f:
121 |         f.write(arr.tobytes())
122 | 
123 | def process_baidu():
124 |     BATCH_SIZE = 1000000
125 | 
126 |     cnt=0
127 |     batch_cnt=0
128 |     token=0
129 |     doc_ids=[]
130 | 
131 |     f1=open('./data/563w_baidubaike/563w_baidubaike.json','r',encoding='utf-8')
132 |     
133 |     while True:
134 |         line = f1.readline()
135 |         if not line:
136 |             break
137 |         line=json.loads(line)
138 |         text=''
139 |         try:
140 |             text+=line['title']+':'+line['summary']
141 |         except:
142 |             pass
143 |         for per in line['sections']:
144 |             text+=per['title']+':'+per['content']+'。'
145 |         text_id=tokenizer.encode(text,add_special_tokens=False)
146 |         text_id.append(tokenizer.special_tokens['<eos>'])
147 |         if len(text_id)>5:
148 |             doc_ids+=text_id
149 |         cnt+=1
150 |         if cnt%BATCH_SIZE==0:
151 |             batch_cnt+=1
152 |             arr = np.array(doc_ids,dtype=np.uint16)
153 |             doc_ids=[]
154 |             print('cnt:',cnt,'arr_shape:',arr.shape)
155 |             with open('./data/baidubaike_563w_{}.bin'.format(batch_cnt),'wb') as f2:
156 |                 f2.write(arr.tobytes())
157 |             del arr
158 | 
159 |     if doc_ids:  # 写出最后不足一个批次的剩余数据
160 |         batch_cnt+=1
161 |         arr = np.array(doc_ids,dtype=np.uint16)
162 |         print('cnt:',cnt,'arr_shape:',arr.shape)
163 |         with open('./data/baidubaike_563w_{}.bin'.format(batch_cnt),'wb') as f:
164 |             f.write(arr.tobytes())
165 |     
166 | def process_c4():
167 |     c4_zh_paths = glob.glob('./data/c4_zh/*')
168 |     c4_zh_paths=sorted(c4_zh_paths)
169 |     print(len(c4_zh_paths))
170 |     cnt=0
171 |     token=0
172 |     doc_ids=[]
173 |     for per in tqdm(c4_zh_paths):
174 |         with open(per,'r') as f:
175 |             for line in f:
176 |                 text = json.loads(line)
177 |                 text = text['text']
178 |                 text_id=tokenizer.encode(text,add_special_tokens=False)
179 |                 text_id.append(tokenizer.special_tokens['<eos>'])
180 |                 if len(text_id)>5:
181 |                     doc_ids+=text_id
182 |                 cnt+=1
183 | 
184 |     arr = np.array(doc_ids,dtype=np.uint16)
185 |     with open('./data/c4_zh.bin','wb') as f:
186 |         f.write(arr.tobytes())
187 |     print(arr.shape)
188 | 
189 | def process_wudao():
190 |     wudao_zh_paths = glob.glob('./data/WuDaoCorpus2.0_base_200G/*')
191 |     wudao_zh_paths=sorted(wudao_zh_paths)
192 |     print(len(wudao_zh_paths))#很多子文件
193 |     cnt=0
194 |     token=0
195 |     doc_ids=[]
196 |     for per in tqdm(wudao_zh_paths[320:]):#wudao_zh_paths[i:j]手动分片,一片片处理,不然太大一次性处理不完
197 |         with open(per,'r') as f:
198 |             data=json.load(f)
199 |             for text in data:
200 |                 text = text['title'] + text['content']
201 |                 text_id=tokenizer.encode(text,add_special_tokens=False)
202 |                 text_id.append(tokenizer.special_tokens['<eos>'])
203 |                 if len(text_id)>5:
204 |                     doc_ids+=text_id
205 |                 #
206 |                 # if cnt%10000==0:
207 |                 #     print(cnt)
208 |                 cnt+=1
209 |                 #token+=len(text_id)
210 |                 #break
211 |         #
212 |         # arr = np.array(doc_ids,dtype=np.uint16)
213 |         # with open('./data/c4-zh/{}.bin'.format(per.split('/')[-1].split('.')[0]),'wb') as f:
214 |         #     f.write(arr.tobytes())
215 |         # print(arr.shape)
216 |     arr = np.array(doc_ids,dtype=np.uint16)
217 |     with open('./data/wudaocorpus_zh_16.bin','wb') as f:
218 |         f.write(arr.tobytes())
219 |     print(arr.shape)
220 | 
221 | if __name__=="__main__":
222 |     tokenizer = ChatGLMTokenizer(vocab_file='./chatglm_tokenizer/tokenizer.model')
223 |     # 数据预处理-如果下载分词处理后的数据,可以不用执行以下函数
224 |     # process_wiki_clean()
225 |     # process_medical('./data/shibing624_medical/pretrain/medical_book_zh.json','book')
226 |     # process_medical('./data/shibing624_medical/pretrain/train_encyclopedia.json','encyclopedia')
227 |     # process_baidu()
228 |     # process_c4()
229 |     # process_wudao()
230 | 
231 |     # print('data processing finished!')
232 | 
233 |     # 分词处理后的文件列表
234 |     data_path_list=[
235 |         './data/baidubaike_563w_1.bin',
236 |         './data/baidubaike_563w_2.bin',
237 |         './data/baidubaike_563w_3.bin',
238 |         './data/baidubaike_563w_4.bin',
239 |         './data/baidubaike_563w_5.bin',
240 |         './data/medical_book.bin',
241 |         './data/medical_encyclopedia.bin',
242 |         './data/wiki.bin',
243 |         './data/c4_zh_0.bin',
244 |         './data/c4_zh_1.bin',
245 |         './data/c4_zh_2.bin',
246 |         './data/c4_zh_3.bin',
247 |         './data/c4_zh_4.bin',
248 |         './data/c4_zh_5.bin',
249 |         './data/c4_zh_6.bin',
250 |         './data/c4_zh_7.bin',
251 |         './data/c4_zh_8.bin',
252 |         './data/wudaocorpus_zh_0.bin',
253 |         './data/wudaocorpus_zh_1.bin',
254 |         './data/wudaocorpus_zh_2.bin',
255 |         './data/wudaocorpus_zh_3.bin',
256 |         './data/wudaocorpus_zh_4.bin',
257 |         './data/wudaocorpus_zh_5.bin',
258 |         './data/wudaocorpus_zh_6.bin',
259 |         './data/wudaocorpus_zh_7.bin',
260 |         './data/wudaocorpus_zh_8.bin',
261 |         './data/wudaocorpus_zh_9.bin',
262 |         './data/wudaocorpus_zh_10.bin',
263 |         './data/wudaocorpus_zh_11.bin',
264 |         './data/wudaocorpus_zh_12.bin',
265 |         './data/wudaocorpus_zh_13.bin',
266 |         './data/wudaocorpus_zh_14.bin',
267 |         './data/wudaocorpus_zh_15.bin',
268 |         './data/wudaocorpus_zh_16.bin',
269 |     ]
270 |     data_lst=[]
271 |     for data_path in tqdm(data_path_list):
272 |         with open(data_path,'rb') as f:
273 |             data=np.fromfile(f,dtype=np.uint16)
274 |             data_lst.append(data)
275 |     arr = np.concatenate(data_lst)
276 |     print(arr.shape)
277 |     with open('./data/pretrain_data.bin','wb') as f:
278 |         f.write(arr.tobytes())
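
# 示意:合并完成后检查 pretrain_data.bin 的 token 数,并解码开头一小段(仅为演示,路径按实际情况调整)
def _demo_check_pretrain_bin(bin_path='./data/pretrain_data.bin', preview_tokens=64):
    import os
    num_tokens = os.path.getsize(bin_path) // np.dtype(np.uint16).itemsize
    print('总 token 数:', num_tokens)
    data = np.memmap(bin_path, dtype=np.uint16, mode='r', shape=(num_tokens,))
    tk = ChatGLMTokenizer(vocab_file='./chatglm_tokenizer/tokenizer.model')
    print(tk.decode(data[:preview_tokens].tolist()))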
279 | 


--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import random
 3 | import pandas as pd
 4 | import numpy as np
 5 | from torch.utils.data import Dataset,DataLoader
 6 | import torch
 7 | from sklearn.model_selection import train_test_split
 8 | 
 9 | class PretrainDataset(Dataset):
10 |     def __init__(self,data_path_lst,max_length=256,memmap=False):
11 |         super().__init__()
12 |         #
13 |         if memmap:
14 |             with open(data_path_lst[0],'rb') as f:
15 |                 nbytes = f.seek(0,2)
16 |                 flen = f.tell() // np.dtype('uint16').itemsize
17 |             self.data = np.memmap(data_path_lst[0],dtype=np.dtype('uint16'),shape=(flen//max_length,max_length))
18 |         else:
19 |             data_lst=[]
20 |             for data_path in data_path_lst:
21 |                 with open(data_path,'rb') as f:
22 |                     data=np.fromfile(f,dtype=np.uint16)
23 |                     data_lst.append(data)
24 |             data = np.concatenate(data_lst)
25 |             data = data[:max_length*int(len(data)/max_length)]
26 |             #np.random.shuffle(data)
27 |             self.data = data.reshape(-1,max_length)
28 |         #
29 |         print("memmap:{} train data.shape:{}".format(memmap,self.data.shape))
30 |         print("data loading finished.....")
31 |         
32 |     def __len__(self):
33 |         return self.data.shape[0]
34 |     def __getitem__(self, index: int):
35 |         #
36 |         sample = self.data[index]
37 |         X=np.array(sample[:-1]).astype(np.int64)
38 |         Y=np.array(sample[1:]).astype(np.int64)
39 |         
40 |         return torch.from_numpy(X),torch.from_numpy(Y)
41 | #
42 | if __name__=="__main__":
43 |     pass
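
# 示意用法:加载分词后的 .bin 文件并查看一个 batch(仅为演示,路径按实际情况调整)
def _demo_pretrain_loader():
    data_path_lst = ['./data/pretrain_data.bin']
    train_ds = PretrainDataset(data_path_lst, max_length=256, memmap=True)
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=False, num_workers=0)
    for X, Y in train_loader:
        # X 为输入 token,Y 为右移一位的目标 token,形状均为 (batch, max_length-1)
        print(X.shape, Y.shape)
        break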


--------------------------------------------------------------------------------
/dataset_sft.py:
--------------------------------------------------------------------------------
 1 | 
 2 | import random
 3 | import pandas as pd
 4 | import numpy as np
 5 | from torch.utils.data import Dataset,DataLoader
 6 | import torch
 7 | from sklearn.model_selection import train_test_split
 8 | from chatglm_tokenizer.tokenization_chatglm import ChatGLMTokenizer
 9 | class SFTDataset(Dataset):
10 |     def __init__(self,df,tokenizer
11 |                  ,max_length=256
12 |                  ,prompt_max_len=128
13 |                  ,answer_max_len=128):
14 |         super().__init__()
15 |         self.df=df
16 |         self.max_length = max_length
17 |         self.prompt_max_len = prompt_max_len
18 |         self.answer_max_len = answer_max_len
19 |         #
20 |         self.tokenizer = tokenizer
21 |         self.bos=self.tokenizer.special_tokens['<bos>']
22 |         self.eos=self.tokenizer.special_tokens['<eos>']
23 |         self.pad=0#self.tokenizer.special_tokens['<pad>']
24 |         
25 |     def __len__(self):
26 |         return self.df.shape[0]
27 |     def __getitem__(self, index: int):
28 |         #
29 |         sample = self.df.iloc[index]
30 |         prompt = self.tokenizer.encode(sample['prompt'],add_special_tokens=False)
31 |         answer = self.tokenizer.encode(sample['answer'],add_special_tokens=False)
32 |         if len(prompt) > self.prompt_max_len:
33 |             prompt = prompt[:self.prompt_max_len-2]
34 |         if len(answer) > self.answer_max_len:
35 |             answer = answer[:self.answer_max_len-2]
36 |         #
37 |         input_id=prompt+[self.bos]+answer+[self.eos]
38 |         context_length = input_id.index(self.bos)
39 |         mask_position = context_length - 1
40 |         pad_len = self.max_length - len(input_id)
41 |         input_id = input_id + [self.pad] * pad_len
42 |         if pad_len==0:
43 |             loss_mask = [0]*context_length+[1]*(len(input_id[mask_position+1:])) + [0]*pad_len
44 |         else:
45 |             loss_mask = [0]*context_length+[1]*(len(input_id[mask_position+1:-pad_len])) + [0]*pad_len
46 |         #
47 |         input_id=np.array(input_id)
48 |         X=np.array(input_id[:-1]).astype(np.int64)
49 |         Y=np.array(input_id[1:]).astype(np.int64)
50 |         loss_mask=np.array(loss_mask[:-1])
51 |         #
52 |         return torch.from_numpy(X),torch.from_numpy(Y),torch.from_numpy(loss_mask)
53 | #
54 | if __name__=="__main__":
55 |     df=pd.read_csv('./data/sft_data.csv')
56 |     tokenizer=ChatGLMTokenizer(vocab_file='./chatglm_tokenizer/tokenizer.model')
57 |     train_ds = SFTDataset(df,tokenizer,max_length=256)
58 |     train_loader = torch.utils.data.DataLoader(
59 |         train_ds,
60 |         batch_size=1,
61 |         pin_memory=False,
62 |         drop_last=False,
63 |         shuffle=False,        
64 |         num_workers=0,
65 |     )
66 |     for i, (X, Y,loss_mask) in enumerate(train_loader):
67 |         print(X.shape,Y.shape)
68 |         print(X[0])
69 |         print(Y[0])
70 |         print(loss_mask[0])
71 |         break


--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Sample from the trained model with PyTorch
  3 | """
  4 | import os
  5 | import json
  6 | from contextlib import nullcontext
  7 | import torch
  8 | from model import ModelArgs, Transformer
  9 | from chatglm_tokenizer.tokenization_chatglm import ChatGLMTokenizer
 10 | import numpy as np
 11 | 
 12 | # def compute_bleu(labels, preds, weights=None):
 13 | #     from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 14 | #     weights = weights or (0.25, 0.25, 0.25, 0.25)
 15 | #     return np.mean([sentence_bleu(references=[label],
 16 | #                                   hypothesis=pred,
 17 | #                                   smoothing_function=SmoothingFunction().method1,
 18 | #                                   weights=weights) for label, pred in zip(labels, preds)])
 19 | # -----------------------------------------------------------------------------
 20 | out_dir = 'out' # ignored if init_from is not 'resume'
 21 | start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
 22 | num_samples = 1 # number of samples to draw
 23 | max_new_tokens = 100 # number of tokens generated in each sample
 24 | temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
 25 | top_k = 100 # retain only the top_k most likely tokens, clamp others to have 0 probability
 26 | seed = 1337
 27 | device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
 28 | #dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
 29 | dtype = "float32"
 30 | compile = False # use PyTorch 2.0 to compile the model to be faster
 31 | #exec(open('configurator.py').read()) # overrides from command line or config file
 32 | # -----------------------------------------------------------------------------
 33 | max_seq_len = 512
 34 | dim = 512
 35 | n_layers = 8
 36 | n_heads = 8
 37 | 
 38 | # max_seq_len = 1024
 39 | # dim = 1024
 40 | # n_layers = 12
 41 | # n_heads = 8
 42 | multiple_of = 32
 43 | dropout = 0.0 
 44 | model_args = dict(
 45 |         dim=dim,
 46 |         n_layers=n_layers,
 47 |         n_heads=n_heads,
 48 |         n_kv_heads=n_heads,
 49 |         vocab_size=64793,#64793,
 50 |         multiple_of=multiple_of,
 51 |         max_seq_len=max_seq_len,
 52 |         dropout=dropout,
 53 |     )  # s
 54 | torch.manual_seed(seed)
 55 | torch.cuda.manual_seed(seed)
 56 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 57 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
 58 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
 59 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
 60 | ctx = nullcontext() if device_type == 'cpu' else torch.cuda.amp.autocast()
 61 | 
 62 | # init from a model saved in a specific directory
 63 | ckpt_path = 'out/sft_Llama2-Chinese-92M-v2/epoch_4.pth'
 64 | state_dict = torch.load(ckpt_path, map_location=device)
 65 | gptconf = ModelArgs(**model_args)
 66 | model = Transformer(gptconf)
 67 | unwanted_prefix = '_orig_mod.'
 68 | for k,v in list(state_dict.items()):
 69 |     if k.startswith(unwanted_prefix):
 70 |         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
 71 | model.load_state_dict(state_dict, strict=False)
 72 | 
 73 | model.eval()
 74 | model.to(device)
 75 | if compile:
 76 |     print("Compiling the model...")
 77 |     model = torch.compile(model) # requires PyTorch 2.0 (optional)
 78 | 
 79 | # load the tokenizer
 80 | tokenizer=ChatGLMTokenizer(vocab_file='./chatglm_tokenizer/tokenizer.model')
 81 | #
 82 | # data = []
 83 | # with open('./test_data/test.json','r') as f:
 84 | #     for line in f:
 85 | #         data.append(json.loads(line))
 86 | 
 87 | #如果有标准答案,可以填到target里面,打开最后几行的注释,计算bleu分数。
 88 | #如果随便测试测试,那就只填你希望问的问题到question里面就可以。
 89 | data = [
 90 |     {"question": "最近我在办公室坐久了会感到头晕,请问这是什么原因?有什么缓解办法吗?", "target": ""},
 91 |     {"question": "前列腺囊肿的症状是什么?", "target": ""},
 92 |     {"question": "请问,世界上最大的动物是什么?", "target": ""},
 93 | ]
 94 | 
 95 | ans_lst=[]
 96 | target_lst=[]
 97 | for p in data:
 98 |     # run generation
 99 |     prompt=p['question']
100 |     x=tokenizer.encode(prompt,add_special_tokens=False)+[tokenizer.special_tokens['<bos>']]
101 |     x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
102 |     target = p['target']
103 |     target_lst.append(target)
104 |     with torch.no_grad():
105 |         with ctx:
106 |             y = model.generate(x, 2, max_new_tokens, temperature=temperature, top_k=top_k)
107 |             #
108 |             answer=tokenizer.decode(y[0].tolist())
109 |             answer=answer.replace(prompt,'')
110 |             ans_lst.append(answer)
111 |             print('[prompt]:',prompt)
112 |             print('[answer]:',answer)
113 |             print('---------------')
114 | #
115 | # import jieba
116 | # target_lst=[jieba.lcut(result.lower()) for result in target_lst]
117 | # preds_lst=[jieba.lcut(result.lower()) for result in ans_lst]
118 | # scores = compute_bleu(preds_lst, target_lst)
119 | # print(scores)
120 | 


--------------------------------------------------------------------------------
/eval_pretrain.py:
--------------------------------------------------------------------------------
  1 | """
  2 | Sample from the trained model with PyTorch
  3 | """
  4 | import os
  5 | import json
  6 | from contextlib import nullcontext
  7 | import torch
  8 | from model import ModelArgs, Transformer
  9 | from chatglm_tokenizer.tokenization_chatglm import ChatGLMTokenizer
 10 | import numpy as np
 11 | 
 12 | # def compute_bleu(labels, preds, weights=None):
 13 | #     from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 14 | #     weights = weights or (0.25, 0.25, 0.25, 0.25)
 15 | #     return np.mean([sentence_bleu(references=[label],
 16 | #                                   hypothesis=pred,
 17 | #                                   smoothing_function=SmoothingFunction().method1,
 18 | #                                   weights=weights) for label, pred in zip(labels, preds)])
 19 | # -----------------------------------------------------------------------------
 20 | out_dir = 'out' # ignored if init_from is not 'resume'
 21 | start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
 22 | num_samples = 1 # number of samples to draw
 23 | max_new_tokens = 100 # number of tokens generated in each sample
 24 | temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
 25 | top_k = 30 # retain only the top_k most likely tokens, clamp others to have 0 probability
 26 | seed = 1337
 27 | device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
 28 | #dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
 29 | dtype = "float32"
 30 | compile = False # use PyTorch 2.0 to compile the model to be faster
 31 | #exec(open('configurator.py').read()) # overrides from command line or config file
 32 | # -----------------------------------------------------------------------------
 33 | # max_seq_len = 512
 34 | # dim = 512
 35 | # n_layers = 8
 36 | # n_heads = 8
 37 | 
 38 | max_seq_len = 1024
 39 | dim = 1024
 40 | n_layers = 12
 41 | n_heads = 8
 42 | multiple_of = 32
 43 | dropout = 0.0 
 44 | model_args = dict(
 45 |         dim=dim,
 46 |         n_layers=n_layers,
 47 |         n_heads=n_heads,
 48 |         n_kv_heads=n_heads,
 49 |         vocab_size=64793,#64793,
 50 |         multiple_of=multiple_of,
 51 |         max_seq_len=max_seq_len,
 52 |         dropout=dropout,
 53 |     )  # s
 54 | torch.manual_seed(seed)
 55 | torch.cuda.manual_seed(seed)
 56 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
 57 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
 58 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
 59 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
 60 | ctx = nullcontext() if device_type == 'cpu' else torch.cuda.amp.autocast()
 61 | 
 62 | # init from a model saved in a specific directory
 63 | ckpt_path = 'out/Llama2-Chinese-218M-v1/epoch_0.pth'
 64 | state_dict = torch.load(ckpt_path, map_location=device)
 65 | gptconf = ModelArgs(**model_args)
 66 | model = Transformer(gptconf)
 67 | unwanted_prefix = '_orig_mod.'
 68 | for k,v in list(state_dict.items()):
 69 |     if k.startswith(unwanted_prefix):
 70 |         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
 71 | model.load_state_dict(state_dict, strict=False)
 72 | 
 73 | model.eval()
 74 | model.to(device)
 75 | if compile:
 76 |     print("Compiling the model...")
 77 |     model = torch.compile(model) # requires PyTorch 2.0 (optional)
 78 | 
 79 | # load the tokenizer
 80 | tokenizer=ChatGLMTokenizer(vocab_file='./chatglm_tokenizer/tokenizer.model')
 81 | #
 82 | # data = []
 83 | # with open('./test_data/test.json','r') as f:
 84 | #     for line in f:
 85 | #         data.append(json.loads(line))
 86 | 
 87 | data = [
 88 |     {"question": "床前明月光,疑是地上霜。举头望明月,"},
 89 |     {"question": "请你讲一个童话故事:"},
 90 |     {"question": "《小王子》是一本畅销童话书,它讲述了:"},
 91 | ]
 92 | 
 93 | ans_lst=[]
 94 | target_lst=[]
 95 | for p in data[:100]:
 96 |     # run generation
 97 |     prompt=p['question']
 98 |     x=tokenizer.encode(prompt,add_special_tokens=False)
 99 |     x = (torch.tensor(x, dtype=torch.long, device=device)[None, ...])
100 |     with torch.no_grad():
101 |         with ctx:
102 |             y = model.generate(x, 2, max_new_tokens, temperature=temperature, top_k=top_k)
103 |             #
104 |             answer=tokenizer.decode(y[0].tolist())
105 |             answer=answer.replace(prompt,'')
106 |             ans_lst.append(answer)
107 |             print('[prompt]:',prompt)
108 |             print('[answer]:',answer)
109 |             print('---------------')
110 | #
111 | # import jieba
112 | # target_lst=[jieba.lcut(result.lower()) for result in target_lst]
113 | # preds_lst=[jieba.lcut(result.lower()) for result in ans_lst]
114 | # scores = compute_bleu(preds_lst, target_lst)
115 | # print(scores)
116 | 


--------------------------------------------------------------------------------
/images/loss_tokens-v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLLXW/baby-llama2-chinese/98a20dbb35e686a62188f61f479809cb2d4f8d6e/images/loss_tokens-v1.png


--------------------------------------------------------------------------------
/images/loss_tokens-v3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLLXW/baby-llama2-chinese/98a20dbb35e686a62188f61f479809cb2d4f8d6e/images/loss_tokens-v3.png


--------------------------------------------------------------------------------
/images/loss_tokens.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DLLXW/baby-llama2-chinese/98a20dbb35e686a62188f61f479809cb2d4f8d6e/images/loss_tokens.png


--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | import struct
  3 | import inspect
  4 | from dataclasses import dataclass
  5 | from typing import Any, Optional, Tuple
  6 | 
  7 | import numpy as np
  8 | import torch
  9 | import torch.nn.functional as F
 10 | from torch import nn
 11 | 
 12 | @dataclass
 13 | class ModelArgs:
 14 |     dim: int = 4096
 15 |     n_layers: int = 32
 16 |     n_heads: int = 32
 17 |     n_kv_heads: Optional[int] = None
 18 |     vocab_size: int = -1  # defined later by tokenizer
 19 |     multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
 20 |     norm_eps: float = 1e-5
 21 |     max_seq_len: int = 2048
 22 |     dropout: float = 0.0
 23 | 
 24 | 
 25 | class RMSNorm(torch.nn.Module):
 26 |     def __init__(self, dim: int, eps: float):
 27 |         super().__init__()
 28 |         self.eps = eps
 29 |         self.weight = nn.Parameter(torch.ones(dim))
 30 | 
 31 |     def _norm(self, x):
 32 |         return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
 33 | 
 34 |     def forward(self, x):
 35 |         output = self._norm(x.float()).type_as(x)
 36 |         return output * self.weight
 37 | 
 38 | 
 39 | def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
 40 |     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
 41 |     t = torch.arange(end, device=freqs.device)  # type: ignore
 42 |     freqs = torch.outer(t, freqs).float()  # type: ignore
 43 |     freqs_cos = torch.cos(freqs)  # real part
 44 |     freqs_sin = torch.sin(freqs)  # imaginary part
 45 |     return freqs_cos, freqs_sin
 46 | 
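# 形状示意:head_dim=64、max_seq_len=512 时,cos/sin 表的形状均为 (512, 32)(仅为演示)
def _demo_freqs_cis_shape():
    freqs_cos, freqs_sin = precompute_freqs_cis(64, 512)
    assert freqs_cos.shape == (512, 32) and freqs_sin.shape == (512, 32)
    return freqs_cos, freqs_sin
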
 47 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
 48 |     ndim = x.ndim
 49 |     assert 0 <= 1 < ndim
 50 |     assert freqs_cis.shape == (x.shape[1], x.shape[-1])
 51 |     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
 52 |     return freqs_cis.view(shape)
 53 | 
 54 | def apply_rotary_emb(
 55 |     xq: torch.Tensor,
 56 |     xk: torch.Tensor,
 57 |     freqs_cos: torch.Tensor,
 58 |     freqs_sin: torch.Tensor
 59 | ) -> Tuple[torch.Tensor, torch.Tensor]:
 60 | 
 61 |     # reshape xq and xk to match the complex representation
 62 |     xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
 63 |     xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
 64 | 
 65 |     # reshape freqs_cos and freqs_sin for broadcasting
 66 |     freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
 67 |     freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
 68 | 
 69 |     # apply rotation using real numbers
 70 |     xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
 71 |     xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
 72 |     xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
 73 |     xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
 74 | 
 75 |     # flatten last two dimensions
 76 |     xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
 77 |     xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
 78 | 
 79 |     return xq_out.type_as(xq), xk_out.type_as(xk)
 80 | 
 81 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
 82 |     """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
 83 |     bs, slen, n_kv_heads, head_dim = x.shape
 84 |     if n_rep == 1:
 85 |         return x
 86 |     return (
 87 |         x[:, :, :, None, :]
 88 |         .expand(bs, slen, n_kv_heads, n_rep, head_dim)
 89 |         .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
 90 |     )
 91 | 
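# 形状示意:将 4 个 KV 头复制为 8 个,以配合 grouped-query attention(仅为演示)
def _demo_repeat_kv_shape():
    x = torch.randn(2, 16, 4, 64)  # (batch, seqlen, n_kv_heads, head_dim)
    y = repeat_kv(x, n_rep=2)
    assert y.shape == (2, 16, 8, 64)
    return y
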
 92 | class Attention(nn.Module):
 93 |     def __init__(self, args: ModelArgs):
 94 |         super().__init__()
 95 |         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
 96 |         model_parallel_size = 1
 97 |         self.n_local_heads = args.n_heads // model_parallel_size
 98 |         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
 99 |         self.n_rep = self.n_local_heads // self.n_local_kv_heads
100 |         self.head_dim = args.dim // args.n_heads
101 |         self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
102 |         self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
103 |         self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
104 |         self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
105 |         self.attn_dropout = nn.Dropout(args.dropout)
106 |         self.resid_dropout = nn.Dropout(args.dropout)
107 |         self.dropout = args.dropout
108 | 
109 |         # use flash attention or a manual implementation?
110 |         self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
111 |         if not self.flash:
112 |             print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
113 |             mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
114 |             mask = torch.triu(mask, diagonal=1)
115 |             self.register_buffer("mask", mask)
116 | 
117 |     def forward(
118 |         self,
119 |         x: torch.Tensor,
120 |         freqs_cos: torch.Tensor,
121 |         freqs_sin: torch.Tensor,
122 |     ):
123 |         bsz, seqlen, _ = x.shape
124 | 
125 |         # QKV
126 |         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
127 |         xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
128 |         xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
129 |         xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
130 | 
131 |         # RoPE relative positional embeddings
132 |         xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
133 | 
134 |         # grouped multiquery attention: expand out keys and values
135 |         xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
136 |         xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
137 | 
138 |         # make heads into a batch dimension
139 |         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
140 |         xk = xk.transpose(1, 2)
141 |         xv = xv.transpose(1, 2)
142 | 
143 |         # flash implementation
144 |         if self.flash:
145 |             output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
146 |         else:
147 |             # manual implementation
148 |             scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
149 |             assert hasattr(self, 'mask')
150 |             scores = scores + self.mask[:, :, :seqlen, :seqlen]   # (bs, n_local_heads, seqlen, cache_len + seqlen)
151 |             scores = F.softmax(scores.float(), dim=-1).type_as(xq)
152 |             scores = self.attn_dropout(scores)
153 |             output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)
154 | 
155 |         # restore time as batch dimension and concat heads
156 |         output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
157 | 
158 |         # final projection into the residual stream
159 |         output = self.wo(output)
160 |         output = self.resid_dropout(output)
161 |         return output
162 | 
163 | 
164 | class FeedForward(nn.Module):
165 |     def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
166 |         super().__init__()
167 |         hidden_dim = int(2 * hidden_dim / 3)
168 |         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
169 |         self.w1 = nn.Linear(dim, hidden_dim, bias=False)
170 |         self.w2 = nn.Linear(hidden_dim, dim, bias=False)
171 |         self.w3 = nn.Linear(dim, hidden_dim, bias=False)
172 |         self.dropout = nn.Dropout(dropout)
173 | 
174 |     def forward(self, x):
175 |         return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
176 | 
177 | 
178 | class TransformerBlock(nn.Module):
179 |     def __init__(self, layer_id: int, args: ModelArgs):
180 |         super().__init__()
181 |         self.n_heads = args.n_heads
182 |         self.dim = args.dim
183 |         self.head_dim = args.dim // args.n_heads
184 |         self.attention = Attention(args)
185 |         self.feed_forward = FeedForward(
186 |             dim=args.dim,
187 |             hidden_dim=4 * args.dim,
188 |             multiple_of=args.multiple_of,
189 |             dropout=args.dropout,
190 |         )
191 |         self.layer_id = layer_id
192 |         self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
193 |         self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
194 | 
195 |     def forward(self, x, freqs_cos, freqs_sin):
196 |         h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
197 |         out = h + self.feed_forward.forward(self.ffn_norm(h))
198 |         return out
199 | 
200 | 
201 | class Transformer(nn.Module):
202 |     last_loss: Optional[torch.Tensor]
203 | 
204 |     def __init__(self, params: ModelArgs):
205 |         super().__init__()
206 |         self.params = params
207 |         self.vocab_size = params.vocab_size
208 |         self.n_layers = params.n_layers
209 | 
210 |         self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
211 |         self.dropout = nn.Dropout(params.dropout)
212 |         self.layers = torch.nn.ModuleList()
213 |         for layer_id in range(params.n_layers):
214 |             self.layers.append(TransformerBlock(layer_id, params))
215 |         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
216 |         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
217 | 
218 |         # share the unembedding parameters with the embedding parameters
219 |         self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying
220 | 
221 |         # some useful precompute for the RoPE relative positional embeddings
222 |         freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
223 |         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
224 |         self.register_buffer("freqs_sin", freqs_sin, persistent=False)
225 | 
226 |         # init all weights
227 |         self.apply(self._init_weights)
228 |         # apply special scaled init to the residual projections, per GPT-2 paper
229 |         for pn, p in self.named_parameters():
230 |             if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
231 |                 torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))
232 | 
233 |         # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
234 |         self.last_loss = None
235 | 
236 |     def _init_weights(self, module):
237 |         if isinstance(module, nn.Linear):
238 |             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
239 |             if module.bias is not None:
240 |                 torch.nn.init.zeros_(module.bias)
241 |         elif isinstance(module, nn.Embedding):
242 |             torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
243 | 
244 |     def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
245 |         _bsz, seqlen = tokens.shape
246 |         h = self.tok_embeddings(tokens)
247 |         h = self.dropout(h)
248 |         freqs_cos = self.freqs_cos[:seqlen]
249 |         freqs_sin = self.freqs_sin[:seqlen]
250 | 
251 |         for layer in self.layers:
252 |             h = layer(h, freqs_cos, freqs_sin)
253 |         h = self.norm(h)
254 | 
255 |         if targets is not None:
256 |             # if we are given some desired targets also calculate the loss
257 |             logits = self.output(h)
258 |             self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
259 |         else:
260 |             # inference-time mini-optimization: only forward the output on the very last position
261 |             logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim
262 |             self.last_loss = None
263 | 
264 |         return logits
265 | 
266 |     def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
267 |         # start with all of the candidate parameters
268 |         param_dict = {pn: p for pn, p in self.named_parameters()}
269 |         # filter out those that do not require grad
270 |         param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
271 |         # create optim groups. Any parameter that is 2D will be weight decayed; all others will not.
272 |         # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
273 |         decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
274 |         nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
275 |         optim_groups = [
276 |             {'params': decay_params, 'weight_decay': weight_decay},
277 |             {'params': nodecay_params, 'weight_decay': 0.0}
278 |         ]
279 |         num_decay_params = sum(p.numel() for p in decay_params)
280 |         num_nodecay_params = sum(p.numel() for p in nodecay_params)
281 |         print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
282 |         print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
283 |         # Create AdamW optimizer and use the fused version if it is available
284 |         fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
285 |         use_fused = fused_available and device_type == 'cuda'
286 |         extra_args = dict(fused=True) if use_fused else dict()
287 |         optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
288 |         print(f"using fused AdamW: {use_fused}")
289 | 
290 |         return optimizer
291 | 
292 |     def estimate_mfu(self, fwdbwd_per_iter, dt):
293 |         """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
294 |         # first estimate the number of flops we do per iteration.
295 |         # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
296 |         N = sum(p.numel() for p in self.parameters())
297 |         cfg = self.params
298 |         L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
299 |         flops_per_token = 6*N + 12*L*H*Q*T
300 |         flops_per_fwdbwd = flops_per_token * T
301 |         flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
302 |         # express our flops throughput as ratio of A100 bfloat16 peak flops
303 |         flops_achieved = flops_per_iter * (1.0/dt) # per second
304 |         flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
305 |         mfu = flops_achieved / flops_promised
306 |         return mfu
307 | 
308 |     #@torch.inference_mode()
309 |     @torch.no_grad()
310 |     def generate(self, idx, eos, max_new_tokens, temperature=1.0, top_k=None):
311 |         """
312 |         Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
313 |         the sequence max_new_tokens times, feeding the predictions back into the model each time.
314 |         Most likely you'll want to make sure to be in model.eval() mode of operation for this.
315 |         Also note this is a super inefficient version of sampling with no key/value cache.
316 |         """
317 |         for _ in range(max_new_tokens):
318 |             # if the sequence context is growing too long we must crop it at block_size
319 |             idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
320 |             # forward the model to get the logits for the index in the sequence
321 |             logits = self(idx_cond)
322 |             logits = logits[:, -1, :] # crop to just the final time step
323 |             if temperature == 0.0:
324 |                 # "sample" the single most likely index
325 |                 _, idx_next = torch.topk(logits, k=1, dim=-1)
326 |             else:
327 |                 # pluck the logits at the final step and scale by desired temperature
328 |                 logits = logits / temperature
329 |                 # optionally crop the logits to only the top k options
330 |                 if top_k is not None:
331 |                     v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
332 |                     logits[logits < v[:, [-1]]] = -float('Inf')
333 |                 # apply softmax to convert logits to (normalized) probabilities
334 |                 probs = F.softmax(logits, dim=-1)
335 |                 idx_next = torch.multinomial(probs, num_samples=1)
336 |             # append sampled index to the running sequence and continue
337 |             idx = torch.cat((idx, idx_next), dim=1)
338 |             if idx_next==eos:
339 |                 break
340 | 
341 |         return idx
342 | 
343 |     def export(self, filepath='model.bin'):
344 |         """export the model weights in fp32 into .bin file to be read from C"""
345 |         f = open(filepath, 'wb')
346 | 
347 |         def serialize(t):
348 |             d = t.detach().cpu().view(-1).numpy().astype(np.float32)
349 |             b = struct.pack(f'{len(d)}f', *d)
350 |             f.write(b)
351 | 
352 |         # first write out the header
353 |         hidden_dim = self.layers[0].feed_forward.w1.weight.shape[0]
354 |         p = self.params
355 |         n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads
356 |         header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads,
357 |                                        n_kv_heads, p.vocab_size, p.max_seq_len)
358 |         f.write(header)
359 | 
360 |         # next write out the embedding weights
361 |         serialize(self.tok_embeddings.weight)
362 | 
363 |         # now all the layers
364 |         # attention weights
365 |         for layer in self.layers:
366 |             serialize(layer.attention_norm.weight)
367 |         for layer in self.layers:
368 |             serialize(layer.attention.wq.weight)
369 |         for layer in self.layers:
370 |             serialize(layer.attention.wk.weight)
371 |         for layer in self.layers:
372 |             serialize(layer.attention.wv.weight)
373 |         for layer in self.layers:
374 |             serialize(layer.attention.wo.weight)
375 |         # ffn weights
376 |         for layer in self.layers:
377 |             serialize(layer.ffn_norm.weight)
378 |         for layer in self.layers:
379 |             serialize(layer.feed_forward.w1.weight)
380 |         for layer in self.layers:
381 |             serialize(layer.feed_forward.w2.weight)
382 |         for layer in self.layers:
383 |             serialize(layer.feed_forward.w3.weight)
384 |         # final rmsnorm
385 |         serialize(self.norm.weight)
386 |         # note: no need to write final classifier weights due to weight sharing
387 |         # freqs_cis
388 |         serialize(self.freqs_cos[:p.max_seq_len])
389 |         serialize(self.freqs_sin[:p.max_seq_len])
390 | 
391 |         # write to binary file
392 |         f.close()
393 |         print(f"wrote {filepath}")
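
The classes above can be exercised in isolation. The following is a minimal sketch (not part of the repository); the hyperparameters simply mirror the defaults used by the training scripts below, and the `eos=2` id is only a placeholder, not the tokenizer's real end-of-sequence token:

```python
import torch
from model import Transformer, ModelArgs

# assumed config, mirroring the defaults in pretrain.py
args = ModelArgs(dim=512, n_layers=8, n_heads=8, n_kv_heads=8,
                 vocab_size=64793, multiple_of=32, max_seq_len=512, dropout=0.0)
model = Transformer(args).eval()

# a forward pass with targets populates model.last_loss
tokens = torch.randint(0, args.vocab_size, (2, 16))
targets = torch.randint(0, args.vocab_size, (2, 16))
logits = model(tokens, targets)
print(logits.shape, model.last_loss.item())

# sampling (generate() already runs under torch.no_grad)
out = model.generate(tokens[:1, :8], eos=2, max_new_tokens=16, temperature=0.8, top_k=30)
print(out.shape)
```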


--------------------------------------------------------------------------------
/pretrain.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # pins the run to GPU 0; adjust or remove when launching multi-GPU DDP via torchrun
  3 | import time
  4 | import math
  5 | import pickle
  6 | from contextlib import nullcontext
  7 | import numpy as np
  8 | import torch
  9 | from model import Transformer, ModelArgs
 10 | from torch.distributed import destroy_process_group, init_process_group
 11 | from torch.nn.parallel import DistributedDataParallel as DDP
 12 | 
 13 | from dataset import PretrainDataset
 14 | import logging
 15 | 
 16 | #To run with DDP on 4 gpus on 1 node, example:
 17 | # torchrun --standalone --nproc_per_node=4 pretrain.py OR python -m torch.distributed.launch --nproc_per_node=4 pretrain.py
 18 |         
 19 | def get_logger(filename, verbosity=1, name=None):
 20 |     level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
 21 |     formatter = logging.Formatter(
 22 |         "[%(asctime)s][%(filename)s][%(levelname)s] %(message)s"
 23 |     )
 24 |     logger = logging.getLogger(name)
 25 |     logger.setLevel(level_dict[verbosity])
 26 | 
 27 |     fh = logging.FileHandler(filename, "w")
 28 |     fh.setFormatter(formatter)
 29 |     logger.addHandler(fh)
 30 | 
 31 |     sh = logging.StreamHandler()
 32 |     sh.setFormatter(formatter)
 33 |     logger.addHandler(sh)
 34 |     return logger
 35 | # -----------------------------------------------------------------------------
 36 | def get_lr(it):
 37 |     # 1) linear warmup for warmup_iters steps
 38 |     if it < warmup_iters:
 39 |         return learning_rate * it / warmup_iters
 40 |     # 2) if it > lr_decay_iters, return min learning rate
 41 |     if it > lr_decay_iters:
 42 |         return min_lr
 43 |     # 3) in between, use cosine decay down to min learning rate
 44 |     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
 45 |     assert 0 <= decay_ratio <= 1
 46 |     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
 47 |     return min_lr + coeff * (learning_rate - min_lr)
 48 | 
 49 | def train_epoch(epoch):
 50 |     start_time=time.time()
 51 |     for step, (X, Y) in enumerate(train_loader):
 52 |         X=X.to(device)
 53 |         Y=Y.to(device)
 54 |         lr = get_lr(epoch*iter_per_epoch+step) if decay_lr else learning_rate
 55 |         for param_group in optimizer.param_groups:
 56 |             param_group['lr'] = lr
 57 |         # and using the GradScaler if data type is float16
 58 |         #for micro_step in range(gradient_accumulation_steps):
 59 |         if ddp:
 60 |             # in DDP training we only need to sync gradients at the last micro step.
 61 |             # the official way to do this is with model.no_sync() context manager, but
 62 |             # I really dislike that this bloats the code and forces us to repeat code
 63 |             # looking at the source of that context manager, it just toggles this variable
 64 |             model.require_backward_grad_sync = (step + 1) % gradient_accumulation_steps == 0
 65 |         with ctx:
 66 |             logits = model(X, Y)
 67 |             loss = raw_model.last_loss
 68 |             loss = loss / gradient_accumulation_steps
 69 |         # immediately async prefetch next batch while model is doing the forward pass on the GPU
 70 |         # backward pass, with gradient scaling if training in fp16
 71 |         scaler.scale(loss).backward()
 72 |         #
 73 |         if (step + 1) % gradient_accumulation_steps == 0:
 74 |             # clip the gradient
 75 |             if grad_clip != 0.0:
 76 |                 scaler.unscale_(optimizer)
 77 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
 78 |             # step the optimizer and scaler if training in fp16
 79 |             scaler.step(optimizer)
 80 |             scaler.update()
 81 |             # flush the gradients as soon as we can, no need for this memory anymore
 82 |             optimizer.zero_grad(set_to_none=True)
 83 |         # log progress
 84 |         if step % log_interval == 0:
 85 |             spend_time=time.time()-start_time
 86 |             logger.info(
 87 |                     'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
 88 |                         epoch,
 89 |                         max_epoch, 
 90 |                         step, 
 91 |                         iter_per_epoch,
 92 |                         loss.item(), 
 93 |                         optimizer.param_groups[-1]['lr'],
 94 |                         spend_time / (step+1) * iter_per_epoch // 60 - spend_time // 60))
 95 |         #
 96 |         if step % save_interval == 0:
 97 |             if ddp:
 98 |                 if torch.distributed.get_rank() == 0:
 99 |                     model.eval()
100 |                     torch.save(model.module.state_dict(),'{}/iter_{}.pth'.format(save_dir,int(step+epoch*iter_per_epoch)))
101 |                     model.train()
102 |             else:
103 |                 model.eval()
104 |                 torch.save(model.state_dict(),'{}/iter_{}.pth'.format(save_dir,int(step+epoch*iter_per_epoch)))
105 |                 model.train()
106 | 
107 | #@torch.no_grad()
108 | # def valid_epoch(epoch):
109 | #     global best_val_loss
110 | #     losses = []
111 | #     model.eval()
112 | #     for _, (X, Y) in enumerate(val_loader):
113 | #         X=X.to(device)
114 | #         Y=Y.to(device)
115 | #         with ctx:
116 | #             logits, loss = model(X, Y)
117 | #         losses.append(loss.item())
118 | #     model.train()
119 | #     val_loss=np.mean(losses)
120 | #     #
121 | #     logger.info('valid loss = {:.4f}'.format(val_loss))
122 | #     if val_loss < best_val_loss:
123 | #         best_val_loss = val_loss
124 | #         logger.info('best val_loss: {} best_epoch: {} '.format(best_val_loss,epoch))
125 | #         torch.save(raw_model.state_dict(),'{}/best.pth'.format(save_dir))
126 | #     #
127 | #     return val_loss
128 | 
129 | def init_model():
130 |     # model init: build the config from the globals above, or restore it from a checkpoint when resuming
132 |     model_args = dict(
133 |         dim=dim,
134 |         n_layers=n_layers,
135 |         n_heads=n_heads,
136 |         n_kv_heads=n_heads,
137 |         vocab_size=64793,
138 |         multiple_of=multiple_of,
139 |         max_seq_len=max_seq_len,
140 |         dropout=dropout,
141 |     )  # start with model_args from command line
142 |     if init_from == "scratch":
143 |         # init a new model from scratch
144 |         print("Initializing a new model from scratch")
145 |         gptconf = ModelArgs(**model_args)
146 |         model = Transformer(gptconf)
147 |     elif init_from == "resume":
148 |         print(f"Resuming training from {out_dir}")
149 |         # resume training from a checkpoint.
150 |         ckpt_path = os.path.join(out_dir, "ckpt.pt")
151 |         checkpoint = torch.load(ckpt_path, map_location=device)
152 |         checkpoint_model_args = checkpoint["model_args"]
153 |         # force these config attributes to be equal otherwise we can't even resume training
154 |         # the rest of the attributes (e.g. dropout) can stay as desired from command line
155 |         for k in ["dim", "n_layers", "n_heads", "n_kv_heads", "vocab_size", "multiple_of", "max_seq_len"]:
156 |             model_args[k] = checkpoint_model_args[k]
157 |         # create the model
158 |         gptconf = ModelArgs(**model_args)
159 |         model = Transformer(gptconf)
160 |         state_dict = checkpoint["model"]
161 |         # fix the keys of the state dictionary :(
162 |         # honestly no idea how checkpoints sometimes get this prefix, have to debug more
163 |         unwanted_prefix = "_orig_mod."
164 |         for k, v in list(state_dict.items()):
165 |             if k.startswith(unwanted_prefix):
166 |                 state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
167 |         model.load_state_dict(state_dict)
168 |         iter_num = checkpoint["iter_num"]
169 |         best_val_loss = checkpoint["best_val_loss"]
170 |     return model
171 | # I/O
172 | if __name__=="__main__":
173 |     out_dir = 'out'
174 |     max_epoch = 1
175 |     eval_interval = 1
176 |     log_interval = 100
177 |     save_interval = 10000
178 |     eval_iters = 200
179 |     eval_only = False # if True, script exits right after the first eval
180 |     always_save_checkpoint = True # if True, always save a checkpoint after each eval
181 |     init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
182 |     #
183 |     gradient_accumulation_steps = 1 # used to simulate larger batch sizes
184 |     batch_size = 32  # if gradient_accumulation_steps > 1, this is the micro-batch size
185 |     # model config; adjust to your compute budget
186 |     max_seq_len = 512
187 |     dim = 512
188 |     n_layers = 8
189 |     n_heads = 8
190 |     multiple_of = 32
191 |     dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
192 |     bias = False # do we use bias inside LayerNorm and Linear layers?
193 |     # adamw optimizer
194 |     learning_rate = 3e-4 # max learning rate
195 |     weight_decay = 1e-1
196 |     beta1 = 0.9
197 |     beta2 = 0.95
198 |     grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
199 |     # learning rate decay settings
200 |     decay_lr = True # whether to decay the learning rate
201 |     warmup_iters = 1000 # how many steps to warm up for
202 |     lr_decay_iters = 80000 # should be ~= max_iters per Chinchilla
203 |     min_lr = 1e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
204 |     # DDP settings
205 |     backend = 'nccl' # 'nccl', 'gloo', etc.
206 |     # system
207 |     device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
208 |     dtype = 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
209 |     compile = False # use PyTorch 2.0 to compile the model to be faster
210 |     # -----------------------------------------------------------------------------
211 |     config_keys = [
212 |         k
213 |         for k, v in globals().items()
214 |         if not k.startswith("_") and isinstance(v, (int, float, bool, str))
215 |     ]
216 |     # exec(open("configurator.py").read())  # overrides from command line or config file
217 |     # config = {k: globals()[k] for k in config_keys}  # will be useful for logging
218 |     # -----------------------------------------------------------------------------
219 | 
220 |     save_dir =os.path.join(out_dir , 'pretrain')
221 |     if not os.path.exists(save_dir): os.makedirs(save_dir)
222 |     logger = get_logger(os.path.join(save_dir,'log.log'))
223 |     # various inits, derived attributes, I/O setup
225 |     ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
226 |     
227 |     if ddp:
228 |         # Check if the operating system is Windows
229 |         if os.name == 'nt':
230 |             # Diff between backends: https://pytorch.org/docs/stable/distributed.html
231 |             init_process_group(backend="gloo")
232 |         else:
233 |             # If the operating system is Linux based, os.name == 'posix'
234 |             init_process_group(backend="nccl")
235 |         ddp_rank = int(os.environ["RANK"])
236 |         ddp_local_rank = int(os.environ["LOCAL_RANK"])
237 |         ddp_world_size = int(os.environ["WORLD_SIZE"])
238 |         device = f"cuda:{ddp_local_rank}"
239 |         torch.cuda.set_device(device)
240 |         master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.
241 |         seed_offset = ddp_rank  # each process gets a different seed
242 |         # world_size number of processes will be training simultaneously, so we can scale
243 |         # down the desired gradient accumulation iterations per process proportionally
244 |         #assert gradient_accumulation_steps % ddp_world_size == 0
245 |         #gradient_accumulation_steps //= ddp_world_size
246 |     else:
247 |         # if not ddp, we are running on a single gpu, and one process
248 |         master_process = True
249 |         seed_offset = 0
250 |         ddp_world_size = 1
251 |     tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
252 |     if master_process:
253 |         print(f"tokens per iteration will be: {tokens_per_iter:,}")
254 |         print(f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len")
255 | 
256 |     if master_process:
257 |         os.makedirs(out_dir, exist_ok=True)
258 |     torch.manual_seed(1337 + seed_offset)
259 |     torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
260 |     torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
261 |     device_type = "cuda" if "cuda" in device else "cpu"  # for later use in torch.autocast
262 |     # note: float16 data type will automatically use a GradScaler
263 |     ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
264 |     ctx = (
265 |         nullcontext()
266 |         if device_type == "cpu"
267 |         else torch.cuda.amp.autocast(dtype=ptdtype)
268 |     )
269 |     #
270 |     best_val_loss = 1e9
271 |     #
272 |     #-----init dataloader------
273 |     data_path_list=[
274 |         './data/pretrain_data.bin'
275 |         #'./data/baidubaike_563w.bin',
276 |         #'./data/medical_book.bin',
277 |         # './data/medical_encyclopedia.bin',
278 |         # './data/medical_qa.bin',
279 |         # './data/wiki.bin'
280 |     ]
281 |     train_ds = PretrainDataset(data_path_list, max_length=max_seq_len,memmap=True)
282 |     train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds) if ddp else None
283 |     train_loader = torch.utils.data.DataLoader(
284 |         train_ds,
285 |         batch_size=batch_size,
286 |         pin_memory=False,
287 |         drop_last=False,
288 |         shuffle=False,        
289 |         num_workers=0 if os.name == 'nt' else 4,
290 |         sampler=train_sampler
291 |     )
292 |     # val_ds = PretrainDataset(data_path_list, max_length=256)
293 |     # val_loader = torch.utils.data.DataLoader(
294 |     #     val_ds,
295 |     #     batch_size=batch_size,
296 |     #     pin_memory=False,
297 |     #     drop_last=False,
298 |     #     shuffle=False,        
299 |     #     num_workers=0,
300 |     # )
301 |     #init model
302 |     model=init_model()
303 |     model.to(device)
304 |     # initialize a GradScaler. If enabled=False scaler is a no-op
305 |     scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
306 |     # optimizer
307 |     optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
308 |     # compile the model
309 |     if compile:
310 |         print("compiling the model... (takes a ~minute)")
311 |         unoptimized_model = model
312 |         model = torch.compile(model) # requires PyTorch 2.0
313 |     # wrap model into DDP container
314 |     if ddp:
315 |         # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
316 |         # construction time since NCCL does not support `ComplexFloat`
317 |         prefix = "_orig_mod." if compile else ""
318 |         model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
319 |         model = DDP(model, device_ids=[ddp_local_rank])
320 |         #
321 |     raw_model = model.module if ddp else model # unwrap DDP container if needed
322 |     # training loop
323 |     iter_per_epoch=len(train_loader)
324 |     for epoch in range(max_epoch):
325 |         train_epoch(epoch)
326 |         #val_loss=valid_epoch(epoch)
327 |         if ddp:
328 |             if torch.distributed.get_rank() == 0:  # rank 0 is conventionally used for saving, though any single rank would do
329 |                 torch.save(raw_model.state_dict(),'{}/epoch_{}.pth'.format(save_dir,epoch))
330 |         else:
331 |             torch.save(raw_model.state_dict(),'{}/epoch_{}.pth'.format(save_dir,epoch))
332 |     if ddp:
333 |         destroy_process_group()
334 | 
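
As a standalone sanity check of the warmup-plus-cosine schedule in `get_lr` (same constants as the defaults above; this sketch is not part of the script):

```python
import math

learning_rate, min_lr = 3e-4, 1e-5        # assumed: the pretraining defaults
warmup_iters, lr_decay_iters = 1000, 80000

def get_lr(it):
    if it < warmup_iters:                  # 1) linear warmup
        return learning_rate * it / warmup_iters
    if it > lr_decay_iters:                # 2) past the decay horizon: floor at min_lr
        return min_lr
    decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # 3) cosine, 1 -> 0
    return min_lr + coeff * (learning_rate - min_lr)

for it in (0, 500, 1000, 40500, 80000, 100000):
    print(it, f"{get_lr(it):.2e}")         # ramps to 3e-4, then decays towards 1e-5
```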


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
 1 | numpy==1.23.5
 2 | pytest==7.4.0
 3 | Requests==2.31.0
 4 | sentencepiece==0.1.99
 5 | torch==2.0.1
 6 | scikit-learn==1.3.0
 7 | tqdm==4.64.1
 8 | jieba
 9 | pandas
10 | transformers==4.33.2
11 | 


--------------------------------------------------------------------------------
/sft.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  3 | import time
  4 | import math
  5 | import pickle
  6 | from contextlib import nullcontext
  7 | import numpy as np
  8 | import torch
  9 | from model import Transformer, ModelArgs
 10 | from torch.distributed import destroy_process_group, init_process_group
 11 | from torch.nn.parallel import DistributedDataParallel as DDP
 12 | import pandas as pd
 13 | from dataset_sft import SFTDataset
 14 | import logging
 15 | import json
 16 | import torch.nn.functional as F
 17 | from chatglm_tokenizer.tokenization_chatglm import ChatGLMTokenizer
 18 |         
 19 | def get_logger(filename, verbosity=1, name=None):
 20 |     level_dict = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING}
 21 |     formatter = logging.Formatter(
 22 |         "[%(asctime)s][%(filename)s][%(levelname)s] %(message)s"
 23 |     )
 24 |     logger = logging.getLogger(name)
 25 |     logger.setLevel(level_dict[verbosity])
 26 | 
 27 |     fh = logging.FileHandler(filename, "w")
 28 |     fh.setFormatter(formatter)
 29 |     logger.addHandler(fh)
 30 | 
 31 |     sh = logging.StreamHandler()
 32 |     sh.setFormatter(formatter)
 33 |     logger.addHandler(sh)
 34 |     return logger
 35 | # -----------------------------------------------------------------------------
 36 | def get_lr(it):
 37 |     # 1) linear warmup for warmup_iters steps
 38 |     if it < warmup_iters:
 39 |         return learning_rate * it / warmup_iters
 40 |     # 2) if it > lr_decay_iters, return min learning rate
 41 |     if it > lr_decay_iters:
 42 |         return min_lr
 43 |     # 3) in between, use cosine decay down to min learning rate
 44 |     decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
 45 |     assert 0 <= decay_ratio <= 1
 46 |     coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
 47 |     return min_lr + coeff * (learning_rate - min_lr)
 48 | #------------------------------------------------------------------------------
 49 | def train_epoch(epoch):
 50 |     start_time=time.time()
 51 |     for step, (X, Y,loss_mask) in enumerate(train_loader):
 52 |         X=X.to(device)
 53 |         Y=Y.to(device)
 54 |         loss_mask=loss_mask.to(device)
 55 |         lr = get_lr(epoch*iter_per_epoch+step) if decay_lr else learning_rate
 56 |         for param_group in optimizer.param_groups:
 57 |             param_group['lr'] = lr
 58 |         # and using the GradScaler if data type is float16
 59 |         #for micro_step in range(gradient_accumulation_steps):
 60 |         if ddp:
 61 |             # in DDP training we only need to sync gradients at the last micro step.
 62 |             # the official way to do this is with model.no_sync() context manager, but
 63 |             # I really dislike that this bloats the code and forces us to repeat code
 64 |             # looking at the source of that context manager, it just toggles this variable
 65 |             model.require_backward_grad_sync = 0 == gradient_accumulation_steps - 1
 66 |         with ctx:
 67 |             logits = model(X, Y)
 68 |             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), Y.view(-1), ignore_index=0, reduction='none')
 69 |             loss_mask = loss_mask.view(-1)
 70 |             loss = torch.sum(loss*loss_mask)/loss_mask.sum()
 71 |             #loss = raw_model.last_loss
 72 |             #loss = loss / gradient_accumulation_steps
 73 |         # immediately async prefetch next batch while model is doing the forward pass on the GPU
 74 |         # backward pass, with gradient scaling if training in fp16
 75 |         scaler.scale(loss).backward()
 76 |         #
 77 |         # clip the gradient
 78 |         if grad_clip != 0.0:
 79 |             scaler.unscale_(optimizer)
 80 |             torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
 81 |         # step the optimizer and scaler if training in fp16
 82 |         scaler.step(optimizer)
 83 |         scaler.update()
 84 |         # flush the gradients as soon as we can, no need for this memory anymore
 85 |         optimizer.zero_grad(set_to_none=True)
 86 |         # log progress
 87 |         if step % log_interval == 0:
 88 |             spend_time=time.time()-start_time
 89 |             logger.info(
 90 |                     'Epoch:[{}/{}]({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
 91 |                         epoch,
 92 |                         max_epoch, 
 93 |                         step, 
 94 |                         iter_per_epoch,
 95 |                         loss.item(), 
 96 |                         optimizer.param_groups[-1]['lr'],
 97 |                         spend_time / (step+1) * iter_per_epoch // 60 - spend_time // 60))
 98 | #------------------
 99 | @torch.no_grad()
100 | def valid_epoch(epoch):
101 |     global best_val_loss
102 |     losses = []
103 |     model.eval()
104 |     for _, (X, Y) in enumerate(val_loader):
105 |         X=X.to(device)
106 |         Y=Y.to(device)
107 |         with ctx:
108 |             logits = model(X, Y); loss = raw_model.last_loss  # the model returns logits; the loss is stored on raw_model.last_loss
109 |         losses.append(loss.item())
110 |     model.train()
111 |     val_loss=np.mean(losses)
112 |     #
113 |     logger.info('valid loss = {:.4f}'.format(val_loss))
114 |     if val_loss < best_val_loss:
115 |         best_val_loss = val_loss
116 |         logger.info('best val_loss: {} best_epoch: {} '.format(best_val_loss,epoch))
117 |         torch.save(raw_model.state_dict(),'{}/best.pth'.format(save_dir))
118 |     #
119 |     return val_loss
120 | 
121 | def init_model():
122 |     # model init
123 |     # model init
124 |     model_args = dict(
125 |         dim=dim,
126 |         n_layers=n_layers,
127 |         n_heads=n_heads,
128 |         n_kv_heads=n_heads,
129 |         vocab_size=64793,
130 |         multiple_of=multiple_of,
131 |         max_seq_len=max_seq_len,
132 |         dropout=dropout,
133 |     )  # start with model_args from command line
134 |     if init_from == "scratch":
135 |         # init a new model from scratch
136 |         print("Initializing a new model from scratch")
137 |         gptconf = ModelArgs(**model_args)
138 |         model = Transformer(gptconf)
139 |     elif init_from == "resume":
140 |         print(f"Resuming training from {out_dir}")
141 |         # resume training from a checkpoint.
142 |         ckpt_path = os.path.join(out_dir, "ckpt.pt")
143 |         checkpoint = torch.load(ckpt_path, map_location=device)
144 |         checkpoint_model_args = checkpoint["model_args"]
145 |         # force these config attributes to be equal otherwise we can't even resume training
146 |         # the rest of the attributes (e.g. dropout) can stay as desired from command line
147 |         for k in ["dim", "n_layers", "n_heads", "n_kv_heads", "vocab_size", "multiple_of", "max_seq_len"]:
148 |             model_args[k] = checkpoint_model_args[k]
149 |         # create the model
150 |         gptconf = ModelArgs(**model_args)
151 |         model = Transformer(gptconf)
152 |         state_dict = checkpoint["model"]
153 |         # fix the keys of the state dictionary :(
154 |         # honestly no idea how checkpoints sometimes get this prefix, have to debug more
155 |         unwanted_prefix = "_orig_mod."
156 |         for k, v in list(state_dict.items()):
157 |             if k.startswith(unwanted_prefix):
158 |                 state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
159 |         model.load_state_dict(state_dict)
160 |         iter_num = checkpoint["iter_num"]
161 |         best_val_loss = checkpoint["best_val_loss"]
162 |     return model
163 | # I/O
164 | if __name__=="__main__":
165 |     out_dir = 'out'
166 |     max_epoch = 2
167 |     eval_interval = 1
168 |     log_interval = 50
169 |     eval_iters = 200
170 |     eval_only = False # if True, script exits right after the first eval
171 |     always_save_checkpoint = True # if True, always save a checkpoint after each eval
172 |     init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*'
173 |     #
174 |     gradient_accumulation_steps = 1 # used to simulate larger batch sizes
175 |     batch_size = 32 # if gradient_accumulation_steps > 1, this is the micro-batch size
176 |     # model
177 |     max_seq_len = 512
178 |     dim = 512
179 |     n_layers = 8
180 |     n_heads = 8
181 |     multiple_of = 32
182 |     dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+
183 |     bias = False # do we use bias inside LayerNorm and Linear layers?
184 |     # adamw optimizer
185 |     learning_rate = 2e-5 # max learning rate
186 |     weight_decay = 1e-4
187 |     beta1 = 0.9
188 |     beta2 = 0.95
189 |     grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0
190 |     # learning rate decay settings
191 |     decay_lr = True # whether to decay the learning rate
192 |     warmup_iters = 1000 # how many steps to warm up for
193 |     lr_decay_iters = 50000 # should be ~= max_iters per Chinchilla
194 |     min_lr = 1e-6 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
195 |     # DDP settings
196 |     backend = 'nccl' # 'nccl', 'gloo', etc.
197 |     # system
198 |     device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
199 |     dtype = 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler
200 |     compile = False # use PyTorch 2.0 to compile the model to be faster
201 |     # -----------------------------------------------------------------------------
202 |     config_keys = [
203 |         k
204 |         for k, v in globals().items()
205 |         if not k.startswith("_") and isinstance(v, (int, float, bool, str))
206 |     ]
207 |     # exec(open("configurator.py").read())  # overrides from command line or config file
208 |     # config = {k: globals()[k] for k in config_keys}  # will be useful for logging
209 |     # -----------------------------------------------------------------------------
210 | 
211 |     save_dir =os.path.join(out_dir , 'sft')
212 |     if not os.path.exists(save_dir): os.makedirs(save_dir)
213 |     logger = get_logger(os.path.join(save_dir,'log.log'))
214 |     # various inits, derived attributes, I/O setup
216 |     ddp = int(os.environ.get("RANK", -1)) != -1  # is this a ddp run?
217 |     if ddp:
218 |         init_process_group(backend="nccl")
219 |         ddp_rank = int(os.environ["RANK"])
220 |         ddp_local_rank = int(os.environ["LOCAL_RANK"])
221 |         ddp_world_size = int(os.environ["WORLD_SIZE"])
222 |         device = f"cuda:{ddp_local_rank}"
223 |         torch.cuda.set_device(device)
224 |         master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.
225 |         seed_offset = ddp_rank  # each process gets a different seed
226 |         # world_size number of processes will be training simultaneously, so we can scale
227 |         # down the desired gradient accumulation iterations per process proportionally
228 |         #assert gradient_accumulation_steps % ddp_world_size == 0
229 |         #gradient_accumulation_steps //= ddp_world_size
230 |     else:
231 |         # if not ddp, we are running on a single gpu, and one process
232 |         master_process = True
233 |         seed_offset = 0
234 |         ddp_world_size = 1
235 |     tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * max_seq_len
236 |     if master_process:
237 |         print(f"tokens per iteration will be: {tokens_per_iter:,}")
238 |         print(f"breaks down as: {gradient_accumulation_steps} grad accum steps * {ddp_world_size} processes * {batch_size} batch size * {max_seq_len} max seq len")
239 | 
240 |     if master_process:
241 |         os.makedirs(out_dir, exist_ok=True)
242 |     torch.manual_seed(1337 + seed_offset)
243 |     torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
244 |     torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
245 |     device_type = "cuda" if "cuda" in device else "cpu"  # for later use in torch.autocast
246 |     # note: float16 data type will automatically use a GradScaler
247 |     ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]
248 |     ctx = (
249 |         nullcontext()
250 |         if device_type == "cpu"
251 |         else torch.cuda.amp.autocast(dtype=ptdtype)
252 |     )
253 |     #
254 |     best_val_loss = 1e9
255 |     
256 |     #-----init dataloader------
257 |     df=pd.read_csv('./sft_data/sft_data.csv')
258 |     # input=[]
259 |     # target=[]
260 |     # with open('../track1/train_valid.json','r') as f:
261 |     #     data=json.load(f)
262 |     # #
263 |     # for l in data:
264 |     #     input.append(l['question'])
265 |     #     target.append(l['answer'])
266 |     # df = pd.DataFrame()
267 |     # df['prompt']=input
268 |     # df['answer']=target
269 |     # df=pd.concat((df_sft,df[100:])).reset_index(drop=True)
270 |     df=df.sample(frac=1.0)
271 |     print(df)
272 |     tokenizer=ChatGLMTokenizer(vocab_file='./chatglm_tokenizer/tokenizer.model')
273 |     train_ds = SFTDataset(df,tokenizer, max_length=512)
274 |     train_loader = torch.utils.data.DataLoader(
275 |         train_ds,
276 |         batch_size=batch_size,
277 |         pin_memory=False,
278 |         drop_last=False,
279 |         shuffle=False,        
280 |         num_workers=0,
281 |     )
282 |     # val_ds = PretrainDataset(data_path_list, max_length=256)
283 |     # val_loader = torch.utils.data.DataLoader(
284 |     #     val_ds,
285 |     #     batch_size=batch_size,
286 |     #     pin_memory=False,
287 |     #     drop_last=False,
288 |     #     shuffle=False,        
289 |     #     num_workers=0,
290 |     # )
291 |     #init model
292 |     model=init_model()
293 |     model.load_state_dict(torch.load('./out/pretrain/epoch_0.pth', map_location=device))
294 |     model.to(device)
295 |     # initialize a GradScaler. If enabled=False scaler is a no-op
296 |     scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))
297 |     # optimizer
298 |     optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)
299 |     #
300 |     iter_per_epoch=len(train_loader)
301 |     # compile the model
302 |     if compile:
303 |         print("compiling the model... (takes a ~minute)")
304 |         unoptimized_model = model
305 |         model = torch.compile(model) # requires PyTorch 2.0
306 |     # wrap model into DDP container
307 |     if ddp:
308 |         # Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
309 |         # construction time since NCCL does not support `ComplexFloat`
310 |         prefix = "_orig_mod." if compile else ""
311 |         model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
312 |         model = DDP(model, device_ids=[ddp_local_rank])
313 |     #
314 |     raw_model = model.module if ddp else model # unwrap DDP container if needed
315 |     # training loop
316 |     for epoch in range(max_epoch):
317 |         train_epoch(epoch)
318 |         #val_loss=valid_epoch(epoch)
319 |         if ddp:
320 |             if torch.distributed.get_rank() == 0:
321 |                 torch.save(raw_model.state_dict(),'{}/epoch_{}.pth'.format(save_dir,epoch))
322 |         else:
323 |             torch.save(raw_model.state_dict(),'{}/epoch_{}.pth'.format(save_dir,epoch))
324 |     if ddp:
325 |         destroy_process_group()
326 | 
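
The masked loss in `train_epoch` above is the heart of the SFT step: only answer tokens contribute to the gradient. A standalone sketch with toy tensors (the shapes and the mask are invented for illustration; in the real script the logits come from the model and the mask from `SFTDataset`):

```python
import torch
import torch.nn.functional as F

vocab, bsz, seqlen = 10, 2, 5
logits = torch.randn(bsz, seqlen, vocab)             # stand-in for model output
targets = torch.randint(1, vocab, (bsz, seqlen))
loss_mask = torch.tensor([[0, 0, 1, 1, 1],           # 1 marks answer tokens,
                          [0, 1, 1, 1, 0]],          # 0 masks prompt/padding
                         dtype=torch.float)

per_token = F.cross_entropy(logits.view(-1, vocab), targets.view(-1),
                            ignore_index=0, reduction='none')   # one value per position
loss = torch.sum(per_token * loss_mask.view(-1)) / loss_mask.sum()
print(loss.item())
```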


--------------------------------------------------------------------------------
/sft_data_process.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import json
 3 | import numpy as np
 4 | from tqdm import tqdm
 5 | import pandas as pd
 6 | 
 7 | def sft_process():
 8 |     with open('./sft_data/alpaca_gpt4_data_zh.json', 'r', encoding='utf-8') as f:
 9 |         data = json.load(f)
10 |     #
11 |     q_lst = []
12 |     a_lst = []
13 |     for per in data:
14 |         q = per['instruction']
15 |         i = per['input']
16 |         a = per['output']
17 |         q = q + i
18 |         if len(q) < 10 or len(a) < 5:
19 |             continue
20 |         if len(q) > 256 or len(a) > 256:
21 |             continue
22 |         q_lst.append(q)
23 |         a_lst.append(a)
24 | 
25 |     f = open('./sft_data/Belle_open_source_1M.json', 'r', encoding='utf-8')
26 | 
 27 |     # Belle_open_source_1M.json is JSON lines; read and filter it line by line
28 |     while True:
29 |         line = f.readline()
30 |         if not line:
31 |             break
32 |         per = json.loads(line)
33 |         q = per['instruction']
34 |         i = per['input']
35 |         a = per['output']
36 |         q = q + i
37 |         if len(q) < 10 or len(a) < 5:
38 |             continue
39 |         if len(q) > 256 or len(a) > 256:
40 |             continue
41 |         q_lst.append(q)
42 |         a_lst.append(a)
43 |     df = pd.DataFrame(columns=['prompt', 'answer'])
44 |     df['prompt'] = q_lst
45 |     df['answer'] = a_lst
46 |     df.to_csv('sft_data/sft_data.csv', index=False)
47 |     print(df)
48 | 
49 | save_dir = './sft_data'
50 | if not os.path.exists(save_dir): os.makedirs(save_dir)
51 | sft_process()
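
A quick, optional way to inspect the file this script writes (assumes it has already been run):

```python
import pandas as pd

df = pd.read_csv('./sft_data/sft_data.csv')
print(df.columns.tolist())   # ['prompt', 'answer']
print(len(df))
print(df.head())
```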


--------------------------------------------------------------------------------