├── LICENSE ├── README.md ├── config ├── cpm-medium.json ├── cpm-one-layer.json └── cpm-small.json ├── data └── .gitkeep ├── data_parallel.py ├── dataset.py ├── distill.py ├── generate.py ├── http_service.py ├── log └── .gitignore ├── model └── .gitignore ├── preprocess.py ├── script ├── distill.sh └── train-v2.sh ├── train-v2.py ├── train.py ├── utils.py └── vocab ├── chinese_vocab.model ├── chinese_vocab.vocab └── vocab.json /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CPM 2 | 3 | ## 项目描述 4 | 微信公众号【YeungNLP】文章:[基于CPM的中文作文生成模型,引经据典、修辞手法,信手拈来](https://mp.weixin.qq.com/s/sFzUNtwrTvi2kAAGQ2M3UA) ,文章内可获取26w+中文作文语料。 5 | 6 | CPM(Chinese Pretrained Models)模型是北京智源人工智能研究院和清华大学发布的中文大规模预训练模型。官方发布了三种规模的模型,参数量分别为109M、334M、2.6B,用户需申请与通过审核,方可下载。 7 | 由于原项目需要考虑大模型的训练和使用,需要安装较为复杂的环境依赖,使用上也较为复杂。 8 | 本项目采用了109M的CPM模型(若资源允许也可以考虑334M的模型),并且简化了模型的训练和使用。 9 | 10 | 本项目是基于CPM模型的中文文本生成项目,可用于作文、小说、新闻、古诗等中文生成任务,并且训练和分享了[中文作文生成模型](#model_share),取得了不错的[生成效果](#sample)。 11 | 本项目提供了数据预处理、模型训练、文本生成、Http服务等代码模块。 12 | 详情可参考[CPM模型论文](https://arxiv.org/abs/2012.00413), [CPM官网](https://cpm.baai.ac.cn/), [项目源码](https://github.com/TsinghuaAI/CPM-Generate) 。 13 | 14 | 15 | ## 运行环境 16 | python==3.6、transformers==4.6.0、sentencepiece==0.1.94、torch==1.7.0、Flask==1.1.2 17 | 18 | 19 | ## 项目结构 20 | 用户可自行创建以下目录。 21 | - config:存放模型的配置文件 22 | - data:存放训练数据 23 | - model:存放模型 24 | - log:存放日志文件 25 | - vocab: 26 | - chinese_voca.model:sentencepiece模型 27 | - vocab.json:分词与id的键值对 28 | - data_parallel.py:解决pytorch的GPU负载不均衡的问题 29 | - generate.py:生成代码 30 | - http_service.py:封装成http服务,支持post与get请求 31 | - preprocess.py:数据预处理代码 32 | - utils.py:存放一些工具代码 33 | 34 | ## 模型参数与训练细节 35 | 由于GPU资源有限,本项目使用cpm-small.json中的模型参数,若资源充足,可尝试cpm-medium.json中的参数配置。 36 | 37 | 本项目的部分模型参数如下: 38 | - n_ctx: 1024 39 | - n_embd: 768 40 | - n_head: 12 41 | - n_layer: 12 42 | - n_positions: 1024 43 | - vocab_size: 30000 44 | 45 | 对26w篇作文进行预处理之后,得到60w+长度为200的训练数据。显卡为三张GTX 1080Ti,batch_size=50,三张卡显存满载,一轮训练大约需要3个小时。训练40轮之后,loss降到2.1左右,单词预测准确率大约为54%。 46 | 47 | ## 使用方法 48 | ### Quick Start 49 | 在[模型分享](#model_share)中下载模型,将模型文件夹zuowen_epoch40放到model目录下,执行如下命令,指定作文标题、作文开头和长度,进行生成。 50 | ``` 51 | python generate.py --model_path model/zuowen_epoch40 --title 家乡的四季 --context 家乡的四季,最美不过了 --max_len 200 52 | ``` 53 | 54 | ### 数据预处理 55 | 每篇作文对应一个txt文件,txt内容格式如下: 56 | ``` 57 | --- 58 | 标题:徜徉在书籍的阳光世界 59 | 日期:xxxx-xx-xx xx:xx:xx 60 | 作者:佚名 61 | --- 62 | 63 | 64 | 一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙;一本书是一个人的耳朵,它可以让你听到大自然的呼唤,听到社会的声音。 65 | 66 | 《森林报》是苏联著名科普作家维。比安基的代表作品,他以春夏秋冬四季为序,有层次、有类别地将森林里动植物的新鲜事描写得栩栩如生,引人入胜。这本书教会我们如何去观察、认识大自然。这本书教会我们感悟生命,体验生动的愉快探速之旅,激发我们热爱科学的兴趣。 67 | 68 | 《三字经》、《弟子规》、《论语》这样的国学经典,我也会拿来阅读,虽然似懂非懂,但读起来朗朗上口,觉得挺有趣。读着读着,好似开始了一场时空之旅,与古代圣贤结为知己,进行心与心之间的倾听与问候。这些书籍让我们在阅读中品味高雅。 69 | 70 | 在成长的过程中,每个人都有着自己不一样的心路历程。阳光姐姐写的《成长的秘密》一书让我们获益不浅。作者用简单生动的文字,把温馨感人、新鲜快乐、爆笑的校园生活展现在我们眼前。书中的人物宁佳鑫看上去弱小,但她实际却很坚强,在她身上,我看到了她散发出的正能量和她在逆境中奋起的精神。她的经历告诉我:无论遇到什么样的挫折与坎坷,都不要气馁。阳光总在风雨后,只要我们坚持不懈地去想办法克服困难,并付诸行动,就一定会柳暗花明! 71 | 72 | 法国作家德尔伦曾说过“智慧可以转化成力量,完成你认为不可能完成的事。”是啊,智慧的力量很强大,这些力量隐藏在书中。当我们在阅读之际,这些知识就偷偷地跑进我们的脑海里,渐渐地,渐渐地,它们就永远地保存下来,显示出无穷的魅力,让我们的未来畅通无阻。 73 | 74 | 书籍,用爱和勇气唤醒每个孩子的心灵;书籍让我们感受到温暖与力量;书籍,教我们用心灵在文字间快乐舞蹈。 75 | 76 | 让我们走进书籍的阳光世界,获取成长的力量。 77 | ``` 78 | 对于每个txt文件,首先取出标题与内容,将标题与内容按照"title[sep]content[eod]"的方式拼接起来,然后对其进行tokenize,最后使用滑动窗口对内容进行截断,得到训练数据。 79 | 运行如下命令,进行数据预处理。注:预处理之后的数据保存为train.pkl,这是一个list,list中每个元素表示一条训练数据。 80 | ``` 81 | python preprocess.py --data_path data/zuowen --save_path data/train.pkl --win_size 200 --step 200 82 | ``` 83 | 超参数说明: 84 | - vocab_file:sentencepiece模型路径,用于tokenize 85 | - log_path:日志存放位置 86 | - data_path:数据集存放位置 87 | - save_path:对训练数据集进行tokenize之后的数据存放位置 88 | - win_size:滑动窗口的大小,相当于每条数据的最大长度 89 | - step:滑动窗口的滑动步幅 90 | 91 | 用户也可以根据自身的需求,对预处理的代码进行相应的修改。后续将更新项目代码,以便用于处理各种数据集。 92 | 93 | ### 训练模型 94 | 运行如下命令,使用预处理后的数据训练模型。 95 | ``` 96 | python train.py --epochs 100 --batch_size 16 --device 0,1 --gpu0_bsz 5 --train_path data/train.pkl 97 | ``` 98 | 超参数说明: 99 | - device:设置使用哪些GPU 100 | - no_cuda:设为True时,不使用GPU 101 | - vocab_path:sentencepiece模型路径,用于tokenize 102 | - model_config:需要从头训练一个模型时,模型参数的配置文件 103 | - train_path:经过预处理之后的数据存放路径 104 | - max_len:训练时,输入数据的最大长度。 105 | - log_path:训练日志存放位置 106 | - ignore_index:对于该token_id,不计算loss,默认为-100 107 | - epochs:训练的最大轮次 108 | - batch_size:训练的batch size 109 | - gpu0_bsz:pytorch使用多GPU并行训练时,存在负载不均衡的问题,即0号卡满载了,其他卡还存在很多空间,抛出OOM异常。该参数可以设置分配到0号卡上的数据数量。 110 | - lr:学习率 111 | - eps:AdamW优化器的衰减率 112 | - log_step:多少步汇报一次loss 113 | - gradient_accumulation_steps:梯度累计的步数。当显存空间不足,batch_size无法设置为较大时,通过梯度累计,缓解batch_size较小的问题。 114 | - save_model_path:模型输出路径 115 | - pretrained_model:预训练的模型的路径 116 | - num_workers:dataloader加载数据时使用的线程数量 117 | - warmup_steps:训练时的warm up步数 118 | 119 | 120 | ### 文本生成 121 | 运行如下命令,进行文本生成。 122 | ``` 123 | python generate.py --device 0 --max_len 200 --title 家乡的四季 --context 家乡的四季,最美不过了 124 | ``` 125 | 超参数说明: 126 | - device:使用哪个GPU进行生成 127 | - temperature:详情可参考temperature sampling的思想 128 | - topk:top-k采样(注:topp为0,topk不为0时采用top-k采样) 129 | - topp:核采样(注:topk为0,topp不为0时,采用核采样) 130 | - max_len:生成的最长长度 131 | - log_path:生成日志存放位置 132 | - no_cuda:设为True时,不使用GPU 133 | - model_path:模型存放路径 134 | - title:作文标题 135 | - context:作文上文 136 | - context_len:每一步生成时,参考的上文的长度 137 | 138 | ### Http服务 139 | 将模型生成能力封装成Http服务,支持Post与Get请求。运行如下命令,启动服务。 140 | ``` 141 | python http_service.py --port 8085 --model_path model/zuowen_epoch40 --context_len 200 142 | ``` 143 | Get请求: 144 | ``` 145 | http://localhost:8085/zuowen?title="家乡的四季"&context="家乡的四季,最美不过了"&max_len=200 146 | ``` 147 | Post请求 148 | ``` 149 | localhost:8085/zuowen 150 | { 151 | 'title':'家乡的四季', 152 | 'context':'家乡的四季,最美不过了', 153 | 'max_len':200 154 | } 155 | ``` 156 | 157 | 超参数说明: 158 | - device:使用哪个GPU进行生成 159 | - temperature:详情可参考temperature sampling的思想 160 | - topk:top-k采样(注:topp为0,topk不为0时采用top-k采样) 161 | - topp:核采样(注:topk为0,topp不为0时,采用核采样) 162 | - port:服务绑定的端口号 163 | - log_path:生成日志存放位置 164 | - no_cuda:设为True时,不使用GPU 165 | - model_path:模型存放路径 166 | - context_len:每一步生成时,参考的上文的长度 167 | 168 | 169 |

模型分享

170 | 171 | |模型 | 共享地址 |模型描述| 172 | |---------|--------|--------| 173 | |zuowen_epoch40 | [百度网盘【提取码:8o3v】](https://pan.baidu.com/s/1nwyqQ6WyE0mE0U6OVlThEQ) |使用26w篇中文作文语料训练了40个epoch,loss降到2.1左右,单词预测准确率大约为54%| 174 | 175 | ## Future Work 176 | - 使用3张1080Ti进行训练,由于显卡资源有限,在数据预处理时,使用了大小为200的滑动窗口对数据进行截断,batch_size设为50。没有充分使用模型1024的最大输入长度,导致训练不够充分。若有充足的显卡资源,可以使用1024的滑动窗口对数据进行截断,提高模型的生成效果。 177 | - 当前代码主要针对作文数据集进行数据预处理、训练、生成。后续将会更新代码,以便用于处理各种数据集。 178 | 179 |

生成样例

180 | 以下生成样例,生成长度默认为200。 181 | 182 | ### 家乡的四季 183 | ``` 184 | title:家乡的四季 185 | context:家乡的四季,最美不过了 186 | 187 | result: 188 | 家乡的四季,最美不过了。家乡的四季,是令人沉醉的。 189 | 春天,万物复苏,冰雪融化,万物复苏。树枝抽出了嫩芽,花朵绽放了笑脸,树木吐出了嫩芽,春笋也破土而出,像是迎接春天的到来。小鸟们也在枝头唱起了动听的歌曲,周围的一切都变成了春的样子。 190 | 夏天,荷塘里的荷花开了,散发出阵阵清香。远处,山的颜色深浅不一,像是穿着一件翠绿的长裙,在荷塘的衬托下显得更加美,更加翠绿。微风拂过,荷花轻轻地摆动着,像是在和我打招呼呢! 191 | 秋天, 192 | 193 | result: 194 | 家乡的四季,最美不过了。 195 | 家乡的春天,柳树发芽了,小草从泥土里探出头来,小花也张开了笑脸,小草偷偷地探出头来。我小时候,经常到那里玩,在那里捉迷藏,去田野里捉迷藏。到了晚上,爷爷便去田野里找蟋蟀,等到第二天早上,爷爷就去捉蟋蟀了。 196 | 家乡的夏天,荷塘里开满了荷花,碧绿的荷叶,荷花都开了,荷叶上还有青蛙王子,他们正在开大会呢! 197 | 家乡的秋天,果实累累,果园里更是瓜果飘香。你看,农民伯伯正忙着摘果实呢!爷爷会摘苹果,苹果熟了, 198 | 199 | result: 200 | 家乡的四季,最美不过了。 201 | 春天,嫩芽破土而出,焕发出生机。每当春姑娘来临之际,小草就会脱下旧衣服,冲出家门,迫不及待地站在土地上,感受春风亲吻着自己的脸庞,贪婪地吸吮着甘甜的露水。春姑娘来到田野里,到处都是一片嫩绿,一派盎然的景象。柳树姑娘刚刚梳理好头发,甩动着长长的头发,伴随着阵阵春风,跳起了欢快的舞蹈。此时此刻,春雨也来凑热闹了,她滴落在溪水中,随着春风舞动起来,漾起一圈圈水纹。在河边,长满了一串串一串串鲜艳的鲜花, 202 | 203 | result: 204 | 家乡的四季,最美不过了,四季各有特色。 205 | 春天,小草探出了它那绿绿的小脑袋,柳树的枝条随风飘动,好像正在给春姑娘梳头发。桃花、杏花、梨花争先恐后的开放,如同一个个粉红的小精灵在枝头跳着美丽的舞蹈。小燕子从南方飞来,在空中快乐的飞来飞去,非常动听。 206 | 夏天,骄阳似火,树木葱葱笼,在骄阳的照耀下,鸟儿也在树上唱着动听的歌。小孩子们穿着短袖,在大树下坐着乘凉,偶尔会出现几个小朋友在那里捉迷藏,嬉戏。 207 | 秋天, 208 | 209 | result: 210 | 家乡的四季,最美不过了,我家乡的四季是如此美丽。 211 | 春天到了,小草从泥土里钻出来了,正东张西望地观察着四周,像是在寻找着什么。大树也绽开了笑脸,开出了许多颜色各异的花,有黄色、红色、紫色、绿色,真是色色俱全啊!花儿在春雨的滋润下,绽放出了自己美丽的花朵,散发出了迷人的芳香,那花儿就像一位位亭亭玉立的少女,娇艳迷人,美丽极了。那嫩绿的小草,铺满了大地,让我们感到生命的希望。 212 | 夏天,小草长得郁郁葱葱,到处都是绿茵茵的,走在路上, 213 | 214 | result: 215 | 家乡的四季,最美不过了。 216 | 春天,到处充满了生机勃勃。风和日丽,万物复苏,柳树那碧绿的头发被风吹得翩翩起舞,像一个亭亭玉立的少女在对我招手。 217 | 夏天,太阳高高地挂在天空,灿烂的阳光照耀着大地,我看见农民伯伯忙碌的身影,心想:这么热的天,还要干什么?我要帮助他们干农活。我想着想着,又想到了奶奶家,我就跑到奶奶家的西瓜地里,去拿奶奶的小锄头,把小锄头递到奶奶的手里,奶奶一边干活,一边说:“你可真棒!” 218 | 秋天, 219 | 220 | result: 221 | 家乡的四季,最美不过了。 222 | 春天到了,花儿苏醒了,小草冒出了头,树木抽出了新的枝条,燕子又飞回来了。远处,连绵起伏的高山一座连着一座,就像一座座大山。山下有一条清澈的小溪,水哗啦啦地流着,就像一匹美丽的蓝丝绸。山上的树木也抽出了新的枝条,长出了嫩绿的叶子,叶子好像一块块绿宝石。燕子从南方飞回来了,站在枝头上,叽叽喳喳地叫着,好像在唱着春天的赞歌。 223 | 夏天到了,太阳像个大火球,照着地面,我和小伙伴们经常到小溪里游泳、打水仗, 224 | 225 | result: 226 | 家乡的四季,最美不过了。 227 | 春天,那条河流解冻了。岸边的柳树随风飘动,那优美的身姿,加上那动人的歌喉,简直是春姑娘的杰作。小燕子从南方赶来,叽叽喳喳的叫着,好像在和柳树说:“柳树姐姐,你的头发可真美呀!”小草偷偷的从土里钻了出来,那嫩嫩的、绿绿的,就像刚打过仗的战士一样。花儿们,现在都刚刚长出来,那些花,就像刚睡醒的小婴儿,那些花,有红的、黄的、紫的......五彩缤纷,美丽极了。夏天, 228 | 229 | result: 230 | 家乡的四季,最美不过了,无论是万物复苏的春天,还是烈日炎炎的夏天,抑或是硕果累累的秋天,每一个季节都有它们的特点。 231 | 春姑娘刚走,夏姐姐就来了。太阳公公热情的烘烤着大地,知了在树上“知了知了”的叫着,仿佛在说:“好热啊!好热啊!”荷花池塘也“出淤泥而不染”,荷叶“接天莲叶无穷碧”的的清香随风飘来,整个家乡弥漫着沁人心脾的花香,让人陶醉。 232 | 春姑娘刚刚走,夏姑娘就来了。太阳像一个又大又圆的火球挂在天空中, 233 | ``` 234 | 235 | ### 徜徉在书籍的阳光世界 236 | ``` 237 | title: 徜徉在书籍的阳光世界 238 | contetx: 一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙 239 | 240 | result: 241 | 一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙。 242 | 你可曾品尝过文学的魅力? 243 | 唐诗宋词,群星璀璨,给我们无尽的想象与愉悦。 244 | “明月几时有,把酒问青天。不知天上宫阙,今夕是何年。”它描绘了一个美好的时代。苏轼在赤壁赏月时,不禁为这美景感叹。“明月几时有,把酒问青天。”它告诉了我们人生的哲理。 245 | 文学作品,不但丰富了我们的知识,也为我们描绘了一幅幅优美的山水画。 246 | 语文书中的婉约柔情,让我感受到世间的人情冷暖, 247 | 248 | result: 249 | 一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙;一本好书是一个人的眸子,它可以让你看清世界的脉络;一本好书是一把钥匙,它可以打开你心灵的窗户。我徜徉在书的世界里,在阅读中,我找到了梦想。 250 | 一本好书,犹如一泓清泉,流入我干渴的心田;一本好书,犹如一只小舟,载着我遨游在知识的海洋;一本好书,犹如一缕阳光,照亮我的心房。 251 | 记得在我很小的时候,我每天都要缠着妈妈给我讲故事,每次妈妈讲完故事,我都会依偎在妈妈的怀里, 252 | 253 | result: 254 | 一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙;一本书是一场细雨,滋润你的心田;一本书是你的拐杖,带你走进这个美妙的世界。 255 | 在我很小的时候,就开始接触书籍了,我有一个非常要好的朋友,叫做书。在我很小的时候,书还是不可缺少的。 256 | 在我不认字的时候,我就会捧着《格林童话》,开始认真地看书,我看的津津有味。《格林童话》让我明白了做人的道理,《白雪公主》让我知道了善良的重要;《卖火柴的小女孩》让我明白了人间的幸福是美好的, 257 | 258 | result: 259 | 一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙。书就像是一颗闪烁的星星,给你引航;书就像一汪清泉,给你洗涤心灵;书就像一束阳光,给你带来无穷的温暖...... 260 | 我从小就喜欢读书。一个冬天的下午,我在家楼下的小广场上坐着,静静地享受着小时候的乐趣。突然,一位老爷爷从远处走了过来,手里拿着一本厚厚的《安徒生童话》,我拿起这本书,心想:这书可是我的心爱之物啊! 261 | 于是,我跑到他身边,与他交谈起来。原来,这位老爷爷就是在我六岁时, 262 | 263 | reslut: 264 | 一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙,每一本都有着不一样的内涵。 265 | ——题记 266 | 在某个宁静的午后,沉醉在书本的世界里,沉醉在阅读的魅力里,沉醉在阅读的心灵深处。 267 | 坐在一望无际的草原上,静静地读书。我像一匹饿狼,贪婪地读着,不一会儿,我就沉浸在书中。不知不觉,太阳已落下去,不知不觉,天色已晚,我们只好依依不舍地收起书本。 268 | 夕阳西下,落日把天空染成了红色,火烧云像一只只巨象,汹涌澎湃,在天空中横飞, 269 | 270 | result: 271 | 一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙之处,更能让你认识到书本的神奇之处,而我则常常沉浸在那散发着淡淡书香的阅读之中。 272 | 从幼儿园开始,我便爱上了读书,最开始是读绘本,后来爱上了古诗文。记得,一开始是爸爸妈妈带我去书店买诗。“床前明月光,疑是地上霜。举头望明月,低头思故乡。”那时,我对诗的理解是这样的:月光皎洁的夜晚,举头望明月,低头思故乡。“水光,山色空蒙雨亦奇。”那时的我对诗的理解是“帘卷西风寒鸦知人不识月, 273 | 274 | result1: 275 | 一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙。书,是一个精神世界的源泉,在我们的精神世界里,书是一位无声的老师,也是一个最忠实的朋友。 276 | 书是人们的良师益友,是精神世界的指南针,有了书,我们便知道了知识,有了知识,才会使我们做事情变得更有趣味,有了书,我们才能做更多的事。书是我的伴侣,书是我的好老师。 277 | 高尔基曾经说过:"书籍是人类进步的阶梯。莎士比亚曾经说过"书籍是全世界的营养品" 278 | 279 | ``` 280 | 281 | ### 我最敬佩的一个人 282 | ``` 283 | title: 我最敬佩的一个人 284 | context:在我的生活中,有外冷内热的妈妈,有拼命工作的爸爸 285 | 286 | result: 287 | 在我的生活中,有外冷内热的妈妈,有拼命工作的爸爸,还有勤劳朴素的爷爷奶奶。但是,我最敬佩的是我的妈妈。 288 | 我的妈妈,有一双炯炯有神的眼睛,高高的鼻梁上架着一副眼镜,她很爱笑,很爱笑。妈妈的头发非常多,细长的柳叶眉下镶嵌着一双炯炯有神的眼睛,好像一个宝石。妈妈长的高高的个子,鼻梁上架着一副眼镜。 289 | 290 | result: 291 | 在我的生活中,有外冷内热的妈妈,有拼命工作的爸爸,有工作努力的奶奶,有勤奋好学的姐姐......但是,我最敬佩的是那位平凡的清洁工人。 292 | 一天,我和妈妈一起乘坐公交车回家。到达了车站,我和妈妈下了车,就急匆匆地跑向家附近的早餐店。吃完早餐后,我们正准备上公交车,可是我发现有一位清洁工老爷爷正在寒风中扫地。他身穿一件单薄的衣服,衣服上沾满了灰尘,他却是用手把垃圾一点一点地扫起来,垃圾车井然有序地行驶着。我心想:这位清洁工真不容易啊! 293 | 294 | result: 295 | 在我的生活中,有外冷内热的妈妈,有拼命工作的爸爸,有幽默搞笑的爷爷,有坚持不懈的老师......但是,我最敬佩的人是我的奶奶。 296 | 我的奶奶非常爱美,喜欢穿红色衣服。有一天,奶奶过生日,我早早地起了床,迫不及待地对奶奶说:“奶奶,奶奶,祝你生日快乐,身体健康,万事如意!”奶奶开心地说:“谢谢你,宝贝,你真是长大了。” 297 | 我的奶奶又很勤劳。她很会做家务活,做家务活,她把家里打扫得干干净净,就像一个小家一样。她不仅把家里打扫得干净, 298 | 299 | result: 300 | 在我的生活中,有外冷内热的妈妈,有拼命工作的爸爸,有知识渊博的爷爷,但是我最敬佩的还是我的妈妈。 301 | 我的妈妈长着一头乌黑发亮的头发,又短又黑,一双炯炯有神的大眼睛,笑起来特别好看,小小的嘴巴一笑起来就露出两个甜甜的酒窝,非常迷人。 302 | 妈妈的个子比较高,还稍微有点胖,这都是为什么妈妈很胖的原因。但是妈妈每天都非常累,她是一个非常勤劳的人,每天都要很早起床,为了家人做出更多的早餐,她总是天不亮就起床,然后再叫醒我。 303 | 304 | result: 305 | 在我的生活中,有外冷内热的妈妈,有拼命工作的爸爸,还有日夜奔波的老师......而我最敬佩的人就是我的英语老师--彭老师。 306 | 她有一头乌黑亮丽的长发,弯弯的眉毛下面是一双炯炯有神的大眼睛。上课时,她的声音很小,经常在黑板上点出枯燥的英语单词。如果我们在写字,她就用眼睛一遍一遍地望着我们,就像在注视着自己的孩子一样。彭老师十分严格。 307 | 308 | result: 309 | 在我的生活中,有外冷内热的妈妈,有拼命工作的爸爸,还有一个吃苦耐劳的爷爷,但是最令我敬佩的是我的爷爷。 310 | 爷爷的身高已经一米七了,已经有80多岁了,他虽然已经退休了,但仍然坚持每天给我做可口的饭菜,陪我玩耍,照顾我的生活起居,而且还坚持每天接送我上下学,爸爸妈妈很爱我。 311 | 我的爷爷长着一张圆圆的脸,一双炯炯有神的眼睛, 312 | 313 | result: 314 | 在我的生活中,有外冷内热的妈妈,有拼命工作的爸爸"我最敬佩的一个人",其中,有一个人我最敬佩。 315 | 她长着一双炯炯有神的眼睛,高高的鼻梁上架着一副黑色的眼镜,给人一种文雅大气的感觉,一张樱桃小嘴一张一合,给人一种读书的感觉,她就是我的妈妈。 316 | 我的妈妈是一个喜欢化妆的人。她每次都会把自己打扮得漂漂亮亮的, 317 | 318 | result: 319 | 在我的生活中,有外冷内热的妈妈,有拼命工作的爸爸,但最敬佩我的哥哥。 320 | 哥哥是一个人见人爱,花见花开,车见车爆胎的卖菜小贩。哥哥的头上会扎成一个三角形,眉毛下面长着一双明亮的大眼睛。鼻子很小巧,还有一个樱桃小嘴。他的嘴巴虽然小,但是能说会道。你看,他的脸蛋上还长了一对小酒窝。 321 | ``` 322 | 323 | 324 | -------------------------------------------------------------------------------- /config/cpm-medium.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "architectures": [ 4 | "GPT2LMHeadModel" 5 | ], 6 | "attn_pdrop": 0.1, 7 | "bos_token_id": 1, 8 | "embd_pdrop": 0.1, 9 | "eos_token_id": 2, 10 | "initializer_range": 0.02, 11 | "layer_norm_epsilon": 1e-05, 12 | "model_type": "gpt2", 13 | "n_ctx": 1024, 14 | "n_embd": 1024, 15 | "n_head": 16, 16 | "n_layer": 24, 17 | "n_positions": 1024, 18 | "n_special": 0, 19 | "predict_special_tokens": true, 20 | "resid_pdrop": 0.1, 21 | "summary_activation": null, 22 | "summary_first_dropout": 0.1, 23 | "summary_proj_to_labels": true, 24 | "summary_type": "cls_index", 25 | "summary_use_proj": true, 26 | "task_specific_params": { 27 | "text-generation": { 28 | "do_sample": true, 29 | "max_length": 50 30 | } 31 | }, 32 | "vocab_size": 30000 33 | } -------------------------------------------------------------------------------- /config/cpm-one-layer.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "architectures": [ 4 | "GPT2LMHeadModel" 5 | ], 6 | "attn_pdrop": 0.1, 7 | "bos_token_id": 50256, 8 | "embd_pdrop": 0.1, 9 | "eos_token_id": 50256, 10 | "initializer_range": 0.02, 11 | "layer_norm_epsilon": 1e-05, 12 | "model_type": "gpt2", 13 | "n_ctx": 1024, 14 | "n_embd": 768, 15 | "n_head": 12, 16 | "n_layer": 1, 17 | "n_positions": 1024, 18 | "resid_pdrop": 0.1, 19 | "summary_activation": null, 20 | "summary_first_dropout": 0.1, 21 | "summary_proj_to_labels": true, 22 | "summary_type": "cls_index", 23 | "summary_use_proj": true, 24 | "task_specific_params": { 25 | "text-generation": { 26 | "do_sample": true, 27 | "max_length": 50 28 | } 29 | }, 30 | "vocab_size": 30000 31 | } -------------------------------------------------------------------------------- /config/cpm-small.json: -------------------------------------------------------------------------------- 1 | { 2 | "activation_function": "gelu_new", 3 | "architectures": [ 4 | "GPT2LMHeadModel" 5 | ], 6 | "attn_pdrop": 0.1, 7 | "bos_token_id": 50256, 8 | "embd_pdrop": 0.1, 9 | "eos_token_id": 50256, 10 | "initializer_range": 0.02, 11 | "layer_norm_epsilon": 1e-05, 12 | "model_type": "gpt2", 13 | "n_ctx": 1024, 14 | "n_embd": 768, 15 | "n_head": 12, 16 | "n_layer": 12, 17 | "n_positions": 1024, 18 | "resid_pdrop": 0.1, 19 | "summary_activation": null, 20 | "summary_first_dropout": 0.1, 21 | "summary_proj_to_labels": true, 22 | "summary_type": "cls_index", 23 | "summary_use_proj": true, 24 | "task_specific_params": { 25 | "text-generation": { 26 | "do_sample": true, 27 | "max_length": 50 28 | } 29 | }, 30 | "vocab_size": 30000 31 | } -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file !.gitkeep -------------------------------------------------------------------------------- /data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.nn.parallel import DataParallel 3 | import torch 4 | from torch.nn.parallel._functions import Scatter 5 | from torch.nn.parallel.parallel_apply import parallel_apply 6 | 7 | def scatter(inputs, target_gpus, chunk_sizes, dim=0): 8 | r""" 9 | Slices tensors into approximately equal chunks and 10 | distributes them across given GPUs. Duplicates 11 | references to objects that are not tensors. 12 | """ 13 | def scatter_map(obj): 14 | if isinstance(obj, torch.Tensor): 15 | try: 16 | return Scatter.apply(target_gpus, chunk_sizes, dim, obj) 17 | except: 18 | print('obj', obj.size()) 19 | print('dim', dim) 20 | print('chunk_sizes', chunk_sizes) 21 | quit() 22 | if isinstance(obj, tuple) and len(obj) > 0: 23 | return list(zip(*map(scatter_map, obj))) 24 | if isinstance(obj, list) and len(obj) > 0: 25 | return list(map(list, zip(*map(scatter_map, obj)))) 26 | if isinstance(obj, dict) and len(obj) > 0: 27 | return list(map(type(obj), zip(*map(scatter_map, obj.items())))) 28 | return [obj for targets in target_gpus] 29 | 30 | # After scatter_map is called, a scatter_map cell will exist. This cell 31 | # has a reference to the actual function scatter_map, which has references 32 | # to a closure that has a reference to the scatter_map cell (because the 33 | # fn is recursive). To avoid this reference cycle, we set the function to 34 | # None, clearing the cell 35 | try: 36 | return scatter_map(inputs) 37 | finally: 38 | scatter_map = None 39 | 40 | def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): 41 | r"""Scatter with support for kwargs dictionary""" 42 | inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] 43 | kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] 44 | if len(inputs) < len(kwargs): 45 | inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) 46 | elif len(kwargs) < len(inputs): 47 | kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) 48 | inputs = tuple(inputs) 49 | kwargs = tuple(kwargs) 50 | return inputs, kwargs 51 | 52 | class BalancedDataParallel(DataParallel): 53 | def __init__(self, gpu0_bsz, *args, **kwargs): 54 | self.gpu0_bsz = gpu0_bsz 55 | super().__init__(*args, **kwargs) 56 | 57 | def forward(self, *inputs, **kwargs): 58 | if not self.device_ids: 59 | return self.module(*inputs, **kwargs) 60 | if self.gpu0_bsz == 0: 61 | device_ids = self.device_ids[1:] 62 | else: 63 | device_ids = self.device_ids 64 | inputs, kwargs = self.scatter(inputs, kwargs, device_ids) 65 | 66 | # print('len(inputs): ', str(len(inputs))) 67 | # print('self.device_ids[:len(inputs)]', str(self.device_ids[:len(inputs)])) 68 | 69 | if len(self.device_ids) == 1: 70 | return self.module(*inputs[0], **kwargs[0]) 71 | if self.gpu0_bsz == 0: 72 | replicas = self.replicate(self.module, self.device_ids) 73 | else: 74 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 75 | 76 | # replicas = self.replicate(self.module, device_ids[:len(inputs)]) 77 | if self.gpu0_bsz == 0: 78 | replicas = replicas[1:] 79 | 80 | # print('replicas:', str(len(replicas))) 81 | 82 | outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) 83 | return self.gather(outputs, self.output_device) 84 | 85 | def parallel_apply(self, replicas, device_ids, inputs, kwargs): 86 | return parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)]) 87 | 88 | def scatter(self, inputs, kwargs, device_ids): 89 | bsz = inputs[0].size(self.dim) 90 | num_dev = len(self.device_ids) 91 | gpu0_bsz = self.gpu0_bsz 92 | bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) 93 | if gpu0_bsz < bsz_unit: 94 | chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) 95 | delta = bsz - sum(chunk_sizes) 96 | for i in range(delta): 97 | chunk_sizes[i + 1] += 1 98 | if gpu0_bsz == 0: 99 | chunk_sizes = chunk_sizes[1:] 100 | else: 101 | return super().scatter(inputs, kwargs, device_ids) 102 | 103 | # print('bsz: ', bsz) 104 | # print('num_dev: ', num_dev) 105 | # print('gpu0_bsz: ', gpu0_bsz) 106 | # print('bsz_unit: ', bsz_unit) 107 | # print('chunk_sizes: ', chunk_sizes) 108 | return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) 109 | 110 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | 4 | 5 | class CPMDataset(Dataset): 6 | """ 7 | 8 | """ 9 | 10 | def __init__(self, input_list, max_len): 11 | self.input_list = input_list 12 | self.max_len = max_len 13 | 14 | def __getitem__(self, index): 15 | input_ids = self.input_list[index] 16 | input_ids = input_ids[:self.max_len] 17 | input_ids = torch.tensor(input_ids, dtype=torch.long) 18 | return input_ids 19 | 20 | def __len__(self): 21 | return len(self.input_list) 22 | -------------------------------------------------------------------------------- /distill.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from loguru import logger 3 | 4 | import numpy as np 5 | 6 | import torch 7 | from torch.utils.data import DataLoader 8 | 9 | from transformers import GPT2LMHeadModel, GPT2Config 10 | import os 11 | from os.path import join 12 | import random 13 | import pickle 14 | import time 15 | import torch.nn.utils.rnn as rnn_utils 16 | import transformers 17 | import torch.nn.functional as F 18 | from torch.utils.tensorboard import SummaryWriter 19 | from dataset import CPMDataset 20 | """ 21 | 模型蒸馏 22 | """ 23 | 24 | 25 | def collate_fn(batch): 26 | input_ids = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=5) 27 | labels = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=-100) 28 | return input_ids, labels 29 | 30 | 31 | def seed_everything(seed=42): 32 | """ 33 | 设置整个开发环境的seed 34 | """ 35 | random.seed(seed) 36 | os.environ['PYTHONHASHSEED'] = str(seed) 37 | np.random.seed(seed) 38 | # tf.random.set_seed(seed) 39 | torch.manual_seed(seed) 40 | torch.cuda.manual_seed(seed) 41 | torch.cuda.manual_seed_all(seed) 42 | # some cudnn methods can be random even after fixing the seed 43 | # unless you tell it to be deterministic 44 | torch.backends.cudnn.deterministic = True 45 | 46 | 47 | def calculate_acc(logit, labels, ignore_index=-100): 48 | logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1)) 49 | labels = labels[..., 1:].contiguous().view(-1) 50 | 51 | _, logit = logit.max(dim=-1) # 对于每条数据,返回最大的index 52 | # 进行非运算,返回一个tensor,若labels的第i个位置为pad_id,则置为0,否则为1 53 | non_pad_mask = labels.ne(ignore_index) 54 | n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item() 55 | n_word = non_pad_mask.sum().item() 56 | return n_correct, n_word 57 | 58 | 59 | def distill_loss(logits, labels, target_logits, ignore_index): 60 | # hard loss 61 | hard_loss = hard_cross_entropy_loss(logits, labels, ignore_index) 62 | # soft loss 63 | soft_loss = soft_cross_entropy_loss(logits, labels, target_logits, ignore_index) 64 | # 加权 65 | loss = 0.5 * hard_loss + 0.5 * soft_loss 66 | return loss 67 | 68 | 69 | def hard_cross_entropy_loss(logits, labels, ignore_index): 70 | logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1)) 71 | labels = labels[..., 1:].contiguous().view(-1) 72 | loss = F.cross_entropy(logits, labels, ignore_index=ignore_index) 73 | return loss 74 | 75 | 76 | def soft_cross_entropy_loss(logits, labels, target_logits, ignore_index): 77 | logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1)) 78 | labels = labels[..., 1:].contiguous().view(-1) 79 | target_probs = torch.softmax(target_logits, axis=-1) 80 | target_probs = target_probs[..., :-1, :].contiguous().view(-1, target_probs.size(-1)) 81 | 82 | # 计算每个位置的loss 83 | loss = F.cross_entropy(logits, target_probs, reduction='none') 84 | 85 | # 选出非padding的loss,求平均 86 | loss_mask = (labels == ignore_index) 87 | loss = torch.masked_select(loss, ~loss_mask) 88 | loss = loss.mean() 89 | 90 | return loss 91 | 92 | 93 | def load_dataset(logger, args): 94 | """ 95 | 加载训练集 96 | """ 97 | logger.info("loading training dataset") 98 | train_path = args.train_path 99 | 100 | with open(train_path, "rb") as f: 101 | train_list = pickle.load(f) 102 | 103 | # test 104 | # train_list = train_list[:24] 105 | logger.info('len of train data:{}'.format(len(train_list))) 106 | train_dataset = CPMDataset(train_list, args.max_len) 107 | 108 | return train_dataset 109 | 110 | 111 | def train(teacher, student, train_dataset, writer, args): 112 | teacher.eval() 113 | student.train() 114 | train_dataloader = DataLoader( 115 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn, 116 | drop_last=True 117 | ) 118 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs 119 | optimizer = transformers.AdamW(student.parameters(), lr=args.lr, eps=args.eps) 120 | scheduler = transformers.get_linear_schedule_with_warmup( 121 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 122 | ) 123 | 124 | logger.info('start training') 125 | device = args.device 126 | ignore_index = args.ignore_index 127 | step = 0 128 | train_loss = 0 129 | train_acc = 0 130 | log_step = args.log_step 131 | save_step = args.save_step 132 | 133 | # ========== start training ========== # 134 | for epoch in range(args.epochs): 135 | logger.info('start {}th epoch training'.format(epoch + 1)) 136 | for batch_idx, (input_ids, labels) in enumerate(train_dataloader): 137 | step += 1 138 | input_ids = input_ids.to(device) 139 | labels = labels.to(device) 140 | with torch.no_grad(): 141 | target_logits = teacher(input_ids=input_ids).logits 142 | logits = student(input_ids=input_ids).logits 143 | 144 | # 计算loss 145 | loss = soft_cross_entropy_loss(logits, labels, target_logits, args.ignore_index) 146 | # 统计该batch的预测token的正确数与总数 147 | batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index) 148 | batch_acc = batch_correct_num/batch_total_num 149 | train_loss += loss 150 | train_acc += batch_acc 151 | 152 | if args.gradient_accumulation_steps > 1: 153 | loss = loss / args.gradient_accumulation_steps 154 | 155 | loss.backward() 156 | # 梯度裁剪 157 | torch.nn.utils.clip_grad_norm_(student.parameters(), args.max_grad_norm) 158 | # 进行一定step的梯度累计之后,更新参数 159 | if step % args.gradient_accumulation_steps == 0: 160 | # 更新参数 161 | optimizer.step() 162 | # 更新学习率 163 | scheduler.step() 164 | # 清空梯度信息 165 | optimizer.zero_grad() 166 | 167 | if step % log_step == 0: 168 | train_loss = train_loss / log_step 169 | train_acc = train_acc / log_step 170 | # 训练集指标 171 | logger.info('Epoch {} step {} train Loss {:.4f}, train ACC {:.4f}'.format(epoch + 1, step, train_loss, train_acc)) 172 | writer.add_scalar('train loss', train_loss, step) 173 | writer.add_scalar('train acc', train_acc, step) 174 | train_loss = 0 175 | train_acc = 0 176 | 177 | if step % save_step == 0: 178 | logger.info('Saving model at Epoch {} step {}'.format(epoch + 1, step)) 179 | model_path = join(args.output_path, 'epoch_{}-step_{}'.format(epoch + 1, step)) 180 | if not os.path.exists(model_path): 181 | os.mkdir(model_path) 182 | model_to_save = student.module if hasattr(student, 'module') else student 183 | model_to_save.save_pretrained(model_path) 184 | 185 | logger.info('training finished') 186 | 187 | 188 | def main(): 189 | # 参数设置 190 | args = set_args() 191 | # 设置随机种子 192 | seed_everything(args.seed) 193 | # 设置显卡 194 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_ids 195 | args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu' 196 | # 创建输出目录 197 | if not os.path.exists(args.output_path): 198 | os.makedirs(args.output_path) 199 | # 日志输出位置 200 | cur_time = time.strftime("%Y%m%d%H%M%S", time.localtime()) 201 | logger.add(join(args.output_path, 'distill-{}.log'.format(cur_time))) 202 | # 初始化tensorboard 203 | writer = SummaryWriter(args.output_path) 204 | # 加载tokenizer 205 | # tokenizer = CpmTokenizer(vocab_file=args.vocab_path) 206 | # args.eod_id = tokenizer.convert_tokens_to_ids("") # 文档结束符 207 | # args.pad_id = tokenizer.pad_token_id 208 | # 加载teacher模型 209 | teacher = GPT2LMHeadModel.from_pretrained(args.teacher_checkpoint) 210 | teacher = teacher.to(args.device) 211 | # 初始化student模型 212 | student_config = GPT2Config.from_pretrained(args.student_config_path) 213 | student = GPT2LMHeadModel(student_config) 214 | student = student.to(args.device) 215 | logger.info('student model config:{}'.format(student_config)) 216 | 217 | # 计算模型参数量 218 | params_teacher = sum([param.nelement() for param in teacher.parameters()]) 219 | logger.info("Number of teacher parameter: %.2fM" % (params_teacher / 1e6)) 220 | params_student = sum([param.nelement() for param in student.parameters()]) 221 | logger.info("Number of student parameter: %.2fM" % (params_student / 1e6)) 222 | # 记录参数设置 223 | logger.info(args) 224 | 225 | # 加载训练集 226 | train_dataset = load_dataset(logger, args) 227 | train(teacher, student, train_dataset, writer, args) 228 | 229 | 230 | def set_args(): 231 | parser = argparse.ArgumentParser() 232 | parser.add_argument("--device_ids", type=str, default='3', help="") 233 | parser.add_argument("--output_path", type=str, default='output/distill') 234 | parser.add_argument('--vocab_path', default='vocab/chinese_vocab.model', type=str, required=False, 235 | help='sp模型路径') 236 | parser.add_argument("--teacher_checkpoint", type=str, default="model/zuowen_epoch40", help='teacher模型的路径') 237 | parser.add_argument("--student_config_path", type=str, default="config/cpm-one-layer.json", help='student模型的配置') 238 | parser.add_argument('--train_path', default='data/train.pkl', type=str, required=False, help='经过预处理之后的数据存放路径') 239 | parser.add_argument('--max_len', default=200, type=int, required=False, help='训练时,输入数据的最大长度') 240 | parser.add_argument('--ignore_index', default=-100, type=int, required=False, 241 | help='对于ignore_index的label token不计算梯度') 242 | 243 | parser.add_argument("--lr", type=float, default=1e-4) 244 | parser.add_argument('--eps', default=1.0e-09, type=float, required=False, help='AdamW优化器的衰减率') 245 | parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False) 246 | parser.add_argument('--warmup_steps', type=int, default=4000, help='warm up步数') 247 | parser.add_argument('--gradient_accumulation_steps', default=1, type=int, required=False, help='梯度积累的步数') 248 | 249 | parser.add_argument("--epochs", type=int, default=40) 250 | parser.add_argument("--batch_size", type=int, default=4) 251 | parser.add_argument("--num_workers", type=int, default=0) 252 | parser.add_argument("--save_step", type=int, default=100, help="every eval_step to save model") 253 | parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss') 254 | parser.add_argument("--seed", type=int, default=42, help="random seed") 255 | 256 | args = parser.parse_args() 257 | return args 258 | 259 | 260 | if __name__ == '__main__': 261 | main() 262 | 263 | 264 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import os 4 | import argparse 5 | from tqdm import trange 6 | from transformers import GPT2LMHeadModel, GPT2Config, CpmTokenizer 7 | from utils import top_k_top_p_filtering, set_logger 8 | from os.path import join, exists 9 | 10 | 11 | def generate_next_token(input_ids): 12 | """ 13 | 对于给定的上文,生成下一个单词 14 | """ 15 | outputs = model(input_ids=input_ids) 16 | logits = outputs.logits 17 | # next_token_logits表示最后一个token的hidden_state对应的prediction_scores,也就是模型要预测的下一个token的概率 18 | next_token_logits = logits[0, -1, :] 19 | next_token_logits = next_token_logits / args.temperature 20 | # 对于的概率设为无穷小,也就是说模型的预测结果不可能是[UNK]这个token 21 | next_token_logits[unk_id] = -float('Inf') 22 | filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp) 23 | # torch.multinomial表示从候选集合中选出无放回地进行抽取num_samples个元素,权重越高,抽到的几率越高,返回元素的下标 24 | next_token_id = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 25 | return next_token_id 26 | 27 | 28 | def generate(max_len): 29 | # 对title与context进行tokenize 30 | title_ids = tokenizer.encode(title, add_special_tokens=False) 31 | context_ids = tokenizer.encode(context, add_special_tokens=False) 32 | input_ids = title_ids + [sep_id] + context_ids 33 | cur_len = len(input_ids) 34 | last_token_id = input_ids[-1] # 已生成的内容的最后一个token 35 | input_ids = torch.tensor([input_ids], dtype=torch.long, device=device) 36 | 37 | while True: 38 | next_token_id = generate_next_token(input_ids[:, -args.context_len:]) 39 | input_ids = torch.cat((input_ids, next_token_id.unsqueeze(0)), dim=1) 40 | cur_len += 1 41 | word = tokenizer.convert_ids_to_tokens(next_token_id.item()) 42 | # if cur_len >= max_len: 43 | # break 44 | # 超过最大长度,并且换行 45 | if cur_len >= max_len and last_token_id == 8 and next_token_id == 3: 46 | break 47 | # 超过最大长度,并且生成标点符号 48 | if cur_len >= max_len and word in [".", "。", "!", "!", "?", "?", ",", ","]: 49 | break 50 | # 生成结束符 51 | if next_token_id == eod_id: 52 | break 53 | result = tokenizer.decode(input_ids.squeeze(0)) 54 | return result 55 | 56 | 57 | if __name__ == '__main__': 58 | # 参数设置 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('--device', default='0', type=str, required=False, help='生成设备') 61 | parser.add_argument('--temperature', default=1, type=float, required=False, help='生成温度') 62 | parser.add_argument('--topk', default=0, type=int, required=False, help='最高几选一') 63 | parser.add_argument('--topp', default=0.85, type=float, required=False, help='最高积累概率') 64 | parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False, help='重复惩罚参数') 65 | parser.add_argument('--context_len', default=200, type=int, required=False, help='每一步生成时,参考的上文的长度') 66 | parser.add_argument('--max_len', default=300, type=int, required=False, help='生成的最长长度') 67 | parser.add_argument('--log_path', default='log/generate.log', type=str, required=False, help='日志存放位置') 68 | parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测') 69 | parser.add_argument('--model_path', type=str, default='model/zuowen_epoch40', help='模型存放位置') 70 | # parser.add_argument('--title', type=str, default='徜徉在书籍的阳光世界', help='作文标题') 71 | # parser.add_argument('--context', type=str, default='一本书是一个人的眼睛,它可以让你看到另一个世界的奇妙', help='作文上文') 72 | parser.add_argument('--title', type=str, default='家乡的四季', help='作文标题') 73 | parser.add_argument('--context', type=str, default='家乡的四季,最美不过了', help='作文上文') 74 | args = parser.parse_args() 75 | 76 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡 77 | args.cuda = torch.cuda.is_available() and not args.no_cuda # 当用户使用GPU,并且GPU可用时 78 | device = 'cuda:0' if args.cuda else 'cpu' 79 | # device = 'cpu' 80 | 81 | # 创建日志对象 82 | logger = set_logger(args.log_path) 83 | 84 | # 初始化tokenizer 85 | tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model") 86 | eod_id = tokenizer.convert_tokens_to_ids("") # 文档结束符 87 | sep_id = tokenizer.sep_token_id 88 | unk_id = tokenizer.unk_token_id 89 | 90 | # 加载模型 91 | model = GPT2LMHeadModel.from_pretrained(args.model_path) 92 | model.eval() 93 | model = model.to(device) 94 | 95 | title = args.title 96 | context = args.context 97 | logger.info("title:{}".format(title)) 98 | logger.info("context:{}".format(context)) 99 | 100 | # 开始生成 101 | result = generate(args.max_len) 102 | result = result.split("")[1] 103 | logger.info("result:{}\n".format(result)) 104 | 105 | # 通过控制台循环生成 106 | # print('开始生成,输入CTRL + Z以退出') 107 | # while True: 108 | # try: 109 | # # 用户输入title与context 110 | # title = input("请输入作文标题:") 111 | # context = input("请输入作文起始句子:") 112 | # 113 | # logger.info("title:{}".format(title)) 114 | # logger.info("context:{}".format(context)) 115 | # 116 | # # 开始生成 117 | # result = generate(args.max_len) 118 | # result = result.split("")[1] 119 | # logger.info("result:{}\n".format(result)) 120 | # break 121 | # 122 | # except KeyboardInterrupt: 123 | # break 124 | 125 | 126 | -------------------------------------------------------------------------------- /http_service.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import os 4 | import argparse 5 | from tqdm import trange 6 | from transformers import GPT2LMHeadModel, CpmTokenizer 7 | from utils import top_k_top_p_filtering, set_logger 8 | from os.path import join 9 | from flask import Flask, redirect, url_for, request 10 | app = Flask(__name__) 11 | app.config["JSON_AS_ASCII"] = False # 防止返回中文乱码 12 | 13 | 14 | def generate_next_token(input_ids): 15 | """ 16 | 对于给定的上文,生成下一个单词 17 | """ 18 | # 只根据当前位置的前context_len个token进行生成 19 | input_ids = input_ids[:, -args.context_len:] 20 | outputs = model(input_ids=input_ids) 21 | logits = outputs.logits 22 | # next_token_logits表示最后一个token的hidden_state对应的prediction_scores,也就是模型要预测的下一个token的概率 23 | next_token_logits = logits[0, -1, :] 24 | next_token_logits = next_token_logits / args.temperature 25 | # 对于的概率设为无穷小,也就是说模型的预测结果不可能是[UNK]这个token 26 | next_token_logits[unk_id] = -float('Inf') 27 | filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=args.topk, top_p=args.topp) 28 | # torch.multinomial表示从候选集合中选出无放回地进行抽取num_samples个元素,权重越高,抽到的几率越高,返回元素的下标 29 | next_token_id = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 30 | return next_token_id 31 | 32 | 33 | @app.route('/zuowen', methods=['POST', 'GET']) 34 | def zuowen(): 35 | if request.method == 'POST': 36 | data = request.get_json() 37 | title = data['title'] 38 | context = data['context'] 39 | max_len = data['max_len'] 40 | elif request.method == 'GET': 41 | title = request.args.get('title', type=str) 42 | context = request.args.get('context', type=str) 43 | max_len = request.args.get('max_len', type=int) 44 | 45 | # print("title:{}".format(title)) 46 | # print("context:{}".format(context)) 47 | logger.info("receive request,title:{}, context:{}".format(title, context)) 48 | 49 | title_ids = tokenizer.encode(title, add_special_tokens=False) 50 | context_ids = tokenizer.encode(context, add_special_tokens=False) 51 | input_ids = title_ids + [sep_id] + context_ids 52 | cur_len = len(input_ids) 53 | last_token_id = input_ids[-1] # 已生成的内容的最后一个token 54 | input_ids = torch.tensor([input_ids], dtype=torch.long, device=device) 55 | 56 | while True: 57 | next_token_id = generate_next_token(input_ids) 58 | input_ids = torch.cat((input_ids, next_token_id.unsqueeze(0)), dim=1) 59 | cur_len += 1 60 | word = tokenizer.convert_ids_to_tokens(next_token_id.item()) 61 | # 超过最大长度,并且换行 62 | if cur_len >= max_len and last_token_id == 8 and next_token_id == 3: 63 | break 64 | # 超过最大长度,并且生成标点符号 65 | if cur_len >= max_len and word in [".", "。", "!", "!", "?", "?", ",", ","]: 66 | break 67 | # 生成结束符 68 | if next_token_id == eod_id: 69 | break 70 | result = tokenizer.decode(input_ids.squeeze(0)) 71 | content = result.split("")[1] # 生成的最终内容 72 | result = {"title": title, "content": content} 73 | logger.info("generated result:{}".format(result)) 74 | return result 75 | 76 | 77 | if __name__ == '__main__': 78 | # 参数设置 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument('--device', default='1', type=str, required=False, help='生成设备') 81 | parser.add_argument('--temperature', default=1, type=float, required=False, help='生成温度') 82 | parser.add_argument('--topk', default=0, type=int, required=False, help='最高几选一') 83 | parser.add_argument('--topp', default=0.85, type=float, required=False, help='最高积累概率') 84 | parser.add_argument('--context_len', default=200, type=int, required=False, help='作文生成中,每一步生成时,参考的上文的长度') 85 | # parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False, help='重复惩罚参数') 86 | parser.add_argument('--port', type=int, default=8085, help='服务绑定的端口号') 87 | parser.add_argument('--log_path', default='log/http_service.log', type=str, required=False, help='日志存放位置') 88 | parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行预测') 89 | parser.add_argument('--model_path', type=str, default='model/zuowen_epoch40', help='模型存放位置') 90 | args = parser.parse_args() 91 | 92 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device # 此处设置程序使用哪些显卡 93 | args.cuda = torch.cuda.is_available() and not args.no_cuda # 当用户使用GPU,并且GPU可用时 94 | device = 'cuda:0' if args.cuda else 'cpu' 95 | # device = 'cpu' 96 | 97 | # 创建日志对象 98 | logger = set_logger(args.log_path) 99 | 100 | # 加载tokenizer 101 | tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model") 102 | eod_id = tokenizer.convert_tokens_to_ids("") # 文档结束符 103 | sep_id = tokenizer.sep_token_id 104 | unk_id = tokenizer.unk_token_id 105 | 106 | # 加载模型 107 | model = GPT2LMHeadModel.from_pretrained(args.model_path) 108 | model.eval() 109 | model = model.to(device) 110 | 111 | app.run(debug=True, host="0.0.0.0", port=args.port) 112 | -------------------------------------------------------------------------------- /log/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file !.gitkeep -------------------------------------------------------------------------------- /model/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file !.gitkeep -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils import set_logger 3 | from transformers import CpmTokenizer 4 | import os 5 | import pickle 6 | from tqdm import tqdm 7 | 8 | 9 | def preprocess(): 10 | """ 11 | 对故事数据集进行预处理 12 | """ 13 | # 设置参数 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--vocab_file', default='vocab/chinese_vocab.model', type=str, required=False, 16 | help='词表路径') 17 | parser.add_argument('--log_path', default='log/preprocess.log', type=str, required=False, help='日志存放位置') 18 | parser.add_argument('--data_path', default='data/zuowen', type=str, required=False, help='数据集存放位置') 19 | parser.add_argument('--save_path', default='data/train.pkl', type=str, required=False, help='对训练数据集进行tokenize之后的数据存放位置') 20 | parser.add_argument('--win_size', default=200, type=int, required=False, help='滑动窗口的大小,相当于每条数据的最大长度') 21 | parser.add_argument('--step', default=200, type=int, required=False, help='滑动窗口的滑动步幅') 22 | args = parser.parse_args() 23 | 24 | # 初始化日志对象 25 | logger = set_logger(args.log_path) 26 | 27 | # 初始化tokenizer 28 | tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model") 29 | eod_id = tokenizer.convert_tokens_to_ids("") # 文档结束符 30 | sep_id = tokenizer.sep_token_id 31 | 32 | # 读取作文数据集目录下的所有文件 33 | train_list = [] 34 | logger.info("start tokenizing data") 35 | for file in tqdm(os.listdir(args.data_path)): 36 | file = os.path.join(args.data_path, file) 37 | with open(file, "r", encoding="utf8")as reader: 38 | lines = reader.readlines() 39 | title = lines[1][3:].strip() # 取出标题 40 | lines = lines[7:] # 取出正文内容 41 | article = "" 42 | for line in lines: 43 | if line.strip() != "": # 去除换行 44 | article += line 45 | title_ids = tokenizer.encode(title, add_special_tokens=False) 46 | article_ids = tokenizer.encode(article, add_special_tokens=False) 47 | token_ids = title_ids + [sep_id] + article_ids + [eod_id] 48 | # train_list.append(token_ids) 49 | 50 | # 对于每条数据,使用滑动窗口对其进行截断 51 | win_size = args.win_size 52 | step = args.step 53 | start_index = 0 54 | end_index = win_size 55 | data = token_ids[start_index:end_index] 56 | train_list.append(data) 57 | start_index += step 58 | end_index += step 59 | while end_index+50 < len(token_ids): # 剩下的数据长度,大于或等于50,才加入训练数据集 60 | data = token_ids[start_index:end_index] 61 | train_list.append(data) 62 | start_index += step 63 | end_index += step 64 | 65 | # 序列化训练数据 66 | with open(args.save_path, "wb") as f: 67 | pickle.dump(train_list, f) 68 | 69 | 70 | if __name__ == '__main__': 71 | preprocess() 72 | 73 | 74 | -------------------------------------------------------------------------------- /script/distill.sh: -------------------------------------------------------------------------------- 1 | python distill.py \ 2 | --device_ids 3 \ 3 | --output_path output/distill \ 4 | --vocab_path vocab/chinese_vocab.model \ 5 | --teacher_checkpoint model/zuowen_epoch40 \ 6 | --student_config_path config/cpm-one-layer.json \ 7 | --train_path data/train.pkl \ 8 | --max_len 200 \ 9 | --ignore_index -100 \ 10 | --lr 1e-4 \ 11 | --eps 1.0e-09 \ 12 | --epochs 40 \ 13 | --batch_size 128 \ 14 | --save_step 5000 \ 15 | --log_step 50 \ 16 | --gradient_accumulation_steps 1 \ 17 | --seed 42 \ 18 | --warmup_steps 4000 -------------------------------------------------------------------------------- /script/train-v2.sh: -------------------------------------------------------------------------------- 1 | python train-v2.py \ 2 | --device_ids 0 \ 3 | --vocab_path vocab/chinese_vocab.model \ 4 | --model_config config/cpm-one-layer.json \ 5 | --train_path data/train.pkl \ 6 | --max_len 200 \ 7 | --ignore_index -100 \ 8 | --epochs 40 \ 9 | --batch_size 64 \ 10 | --lr 1e-4 \ 11 | --eps 1.0e-09 \ 12 | --log_step 50 \ 13 | --save_step 5000 \ 14 | --gradient_accumulation_steps 1 \ 15 | --output_path output/train \ 16 | --pretrained_model '' \ 17 | --seed 42 \ 18 | --warmup_steps 4000 -------------------------------------------------------------------------------- /train-v2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import torch 4 | from loguru import logger 5 | import os 6 | from torch.utils.data import Dataset, DataLoader 7 | from os.path import join, exists 8 | import transformers 9 | import pickle 10 | import sys 11 | from utils import set_logger, set_random_seed 12 | from sklearn.model_selection import train_test_split 13 | from data_parallel import BalancedDataParallel 14 | from transformers import GPT2LMHeadModel, GPT2Config 15 | import pandas as pd 16 | import torch.nn.utils.rnn as rnn_utils 17 | import numpy as np 18 | from dataset import CPMDataset 19 | from torch.utils.tensorboard import SummaryWriter 20 | 21 | 22 | def set_args(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument('--device_ids', default='0', type=str, required=False, help='设置使用哪些显卡') 25 | parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行训练') 26 | parser.add_argument('--vocab_path', default='vocab/chinese_vocab.model', type=str, required=False, 27 | help='sp模型路径') 28 | parser.add_argument('--model_config', default='config/cpm-small.json', type=str, required=False, 29 | help='需要从头训练一个模型时,模型参数的配置文件') 30 | parser.add_argument('--train_path', default='data/train.pkl', type=str, required=False, help='经过预处理之后的数据存放路径') 31 | parser.add_argument('--max_len', default=200, type=int, required=False, help='训练时,输入数据的最大长度') 32 | 33 | parser.add_argument('--ignore_index', default=-100, type=int, required=False, help='对于ignore_index的label token不计算梯度') 34 | parser.add_argument('--epochs', default=40, type=int, required=False, help='训练的最大轮次') 35 | parser.add_argument('--batch_size', default=4, type=int, required=False, help='训练的batch size') 36 | parser.add_argument('--gpu0_bsz', default=6, type=int, required=False, help='0号卡的batch size') 37 | parser.add_argument('--lr', default=1e-5, type=float, required=False, help='学习率') 38 | parser.add_argument('--eps', default=1.0e-09, type=float, required=False, help='AdamW优化器的衰减率') 39 | parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss') 40 | parser.add_argument("--save_step", type=int, default=1, help="every eval_step to save model") 41 | parser.add_argument('--gradient_accumulation_steps', default=6, type=int, required=False, help='梯度积累的步数') 42 | parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False) 43 | parser.add_argument('--output_path', default='output/train', type=str, required=False, 44 | help='模型输出路径') 45 | parser.add_argument('--pretrained_model', default='', type=str, required=False, 46 | help='预训练的模型的路径') 47 | parser.add_argument('--seed', type=int, default=1234, help='设置随机种子') 48 | parser.add_argument('--num_workers', type=int, default=0, help="dataloader加载数据时使用的线程数量") 49 | parser.add_argument('--warmup_steps', type=int, default=4000, help='warm up步数') 50 | args = parser.parse_args() 51 | return args 52 | 53 | 54 | def collate_fn(batch): 55 | input_ids = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=5) 56 | labels = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=-100) 57 | return input_ids, labels 58 | 59 | 60 | def load_dataset(logger, args): 61 | """ 62 | 加载训练集 63 | """ 64 | logger.info("loading training dataset") 65 | train_path = args.train_path 66 | 67 | with open(train_path, "rb") as f: 68 | train_list = pickle.load(f) 69 | 70 | # test 71 | # train_list = train_list[:24] 72 | logger.info('len of train data:{}'.format(len(train_list))) 73 | train_dataset = CPMDataset(train_list, args.max_len) 74 | 75 | return train_dataset 76 | 77 | 78 | def train(model, logger, train_dataset, writer, args): 79 | model.train() 80 | train_dataloader = DataLoader( 81 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn, 82 | drop_last=True 83 | ) 84 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs 85 | optimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps) 86 | scheduler = transformers.get_linear_schedule_with_warmup( 87 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 88 | ) 89 | 90 | logger.info('start training') 91 | device = args.device 92 | ignore_index = args.ignore_index 93 | step = 0 94 | train_loss = 0 95 | train_acc = 0 96 | log_step = args.log_step 97 | save_step = args.save_step 98 | 99 | # ========== start training ========== # 100 | for epoch in range(args.epochs): 101 | logger.info('start {}th epoch training'.format(epoch + 1)) 102 | for batch_idx, (input_ids, labels) in enumerate(train_dataloader): 103 | step += 1 104 | input_ids = input_ids.to(device) 105 | labels = labels.to(device) 106 | 107 | outputs = model.forward(input_ids, labels=labels) 108 | logits = outputs.logits 109 | loss = outputs.loss 110 | loss = loss.mean() # 多卡损失的均值 111 | 112 | # 统计该batch的预测token的正确数与总数 113 | batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index) 114 | batch_acc = batch_correct_num / batch_total_num 115 | train_loss += loss 116 | train_acc += batch_acc 117 | 118 | if args.gradient_accumulation_steps > 1: 119 | loss = loss / args.gradient_accumulation_steps 120 | 121 | loss.backward() 122 | # 梯度裁剪 123 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 124 | # 进行一定step的梯度累计之后,更新参数 125 | if step % args.gradient_accumulation_steps == 0: 126 | # 更新参数 127 | optimizer.step() 128 | # 更新学习率 129 | scheduler.step() 130 | # 清空梯度信息 131 | optimizer.zero_grad() 132 | 133 | if step % log_step == 0: 134 | train_loss = train_loss / log_step 135 | train_acc = train_acc / log_step 136 | # 训练集指标 137 | logger.info('Epoch {} step {} train Loss {:.4f}, train ACC {:.4f}'.format(epoch + 1, step, train_loss, 138 | train_acc)) 139 | writer.add_scalar('train loss', train_loss, step) 140 | writer.add_scalar('train acc', train_acc, step) 141 | train_loss = 0 142 | train_acc = 0 143 | 144 | if step % save_step == 0: 145 | logger.info('Saving model at Epoch {} step {}'.format(epoch + 1, step)) 146 | model_path = join(args.output_path, 'epoch_{}-step_{}'.format(epoch + 1, step)) 147 | if not os.path.exists(model_path): 148 | os.mkdir(model_path) 149 | model_to_save = model.module if hasattr(model, 'module') else model 150 | model_to_save.save_pretrained(model_path) 151 | 152 | logger.info('training finished') 153 | 154 | 155 | def calculate_acc(logit, labels, ignore_index=-100): 156 | logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1)) 157 | labels = labels[..., 1:].contiguous().view(-1) 158 | 159 | _, logit = logit.max(dim=-1) # 对于每条数据,返回最大的index 160 | # 进行非运算,返回一个tensor,若labels的第i个位置为pad_id,则置为0,否则为1 161 | non_pad_mask = labels.ne(ignore_index) 162 | n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item() 163 | n_word = non_pad_mask.sum().item() 164 | return n_correct, n_word 165 | 166 | 167 | def main(): 168 | # 初始化参数 169 | args = set_args() 170 | 171 | # 设置使用哪些显卡进行训练 172 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device_ids 173 | args.cuda = not args.no_cuda 174 | 175 | # if args.batch_size < 2048 and args.warmup_steps <= 4000: 176 | # print('[Warning] The warmup steps may be not enough.\n' \ 177 | # '(sz_b, warmup) = (2048, 4000) is the official setting.\n' \ 178 | # 'Using smaller batch w/o longer warmup may cause ' \ 179 | # 'the warmup stage ends with only little data trained.') 180 | 181 | # 创建日志对象 182 | cur_time = time.strftime("%Y%m%d%H%M%S", time.localtime()) 183 | logger.add(join(args.output_path, 'train-{}.log'.format(cur_time))) 184 | # 初始化tensorboard 185 | writer = SummaryWriter(args.output_path) 186 | # 当用户使用GPU,并且GPU可用时 187 | args.cuda = torch.cuda.is_available() and not args.no_cuda 188 | device = 'cuda:0' if args.cuda else 'cpu' 189 | args.device = device 190 | logger.info('using device:{}'.format(device)) 191 | 192 | # 设置随机种子 193 | set_random_seed(args.seed, args.cuda) 194 | 195 | # 初始化tokenizer 196 | # tokenizer = CpmTokenizer(vocab_file=args.vocab_path) 197 | # args.eod_id = tokenizer.convert_tokens_to_ids("") # 文档结束符 198 | # args.pad_id = tokenizer.pad_token_id 199 | 200 | # 创建模型的输出目录 201 | if not os.path.exists(args.output_path): 202 | os.mkdir(args.output_path) 203 | 204 | # 创建模型 205 | if args.pretrained_model: # 加载预训练模型 206 | logger.info('') 207 | model = GPT2LMHeadModel.from_pretrained(args.pretrained_model) 208 | else: # 初始化模型 209 | model_config = GPT2Config.from_json_file(args.model_config) 210 | model = GPT2LMHeadModel(config=model_config) 211 | model = model.to(device) 212 | logger.info('model config:\n{}'.format(model.config.to_json_string())) 213 | # assert model.config.vocab_size == tokenizer.vocab_size 214 | 215 | # 多卡并行训练模型 216 | if args.cuda and torch.cuda.device_count() > 1: 217 | # model = DataParallel(model).cuda() 218 | model = BalancedDataParallel(args.gpu0_bsz, model, dim=0).cuda() 219 | logger.info("use GPU {} to train".format(args.device)) 220 | 221 | # 计算模型参数数量 222 | num_parameters = 0 223 | parameters = model.parameters() 224 | for parameter in parameters: 225 | num_parameters += parameter.numel() 226 | logger.info("Number of teacher parameter: %.2fM" % (num_parameters / 1e6)) 227 | 228 | # 记录参数设置 229 | logger.info("args:{}".format(args)) 230 | 231 | # 加载训练集和验证集 232 | # ========= Loading Dataset ========= # 233 | train_dataset = load_dataset(logger, args) 234 | 235 | train(model, logger, train_dataset, writer, args) 236 | 237 | 238 | if __name__ == '__main__': 239 | main() 240 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import time 4 | import torch 5 | import torch.nn.functional as F 6 | from loguru import logger 7 | from datetime import datetime 8 | import os 9 | from torch.utils.data import Dataset, DataLoader 10 | from os.path import join, exists 11 | from torch.nn import DataParallel 12 | import transformers 13 | import pickle 14 | import sys 15 | from utils import set_logger, set_random_seed 16 | from sklearn.model_selection import train_test_split 17 | from data_parallel import BalancedDataParallel 18 | from transformers import GPT2LMHeadModel, GPT2Config 19 | import pandas as pd 20 | import torch.nn.utils.rnn as rnn_utils 21 | import numpy as np 22 | from dataset import CPMDataset 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | 26 | def set_args(): 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument('--device', default='0,1', type=str, required=False, help='设置使用哪些显卡') 29 | parser.add_argument('--no_cuda', action='store_true', help='不使用GPU进行训练') 30 | parser.add_argument('--vocab_path', default='vocab/chinese_vocab.model', type=str, required=False, 31 | help='sp模型路径') 32 | parser.add_argument('--model_config', default='config/cpm-small.json', type=str, required=False, 33 | help='需要从头训练一个模型时,模型参数的配置文件') 34 | parser.add_argument('--train_path', default='data/train.pkl', type=str, required=False, help='经过预处理之后的数据存放路径') 35 | parser.add_argument('--max_len', default=200, type=int, required=False, help='训练时,输入数据的最大长度') 36 | 37 | parser.add_argument('--ignore_index', default=-100, type=int, required=False, help='对于ignore_index的label token不计算梯度') 38 | parser.add_argument('--epochs', default=40, type=int, required=False, help='训练的最大轮次') 39 | parser.add_argument('--batch_size', default=16, type=int, required=False, help='训练的batch size') 40 | parser.add_argument('--gpu0_bsz', default=6, type=int, required=False, help='0号卡的batch size') 41 | parser.add_argument('--lr', default=1e-5, type=float, required=False, help='学习率') 42 | parser.add_argument('--eps', default=1.0e-09, type=float, required=False, help='AdamW优化器的衰减率') 43 | parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss') 44 | parser.add_argument('--gradient_accumulation_steps', default=6, type=int, required=False, help='梯度积累的步数') 45 | parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False) 46 | parser.add_argument('--output_path', default='output/train', type=str, required=False, 47 | help='模型输出路径') 48 | parser.add_argument('--pretrained_model', default='model/zuowen_epoch40', type=str, required=False, 49 | help='预训练的模型的路径') 50 | parser.add_argument('--seed', type=int, default=1234, help='设置随机种子') 51 | parser.add_argument('--num_workers', type=int, default=0, help="dataloader加载数据时使用的线程数量") 52 | parser.add_argument('--warmup_steps', type=int, default=4000, help='warm up步数') 53 | args = parser.parse_args() 54 | return args 55 | 56 | 57 | def collate_fn(batch): 58 | input_ids = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=5) 59 | labels = rnn_utils.pad_sequence(batch, batch_first=True, padding_value=-100) 60 | return input_ids, labels 61 | 62 | 63 | def load_dataset(logger, args): 64 | """ 65 | 加载训练集 66 | """ 67 | logger.info("loading training dataset") 68 | train_path = args.train_path 69 | 70 | with open(train_path, "rb") as f: 71 | train_list = pickle.load(f) 72 | 73 | # test 74 | # train_list = train_list[:24] 75 | logger.info('len of train data:{}'.format(len(train_list))) 76 | train_dataset = CPMDataset(train_list, args.max_len) 77 | 78 | return train_dataset 79 | 80 | 81 | def train_epoch(model, train_dataloader, optimizer, scheduler, logger, writer, 82 | epoch, args): 83 | model.train() 84 | device = args.device 85 | ignore_index = args.ignore_index 86 | epoch_start_time = datetime.now() 87 | 88 | total_loss = 0 # 记录下整个epoch的loss的总和 89 | epoch_correct_num = 0 # 每个epoch中,预测正确的word的数量 90 | epoch_total_num = 0 # 每个epoch中,预测的word的总数量 91 | 92 | for batch_idx, (input_ids, labels) in enumerate(train_dataloader): 93 | # 捕获cuda out of memory exception 94 | try: 95 | input_ids = input_ids.to(device) 96 | labels = labels.to(device) 97 | outputs = model.forward(input_ids, labels=labels) 98 | logits = outputs.logits 99 | loss = outputs.loss 100 | loss = loss.mean() 101 | 102 | # 统计该batch的预测token的正确数与总数 103 | batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index) 104 | # 统计该epoch的预测token的正确数与总数 105 | epoch_correct_num += batch_correct_num 106 | epoch_total_num += batch_total_num 107 | # 计算该batch的accuracy 108 | batch_acc = batch_correct_num / batch_total_num 109 | 110 | total_loss += loss.item() 111 | if args.gradient_accumulation_steps > 1: 112 | loss = loss / args.gradient_accumulation_steps 113 | 114 | loss.backward() 115 | # 梯度裁剪 116 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 117 | 118 | # 进行一定step的梯度累计之后,更新参数 119 | if (batch_idx + 1) % args.gradient_accumulation_steps == 0: 120 | # 更新参数 121 | optimizer.step() 122 | # 更新学习率 123 | scheduler.step() 124 | # 清空梯度信息 125 | optimizer.zero_grad() 126 | 127 | if (batch_idx + 1) % args.log_step == 0: 128 | logger.info( 129 | "batch {} of epoch {}, step {}, loss {}, batch_acc {}, lr {}".format( 130 | batch_idx + 1, epoch + 1, step, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr())) 131 | step = epoch * len(train_dataloader) + batch_idx 132 | writer.add_scalar('train loss', loss.item()*args.gradient_accumulation_steps, step) 133 | writer.add_scalar('train acc', batch_acc, step) 134 | 135 | del input_ids, outputs 136 | 137 | except RuntimeError as exception: 138 | if "out of memory" in str(exception): 139 | logger.info("WARNING: ran out of memory") 140 | if hasattr(torch.cuda, 'empty_cache'): 141 | torch.cuda.empty_cache() 142 | else: 143 | logger.info(str(exception)) 144 | raise exception 145 | 146 | # 记录当前epoch的平均loss与accuracy 147 | epoch_mean_loss = total_loss / len(train_dataloader) 148 | epoch_mean_acc = epoch_correct_num / epoch_total_num 149 | logger.info( 150 | "epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc)) 151 | 152 | # save model 153 | logger.info('saving model for epoch {}'.format(epoch + 1)) 154 | model_path = join(args.output_path, 'epoch{}'.format(epoch + 1)) 155 | if not os.path.exists(model_path): 156 | os.mkdir(model_path) 157 | model_to_save = model.module if hasattr(model, 'module') else model 158 | model_to_save.save_pretrained(model_path) 159 | logger.info('epoch {} finished'.format(epoch + 1)) 160 | epoch_finish_time = datetime.now() 161 | logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time)) 162 | 163 | return epoch_mean_loss 164 | 165 | 166 | def train(model, logger, train_dataset, writer, args): 167 | train_dataloader = DataLoader( 168 | train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn, 169 | drop_last=True 170 | ) 171 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochs 172 | optimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps) 173 | scheduler = transformers.get_linear_schedule_with_warmup( 174 | optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total 175 | ) 176 | 177 | logger.info('start training') 178 | 179 | train_losses = [] # 记录每个epoch的平均loss 180 | # ========== start training ========== # 181 | for epoch in range(args.epochs): 182 | train_loss = train_epoch( 183 | model=model, train_dataloader=train_dataloader, 184 | optimizer=optimizer, scheduler=scheduler, 185 | logger=logger, writer=writer, epoch=epoch, args=args) 186 | train_losses.append(round(train_loss, 4)) 187 | logger.info("train loss list:{}".format(train_losses)) 188 | 189 | logger.info('training finished') 190 | logger.info("train_losses:{}".format(train_losses)) 191 | 192 | 193 | def calculate_acc(logit, labels, ignore_index=-100): 194 | logit = logit[..., :-1, :].contiguous().view(-1, logit.size(-1)) 195 | labels = labels[..., 1:].contiguous().view(-1) 196 | 197 | _, logit = logit.max(dim=-1) # 对于每条数据,返回最大的index 198 | # 进行非运算,返回一个tensor,若labels的第i个位置为pad_id,则置为0,否则为1 199 | non_pad_mask = labels.ne(ignore_index) 200 | n_correct = logit.eq(labels).masked_select(non_pad_mask).sum().item() 201 | n_word = non_pad_mask.sum().item() 202 | return n_correct, n_word 203 | 204 | 205 | def main(): 206 | # 初始化参数 207 | args = set_args() 208 | 209 | # 设置使用哪些显卡进行训练 210 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 211 | args.cuda = not args.no_cuda 212 | 213 | # if args.batch_size < 2048 and args.warmup_steps <= 4000: 214 | # print('[Warning] The warmup steps may be not enough.\n' \ 215 | # '(sz_b, warmup) = (2048, 4000) is the official setting.\n' \ 216 | # 'Using smaller batch w/o longer warmup may cause ' \ 217 | # 'the warmup stage ends with only little data trained.') 218 | 219 | # 创建日志对象 220 | cur_time = time.strftime("%Y%m%d%H%M%S", time.localtime()) 221 | logger.add(join(args.output_path, 'train-{}.log'.format(cur_time))) 222 | # 初始化tensorboard 223 | writer = SummaryWriter(args.output_path) 224 | # 当用户使用GPU,并且GPU可用时 225 | args.cuda = torch.cuda.is_available() and not args.no_cuda 226 | device = 'cuda:0' if args.cuda else 'cpu' 227 | args.device = device 228 | logger.info('using device:{}'.format(device)) 229 | 230 | # 设置随机种子 231 | set_random_seed(args.seed, args.cuda) 232 | 233 | # 初始化tokenizer 234 | # tokenizer = CpmTokenizer(vocab_file=args.vocab_path) 235 | # args.eod_id = tokenizer.convert_tokens_to_ids("") # 文档结束符 236 | # args.pad_id = tokenizer.pad_token_id 237 | 238 | # 创建模型的输出目录 239 | if not os.path.exists(args.output_path): 240 | os.mkdir(args.output_path) 241 | 242 | # 创建模型 243 | if args.pretrained_model: # 加载预训练模型 244 | model = GPT2LMHeadModel.from_pretrained(args.pretrained_model) 245 | else: # 初始化模型 246 | model_config = GPT2Config.from_json_file(args.model_config) 247 | model = GPT2LMHeadModel(config=model_config) 248 | model = model.to(device) 249 | logger.info('model config:\n{}'.format(model.config.to_json_string())) 250 | # assert model.config.vocab_size == tokenizer.vocab_size 251 | 252 | # 多卡并行训练模型 253 | if args.cuda and torch.cuda.device_count() > 1: 254 | # model = DataParallel(model).cuda() 255 | model = BalancedDataParallel(args.gpu0_bsz, model, dim=0).cuda() 256 | logger.info("use GPU {} to train".format(args.device)) 257 | 258 | # 计算模型参数数量 259 | num_parameters = 0 260 | parameters = model.parameters() 261 | for parameter in parameters: 262 | num_parameters += parameter.numel() 263 | logger.info("Number of teacher parameter: %.2fM" % (num_parameters / 1e6)) 264 | 265 | # 记录参数设置 266 | logger.info("args:{}".format(args)) 267 | 268 | # 加载训练集和验证集 269 | # ========= Loading Dataset ========= # 270 | train_dataset = load_dataset(logger, args) 271 | 272 | train(model, logger, train_dataset, writer, args) 273 | 274 | 275 | if __name__ == '__main__': 276 | main() 277 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | import random 4 | import numpy as np 5 | import torch.nn.functional as F 6 | 7 | 8 | def set_logger(log_path): 9 | """ 10 | 将日志输出到日志文件和控制台 11 | """ 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | 15 | formatter = logging.Formatter( 16 | '%(asctime)s - %(levelname)s - %(message)s') 17 | 18 | # 创建一个handler,用于写入日志文件 19 | file_handler = logging.FileHandler( 20 | filename=log_path) 21 | file_handler.setFormatter(formatter) 22 | file_handler.setLevel(logging.INFO) 23 | logger.addHandler(file_handler) 24 | 25 | # 创建一个handler,用于将日志输出到控制台 26 | console = logging.StreamHandler() 27 | console.setLevel(logging.DEBUG) 28 | console.setFormatter(formatter) 29 | logger.addHandler(console) 30 | return logger 31 | 32 | 33 | def set_random_seed(seed, cuda): 34 | """ 35 | 设置训练的随机种子 36 | """ 37 | torch.manual_seed(seed) 38 | random.seed(seed) 39 | np.random.seed(seed) 40 | 41 | if cuda: 42 | torch.backends.cudnn.deterministic = True 43 | torch.backends.cudnn.benchmark = False 44 | 45 | 46 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 47 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 48 | Args: 49 | logits: logits distribution shape (vocabulary size) 50 | top_k > 0: keep only top k tokens with highest probability (top-k filtering). 51 | top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 52 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 53 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 54 | """ 55 | assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear 56 | top_k = min(top_k, logits.size(-1)) # Safety check 57 | if top_k > 0: 58 | # Remove all tokens with a probability less than the last token of the top-k 59 | # torch.topk()返回最后一维最大的top_k个元素,返回值为二维(values,indices) 60 | # ...表示其他维度由计算机自行推断 61 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 62 | logits[indices_to_remove] = filter_value 63 | 64 | if top_p > 0.0: 65 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 66 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 67 | 68 | # Remove tokens with cumulative probability above the threshold 69 | sorted_indices_to_remove = cumulative_probs > top_p 70 | # Shift the indices to the right to keep also the first token above the threshold 71 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 72 | sorted_indices_to_remove[..., 0] = 0 73 | 74 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 75 | logits[indices_to_remove] = filter_value 76 | return logits -------------------------------------------------------------------------------- /vocab/chinese_vocab.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangjianxin1/CPM/2e11cc9dc8e545afdf82d49edbd774b7b9e47cd5/vocab/chinese_vocab.model --------------------------------------------------------------------------------