├── .gitignore
├── LICENSE
├── QA_data
├── QA.db
├── QA_test.py
├── __init__.py
└── stop_words.txt
├── README.md
├── checkpoints
└── chatbot_0509_1437
├── clean_chat_corpus
└── qingyun.tsv
├── config.py
├── corpus.pth
├── dataload.py
├── datapreprocess.py
├── main.py
├── model.py
├── requirements.txt
├── train_eval.py
└── utils
├── __init__.py
├── beamsearch.py
└── greedysearch.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/QA_data/QA.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/QA_data/QA.db
--------------------------------------------------------------------------------
/QA_data/QA_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: UTF-8 -*-
2 |
3 | import sqlite3
4 | import jieba
5 | import logging
6 | jieba.setLogLevel(logging.INFO) #设置不输出信息
7 |
# Module-level handles to the knowledge base; paths are relative to the
# directory the program is started from.
conn = sqlite3.connect('./QA_data/QA.db')
cursor = conn.cursor()

# One stop word per line in a GBK-encoded file; only the trailing newline is
# stripped, any other whitespace is kept as part of the word.
with open('./QA_data/stop_words.txt', encoding='gbk') as f:
    stop_words = [line.strip('\n') for line in f]
15 |
def match(input_question):
    """Find the knowledge-base row that best matches a question.

    The question is segmented with jieba, stop words are dropped, and each
    remaining word is looked up in the ``tag`` column of the ``QA`` table
    with a ``LIKE '%word%'`` query. The row hit by the largest number of
    words wins; on a tie the row that was hit first is kept.

    Args:
        input_question: raw question text.

    Returns:
        The winning row as a tuple (id, question, answer, tag), or an empty
        tuple when no row matches.
    """
    # Segment the query and filter stop words without mutating the list
    # while iterating over it.
    keywords = [word
                for word in jieba.cut(input_question, cut_all=False)
                if word not in stop_words]

    # Count, per row id, how many keywords hit that row. Bound parameters
    # keep quotes in user input from breaking (or injecting into) the SQL.
    hits = {}
    for tag in keywords:
        rows = cursor.execute("select * from QA where tag like ?",
                              ('%' + tag + '%',))
        for row in rows:
            hits[row[0]] = hits.get(row[0], 0) + 1

    if not hits:
        return tuple()  # nothing matched

    # max() returns the first id with the highest count, matching the
    # previous stable descending sort.
    res_id = max(hits, key=hits.get)
    cursor.execute("select * from QA where id = ?", (res_id,))
    res = cursor.fetchone()
    return res if isinstance(res, tuple) else tuple()
40 |
41 |
--------------------------------------------------------------------------------
/QA_data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/QA_data/__init__.py
--------------------------------------------------------------------------------
/QA_data/stop_words.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/QA_data/stop_words.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🍀 小智,又一个中文聊天机器人:yum:
2 |
3 | 💖 利用有趣的中文语料库qingyun,由@Doragd 同学编写的中文聊天机器人:snowman:
4 |
5 | * 尽管她不是那么完善:muscle:,不是那么出色:paw_prints:
6 | * 但她是由我自己coding出来的:sparkling_heart: ,所以
7 |
8 | * **希望大家能够多多star支持**:star: 这个NLP初学者:runner:和他的朋友🍀 小智
9 |
10 |
11 |
12 | ## :rainbow:背景
13 |
14 | 这个项目实际是软件工程课程设计的子模块。我们的目标是开发一个智能客服工单处理系统。
15 |
16 | 智能客服工单系统实际的工作流程是:当人向系统发出提问时,系统首先去知识库中查找是否存在相关问题,如果有,则返回问题的答案,此时如果用户不满意,则可以直接提交工单。如果知识库中不存在,则调用这个聊天机器人进行自动回复。
17 |
18 | 该系统服务的场景类似腾讯云的客服系统,客户多是来咨询相关问题的(云服务器,域名等),所以知识库也是有关云服务器,域名等的咨询,故障处理的 (问题,答案) 集合。
19 |
20 | 系统的前端界面和前后端消息交互由另一个同学@adjlyadv 完成,主要采用React+Django方式。
21 |
22 | @Doragd 同学负责的是知识库的获取和聊天机器人的编写,训练,测试。这个repo的内容也是关于这个的。
23 |
24 |
25 |
26 | ## :star2: 测试效果
27 |
28 | * 不使用知识库进行回答
29 |
30 |
31 | * 使用知识库进行回答
32 |
33 |
34 |
35 | * 整个系统效果:
36 |
37 |
38 |
39 |
40 | ## :floppy_disk:项目结构
41 |
42 | ```
43 | │ .gitignore
44 | │ config.py #模型配置参数
45 | │ corpus.pth #已经过处理的数据集
46 | │ dataload.py #dataloader
47 | │ datapreprocess.py #数据预处理
48 | │ LICENSE
49 | │ main.py
50 | │ model.py
51 | │ README.md
52 | │ requirements.txt
53 | │ train_eval.py #训练和验证,测试
54 | │
55 | ├─checkpoints
56 | │ chatbot_0509_1437 #已经训练好的模型
57 | │
58 | ├─clean_chat_corpus
59 | │ qingyun.tsv #语料库
60 | │
61 | ├─QA_data
62 | │ QA.db #知识库
63 | │ QA_test.py #使用知识库时调用
64 | │ stop_words.txt #停用词
65 | │ __init__.py
66 | │
67 | └─utils
68 | beamsearch.py #to do 未完工
69 | greedysearch.py #贪婪搜索,用于测试
70 | __init__.py
71 | ```
72 |
73 |
74 |
75 | ## :couple:依赖库
76 |
77 | 
78 | 
79 | 
80 | 
81 |
82 | 安装依赖
83 |
84 | ```shell
85 | $ pip install -r requirements.txt
86 | ```
87 |
88 |
89 |
90 | ## :sparkling_heart:开始使用
91 |
92 | ### 数据预处理(可省略)
93 |
94 | ```shell
95 | $ python datapreprocess.py
96 | ```
97 |
98 | 对语料库进行预处理,产生corpus.pth (**这里已经上传好corpus.pth, 故此步可以省略**)
99 |
100 | 可修改参数:
101 |
102 | ```
103 | # datapreprocess.py
104 | corpus_file = 'clean_chat_corpus/qingyun.tsv' #未处理的对话数据集
105 | max_voc_length = 10000 #字典最大长度
106 | min_word_appear = 10 #加入字典的词的词频最小值
107 | max_sentence_length = 50 #最大句子长度
108 | save_path = 'corpus.pth' #已处理的对话数据集保存路径
109 | ```
110 |
111 | ### 使用
112 |
113 | * 使用知识库
114 |
115 | 使用知识库时, 需要传入参数`use_QA_first=True`。此时,对于输入的字符串,首先在知识库中匹配最佳的问题和答案,并返回。找不到时,才调用聊天机器人自动生成回复。
116 |
117 | 这里的知识库是爬取整理的腾讯云官方文档中的常见问题和答案,100条,仅用于测试!
118 |
119 | ```shell
120 | $ python main.py chat --use_QA_first=True
121 | ```
122 |
123 | * 不使用知识库
124 |
125 | 由于课程设计需要,加入了腾讯云的问题答案对,但对于聊天机器人这个项目来说是无关紧要的,所以一般使用时,`use_QA_first=False` ,该参数默认为`True`
126 |
127 | ```shell
128 | $ python main.py chat --use_QA_first=False
129 | ```
130 |
131 | * 使用默认参数
132 |
133 | ```shell
134 | $ python main.py chat
135 | ```
136 |
137 | * 退出聊天:输入`exit`, `quit`, `q` 均可
138 |
139 | ### 其他可配置参数
140 |
141 | 在`config.py` 文件中说明
142 |
143 | 需要传入新的参数时,只需要命令行传入即可,形如
144 |
145 | ```shell
146 | $ python main.py chat --model_ckpt='checkpoints/chatbot_0509_1437' --use_QA_first=False
147 | ```
148 |
149 | 上面的命令指出了加载已训练模型的路径和是否使用知识库
150 |
151 |
152 |
153 | ## :cherry_blossom:技术实现
154 |
155 | ### 语料库
156 |
157 | | 语料名称 | 语料数量 | 语料来源说明 | 语料特点 | 语料样例 | 是否已分词 |
158 | | ------------------- | -------- | ------------------ | ---------------- | ----------------------------------------- | ---------- |
159 | | qingyun(青云语料) | 10W | 某聊天机器人交流群 | 相对不错,生活化 | Q:看来你很爱钱 A:噢是吗?那么你也差不多了 | 否 |
160 |
161 | * 来源:
162 |
163 | ### Seq2Seq
164 |
165 | * Encoder:两层双向GRU
166 | * Decoder:双层单向GRU
167 |
168 | ### Attention
169 |
170 | * Global attention,采用dot计算分数
171 | * Ref. https://arxiv.org/abs/1508.04025
172 |
173 |
174 |
175 | ## :construction_worker:模型训练与评估
176 |
177 | ```shell
178 | $ python train_eval.py train [--options]
179 | ```
180 |
181 | 定量评估部分暂时还没写好,应该采用困惑度来衡量,目前只能生成句子,人为评估质量
182 |
183 | ```shell
184 | $ python train_eval.py eval [--options]
185 | ```
186 |
187 |
188 |
189 | ## :sob:跳坑记录与总结
190 |
191 | * 最深刻的体会就是“深度学习知识的了解和理解之间差了N个编程实现”。虽然理论大家都很清楚,但是真正到编程实现时,总会出这样,那样的问题:从数据集的处理,到许多公式的编程实现,到参数的调节,GPU配置等等各种问题
192 | * 这次实践的过程实际是跟着PyTorch Tutorial先过了一遍Chatbot部分,跑通以后,再更换语料库,处理语料库,再按照类的风格去重构了代码,然后就是无尽的Debug过程,遇到了很多坑,尤其是把张量移到GPU上遇到各种问题,主要是不清楚to(device)时究竟移动了哪些。
193 | * 通过测试发现,model.to(device)只会把参数移到GPU,不会把类中定义的成员tensor移过去,所以如果在forward方法中定义了新的张量,要记得移动。
194 | * 还有就是移动的顺序问题:先把模型移动到GPU,再去定义优化器。以及移动的方法:model=model.to(device),不要忘记赋值。
195 | * 很容易出现GPU显存不足的情况,注意写代码时要考虑内存利用率问题,尽量减少重复tensor。
196 | * 在一开始更换中文语料库后,训练总是不收敛,最后才发现原来是batch_size设置小了,实际上我感觉batch_size在显存足够时要尽量大,其实之前看到过,只是写代码的时候完全忘记这回事了。说明自己当时看mini-batch时还不够理解,还是要真的写代码才能够深入人心,至少bug深入人心
197 | * 还有一个问题就是误解了torch.long,以为是高精度浮点,结果是int64型,造成了一个bug,找了好久才发现怎么回事。这告诉我们要认真看文档。
198 | * 最后的收获就是熟悉了如何实际实现一个模型,这很重要。
199 | * 实际上这个模型的效果不是很好,除开模型本身的问题不谈,我发现分词的质量会严重影响句子的质量,但是分词时我连停用词还没设置,会出现一些奇特的结果
200 | * 还有一个问题是处理变长序列时,损失函数如果用自己定义的,很容易出现不稳定情况,现在还在研究官方API
201 | * 本次实践还发现自己对一些参数理解还不够深,不知道怎么调,还要补理论。
202 | * 对模型的评估这部分还要继续做。
203 |
204 | ## :pray:致谢
205 |
206 | * 官方的Chatbot Tutorial
207 | *
208 | * 提供中文语料库
209 | *
210 | * 与官方的Chatbot Tutorial内容一致,但是有详尽的代码注释
211 | *
212 | * 模型的写法和习惯均参考
213 | *
214 |
--------------------------------------------------------------------------------
/checkpoints/chatbot_0509_1437:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/checkpoints/chatbot_0509_1437
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 |
class Config:
    """Chatbot settings: model/data paths, training hyper-parameters,
    training schedule and device selection."""

    # ---- model / data ----
    corpus_data_path = 'corpus.pth'  # preprocessed dialogue data
    use_QA_first = True  # consult the knowledge base before the model
    max_input_length = 50  # maximum length of an input sentence
    max_generate_length = 20  # maximum length of a generated sentence
    prefix = 'checkpoints/chatbot'  # path prefix for saved checkpoints
    model_ckpt = 'checkpoints/chatbot_0509_1437'  # checkpoint to load

    # ---- training hyper-parameters ----
    batch_size = 2048
    shuffle = True  # whether the dataloader shuffles the data
    num_workers = 0  # dataloader worker processes
    bidirectional = True  # whether the encoder RNN is bidirectional
    hidden_size = 256
    embedding_dim = 256
    method = 'dot'  # attention scoring method
    dropout = 0  # dropout probability (0 disables dropout)
    clip = 50.0  # gradient clipping threshold
    num_layers = 2  # number of encoder RNN layers
    learning_rate = 1e-3
    teacher_forcing_ratio = 1.0  # teacher forcing ratio
    decoder_learning_ratio = 5.0

    # ---- training schedule ----
    epoch = 6000
    print_every = 1  # print once every print_every iterations
    save_every = 50  # save once every save_every epochs

    # ---- device ----
    use_gpu = torch.cuda.is_available()  # use the GPU when one is available
    device = torch.device("cuda") if use_gpu else torch.device("cpu")
42 |
43 |
44 |
--------------------------------------------------------------------------------
/corpus.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/corpus.pth
--------------------------------------------------------------------------------
/dataload.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import itertools
5 | from torch.utils import data as dataimport
6 |
def zeroPadding(l, fillvalue):
    """Pad sentences to a common length and transpose them.

    ``l`` holds sentences (sequences of tokens) of different lengths.
    zip_longest pads the shorter ones with ``fillvalue`` up to the longest
    sentence and, as a side effect of zipping, transposes the batch:
    [batch_size, max_seq_len] ==> [max_seq_len, batch_size].

    Returns a list with one tuple per time step.
    """
    time_steps = itertools.zip_longest(*l, fillvalue=fillvalue)
    return [step for step in time_steps]
14 |
def binaryMatrix(l, value):
    """Build a mask with the same nested shape as ``l``.

    Every token equal to ``value`` (the padding token) maps to 0 and every
    other token maps to 1.
    """
    return [[0 if token == value else 1 for token in seq] for seq in l]
29 |
def create_collate_fn(padding, eos):
    '''
    Build the collate_fn that tells the DataLoader how to assemble a batch.

    padding: index used to fill short sentences.
    eos: end-of-sentence index (accepted but not used by the collate_fn).

    The returned collate_fn receives corpus_item, a list of the values
    returned by __getitem__ for one batch:
        [(inputVar1, targetVar1, index1), (inputVar2, targetVar2, index2), ...]
    where inputVar / targetVar are lists of word indexes.

    It returns, in this order:
    inputs:
        LongTensor built from all inputVar after padding (which also
        transposes), shape [max_seq_len, batch_size]
    targets:
        LongTensor built from all targetVar after padding (which also
        transposes), shape [max_seq_len, batch_size]
    mask:
        ByteTensor with the shape of targets, 0 at padding and 1 elsewhere
    input_lengths:
        tensor with the length of every inputVar before padding
    max_target_length:
        length of the longest targetVar in the batch
    indexes:
        tuple with the corpus position of every sentence pair in the batch
    '''
    def collate_fn(corpus_item):
        # Longest input first (sorted in place); packing padded sequences
        # expects the batch in descending length order.
        corpus_item.sort(key=lambda sample: len(sample[0]), reverse=True)
        inputs, targets, indexes = zip(*corpus_item)

        # Lengths must be recorded before the sentences are padded.
        input_lengths = torch.tensor([len(sentence) for sentence in inputs])
        inputs = torch.LongTensor(zeroPadding(inputs, padding))

        max_target_length = max([len(sentence) for sentence in targets])
        padded_targets = zeroPadding(targets, padding)
        mask = torch.ByteTensor(binaryMatrix(padded_targets, padding))
        targets = torch.LongTensor(padded_targets)

        return inputs, targets, mask, input_lengths, max_target_length, indexes

    return collate_fn
75 |
76 |
77 |
78 |
class CorpusDataset(dataimport.Dataset):
    """Dataset over the preprocessed corpus stored at opt.corpus_data_path.

    The loaded object is expected to provide the keys 'word2ix', 'corpus',
    'padding', 'eos' and 'sos'; the three special tokens are resolved to
    their indexes through word2ix (None when a token is missing).
    """

    def __init__(self, opt):
        self.opt = opt
        self._data = torch.load(opt.corpus_data_path)
        self.word2ix = self._data['word2ix']
        self.corpus = self._data['corpus']
        # Resolve the special tokens to vocabulary indexes.
        for token_name in ('padding', 'eos', 'sos'):
            token = self._data.get(token_name)
            setattr(self, token_name, self.word2ix.get(token))

    def __getitem__(self, index):
        """Return (input indexes, target indexes, position in the corpus)."""
        pair = self.corpus[index]
        return pair[0], pair[1], index

    def __len__(self):
        """Number of sentence pairs in the corpus."""
        return len(self.corpus)
97 |
98 |
def get_dataloader(opt):
    """Create the DataLoader over the corpus described by ``opt``.

    Reads batch_size, shuffle and num_workers from ``opt``; the last
    incomplete batch is always dropped.
    """
    corpus = CorpusDataset(opt)
    batch_builder = create_collate_fn(corpus.padding, corpus.eos)
    return dataimport.DataLoader(
        corpus,
        batch_size=opt.batch_size,
        shuffle=opt.shuffle,  # whether to shuffle the data
        num_workers=opt.num_workers,  # worker processes used to fetch data
        drop_last=True,  # drop the final batch if it is not full
        collate_fn=batch_builder,
    )
--------------------------------------------------------------------------------
/datapreprocess.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import jieba
4 | import torch
5 | import re
6 | import logging
7 | jieba.setLogLevel(logging.INFO) #关闭jieba输出信息
8 |
corpus_file = 'clean_chat_corpus/qingyun.tsv'  # raw dialogue corpus (tab-separated)
# Removes every character that is not a CJK ideograph, an ASCII letter or a
# digit. '^' only negates at the start of a character class; repeating it
# inside the class would make the caret itself a kept character.
cop = re.compile("[^\u4e00-\u9fa5a-zA-Z0-9]")
# The special tokens must be distinct from each other and from anything the
# cleaning regex can leave in a sentence, otherwise they collapse into a
# single vocabulary entry.
unknown = '</UNK>'  # out-of-vocabulary token
eos = '</EOS>'  # end-of-sentence token
sos = '</SOS>'  # start-of-sentence token
padding = '</PAD>'  # padding token
max_voc_length = 10000  # maximum vocabulary size (special tokens excluded)
min_word_appear = 10  # minimum frequency for a word to enter the vocabulary
max_sentence_length = 50  # maximum sentence length in tokens (eos excluded)
save_path = 'corpus.pth'  # where the processed corpus is written


def preprocess():
    """Turn the raw corpus into index sequences and save them.

    Reads ``corpus_file`` (one tab-separated dialogue per line), cleans and
    segments every sentence, builds the vocabulary and writes a dict to
    ``save_path`` with the keys:

        corpus:  list of [question_indexes, answer_indexes, ...] items
        word2ix: dict, word -> index
        ix2word: dict, index -> word
        unknown / eos / sos / padding: the special token strings
    """
    print("preprocessing...")

    # Clean, segment and truncate every sentence, then terminate it with eos.
    data = []
    with open(corpus_file, encoding='utf-8') as f:
        for line in f:
            values = line.strip('\n').split('\t')
            if len(values) < 2:
                # Blank or malformed line: there is no (question, answer)
                # pair to learn from.
                continue
            sentences = []
            for value in values:
                sentence = jieba.lcut(cop.sub("", value))
                sentences.append(sentence[:max_sentence_length] + [eos])
            data.append(sentences)

    # Word frequencies over the whole corpus.
    word_nums = {}
    for sentences in data:
        for sentence in sentences:
            for word in sentence:
                word_nums[word] = word_nums.get(word, 0) + 1

    # Most frequent first; ties are broken by the word itself (descending).
    word_nums_list = sorted(((num, word) for word, num in word_nums.items()),
                            reverse=True)
    words = [word for num, word in word_nums_list[:max_voc_length]
             if num >= min_word_appear]
    # eos ends every sentence, so it normally ranks first; make sure it is
    # present even for a corpus too small to reach min_word_appear.
    if eos not in words:
        words.append(eos)
    words = [unknown, padding, sos] + words

    word2ix = {word: ix for ix, word in enumerate(words)}
    ix2word = {ix: word for word, ix in word2ix.items()}
    unknown_ix = word2ix[unknown]
    ix_corpus = [[[word2ix.get(word, unknown_ix) for word in sentence]
                  for sentence in item]
                 for item in data]

    # Store the token strings from the module constants so they always agree
    # with the keys of word2ix.
    clean_data = {
        'corpus': ix_corpus,
        'word2ix': word2ix,
        'ix2word': ix2word,
        'unknown': unknown,
        'eos': eos,
        'sos': sos,
        'padding': padding,
    }
    torch.save(clean_data, save_path)
    print('save clean data in %s' % save_path)


if __name__ == "__main__":
    preprocess()
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | from datapreprocess import preprocess
4 | import train_eval
5 | import fire
6 | from QA_data import QA_test
7 | from config import Config
8 |
9 |
10 |
def chat(**kwargs):
    """Run the interactive chat loop.

    Keyword arguments override the attributes of ``Config`` with the same
    name (e.g. ``use_QA_first``, ``model_ckpt``). Typing ``q``, ``quit`` or
    ``exit``, or closing stdin / pressing Ctrl-C, ends the session. The
    knowledge-base connection is closed on the way out.
    """
    opt = Config()
    for k, v in kwargs.items():  # apply command-line overrides
        setattr(opt, k, v)

    try:
        # Create the processed corpus before the model setup asks for it;
        # checking afterwards is too late to help.
        # NOTE(review): preprocess() writes to datapreprocess.save_path, which
        # only matches opt.corpus_data_path when the default is used.
        if not os.path.isfile(opt.corpus_data_path):
            preprocess()

        searcher, sos, eos, unknown, word2ix, ix2word = train_eval.test(opt)

        while True:
            try:
                input_sentence = input('Doragd > ')
            except (EOFError, KeyboardInterrupt):
                break
            if input_sentence in ('q', 'quit', 'exit'):
                break

            output_words = None
            if opt.use_QA_first:
                # Knowledge base first; an empty tuple means "no match".
                query_res = QA_test.match(input_sentence)
                if query_res:
                    output_words = "您是不是要找以下问题: " + query_res[1] + ',您可以尝试这样: ' + query_res[2]
            if output_words is None:
                # Fall back to the generative model.
                output_words = train_eval.output_answer(
                    input_sentence, searcher, sos, eos, unknown, opt, word2ix, ix2word)
            print('BOT > ', output_words)
    finally:
        QA_test.conn.close()


if __name__ == "__main__":
    fire.Fire()
39 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | import logging
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from utils.greedysearch import GreedySearchDecoder
7 |
8 |
class EncoderRNN(nn.Module):
    """GRU encoder: embedding layer followed by a (bi)directional GRU."""

    def __init__(self, opt, voc_length):
        '''
        opt: configuration object; reads num_layers, hidden_size,
            embedding_dim, dropout and bidirectional.
        voc_length: vocabulary size, i.e. the number of embedding rows.
        '''
        super(EncoderRNN, self).__init__()
        self.num_layers = opt.num_layers
        self.hidden_size = opt.hidden_size
        # Maps word indexes (< voc_length) to embedding_dim-sized vectors.
        self.embedding = nn.Embedding(voc_length, opt.embedding_dim)
        # GRU dropout acts between layers, so it is disabled for one layer.
        self.gru = nn.GRU(opt.embedding_dim, self.hidden_size, self.num_layers,
                          dropout=(0 if opt.num_layers == 1 else opt.dropout),
                          bidirectional=opt.bidirectional)

    def forward(self, input_seq, input_lengths, hidden=None):
        '''
        input_seq:
            padded word indexes, shape [max_seq_len, batch_size], sorted by
            decreasing length (required by pack_padded_sequence)
        input_lengths:
            true length of every sentence in the batch, shape [batch_size]
        hidden:
            optional initial hidden state,
            shape [num_layers*num_directions, batch_size, hidden_size]

        Returns (outputs, hidden):
        outputs:
            top-layer output at every time step. The GRU yields
            [max_seq_len, batch_size, hidden_size*num_directions] with the
            forward direction in [:, :, :hidden_size] and the backward
            direction in [:, :, hidden_size:]; for a bidirectional GRU the
            two halves are summed, so the result is always
            [max_seq_len, batch_size, hidden_size]
        hidden:
            final hidden state,
            shape [num_layers*num_directions, batch_size, hidden_size].
            The first dimension is layer-major: it can be viewed as
            (num_layers, num_directions), i.e. layer0-forward,
            layer0-backward, layer1-forward, ...
        '''
        embedded = self.embedding(input_seq)  # [max_seq_len, batch_size, embedding_dim]
        # Recent PyTorch versions require the lengths to live on the CPU even
        # when the data is on the GPU.
        if torch.is_tensor(input_lengths):
            input_lengths = input_lengths.cpu()
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Only a bidirectional GRU has two halves to merge; slicing a
        # unidirectional output would produce an empty second operand.
        if self.gru.bidirectional:
            outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden
56 |
57 |
58 |
class Attn(torch.nn.Module):
    '''
    Luong-style attention scorer supporting the 'dot', 'general' and
    'concat' scoring functions.
    '''

    def __init__(self, attn_method, hidden_size):
        '''
        attn_method: one of 'dot', 'general', 'concat'
        hidden_size: feature width of both the decoder state and the
            encoder outputs

        Raises ValueError for any other attn_method.
        '''
        super(Attn, self).__init__()
        self.method = attn_method
        self.hidden_size = hidden_size
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(
                "{} is not an appropriate attention method.".format(self.method))
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, self.hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, self.hidden_size)
            # torch.FloatTensor(n) returns uninitialized memory (possibly
            # NaN/inf), so give v an explicit small random start instead.
            bound = 1.0 / (self.hidden_size ** 0.5)
            self.v = torch.nn.Parameter(
                torch.empty(self.hidden_size).uniform_(-bound, bound))

    def dot_score(self, hidden, encoder_outputs):
        '''
        hidden:
            decoder GRU output for the current step (h_t),
            shape [1, batch_size, hidden_size] as passed by the decoder
        encoder_outputs:
            encoder outputs for every source position (h_s),
            shape [max_seq_len, batch_size, hidden_size]

        The element-wise product broadcasts h_t over the source positions;
        summing over dim=2 gives one score per position and sample,
        shape [max_seq_len, batch_size].
        '''
        return torch.sum(hidden * encoder_outputs, dim=2)

    def general_score(self, hidden, encoder_outputs):
        '''Score h_t . (W h_s): project the encoder outputs, then dot with h_t.'''
        energy = self.attn(encoder_outputs)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_outputs):
        '''
        Score v . tanh(W [h_t ; h_s]).

        hidden is expanded along dim 0 to max_seq_len, concatenated with
        encoder_outputs on dim 2 ([max_seq_len, batch_size, hidden_size*2]),
        projected back to hidden_size, squashed with tanh and finally
        reduced against v, giving [max_seq_len, batch_size].
        '''
        energy = self.attn(torch.cat((hidden.expand(encoder_outputs.size(0), -1, -1),
                                      encoder_outputs), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        '''
        Return attention weights of shape [batch_size, 1, max_seq_len];
        each row sums to 1 over the source positions.
        '''
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)
        # [max_seq_len, batch_size] -> [batch_size, max_seq_len]
        attn_energies = attn_energies.t()
        # Normalise over source positions, then add the middle dim for bmm
        return F.softmax(attn_energies, dim=1).unsqueeze(1)
122 |
class LuongAttnDecoderRNN(nn.Module):
    '''
    Single-step GRU decoder with Luong attention over the encoder outputs.
    '''

    def __init__(self, opt, voc_length):
        '''
        opt: config object; reads method, hidden_size, num_layers, dropout
            and embedding_dim
        voc_length: vocabulary size; also the width of the output distribution
        '''
        super(LuongAttnDecoderRNN, self).__init__()

        self.attn_method = opt.method
        self.hidden_size = opt.hidden_size
        self.output_size = voc_length
        self.num_layers = opt.num_layers
        self.dropout = opt.dropout

        # GRU inter-layer dropout is switched off for a single-layer stack
        gru_dropout = 0 if self.num_layers == 1 else self.dropout

        # NOTE: submodules are created in the same order as before so that
        # parameter ordering (and thus saved optimizer state) is unchanged.
        self.embedding = nn.Embedding(voc_length, opt.embedding_dim)
        self.embedding_dropout = nn.Dropout(self.dropout)
        self.gru = nn.GRU(opt.embedding_dim, self.hidden_size, self.num_layers,
                          dropout=gru_dropout)
        self.concat = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)
        self.attn = Attn(self.attn_method, self.hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        '''
        Decode one time step.

        input_step:
            token indices fed at this step, shape [1, batch_size]
            (the start-of-sentence index on the first step)
        last_hidden:
            previous GRU state, shape [num_layers, batch_size, hidden_size];
            callers initialise it from the encoder's final hidden state
        encoder_outputs:
            encoder outputs used for attention,
            shape [max_seq_len, batch_size, hidden_size]

        Returns (output, hidden):
        output: softmax distribution over the vocabulary, [batch_size, voc_length]
        hidden: new GRU state, [num_layers, batch_size, hidden_size]
        '''
        # [1, batch_size] -> [1, batch_size, embedding_dim], with dropout
        step_embedding = self.embedding_dropout(self.embedding(input_step))
        # rnn_output: [1, batch_size, hidden_size]
        rnn_output, hidden = self.gru(step_embedding, last_hidden)
        # attention weights: [batch_size, 1, max_seq_len]
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # weighted sum of encoder outputs per sample:
        # [batch, 1, seq] x [batch, seq, hidden] -> [batch, 1, hidden] -> [batch, hidden]
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1)).squeeze(1)
        # [batch, hidden] ++ [batch, hidden] -> [batch, hidden*2]
        features = torch.cat((rnn_output.squeeze(0), context), 1)
        # project back to hidden_size, squash, then map to the vocabulary
        logits = self.out(torch.tanh(self.concat(features)))
        return F.softmax(logits, dim=1), hidden
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torchnet==0.0.4
2 | fire==0.1.3
3 | torch==1.0.1
4 | jieba==0.39
5 |
--------------------------------------------------------------------------------
/train_eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import re
3 | import time
4 | import random
5 | import jieba
6 | import torch
7 | import logging
8 | import torch.nn as nn
9 | from torchnet import meter
10 | from model import EncoderRNN, LuongAttnDecoderRNN
11 | from utils.greedysearch import GreedySearchDecoder
12 | from dataload import get_dataloader
13 | from config import Config
14 | jieba.setLogLevel(logging.INFO) #关闭jieba输出信息
15 |
def maskNLLLoss(inp, target, mask):
    '''
    Masked negative log-likelihood for one decoding step.

    inp:
        predicted probabilities, shape [batch_size, voc_length]
    target:
        gold token indices, shape [batch_size]; reshaped to [batch_size, 1]
        so it can index inp along dim=1 with gather
    mask:
        shape [batch_size]; truthy where the position is a real token,
        falsy where it is padding

    Returns (loss, n_total):
    loss:
        mean of -log(p[target]) over the unmasked samples only
    n_total:
        number of unmasked samples at this step, as a Python number
    '''
    # probability the model assigned to the gold token of every sample
    gold_prob = inp.gather(1, target.view(-1, 1)).squeeze(1)
    per_sample_nll = -gold_prob.log()
    # drop padded positions before averaging
    loss = per_sample_nll.masked_select(mask).mean()
    return loss, mask.sum().item()
32 |
def train_by_batch(sos, opt, data, encoder_optimizer, decoder_optimizer, encoder, decoder):
    '''
    Run one optimisation step on a single batch.

    sos:
        index of the start-of-sentence token
    opt:
        config object; reads device, teacher_forcing_ratio and clip
    data:
        one batch from the dataloader, unpacked as
        (inputs, targets, mask, input_lengths, max_target_length, indexes);
        per the shapes used below, inputs is [max_seq_len, batch_size] and
        targets / mask are indexed by time step first
    encoder_optimizer / decoder_optimizer:
        optimizers for the two networks
    encoder / decoder:
        the networks being trained

    Returns the average loss per non-padded target token of this batch.
    '''
    # reset gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # unpack the batch and move it to the training device
    inputs, targets, mask, input_lengths, max_target_length, _indexes = data
    inputs = inputs.to(opt.device)
    targets = targets.to(opt.device)
    mask = mask.to(opt.device)
    input_lengths = input_lengths.to(opt.device)

    # Use the real batch width rather than opt.batch_size: the last batch of
    # an epoch can be smaller than the configured size.
    batch_size = inputs.size(1)

    loss = 0           # differentiable sum of per-step mean losses
    print_losses = []  # per-step loss totals (mean * count) for reporting
    n_totals = 0       # number of non-padded target tokens seen

    # encoder_outputs: [max_seq_len, batch_size, hidden_size]
    # encoder_hidden:  [num_layers*num_directions, batch_size, hidden_size]
    encoder_outputs, encoder_hidden = encoder(inputs, input_lengths)
    # first decoder input: a row of SOS indices, shape [1, batch_size]
    decoder_input = torch.LongTensor([[sos] * batch_size]).to(opt.device)
    # The decoder state starts from the first decoder.num_layers slices of the
    # encoder state.
    # NOTE(review): nn.GRU documents h_n as layer-major, so for a
    # bidirectional encoder this slice mixes directions of the lower layers
    # rather than selecting "the forward direction" - confirm intent.
    decoder_hidden = encoder_hidden[:decoder.num_layers]

    # decide once per batch whether to feed gold tokens back in
    use_teacher_forcing = random.random() < opt.teacher_forcing_ratio

    for t in range(max_target_length):
        # decoder_output: [batch_size, voc_length]
        decoder_output, decoder_hidden = decoder(
            decoder_input, decoder_hidden, encoder_outputs
        )
        if use_teacher_forcing:
            # next input is the gold token: [batch_size] -> [1, batch_size]
            decoder_input = targets[t].view(1, -1)
        else:
            # next input is the model's own most likely token
            _, topi = decoder_output.topk(1)
            decoder_input = topi.view(1, -1).detach()

        # mask_loss is the mean loss over the non-padded samples at step t;
        # summing over t gives the loss that is back-propagated
        mask_loss, nTotal = maskNLLLoss(decoder_output, targets[t], mask[t])
        loss += mask_loss
        print_losses.append(mask_loss.item() * nTotal)
        n_totals += nTotal

    # back-propagate
    loss.backward()

    # clip gradients of both networks
    _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), opt.clip)
    _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), opt.clip)

    # update parameters
    encoder_optimizer.step()
    decoder_optimizer.step()

    # average loss per non-padded (step, sample) cell of the batch
    return sum(print_losses) / n_totals
129 |
def train(**kwargs):
    '''
    Train the encoder/decoder pair, optionally resuming from a checkpoint.

    kwargs:
        overrides applied to a fresh Config() via setattr. The function reads
        model_ckpt, device, learning_rate, decoder_learning_ratio, epoch,
        print_every, save_every and prefix, and passes the config on to
        get_dataloader and the model constructors.

    Every opt.save_every epochs a checkpoint with keys 'en', 'de', 'en_opt'
    and 'de_opt' is written to '<prefix>_<MMDD_HHMM>'.
    '''
    opt = Config()
    for k, v in kwargs.items():  # apply overrides
        setattr(opt, k, v)

    # data
    dataloader = get_dataloader(opt)
    _data = dataloader.dataset._data
    word2ix = _data['word2ix']
    sos = word2ix.get(_data.get('sos'))
    voc_length = len(word2ix)

    # models
    encoder = EncoderRNN(opt, voc_length)
    decoder = LuongAttnDecoderRNN(opt, voc_length)

    # resume from a checkpoint if one is configured
    checkpoint = None
    if opt.model_ckpt:
        # Load onto the CPU first (as eval/test do) so a checkpoint saved on a
        # GPU can be resumed on any device; the modules are moved below and
        # optimizer.load_state_dict moves its state to the parameters' device.
        checkpoint = torch.load(opt.model_ckpt, map_location=lambda s, l: s)
        encoder.load_state_dict(checkpoint['en'])
        decoder.load_state_dict(checkpoint['de'])

    # move to the target device and switch to training mode
    encoder = encoder.to(opt.device)
    decoder = decoder.to(opt.device)
    encoder.train()
    decoder.train()

    # optimizers must be built after the modules were moved to the device
    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=opt.learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=opt.learning_rate * opt.decoder_learning_ratio)
    if checkpoint is not None:
        encoder_optimizer.load_state_dict(checkpoint['en_opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])

    for epoch in range(opt.epoch):
        # running loss since the last report; reset per epoch so a partial
        # window never leaks into the next epoch's first average
        print_loss = 0
        for ii, data in enumerate(dataloader):
            # train on one batch
            loss = train_by_batch(sos, opt, data, encoder_optimizer, decoder_optimizer, encoder, decoder)
            print_loss += loss
            # Report after every full window of print_every batches, so the
            # sum really is divided by the number of batches it contains
            # (reporting at ii == 0 divided a single loss by print_every).
            if (ii + 1) % opt.print_every == 0:
                print_loss_avg = print_loss / opt.print_every
                print("Epoch: {}; Epoch Percent complete: {:.1f}%; Average loss: {:.4f}"
                      .format(epoch, epoch / opt.epoch * 100, print_loss_avg))
                print_loss = 0

        # save a checkpoint
        if epoch % opt.save_every == 0:
            checkpoint_path = '{prefix}_{time}'.format(prefix=opt.prefix,
                                                       time=time.strftime('%m%d_%H%M'))
            torch.save({
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
            }, checkpoint_path)
193 |
def generate(input_seq, searcher, sos, eos, opt):
    '''
    Decode a reply for a single, already indexed sentence.

    input_seq:
        list of token indices (segmented and mapped to indices by the caller)
    searcher:
        callable invoked as
        searcher(sos, eos, batch, lengths, max_length, device) and returning
        (tokens, scores)
    sos / eos:
        start / end-of-sentence indices, passed through to the searcher
    opt:
        config object; reads device and max_generate_length

    Returns the tokens produced by the searcher (scores are discarded).
    '''
    # a batch of one: lengths is [1], the batch is [seq_len, 1]
    lengths = torch.tensor([len(input_seq)]).to(opt.device)
    batch = torch.LongTensor([input_seq]).transpose(0, 1).to(opt.device)
    tokens, _scores = searcher(sos, eos, batch, lengths, opt.max_generate_length, opt.device)
    return tokens
204 |
def eval(**kwargs):
    '''
    Interactive console chat with a trained model.

    kwargs:
        overrides applied to a fresh Config() via setattr; model_ckpt must
        point at a saved checkpoint.

    Reads sentences from stdin until 'q' or 'quit' is entered and prints the
    model's reply for each one.

    Raises ValueError (from test()) when opt.model_ckpt is None.
    '''
    opt = Config()
    for k, v in kwargs.items():  # apply overrides
        setattr(opt, k, v)

    # Model/vocabulary loading is identical to test(), so reuse it instead of
    # duplicating the code here.
    searcher, sos, eos, unknown, word2ix, ix2word = test(opt)

    # inference only: no autograd bookkeeping needed
    with torch.no_grad():
        while True:
            input_sentence = input('> ')
            if input_sentence in ('q', 'quit'):
                break
            # output_answer cleans, segments and indexes the sentence, decodes
            # a reply and drops the end-of-sentence token, so the console
            # output matches what other callers of output_answer show.
            output_words = output_answer(input_sentence, searcher, sos, eos,
                                         unknown, opt, word2ix, ix2word)
            print('BOT: ', output_words)
252 |
def test(opt):
    '''
    Build an inference-ready greedy searcher from a saved checkpoint.

    opt:
        config object; reads model_ckpt and device, and is passed on to
        get_dataloader and the model constructors.

    Returns (searcher, sos, eos, unknown, word2ix, ix2word):
        the GreedySearchDecoder wrapping the loaded encoder/decoder, the
        indices of the start / end / unknown tokens, and both vocabulary
        mappings taken from the dataset.

    Raises ValueError when opt.model_ckpt is None.

    Note: this function does not disable autograd for later calls; callers
    that want gradient-free inference should wrap their calls in
    torch.no_grad().
    '''
    # fail fast, before the corpus is loaded
    if opt.model_ckpt is None:
        raise ValueError('model_ckpt is None.')

    # data / vocabulary
    dataloader = get_dataloader(opt)
    _data = dataloader.dataset._data
    word2ix, ix2word = _data['word2ix'], _data['ix2word']
    sos = word2ix.get(_data.get('sos'))
    eos = word2ix.get(_data.get('eos'))
    unknown = word2ix.get(_data.get('unknown'))
    voc_length = len(word2ix)

    # models
    encoder = EncoderRNN(opt, voc_length)
    decoder = LuongAttnDecoderRNN(opt, voc_length)

    # load weights; storages are kept on the CPU regardless of where the
    # checkpoint was saved, the modules are moved to opt.device afterwards
    checkpoint = torch.load(opt.model_ckpt, map_location=lambda s, l: s)
    encoder.load_state_dict(checkpoint['en'])
    decoder.load_state_dict(checkpoint['de'])

    # move to the target device and switch to evaluation mode
    encoder = encoder.to(opt.device)
    decoder = decoder.to(opt.device)
    encoder.eval()
    decoder.eval()

    searcher = GreedySearchDecoder(encoder, decoder)
    return searcher, sos, eos, unknown, word2ix, ix2word
285 |
def output_answer(input_sentence, searcher, sos, eos, unknown, opt, word2ix, ix2word):
    '''
    Produce the model's reply to one raw input sentence.

    input_sentence: raw user text
    searcher: decoder callable handed to generate()
    sos / eos / unknown: indices of the start, end and unknown tokens
    opt: config object; reads max_input_length (and whatever generate() reads)
    word2ix / ix2word: vocabulary mappings word -> index and index -> word

    Returns the decoded words joined into one string, with every
    end-of-sentence token removed.
    '''
    # strip every character the pattern does not allow, then segment with jieba
    cleaned = re.sub("[^\u4e00-\u9fa5^a-z^A-Z^0-9]", "", input_sentence)
    words = jieba.lcut(cleaned)[:opt.max_input_length]
    # NOTE(review): an empty-string token is appended here; it maps to
    # `unknown` unless '' is a key of word2ix. This may be a stripped
    # end-of-sentence marker - confirm against the corpus preprocessing.
    words.append('')
    input_seq = [word2ix.get(word, unknown) for word in words]
    tokens = generate(input_seq, searcher, sos, eos, opt)
    pieces = []
    for token in tokens:
        index = token.item()
        if index != eos:
            pieces.append(ix2word[index])
    return ''.join(pieces)
294 |
295 |
# Script entry point. fire is imported lazily so that importing this module
# from other code does not require it. Called with no component, Fire()
# exposes the names defined in this module as command-line sub-commands
# (see the Python Fire guide).
if __name__ == "__main__":
    import fire
    fire.Fire()
299 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/utils/__init__.py
--------------------------------------------------------------------------------
/utils/beamsearch.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Doragd/Chinese-Chatbot-PyTorch-Implementation/a3983f1ffad2282c022e8995a4945dc8be98d76e/utils/beamsearch.py
--------------------------------------------------------------------------------
/utils/greedysearch.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 |
class GreedySearchDecoder(nn.Module):
    '''
    Greedy (arg-max) decoding for a single input sequence.
    '''

    def __init__(self, encoder, decoder):
        '''
        encoder: callable as encoder(input_seq, input_length) returning
            (encoder_outputs, encoder_hidden)
        decoder: callable as decoder(decoder_input, decoder_hidden,
            encoder_outputs) returning (output, hidden); must expose
            num_layers
        '''
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, sos, eos, input_seq, input_length, max_length, device):
        '''
        sos / eos: start / end-of-sentence token indices
        input_seq: source token indices, shape [seq_len, 1] (batch of one)
        input_length: length tensor for the batch, shape [1]
        max_length: maximum number of tokens to generate
        device: device on which the working tensors are created

        Returns (all_tokens, all_scores): 1-D tensors holding the chosen
        token index and its score for every generated step. Generation stops
        after max_length steps or right after eos is produced; the eos token
        itself is included in the result.

        Decoding runs under torch.no_grad(): this module is inference-only,
        and callers do not always wrap the call themselves, which would keep
        the autograd graph of every step alive.
        '''
        with torch.no_grad():
            encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
            # initial decoder state: first decoder.num_layers slices of the
            # encoder's final hidden state
            decoder_hidden = encoder_hidden[:self.decoder.num_layers]
            # the decoder expects (time, batch), so even a single SOS is 2-D
            decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * sos
            # accumulators for the decoded tokens and their scores
            all_tokens = torch.zeros([0], device=device, dtype=torch.long)
            all_scores = torch.zeros([0], device=device)
            for _ in range(max_length):
                # one decoder step; decoder_output is (batch=1, voc_size)
                decoder_output, decoder_hidden = self.decoder(
                    decoder_input, decoder_hidden, encoder_outputs)
                # most likely token and its score (max drops the vocab dim)
                decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
                all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
                all_scores = torch.cat((all_scores, decoder_scores), dim=0)
                if decoder_input.item() == eos:
                    break
                # restore the leading time dimension for the next step
                decoder_input = torch.unsqueeze(decoder_input, 0)

        return all_tokens, all_scores
--------------------------------------------------------------------------------