├── .gitignore ├── LICENSE ├── QA_data ├── QA.db ├── QA_test.py ├── __init__.py └── stop_words.txt ├── README.md ├── checkpoints └── chatbot_0509_1437 ├── clean_chat_corpus └── qingyun.tsv ├── config.py ├── corpus.pth ├── dataload.py ├── datapreprocess.py ├── main.py ├── model.py ├── requirements.txt ├── train_eval.py └── utils ├── __init__.py ├── beamsearch.py └── greedysearch.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/QA_data/QA.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/QA_data/QA.db
--------------------------------------------------------------------------------
/QA_data/QA_test.py:
--------------------------------------------------------------------------------
# -*- coding: UTF-8 -*-

import sqlite3
import jieba
import logging
jieba.setLogLevel(logging.INFO)  # silence jieba's logging output

conn = sqlite3.connect('./QA_data/QA.db')

cursor = conn.cursor()
stop_words = []
with open('./QA_data/stop_words.txt', encoding='gbk') as f:
    for line in f.readlines():
        stop_words.append(line.strip('\n'))

def match(input_question):
    cnt = {}
    question = list(jieba.cut(input_question, cut_all=False))  # tokenize the query string
    # drop stop words (filtering instead of removing while iterating, which can skip items)
    question = [word for word in question if word not in stop_words]
    for tag in question:  # run one query per tag
        # parameterized LIKE query, so user input cannot inject SQL
        result = cursor.execute("select * from QA where tag like ?", ('%' + tag + '%',))
        for row in result:
            if row[0] not in cnt:
                cnt[row[0]] = 0
            cnt[row[0]] += 1  # count how often each record matches
    try:
        res_id = sorted(cnt.items(), key=lambda d: d[1], reverse=True)[0][0]  # id of the most frequent record
    except IndexError:
        return tuple()  # no match found
    cursor.execute("select * from QA where id = ?", (res_id,))
    res = cursor.fetchone()
    if isinstance(res, tuple):
        return res  # tuple of (id, question, answer, tag)
    else:
        return tuple()  # no match found
--------------------------------------------------------------------------------
/QA_data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/QA_data/__init__.py
--------------------------------------------------------------------------------
/QA_data/stop_words.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/QA_data/stop_words.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# 🍀 小智 (XiaoZhi), Yet Another Chinese Chatbot :yum:

💖 A Chinese chatbot written by @Doragd, trained on the fun Chinese corpus qingyun :snowman:

* She is not all that polished :muscle:, nor all that brilliant :paw_prints:
* But I coded her myself :sparkling_heart:, so ![](https://img.shields.io/badge/-It%20means%20everything-ff69b4.svg)

* **Please star the repo to support** :star: this NLP beginner :runner: and his friend 🍀 小智



## :rainbow: Background

This project is actually a submodule of a software-engineering course project. Our goal is an intelligent customer-service ticketing system.

The ticketing system works as follows: when a user asks a question, the system first searches the knowledge base for a related question and, if one exists, returns its answer; if the user is still not satisfied, they can submit a ticket directly. If the knowledge base has no match, this chatbot is called to generate a reply automatically.

The target scenario resembles Tencent Cloud's customer service, where customers mostly ask about cloud servers, domain names, and the like, so the knowledge base is a collection of (question, answer) pairs about consulting and troubleshooting for those products.

The front-end UI and the front/back-end message exchange were built by another student, @adjlyadv, mainly with React + Django.

@Doragd was responsible for acquiring the knowledge base and for writing, training, and testing the chatbot, which is what this repo covers.



## :star2: Demo

* Answering without the knowledge base


* Answering with the knowledge base



* The whole system in action:




## :floppy_disk: Project Structure

```
│  .gitignore
│  config.py            # model configuration parameters
│  corpus.pth           # preprocessed dataset
│  dataload.py          # dataloader
│  datapreprocess.py    # data preprocessing
│  LICENSE
│  main.py
│  model.py
│  README.md
│  requirements.txt
│  train_eval.py        # training, validation, and testing
│
├─checkpoints
│      chatbot_0509_1437    # pretrained model
│
├─clean_chat_corpus
│      qingyun.tsv          # corpus
│
├─QA_data
│      QA.db            # knowledge base
│      QA_test.py       # called when the knowledge base is used
│      stop_words.txt   # stop words
│      __init__.py
│
└─utils
        beamsearch.py   # TODO: unfinished
        greedysearch.py # greedy search, used for testing
        __init__.py
```



## :couple: Dependencies

![torch](https://img.shields.io/badge/torch-1.0.1-orange.svg)
![torchnet](https://img.shields.io/badge/torchnet-0.0.4-brightgreen.svg)
![fire](https://img.shields.io/badge/fire-0.1.3-red.svg)
![jieba](https://img.shields.io/badge/jieba-0.39-blue.svg)

Install the dependencies:

```shell
$ pip install -r requirements.txt
```



## :sparkling_heart: Getting Started

### Data preprocessing (optional)

```shell
$ python datapreprocess.py
```

Preprocesses the corpus and produces corpus.pth (**corpus.pth is already included, so this step can be skipped**).

Tunable parameters:

```
# datapreprocess.py
corpus_file = 'clean_chat_corpus/qingyun.tsv'  # raw dialogue corpus
max_voc_length = 10000      # maximum vocabulary size
min_word_appear = 10        # minimum word frequency for the vocabulary
max_sentence_length = 50    # maximum sentence length
save_path = 'corpus.pth'    # where to save the processed corpus
```

### Usage

* With the knowledge base

  To use the knowledge base, pass `use_QA_first=True`. The input string is then first matched against the knowledge base for the best question and answer, which are returned; only when no match is found is the chatbot called to generate a reply.

  The knowledge base here consists of 100 FAQ entries scraped from Tencent Cloud's official documentation, for testing only!

```shell
$ python main.py chat --use_QA_first=True
```

* Without the knowledge base

  The Tencent Cloud Q&A pairs were added for the course project and are irrelevant to the chatbot itself, so for ordinary use set `use_QA_first=False`. The parameter defaults to `True`.

```shell
$ python main.py chat --use_QA_first=False
```

* With default parameters

```shell
$ python main.py chat
```

* To quit the chat, type any of `exit`, `quit`, or `q`.

### Other configurable parameters

Documented in `config.py`.

New parameter values can simply be passed on the command line, e.g.

```shell
$ python main.py chat --model_ckpt='checkpoints/chatbot_0509_1437' --use_QA_first=False
```

The command above specifies the path of the trained model to load and whether to use the knowledge base.



## :cherry_blossom: Technical Details

### Corpus

| Corpus             | Size | Source                     | Characteristics              | Sample                                      | Pre-tokenized |
| ------------------ | ---- | -------------------------- | ---------------------------- | ------------------------------------------- | ------------- |
| qingyun (青云语料) | 100k | a chatbot discussion group | fairly good, conversational  | Q:看来你很爱钱 A:噢是吗?那么你也差不多了 | no            |

* Source: <https://github.com/codemayq/chinese_chatbot_corpus>

### Seq2Seq

* Encoder: two-layer bidirectional GRU
* Decoder: two-layer unidirectional GRU

### Attention

* Global (Luong) attention with dot-product scoring
* Ref. https://arxiv.org/abs/1508.04025



## :construction_worker: Training and Evaluation

```shell
$ python train_eval.py train [--options]
```

Quantitative evaluation is not written yet; perplexity would be the right metric. For now the model only generates sentences, and their quality is judged by a human.

```shell
$ python train_eval.py eval [--options]
```



## :sob: Pitfalls and Lessons Learned

* The deepest lesson: between knowing deep-learning theory and truly understanding it lie N rounds of implementation. The theory is clear to everyone, but actual coding raises problem after problem, from processing the dataset, to turning the formulas into code, to tuning parameters and configuring the GPU.
* In practice I first worked through the Chatbot part of the official PyTorch tutorial, got it running, swapped in a new corpus, processed it, refactored the code into a class-based style, and then debugged endlessly. Moving tensors to the GPU caused the most trouble, mainly because it was unclear exactly what to(device) moves.
* Testing showed that model.to(device) only moves the parameters to the GPU; member tensors defined in the class are not moved, so any new tensor created in the forward method must be moved explicitly.
* Ordering matters as well: move the model to the GPU first, then define the optimizer. And remember the assignment when moving: model = model.to(device).
* GPU memory runs out easily, so write code with memory utilization in mind and avoid redundant tensors.
* After first switching to the Chinese corpus, training would not converge; the cause turned out to be a batch_size set too small. When memory allows, batch_size should be as large as possible. I had read about mini-batches before but forgot all of it while coding, which shows that real understanding only comes from writing the code — at least the bug made it sink in.
* I also misread torch.long as a high-precision float when it is actually int64, which caused a bug that took a long time to find. Read the documentation carefully.
* The final takeaway is familiarity with actually implementing a model end to end, which matters a lot.
* Honestly, the model's output quality is not great. Model issues aside, tokenization quality severely affects sentence quality, and I had not even set up stop words for tokenization, so some odd results appear.
* Another issue: when handling variable-length sequences, a hand-rolled loss function easily becomes unstable; I am still studying the official API.
* This exercise also revealed that my understanding of some hyperparameters is too shallow to tune them well; more theory is needed.
* Model evaluation still remains to be done.

## :pray: Acknowledgements

* The official PyTorch Chatbot Tutorial
* <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>
* Provider of the Chinese corpus
* <https://github.com/codemayq/chinese_chatbot_corpus>
* Same content as the official Chatbot Tutorial, but with thorough code comments
* 
* The model's coding style and conventions follow
* 
--------------------------------------------------------------------------------
/checkpoints/chatbot_0509_1437:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/checkpoints/chatbot_0509_1437
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import torch

class Config:
    '''
    Chatbot model parameters
    '''
    corpus_data_path = 'corpus.pth'  # preprocessed dialogue data
    use_QA_first = True              # whether to query the knowledge base first
    max_input_length = 50            # maximum input sentence length
    max_generate_length = 20         # maximum generated sentence length
    prefix = 'checkpoints/chatbot'   # checkpoint path prefix
    model_ckpt = 'checkpoints/chatbot_0509_1437'  # model checkpoint to load
    '''
    Training hyperparameters
    '''
    batch_size = 2048
    shuffle = True        # whether the dataloader shuffles the data
    num_workers = 0       # number of dataloader worker processes
    bidirectional = True  # whether the encoder RNN is bidirectional
    hidden_size = 256
    embedding_dim = 256
    method = 'dot'        # attention scoring method
    dropout = 0           # dropout probability (0 disables dropout)
    clip = 50.0           # gradient clipping threshold
    num_layers = 2        # number of encoder RNN layers
    learning_rate = 1e-3
    teacher_forcing_ratio = 1.0   # teacher forcing ratio
    decoder_learning_ratio = 5.0
    '''
    Training schedule
    '''
    epoch = 6000
    print_every = 1   # print every print_every iterations
    save_every = 50   # save a checkpoint every save_every epochs
    '''
    GPU
    '''
    use_gpu = torch.cuda.is_available()                 # whether to use the GPU
    device = torch.device("cuda" if use_gpu else "cpu") # device


--------------------------------------------------------------------------------
/corpus.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/corpus.pth
--------------------------------------------------------------------------------
/dataload.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import torch
import itertools
from torch.utils import data as dataimport

def zeroPadding(l, fillvalue):
    '''
    l is a list of sentences of different lengths; zip_longest pads them to the
    length of the longest sentence. zeroPadding implicitly transposes:
    [batch_size, max_seq_len] ==> [max_seq_len, batch_size]
    '''
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))

def binaryMatrix(l, value):
    '''
    Builds the mask matrix: 0 marks padding, 1 marks real tokens.
    Same shape as l, i.e. [max_seq_len, batch_size].
    '''
    m = []
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == value:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

def create_collate_fn(padding, eos):
    '''
    Tells the dataloader how to assemble a batch. The arguments are the index of
    the </PAD> token (padding) and the index of the </EOS> token (eos).
    collate_fn receives corpus_item, the list of __getitem__ return values for one batch.

    corpus_item:
        list, e.g. [(inputVar1, targetVar1, index1), (inputVar2, targetVar2, index2), ...]
        inputVar1: [word_ix, word_ix, word_ix, ...]
        targetVar1: [word_ix, word_ix, word_ix, ...]
    inputs:
        list of all inputVars, e.g. [inputVar1, inputVar2, inputVar3, ...];
        after padding (with the implicit transpose) the tensor shape is [max_seq_len, batch_size]
    targets:
        list of all targetVars, e.g. [targetVar1, targetVar2, targetVar3, ...];
        after padding (with the implicit transpose) the tensor shape is [max_seq_len, batch_size]
    input_lengths:
        the original inputVar lengths recorded before padding, needed by pack_padded_sequence;
        e.g. [length_inputVar1, length_inputVar2, length_inputVar3, ...]
    max_target_length:
        the maximum target length in this batch
    mask:
        shape: [max_seq_len, batch_size]
    indexes:
        the position of each sentence pair of the batch within the corpus,
        e.g. [index1, index2, ...]
    '''
    def collate_fn(corpus_item):
        # sort by inputVar length, descending, as pack_padded_sequence requires
        corpus_item.sort(key=lambda p: len(p[0]), reverse=True)
        inputs, targets, indexes = zip(*corpus_item)
        input_lengths = torch.tensor([len(inputVar) for inputVar in inputs])
        inputs = zeroPadding(inputs, padding)
        inputs = torch.LongTensor(inputs)  # note: must be a LongTensor

        max_target_length = max([len(targetVar) for targetVar in targets])
        targets = zeroPadding(targets, padding)
        mask = binaryMatrix(targets, padding)
        mask = torch.ByteTensor(mask)
        targets = torch.LongTensor(targets)

        return inputs, targets, mask, input_lengths, max_target_length, indexes

    return collate_fn


class CorpusDataset(dataimport.Dataset):

    def __init__(self, opt):
        self.opt = opt
        self._data = torch.load(opt.corpus_data_path)
        self.word2ix = self._data['word2ix']
        self.corpus = self._data['corpus']
        self.padding = self.word2ix.get(self._data.get('padding'))
        self.eos = self.word2ix.get(self._data.get('eos'))
        self.sos = self.word2ix.get(self._data.get('sos'))

    def __getitem__(self, index):
        inputVar = self.corpus[index][0]
        targetVar = self.corpus[index][1]
        return inputVar, targetVar, index

    def __len__(self):
        return len(self.corpus)


def get_dataloader(opt):
    dataset = CorpusDataset(opt)
    dataloader = dataimport.DataLoader(dataset,
                                       batch_size=opt.batch_size,
                                       shuffle=opt.shuffle,          # whether to shuffle the data
                                       num_workers=opt.num_workers,  # worker processes for loading
                                       drop_last=True,               # drop the last incomplete batch
                                       collate_fn=create_collate_fn(dataset.padding, dataset.eos))
    return dataloader
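
# --- Illustrative usage sketch (not part of the training pipeline) ---
# Shows how one batch from get_dataloader unpacks; the 6-tuple mirrors
# collate_fn's return value. Assumes corpus.pth exists at Config.corpus_data_path.
if __name__ == "__main__":
    from config import Config
    opt = Config()
    dataloader = get_dataloader(opt)
    inputs, targets, mask, input_lengths, max_target_length, indexes = next(iter(dataloader))
    print(inputs.shape)       # [max_seq_len, batch_size]: zeroPadding implicitly transposes
    print(targets.shape)      # [max_seq_len, batch_size]
    print(mask.shape)         # same as targets; 1 = real token, 0 = padding
    print(input_lengths[:5])  # descending lengths, as pack_padded_sequence requires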
--------------------------------------------------------------------------------
/datapreprocess.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import jieba
import torch
import re
import logging
jieba.setLogLevel(logging.INFO)  # silence jieba's logging output

corpus_file = 'clean_chat_corpus/qingyun.tsv'  # raw dialogue corpus
cop = re.compile("[^\u4e00-\u9fa5a-zA-Z0-9]")  # keep only Chinese, letters, and digits
unknown = '</UNK>'   # unknown-word token
eos = '</EOS>'       # end-of-sentence token
sos = '</SOS>'       # start-of-sentence token
padding = '</PAD>'   # padding token
max_voc_length = 10000     # maximum vocabulary size
min_word_appear = 10       # minimum word frequency for the vocabulary
max_sentence_length = 50   # maximum sentence length
save_path = 'corpus.pth'   # where to save the processed corpus

def preprocess():
    print("preprocessing...")
    '''process the dialogue corpus'''
    data = []
    with open(corpus_file, encoding='utf-8') as f:
        lines = f.readlines()
        for line in lines:
            values = line.strip('\n').split('\t')
            sentences = []
            for value in values:
                sentence = jieba.lcut(cop.sub("", value))
                sentence = sentence[:max_sentence_length] + [eos]
                sentences.append(sentence)
            data.append(sentences)

    '''build the vocabulary and index the sentences'''
    word_nums = {}  # word frequencies
    for sentences in data:
        for sentence in sentences:
            for word in sentence:
                word_nums[word] = word_nums.get(word, 0) + 1
    # sort by frequency, highest first
    word_nums_list = sorted([(num, word) for word, num in word_nums.items()], reverse=True)
    # cap the vocabulary at max_voc_length; require at least min_word_appear occurrences
    words = [word[1] for word in word_nums_list[:max_voc_length] if word[0] >= min_word_appear]
    # note: eos was appended to every sentence, so it is already in words
    words = [unknown, padding, sos] + words
    word2ix = {word: ix for ix, word in enumerate(words)}
    ix2word = {ix: word for word, ix in word2ix.items()}
    ix_corpus = [[[word2ix.get(word, word2ix.get(unknown)) for word in sentence]
                  for sentence in item]
                 for item in data]

    '''
    Save the processed dialogue corpus.

    ix_corpus: list whose elements are [Question_sequence_list, Answer_sequence_list]
        Question_sequence_list: e.g. [word_ix, word_ix, word_ix, ...]

    word2ix: dict, word -> index

    ix2word: dict, index -> word
    '''
    clean_data = {
        'corpus': ix_corpus,
        'word2ix': word2ix,
        'ix2word': ix2word,
        'unknown': unknown,
        'eos': eos,
        'sos': sos,
        'padding': padding,
    }
    torch.save(clean_data, save_path)
    print('save clean data in %s' % save_path)

if __name__ == "__main__":
    preprocess()
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import os
from datapreprocess import preprocess
import train_eval
import fire
from QA_data import QA_test
from config import Config



def chat(**kwargs):

    opt = Config()
    for k, v in kwargs.items():  # apply command-line overrides
        setattr(opt, k, v)

    # the corpus must exist before test() tries to load it
    if not os.path.isfile(opt.corpus_data_path):
        preprocess()

    searcher, sos, eos, unknown, word2ix, ix2word = train_eval.test(opt)

    while True:
        input_sentence = input('Doragd > ')
        if input_sentence == 'q' or input_sentence == 'quit' or input_sentence == 'exit':
            break
        if opt.use_QA_first:
            query_res = QA_test.match(input_sentence)
            if query_res == tuple():
                output_words = train_eval.output_answer(input_sentence, searcher, sos, eos, unknown, opt, word2ix, ix2word)
            else:
                output_words = "您是不是要找以下问题: " + query_res[1] + ',您可以尝试这样: ' + query_res[2]
        else:
            output_words = train_eval.output_answer(input_sentence, searcher, sos, eos, unknown, opt, word2ix, ix2word)
        print('BOT > ', output_words)

    QA_test.conn.close()

if __name__ == "__main__":
    fire.Fire()
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F


class EncoderRNN(nn.Module):
    def __init__(self, opt, voc_length):
        '''
        voc_length: vocabulary size, i.e. the length of a word's one-hot encoding
        '''
        super(EncoderRNN, self).__init__()
        self.num_layers = opt.num_layers
        self.hidden_size = opt.hidden_size
        # nn.Embedding maps from the vocabulary size to the word-vector dimension
        self.embedding = nn.Embedding(voc_length, opt.embedding_dim)
        # a bidirectional GRU as the encoder
        self.gru = nn.GRU(opt.embedding_dim, self.hidden_size, self.num_layers,
                          dropout=(0 if opt.num_layers == 1 else opt.dropout), bidirectional=opt.bidirectional)

    def forward(self, input_seq, input_lengths, hidden=None):
        '''
        input_seq:
            shape: [max_seq_len, batch_size]
        input_lengths:
            list of sentence lengths for the batch
            shape: [batch_size]
        hidden:
            the encoder's initial hidden state, None by default
            shape: [num_layers*num_directions, batch_size, hidden_size]
            (the first dimension is layer-major with the two directions of each
            layer adjacent; h.view(num_layers, num_directions, batch_size, hidden_size)
            separates layers and directions)
        embedded:
            the word vectors after embedding
            shape: [max_seq_len, batch_size, embedding_dim]
        outputs:
            the hidden-state outputs for all timesteps
            initial shape: [max_seq_len, batch_size, hidden_size*num_directions]
            note: the directions are concatenated on the last dimension, i.e.
            forward: [:, :, :hidden_size]   backward: [:, :, hidden_size:]
            the two directions are summed, giving final outputs of shape
            [max_seq_len, batch_size, hidden_size]
        returned hidden:
            [num_layers*num_directions, batch_size, hidden_size]
        '''

        embedded = self.embedding(input_seq)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden



class Attn(torch.nn.Module):
    def __init__(self, attn_method, hidden_size):
        super(Attn, self).__init__()
        self.method = attn_method  # attention scoring method
        self.hidden_size = hidden_size
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, self.hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, self.hidden_size)
            self.v = torch.nn.Parameter(torch.FloatTensor(self.hidden_size))

    def dot_score(self, hidden, encoder_outputs):
        '''
        encoder_outputs:
            the encoder's (bidirectional GRU's) last-layer hidden outputs for all timesteps
            shape: [max_seq_len, batch_size, hidden_size]
            mathematical notation: h_s
        hidden:
            the decoder's (unidirectional GRU's) last-layer hidden outputs, i.e. decoder_outputs
            shape: [max_seq_len, batch_size, hidden_size]
            mathematical notation: h_t
        note: for the 'dot' method this is a Hadamard (elementwise) product, so * is enough
        (torch.matmul would be a matrix product); the result is h_s * h_t.
        Each element of h_s is a hidden_size vector, so the score requires summing over dim=2:
        ignoring batch_size, h_s * h_t must reduce to [max_seq_len],
        one score per timestep; with batch_size included,
        the final shape is [max_seq_len, batch_size]
        '''

        return torch.sum(hidden * encoder_outputs, dim=2)

    def general_score(self, hidden, encoder_outputs):
        # first learn a linear transform Wh_s (the energy),
        # then take the dot product, giving h_t * Wh_s
        energy = self.attn(encoder_outputs)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_outputs):
        '''
        hidden:
            h_t, shape: [max_seq_len, batch_size, hidden_size]
            expand(max_seq_len, -1, -1) ==> [max_seq_len, batch_size, hidden_size],
            concatenate with encoder_outputs along dim 2: [max_seq_len, batch_size, hidden_size*2],
            apply attn to get [max_seq_len, batch_size, hidden_size], then tanh (shape unchanged),
            and finally multiply by v
        '''
        energy = self.attn(torch.cat((hidden.expand(encoder_outputs.size(0), -1, -1),
                                      encoder_outputs), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)
        # scores have shape [max_seq_len, batch_size]; transpose to [batch_size, max_seq_len]
        attn_energies = attn_energies.t()
        # softmax over dim=1, then insert a dimension: [batch_size, 1, max_seq_len]
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, opt, voc_length):
        super(LuongAttnDecoderRNN, self).__init__()

        self.attn_method = opt.method
        self.hidden_size = opt.hidden_size
        self.output_size = voc_length
        self.num_layers = opt.num_layers
        self.dropout = opt.dropout
        self.embedding = nn.Embedding(voc_length, opt.embedding_dim)
        self.embedding_dropout = nn.Dropout(self.dropout)
        self.gru = nn.GRU(opt.embedding_dim, self.hidden_size, self.num_layers,
                          dropout=(0 if self.num_layers == 1 else self.dropout))
        self.concat = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.attn = Attn(self.attn_method, self.hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        '''
        input_step:
            the decoder generates one token per timestep; its inputs are
            input_step (initially the index of the '</SOS>' token) and the
            encoder's final hidden state,
            so the shape is [1, batch_size]
        last_hidden:
            the previous GRU cell's hidden output;
            initialized from the encoder's final hidden state, passed in as
            encoder_hidden[:decoder.num_layers] so that it matches the decoder's layer count
            shape: [num_layers, batch_size, hidden_size]
        encoder_outputs:
            also received here, used to compute the attention weights
        '''
        # to word vectors: [1, batch_size, embedding_dim]
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # rnn_output: [1, batch_size, hidden_size]
        # hidden: [num_layers, batch_size, hidden_size]
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # attn_weights: [batch_size, 1, max_seq_len]
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # bmm is a batched matrix product:
        # attn_weights is batch_size matrices of shape [1, max_seq_len];
        # encoder_outputs.transpose(0, 1) is batch_size matrices of [max_seq_len, hidden_size];
        # the product context is [batch_size, 1, hidden_size]
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        # after squeezing: rnn_output [batch_size, hidden_size], context [batch_size, hidden_size]
        # concatenate along dim=1: [batch_size, hidden_size * 2]; after concat: [batch_size, hidden_size]
        # tanh leaves the shape unchanged: [batch_size, hidden_size]
        # the final linear layer maps to [batch_size, voc_length]
        # and softmax gives probabilities: [batch_size, voc_length]
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)

        return output, hidden
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torchnet==0.0.4
fire==0.1.3
torch==1.0.1
jieba==0.39
--------------------------------------------------------------------------------
/train_eval.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
import re
import time
import random
import jieba
import torch
import logging
import torch.nn as nn
from torchnet import meter
from model import EncoderRNN, LuongAttnDecoderRNN
from utils.greedysearch import GreedySearchDecoder
from dataload import get_dataloader
from config import Config
jieba.setLogLevel(logging.INFO)  # silence jieba's logging output

def maskNLLLoss(inp, target, mask):
    '''
    inp: shape [batch_size, voc_length]
    target: shape [batch_size]; view ==> [batch_size, 1], matching inp's rank so gather can be used.
    target indexes inp along dim=1, giving a tensor shaped like target, [batch_size, 1];
    squeeze to [batch_size], take the negative log,
    select only the positions whose mask is 1, and average to get the loss.
    The loss is therefore the mean over the batch column: the average loss of a
    sentence at one position (t), and nTotal is the number of sentences that
    have a real token at that position.
    mask: shape [batch_size]
    loss: the average per-sentence loss at position t
    '''
    nTotal = mask.sum()  # padding is 0 and real tokens are 1, so the sum is the token count
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    return loss, nTotal.item()

def train_by_batch(sos, opt, data, encoder_optimizer, decoder_optimizer, encoder, decoder):
    # clear gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # unpack one batch
    inputs, targets, mask, input_lengths, max_target_length, indexes = data
    inputs = inputs.to(opt.device)
    targets = targets.to(opt.device)
    mask = mask.to(opt.device)
    input_lengths = input_lengths.to(opt.device)


    # initialize bookkeeping variables
    loss = 0
    print_losses = []
    n_totals = 0

    # forward pass
    '''
    inputs: shape [max_seq_len, batch_size]
    input_lengths: shape [batch_size]
    encoder_outputs: shape [max_seq_len, batch_size, hidden_size]
    encoder_hidden: shape [num_layers*num_directions, batch_size, hidden_size]
    decoder_input: shape [1, batch_size]
    decoder_hidden: the decoder's initial hidden state,
        the first decoder.num_layers slices of encoder_hidden
    '''
    encoder_outputs, encoder_hidden = encoder(inputs, input_lengths)
    decoder_input = torch.LongTensor([[sos for _ in range(opt.batch_size)]])
    decoder_input = decoder_input.to(opt.device)
    decoder_hidden = encoder_hidden[:decoder.num_layers]

    # decide whether to use teacher forcing
    use_teacher_forcing = True if random.random() < opt.teacher_forcing_ratio else False

    '''
    one timestep, i.e. one token, is processed at a time
    decoder_output: [batch_size, voc_length]
    decoder_hidden: [decoder_num_layers, batch_size, hidden_size]
    with teacher forcing, the next timestep's input is the current ground truth:
    targets[t] shape: [batch_size] ==> after view, [1, batch_size], used as decoder_input
    '''
    if use_teacher_forcing:
        for t in range(max_target_length):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_input = targets[t].view(1, -1)


            # accumulate the loss
            '''
            per iteration,
            targets[t]: the values of all samples in the batch at position t, shape [batch_size]
            mask[t]: the mask values of all samples at position t; 1 means not padding
            decoder_output: [batch_size, voc_length]
            '''
            mask_loss, nTotal = maskNLLLoss(decoder_output, targets[t], mask[t])
            mask_loss = mask_loss.to(opt.device)
            loss += mask_loss
            '''
            loss is accumulated along the seq_len dimension, finally giving the sum of a
            sentence's average loss over all positions.
            In short: mask_loss averages over the batch dimension, loss sums over seq_len;
            i.e. the batch loss sums the loss of all sentences at each position and divides by batch_size.
            The loss variable used for backprop is therefore the average loss of one sentence.
            '''
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
    else:
        for t in range(max_target_length):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            # without teacher forcing, the next input is the model's most probable prediction
            _, topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(opt.batch_size)]])
            decoder_input = decoder_input.to(opt.device)
            # accumulate the loss
            mask_loss, nTotal = maskNLLLoss(decoder_output, targets[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal

    # backpropagation
    loss.backward()

    # clip gradients for both encoder and decoder
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), opt.clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), opt.clip)

    # update parameters
    encoder_optimizer.step()
    decoder_optimizer.step()
    # the average loss of one cell of the batch (viewing the batch as a 2-D grid)
    return sum(print_losses) / n_totals

def train(**kwargs):

    opt = Config()
    for k, v in kwargs.items():  # apply command-line overrides
        setattr(opt, k, v)

    # data
    dataloader = get_dataloader(opt)
    _data = dataloader.dataset._data
    word2ix = _data['word2ix']
    sos = word2ix.get(_data.get('sos'))
    voc_length = len(word2ix)

    # define the models
    encoder = EncoderRNN(opt, voc_length)
    decoder = LuongAttnDecoderRNN(opt, voc_length)

    # resume from a checkpoint if one is given
    if opt.model_ckpt:
        checkpoint = torch.load(opt.model_ckpt)
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])


    # switch to training mode
    encoder = encoder.to(opt.device)
    decoder = decoder.to(opt.device)
    encoder.train()
    decoder.train()


    # define the optimizers (note: this must come after encoder.to(device))
    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=opt.learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=opt.learning_rate * opt.decoder_learning_ratio)
    if opt.model_ckpt:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])

    # running loss for printing
    print_loss = 0

    for epoch in range(opt.epoch):
        for ii, data in enumerate(dataloader):
            # train on one batch
            loss = train_by_batch(sos, opt, data, encoder_optimizer, decoder_optimizer, encoder, decoder)
            print_loss += loss
            # print the loss
            if ii % opt.print_every == 0:
                print_loss_avg = print_loss / opt.print_every
                print("Epoch: {}; Percent complete: {:.1f}%; Average loss: {:.4f}"
                      .format(epoch, epoch / opt.epoch * 100, print_loss_avg))
                print_loss = 0

        # save a checkpoint
        if epoch % opt.save_every == 0:
            checkpoint_path = '{prefix}_{time}'.format(prefix=opt.prefix,
                                                       time=time.strftime('%m%d_%H%M'))
            torch.save({
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
            }, checkpoint_path)

def generate(input_seq, searcher, sos, eos, opt):
    # input_seq: an already tokenized and indexed sequence
    # input_batch: shape [1, seq_len] ==> [seq_len, 1] (i.e. batch_size=1)
    input_batch = [input_seq]
    input_lengths = torch.tensor([len(seq) for seq in input_batch])
    input_batch = torch.LongTensor([input_seq]).transpose(0, 1)
    input_batch = input_batch.to(opt.device)
    input_lengths = input_lengths.to(opt.device)
    tokens, scores = searcher(sos, eos, input_batch, input_lengths, opt.max_generate_length, opt.device)
    return tokens

def eval(**kwargs):

    opt = Config()
    for k, v in kwargs.items():  # apply command-line overrides
        setattr(opt, k, v)


    # data
    dataloader = get_dataloader(opt)
    _data = dataloader.dataset._data
    word2ix, ix2word = _data['word2ix'], _data['ix2word']
    sos = word2ix.get(_data.get('sos'))
    eos = word2ix.get(_data.get('eos'))
    unknown = word2ix.get(_data.get('unknown'))
    voc_length = len(word2ix)

    # define the models
    encoder = EncoderRNN(opt, voc_length)
    decoder = LuongAttnDecoderRNN(opt, voc_length)

    # load the trained model
    if opt.model_ckpt is None:
        raise ValueError('model_ckpt is None.')
    checkpoint = torch.load(opt.model_ckpt, map_location=lambda s, l: s)
    encoder.load_state_dict(checkpoint['en'])
    decoder.load_state_dict(checkpoint['de'])

    with torch.no_grad():
        # switch to evaluation mode
        encoder = encoder.to(opt.device)
        decoder = decoder.to(opt.device)
        encoder.eval()
        decoder.eval()
        # define the searcher
        searcher = GreedySearchDecoder(encoder, decoder)

        while True:
            input_sentence = input('> ')
            if input_sentence == 'q' or input_sentence == 'quit':
                break
            cop = re.compile("[^\u4e00-\u9fa5a-zA-Z0-9]")  # keep only Chinese, letters, and digits
            input_seq = jieba.lcut(cop.sub("", input_sentence))  # tokenized sequence
            input_seq = input_seq[:opt.max_input_length] + ['</EOS>']
            input_seq = [word2ix.get(word, unknown) for word in input_seq]
            tokens = generate(input_seq, searcher, sos, eos, opt)
            output_words = ''.join([ix2word[token.item()] for token in tokens])
            print('BOT: ', output_words)

def test(opt):

    # data
    dataloader = get_dataloader(opt)
    _data = dataloader.dataset._data
    word2ix, ix2word = _data['word2ix'], _data['ix2word']
    sos = word2ix.get(_data.get('sos'))
    eos = word2ix.get(_data.get('eos'))
    unknown = word2ix.get(_data.get('unknown'))
    voc_length = len(word2ix)

    # define the models
    encoder = EncoderRNN(opt, voc_length)
    decoder = LuongAttnDecoderRNN(opt, voc_length)

    # load the trained model
    if opt.model_ckpt is None:
        raise ValueError('model_ckpt is None.')
    checkpoint = torch.load(opt.model_ckpt, map_location=lambda s, l: s)
    encoder.load_state_dict(checkpoint['en'])
    decoder.load_state_dict(checkpoint['de'])

    with torch.no_grad():
        # switch to evaluation mode
        encoder = encoder.to(opt.device)
        decoder = decoder.to(opt.device)
        encoder.eval()
        decoder.eval()
        # define the searcher
        searcher = GreedySearchDecoder(encoder, decoder)
        return searcher, sos, eos, unknown, word2ix, ix2word

def output_answer(input_sentence, searcher, sos, eos, unknown, opt, word2ix, ix2word):
    cop = re.compile("[^\u4e00-\u9fa5a-zA-Z0-9]")  # keep only Chinese, letters, and digits
    input_seq = jieba.lcut(cop.sub("", input_sentence))  # tokenized sequence
    input_seq = input_seq[:opt.max_input_length] + ['</EOS>']
    input_seq = [word2ix.get(word, unknown) for word in input_seq]
    tokens = generate(input_seq, searcher, sos, eos, opt)
    output_words = ''.join([ix2word[token.item()] for token in tokens if token.item() != eos])
    return output_words


if __name__ == "__main__":
    import fire
    fire.Fire()
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/utils/__init__.py
--------------------------------------------------------------------------------
/utils/beamsearch.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/utils/beamsearch.py
--------------------------------------------------------------------------------
/utils/greedysearch.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn



class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, sos, eos, input_seq, input_length, max_length, device):

        # run the encoder forward
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        # use the encoder's final hidden state as the decoder's initial hidden state
        decoder_hidden = encoder_hidden[:self.decoder.num_layers]
        # our functions all expect (time, batch) tensors, so even a single example must be 2-D
        # the decoder's first input is SOS
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * sos
        # tensors that collect the decoded tokens and scores
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        # loop, bounded only by max_length; EOS is stripped out in later processing
        for _ in range(max_length):
            # one decoder forward step
            decoder_output, decoder_hidden = self.decoder(decoder_input, decoder_hidden,
                                                          encoder_outputs)
            # decoder_output is (batch=1, voc_size);
            # max returns the most probable word and its score
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            # append the result to all_tokens and all_scores
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # decoder_input is the ID of the word emitted at this timestep; it is 1-D
            # because max drops a dimension, but the decoder expects a batch dimension,
            # so unsqueeze adds one back
            if decoder_input.item() == eos:
                break
            decoder_input = torch.unsqueeze(decoder_input, 0)

        # return all tokens and scores
        return all_tokens, all_scores
--------------------------------------------------------------------------------
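
For reference, a minimal driver sketch for `GreedySearchDecoder`, mirroring what `train_eval.generate` and `train_eval.output_answer` do. `greedy_reply` is a hypothetical helper name; `searcher`, `sos`, `eos`, `opt`, and `ix2word` are assumed to come from `train_eval.test(opt)`:

```python
import torch

def greedy_reply(input_seq, searcher, sos, eos, opt, ix2word):
    # input_seq: a tokenized sentence already mapped to vocabulary indices
    input_lengths = torch.tensor([len(input_seq)]).to(opt.device)
    # [1, seq_len] -> [seq_len, 1], i.e. a batch of size 1
    input_batch = torch.LongTensor([input_seq]).transpose(0, 1).to(opt.device)
    tokens, scores = searcher(sos, eos, input_batch, input_lengths,
                              opt.max_generate_length, opt.device)
    # drop the trailing EOS token and join the words into a reply string
    return ''.join(ix2word[t.item()] for t in tokens if t.item() != eos)
```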