├── .gitignore
├── LICENSE
├── README.md
├── README_EN.md
├── cache
│   ├── make_vocab.py
│   ├── make_vocab.sh
│   ├── vocab.txt
│   ├── vocab_all.txt
│   ├── vocab_guwen.txt
│   ├── vocab_seg.txt
│   └── vocab_small.txt
├── config
│   ├── model_config.json
│   ├── model_config_small.json
│   └── model_config_test.json
├── eval.py
├── generate.py
├── generate_texts.py
├── requirements.txt
├── sample
│   ├── doupo.jpeg
│   ├── poem_1.png
│   ├── poem_2.png
│   ├── tiyu.jpg
│   ├── 律诗绝句.png
│   ├── 散文1.png
│   ├── 散文2.png
│   ├── 散文3.png
│   ├── 浣溪沙_江城子.png
│   ├── 蝶恋花_满江红.png
│   ├── 金庸_倚天屠龍記.jpg
│   ├── 金庸_天龍八部.jpg
│   ├── 金庸_神鵰俠侶.jpg
│   └── 金庸_鹿鼎記.jpg
├── scripts
│   ├── generate.sh
│   └── train.sh
├── tokenizations
│   ├── bpe_tokenizer.py
│   ├── encoder.json
│   ├── thulac_dict
│   │   └── seg
│   ├── tokenization_bert.py
│   ├── tokenization_bert_word_level.py
│   └── vocab.bpe
├── train.json
├── train.py
└── train_single.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
.DS_Store
cache/.DS_Store
.idea/workspace.xml
.idea/misc.xml
.idea/GPT2-Chinese.iml
data/
.samples.txt
.idea/modules.xml
.idea/vcs.xml
.idea

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Zeyao Du

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GPT2-Chinese

## Description

- Chinese version of GPT2 training code, using the BERT tokenizer or a BPE tokenizer. It is based on the excellent [Transformers](https://github.com/huggingface/transformers) repository from the HuggingFace team. It can write poems, news and novels, or train general language models. It supports character-level, word-level and BPE-level tokenization, as well as large training corpora.
- The BPE mode is built on a SentencePiece BPE model (thanks to [kangzhonghua](https://github.com/kangzhonghua) for the contribution) and requires slight modifications to train.py.

## UPDATE 04.11.2024

- Thank you all for your interest in this project, which has drawn renewed attention since the release of ChatGPT. It started as a practice project while I was teaching myself PyTorch, and I do not plan to maintain or update it long-term. If you are interested in large language models (LLMs), you can email me (ned1991@gmail.com) to join a discussion group, or start a discussion in the Issues.

## UPDATE 02.06.2021

- This project now offers a [general-purpose Chinese GPT-2 pre-trained model](https://github.com/Morizeyao/GPT2-Chinese#%E6%A8%A1%E5%9E%8B%E5%88%86%E4%BA%AB), a [small general-purpose Chinese GPT-2 pre-trained model](https://github.com/Morizeyao/GPT2-Chinese#%E6%A8%A1%E5%9E%8B%E5%88%86%E4%BA%AB), a [Chinese lyrics GPT-2 pre-trained model](https://github.com/Morizeyao/GPT2-Chinese#%E6%A8%A1%E5%9E%8B%E5%88%86%E4%BA%AB) and a [Classical Chinese GPT-2 pre-trained model](https://github.com/Morizeyao/GPT2-Chinese#%E6%A8%A1%E5%9E%8B%E5%88%86%E4%BA%AB). The models were trained with the UER-py project; you are welcome to use them.
In addition, the models have been uploaded to the Huggingface Model Hub. For more details about the models, see [gpt2-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall), [gpt2-distil-chinese-cluecorpussmall](https://huggingface.co/uer/gpt2-distil-chinese-cluecorpussmall), [gpt2-chinese-lyric](https://huggingface.co/uer/gpt2-chinese-lyric) and [gpt2-chinese-ancient](https://huggingface.co/uer/gpt2-chinese-ancient).

When generating with any of these models, prepend a start token to the input text. For example, to input “最美的不是下雨天,是曾与你躲过雨的屋檐”, the correct format is “[CLS]最美的不是下雨天,是曾与你躲过雨的屋檐”.


## UPDATE 11.03.2020

- This project now offers a [classical poetry GPT-2 pre-trained model](https://github.com/Morizeyao/GPT2-Chinese#%E6%A8%A1%E5%9E%8B%E5%88%86%E4%BA%AB) and a [couplet GPT-2 pre-trained model](https://github.com/Morizeyao/GPT2-Chinese#%E6%A8%A1%E5%9E%8B%E5%88%86%E4%BA%AB). The models were trained with the UER-py project; you are welcome to use them.
In addition, the models have been uploaded to the Huggingface Model Hub. For more details about the models, see [gpt2-chinese-poem](https://huggingface.co/uer/gpt2-chinese-poem) and [gpt2-chinese-couplet](https://huggingface.co/uer/gpt2-chinese-couplet).

When generating with the classical poetry model, prepend a start token to the input text. For example, to input “梅山如积翠,”, the correct format is “[CLS]梅山如积翠,”.

The couplet model was trained on a corpus in the format “first line-second line”. When generating with it, prepend a start token as well: to input “丹枫江冷人初去-”, the correct format is “[CLS]丹枫江冷人初去-”.

## NEWS 08.11.2020

- [CDial-GPT](https://github.com/thu-coai/CDial-GPT) (which can be loaded with this code) has been released. It provides a rigorously cleaned, large-scale open-domain Chinese dialogue dataset, GPT dialogue models pre-trained on it, and generated samples. You are welcome to take a look.

## NEWS 12.9.2019

- A new project, [GPT2-chitchat](https://github.com/yangjianxin1/GPT2-chitchat), has been released, partly based on the code of this project. It contains code for training a GPT2 dialogue model, a pre-trained model, and generated samples. You are welcome to take a look.

## NEWS 12.7.2019

- A new project, [Decoders-Chinese-TF2.0](https://github.com/Morizeyao/Decoders-Chinese-TF2.0), also supports training GPT2 on Chinese. It is simpler to use and less prone to problems. It is still in testing, and feedback is welcome.

## NEWS 11.9

- [GPT2-ML](https://github.com/imcaspar/gpt2-ml) (not directly related to this project) has been released, including a 1.5B-parameter Chinese GPT2 model. If you are interested, you can convert it to the PyTorch format supported by this project for further training or generation tests.

## UPDATE 10.25

- The first pre-trained model of this project, a prose generation model, has been released. See the Model Sharing section of this README.

## Project Status

- When this project was released, Chinese GPT2 resources were almost nonexistent; that is no longer the case. The project's functionality has also largely stabilized, so updates are paused for now. I originally wrote this code to practice using PyTorch; although I fixed a number of issues later on, many rough edges inevitably remain. Thank you for your understanding.

## Usage

- Create a data folder in the project root and put the training corpus in it under the name train.json. **train.json is a JSON list; each element is the full text of one article to train on (not a file path).**
- Run train.py with the --raw flag; it preprocesses the data automatically.
- Once preprocessing finishes, training starts automatically.

### Generate text

``` bash
python ./generate.py --length=50 --nsamples=4 --prefix=xxx --fast_pattern --save_samples --save_samples_path=/mnt/xx
```
- **--fast_pattern** (contributed by [LeeCP8](https://github.com/LeeCP8)): when the generation length is small, the speed difference is negligible; in my own test with length=250 it was about 2 seconds faster. fast_pattern is not used unless you pass --fast_pattern.
- **--save_samples**: by default, output samples are printed to the console; with this flag they are saved to **samples.txt** in the project root.
- **--save_samples_path**: specifies the directory to save to; multi-level directories are created recursively. You cannot pass a file name; the file is always named **samples.txt**.

## File Structure

- generate.py and train.py are the generation and training scripts, respectively.
- train_single.py is an extension of train.py for a single, very large list element (for example, training on the whole 斗破苍穹 (Doupo Cangqiong) novel).
- eval.py evaluates the perplexity (ppl) of a trained model.
- generate_texts.py is an extension of generate.py that generates several sentences from each keyword in a list of starting keywords and writes them to a file.
- train.json is an example of the training data format, for reference.
- The cache folder contains several BERT vocabularies; make_vocab.py is a helper script for building a vocabulary from a train.json corpus file. vocab.txt is the original BERT vocabulary, vocab_all.txt additionally includes Classical Chinese words, and vocab_small.txt is a small vocabulary.
- The tokenizations folder contains three optional tokenizers: the default Bert Tokenizer, a word-segmented Bert Tokenizer, and a BPE Tokenizer.
- The scripts folder contains example training and generation scripts.

## Notes

- This project uses BERT's tokenizer to handle Chinese characters.
- If you are not using the word-segmented tokenizer, you do not need to segment the text beforehand; the tokenizer does it for you.
- If you use the word-segmented tokenizer, it is best to first build a vocabulary for your corpus with make_vocab.py in the cache folder.
- You need to train the models yourself. If you complete pre-training, you are welcome to share and discuss your results.
- If you have a lot of memory or a small corpus, you can modify the corresponding code in build_files in train.py to preprocess the corpus directly without splitting it.
- If you use the BPE Tokenizer, you need to build your own Chinese vocabulary.

## Corpora

- Corpora can be downloaded [here](https://github.com/brightmart/nlp_chinese_corpus) and [here](http://thuctc.thunlp.org/#获取链接).
- The 斗破苍穹 (Doupo Cangqiong) corpus can be downloaded [here](https://github.com/GaoPeng97/transformer-xl-chinese/tree/master/data/doupo).

## FP16 and Gradient Accumulation Support

- I added fp16 and gradient accumulation support to train.py. If you have apex installed and know what fp16 is, you can enable it by setting the variable fp16=True. However, fp16 may currently fail to converge, for unknown reasons.

## Contact

- Mail: ned1991@gmail.com

## Citing

```
@misc{GPT2-Chinese,
  author = {Zeyao Du},
  title = {GPT2-Chinese: Tools for training GPT2 model in Chinese language},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/Morizeyao/GPT2-Chinese}},
}
```

## Model Sharing
| Model | Description | Contributor | Link 1 | Link 2 |
| :----------- | :----------- | :----------- | :----------- | ------------ |
| Prose model | Trained on 130MB of prose by renowned writers, sentimental prose, and prose poetry. | [hughqiu](https://github.com/hughqiu "hughqiu") | [Baidu Netdisk (code fpyu)](https://pan.baidu.com/s/1nbrW5iw34GRhoTin8uU2tQ) | [GDrive](https://drive.google.com/drive/folders/1rJC4niJKMVwixUQkuL9k5teLRnEYTmUf?usp=sharing "GDrive") |
| Poetry model | Trained on 180MB of about 800,000 classical poems. | [hhou435](https://github.com/hhou435) | [Baidu Netdisk (code 7fev)](https://pan.baidu.com/s/1Hy0OQ5xZcTLer9MQZW8o3g) | [GDrive](https://drive.google.com/drive/folders/1Z6nF1nrgTkrZcRLHedQHXb4_M9I7yQPN?usp=sharing) |
| Couplet model | Trained on 40MB of about 700,000 couplets. | [hhou435](https://github.com/hhou435) | [Baidu Netdisk (code i5n0)](https://pan.baidu.com/s/1j9yVQwjlXZq58wOyXK4lcg) | [GDrive](https://drive.google.com/drive/folders/1ZnsvS7oHRVueNKj_SeEhiQt86aze3ojj?usp=sharing) |
| General Chinese model | Trained on the [CLUECorpusSmall](https://github.com/CLUEbenchmark/CLUECorpus2020/) corpus. | [hhou435](https://github.com/hhou435) | [Baidu Netdisk (code n3s8)](https://pan.baidu.com/s/16x0hfBCekWju75xPeyyRfA) | [GDrive](https://drive.google.com/drive/folders/1dLEANs5z4pWS0pzrak6Q2H2Nq4iYsMsf?usp=sharing) |
| General Chinese small model | Trained on the [CLUECorpusSmall](https://github.com/CLUEbenchmark/CLUECorpus2020/) corpus. | [hhou435](https://github.com/hhou435) | [Baidu Netdisk (code rpjk)](https://pan.baidu.com/s/1AiSm2GWhbGNxvhrcUlDXNA) | [GDrive](https://drive.google.com/drive/folders/1eerX1N8n_eFlnQ4xpxZ4iU2-Mx83pXFp?usp=sharing) |
| Chinese lyrics model | Trained on 140MB of about 150,000 Chinese song lyrics. | [hhou435](https://github.com/hhou435) | [Baidu Netdisk (code 0qnn)](https://pan.baidu.com/s/19x0d0bPGCWHi9L4Pu0pSiw) | [GDrive](https://drive.google.com/drive/folders/1RFq4NoQ3phCJjrhKtu2Xbn6z0krcN9TM?usp=sharing) |
| Classical Chinese model | Trained on 1.8GB of about 3 million Classical Chinese texts. | [hhou435](https://github.com/hhou435) | [Baidu Netdisk (code ek2z)](https://pan.baidu.com/s/1X3Um9HketnlGYZubY9gnew) | [GDrive](https://drive.google.com/drive/folders/1dtHTRn3fX7g8cPCCaJEXA2tmrIcImR6t?usp=sharing) |

These model files were trained by generous community members and are shared publicly for everyone to use. You are also welcome to share your own trained models here.


## Demo

- A new version of the [Jiuge poetry generator](https://jiuge.thunlp.cn/lvshi.html) is live; its regulated verse (律诗) and quatrain (绝句) backend uses a model trained by user [JamesHujy](https://github.com/JamesHujy) with a modified version of this repository's code.
- Contributed by [leemengtaiwan](https://github.com/leemengtaiwan): [an article giving an intuitive introduction to GPT-2 and how to visualize the self-attention mechanism](https://leemeng.tw/gpt2-language-model-generate-chinese-jing-yong-novels.html). A [Colab notebook and model](https://colab.research.google.com/drive/1MaT8-HUHfZkdCra0OqZEIr0IFCq0MJBx) are also provided so anyone can generate new samples with one click.

## Generated Samples

- Below are literary prose samples contributed by [hughqiu](https://github.com/hughqiu "hughqiu"); the model is listed in the Model Sharing table. It was trained for 10 epochs on a 130MB corpus with batch size 16 and a depth of 10 layers.
![avatar](sample/散文1.png)
![avatar](sample/散文2.png)
![avatar](sample/散文3.png)

- Below are 斗破苍穹 (Doupo Cangqiong) samples, from a GPT2 with about 50M parameters trained with batch size 32 on 16MB of the novel's text. Here [SEP] denotes a line break.

![avatar](sample/doupo.jpeg)

- Below are classical poetry samples, computed and contributed by user [JamesHujy](https://github.com/JamesHujy).

![avatar](sample/poem_1.png)
![avatar](sample/poem_2.png)

- Below are classical poetry samples generated with a constrained poetic form, computed and contributed by user [JamesHujy](https://github.com/JamesHujy).

![avatar](sample/律诗绝句.png)
![avatar](sample/浣溪沙_江城子.png)
![avatar](sample/蝶恋花_满江红.png)

- Below are sample screenplay texts, computed and contributed by user [chiangandy](https://github.com/chiangandy).

[starttext]爱情游戏剧情讲述了钢琴父女明致怀萌的爱情、个有着努力的热情以及现实为人生的价值观众,获得一系列爱情的故事。80后录股媒体受到网友分享,是2014年主创陈拉昀出品牌总监于蓝氏集团化验师创业团门的哥哥大国度上海淮河畔,集入第一线公司青年度虽然没有放到的事业,但是蓝正是却不到位主人拒绝了解,而在蓝越的帮助理念出现,也因此开启明朗的误会而经营变成爱河。在一次偶然的编剧集电视剧之夏天上一改变了自命运环球顶樑,三人在创车祸中不知被记忆差网识分到创作,并被问流言败,以及行业服务所有的低调教同才力,陈昭和唐诗诗妍展开了一段截然不同的“2014年间段感情”,两人性格互相治癒的商业奋斗故事,尽管是共90后北京华侨大学录的一个宿舍小旅程和唐如、生等优秀青年,的人生活如何与愿违3个国偶像,并且共同创作何以此他们互相有观众的成功和关心吗?[endtext]

[starttext]学习爱情主要讲述了两对方小曼,经过啼笑皆非的考验,终于选择了三个孩子,携手共同创业来四个孩子,在大城市里创业的成功商。两家内事业的加入了北京城市,经过了一次元城市融风雨故、差异后得到异的他们,最终收获了梦想的真正属于自己的爱情。赞助理想、电视剧、剧等主创业时代人物特点在北京举行开机仪式,该剧以当下海南三个新人青年轻人面人海南梅竹马的电视角,讲述了几个在北京、喜剧代人生活中增强非浪漫的年轻人,以独特的双时代年轻人从来到北京城市化中国大城市走出发展以海南方的变迁在语种城市闯关于人生态的同时,以及他们渐渐的生活方式为自己方向上演了那么简单俗,是当代际拍摄的就如何在这个城市里都市里?那么平静的城市就是城市的风格特张嘉和支持工作打造,而这是一点就要打造出机场话剧组会。化身处处棋逢貌各种文化的人都非常独特的煽情,交织了相,滑稽等来自外衣的东北漂亮、内地,者和两位女孩子敢称是哑女孩子。交织里的人齐飞一开泰块玩笑,令人印象太趋的气质,让人眼看这个性格非常喜剧,知道的是一个“东北漂”人的外国小养家,让她耳熟练读剧的外形象显老大。之后齐飞、表示爱朗的齐飞、范儿、楚月子、白天杰。两代人的生活里友情似乎没有结合、精彩表态的开朗和丽丽丽。[endtext]

- Below are Jin Yong (金庸) wuxia novel samples contributed by [leemengtaiwan](https://github.com/leemengtaiwan). The model size is about 82M, trained on a 50MB corpus with batch size 16. An [article giving an intuitive introduction to GPT-2 and how to visualize the self-attention mechanism](https://leemeng.tw/gpt2-language-model-generate-chinese-jing-yong-novels.html) and a [Colab notebook and model](https://colab.research.google.com/drive/1MaT8-HUHfZkdCra0OqZEIr0IFCq0MJBx) are also provided so anyone can generate new samples with one click.

![avatar](sample/金庸_天龍八部.jpg)
![avatar](sample/金庸_倚天屠龍記.jpg)
![avatar](sample/金庸_鹿鼎記.jpg)
![avatar](sample/金庸_神鵰俠侶.jpg)

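A minimal sketch of the corpus and prompt conventions this README describes, assuming plain Python 3 and no project code: train.json must be a JSON list whose elements are the article texts themselves, newlines are turned into [SEP] (as build_files in eval.py does), and prompts for the UER-trained models need a leading [CLS]. The helper names dump_train_json, mark_newlines and add_start_token are hypothetical and are not part of this repository.

```python
import json


def dump_train_json(articles):
    """Serialize articles in the train.json layout: a JSON list of article texts.

    Hypothetical helper: each element must be the article text itself, not a path.
    """
    if not all(isinstance(a, str) for a in articles):
        raise TypeError("each element of train.json must be an article string")
    # ensure_ascii=False keeps Chinese characters human-readable in the file
    return json.dumps(articles, ensure_ascii=False)


def mark_newlines(article):
    """Replace line breaks with ' [SEP] ', as build_files in eval.py does."""
    return article.replace('\n', ' [SEP] ')


def add_start_token(prompt, start_token='[CLS]'):
    """Prepend the start token the UER-trained models expect, if it is missing."""
    return prompt if prompt.startswith(start_token) else start_token + prompt


corpus = ["第一段。\n第二段。", "另一篇文章。"]
text = dump_train_json(corpus)
# To create the file: with open('data/train.json', 'w', encoding='utf8') as f: f.write(text)
print(json.loads(text) == corpus)       # True
print(mark_newlines(corpus[0]))         # 第一段。 [SEP] 第二段。
print(add_start_token("梅山如积翠,"))    # [CLS]梅山如积翠,
```

Keeping add_start_token idempotent means a prompt that already carries [CLS] is not given a second one.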
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
# GPT2-Chinese

## Description

- Chinese version of GPT2 training code, using the BERT tokenizer or a BPE tokenizer. It is based on the excellent [Transformers](https://github.com/huggingface/transformers) repository from the HuggingFace team. It can write poems, news and novels, or train general language models. It supports character-level, word-level and BPE-level tokenization, as well as large training corpora.
- The BPE mode uses a SentencePiece BPE model (thanks to [kangzhonghua](https://github.com/kangzhonghua) for the contribution) and requires slightly modifying the code in train.py.

## NEWS 08.11.2020

- [CDial-GPT](https://github.com/thu-coai/CDial-GPT) (which can be loaded with this code) has been released. It contains a rigorously cleaned, large-scale open-domain Chinese dialogue dataset, a GPT dialogue model pre-trained on this dataset, and generated samples.

## NEWS 12.9.2019

- A new project, [GPT2-chitchat](https://github.com/yangjianxin1/GPT2-chitchat), has been released, based in part on the code of this project. It contains code for training a GPT2 dialogue model, a pre-trained model, and generated samples; you are welcome to take a look.

## NEWS 12.7.2019

- The new project [Decoders-Chinese-TF2.0](https://github.com/Morizeyao/Decoders-Chinese-TF2.0) also supports training GPT2 on Chinese; it is simpler to use and less prone to problems. It is still in testing, and feedback is welcome.

## NEWS 11.9

- [GPT2-ML](https://github.com/imcaspar/gpt2-ml) (not directly related to this project) has been released, including a 1.5B-parameter Chinese GPT2 model. If you are interested, you can convert it to the PyTorch format supported by this project for further training or generation tests.

## UPDATE 10.25

- The first pre-trained model of this project, a prose generation model, has been released; see the Model sharing section of this README.

## Project Status

- When this project was released, there were almost no Chinese GPT2 resources, but the situation is different now. The project's functionality has also largely stabilized, so updates are paused for the time being. I wrote this code to practice using PyTorch; even though I fixed some issues later on, many rough edges remain. Thank you for your understanding.

## Usage

- Create a data folder in the project root directory. Put the training corpus into the data directory under the name train.json. **train.json is a JSON list; each element of the list is the text of an article to be trained on (rather than a link to a file)**.
- Run train.py with the --raw flag; it will automatically preprocess the data.
- When preprocessing is complete, training starts automatically.

### Generate text

``` bash
python ./generate.py --length=50 --nsamples=4 --prefix=xxx --fast_pattern --save_samples --save_samples_path=/mnt/xx
```
- **--fast_pattern** (contributed by [LeeCP8](https://github.com/LeeCP8)): if the length parameter is relatively small, there is essentially no speed difference; in my own test with length=250 it was about 2 seconds faster. fast_pattern is not used unless you pass --fast_pattern.
- **--save_samples**: by default, output samples are printed directly to the console; with this flag they are saved to **samples.txt** in the root directory.
- **--save_samples_path**: specifies the directory to save to; multi-level directories are created recursively. You cannot pass a file name; the file name is always **samples.txt**.

## File structure

- generate.py and train.py are the generation and training scripts, respectively.
- train_single.py is an extension of train.py for a single, very large list element (e.g. training on the whole Doupo Cangqiong novel).
- eval.py evaluates the perplexity (ppl) of a trained model.
- generate_texts.py is an extension of generate.py that generates several sentences from each keyword in a list of starting keywords and writes them to a file.
- train.json is an example of the training data format, for reference.
- The cache folder contains several BERT vocabularies. make_vocab.py is a script that helps build a vocabulary from a train.json corpus file. vocab.txt is the original BERT vocabulary, vocab_all.txt additionally includes Classical Chinese words, and vocab_small.txt is a small vocabulary.
- The tokenizations folder contains the three tokenizers you can choose from: the default Bert Tokenizer, the word-segmented Bert Tokenizer, and the BPE Tokenizer.
- The scripts folder contains sample training and generation scripts.

## Notes

- This project uses BERT's tokenizer to handle Chinese characters.
- If you don't use the word-segmented tokenizer, you don't need to segment the text yourself; the tokenizer does it for you.
- If you use the word-segmented tokenizer, you should first use make_vocab.py in the cache folder to build a vocabulary for your corpus.
- You need to train the models yourself. If you have finished pre-training, feel free to share and discuss it.
- If you have a lot of memory or your corpus is small, you can change the corresponding code in build_files in train.py to preprocess the corpus without splitting it.
- If you use the BPE Tokenizer, you need to build your own Chinese vocabulary.

## Corpora

- Corpora can be downloaded from [here](https://github.com/brightmart/nlp_chinese_corpus) and [here](http://thuctc.thunlp.org/#获取链接).
- The Doupo Cangqiong corpus can be downloaded from [here](https://github.com/GaoPeng97/transformer-xl-chinese/tree/master/data/doupo).

## FP16 and Gradient Accumulation Support

- I've added fp16 and gradient accumulation support in train.py. If you have apex installed and know what fp16 is, you can set the variable fp16=True to enable it. However, fp16 may currently fail to converge, for unknown reasons.

## Contact the author

- Mail: ned1991@gmail.com

## Citing

```
@misc{GPT2-Chinese,
  author = {Zeyao Du},
  title = {GPT2-Chinese: Tools for training GPT2 model in Chinese language},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/Morizeyao/GPT2-Chinese}},
}
```

## Model sharing
| Model Name | Model Description | Contributor | Link 1 | Link 2 |
| :----------- | :----------- | :----------- | :----------- | ------------ |
| Prose model | Trained on 130MB of prose by renowned writers, sentimental prose, and prose poetry. | [hughqiu](https://github.com/hughqiu "hughqiu") | [Baidu Netdisk (code fpyu)](https://pan.baidu.com/s/1nbrW5iw34GRhoTin8uU2tQ) | [GDrive](https://drive.google.com/drive/folders/1rJC4niJKMVwixUQkuL9k5teLRnEYTmUf?usp=sharing "GDrive") |

These model files were trained by generous community members and are open for everyone to use; you are welcome to share your own trained models here as well.


## Demo

- A new version of the [Jiuge poetry generator](https://jiuge.thunlp.cn/lvshi.html) is now available; its regulated verse and quatrain backend uses a model trained by user [JamesHujy](https://github.com/JamesHujy) with a modified version of this repository's code.
- Contributed by [leemengtaiwan](https://github.com/leemengtaiwan): an [article giving an intuitive introduction to GPT-2 and how to visualize the self-attention mechanism](https://leemeng.tw/gpt2-language-model-generate-chinese-jing-yong-novels.html). A [Colab notebook and model](https://colab.research.google.com/drive/1MaT8-HUHfZkdCra0OqZEIr0IFCq0MJBx) are also available so any user can generate new samples with a single click.
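eval.py computes a trained model's perplexity as the exponential of the average cross-entropy loss. A minimal sketch of that final step, assuming only the Python standard library (the function name perplexity is hypothetical, not part of this repository): averaging per-batch mean losses, as eval.py does, matches the token-level average because every batch there has the same shape (batch_size windows of n_ctx tokens, with the incomplete last batch dropped). For unequal batches, weight each loss by its token count.

```python
import math


def perplexity(batch_losses, batch_token_counts=None):
    """Perplexity from per-batch mean cross-entropy losses (natural log).

    batch_token_counts optionally weights each batch by the number of
    tokens it predicted; without it, every batch counts equally, which is
    what eval.py does with its equally sized batches.
    """
    if not batch_losses:
        raise ValueError("need at least one batch loss")
    if batch_token_counts is None:
        batch_token_counts = [1] * len(batch_losses)
    total = sum(loss * n for loss, n in zip(batch_losses, batch_token_counts))
    return math.exp(total / sum(batch_token_counts))


# Two batches with mean losses log(2) and log(8) average to log(4), so ppl is about 4.
print(perplexity([math.log(2.0), math.log(8.0)]))
```

Working in log space and exponentiating once at the end avoids multiplying many small probabilities together.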

--------------------------------------------------------------------------------
/cache/make_vocab.py:
--------------------------------------------------------------------------------
import argparse
import thulac
import json

from tqdm import tqdm
from keras.preprocessing.text import Tokenizer


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--raw_data_path', default='../data/train.json', type=str, required=False, help='raw training corpus')
    parser.add_argument('--vocab_file', default='vocab_processed.txt', type=str, required=False, help='path of the generated vocab file')
    parser.add_argument('--vocab_size', default=50000, type=int, required=False, help='vocabulary size')
    args = parser.parse_args()

    lac = thulac.thulac(seg_only=True)
    tokenizer = Tokenizer(num_words=args.vocab_size)
    print('args:\n' + args.__repr__())
    print('This script is extremely slow especially for large corpus. Take a break.')

    # train.json is a JSON list of article strings
    with open(args.raw_data_path, 'r', encoding='utf8') as f:
        lines = json.load(f)
    # segment each article into space-separated words with THULAC
    for i, line in enumerate(tqdm(lines)):
        lines[i] = lac.cut(line, text=True)

    tokenizer.fit_on_texts(lines)
    vocab = list(tokenizer.index_word.values())
    pre = ['[SEP]', '[CLS]', '[MASK]', '[PAD]', '[UNK]']
    vocab = pre + vocab
    with open(args.vocab_file, 'w', encoding='utf8') as f:
        for word in vocab[:args.vocab_size + 5]:
            f.write(word + '\n')


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/cache/make_vocab.sh:
--------------------------------------------------------------------------------
python make_vocab.py \
    --raw_data_path ../data/train.json \
    --vocab_file vocab_user.txt \
    --vocab_size 50000

--------------------------------------------------------------------------------
/config/model_config.json:
--------------------------------------------------------------------------------
{
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12,
  "n_positions": 1024,
  "vocab_size": 21128
}

--------------------------------------------------------------------------------
/config/model_config_small.json:
--------------------------------------------------------------------------------
{
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 10,
  "n_positions": 1024,
  "vocab_size": 13317
}

--------------------------------------------------------------------------------
/config/model_config_test.json:
--------------------------------------------------------------------------------
{
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "n_ctx": 64,
  "n_embd": 128,
  "n_head": 2,
  "n_layer": 1,
  "n_positions": 64,
  "vocab_size": 13317
}

--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import transformers
import torch
import os
import json
import random
import numpy as np
import argparse
from datetime import datetime
from tqdm import tqdm
from torch.nn import DataParallel


def build_files(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length):
    if not os.path.exists(tokenized_data_path):
        os.mkdir(tokenized_data_path)
    with open(data_path, 'r', encoding='utf8') as f:
        print('reading lines')
        lines = json.load(f)
        lines = [line.replace('\n', ' [SEP] ') for line in lines]  # use [SEP] to mark line breaks, i.e. the end of a paragraph
    all_len = len(lines)
    for i in tqdm(range(num_pieces)):
        sublines = lines[all_len // num_pieces * i: all_len // num_pieces * (i + 1)]
        if i == num_pieces - 1:
            sublines.extend(lines[all_len // num_pieces * (i + 1):])  # append the leftover examples to the last piece
        sublines = [full_tokenizer.tokenize(line) for line in sublines if
                    len(line) > min_length]  # only keep texts longer than min_length
        sublines = [full_tokenizer.convert_tokens_to_ids(line) for line in sublines]
        full_line = []
        for subline in sublines:
            full_line.append(full_tokenizer.convert_tokens_to_ids('[MASK]'))  # [MASK] marks the start of an article
            full_line.extend(subline)
            full_line.append(full_tokenizer.convert_tokens_to_ids('[CLS]'))  # [CLS] marks the end of an article
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'w') as f:
            for id in full_line:
                f.write(str(id) + ' ')
    print('finish')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='which GPUs to use')
    parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
                        help='model config file')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='vocabulary file')
    parser.add_argument('--raw_data_path', default='data/eval.json', type=str, required=False, help='raw corpus')
    parser.add_argument('--tokenized_data_path', default='data/tokenized_eval/', type=str, required=False,
                        help='where to store the tokenized corpus')
    parser.add_argument('--raw', action='store_true', help='tokenize the raw corpus first')
    parser.add_argument('--batch_size', default=8, type=int, required=False, help='batch size')
    parser.add_argument('--log_step', default=1, type=int, required=False, help='log every N steps')
    parser.add_argument('--stride', default=768, type=int, required=False, help='stride of the sliding window over the data')
    parser.add_argument('--num_pieces', default=100, type=int, required=False, help='number of pieces to split the corpus into')
    parser.add_argument('--min_length', default=128, type=int, required=False, help='minimum article length to include')
    parser.add_argument('--pretrained_model', default='', type=str, required=False, help='path of the trained model')
    parser.add_argument('--output_dir', default='eval_result/', type=str, required=False, help='output directory for results')

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    # if args.no_wordpiece:
    #     from tokenizations import tokenization_bert_without_wordpiece as tokenization_bert
    # else:
    from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # select which GPUs the program uses

    model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
    print('config:\n' + model_config.to_json_string())

    n_ctx = model_config.n_ctx
    full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    full_tokenizer.max_len = n_ctx
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('using device:', device)

    raw_data_path = args.raw_data_path
    tokenized_data_path = args.tokenized_data_path
    raw = args.raw  # whether to build the dataset from scratch
    batch_size = args.batch_size
    log_step = args.log_step
    stride = args.stride
    num_pieces = args.num_pieces
    min_length = args.min_length
    output_dir = args.output_dir

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    if raw:
        print('building files')
        build_files(data_path=raw_data_path, tokenized_data_path=tokenized_data_path, num_pieces=num_pieces,
                    full_tokenizer=full_tokenizer, min_length=min_length)
        print('files built')

    if not args.pretrained_model:
        print('you need to specify a trained model.')
        exit(1)
    else:
        model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
    model.eval()
    model.to(device)

    num_parameters = 0
    parameters = model.parameters()
    for parameter in parameters:
        num_parameters += parameter.numel()
    print('number of parameters: {}'.format(num_parameters))

    multi_gpu = False
    full_len = 0
    print('calculating total steps')
    for i in tqdm(range(num_pieces)):
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
            full_len += len([int(item) for item in f.read().strip().split()])

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = DataParallel(model)
        multi_gpu = True
    print('starting evaluation')
    overall_step = 0

    total_loss = 0
    total_steps = 0
    # eval
    now = datetime.now()
    print('time: {}'.format(now))
    piece_num = 0
    for i in range(num_pieces):
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
            line = f.read().strip()
        tokens = line.split()
        tokens = [int(token) for token in tokens]
        start_point = 0
        samples = []
        # cut the token stream into windows of n_ctx tokens, advancing by stride
        while start_point < len(tokens) - n_ctx:
            samples.append(tokens[start_point: start_point + n_ctx])
            start_point += stride
        start_point -= stride
        last = tokens[start_point + n_ctx:]
        # pad the leftover tail to n_ctx tokens (note: it is not added to samples)
        last.extend(full_tokenizer.convert_tokens_to_ids(['[PAD]']) * (n_ctx - len(last)))
        random.shuffle(samples)
        for step in range(len(samples) // batch_size):  # drop last

            # prepare data
            batch = samples[step * batch_size: (step + 1) * batch_size]
            batch_labels = []
            batch_inputs = []
            for ids in batch:
                int_ids_for_labels = [int(x) for x in ids]
                int_ids_for_inputs = [int(x) for x in ids]
                batch_labels.append(int_ids_for_labels)
                batch_inputs.append(int_ids_for_inputs)
            batch_labels = torch.tensor(batch_labels).long().to(device)
            batch_inputs = torch.tensor(batch_inputs).long().to(device)

            # forward pass; no gradients are needed for evaluation
            with torch.no_grad():
                outputs = model.forward(input_ids=batch_inputs, labels=batch_labels)
                loss, logits = outputs[:2]

            # get loss (DataParallel returns one loss per GPU)
            if multi_gpu:
                loss = loss.mean()
            total_loss += loss.item()
            total_steps += 1

            if (overall_step + 1) % log_step == 0:
                print('now time: {}:{}. Step {} of piece {}, ppl {}'.format(
                    datetime.now().hour,
                    datetime.now().minute,
                    (step + 1),
                    piece_num,
                    np.exp(loss.item())))
            overall_step += 1
        piece_num += 1

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)
    # perplexity = exp(mean cross-entropy loss over all evaluated batches)
    with open(args.output_dir + 'result.txt', 'w') as f:
        f.write(str(np.exp(total_loss / total_steps)))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import os
import argparse
from tqdm import trange
from transformers import GPT2LMHeadModel


def is_word(word):
    for item in list(word):
        if item not in 'qwertyuiopasdfghjklzxcvbnm':
            return False
    return True


def _is_chinese_char(char):
    """Checks whether `char` is a CJK ideograph, based on its Unicode codepoint."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as is Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and handled
    # like all of the other languages.
26 | cp = ord(char) 27 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 28 | (cp >= 0x3400 and cp <= 0x4DBF) or # 29 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 30 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 31 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 32 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 33 | (cp >= 0xF900 and cp <= 0xFAFF) or # 34 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 35 | return True 36 | 37 | return False 38 | 39 | 40 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 41 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 42 | Args: 43 | logits: logits distribution shape (vocabulary size) 44 | top_k > 0: keep only top k tokens with highest probability (top-k filtering). 45 | top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 46 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 47 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 48 | """ 49 | assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear 50 | top_k = min(top_k, logits.size(-1)) # Safety check 51 | if top_k > 0: 52 | # Remove all tokens with a probability less than the last token of the top-k 53 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 54 | logits[indices_to_remove] = filter_value 55 | 56 | if top_p > 0.0: 57 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 58 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 59 | 60 | # Remove tokens with cumulative probability above the threshold 61 | sorted_indices_to_remove = cumulative_probs > top_p 62 | # Shift the indices to the right to keep also the first token above the threshold 63 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 64 | sorted_indices_to_remove[..., 0] = 0 65 | 66 | indices_to_remove = 
sorted_indices[sorted_indices_to_remove] 67 | logits[indices_to_remove] = filter_value 68 | return logits 69 | 70 | 71 | def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0, 72 | device='cpu'): 73 | context = torch.tensor(context, dtype=torch.long, device=device) 74 | context = context.unsqueeze(0) 75 | generated = context 76 | with torch.no_grad(): 77 | for _ in trange(length): 78 | inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)} 79 | outputs = model( 80 | **inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states) 81 | next_token_logits = outputs[0][0, -1, :] 82 | for token_id in set(generated[0].tolist()): # generated is a (1, seq_len) tensor; penalize every token id already in it 83 | next_token_logits[token_id] /= repitition_penalty 84 | next_token_logits = next_token_logits / temperature 85 | next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf') 86 | filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) 87 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 88 | generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1) 89 | return generated.tolist()[0] 90 | 91 | 92 | def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_p=0.0, device='cpu'): 93 | inputs = torch.LongTensor(context).view(1, -1).to(device) 94 | if len(context) > 1: 95 | _, past = model(inputs[:, :-1], None)[:2] 96 | prev = inputs[:, -1].view(1, -1) 97 | else: 98 | past = None 99 | prev = inputs 100 | generate = [] + context 101 | with torch.no_grad(): 102 | for i in trange(length): 103 | output = model(prev, past=past) 104 | output, past = output[:2] 105 | output = output[-1].squeeze(0) / temperature 106 | filtered_logits = top_k_top_p_filtering(output, top_k=top_k, top_p=top_p) 107 | next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1) 108 | generate.append(next_token.item()) 109 | prev = next_token.view(1, 1) 110 | return
generate 111 | 112 | 113 | # The --fast_pattern command-line flag selects the mode; the fast path reuses cached `past` states but ignores repitition_penalty 114 | def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0, repitition_penalty=1.0, device='cpu', 115 | is_fast_pattern=False): 116 | if is_fast_pattern: 117 | return fast_sample_sequence(model, context, length, temperature=temperature, top_k=top_k, top_p=top_p, 118 | device=device) 119 | else: 120 | return sample_sequence(model, context, length, n_ctx, tokenizer=tokenizer, temperature=temperature, top_k=top_k, top_p=top_p, 121 | repitition_penalty=repitition_penalty, device=device) 122 | 123 | 124 | def main(): 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='GPU ids to use for generation') 127 | parser.add_argument('--length', default=-1, type=int, required=False, help='generation length; -1 uses the model n_ctx') 128 | parser.add_argument('--batch_size', default=1, type=int, required=False, help='batch size for generation') 129 | parser.add_argument('--nsamples', default=10, type=int, required=False, help='number of samples to generate') 130 | parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature') 131 | parser.add_argument('--topk', default=8, type=int, required=False, help='top-k: sample from the k most likely tokens') 132 | parser.add_argument('--topp', default=0, type=float, required=False, help='top-p: cumulative probability threshold') 133 | parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False, 134 | help='model config') 135 | parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='vocabulary path') 136 | parser.add_argument('--model_path', default='model/final_model', type=str, required=False, help='model path') 137 | parser.add_argument('--prefix', default='萧炎', type=str, required=False, help='opening text of the generated article') 138 | parser.add_argument('--no_wordpiece', action='store_true', help='disable WordPiece tokenization') 139 | parser.add_argument('--segment', action='store_true', help='tokenize Chinese at the word level') 140 | parser.add_argument('--fast_pattern', action='store_true',
help='use the faster generation method') 141 | parser.add_argument('--save_samples', action='store_true', help='save the generated samples') 142 | parser.add_argument('--save_samples_path', default='.', type=str, required=False, help="directory to save samples in") 143 | parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False) 144 | 145 | args = parser.parse_args() 146 | print('args:\n' + args.__repr__()) 147 | 148 | if args.segment: 149 | from tokenizations import tokenization_bert_word_level as tokenization_bert 150 | else: 151 | from tokenizations import tokenization_bert 152 | 153 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device # select which GPUs the program uses 154 | length = args.length 155 | batch_size = args.batch_size 156 | nsamples = args.nsamples 157 | temperature = args.temperature 158 | topk = args.topk 159 | topp = args.topp 160 | repetition_penalty = args.repetition_penalty 161 | 162 | device = "cuda" if torch.cuda.is_available() else "cpu" 163 | 164 | tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path) 165 | model = GPT2LMHeadModel.from_pretrained(args.model_path) 166 | model.to(device) 167 | model.eval() 168 | 169 | n_ctx = model.config.n_ctx 170 | 171 | if length == -1: 172 | length = model.config.n_ctx 173 | if args.save_samples: 174 | if not os.path.exists(args.save_samples_path): 175 | os.makedirs(args.save_samples_path) 176 | samples_file = open(args.save_samples_path + '/samples.txt', 'w', encoding='utf8') 177 | while True: 178 | raw_text = args.prefix 179 | context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text)) 180 | generated = 0 181 | for _ in range(nsamples // batch_size): 182 | out = generate( 183 | n_ctx=n_ctx, 184 | model=model, 185 | context=context_tokens, 186 | length=length, 187 | is_fast_pattern=args.fast_pattern, tokenizer=tokenizer, 188 | temperature=temperature, top_k=topk, top_p=topp, repitition_penalty=repetition_penalty, device=device 189 | ) 190 | for i in range(batch_size): 191 | generated += 1 192 | text =
tokenizer.convert_ids_to_tokens(out) 193 | for i, item in enumerate(text[:-1]): # make sure adjacent English words are separated by spaces 194 | if is_word(item) and is_word(text[i + 1]): 195 | text[i] = item + ' ' 196 | for i, item in enumerate(text): 197 | if item == '[MASK]': 198 | text[i] = '' 199 | elif item == '[CLS]': 200 | text[i] = '\n\n' 201 | elif item == '[SEP]': 202 | text[i] = '\n' 203 | info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n" 204 | print(info) 205 | text = ''.join(text).replace('##', '').strip() 206 | print(text) 207 | if args.save_samples: 208 | samples_file.write(info) 209 | samples_file.write(text) 210 | samples_file.write('\n') 211 | samples_file.write('=' * 90) 212 | samples_file.write('\n' * 2) 213 | print("=" * 80) 214 | if generated == nsamples: 215 | # close the file once all samples are written. 216 | if args.save_samples: 217 | samples_file.close() 218 | break 219 | 220 | 221 | if __name__ == '__main__': 222 | main() 223 | -------------------------------------------------------------------------------- /generate_texts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import os 4 | import argparse 5 | from tqdm import trange 6 | from transformers import GPT2LMHeadModel 7 | 8 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" # select which GPUs the program uses 9 | 10 | 11 | def is_word(word): 12 | for item in list(word): 13 | if item not in 'qwertyuiopasdfghjklzxcvbnm': 14 | return False 15 | return True 16 | 17 | 18 | def _is_chinese_char(char): 19 | """Checks whether the codepoint of `char` is a CJK character.""" 20 | # This defines a "chinese character" as anything in the CJK Unicode block: 21 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 22 | # 23 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 24 | # despite its name. The modern Korean Hangul alphabet is a different block, 25 | # as is Japanese Hiragana and Katakana.
Those alphabets are used to write 26 | # space-separated words, so they are not treated specially and handled 27 | # like all of the other languages. 28 | cp = ord(char) 29 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 30 | (cp >= 0x3400 and cp <= 0x4DBF) or # 31 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 32 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 33 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 34 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 35 | (cp >= 0xF900 and cp <= 0xFAFF) or # 36 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 37 | return True 38 | 39 | return False 40 | 41 | 42 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 43 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 44 | Args: 45 | logits: logits distribution shape (vocabulary size) 46 | top_k > 0: keep only top k tokens with highest probability (top-k filtering). 47 | top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 48 | Nucleus filtering is described in Holtzman et al.
(http://arxiv.org/abs/1904.09751) 49 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 50 | """ 51 | assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear 52 | top_k = min(top_k, logits.size(-1)) # Safety check 53 | if top_k > 0: 54 | # Remove all tokens with a probability less than the last token of the top-k 55 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 56 | logits[indices_to_remove] = filter_value 57 | 58 | if top_p > 0.0: 59 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 60 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 61 | 62 | # Remove tokens with cumulative probability above the threshold 63 | sorted_indices_to_remove = cumulative_probs > top_p 64 | # Shift the indices to the right to keep also the first token above the threshold 65 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 66 | sorted_indices_to_remove[..., 0] = 0 67 | 68 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 69 | logits[indices_to_remove] = filter_value 70 | return logits 71 | 72 | 73 | def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0, repitition_penalty=1.0, 74 | device='cpu'): 75 | context = torch.tensor(context, dtype=torch.long, device=device) 76 | context = context.unsqueeze(0) 77 | generated = context 78 | with torch.no_grad(): 79 | for _ in trange(length): 80 | inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)} 81 | outputs = model( 82 | **inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states) 83 | next_token_logits = outputs[0][0, -1, :] 84 | for token_id in set(generated[0].tolist()): # generated is a (1, seq_len) tensor; penalize every token id already in it 85 | next_token_logits[token_id] /= repitition_penalty 86 | next_token_logits = next_token_logits / temperature 87 | next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf') 88 | filtered_logits =
top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p) 89 | next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1) 90 | generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1) 91 | return generated 92 | 93 | 94 | def main(): 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='which GPUs to use') 97 | parser.add_argument('--length', default=-1, type=int, required=False, help='generation length; -1 uses the model n_ctx') 98 | parser.add_argument('--temperature', default=1, type=float, required=False, help='sampling temperature; higher is more random') 99 | parser.add_argument('--topk', default=8, type=int, required=False, help='top-k: sample from the k most likely tokens') 100 | parser.add_argument('--topp', default=0, type=float, required=False, help='top-p: cumulative probability threshold for sampling') 101 | parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False, 102 | help='path to the model config') 103 | parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='vocabulary path') 104 | parser.add_argument('--model_path', default='model/final_model', type=str, required=False, help='model path') 105 | parser.add_argument('--save_path', default='generated/', type=str, required=False, help='directory for the generated files') 106 | parser.add_argument('--articles_per_title', default=5, type=int, required=False, help='number of articles to generate per title') 107 | parser.add_argument('--titles', default='萧炎', type=str, required=False, help='list of titles, given as one space-separated string') 108 | parser.add_argument('--titles_file', default='', type=str, required=False, 109 | help='file of titles, one per line; if set, --titles is ignored') 110 | parser.add_argument('--no_wordpiece', action='store_true', help='disable WordPiece tokenization') 111 | parser.add_argument('--segment', action='store_true', help='tokenize Chinese at the word level') 112 | parser.add_argument('--repetition_penalty', default=1.0, type=float, required=False) 113 | 114 | args = parser.parse_args() 115 | print('args:\n' + args.__repr__()) 116 | 117 | if args.segment: 118 | from tokenizations import
tokenization_bert_word_level as tokenization_bert 119 | else: 120 | from tokenizations import tokenization_bert 121 | 122 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device # select which GPUs the program uses 123 | length = args.length 124 | temperature = args.temperature 125 | topk = args.topk 126 | topp = args.topp 127 | repetition_penalty = args.repetition_penalty 128 | 129 | titles = args.titles.split() # list; each element is one title to generate from 130 | if args.titles_file: 131 | with open(args.titles_file, 'r', encoding='utf8') as f: 132 | titles = [line.strip('\n') for line in f.readlines()] 133 | articles_per_title = args.articles_per_title # number of articles to generate per title 134 | save_path = args.save_path # output directory 135 | 136 | device = "cuda" if torch.cuda.is_available() else "cpu" 137 | 138 | tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path) 139 | model = GPT2LMHeadModel.from_pretrained(args.model_path) 140 | model.to(device) 141 | model.eval() 142 | 143 | n_ctx = model.config.n_ctx 144 | 145 | if not os.path.exists(save_path): 146 | os.mkdir(save_path) 147 | if length == -1: 148 | length = model.config.n_ctx 149 | 150 | for i, title in enumerate(titles): 151 | for j in range(articles_per_title): 152 | with open(os.path.join(save_path, str(i) + '-' + str(j) + '.txt'), 'w', encoding='utf8') as f: 153 | context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(title)) 154 | generated = 0 155 | out = sample_sequence( 156 | n_ctx=n_ctx, 157 | model=model, length=length, 158 | context=context_tokens, tokenizer=tokenizer, 159 | temperature=temperature, top_k=topk, top_p=topp, repitition_penalty=repetition_penalty, 160 | device=device 161 | ) 162 | out = out.tolist()[0] 163 | 164 | generated += 1 165 | text = tokenizer.convert_ids_to_tokens(out) 166 | 167 | for k, item in enumerate(text[:-1]): # make sure adjacent English words are separated by spaces; k avoids shadowing the title index i 168 | if is_word(item) and is_word(text[k + 1]): 169 | text[k] = item + ' ' 170 | 171 | for k, item in enumerate(text): 172 | if item == '[MASK]': 173 | text[k] = '' 174 | if item == '[CLS]' or item == '[SEP]': 175 | text[k] = '\n' 176
| 177 | print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 178 | text = ''.join(text).replace('##', '').strip() 179 | # text = ''.join(text.split('\n')[:-1]) 180 | print(text) 181 | f.write(text + '\n') 182 | print("=" * 80) 183 | 184 | 185 | if __name__ == '__main__': 186 | main() 187 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==2.1.1 2 | torch 3 | numpy 4 | tqdm 5 | scikit-learn 6 | keras 7 | tb-nightly 8 | future 9 | thulac 10 | -------------------------------------------------------------------------------- /sample/doupo.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/doupo.jpeg -------------------------------------------------------------------------------- /sample/poem_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/poem_1.png -------------------------------------------------------------------------------- /sample/poem_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/poem_2.png -------------------------------------------------------------------------------- /sample/tiyu.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/tiyu.jpg -------------------------------------------------------------------------------- /sample/律诗绝句.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/律诗绝句.png -------------------------------------------------------------------------------- /sample/散文1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/散文1.png -------------------------------------------------------------------------------- /sample/散文2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/散文2.png -------------------------------------------------------------------------------- /sample/散文3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/散文3.png -------------------------------------------------------------------------------- /sample/浣溪沙_江城子.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/浣溪沙_江城子.png -------------------------------------------------------------------------------- /sample/蝶恋花_满江红.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/蝶恋花_满江红.png -------------------------------------------------------------------------------- /sample/金庸_倚天屠龍記.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/金庸_倚天屠龍記.jpg -------------------------------------------------------------------------------- /sample/金庸_天龍八部.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/金庸_天龍八部.jpg -------------------------------------------------------------------------------- /sample/金庸_神鵰俠侶.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/金庸_神鵰俠侶.jpg -------------------------------------------------------------------------------- /sample/金庸_鹿鼎記.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Morizeyao/GPT2-Chinese/9dc45aa24275944bec6ddfd132e0681d24d631ad/sample/金庸_鹿鼎記.jpg -------------------------------------------------------------------------------- /scripts/generate.sh: -------------------------------------------------------------------------------- 1 | python generate.py \ 2 | --device 0 \ 3 | --length 900 \ 4 | --tokenizer_path cache/vocab_small.txt \ 5 | --model_path model/final_model \ 6 | --prefix "[CLS][MASK]" \ 7 | --topp 1 \ 8 | --temperature 1.0 9 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | python train.py \ 2 | --model_config config/model_config_small.json \ 3 | --tokenized_data_path data/tokenized/ \ 4 | --tokenizer_path cache/vocab_small.txt \ 5 | --raw_data_path data/train.json \ 6 | --epochs 30 \ 7 | --log_step 200 \ 8 | --stride 512 \ 9 | --output_dir model/ \ 10 | --device 0,1,2,3 \ 11 | --num_pieces 100 \ 12 | --raw 13 | -------------------------------------------------------------------------------- /tokenizations/bpe_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | from https://github.com/openai/gpt-2/, changed for chinese 3 | """ 4 
| import json 5 | import os 6 | import sentencepiece as spm 7 | """ 8 | SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation 9 | systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements 10 | subword units (e.g., byte-pair-encoding (BPE) [Sennrich et al.] and unigram language model [Kudo.]) with the 11 | extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end 12 | system that does not depend on language-specific pre/postprocessing. 13 | https://github.com/google/sentencepiece 14 | 15 | pip install sentencepiece 16 | 17 | or git clone https://github.com/google/sentencepiece.git 18 | python setup.py install 19 | 20 | """ 21 | 22 | def get_pairs(word): 23 | pairs = set() 24 | prev_char = word[0] 25 | for char in word[1:]: 26 | pairs.add((prev_char, char)) 27 | prev_char = char 28 | return pairs 29 | 30 | 31 | class Encoder: 32 | def __init__(self, encoder, bpe_merges): 33 | self.encoder = encoder 34 | self.decoder = {v: k for k, v in self.encoder.items()} 35 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 36 | self.cache = {} 37 | self.max_len = 0 38 | 39 | def bpe(self, token): 40 | if token in self.cache: 41 | return self.cache[token] 42 | word = tuple(token) 43 | pairs = get_pairs(word) 44 | if not pairs: 45 | return token 46 | 47 | while True: 48 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 49 | if bigram not in self.bpe_ranks: 50 | break 51 | first, second = bigram 52 | new_word = [] 53 | i = 0 54 | while i < len(word): 55 | try: 56 | j = word.index(first, i) 57 | new_word.extend(word[i:j]) 58 | i = j 59 | except ValueError: # `first` does not occur in word[i:] 60 | new_word.extend(word[i:]) 61 | break 62 | 63 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 64 | new_word.append(first + second) 65 | i += 2 66 | else: 67 | new_word.append(word[i]) 68 | i += 1 69 | new_word =
tuple(new_word) 70 | word = new_word 71 | if len(word) == 1: 72 | break 73 | else: 74 | pairs = get_pairs(word) 75 | word = ' '.join(word) 76 | self.cache[token] = word 77 | return word 78 | 79 | def encode(self, text): 80 | return [self.encoder.get(token, 1) for token in self.tokenize(text)] 81 | 82 | def decode(self, tokens): 83 | text = ''.join([self.decoder[token] for token in tokens]) 84 | return text 85 | 86 | def tokenize(self, text): 87 | bpe_tokens = [] 88 | bpe_tokens.extend(bpe_token for bpe_token in self.bpe(text).split(' ')) 89 | return bpe_tokens 90 | 91 | def convert_tokens_to_ids(self, tokens): 92 | return [self.encoder.get(token, 1) for token in tokens] 93 | 94 | class Encoder_SP: 95 | def __init__(self, model_path): 96 | self.sp = spm.SentencePieceProcessor() 97 | self.sp.Load(model_path) 98 | 99 | 100 | def encode(self, text): 101 | """ 102 | text="...." 103 | """ 104 | return self.sp.EncodeAsIds(text) 105 | 106 | 107 | def decode(self, tokens): 108 | """ 109 | tokens=[x1,x2,...] 
110 | """ 111 | text = [int(token) for token in tokens] 112 | #print(text) 113 | return self.sp.DecodeIds(text) 114 | 115 | def tokenize(self, text): 116 | return self.sp.EncodeAsPieces(text) 117 | 118 | def convert_tokens_to_ids(self, tokens): 119 | return [self.sp.PieceToId(token) for token in tokens] 120 | 121 | def get_encoder(encoder_file, bpe_file): 122 | 123 | # Let the same entry point also load a SentencePiece model (a .model file with no BPE merges file) 124 | filepath, filename = os.path.split(encoder_file) 125 | shotname, extension = os.path.splitext(filename) 126 | 127 | if extension == ".model" and bpe_file == "": 128 | return Encoder_SP(encoder_file) 129 | else: 130 | with open(encoder_file, 'r', encoding="utf-8") as f: 131 | encoder = json.load(f) 132 | with open(bpe_file, 'r', encoding="utf-8") as f: 133 | bpe_data = f.read() 134 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 135 | return Encoder( 136 | encoder=encoder, 137 | bpe_merges=bpe_merges, 138 | ) 139 | 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /tokenizations/encoder.json: -------------------------------------------------------------------------------- 1 | {"c":0, "d":1, "大学":2} -------------------------------------------------------------------------------- /tokenizations/thulac_dict/seg: -------------------------------------------------------------------------------- 1 | [SEP] 2 | [PAD] 3 | [CLS] 4 | [UNK] 5 | [MASK] -------------------------------------------------------------------------------- /tokenizations/tokenization_bert.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | from io import open 24 | 25 | from transformers.tokenization_utils import PreTrainedTokenizer 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 30 | 31 | PRETRAINED_VOCAB_FILES_MAP = { 32 | 'vocab_file': 33 | { 34 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 35 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 36 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 37 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 38 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 39 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 40 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 41 | 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 42 | 'bert-large-uncased-whole-word-masking': 
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 43 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 44 | 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 45 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 46 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 47 | } 48 | } 49 | 50 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 51 | 'bert-base-uncased': 512, 52 | 'bert-large-uncased': 512, 53 | 'bert-base-cased': 512, 54 | 'bert-large-cased': 512, 55 | 'bert-base-multilingual-uncased': 512, 56 | 'bert-base-multilingual-cased': 512, 57 | 'bert-base-chinese': 512, 58 | 'bert-base-german-cased': 512, 59 | 'bert-large-uncased-whole-word-masking': 512, 60 | 'bert-large-cased-whole-word-masking': 512, 61 | 'bert-large-uncased-whole-word-masking-finetuned-squad': 512, 62 | 'bert-large-cased-whole-word-masking-finetuned-squad': 512, 63 | 'bert-base-cased-finetuned-mrpc': 512, 64 | } 65 | 66 | def load_vocab(vocab_file): 67 | """Loads a vocabulary file into a dictionary.""" 68 | vocab = collections.OrderedDict() 69 | with open(vocab_file, "r", encoding="utf-8") as reader: 70 | tokens = reader.readlines() 71 | for index, token in enumerate(tokens): 72 | token = token.rstrip('\n') 73 | vocab[token] = index 74 | return vocab 75 | 76 | 77 | def whitespace_tokenize(text): 78 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 79 | text = text.strip() 80 | if not text: 81 | return [] 82 | tokens = text.split() 83 | return tokens 84 | 85 | 86 | class BertTokenizer(PreTrainedTokenizer): 87 | 
r""" 88 | Constructs a BertTokenizer. 89 | :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece 90 | 91 | Args: 92 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 93 | do_lower_case: Whether to lower case the input. Only has an effect when do_basic_tokenize=True 94 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 95 | max_len: An artificial maximum length to truncate tokenized sequences to; the effective maximum length is always the 96 | minimum of this value (if specified) and the underlying BERT model's sequence length. 97 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 98 | do_basic_tokenize=True 99 | """ 100 | 101 | vocab_files_names = VOCAB_FILES_NAMES 102 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 103 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 104 | 105 | def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, 106 | unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", 107 | mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs): 108 | """Constructs a BertTokenizer. 109 | 110 | Args: 111 | **vocab_file**: Path to a one-wordpiece-per-line vocabulary file 112 | **do_lower_case**: (`optional`) boolean (default True) 113 | Whether to lower case the input 114 | Only has an effect when do_basic_tokenize=True 115 | **do_basic_tokenize**: (`optional`) boolean (default True) 116 | Whether to do basic tokenization before wordpiece. 117 | **never_split**: (`optional`) list of string 118 | List of tokens which will never be split during tokenization. 119 | Only has an effect when do_basic_tokenize=True 120 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 121 | Whether to tokenize Chinese characters.
122 | This should likely be deactivated for Japanese: 123 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 124 | """ 125 | super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, 126 | pad_token=pad_token, cls_token=cls_token, 127 | mask_token=mask_token, **kwargs) 128 | if not os.path.isfile(vocab_file): 129 | raise ValueError( 130 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 131 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 132 | self.vocab = load_vocab(vocab_file) 133 | self.ids_to_tokens = collections.OrderedDict( 134 | [(ids, tok) for tok, ids in self.vocab.items()]) 135 | self.do_basic_tokenize = do_basic_tokenize 136 | if do_basic_tokenize: 137 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 138 | never_split=never_split, 139 | tokenize_chinese_chars=tokenize_chinese_chars) 140 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 141 | 142 | @property 143 | def vocab_size(self): 144 | return len(self.vocab) 145 | 146 | def _tokenize(self, text): 147 | split_tokens = [] 148 | if self.do_basic_tokenize: 149 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 150 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 151 | split_tokens.append(sub_token) 152 | else: 153 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 154 | return split_tokens 155 | 156 | def _convert_token_to_id(self, token): 157 | """ Converts a token (str/unicode) into an id using the vocab.
""" 158 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 159 | 160 | def _convert_id_to_token(self, index): 161 | """Converts an index (integer) into a token (string/unicode) using the vocab.""" 162 | return self.ids_to_tokens.get(index, self.unk_token) 163 | 164 | def convert_tokens_to_string(self, tokens): 165 | """ Converts a sequence of tokens (string) into a single string. """ 166 | out_string = ' '.join(tokens).replace(' ##', '').strip() 167 | return out_string 168 | 169 | def save_vocabulary(self, vocab_path): 170 | """Save the tokenizer vocabulary to a directory or file.""" 171 | index = 0 172 | # vocab_path may be a directory (write the default file name inside it) or a file path 173 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) if os.path.isdir(vocab_path) else vocab_path 174 | with open(vocab_file, "w", encoding="utf-8") as writer: 175 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 176 | if index != token_index: 177 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 178 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 179 | index = token_index 180 | writer.write(token + u'\n') 181 | index += 1 182 | return (vocab_file,) 183 | 184 | @classmethod 185 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): 186 | """ Instantiate a BertTokenizer from pre-trained vocabulary files. 187 | """ 188 | if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: 189 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 190 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 191 | "`do_lower_case` to False.
We are setting `do_lower_case=False` for you but " 192 | "you may want to check this behavior.") 193 | kwargs['do_lower_case'] = False 194 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 195 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 196 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you " 197 | "but you may want to check this behavior.") 198 | kwargs['do_lower_case'] = True 199 | 200 | return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 201 | 202 | 203 | class BasicTokenizer(object): 204 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 205 | 206 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): 207 | """ Constructs a BasicTokenizer. 208 | 209 | Args: 210 | **do_lower_case**: Whether to lower case the input. 211 | **never_split**: (`optional`) list of str 212 | Kept for backward compatibility purposes. 213 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 214 | List of tokens not to split. 215 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 216 | Whether to tokenize Chinese characters. 217 | This should likely be deactivated for Japanese: 218 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 219 | """ 220 | if never_split is None: 221 | never_split = [] 222 | self.do_lower_case = do_lower_case 223 | self.never_split = never_split 224 | self.tokenize_chinese_chars = tokenize_chinese_chars 225 | 226 | def tokenize(self, text, never_split=None): 227 | """ Basic Tokenization of a piece of text. 228 | Performs whitespace and punctuation splitting only; for sub-word tokenization, see WordpieceTokenizer. 229 | 230 | Args: 231 | **never_split**: (`optional`) list of str 232 | Kept for backward compatibility purposes.
233 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 234 | List of tokens not to split. 235 | """ 236 | never_split = self.never_split + (never_split if never_split is not None else []) 237 | text = self._clean_text(text) 238 | # This was added on November 1st, 2018 for the multilingual and Chinese 239 | # models. This is also applied to the English models now, but it doesn't 240 | # matter since the English models were not trained on any Chinese data 241 | # and generally don't have any Chinese data in them (there are Chinese 242 | # characters in the vocabulary because Wikipedia does have some Chinese 243 | # words in the English Wikipedia). 244 | if self.tokenize_chinese_chars: 245 | text = self._tokenize_chinese_chars(text) 246 | orig_tokens = whitespace_tokenize(text) 247 | split_tokens = [] 248 | for token in orig_tokens: 249 | if self.do_lower_case and token not in never_split: 250 | token = token.lower() 251 | token = self._run_strip_accents(token) 252 | split_tokens.extend(self._run_split_on_punc(token, never_split)) 253 | 254 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 255 | return output_tokens 256 | 257 | def _run_strip_accents(self, text): 258 | """Strips accents from a piece of text.""" 259 | text = unicodedata.normalize("NFD", text) 260 | output = [] 261 | for char in text: 262 | cat = unicodedata.category(char) 263 | if cat == "Mn": 264 | continue 265 | output.append(char) 266 | return "".join(output) 267 | 268 | def _run_split_on_punc(self, text, never_split=None): 269 | """Splits a piece of text on punctuation.""" 270 | if never_split is not None and text in never_split: 271 | return [text] 272 | chars = list(text) 273 | i = 0 274 | start_new_word = True 275 | output = [] 276 | while i < len(chars): 277 | char = chars[i] 278 | if _is_punctuation(char): 279 | output.append([char]) 280 | start_new_word = True 281 | else: 282 | if start_new_word: 283 | output.append([]) 284 | start_new_word = False
285 | output[-1].append(char) 286 | i += 1 287 | 288 | return ["".join(x) for x in output] 289 | 290 | def _tokenize_chinese_chars(self, text): 291 | """Adds whitespace around any CJK character and any digit.""" 292 | output = [] 293 | for char in text: 294 | cp = ord(char) 295 | if self._is_chinese_char(cp) or char.isdigit(): 296 | output.append(" ") 297 | output.append(char) 298 | output.append(" ") 299 | else: 300 | output.append(char) 301 | return "".join(output) 302 | 303 | def _is_chinese_char(self, cp): 304 | """Checks whether CP is the codepoint of a CJK character.""" 305 | # This defines a "chinese character" as anything in the CJK Unicode block: 306 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 307 | # 308 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 309 | # despite its name. The modern Korean Hangul alphabet is a different block, 310 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 311 | # space-separated words, so they are not treated specially and are handled 312 | # like all of the other languages.
313 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 314 | (cp >= 0x3400 and cp <= 0x4DBF) or # 315 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 316 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 317 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 318 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 319 | (cp >= 0xF900 and cp <= 0xFAFF) or # 320 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 321 | return True 322 | 323 | return False 324 | 325 | def _clean_text(self, text): 326 | """Performs invalid character removal and whitespace cleanup on text.""" 327 | output = [] 328 | for char in text: 329 | cp = ord(char) 330 | if cp == 0 or cp == 0xfffd or _is_control(char): 331 | continue 332 | if _is_whitespace(char): 333 | output.append(" ") 334 | else: 335 | output.append(char) 336 | return "".join(output) 337 | 338 | 339 | class WordpieceTokenizer(object): 340 | """Runs WordPiece tokenization.""" 341 | 342 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 343 | self.vocab = vocab 344 | self.unk_token = unk_token 345 | self.max_input_chars_per_word = max_input_chars_per_word 346 | 347 | def tokenize(self, text): 348 | """Tokenizes a piece of text into its word pieces. 349 | 350 | This uses a greedy longest-match-first algorithm to perform tokenization 351 | using the given vocabulary. 352 | 353 | For example: 354 | input = "unaffable" 355 | output = ["un", "##aff", "##able"] 356 | 357 | Args: 358 | text: A single token or whitespace separated tokens. This should have 359 | already been passed through `BasicTokenizer`. 360 | 361 | Returns: 362 | A list of wordpiece tokens. 
363 | """ 364 | 365 | output_tokens = [] 366 | for token in whitespace_tokenize(text): 367 | chars = list(token) 368 | if len(chars) > self.max_input_chars_per_word: 369 | output_tokens.append(self.unk_token) 370 | continue 371 | 372 | is_bad = False 373 | start = 0 374 | sub_tokens = [] 375 | while start < len(chars): 376 | end = len(chars) 377 | cur_substr = None 378 | while start < end: 379 | substr = "".join(chars[start:end]) 380 | if start > 0: 381 | substr = "##" + substr 382 | if substr in self.vocab: 383 | cur_substr = substr 384 | break 385 | end -= 1 386 | if cur_substr is None: 387 | is_bad = True 388 | break 389 | sub_tokens.append(cur_substr) 390 | start = end 391 | 392 | if is_bad: 393 | output_tokens.append(self.unk_token) 394 | else: 395 | output_tokens.extend(sub_tokens) 396 | return output_tokens 397 | 398 | 399 | def _is_whitespace(char): 400 | """Checks whether `char` is a whitespace character.""" 401 | # \t, \n, and \r are technically control characters but we treat them 402 | # as whitespace since they are generally considered as such. 403 | if char == " " or char == "\t" or char == "\n" or char == "\r": 404 | return True 405 | cat = unicodedata.category(char) 406 | if cat == "Zs": 407 | return True 408 | return False 409 | 410 | 411 | def _is_control(char): 412 | """Checks whether `char` is a control character.""" 413 | # These are technically control characters but we count them as whitespace 414 | # characters. 415 | if char == "\t" or char == "\n" or char == "\r": 416 | return False 417 | cat = unicodedata.category(char) 418 | if cat.startswith("C"): 419 | return True 420 | return False 421 | 422 | 423 | def _is_punctuation(char): 424 | """Checks whether `char` is a punctuation character.""" 425 | cp = ord(char) 426 | # We treat all non-letter/number ASCII as punctuation. 427 | # Characters such as "^", "$", and "`" are not in the Unicode 428 | # Punctuation class but we treat them as punctuation anyway, for 429 | # consistency.
430 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 431 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 432 | return True 433 | cat = unicodedata.category(char) 434 | if cat.startswith("P"): 435 | return True 436 | return False 437 | -------------------------------------------------------------------------------- /tokenizations/tokenization_bert_word_level.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
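The two-stage pipeline that `BertTokenizer._tokenize` runs in tokenization_bert.py above (`BasicTokenizer` pads CJK characters and digits with spaces and splits off punctuation, then `WordpieceTokenizer` applies greedy longest-match-first over the vocabulary) can be illustrated with a small standalone sketch. It is a simplification, not the module's code: the function names and the toy vocabulary are hypothetical, only three of the CJK ranges are checked, and control-character cleanup, accent stripping, `never_split` and the 100-character `max_input_chars_per_word` cutoff are left out.

```python
import unicodedata


def is_cjk(cp):
    # Subset of the ranges checked by BasicTokenizer._is_chinese_char
    return 0x4E00 <= cp <= 0x9FFF or 0x3400 <= cp <= 0x4DBF or 0xF900 <= cp <= 0xFAFF


def is_punct(ch):
    # Same rule as _is_punctuation: ASCII symbols plus Unicode "P*" categories
    cp = ord(ch)
    if 33 <= cp <= 47 or 58 <= cp <= 64 or 91 <= cp <= 96 or 123 <= cp <= 126:
        return True
    return unicodedata.category(ch).startswith("P")


def basic_tokenize(text):
    # Pad every CJK character and digit with spaces, lowercase, split on
    # whitespace, then peel punctuation off into standalone tokens.
    spaced = "".join(f" {ch} " if is_cjk(ord(ch)) or ch.isdigit() else ch for ch in text)
    tokens = []
    for word in spaced.lower().split():
        current = ""
        for ch in word:
            if is_punct(ch):
                if current:
                    tokens.append(current)
                    current = ""
                tokens.append(ch)
            else:
                current += ch
        if current:
            tokens.append(current)
    return tokens


def wordpiece(token, vocab, unk="[UNK]"):
    # Greedy longest-match-first: take the longest piece found in vocab;
    # continuation pieces carry the "##" prefix; any dead end yields [UNK].
    pieces, start = [], 0
    while start < len(token):
        end, match = len(token), None
        while start < end:
            sub = token[start:end] if start == 0 else "##" + token[start:end]
            if sub in vocab:
                match = sub
                break
            end -= 1
        if match is None:
            return [unk]
        pieces.append(match)
        start = end
    return pieces


def tokenize(text, vocab):
    return [piece for tok in basic_tokenize(text) for piece in wordpiece(tok, vocab)]


vocab = {"un", "##aff", "##able", "大", "学", "!"}
print(tokenize("Unaffable 大学!", vocab))  # ['un', '##aff', '##able', '大', '学', '!']
```

Because every Chinese character is already its own token before WordPiece runs, WordPiece effectively sees single characters for Chinese input; the word-level variant in tokenization_bert_word_level.py instead segments words with THULAC before this step.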
15 | """Tokenization classes.""" 16 | 17 | from __future__ import absolute_import, division, print_function, unicode_literals 18 | 19 | import collections 20 | import logging 21 | import os 22 | import unicodedata 23 | import thulac 24 | from io import open 25 | 26 | from transformers.tokenization_utils import PreTrainedTokenizer 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | lac = thulac.thulac(user_dict='tokenizations/thulac_dict/seg', seg_only=True) 31 | 32 | VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'} 33 | 34 | PRETRAINED_VOCAB_FILES_MAP = { 35 | 'vocab_file': 36 | { 37 | 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 38 | 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", 39 | 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", 40 | 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", 41 | 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", 42 | 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", 43 | 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", 44 | 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-vocab.txt", 45 | 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-vocab.txt", 46 | 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-vocab.txt", 47 | 'bert-large-uncased-whole-word-masking-finetuned-squad': 
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-vocab.txt", 48 | 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt", 49 | 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt", 50 | } 51 | } 52 | 53 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 54 | 'bert-base-uncased': 512, 55 | 'bert-large-uncased': 512, 56 | 'bert-base-cased': 512, 57 | 'bert-large-cased': 512, 58 | 'bert-base-multilingual-uncased': 512, 59 | 'bert-base-multilingual-cased': 512, 60 | 'bert-base-chinese': 512, 61 | 'bert-base-german-cased': 512, 62 | 'bert-large-uncased-whole-word-masking': 512, 63 | 'bert-large-cased-whole-word-masking': 512, 64 | 'bert-large-uncased-whole-word-masking-finetuned-squad': 512, 65 | 'bert-large-cased-whole-word-masking-finetuned-squad': 512, 66 | 'bert-base-cased-finetuned-mrpc': 512, 67 | } 68 | 69 | def load_vocab(vocab_file): 70 | """Loads a vocabulary file into a dictionary.""" 71 | vocab = collections.OrderedDict() 72 | with open(vocab_file, "r", encoding="utf-8") as reader: 73 | tokens = reader.readlines() 74 | for index, token in enumerate(tokens): 75 | token = token.rstrip('\n') 76 | vocab[token] = index 77 | return vocab 78 | 79 | 80 | def whitespace_tokenize(text): 81 | """Runs basic whitespace cleaning and splitting on a piece of text.""" 82 | text = text.strip() 83 | if not text: 84 | return [] 85 | tokens = text.split() 86 | return tokens 87 | 88 | 89 | class BertTokenizer(PreTrainedTokenizer): 90 | r""" 91 | Constructs a BertTokenizer. 92 | :class:`~pytorch_pretrained_bert.BertTokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece 93 | 94 | Args: 95 | vocab_file: Path to a one-wordpiece-per-line vocabulary file 96 | do_lower_case: Whether to lower case the input. 
Only has an effect when do_basic_tokenize=True 97 | do_basic_tokenize: Whether to do basic tokenization before wordpiece. 98 | max_len: An artificial maximum length to truncate tokenized sequences to; the effective maximum length is always the 99 | minimum of this value (if specified) and the underlying BERT model's sequence length. 100 | never_split: List of tokens which will never be split during tokenization. Only has an effect when 101 | do_basic_tokenize=True 102 | """ 103 | 104 | vocab_files_names = VOCAB_FILES_NAMES 105 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 106 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 107 | 108 | def __init__(self, vocab_file, do_lower_case=True, do_basic_tokenize=True, never_split=None, 109 | unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]", 110 | mask_token="[MASK]", tokenize_chinese_chars=True, **kwargs): 111 | """Constructs a BertTokenizer. 112 | 113 | Args: 114 | **vocab_file**: Path to a one-wordpiece-per-line vocabulary file 115 | **do_lower_case**: (`optional`) boolean (default True) 116 | Whether to lower case the input 117 | Only has an effect when do_basic_tokenize=True 118 | **do_basic_tokenize**: (`optional`) boolean (default True) 119 | Whether to do basic tokenization before wordpiece. 120 | **never_split**: (`optional`) list of string 121 | List of tokens which will never be split during tokenization. 122 | Only has an effect when do_basic_tokenize=True 123 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 124 | Whether to tokenize Chinese characters.
125 | This should likely be deactivated for Japanese: 126 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 127 | """ 128 | super(BertTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, 129 | pad_token=pad_token, cls_token=cls_token, 130 | mask_token=mask_token, **kwargs) 131 | if not os.path.isfile(vocab_file): 132 | raise ValueError( 133 | "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " 134 | "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) 135 | self.vocab = load_vocab(vocab_file) 136 | self.ids_to_tokens = collections.OrderedDict( 137 | [(ids, tok) for tok, ids in self.vocab.items()]) 138 | self.do_basic_tokenize = do_basic_tokenize 139 | if do_basic_tokenize: 140 | self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 141 | never_split=never_split, 142 | tokenize_chinese_chars=tokenize_chinese_chars) 143 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) 144 | 145 | @property 146 | def vocab_size(self): 147 | return len(self.vocab) 148 | 149 | def _tokenize(self, text): 150 | split_tokens = [] 151 | if self.do_basic_tokenize: 152 | for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): 153 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 154 | split_tokens.append(sub_token) 155 | else: 156 | split_tokens = self.wordpiece_tokenizer.tokenize(text) 157 | return split_tokens 158 | 159 | def _convert_token_to_id(self, token): 160 | """ Converts a token (str/unicode) into an id using the vocab.
""" 161 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 162 | 163 | def _convert_id_to_token(self, index): 164 | """Converts an index (integer) into a token (string/unicode) using the vocab.""" 165 | return self.ids_to_tokens.get(index, self.unk_token) 166 | 167 | def convert_tokens_to_string(self, tokens): 168 | """ Converts a sequence of tokens (string) into a single string. """ 169 | out_string = ' '.join(tokens).replace(' ##', '').strip() 170 | return out_string 171 | 172 | def save_vocabulary(self, vocab_path): 173 | """Save the tokenizer vocabulary to a directory or file.""" 174 | index = 0 175 | # vocab_path may be a directory (write the default file name inside it) or a file path 176 | vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file']) if os.path.isdir(vocab_path) else vocab_path 177 | with open(vocab_file, "w", encoding="utf-8") as writer: 178 | for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): 179 | if index != token_index: 180 | logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive." 181 | " Please check that the vocabulary is not corrupted!".format(vocab_file)) 182 | index = token_index 183 | writer.write(token + u'\n') 184 | index += 1 185 | return (vocab_file,) 186 | 187 | @classmethod 188 | def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs): 189 | """ Instantiate a BertTokenizer from pre-trained vocabulary files. 190 | """ 191 | if pretrained_model_name_or_path in PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES: 192 | if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): 193 | logger.warning("The pre-trained model you are loading is a cased model but you have not set " 194 | "`do_lower_case` to False.
We are setting `do_lower_case=False` for you but " 195 | "you may want to check this behavior.") 196 | kwargs['do_lower_case'] = False 197 | elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): 198 | logger.warning("The pre-trained model you are loading is an uncased model but you have set " 199 | "`do_lower_case` to False. We are setting `do_lower_case=True` for you " 200 | "but you may want to check this behavior.") 201 | kwargs['do_lower_case'] = True 202 | 203 | return super(BertTokenizer, cls)._from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs) 204 | 205 | 206 | class BasicTokenizer(object): 207 | """Runs basic tokenization (punctuation splitting, lower casing, etc.).""" 208 | 209 | def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True): 210 | """ Constructs a BasicTokenizer. 211 | 212 | Args: 213 | **do_lower_case**: Whether to lower case the input. 214 | **never_split**: (`optional`) list of str 215 | Kept for backward compatibility purposes. 216 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 217 | List of tokens not to split. 218 | **tokenize_chinese_chars**: (`optional`) boolean (default True) 219 | Whether to tokenize Chinese characters. 220 | This should likely be deactivated for Japanese: 221 | see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328 222 | """ 223 | if never_split is None: 224 | never_split = [] 225 | self.do_lower_case = do_lower_case 226 | self.never_split = never_split 227 | self.tokenize_chinese_chars = tokenize_chinese_chars 228 | 229 | def tokenize(self, text, never_split=None): 230 | """ Basic Tokenization of a piece of text. 231 | Performs whitespace and punctuation splitting only; for sub-word tokenization, see WordpieceTokenizer. 232 | 233 | Args: 234 | **never_split**: (`optional`) list of str 235 | Kept for backward compatibility purposes.
236 | Now implemented directly at the base class level (see :func:`PreTrainedTokenizer.tokenize`) 237 | List of tokens not to split. 238 | """ 239 | never_split = self.never_split + (never_split if never_split is not None else []) 240 | text = self._clean_text(text) 241 | # This was added on November 1st, 2018 for the multilingual and Chinese 242 | # models. This is also applied to the English models now, but it doesn't 243 | # matter since the English models were not trained on any Chinese data 244 | # and generally don't have any Chinese data in them (there are Chinese 245 | # characters in the vocabulary because Wikipedia does have some Chinese 246 | # words in the English Wikipedia). 247 | if self.tokenize_chinese_chars: 248 | text = self._tokenize_chinese_chars(text) 249 | orig_tokens = whitespace_tokenize(text) 250 | split_tokens = [] 251 | for token in orig_tokens: 252 | if self.do_lower_case and token not in never_split: 253 | token = token.lower() 254 | token = self._run_strip_accents(token) 255 | split_tokens.extend(self._run_split_on_punc(token, never_split)) 256 | 257 | output_tokens = whitespace_tokenize(" ".join(split_tokens)) 258 | return output_tokens 259 | 260 | def _run_strip_accents(self, text): 261 | """Strips accents from a piece of text.""" 262 | text = unicodedata.normalize("NFD", text) 263 | output = [] 264 | for char in text: 265 | cat = unicodedata.category(char) 266 | if cat == "Mn": 267 | continue 268 | output.append(char) 269 | return "".join(output) 270 | 271 | def _run_split_on_punc(self, text, never_split=None): 272 | """Splits a piece of text on punctuation.""" 273 | if never_split is not None and text in never_split: 274 | return [text] 275 | chars = list(text) 276 | i = 0 277 | start_new_word = True 278 | output = [] 279 | while i < len(chars): 280 | char = chars[i] 281 | if _is_punctuation(char): 282 | output.append([char]) 283 | start_new_word = True 284 | else: 285 | if start_new_word: 286 | output.append([]) 287 | start_new_word = False
288 | output[-1].append(char) 289 | i += 1 290 | 291 | return ["".join(x) for x in output] 292 | 293 | # def _tokenize_chinese_chars(self, text): 294 | # """Adds whitespace around any CJK character.""" 295 | # output = [] 296 | # for char in text: 297 | # cp = ord(char) 298 | # if self._is_chinese_char(cp) or char.isdigit(): 299 | # output.append(" ") 300 | # output.append(char) 301 | # output.append(" ") 302 | # else: 303 | # output.append(char) 304 | # return "".join(output) 305 | def _tokenize_chinese_chars(self, text): 306 | """Pads digits with spaces, then segments the text into words with THULAC (`lac`) and joins the words with single spaces.""" 307 | output = [] 308 | for char in text: 309 | if char.isdigit(): 310 | output.append(" ") 311 | output.append(char) 312 | output.append(" ") 313 | else: 314 | output.append(char) 315 | text = "".join(output) 316 | text = [item[0].strip() for item in lac.cut(text)] 317 | text = [item for item in text if item] 318 | return " ".join(text) 319 | 320 | def _is_chinese_char(self, cp): 321 | """Checks whether CP is the codepoint of a CJK character.""" 322 | # This defines a "chinese character" as anything in the CJK Unicode block: 323 | # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) 324 | # 325 | # Note that the CJK Unicode block is NOT all Japanese and Korean characters, 326 | # despite its name. The modern Korean Hangul alphabet is a different block, 327 | # as is Japanese Hiragana and Katakana. Those alphabets are used to write 328 | # space-separated words, so they are not treated specially and are handled 329 | # like all of the other languages.
330 | if ((cp >= 0x4E00 and cp <= 0x9FFF) or # 331 | (cp >= 0x3400 and cp <= 0x4DBF) or # 332 | (cp >= 0x20000 and cp <= 0x2A6DF) or # 333 | (cp >= 0x2A700 and cp <= 0x2B73F) or # 334 | (cp >= 0x2B740 and cp <= 0x2B81F) or # 335 | (cp >= 0x2B820 and cp <= 0x2CEAF) or 336 | (cp >= 0xF900 and cp <= 0xFAFF) or # 337 | (cp >= 0x2F800 and cp <= 0x2FA1F)): # 338 | return True 339 | 340 | return False 341 | 342 | def _clean_text(self, text): 343 | """Performs invalid character removal and whitespace cleanup on text.""" 344 | output = [] 345 | for char in text: 346 | cp = ord(char) 347 | if cp == 0 or cp == 0xfffd or _is_control(char): 348 | continue 349 | if _is_whitespace(char): 350 | output.append(" ") 351 | else: 352 | output.append(char) 353 | return "".join(output) 354 | 355 | 356 | class WordpieceTokenizer(object): 357 | """Runs WordPiece tokenization.""" 358 | 359 | def __init__(self, vocab, unk_token, max_input_chars_per_word=100): 360 | self.vocab = vocab 361 | self.unk_token = unk_token 362 | self.max_input_chars_per_word = max_input_chars_per_word 363 | 364 | def tokenize(self, text): 365 | """Tokenizes a piece of text into its word pieces. 366 | 367 | This uses a greedy longest-match-first algorithm to perform tokenization 368 | using the given vocabulary. 369 | 370 | For example: 371 | input = "unaffable" 372 | output = ["un", "##aff", "##able"] 373 | 374 | Args: 375 | text: A single token or whitespace separated tokens. This should have 376 | already been passed through `BasicTokenizer`. 377 | 378 | Returns: 379 | A list of wordpiece tokens. 
380 | """ 381 | 382 | output_tokens = [] 383 | for token in whitespace_tokenize(text): 384 | chars = list(token) 385 | if len(chars) > self.max_input_chars_per_word: 386 | output_tokens.append(self.unk_token) 387 | continue 388 | 389 | is_bad = False 390 | start = 0 391 | sub_tokens = [] 392 | while start < len(chars): 393 | end = len(chars) 394 | cur_substr = None 395 | while start < end: 396 | substr = "".join(chars[start:end]) 397 | if start > 0: 398 | substr = "##" + substr 399 | if substr in self.vocab: 400 | cur_substr = substr 401 | break 402 | end -= 1 403 | if cur_substr is None: 404 | is_bad = True 405 | break 406 | sub_tokens.append(cur_substr) 407 | start = end 408 | 409 | if is_bad: 410 | output_tokens.append(self.unk_token) 411 | else: 412 | output_tokens.extend(sub_tokens) 413 | return output_tokens 414 | 415 | 416 | def _is_whitespace(char): 417 | """Checks whether `char` is a whitespace character.""" 418 | # \t, \n, and \r are technically control characters but we treat them 419 | # as whitespace since they are generally considered as such. 420 | if char == " " or char == "\t" or char == "\n" or char == "\r": 421 | return True 422 | cat = unicodedata.category(char) 423 | if cat == "Zs": 424 | return True 425 | return False 426 | 427 | 428 | def _is_control(char): 429 | """Checks whether `char` is a control character.""" 430 | # These are technically control characters but we count them as whitespace 431 | # characters. 432 | if char == "\t" or char == "\n" or char == "\r": 433 | return False 434 | cat = unicodedata.category(char) 435 | if cat.startswith("C"): 436 | return True 437 | return False 438 | 439 | 440 | def _is_punctuation(char): 441 | """Checks whether `char` is a punctuation character.""" 442 | cp = ord(char) 443 | # We treat all non-letter/number ASCII as punctuation. 444 | # Characters such as "^", "$", and "`" are not in the Unicode 445 | # Punctuation class but we treat them as punctuation anyway, for 446 | # consistency.
447 | if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or 448 | (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): 449 | return True 450 | cat = unicodedata.category(char) 451 | if cat.startswith("P"): 452 | return True 453 | return False 454 | -------------------------------------------------------------------------------- /tokenizations/vocab.bpe: -------------------------------------------------------------------------------- 1 | #version: 0.2 2 | 大 学 -------------------------------------------------------------------------------- /train.json: -------------------------------------------------------------------------------- 1 | ["第一篇文章的正文", "第二篇文章的正文", "第三篇文章的正文"] -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | import torch 3 | import os 4 | import json 5 | import random 6 | import numpy as np 7 | import argparse 8 | from torch.utils.tensorboard import SummaryWriter 9 | from datetime import datetime 10 | from tqdm import tqdm 11 | from torch.nn import DataParallel 12 | from tokenizations.bpe_tokenizer import get_encoder 13 | 14 | 15 | def build_files(data_path, tokenized_data_path, num_pieces, full_tokenizer, min_length): 16 | with open(data_path, 'r', encoding='utf8') as f: 17 | print('reading lines') 18 | lines = json.load(f) 19 | lines = [line.replace('\n', ' [SEP] ') for line in lines] # use [SEP] for line breaks, marking the end of each paragraph 20 | all_len = len(lines) 21 | if not os.path.exists(tokenized_data_path): 22 | os.mkdir(tokenized_data_path) 23 | for i in tqdm(range(num_pieces)): 24 | sublines = lines[all_len // num_pieces * i: all_len // num_pieces * (i + 1)] 25 | if i == num_pieces - 1: 26 | sublines.extend(lines[all_len // num_pieces * (i + 1):]) # append the leftover tail examples to the last piece 27 | sublines = [full_tokenizer.tokenize(line) for line in sublines if 28 | len(line) > min_length] # keep only lines longer than min_length 29 | sublines =
[full_tokenizer.convert_tokens_to_ids(line) for line in sublines] 30 | full_line = [] 31 | for subline in sublines: 32 | full_line.append(full_tokenizer.convert_tokens_to_ids('[MASK]')) # 文章开头添加MASK表示文章开始 33 | full_line.extend(subline) 34 | full_line.append(full_tokenizer.convert_tokens_to_ids('[CLS]')) # 文章之间添加CLS表示文章结束 35 | with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'w') as f: 36 | for id in full_line: 37 | f.write(str(id) + ' ') 38 | print('finish') 39 | 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='设置使用哪些显卡') 44 | parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False, 45 | help='选择模型参数') 46 | parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='选择词库') 47 | parser.add_argument('--raw_data_path', default='data/train.json', type=str, required=False, help='原始训练语料') 48 | parser.add_argument('--tokenized_data_path', default='data/tokenized/', type=str, required=False, 49 | help='tokenized语料存放位置') 50 | parser.add_argument('--raw', action='store_true', help='是否先做tokenize') 51 | parser.add_argument('--epochs', default=5, type=int, required=False, help='训练循环') 52 | parser.add_argument('--batch_size', default=8, type=int, required=False, help='训练batch size') 53 | parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='学习率') 54 | parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='warm up步数') 55 | parser.add_argument('--log_step', default=1, type=int, required=False, help='多少步汇报一次loss,设置为gradient accumulation的整数倍') 56 | parser.add_argument('--stride', default=768, type=int, required=False, help='训练时取训练数据的窗口步长') 57 | parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='梯度积累') 58 | parser.add_argument('--fp16', action='store_true', help='混合精度') 59 | 
    parser.add_argument('--fp16_opt_level', default='O1', type=str, required=False)
    parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False)
    parser.add_argument('--num_pieces', default=100, type=int, required=False, help='number of pieces to split the training corpus into')
    parser.add_argument('--min_length', default=128, type=int, required=False, help='minimum length of an article to be included')
    parser.add_argument('--output_dir', default='model/', type=str, required=False, help='model output directory')
    parser.add_argument('--pretrained_model', default='', type=str, required=False, help='path of the model to start training from')
    parser.add_argument('--writer_dir', default='tensorboard_summary/', type=str, required=False, help='TensorBoard log directory')
    parser.add_argument('--segment', action='store_true', help='segment Chinese text into words instead of characters')
    parser.add_argument('--bpe_token', action='store_true', help='subword')
    parser.add_argument('--encoder_json', default="tokenizations/encoder.json", type=str, help="encoder.json")
    parser.add_argument('--vocab_bpe', default="tokenizations/vocab.bpe", type=str, help="vocab.bpe")

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # restrict the program to the selected GPUs

    model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
    print('config:\n' + model_config.to_json_string())

    n_ctx = model_config.n_ctx
    if args.bpe_token:
        full_tokenizer = get_encoder(args.encoder_json, args.vocab_bpe)
    else:
        full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    full_tokenizer.max_len = 999999
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('using device:', device)

    raw_data_path = args.raw_data_path
    tokenized_data_path = args.tokenized_data_path
    raw = args.raw  # whether to build the tokenized dataset from scratch
    epochs = args.epochs
    batch_size = args.batch_size
    lr = args.lr
    warmup_steps = args.warmup_steps
    log_step = args.log_step
    stride = args.stride
    gradient_accumulation = args.gradient_accumulation
    fp16 = args.fp16  # do not enable on GPUs without half-precision support
    fp16_opt_level = args.fp16_opt_level
    max_grad_norm = args.max_grad_norm
    num_pieces = args.num_pieces
    min_length = args.min_length
    output_dir = args.output_dir
    tb_writer = SummaryWriter(log_dir=args.writer_dir)
    assert log_step % gradient_accumulation == 0

    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    if raw:
        print('building files')
        build_files(data_path=raw_data_path, tokenized_data_path=tokenized_data_path, num_pieces=num_pieces,
                    full_tokenizer=full_tokenizer, min_length=min_length)
        print('files built')

    if not args.pretrained_model:
        model = transformers.modeling_gpt2.GPT2LMHeadModel(config=model_config)
    else:
        model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
    model.train()
    model.to(device)

    num_parameters = 0
    parameters = model.parameters()
    for parameter in parameters:
        num_parameters += parameter.numel()
    print('number of parameters: {}'.format(num_parameters))

    multi_gpu = False
    full_len = 0
    print('calculating total steps')
    for i in tqdm(range(num_pieces)):
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
            full_len += len([int(item) for item in f.read().strip().split()])
    total_steps = int(full_len / stride * epochs / batch_size / gradient_accumulation)
    print('total steps = {}'.format(total_steps))

    optimizer = transformers.AdamW(model.parameters(), lr=lr, correct_bias=True)
    scheduler = transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps,
                                                  t_total=total_steps)
    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # CUDA_VISIBLE_DEVICES renumbers the selected GPUs from 0, so use every
        # visible device instead of the original --device ids.
        model = DataParallel(model)
        multi_gpu = True
    print('starting training')
    overall_step = 0
    running_loss = 0
    for epoch in range(epochs):
        print('epoch {}'.format(epoch + 1))
        now = datetime.now()
        print('time: {}'.format(now))
        x = np.linspace(0, num_pieces - 1, num_pieces, dtype=np.int32)
        random.shuffle(x)
        piece_num = 0
        for i in x:
            with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
                line = f.read().strip()
            tokens = line.split()
            tokens = [int(token) for token in tokens]
            start_point = 0
            samples = []
            while start_point < len(tokens) - n_ctx:
                samples.append(tokens[start_point: start_point + n_ctx])
                start_point += stride
            if start_point < len(tokens):
                samples.append(tokens[len(tokens)-n_ctx:])
            random.shuffle(samples)
            for step in range(len(samples) // batch_size):  # drop last

                # prepare data
                batch = samples[step * batch_size: (step + 1) * batch_size]
                batch_inputs = []
                for ids in batch:
                    int_ids = [int(x) for x in ids]
                    batch_inputs.append(int_ids)
                batch_inputs = torch.tensor(batch_inputs).long().to(device)

                # forward pass
                outputs = model.forward(input_ids=batch_inputs, labels=batch_inputs)
                loss, logits = outputs[:2]

                # get loss
                if multi_gpu:
                    loss = loss.mean()
                if gradient_accumulation > 1:
                    loss = loss / gradient_accumulation

                # loss backward
                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    # clip after leaving scale_loss, once the gradients are unscaled
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                # optimizer step
                if (overall_step + 1) % gradient_accumulation == 0:
                    running_loss += loss.item()
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                if (overall_step + 1) % log_step == 0:
                    tb_writer.add_scalar('loss', loss.item() * gradient_accumulation, overall_step)
                    print('now time: {}:{}. Step {} of piece {} of epoch {}, loss {}'.format(
                        datetime.now().hour,
                        datetime.now().minute,
                        step + 1,
                        piece_num,
                        epoch + 1,
                        running_loss * gradient_accumulation / (log_step / gradient_accumulation)))
                    running_loss = 0
                overall_step += 1
            piece_num += 1

        print('saving model for epoch {}'.format(epoch + 1))
        if not os.path.exists(output_dir + 'model_epoch{}'.format(epoch + 1)):
            os.mkdir(output_dir + 'model_epoch{}'.format(epoch + 1))
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(output_dir + 'model_epoch{}'.format(epoch + 1))
        # torch.save(scheduler.state_dict(), output_dir + 'model_epoch{}/scheduler.pt'.format(epoch + 1))
        # torch.save(optimizer.state_dict(), output_dir + 'model_epoch{}/optimizer.pt'.format(epoch + 1))
        print('epoch {} finished'.format(epoch + 1))

        then = datetime.now()
        print('time: {}'.format(then))
        print('time for one epoch: {}'.format(then - now))

    print('training finished')
    if not os.path.exists(output_dir + 'final_model'):
        os.mkdir(output_dir + 'final_model')
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(output_dir + 'final_model')
    # torch.save(scheduler.state_dict(), output_dir + 'final_model/scheduler.pt')
    # torch.save(optimizer.state_dict(), output_dir + 'final_model/optimizer.pt')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/train_single.py:
--------------------------------------------------------------------------------
import transformers
import torch
import os
import json
import random
import argparse
import numpy as np
from datetime import datetime
from torch.nn import DataParallel
from tqdm import tqdm

'''
Use this script when the training material is one continuous text that is not divided into separate articles.
'''


def build_files(raw_data_path, tokenized_data_path, full_tokenizer, num_pieces):
    with open(raw_data_path, 'r', encoding='utf8') as f:
        print('reading lines')
        lines = json.load(f)
        lines = [line.replace('\n', ' [SEP] ') for line in lines]  # replace newlines with [SEP], which marks the end of a paragraph
    single = ''.join(lines)
    len_single = len(single)
    if not os.path.exists(tokenized_data_path):
        os.mkdir(tokenized_data_path)
    for i in tqdm(range(num_pieces)):
        single_ids = full_tokenizer.convert_tokens_to_ids(
            full_tokenizer.tokenize(single[len_single // num_pieces * i: len_single // num_pieces * (i + 1)]))
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'w') as f:
            for id in single_ids[:-1]:
                f.write(str(id) + ' ')
            f.write(str(single_ids[-1]))
            f.write('\n')

    print('finish')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='0,1,2,3', type=str, required=False, help='GPUs to use')
    parser.add_argument('--model_config', default='config/model_config_small.json', type=str, required=False,
                        help='model configuration file')
    parser.add_argument('--tokenizer_path', default='cache/vocab_small.txt', type=str, required=False, help='vocabulary file')
    parser.add_argument('--raw_data_path', default='data/train.json', type=str, required=False, help='raw training corpus')
    parser.add_argument('--tokenized_data_path', default='data/tokenized/', type=str, required=False,
                        help='where to store the tokenized corpus')
    parser.add_argument('--raw', action='store_true', help='tokenize the raw corpus first')
    parser.add_argument('--epochs', default=5, type=int, required=False, help='number of training epochs')
    parser.add_argument('--batch_size', default=8, type=int, required=False, help='training batch size')
    parser.add_argument('--lr', default=1.5e-4, type=float, required=False, help='learning rate')
    parser.add_argument('--warmup_steps', default=2000, type=int, required=False, help='number of warm-up steps')
    parser.add_argument('--log_step', default=1, type=int, required=False, help='report the loss every N steps')
    parser.add_argument('--stride', default=768, type=int, required=False, help='stride of the sliding window over the training data')
    parser.add_argument('--gradient_accumulation', default=1, type=int, required=False, help='gradient accumulation steps')
    parser.add_argument('--fp16', action='store_true', help='mixed-precision training')
    parser.add_argument('--fp16_opt_level', default='O1', type=str, required=False)
    parser.add_argument('--max_grad_norm', default=1.0, type=float, required=False)
    parser.add_argument('--num_pieces', default=100, type=int, required=False, help='number of pieces to split the training corpus into')
    parser.add_argument('--output_dir', default='model/', type=str, required=False, help='model output directory')
    parser.add_argument('--pretrained_model', default='', type=str, required=False, help='path of the model to start training from')
    parser.add_argument('--segment', action='store_true', help='segment Chinese text into words instead of characters')

    args = parser.parse_args()
    print('args:\n' + args.__repr__())

    if args.segment:
        from tokenizations import tokenization_bert_word_level as tokenization_bert
    else:
        from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = args.device  # restrict the program to the selected GPUs
    model_config = transformers.modeling_gpt2.GPT2Config.from_json_file(args.model_config)
    print('config:\n' + model_config.to_json_string())

    n_ctx = model_config.n_ctx
    full_tokenizer = tokenization_bert.BertTokenizer(vocab_file=args.tokenizer_path)
    full_tokenizer.max_len = 999999
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print('using device:', device)

    raw_data_path = args.raw_data_path
    tokenized_data_path = args.tokenized_data_path
    raw = args.raw  # whether to build the tokenized dataset from scratch
    epochs = args.epochs
    batch_size = args.batch_size
    lr = args.lr
    warmup_steps = args.warmup_steps
    log_step = args.log_step
    stride = args.stride
    gradient_accumulation = args.gradient_accumulation
    fp16 = args.fp16  # do not enable on GPUs without half-precision support
    fp16_opt_level = args.fp16_opt_level
    max_grad_norm = args.max_grad_norm
    num_pieces = args.num_pieces
    output_dir = args.output_dir

    # the per-epoch save below creates subdirectories, so output_dir must exist
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    if raw:
        print('building files')
        build_files(raw_data_path=raw_data_path, tokenized_data_path=tokenized_data_path, full_tokenizer=full_tokenizer,
                    num_pieces=num_pieces)
        print('files built')

    if not args.pretrained_model:
        model = transformers.modeling_gpt2.GPT2LMHeadModel(config=model_config)
    else:
        model = transformers.modeling_gpt2.GPT2LMHeadModel.from_pretrained(args.pretrained_model)
    model.train()
    model.to(device)
    multi_gpu = False
    full_len = 0
    print('calculating total steps')
    for i in tqdm(range(num_pieces)):
        with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
            full_len += len([int(item) for item in f.read().strip().split()])
    total_steps = int(full_len / stride * epochs / batch_size / gradient_accumulation)
    print('total steps = {}'.format(total_steps))

    optimizer = transformers.AdamW(model.parameters(), lr=lr, correct_bias=True)
    scheduler = transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps,
                                                  t_total=total_steps)
    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = DataParallel(model)
        multi_gpu = True
    print('starting training')
    running_loss = 0
    for epoch in range(epochs):
        print('epoch {}'.format(epoch + 1))
        now = datetime.now()
        print('time: {}'.format(now))
        x = np.linspace(0, num_pieces - 1, num_pieces, dtype=np.int32)
        random.shuffle(x)
        piece_num = 0
        for i in x:
            with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
                line = f.read().strip()
            tokens = line.split()
            tokens = [int(token) for token in tokens]
            start_point = 0
            samples = []
            while start_point < len(tokens) - n_ctx:
                samples.append(tokens[start_point: start_point + n_ctx])
                start_point += stride
            if start_point < len(tokens):
                samples.append(tokens[len(tokens)-n_ctx:])
            random.shuffle(samples)
            for step in range(len(samples) // batch_size):

                # prepare data
                batch = samples[step * batch_size: (step + 1) * batch_size]
                batch_labels = []
                batch_inputs = []
                for ids in batch:
                    int_ids_for_labels = [int(x) for x in ids]
                    int_ids_for_inputs = [int(x) for x in ids]
                    batch_labels.append(int_ids_for_labels)
                    batch_inputs.append(int_ids_for_inputs)
                batch_labels = torch.tensor(batch_labels).long().to(device)
                batch_inputs = torch.tensor(batch_inputs).long().to(device)

                # forward pass
                outputs = model.forward(input_ids=batch_inputs, labels=batch_labels)
                loss, logits = outputs[:2]

                # get loss
                if multi_gpu:
                    loss = loss.mean()
                if gradient_accumulation > 1:
                    loss = loss / gradient_accumulation

                # loss backward
                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    # clip after leaving scale_loss, once the gradients are unscaled
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

                # optimizer step
                if (step + 1) % gradient_accumulation == 0:
                    running_loss += loss.item()
                    optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                if (step + 1) % log_step == 0:
                    print('now time: {}:{}. Step {} of piece {} of epoch {}, loss {}'.format(
                        datetime.now().hour,
                        datetime.now().minute,
                        (step + 1) // gradient_accumulation,
                        piece_num,
                        epoch + 1,
                        running_loss / log_step))
                    running_loss = 0
            piece_num += 1

        print('saving model for epoch {}'.format(epoch + 1))
        if not os.path.exists(output_dir + 'model_epoch{}'.format(epoch + 1)):
            os.mkdir(output_dir + 'model_epoch{}'.format(epoch + 1))
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(output_dir + 'model_epoch{}'.format(epoch + 1))
        # torch.save(scheduler.state_dict(), output_dir + 'model_epoch{}/scheduler.pt'.format(epoch + 1))
        # torch.save(optimizer.state_dict(), output_dir + 'model_epoch{}/optimizer.pt'.format(epoch + 1))
        print('epoch {} finished'.format(epoch + 1))

        then = datetime.now()
        print('time: {}'.format(then))
        print('time for one epoch: {}'.format(then - now))

    print('training finished')
    if not os.path.exists(output_dir + 'final_model'):
        os.mkdir(output_dir + 'final_model')
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(output_dir + 'final_model')
    # torch.save(scheduler.state_dict(), output_dir + 'final_model/scheduler.pt')
    # torch.save(optimizer.state_dict(), output_dir + 'final_model/optimizer.pt')


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
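Both train.py and train_single.py turn each tokenized piece into training samples the same way: windows of `n_ctx` ids start every `stride` ids, and one extra window aligned to the end of the piece keeps the tail from being dropped. The sketch below isolates that loop together with the `total_steps` estimate passed to the scheduler, so the windowing can be checked without a GPU or a tokenizer. The function names `make_samples` and `estimate_total_steps` are hypothetical and not part of the repository. One deliberate change is assumed: the start of the final window is clamped at 0. In the scripts, a piece shorter than `n_ctx` makes `len(tokens) - n_ctx` negative, and Python's negative slicing then returns only a suffix of the piece.

```python
from typing import List


def make_samples(tokens: List[int], n_ctx: int, stride: int) -> List[List[int]]:
    """Cut one piece's token-id stream into training windows of n_ctx ids.

    Follows the loop in train.py / train_single.py, except that the start of
    the final, right-aligned window is clamped at 0 (hypothetical change).
    """
    samples = []
    start_point = 0
    # Full windows starting every `stride` ids; consecutive windows overlap
    # when stride < n_ctx.
    while start_point < len(tokens) - n_ctx:
        samples.append(tokens[start_point:start_point + n_ctx])
        start_point += stride
    # One extra window aligned to the end so the tail ids are not dropped.
    if start_point < len(tokens):
        samples.append(tokens[max(0, len(tokens) - n_ctx):])
    return samples


def estimate_total_steps(full_len: int, stride: int, epochs: int,
                         batch_size: int, gradient_accumulation: int) -> int:
    """Same formula the scripts use for the scheduler's t_total."""
    return int(full_len / stride * epochs / batch_size / gradient_accumulation)


if __name__ == '__main__':
    # Ten ids, windows of 4, stride 3:
    # prints [0, 1, 2, 3], then [3, 4, 5, 6], then [6, 7, 8, 9]
    for window in make_samples(list(range(10)), n_ctx=4, stride=3):
        print(window)
```

In the scripts, the resulting samples are then shuffled and grouped into batches of `batch_size`, with any incomplete final batch dropped.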