├── doc
└── img
│ ├── plugin_example.jpg
│ ├── plugin_step1.jpg
│ ├── plugin_step2.jpg
│ ├── plugin_step3.jpg
│ ├── backend_helloworld.jpg
│ ├── code_completion_result_1.jpg
│ └── code_completion_result_2.jpg
├── plugin
└── 01-code_completion
│ ├── lib
│ ├── ezmorph-1.0.6.jar
│ ├── commons-beanutils.jar
│ ├── commons-lang-2.5.jar
│ ├── commons-logging-1.1.jar
│ ├── json-lib-2.4-jdk15.jar
│ └── commons-collections-3.2.1.jar
│ ├── .idea
│ ├── compiler.xml
│ ├── libraries
│ │ ├── ezmorph_1_0_6.xml
│ │ ├── commons_lang_2_5.xml
│ │ ├── commons_beanutils.xml
│ │ ├── json_lib_2_4_jdk15.xml
│ │ ├── commons_logging_1_1.xml
│ │ └── commons_collections_3_2_1.xml
│ ├── .gitignore
│ ├── modules.xml
│ └── misc.xml
│ ├── out
│ └── production
│ │ └── 01-code_completion
│ │ ├── SimpleCompletionContributor.class
│ │ └── META-INF
│ │ └── plugin.xml
│ ├── 01-code_completion.iml
│ ├── resources
│ └── META-INF
│ │ └── plugin.xml
│ └── src
│ └── SimpleCompletionContributor.java
├── backend
├── 01-flask
│ ├── server_demo.py
│ ├── cors.py
│ └── server.py
└── 02-crow
│ ├── server.py
│ └── server.cpp
├── README.md
├── LICENSE
└── model
├── 02-Seq2Seq
├── Seq2Seq_tf20.ipynb
├── .~Seq2Seq.ipynb
├── .ipynb_checkpoints
│ └── Seq2Seq-checkpoint.ipynb
└── Seq2Seq_tf10.ipynb
└── 01-Text_Gen
├── Text_Gen_tf10.ipynb
└── .ipynb_checkpoints
└── Text_Gen_tf10-checkpoint.ipynb
/doc/img/plugin_example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/plugin_example.jpg
--------------------------------------------------------------------------------
/doc/img/plugin_step1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/plugin_step1.jpg
--------------------------------------------------------------------------------
/doc/img/plugin_step2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/plugin_step2.jpg
--------------------------------------------------------------------------------
/doc/img/plugin_step3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/plugin_step3.jpg
--------------------------------------------------------------------------------
/doc/img/backend_helloworld.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/backend_helloworld.jpg
--------------------------------------------------------------------------------
/doc/img/code_completion_result_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/code_completion_result_1.jpg
--------------------------------------------------------------------------------
/doc/img/code_completion_result_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/code_completion_result_2.jpg
--------------------------------------------------------------------------------
/plugin/01-code_completion/lib/ezmorph-1.0.6.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/ezmorph-1.0.6.jar
--------------------------------------------------------------------------------
/plugin/01-code_completion/lib/commons-beanutils.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/commons-beanutils.jar
--------------------------------------------------------------------------------
/plugin/01-code_completion/lib/commons-lang-2.5.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/commons-lang-2.5.jar
--------------------------------------------------------------------------------
/plugin/01-code_completion/lib/commons-logging-1.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/commons-logging-1.1.jar
--------------------------------------------------------------------------------
/plugin/01-code_completion/lib/json-lib-2.4-jdk15.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/json-lib-2.4-jdk15.jar
--------------------------------------------------------------------------------
/plugin/01-code_completion/lib/commons-collections-3.2.1.jar:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/commons-collections-3.2.1.jar
--------------------------------------------------------------------------------
/plugin/01-code_completion/.idea/compiler.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/out/production/01-code_completion/SimpleCompletionContributor.class:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/out/production/01-code_completion/SimpleCompletionContributor.class
--------------------------------------------------------------------------------
/plugin/01-code_completion/.idea/libraries/ezmorph_1_0_6.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /../../../../../:\02-code\00-bishe\01-code_completion\.idea/dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/.idea/libraries/commons_lang_2_5.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/.idea/libraries/commons_beanutils.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/.idea/libraries/json_lib_2_4_jdk15.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/.idea/libraries/commons_logging_1_1.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/.idea/libraries/commons_collections_3_2_1.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/backend/01-flask/server_demo.py:
--------------------------------------------------------------------------------
1 | from cors import crossdomain
2 | from flask import Flask, jsonify, request
3 |
4 |
5 | app = Flask(__name__)
6 |
7 |
def get_args(req):
    """Extract request parameters for a GET or POST request.

    Args:
        req: the incoming request object (unused; the flask ``request``
             proxy is consulted instead — parameter kept for interface
             compatibility).

    Returns:
        The JSON body for POST, the query-string args for GET, or an
        empty dict for any other method (previously ``args`` was left
        unbound and raised UnboundLocalError).
    """
    if request.method == 'POST':
        args = request.json
    elif request.method == "GET":
        args = request.args
    else:
        # Defensive fallback: avoid UnboundLocalError for other methods.
        args = {}
    return args
14 |
15 |
@app.route("/plugin_test", methods=["GET", "POST", "OPTIONS"])
@crossdomain(origin='*', headers="Content-Type")
def plugin_test():
    """Demo endpoint: echo the input back as five fake completion results."""
    args = get_args(request)
    sentence = args.get("keyword", "Error: nothing input")
    # Build the dummy result list (stands in for real model inference).
    results = [sentence + " test: " + str(i) for i in range(5)]
    return jsonify({"data": results})
26 |
27 |
def main(host="127.0.0.1", port=9078):
    # Run the flask development server; debug=True enables the reloader
    # and interactive debugger — development use only.
    app.run(host=host, port=port, debug=True)
30 |
31 |
32 | if __name__ == "__main__":
33 | main()
34 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/01-code_completion.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/resources/META-INF/plugin.xml:
--------------------------------------------------------------------------------
1 |
2 | com.yang.pku.coder_aider
3 | Code Sentence Completion for PyCharm
4 | 1.0
5 | PKU_yang
6 |
7 |
9 | Able to autocomplete for code sentences
10 | ]]>
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
23 |
24 | since-build
25 | com.intellij.modules.platform
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
39 |
40 |
--------------------------------------------------------------------------------
/plugin/01-code_completion/out/production/01-code_completion/META-INF/plugin.xml:
--------------------------------------------------------------------------------
1 |
2 | com.yang.pku.coder_aider
3 | Code Sentence Completion for PyCharm
4 | 1.0
5 | PKU_yang
6 |
7 |
9 | Able to autocomplete for code sentences
10 | ]]>
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
23 |
24 | since-build
25 | com.intellij.modules.platform
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
39 |
40 |
--------------------------------------------------------------------------------
/backend/01-flask/cors.py:
--------------------------------------------------------------------------------
1 | from datetime import timedelta
2 | from flask import make_response, request, current_app
3 | from functools import update_wrapper
4 |
5 |
def crossdomain(origin=None, methods=None, headers=None, max_age=21600, attach_to_all=True, automatic_options=True):
    """Decorator factory that adds CORS headers to a flask view.

    Args:
        origin: allowed origin(s) — a string (e.g. ``'*'``) or an iterable
            of strings joined with ``', '``. May now be None without
            crashing (see bug-fix note below).
        methods: iterable of allowed HTTP method names; when None, the
            default OPTIONS response's ``Allow`` header is mirrored.
        headers: allowed request headers — a string or an iterable
            (iterables are upper-cased and joined; header names are
            case-insensitive per HTTP).
        max_age: preflight cache lifetime — seconds or a timedelta.
        attach_to_all: attach CORS headers to every response, not just
            OPTIONS preflights.
        automatic_options: answer OPTIONS preflights here instead of
            delegating to flask's automatic handling.
    """
    if methods is not None:
        methods = ', '.join(sorted(x.upper() for x in methods))
    if headers is not None and not isinstance(headers, str):
        headers = ', '.join(x.upper() for x in headers)
    # Bug fix: with the default origin=None, the original code reached
    # ', '.join(None) and raised TypeError the moment the factory ran.
    if origin is not None and not isinstance(origin, str):
        origin = ', '.join(origin)
    if isinstance(max_age, timedelta):
        max_age = max_age.total_seconds()

    def get_methods():
        # Prefer the explicit method list; otherwise ask flask what the
        # default OPTIONS response would advertise.
        if methods is not None:
            return methods

        options_resp = current_app.make_default_options_response()
        return options_resp.headers['allow']

    def decorator(f):
        def wrapped_function(*args, **kwargs):
            if automatic_options and request.method == 'OPTIONS':
                resp = current_app.make_default_options_response()
            else:
                resp = make_response(f(*args, **kwargs))
            if not attach_to_all and request.method != 'OPTIONS':
                return resp

            h = resp.headers

            h['Access-Control-Allow-Origin'] = origin
            h['Access-Control-Allow-Methods'] = get_methods()
            h['Access-Control-Max-Age'] = str(max_age)
            if headers is not None:
                h['Access-Control-Allow-Headers'] = headers
            return resp

        # Let this decorator (not flask) answer OPTIONS requests.
        f.provide_automatic_options = False
        return update_wrapper(wrapped_function, f)
    return decorator
44 |
--------------------------------------------------------------------------------
/backend/02-crow/server.py:
--------------------------------------------------------------------------------
1 | from keras.preprocessing.text import Tokenizer
2 | from keras.models import load_model
3 | from keras.preprocessing.sequence import pad_sequences
4 | import json
5 |
class Server:
    """Code-completion inference server.

    Wraps a trained Keras language model plus its vocabulary; punctuation
    is encoded to placeholder words before inference (the model was
    trained on the encoded form) and decoded again on the way out.
    """

    def __init__(self):
        # Model artefacts are loaded once at start-up (disk + model I/O).
        self.model_path = "model/CheckpointModel_2.h5"
        self.word_index = json.load(open('word_dict.json', 'r'))
        self.para = json.load(open('para_dict.json', 'r'))
        self.total_words = self.para['total_words']
        self.max_sequence_len = self.para['max_sequence_len']
        self.model = load_model(self.model_path)
        print("模型初始化成功!")

    def reference(self, keyword):
        """Public entry point: generate a completion for `keyword` and
        decode the placeholder tokens back into punctuation."""
        result = self.rep(self.generate_text(keyword))
        return result

    def generate_text(self, keyword):
        """Extend `keyword` with up to 11 model-predicted tokens.

        Stops early on the end-of-sentence token 'dom'. Returns the
        extended, still-encoded text, or the literal "无法预测" string
        when the seed contains an out-of-vocabulary token.
        """
        # Encode punctuation into the placeholder vocabulary words.
        keyword = keyword.replace("(", " pareleft", ).replace(")", " pareright").replace(".", " dot").replace(",",\
            " comma").replace(" \'\'", " quotMark").replace("=", " equal").replace(":", " colon")
        i = 0
        while i <= 10:
            token_list = self.find_index(keyword)
            if token_list is None:  # fixed: was `== None`
                return "无法预测"
            token_list = pad_sequences([token_list], maxlen=self.max_sequence_len - 1, padding='pre')
            # NOTE(review): predict_classes was removed in TF >= 2.6;
            # migrate to np.argmax(model.predict(...), axis=-1) on upgrade.
            predicted = self.model.predict_classes(token_list, verbose=0)
            output_word = ""
            for word, index in self.word_index.items():
                # `predicted` is a length-1 array; element-wise equality
                # against a scalar index is truthy for the single match.
                if index == predicted:
                    output_word = word
                    break
            if output_word == 'dom':  # end-of-sentence marker
                break
            keyword += " " + output_word
            i += 1
        return keyword

    def find_index(self, keyword):
        """Map each whitespace-separated token to its vocabulary index.

        Returns a list of indices, or None when any token is unknown
        (callers treat None as "cannot predict").
        """
        index_list = []
        for word in keyword.split():
            if word not in self.word_index:
                return None
            index_list.append(self.word_index[word])
        return index_list

    def rep(self, str):
        """Decode placeholder tokens back into punctuation characters.
        (`str` shadows the builtin; kept for interface compatibility.)"""
        return str.replace(" pareleft", "(", ).replace(" pareright", ")").replace(" dot", ".") \
            .replace(" comma", ",").replace(" quotMark ", "\'\'").replace(" equal", "=").replace(" colon", ":") \
            .replace(" quotmark", "\'\'")
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AI-Coder
2 | AI-Coder是一款基于PyCharm的代码句补全插件。
3 |
4 | 其补全效果如下:
5 |
6 |

7 |
8 | 句中调用代码句补全
9 |
10 |
11 |
12 |
13 |
14 |
15 | 句间调用代码句补全
16 |
17 |
18 |
19 | ## 目录结构
20 | - backend——代码句补全服务器
21 | - dataset——训练数据集
22 | - model——代码句补全模型
23 | - plugin——插件开发配置
24 |
25 |
26 |
27 | ## 服务器
28 | 代码句补全服务器尝试了两种框架,分别是Flask和Crow。
29 | ### Flask
30 |
31 | #### 1. 准备
32 |
33 | 如果没有安装 flask,先安装 flask。anaconda 自带了 flask。
34 | `pip install flask`
35 |
36 | #### 2. 运行
37 |
38 | 进入 backend 文件夹,运行 server_demo.py。
39 |
40 | 在浏览器中输入 localhost:9078/plugin_test?keyword=helloworld ,浏览器返回内容如下。
41 |
42 |
43 |
44 | 后端获取 keyword 中的数据,处理之后返回。后续我们使用模型处理输入,道理是一样的。
45 |
46 | ### [Crow](https://github.com/ipkn/crow)
47 | Crow是一个轻量级的Web服务器框架,这个框架是受Python下的Flask启发的,其实现的功能和Flask基本一致,核心的区别在于Crow是用C++编写的,性能较Flask有一定的提升。
48 |
49 |
50 |
51 | ## 数据集
52 | 数据集有两个,第一个是Keras领域的代码数据,第二个是TensorFlow领域的代码数据。
53 |
54 | 代码数据中均删除了参数。
55 |
56 |
57 |
58 |
59 | ## 模型
60 | ### 深度学习环境
61 |
62 | 模型训练环境包含两种:
63 | #### tensorflow 1
64 | - python 3.5
65 | - tensorflow 1.13.1
66 | - keras 2.2.4
67 |
68 | #### tensorflow 2
69 | - python 3.7
70 | - tensorflow 2.2
71 |
72 | 从模型代码的文件名中可以得知对应的版本号,例如:
73 | > Text_Gen_tf10.ipynb
74 |
75 |
76 | 表示该代码文件为生成模型,使用的版本为tensorflow 1的版本
77 |
78 |
79 | ### 模型介绍
80 | 模型尝试了三种,分别是:
81 | - 基于长短期记忆的代码句生成模型
82 | - 基于序列到序列的代码句预测模型
83 | - 基于Transformer的代码句预测模型
84 |
85 |
86 |
87 |
88 | 注:模型文件类型为ipynb,需要用jupyter打开。
89 |
90 |
91 |
92 |
93 | ## 插件
94 |
95 | ### 1. 准备
96 |
97 | - 开发插件所用的编辑器——IDEA 测试时使用版本 IntelliJ IDEA 2019.1 x64
98 |
99 | - 插件适用对象——Pycharm 测试时使用版本 JetBrains PyCharm Community Edition 2019.1.1 x64
100 |
101 | 注:Pycharm必须安装社区版!否则不能调试。
102 |
103 | ### 2. 插件项目的导入与运行
104 |
105 | 打开IDEA,File->open->我们项目的根目录
106 |
107 | 然后需要配置:
108 |
109 | 1) 在IDEA中选择项目根目录右键Open module settings
110 |
111 | 
112 |
113 | 设置项目的SDK为本机安装的Pycharm社区版,新建一个SDK,路径选择为安装的pycharm社区版根目录
114 |
115 | 
116 |
117 | 
118 |
119 | 2) 运行项目时会启动一个Pycharm 窗口,是带有我们这个插件效果的。
120 |
121 | 
122 |
--------------------------------------------------------------------------------
/backend/01-flask/server.py:
--------------------------------------------------------------------------------
1 | from cors import crossdomain
2 | from flask import Flask, jsonify, request
3 | from keras.preprocessing.text import Tokenizer
4 | from keras.models import load_model
5 | from keras.preprocessing.sequence import pad_sequences
6 | import json
7 |
8 | app = Flask(__name__)
9 |
10 | # total_words = 2058
11 | total_words = 0
12 | # max_sequence_len = 9
13 | max_sequence_len = 0
14 | word_index = dict()
15 | model = None
16 |
17 |
def get_args(req):
    """Extract request parameters for a GET or POST request.

    Args:
        req: the incoming request object (unused; the flask ``request``
             proxy is consulted instead — parameter kept for interface
             compatibility).

    Returns:
        The JSON body for POST, the query-string args for GET, or an
        empty dict for any other method (previously ``args`` was left
        unbound and raised UnboundLocalError).
    """
    if request.method == 'POST':
        args = request.json
    elif request.method == "GET":
        args = request.args
    else:
        # Defensive fallback: avoid UnboundLocalError for other methods.
        args = {}
    return args
24 |
25 |
@app.route("/plugin_test", methods=["GET", "POST", "OPTIONS"])
@crossdomain(origin='*', headers="Content-Type")
def plugin_test():
    """Run the completion model on the first request parameter.

    NOTE(review): the loop iterates over the parameter *names* and feeds
    the first name (not its value) to the model — this looks suspicious
    (the demo server reads args.get("keyword")) but is preserved here;
    confirm against the plugin client before changing.
    """
    if model is None:  # lazy one-time initialisation (fixed: was `== None`)
        init()
    args = get_args(request)

    for seed in args.keys():
        result = rep(generate_text(seed, model, max_sequence_len))
        return jsonify({"data": result})
    # Previously a request with no parameters fell through and returned
    # None, which flask turns into a 500 error.
    return jsonify({"data": "Error: nothing input"})
36 |
37 |
def main(host="0.0.0.0", port=9000):
    # Run the flask development server; debug=True enables the reloader
    # and interactive debugger — development use only.
    app.run(host=host, port=port, debug=True)
40 |
41 |
def init():
    """Load the vocabulary, hyper-parameters and trained model into the
    module-level globals read by the request handlers."""
    global word_index, total_words, max_sequence_len, model
    save_path = "model/CheckpointModel_2.h5"
    word_index = json.load(open('word_dict.json', 'r'))
    para = json.load(open('para_dict.json', 'r'))
    total_words = para['total_words']
    max_sequence_len = para['max_sequence_len']
    model = load_model(save_path)
52 |
53 |
def find_index(seed_text):
    """Map each whitespace-separated token of `seed_text` to its
    vocabulary index.

    Returns:
        A list of indices, or None when any token is out of vocabulary
        (the caller treats None as "cannot predict").
    """
    index_list = []
    for word in seed_text.split():
        if word not in word_index:
            # Out-of-vocabulary token: signal failure to the caller.
            # (Removed a dead `index = None` assignment that was never read.)
            return None
        index_list.append(word_index[word])
    return index_list
63 |
64 |
def generate_text(seed_text, model, max_sequence_len):
    """Iteratively extend `seed_text` with up to 11 predicted tokens.

    Punctuation is first encoded into placeholder words (the vocabulary
    was trained on the encoded form); generation stops early on the
    end-of-sentence token 'dom'. Returns the extended, still-encoded
    text, or the literal "cant predict!" string when the seed contains
    an out-of-vocabulary token.
    """
    # Encode punctuation into the placeholder vocabulary words.
    seed_text = seed_text.replace("(", " pareleft", ).replace(")", " pareright").replace(".", " dot").replace(",",\
        " comma").replace(" \'\'", " quotMark").replace("=", " equal").replace(":", " colon")
    i = 0
    while i <= 10:
        token_list = find_index(seed_text)
        if token_list is None:  # fixed: was `== None`
            return "cant predict!"
        token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
        # NOTE(review): predict_classes was removed in TF >= 2.6; migrate
        # to np.argmax(model.predict(...), axis=-1) on upgrade.
        predicted = model.predict_classes(token_list, verbose=0)
        output_word = ""
        for word, index in word_index.items():
            # `predicted` is a length-1 array; element-wise equality
            # against a scalar index is truthy for the single match.
            if index == predicted:
                output_word = word
                break
        if output_word == 'dom':  # end-of-sentence marker
            break
        seed_text += " " + output_word
        i += 1
    return seed_text
87 |
88 |
def rep(str):
    """Decode placeholder tokens back into punctuation characters.
    (`str` shadows the builtin; name kept for interface compatibility.)"""
    replacements = (
        (" pareleft", "("),
        (" pareright", ")"),
        (" dot", "."),
        (" comma", ","),
        (" quotMark ", "''"),
        (" equal", "="),
        (" colon", ":"),
        (" quotmark", "''"),
    )
    # Apply substitutions in the original chained order — order matters.
    for token, symbol in replacements:
        str = str.replace(token, symbol)
    return str
93 |
94 |
95 | if __name__ == "__main__":
96 | main()
97 |
--------------------------------------------------------------------------------
/backend/02-crow/server.cpp:
--------------------------------------------------------------------------------
1 | #include "crow.h"
2 | #include "crow/query_string.h"
3 | #include "Python.h"
4 | #include
5 | #include
6 | #include
7 |
8 | using namespace std;
9 |
10 | //CPython编译:g++ -Wall -std=c++11 test.cpp -o 0test -fsanitize=leak -lpython3.6m
11 | //crow编译:g++ -std=c++11 serve.cpp -o server -lpython3.6m -lboost_system -lboost_filesystem -L../boost/stage/lib -pthread
12 |
13 | int main()
14 | {
15 | //step1: PyC init.
16 | Py_Initialize();
17 | if(!Py_IsInitialized()){
18 | cout << "[error]: PyC init error." << endl;
19 | return 1;
20 | }
21 | cout << "[INFO]: PyC init succeed." << endl;
22 |
23 | //step2: Export serve.py path
24 | string work_path = string("sys.path.append(\'") + "\')";
25 | PyRun_SimpleString("import sys");
26 | PyRun_SimpleString(work_path.c_str());
27 |
28 | //step3: Import serve.py
29 | PyObject* pModule = PyImport_ImportModule("serve");
30 | if(!pModule){
31 | cout << "[Error]: Import serve.py failed." << endl;
32 | return 1;
33 | }
34 | cout << "[INFO]: Import serve.py succeed." << endl;
35 |
36 | //step4: Get classes and methods from pModule
37 | PyObject *pDict = PyModule_GetDict(pModule);
38 | if(!pDict){
39 | cout << "[Error]: Get classes and methods from pModule failed." << endl;
40 | return 1;
41 | }
42 | cout << "[INFO]: Get classes and methods from pModule succeed." << endl;
43 |
44 | //step5: Get Server class
45 | PyObject* pClass = PyDict_GetItemString(pDict,"Server");
46 | if(!pClass){
47 | cout << "[Error]: Get Server class failed." << endl;
48 | return 1;
49 | }
50 | cout << "[INFO]: Get Server class succeed." << endl;
51 |
52 | //step6: Get server construct function
53 | PyObject *pConstruct = PyInstanceMethod_New(pClass);
54 | if(!pConstruct){
55 | cout << "[Error]: Get server construct function failed." << endl;
56 | return 1;
57 | }
58 | cout << "[INFO]: Get server construct function succeed." << endl;
59 |
60 | //step7: Create and init Server instance
61 | PyObject* pServer = PyObject_CallObject(pConstruct,nullptr);
62 | if(!pServer){
63 | cout << "[Error]: Create and init Server instance failed." << endl;
64 | return 1;
65 | }
66 | cout << "[INFO]: Create and init Server instance succeed." << endl;
67 |
68 | //step8: Crow start
69 | crow::SimpleApp app;
70 |
71 | CROW_ROUTE(app, "/")([](){
72 | return "Welcome to use the plugin!";
73 | });
74 |
75 | //step9: Model inference
76 | CROW_ROUTE(app, "/plugin_test")
77 | .methods("GET"_method)
78 | ([pServer](const crow::request& req){
79 | std::ostringstream os;
80 | crow::query_string keyword = req.url_params;
81 | for(char* str :keyword.get_key_value_pairs_()){
82 | if(strlen(str) == 0){
83 | cout << "Got a empty value." < 0;++i){
52 | String line = strs[strs.length - i].trim();
53 | if(line.length() > 3){
54 | key = line + "\n" + key;
55 | sentence_nums--;
56 | }
57 | }
58 | //网页会自动将空格编解码为%20
59 | key = key.replaceAll(" ","%20");
60 | System.out.println("key: " + key);
61 |
62 | try{
63 | String urlStr = "http://xx.xx.xx.xx:5000/plugin_test";
64 | JSONObject jsonObject = null;
65 | URL url = new URL(urlStr);
66 | //创建http链接
67 | HttpURLConnection httpURLConnection = (HttpURLConnection) url.openConnection();
68 | //设置请求的方法类型
69 | httpURLConnection.setRequestMethod("POST");
70 | // Post请求不能使用缓存
71 | httpURLConnection.setUseCaches(false);
72 | //设置请求的内容类型
73 | httpURLConnection.setRequestProperty("Content-type", "application/json");
74 | //设置发送数据
75 | httpURLConnection.setDoOutput(true);
76 | //设置接受数据
77 | httpURLConnection.setDoInput(true);
78 |
79 | //设置body内的参数,put到JSONObject中
80 | JSONObject param = new JSONObject();
81 | param.put("type","Text_Gen");
82 | param.put("data",key + "\n" + new_line);
83 |
84 | //param.put("type","Seq2Seq");
85 | //param.put("data",key);
86 |
87 | // 建立实际的连接
88 | httpURLConnection.connect();
89 | // 得到请求的输出流对象
90 | OutputStreamWriter writer = new OutputStreamWriter(httpURLConnection.getOutputStream(),"UTF-8");
91 | writer.write(param.toString());
92 | writer.flush();
93 |
94 | //接收数据
95 | InputStream inputStream = httpURLConnection.getInputStream();
96 | //定义字节数组
97 | byte[] b = new byte[1024];
98 | //定义一个输出流存储接收到的数据
99 | ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
100 | //开始接收数据
101 | int len = 0;
102 | while (true) {
103 | len = inputStream.read(b);
104 | if (len == -1) {
105 | //数据读完
106 | break;
107 | }
108 | byteArrayOutputStream.write(b, 0, len);
109 | }
110 | //从输出流中获取读取到数据(服务端返回的)
111 | String response = byteArrayOutputStream.toString();
112 | //System.out.println("response:"+response);
113 | jsonObject = JSONObject.fromObject(response);
114 | //遍历JSON对象
115 | Iterator iter = jsonObject.entrySet().iterator();
116 | while (iter.hasNext()) {
117 | Map.Entry entry = (Map.Entry) iter.next();
118 | //System.out.println(entry.getKey().toString());
119 | String completion_result = entry.getValue().toString();
120 | //System.out.println("data:"+completion_result);
121 | if(new_line.length() < completion_result.length()){
122 | String completion_result_start = completion_result.substring(0,new_line.length());
123 | //System.out.println("new_line: " + new_line + ",completion_result_start:"+ completion_result_start);
124 | if(new_line.equals(completion_result_start))
125 | completion_result = completion_result.substring(new_line.length());
126 | }
127 | result.addElement(LookupElementBuilder.create(completion_result.trim()));
128 | }
129 |
130 | //byteArrayOutputStream.close();
131 | // 关闭连接
132 | httpURLConnection.disconnect();
133 |
134 | }catch(IOException e){
135 | e.printStackTrace();
136 | }
137 | super.fillCompletionVariants(parameters, result);
138 | }
139 | }
140 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/model/02-Seq2Seq/Seq2Seq_tf20.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "decent-client",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import io\n",
11 | "import re\n",
12 | "import random\n",
13 | "import tensorflow as tf\n",
14 | "\n",
15 | "\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "import matplotlib.ticker as ticker\n",
18 | "from sklearn.model_selection import train_test_split\n",
19 | "\n",
20 | "import unicodedata\n",
21 | "import numpy as np\n",
22 | "import os\n",
23 | "import time"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "id": "polar-recognition",
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "#GPU设置\n",
34 | "gpus = tf.config.experimental.list_physical_devices(device_type='GPU')\n",
35 | "tf.config.experimental.set_visible_devices(devices=gpus[:], device_type='GPU')\n",
36 | "for gpu in gpus:\n",
37 | " tf.config.experimental.set_memory_growth(gpu, True)\n",
38 | "os.environ['CUDA_VISIBLE_DEVICES']=\"0\""
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 5,
44 | "id": "parallel-adjustment",
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "data": {
49 | "text/plain": [
50 | "'/device:GPU:0'"
51 | ]
52 | },
53 | "execution_count": 5,
54 | "metadata": {},
55 | "output_type": "execute_result"
56 | }
57 | ],
58 | "source": [
59 | "tf.test.gpu_device_name()"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "id": "scenic-decrease",
65 | "metadata": {},
66 | "source": [
67 | "## 构建数据集"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 6,
73 | "id": "starting-toilet",
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "def preprocess_sentence(line): \n",
78 | " line = re.split('([: ,.(){}\\[\\]=])',line) \n",
79 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n",
80 | " \n",
81 | " new_line = ' ' + ' '.join(line) + ' '\n",
82 | " return new_line"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 7,
88 | "id": "variable-going",
89 | "metadata": {},
90 | "outputs": [],
91 |    "source": [
92 |     "def create_dataset(path,num_examples,rand_max=15,duplicate=3):\n",
93 |     "    input_data = []\n",
94 |     "    output_data = []\n",
95 |     "    \n",
96 |     "    lines = io.open(path,encoding='utf-8').read().strip().split('\\n')\n",
97 |     "    if num_examples == -1:\n",
98 |     "        num_examples = len(lines)\n",
99 |     "    \n",
100 |     "    for i in range(1,num_examples):\n",
101 |     "        \n",
102 |     "        rand_nums = set(random.randint(1,rand_max) for _ in range(duplicate))\n",
103 |     "        for rand_num in rand_nums:\n",
104 |     "            data = ''\n",
105 |     "            for j in range(max(0, i - rand_num),i):  # clamp: a negative j would wrap and pull lines from the end of the file\n",
106 |     "                line = preprocess_sentence(lines[j].strip()) + ' '\n",
107 |     "                data += line\n",
108 |     "            input_data.append(data.strip())\n",
109 |     "            output_data.append(preprocess_sentence(lines[i].strip()))\n",
110 |     "            \n",
111 |     "    return input_data,output_data"
112 |    ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 8,
117 | "id": "polished-order",
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "def max_length(tensor):\n",
122 | " return max(len(t) for t in tensor)\n",
123 | "\n",
124 | "def get_num_words(lang,scale):\n",
125 | " #获取词典大小\n",
126 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n",
127 | " lang_tokenizer.fit_on_texts(lang)\n",
128 | " num_words = int(len(lang_tokenizer.word_index) * scale)\n",
129 | " return num_words\n",
130 | " \n",
131 | "def tokenize(lang,scale):\n",
132 | " num_words = get_num_words(lang,scale)\n",
133 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=num_words,\n",
134 | " oov_token='',filters='')\n",
135 | " lang_tokenizer.fit_on_texts(lang)\n",
136 | " tensor = lang_tokenizer.texts_to_sequences(lang)\n",
137 | " tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,padding='post')\n",
138 | " \n",
139 | " return tensor, lang_tokenizer\n",
140 | "\n",
141 | "def load_dataset(path,num_examples=None,scale=0.9):\n",
142 | " inp_lang,targ_lang = create_dataset(path, num_examples)\n",
143 | "\n",
144 | " input_tensor, inp_lang_tokenizer = tokenize(inp_lang,scale)\n",
145 | " target_tensor, targ_lang_tokenizer = tokenize(targ_lang,scale)\n",
146 | "\n",
147 | " return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "id": "higher-punishment",
153 | "metadata": {},
154 | "source": [
155 | "### 限制数据集的大小以加快实验速度(可选)"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 9,
161 | "id": "narrow-palmer",
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "# 尝试实验不同大小的数据集\n",
166 | "num_examples = -1\n",
167 | "scale = 0.95\n",
168 | "path = \"../00-data/tf_data.txt\"\n",
169 | "input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path, num_examples,scale)\n",
170 | "\n",
171 | "# 计算目标张量的最大长度 (max_length)\n",
172 | "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 10,
178 | "id": "convenient-string",
179 | "metadata": {},
180 | "outputs": [
181 | {
182 | "data": {
183 | "text/plain": [
184 | "109453"
185 | ]
186 | },
187 | "execution_count": 10,
188 | "metadata": {},
189 | "output_type": "execute_result"
190 | }
191 | ],
192 | "source": [
193 | "len(input_tensor)"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 11,
199 | "id": "significant-oxygen",
200 | "metadata": {},
201 | "outputs": [
202 | {
203 | "data": {
204 | "text/plain": [
205 | "109453"
206 | ]
207 | },
208 | "execution_count": 11,
209 | "metadata": {},
210 | "output_type": "execute_result"
211 | }
212 | ],
213 | "source": [
214 | "len(target_tensor)"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 31,
220 | "id": "convinced-grounds",
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "name": "stdout",
225 | "output_type": "stream",
226 | "text": [
227 | "2501 2501 278 278\n"
228 | ]
229 | }
230 | ],
231 | "source": [
232 | "\"\"\"目前先不采用\"\"\"\n",
233 | "# 采用 90 - 10 的比例切分训练集和验证集\n",
234 | "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.1)\n",
235 | "\n",
236 | "# 显示长度\n",
237 | "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))"
238 | ]
239 | },
240 | {
241 | "cell_type": "markdown",
242 | "id": "incorrect-software",
243 | "metadata": {},
244 | "source": [
245 | "### 创建一个 tf.data 数据集"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 12,
251 | "id": "hungry-lafayette",
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "BUFFER_SIZE = len(input_tensor)\n",
256 | "BATCH_SIZE = 64\n",
257 | "steps_per_epoch = len(input_tensor)//BATCH_SIZE\n",
258 | "embedding_dim = 256\n",
259 | "units = 256\n",
260 | "vocab_inp_size = len(inp_lang.word_index)+1\n",
261 | "vocab_tar_size = len(targ_lang.word_index)+1\n",
262 | "\n",
263 | "dataset = tf.data.Dataset.from_tensor_slices((input_tensor, target_tensor)).shuffle(BUFFER_SIZE)\n",
264 | "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 13,
270 | "id": "requested-township",
271 | "metadata": {},
272 | "outputs": [
273 | {
274 | "data": {
275 | "text/plain": [
276 | "(TensorShape([64, 597]), TensorShape([64, 245]))"
277 | ]
278 | },
279 | "execution_count": 13,
280 | "metadata": {},
281 | "output_type": "execute_result"
282 | }
283 | ],
284 | "source": [
285 | "example_input_batch, example_target_batch = next(iter(dataset))\n",
286 | "example_input_batch.shape, example_target_batch.shape"
287 | ]
288 | },
289 | {
290 | "cell_type": "markdown",
291 | "id": "trying-channel",
292 | "metadata": {},
293 | "source": [
294 | "## 编写编码器encoder和解码器decoder模型"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": 14,
300 | "id": "consecutive-literacy",
301 | "metadata": {},
302 | "outputs": [],
303 | "source": [
304 | "class Encoder(tf.keras.Model):\n",
305 | " def __init__(self, vocab_size, embedding_dim, enc_units, batch_size):\n",
306 | " super(Encoder, self).__init__()\n",
307 | " self.batch_size = batch_size\n",
308 | " self.enc_units = enc_units\n",
309 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,batch_input_shape=[batch_size,None])\n",
310 | " self.lstm = tf.keras.layers.LSTM(self.enc_units,\n",
311 | " return_sequences=True,\n",
312 | " return_state=True,\n",
313 | " dropout=0.1,\n",
314 | " recurrent_dropout=0.1)\n",
315 | "\n",
316 | " def call(self, x):\n",
317 | " x = self.embedding(x)\n",
318 | " output_l1,state_l1,_ = self.lstm(x)\n",
319 | " #output_l2,state_l2,_ = self.lstm(output_l1)\n",
320 | " return output_l1,state_l1\n",
321 | "\n",
322 | " def initialize_hidden_state(self):\n",
323 | " return tf.zeros((self.batch_size, self.enc_units))"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 15,
329 | "id": "offensive-survivor",
330 | "metadata": {
331 | "scrolled": true
332 | },
333 | "outputs": [
334 | {
335 | "name": "stdout",
336 | "output_type": "stream",
337 | "text": [
338 | "Encoder output shape: (batch size, sequence length, units) (64, 597, 256)\n",
339 | "Encoder Hidden state shape: (batch size, units) (64, 256)\n"
340 | ]
341 | }
342 | ],
343 | "source": [
344 | "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
345 | "\n",
346 | "# 样本输入\n",
347 | "sample_hidden = encoder.initialize_hidden_state()\n",
348 | "sample_output,sample_hidden = encoder(example_input_batch)\n",
349 | "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n",
350 | "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": 16,
356 | "id": "therapeutic-favor",
357 | "metadata": {},
358 | "outputs": [],
359 | "source": [
360 | "class BahdanauAttention(tf.keras.layers.Layer):\n",
361 | " def __init__(self, units):\n",
362 | " super(BahdanauAttention, self).__init__()\n",
363 | " self.W1 = tf.keras.layers.Dense(units)\n",
364 | " self.W2 = tf.keras.layers.Dense(units)\n",
365 | " self.V = tf.keras.layers.Dense(1)\n",
366 | "\n",
367 | " def call(self, query, values):\n",
368 | " # 隐藏层的形状 == (批大小,隐藏层大小)\n",
369 | " # hidden_with_time_axis 的形状 == (批大小,1,隐藏层大小)\n",
370 | " # 这样做是为了执行加法以计算分数 \n",
371 | " hidden_with_time_axis = tf.expand_dims(query, 1)\n",
372 | "\n",
373 | " # 分数的形状 == (批大小,最大长度,1)\n",
374 | " # 我们在最后一个轴上得到 1, 因为我们把分数应用于 self.V\n",
375 | " # 在应用 self.V 之前,张量的形状是(批大小,最大长度,单位)\n",
376 | " score = self.V(tf.nn.tanh(\n",
377 | " self.W1(values) + self.W2(hidden_with_time_axis)))\n",
378 | "\n",
379 | " # 注意力权重 (attention_weights) 的形状 == (批大小,最大长度,1)\n",
380 | " attention_weights = tf.nn.softmax(score, axis=1)\n",
381 | "\n",
382 | " # 上下文向量 (context_vector) 求和之后的形状 == (批大小,隐藏层大小)\n",
383 | " context_vector = attention_weights * values\n",
384 | " context_vector = tf.reduce_sum(context_vector, axis=1)\n",
385 | "\n",
386 | " return context_vector, attention_weights"
387 | ]
388 | },
389 | {
390 | "cell_type": "code",
391 | "execution_count": 17,
392 | "id": "under-swaziland",
393 | "metadata": {},
394 | "outputs": [
395 | {
396 | "name": "stdout",
397 | "output_type": "stream",
398 | "text": [
399 | "Attention result shape: (batch size, units) (64, 256)\n",
400 | "Attention weights shape: (batch_size, sequence_length, 1) (64, 597, 1)\n"
401 | ]
402 | }
403 | ],
404 | "source": [
405 | "attention_layer = BahdanauAttention(10)\n",
406 | "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n",
407 | "\n",
408 | "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n",
409 | "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": 18,
415 | "id": "smoking-horizon",
416 | "metadata": {},
417 | "outputs": [],
418 | "source": [
419 | "class Decoder(tf.keras.Model):\n",
420 | " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
421 | " super(Decoder, self).__init__()\n",
422 | " self.batch_sz = batch_sz\n",
423 | " self.dec_units = dec_units\n",
424 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
425 | " self.train_LSTM = tf.keras.layers.LSTM(self.dec_units,\n",
426 | " return_sequences=True,\n",
427 | " return_state=True,\n",
428 | " dropout=0.1,\n",
429 | " recurrent_dropout=0.1)\n",
430 | " self.infer_LSTM = tf.keras.layers.LSTM(self.dec_units,\n",
431 | " return_sequences=True,\n",
432 | " return_state=True)\n",
433 | " \n",
434 | " self.fc = tf.keras.layers.Dense(vocab_size)\n",
435 | "\n",
436 | " # 用于注意力\n",
437 | " self.attention = BahdanauAttention(self.dec_units)\n",
438 | "\n",
439 | " def call(self, x, hidden, enc_output,is_train=True):\n",
440 | " # 编码器输出 (enc_output) 的形状 == (批大小,最大长度,隐藏层大小)\n",
441 | " context_vector, attention_weights = self.attention(hidden, enc_output)\n",
442 | "\n",
443 | " # x 在通过嵌入层后的形状 == (批大小,1,嵌入维度)\n",
444 | " x = self.embedding(x)\n",
445 | "\n",
446 | " # x 在拼接 (concatenation) 后的形状 == (批大小,1,嵌入维度 + 隐藏层大小)\n",
447 | " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
448 | "\n",
449 | " # 将合并后的向量传送到 LSTM\n",
450 | " if is_train:\n",
451 | " output, state,_ = self.train_LSTM(x)\n",
452 | " else:\n",
453 | " output, state,_ = self.infer_LSTM(x)\n",
454 | "\n",
455 | " # 输出的形状 == (批大小 * 1,隐藏层大小)\n",
456 | " output = tf.reshape(output, (-1, output.shape[2]))\n",
457 | "\n",
458 | " # 输出的形状 == (批大小,vocab)\n",
459 | " x = self.fc(output)\n",
460 | "\n",
461 | " return x, state, attention_weights"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": 19,
467 | "id": "straight-sending",
468 | "metadata": {},
469 | "outputs": [
470 | {
471 | "name": "stdout",
472 | "output_type": "stream",
473 | "text": [
474 | "Decoder output shape: (batch_size, vocab size) (64, 11568)\n"
475 | ]
476 | }
477 | ],
478 | "source": [
479 | "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n",
480 | "\n",
481 | "sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),\n",
482 | " sample_hidden, sample_output,is_train=True)\n",
483 | "\n",
484 | "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": 20,
490 | "id": "compliant-directive",
491 | "metadata": {},
492 | "outputs": [],
493 | "source": [
494 | "optimizer = tf.keras.optimizers.Adam()\n",
495 | "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
496 | " from_logits=True, reduction='none')\n",
497 | "\n",
498 | "def loss_function(real, pred):\n",
499 | " mask = tf.math.logical_not(tf.math.equal(real, 0))\n",
500 | " loss_ = loss_object(real, pred)\n",
501 | "\n",
502 | " mask = tf.cast(mask, dtype=loss_.dtype)\n",
503 | " loss_ *= mask\n",
504 | "\n",
505 | " return tf.reduce_mean(loss_)"
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": 21,
511 | "id": "adult-grocery",
512 | "metadata": {},
513 | "outputs": [],
514 | "source": [
515 | "checkpoint_dir = '../02-checkpoints/'\n",
516 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
517 | "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
518 | " encoder=encoder,\n",
519 | " decoder=decoder)"
520 | ]
521 | },
522 | {
523 | "cell_type": "code",
524 | "execution_count": 22,
525 | "id": "charming-apparatus",
526 | "metadata": {},
527 | "outputs": [],
528 |    "source": [
529 |     "@tf.function\n",
530 |     "def train_step(inp,targ,enc_hidden):\n",
531 |     "    loss = 0\n",
532 |     "    \n",
533 |     "    with tf.GradientTape() as tape:\n",
534 |     "        enc_output,enc_hidden = encoder(inp)\n",
535 |     "        dec_hidden = enc_hidden\n",
536 |     "        \n",
537 |     "        # dec_hidden already set above; duplicate assignment removed\n",
538 |     "        \n",
539 |     "        dec_input = tf.expand_dims([targ_lang.word_index['']] * BATCH_SIZE, 1)\n",
540 |     "        \n",
541 |     "        for t in range(1,targ.shape[1]):\n",
542 |     "            predictions,dec_hidden,_ = decoder(dec_input,dec_hidden,enc_output,is_train=True)\n",
543 |     "            loss += loss_function(targ[:,t],predictions)\n",
544 |     "            \n",
545 |     "            dec_input = tf.expand_dims(targ[:,t],1)\n",
546 |     "        \n",
547 |     "    batch_loss = (loss / int(targ.shape[1]))\n",
548 |     "    variables = encoder.trainable_variables + decoder.trainable_variables\n",
549 |     "    gradients = tape.gradient(loss,variables)\n",
550 |     "    optimizer.apply_gradients(zip(gradients,variables))\n",
551 |     "    \n",
552 |     "    return batch_loss"
553 |    ]
554 | },
555 | {
556 | "cell_type": "code",
557 | "execution_count": null,
558 | "id": "based-wound",
559 | "metadata": {},
560 | "outputs": [],
561 | "source": [
562 | "EPOCHS = 30\n",
563 | "\n",
564 | "for epoch in range(EPOCHS):\n",
565 | " start = time.time()\n",
566 | "\n",
567 | " enc_hidden = encoder.initialize_hidden_state()\n",
568 | " total_loss = 0\n",
569 | "\n",
570 | " for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n",
571 | " batch_loss = train_step(inp, targ, enc_hidden)\n",
572 | " total_loss += batch_loss\n",
573 | "\n",
574 | " #if batch % 5 == 0:\n",
575 | " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
576 | " batch,\n",
577 | " batch_loss.numpy()))\n",
578 | " # 每 5 个周期(epoch),保存(检查点)一次模型\n",
579 | " if (epoch + 1) % 5 == 0:\n",
580 | " checkpoint.save(file_prefix=checkpoint_prefix) \n",
581 | "\n",
582 | " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
583 | " total_loss / steps_per_epoch))\n",
584 | " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
585 | ]
586 | },
587 | {
588 | "cell_type": "markdown",
589 | "id": "terminal-suite",
590 | "metadata": {},
591 | "source": [
592 | "## 预测"
593 | ]
594 | },
595 | {
596 | "cell_type": "code",
597 | "execution_count": 13,
598 | "id": "satellite-seating",
599 | "metadata": {},
600 | "outputs": [],
601 | "source": [
602 | "def create_inference_data(lines):\n",
603 | " data = \"\"\n",
604 | " for line in lines:\n",
605 | " line = preprocess_sentence(line.strip()) + ' '\n",
606 | " data += line\n",
607 | " return data.strip()"
608 | ]
609 | },
610 | {
611 | "cell_type": "code",
612 | "execution_count": 14,
613 | "id": "cellular-impossible",
614 | "metadata": {},
615 | "outputs": [],
616 | "source": [
617 | "def evaluate(sentence):\n",
618 | " attention_plot = np.zeros((max_length_targ, max_length_inp))\n",
619 | "\n",
620 | " sentence = create_inference_data(sentence)\n",
621 | "\n",
622 | " inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]\n",
623 | " inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],\n",
624 | " maxlen=max_length_inp,\n",
625 | " padding='post')\n",
626 | " inputs = tf.convert_to_tensor(inputs)\n",
627 | "\n",
628 | " result = ''\n",
629 | "\n",
630 | " hidden = [tf.zeros((1, units))]\n",
631 | " enc_out, enc_hidden = encoder(inputs)\n",
632 | "\n",
633 | " dec_hidden = enc_hidden\n",
634 | " dec_input = tf.expand_dims([targ_lang.word_index['']], 0)\n",
635 | "\n",
636 | " for t in range(max_length_targ):\n",
637 | " predictions, dec_hidden, attention_weights = decoder(dec_input,\n",
638 | " dec_hidden,\n",
639 | " enc_out,\n",
640 | " is_train=False)\n",
641 | "\n",
642 | " # 存储注意力权重以便后面制图\n",
643 | " attention_weights = tf.reshape(attention_weights, (-1, ))\n",
644 | " attention_plot[t] = attention_weights.numpy()\n",
645 | "\n",
646 | " predicted_id = tf.argmax(predictions[0]).numpy()\n",
647 | "\n",
648 | " result += targ_lang.index_word[predicted_id] + ' '\n",
649 | "\n",
650 | " if targ_lang.index_word[predicted_id] == '':\n",
651 | " return result, sentence, attention_plot\n",
652 | "\n",
653 | " # 预测的 ID 被输送回模型\n",
654 | " dec_input = tf.expand_dims([predicted_id], 0)\n",
655 | "\n",
656 | " return result, sentence, attention_plot"
657 | ]
658 | },
659 | {
660 | "cell_type": "code",
661 | "execution_count": 21,
662 | "id": "current-feedback",
663 | "metadata": {},
664 | "outputs": [],
665 | "source": [
666 | "# 注意力权重制图函数\n",
667 | "def plot_attention(attention, sentence, predicted_sentence):\n",
668 | " fig = plt.figure(figsize=(20,20))\n",
669 | " ax = fig.add_subplot(1, 1, 1)\n",
670 | " ax.matshow(attention, cmap='viridis')\n",
671 | "\n",
672 | " fontdict = {'fontsize': 10}\n",
673 | "\n",
674 | " ax.set_xticklabels([''] + sentence, fontdict=fontdict)\n",
675 | " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict, rotation=90)\n",
676 | "\n",
677 | " ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
678 | " ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
679 | "\n",
680 | " plt.show()"
681 | ]
682 | },
683 | {
684 | "cell_type": "code",
685 | "execution_count": 22,
686 | "id": "worse-logging",
687 | "metadata": {},
688 | "outputs": [],
689 | "source": [
690 | "def translate(sentence):\n",
691 | " result,sentence,attention_plot = evaluate(sentence)\n",
692 | " \n",
693 | " print('Input: %s' % (sentence))\n",
694 | " print('Predicted translation: {}'.format(result))\n",
695 | " \n",
696 | " attention_plot = attention_plot[:len(result.split(' ')),:len(sentence.split(' '))]\n",
697 | " plot_attention(attention_plot,sentence.split(' '),result.split(' ')) \n",
698 | " print(result.split(' '))"
699 | ]
700 | },
701 | {
702 | "cell_type": "code",
703 | "execution_count": 29,
704 | "id": "integrated-victoria",
705 | "metadata": {},
706 | "outputs": [],
707 | "source": [
708 | "lines = io.open(path,encoding='utf-8').read().strip().split('\\n')[12:13]\n",
709 | "lines = create_inference_data(lines)"
710 | ]
711 | },
712 | {
713 | "cell_type": "code",
714 | "execution_count": 30,
715 | "id": "numerous-advisory",
716 | "metadata": {
717 | "scrolled": false
718 | },
719 | "outputs": [
720 | {
721 | "data": {
722 | "text/plain": [
723 | "' y = tf . placeholder ( tf . float32 , shape = [ ] , name = ) '"
724 | ]
725 | },
726 | "execution_count": 30,
727 | "metadata": {},
728 | "output_type": "execute_result"
729 | }
730 | ],
731 | "source": [
732 | "lines"
733 | ]
734 | }
735 | ],
736 | "metadata": {
737 | "kernelspec": {
738 | "display_name": "myconda",
739 | "language": "python",
740 | "name": "myconda"
741 | },
742 | "language_info": {
743 | "codemirror_mode": {
744 | "name": "ipython",
745 | "version": 3
746 | },
747 | "file_extension": ".py",
748 | "mimetype": "text/x-python",
749 | "name": "python",
750 | "nbconvert_exporter": "python",
751 | "pygments_lexer": "ipython3",
752 | "version": "3.7.7"
753 | }
754 | },
755 | "nbformat": 4,
756 | "nbformat_minor": 5
757 | }
758 |
--------------------------------------------------------------------------------
/model/02-Seq2Seq/.~Seq2Seq.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "decent-client",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import io\n",
11 | "import re\n",
12 | "import random\n",
13 | "import tensorflow as tf\n",
14 | "\n",
15 | "\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "import matplotlib.ticker as ticker\n",
18 | "from sklearn.model_selection import train_test_split\n",
19 | "\n",
20 | "import unicodedata\n",
21 | "import numpy as np\n",
22 | "import os\n",
23 | "import time"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "id": "polar-recognition",
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "#GPU设置\n",
34 | "gpus = tf.config.experimental.list_physical_devices(device_type='GPU')\n",
35 | "tf.config.experimental.set_visible_devices(devices=gpus[:], device_type='GPU')\n",
36 | "for gpu in gpus:\n",
37 | " tf.config.experimental.set_memory_growth(gpu, True)\n",
38 | "os.environ['CUDA_VISIBLE_DEVICES']=\"0\""
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 5,
44 | "id": "parallel-adjustment",
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "data": {
49 | "text/plain": [
50 | "'/device:GPU:0'"
51 | ]
52 | },
53 | "execution_count": 5,
54 | "metadata": {},
55 | "output_type": "execute_result"
56 | }
57 | ],
58 | "source": [
59 | "tf.test.gpu_device_name()"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "id": "scenic-decrease",
65 | "metadata": {},
66 | "source": [
67 | "## 构建数据集"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 6,
73 | "id": "starting-toilet",
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "def preprocess_sentence(line): \n",
78 | " line = re.split('([: ,.(){}\\[\\]=])',line) \n",
79 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n",
80 | " \n",
81 | " new_line = ' ' + ' '.join(line) + ' '\n",
82 | " return new_line"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 7,
88 | "id": "variable-going",
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "def create_dataset(path,num_examples,rand_max=15,duplicate=3):\n",
93 | " input_data = []\n",
94 | " output_data = []\n",
95 | " \n",
96 | " lines = io.open(path,encoding='utf-8').read().strip().split('\\n')\n",
97 | " if num_examples == -1:\n",
98 | " num_examples = len(lines)\n",
99 | " \n",
100 | " for i in range(1,num_examples):\n",
101 | " \n",
102 | " rand_nums = set(random.randint(1,rand_max) for _ in range(duplicate))\n",
103 | " for rand_num in rand_nums:\n",
104 | " data = ''\n",
105 | " for j in range(i - rand_num,i):\n",
106 | " line = preprocess_sentence(lines[j].strip()) + ' '\n",
107 | " data += line\n",
108 | " input_data.append(data.strip())\n",
109 | " output_data.append(preprocess_sentence(lines[i].strip()))\n",
110 | " \n",
111 | " return input_data,output_data"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 8,
117 | "id": "polished-order",
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "def max_length(tensor):\n",
122 | " return max(len(t) for t in tensor)\n",
123 | "\n",
124 | "def get_num_words(lang,scale):\n",
125 | " #获取词典大小\n",
126 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n",
127 | " lang_tokenizer.fit_on_texts(lang)\n",
128 | " num_words = int(len(lang_tokenizer.word_index) * scale)\n",
129 | " return num_words\n",
130 | " \n",
131 | "def tokenize(lang,scale):\n",
132 | " num_words = get_num_words(lang,scale)\n",
133 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=num_words,\n",
134 | " oov_token='',filters='')\n",
135 | " lang_tokenizer.fit_on_texts(lang)\n",
136 | " tensor = lang_tokenizer.texts_to_sequences(lang)\n",
137 | " tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,padding='post')\n",
138 | " \n",
139 | " return tensor, lang_tokenizer\n",
140 | "\n",
141 | "def load_dataset(path,num_examples=None,scale=0.9):\n",
142 | " inp_lang,targ_lang = create_dataset(path, num_examples)\n",
143 | "\n",
144 | " input_tensor, inp_lang_tokenizer = tokenize(inp_lang,scale)\n",
145 | " target_tensor, targ_lang_tokenizer = tokenize(targ_lang,scale)\n",
146 | "\n",
147 | " return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "id": "higher-punishment",
153 | "metadata": {},
154 | "source": [
155 | "### 限制数据集的大小以加快实验速度(可选)"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 9,
161 | "id": "narrow-palmer",
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "# 尝试实验不同大小的数据集\n",
166 | "num_examples = -1\n",
167 | "scale = 0.95\n",
168 | "path = \"../00-data/tf_data.txt\"\n",
169 | "input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path, num_examples,scale)\n",
170 | "\n",
171 | "# 计算目标张量的最大长度 (max_length)\n",
172 | "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 10,
178 | "id": "convenient-string",
179 | "metadata": {},
180 | "outputs": [
181 | {
182 | "data": {
183 | "text/plain": [
184 | "109453"
185 | ]
186 | },
187 | "execution_count": 10,
188 | "metadata": {},
189 | "output_type": "execute_result"
190 | }
191 | ],
192 | "source": [
193 | "len(input_tensor)"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 11,
199 | "id": "significant-oxygen",
200 | "metadata": {},
201 | "outputs": [
202 | {
203 | "data": {
204 | "text/plain": [
205 | "109453"
206 | ]
207 | },
208 | "execution_count": 11,
209 | "metadata": {},
210 | "output_type": "execute_result"
211 | }
212 | ],
213 | "source": [
214 | "len(target_tensor)"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 31,
220 | "id": "convinced-grounds",
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "name": "stdout",
225 | "output_type": "stream",
226 | "text": [
227 | "2501 2501 278 278\n"
228 | ]
229 | }
230 | ],
231 | "source": [
232 | "\"\"\"目前先不采用\"\"\"\n",
233 | "# 采用 90 - 10 的比例切分训练集和验证集\n",
234 | "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.1)\n",
235 | "\n",
236 | "# 显示长度\n",
237 | "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))"
238 | ]
239 | },
240 | {
241 | "cell_type": "markdown",
242 | "id": "incorrect-software",
243 | "metadata": {},
244 | "source": [
245 | "### 创建一个 tf.data 数据集"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 12,
251 | "id": "hungry-lafayette",
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "BUFFER_SIZE = len(input_tensor)\n",
256 | "BATCH_SIZE = 64\n",
257 | "steps_per_epoch = len(input_tensor)//BATCH_SIZE\n",
258 | "embedding_dim = 256\n",
259 | "units = 256\n",
260 | "vocab_inp_size = len(inp_lang.word_index)+1\n",
261 | "vocab_tar_size = len(targ_lang.word_index)+1\n",
262 | "\n",
263 | "dataset = tf.data.Dataset.from_tensor_slices((input_tensor, target_tensor)).shuffle(BUFFER_SIZE)\n",
264 | "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 13,
270 | "id": "requested-township",
271 | "metadata": {},
272 | "outputs": [
273 | {
274 | "data": {
275 | "text/plain": [
276 | "(TensorShape([64, 597]), TensorShape([64, 245]))"
277 | ]
278 | },
279 | "execution_count": 13,
280 | "metadata": {},
281 | "output_type": "execute_result"
282 | }
283 | ],
284 | "source": [
285 | "example_input_batch, example_target_batch = next(iter(dataset))\n",
286 | "example_input_batch.shape, example_target_batch.shape"
287 | ]
288 | },
289 | {
290 | "cell_type": "markdown",
291 | "id": "trying-channel",
292 | "metadata": {},
293 | "source": [
294 | "## 编写编码器encoder和解码器decoder模型"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": 14,
300 | "id": "consecutive-literacy",
301 | "metadata": {},
302 | "outputs": [],
303 | "source": [
304 | "class Encoder(tf.keras.Model):\n",
305 | " def __init__(self, vocab_size, embedding_dim, enc_units, batch_size):\n",
306 | " super(Encoder, self).__init__()\n",
307 | " self.batch_size = batch_size\n",
308 | " self.enc_units = enc_units\n",
309 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,batch_input_shape=[batch_size,None])\n",
310 | " self.lstm = tf.keras.layers.LSTM(self.enc_units,\n",
311 | " return_sequences=True,\n",
312 | " return_state=True,\n",
313 | " dropout=0.1,\n",
314 | " recurrent_dropout=0.1)\n",
315 | "\n",
316 | " def call(self, x):\n",
317 | " x = self.embedding(x)\n",
318 | " output_l1,state_l1,_ = self.lstm(x)\n",
319 | " #output_l2,state_l2,_ = self.lstm(output_l1)\n",
320 | " return output_l1,state_l1\n",
321 | "\n",
322 | " def initialize_hidden_state(self):\n",
323 | " return tf.zeros((self.batch_size, self.enc_units))"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 15,
329 | "id": "offensive-survivor",
330 | "metadata": {
331 | "scrolled": true
332 | },
333 | "outputs": [
334 | {
335 | "name": "stdout",
336 | "output_type": "stream",
337 | "text": [
338 | "Encoder output shape: (batch size, sequence length, units) (64, 597, 256)\n",
339 | "Encoder Hidden state shape: (batch size, units) (64, 256)\n"
340 | ]
341 | }
342 | ],
343 | "source": [
344 | "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
345 | "\n",
346 | "# 样本输入\n",
347 | "sample_hidden = encoder.initialize_hidden_state()\n",
348 | "sample_output,sample_hidden = encoder(example_input_batch)\n",
349 | "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n",
350 | "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": 16,
356 | "id": "therapeutic-favor",
357 | "metadata": {},
358 | "outputs": [],
359 | "source": [
360 | "class BahdanauAttention(tf.keras.layers.Layer):\n",
361 | " def __init__(self, units):\n",
362 | " super(BahdanauAttention, self).__init__()\n",
363 | " self.W1 = tf.keras.layers.Dense(units)\n",
364 | " self.W2 = tf.keras.layers.Dense(units)\n",
365 | " self.V = tf.keras.layers.Dense(1)\n",
366 | "\n",
367 | " def call(self, query, values):\n",
368 | " # 隐藏层的形状 == (批大小,隐藏层大小)\n",
369 | " # hidden_with_time_axis 的形状 == (批大小,1,隐藏层大小)\n",
370 | " # 这样做是为了执行加法以计算分数 \n",
371 | " hidden_with_time_axis = tf.expand_dims(query, 1)\n",
372 | "\n",
373 | " # 分数的形状 == (批大小,最大长度,1)\n",
374 | " # 我们在最后一个轴上得到 1, 因为我们把分数应用于 self.V\n",
375 | " # 在应用 self.V 之前,张量的形状是(批大小,最大长度,单位)\n",
376 | " score = self.V(tf.nn.tanh(\n",
377 | " self.W1(values) + self.W2(hidden_with_time_axis)))\n",
378 | "\n",
379 | " # 注意力权重 (attention_weights) 的形状 == (批大小,最大长度,1)\n",
380 | " attention_weights = tf.nn.softmax(score, axis=1)\n",
381 | "\n",
382 | " # 上下文向量 (context_vector) 求和之后的形状 == (批大小,隐藏层大小)\n",
383 | " context_vector = attention_weights * values\n",
384 | " context_vector = tf.reduce_sum(context_vector, axis=1)\n",
385 | "\n",
386 | " return context_vector, attention_weights"
387 | ]
388 | },
389 | {
390 | "cell_type": "code",
391 | "execution_count": 17,
392 | "id": "under-swaziland",
393 | "metadata": {},
394 | "outputs": [
395 | {
396 | "name": "stdout",
397 | "output_type": "stream",
398 | "text": [
399 | "Attention result shape: (batch size, units) (64, 256)\n",
400 | "Attention weights shape: (batch_size, sequence_length, 1) (64, 597, 1)\n"
401 | ]
402 | }
403 | ],
404 | "source": [
405 | "attention_layer = BahdanauAttention(10)\n",
406 | "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n",
407 | "\n",
408 | "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n",
409 | "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": 18,
415 | "id": "smoking-horizon",
416 | "metadata": {},
417 | "outputs": [],
418 | "source": [
419 | "class Decoder(tf.keras.Model):\n",
420 | " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
421 | " super(Decoder, self).__init__()\n",
422 | " self.batch_sz = batch_sz\n",
423 | " self.dec_units = dec_units\n",
424 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
425 | " self.train_LSTM = tf.keras.layers.LSTM(self.dec_units,\n",
426 | " return_sequences=True,\n",
427 | " return_state=True,\n",
428 | " dropout=0.1,\n",
429 | " recurrent_dropout=0.1)\n",
430 | " self.infer_LSTM = tf.keras.layers.LSTM(self.dec_units,\n",
431 | " return_sequences=True,\n",
432 | " return_state=True)\n",
433 | " \n",
434 | " self.fc = tf.keras.layers.Dense(vocab_size)\n",
435 | "\n",
436 | " # 用于注意力\n",
437 | " self.attention = BahdanauAttention(self.dec_units)\n",
438 | "\n",
439 | " def call(self, x, hidden, enc_output,is_train=True):\n",
440 | " # 编码器输出 (enc_output) 的形状 == (批大小,最大长度,隐藏层大小)\n",
441 | " context_vector, attention_weights = self.attention(hidden, enc_output)\n",
442 | "\n",
443 | " # x 在通过嵌入层后的形状 == (批大小,1,嵌入维度)\n",
444 | " x = self.embedding(x)\n",
445 | "\n",
446 | " # x 在拼接 (concatenation) 后的形状 == (批大小,1,嵌入维度 + 隐藏层大小)\n",
447 | " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
448 | "\n",
449 | " # 将合并后的向量传送到 LSTM\n",
450 | " if is_train:\n",
451 | " output, state,_ = self.train_LSTM(x)\n",
452 | " else:\n",
453 | " output, state,_ = self.infer_LSTM(x)\n",
454 | "\n",
455 | " # 输出的形状 == (批大小 * 1,隐藏层大小)\n",
456 | " output = tf.reshape(output, (-1, output.shape[2]))\n",
457 | "\n",
458 | " # 输出的形状 == (批大小,vocab)\n",
459 | " x = self.fc(output)\n",
460 | "\n",
461 | " return x, state, attention_weights"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": 19,
467 | "id": "straight-sending",
468 | "metadata": {},
469 | "outputs": [
470 | {
471 | "name": "stdout",
472 | "output_type": "stream",
473 | "text": [
474 | "Decoder output shape: (batch_size, vocab size) (64, 11568)\n"
475 | ]
476 | }
477 | ],
478 | "source": [
479 | "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n",
480 | "\n",
481 | "sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),\n",
482 | " sample_hidden, sample_output,is_train=True)\n",
483 | "\n",
484 | "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": 20,
490 | "id": "compliant-directive",
491 | "metadata": {},
492 | "outputs": [],
493 | "source": [
494 | "optimizer = tf.keras.optimizers.Adam()\n",
495 | "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
496 | " from_logits=True, reduction='none')\n",
497 | "\n",
498 | "def loss_function(real, pred):\n",
499 | " mask = tf.math.logical_not(tf.math.equal(real, 0))\n",
500 | " loss_ = loss_object(real, pred)\n",
501 | "\n",
502 | " mask = tf.cast(mask, dtype=loss_.dtype)\n",
503 | " loss_ *= mask\n",
504 | "\n",
505 | " return tf.reduce_mean(loss_)"
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": 21,
511 | "id": "adult-grocery",
512 | "metadata": {},
513 | "outputs": [],
514 | "source": [
515 | "checkpoint_dir = '../02-checkpoints/'\n",
516 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
517 | "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
518 | " encoder=encoder,\n",
519 | " decoder=decoder)"
520 | ]
521 | },
522 | {
523 | "cell_type": "code",
524 | "execution_count": 22,
525 | "id": "charming-apparatus",
526 | "metadata": {},
527 | "outputs": [],
528 | "source": [
529 | "@tf.function\n",
530 | "def train_step(inp,targ,enc_hidden):\n",
531 | " loss = 0\n",
532 | " \n",
533 | " with tf.GradientTape() as tape:\n",
534 | " enc_output,enc_hidden = encoder(inp)\n",
535 | " dec_hidden = enc_hidden\n",
536 | " \n",
537 | " dec_hidden = enc_hidden\n",
538 | " \n",
539 | "    dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)\n",
540 | " \n",
541 | " for t in range(1,targ.shape[1]):\n",
542 | " predictions,dec_hidden,_ = decoder(dec_input,dec_hidden,enc_output,is_train=True)\n",
543 | " loss += loss_function(targ[:,t],predictions)\n",
544 | " \n",
545 | " dec_input = tf.expand_dims(targ[:,t],1)\n",
546 | " \n",
547 | " batch_loss = (loss / int(targ.shape[1]))\n",
548 | " variables = encoder.trainable_variables + decoder.trainable_variables\n",
549 | " gradients = tape.gradient(loss,variables)\n",
550 | " optimizer.apply_gradients(zip(gradients,variables))\n",
551 | " \n",
552 | " return batch_loss"
553 | ]
554 | },
555 | {
556 | "cell_type": "code",
557 | "execution_count": null,
558 | "id": "based-wound",
559 | "metadata": {},
560 | "outputs": [],
561 | "source": [
562 | "EPOCHS = 30\n",
563 | "\n",
564 | "for epoch in range(EPOCHS):\n",
565 | " start = time.time()\n",
566 | "\n",
567 | " enc_hidden = encoder.initialize_hidden_state()\n",
568 | " total_loss = 0\n",
569 | "\n",
570 | " for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n",
571 | " batch_loss = train_step(inp, targ, enc_hidden)\n",
572 | " total_loss += batch_loss\n",
573 | "\n",
574 | " #if batch % 5 == 0:\n",
575 | " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
576 | " batch,\n",
577 | " batch_loss.numpy()))\n",
578 | " # 每 5 个周期(epoch),保存(检查点)一次模型\n",
579 | " if (epoch + 1) % 5 == 0:\n",
580 | " checkpoint.save(file_prefix=checkpoint_prefix) \n",
581 | "\n",
582 | " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
583 | " total_loss / steps_per_epoch))\n",
584 | " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
585 | ]
586 | },
587 | {
588 | "cell_type": "markdown",
589 | "id": "terminal-suite",
590 | "metadata": {},
591 | "source": [
592 | "## 预测"
593 | ]
594 | },
595 | {
596 | "cell_type": "code",
597 | "execution_count": 13,
598 | "id": "satellite-seating",
599 | "metadata": {},
600 | "outputs": [],
601 | "source": [
602 | "def create_inference_data(lines):\n",
603 | " data = \"\"\n",
604 | " for line in lines:\n",
605 | " line = preprocess_sentence(line.strip()) + ' '\n",
606 | " data += line\n",
607 | " return data.strip()"
608 | ]
609 | },
610 | {
611 | "cell_type": "code",
612 | "execution_count": 14,
613 | "id": "cellular-impossible",
614 | "metadata": {},
615 | "outputs": [],
616 | "source": [
617 | "def evaluate(sentence):\n",
618 | " attention_plot = np.zeros((max_length_targ, max_length_inp))\n",
619 | "\n",
620 | " sentence = create_inference_data(sentence)\n",
621 | "\n",
622 | " inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]\n",
623 | " inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],\n",
624 | " maxlen=max_length_inp,\n",
625 | " padding='post')\n",
626 | " inputs = tf.convert_to_tensor(inputs)\n",
627 | "\n",
628 | " result = ''\n",
629 | "\n",
630 | " hidden = [tf.zeros((1, units))]\n",
631 | " enc_out, enc_hidden = encoder(inputs)\n",
632 | "\n",
633 | " dec_hidden = enc_hidden\n",
634 | "    dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)\n",
635 | "\n",
636 | " for t in range(max_length_targ):\n",
637 | " predictions, dec_hidden, attention_weights = decoder(dec_input,\n",
638 | " dec_hidden,\n",
639 | " enc_out,\n",
640 | " is_train=False)\n",
641 | "\n",
642 | " # 存储注意力权重以便后面制图\n",
643 | " attention_weights = tf.reshape(attention_weights, (-1, ))\n",
644 | " attention_plot[t] = attention_weights.numpy()\n",
645 | "\n",
646 | " predicted_id = tf.argmax(predictions[0]).numpy()\n",
647 | "\n",
648 | " result += targ_lang.index_word[predicted_id] + ' '\n",
649 | "\n",
650 | "        if targ_lang.index_word[predicted_id] == '<end>':\n",
651 | " return result, sentence, attention_plot\n",
652 | "\n",
653 | " # 预测的 ID 被输送回模型\n",
654 | " dec_input = tf.expand_dims([predicted_id], 0)\n",
655 | "\n",
656 | " return result, sentence, attention_plot"
657 | ]
658 | },
659 | {
660 | "cell_type": "code",
661 | "execution_count": 21,
662 | "id": "current-feedback",
663 | "metadata": {},
664 | "outputs": [],
665 | "source": [
666 | "# 注意力权重制图函数\n",
667 | "def plot_attention(attention, sentence, predicted_sentence):\n",
668 | " fig = plt.figure(figsize=(20,20))\n",
669 | " ax = fig.add_subplot(1, 1, 1)\n",
670 | " ax.matshow(attention, cmap='viridis')\n",
671 | "\n",
672 | " fontdict = {'fontsize': 10}\n",
673 | "\n",
674 | " ax.set_xticklabels([''] + sentence, fontdict=fontdict)\n",
675 | " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict, rotation=90)\n",
676 | "\n",
677 | " ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
678 | " ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
679 | "\n",
680 | " plt.show()"
681 | ]
682 | },
683 | {
684 | "cell_type": "code",
685 | "execution_count": 22,
686 | "id": "worse-logging",
687 | "metadata": {},
688 | "outputs": [],
689 | "source": [
690 | "def translate(sentence):\n",
691 | " result,sentence,attention_plot = evaluate(sentence)\n",
692 | " \n",
693 | " print('Input: %s' % (sentence))\n",
694 | " print('Predicted translation: {}'.format(result))\n",
695 | " \n",
696 | " attention_plot = attention_plot[:len(result.split(' ')),:len(sentence.split(' '))]\n",
697 | " plot_attention(attention_plot,sentence.split(' '),result.split(' ')) \n",
698 | " print(result.split(' '))"
699 | ]
700 | },
701 | {
702 | "cell_type": "code",
703 | "execution_count": 29,
704 | "id": "integrated-victoria",
705 | "metadata": {},
706 | "outputs": [],
707 | "source": [
708 | "lines = io.open(path,encoding='utf-8').read().strip().split('\\n')[12:13]\n",
709 | "lines = create_inference_data(lines)"
710 | ]
711 | },
712 | {
713 | "cell_type": "code",
714 | "execution_count": 30,
715 | "id": "numerous-advisory",
716 | "metadata": {
717 | "scrolled": false
718 | },
719 | "outputs": [
720 | {
721 | "data": {
722 | "text/plain": [
723 | "' y = tf . placeholder ( tf . float32 , shape = [ ] , name = ) '"
724 | ]
725 | },
726 | "execution_count": 30,
727 | "metadata": {},
728 | "output_type": "execute_result"
729 | }
730 | ],
731 | "source": [
732 | "lines"
733 | ]
734 | },
735 | {
736 | "cell_type": "code",
737 | "execution_count": null,
738 | "id": "structured-planet",
739 | "metadata": {},
740 | "outputs": [],
741 | "source": []
742 | }
743 | ],
744 | "metadata": {
745 | "kernelspec": {
746 | "display_name": "myconda",
747 | "language": "python",
748 | "name": "myconda"
749 | },
750 | "language_info": {
751 | "codemirror_mode": {
752 | "name": "ipython",
753 | "version": 3
754 | },
755 | "file_extension": ".py",
756 | "mimetype": "text/x-python",
757 | "name": "python",
758 | "nbconvert_exporter": "python",
759 | "pygments_lexer": "ipython3",
760 | "version": "3.7.7"
761 | }
762 | },
763 | "nbformat": 4,
764 | "nbformat_minor": 5
765 | }
766 |
--------------------------------------------------------------------------------
/model/02-Seq2Seq/.ipynb_checkpoints/Seq2Seq-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "decent-client",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import io\n",
11 | "import re\n",
12 | "import random\n",
13 | "import tensorflow as tf\n",
14 | "\n",
15 | "\n",
16 | "import matplotlib.pyplot as plt\n",
17 | "import matplotlib.ticker as ticker\n",
18 | "from sklearn.model_selection import train_test_split\n",
19 | "\n",
20 | "import unicodedata\n",
21 | "import numpy as np\n",
22 | "import os\n",
23 | "import time"
24 | ]
25 | },
26 | {
27 | "cell_type": "code",
28 | "execution_count": 2,
29 | "id": "polar-recognition",
30 | "metadata": {},
31 | "outputs": [],
32 | "source": [
33 | "#GPU设置\n",
34 | "gpus = tf.config.experimental.list_physical_devices(device_type='GPU')\n",
35 | "tf.config.experimental.set_visible_devices(devices=gpus[:], device_type='GPU')\n",
36 | "for gpu in gpus:\n",
37 | " tf.config.experimental.set_memory_growth(gpu, True)\n",
38 | "os.environ['CUDA_VISIBLE_DEVICES']=\"0\""
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": 5,
44 | "id": "parallel-adjustment",
45 | "metadata": {},
46 | "outputs": [
47 | {
48 | "data": {
49 | "text/plain": [
50 | "'/device:GPU:0'"
51 | ]
52 | },
53 | "execution_count": 5,
54 | "metadata": {},
55 | "output_type": "execute_result"
56 | }
57 | ],
58 | "source": [
59 | "tf.test.gpu_device_name()"
60 | ]
61 | },
62 | {
63 | "cell_type": "markdown",
64 | "id": "scenic-decrease",
65 | "metadata": {},
66 | "source": [
67 | "## 构建数据集"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": 6,
73 | "id": "starting-toilet",
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "def preprocess_sentence(line): \n",
78 | " line = re.split('([: ,.(){}\\[\\]=])',line) \n",
79 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n",
80 | " \n",
81 | "    new_line = '<start> ' + ' '.join(line) + ' <end>'\n",
82 | " return new_line"
83 | ]
84 | },
85 | {
86 | "cell_type": "code",
87 | "execution_count": 7,
88 | "id": "variable-going",
89 | "metadata": {},
90 | "outputs": [],
91 | "source": [
92 | "def create_dataset(path,num_examples,rand_max=15,duplicate=3):\n",
93 | " input_data = []\n",
94 | " output_data = []\n",
95 | " \n",
96 | " lines = io.open(path,encoding='utf-8').read().strip().split('\\n')\n",
97 | " if num_examples == -1:\n",
98 | " num_examples = len(lines)\n",
99 | " \n",
100 | " for i in range(1,num_examples):\n",
101 | " \n",
102 | " rand_nums = set(random.randint(1,rand_max) for _ in range(duplicate))\n",
103 | " for rand_num in rand_nums:\n",
104 | " data = ''\n",
105 | " for j in range(i - rand_num,i):\n",
106 | " line = preprocess_sentence(lines[j].strip()) + ' '\n",
107 | " data += line\n",
108 | " input_data.append(data.strip())\n",
109 | " output_data.append(preprocess_sentence(lines[i].strip()))\n",
110 | " \n",
111 | " return input_data,output_data"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 8,
117 | "id": "polished-order",
118 | "metadata": {},
119 | "outputs": [],
120 | "source": [
121 | "def max_length(tensor):\n",
122 | " return max(len(t) for t in tensor)\n",
123 | "\n",
124 | "def get_num_words(lang,scale):\n",
125 | " #获取词典大小\n",
126 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n",
127 | " lang_tokenizer.fit_on_texts(lang)\n",
128 | " num_words = int(len(lang_tokenizer.word_index) * scale)\n",
129 | " return num_words\n",
130 | " \n",
131 | "def tokenize(lang,scale):\n",
132 | " num_words = get_num_words(lang,scale)\n",
133 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=num_words,\n",
134 | " oov_token='',filters='')\n",
135 | " lang_tokenizer.fit_on_texts(lang)\n",
136 | " tensor = lang_tokenizer.texts_to_sequences(lang)\n",
137 | " tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,padding='post')\n",
138 | " \n",
139 | " return tensor, lang_tokenizer\n",
140 | "\n",
141 | "def load_dataset(path,num_examples=None,scale=0.9):\n",
142 | " inp_lang,targ_lang = create_dataset(path, num_examples)\n",
143 | "\n",
144 | " input_tensor, inp_lang_tokenizer = tokenize(inp_lang,scale)\n",
145 | " target_tensor, targ_lang_tokenizer = tokenize(targ_lang,scale)\n",
146 | "\n",
147 | " return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer"
148 | ]
149 | },
150 | {
151 | "cell_type": "markdown",
152 | "id": "higher-punishment",
153 | "metadata": {},
154 | "source": [
155 | "### 限制数据集的大小以加快实验速度(可选)"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": 9,
161 | "id": "narrow-palmer",
162 | "metadata": {},
163 | "outputs": [],
164 | "source": [
165 | "# 尝试实验不同大小的数据集\n",
166 | "num_examples = -1\n",
167 | "scale = 0.95\n",
168 | "path = \"../00-data/tf_data.txt\"\n",
169 | "input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path, num_examples,scale)\n",
170 | "\n",
171 | "# 计算目标张量的最大长度 (max_length)\n",
172 | "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)"
173 | ]
174 | },
175 | {
176 | "cell_type": "code",
177 | "execution_count": 10,
178 | "id": "convenient-string",
179 | "metadata": {},
180 | "outputs": [
181 | {
182 | "data": {
183 | "text/plain": [
184 | "109453"
185 | ]
186 | },
187 | "execution_count": 10,
188 | "metadata": {},
189 | "output_type": "execute_result"
190 | }
191 | ],
192 | "source": [
193 | "len(input_tensor)"
194 | ]
195 | },
196 | {
197 | "cell_type": "code",
198 | "execution_count": 11,
199 | "id": "significant-oxygen",
200 | "metadata": {},
201 | "outputs": [
202 | {
203 | "data": {
204 | "text/plain": [
205 | "109453"
206 | ]
207 | },
208 | "execution_count": 11,
209 | "metadata": {},
210 | "output_type": "execute_result"
211 | }
212 | ],
213 | "source": [
214 | "len(target_tensor)"
215 | ]
216 | },
217 | {
218 | "cell_type": "code",
219 | "execution_count": 31,
220 | "id": "convinced-grounds",
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "name": "stdout",
225 | "output_type": "stream",
226 | "text": [
227 | "2501 2501 278 278\n"
228 | ]
229 | }
230 | ],
231 | "source": [
232 | "\"\"\"目前先不采用\"\"\"\n",
233 | "# 采用 90 - 10 的比例切分训练集和验证集\n",
234 | "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.1)\n",
235 | "\n",
236 | "# 显示长度\n",
237 | "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))"
238 | ]
239 | },
240 | {
241 | "cell_type": "markdown",
242 | "id": "incorrect-software",
243 | "metadata": {},
244 | "source": [
245 | "### 创建一个 tf.data 数据集"
246 | ]
247 | },
248 | {
249 | "cell_type": "code",
250 | "execution_count": 12,
251 | "id": "hungry-lafayette",
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "BUFFER_SIZE = len(input_tensor)\n",
256 | "BATCH_SIZE = 64\n",
257 | "steps_per_epoch = len(input_tensor)//BATCH_SIZE\n",
258 | "embedding_dim = 256\n",
259 | "units = 256\n",
260 | "vocab_inp_size = len(inp_lang.word_index)+1\n",
261 | "vocab_tar_size = len(targ_lang.word_index)+1\n",
262 | "\n",
263 | "dataset = tf.data.Dataset.from_tensor_slices((input_tensor, target_tensor)).shuffle(BUFFER_SIZE)\n",
264 | "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)"
265 | ]
266 | },
267 | {
268 | "cell_type": "code",
269 | "execution_count": 13,
270 | "id": "requested-township",
271 | "metadata": {},
272 | "outputs": [
273 | {
274 | "data": {
275 | "text/plain": [
276 | "(TensorShape([64, 597]), TensorShape([64, 245]))"
277 | ]
278 | },
279 | "execution_count": 13,
280 | "metadata": {},
281 | "output_type": "execute_result"
282 | }
283 | ],
284 | "source": [
285 | "example_input_batch, example_target_batch = next(iter(dataset))\n",
286 | "example_input_batch.shape, example_target_batch.shape"
287 | ]
288 | },
289 | {
290 | "cell_type": "markdown",
291 | "id": "trying-channel",
292 | "metadata": {},
293 | "source": [
294 | "## 编写编码器encoder和解码器decoder模型"
295 | ]
296 | },
297 | {
298 | "cell_type": "code",
299 | "execution_count": 14,
300 | "id": "consecutive-literacy",
301 | "metadata": {},
302 | "outputs": [],
303 | "source": [
304 | "class Encoder(tf.keras.Model):\n",
305 | " def __init__(self, vocab_size, embedding_dim, enc_units, batch_size):\n",
306 | " super(Encoder, self).__init__()\n",
307 | " self.batch_size = batch_size\n",
308 | " self.enc_units = enc_units\n",
309 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,batch_input_shape=[batch_size,None])\n",
310 | " self.lstm = tf.keras.layers.LSTM(self.enc_units,\n",
311 | " return_sequences=True,\n",
312 | " return_state=True,\n",
313 | " dropout=0.1,\n",
314 | " recurrent_dropout=0.1)\n",
315 | "\n",
316 | " def call(self, x):\n",
317 | " x = self.embedding(x)\n",
318 | " output_l1,state_l1,_ = self.lstm(x)\n",
319 | " #output_l2,state_l2,_ = self.lstm(output_l1)\n",
320 | " return output_l1,state_l1\n",
321 | "\n",
322 | " def initialize_hidden_state(self):\n",
323 | " return tf.zeros((self.batch_size, self.enc_units))"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 15,
329 | "id": "offensive-survivor",
330 | "metadata": {
331 | "scrolled": true
332 | },
333 | "outputs": [
334 | {
335 | "name": "stdout",
336 | "output_type": "stream",
337 | "text": [
338 | "Encoder output shape: (batch size, sequence length, units) (64, 597, 256)\n",
339 | "Encoder Hidden state shape: (batch size, units) (64, 256)\n"
340 | ]
341 | }
342 | ],
343 | "source": [
344 | "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n",
345 | "\n",
346 | "# 样本输入\n",
347 | "sample_hidden = encoder.initialize_hidden_state()\n",
348 | "sample_output,sample_hidden = encoder(example_input_batch)\n",
349 | "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n",
350 | "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": 16,
356 | "id": "therapeutic-favor",
357 | "metadata": {},
358 | "outputs": [],
359 | "source": [
360 | "class BahdanauAttention(tf.keras.layers.Layer):\n",
361 | " def __init__(self, units):\n",
362 | " super(BahdanauAttention, self).__init__()\n",
363 | " self.W1 = tf.keras.layers.Dense(units)\n",
364 | " self.W2 = tf.keras.layers.Dense(units)\n",
365 | " self.V = tf.keras.layers.Dense(1)\n",
366 | "\n",
367 | " def call(self, query, values):\n",
368 | " # 隐藏层的形状 == (批大小,隐藏层大小)\n",
369 | " # hidden_with_time_axis 的形状 == (批大小,1,隐藏层大小)\n",
370 | " # 这样做是为了执行加法以计算分数 \n",
371 | " hidden_with_time_axis = tf.expand_dims(query, 1)\n",
372 | "\n",
373 | " # 分数的形状 == (批大小,最大长度,1)\n",
374 | " # 我们在最后一个轴上得到 1, 因为我们把分数应用于 self.V\n",
375 | " # 在应用 self.V 之前,张量的形状是(批大小,最大长度,单位)\n",
376 | " score = self.V(tf.nn.tanh(\n",
377 | " self.W1(values) + self.W2(hidden_with_time_axis)))\n",
378 | "\n",
379 | " # 注意力权重 (attention_weights) 的形状 == (批大小,最大长度,1)\n",
380 | " attention_weights = tf.nn.softmax(score, axis=1)\n",
381 | "\n",
382 | " # 上下文向量 (context_vector) 求和之后的形状 == (批大小,隐藏层大小)\n",
383 | " context_vector = attention_weights * values\n",
384 | " context_vector = tf.reduce_sum(context_vector, axis=1)\n",
385 | "\n",
386 | " return context_vector, attention_weights"
387 | ]
388 | },
389 | {
390 | "cell_type": "code",
391 | "execution_count": 17,
392 | "id": "under-swaziland",
393 | "metadata": {},
394 | "outputs": [
395 | {
396 | "name": "stdout",
397 | "output_type": "stream",
398 | "text": [
399 | "Attention result shape: (batch size, units) (64, 256)\n",
400 | "Attention weights shape: (batch_size, sequence_length, 1) (64, 597, 1)\n"
401 | ]
402 | }
403 | ],
404 | "source": [
405 | "attention_layer = BahdanauAttention(10)\n",
406 | "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n",
407 | "\n",
408 | "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n",
409 | "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))"
410 | ]
411 | },
412 | {
413 | "cell_type": "code",
414 | "execution_count": 18,
415 | "id": "smoking-horizon",
416 | "metadata": {},
417 | "outputs": [],
418 | "source": [
419 | "class Decoder(tf.keras.Model):\n",
420 | " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n",
421 | " super(Decoder, self).__init__()\n",
422 | " self.batch_sz = batch_sz\n",
423 | " self.dec_units = dec_units\n",
424 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n",
425 | " self.train_LSTM = tf.keras.layers.LSTM(self.dec_units,\n",
426 | " return_sequences=True,\n",
427 | " return_state=True,\n",
428 | " dropout=0.1,\n",
429 | " recurrent_dropout=0.1)\n",
430 | " self.infer_LSTM = tf.keras.layers.LSTM(self.dec_units,\n",
431 | " return_sequences=True,\n",
432 | " return_state=True)\n",
433 | " \n",
434 | " self.fc = tf.keras.layers.Dense(vocab_size)\n",
435 | "\n",
436 | " # 用于注意力\n",
437 | " self.attention = BahdanauAttention(self.dec_units)\n",
438 | "\n",
439 | " def call(self, x, hidden, enc_output,is_train=True):\n",
440 | " # 编码器输出 (enc_output) 的形状 == (批大小,最大长度,隐藏层大小)\n",
441 | " context_vector, attention_weights = self.attention(hidden, enc_output)\n",
442 | "\n",
443 | " # x 在通过嵌入层后的形状 == (批大小,1,嵌入维度)\n",
444 | " x = self.embedding(x)\n",
445 | "\n",
446 | " # x 在拼接 (concatenation) 后的形状 == (批大小,1,嵌入维度 + 隐藏层大小)\n",
447 | " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n",
448 | "\n",
449 | " # 将合并后的向量传送到 LSTM\n",
450 | " if is_train:\n",
451 | " output, state,_ = self.train_LSTM(x)\n",
452 | " else:\n",
453 | " output, state,_ = self.infer_LSTM(x)\n",
454 | "\n",
455 | " # 输出的形状 == (批大小 * 1,隐藏层大小)\n",
456 | " output = tf.reshape(output, (-1, output.shape[2]))\n",
457 | "\n",
458 | " # 输出的形状 == (批大小,vocab)\n",
459 | " x = self.fc(output)\n",
460 | "\n",
461 | " return x, state, attention_weights"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": 19,
467 | "id": "straight-sending",
468 | "metadata": {},
469 | "outputs": [
470 | {
471 | "name": "stdout",
472 | "output_type": "stream",
473 | "text": [
474 | "Decoder output shape: (batch_size, vocab size) (64, 11568)\n"
475 | ]
476 | }
477 | ],
478 | "source": [
479 | "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n",
480 | "\n",
481 | "sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),\n",
482 | " sample_hidden, sample_output,is_train=True)\n",
483 | "\n",
484 | "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))"
485 | ]
486 | },
487 | {
488 | "cell_type": "code",
489 | "execution_count": 20,
490 | "id": "compliant-directive",
491 | "metadata": {},
492 | "outputs": [],
493 | "source": [
494 | "optimizer = tf.keras.optimizers.Adam()\n",
495 | "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n",
496 | " from_logits=True, reduction='none')\n",
497 | "\n",
498 | "def loss_function(real, pred):\n",
499 | " mask = tf.math.logical_not(tf.math.equal(real, 0))\n",
500 | " loss_ = loss_object(real, pred)\n",
501 | "\n",
502 | " mask = tf.cast(mask, dtype=loss_.dtype)\n",
503 | " loss_ *= mask\n",
504 | "\n",
505 | " return tf.reduce_mean(loss_)"
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": 21,
511 | "id": "adult-grocery",
512 | "metadata": {},
513 | "outputs": [],
514 | "source": [
515 | "checkpoint_dir = '../02-checkpoints/'\n",
516 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n",
517 | "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n",
518 | " encoder=encoder,\n",
519 | " decoder=decoder)"
520 | ]
521 | },
522 | {
523 | "cell_type": "code",
524 | "execution_count": 22,
525 | "id": "charming-apparatus",
526 | "metadata": {},
527 | "outputs": [],
528 | "source": [
529 | "@tf.function\n",
530 | "def train_step(inp,targ,enc_hidden):\n",
531 | " loss = 0\n",
532 | " \n",
533 | " with tf.GradientTape() as tape:\n",
534 | " enc_output,enc_hidden = encoder(inp)\n",
535 | " dec_hidden = enc_hidden\n",
536 | " \n",
537 | " dec_hidden = enc_hidden\n",
538 | " \n",
539 | "        dec_input = tf.expand_dims([targ_lang.word_index['<start>']] * BATCH_SIZE, 1)\n",
540 | " \n",
541 | " for t in range(1,targ.shape[1]):\n",
542 | " predictions,dec_hidden,_ = decoder(dec_input,dec_hidden,enc_output,is_train=True)\n",
543 | " loss += loss_function(targ[:,t],predictions)\n",
544 | " \n",
545 | " dec_input = tf.expand_dims(targ[:,t],1)\n",
546 | " \n",
547 | " batch_loss = (loss / int(targ.shape[1]))\n",
548 | " variables = encoder.trainable_variables + decoder.trainable_variables\n",
549 | " gradients = tape.gradient(loss,variables)\n",
550 | " optimizer.apply_gradients(zip(gradients,variables))\n",
551 | " \n",
552 | " return batch_loss"
553 | ]
554 | },
555 | {
556 | "cell_type": "code",
557 | "execution_count": null,
558 | "id": "based-wound",
559 | "metadata": {},
560 | "outputs": [],
561 | "source": [
562 | "EPOCHS = 30\n",
563 | "\n",
564 | "for epoch in range(EPOCHS):\n",
565 | " start = time.time()\n",
566 | "\n",
567 | " enc_hidden = encoder.initialize_hidden_state()\n",
568 | " total_loss = 0\n",
569 | "\n",
570 | " for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n",
571 | " batch_loss = train_step(inp, targ, enc_hidden)\n",
572 | " total_loss += batch_loss\n",
573 | "\n",
574 | " #if batch % 5 == 0:\n",
575 | " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n",
576 | " batch,\n",
577 | " batch_loss.numpy()))\n",
578 | " # 每 5 个周期(epoch),保存(检查点)一次模型\n",
579 | " if (epoch + 1) % 5 == 0:\n",
580 | " checkpoint.save(file_prefix=checkpoint_prefix) \n",
581 | "\n",
582 | " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n",
583 | " total_loss / steps_per_epoch))\n",
584 | " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))"
585 | ]
586 | },
587 | {
588 | "cell_type": "markdown",
589 | "id": "terminal-suite",
590 | "metadata": {},
591 | "source": [
592 | "## 预测"
593 | ]
594 | },
595 | {
596 | "cell_type": "code",
597 | "execution_count": 13,
598 | "id": "satellite-seating",
599 | "metadata": {},
600 | "outputs": [],
601 | "source": [
602 | "def create_inference_data(lines):\n",
603 | " data = \"\"\n",
604 | " for line in lines:\n",
605 | " line = preprocess_sentence(line.strip()) + ' '\n",
606 | " data += line\n",
607 | " return data.strip()"
608 | ]
609 | },
610 | {
611 | "cell_type": "code",
612 | "execution_count": 14,
613 | "id": "cellular-impossible",
614 | "metadata": {},
615 | "outputs": [],
616 | "source": [
617 | "def evaluate(sentence):\n",
618 | " attention_plot = np.zeros((max_length_targ, max_length_inp))\n",
619 | "\n",
620 | " sentence = create_inference_data(sentence)\n",
621 | "\n",
622 | " inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]\n",
623 | " inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],\n",
624 | " maxlen=max_length_inp,\n",
625 | " padding='post')\n",
626 | " inputs = tf.convert_to_tensor(inputs)\n",
627 | "\n",
628 | " result = ''\n",
629 | "\n",
630 | " hidden = [tf.zeros((1, units))]\n",
631 | " enc_out, enc_hidden = encoder(inputs)\n",
632 | "\n",
633 | " dec_hidden = enc_hidden\n",
634 | "    dec_input = tf.expand_dims([targ_lang.word_index['<start>']], 0)\n",
635 | "\n",
636 | " for t in range(max_length_targ):\n",
637 | " predictions, dec_hidden, attention_weights = decoder(dec_input,\n",
638 | " dec_hidden,\n",
639 | " enc_out,\n",
640 | " is_train=False)\n",
641 | "\n",
642 | " # 存储注意力权重以便后面制图\n",
643 | " attention_weights = tf.reshape(attention_weights, (-1, ))\n",
644 | " attention_plot[t] = attention_weights.numpy()\n",
645 | "\n",
646 | " predicted_id = tf.argmax(predictions[0]).numpy()\n",
647 | "\n",
648 | " result += targ_lang.index_word[predicted_id] + ' '\n",
649 | "\n",
650 | "        if targ_lang.index_word[predicted_id] == '<end>':\n",
651 | " return result, sentence, attention_plot\n",
652 | "\n",
653 | " # 预测的 ID 被输送回模型\n",
654 | " dec_input = tf.expand_dims([predicted_id], 0)\n",
655 | "\n",
656 | " return result, sentence, attention_plot"
657 | ]
658 | },
659 | {
660 | "cell_type": "code",
661 | "execution_count": 21,
662 | "id": "current-feedback",
663 | "metadata": {},
664 | "outputs": [],
665 | "source": [
666 | "# 注意力权重制图函数\n",
667 | "def plot_attention(attention, sentence, predicted_sentence):\n",
668 | " fig = plt.figure(figsize=(20,20))\n",
669 | " ax = fig.add_subplot(1, 1, 1)\n",
670 | " ax.matshow(attention, cmap='viridis')\n",
671 | "\n",
672 | " fontdict = {'fontsize': 10}\n",
673 | "\n",
674 | " ax.set_xticklabels([''] + sentence, fontdict=fontdict)\n",
675 | " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict, rotation=90)\n",
676 | "\n",
677 | " ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n",
678 | " ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n",
679 | "\n",
680 | " plt.show()"
681 | ]
682 | },
683 | {
684 | "cell_type": "code",
685 | "execution_count": 22,
686 | "id": "worse-logging",
687 | "metadata": {},
688 | "outputs": [],
689 | "source": [
690 | "def translate(sentence):\n",
691 | " result,sentence,attention_plot = evaluate(sentence)\n",
692 | " \n",
693 | " print('Input: %s' % (sentence))\n",
694 | " print('Predicted translation: {}'.format(result))\n",
695 | " \n",
696 | " attention_plot = attention_plot[:len(result.split(' ')),:len(sentence.split(' '))]\n",
697 | " plot_attention(attention_plot,sentence.split(' '),result.split(' ')) \n",
698 | " print(result.split(' '))"
699 | ]
700 | },
701 | {
702 | "cell_type": "code",
703 | "execution_count": 29,
704 | "id": "integrated-victoria",
705 | "metadata": {},
706 | "outputs": [],
707 | "source": [
708 | "lines = io.open(path,encoding='utf-8').read().strip().split('\\n')[12:13]\n",
709 | "lines = create_inference_data(lines)"
710 | ]
711 | },
712 | {
713 | "cell_type": "code",
714 | "execution_count": 30,
715 | "id": "numerous-advisory",
716 | "metadata": {
717 | "scrolled": false
718 | },
719 | "outputs": [
720 | {
721 | "data": {
722 | "text/plain": [
723 | "' y = tf . placeholder ( tf . float32 , shape = [ ] , name = ) '"
724 | ]
725 | },
726 | "execution_count": 30,
727 | "metadata": {},
728 | "output_type": "execute_result"
729 | }
730 | ],
731 | "source": [
732 | "lines"
733 | ]
734 | },
735 | {
736 | "cell_type": "code",
737 | "execution_count": null,
738 | "id": "structured-planet",
739 | "metadata": {},
740 | "outputs": [],
741 | "source": []
742 | }
743 | ],
744 | "metadata": {
745 | "kernelspec": {
746 | "display_name": "myconda",
747 | "language": "python",
748 | "name": "myconda"
749 | },
750 | "language_info": {
751 | "codemirror_mode": {
752 | "name": "ipython",
753 | "version": 3
754 | },
755 | "file_extension": ".py",
756 | "mimetype": "text/x-python",
757 | "name": "python",
758 | "nbconvert_exporter": "python",
759 | "pygments_lexer": "ipython3",
760 | "version": "3.7.7"
761 | }
762 | },
763 | "nbformat": 4,
764 | "nbformat_minor": 5
765 | }
766 |
--------------------------------------------------------------------------------
/model/01-Text_Gen/Text_Gen_tf10.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "\"\"\"数据导入\"\"\"\n",
10 | "\n",
11 | "import re\n",
12 | "data_file = \"../00-data/tf_data.txt\"\n",
13 | "filename =open(data_file,'r',encoding='utf-8') #打开数据文件\n",
14 | "\n",
15 | "text = filename.read() #将数据读取到字符串text中\n",
16 | "text = ' '.join(re.split(' |\\t|\\v',text)) #将数据中的空格符统一,便于后期处理(原始数据中空格符包含\\t、\\v等) \n",
17 | "text = re.split('([: ,.*\\n(){}\\[\\]=])',text) #将字符串数据按照括号中的符号进行分割,分割成列表格式,并且在列表中保留分隔符\n",
18 | "\n",
19 | "text = list(filter(lambda x: x!=' 'and x!='',text)) #将列表中的空格和非空格筛选掉\n",
20 | "list_text = text #保留一份列表格式的数据\n",
21 | "text = ' '.join(text) #将列表转换成字符串"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "\"\"\"文本词频统计\"\"\"\n",
31 | "\n",
32 | "def word_count(list_text): #定义计算文本词频的函数,传入list_text列表\n",
33 | " import collections\n",
34 | " word_freq = collections.defaultdict(int) #定义一个int型的词频词典,并提供默认值\n",
35 | " for w in list_text: #遍历列表中的元素,元素出现一次,频次加一\n",
36 | " word_freq[w] += 1\n",
37 | " return word_freq #返回词频词典\n",
38 | " \n",
39 | " #return word_freq.items() 该语句返回值的类型为list(这句话有语法问题,不必考虑)"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "\"\"\"根据text文本创建代码词词典\"\"\"\n",
49 | "\n",
50 | "def build_dict(text, min_word_freq=50):\n",
51 | " word_freq = word_count(text) #文本词频统计,返回一个词频词典\n",
52 | " word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items()) # filter将词频数量低于指定值的代码词删除。\n",
53 | " word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) # key用于指定排序的元素,因为sorted默认使用list中每个item的第一个元素从小到大排列,所以这里通过lambda进行前后元素调序,并对词频去相反数,从而将词频最大的排列在最前面\n",
54 | " words, _ = list(zip(*word_freq_sorted)) #获取每一个代码词\n",
55 | " words = list(words)\n",
56 | " words.append('')\n",
57 | " word_idx = dict(zip(words, range(len(words)))) #构建词典(不包含词频)\n",
58 | " return words,word_idx #这里只返回了words,倒数两行代码还用不上。返回的是一个不含重复的代码词词典,不包含词频。"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 4,
64 | "metadata": {},
65 | "outputs": [
66 | {
67 | "name": "stderr",
68 | "output_type": "stream",
69 | "text": [
70 | "Using TensorFlow backend.\n"
71 | ]
72 | },
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "Number of sequences: 79181\n",
78 | "Unique words: 7865\n",
79 | "Vectorization...\n"
80 | ]
81 | }
82 | ],
83 | "source": [
84 | "\"\"\"数据预处理-字符串序列向量化\"\"\"\n",
85 | "\n",
86 | "import numpy as np\n",
87 | "import keras\n",
88 | "import json\n",
89 | "\n",
90 | "maxlen = 50 #提取50个代码词组成的序列\n",
91 | "step = 5 #每5个代码词采样一个新序列\n",
92 | "sentences = [] #保存所提取的序列\n",
93 | "next_words = [] #保存目标代码词\n",
94 | "vocab_file = \"../00-data/vocab\"\n",
95 | "\n",
96 | "cut_words = list_text #将列表形式的元数据保存在cut_words中\n",
97 | "for i in range(0,len(cut_words) - maxlen,step):\n",
98 | " sentences.append(cut_words[i:i + maxlen]) #将元数据按照步长来存储在每个序列中 \n",
99 | " next_words.append(cut_words[i + maxlen]) #将目标代码词存储在next_words中\n",
100 | " \n",
101 | " \n",
102 | "print('Number of sequences:', len(sentences))\n",
103 | "\n",
104 | "\n",
105 | "words,word_idx = list(build_dict(list_text,1)) #创建代码词词典,返回的是一个不含重复的代码词词典,不包含词频。\n",
106 | "print('Unique words:',len(words))\n",
107 | "json_dict = json.dumps(word_idx)\n",
108 | "with open(vocab_file,\"w\") as f:\n",
109 | " f.write(json_dict)\n",
110 | "\n",
111 | "word_indices = dict((word,words.index(word)) for word in words) #创建一个包含代码词唯一索引的代码词词典,返回的是一个字典\n",
112 | "#print(word_indices)\n",
113 | "\n",
114 | "print('Vectorization...')\n",
115 | "x = np.zeros((len(sentences),maxlen)) #初始化x\n",
116 | "y = np.zeros((len(sentences))) #初始化y\n",
117 | "for i,sentence in enumerate(sentences):\n",
118 | " for t,word in enumerate(sentence):\n",
119 | " x[i,t] = word_indices.get(word,word_indices['']) #将代码词转换成向量形式的编码\n",
120 | " #y[i] = word_indices[next_words[i]]\n",
121 | " y[i] = word_indices.get(next_words[i],word_indices[''])\n",
122 | "\n",
123 | "y = keras.utils.to_categorical(y, len(words)) #将int型数组y转换成one-hot编码"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 6,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "\"\"\"定义下一个代码词的采样函数---temperature越大,代码生成的随机性越强---\"\"\"\n",
133 | "\n",
134 | "def sample(preds,temperature=0.1):\n",
135 | " preds = np.asarray(preds).astype('float')\n",
136 | " preds = np.log(preds) /temperature\n",
137 | " exp_preds = np.exp(preds)\n",
138 | " preds = exp_preds / np.sum(exp_preds)\n",
139 | " probas = np.random.multinomial(1,preds,1)\n",
140 | " return np.argmax(probas)"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 7,
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "\"\"\"将字符串写到指定文件中\"\"\"\n",
150 | "\n",
151 | "def save(filename, contents): \n",
152 | " file = open(filename, 'a', encoding='utf-8')\n",
153 | " file.write(contents)\n",
154 | " file.close()"
155 | ]
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "metadata": {},
160 | "source": [
161 | "### 模型训练"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": 8,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "import keras\n",
171 | "from keras import layers\n",
172 | "from keras.layers import LSTM, Dense, Dropout\n",
173 | "\n",
174 | "def create_model(words,learning_rate): #定义创建模型的函数\n",
175 | " model = keras.models.Sequential() #模型初始化\n",
176 | " model.add(layers.Embedding(len(words),512)) #模型第一层为embedding层\n",
177 | " model.add(layers.LSTM(512,return_sequences=True,dropout=0.2,recurrent_dropout=0.2)) #模型第二层为LSTM层,加入dropout减少过拟合\n",
178 | " model.add(layers.LSTM(512,dropout=0.2,recurrent_dropout=0.2)) #模型第三层为LSTM层,加入dropout减少过拟合\n",
179 | " model.add(layers.Dense(len(words),activation='softmax')) #模型第三层为全连接层\n",
180 | "\n",
181 | " optimizer = keras.optimizers.RMSprop(lr=learning_rate) #定义优化器\n",
182 | " model.compile(loss='categorical_crossentropy',optimizer=optimizer) #模型编译\n",
183 | " \n",
184 | " return model"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 9,
190 | "metadata": {
191 | "scrolled": true
192 | },
193 | "outputs": [
194 | {
195 | "name": "stdout",
196 | "output_type": "stream",
197 | "text": [
198 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
199 | "Instructions for updating:\n",
200 | "Colocations handled automatically by placer.\n",
201 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
202 | "Instructions for updating:\n",
203 | "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
204 | "_________________________________________________________________\n",
205 | "Layer (type) Output Shape Param # \n",
206 | "=================================================================\n",
207 | "embedding_1 (Embedding) (None, None, 512) 4026880 \n",
208 | "_________________________________________________________________\n",
209 | "lstm_1 (LSTM) (None, None, 512) 2099200 \n",
210 | "_________________________________________________________________\n",
211 | "lstm_2 (LSTM) (None, 512) 2099200 \n",
212 | "_________________________________________________________________\n",
213 | "dense_1 (Dense) (None, 7865) 4034745 \n",
214 | "=================================================================\n",
215 | "Total params: 12,260,025\n",
216 | "Trainable params: 12,260,025\n",
217 | "Non-trainable params: 0\n",
218 | "_________________________________________________________________\n"
219 | ]
220 | }
221 | ],
222 | "source": [
223 | "\"\"\"创建模型实例\"\"\"\n",
224 | "learning_rate = 0.003\n",
225 | "model = create_model(words,learning_rate) #创建模型\n",
226 | "model.summary() #打印模型结构"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": 10,
232 | "metadata": {},
233 | "outputs": [],
234 | "source": [
235 | "def clean_and_split(line):\n",
236 | " #将数据中的空格符统一,便于后期处理(原始数据中空格符包含\\t、\\v等) \n",
237 | " line = ' '.join(re.split(' |\\t|\\v',line))\n",
238 | " #将字符串数据按照括号中的符号进行分割,分割成列表格式,并且在列表中保留分隔符\n",
239 | " line = re.split('([: ,.*\\n(){}\\[\\]=])',line) \n",
240 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n",
241 | " return line"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 11,
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "\"\"\"打印生成的结果\"\"\"\n",
251 | "def print_code_text(code_list):\n",
252 | " mark = \".,*()[]:{}\\n\"\n",
253 | " \n",
254 | " result = \"\"\n",
255 | " last_word = \"\"\n",
256 | " \n",
257 | " for word in code_list:\n",
258 | " if last_word not in mark and word not in mark:\n",
259 | " result += \" \" + word\n",
260 | " else:\n",
261 | " result += word\n",
262 | " \n",
263 | " last_word = word\n",
264 | " \n",
265 | " print(result)\n",
266 | " \n",
267 | " return result"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 12,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "def text_generate(model,input_file,maxlen,\n",
277 | " temperatures,save_path,epoch,gen_lines=30):\n",
278 | " import random\n",
279 | " import codecs\n",
280 | " \n",
281 | " #从原始数据中随机找一行作为文本生成的起点\n",
282 | " with codecs.open(input_file,\"r\",\"utf-8\") as fin:\n",
283 | " lines = fin.readlines()\n",
284 | " random_line = random.randint(0,len(lines))\n",
285 | " start_line = clean_and_split(lines[random_line])\n",
286 | " \n",
287 | " one_line_max_words = 30\n",
288 | " \n",
289 | " print(\"===========epoch:%d===========\" % epoch)\n",
290 | " \n",
291 | " for temperature in temperatures:\n",
292 | " generated_text = start_line[:]\n",
293 | " print_string = start_line[:]\n",
294 | "\n",
295 | " for i in range(gen_lines):\n",
296 | " for j in range(one_line_max_words):\n",
297 | " sampled = np.zeros((1,len(generated_text))) \n",
298 | " #向量化\n",
299 | " for t,word in enumerate(generated_text): \n",
300 | " sampled[0,t] = word_indices.get(word,word_indices[''])\n",
301 | " #预测下一个词\n",
302 | " preds = model.predict(sampled,verbose=0)[0]\n",
303 | " next_index = sample(preds,temperature)\n",
304 | " next_word = words[next_index]\n",
305 | "\n",
306 | " if len(generated_text) == maxlen:\n",
307 | " generated_text = generated_text[1:]\n",
308 | " generated_text.append(next_word)\n",
309 | " print_string.append(next_word)\n",
310 | " if next_word == '\\n':\n",
311 | " break\n",
312 | "\n",
313 | " print(\"-----temperature: {}-----\".format(temperature))\n",
314 | " result = print_code_text(print_string)\n",
315 | "\n",
316 | " save_file = save_path + \"/{}_epoch_{}_temperature\".format(epoch,temperature)\n",
317 | " with codecs.open(save_file,\"w\",\"utf-8\") as fout:\n",
318 | " fout.write(result)"
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": 13,
324 | "metadata": {},
325 | "outputs": [],
326 | "source": [
327 | "\"\"\"模型保存\"\"\"\n",
328 | "from keras.callbacks import ModelCheckpoint\n",
329 | "filepath = \"../02-checkpoints/\"\n",
330 | "checkpoint = ModelCheckpoint(filepath, save_weights_only=False,verbose=1,save_best_only=False) #回调函数,实现断点续训功能"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": 14,
336 | "metadata": {},
337 | "outputs": [],
338 | "source": [
339 | "\"\"\"学习率随模型效果变小\"\"\"\n",
340 | "import keras.backend as K\n",
341 | "from keras.callbacks import LearningRateScheduler\n",
342 | " \n",
343 | "def scheduler(epoch):\n",
344 | " # 每隔10个epoch,学习率减小为原来的5/10\n",
345 | " if epoch % 10 == 0 and epoch != 0:\n",
346 | " lr = K.get_value(model.optimizer.lr)\n",
347 | " K.set_value(model.optimizer.lr, lr * 0.5)\n",
348 | " print(\"lr changed to {}\".format(lr * 0.5))\n",
349 | " return K.get_value(model.optimizer.lr)\n",
350 | " \n",
351 | "reduce_lr = LearningRateScheduler(scheduler)"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": 15,
357 | "metadata": {},
358 | "outputs": [
359 | {
360 | "name": "stdout",
361 | "output_type": "stream",
362 | "text": [
363 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
364 | "Instructions for updating:\n",
365 | "Use tf.cast instead.\n",
366 | "Epoch 1/5\n",
367 | " 4096/79181 [>.............................] - ETA: 2:37 - loss: 0.6361"
368 | ]
369 | },
370 | {
371 | "ename": "KeyboardInterrupt",
372 | "evalue": "",
373 | "output_type": "error",
374 | "traceback": [
375 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
376 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
377 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0minit_epoch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1024\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_epoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minit_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mreduce_lr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcheckpoint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#开始训练模型\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0mtext_generate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdata_file\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmaxlen\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0.4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0.8\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mprint_save_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m 
\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_save_path\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"epoch_{}.hdf5\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
378 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)\u001b[0m\n\u001b[1;32m 1037\u001b[0m \u001b[0minitial_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1038\u001b[0m \u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1039\u001b[0;31m validation_steps=validation_steps)\n\u001b[0m\u001b[1;32m 1040\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1041\u001b[0m def evaluate(self, x=None, y=None,\n",
379 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/engine/training_arrays.py\u001b[0m in \u001b[0;36mfit_loop\u001b[0;34m(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0mins_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mins_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_labels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
380 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2713\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_legacy_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2714\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2715\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2716\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2717\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mpy_any\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
381 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2673\u001b[0m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_metadata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2674\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2675\u001b[0;31m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2676\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2677\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
382 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1437\u001b[0m ret = tf_session.TF_SessionRunCallable(\n\u001b[1;32m 1438\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1439\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 1440\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1441\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
383 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
384 | ]
385 | }
386 | ],
387 | "source": [
388 | "\"\"\"训练模型\"\"\"\n",
389 | "import os\n",
390 | "print_save_path = \"../00-data/\"\n",
391 | "model_save_path = \"../02-checkpoints/\"\n",
392 | "total_epochs = 50 \n",
393 | "for epoch in range(1,total_epochs):\n",
394 | " if os.path.exists(filepath): #如果模型存在,则从现有模型开始训练\n",
395 | " model.load_weights(filepath)\n",
396 | " init_epoch = (epoch - 1) * 5\n",
397 | " model.fit(x,y,batch_size=1024,epochs=(init_epoch + 5),initial_epoch=init_epoch,callbacks=[reduce_lr,checkpoint]) #开始训练模型\n",
398 | " text_generate(model,data_file,maxlen,[0.1,0.4,0.8],print_save_path,epoch * 5)\n",
399 | " model.save(model_save_path + \"epoch_{}.hdf5\".format(epoch * 5))"
400 | ]
401 | },
402 | {
403 | "cell_type": "markdown",
404 | "metadata": {},
405 | "source": [
406 | "### 模型预测"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": 32,
412 | "metadata": {},
413 | "outputs": [],
414 | "source": [
415 | "def generate_text_sentence(seed_text,model_filename): #测试代码和上面训练模型的代码基本一样,就不再介绍\n",
416 | " model.load_weights(model_filename)\n",
417 |         "    mark = \".,*()[]:{}\\n\"\n",
418 | " strings=''\n",
419 | " last_word=''\n",
420 | " seed_text = re.split('([: ,.\\n(){}\\[\\]=])',seed_text)\n",
421 | " seed_text = list(filter(lambda x: x!=' 'and x!='',seed_text))\n",
422 | " \n",
423 | " generated_text = seed_text[:]\n",
424 | " \n",
425 | " for temperature in [0.1,0.4,0.8]:\n",
426 | " strings += '\\n' + '-------------temperature:' + str(temperature) +'-------------\\n' +'\\n'\n",
427 | " \n",
428 | " for i in range(50):\n",
429 | " if i == 0:\n",
430 | " for k in range(len(generated_text)):\n",
431 | " if generated_text[k] not in mark and last_word not in mark:\n",
432 | " strings += ' ' + generated_text[k]\n",
433 | " else:\n",
434 | " strings += generated_text[k]\n",
435 | " last_word = generated_text[k]\n",
436 | "\n",
437 | " sampled = np.zeros((1,len(generated_text)))\n",
438 | " for t,word in enumerate(generated_text):\n",
439 |         "            sampled[0,t] = word_indices.get(word,word_indices[''])\n",
440 | "\n",
441 | " preds = model.predict(sampled,verbose=0)[0]\n",
442 |         "            next_index = sample(preds,temperature)\n",
443 | " next_word = words[next_index]\n",
444 | "\n",
445 | "\n",
446 | " generated_text.append(next_word)\n",
447 | "\n",
448 | " #if len(generated_text) == maxlen:\n",
449 | " # generated_text = generated_text[1:]\n",
450 | "\n",
451 | " if next_word not in mark and last_word not in mark:\n",
452 | " strings += ' ' + next_word\n",
453 | " else:\n",
454 | " strings += next_word\n",
455 | "\n",
456 | " last_word = next_word\n",
457 | "\n",
458 | " if next_word == '\\n':\n",
459 | " break\n",
460 | " \n",
461 | " generated_text = seed_text[:]\n",
462 | " \n",
463 | " return strings\n"
464 | ]
465 | },
466 | {
467 | "cell_type": "code",
468 | "execution_count": null,
469 | "metadata": {},
470 | "outputs": [],
471 | "source": []
472 | }
473 | ],
474 | "metadata": {
475 | "kernelspec": {
476 | "display_name": "myconda",
477 | "language": "python",
478 | "name": "myconda"
479 | },
480 | "language_info": {
481 | "codemirror_mode": {
482 | "name": "ipython",
483 | "version": 3
484 | },
485 | "file_extension": ".py",
486 | "mimetype": "text/x-python",
487 | "name": "python",
488 | "nbconvert_exporter": "python",
489 | "pygments_lexer": "ipython3",
490 | "version": "3.5.6"
491 | }
492 | },
493 | "nbformat": 4,
494 | "nbformat_minor": 4
495 | }
496 |
--------------------------------------------------------------------------------
/model/01-Text_Gen/.ipynb_checkpoints/Text_Gen_tf10-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "\"\"\"数据导入\"\"\"\n",
10 | "\n",
11 | "import re\n",
12 | "data_file = \"../00-data/tf_data.txt\"\n",
13 | "filename =open(data_file,'r',encoding='utf-8') #打开数据文件\n",
14 | "\n",
15 | "text = filename.read() #将数据读取到字符串text中\n",
16 | "text = ' '.join(re.split(' |\\t|\\v',text)) #将数据中的空格符统一,便于后期处理(原始数据中空格符包含\\t、\\v等) \n",
17 | "text = re.split('([: ,.*\\n(){}\\[\\]=])',text) #将字符串数据按照括号中的符号进行分割,分割成列表格式,并且在列表中保留分隔符\n",
18 | "\n",
19 | "text = list(filter(lambda x: x!=' 'and x!='',text)) #将列表中的空格和非空格筛选掉\n",
20 | "list_text = text #保留一份列表格式的数据\n",
21 | "text = ' '.join(text) #将列表转换成字符串"
22 | ]
23 | },
24 | {
25 | "cell_type": "code",
26 | "execution_count": 2,
27 | "metadata": {},
28 | "outputs": [],
29 | "source": [
30 | "\"\"\"文本词频统计\"\"\"\n",
31 | "\n",
32 | "def word_count(list_text): #定义计算文本词频的函数,传入list_text列表\n",
33 | " import collections\n",
34 | " word_freq = collections.defaultdict(int) #定义一个int型的词频词典,并提供默认值\n",
35 | " for w in list_text: #遍历列表中的元素,元素出现一次,频次加一\n",
36 | " word_freq[w] += 1\n",
37 | " return word_freq #返回词频词典\n",
38 | " \n",
39 | " #return word_freq.items() 该语句返回值的类型为list(这句话有语法问题,不必考虑)"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": 3,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "\"\"\"根据text文本创建代码词词典\"\"\"\n",
49 | "\n",
50 | "def build_dict(text, min_word_freq=50):\n",
51 | " word_freq = word_count(text) #文本词频统计,返回一个词频词典\n",
52 | " word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items()) # filter将词频数量低于指定值的代码词删除。\n",
53 | " word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) # key用于指定排序的元素,因为sorted默认使用list中每个item的第一个元素从小到大排列,所以这里通过lambda进行前后元素调序,并对词频去相反数,从而将词频最大的排列在最前面\n",
54 | " words, _ = list(zip(*word_freq_sorted)) #获取每一个代码词\n",
55 | " words = list(words)\n",
56 | " words.append('')\n",
57 | " word_idx = dict(zip(words, range(len(words)))) #构建词典(不包含词频)\n",
58 | " return words,word_idx #这里只返回了words,倒数两行代码还用不上。返回的是一个不含重复的代码词词典,不包含词频。"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 4,
64 | "metadata": {},
65 | "outputs": [
66 | {
67 | "name": "stderr",
68 | "output_type": "stream",
69 | "text": [
70 | "Using TensorFlow backend.\n"
71 | ]
72 | },
73 | {
74 | "name": "stdout",
75 | "output_type": "stream",
76 | "text": [
77 | "Number of sequences: 79181\n",
78 | "Unique words: 7865\n",
79 | "Vectorization...\n"
80 | ]
81 | }
82 | ],
83 | "source": [
84 | "\"\"\"数据预处理-字符串序列向量化\"\"\"\n",
85 | "\n",
86 | "import numpy as np\n",
87 | "import keras\n",
88 | "import json\n",
89 | "\n",
90 | "maxlen = 50 #提取50个代码词组成的序列\n",
91 | "step = 5 #每5个代码词采样一个新序列\n",
92 | "sentences = [] #保存所提取的序列\n",
93 | "next_words = [] #保存目标代码词\n",
94 | "vocab_file = \"../00-data/vocab\"\n",
95 | "\n",
96 | "cut_words = list_text #将列表形式的元数据保存在cut_words中\n",
97 | "for i in range(0,len(cut_words) - maxlen,step):\n",
98 | " sentences.append(cut_words[i:i + maxlen]) #将元数据按照步长来存储在每个序列中 \n",
99 | " next_words.append(cut_words[i + maxlen]) #将目标代码词存储在next_words中\n",
100 | " \n",
101 | " \n",
102 | "print('Number of sequences:', len(sentences))\n",
103 | "\n",
104 | "\n",
105 | "words,word_idx = list(build_dict(list_text,1)) #创建代码词词典,返回的是一个不含重复的代码词词典,不包含词频。\n",
106 | "print('Unique words:',len(words))\n",
107 | "json_dict = json.dumps(word_idx)\n",
108 | "with open(vocab_file,\"w\") as f:\n",
109 | " f.write(json_dict)\n",
110 | "\n",
111 | "word_indices = dict((word,words.index(word)) for word in words) #创建一个包含代码词唯一索引的代码词词典,返回的是一个字典\n",
112 | "#print(word_indices)\n",
113 | "\n",
114 | "print('Vectorization...')\n",
115 | "x = np.zeros((len(sentences),maxlen)) #初始化x\n",
116 | "y = np.zeros((len(sentences))) #初始化y\n",
117 | "for i,sentence in enumerate(sentences):\n",
118 | " for t,word in enumerate(sentence):\n",
119 | " x[i,t] = word_indices.get(word,word_indices['']) #将代码词转换成向量形式的编码\n",
120 | " #y[i] = word_indices[next_words[i]]\n",
121 | " y[i] = word_indices.get(next_words[i],word_indices[''])\n",
122 | "\n",
123 | "y = keras.utils.to_categorical(y, len(words)) #将int型数组y转换成one-hot编码"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": 6,
129 | "metadata": {},
130 | "outputs": [],
131 | "source": [
132 | "\"\"\"定义下一个代码词的采样函数---temperature越大,代码生成的随机性越强---\"\"\"\n",
133 | "\n",
134 | "def sample(preds,temperature=0.1):\n",
135 | " preds = np.asarray(preds).astype('float')\n",
136 | " preds = np.log(preds) /temperature\n",
137 | " exp_preds = np.exp(preds)\n",
138 | " preds = exp_preds / np.sum(exp_preds)\n",
139 | " probas = np.random.multinomial(1,preds,1)\n",
140 | " return np.argmax(probas)"
141 | ]
142 | },
143 | {
144 | "cell_type": "code",
145 | "execution_count": 7,
146 | "metadata": {},
147 | "outputs": [],
148 | "source": [
149 | "\"\"\"将字符串写到指定文件中\"\"\"\n",
150 | "\n",
151 | "def save(filename, contents): \n",
152 | " file = open(filename, 'a', encoding='utf-8')\n",
153 | " file.write(contents)\n",
154 | " file.close()"
155 | ]
156 | },
157 | {
158 | "cell_type": "markdown",
159 | "metadata": {},
160 | "source": [
161 | "### 模型训练"
162 | ]
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": 8,
167 | "metadata": {},
168 | "outputs": [],
169 | "source": [
170 | "import keras\n",
171 | "from keras import layers\n",
172 | "from keras.layers import LSTM, Dense, Dropout\n",
173 | "\n",
174 | "def create_model(words,learning_rate): #定义创建模型的函数\n",
175 | " model = keras.models.Sequential() #模型初始化\n",
176 | " model.add(layers.Embedding(len(words),512)) #模型第一层为embedding层\n",
177 | " model.add(layers.LSTM(512,return_sequences=True,dropout=0.2,recurrent_dropout=0.2)) #模型第二层为LSTM层,加入dropout减少过拟合\n",
178 | " model.add(layers.LSTM(512,dropout=0.2,recurrent_dropout=0.2)) #模型第三层为LSTM层,加入dropout减少过拟合\n",
179 | " model.add(layers.Dense(len(words),activation='softmax')) #模型第三层为全连接层\n",
180 | "\n",
181 | " optimizer = keras.optimizers.RMSprop(lr=learning_rate) #定义优化器\n",
182 | " model.compile(loss='categorical_crossentropy',optimizer=optimizer) #模型编译\n",
183 | " \n",
184 | " return model"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": 9,
190 | "metadata": {
191 | "scrolled": true
192 | },
193 | "outputs": [
194 | {
195 | "name": "stdout",
196 | "output_type": "stream",
197 | "text": [
198 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
199 | "Instructions for updating:\n",
200 | "Colocations handled automatically by placer.\n",
201 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
202 | "Instructions for updating:\n",
203 | "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
204 | "_________________________________________________________________\n",
205 | "Layer (type) Output Shape Param # \n",
206 | "=================================================================\n",
207 | "embedding_1 (Embedding) (None, None, 512) 4026880 \n",
208 | "_________________________________________________________________\n",
209 | "lstm_1 (LSTM) (None, None, 512) 2099200 \n",
210 | "_________________________________________________________________\n",
211 | "lstm_2 (LSTM) (None, 512) 2099200 \n",
212 | "_________________________________________________________________\n",
213 | "dense_1 (Dense) (None, 7865) 4034745 \n",
214 | "=================================================================\n",
215 | "Total params: 12,260,025\n",
216 | "Trainable params: 12,260,025\n",
217 | "Non-trainable params: 0\n",
218 | "_________________________________________________________________\n"
219 | ]
220 | }
221 | ],
222 | "source": [
223 | "\"\"\"创建模型实例\"\"\"\n",
224 | "learning_rate = 0.003\n",
225 | "model = create_model(words,learning_rate) #创建模型\n",
226 | "model.summary() #打印模型结构"
227 | ]
228 | },
229 | {
230 | "cell_type": "code",
231 | "execution_count": 10,
232 | "metadata": {},
233 | "outputs": [],
234 | "source": [
235 | "def clean_and_split(line):\n",
236 | " #将数据中的空格符统一,便于后期处理(原始数据中空格符包含\\t、\\v等) \n",
237 | " line = ' '.join(re.split(' |\\t|\\v',line))\n",
238 | " #将字符串数据按照括号中的符号进行分割,分割成列表格式,并且在列表中保留分隔符\n",
239 | " line = re.split('([: ,.*\\n(){}\\[\\]=])',line) \n",
240 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n",
241 | " return line"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 11,
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "\"\"\"打印生成的结果\"\"\"\n",
251 | "def print_code_text(code_list):\n",
252 | " mark = \".,*()[]:{}\\n\"\n",
253 | " \n",
254 | " result = \"\"\n",
255 | " last_word = \"\"\n",
256 | " \n",
257 | " for word in code_list:\n",
258 | " if last_word not in mark and word not in mark:\n",
259 | " result += \" \" + word\n",
260 | " else:\n",
261 | " result += word\n",
262 | " \n",
263 | " last_word = word\n",
264 | " \n",
265 | " print(result)\n",
266 | " \n",
267 | " return result"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 12,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "def text_generate(model,input_file,maxlen,\n",
277 | " temperatures,save_path,epoch,gen_lines=30):\n",
278 | " import random\n",
279 | " import codecs\n",
280 | " \n",
281 | " #从原始数据中随机找一行作为文本生成的起点\n",
282 | " with codecs.open(input_file,\"r\",\"utf-8\") as fin:\n",
283 | " lines = fin.readlines()\n",
284 |         "        random_line = random.randint(0,len(lines) - 1)\n",
285 | " start_line = clean_and_split(lines[random_line])\n",
286 | " \n",
287 | " one_line_max_words = 30\n",
288 | " \n",
289 | " print(\"===========epoch:%d===========\" % epoch)\n",
290 | " \n",
291 | " for temperature in temperatures:\n",
292 | " generated_text = start_line[:]\n",
293 | " print_string = start_line[:]\n",
294 | "\n",
295 | " for i in range(gen_lines):\n",
296 | " for j in range(one_line_max_words):\n",
297 | " sampled = np.zeros((1,len(generated_text))) \n",
298 | " #向量化\n",
299 | " for t,word in enumerate(generated_text): \n",
300 | " sampled[0,t] = word_indices.get(word,word_indices[''])\n",
301 | " #预测下一个词\n",
302 | " preds = model.predict(sampled,verbose=0)[0]\n",
303 | " next_index = sample(preds,temperature)\n",
304 | " next_word = words[next_index]\n",
305 | "\n",
306 | " if len(generated_text) == maxlen:\n",
307 | " generated_text = generated_text[1:]\n",
308 | " generated_text.append(next_word)\n",
309 | " print_string.append(next_word)\n",
310 | " if next_word == '\\n':\n",
311 | " break\n",
312 | "\n",
313 | " print(\"-----temperature: {}-----\".format(temperature))\n",
314 | " result = print_code_text(print_string)\n",
315 | "\n",
316 | " save_file = save_path + \"/{}_epoch_{}_temperature\".format(epoch,temperature)\n",
317 | " with codecs.open(save_file,\"w\",\"utf-8\") as fout:\n",
318 | " fout.write(result)"
319 | ]
320 | },
321 | {
322 | "cell_type": "code",
323 | "execution_count": 13,
324 | "metadata": {},
325 | "outputs": [],
326 | "source": [
327 | "\"\"\"模型保存\"\"\"\n",
328 | "from keras.callbacks import ModelCheckpoint\n",
329 | "filepath = \"../02-checkpoints/\"\n",
330 | "checkpoint = ModelCheckpoint(filepath, save_weights_only=False,verbose=1,save_best_only=False) #回调函数,实现断点续训功能"
331 | ]
332 | },
333 | {
334 | "cell_type": "code",
335 | "execution_count": 14,
336 | "metadata": {},
337 | "outputs": [],
338 | "source": [
339 | "\"\"\"学习率随模型效果变小\"\"\"\n",
340 | "import keras.backend as K\n",
341 | "from keras.callbacks import LearningRateScheduler\n",
342 | " \n",
343 | "def scheduler(epoch):\n",
344 | " # 每隔10个epoch,学习率减小为原来的5/10\n",
345 | " if epoch % 10 == 0 and epoch != 0:\n",
346 | " lr = K.get_value(model.optimizer.lr)\n",
347 | " K.set_value(model.optimizer.lr, lr * 0.5)\n",
348 | " print(\"lr changed to {}\".format(lr * 0.5))\n",
349 | " return K.get_value(model.optimizer.lr)\n",
350 | " \n",
351 | "reduce_lr = LearningRateScheduler(scheduler)"
352 | ]
353 | },
354 | {
355 | "cell_type": "code",
356 | "execution_count": 15,
357 | "metadata": {},
358 | "outputs": [
359 | {
360 | "name": "stdout",
361 | "output_type": "stream",
362 | "text": [
363 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
364 | "Instructions for updating:\n",
365 | "Use tf.cast instead.\n",
366 | "Epoch 1/5\n",
367 | " 4096/79181 [>.............................] - ETA: 2:37 - loss: 0.6361"
368 | ]
369 | },
370 | {
371 | "ename": "KeyboardInterrupt",
372 | "evalue": "",
373 | "output_type": "error",
374 | "traceback": [
375 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
376 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
377 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0minit_epoch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1024\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_epoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minit_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mreduce_lr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcheckpoint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#开始训练模型\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0mtext_generate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdata_file\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmaxlen\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0.4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0.8\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mprint_save_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m 
\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_save_path\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"epoch_{}.hdf5\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
378 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)\u001b[0m\n\u001b[1;32m 1037\u001b[0m \u001b[0minitial_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1038\u001b[0m \u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1039\u001b[0;31m validation_steps=validation_steps)\n\u001b[0m\u001b[1;32m 1040\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1041\u001b[0m def evaluate(self, x=None, y=None,\n",
379 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/engine/training_arrays.py\u001b[0m in \u001b[0;36mfit_loop\u001b[0;34m(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0mins_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mins_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_labels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
380 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2713\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_legacy_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2714\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2715\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2716\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2717\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mpy_any\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
381 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2673\u001b[0m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_metadata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2674\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2675\u001b[0;31m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2676\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2677\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
382 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1437\u001b[0m ret = tf_session.TF_SessionRunCallable(\n\u001b[1;32m 1438\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1439\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 1440\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1441\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
383 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
384 | ]
385 | }
386 | ],
387 | "source": [
388 | "\"\"\"训练模型\"\"\"\n",
389 | "import os\n",
390 | "print_save_path = \"../00-data/\"\n",
391 | "model_save_path = \"../02-checkpoints/\"\n",
392 | "total_epochs = 50 \n",
393 | "for epoch in range(1,total_epochs):\n",
394 | " if os.path.exists(filepath): #如果模型存在,则从现有模型开始训练\n",
395 | " model.load_weights(filepath)\n",
396 | " init_epoch = (epoch - 1) * 5\n",
397 | " model.fit(x,y,batch_size=1024,epochs=(init_epoch + 5),initial_epoch=init_epoch,callbacks=[reduce_lr,checkpoint]) #开始训练模型\n",
398 | " text_generate(model,data_file,maxlen,[0.1,0.4,0.8],print_save_path,epoch * 5)\n",
399 | " model.save(model_save_path + \"epoch_{}.hdf5\".format(epoch * 5))"
400 | ]
401 | },
402 | {
403 | "cell_type": "markdown",
404 | "metadata": {},
405 | "source": [
406 | "### 模型预测"
407 | ]
408 | },
409 | {
410 | "cell_type": "code",
411 | "execution_count": 32,
412 | "metadata": {},
413 | "outputs": [],
414 | "source": [
415 | "def generate_text_sentence(seed_text,model_filename): #测试代码和上面训练模型的代码基本一样,就不再介绍\n",
416 | " model.load_weights(model_filename)\n",
417 |         "    mark = \".,*()[]:{}\\n\"\n",
418 | " strings=''\n",
419 | " last_word=''\n",
420 | " seed_text = re.split('([: ,.\\n(){}\\[\\]=])',seed_text)\n",
421 | " seed_text = list(filter(lambda x: x!=' 'and x!='',seed_text))\n",
422 | " \n",
423 | " generated_text = seed_text[:]\n",
424 | " \n",
425 | " for temperature in [0.1,0.4,0.8]:\n",
426 | " strings += '\\n' + '-------------temperature:' + str(temperature) +'-------------\\n' +'\\n'\n",
427 | " \n",
428 | " for i in range(50):\n",
429 | " if i == 0:\n",
430 | " for k in range(len(generated_text)):\n",
431 | " if generated_text[k] not in mark and last_word not in mark:\n",
432 | " strings += ' ' + generated_text[k]\n",
433 | " else:\n",
434 | " strings += generated_text[k]\n",
435 | " last_word = generated_text[k]\n",
436 | "\n",
437 | " sampled = np.zeros((1,len(generated_text)))\n",
438 | " for t,word in enumerate(generated_text):\n",
439 |         "            sampled[0,t] = word_indices.get(word,word_indices[''])\n",
440 | "\n",
441 | " preds = model.predict(sampled,verbose=0)[0]\n",
442 |         "            next_index = sample(preds,temperature)\n",
443 | " next_word = words[next_index]\n",
444 | "\n",
445 | "\n",
446 | " generated_text.append(next_word)\n",
447 | "\n",
448 | " #if len(generated_text) == maxlen:\n",
449 | " # generated_text = generated_text[1:]\n",
450 | "\n",
451 | " if next_word not in mark and last_word not in mark:\n",
452 | " strings += ' ' + next_word\n",
453 | " else:\n",
454 | " strings += next_word\n",
455 | "\n",
456 | " last_word = next_word\n",
457 | "\n",
458 | " if next_word == '\\n':\n",
459 | " break\n",
460 | " \n",
461 | " generated_text = seed_text[:]\n",
462 | " \n",
463 | " return strings\n"
464 | ]
465 | },
466 | {
467 | "cell_type": "code",
468 | "execution_count": null,
469 | "metadata": {},
470 | "outputs": [],
471 | "source": []
472 | }
473 | ],
474 | "metadata": {
475 | "kernelspec": {
476 | "display_name": "myconda",
477 | "language": "python",
478 | "name": "myconda"
479 | },
480 | "language_info": {
481 | "codemirror_mode": {
482 | "name": "ipython",
483 | "version": 3
484 | },
485 | "file_extension": ".py",
486 | "mimetype": "text/x-python",
487 | "name": "python",
488 | "nbconvert_exporter": "python",
489 | "pygments_lexer": "ipython3",
490 | "version": "3.5.6"
491 | }
492 | },
493 | "nbformat": 4,
494 | "nbformat_minor": 4
495 | }
496 |
--------------------------------------------------------------------------------
/model/02-Seq2Seq/Seq2Seq_tf10.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import tensorflow as tf\n",
10 | "import re\n",
11 | "import os\n",
12 | "import collections\n",
13 | "import numpy as np\n",
14 | "import codecs\n",
15 | "import random\n",
16 | "from operator import itemgetter"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": 2,
22 | "metadata": {},
23 | "outputs": [],
24 | "source": [
25 | "#GPU设置\n",
26 | "config = tf.ConfigProto()\n",
27 | "config.gpu_options.per_process_gpu_memory_fraction = 0.95 #占用95%显存\n",
28 | "session = tf.Session(config=config)\n",
29 | "os.environ['CUDA_VISIBLE_DEVICES']=\"0\""
30 | ]
31 | },
32 | {
33 | "cell_type": "markdown",
34 | "metadata": {},
35 | "source": [
36 | "## 构建词表"
37 | ]
38 | },
39 | {
40 | "cell_type": "code",
41 | "execution_count": 14,
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "def get_vocab(input_data,min_word_freq):\n",
46 | " counter = collections.Counter()\n",
47 | " with codecs.open(input_data,\"r\",\"utf-8\") as f:\n",
48 | " for line in f:\n",
49 | " line = ' '.join(re.split(' |\\t|\\v|\\n',line)) #将数据中的空格符统一,便于后期处理(原始数据中空格符包含\\t、\\v等) \n",
50 | " line = re.split('([: ,.(){}\\[\\]=])',line) #将字符串数据按照括号中的符号进行分割,分割成列表格式,并且在列表中保留分隔符\n",
51 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n",
52 | " for word in line:\n",
53 | " counter[word] += 1\n",
54 | " \n",
55 | " counter = filter(lambda x: x[1] > min_word_freq, counter.items())\n",
56 | " sorted_word_to_cnt = sorted(counter,key=itemgetter(1),reverse=True)\n",
57 | " sorted_words = [x[0] for x in sorted_word_to_cnt]\n",
58 | "\n",
59 |         "    sorted_words = [\"\",\"\",\"\"] + sorted_words # NOTE(review): these three empty strings look like stripped special tokens (likely <sos>/<eos>/<unk> lost in export) -- verify against original\n",
60 | "\n",
61 | " print(\"vocab_len: \" + str(len(sorted_words)))\n",
62 | "\n",
63 | " return sorted_words"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 15,
69 | "metadata": {},
70 | "outputs": [
71 | {
72 | "name": "stdout",
73 | "output_type": "stream",
74 | "text": [
75 | "vocab_len: 11928\n"
76 | ]
77 | },
78 | {
79 | "data": {
80 | "text/plain": [
81 | "['in_channels',\n",
82 | " 'zeros',\n",
83 | " 'class',\n",
84 | " 'plt',\n",
85 | " 'session',\n",
86 | " 'variable_scope',\n",
87 | " 'join',\n",
88 | " 'size',\n",
89 | " 'Session',\n",
90 | " 'zip']"
91 | ]
92 | },
93 | "execution_count": 15,
94 | "metadata": {},
95 | "output_type": "execute_result"
96 | }
97 | ],
98 | "source": [
99 | "input_data = \"../00-data/tf_data.txt\"\n",
100 | "min_word_freq = 0\n",
101 | "vocab = get_vocab(input_data,min_word_freq)\n",
102 | "vocab[70:80]"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 8,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "word_dict = {word:index for index,word in enumerate(vocab)}"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 10,
117 | "metadata": {},
118 | "outputs": [
119 | {
120 | "data": {
121 | "text/plain": [
122 | "78"
123 | ]
124 | },
125 | "execution_count": 10,
126 | "metadata": {},
127 | "output_type": "execute_result"
128 | }
129 | ],
130 | "source": [
131 | "word_dict[\"Session\"]"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {},
137 | "source": [
138 | "## 【随机】前k句预测下一句"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 5,
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "def clean_and_split(line):\n",
148 | " line = ' '.join(re.split(' |\\t|\\v|\\n',line)) \n",
149 | " line = re.split('([: ,.(){}\\[\\]=])',line) \n",
150 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n",
151 | " return line"
152 | ]
153 | },
154 | {
155 | "cell_type": "code",
156 | "execution_count": 6,
157 | "metadata": {},
158 | "outputs": [],
159 | "source": [
160 | "def get_newdata_by_random(raw_data,vocab,enc_vec_data,enc_text_data,dec_vec_data,dec_text_data,rand_max=10,duplicate=2):\n",
161 | " import codecs\n",
162 | " import sys\n",
163 | " import re\n",
164 | "\n",
165 | " word_to_id = {k:v for(k,v) in zip(vocab,range(len(vocab)))}\n",
166 | "\n",
167 | " def get_id(word):\n",
168 | "        return word_to_id[word] if word in word_to_id else word_to_id[\"<unk>\"]\n",
169 | "\n",
170 | " fout_vec_enc = codecs.open(enc_vec_data,\"w\",\"utf-8\")\n",
171 | " fout_vec_dec = codecs.open(dec_vec_data,\"w\",\"utf-8\")\n",
172 | " fout_text_enc = codecs.open(enc_text_data,\"w\",\"utf-8\")\n",
173 | " fout_text_dec = codecs.open(dec_text_data,\"w\",\"utf-8\")\n",
174 | " \n",
175 | " data_length = 0\n",
176 | " with open(raw_data,\"r\") as fin:\n",
177 | " lines = fin.readlines()\n",
178 | " for i in range(rand_max,len(lines)):\n",
179 | " rand_nums = set(random.randint(1,rand_max) for _ in range(duplicate))\n",
180 | " for rand_num in rand_nums:\n",
181 | " #构造enc_data\n",
182 | " words = []\n",
183 | " for j in range(i - rand_num,i):\n",
184 | " line = clean_and_split(lines[j])\n",
185 | "                words += [\"<sos>\"] + line + [\"<eos>\"]\n",
186 | " out_line = ' '.join([str(get_id(w)) for w in words]) + '\\n'\n",
187 | " fout_text_enc.write(' '.join(words) + '\\n')\n",
188 | " fout_vec_enc.write(out_line)\n",
189 | "\n",
190 | " #构造dec_data\n",
191 | " words = []\n",
192 | " line = clean_and_split(lines[i])\n",
193 | "            words = line + [\"<eos>\"]\n",
194 | " out_line = ' '.join([str(get_id(w)) for w in words]) + '\\n'\n",
195 | " fout_text_dec.write(' '.join(words) + '\\n')\n",
196 | " fout_vec_dec.write(out_line)\n",
197 | " \n",
198 | " data_length += 1\n",
199 | " fout_vec_enc.close()\n",
200 | " fout_vec_dec.close()\n",
201 | " fout_text_enc.close()\n",
202 | " fout_text_dec.close()\n",
203 | " \n",
204 | " return data_length"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": 7,
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "enc_vec_data = \"../00-data/train.enc\"\n",
214 | "dec_vec_data = \"../00-data/train.dec\"\n",
215 | "enc_text_data = \"../00-data/train_text.enc\"\n",
216 | "dec_text_data = \"../00-data/train_text.dec\""
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": 8,
222 | "metadata": {},
223 | "outputs": [
224 | {
225 | "name": "stdout",
226 | "output_type": "stream",
227 | "text": [
228 | "data_len: 74052\n"
229 | ]
230 | }
231 | ],
232 | "source": [
233 | "data_length = get_newdata_by_random(input_data,vocab,enc_vec_data,enc_text_data,dec_vec_data,dec_text_data)\n",
234 | "print(\"data_len: \" + str(data_length))"
235 | ]
236 | },
237 | {
238 | "cell_type": "markdown",
239 | "metadata": {},
240 | "source": [
241 | "## 原始数据向量化"
242 | ]
243 | },
244 | {
245 | "cell_type": "code",
246 | "execution_count": 9,
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "def MakeDataset(file_path):\n",
251 | " dataset = tf.data.TextLineDataset(file_path)\n",
252 | " dataset = dataset.map(lambda string:tf.string_split([string]).values)\n",
253 | " dataset = dataset.map(\n",
254 | " lambda string: tf.string_to_number(string,tf.int32))\n",
255 | " dataset = dataset.map(lambda x:(x,tf.size(x)))\n",
256 | " return dataset"
257 | ]
258 | },
259 | {
260 | "cell_type": "code",
261 | "execution_count": 10,
262 | "metadata": {},
263 | "outputs": [],
264 | "source": [
265 | "def MakeSrcTrgDataset(src_path,trg_path,batch_size,shuffle_size=3000,start_id=1):\n",
266 | " src_data = MakeDataset(src_path)\n",
267 | " trg_data = MakeDataset(trg_path)\n",
268 | " \n",
269 | " dataset = tf.data.Dataset.zip((src_data,trg_data))\n",
270 | " \n",
271 | " #删除内容为空or长度过长的句子\n",
272 | " #不需要执行\n",
273 | " def FilterLength(src_tuple,trg_tuple):\n",
274 | " ((src_input,src_len),(trg_label,trg_len)) = (src_tuple,trg_tuple)\n",
275 | " src_len_ok = tf.logical_and(\n",
276 | " tf.greater(src_len,1),tf.less_equal(src_len,MAX_LEN))\n",
277 | " trg_len_ok = tf.logical_and(\n",
278 | " tf.greater(trg_len,1),tf.less_equal(trg_len,MAX_LEN))\n",
279 | " return tf.logical_and(src_len_ok,trg_len_ok)\n",
280 | " #dataset = dataset.filter(FilterLength)\n",
281 | " \n",
282 | " #生成 X Y Z 作为解码器的输入\n",
283 | " def MakeTrgInput(src_tuple,trg_tuple):\n",
284 | " ((src_input,src_len),(trg_label,trg_len)) = (src_tuple,trg_tuple)\n",
285 | " trg_input = tf.concat([[start_id],trg_label[:-1]],axis=0)\n",
286 | " return ((src_input,src_len),(trg_input,trg_label,trg_len))\n",
287 | " dataset = dataset.map(MakeTrgInput)\n",
288 | " \n",
289 | " #随机打乱训练数据\n",
290 | " dataset = dataset.shuffle(shuffle_size)\n",
291 | " \n",
292 | " #规定填充后输出的数据维度\n",
293 | " padded_shapes = (\n",
294 | " (tf.TensorShape([None]),\n",
295 | " tf.TensorShape([])),\n",
296 | " (tf.TensorShape([None]),\n",
297 | " tf.TensorShape([None]),\n",
298 | " tf.TensorShape([])))\n",
299 | " #调用padded_batch方法进行batching操作\n",
300 | " batched_dataset = dataset.padded_batch(batch_size,padded_shapes)\n",
301 | " return batched_dataset"
302 | ]
303 | },
304 | {
305 | "cell_type": "markdown",
306 | "metadata": {},
307 | "source": [
308 | "## seq2seq + attention模型\n",
309 | "encoder:2层单向lstm\n",
310 | "decoder:1层单向lstm"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": 11,
316 | "metadata": {},
317 | "outputs": [],
318 | "source": [
319 | "CHECKPOINT_PATH = \"../02-checkpoints/\"\n",
320 | "HIDDEN_SIZE = 256\n",
321 | "ENCODER_LAYERS = 2\n",
322 | "DECODER_LAYERS = 1\n",
323 | "SRC_VOCAB_SIZE = len(vocab)\n",
324 | "TRG_VOCAB_SIZE = len(vocab)\n",
325 | "BATCH_SIZE = 64\n",
326 | "NUM_EPOCH = 10\n",
327 | "KEEP_PROB = 0.9\n",
328 | "MAX_GRAD_NORM = 5\n",
329 | "LEARNING_RATE_BASE = 1.0\n",
330 | "LEARNING_RATE_DECAY = 0.7\n",
331 | "SHARE_EMB_AND_SOFTMAX = True"
332 | ]
333 | },
334 | {
335 | "cell_type": "code",
336 | "execution_count": 12,
337 | "metadata": {},
338 | "outputs": [],
339 | "source": [
340 | "# 定义NMTModel类来描述模型。\n",
341 | "class NMTModel(object):\n",
342 | " # 在模型的初始化函数中定义模型要用到的变量。\n",
343 | " def __init__(self):\n",
344 | " # 定义编码器和解码器所使用的LSTM结构。\n",
345 | " self.enc_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)\n",
346 | " self.enc_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)\n",
347 | " \n",
348 | " self.enc_cell = tf.nn.rnn_cell.MultiRNNCell(\n",
349 | " [tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) \n",
350 | " for _ in range(ENCODER_LAYERS)])\n",
351 | " \n",
352 | " self.dec_cell = tf.nn.rnn_cell.MultiRNNCell(\n",
353 | " [tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) \n",
354 | " for _ in range(DECODER_LAYERS)])\n",
355 | "\n",
356 | " # 为源语言和目标语言分别定义词向量。 \n",
357 | " self.src_embedding = tf.get_variable(\n",
358 | " \"src_emb\", [SRC_VOCAB_SIZE, HIDDEN_SIZE])\n",
359 | " self.trg_embedding = tf.get_variable(\n",
360 | " \"trg_emb\", [TRG_VOCAB_SIZE, HIDDEN_SIZE])\n",
361 | "\n",
362 | " # 定义softmax层的变量\n",
363 | " if SHARE_EMB_AND_SOFTMAX:\n",
364 | " self.softmax_weight = tf.transpose(self.trg_embedding)\n",
365 | " else:\n",
366 | " self.softmax_weight = tf.get_variable(\n",
367 | " \"weight\", [HIDDEN_SIZE, TRG_VOCAB_SIZE])\n",
368 | " self.softmax_bias = tf.get_variable(\n",
369 | " \"softmax_bias\", [TRG_VOCAB_SIZE])\n",
370 | "\n",
371 | " # 在forward函数中定义模型的前向计算图。\n",
372 | " # src_input, src_size, trg_input, trg_label, trg_size分别是上面\n",
373 | " # MakeSrcTrgDataset函数产生的五种张量。\n",
374 | " def forward(self, src_input, src_size, trg_input, trg_label, trg_size,data_length):\n",
375 | " global_step = tf.Variable(0, trainable=False)\n",
376 | " batch_size = tf.shape(src_input)[0]\n",
377 | " \n",
378 | " # 将输入和输出单词编号转为词向量。\n",
379 | " src_emb = tf.nn.embedding_lookup(self.src_embedding, src_input)\n",
380 | " trg_emb = tf.nn.embedding_lookup(self.trg_embedding, trg_input)\n",
381 | " \n",
382 | " # 在词向量上进行dropout。\n",
383 | " src_emb = tf.nn.dropout(src_emb, KEEP_PROB)\n",
384 | " trg_emb = tf.nn.dropout(trg_emb, KEEP_PROB)\n",
385 | "\n",
386 | " # 使用dynamic_rnn构造编码器。\n",
387 | " # 编码器读取源句子每个位置的词向量,输出最后一步的隐藏状态enc_state。\n",
388 | " # 因为编码器是一个双层LSTM,因此enc_state是一个包含两个LSTMStateTuple类\n",
389 | " # 张量的tuple,每个LSTMStateTuple对应编码器中的一层。\n",
390 | " # 张量的维度是 [batch_size, HIDDEN_SIZE]。\n",
391 | " # enc_outputs是顶层LSTM在每一步的输出,它的维度是[batch_size, \n",
392 | " # max_time, HIDDEN_SIZE]。Seq2Seq模型中不需要用到enc_outputs,而\n",
393 | " # 后面介绍的attention模型会用到它。\n",
394 | " # 下面的代码取代了Seq2Seq样例代码中forward函数里的相应部分。\n",
395 | " with tf.variable_scope(\"encoder\"):\n",
396 | " # 构造编码器时,使用dynamic_rnn构造单向循环网络。\n",
397 | " # 单向循环网络的顶层输出enc_outputs是一个包含两个张量的tuple,每个张量的\n",
398 | " # 维度都是[batch_size, max_time, HIDDEN_SIZE],代表两个LSTM在每一步的输出。\n",
399 | " enc_outputs,enc_state = tf.nn.dynamic_rnn(\n",
400 | " self.enc_cell,src_emb,src_size,dtype=tf.float32) \n",
401 | "\n",
402 | " with tf.variable_scope(\"decoder\"):\n",
403 | " # 选择注意力权重的计算模型。BahdanauAttention是使用一个隐藏层的前馈神经网络。\n",
404 | " # memory_sequence_length是一个维度为[batch_size]的张量,代表batch\n",
405 | " # 中每个句子的长度,Attention需要根据这个信息把填充位置的注意力权重设置为0。\n",
406 | " attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(\n",
407 | " HIDDEN_SIZE, enc_outputs,\n",
408 | " memory_sequence_length=src_size)\n",
409 | "\n",
410 | " # 将解码器的循环神经网络self.dec_cell和注意力一起封装成更高层的循环神经网络。\n",
411 | " attention_cell = tf.contrib.seq2seq.AttentionWrapper(\n",
412 | " self.dec_cell, attention_mechanism,\n",
413 | " attention_layer_size=HIDDEN_SIZE)\n",
414 | "\n",
415 | " # 使用attention_cell和dynamic_rnn构造编码器。\n",
416 | " # 这里没有指定init_state,也就是没有使用编码器的输出来初始化输入,而完全依赖\n",
417 | " # 注意力作为信息来源。\n",
418 | " dec_outputs, _ = tf.nn.dynamic_rnn(\n",
419 | " attention_cell, trg_emb, trg_size, dtype=tf.float32)\n",
420 | "\n",
421 | " # 计算解码器每一步的log perplexity。这一步与语言模型代码相同。\n",
422 | " output = tf.reshape(dec_outputs, [-1, HIDDEN_SIZE])\n",
423 | " logits = tf.matmul(output, self.softmax_weight) + self.softmax_bias\n",
424 | " loss = tf.nn.sparse_softmax_cross_entropy_with_logits(\n",
425 | " labels=tf.reshape(trg_label, [-1]), logits=logits)\n",
426 | "\n",
427 | " # 在计算平均损失时,需要将填充位置的权重设置为0,以避免无效位置的预测干扰\n",
428 | " # 模型的训练。\n",
429 | " label_weights = tf.sequence_mask(\n",
430 | " trg_size, maxlen=tf.shape(trg_label)[1], dtype=tf.float32)\n",
431 | " label_weights = tf.reshape(label_weights, [-1])\n",
432 | " cost = tf.reduce_sum(loss * label_weights)\n",
433 | " cost_op = cost / tf.reduce_sum(label_weights)\n",
434 | " \n",
435 | " # 定义反向传播操作。反向操作的实现与语言模型代码相同。\n",
436 | " trainable_variables = tf.trainable_variables()\n",
437 | "\n",
438 | " # 控制梯度大小,定义优化方法和训练步骤。\n",
439 | " grads = tf.gradients(cost / tf.to_float(batch_size),\n",
440 | " trainable_variables)\n",
441 | " grads, _ = tf.clip_by_global_norm(grads, MAX_GRAD_NORM)\n",
442 | " \n",
443 | " learning_rate = tf.train.exponential_decay(\n",
444 | " LEARNING_RATE_BASE,\n",
445 | " global_step,\n",
446 | " data_length / batch_size, \n",
447 | " LEARNING_RATE_DECAY,\n",
448 | " staircase=True)\n",
449 | " optimizer = tf.train.GradientDescentOptimizer(learning_rate)\n",
450 | " train_op = optimizer.apply_gradients(\n",
451 | " zip(grads, trainable_variables))\n",
452 | " return cost_op, train_op,learning_rate,global_step"
453 | ]
454 | },
455 | {
456 | "cell_type": "code",
457 | "execution_count": 13,
458 | "metadata": {},
459 | "outputs": [],
460 | "source": [
461 | "def run_epoch(session,cost_op,train_op,learning_rate_op,global_step_op,step,saver,epoch):\n",
462 | " while True:\n",
463 | " try:\n",
464 | " cost,_,learning_rate,global_step = session.run([cost_op,train_op,learning_rate_op,global_step_op])\n",
465 | " if global_step % 50 == 0:\n",
466 | " print(\"After %d global_steps,per token cost is %.3f,learning_rate is %.5f\" %(global_step,cost,learning_rate))\n",
467 | " session.run(tf.assign(global_step_op,step))\n",
468 | " step += 1\n",
469 | " except tf.errors.OutOfRangeError:\n",
470 | " if epoch % 2 == 0:\n",
471 | " saver.save(session,CHECKPOINT_PATH,global_step=global_step)\n",
472 | " break\n",
473 | " return step"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": null,
479 | "metadata": {},
480 | "outputs": [
481 | {
482 | "name": "stdout",
483 | "output_type": "stream",
484 | "text": [
485 | "WARNING:tensorflow:From :6: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
486 | "Instructions for updating:\n",
487 | "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n",
488 | "WARNING:tensorflow:From :11: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n",
489 | "Instructions for updating:\n",
490 | "This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.\n",
491 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
492 | "Instructions for updating:\n",
493 | "Colocations handled automatically by placer.\n",
494 | "WARNING:tensorflow:From :44: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
495 | "Instructions for updating:\n",
496 | "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
497 | "WARNING:tensorflow:From :61: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n",
498 | "Instructions for updating:\n",
499 | "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n",
500 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py:626: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
501 | "Instructions for updating:\n",
502 | "Use tf.cast instead.\n",
503 | "\n",
504 | "WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
505 | "For more information, please see:\n",
506 | " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
507 | " * https://github.com/tensorflow/addons\n",
508 | "If you depend on functionality not listed there, please file an issue.\n",
509 | "\n",
510 | "WARNING:tensorflow:From :100: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n",
511 | "Instructions for updating:\n",
512 | "Use tf.cast instead.\n",
513 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n",
514 | "Instructions for updating:\n",
515 | "Use standard file APIs to check for files with this prefix.\n",
516 | "INFO:tensorflow:Restoring parameters from ../02-checkpoints/02-0318/-4635\n",
517 | "In EPOCH: 1\n",
518 | "After 50 global_steps,per token cost is 2.616,learning_rate is 1.00000\n",
519 | "After 100 global_steps,per token cost is 2.445,learning_rate is 1.00000\n",
520 | "After 150 global_steps,per token cost is 1.941,learning_rate is 1.00000\n",
521 | "After 200 global_steps,per token cost is 1.367,learning_rate is 1.00000\n",
522 | "After 250 global_steps,per token cost is 1.796,learning_rate is 1.00000\n",
523 | "After 300 global_steps,per token cost is 2.015,learning_rate is 1.00000\n",
524 | "After 350 global_steps,per token cost is 2.002,learning_rate is 1.00000\n",
525 | "After 400 global_steps,per token cost is 1.729,learning_rate is 1.00000\n",
526 | "After 450 global_steps,per token cost is 1.414,learning_rate is 1.00000\n",
527 | "After 500 global_steps,per token cost is 1.769,learning_rate is 1.00000\n",
528 | "After 550 global_steps,per token cost is 1.915,learning_rate is 1.00000\n",
529 | "After 600 global_steps,per token cost is 2.256,learning_rate is 1.00000\n",
530 | "After 650 global_steps,per token cost is 2.122,learning_rate is 1.00000\n",
531 | "After 700 global_steps,per token cost is 2.302,learning_rate is 1.00000\n",
532 | "After 750 global_steps,per token cost is 2.152,learning_rate is 1.00000\n",
533 | "After 800 global_steps,per token cost is 1.761,learning_rate is 1.00000\n",
534 | "After 850 global_steps,per token cost is 1.629,learning_rate is 1.00000\n",
535 | "After 900 global_steps,per token cost is 1.359,learning_rate is 1.00000\n",
536 | "After 950 global_steps,per token cost is 1.678,learning_rate is 1.00000\n",
537 | "After 1000 global_steps,per token cost is 2.382,learning_rate is 1.00000\n",
538 | "After 1050 global_steps,per token cost is 1.874,learning_rate is 1.00000\n",
539 | "After 1100 global_steps,per token cost is 1.742,learning_rate is 1.00000\n",
540 | "After 1150 global_steps,per token cost is 1.647,learning_rate is 1.00000\n",
541 | "In EPOCH: 2\n",
542 | "After 1200 global_steps,per token cost is 1.876,learning_rate is 0.70000\n",
543 | "After 1250 global_steps,per token cost is 1.677,learning_rate is 0.70000\n",
544 | "After 1300 global_steps,per token cost is 1.445,learning_rate is 0.70000\n",
545 | "After 1350 global_steps,per token cost is 0.987,learning_rate is 0.70000\n",
546 | "After 1400 global_steps,per token cost is 1.265,learning_rate is 0.70000\n",
547 | "After 1450 global_steps,per token cost is 1.468,learning_rate is 0.70000\n",
548 | "After 1500 global_steps,per token cost is 1.557,learning_rate is 0.70000\n",
549 | "After 1550 global_steps,per token cost is 1.438,learning_rate is 0.70000\n",
550 | "After 1600 global_steps,per token cost is 1.204,learning_rate is 0.70000\n",
551 | "After 1650 global_steps,per token cost is 1.480,learning_rate is 0.70000\n",
552 | "After 1700 global_steps,per token cost is 1.611,learning_rate is 0.70000\n",
553 | "After 1750 global_steps,per token cost is 1.511,learning_rate is 0.70000\n",
554 | "After 1800 global_steps,per token cost is 1.696,learning_rate is 0.70000\n",
555 | "After 1850 global_steps,per token cost is 1.664,learning_rate is 0.70000\n",
556 | "After 1900 global_steps,per token cost is 1.624,learning_rate is 0.70000\n",
557 | "After 1950 global_steps,per token cost is 1.385,learning_rate is 0.70000\n",
558 | "After 2000 global_steps,per token cost is 1.141,learning_rate is 0.70000\n",
559 | "After 2050 global_steps,per token cost is 0.869,learning_rate is 0.70000\n",
560 | "After 2100 global_steps,per token cost is 1.091,learning_rate is 0.70000\n",
561 | "After 2150 global_steps,per token cost is 1.086,learning_rate is 0.70000\n",
562 | "After 2200 global_steps,per token cost is 1.226,learning_rate is 0.70000\n",
563 | "After 2250 global_steps,per token cost is 1.461,learning_rate is 0.70000\n",
564 | "After 2300 global_steps,per token cost is 1.344,learning_rate is 0.70000\n",
565 | "In EPOCH: 3\n",
566 | "After 2350 global_steps,per token cost is 1.346,learning_rate is 0.49000\n",
567 | "After 2400 global_steps,per token cost is 1.525,learning_rate is 0.49000\n",
568 | "After 2450 global_steps,per token cost is 1.384,learning_rate is 0.49000\n",
569 | "After 2500 global_steps,per token cost is 0.892,learning_rate is 0.49000\n",
570 | "After 2550 global_steps,per token cost is 0.815,learning_rate is 0.49000\n",
571 | "After 2600 global_steps,per token cost is 1.259,learning_rate is 0.49000\n",
572 | "After 2650 global_steps,per token cost is 1.180,learning_rate is 0.49000\n",
573 | "After 2700 global_steps,per token cost is 1.178,learning_rate is 0.49000\n",
574 | "After 2750 global_steps,per token cost is 1.122,learning_rate is 0.49000\n",
575 | "After 2800 global_steps,per token cost is 1.074,learning_rate is 0.49000\n",
576 | "After 2850 global_steps,per token cost is 1.415,learning_rate is 0.49000\n",
577 | "After 2900 global_steps,per token cost is 1.409,learning_rate is 0.49000\n",
578 | "After 2950 global_steps,per token cost is 1.326,learning_rate is 0.49000\n",
579 | "After 3000 global_steps,per token cost is 1.502,learning_rate is 0.49000\n",
580 | "After 3050 global_steps,per token cost is 1.608,learning_rate is 0.49000\n",
581 | "After 3100 global_steps,per token cost is 1.347,learning_rate is 0.49000\n",
582 | "After 3150 global_steps,per token cost is 1.076,learning_rate is 0.49000\n",
583 | "After 3200 global_steps,per token cost is 0.790,learning_rate is 0.49000\n",
584 | "After 3250 global_steps,per token cost is 0.996,learning_rate is 0.49000\n",
585 | "After 3300 global_steps,per token cost is 1.269,learning_rate is 0.49000\n",
586 | "After 3350 global_steps,per token cost is 1.299,learning_rate is 0.49000\n",
587 | "After 3400 global_steps,per token cost is 1.397,learning_rate is 0.49000\n",
588 | "After 3450 global_steps,per token cost is 1.084,learning_rate is 0.49000\n",
589 | "In EPOCH: 4\n",
590 | "After 3500 global_steps,per token cost is 1.322,learning_rate is 0.34300\n",
591 | "After 3550 global_steps,per token cost is 1.174,learning_rate is 0.34300\n",
592 | "After 3600 global_steps,per token cost is 1.244,learning_rate is 0.34300\n",
593 | "After 3650 global_steps,per token cost is 0.905,learning_rate is 0.34300\n",
594 | "After 3700 global_steps,per token cost is 0.956,learning_rate is 0.34300\n",
595 | "After 3750 global_steps,per token cost is 1.027,learning_rate is 0.34300\n",
596 | "After 3800 global_steps,per token cost is 1.224,learning_rate is 0.34300\n",
597 | "After 3850 global_steps,per token cost is 1.405,learning_rate is 0.34300\n",
598 | "After 3900 global_steps,per token cost is 0.747,learning_rate is 0.34300\n",
599 | "After 3950 global_steps,per token cost is 0.739,learning_rate is 0.34300\n",
600 | "After 4000 global_steps,per token cost is 1.245,learning_rate is 0.34300\n",
601 | "After 4050 global_steps,per token cost is 1.008,learning_rate is 0.34300\n",
602 | "After 4100 global_steps,per token cost is 1.243,learning_rate is 0.34300\n",
603 | "After 4150 global_steps,per token cost is 1.303,learning_rate is 0.34300\n",
604 | "After 4200 global_steps,per token cost is 1.521,learning_rate is 0.34300\n",
605 | "After 4250 global_steps,per token cost is 1.167,learning_rate is 0.34300\n",
606 | "After 4300 global_steps,per token cost is 1.118,learning_rate is 0.34300\n",
607 | "After 4350 global_steps,per token cost is 0.657,learning_rate is 0.34300\n",
608 | "After 4400 global_steps,per token cost is 0.544,learning_rate is 0.34300\n",
609 | "After 4450 global_steps,per token cost is 0.950,learning_rate is 0.34300\n",
610 | "After 4500 global_steps,per token cost is 1.181,learning_rate is 0.34300\n",
611 | "After 4550 global_steps,per token cost is 1.271,learning_rate is 0.34300\n",
612 | "After 4600 global_steps,per token cost is 0.746,learning_rate is 0.34300\n",
613 | "In EPOCH: 5\n",
614 | "After 4650 global_steps,per token cost is 1.303,learning_rate is 0.24010\n",
615 | "After 4700 global_steps,per token cost is 1.117,learning_rate is 0.24010\n",
616 | "After 4750 global_steps,per token cost is 1.203,learning_rate is 0.24010\n",
617 | "After 4800 global_steps,per token cost is 0.825,learning_rate is 0.24010\n",
618 | "After 4850 global_steps,per token cost is 0.638,learning_rate is 0.24010\n",
619 | "After 4900 global_steps,per token cost is 0.877,learning_rate is 0.24010\n",
620 | "After 4950 global_steps,per token cost is 1.058,learning_rate is 0.24010\n",
621 | "After 5000 global_steps,per token cost is 0.943,learning_rate is 0.24010\n",
622 | "After 5050 global_steps,per token cost is 0.828,learning_rate is 0.24010\n",
623 | "After 5100 global_steps,per token cost is 0.860,learning_rate is 0.24010\n",
624 | "After 5150 global_steps,per token cost is 0.915,learning_rate is 0.24010\n",
625 | "After 5200 global_steps,per token cost is 1.021,learning_rate is 0.24010\n",
626 | "After 5250 global_steps,per token cost is 1.344,learning_rate is 0.24010\n",
627 | "After 5300 global_steps,per token cost is 1.096,learning_rate is 0.24010\n",
628 | "After 5350 global_steps,per token cost is 1.425,learning_rate is 0.24010\n",
629 | "After 5400 global_steps,per token cost is 0.972,learning_rate is 0.24010\n",
630 | "After 5450 global_steps,per token cost is 1.101,learning_rate is 0.24010\n",
631 | "After 5500 global_steps,per token cost is 0.593,learning_rate is 0.24010\n",
632 | "After 5550 global_steps,per token cost is 0.569,learning_rate is 0.24010\n",
633 | "After 5600 global_steps,per token cost is 0.745,learning_rate is 0.24010\n",
634 | "After 5650 global_steps,per token cost is 0.944,learning_rate is 0.24010\n",
635 | "After 5700 global_steps,per token cost is 1.116,learning_rate is 0.24010\n",
636 | "After 5750 global_steps,per token cost is 0.812,learning_rate is 0.24010\n",
637 | "In EPOCH: 6\n",
638 | "After 5800 global_steps,per token cost is 1.252,learning_rate is 0.16807\n",
639 | "After 5850 global_steps,per token cost is 0.931,learning_rate is 0.16807\n",
640 | "After 5900 global_steps,per token cost is 1.007,learning_rate is 0.16807\n",
641 | "After 5950 global_steps,per token cost is 1.001,learning_rate is 0.16807\n",
642 | "After 6000 global_steps,per token cost is 0.770,learning_rate is 0.16807\n",
643 | "After 6050 global_steps,per token cost is 0.859,learning_rate is 0.16807\n",
644 | "After 6100 global_steps,per token cost is 1.054,learning_rate is 0.16807\n",
645 | "After 6150 global_steps,per token cost is 1.127,learning_rate is 0.16807\n",
646 | "After 6200 global_steps,per token cost is 0.683,learning_rate is 0.16807\n",
647 | "After 6250 global_steps,per token cost is 0.933,learning_rate is 0.16807\n",
648 | "After 6300 global_steps,per token cost is 0.823,learning_rate is 0.16807\n",
649 | "After 6350 global_steps,per token cost is 0.927,learning_rate is 0.16807\n",
650 | "After 6400 global_steps,per token cost is 1.110,learning_rate is 0.16807\n",
651 | "After 6450 global_steps,per token cost is 0.911,learning_rate is 0.16807\n",
652 | "After 6500 global_steps,per token cost is 1.243,learning_rate is 0.16807\n",
653 | "After 6550 global_steps,per token cost is 0.856,learning_rate is 0.16807\n",
654 | "After 6600 global_steps,per token cost is 0.755,learning_rate is 0.16807\n",
655 | "After 6650 global_steps,per token cost is 0.660,learning_rate is 0.16807\n",
656 | "After 6700 global_steps,per token cost is 0.566,learning_rate is 0.16807\n",
657 | "After 6750 global_steps,per token cost is 0.749,learning_rate is 0.16807\n",
658 | "After 6800 global_steps,per token cost is 0.912,learning_rate is 0.16807\n",
659 | "After 6850 global_steps,per token cost is 0.848,learning_rate is 0.16807\n",
660 | "After 6900 global_steps,per token cost is 0.909,learning_rate is 0.16807\n",
661 | "In EPOCH: 7\n",
662 | "After 6950 global_steps,per token cost is 1.221,learning_rate is 0.11765\n",
663 | "After 7000 global_steps,per token cost is 0.989,learning_rate is 0.11765\n",
664 | "After 7050 global_steps,per token cost is 1.083,learning_rate is 0.11765\n",
665 | "After 7100 global_steps,per token cost is 0.834,learning_rate is 0.11765\n",
666 | "After 7150 global_steps,per token cost is 0.461,learning_rate is 0.11765\n",
667 | "After 7200 global_steps,per token cost is 0.692,learning_rate is 0.11765\n",
668 | "After 7250 global_steps,per token cost is 0.884,learning_rate is 0.11765\n",
669 | "After 7300 global_steps,per token cost is 0.838,learning_rate is 0.11765\n",
670 | "After 7350 global_steps,per token cost is 0.584,learning_rate is 0.11765\n",
671 | "After 7400 global_steps,per token cost is 0.714,learning_rate is 0.11765\n",
672 | "After 7450 global_steps,per token cost is 0.888,learning_rate is 0.11765\n",
673 | "After 7500 global_steps,per token cost is 0.859,learning_rate is 0.11765\n",
674 | "After 7550 global_steps,per token cost is 1.101,learning_rate is 0.11765\n",
675 | "After 7600 global_steps,per token cost is 0.969,learning_rate is 0.11765\n",
676 | "After 7650 global_steps,per token cost is 1.108,learning_rate is 0.11765\n",
677 | "After 7700 global_steps,per token cost is 0.907,learning_rate is 0.11765\n",
678 | "After 7750 global_steps,per token cost is 0.851,learning_rate is 0.11765\n",
679 | "After 7800 global_steps,per token cost is 0.567,learning_rate is 0.11765\n",
680 | "After 7850 global_steps,per token cost is 0.395,learning_rate is 0.11765\n",
681 | "After 7900 global_steps,per token cost is 0.666,learning_rate is 0.11765\n",
682 | "After 7950 global_steps,per token cost is 0.874,learning_rate is 0.11765\n",
683 | "After 8000 global_steps,per token cost is 0.945,learning_rate is 0.11765\n",
684 | "After 8050 global_steps,per token cost is 0.636,learning_rate is 0.11765\n",
685 | "After 8100 global_steps,per token cost is 0.598,learning_rate is 0.08235\n",
686 | "In EPOCH: 8\n",
687 | "After 8150 global_steps,per token cost is 1.055,learning_rate is 0.08235\n",
688 | "After 8200 global_steps,per token cost is 0.970,learning_rate is 0.08235\n",
689 | "After 8250 global_steps,per token cost is 0.936,learning_rate is 0.08235\n",
690 | "After 8300 global_steps,per token cost is 0.468,learning_rate is 0.08235\n",
691 | "After 8350 global_steps,per token cost is 0.610,learning_rate is 0.08235\n",
692 | "After 8400 global_steps,per token cost is 0.903,learning_rate is 0.08235\n",
693 | "After 8450 global_steps,per token cost is 0.718,learning_rate is 0.08235\n",
694 | "After 8500 global_steps,per token cost is 0.577,learning_rate is 0.08235\n",
695 | "After 8550 global_steps,per token cost is 0.633,learning_rate is 0.08235\n",
696 | "After 8600 global_steps,per token cost is 0.596,learning_rate is 0.08235\n",
697 | "After 8650 global_steps,per token cost is 0.829,learning_rate is 0.08235\n",
698 | "After 8700 global_steps,per token cost is 0.812,learning_rate is 0.08235\n",
699 | "After 8750 global_steps,per token cost is 0.826,learning_rate is 0.08235\n",
700 | "After 8800 global_steps,per token cost is 1.072,learning_rate is 0.08235\n",
701 | "After 8850 global_steps,per token cost is 0.958,learning_rate is 0.08235\n",
702 | "After 8900 global_steps,per token cost is 0.766,learning_rate is 0.08235\n",
703 | "After 8950 global_steps,per token cost is 0.510,learning_rate is 0.08235\n",
704 | "After 9000 global_steps,per token cost is 0.433,learning_rate is 0.08235\n",
705 | "After 9050 global_steps,per token cost is 0.442,learning_rate is 0.08235\n",
706 | "After 9100 global_steps,per token cost is 0.723,learning_rate is 0.08235\n",
707 | "After 9150 global_steps,per token cost is 0.803,learning_rate is 0.08235\n",
708 | "After 9200 global_steps,per token cost is 0.767,learning_rate is 0.08235\n",
709 | "After 9250 global_steps,per token cost is 0.646,learning_rate is 0.08235\n",
710 | "In EPOCH: 9\n",
711 | "After 9300 global_steps,per token cost is 0.939,learning_rate is 0.05765\n",
712 | "After 9350 global_steps,per token cost is 0.802,learning_rate is 0.05765\n",
713 | "After 9400 global_steps,per token cost is 0.766,learning_rate is 0.05765\n",
714 | "After 9450 global_steps,per token cost is 0.540,learning_rate is 0.05765\n",
715 | "After 9500 global_steps,per token cost is 0.510,learning_rate is 0.05765\n",
716 | "After 9550 global_steps,per token cost is 0.829,learning_rate is 0.05765\n",
717 | "After 9600 global_steps,per token cost is 0.697,learning_rate is 0.05765\n",
718 | "After 9650 global_steps,per token cost is 0.651,learning_rate is 0.05765\n",
719 | "After 9700 global_steps,per token cost is 0.564,learning_rate is 0.05765\n",
720 | "After 9750 global_steps,per token cost is 0.703,learning_rate is 0.05765\n",
721 | "After 9800 global_steps,per token cost is 0.867,learning_rate is 0.05765\n",
722 | "After 9850 global_steps,per token cost is 0.830,learning_rate is 0.05765\n",
723 | "After 9900 global_steps,per token cost is 0.817,learning_rate is 0.05765\n",
724 | "After 9950 global_steps,per token cost is 0.729,learning_rate is 0.05765\n",
725 | "After 10000 global_steps,per token cost is 0.932,learning_rate is 0.05765\n",
726 | "After 10050 global_steps,per token cost is 0.769,learning_rate is 0.05765\n",
727 | "After 10100 global_steps,per token cost is 0.541,learning_rate is 0.05765\n"
728 | ]
729 | }
730 | ],
731 | "source": [
732 | "def main():\n",
733 | " tf.reset_default_graph()\n",
734 | " initializer = tf.random_uniform_initializer(-0.05,0.05)\n",
735 | " with tf.variable_scope(\"nmt_model\",reuse=None,initializer=initializer):\n",
736 | " train_model = NMTModel()\n",
737 | " \n",
738 | " data = MakeSrcTrgDataset(enc_vec_data,dec_vec_data,BATCH_SIZE)\n",
739 | " iterator = data.make_initializable_iterator()\n",
740 | " (src,src_size),(trg_input,trg_label,trg_size) = iterator.get_next()\n",
741 | " \n",
742 | " cost_op,train_op,learning_rate_op,global_step_op = train_model.forward(src,src_size,trg_input,trg_label,trg_size,data_length)\n",
743 | " \n",
744 | " saver = tf.train.Saver()\n",
745 | " step = 1\n",
746 | " \n",
747 | " with tf.Session() as sess:\n",
748 | " tf.global_variables_initializer().run()\n",
749 | " ckpt = tf.train.get_checkpoint_state(CHECKPOINT_PATH) #获取checkpoints对象 \n",
750 | " if ckpt and ckpt.model_checkpoint_path:##判断ckpt是否为空,若不为空,才进行模型的加载,否则从头开始训练 \n",
751 | " saver.restore(sess,ckpt.model_checkpoint_path)#恢复保存的神经网络结构,实现断点续训 \n",
752 | " for i in range(NUM_EPOCH):\n",
753 | " print(\"In EPOCH: %d\" %(i + 1))\n",
754 | " sess.run(iterator.initializer)\n",
755 | " step = run_epoch(sess,cost_op,train_op,learning_rate_op,global_step_op,step,saver,i + 1)\n",
756 | "if __name__ == \"__main__\":\n",
757 | " main()"
758 | ]
759 | },
760 | {
761 | "cell_type": "code",
762 | "execution_count": null,
763 | "metadata": {},
764 | "outputs": [],
765 | "source": []
766 | }
767 | ],
768 | "metadata": {
769 | "kernelspec": {
770 | "display_name": "myconda",
771 | "language": "python",
772 | "name": "myconda"
773 | },
774 | "language_info": {
775 | "codemirror_mode": {
776 | "name": "ipython",
777 | "version": 3
778 | },
779 | "file_extension": ".py",
780 | "mimetype": "text/x-python",
781 | "name": "python",
782 | "nbconvert_exporter": "python",
783 | "pygments_lexer": "ipython3",
784 | "version": "3.7.7"
785 | }
786 | },
787 | "nbformat": 4,
788 | "nbformat_minor": 4
789 | }
790 |
--------------------------------------------------------------------------------