├── doc └── img │ ├── plugin_example.jpg │ ├── plugin_step1.jpg │ ├── plugin_step2.jpg │ ├── plugin_step3.jpg │ ├── backend_helloworld.jpg │ ├── code_completion_result_1.jpg │ └── code_completion_result_2.jpg ├── plugin └── 01-code_completion │ ├── lib │ ├── ezmorph-1.0.6.jar │ ├── commons-beanutils.jar │ ├── commons-lang-2.5.jar │ ├── commons-logging-1.1.jar │ ├── json-lib-2.4-jdk15.jar │ └── commons-collections-3.2.1.jar │ ├── .idea │ ├── compiler.xml │ ├── libraries │ │ ├── ezmorph_1_0_6.xml │ │ ├── commons_lang_2_5.xml │ │ ├── commons_beanutils.xml │ │ ├── json_lib_2_4_jdk15.xml │ │ ├── commons_logging_1_1.xml │ │ └── commons_collections_3_2_1.xml │ ├── .gitignore │ ├── modules.xml │ └── misc.xml │ ├── out │ └── production │ │ └── 01-code_completion │ │ ├── SimpleCompletionContributor.class │ │ └── META-INF │ │ └── plugin.xml │ ├── 01-code_completion.iml │ ├── resources │ └── META-INF │ │ └── plugin.xml │ └── src │ └── SimpleCompletionContributor.java ├── backend ├── 01-flask │ ├── server_demo.py │ ├── cors.py │ └── server.py └── 02-crow │ ├── server.py │ └── server.cpp ├── README.md ├── LICENSE └── model ├── 02-Seq2Seq ├── Seq2Seq_tf20.ipynb ├── .~Seq2Seq.ipynb ├── .ipynb_checkpoints │ └── Seq2Seq-checkpoint.ipynb └── Seq2Seq_tf10.ipynb └── 01-Text_Gen ├── Text_Gen_tf10.ipynb └── .ipynb_checkpoints └── Text_Gen_tf10-checkpoint.ipynb /doc/img/plugin_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/plugin_example.jpg -------------------------------------------------------------------------------- /doc/img/plugin_step1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/plugin_step1.jpg -------------------------------------------------------------------------------- /doc/img/plugin_step2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/plugin_step2.jpg -------------------------------------------------------------------------------- /doc/img/plugin_step3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/plugin_step3.jpg -------------------------------------------------------------------------------- /doc/img/backend_helloworld.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/backend_helloworld.jpg -------------------------------------------------------------------------------- /doc/img/code_completion_result_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/code_completion_result_1.jpg -------------------------------------------------------------------------------- /doc/img/code_completion_result_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/doc/img/code_completion_result_2.jpg -------------------------------------------------------------------------------- /plugin/01-code_completion/lib/ezmorph-1.0.6.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/ezmorph-1.0.6.jar -------------------------------------------------------------------------------- /plugin/01-code_completion/lib/commons-beanutils.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/commons-beanutils.jar 
-------------------------------------------------------------------------------- /plugin/01-code_completion/lib/commons-lang-2.5.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/commons-lang-2.5.jar -------------------------------------------------------------------------------- /plugin/01-code_completion/lib/commons-logging-1.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/commons-logging-1.1.jar -------------------------------------------------------------------------------- /plugin/01-code_completion/lib/json-lib-2.4-jdk15.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/json-lib-2.4-jdk15.jar -------------------------------------------------------------------------------- /plugin/01-code_completion/lib/commons-collections-3.2.1.jar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/lib/commons-collections-3.2.1.jar -------------------------------------------------------------------------------- /plugin/01-code_completion/.idea/compiler.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /plugin/01-code_completion/out/production/01-code_completion/SimpleCompletionContributor.class: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-ABC/AI-Coder/HEAD/plugin/01-code_completion/out/production/01-code_completion/SimpleCompletionContributor.class 
-------------------------------------------------------------------------------- /plugin/01-code_completion/.idea/libraries/ezmorph_1_0_6.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /plugin/01-code_completion/.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Datasource local storage ignored files 5 | /../../../../../:\02-code\00-bishe\01-code_completion\.idea/dataSources/ 6 | /dataSources.local.xml 7 | # Editor-based HTTP Client requests 8 | /httpRequests/ 9 | -------------------------------------------------------------------------------- /plugin/01-code_completion/.idea/libraries/commons_lang_2_5.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /plugin/01-code_completion/.idea/libraries/commons_beanutils.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /plugin/01-code_completion/.idea/libraries/json_lib_2_4_jdk15.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /plugin/01-code_completion/.idea/libraries/commons_logging_1_1.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /plugin/01-code_completion/.idea/libraries/commons_collections_3_2_1.xml: 
-------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /plugin/01-code_completion/.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /plugin/01-code_completion/.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /backend/01-flask/server_demo.py: -------------------------------------------------------------------------------- 1 | from cors import crossdomain 2 | from flask import Flask, jsonify, request 3 | 4 | 5 | app = Flask(__name__) 6 | 7 | 8 | def get_args(req): 9 | if request.method == 'POST': 10 | args = request.json 11 | elif request.method == "GET": 12 | args = request.args 13 | return args 14 | 15 | 16 | @app.route("/plugin_test", methods=["GET", "POST", "OPTIONS"]) 17 | @crossdomain(origin='*', headers="Content-Type") 18 | def plugin_test(): 19 | args = get_args(request) 20 | sentence = args.get("keyword", "Error: nothing input") 21 | # 处理输入返回数据 22 | results = [] 23 | for i in range(5): 24 | results.append(sentence + " test: " + str(i)) 25 | return jsonify({"data": results}) 26 | 27 | 28 | def main(host="127.0.0.1", port=9078): 29 | app.run(host=host, port=port, debug=True) 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /plugin/01-code_completion/01-code_completion.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | -------------------------------------------------------------------------------- 
/plugin/01-code_completion/resources/META-INF/plugin.xml: -------------------------------------------------------------------------------- 1 | 2 | com.yang.pku.coder_aider 3 | Code Sentence Completion for PyCharm 4 | 1.0 5 | PKU_yang 6 | 7 | 9 | Able to autocomplete for code sentences 10 | ]]> 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 23 | 24 | since-build 25 | com.intellij.modules.platform 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 39 | 40 | -------------------------------------------------------------------------------- /plugin/01-code_completion/out/production/01-code_completion/META-INF/plugin.xml: -------------------------------------------------------------------------------- 1 | 2 | com.yang.pku.coder_aider 3 | Code Sentence Completion for PyCharm 4 | 1.0 5 | PKU_yang 6 | 7 | 9 | Able to autocomplete for code sentences 10 | ]]> 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 23 | 24 | since-build 25 | com.intellij.modules.platform 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 39 | 40 | -------------------------------------------------------------------------------- /backend/01-flask/cors.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from flask import make_response, request, current_app 3 | from functools import update_wrapper 4 | 5 | 6 | def crossdomain(origin=None, methods=None, headers=None, max_age=21600, attach_to_all=True, automatic_options=True): 7 | if methods is not None: 8 | methods = ', '.join(sorted(x.upper() for x in methods)) 9 | if headers is not None and not isinstance(headers, str): 10 | headers = ', '.join(x.upper() for x in headers) 11 | if not isinstance(origin, str): 12 | origin = ', '.join(origin) 13 | if isinstance(max_age, timedelta): 14 | max_age = max_age.total_seconds() 15 | 16 | def get_methods(): 17 | if methods is not None: 18 | return methods 19 | 20 | options_resp = current_app.make_default_options_response() 21 | return 
options_resp.headers['allow'] 22 | 23 | def decorator(f): 24 | def wrapped_function(*args, **kwargs): 25 | if automatic_options and request.method == 'OPTIONS': 26 | resp = current_app.make_default_options_response() 27 | else: 28 | resp = make_response(f(*args, **kwargs)) 29 | if not attach_to_all and request.method != 'OPTIONS': 30 | return resp 31 | 32 | h = resp.headers 33 | 34 | h['Access-Control-Allow-Origin'] = origin 35 | h['Access-Control-Allow-Methods'] = get_methods() 36 | h['Access-Control-Max-Age'] = str(max_age) 37 | if headers is not None: 38 | h['Access-Control-Allow-Headers'] = headers 39 | return resp 40 | 41 | f.provide_automatic_options = False 42 | return update_wrapper(wrapped_function, f) 43 | return decorator 44 | -------------------------------------------------------------------------------- /backend/02-crow/server.py: -------------------------------------------------------------------------------- 1 | from keras.preprocessing.text import Tokenizer 2 | from keras.models import load_model 3 | from keras.preprocessing.sequence import pad_sequences 4 | import json 5 | 6 | class Server: 7 | def __init__(self): 8 | self.model_path = "model/CheckpointModel_2.h5" 9 | self.word_index = json.load(open('word_dict.json', 'r')) 10 | self.para = json.load(open('para_dict.json', 'r')) 11 | self.total_words = self.para['total_words'] 12 | self.max_sequence_len = self.para['max_sequence_len'] 13 | self.model = load_model(self.model_path) 14 | print("模型初始化成功!") 15 | 16 | def reference(self,keyword): 17 | result = self.rep(self.generate_text(keyword)) 18 | #print("python预测结果:" + result) 19 | return result 20 | 21 | def generate_text(self,keyword): 22 | keyword = keyword.replace("(", " pareleft", ).replace(")", " pareright").replace(".", " dot").replace(",",\ 23 | " comma").replace(" \'\'", " quotMark").replace("=", " equal").replace(":", " colon") 24 | i = 0 25 | while i <= 10: 26 | token_list = self.find_index(keyword) 27 | if token_list == None: 28 | 
return "无法预测" 29 | # print(token_list) 30 | token_list = pad_sequences([token_list], maxlen=self.max_sequence_len - 1, padding='pre') 31 | predicted = self.model.predict_classes(token_list, verbose=0) 32 | output_word = "" 33 | for word, index in self.word_index.items(): 34 | if index == predicted: 35 | output_word = word 36 | break 37 | if output_word == 'dom': 38 | break 39 | keyword += " " + output_word 40 | i += 1 41 | return keyword 42 | 43 | def find_index(self,keyword): 44 | index_list = [] 45 | keyword = keyword.split() 46 | for word in keyword: 47 | if not word in self.word_index: 48 | return 49 | index_list.append(self.word_index[word]) 50 | return index_list 51 | 52 | def rep(self,str): 53 | return str.replace(" pareleft", "(", ).replace(" pareright", ")").replace(" dot", ".") \ 54 | .replace(" comma", ",").replace(" quotMark ", "\'\'").replace(" equal", "=").replace(" colon", ":") \ 55 | .replace(" quotmark", "\'\'") -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AI-Coder 2 | AI-Coder是一款基于PyCharm的代码句补全插件。 3 | 4 | 其补全效果如下: 5 | 6 |
7 | 8 |
句中调用代码句补全
9 | 10 |   11 | 12 | 13 |
14 | 15 |
句间调用代码句补全
16 | 17 |   18 | 19 | ## 目录结构 20 | - backend——代码句补全服务器 21 | - dataset——训练数据集 22 | - model——代码句补全模型 23 | - plugin——插件开发配置 24 | 25 |   26 | 27 | ## 服务器 28 | 代码句补全服务器尝试了两种框架,分别是Flask和Crow。 29 | ### Flask 30 | 31 | #### 1. 准备 32 | 33 | 如果没有安装 flask,先安装 flask。anaconda 自带了 flask。 34 | `pip install flask` 35 | 36 | #### 2. 运行 37 | 38 | 进入 backend 文件夹,运行 server_demo.py。 39 | 40 | 在浏览器中输入 localhost:9078/plugin_test?keyword=helloworld ,浏览器返回内容如下。 41 | 42 | 43 | 44 | 后端获取 keyword 中的数据,处理之后返回。后续我们使用模型处理输入,道理是一样的。 45 | 46 | ### [Crow](https://github.com/ipkn/crow) 47 | Crow是一个轻量级的Web服务器框架,这个框架是受Python下的Flask启发的,其实现的功能和Flask基本一致,核心的区别在于Crow是用C++编写的,性能较Flask有一定的提升。 48 | 49 |   50 | 51 | ## 数据集 52 | 数据集有两个,第一个是Keras领域的代码数据,第二个是TensorFlow领域的代码数据。 53 | 54 | 代码数据中均删除了参数。 55 | 56 |   57 | 58 | 59 | ## 模型 60 | ### 深度学习环境 61 | 62 | 模型训练环境包含两种: 63 | #### tensorflow 1 64 | - python 3.5 65 | - tensorflow 1.13.1 66 | - keras 2.2.4 67 | 68 | #### tensorflow 2 69 | - python 3.7 70 | - tensorflow 2.2 71 | 72 | 从模型代码的文件名中可以得知对应的版本号,例如: 73 | > Text_Gen_tf10.ipynb 74 | 75 | 76 | 表示该代码文件为生成模型,使用的版本为tensorflow 1的版本 77 | 78 | 79 | ### 模型介绍 80 | 模型尝试了三种,分别是: 81 | - 基于长短期记忆的代码句生成模型 82 | - 基于序列到序列的代码句预测模型 83 | - 基于Transformer的代码句预测模型 84 | 85 | 86 | 87 | 88 | 注:模型文件类型为ipynb,需要用jupyter打开。 89 | 90 |   91 | 92 | 93 | ## 插件 94 | 95 | ### 1. 准备 96 | 97 | - 开发插件所用的编辑器——IDEA 测试时使用版本 IntelliJ IDEA 2019.1 x64 98 | 99 | - 插件适用对象——Pycharm 测试时使用版本 JetBrains PyCharm Community Edition 2019.1.1 x64 100 | 101 | 注:Pycharm必须安装社区版!否则不能调试。 102 | 103 | ### 2. 插件项目的导入与运行 104 | 105 | 打开IDEA,File->open->我们项目的根目录 106 | 107 | 然后需要配置: 108 | 109 | 1) 在IDEA中选择项目根目录右键Open module settings 110 | 111 |
112 | 113 | 将项目的 SDK 设置为本机安装的 PyCharm 社区版:新建一个 SDK,路径选择 PyCharm 社区版的安装根目录。 114 | 115 |
116 | 117 |
118 | 119 | 2) 运行项目时会启动一个已加载本插件的 PyCharm 窗口,可在其中查看插件效果。 120 | 121 |
122 | -------------------------------------------------------------------------------- /backend/01-flask/server.py: -------------------------------------------------------------------------------- 1 | from cors import crossdomain 2 | from flask import Flask, jsonify, request 3 | from keras.preprocessing.text import Tokenizer 4 | from keras.models import load_model 5 | from keras.preprocessing.sequence import pad_sequences 6 | import json 7 | 8 | app = Flask(__name__) 9 | 10 | # total_words = 2058 11 | total_words = 0 12 | # max_sequence_len = 9 13 | max_sequence_len = 0 14 | word_index = dict() 15 | model = None 16 | 17 | 18 | def get_args(req): 19 | if request.method == 'POST': 20 | args = request.json 21 | elif request.method == "GET": 22 | args = request.args 23 | return args 24 | 25 | 26 | @app.route("/plugin_test", methods=["GET", "POST", "OPTIONS"]) 27 | @crossdomain(origin='*', headers="Content-Type") 28 | def plugin_test(): 29 | if model == None: 30 | init() 31 | args = get_args(request) 32 | 33 | for seed in args.keys(): 34 | result = rep(generate_text(seed, model, max_sequence_len)) 35 | return jsonify({"data": result}) 36 | 37 | 38 | def main(host="0.0.0.0", port=9000): 39 | app.run(host=host, port=port, debug=True) 40 | 41 | 42 | def init(): 43 | savePath = "model/CheckpointModel_2.h5" 44 | global word_index 45 | word_index = json.load(open('word_dict.json', 'r')) 46 | para = json.load(open('para_dict.json', 'r')) 47 | global total_words, max_sequence_len 48 | total_words = para['total_words'] 49 | max_sequence_len = para['max_sequence_len'] 50 | global model 51 | model = load_model(savePath) 52 | 53 | 54 | def find_index(seed_text): 55 | index_list = [] 56 | seed_text = seed_text.split() 57 | for word in seed_text: 58 | if not word in word_index: 59 | index = None 60 | return 61 | index_list.append(word_index[word]) 62 | return index_list 63 | 64 | 65 | def generate_text(seed_text, model, max_sequence_len): 66 | # for _ in range(next_words): 67 | 
seed_text = seed_text.replace("(", " pareleft", ).replace(")", " pareright").replace(".", " dot").replace(",",\ 68 | " comma").replace(" \'\'", " quotMark").replace("=", " equal").replace(":", " colon") 69 | i = 0 70 | while i <= 10: 71 | token_list = find_index(seed_text) 72 | if token_list == None: 73 | return "cant predict!" 74 | # print(token_list) 75 | token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre') 76 | predicted = model.predict_classes(token_list, verbose=0) 77 | output_word = "" 78 | for word, index in word_index.items(): 79 | if index == predicted: 80 | output_word = word 81 | break 82 | if output_word == 'dom': 83 | break 84 | seed_text += " " + output_word 85 | i += 1 86 | return seed_text 87 | 88 | 89 | def rep(str): 90 | return str.replace(" pareleft", "(", ).replace(" pareright", ")").replace(" dot", ".") \ 91 | .replace(" comma", ",").replace(" quotMark ", "\'\'").replace(" equal", "=").replace(" colon", ":") \ 92 | .replace(" quotmark", "\'\'") 93 | 94 | 95 | if __name__ == "__main__": 96 | main() 97 | -------------------------------------------------------------------------------- /backend/02-crow/server.cpp: -------------------------------------------------------------------------------- 1 | #include "crow.h" 2 | #include "crow/query_string.h" 3 | #include "Python.h" 4 | #include 5 | #include 6 | #include 7 | 8 | using namespace std; 9 | 10 | //CPython编译:g++ -Wall -std=c++11 test.cpp -o 0test -fsanitize=leak -lpython3.6m 11 | //crow编译:g++ -std=c++11 serve.cpp -o server -lpython3.6m -lboost_system -lboost_filesystem -L../boost/stage/lib -pthread 12 | 13 | int main() 14 | { 15 | //step1: PyC init. 16 | Py_Initialize(); 17 | if(!Py_IsInitialized()){ 18 | cout << "[error]: PyC init error." << endl; 19 | return 1; 20 | } 21 | cout << "[INFO]: PyC init succeed." 
<< endl; 22 | 23 | //step2: Export serve.py path 24 | string work_path = string("sys.path.append(\'") + "\')"; 25 | PyRun_SimpleString("import sys"); 26 | PyRun_SimpleString(work_path.c_str()); 27 | 28 | //step3: Import serve.py 29 | PyObject* pModule = PyImport_ImportModule("serve"); 30 | if(!pModule){ 31 | cout << "[Error]: Import serve.py failed." << endl; 32 | return 1; 33 | } 34 | cout << "[INFO]: Import serve.py succeed." << endl; 35 | 36 | //step4: Get classes and methods from pModule 37 | PyObject *pDict = PyModule_GetDict(pModule); 38 | if(!pDict){ 39 | cout << "[Error]: Get classes and methods from pModule failed." << endl; 40 | return 1; 41 | } 42 | cout << "[INFO]: Get classes and methods from pModule succeed." << endl; 43 | 44 | //step5: Get Server class 45 | PyObject* pClass = PyDict_GetItemString(pDict,"Server"); 46 | if(!pClass){ 47 | cout << "[Error]: Get Server class failed." << endl; 48 | return 1; 49 | } 50 | cout << "[INFO]: Get Server class succeed." << endl; 51 | 52 | //step6: Get server construct function 53 | PyObject *pConstruct = PyInstanceMethod_New(pClass); 54 | if(!pConstruct){ 55 | cout << "[Error]: Get server construct function failed." << endl; 56 | return 1; 57 | } 58 | cout << "[INFO]: Get server construct function succeed." << endl; 59 | 60 | //step7: Create and init Server instance 61 | PyObject* pServer = PyObject_CallObject(pConstruct,nullptr); 62 | if(!pServer){ 63 | cout << "[Error]: Create and init Server instance failed." << endl; 64 | return 1; 65 | } 66 | cout << "[INFO]: Create and init Server instance succeed." 
<< endl; 67 | 68 | //step8: Crow start 69 | crow::SimpleApp app; 70 | 71 | CROW_ROUTE(app, "/")([](){ 72 | return "Welcome to use the plugin!"; 73 | }); 74 | 75 | //step9: Model inference 76 | CROW_ROUTE(app, "/plugin_test") 77 | .methods("GET"_method) 78 | ([pServer](const crow::request& req){ 79 | std::ostringstream os; 80 | crow::query_string keyword = req.url_params; 81 | for(char* str :keyword.get_key_value_pairs_()){ 82 | if(strlen(str) == 0){ 83 | cout << "Got a empty value." < 0;++i){ 52 | String line = strs[strs.length - i].trim(); 53 | if(line.length() > 3){ 54 | key = line + "\n" + key; 55 | sentence_nums--; 56 | } 57 | } 58 | //网页会自动将空格编解码为%20 59 | key = key.replaceAll(" ","%20"); 60 | System.out.println("key: " + key); 61 | 62 | try{ 63 | String urlStr = "http://xx.xx.xx.xx:5000/plugin_test"; 64 | JSONObject jsonObject = null; 65 | URL url = new URL(urlStr); 66 | //创建http链接 67 | HttpURLConnection httpURLConnection = (HttpURLConnection) url.openConnection(); 68 | //设置请求的方法类型 69 | httpURLConnection.setRequestMethod("POST"); 70 | // Post请求不能使用缓存 71 | httpURLConnection.setUseCaches(false); 72 | //设置请求的内容类型 73 | httpURLConnection.setRequestProperty("Content-type", "application/json"); 74 | //设置发送数据 75 | httpURLConnection.setDoOutput(true); 76 | //设置接受数据 77 | httpURLConnection.setDoInput(true); 78 | 79 | //设置body内的参数,put到JSONObject中 80 | JSONObject param = new JSONObject(); 81 | param.put("type","Text_Gen"); 82 | param.put("data",key + "\n" + new_line); 83 | 84 | //param.put("type","Seq2Seq"); 85 | //param.put("data",key); 86 | 87 | // 建立实际的连接 88 | httpURLConnection.connect(); 89 | // 得到请求的输出流对象 90 | OutputStreamWriter writer = new OutputStreamWriter(httpURLConnection.getOutputStream(),"UTF-8"); 91 | writer.write(param.toString()); 92 | writer.flush(); 93 | 94 | //接收数据 95 | InputStream inputStream = httpURLConnection.getInputStream(); 96 | //定义字节数组 97 | byte[] b = new byte[1024]; 98 | //定义一个输出流存储接收到的数据 99 | ByteArrayOutputStream byteArrayOutputStream = new 
ByteArrayOutputStream(); 100 | //开始接收数据 101 | int len = 0; 102 | while (true) { 103 | len = inputStream.read(b); 104 | if (len == -1) { 105 | //数据读完 106 | break; 107 | } 108 | byteArrayOutputStream.write(b, 0, len); 109 | } 110 | //从输出流中获取读取到数据(服务端返回的) 111 | String response = byteArrayOutputStream.toString(); 112 | //System.out.println("response:"+response); 113 | jsonObject = JSONObject.fromObject(response); 114 | //遍历JSON对象 115 | Iterator iter = jsonObject.entrySet().iterator(); 116 | while (iter.hasNext()) { 117 | Map.Entry entry = (Map.Entry) iter.next(); 118 | //System.out.println(entry.getKey().toString()); 119 | String completion_result = entry.getValue().toString(); 120 | //System.out.println("data:"+completion_result); 121 | if(new_line.length() < completion_result.length()){ 122 | String completion_result_start = completion_result.substring(0,new_line.length()); 123 | //System.out.println("new_line: " + new_line + ",completion_result_start:"+ completion_result_start); 124 | if(new_line.equals(completion_result_start)) 125 | completion_result = completion_result.substring(new_line.length()); 126 | } 127 | result.addElement(LookupElementBuilder.create(completion_result.trim())); 128 | } 129 | 130 | //byteArrayOutputStream.close(); 131 | // 关闭连接 132 | httpURLConnection.disconnect(); 133 | 134 | }catch(IOException e){ 135 | e.printStackTrace(); 136 | } 137 | super.fillCompletionVariants(parameters, result); 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /model/02-Seq2Seq/Seq2Seq_tf20.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "decent-client", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import io\n", 11 | "import re\n", 12 | "import random\n", 13 | "import tensorflow as tf\n", 14 | "\n", 15 | "\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import matplotlib.ticker as ticker\n", 18 | "from sklearn.model_selection import train_test_split\n", 19 | "\n", 20 | "import unicodedata\n", 21 | "import numpy as np\n", 22 | "import os\n", 23 | "import time" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "id": "polar-recognition", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "#GPU设置\n", 34 | "gpus = tf.config.experimental.list_physical_devices(device_type='GPU')\n", 35 | "tf.config.experimental.set_visible_devices(devices=gpus[:], device_type='GPU')\n", 36 | "for gpu in gpus:\n", 37 | " 
tf.config.experimental.set_memory_growth(gpu, True)\n", 38 | "os.environ['CUDA_VISIBLE_DEVICES']=\"0\"" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 5, 44 | "id": "parallel-adjustment", 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "text/plain": [ 50 | "'/device:GPU:0'" 51 | ] 52 | }, 53 | "execution_count": 5, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "tf.test.gpu_device_name()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "scenic-decrease", 65 | "metadata": {}, 66 | "source": [ 67 | "## 构建数据集" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 6, 73 | "id": "starting-toilet", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def preprocess_sentence(line): \n", 78 | " line = re.split('([: ,.(){}\\[\\]=])',line) \n", 79 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n", 80 | " \n", 81 | " new_line = ' ' + ' '.join(line) + ' '\n", 82 | " return new_line" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 7, 88 | "id": "variable-going", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "def create_dataset(path,num_examples,rand_max=15,duplicate=3):\n", 93 | " input_data = []\n", 94 | " output_data = []\n", 95 | " \n", 96 | " lines = io.open(path,encoding='utf-8').read().strip().split('\\n')\n", 97 | " if num_examples == -1:\n", 98 | " num_examples = len(lines)\n", 99 | " \n", 100 | " for i in range(1,num_examples):\n", 101 | " \n", 102 | " rand_nums = set(random.randint(1,rand_max) for _ in range(duplicate))\n", 103 | " for rand_num in rand_nums:\n", 104 | " data = ''\n", 105 | " for j in range(i - rand_num,i):\n", 106 | " line = preprocess_sentence(lines[j].strip()) + ' '\n", 107 | " data += line\n", 108 | " input_data.append(data.strip())\n", 109 | " output_data.append(preprocess_sentence(lines[i].strip()))\n", 110 | " \n", 111 | " return input_data,output_data" 112 | ] 
113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 8, 117 | "id": "polished-order", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "def max_length(tensor):\n", 122 | " return max(len(t) for t in tensor)\n", 123 | "\n", 124 | "def get_num_words(lang,scale):\n", 125 | " #获取词典大小\n", 126 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n", 127 | " lang_tokenizer.fit_on_texts(lang)\n", 128 | " num_words = int(len(lang_tokenizer.word_index) * scale)\n", 129 | " return num_words\n", 130 | " \n", 131 | "def tokenize(lang,scale):\n", 132 | " num_words = get_num_words(lang,scale)\n", 133 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=num_words,\n", 134 | " oov_token='',filters='')\n", 135 | " lang_tokenizer.fit_on_texts(lang)\n", 136 | " tensor = lang_tokenizer.texts_to_sequences(lang)\n", 137 | " tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,padding='post')\n", 138 | " \n", 139 | " return tensor, lang_tokenizer\n", 140 | "\n", 141 | "def load_dataset(path,num_examples=None,scale=0.9):\n", 142 | " inp_lang,targ_lang = create_dataset(path, num_examples)\n", 143 | "\n", 144 | " input_tensor, inp_lang_tokenizer = tokenize(inp_lang,scale)\n", 145 | " target_tensor, targ_lang_tokenizer = tokenize(targ_lang,scale)\n", 146 | "\n", 147 | " return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "higher-punishment", 153 | "metadata": {}, 154 | "source": [ 155 | "### 限制数据集的大小以加快实验速度(可选)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 9, 161 | "id": "narrow-palmer", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# 尝试实验不同大小的数据集\n", 166 | "num_examples = -1\n", 167 | "scale = 0.95\n", 168 | "path = \"../00-data/tf_data.txt\"\n", 169 | "input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path, num_examples,scale)\n", 170 | "\n", 171 | "# 
计算目标张量和输入张量的最大长度 (max_length)\n", 172 | "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 10, 178 | "id": "convenient-string", 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "data": { 183 | "text/plain": [ 184 | "109453" 185 | ] 186 | }, 187 | "execution_count": 10, 188 | "metadata": {}, 189 | "output_type": "execute_result" 190 | } 191 | ], 192 | "source": [ 193 | "len(input_tensor)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 11, 199 | "id": "significant-oxygen", 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "109453" 206 | ] 207 | }, 208 | "execution_count": 11, 209 | "metadata": {}, 210 | "output_type": "execute_result" 211 | } 212 | ], 213 | "source": [ 214 | "len(target_tensor)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 31, 220 | "id": "convinced-grounds", 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "2501 2501 278 278\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "\"\"\"目前先不采用\"\"\"\n", 233 | "# 采用 90 - 10 的比例切分训练集和验证集\n", 234 | "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.1)\n", 235 | "\n", 236 | "# 显示长度\n", 237 | "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "id": "incorrect-software", 243 | "metadata": {}, 244 | "source": [ 245 | "### 创建一个 tf.data 数据集" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 12, 251 | "id": "hungry-lafayette", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "BUFFER_SIZE = len(input_tensor)\n", 256 | "BATCH_SIZE = 64\n", 257 | "steps_per_epoch =
len(input_tensor)//BATCH_SIZE\n", 258 | "embedding_dim = 256\n", 259 | "units = 256\n", 260 | "vocab_inp_size = len(inp_lang.word_index)+1\n", 261 | "vocab_tar_size = len(targ_lang.word_index)+1\n", 262 | "\n", 263 | "dataset = tf.data.Dataset.from_tensor_slices((input_tensor, target_tensor)).shuffle(BUFFER_SIZE)\n", 264 | "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 13, 270 | "id": "requested-township", 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "data": { 275 | "text/plain": [ 276 | "(TensorShape([64, 597]), TensorShape([64, 245]))" 277 | ] 278 | }, 279 | "execution_count": 13, 280 | "metadata": {}, 281 | "output_type": "execute_result" 282 | } 283 | ], 284 | "source": [ 285 | "example_input_batch, example_target_batch = next(iter(dataset))\n", 286 | "example_input_batch.shape, example_target_batch.shape" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "id": "trying-channel", 292 | "metadata": {}, 293 | "source": [ 294 | "## 编写编码器encoder和解码器decoder模型" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 14, 300 | "id": "consecutive-literacy", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "class Encoder(tf.keras.Model):\n", 305 | " def __init__(self, vocab_size, embedding_dim, enc_units, batch_size):\n", 306 | " super(Encoder, self).__init__()\n", 307 | " self.batch_size = batch_size\n", 308 | " self.enc_units = enc_units\n", 309 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,batch_input_shape=[batch_size,None])\n", 310 | " self.lstm = tf.keras.layers.LSTM(self.enc_units,\n", 311 | " return_sequences=True,\n", 312 | " return_state=True,\n", 313 | " dropout=0.1,\n", 314 | " recurrent_dropout=0.1)\n", 315 | "\n", 316 | " def call(self, x):\n", 317 | " x = self.embedding(x)\n", 318 | " output_l1,state_l1,_ = self.lstm(x)\n", 319 | " #output_l2,state_l2,_ = self.lstm(output_l1)\n", 
320 | " return output_l1,state_l1\n", 321 | "\n", 322 | " def initialize_hidden_state(self):\n", 323 | " return tf.zeros((self.batch_size, self.enc_units))" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 15, 329 | "id": "offensive-survivor", 330 | "metadata": { 331 | "scrolled": true 332 | }, 333 | "outputs": [ 334 | { 335 | "name": "stdout", 336 | "output_type": "stream", 337 | "text": [ 338 | "Encoder output shape: (batch size, sequence length, units) (64, 597, 256)\n", 339 | "Encoder Hidden state shape: (batch size, units) (64, 256)\n" 340 | ] 341 | } 342 | ], 343 | "source": [ 344 | "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", 345 | "\n", 346 | "# 样本输入\n", 347 | "sample_hidden = encoder.initialize_hidden_state()\n", 348 | "sample_output,sample_hidden = encoder(example_input_batch)\n", 349 | "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n", 350 | "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 16, 356 | "id": "therapeutic-favor", 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "class BahdanauAttention(tf.keras.layers.Layer):\n", 361 | " def __init__(self, units):\n", 362 | " super(BahdanauAttention, self).__init__()\n", 363 | " self.W1 = tf.keras.layers.Dense(units)\n", 364 | " self.W2 = tf.keras.layers.Dense(units)\n", 365 | " self.V = tf.keras.layers.Dense(1)\n", 366 | "\n", 367 | " def call(self, query, values):\n", 368 | " # 隐藏层的形状 == (批大小,隐藏层大小)\n", 369 | " # hidden_with_time_axis 的形状 == (批大小,1,隐藏层大小)\n", 370 | " # 这样做是为了执行加法以计算分数 \n", 371 | " hidden_with_time_axis = tf.expand_dims(query, 1)\n", 372 | "\n", 373 | " # 分数的形状 == (批大小,最大长度,1)\n", 374 | " # 我们在最后一个轴上得到 1, 因为我们把分数应用于 self.V\n", 375 | " # 在应用 self.V 之前,张量的形状是(批大小,最大长度,单位)\n", 376 | " score = self.V(tf.nn.tanh(\n", 377 | " self.W1(values) + 
self.W2(hidden_with_time_axis)))\n", 378 | "\n", 379 | " # 注意力权重 (attention_weights) 的形状 == (批大小,最大长度,1)\n", 380 | " attention_weights = tf.nn.softmax(score, axis=1)\n", 381 | "\n", 382 | " # 上下文向量 (context_vector) 求和之后的形状 == (批大小,隐藏层大小)\n", 383 | " context_vector = attention_weights * values\n", 384 | " context_vector = tf.reduce_sum(context_vector, axis=1)\n", 385 | "\n", 386 | " return context_vector, attention_weights" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 17, 392 | "id": "under-swaziland", 393 | "metadata": {}, 394 | "outputs": [ 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "Attention result shape: (batch size, units) (64, 256)\n", 400 | "Attention weights shape: (batch_size, sequence_length, 1) (64, 597, 1)\n" 401 | ] 402 | } 403 | ], 404 | "source": [ 405 | "attention_layer = BahdanauAttention(10)\n", 406 | "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n", 407 | "\n", 408 | "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n", 409 | "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 18, 415 | "id": "smoking-horizon", 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 419 | "class Decoder(tf.keras.Model):\n", 420 | " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", 421 | " super(Decoder, self).__init__()\n", 422 | " self.batch_sz = batch_sz\n", 423 | " self.dec_units = dec_units\n", 424 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", 425 | " self.train_LSTM = tf.keras.layers.LSTM(self.dec_units,\n", 426 | " return_sequences=True,\n", 427 | " return_state=True,\n", 428 | " dropout=0.1,\n", 429 | " recurrent_dropout=0.1)\n", 430 | " self.infer_LSTM = tf.keras.layers.LSTM(self.dec_units,\n", 431 | " 
return_sequences=True,\n", 432 | " return_state=True)\n", 433 | " \n", 434 | " self.fc = tf.keras.layers.Dense(vocab_size)\n", 435 | "\n", 436 | " # 用于注意力\n", 437 | " self.attention = BahdanauAttention(self.dec_units)\n", 438 | "\n", 439 | " def call(self, x, hidden, enc_output,is_train=True):\n", 440 | " # 编码器输出 (enc_output) 的形状 == (批大小,最大长度,隐藏层大小)\n", 441 | " context_vector, attention_weights = self.attention(hidden, enc_output)\n", 442 | "\n", 443 | " # x 在通过嵌入层后的形状 == (批大小,1,嵌入维度)\n", 444 | " x = self.embedding(x)\n", 445 | "\n", 446 | " # x 在拼接 (concatenation) 后的形状 == (批大小,1,嵌入维度 + 隐藏层大小)\n", 447 | " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", 448 | "\n", 449 | " # 将合并后的向量传送到 LSTM\n", 450 | " if is_train:\n", 451 | " output, state,_ = self.train_LSTM(x)\n", 452 | " else:\n", 453 | " output, state,_ = self.infer_LSTM(x)\n", 454 | "\n", 455 | " # 输出的形状 == (批大小 * 1,隐藏层大小)\n", 456 | " output = tf.reshape(output, (-1, output.shape[2]))\n", 457 | "\n", 458 | " # 输出的形状 == (批大小,vocab)\n", 459 | " x = self.fc(output)\n", 460 | "\n", 461 | " return x, state, attention_weights" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 19, 467 | "id": "straight-sending", 468 | "metadata": {}, 469 | "outputs": [ 470 | { 471 | "name": "stdout", 472 | "output_type": "stream", 473 | "text": [ 474 | "Decoder output shape: (batch_size, vocab size) (64, 11568)\n" 475 | ] 476 | } 477 | ], 478 | "source": [ 479 | "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n", 480 | "\n", 481 | "sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),\n", 482 | " sample_hidden, sample_output,is_train=True)\n", 483 | "\n", 484 | "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 20, 490 | "id": "compliant-directive", 491 | "metadata": {}, 492 | "outputs": [], 493 | "source": [ 494 | "optimizer 
= tf.keras.optimizers.Adam()\n", 495 | "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n", 496 | " from_logits=True, reduction='none')\n", 497 | "\n", 498 | "def loss_function(real, pred):\n", 499 | " mask = tf.math.logical_not(tf.math.equal(real, 0))\n", 500 | " loss_ = loss_object(real, pred)\n", 501 | "\n", 502 | " mask = tf.cast(mask, dtype=loss_.dtype)\n", 503 | " loss_ *= mask\n", 504 | "\n", 505 | " return tf.reduce_mean(loss_)" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": 21, 511 | "id": "adult-grocery", 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "checkpoint_dir = '../02-checkpoints/'\n", 516 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", 517 | "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", 518 | " encoder=encoder,\n", 519 | " decoder=decoder)" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 22, 525 | "id": "charming-apparatus", 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "@tf.function\n", 530 | "def train_step(inp,targ,enc_hidden):\n", 531 | " loss = 0\n", 532 | " \n", 533 | " with tf.GradientTape() as tape:\n", 534 | " enc_output,enc_hidden = encoder(inp)\n", 535 | " dec_hidden = enc_hidden\n", 536 | " \n", 537 | " dec_hidden = enc_hidden\n", 538 | " \n", 539 | " dec_input = tf.expand_dims([targ_lang.word_index['']] * BATCH_SIZE, 1)\n", 540 | " \n", 541 | " for t in range(1,targ.shape[1]):\n", 542 | " predictions,dec_hidden,_ = decoder(dec_input,dec_hidden,enc_output,is_train=True)\n", 543 | " loss += loss_function(targ[:,t],predictions)\n", 544 | " \n", 545 | " dec_input = tf.expand_dims(targ[:,t],1)\n", 546 | " \n", 547 | " batch_loss = (loss / int(targ.shape[1]))\n", 548 | " variables = encoder.trainable_variables + decoder.trainable_variables\n", 549 | " gradients = tape.gradient(loss,variables)\n", 550 | " optimizer.apply_gradients(zip(gradients,variables))\n", 551 | " \n", 552 | " return batch_loss" 
553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "id": "based-wound", 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [ 562 | "EPOCHS = 30\n", 563 | "\n", 564 | "for epoch in range(EPOCHS):\n", 565 | " start = time.time()\n", 566 | "\n", 567 | " enc_hidden = encoder.initialize_hidden_state()\n", 568 | " total_loss = 0\n", 569 | "\n", 570 | " for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n", 571 | " batch_loss = train_step(inp, targ, enc_hidden)\n", 572 | " total_loss += batch_loss\n", 573 | "\n", 574 | " #if batch % 5 == 0:\n", 575 | " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", 576 | " batch,\n", 577 | " batch_loss.numpy()))\n", 578 | " # 每 5 个周期(epoch),保存(检查点)一次模型\n", 579 | " if (epoch + 1) % 5 == 0:\n", 580 | " checkpoint.save(file_prefix=checkpoint_prefix) \n", 581 | "\n", 582 | " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", 583 | " total_loss / steps_per_epoch))\n", 584 | " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" 585 | ] 586 | }, 587 | { 588 | "cell_type": "markdown", 589 | "id": "terminal-suite", 590 | "metadata": {}, 591 | "source": [ 592 | "## 预测" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 13, 598 | "id": "satellite-seating", 599 | "metadata": {}, 600 | "outputs": [], 601 | "source": [ 602 | "def create_inference_data(lines):\n", 603 | " data = \"\"\n", 604 | " for line in lines:\n", 605 | " line = preprocess_sentence(line.strip()) + ' '\n", 606 | " data += line\n", 607 | " return data.strip()" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": 14, 613 | "id": "cellular-impossible", 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "def evaluate(sentence):\n", 618 | " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", 619 | "\n", 620 | " sentence = create_inference_data(sentence)\n", 621 | "\n", 622 | " inputs = [inp_lang.word_index[i] for i 
in sentence.split(' ')]\n", 623 | " inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],\n", 624 | " maxlen=max_length_inp,\n", 625 | " padding='post')\n", 626 | " inputs = tf.convert_to_tensor(inputs)\n", 627 | "\n", 628 | " result = ''\n", 629 | "\n", 630 | " hidden = [tf.zeros((1, units))]\n", 631 | " enc_out, enc_hidden = encoder(inputs)\n", 632 | "\n", 633 | " dec_hidden = enc_hidden\n", 634 | " dec_input = tf.expand_dims([targ_lang.word_index['']], 0)\n", 635 | "\n", 636 | " for t in range(max_length_targ):\n", 637 | " predictions, dec_hidden, attention_weights = decoder(dec_input,\n", 638 | " dec_hidden,\n", 639 | " enc_out,\n", 640 | " is_train=False)\n", 641 | "\n", 642 | " # 存储注意力权重以便后面制图\n", 643 | " attention_weights = tf.reshape(attention_weights, (-1, ))\n", 644 | " attention_plot[t] = attention_weights.numpy()\n", 645 | "\n", 646 | " predicted_id = tf.argmax(predictions[0]).numpy()\n", 647 | "\n", 648 | " result += targ_lang.index_word[predicted_id] + ' '\n", 649 | "\n", 650 | " if targ_lang.index_word[predicted_id] == '':\n", 651 | " return result, sentence, attention_plot\n", 652 | "\n", 653 | " # 预测的 ID 被输送回模型\n", 654 | " dec_input = tf.expand_dims([predicted_id], 0)\n", 655 | "\n", 656 | " return result, sentence, attention_plot" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 21, 662 | "id": "current-feedback", 663 | "metadata": {}, 664 | "outputs": [], 665 | "source": [ 666 | "# 注意力权重制图函数\n", 667 | "def plot_attention(attention, sentence, predicted_sentence):\n", 668 | " fig = plt.figure(figsize=(20,20))\n", 669 | " ax = fig.add_subplot(1, 1, 1)\n", 670 | " ax.matshow(attention, cmap='viridis')\n", 671 | "\n", 672 | " fontdict = {'fontsize': 10}\n", 673 | "\n", 674 | " ax.set_xticklabels([''] + sentence, fontdict=fontdict)\n", 675 | " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict, rotation=90)\n", 676 | "\n", 677 | " ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n", 678 | " 
ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n", 679 | "\n", 680 | " plt.show()" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 22, 686 | "id": "worse-logging", 687 | "metadata": {}, 688 | "outputs": [], 689 | "source": [ 690 | "def translate(sentence):\n", 691 | " result,sentence,attention_plot = evaluate(sentence)\n", 692 | " \n", 693 | " print('Input: %s' % (sentence))\n", 694 | " print('Predicted translation: {}'.format(result))\n", 695 | " \n", 696 | " attention_plot = attention_plot[:len(result.split(' ')),:len(sentence.split(' '))]\n", 697 | " plot_attention(attention_plot,sentence.split(' '),result.split(' ')) \n", 698 | " print(result.split(' '))" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 29, 704 | "id": "integrated-victoria", 705 | "metadata": {}, 706 | "outputs": [], 707 | "source": [ 708 | "lines = io.open(path,encoding='utf-8').read().strip().split('\\n')[12:13]\n", 709 | "lines = create_inference_data(lines)" 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "execution_count": 30, 715 | "id": "numerous-advisory", 716 | "metadata": { 717 | "scrolled": false 718 | }, 719 | "outputs": [ 720 | { 721 | "data": { 722 | "text/plain": [ 723 | "' y = tf . placeholder ( tf . 
float32 , shape = [ ] , name = ) '" 724 | ] 725 | }, 726 | "execution_count": 30, 727 | "metadata": {}, 728 | "output_type": "execute_result" 729 | } 730 | ], 731 | "source": [ 732 | "lines" 733 | ] 734 | } 735 | ], 736 | "metadata": { 737 | "kernelspec": { 738 | "display_name": "myconda", 739 | "language": "python", 740 | "name": "myconda" 741 | }, 742 | "language_info": { 743 | "codemirror_mode": { 744 | "name": "ipython", 745 | "version": 3 746 | }, 747 | "file_extension": ".py", 748 | "mimetype": "text/x-python", 749 | "name": "python", 750 | "nbconvert_exporter": "python", 751 | "pygments_lexer": "ipython3", 752 | "version": "3.7.7" 753 | } 754 | }, 755 | "nbformat": 4, 756 | "nbformat_minor": 5 757 | } 758 | -------------------------------------------------------------------------------- /model/02-Seq2Seq/.~Seq2Seq.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "decent-client", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import io\n", 11 | "import re\n", 12 | "import random\n", 13 | "import tensorflow as tf\n", 14 | "\n", 15 | "\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import matplotlib.ticker as ticker\n", 18 | "from sklearn.model_selection import train_test_split\n", 19 | "\n", 20 | "import unicodedata\n", 21 | "import numpy as np\n", 22 | "import os\n", 23 | "import time" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "id": "polar-recognition", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "#GPU设置\n", 34 | "gpus = tf.config.experimental.list_physical_devices(device_type='GPU')\n", 35 | "tf.config.experimental.set_visible_devices(devices=gpus[:], device_type='GPU')\n", 36 | "for gpu in gpus:\n", 37 | " tf.config.experimental.set_memory_growth(gpu, True)\n", 38 | "os.environ['CUDA_VISIBLE_DEVICES']=\"0\"" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 
| "execution_count": 5, 44 | "id": "parallel-adjustment", 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "text/plain": [ 50 | "'/device:GPU:0'" 51 | ] 52 | }, 53 | "execution_count": 5, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "tf.test.gpu_device_name()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "scenic-decrease", 65 | "metadata": {}, 66 | "source": [ 67 | "## 构建数据集" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 6, 73 | "id": "starting-toilet", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def preprocess_sentence(line): \n", 78 | " line = re.split('([: ,.(){}\\[\\]=])',line) \n", 79 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n", 80 | " \n", 81 | " new_line = ' ' + ' '.join(line) + ' '\n", 82 | " return new_line" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 7, 88 | "id": "variable-going", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "def create_dataset(path,num_examples,rand_max=15,duplicate=3):\n", 93 | " input_data = []\n", 94 | " output_data = []\n", 95 | " \n", 96 | " lines = io.open(path,encoding='utf-8').read().strip().split('\\n')\n", 97 | " if num_examples == -1:\n", 98 | " num_examples = len(lines)\n", 99 | " \n", 100 | " for i in range(1,num_examples):\n", 101 | " \n", 102 | " rand_nums = set(random.randint(1,rand_max) for _ in range(duplicate))\n", 103 | " for rand_num in rand_nums:\n", 104 | " data = ''\n", 105 | " for j in range(i - rand_num,i):\n", 106 | " line = preprocess_sentence(lines[j].strip()) + ' '\n", 107 | " data += line\n", 108 | " input_data.append(data.strip())\n", 109 | " output_data.append(preprocess_sentence(lines[i].strip()))\n", 110 | " \n", 111 | " return input_data,output_data" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 8, 117 | "id": "polished-order", 118 | "metadata": {}, 119 | "outputs": [], 120 | 
"source": [ 121 | "def max_length(tensor):\n", 122 | " return max(len(t) for t in tensor)\n", 123 | "\n", 124 | "def get_num_words(lang,scale):\n", 125 | " #获取词典大小\n", 126 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n", 127 | " lang_tokenizer.fit_on_texts(lang)\n", 128 | " num_words = int(len(lang_tokenizer.word_index) * scale)\n", 129 | " return num_words\n", 130 | " \n", 131 | "def tokenize(lang,scale):\n", 132 | " num_words = get_num_words(lang,scale)\n", 133 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=num_words,\n", 134 | " oov_token='',filters='')\n", 135 | " lang_tokenizer.fit_on_texts(lang)\n", 136 | " tensor = lang_tokenizer.texts_to_sequences(lang)\n", 137 | " tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,padding='post')\n", 138 | " \n", 139 | " return tensor, lang_tokenizer\n", 140 | "\n", 141 | "def load_dataset(path,num_examples=None,scale=0.9):\n", 142 | " inp_lang,targ_lang = create_dataset(path, num_examples)\n", 143 | "\n", 144 | " input_tensor, inp_lang_tokenizer = tokenize(inp_lang,scale)\n", 145 | " target_tensor, targ_lang_tokenizer = tokenize(targ_lang,scale)\n", 146 | "\n", 147 | " return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "higher-punishment", 153 | "metadata": {}, 154 | "source": [ 155 | "### 限制数据集的大小以加快实验速度(可选)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 9, 161 | "id": "narrow-palmer", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# 尝试实验不同大小的数据集\n", 166 | "num_examples = -1\n", 167 | "scale = 0.95\n", 168 | "path = \"../00-data/tf_data.txt\"\n", 169 | "input_tensor, target_tensor, inp_lang, targ_lang = load_dataset(path, num_examples,scale)\n", 170 | "\n", 171 | "# 计算目标张量和输入张量的最大长度 (max_length)\n", 172 | "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)" 173 | ] 174 | }, 175 | {
"cell_type": "code", 177 | "execution_count": 10, 178 | "id": "convenient-string", 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "data": { 183 | "text/plain": [ 184 | "109453" 185 | ] 186 | }, 187 | "execution_count": 10, 188 | "metadata": {}, 189 | "output_type": "execute_result" 190 | } 191 | ], 192 | "source": [ 193 | "len(input_tensor)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 11, 199 | "id": "significant-oxygen", 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "109453" 206 | ] 207 | }, 208 | "execution_count": 11, 209 | "metadata": {}, 210 | "output_type": "execute_result" 211 | } 212 | ], 213 | "source": [ 214 | "len(target_tensor)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 31, 220 | "id": "convinced-grounds", 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "2501 2501 278 278\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "\"\"\"目前先不采用\"\"\"\n", 233 | "# 采用 90 - 10 的比例切分训练集和验证集\n", 234 | "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.1)\n", 235 | "\n", 236 | "# 显示长度\n", 237 | "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "id": "incorrect-software", 243 | "metadata": {}, 244 | "source": [ 245 | "### 创建一个 tf.data 数据集" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 12, 251 | "id": "hungry-lafayette", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "BUFFER_SIZE = len(input_tensor)\n", 256 | "BATCH_SIZE = 64\n", 257 | "steps_per_epoch = len(input_tensor)//BATCH_SIZE\n", 258 | "embedding_dim = 256\n", 259 | "units = 256\n", 260 | "vocab_inp_size = len(inp_lang.word_index)+1\n", 261 | "vocab_tar_size = 
len(targ_lang.word_index)+1\n", 262 | "\n", 263 | "dataset = tf.data.Dataset.from_tensor_slices((input_tensor, target_tensor)).shuffle(BUFFER_SIZE)\n", 264 | "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 13, 270 | "id": "requested-township", 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "data": { 275 | "text/plain": [ 276 | "(TensorShape([64, 597]), TensorShape([64, 245]))" 277 | ] 278 | }, 279 | "execution_count": 13, 280 | "metadata": {}, 281 | "output_type": "execute_result" 282 | } 283 | ], 284 | "source": [ 285 | "example_input_batch, example_target_batch = next(iter(dataset))\n", 286 | "example_input_batch.shape, example_target_batch.shape" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "id": "trying-channel", 292 | "metadata": {}, 293 | "source": [ 294 | "## 编写编码器encoder和解码器decoder模型" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 14, 300 | "id": "consecutive-literacy", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "class Encoder(tf.keras.Model):\n", 305 | " def __init__(self, vocab_size, embedding_dim, enc_units, batch_size):\n", 306 | " super(Encoder, self).__init__()\n", 307 | " self.batch_size = batch_size\n", 308 | " self.enc_units = enc_units\n", 309 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,batch_input_shape=[batch_size,None])\n", 310 | " self.lstm = tf.keras.layers.LSTM(self.enc_units,\n", 311 | " return_sequences=True,\n", 312 | " return_state=True,\n", 313 | " dropout=0.1,\n", 314 | " recurrent_dropout=0.1)\n", 315 | "\n", 316 | " def call(self, x):\n", 317 | " x = self.embedding(x)\n", 318 | " output_l1,state_l1,_ = self.lstm(x)\n", 319 | " #output_l2,state_l2,_ = self.lstm(output_l1)\n", 320 | " return output_l1,state_l1\n", 321 | "\n", 322 | " def initialize_hidden_state(self):\n", 323 | " return tf.zeros((self.batch_size, self.enc_units))" 324 | ] 325 
| }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 15, 329 | "id": "offensive-survivor", 330 | "metadata": { 331 | "scrolled": true 332 | }, 333 | "outputs": [ 334 | { 335 | "name": "stdout", 336 | "output_type": "stream", 337 | "text": [ 338 | "Encoder output shape: (batch size, sequence length, units) (64, 597, 256)\n", 339 | "Encoder Hidden state shape: (batch size, units) (64, 256)\n" 340 | ] 341 | } 342 | ], 343 | "source": [ 344 | "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", 345 | "\n", 346 | "# 样本输入\n", 347 | "sample_hidden = encoder.initialize_hidden_state()\n", 348 | "sample_output,sample_hidden = encoder(example_input_batch)\n", 349 | "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n", 350 | "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 16, 356 | "id": "therapeutic-favor", 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "class BahdanauAttention(tf.keras.layers.Layer):\n", 361 | " def __init__(self, units):\n", 362 | " super(BahdanauAttention, self).__init__()\n", 363 | " self.W1 = tf.keras.layers.Dense(units)\n", 364 | " self.W2 = tf.keras.layers.Dense(units)\n", 365 | " self.V = tf.keras.layers.Dense(1)\n", 366 | "\n", 367 | " def call(self, query, values):\n", 368 | " # 隐藏层的形状 == (批大小,隐藏层大小)\n", 369 | " # hidden_with_time_axis 的形状 == (批大小,1,隐藏层大小)\n", 370 | " # 这样做是为了执行加法以计算分数 \n", 371 | " hidden_with_time_axis = tf.expand_dims(query, 1)\n", 372 | "\n", 373 | " # 分数的形状 == (批大小,最大长度,1)\n", 374 | " # 我们在最后一个轴上得到 1, 因为我们把分数应用于 self.V\n", 375 | " # 在应用 self.V 之前,张量的形状是(批大小,最大长度,单位)\n", 376 | " score = self.V(tf.nn.tanh(\n", 377 | " self.W1(values) + self.W2(hidden_with_time_axis)))\n", 378 | "\n", 379 | " # 注意力权重 (attention_weights) 的形状 == (批大小,最大长度,1)\n", 380 | " attention_weights = tf.nn.softmax(score, axis=1)\n", 381 | 
"\n", 382 | " # 上下文向量 (context_vector) 求和之后的形状 == (批大小,隐藏层大小)\n", 383 | " context_vector = attention_weights * values\n", 384 | " context_vector = tf.reduce_sum(context_vector, axis=1)\n", 385 | "\n", 386 | " return context_vector, attention_weights" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 17, 392 | "id": "under-swaziland", 393 | "metadata": {}, 394 | "outputs": [ 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "Attention result shape: (batch size, units) (64, 256)\n", 400 | "Attention weights shape: (batch_size, sequence_length, 1) (64, 597, 1)\n" 401 | ] 402 | } 403 | ], 404 | "source": [ 405 | "attention_layer = BahdanauAttention(10)\n", 406 | "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n", 407 | "\n", 408 | "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n", 409 | "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 18, 415 | "id": "smoking-horizon", 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 419 | "class Decoder(tf.keras.Model):\n", 420 | " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", 421 | " super(Decoder, self).__init__()\n", 422 | " self.batch_sz = batch_sz\n", 423 | " self.dec_units = dec_units\n", 424 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", 425 | " self.train_LSTM = tf.keras.layers.LSTM(self.dec_units,\n", 426 | " return_sequences=True,\n", 427 | " return_state=True,\n", 428 | " dropout=0.1,\n", 429 | " recurrent_dropout=0.1)\n", 430 | " self.infer_LSTM = tf.keras.layers.LSTM(self.dec_units,\n", 431 | " return_sequences=True,\n", 432 | " return_state=True)\n", 433 | " \n", 434 | " self.fc = tf.keras.layers.Dense(vocab_size)\n", 435 | "\n", 436 | " # 用于注意力\n", 437 | " self.attention = 
BahdanauAttention(self.dec_units)\n", 438 | "\n", 439 | " def call(self, x, hidden, enc_output,is_train=True):\n", 440 | " # 编码器输出 (enc_output) 的形状 == (批大小,最大长度,隐藏层大小)\n", 441 | " context_vector, attention_weights = self.attention(hidden, enc_output)\n", 442 | "\n", 443 | " # x 在通过嵌入层后的形状 == (批大小,1,嵌入维度)\n", 444 | " x = self.embedding(x)\n", 445 | "\n", 446 | " # x 在拼接 (concatenation) 后的形状 == (批大小,1,嵌入维度 + 隐藏层大小)\n", 447 | " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", 448 | "\n", 449 | " # 将合并后的向量传送到 LSTM\n", 450 | " if is_train:\n", 451 | " output, state,_ = self.train_LSTM(x)\n", 452 | " else:\n", 453 | " output, state,_ = self.infer_LSTM(x)\n", 454 | "\n", 455 | " # 输出的形状 == (批大小 * 1,隐藏层大小)\n", 456 | " output = tf.reshape(output, (-1, output.shape[2]))\n", 457 | "\n", 458 | " # 输出的形状 == (批大小,vocab)\n", 459 | " x = self.fc(output)\n", 460 | "\n", 461 | " return x, state, attention_weights" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 19, 467 | "id": "straight-sending", 468 | "metadata": {}, 469 | "outputs": [ 470 | { 471 | "name": "stdout", 472 | "output_type": "stream", 473 | "text": [ 474 | "Decoder output shape: (batch_size, vocab size) (64, 11568)\n" 475 | ] 476 | } 477 | ], 478 | "source": [ 479 | "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n", 480 | "\n", 481 | "sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),\n", 482 | " sample_hidden, sample_output,is_train=True)\n", 483 | "\n", 484 | "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 20, 490 | "id": "compliant-directive", 491 | "metadata": {}, 492 | "outputs": [], 493 | "source": [ 494 | "optimizer = tf.keras.optimizers.Adam()\n", 495 | "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n", 496 | " from_logits=True, reduction='none')\n", 497 | "\n", 498 | "def 
loss_function(real, pred):\n", 499 | " mask = tf.math.logical_not(tf.math.equal(real, 0))\n", 500 | " loss_ = loss_object(real, pred)\n", 501 | "\n", 502 | " mask = tf.cast(mask, dtype=loss_.dtype)\n", 503 | " loss_ *= mask\n", 504 | "\n", 505 | " return tf.reduce_mean(loss_)" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": 21, 511 | "id": "adult-grocery", 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "checkpoint_dir = '../02-checkpoints/'\n", 516 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", 517 | "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", 518 | " encoder=encoder,\n", 519 | " decoder=decoder)" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 22, 525 | "id": "charming-apparatus", 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "@tf.function\n", 530 | "def train_step(inp,targ,enc_hidden):\n", 531 | " loss = 0\n", 532 | " \n", 533 | " with tf.GradientTape() as tape:\n", 534 | " enc_output,enc_hidden = encoder(inp)\n", 535 | " dec_hidden = enc_hidden\n", 536 | " \n", 537 | " dec_hidden = enc_hidden\n", 538 | " \n", 539 | " dec_input = tf.expand_dims([targ_lang.word_index['']] * BATCH_SIZE, 1)\n", 540 | " \n", 541 | " for t in range(1,targ.shape[1]):\n", 542 | " predictions,dec_hidden,_ = decoder(dec_input,dec_hidden,enc_output,is_train=True)\n", 543 | " loss += loss_function(targ[:,t],predictions)\n", 544 | " \n", 545 | " dec_input = tf.expand_dims(targ[:,t],1)\n", 546 | " \n", 547 | " batch_loss = (loss / int(targ.shape[1]))\n", 548 | " variables = encoder.trainable_variables + decoder.trainable_variables\n", 549 | " gradients = tape.gradient(loss,variables)\n", 550 | " optimizer.apply_gradients(zip(gradients,variables))\n", 551 | " \n", 552 | " return batch_loss" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "id": "based-wound", 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [ 562 
| "EPOCHS = 30\n", 563 | "\n", 564 | "for epoch in range(EPOCHS):\n", 565 | " start = time.time()\n", 566 | "\n", 567 | " enc_hidden = encoder.initialize_hidden_state()\n", 568 | " total_loss = 0\n", 569 | "\n", 570 | " for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n", 571 | " batch_loss = train_step(inp, targ, enc_hidden)\n", 572 | " total_loss += batch_loss\n", 573 | "\n", 574 | " #if batch % 5 == 0:\n", 575 | " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", 576 | " batch,\n", 577 | " batch_loss.numpy()))\n", 578 | " # 每 5 个周期(epoch),保存(检查点)一次模型\n", 579 | " if (epoch + 1) % 5 == 0:\n", 580 | " checkpoint.save(file_prefix=checkpoint_prefix) \n", 581 | "\n", 582 | " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", 583 | " total_loss / steps_per_epoch))\n", 584 | " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" 585 | ] 586 | }, 587 | { 588 | "cell_type": "markdown", 589 | "id": "terminal-suite", 590 | "metadata": {}, 591 | "source": [ 592 | "## 预测" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 13, 598 | "id": "satellite-seating", 599 | "metadata": {}, 600 | "outputs": [], 601 | "source": [ 602 | "def create_inference_data(lines):\n", 603 | " data = \"\"\n", 604 | " for line in lines:\n", 605 | " line = preprocess_sentence(line.strip()) + ' '\n", 606 | " data += line\n", 607 | " return data.strip()" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": 14, 613 | "id": "cellular-impossible", 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "def evaluate(sentence):\n", 618 | " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", 619 | "\n", 620 | " sentence = create_inference_data(sentence)\n", 621 | "\n", 622 | " inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]\n", 623 | " inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],\n", 624 | " maxlen=max_length_inp,\n", 625 | " padding='post')\n", 626 | " 
inputs = tf.convert_to_tensor(inputs)\n", 627 | "\n", 628 | " result = ''\n", 629 | "\n", 630 | " hidden = [tf.zeros((1, units))]\n", 631 | " enc_out, enc_hidden = encoder(inputs)\n", 632 | "\n", 633 | " dec_hidden = enc_hidden\n", 634 | " dec_input = tf.expand_dims([targ_lang.word_index['']], 0)\n", 635 | "\n", 636 | " for t in range(max_length_targ):\n", 637 | " predictions, dec_hidden, attention_weights = decoder(dec_input,\n", 638 | " dec_hidden,\n", 639 | " enc_out,\n", 640 | " is_train=False)\n", 641 | "\n", 642 | " # 存储注意力权重以便后面制图\n", 643 | " attention_weights = tf.reshape(attention_weights, (-1, ))\n", 644 | " attention_plot[t] = attention_weights.numpy()\n", 645 | "\n", 646 | " predicted_id = tf.argmax(predictions[0]).numpy()\n", 647 | "\n", 648 | " result += targ_lang.index_word[predicted_id] + ' '\n", 649 | "\n", 650 | " if targ_lang.index_word[predicted_id] == '':\n", 651 | " return result, sentence, attention_plot\n", 652 | "\n", 653 | " # 预测的 ID 被输送回模型\n", 654 | " dec_input = tf.expand_dims([predicted_id], 0)\n", 655 | "\n", 656 | " return result, sentence, attention_plot" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 21, 662 | "id": "current-feedback", 663 | "metadata": {}, 664 | "outputs": [], 665 | "source": [ 666 | "# 注意力权重制图函数\n", 667 | "def plot_attention(attention, sentence, predicted_sentence):\n", 668 | " fig = plt.figure(figsize=(20,20))\n", 669 | " ax = fig.add_subplot(1, 1, 1)\n", 670 | " ax.matshow(attention, cmap='viridis')\n", 671 | "\n", 672 | " fontdict = {'fontsize': 10}\n", 673 | "\n", 674 | " ax.set_xticklabels([''] + sentence, fontdict=fontdict)\n", 675 | " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict, rotation=90)\n", 676 | "\n", 677 | " ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n", 678 | " ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n", 679 | "\n", 680 | " plt.show()" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 22, 686 | 
"id": "worse-logging", 687 | "metadata": {}, 688 | "outputs": [], 689 | "source": [ 690 | "def translate(sentence):\n", 691 | " result,sentence,attention_plot = evaluate(sentence)\n", 692 | " \n", 693 | " print('Input: %s' % (sentence))\n", 694 | " print('Predicted translation: {}'.format(result))\n", 695 | " \n", 696 | " attention_plot = attention_plot[:len(result.split(' ')),:len(sentence.split(' '))]\n", 697 | " plot_attention(attention_plot,sentence.split(' '),result.split(' ')) \n", 698 | " print(result.split(' '))" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 29, 704 | "id": "integrated-victoria", 705 | "metadata": {}, 706 | "outputs": [], 707 | "source": [ 708 | "lines = io.open(path,encoding='utf-8').read().strip().split('\\n')[12:13]\n", 709 | "lines = create_inference_data(lines)" 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "execution_count": 30, 715 | "id": "numerous-advisory", 716 | "metadata": { 717 | "scrolled": false 718 | }, 719 | "outputs": [ 720 | { 721 | "data": { 722 | "text/plain": [ 723 | "' y = tf . placeholder ( tf . 
float32 , shape = [ ] , name = ) '" 724 | ] 725 | }, 726 | "execution_count": 30, 727 | "metadata": {}, 728 | "output_type": "execute_result" 729 | } 730 | ], 731 | "source": [ 732 | "lines" 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": null, 738 | "id": "structured-planet", 739 | "metadata": {}, 740 | "outputs": [], 741 | "source": [] 742 | } 743 | ], 744 | "metadata": { 745 | "kernelspec": { 746 | "display_name": "myconda", 747 | "language": "python", 748 | "name": "myconda" 749 | }, 750 | "language_info": { 751 | "codemirror_mode": { 752 | "name": "ipython", 753 | "version": 3 754 | }, 755 | "file_extension": ".py", 756 | "mimetype": "text/x-python", 757 | "name": "python", 758 | "nbconvert_exporter": "python", 759 | "pygments_lexer": "ipython3", 760 | "version": "3.7.7" 761 | } 762 | }, 763 | "nbformat": 4, 764 | "nbformat_minor": 5 765 | } 766 | -------------------------------------------------------------------------------- /model/02-Seq2Seq/.ipynb_checkpoints/Seq2Seq-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "decent-client", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import io\n", 11 | "import re\n", 12 | "import random\n", 13 | "import tensorflow as tf\n", 14 | "\n", 15 | "\n", 16 | "import matplotlib.pyplot as plt\n", 17 | "import matplotlib.ticker as ticker\n", 18 | "from sklearn.model_selection import train_test_split\n", 19 | "\n", 20 | "import unicodedata\n", 21 | "import numpy as np\n", 22 | "import os\n", 23 | "import time" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 2, 29 | "id": "polar-recognition", 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "#GPU设置\n", 34 | "gpus = tf.config.experimental.list_physical_devices(device_type='GPU')\n", 35 | "tf.config.experimental.set_visible_devices(devices=gpus[:], 
device_type='GPU')\n", 36 | "for gpu in gpus:\n", 37 | " tf.config.experimental.set_memory_growth(gpu, True)\n", 38 | "os.environ['CUDA_VISIBLE_DEVICES']=\"0\"" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 5, 44 | "id": "parallel-adjustment", 45 | "metadata": {}, 46 | "outputs": [ 47 | { 48 | "data": { 49 | "text/plain": [ 50 | "'/device:GPU:0'" 51 | ] 52 | }, 53 | "execution_count": 5, 54 | "metadata": {}, 55 | "output_type": "execute_result" 56 | } 57 | ], 58 | "source": [ 59 | "tf.test.gpu_device_name()" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "id": "scenic-decrease", 65 | "metadata": {}, 66 | "source": [ 67 | "## 构建数据集" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 6, 73 | "id": "starting-toilet", 74 | "metadata": {}, 75 | "outputs": [], 76 | "source": [ 77 | "def preprocess_sentence(line): \n", 78 | " line = re.split('([: ,.(){}\\[\\]=])',line) \n", 79 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n", 80 | " \n", 81 | " new_line = ' ' + ' '.join(line) + ' '\n", 82 | " return new_line" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": 7, 88 | "id": "variable-going", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "def create_dataset(path,num_examples,rand_max=15,duplicate=3):\n", 93 | " input_data = []\n", 94 | " output_data = []\n", 95 | " \n", 96 | " lines = io.open(path,encoding='utf-8').read().strip().split('\\n')\n", 97 | " if num_examples == -1:\n", 98 | " num_examples = len(lines)\n", 99 | " \n", 100 | " for i in range(1,num_examples):\n", 101 | " \n", 102 | " rand_nums = set(random.randint(1,rand_max) for _ in range(duplicate))\n", 103 | " for rand_num in rand_nums:\n", 104 | " data = ''\n", 105 | " for j in range(i - rand_num,i):\n", 106 | " line = preprocess_sentence(lines[j].strip()) + ' '\n", 107 | " data += line\n", 108 | " input_data.append(data.strip())\n", 109 | " output_data.append(preprocess_sentence(lines[i].strip()))\n", 110 | 
" \n", 111 | " return input_data,output_data" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 8, 117 | "id": "polished-order", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "def max_length(tensor):\n", 122 | " return max(len(t) for t in tensor)\n", 123 | "\n", 124 | "def get_num_words(lang,scale):\n", 125 | " #获取词典大小\n", 126 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')\n", 127 | " lang_tokenizer.fit_on_texts(lang)\n", 128 | " num_words = int(len(lang_tokenizer.word_index) * scale)\n", 129 | " return num_words\n", 130 | " \n", 131 | "def tokenize(lang,scale):\n", 132 | " num_words = get_num_words(lang,scale)\n", 133 | " lang_tokenizer = tf.keras.preprocessing.text.Tokenizer(num_words=num_words,\n", 134 | " oov_token='',filters='')\n", 135 | " lang_tokenizer.fit_on_texts(lang)\n", 136 | " tensor = lang_tokenizer.texts_to_sequences(lang)\n", 137 | " tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor,padding='post')\n", 138 | " \n", 139 | " return tensor, lang_tokenizer\n", 140 | "\n", 141 | "def load_dataset(path,num_examples=None,scale=0.9):\n", 142 | " inp_lang,targ_lang = create_dataset(path, num_examples)\n", 143 | "\n", 144 | " input_tensor, inp_lang_tokenizer = tokenize(inp_lang,scale)\n", 145 | " target_tensor, targ_lang_tokenizer = tokenize(targ_lang,scale)\n", 146 | "\n", 147 | " return input_tensor, target_tensor, inp_lang_tokenizer, targ_lang_tokenizer" 148 | ] 149 | }, 150 | { 151 | "cell_type": "markdown", 152 | "id": "higher-punishment", 153 | "metadata": {}, 154 | "source": [ 155 | "### 限制数据集的大小以加快实验速度(可选)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": 9, 161 | "id": "narrow-palmer", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "# 尝试实验不同大小的数据集\n", 166 | "num_examples = -1\n", 167 | "scale = 0.95\n", 168 | "path = \"../00-data/tf_data.txt\"\n", 169 | "input_tensor, target_tensor, inp_lang, targ_lang = 
load_dataset(path, num_examples,scale)\n", 170 | "\n", 171 | "# 计算目标张量的最大长度 (max_length)\n", 172 | "max_length_targ, max_length_inp = max_length(target_tensor), max_length(input_tensor)" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 10, 178 | "id": "convenient-string", 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "data": { 183 | "text/plain": [ 184 | "109453" 185 | ] 186 | }, 187 | "execution_count": 10, 188 | "metadata": {}, 189 | "output_type": "execute_result" 190 | } 191 | ], 192 | "source": [ 193 | "len(input_tensor)" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 11, 199 | "id": "significant-oxygen", 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "data": { 204 | "text/plain": [ 205 | "109453" 206 | ] 207 | }, 208 | "execution_count": 11, 209 | "metadata": {}, 210 | "output_type": "execute_result" 211 | } 212 | ], 213 | "source": [ 214 | "len(target_tensor)" 215 | ] 216 | }, 217 | { 218 | "cell_type": "code", 219 | "execution_count": 31, 220 | "id": "convinced-grounds", 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "2501 2501 278 278\n" 228 | ] 229 | } 230 | ], 231 | "source": [ 232 | "\"\"\"目前先不采用\"\"\"\n", 233 | "# 采用 90 - 10 的比例切分训练集和验证集\n", 234 | "input_tensor_train, input_tensor_val, target_tensor_train, target_tensor_val = train_test_split(input_tensor, target_tensor, test_size=0.1)\n", 235 | "\n", 236 | "# 显示长度\n", 237 | "print(len(input_tensor_train), len(target_tensor_train), len(input_tensor_val), len(target_tensor_val))" 238 | ] 239 | }, 240 | { 241 | "cell_type": "markdown", 242 | "id": "incorrect-software", 243 | "metadata": {}, 244 | "source": [ 245 | "### 创建一个 tf.data 数据集" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 12, 251 | "id": "hungry-lafayette", 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "BUFFER_SIZE = len(input_tensor)\n", 256 
| "BATCH_SIZE = 64\n", 257 | "steps_per_epoch = len(input_tensor)//BATCH_SIZE\n", 258 | "embedding_dim = 256\n", 259 | "units = 256\n", 260 | "vocab_inp_size = len(inp_lang.word_index)+1\n", 261 | "vocab_tar_size = len(targ_lang.word_index)+1\n", 262 | "\n", 263 | "dataset = tf.data.Dataset.from_tensor_slices((input_tensor, target_tensor)).shuffle(BUFFER_SIZE)\n", 264 | "dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 13, 270 | "id": "requested-township", 271 | "metadata": {}, 272 | "outputs": [ 273 | { 274 | "data": { 275 | "text/plain": [ 276 | "(TensorShape([64, 597]), TensorShape([64, 245]))" 277 | ] 278 | }, 279 | "execution_count": 13, 280 | "metadata": {}, 281 | "output_type": "execute_result" 282 | } 283 | ], 284 | "source": [ 285 | "example_input_batch, example_target_batch = next(iter(dataset))\n", 286 | "example_input_batch.shape, example_target_batch.shape" 287 | ] 288 | }, 289 | { 290 | "cell_type": "markdown", 291 | "id": "trying-channel", 292 | "metadata": {}, 293 | "source": [ 294 | "## 编写编码器encoder和解码器decoder模型" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 14, 300 | "id": "consecutive-literacy", 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "class Encoder(tf.keras.Model):\n", 305 | " def __init__(self, vocab_size, embedding_dim, enc_units, batch_size):\n", 306 | " super(Encoder, self).__init__()\n", 307 | " self.batch_size = batch_size\n", 308 | " self.enc_units = enc_units\n", 309 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim,batch_input_shape=[batch_size,None])\n", 310 | " self.lstm = tf.keras.layers.LSTM(self.enc_units,\n", 311 | " return_sequences=True,\n", 312 | " return_state=True,\n", 313 | " dropout=0.1,\n", 314 | " recurrent_dropout=0.1)\n", 315 | "\n", 316 | " def call(self, x):\n", 317 | " x = self.embedding(x)\n", 318 | " output_l1,state_l1,_ = self.lstm(x)\n", 319 | " 
#output_l2,state_l2,_ = self.lstm(output_l1)\n", 320 | " return output_l1,state_l1\n", 321 | "\n", 322 | " def initialize_hidden_state(self):\n", 323 | " return tf.zeros((self.batch_size, self.enc_units))" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 15, 329 | "id": "offensive-survivor", 330 | "metadata": { 331 | "scrolled": true 332 | }, 333 | "outputs": [ 334 | { 335 | "name": "stdout", 336 | "output_type": "stream", 337 | "text": [ 338 | "Encoder output shape: (batch size, sequence length, units) (64, 597, 256)\n", 339 | "Encoder Hidden state shape: (batch size, units) (64, 256)\n" 340 | ] 341 | } 342 | ], 343 | "source": [ 344 | "encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)\n", 345 | "\n", 346 | "# 样本输入\n", 347 | "sample_hidden = encoder.initialize_hidden_state()\n", 348 | "sample_output,sample_hidden = encoder(example_input_batch)\n", 349 | "print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))\n", 350 | "print ('Encoder Hidden state shape: (batch size, units) {}'.format(sample_hidden.shape))" 351 | ] 352 | }, 353 | { 354 | "cell_type": "code", 355 | "execution_count": 16, 356 | "id": "therapeutic-favor", 357 | "metadata": {}, 358 | "outputs": [], 359 | "source": [ 360 | "class BahdanauAttention(tf.keras.layers.Layer):\n", 361 | " def __init__(self, units):\n", 362 | " super(BahdanauAttention, self).__init__()\n", 363 | " self.W1 = tf.keras.layers.Dense(units)\n", 364 | " self.W2 = tf.keras.layers.Dense(units)\n", 365 | " self.V = tf.keras.layers.Dense(1)\n", 366 | "\n", 367 | " def call(self, query, values):\n", 368 | " # 隐藏层的形状 == (批大小,隐藏层大小)\n", 369 | " # hidden_with_time_axis 的形状 == (批大小,1,隐藏层大小)\n", 370 | " # 这样做是为了执行加法以计算分数 \n", 371 | " hidden_with_time_axis = tf.expand_dims(query, 1)\n", 372 | "\n", 373 | " # 分数的形状 == (批大小,最大长度,1)\n", 374 | " # 我们在最后一个轴上得到 1, 因为我们把分数应用于 self.V\n", 375 | " # 在应用 self.V 之前,张量的形状是(批大小,最大长度,单位)\n", 376 | " score = 
self.V(tf.nn.tanh(\n", 377 | " self.W1(values) + self.W2(hidden_with_time_axis)))\n", 378 | "\n", 379 | " # 注意力权重 (attention_weights) 的形状 == (批大小,最大长度,1)\n", 380 | " attention_weights = tf.nn.softmax(score, axis=1)\n", 381 | "\n", 382 | " # 上下文向量 (context_vector) 求和之后的形状 == (批大小,隐藏层大小)\n", 383 | " context_vector = attention_weights * values\n", 384 | " context_vector = tf.reduce_sum(context_vector, axis=1)\n", 385 | "\n", 386 | " return context_vector, attention_weights" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 17, 392 | "id": "under-swaziland", 393 | "metadata": {}, 394 | "outputs": [ 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "Attention result shape: (batch size, units) (64, 256)\n", 400 | "Attention weights shape: (batch_size, sequence_length, 1) (64, 597, 1)\n" 401 | ] 402 | } 403 | ], 404 | "source": [ 405 | "attention_layer = BahdanauAttention(10)\n", 406 | "attention_result, attention_weights = attention_layer(sample_hidden, sample_output)\n", 407 | "\n", 408 | "print(\"Attention result shape: (batch size, units) {}\".format(attention_result.shape))\n", 409 | "print(\"Attention weights shape: (batch_size, sequence_length, 1) {}\".format(attention_weights.shape))" 410 | ] 411 | }, 412 | { 413 | "cell_type": "code", 414 | "execution_count": 18, 415 | "id": "smoking-horizon", 416 | "metadata": {}, 417 | "outputs": [], 418 | "source": [ 419 | "class Decoder(tf.keras.Model):\n", 420 | " def __init__(self, vocab_size, embedding_dim, dec_units, batch_sz):\n", 421 | " super(Decoder, self).__init__()\n", 422 | " self.batch_sz = batch_sz\n", 423 | " self.dec_units = dec_units\n", 424 | " self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)\n", 425 | " self.train_LSTM = tf.keras.layers.LSTM(self.dec_units,\n", 426 | " return_sequences=True,\n", 427 | " return_state=True,\n", 428 | " dropout=0.1,\n", 429 | " recurrent_dropout=0.1)\n", 430 | " self.infer_LSTM = 
tf.keras.layers.LSTM(self.dec_units,\n", 431 | " return_sequences=True,\n", 432 | " return_state=True)\n", 433 | " \n", 434 | " self.fc = tf.keras.layers.Dense(vocab_size)\n", 435 | "\n", 436 | " # 用于注意力\n", 437 | " self.attention = BahdanauAttention(self.dec_units)\n", 438 | "\n", 439 | " def call(self, x, hidden, enc_output,is_train=True):\n", 440 | " # 编码器输出 (enc_output) 的形状 == (批大小,最大长度,隐藏层大小)\n", 441 | " context_vector, attention_weights = self.attention(hidden, enc_output)\n", 442 | "\n", 443 | " # x 在通过嵌入层后的形状 == (批大小,1,嵌入维度)\n", 444 | " x = self.embedding(x)\n", 445 | "\n", 446 | " # x 在拼接 (concatenation) 后的形状 == (批大小,1,嵌入维度 + 隐藏层大小)\n", 447 | " x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)\n", 448 | "\n", 449 | " # 将合并后的向量传送到 LSTM\n", 450 | " if is_train:\n", 451 | " output, state,_ = self.train_LSTM(x)\n", 452 | " else:\n", 453 | " output, state,_ = self.infer_LSTM(x)\n", 454 | "\n", 455 | " # 输出的形状 == (批大小 * 1,隐藏层大小)\n", 456 | " output = tf.reshape(output, (-1, output.shape[2]))\n", 457 | "\n", 458 | " # 输出的形状 == (批大小,vocab)\n", 459 | " x = self.fc(output)\n", 460 | "\n", 461 | " return x, state, attention_weights" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 19, 467 | "id": "straight-sending", 468 | "metadata": {}, 469 | "outputs": [ 470 | { 471 | "name": "stdout", 472 | "output_type": "stream", 473 | "text": [ 474 | "Decoder output shape: (batch_size, vocab size) (64, 11568)\n" 475 | ] 476 | } 477 | ], 478 | "source": [ 479 | "decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)\n", 480 | "\n", 481 | "sample_decoder_output, _, _ = decoder(tf.random.uniform((BATCH_SIZE, 1)),\n", 482 | " sample_hidden, sample_output,is_train=True)\n", 483 | "\n", 484 | "print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))" 485 | ] 486 | }, 487 | { 488 | "cell_type": "code", 489 | "execution_count": 20, 490 | "id": "compliant-directive", 491 | "metadata": {}, 492 | 
"outputs": [], 493 | "source": [ 494 | "optimizer = tf.keras.optimizers.Adam()\n", 495 | "loss_object = tf.keras.losses.SparseCategoricalCrossentropy(\n", 496 | " from_logits=True, reduction='none')\n", 497 | "\n", 498 | "def loss_function(real, pred):\n", 499 | " mask = tf.math.logical_not(tf.math.equal(real, 0))\n", 500 | " loss_ = loss_object(real, pred)\n", 501 | "\n", 502 | " mask = tf.cast(mask, dtype=loss_.dtype)\n", 503 | " loss_ *= mask\n", 504 | "\n", 505 | " return tf.reduce_mean(loss_)" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": 21, 511 | "id": "adult-grocery", 512 | "metadata": {}, 513 | "outputs": [], 514 | "source": [ 515 | "checkpoint_dir = '../02-checkpoints/'\n", 516 | "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", 517 | "checkpoint = tf.train.Checkpoint(optimizer=optimizer,\n", 518 | " encoder=encoder,\n", 519 | " decoder=decoder)" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 22, 525 | "id": "charming-apparatus", 526 | "metadata": {}, 527 | "outputs": [], 528 | "source": [ 529 | "@tf.function\n", 530 | "def train_step(inp,targ,enc_hidden):\n", 531 | " loss = 0\n", 532 | " \n", 533 | " with tf.GradientTape() as tape:\n", 534 | " enc_output,enc_hidden = encoder(inp)\n", 535 | " dec_hidden = enc_hidden\n", 536 | " \n", 537 | " dec_hidden = enc_hidden\n", 538 | " \n", 539 | " dec_input = tf.expand_dims([targ_lang.word_index['']] * BATCH_SIZE, 1)\n", 540 | " \n", 541 | " for t in range(1,targ.shape[1]):\n", 542 | " predictions,dec_hidden,_ = decoder(dec_input,dec_hidden,enc_output,is_train=True)\n", 543 | " loss += loss_function(targ[:,t],predictions)\n", 544 | " \n", 545 | " dec_input = tf.expand_dims(targ[:,t],1)\n", 546 | " \n", 547 | " batch_loss = (loss / int(targ.shape[1]))\n", 548 | " variables = encoder.trainable_variables + decoder.trainable_variables\n", 549 | " gradients = tape.gradient(loss,variables)\n", 550 | " 
optimizer.apply_gradients(zip(gradients,variables))\n", 551 | " \n", 552 | " return batch_loss" 553 | ] 554 | }, 555 | { 556 | "cell_type": "code", 557 | "execution_count": null, 558 | "id": "based-wound", 559 | "metadata": {}, 560 | "outputs": [], 561 | "source": [ 562 | "EPOCHS = 30\n", 563 | "\n", 564 | "for epoch in range(EPOCHS):\n", 565 | " start = time.time()\n", 566 | "\n", 567 | " enc_hidden = encoder.initialize_hidden_state()\n", 568 | " total_loss = 0\n", 569 | "\n", 570 | " for (batch, (inp, targ)) in enumerate(dataset.take(steps_per_epoch)):\n", 571 | " batch_loss = train_step(inp, targ, enc_hidden)\n", 572 | " total_loss += batch_loss\n", 573 | "\n", 574 | " #if batch % 5 == 0:\n", 575 | " print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1,\n", 576 | " batch,\n", 577 | " batch_loss.numpy()))\n", 578 | " # 每 5 个周期(epoch),保存(检查点)一次模型\n", 579 | " if (epoch + 1) % 5 == 0:\n", 580 | " checkpoint.save(file_prefix=checkpoint_prefix) \n", 581 | "\n", 582 | " print('Epoch {} Loss {:.4f}'.format(epoch + 1,\n", 583 | " total_loss / steps_per_epoch))\n", 584 | " print('Time taken for 1 epoch {} sec\\n'.format(time.time() - start))" 585 | ] 586 | }, 587 | { 588 | "cell_type": "markdown", 589 | "id": "terminal-suite", 590 | "metadata": {}, 591 | "source": [ 592 | "## 预测" 593 | ] 594 | }, 595 | { 596 | "cell_type": "code", 597 | "execution_count": 13, 598 | "id": "satellite-seating", 599 | "metadata": {}, 600 | "outputs": [], 601 | "source": [ 602 | "def create_inference_data(lines):\n", 603 | " data = \"\"\n", 604 | " for line in lines:\n", 605 | " line = preprocess_sentence(line.strip()) + ' '\n", 606 | " data += line\n", 607 | " return data.strip()" 608 | ] 609 | }, 610 | { 611 | "cell_type": "code", 612 | "execution_count": 14, 613 | "id": "cellular-impossible", 614 | "metadata": {}, 615 | "outputs": [], 616 | "source": [ 617 | "def evaluate(sentence):\n", 618 | " attention_plot = np.zeros((max_length_targ, max_length_inp))\n", 619 | "\n", 620 | " sentence = 
create_inference_data(sentence)\n", 621 | "\n", 622 | " inputs = [inp_lang.word_index[i] for i in sentence.split(' ')]\n", 623 | " inputs = tf.keras.preprocessing.sequence.pad_sequences([inputs],\n", 624 | " maxlen=max_length_inp,\n", 625 | " padding='post')\n", 626 | " inputs = tf.convert_to_tensor(inputs)\n", 627 | "\n", 628 | " result = ''\n", 629 | "\n", 630 | " hidden = [tf.zeros((1, units))]\n", 631 | " enc_out, enc_hidden = encoder(inputs)\n", 632 | "\n", 633 | " dec_hidden = enc_hidden\n", 634 | " dec_input = tf.expand_dims([targ_lang.word_index['']], 0)\n", 635 | "\n", 636 | " for t in range(max_length_targ):\n", 637 | " predictions, dec_hidden, attention_weights = decoder(dec_input,\n", 638 | " dec_hidden,\n", 639 | " enc_out,\n", 640 | " is_train=False)\n", 641 | "\n", 642 | " # 存储注意力权重以便后面制图\n", 643 | " attention_weights = tf.reshape(attention_weights, (-1, ))\n", 644 | " attention_plot[t] = attention_weights.numpy()\n", 645 | "\n", 646 | " predicted_id = tf.argmax(predictions[0]).numpy()\n", 647 | "\n", 648 | " result += targ_lang.index_word[predicted_id] + ' '\n", 649 | "\n", 650 | " if targ_lang.index_word[predicted_id] == '':\n", 651 | " return result, sentence, attention_plot\n", 652 | "\n", 653 | " # 预测的 ID 被输送回模型\n", 654 | " dec_input = tf.expand_dims([predicted_id], 0)\n", 655 | "\n", 656 | " return result, sentence, attention_plot" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 21, 662 | "id": "current-feedback", 663 | "metadata": {}, 664 | "outputs": [], 665 | "source": [ 666 | "# 注意力权重制图函数\n", 667 | "def plot_attention(attention, sentence, predicted_sentence):\n", 668 | " fig = plt.figure(figsize=(20,20))\n", 669 | " ax = fig.add_subplot(1, 1, 1)\n", 670 | " ax.matshow(attention, cmap='viridis')\n", 671 | "\n", 672 | " fontdict = {'fontsize': 10}\n", 673 | "\n", 674 | " ax.set_xticklabels([''] + sentence, fontdict=fontdict)\n", 675 | " ax.set_yticklabels([''] + predicted_sentence, fontdict=fontdict, 
rotation=90)\n", 676 | "\n", 677 | " ax.xaxis.set_major_locator(ticker.MultipleLocator(1))\n", 678 | " ax.yaxis.set_major_locator(ticker.MultipleLocator(1))\n", 679 | "\n", 680 | " plt.show()" 681 | ] 682 | }, 683 | { 684 | "cell_type": "code", 685 | "execution_count": 22, 686 | "id": "worse-logging", 687 | "metadata": {}, 688 | "outputs": [], 689 | "source": [ 690 | "def translate(sentence):\n", 691 | " result,sentence,attention_plot = evaluate(sentence)\n", 692 | " \n", 693 | " print('Input: %s' % (sentence))\n", 694 | " print('Predicted translation: {}'.format(result))\n", 695 | " \n", 696 | " attention_plot = attention_plot[:len(result.split(' ')),:len(sentence.split(' '))]\n", 697 | " plot_attention(attention_plot,sentence.split(' '),result.split(' ')) \n", 698 | " print(result.split(' '))" 699 | ] 700 | }, 701 | { 702 | "cell_type": "code", 703 | "execution_count": 29, 704 | "id": "integrated-victoria", 705 | "metadata": {}, 706 | "outputs": [], 707 | "source": [ 708 | "lines = io.open(path,encoding='utf-8').read().strip().split('\\n')[12:13]\n", 709 | "lines = create_inference_data(lines)" 710 | ] 711 | }, 712 | { 713 | "cell_type": "code", 714 | "execution_count": 30, 715 | "id": "numerous-advisory", 716 | "metadata": { 717 | "scrolled": false 718 | }, 719 | "outputs": [ 720 | { 721 | "data": { 722 | "text/plain": [ 723 | "' y = tf . placeholder ( tf . 
float32 , shape = [ ] , name = ) '" 724 | ] 725 | }, 726 | "execution_count": 30, 727 | "metadata": {}, 728 | "output_type": "execute_result" 729 | } 730 | ], 731 | "source": [ 732 | "lines" 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": null, 738 | "id": "structured-planet", 739 | "metadata": {}, 740 | "outputs": [], 741 | "source": [] 742 | } 743 | ], 744 | "metadata": { 745 | "kernelspec": { 746 | "display_name": "myconda", 747 | "language": "python", 748 | "name": "myconda" 749 | }, 750 | "language_info": { 751 | "codemirror_mode": { 752 | "name": "ipython", 753 | "version": 3 754 | }, 755 | "file_extension": ".py", 756 | "mimetype": "text/x-python", 757 | "name": "python", 758 | "nbconvert_exporter": "python", 759 | "pygments_lexer": "ipython3", 760 | "version": "3.7.7" 761 | } 762 | }, 763 | "nbformat": 4, 764 | "nbformat_minor": 5 765 | } 766 | -------------------------------------------------------------------------------- /model/01-Text_Gen/Text_Gen_tf10.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\"\"\"数据导入\"\"\"\n", 10 | "\n", 11 | "import re\n", 12 | "data_file = \"../00-data/tf_data.txt\"\n", 13 | "filename =open(data_file,'r',encoding='utf-8') #打开数据文件\n", 14 | "\n", 15 | "text = filename.read() #将数据读取到字符串text中\n", 16 | "text = ' '.join(re.split(' |\\t|\\v',text)) #将数据中的空格符统一,便于后期处理(原始数据中空格符包含\\t、\\v等) \n", 17 | "text = re.split('([: ,.*\\n(){}\\[\\]=])',text) #将字符串数据按照括号中的符号进行分割,分割成列表格式,并且在列表中保留分隔符\n", 18 | "\n", 19 | "text = list(filter(lambda x: x!=' 'and x!='',text)) #将列表中的空格和非空格筛选掉\n", 20 | "list_text = text #保留一份列表格式的数据\n", 21 | "text = ' '.join(text) #将列表转换成字符串" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "\"\"\"文本词频统计\"\"\"\n", 31 | "\n", 32 | 
"def word_count(list_text): #定义计算文本词频的函数,传入list_text列表\n", 33 | " import collections\n", 34 | " word_freq = collections.defaultdict(int) #定义一个int型的词频词典,并提供默认值\n", 35 | " for w in list_text: #遍历列表中的元素,元素出现一次,频次加一\n", 36 | " word_freq[w] += 1\n", 37 | " return word_freq #返回词频词典\n", 38 | " \n", 39 | " #return word_freq.items() 该语句返回值的类型为list(这句话有语法问题,不必考虑)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "\"\"\"根据text文本创建代码词词典\"\"\"\n", 49 | "\n", 50 | "def build_dict(text, min_word_freq=50):\n", 51 | " word_freq = word_count(text) #文本词频统计,返回一个词频词典\n", 52 | " word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items()) # filter将词频数量低于指定值的代码词删除。\n", 53 | " word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) # key用于指定排序的元素,因为sorted默认使用list中每个item的第一个元素从小到大排列,所以这里通过lambda进行前后元素调序,并对词频去相反数,从而将词频最大的排列在最前面\n", 54 | " words, _ = list(zip(*word_freq_sorted)) #获取每一个代码词\n", 55 | " words = list(words)\n", 56 | " words.append('')\n", 57 | " word_idx = dict(zip(words, range(len(words)))) #构建词典(不包含词频)\n", 58 | " return words,word_idx #这里只返回了words,倒数两行代码还用不上。返回的是一个不含重复的代码词词典,不包含词频。" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stderr", 68 | "output_type": "stream", 69 | "text": [ 70 | "Using TensorFlow backend.\n" 71 | ] 72 | }, 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "Number of sequences: 79181\n", 78 | "Unique words: 7865\n", 79 | "Vectorization...\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "\"\"\"数据预处理-字符串序列向量化\"\"\"\n", 85 | "\n", 86 | "import numpy as np\n", 87 | "import keras\n", 88 | "import json\n", 89 | "\n", 90 | "maxlen = 50 #提取50个代码词组成的序列\n", 91 | "step = 5 #每5个代码词采样一个新序列\n", 92 | "sentences = [] #保存所提取的序列\n", 93 | "next_words = [] #保存目标代码词\n", 94 | "vocab_file = \"../00-data/vocab\"\n", 95 | "\n", 96 | "cut_words = list_text 
#将列表形式的元数据保存在cut_words中\n", 97 | "for i in range(0,len(cut_words) - maxlen,step):\n", 98 | " sentences.append(cut_words[i:i + maxlen]) #将元数据按照步长来存储在每个序列中 \n", 99 | " next_words.append(cut_words[i + maxlen]) #将目标代码词存储在next_words中\n", 100 | " \n", 101 | " \n", 102 | "print('Number of sequences:', len(sentences))\n", 103 | "\n", 104 | "\n", 105 | "words,word_idx = list(build_dict(list_text,1)) #创建代码词词典,返回的是一个不含重复的代码词词典,不包含词频。\n", 106 | "print('Unique words:',len(words))\n", 107 | "json_dict = json.dumps(word_idx)\n", 108 | "with open(vocab_file,\"w\") as f:\n", 109 | " f.write(json_dict)\n", 110 | "\n", 111 | "word_indices = dict((word,words.index(word)) for word in words) #创建一个包含代码词唯一索引的代码词词典,返回的是一个字典\n", 112 | "#print(word_indices)\n", 113 | "\n", 114 | "print('Vectorization...')\n", 115 | "x = np.zeros((len(sentences),maxlen)) #初始化x\n", 116 | "y = np.zeros((len(sentences))) #初始化y\n", 117 | "for i,sentence in enumerate(sentences):\n", 118 | " for t,word in enumerate(sentence):\n", 119 | " x[i,t] = word_indices.get(word,word_indices['']) #将代码词转换成向量形式的编码\n", 120 | " #y[i] = word_indices[next_words[i]]\n", 121 | " y[i] = word_indices.get(next_words[i],word_indices[''])\n", 122 | "\n", 123 | "y = keras.utils.to_categorical(y, len(words)) #将int型数组y转换成one-hot编码" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "\"\"\"定义下一个代码词的采样函数---temperature越大,代码生成的随机性越强---\"\"\"\n", 133 | "\n", 134 | "def sample(preds,temperature=0.1):\n", 135 | " preds = np.asarray(preds).astype('float')\n", 136 | " preds = np.log(preds) /temperature\n", 137 | " exp_preds = np.exp(preds)\n", 138 | " preds = exp_preds / np.sum(exp_preds)\n", 139 | " probas = np.random.multinomial(1,preds,1)\n", 140 | " return np.argmax(probas)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 7, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "\"\"\"将字符串写到指定文件中\"\"\"\n", 150 
| "\n", 151 | "def save(filename, contents): \n", 152 | " file = open(filename, 'a', encoding='utf-8')\n", 153 | " file.write(contents)\n", 154 | " file.close()" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "### 模型训练" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 8, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "import keras\n", 171 | "from keras import layers\n", 172 | "from keras.layers import LSTM, Dense, Dropout\n", 173 | "\n", 174 | "def create_model(words,learning_rate): #定义创建模型的函数\n", 175 | " model = keras.models.Sequential() #模型初始化\n", 176 | " model.add(layers.Embedding(len(words),512)) #模型第一层为embedding层\n", 177 | " model.add(layers.LSTM(512,return_sequences=True,dropout=0.2,recurrent_dropout=0.2)) #模型第二层为LSTM层,加入dropout减少过拟合\n", 178 | " model.add(layers.LSTM(512,dropout=0.2,recurrent_dropout=0.2)) #模型第三层为LSTM层,加入dropout减少过拟合\n", 179 | " model.add(layers.Dense(len(words),activation='softmax')) #模型第三层为全连接层\n", 180 | "\n", 181 | " optimizer = keras.optimizers.RMSprop(lr=learning_rate) #定义优化器\n", 182 | " model.compile(loss='categorical_crossentropy',optimizer=optimizer) #模型编译\n", 183 | " \n", 184 | " return model" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 9, 190 | "metadata": { 191 | "scrolled": true 192 | }, 193 | "outputs": [ 194 | { 195 | "name": "stdout", 196 | "output_type": "stream", 197 | "text": [ 198 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 199 | "Instructions for updating:\n", 200 | "Colocations handled automatically by placer.\n", 201 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from 
tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n", 202 | "Instructions for updating:\n", 203 | "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n", 204 | "_________________________________________________________________\n", 205 | "Layer (type) Output Shape Param # \n", 206 | "=================================================================\n", 207 | "embedding_1 (Embedding) (None, None, 512) 4026880 \n", 208 | "_________________________________________________________________\n", 209 | "lstm_1 (LSTM) (None, None, 512) 2099200 \n", 210 | "_________________________________________________________________\n", 211 | "lstm_2 (LSTM) (None, 512) 2099200 \n", 212 | "_________________________________________________________________\n", 213 | "dense_1 (Dense) (None, 7865) 4034745 \n", 214 | "=================================================================\n", 215 | "Total params: 12,260,025\n", 216 | "Trainable params: 12,260,025\n", 217 | "Non-trainable params: 0\n", 218 | "_________________________________________________________________\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "\"\"\"创建模型实例\"\"\"\n", 224 | "learning_rate = 0.003\n", 225 | "model = create_model(words,learning_rate) #创建模型\n", 226 | "model.summary() #打印模型结构" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 10, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "def clean_and_split(line):\n", 236 | " #将数据中的空格符统一,便于后期处理(原始数据中空格符包含\\t、\\v等) \n", 237 | " line = ' '.join(re.split(' |\\t|\\v',line))\n", 238 | " #将字符串数据按照括号中的符号进行分割,分割成列表格式,并且在列表中保留分隔符\n", 239 | " line = re.split('([: ,.*\\n(){}\\[\\]=])',line) \n", 240 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n", 241 | " return line" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 11, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | 
"\"\"\"打印生成的结果\"\"\"\n", 251 | "def print_code_text(code_list):\n", 252 | " mark = \".,*()[]:{}\\n\"\n", 253 | " \n", 254 | " result = \"\"\n", 255 | " last_word = \"\"\n", 256 | " \n", 257 | " for word in code_list:\n", 258 | " if last_word not in mark and word not in mark:\n", 259 | " result += \" \" + word\n", 260 | " else:\n", 261 | " result += word\n", 262 | " \n", 263 | " last_word = word\n", 264 | " \n", 265 | " print(result)\n", 266 | " \n", 267 | " return result" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 12, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "def text_generate(model,input_file,maxlen,\n", 277 | " temperatures,save_path,epoch,gen_lines=30):\n", 278 | " import random\n", 279 | " import codecs\n", 280 | " \n", 281 | " #从原始数据中随机找一行作为文本生成的起点\n", 282 | " with codecs.open(input_file,\"r\",\"utf-8\") as fin:\n", 283 | " lines = fin.readlines()\n", 284 | " random_line = random.randint(0,len(lines))\n", 285 | " start_line = clean_and_split(lines[random_line])\n", 286 | " \n", 287 | " one_line_max_words = 30\n", 288 | " \n", 289 | " print(\"===========epoch:%d===========\" % epoch)\n", 290 | " \n", 291 | " for temperature in temperatures:\n", 292 | " generated_text = start_line[:]\n", 293 | " print_string = start_line[:]\n", 294 | "\n", 295 | " for i in range(gen_lines):\n", 296 | " for j in range(one_line_max_words):\n", 297 | " sampled = np.zeros((1,len(generated_text))) \n", 298 | " #向量化\n", 299 | " for t,word in enumerate(generated_text): \n", 300 | " sampled[0,t] = word_indices.get(word,word_indices[''])\n", 301 | " #预测下一个词\n", 302 | " preds = model.predict(sampled,verbose=0)[0]\n", 303 | " next_index = sample(preds,temperature)\n", 304 | " next_word = words[next_index]\n", 305 | "\n", 306 | " if len(generated_text) == maxlen:\n", 307 | " generated_text = generated_text[1:]\n", 308 | " generated_text.append(next_word)\n", 309 | " print_string.append(next_word)\n", 310 | " if next_word == 
'\\n':\n", 311 | " break\n", 312 | "\n", 313 | " print(\"-----temperature: {}-----\".format(temperature))\n", 314 | " result = print_code_text(print_string)\n", 315 | "\n", 316 | " save_file = save_path + \"/{}_epoch_{}_temperature\".format(epoch,temperature)\n", 317 | " with codecs.open(save_file,\"w\",\"utf-8\") as fout:\n", 318 | " fout.write(result)" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 13, 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "\"\"\"模型保存\"\"\"\n", 328 | "from keras.callbacks import ModelCheckpoint\n", 329 | "filepath = \"../02-checkpoints/\"\n", 330 | "checkpoint = ModelCheckpoint(filepath, save_weights_only=False,verbose=1,save_best_only=False) #回调函数,实现断点续训功能" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 14, 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "\"\"\"学习率随模型效果变小\"\"\"\n", 340 | "import keras.backend as K\n", 341 | "from keras.callbacks import LearningRateScheduler\n", 342 | " \n", 343 | "def scheduler(epoch):\n", 344 | " # 每隔10个epoch,学习率减小为原来的5/10\n", 345 | " if epoch % 10 == 0 and epoch != 0:\n", 346 | " lr = K.get_value(model.optimizer.lr)\n", 347 | " K.set_value(model.optimizer.lr, lr * 0.5)\n", 348 | " print(\"lr changed to {}\".format(lr * 0.5))\n", 349 | " return K.get_value(model.optimizer.lr)\n", 350 | " \n", 351 | "reduce_lr = LearningRateScheduler(scheduler)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 15, 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", 364 | "Instructions for updating:\n", 365 | "Use tf.cast instead.\n", 366 | "Epoch 1/5\n", 367 | " 4096/79181 
[>.............................] - ETA: 2:37 - loss: 0.6361" 368 | ] 369 | }, 370 | { 371 | "ename": "KeyboardInterrupt", 372 | "evalue": "", 373 | "output_type": "error", 374 | "traceback": [ 375 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 376 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 377 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0minit_epoch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1024\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_epoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minit_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mreduce_lr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcheckpoint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#开始训练模型\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m 
\u001b[0mtext_generate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdata_file\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmaxlen\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0.4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0.8\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mprint_save_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_save_path\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"epoch_{}.hdf5\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 378 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)\u001b[0m\n\u001b[1;32m 1037\u001b[0m \u001b[0minitial_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1038\u001b[0m \u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1039\u001b[0;31m validation_steps=validation_steps)\n\u001b[0m\u001b[1;32m 1040\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1041\u001b[0m def evaluate(self, x=None, y=None,\n", 379 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/engine/training_arrays.py\u001b[0m in 
\u001b[0;36mfit_loop\u001b[0;34m(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0mins_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mins_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_labels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 380 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2713\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_legacy_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2714\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2715\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2716\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2717\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mpy_any\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 381 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2673\u001b[0m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_metadata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2674\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2675\u001b[0;31m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2676\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2677\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n", 382 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1437\u001b[0m ret = tf_session.TF_SessionRunCallable(\n\u001b[1;32m 1438\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1439\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 1440\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1441\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 383 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 384 | ] 385 | } 386 | ], 387 | "source": [ 388 | "\"\"\"训练模型\"\"\"\n", 389 | "import os\n", 390 | "print_save_path = \"../00-data/\"\n", 391 | "model_save_path = \"../02-checkpoints/\"\n", 392 | "total_epochs = 50 \n", 393 | "for epoch in range(1,total_epochs):\n", 394 | " if os.path.exists(filepath): #如果模型存在,则从现有模型开始训练\n", 395 | " model.load_weights(filepath)\n", 396 | " init_epoch = (epoch - 1) * 5\n", 397 | " model.fit(x,y,batch_size=1024,epochs=(init_epoch + 5),initial_epoch=init_epoch,callbacks=[reduce_lr,checkpoint]) #开始训练模型\n", 398 | " text_generate(model,data_file,maxlen,[0.1,0.4,0.8],print_save_path,epoch * 5)\n", 399 | " model.save(model_save_path + \"epoch_{}.hdf5\".format(epoch * 5))" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | 
"### 模型预测" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 32, 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "def generate_text_sentence(seed_text,model_filename): #测试代码和上面训练模型的代码基本一样,就不再介绍\n", 416 | " model.load_weights(model_filename)\n", 417 | " \n", 418 | " strings=''\n", 419 | " last_word=''\n", 420 | " seed_text = re.split('([: ,.\\n(){}\\[\\]=])',seed_text)\n", 421 | " seed_text = list(filter(lambda x: x!=' 'and x!='',seed_text))\n", 422 | " \n", 423 | " generated_text = seed_text[:]\n", 424 | " \n", 425 | " for temperature in [0.1,0.4,0.8]:\n", 426 | " strings += '\\n' + '-------------temperature:' + str(temperature) +'-------------\\n' +'\\n'\n", 427 | " \n", 428 | " for i in range(50):\n", 429 | " if i == 0:\n", 430 | " for k in range(len(generated_text)):\n", 431 | " if generated_text[k] not in mark and last_word not in mark:\n", 432 | " strings += ' ' + generated_text[k]\n", 433 | " else:\n", 434 | " strings += generated_text[k]\n", 435 | " last_word = generated_text[k]\n", 436 | "\n", 437 | " sampled = np.zeros((1,len(generated_text)))\n", 438 | " for t,word in enumerate(generated_text):\n", 439 | " sampled[0,t] = word_indices[word]\n", 440 | "\n", 441 | " preds = model.predict(sampled,verbose=0)[0]\n", 442 | " next_index = sample(preds,temperature = 0.3)\n", 443 | " next_word = words[next_index]\n", 444 | "\n", 445 | "\n", 446 | " generated_text.append(next_word)\n", 447 | "\n", 448 | " #if len(generated_text) == maxlen:\n", 449 | " # generated_text = generated_text[1:]\n", 450 | "\n", 451 | " if next_word not in mark and last_word not in mark:\n", 452 | " strings += ' ' + next_word\n", 453 | " else:\n", 454 | " strings += next_word\n", 455 | "\n", 456 | " last_word = next_word\n", 457 | "\n", 458 | " if next_word == '\\n':\n", 459 | " break\n", 460 | " \n", 461 | " generated_text = seed_text[:]\n", 462 | " \n", 463 | " return strings\n" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | 
"execution_count": null, 469 | "metadata": {}, 470 | "outputs": [], 471 | "source": [] 472 | } 473 | ], 474 | "metadata": { 475 | "kernelspec": { 476 | "display_name": "myconda", 477 | "language": "python", 478 | "name": "myconda" 479 | }, 480 | "language_info": { 481 | "codemirror_mode": { 482 | "name": "ipython", 483 | "version": 3 484 | }, 485 | "file_extension": ".py", 486 | "mimetype": "text/x-python", 487 | "name": "python", 488 | "nbconvert_exporter": "python", 489 | "pygments_lexer": "ipython3", 490 | "version": "3.5.6" 491 | } 492 | }, 493 | "nbformat": 4, 494 | "nbformat_minor": 4 495 | } 496 | -------------------------------------------------------------------------------- /model/01-Text_Gen/.ipynb_checkpoints/Text_Gen_tf10-checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "\"\"\"数据导入\"\"\"\n", 10 | "\n", 11 | "import re\n", 12 | "data_file = \"../00-data/tf_data.txt\"\n", 13 | "filename =open(data_file,'r',encoding='utf-8') #打开数据文件\n", 14 | "\n", 15 | "text = filename.read() #将数据读取到字符串text中\n", 16 | "text = ' '.join(re.split(' |\\t|\\v',text)) #将数据中的空格符统一,便于后期处理(原始数据中空格符包含\\t、\\v等) \n", 17 | "text = re.split('([: ,.*\\n(){}\\[\\]=])',text) #将字符串数据按照括号中的符号进行分割,分割成列表格式,并且在列表中保留分隔符\n", 18 | "\n", 19 | "text = list(filter(lambda x: x!=' 'and x!='',text)) #将列表中的空格和非空格筛选掉\n", 20 | "list_text = text #保留一份列表格式的数据\n", 21 | "text = ' '.join(text) #将列表转换成字符串" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "\"\"\"文本词频统计\"\"\"\n", 31 | "\n", 32 | "def word_count(list_text): #定义计算文本词频的函数,传入list_text列表\n", 33 | " import collections\n", 34 | " word_freq = collections.defaultdict(int) #定义一个int型的词频词典,并提供默认值\n", 35 | " for w in list_text: #遍历列表中的元素,元素出现一次,频次加一\n", 36 | " word_freq[w] += 1\n", 37 | " 
return word_freq #返回词频词典\n", 38 | " \n", 39 | " #return word_freq.items() 该语句返回值的类型为list(这句话有语法问题,不必考虑)" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": 3, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "\"\"\"根据text文本创建代码词词典\"\"\"\n", 49 | "\n", 50 | "def build_dict(text, min_word_freq=50):\n", 51 | " word_freq = word_count(text) #文本词频统计,返回一个词频词典\n", 52 | " word_freq = filter(lambda x: x[1] > min_word_freq, word_freq.items()) # filter将词频数量低于指定值的代码词删除。\n", 53 | " word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) # key用于指定排序的元素,因为sorted默认使用list中每个item的第一个元素从小到大排列,所以这里通过lambda进行前后元素调序,并对词频去相反数,从而将词频最大的排列在最前面\n", 54 | " words, _ = list(zip(*word_freq_sorted)) #获取每一个代码词\n", 55 | " words = list(words)\n", 56 | " words.append('')\n", 57 | " word_idx = dict(zip(words, range(len(words)))) #构建词典(不包含词频)\n", 58 | " return words,word_idx #这里只返回了words,倒数两行代码还用不上。返回的是一个不含重复的代码词词典,不包含词频。" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 4, 64 | "metadata": {}, 65 | "outputs": [ 66 | { 67 | "name": "stderr", 68 | "output_type": "stream", 69 | "text": [ 70 | "Using TensorFlow backend.\n" 71 | ] 72 | }, 73 | { 74 | "name": "stdout", 75 | "output_type": "stream", 76 | "text": [ 77 | "Number of sequences: 79181\n", 78 | "Unique words: 7865\n", 79 | "Vectorization...\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "\"\"\"数据预处理-字符串序列向量化\"\"\"\n", 85 | "\n", 86 | "import numpy as np\n", 87 | "import keras\n", 88 | "import json\n", 89 | "\n", 90 | "maxlen = 50 #提取50个代码词组成的序列\n", 91 | "step = 5 #每5个代码词采样一个新序列\n", 92 | "sentences = [] #保存所提取的序列\n", 93 | "next_words = [] #保存目标代码词\n", 94 | "vocab_file = \"../00-data/vocab\"\n", 95 | "\n", 96 | "cut_words = list_text #将列表形式的元数据保存在cut_words中\n", 97 | "for i in range(0,len(cut_words) - maxlen,step):\n", 98 | " sentences.append(cut_words[i:i + maxlen]) #将元数据按照步长来存储在每个序列中 \n", 99 | " next_words.append(cut_words[i + maxlen]) #将目标代码词存储在next_words中\n", 100 | " \n", 101 | " 
\n", 102 | "print('Number of sequences:', len(sentences))\n", 103 | "\n", 104 | "\n", 105 | "words,word_idx = list(build_dict(list_text,1)) #创建代码词词典,返回的是一个不含重复的代码词词典,不包含词频。\n", 106 | "print('Unique words:',len(words))\n", 107 | "json_dict = json.dumps(word_idx)\n", 108 | "with open(vocab_file,\"w\") as f:\n", 109 | " f.write(json_dict)\n", 110 | "\n", 111 | "word_indices = dict((word,words.index(word)) for word in words) #创建一个包含代码词唯一索引的代码词词典,返回的是一个字典\n", 112 | "#print(word_indices)\n", 113 | "\n", 114 | "print('Vectorization...')\n", 115 | "x = np.zeros((len(sentences),maxlen)) #初始化x\n", 116 | "y = np.zeros((len(sentences))) #初始化y\n", 117 | "for i,sentence in enumerate(sentences):\n", 118 | " for t,word in enumerate(sentence):\n", 119 | " x[i,t] = word_indices.get(word,word_indices['']) #将代码词转换成向量形式的编码\n", 120 | " #y[i] = word_indices[next_words[i]]\n", 121 | " y[i] = word_indices.get(next_words[i],word_indices[''])\n", 122 | "\n", 123 | "y = keras.utils.to_categorical(y, len(words)) #将int型数组y转换成one-hot编码" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": 6, 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [ 132 | "\"\"\"定义下一个代码词的采样函数---temperature越大,代码生成的随机性越强---\"\"\"\n", 133 | "\n", 134 | "def sample(preds,temperature=0.1):\n", 135 | " preds = np.asarray(preds).astype('float')\n", 136 | " preds = np.log(preds) /temperature\n", 137 | " exp_preds = np.exp(preds)\n", 138 | " preds = exp_preds / np.sum(exp_preds)\n", 139 | " probas = np.random.multinomial(1,preds,1)\n", 140 | " return np.argmax(probas)" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 7, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "\"\"\"将字符串写到指定文件中\"\"\"\n", 150 | "\n", 151 | "def save(filename, contents): \n", 152 | " file = open(filename, 'a', encoding='utf-8')\n", 153 | " file.write(contents)\n", 154 | " file.close()" 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | 
"source": [ 161 | "### 模型训练" 162 | ] 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 8, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [ 170 | "import keras\n", 171 | "from keras import layers\n", 172 | "from keras.layers import LSTM, Dense, Dropout\n", 173 | "\n", 174 | "def create_model(words,learning_rate): #定义创建模型的函数\n", 175 | " model = keras.models.Sequential() #模型初始化\n", 176 | " model.add(layers.Embedding(len(words),512)) #模型第一层为embedding层\n", 177 | " model.add(layers.LSTM(512,return_sequences=True,dropout=0.2,recurrent_dropout=0.2)) #模型第二层为LSTM层,加入dropout减少过拟合\n", 178 | " model.add(layers.LSTM(512,dropout=0.2,recurrent_dropout=0.2)) #模型第三层为LSTM层,加入dropout减少过拟合\n", 179 | " model.add(layers.Dense(len(words),activation='softmax')) #模型第三层为全连接层\n", 180 | "\n", 181 | " optimizer = keras.optimizers.RMSprop(lr=learning_rate) #定义优化器\n", 182 | " model.compile(loss='categorical_crossentropy',optimizer=optimizer) #模型编译\n", 183 | " \n", 184 | " return model" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 9, 190 | "metadata": { 191 | "scrolled": true 192 | }, 193 | "outputs": [ 194 | { 195 | "name": "stdout", 196 | "output_type": "stream", 197 | "text": [ 198 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 199 | "Instructions for updating:\n", 200 | "Colocations handled automatically by placer.\n", 201 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n", 202 | "Instructions for updating:\n", 203 | "Please use `rate` instead of `keep_prob`. 
Rate should be set to `rate = 1 - keep_prob`.\n", 204 | "_________________________________________________________________\n", 205 | "Layer (type) Output Shape Param # \n", 206 | "=================================================================\n", 207 | "embedding_1 (Embedding) (None, None, 512) 4026880 \n", 208 | "_________________________________________________________________\n", 209 | "lstm_1 (LSTM) (None, None, 512) 2099200 \n", 210 | "_________________________________________________________________\n", 211 | "lstm_2 (LSTM) (None, 512) 2099200 \n", 212 | "_________________________________________________________________\n", 213 | "dense_1 (Dense) (None, 7865) 4034745 \n", 214 | "=================================================================\n", 215 | "Total params: 12,260,025\n", 216 | "Trainable params: 12,260,025\n", 217 | "Non-trainable params: 0\n", 218 | "_________________________________________________________________\n" 219 | ] 220 | } 221 | ], 222 | "source": [ 223 | "\"\"\"创建模型实例\"\"\"\n", 224 | "learning_rate = 0.003\n", 225 | "model = create_model(words,learning_rate) #创建模型\n", 226 | "model.summary() #打印模型结构" 227 | ] 228 | }, 229 | { 230 | "cell_type": "code", 231 | "execution_count": 10, 232 | "metadata": {}, 233 | "outputs": [], 234 | "source": [ 235 | "def clean_and_split(line):\n", 236 | " #将数据中的空格符统一,便于后期处理(原始数据中空格符包含\\t、\\v等) \n", 237 | " line = ' '.join(re.split(' |\\t|\\v',line))\n", 238 | " #将字符串数据按照括号中的符号进行分割,分割成列表格式,并且在列表中保留分隔符\n", 239 | " line = re.split('([: ,.*\\n(){}\\[\\]=])',line) \n", 240 | " line = list(filter(lambda x: x!=' 'and x!='',line))\n", 241 | " return line" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 11, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "\"\"\"打印生成的结果\"\"\"\n", 251 | "def print_code_text(code_list):\n", 252 | " mark = \".,*()[]:{}\\n\"\n", 253 | " \n", 254 | " result = \"\"\n", 255 | " last_word = \"\"\n", 256 | " \n", 257 | " for word in 
code_list:\n", 258 | " if last_word not in mark and word not in mark:\n", 259 | " result += \" \" + word\n", 260 | " else:\n", 261 | " result += word\n", 262 | " \n", 263 | " last_word = word\n", 264 | " \n", 265 | " print(result)\n", 266 | " \n", 267 | " return result" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 12, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "def text_generate(model,input_file,maxlen,\n", 277 | " temperatures,save_path,epoch,gen_lines=30):\n", 278 | " import random\n", 279 | " import codecs\n", 280 | " \n", 281 | " #从原始数据中随机找一行作为文本生成的起点\n", 282 | " with codecs.open(input_file,\"r\",\"utf-8\") as fin:\n", 283 | " lines = fin.readlines()\n", 284 | " random_line = random.randint(0,len(lines))\n", 285 | " start_line = clean_and_split(lines[random_line])\n", 286 | " \n", 287 | " one_line_max_words = 30\n", 288 | " \n", 289 | " print(\"===========epoch:%d===========\" % epoch)\n", 290 | " \n", 291 | " for temperature in temperatures:\n", 292 | " generated_text = start_line[:]\n", 293 | " print_string = start_line[:]\n", 294 | "\n", 295 | " for i in range(gen_lines):\n", 296 | " for j in range(one_line_max_words):\n", 297 | " sampled = np.zeros((1,len(generated_text))) \n", 298 | " #向量化\n", 299 | " for t,word in enumerate(generated_text): \n", 300 | " sampled[0,t] = word_indices.get(word,word_indices[''])\n", 301 | " #预测下一个词\n", 302 | " preds = model.predict(sampled,verbose=0)[0]\n", 303 | " next_index = sample(preds,temperature)\n", 304 | " next_word = words[next_index]\n", 305 | "\n", 306 | " if len(generated_text) == maxlen:\n", 307 | " generated_text = generated_text[1:]\n", 308 | " generated_text.append(next_word)\n", 309 | " print_string.append(next_word)\n", 310 | " if next_word == '\\n':\n", 311 | " break\n", 312 | "\n", 313 | " print(\"-----temperature: {}-----\".format(temperature))\n", 314 | " result = print_code_text(print_string)\n", 315 | "\n", 316 | " save_file = save_path + 
\"/{}_epoch_{}_temperature\".format(epoch,temperature)\n", 317 | " with codecs.open(save_file,\"w\",\"utf-8\") as fout:\n", 318 | " fout.write(result)" 319 | ] 320 | }, 321 | { 322 | "cell_type": "code", 323 | "execution_count": 13, 324 | "metadata": {}, 325 | "outputs": [], 326 | "source": [ 327 | "\"\"\"模型保存\"\"\"\n", 328 | "from keras.callbacks import ModelCheckpoint\n", 329 | "filepath = \"../02-checkpoints/\"\n", 330 | "checkpoint = ModelCheckpoint(filepath, save_weights_only=False,verbose=1,save_best_only=False) #回调函数,实现断点续训功能" 331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 14, 336 | "metadata": {}, 337 | "outputs": [], 338 | "source": [ 339 | "\"\"\"学习率随模型效果变小\"\"\"\n", 340 | "import keras.backend as K\n", 341 | "from keras.callbacks import LearningRateScheduler\n", 342 | " \n", 343 | "def scheduler(epoch):\n", 344 | " # 每隔10个epoch,学习率减小为原来的5/10\n", 345 | " if epoch % 10 == 0 and epoch != 0:\n", 346 | " lr = K.get_value(model.optimizer.lr)\n", 347 | " K.set_value(model.optimizer.lr, lr * 0.5)\n", 348 | " print(\"lr changed to {}\".format(lr * 0.5))\n", 349 | " return K.get_value(model.optimizer.lr)\n", 350 | " \n", 351 | "reduce_lr = LearningRateScheduler(scheduler)" 352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": 15, 357 | "metadata": {}, 358 | "outputs": [ 359 | { 360 | "name": "stdout", 361 | "output_type": "stream", 362 | "text": [ 363 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", 364 | "Instructions for updating:\n", 365 | "Use tf.cast instead.\n", 366 | "Epoch 1/5\n", 367 | " 4096/79181 [>.............................] 
- ETA: 2:37 - loss: 0.6361" 368 | ] 369 | }, 370 | { 371 | "ename": "KeyboardInterrupt", 372 | "evalue": "", 373 | "output_type": "error", 374 | "traceback": [ 375 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 376 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 377 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilepath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0minit_epoch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1024\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minit_epoch\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minit_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mreduce_lr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcheckpoint\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#开始训练模型\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m 
\u001b[0mtext_generate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdata_file\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmaxlen\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0.1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0.4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m0.8\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mprint_save_path\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel_save_path\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"epoch_{}.hdf5\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 378 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/engine/training.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)\u001b[0m\n\u001b[1;32m 1037\u001b[0m \u001b[0minitial_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minitial_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1038\u001b[0m \u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msteps_per_epoch\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1039\u001b[0;31m validation_steps=validation_steps)\n\u001b[0m\u001b[1;32m 1040\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1041\u001b[0m def evaluate(self, x=None, y=None,\n", 379 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/engine/training_arrays.py\u001b[0m in 
\u001b[0;36mfit_loop\u001b[0;34m(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0mins_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mins_batch\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtoarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 199\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 200\u001b[0m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mto_list\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_labels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 380 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2713\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_legacy_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2714\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2715\u001b[0;31m \u001b[0;32mreturn\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2716\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2717\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mpy_any\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_tensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 381 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/keras/backend/tensorflow_backend.py\u001b[0m in \u001b[0;36m_call\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 2673\u001b[0m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_metadata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2674\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2675\u001b[0;31m \u001b[0mfetched\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callable_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0marray_vals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2676\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2677\u001b[0m 
\u001b[0;34m\u001b[0m\u001b[0m\n", 382 | "\u001b[0;32m~/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1437\u001b[0m ret = tf_session.TF_SessionRunCallable(\n\u001b[1;32m 1438\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_session\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstatus\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1439\u001b[0;31m run_metadata_ptr)\n\u001b[0m\u001b[1;32m 1440\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1441\u001b[0m \u001b[0mproto_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_GetBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrun_metadata_ptr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 383 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 384 | ] 385 | } 386 | ], 387 | "source": [ 388 | "\"\"\"训练模型\"\"\"\n", 389 | "import os\n", 390 | "print_save_path = \"../00-data/\"\n", 391 | "model_save_path = \"../02-checkpoints/\"\n", 392 | "total_epochs = 50 \n", 393 | "for epoch in range(1,total_epochs):\n", 394 | " if os.path.exists(filepath): #如果模型存在,则从现有模型开始训练\n", 395 | " model.load_weights(filepath)\n", 396 | " init_epoch = (epoch - 1) * 5\n", 397 | " model.fit(x,y,batch_size=1024,epochs=(init_epoch + 5),initial_epoch=init_epoch,callbacks=[reduce_lr,checkpoint]) #开始训练模型\n", 398 | " text_generate(model,data_file,maxlen,[0.1,0.4,0.8],print_save_path,epoch * 5)\n", 399 | " model.save(model_save_path + \"epoch_{}.hdf5\".format(epoch * 5))" 400 | ] 401 | }, 402 | { 403 | "cell_type": "markdown", 404 | "metadata": {}, 405 | "source": [ 406 | 
"### 模型预测" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": 32, 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "def generate_text_sentence(seed_text,model_filename): #测试代码和上面训练模型的代码基本一样,就不再介绍\n", 416 | " model.load_weights(model_filename)\n", 417 | " \n", 418 | " strings=''\n", 419 | " last_word=''\n", 420 | " seed_text = re.split('([: ,.\\n(){}\\[\\]=])',seed_text)\n", 421 | " seed_text = list(filter(lambda x: x!=' 'and x!='',seed_text))\n", 422 | " \n", 423 | " generated_text = seed_text[:]\n", 424 | " \n", 425 | " for temperature in [0.1,0.4,0.8]:\n", 426 | " strings += '\\n' + '-------------temperature:' + str(temperature) +'-------------\\n' +'\\n'\n", 427 | " \n", 428 | " for i in range(50):\n", 429 | " if i == 0:\n", 430 | " for k in range(len(generated_text)):\n", 431 | " if generated_text[k] not in mark and last_word not in mark:\n", 432 | " strings += ' ' + generated_text[k]\n", 433 | " else:\n", 434 | " strings += generated_text[k]\n", 435 | " last_word = generated_text[k]\n", 436 | "\n", 437 | " sampled = np.zeros((1,len(generated_text)))\n", 438 | " for t,word in enumerate(generated_text):\n", 439 | " sampled[0,t] = word_indices[word]\n", 440 | "\n", 441 | " preds = model.predict(sampled,verbose=0)[0]\n", 442 | " next_index = sample(preds,temperature = 0.3)\n", 443 | " next_word = words[next_index]\n", 444 | "\n", 445 | "\n", 446 | " generated_text.append(next_word)\n", 447 | "\n", 448 | " #if len(generated_text) == maxlen:\n", 449 | " # generated_text = generated_text[1:]\n", 450 | "\n", 451 | " if next_word not in mark and last_word not in mark:\n", 452 | " strings += ' ' + next_word\n", 453 | " else:\n", 454 | " strings += next_word\n", 455 | "\n", 456 | " last_word = next_word\n", 457 | "\n", 458 | " if next_word == '\\n':\n", 459 | " break\n", 460 | " \n", 461 | " generated_text = seed_text[:]\n", 462 | " \n", 463 | " return strings\n" 464 | ] 465 | }, 466 | { 467 | "cell_type": "code", 468 | 
"execution_count": null, 469 | "metadata": {}, 470 | "outputs": [], 471 | "source": [] 472 | } 473 | ], 474 | "metadata": { 475 | "kernelspec": { 476 | "display_name": "myconda", 477 | "language": "python", 478 | "name": "myconda" 479 | }, 480 | "language_info": { 481 | "codemirror_mode": { 482 | "name": "ipython", 483 | "version": 3 484 | }, 485 | "file_extension": ".py", 486 | "mimetype": "text/x-python", 487 | "name": "python", 488 | "nbconvert_exporter": "python", 489 | "pygments_lexer": "ipython3", 490 | "version": "3.5.6" 491 | } 492 | }, 493 | "nbformat": 4, 494 | "nbformat_minor": 4 495 | } 496 | -------------------------------------------------------------------------------- /model/02-Seq2Seq/Seq2Seq_tf10.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import tensorflow as tf\n", 10 | "import re\n", 11 | "import os\n", 12 | "import collections\n", 13 | "import numpy as np\n", 14 | "import codecs\n", 15 | "import random\n", 16 | "from operator import itemgetter" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "# GPU setup. CUDA_VISIBLE_DEVICES is only honoured if set before CUDA is initialised, i.e. before the first tf.Session.\n", 26 | "os.environ['CUDA_VISIBLE_DEVICES']=\"0\"\n", 27 | "config = tf.ConfigProto()\n", 28 | "config.gpu_options.per_process_gpu_memory_fraction = 0.95 # let TensorFlow claim up to 95% of GPU memory\n", 29 | "session = tf.Session(config=config)" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "metadata": {}, 35 | "source": [ 36 | "## 构建词表" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 14, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "def get_vocab(input_data,min_word_freq): # frequency-sorted token list (reserved entries first) from a UTF-8 text file\n", 46 | "    counter = collections.Counter()\n", 47 | "    with codecs.open(input_data,\"r\",\"utf-8\") as f:\n", 48 | "        for line in f:\n", 49 | "            line = ' '.join(re.split(' |\\t|\\v|\\n',line)) # normalise whitespace: \\t, \\v and \\n become plain spaces\n", 50 | "            line = re.split('([: ,.(){}\\[\\]=])',line) # split on these delimiters, keeping each delimiter as its own token\n", 51 | "            line = list(filter(lambda x: x!=' 'and x!='',line))\n", 52 | "            for word in line:\n", 53 | "                counter[word] += 1\n", 54 | "    \n", 55 | "    counter = filter(lambda x: x[1] > min_word_freq, counter.items()) # keep tokens seen more than min_word_freq times\n", 56 | "    sorted_word_to_cnt = sorted(counter,key=itemgetter(1),reverse=True)\n", 57 | "    sorted_words = [x[0] for x in sorted_word_to_cnt]\n", 58 | "    # NOTE(review): ids 0-2 are reserved, but all three entries are \"\" in this copy, so a word->id dict maps \"\" to 2 only; confirm the intended token names.\n", 59 | "    sorted_words = [\"\",\"\",\"\"] + sorted_words\n", 60 | "\n", 61 | "    print(\"vocab_len: \" + str(len(sorted_words)))\n", 62 | "\n", 63 | "    return sorted_words" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 15, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "name": "stdout", 73 | "output_type": "stream", 74 | "text": [ 75 | "vocab_len: 11928\n" 76 | ] 77 | }, 78 | { 79 | "data": { 80 | "text/plain": [ 81 | "['in_channels',\n", 82 | " 'zeros',\n", 83 | " 'class',\n", 84 | " 'plt',\n", 85 | " 'session',\n", 86 | " 'variable_scope',\n", 87 | " 'join',\n", 88 | " 'size',\n", 89 | " 'Session',\n", 90 | " 'zip']" 91 | ] 92 | }, 93 | "execution_count": 15, 94 | "metadata": {}, 95 | "output_type": "execute_result" 96 | } 97 | ], 98 | "source": [ 99 | "input_data = \"../00-data/tf_data.txt\"\n", 100 | "min_word_freq = 0\n", 101 | "vocab = get_vocab(input_data,min_word_freq)\n", 102 | "vocab[70:80]" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 8, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "word_dict = {word:index for index,word in enumerate(vocab)}" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": 10, 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "data": { 121 | "text/plain": [ 122 | "78" 123 | ] 124 | }, 125 | "execution_count": 10, 126 | "metadata": {}, 127 | "output_type": "execute_result" 128 | } 129 | ], 130 | "source": [ 131 | "word_dict[\"Session\"]"
132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "## 【随机】前k句预测下一句" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 5, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "def clean_and_split(line): # same tokenisation as get_vocab\n", 148 | "    line = ' '.join(re.split(' |\\t|\\v|\\n',line)) \n", 149 | "    line = re.split('([: ,.(){}\\[\\]=])',line) \n", 150 | "    line = list(filter(lambda x: x!=' 'and x!='',line))\n", 151 | "    return line" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 6, 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "def get_newdata_by_random(raw_data,vocab,enc_vec_data,enc_text_data,dec_vec_data,dec_text_data,rand_max=10,duplicate=2):\n", 161 | "    import codecs\n", 162 | "    # For each line i >= rand_max, draw `duplicate` context sizes from [1, rand_max] (repeats collapse). Each distinct size k writes one pair:\n", 163 | "    # encoder = the k lines before line i, decoder = line i; both as tokens (*_text_data) and as ids (*_vec_data). Returns the number of pairs.\n", 164 | "\n", 165 | "    word_to_id = {k:v for(k,v) in zip(vocab,range(len(vocab)))}\n", 166 | "\n", 167 | "    def get_id(word):\n", 168 | "        return word_to_id[word] if word in word_to_id else word_to_id[\"\"]\n", 169 | "\n", 170 | "    data_length = 0\n", 171 | "    # Read the corpus as UTF-8 like get_vocab, and close every file even if an error occurs.\n", 172 | "    with open(raw_data,\"r\",encoding=\"utf-8\") as fin, \\\n", 173 | "         codecs.open(enc_vec_data,\"w\",\"utf-8\") as fout_vec_enc, \\\n", 174 | "         codecs.open(dec_vec_data,\"w\",\"utf-8\") as fout_vec_dec, \\\n", 175 | "         codecs.open(enc_text_data,\"w\",\"utf-8\") as fout_text_enc, \\\n", 176 | "         codecs.open(dec_text_data,\"w\",\"utf-8\") as fout_text_dec:\n", 177 | "        lines = fin.readlines()\n", 178 | "        for i in range(rand_max,len(lines)):\n", 179 | "            rand_nums = set(random.randint(1,rand_max) for _ in range(duplicate))\n", 180 | "            for rand_num in rand_nums:\n", 181 | "                # encoder input: the rand_num lines preceding line i\n", 182 | "                words = []\n", 183 | "                for j in range(i - rand_num,i):\n", 184 | "                    line = clean_and_split(lines[j])\n", 185 | "                    words += [\"\"] + line + [\"\"]\n", 186 | "                out_line = ' '.join([str(get_id(w)) for w in words]) + '\\n'\n", 187 | "                fout_text_enc.write(' '.join(words) + '\\n')\n", 188 | "                fout_vec_enc.write(out_line)\n", 189 | "\n", 190 | "                # decoder target: line i\n", 192 | "                line = clean_and_split(lines[i])\n", 193 | "                words = line + [\"\"]\n", 194 | "                out_line = ' '.join([str(get_id(w)) for w in words]) + '\\n'\n", 195 | "                fout_text_dec.write(' '.join(words) + '\\n')\n", 196 | "                fout_vec_dec.write(out_line)\n", 197 | "\n", 198 | "                data_length += 1\n", 199 | "\n", 204 | "    return data_length" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": 7, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "enc_vec_data = \"../00-data/train.enc\"\n", 214 | "dec_vec_data = \"../00-data/train.dec\"\n", 215 | "enc_text_data = \"../00-data/train_text.enc\"\n", 216 | "dec_text_data = \"../00-data/train_text.dec\"" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": 8, 222 | "metadata": {}, 223 | "outputs": [ 224 | { 225 | "name": "stdout", 226 | "output_type": "stream", 227 | "text": [ 228 | "data_len: 74052\n" 229 | ] 230 | } 231 | ], 232 | "source": [ 233 | "data_length = get_newdata_by_random(input_data,vocab,enc_vec_data,enc_text_data,dec_vec_data,dec_text_data)\n", 234 | "print(\"data_len: \" + str(data_length))" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "metadata": {}, 240 | "source": [ 241 | "## 原始数据向量化" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 9, 247 | "metadata": {}, 248 | "outputs": [], 249 | "source": [ 250 | "def MakeDataset(file_path): # one element per line: (int32 vector of the space-separated ids, its length)\n", 251 | "    dataset = tf.data.TextLineDataset(file_path)\n", 252 | "    dataset = dataset.map(lambda string:tf.string_split([string]).values)\n", 253 | "    dataset = dataset.map(\n", 254 | "        lambda string: tf.string_to_number(string,tf.int32))\n", 255 | "    dataset = dataset.map(lambda x:(x,tf.size(x)))\n", 256 | "    return dataset" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 
10, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "def MakeSrcTrgDataset(src_path,trg_path,batch_size,shuffle_size=3000,start_id=1,max_len=None):\n", 266 | "    src_data = MakeDataset(src_path)\n", 267 | "    trg_data = MakeDataset(trg_path)\n", 268 | "    \n", 269 | "    dataset = tf.data.Dataset.zip((src_data,trg_data))\n", 270 | "    \n", 271 | "    # Optionally drop pairs whose source or target length is <= 1 or > max_len.\n", 272 | "    # Skipped when max_len is None (the default).\n", 273 | "    def FilterLength(src_tuple,trg_tuple):\n", 274 | "        ((src_input,src_len),(trg_label,trg_len)) = (src_tuple,trg_tuple)\n", 275 | "        src_len_ok = tf.logical_and(\n", 276 | "            tf.greater(src_len,1),tf.less_equal(src_len,max_len))\n", 277 | "        trg_len_ok = tf.logical_and(\n", 278 | "            tf.greater(trg_len,1),tf.less_equal(trg_len,max_len))\n", 279 | "        return tf.logical_and(src_len_ok,trg_len_ok)\n", 280 | "    if max_len is not None:\n", "        dataset = dataset.filter(FilterLength)\n", 281 | "    \n", 282 | "    # Decoder input = [start_id] + trg_label[:-1], i.e. the target shifted right by one.\n", 283 | "    def MakeTrgInput(src_tuple,trg_tuple):\n", 284 | "        ((src_input,src_len),(trg_label,trg_len)) = (src_tuple,trg_tuple)\n", 285 | "        trg_input = tf.concat([[start_id],trg_label[:-1]],axis=0)\n", 286 | "        return ((src_input,src_len),(trg_input,trg_label,trg_len))\n", 287 | "    dataset = dataset.map(MakeTrgInput)\n", 288 | "    \n", 289 | "    # Shuffle the training pairs.\n", 290 | "    dataset = dataset.shuffle(shuffle_size)\n", 291 | "    \n", 292 | "    # Shapes after padding: variable-length id vectors and scalar lengths.\n", 293 | "    padded_shapes = (\n", 294 | "        (tf.TensorShape([None]),\n", 295 | "         tf.TensorShape([])),\n", 296 | "        (tf.TensorShape([None]),\n", 297 | "         tf.TensorShape([None]),\n", 298 | "         tf.TensorShape([])))\n", 299 | "    # Batch with padding up to the longest sequence in each batch.\n", 300 | "    batched_dataset = dataset.padded_batch(batch_size,padded_shapes)\n", 301 | "    return batched_dataset" 302 | ] 303 | }, 304 | { 305 | "cell_type": "markdown", 306 | "metadata": {}, 307 | "source": [ 308 | "## seq2seq + attention模型\n", 309 | "encoder:2层单向lstm\n", 310 | "decoder:1层单向lstm" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 11, 316 | "metadata": {}, 317 | "outputs": [], 318 | 
"source": [ 319 | "CHECKPOINT_PATH = \"../02-checkpoints/\"\n", 320 | "HIDDEN_SIZE = 256\n", 321 | "ENCODER_LAYERS = 2\n", 322 | "DECODER_LAYERS = 1\n", 323 | "SRC_VOCAB_SIZE = len(vocab)\n", 324 | "TRG_VOCAB_SIZE = len(vocab)\n", 325 | "BATCH_SIZE = 64\n", 326 | "NUM_EPOCH = 10\n", 327 | "KEEP_PROB = 0.9\n", 328 | "MAX_GRAD_NORM = 5\n", 329 | "LEARNING_RATE_BASE = 1.0\n", 330 | "LEARNING_RATE_DECAY = 0.7\n", 331 | "SHARE_EMB_AND_SOFTMAX = True" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": 12, 337 | "metadata": {}, 338 | "outputs": [], 339 | "source": [ 340 | "# NMTModel: attention seq2seq with an ENCODER_LAYERS-deep unidirectional LSTM encoder and a DECODER_LAYERS-deep LSTM decoder.\n", 341 | "class NMTModel(object):\n", 342 | "    # Create the RNN cells and the embedding / softmax variables.\n", 343 | "    def __init__(self):\n", 344 | "        # LSTM cells for the encoder and decoder.\n", 345 | "        self.enc_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)\n", 346 | "        self.enc_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE)\n", 347 | "        # NOTE(review): enc_cell_fw / enc_cell_bw are not used by forward(); the encoder uses self.enc_cell.\n", 348 | "        self.enc_cell = tf.nn.rnn_cell.MultiRNNCell(\n", 349 | "            [tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) \n", 350 | "             for _ in range(ENCODER_LAYERS)])\n", 351 | "        \n", 352 | "        self.dec_cell = tf.nn.rnn_cell.MultiRNNCell(\n", 353 | "            [tf.nn.rnn_cell.BasicLSTMCell(HIDDEN_SIZE) \n", 354 | "             for _ in range(DECODER_LAYERS)])\n", 355 | "\n", 356 | "        # Separate embedding tables for source and target tokens.\n", 357 | "        self.src_embedding = tf.get_variable(\n", 358 | "            \"src_emb\", [SRC_VOCAB_SIZE, HIDDEN_SIZE])\n", 359 | "        self.trg_embedding = tf.get_variable(\n", 360 | "            \"trg_emb\", [TRG_VOCAB_SIZE, HIDDEN_SIZE])\n", 361 | "\n", 362 | "        # Softmax layer; its weight is tied to the target embedding when SHARE_EMB_AND_SOFTMAX is set.\n", 363 | "        if SHARE_EMB_AND_SOFTMAX:\n", 364 | "            self.softmax_weight = tf.transpose(self.trg_embedding)\n", 365 | "        else:\n", 366 | "            self.softmax_weight = tf.get_variable(\n", 367 | "                \"weight\", [HIDDEN_SIZE, TRG_VOCAB_SIZE])\n", 368 | "        self.softmax_bias = tf.get_variable(\n", 369 | "            \"softmax_bias\", [TRG_VOCAB_SIZE])\n", 370 | "\n", 371 | "    # forward builds the training graph. src_input, src_size, trg_input, trg_label and trg_size are the\n", 372 | "    # five tensors produced by MakeSrcTrgDataset; data_length (number of training pairs) sets the decay period.\n", 373 | "    # Returns (per-token cost, train op, learning rate, global_step variable).\n", 374 | "    def forward(self, src_input, src_size, trg_input, trg_label, trg_size,data_length):\n", 375 | "        global_step = tf.Variable(0, trainable=False)\n", 376 | "        batch_size = tf.shape(src_input)[0]\n", 377 | "        \n", 378 | "        # Map source and target token ids to embeddings.\n", 379 | "        src_emb = tf.nn.embedding_lookup(self.src_embedding, src_input)\n", 380 | "        trg_emb = tf.nn.embedding_lookup(self.trg_embedding, trg_input)\n", 381 | "        \n", 382 | "        # Dropout on the embeddings.\n", 383 | "        src_emb = tf.nn.dropout(src_emb, KEEP_PROB)\n", 384 | "        trg_emb = tf.nn.dropout(trg_emb, KEEP_PROB)\n", 385 | "\n", 386 | "        # Encoder: dynamic_rnn over the source embeddings with the stacked cell self.enc_cell.\n", 387 | "        # enc_state holds one LSTMStateTuple per encoder layer, each [batch_size, HIDDEN_SIZE]; it is not used below.\n", 388 | "        # enc_outputs is a single tensor [batch_size, max_time, HIDDEN_SIZE] holding the top layer's output at every\n", 389 | "        # step; it is the memory the attention mechanism attends over.\n", 395 | "        with tf.variable_scope(\"encoder\"):\n", 399 | "            enc_outputs,enc_state = tf.nn.dynamic_rnn(\n", 400 | "                self.enc_cell,src_emb,src_size,dtype=tf.float32) \n", 401 | "\n", 402 | "        with tf.variable_scope(\"decoder\"):\n", 403 | "            # BahdanauAttention: additive attention scored by a one-hidden-layer feed-forward network.\n", 404 | "            # memory_sequence_length ([batch_size] source lengths) masks padded positions so they get\n", 405 | "            # zero attention weight.\n", 406 | "            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(\n", 407 | "                HIDDEN_SIZE, enc_outputs,\n", 408 | "                memory_sequence_length=src_size)\n", 409 | "\n", 410 | "            # Wrap the decoder cell and the attention mechanism into a single RNN cell.\n", 411 | "            attention_cell = tf.contrib.seq2seq.AttentionWrapper(\n", 412 | "                self.dec_cell, attention_mechanism,\n", 413 | "                attention_layer_size=HIDDEN_SIZE)\n", 414 | "\n", 415 | "            # Decoder: run attention_cell over the target-input embeddings with dynamic_rnn.\n", 416 | "            # No initial state is given, so the decoder is not initialised from the encoder state\n", 417 | "            # and relies entirely on attention for source information.\n", 418 | "            dec_outputs, _ = tf.nn.dynamic_rnn(\n", 419 | "                attention_cell, trg_emb, trg_size, dtype=tf.float32)\n", 420 | "\n", 421 | "        # Per-token cross-entropy (log perplexity) of the decoder outputs.\n", 422 | "        output = tf.reshape(dec_outputs, [-1, HIDDEN_SIZE])\n", 423 | "        logits = tf.matmul(output, self.softmax_weight) + self.softmax_bias\n", 424 | "        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(\n", 425 | "            labels=tf.reshape(trg_label, [-1]), logits=logits)\n", 426 | "\n", 427 | "        # Give padded target positions zero weight so they do not affect\n", 428 | "        # the loss.\n", 429 | "        label_weights = tf.sequence_mask(\n", 430 | "            trg_size, maxlen=tf.shape(trg_label)[1], dtype=tf.float32)\n", 431 | "        label_weights = tf.reshape(label_weights, [-1])\n", 432 | "        cost = tf.reduce_sum(loss * label_weights)\n", 433 | "        cost_op = cost / tf.reduce_sum(label_weights)\n", 434 | "        \n", 435 | "        # Backward pass over all trainable variables.\n", 436 | "        trainable_variables = tf.trainable_variables()\n", 437 | "\n", 438 | "        # Gradients of the summed cost divided by batch size, clipped by global norm, applied with SGD.\n", 439 | "        grads = tf.gradients(cost / tf.to_float(batch_size),\n", 440 | "                             trainable_variables)\n", 441 | "        grads, _ = tf.clip_by_global_norm(grads, MAX_GRAD_NORM)\n", 442 | "        # Staircase decay every data_length / batch_size steps (about one epoch); apply_gradients does not advance global_step, run_epoch sets it.\n", 443 | "        learning_rate = tf.train.exponential_decay(\n", 444 | "            LEARNING_RATE_BASE,\n", 445 | "            global_step,\n", 446 | "            data_length / batch_size, \n", 447 | "            LEARNING_RATE_DECAY,\n", 448 | "            staircase=True)\n", 449 | "        optimizer = tf.train.GradientDescentOptimizer(learning_rate)\n", 450 | "        train_op = optimizer.apply_gradients(\n", 451 | "            zip(grads, trainable_variables))\n", 452 | "        return cost_op, train_op,learning_rate,global_step" 453 | ] 454 | }, 455 | { 456 | "cell_type": "code", 457 | "execution_count": 13, 458 | "metadata": {}, 459 | "outputs": [], 460 | "source": [ 461 | "def run_epoch(session,cost_op,train_op,learning_rate_op,global_step_op,step,saver,epoch): # train until the iterator is exhausted; returns the next step value\n", "    global_step = session.run(global_step_op) # defined even if the epoch yields no batch\n", 462 | "    while True:\n", 463 | "        try:\n", 464 | "            cost,_,learning_rate,global_step = session.run([cost_op,train_op,learning_rate_op,global_step_op])\n", 465 | "            if global_step % 50 == 0:\n", 466 | "                print(\"After %d global_steps,per token cost is %.3f,learning_rate is %.5f\" %(global_step,cost,learning_rate))\n", 467 | "            # Variable.load writes the value without adding ops; building tf.assign here grew the graph on every step.\n", "            global_step_op.load(step,session)\n", 468 | "            step += 1\n", 469 | "        except tf.errors.OutOfRangeError:\n", 470 | "            if epoch % 2 == 0: # checkpoint every second epoch\n", 471 | "                saver.save(session,CHECKPOINT_PATH,global_step=global_step)\n", 472 | "            break\n", 473 | "    return step" 474 | ] 475 | }, 476 | { 477 | "cell_type": "code", 478 | "execution_count": null, 479 | "metadata": {}, 480 | "outputs": [ 481 | { 482 | "name": "stdout", 483 | "output_type": "stream", 484 | "text": [ 485 | "WARNING:tensorflow:From :6: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n", 486 | "Instructions for updating:\n", 487 | "This class is equivalent as tf.keras.layers.LSTMCell, and will be replaced by that in Tensorflow 2.0.\n", 488 | "WARNING:tensorflow:From :11: MultiRNNCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.\n", 489 | "Instructions for updating:\n", 490 | "This class is equivalent as tf.keras.layers.StackedRNNCells, and will be replaced by that in Tensorflow 2.0.\n", 491 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 492 | "Instructions for updating:\n", 493 | "Colocations handled automatically by placer.\n", 494 | "WARNING:tensorflow:From :44: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future
version.\n", 495 | "Instructions for updating:\n", 496 | "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n", 497 | "WARNING:tensorflow:From :61: dynamic_rnn (from tensorflow.python.ops.rnn) is deprecated and will be removed in a future version.\n", 498 | "Instructions for updating:\n", 499 | "Please use `keras.layers.RNN(cell)`, which is equivalent to this API\n", 500 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py:626: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", 501 | "Instructions for updating:\n", 502 | "Use tf.cast instead.\n", 503 | "\n", 504 | "WARNING: The TensorFlow contrib module will not be included in TensorFlow 2.0.\n", 505 | "For more information, please see:\n", 506 | " * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n", 507 | " * https://github.com/tensorflow/addons\n", 508 | "If you depend on functionality not listed there, please file an issue.\n", 509 | "\n", 510 | "WARNING:tensorflow:From :100: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", 511 | "Instructions for updating:\n", 512 | "Use tf.cast instead.\n", 513 | "WARNING:tensorflow:From /root/miniconda3/envs/myconda/lib/python3.5/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.\n", 514 | "Instructions for updating:\n", 515 | "Use standard file APIs to check for files with this prefix.\n", 516 | "INFO:tensorflow:Restoring parameters from ../02-checkpoints/02-0318/-4635\n", 517 | "In EPOCH: 1\n", 518 | "After 50 global_steps,per token cost is 2.616,learning_rate is 1.00000\n", 519 | "After 100 global_steps,per token cost is 2.445,learning_rate is 1.00000\n", 520 | "After 150 global_steps,per token 
cost is 1.941,learning_rate is 1.00000\n", 521 | "After 200 global_steps,per token cost is 1.367,learning_rate is 1.00000\n", 522 | "After 250 global_steps,per token cost is 1.796,learning_rate is 1.00000\n", 523 | "After 300 global_steps,per token cost is 2.015,learning_rate is 1.00000\n", 524 | "After 350 global_steps,per token cost is 2.002,learning_rate is 1.00000\n", 525 | "After 400 global_steps,per token cost is 1.729,learning_rate is 1.00000\n", 526 | "After 450 global_steps,per token cost is 1.414,learning_rate is 1.00000\n", 527 | "After 500 global_steps,per token cost is 1.769,learning_rate is 1.00000\n", 528 | "After 550 global_steps,per token cost is 1.915,learning_rate is 1.00000\n", 529 | "After 600 global_steps,per token cost is 2.256,learning_rate is 1.00000\n", 530 | "After 650 global_steps,per token cost is 2.122,learning_rate is 1.00000\n", 531 | "After 700 global_steps,per token cost is 2.302,learning_rate is 1.00000\n", 532 | "After 750 global_steps,per token cost is 2.152,learning_rate is 1.00000\n", 533 | "After 800 global_steps,per token cost is 1.761,learning_rate is 1.00000\n", 534 | "After 850 global_steps,per token cost is 1.629,learning_rate is 1.00000\n", 535 | "After 900 global_steps,per token cost is 1.359,learning_rate is 1.00000\n", 536 | "After 950 global_steps,per token cost is 1.678,learning_rate is 1.00000\n", 537 | "After 1000 global_steps,per token cost is 2.382,learning_rate is 1.00000\n", 538 | "After 1050 global_steps,per token cost is 1.874,learning_rate is 1.00000\n", 539 | "After 1100 global_steps,per token cost is 1.742,learning_rate is 1.00000\n", 540 | "After 1150 global_steps,per token cost is 1.647,learning_rate is 1.00000\n", 541 | "In EPOCH: 2\n", 542 | "After 1200 global_steps,per token cost is 1.876,learning_rate is 0.70000\n", 543 | "After 1250 global_steps,per token cost is 1.677,learning_rate is 0.70000\n", 544 | "After 1300 global_steps,per token cost is 1.445,learning_rate is 0.70000\n", 545 | "After 1350 
global_steps,per token cost is 0.987,learning_rate is 0.70000\n", 546 | "After 1400 global_steps,per token cost is 1.265,learning_rate is 0.70000\n", 547 | "After 1450 global_steps,per token cost is 1.468,learning_rate is 0.70000\n", 548 | "After 1500 global_steps,per token cost is 1.557,learning_rate is 0.70000\n", 549 | "After 1550 global_steps,per token cost is 1.438,learning_rate is 0.70000\n", 550 | "After 1600 global_steps,per token cost is 1.204,learning_rate is 0.70000\n", 551 | "After 1650 global_steps,per token cost is 1.480,learning_rate is 0.70000\n", 552 | "After 1700 global_steps,per token cost is 1.611,learning_rate is 0.70000\n", 553 | "After 1750 global_steps,per token cost is 1.511,learning_rate is 0.70000\n", 554 | "After 1800 global_steps,per token cost is 1.696,learning_rate is 0.70000\n", 555 | "After 1850 global_steps,per token cost is 1.664,learning_rate is 0.70000\n", 556 | "After 1900 global_steps,per token cost is 1.624,learning_rate is 0.70000\n", 557 | "After 1950 global_steps,per token cost is 1.385,learning_rate is 0.70000\n", 558 | "After 2000 global_steps,per token cost is 1.141,learning_rate is 0.70000\n", 559 | "After 2050 global_steps,per token cost is 0.869,learning_rate is 0.70000\n", 560 | "After 2100 global_steps,per token cost is 1.091,learning_rate is 0.70000\n", 561 | "After 2150 global_steps,per token cost is 1.086,learning_rate is 0.70000\n", 562 | "After 2200 global_steps,per token cost is 1.226,learning_rate is 0.70000\n", 563 | "After 2250 global_steps,per token cost is 1.461,learning_rate is 0.70000\n", 564 | "After 2300 global_steps,per token cost is 1.344,learning_rate is 0.70000\n", 565 | "In EPOCH: 3\n", 566 | "After 2350 global_steps,per token cost is 1.346,learning_rate is 0.49000\n", 567 | "After 2400 global_steps,per token cost is 1.525,learning_rate is 0.49000\n", 568 | "After 2450 global_steps,per token cost is 1.384,learning_rate is 0.49000\n", 569 | "After 2500 global_steps,per token cost is 
0.892,learning_rate is 0.49000\n", 570 | "After 2550 global_steps,per token cost is 0.815,learning_rate is 0.49000\n", 571 | "After 2600 global_steps,per token cost is 1.259,learning_rate is 0.49000\n", 572 | "After 2650 global_steps,per token cost is 1.180,learning_rate is 0.49000\n", 573 | "After 2700 global_steps,per token cost is 1.178,learning_rate is 0.49000\n", 574 | "After 2750 global_steps,per token cost is 1.122,learning_rate is 0.49000\n", 575 | "After 2800 global_steps,per token cost is 1.074,learning_rate is 0.49000\n", 576 | "After 2850 global_steps,per token cost is 1.415,learning_rate is 0.49000\n", 577 | "After 2900 global_steps,per token cost is 1.409,learning_rate is 0.49000\n", 578 | "After 2950 global_steps,per token cost is 1.326,learning_rate is 0.49000\n", 579 | "After 3000 global_steps,per token cost is 1.502,learning_rate is 0.49000\n", 580 | "After 3050 global_steps,per token cost is 1.608,learning_rate is 0.49000\n", 581 | "After 3100 global_steps,per token cost is 1.347,learning_rate is 0.49000\n", 582 | "After 3150 global_steps,per token cost is 1.076,learning_rate is 0.49000\n", 583 | "After 3200 global_steps,per token cost is 0.790,learning_rate is 0.49000\n", 584 | "After 3250 global_steps,per token cost is 0.996,learning_rate is 0.49000\n", 585 | "After 3300 global_steps,per token cost is 1.269,learning_rate is 0.49000\n", 586 | "After 3350 global_steps,per token cost is 1.299,learning_rate is 0.49000\n", 587 | "After 3400 global_steps,per token cost is 1.397,learning_rate is 0.49000\n", 588 | "After 3450 global_steps,per token cost is 1.084,learning_rate is 0.49000\n", 589 | "In EPOCH: 4\n", 590 | "After 3500 global_steps,per token cost is 1.322,learning_rate is 0.34300\n", 591 | "After 3550 global_steps,per token cost is 1.174,learning_rate is 0.34300\n", 592 | "After 3600 global_steps,per token cost is 1.244,learning_rate is 0.34300\n", 593 | "After 3650 global_steps,per token cost is 0.905,learning_rate is 0.34300\n", 594 | 
"After 3700 global_steps,per token cost is 0.956,learning_rate is 0.34300\n", 595 | "After 3750 global_steps,per token cost is 1.027,learning_rate is 0.34300\n", 596 | "After 3800 global_steps,per token cost is 1.224,learning_rate is 0.34300\n", 597 | "After 3850 global_steps,per token cost is 1.405,learning_rate is 0.34300\n", 598 | "After 3900 global_steps,per token cost is 0.747,learning_rate is 0.34300\n", 599 | "After 3950 global_steps,per token cost is 0.739,learning_rate is 0.34300\n", 600 | "After 4000 global_steps,per token cost is 1.245,learning_rate is 0.34300\n", 601 | "After 4050 global_steps,per token cost is 1.008,learning_rate is 0.34300\n", 602 | "After 4100 global_steps,per token cost is 1.243,learning_rate is 0.34300\n", 603 | "After 4150 global_steps,per token cost is 1.303,learning_rate is 0.34300\n", 604 | "After 4200 global_steps,per token cost is 1.521,learning_rate is 0.34300\n", 605 | "After 4250 global_steps,per token cost is 1.167,learning_rate is 0.34300\n", 606 | "After 4300 global_steps,per token cost is 1.118,learning_rate is 0.34300\n", 607 | "After 4350 global_steps,per token cost is 0.657,learning_rate is 0.34300\n", 608 | "After 4400 global_steps,per token cost is 0.544,learning_rate is 0.34300\n", 609 | "After 4450 global_steps,per token cost is 0.950,learning_rate is 0.34300\n", 610 | "After 4500 global_steps,per token cost is 1.181,learning_rate is 0.34300\n", 611 | "After 4550 global_steps,per token cost is 1.271,learning_rate is 0.34300\n", 612 | "After 4600 global_steps,per token cost is 0.746,learning_rate is 0.34300\n", 613 | "In EPOCH: 5\n", 614 | "After 4650 global_steps,per token cost is 1.303,learning_rate is 0.24010\n", 615 | "After 4700 global_steps,per token cost is 1.117,learning_rate is 0.24010\n", 616 | "After 4750 global_steps,per token cost is 1.203,learning_rate is 0.24010\n", 617 | "After 4800 global_steps,per token cost is 0.825,learning_rate is 0.24010\n", 618 | "After 4850 global_steps,per token cost is 
0.638,learning_rate is 0.24010\n", 619 | "After 4900 global_steps,per token cost is 0.877,learning_rate is 0.24010\n", 620 | "After 4950 global_steps,per token cost is 1.058,learning_rate is 0.24010\n", 621 | "After 5000 global_steps,per token cost is 0.943,learning_rate is 0.24010\n", 622 | "After 5050 global_steps,per token cost is 0.828,learning_rate is 0.24010\n", 623 | "After 5100 global_steps,per token cost is 0.860,learning_rate is 0.24010\n", 624 | "After 5150 global_steps,per token cost is 0.915,learning_rate is 0.24010\n", 625 | "After 5200 global_steps,per token cost is 1.021,learning_rate is 0.24010\n", 626 | "After 5250 global_steps,per token cost is 1.344,learning_rate is 0.24010\n", 627 | "After 5300 global_steps,per token cost is 1.096,learning_rate is 0.24010\n", 628 | "After 5350 global_steps,per token cost is 1.425,learning_rate is 0.24010\n", 629 | "After 5400 global_steps,per token cost is 0.972,learning_rate is 0.24010\n", 630 | "After 5450 global_steps,per token cost is 1.101,learning_rate is 0.24010\n", 631 | "After 5500 global_steps,per token cost is 0.593,learning_rate is 0.24010\n", 632 | "After 5550 global_steps,per token cost is 0.569,learning_rate is 0.24010\n", 633 | "After 5600 global_steps,per token cost is 0.745,learning_rate is 0.24010\n", 634 | "After 5650 global_steps,per token cost is 0.944,learning_rate is 0.24010\n", 635 | "After 5700 global_steps,per token cost is 1.116,learning_rate is 0.24010\n", 636 | "After 5750 global_steps,per token cost is 0.812,learning_rate is 0.24010\n", 637 | "In EPOCH: 6\n", 638 | "After 5800 global_steps,per token cost is 1.252,learning_rate is 0.16807\n", 639 | "After 5850 global_steps,per token cost is 0.931,learning_rate is 0.16807\n", 640 | "After 5900 global_steps,per token cost is 1.007,learning_rate is 0.16807\n", 641 | "After 5950 global_steps,per token cost is 1.001,learning_rate is 0.16807\n", 642 | "After 6000 global_steps,per token cost is 0.770,learning_rate is 0.16807\n", 643 | 
"After 6050 global_steps,per token cost is 0.859,learning_rate is 0.16807\n", 644 | "After 6100 global_steps,per token cost is 1.054,learning_rate is 0.16807\n", 645 | "After 6150 global_steps,per token cost is 1.127,learning_rate is 0.16807\n", 646 | "After 6200 global_steps,per token cost is 0.683,learning_rate is 0.16807\n", 647 | "After 6250 global_steps,per token cost is 0.933,learning_rate is 0.16807\n", 648 | "After 6300 global_steps,per token cost is 0.823,learning_rate is 0.16807\n", 649 | "After 6350 global_steps,per token cost is 0.927,learning_rate is 0.16807\n", 650 | "After 6400 global_steps,per token cost is 1.110,learning_rate is 0.16807\n", 651 | "After 6450 global_steps,per token cost is 0.911,learning_rate is 0.16807\n", 652 | "After 6500 global_steps,per token cost is 1.243,learning_rate is 0.16807\n", 653 | "After 6550 global_steps,per token cost is 0.856,learning_rate is 0.16807\n", 654 | "After 6600 global_steps,per token cost is 0.755,learning_rate is 0.16807\n", 655 | "After 6650 global_steps,per token cost is 0.660,learning_rate is 0.16807\n", 656 | "After 6700 global_steps,per token cost is 0.566,learning_rate is 0.16807\n", 657 | "After 6750 global_steps,per token cost is 0.749,learning_rate is 0.16807\n", 658 | "After 6800 global_steps,per token cost is 0.912,learning_rate is 0.16807\n", 659 | "After 6850 global_steps,per token cost is 0.848,learning_rate is 0.16807\n", 660 | "After 6900 global_steps,per token cost is 0.909,learning_rate is 0.16807\n", 661 | "In EPOCH: 7\n", 662 | "After 6950 global_steps,per token cost is 1.221,learning_rate is 0.11765\n", 663 | "After 7000 global_steps,per token cost is 0.989,learning_rate is 0.11765\n", 664 | "After 7050 global_steps,per token cost is 1.083,learning_rate is 0.11765\n", 665 | "After 7100 global_steps,per token cost is 0.834,learning_rate is 0.11765\n", 666 | "After 7150 global_steps,per token cost is 0.461,learning_rate is 0.11765\n", 667 | "After 7200 global_steps,per token cost is 
0.692,learning_rate is 0.11765\n", 668 | "After 7250 global_steps,per token cost is 0.884,learning_rate is 0.11765\n", 669 | "After 7300 global_steps,per token cost is 0.838,learning_rate is 0.11765\n", 670 | "After 7350 global_steps,per token cost is 0.584,learning_rate is 0.11765\n", 671 | "After 7400 global_steps,per token cost is 0.714,learning_rate is 0.11765\n", 672 | "After 7450 global_steps,per token cost is 0.888,learning_rate is 0.11765\n", 673 | "After 7500 global_steps,per token cost is 0.859,learning_rate is 0.11765\n", 674 | "After 7550 global_steps,per token cost is 1.101,learning_rate is 0.11765\n", 675 | "After 7600 global_steps,per token cost is 0.969,learning_rate is 0.11765\n", 676 | "After 7650 global_steps,per token cost is 1.108,learning_rate is 0.11765\n", 677 | "After 7700 global_steps,per token cost is 0.907,learning_rate is 0.11765\n", 678 | "After 7750 global_steps,per token cost is 0.851,learning_rate is 0.11765\n", 679 | "After 7800 global_steps,per token cost is 0.567,learning_rate is 0.11765\n", 680 | "After 7850 global_steps,per token cost is 0.395,learning_rate is 0.11765\n", 681 | "After 7900 global_steps,per token cost is 0.666,learning_rate is 0.11765\n", 682 | "After 7950 global_steps,per token cost is 0.874,learning_rate is 0.11765\n", 683 | "After 8000 global_steps,per token cost is 0.945,learning_rate is 0.11765\n", 684 | "After 8050 global_steps,per token cost is 0.636,learning_rate is 0.11765\n", 685 | "After 8100 global_steps,per token cost is 0.598,learning_rate is 0.08235\n", 686 | "In EPOCH: 8\n", 687 | "After 8150 global_steps,per token cost is 1.055,learning_rate is 0.08235\n", 688 | "After 8200 global_steps,per token cost is 0.970,learning_rate is 0.08235\n", 689 | "After 8250 global_steps,per token cost is 0.936,learning_rate is 0.08235\n", 690 | "After 8300 global_steps,per token cost is 0.468,learning_rate is 0.08235\n", 691 | "After 8350 global_steps,per token cost is 0.610,learning_rate is 0.08235\n", 692 | 
"After 8400 global_steps,per token cost is 0.903,learning_rate is 0.08235\n", 693 | "After 8450 global_steps,per token cost is 0.718,learning_rate is 0.08235\n", 694 | "After 8500 global_steps,per token cost is 0.577,learning_rate is 0.08235\n", 695 | "After 8550 global_steps,per token cost is 0.633,learning_rate is 0.08235\n", 696 | "After 8600 global_steps,per token cost is 0.596,learning_rate is 0.08235\n", 697 | "After 8650 global_steps,per token cost is 0.829,learning_rate is 0.08235\n", 698 | "After 8700 global_steps,per token cost is 0.812,learning_rate is 0.08235\n", 699 | "After 8750 global_steps,per token cost is 0.826,learning_rate is 0.08235\n", 700 | "After 8800 global_steps,per token cost is 1.072,learning_rate is 0.08235\n", 701 | "After 8850 global_steps,per token cost is 0.958,learning_rate is 0.08235\n", 702 | "After 8900 global_steps,per token cost is 0.766,learning_rate is 0.08235\n", 703 | "After 8950 global_steps,per token cost is 0.510,learning_rate is 0.08235\n", 704 | "After 9000 global_steps,per token cost is 0.433,learning_rate is 0.08235\n", 705 | "After 9050 global_steps,per token cost is 0.442,learning_rate is 0.08235\n", 706 | "After 9100 global_steps,per token cost is 0.723,learning_rate is 0.08235\n", 707 | "After 9150 global_steps,per token cost is 0.803,learning_rate is 0.08235\n", 708 | "After 9200 global_steps,per token cost is 0.767,learning_rate is 0.08235\n", 709 | "After 9250 global_steps,per token cost is 0.646,learning_rate is 0.08235\n", 710 | "In EPOCH: 9\n", 711 | "After 9300 global_steps,per token cost is 0.939,learning_rate is 0.05765\n", 712 | "After 9350 global_steps,per token cost is 0.802,learning_rate is 0.05765\n", 713 | "After 9400 global_steps,per token cost is 0.766,learning_rate is 0.05765\n", 714 | "After 9450 global_steps,per token cost is 0.540,learning_rate is 0.05765\n", 715 | "After 9500 global_steps,per token cost is 0.510,learning_rate is 0.05765\n", 716 | "After 9550 global_steps,per token cost is 
0.829,learning_rate is 0.05765\n", 717 | "After 9600 global_steps,per token cost is 0.697,learning_rate is 0.05765\n", 718 | "After 9650 global_steps,per token cost is 0.651,learning_rate is 0.05765\n", 719 | "After 9700 global_steps,per token cost is 0.564,learning_rate is 0.05765\n", 720 | "After 9750 global_steps,per token cost is 0.703,learning_rate is 0.05765\n", 721 | "After 9800 global_steps,per token cost is 0.867,learning_rate is 0.05765\n", 722 | "After 9850 global_steps,per token cost is 0.830,learning_rate is 0.05765\n", 723 | "After 9900 global_steps,per token cost is 0.817,learning_rate is 0.05765\n", 724 | "After 9950 global_steps,per token cost is 0.729,learning_rate is 0.05765\n", 725 | "After 10000 global_steps,per token cost is 0.932,learning_rate is 0.05765\n", 726 | "After 10050 global_steps,per token cost is 0.769,learning_rate is 0.05765\n", 727 | "After 10100 global_steps,per token cost is 0.541,learning_rate is 0.05765\n" 728 | ] 729 | } 730 | ], 731 | "source": [ 732 | "def main():\n", 733 | " tf.reset_default_graph()\n", 734 | " initializer = tf.random_uniform_initializer(-0.05,0.05)\n", 735 | " with tf.variable_scope(\"nmt_model\",reuse=None,initializer=initializer):\n", 736 | " train_model = NMTModel()\n", 737 | " \n", 738 | " data = MakeSrcTrgDataset(enc_vec_data,dec_vec_data,BATCH_SIZE)\n", 739 | " iterator = data.make_initializable_iterator()\n", 740 | " (src,src_size),(trg_input,trg_label,trg_size) = iterator.get_next()\n", 741 | " \n", 742 | " cost_op,train_op,learning_rate_op,global_step_op = train_model.forward(src,src_size,trg_input,trg_label,trg_size,data_length)\n", 743 | " \n", 744 | " saver = tf.train.Saver()\n", 745 | " step = 1\n", 746 | " \n", 747 | " with tf.Session() as sess:\n", 748 | " tf.global_variables_initializer().run()\n", 749 | " ckpt = tf.train.get_checkpoint_state(CHECKPOINT_PATH) #获取checkpoints对象 \n", 750 | " if ckpt and ckpt.model_checkpoint_path:##判断ckpt是否为空,若不为空,才进行模型的加载,否则从头开始训练 \n", 751 | " 
saver.restore(sess,ckpt.model_checkpoint_path)#恢复保存的神经网络结构,实现断点续训 \n", 752 | " for i in range(NUM_EPOCH):\n", 753 | " print(\"In EPOCH: %d\" %(i + 1))\n", 754 | " sess.run(iterator.initializer)\n", 755 | " step = run_epoch(sess,cost_op,train_op,learning_rate_op,global_step_op,step,saver,i + 1)\n", 756 | "if __name__ == \"__main__\":\n", 757 | " main()" 758 | ] 759 | }, 760 | { 761 | "cell_type": "code", 762 | "execution_count": null, 763 | "metadata": {}, 764 | "outputs": [], 765 | "source": [] 766 | } 767 | ], 768 | "metadata": { 769 | "kernelspec": { 770 | "display_name": "myconda", 771 | "language": "python", 772 | "name": "myconda" 773 | }, 774 | "language_info": { 775 | "codemirror_mode": { 776 | "name": "ipython", 777 | "version": 3 778 | }, 779 | "file_extension": ".py", 780 | "mimetype": "text/x-python", 781 | "name": "python", 782 | "nbconvert_exporter": "python", 783 | "pygments_lexer": "ipython3", 784 | "version": "3.7.7" 785 | } 786 | }, 787 | "nbformat": 4, 788 | "nbformat_minor": 4 789 | } 790 | --------------------------------------------------------------------------------