├── LICENSE ├── README.md ├── corpus ├── __init__.py ├── cornelldata.py ├── data │ ├── cornell │ │ ├── README.txt │ │ ├── movie_conversations.txt │ │ └── movie_lines.txt │ ├── lightweight │ │ └── README.md │ ├── opensubs │ │ └── README.md │ ├── samples │ │ ├── dataset-cornell-length40-filter1-vocabSize40000.pkl │ │ └── dataset-cornell.pkl │ ├── scotus │ │ └── README.md │ └── ubuntu │ │ └── README.md ├── lightweightdata.py ├── opensubsdata.py ├── scotusdata.py ├── textdata.py └── ubuntudata.py ├── disc ├── __init__.py ├── hier_disc.py └── hier_rnn_model.py ├── gen ├── __init__.py ├── gen_model.py ├── generator.py └── seq2seq.py ├── main.py └── utils ├── __init__.py └── conf.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AL/RL-based Chatbot
An implementation of a chatbot using adversarial learning (AL) and reinforcement learning (RL)
## References
- **Paper**:
[Adversarial Learning for Neural Dialogue Generation](https://arxiv.org/abs/1701.06547 "Adversarial Learning for Neural Dialogue Generation")
- **Initial implementation**:
Many thanks to [@liuyuemaicha](https://github.com/liuyuemaicha/Adversarial-Learning-for-Neural-Dialogue-Generation-in-Tensorflow "@liuyuemaicha")

## Dependencies
- Python 3.6
- TensorFlow 1.12.0
- nltk
- tqdm
- numpy

Note: You might also need to download additional data to make nltk work.
```
python -m nltk.downloader punkt
```

## Corpus
- [Cornell Movie Dialogs Corpus](http://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html "Cornell Movie Dialogs Corpus") (default). Already included when cloning the repository.
- [OpenSubtitles](http://opus.nlpl.eu/OpenSubtitles.php "OpenSubtitles"). Follow the [instructions](corpus/data/opensubs "instructions") and use the argument `--corpus opensubs`.
- Supreme Court Conversation Data. Follow the [instructions](corpus/data/scotus "instructions") and use the argument `--corpus scotus`.
- [Ubuntu Dialogue Corpus](https://arxiv.org/abs/1506.08909 "Ubuntu Dialogue Corpus"). Follow the [instructions](corpus/data/ubuntu "instructions") and use the argument `--corpus ubuntu`.

Many thanks to @Conchylicultor for providing the corpus information above.

## Training
`python main.py`

**Follow the steps below:**
1. Pre-train the Generator to obtain the GEN_0 model.
2. Test the GEN model.
3. Pre-train the Discriminator to obtain the DISC_0 model.
4. Train the GEN and DISC models jointly using AL/RL.
5. Test the GEN model.

## Test
`python main.py --test 1`

Results:

    Q: Hi.
    A: Hey.
    Q: How are you..
    A: Fine.
    Q: What is your name
    A: Harold ...
    Q: You are so nice
    A: I am?
    Q: Wow
    A: What do you mean?
    Q: That's great
    A: That's right.
    Q: Cool
    A: Yeah?
    Q: My name is Amigo
    A: I know.
    Q: What is the first letter of the alphabet ?
    A: What?
    Q: See you later
    A: Yeah.
    Q: Tell me something ...
    A: I'm not saying ...
    Q: You are so cute !!!!!!!!!!
    A: Wonderful!
    Q: Can we make this quick?
    A: Sure.
    Q: Where are you from
    A: Here.
    Q: Merry christmas !
    A: Merry christmas.
    Q: How old are you
    A: Twenty-eight.
--------------------------------------------------------------------------------
/corpus/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/corpus/cornelldata.py:
--------------------------------------------------------------------------------
# Copyright 2015 Conchylicultor. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
import ast

"""
Load the Cornell Movie Dialogs Corpus.

Available from here:
http://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html

"""


class CornellData:
    """Loader for the Cornell Movie Dialogs Corpus.
    """

    def __init__(self, dirName):
        """
        Args:
            dirName (string): directory from which to load the corpus
        """
        self.lines = {}
        self.conversations = []

        MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
        MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

        self.lines = self.loadLines(os.path.join(dirName, "movie_lines.txt"), MOVIE_LINES_FIELDS)
        self.conversations = self.loadConversations(os.path.join(dirName, "movie_conversations.txt"), MOVIE_CONVERSATIONS_FIELDS)

        # TODO: Cleaner program (merge copy-paste) !!

    def loadLines(self, fileName, fields):
        """
        Args:
            fileName (str): file to load
            fields (list<str>): fields to extract
        Return:
            dict<dict<str>>: the extracted fields for each line
        """
        lines = {}

        with open(fileName, 'r', encoding='iso-8859-1') as f:  # TODO: Solve Iso encoding pb !
            for line in f:
                values = line.split(" +++$+++ ")

                # Extract fields
                lineObj = {}
                for i, field in enumerate(fields):
                    lineObj[field] = values[i]

                lines[lineObj['lineID']] = lineObj

        return lines

    def loadConversations(self, fileName, fields):
        """
        Args:
            fileName (str): file to load
            fields (list<str>): fields to extract
        Return:
            list<dict>: the extracted conversations
        """
        conversations = []

        with open(fileName, 'r', encoding='iso-8859-1') as f:  # TODO: Solve Iso encoding pb !
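            # Hypothetical example of the ' +++$+++ '-separated format of
            # movie_conversations.txt (fields: character1ID, character2ID, movieID, utteranceIDs):
            #   u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']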
            for line in f:
                values = line.split(" +++$+++ ")

                # Extract fields
                convObj = {}
                for i, field in enumerate(fields):
                    convObj[field] = values[i]

                # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
                lineIds = ast.literal_eval(convObj["utteranceIDs"])

                # Reassemble lines
                convObj["lines"] = []
                for lineId in lineIds:
                    convObj["lines"].append(self.lines[lineId])

                conversations.append(convObj)

        return conversations

    def getConversations(self):
        return self.conversations
--------------------------------------------------------------------------------
/corpus/data/cornell/README.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geekberu/chatbot_al/b732a587fca68867d63c020f10997d4f1d7087d8/corpus/data/cornell/README.txt
--------------------------------------------------------------------------------
/corpus/data/cornell/movie_lines.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geekberu/chatbot_al/b732a587fca68867d63c020f10997d4f1d7087d8/corpus/data/cornell/movie_lines.txt
--------------------------------------------------------------------------------
/corpus/data/lightweight/README.md:
--------------------------------------------------------------------------------
You can create your own dataset using a simple custom format where one line corresponds to one line of dialogue. Use `===` to separate conversations between 2 people. Example of conversation file:


```
from A to B
from B to A
from A to B
from B to A
from A to B
===
from C to D
from D to C
from C to D
===
from E to F
from F to E
from E to F
from F to E
```

To use your conversation file `<name>.txt`, copy it into this repository and launch the program with the option `--corpus lightweight --datasetTag <name>`.

--------------------------------------------------------------------------------
/corpus/data/opensubs/README.md:
--------------------------------------------------------------------------------
In order to use the OpenSubtitles dataset, you must first download and unpack the archive in this folder. The program will automatically look at every subfolder here. Train with this dataset using `./main.py --corpus opensubs`.

Download the English corpus directly here:
http://opus.lingfil.uu.se/download.php?f=OpenSubtitles/en.tar.gz

All details on the corpus here:
http://opus.lingfil.uu.se/OpenSubtitles.php

Note that although this has not been tested, the program should be compatible with other languages as well. Just download the subtitles for the language you want from the OpenSubtitles database website.
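
If `wget` is not available, the download and extraction can also be scripted in Python. A minimal sketch using only the standard library (same URL as above; the archive is large):

```python
import urllib.request
import tarfile

URL = "http://opus.lingfil.uu.se/download.php?f=OpenSubtitles/en.tar.gz"

urllib.request.urlretrieve(URL, "en.tar.gz")  # Download the archive into this folder
with tarfile.open("en.tar.gz", "r:gz") as tar:
    tar.extractall(".")  # Unpack; the loader scans every subfolder recursively
```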
--------------------------------------------------------------------------------
/corpus/data/samples/dataset-cornell-length40-filter1-vocabSize40000.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geekberu/chatbot_al/b732a587fca68867d63c020f10997d4f1d7087d8/corpus/data/samples/dataset-cornell-length40-filter1-vocabSize40000.pkl
--------------------------------------------------------------------------------
/corpus/data/samples/dataset-cornell.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geekberu/chatbot_al/b732a587fca68867d63c020f10997d4f1d7087d8/corpus/data/samples/dataset-cornell.pkl
--------------------------------------------------------------------------------
/corpus/data/scotus/README.md:
--------------------------------------------------------------------------------
Download and extract the Scotus dataset here:

```bash
# From this directory:
wget https://github.com/pender/chatbot-rnn/raw/master/data/scotus/scotus.bz2 && bzip2 -dk scotus.bz2 && rm scotus.bz2
```

--------------------------------------------------------------------------------
/corpus/data/ubuntu/README.md:
--------------------------------------------------------------------------------
Download and extract the Ubuntu Dialogue Corpus dataset here:

Source: http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/

```bash
# From this directory:
wget http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz && tar -xvzf ubuntu_dialogs.tgz && rm ubuntu_dialogs.tgz
```

Individual conversation files will be located in a `dialogs/` subdirectory.

--------------------------------------------------------------------------------
/corpus/lightweightdata.py:
--------------------------------------------------------------------------------
# Copyright 2015 Conchylicultor. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os

"""
Load data from a dataset of simply-formatted data

from A to B
from B to A
from A to B
from B to A
from A to B
===
from C to D
from D to C
from C to D
from D to C
from C to D
from D to C
...

`===` lines just separate linear conversations between 2 people.
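
A minimal usage sketch (assuming a hypothetical file
corpus/data/lightweight/mychats.txt in the format above):

    data = LightweightData('corpus/data/lightweight/mychats')  # '.txt' is appended by the loader
    conversations = data.getConversations()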

"""


class LightweightData:
    """Loader for the lightweight, '==='-separated conversation format.
    """

    def __init__(self, lightweightFile):
        """
        Args:
            lightweightFile (string): file containing our lightweight-formatted corpus
        """
        self.CONVERSATION_SEP = "==="
        self.conversations = []
        self.loadLines(lightweightFile + '.txt')

    def loadLines(self, fileName):
        """
        Args:
            fileName (str): file to load
        """

        linesBuffer = []
        with open(fileName, 'r') as f:
            for line in f:
                l = line.strip()
                if l == self.CONVERSATION_SEP:
                    self.conversations.append({"lines": linesBuffer})
                    linesBuffer = []
                else:
                    linesBuffer.append({"text": l})
            if len(linesBuffer):  # Flush the last conversation, if any
                self.conversations.append({"lines": linesBuffer})

    def getConversations(self):
        return self.conversations
--------------------------------------------------------------------------------
/corpus/opensubsdata.py:
--------------------------------------------------------------------------------
# Based on code from https://github.com/AlJohri/OpenSubtitles
# by Al Johri

import xml.etree.ElementTree as ET
import datetime
import os
import sys
import json
import re
import pprint

from gzip import GzipFile
from tqdm import tqdm

"""
Load the OpenSubtitles dialog corpus.
"""


class OpensubsData:
    """Loader for the OpenSubtitles dialog corpus.
    """

    def __init__(self, dirName):
        """
        Args:
            dirName (string): directory from which to load the corpus
        """

        # Hack this to filter on subset of Opensubtitles
        # dirName = "%s/en/Action" % dirName

        print("Loading OpenSubtitles conversations in %s." % dirName)
        self.conversations = []
        self.tag_re = re.compile(r'(<!--.*?-->|<[^>]*>)')
        self.conversations = self.loadConversations(dirName)

    def loadConversations(self, dirName):
        """
        Args:
            dirName (str): folder to load
        Return:
            array(question, answer): the extracted QA pairs
        """
        conversations = []
        dirList = self.filesInDir(dirName)
        for filepath in tqdm(dirList, "OpenSubtitles data files"):
            if filepath.endswith('gz'):
                try:
                    doc = self.getXML(filepath)
                    conversations.extend(self.genList(doc))
                except ValueError:
                    tqdm.write("Skipping file %s with errors." % filepath)
                except:
                    print("Unexpected error:", sys.exc_info()[0])
                    raise
        return conversations

    def getConversations(self):
        return self.conversations

    def genList(self, tree):
        root = tree.getroot()

        timeFormat = '%H:%M:%S'
        maxDelta = datetime.timedelta(seconds=1)

        startTime = datetime.datetime.min
        strbuf = ''
        sentList = []

        for child in root:
            for elem in child:
                if elem.tag == 'time':
                    elemID = elem.attrib['id']
                    elemVal = elem.attrib['value'][:-4]
                    if elemID[-1] == 'S':
                        startTime = datetime.datetime.strptime(elemVal, timeFormat)
                    else:
                        sentList.append((strbuf.strip(), startTime, datetime.datetime.strptime(elemVal, timeFormat)))
                        strbuf = ''
                else:
                    try:
                        strbuf = strbuf + " " + elem.text
                    except:
                        pass

        conversations = []
        for idx in range(0, len(sentList) - 1):
            cur = sentList[idx]
            nxt = sentList[idx + 1]
            if nxt[1] - cur[2] <= maxDelta and cur and nxt:
                tmp = {}
                tmp["lines"] = []
                tmp["lines"].append(self.getLine(cur[0]))
                tmp["lines"].append(self.getLine(nxt[0]))
                if self.filter(tmp):
                    conversations.append(tmp)

        return conversations

    def getLine(self, sentence):
        line = {}
        line["text"] = self.tag_re.sub('', sentence).replace('\\\'', '\'').strip().lower()
        return line

    def filter(self, lines):
        # Use the following to customize filtering of QA pairs
        #
        # startwords = ("what", "how", "when", "why", "where", "do", "did", "is", "are", "can", "could", "would", "will")
        # question = lines["lines"][0]["text"]
        # if not question.endswith('?'):
        #     return False
        # if not question.split(' ')[0] in startwords:
        #     return False
        #
        return True

    def getXML(self, filepath):
        fext = os.path.splitext(filepath)[1]
        if fext == '.gz':
            tmp = GzipFile(filename=filepath)
            return ET.parse(tmp)
        else:
            return ET.parse(filepath)

    def filesInDir(self, dirname):
        result = []
        for dirpath, dirs, files in os.walk(dirname):
            for filename in files:
                fname = os.path.join(dirpath, filename)
                result.append(fname)
        return result
--------------------------------------------------------------------------------
/corpus/scotusdata.py:
--------------------------------------------------------------------------------
# Copyright 2015 Conchylicultor. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os

"""
Load transcripts from the Supreme Court of the USA.

Available from here:
https://github.com/pender/chatbot-rnn

"""


class ScotusData:
    """Loader for the Supreme Court (scotus) transcript corpus.
    """

    def __init__(self, dirName):
        """
        Args:
            dirName (string): directory from which to load the corpus
        """
        self.lines = self.loadLines(os.path.join(dirName, "scotus"))
        self.conversations = [{"lines": self.lines}]


    def loadLines(self, fileName):
        """
        Args:
            fileName (str): file to load
        Return:
            list<dict<str>>: the extracted fields for each line
        """
        lines = []

        with open(fileName, 'r') as f:
            for line in f:
                l = line[line.index(":")+1:].strip()  # Strip name of speaker.

                lines.append({"text": l})

        return lines


    def getConversations(self):
        return self.conversations
--------------------------------------------------------------------------------
/corpus/textdata.py:
--------------------------------------------------------------------------------
# Copyright 2015 Conchylicultor. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""
Loads the dialogue corpus, builds the vocabulary
"""

import numpy as np
import nltk  # For tokenize
from tqdm import tqdm  # Progress bar
import pickle  # Saving the data
import math  # For float comparison
import os  # Checking file existence
import random
import string
import collections

from corpus.cornelldata import CornellData
from corpus.opensubsdata import OpensubsData
from corpus.scotusdata import ScotusData
from corpus.ubuntudata import UbuntuData
from corpus.lightweightdata import LightweightData


class Batch:
    """Struct containing batches info
    """
    def __init__(self):
        self.encoderSeqs = []
        self.decoderSeqs = []
        self.targetSeqs = []
        self.weights = []


class TextData:
    """Dataset class
    Warning: No vocabulary limit
    """

    availableCorpus = collections.OrderedDict([  # OrderedDict because the first element is the default choice
        ('cornell', CornellData),
        ('opensubs', OpensubsData),
        ('scotus', ScotusData),
        ('ubuntu', UbuntuData),
        ('lightweight', LightweightData),
    ])

    @staticmethod
    def corpusChoices():
        """Return the available datasets
        Return:
            list<string>: the supported corpus
        """
        return list(TextData.availableCorpus.keys())

    def __init__(self, args):
        """Load all conversations
        Args:
            args: parameters of the model
        """
        # Model parameters
        self.args = args

        # Path variables
        self.corpusDir = os.path.join(self.args.rootDir, 'data', self.args.corpus)
        basePath = self._constructBasePath()
        self.fullSamplesPath = basePath + '.pkl'  # Full sentences length/vocab
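        # Two-level cache: the full corpus is pickled once, and a second, filtered
        # pickle is derived per (maxLength, filterVocab, vocabularySize) setting,
        # e.g. the shipped samples dataset-cornell.pkl vs
        # dataset-cornell-length40-filter1-vocabSize40000.pkl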
        self.filteredSamplesPath = basePath + '-length{}-filter{}-vocabSize{}.pkl'.format(
            self.args.maxLength,
            self.args.filterVocab,
            self.args.vocabularySize,
        )  # Sentences/vocab filtered for this model

        self.padToken = -1  # Padding
        self.goToken = -1  # Start of sequence
        self.eosToken = -1  # End of sequence
        self.unknownToken = -1  # Word dropped from vocabulary

        self.trainingSamples = []  # 2d array containing each question and its answer [[input, target]]

        self.word2id = {}
        self.id2word = {}  # For a rapid conversion (Warning: If replace dict by list, modify the filtering to avoid linear complexity with del)
        self.idCount = {}  # Useful to filter the words (TODO: Could replace dict by list or use collections.Counter)

        self.loadCorpus()

        # Plot some stats:
        self._printStats()

        # if self.args.playDataset:
        #     self.playDataset()

    def _printStats(self):
        print('Loaded {}: {} words, {} QA'.format(self.args.corpus, len(self.word2id), len(self.trainingSamples)))

    def _constructBasePath(self):
        """Return the name of the base prefix of the current dataset
        """
        path = os.path.join(self.args.rootDir, 'data' + os.sep + 'samples' + os.sep)
        path += 'dataset-{}'.format(self.args.corpus)
        if self.args.datasetTag:
            path += '-' + self.args.datasetTag
        return path

    def makeLighter(self, ratioDataset):
        """Only keep a small fraction of the dataset, given by the ratio
        """
        # if not math.isclose(ratioDataset, 1.0):
        #     self.shuffle()  # Really ?
        #     print('WARNING: Ratio feature not implemented !!!')
        pass

    def shuffle(self):
        """Shuffle the training samples
        """
        print('Shuffling the dataset...')
        random.shuffle(self.trainingSamples)

    def add_pad(self, source_seq, target_len):
        target_seq = source_seq + [self.padToken] * (target_len - len(source_seq))
        return target_seq

    # def get_batch(self, samples, Q_size, A_size):
    def get_batch(self, samples, encoder_size, decoder_size):
        batch = Batch()
        batchSize = len(samples)
        # encoder_size = Q_size
        # decoder_size = A_size + 2
        # Create the batch tensor
        for i in range(batchSize):
            # Unpack the sample
            sample = samples[i]
            # if not self.args.test and self.args.watsonMode:  # Watson mode: invert question and answer
            #     sample = list(reversed(sample))
            # if not self.args.test and self.args.autoEncode:  # Autoencode: use either the question or answer for both input and output
            #     k = random.randint(0, 1)
            #     sample = (sample[k], sample[k])
            # TODO: Why re-processed that at each epoch ? Could precompute that
            # once and reuse those every time. Is not the bottleneck so won't change
            # much ? and if preprocessing, should be compatible with autoEncode & cie.
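            # Illustration with hypothetical ids, encoder_size=5, decoder_size=6
            # and sample = [[4, 7], [8]]:
            #   encoderSeqs[i]: [pad, pad, pad, 7, 4]        (input reversed, then left-padded)
            #   decoderSeqs[i]: [go, 8, eos, pad, pad, pad]  (right-padded)
            #   targetSeqs[i]:  [8, eos, pad, pad, pad, pad] (decoder shifted left)
            #   weights[i]:     [1.0, 1.0, 0.0, 0.0, 0.0, 0.0]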
            batch.encoderSeqs.append(list(reversed(sample[0])))  # Reverse inputs (and not outputs), little trick as defined on the original seq2seq paper
            batch.decoderSeqs.append([self.goToken] + sample[1] + [self.eosToken])  # Add the <go> and <eos> tokens
            batch.targetSeqs.append(batch.decoderSeqs[-1][1:])  # Same as decoder, but shifted to the left (ignore the <go>)

            # Long sentences should have been filtered during the dataset creation
            assert len(batch.encoderSeqs[i]) <= encoder_size
            assert len(batch.decoderSeqs[i]) <= decoder_size

            # TODO: Should use tf batch function to automatically add padding and batch samples
            # Add padding & define weight
            batch.encoderSeqs[i] = [self.padToken] * (encoder_size - len(batch.encoderSeqs[i])) + batch.encoderSeqs[i]  # Left padding for the input
            batch.weights.append([1.0] * len(batch.targetSeqs[i]) + [0.0] * (decoder_size - len(batch.targetSeqs[i])))
            batch.decoderSeqs[i] = batch.decoderSeqs[i] + [self.padToken] * (decoder_size - len(batch.decoderSeqs[i]))
            batch.targetSeqs[i] = batch.targetSeqs[i] + [self.padToken] * (decoder_size - len(batch.targetSeqs[i]))

        # Simple hack to reshape the batch
        encoderSeqsT = []  # Corrected orientation
        for i in range(encoder_size):
            encoderSeqT = []
            for j in range(batchSize):
                encoderSeqT.append(batch.encoderSeqs[j][i])
            encoderSeqsT.append(encoderSeqT)
        batch.encoderSeqs = encoderSeqsT

        decoderSeqsT = []
        targetSeqsT = []
        weightsT = []
        for i in range(decoder_size):
            decoderSeqT = []
            targetSeqT = []
            weightT = []
            for j in range(batchSize):
                decoderSeqT.append(batch.decoderSeqs[j][i])
                targetSeqT.append(batch.targetSeqs[j][i])
                weightT.append(batch.weights[j][i])
            decoderSeqsT.append(decoderSeqT)
            targetSeqsT.append(targetSeqT)
            weightsT.append(weightT)
        batch.decoderSeqs = decoderSeqsT
        batch.targetSeqs = targetSeqsT
        batch.weights = weightsT

        # # Debug
        # self.printBatch(batch)  # Input inverted, padding should be correct
        # print(self.sequence2str(samples[0][0]))
        # print(self.sequence2str(samples[0][1]))  # Check we did not modified the original sample

        return batch

    def _createBatch(self, samples):
        """Create a single batch from the list of sample. The batch size is automatically defined by the number of
        samples given.
        The inputs should already be inverted. The target should already have <go> and <eos>
        Warning: This function should not make direct calls to args.batchSize !!!
        Args:
            samples (list<Obj>): a list of samples, each sample being on the form [input, target]
        Return:
            Batch: a batch object
        """

        batch = Batch()
        batchSize = len(samples)

        # Create the batch tensor
        for i in range(batchSize):
            # Unpack the sample
            sample = samples[i]
            if not self.args.test and self.args.watsonMode:  # Watson mode: invert question and answer
                sample = list(reversed(sample))
            if not self.args.test and self.args.autoEncode:  # Autoencode: use either the question or answer for both input and output
                k = random.randint(0, 1)
                sample = (sample[k], sample[k])
            # TODO: Why re-processed that at each epoch ? Could precompute that
            # once and reuse those every time. Is not the bottleneck so won't change
            # much ? and if preprocessing, should be compatible with autoEncode & cie.
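            # NOTE: the remainder of this loop mirrors get_batch() above, driven by
            # self.args.maxLengthEnco/maxLengthDeco instead of explicit sizes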
            batch.encoderSeqs.append(list(reversed(sample[0])))  # Reverse inputs (and not outputs), little trick as defined on the original seq2seq paper
            batch.decoderSeqs.append([self.goToken] + sample[1] + [self.eosToken])  # Add the <go> and <eos> tokens
            batch.targetSeqs.append(batch.decoderSeqs[-1][1:])  # Same as decoder, but shifted to the left (ignore the <go>)

            # Long sentences should have been filtered during the dataset creation
            assert len(batch.encoderSeqs[i]) <= self.args.maxLengthEnco
            assert len(batch.decoderSeqs[i]) <= self.args.maxLengthDeco

            # TODO: Should use tf batch function to automatically add padding and batch samples
            # Add padding & define weight
            batch.encoderSeqs[i] = [self.padToken] * (self.args.maxLengthEnco - len(batch.encoderSeqs[i])) + batch.encoderSeqs[i]  # Left padding for the input
            batch.weights.append([1.0] * len(batch.targetSeqs[i]) + [0.0] * (self.args.maxLengthDeco - len(batch.targetSeqs[i])))
            batch.decoderSeqs[i] = batch.decoderSeqs[i] + [self.padToken] * (self.args.maxLengthDeco - len(batch.decoderSeqs[i]))
            batch.targetSeqs[i] = batch.targetSeqs[i] + [self.padToken] * (self.args.maxLengthDeco - len(batch.targetSeqs[i]))

        # Simple hack to reshape the batch
        encoderSeqsT = []  # Corrected orientation
        for i in range(self.args.maxLengthEnco):
            encoderSeqT = []
            for j in range(batchSize):
                encoderSeqT.append(batch.encoderSeqs[j][i])
            encoderSeqsT.append(encoderSeqT)
        batch.encoderSeqs = encoderSeqsT

        decoderSeqsT = []
        targetSeqsT = []
        weightsT = []
        for i in range(self.args.maxLengthDeco):
            decoderSeqT = []
            targetSeqT = []
            weightT = []
            for j in range(batchSize):
                decoderSeqT.append(batch.decoderSeqs[j][i])
                targetSeqT.append(batch.targetSeqs[j][i])
                weightT.append(batch.weights[j][i])
            decoderSeqsT.append(decoderSeqT)
            targetSeqsT.append(targetSeqT)
            weightsT.append(weightT)
        batch.decoderSeqs = decoderSeqsT
        batch.targetSeqs = targetSeqsT
        batch.weights = weightsT

        # # Debug
        # self.printBatch(batch)  # Input inverted, padding should be correct
        # print(self.sequence2str(samples[0][0]))
        # print(self.sequence2str(samples[0][1]))  # Check we did not modified the original sample

        return batch

    def getBatches(self):
        """Prepare the batches for the current epoch
        Return:
            list<Batch>: Get a list of the batches for the next epoch
        """
        self.shuffle()

        batches = []

        def genNextSamples():
            """ Generator over the mini-batch training samples
            """
            for i in range(0, self.getSampleSize(), self.args.batchSize):
                yield self.trainingSamples[i:min(i + self.args.batchSize, self.getSampleSize())]

        # TODO: Should replace that by generator (better: by tf.queue)

        for samples in genNextSamples():
            batch = self._createBatch(samples)
            batches.append(batch)
        return batches

    def getSampleSize(self):
        """Return the size of the dataset
        Return:
            int: Number of training samples
        """
        return len(self.trainingSamples)

    def getVocabularySize(self):
        """Return the number of words present in the dataset
        Return:
            int: Number of words in the loaded corpus
        """
        return len(self.word2id)

    def loadCorpus(self):
        """Load/create the conversations data
        """
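        # Resolution order: reuse the filtered pickle if present; otherwise derive it
        # from the full-corpus pickle; otherwise parse the raw corpus files first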
        datasetExist = os.path.isfile(self.filteredSamplesPath)
        if not datasetExist:  # First time we load the database: creating all files
            print('Training samples not found. Creating dataset...')

            datasetExist = os.path.isfile(self.fullSamplesPath)  # Try to construct the dataset from the preprocessed entry
            if not datasetExist:
                print('Constructing full dataset...')

                optional = ''
                if self.args.corpus == 'lightweight':
                    if not self.args.datasetTag:
                        raise ValueError('Use the --datasetTag to define the lightweight file to use.')
                    optional = os.sep + self.args.datasetTag  # HACK: Forward the filename

                # Corpus creation
                corpusData = TextData.availableCorpus[self.args.corpus](self.corpusDir + optional)
                self.createFullCorpus(corpusData.getConversations())
                self.saveDataset(self.fullSamplesPath)
            else:
                self.loadDataset(self.fullSamplesPath)
            self._printStats()

            print('Filtering words (vocabSize = {} and wordCount > {})...'.format(
                self.args.vocabularySize,
                self.args.filterVocab
            ))
            self.filterFromFull()  # Extract the sub vocabulary for the given maxLength and filterVocab

            # Saving
            print('Saving dataset...')
            self.saveDataset(self.filteredSamplesPath)  # Saving tf samples
        else:
            self.loadDataset(self.filteredSamplesPath)

        assert self.padToken == 0

    def saveDataset(self, filename):
        """Save samples to file
        Args:
            filename (str): pickle filename
        """

        with open(os.path.join(filename), 'wb') as handle:
            data = {  # Warning: If adding something here, also modifying loadDataset
                'word2id': self.word2id,
                'id2word': self.id2word,
                'idCount': self.idCount,
                'trainingSamples': self.trainingSamples
            }
            pickle.dump(data, handle, -1)  # Using the highest protocol available

    def loadDataset(self, filename):
        """Load samples from file
        Args:
            filename (str): pickle filename
        """
        dataset_path = os.path.join(filename)
        print('Loading dataset from {}'.format(dataset_path))
        with open(dataset_path, 'rb') as handle:
            data = pickle.load(handle)  # Warning: If adding something here, also modifying saveDataset
            self.word2id = data['word2id']
            self.id2word = data['id2word']
            self.idCount = data.get('idCount', None)
            self.trainingSamples = data['trainingSamples']

            self.padToken = self.word2id['<pad>']
            self.goToken = self.word2id['<go>']
            self.eosToken = self.word2id['<eos>']
            self.unknownToken = self.word2id['<unknown>']  # Restore special words

    def filterFromFull(self):
        """ Load the pre-processed full corpus and filter the vocabulary / sentences
        to match the given model options
        """

        def mergeSentences(sentences, fromEnd=False):
            """Merge the sentences until the max sentence length is reached
            Also decrement id count for unused sentences.
            Args:
                sentences (list<list<int>>): the list of sentences for the current line
                fromEnd (bool): if True, merge from the end (question side) instead of the start (answer side)
            Return:
                list<int>: the list of the word ids of the sentence
            """
            # We add sentence by sentence until we reach the maximum length
            merged = []

            # If question: we only keep the last sentences
            # If answer: we only keep the first sentences
            if fromEnd:
                sentences = reversed(sentences)

            for sentence in sentences:

                # If the total length is not too big, we still can add one more sentence
                if len(merged) + len(sentence) <= self.args.maxLength:
                    if fromEnd:  # Append the sentence
                        merged = sentence + merged
                    else:
                        merged = merged + sentence
                else:  # If the sentence is not used, neither are the words
                    for w in sentence:
                        self.idCount[w] -= 1
            return merged

        newSamples = []

        # 1st step: Iterate over all samples and filter the sentences
        # according to the sentence lengths
        for inputWords, targetWords in tqdm(self.trainingSamples, desc='Filter sentences:', leave=False):
            inputWords = mergeSentences(inputWords, fromEnd=True)
            targetWords = mergeSentences(targetWords, fromEnd=False)

            newSamples.append([inputWords, targetWords])
        words = []

        # WARNING: DO NOT FILTER THE UNKNOWN TOKEN !!! Only word which has count==0 ?

        # 2nd step: filter the unused words and replace them by the unknown token
        # This is also where we update the correspondence dictionaries
        specialTokens = {  # TODO: bad HACK to filter the special tokens. Error prone if one day add new special tokens
            self.padToken,
            self.goToken,
            self.eosToken,
            self.unknownToken
        }
        newMapping = {}  # Map the full words ids to the new one (TODO: Should be a list)
        newId = 0

        selectedWordIds = collections \
            .Counter(self.idCount) \
            .most_common(self.args.vocabularySize or None)  # Keep all if vocabularySize == 0
        selectedWordIds = {k for k, v in selectedWordIds if v > self.args.filterVocab}
        selectedWordIds |= specialTokens

        for wordId, count in [(i, self.idCount[i]) for i in range(len(self.idCount))]:  # Iterate in order
            if wordId in selectedWordIds:  # Update the word id
                newMapping[wordId] = newId
                word = self.id2word[wordId]  # The new id has changed, update the dictionaries
                del self.id2word[wordId]  # Will be recreated if newId == wordId
                self.word2id[word] = newId
                self.id2word[newId] = word
                newId += 1
            else:  # Candidate for filtering, map it to unknownToken (Warning: don't filter special tokens)
                newMapping[wordId] = self.unknownToken
                del self.word2id[self.id2word[wordId]]  # The word isn't used anymore
                del self.id2word[wordId]

        # Last step: replace old ids by new ones and filter empty sentences
        def replace_words(words):
            valid = False  # Filter empty sequences
            for i, w in enumerate(words):
                words[i] = newMapping[w]
                if words[i] != self.unknownToken:  # Also filter if only contains unknown tokens
                    valid = True
            return valid

        self.trainingSamples.clear()

        for inputWords, targetWords in tqdm(newSamples, desc='Replace ids:', leave=False):
            valid = True
            valid &= replace_words(inputWords)
            valid &= replace_words(targetWords)
            valid &= targetWords.count(self.unknownToken) == 0  # Filter target with out-of-vocabulary target words ?

            if valid:
                self.trainingSamples.append([inputWords, targetWords])  # TODO: Could replace list by tuple

        self.idCount.clear()  # Not useful anymore. Free data

    def createFullCorpus(self, conversations):
        """Extract all data from the given vocabulary.
        Save the data on disk. Note that the entire corpus is pre-processed
        without restriction on the sentence length or vocab size.
        """
        # Add standard tokens
        self.padToken = self.getWordId('<pad>')  # Padding (Warning: first things to add > id=0 !!)
        self.goToken = self.getWordId('<go>')  # Start of sequence
        self.eosToken = self.getWordId('<eos>')  # End of sequence
        self.unknownToken = self.getWordId('<unknown>')  # Word dropped from vocabulary

        # Preprocessing data

        for conversation in tqdm(conversations, desc='Extract conversations'):
            self.extractConversation(conversation)

        # The dataset will be saved in the same order it has been extracted

    def extractConversation(self, conversation):
        """Extract the sample lines from the conversations
        Args:
            conversation (Obj): a conversation object containing the lines to extract
        """

        if self.args.skipLines:  # WARNING: The dataset won't be regenerated if the choice evolve (have to use the datasetTag)
            step = 2
        else:
            step = 1

        # Iterate over all the lines of the conversation
        for i in tqdm_wrap(
            range(0, len(conversation['lines']) - 1, step),  # We ignore the last line (no answer for it)
            desc='Conversation',
            leave=False
        ):
            inputLine = conversation['lines'][i]
            targetLine = conversation['lines'][i+1]

            inputWords = self.extractText(inputLine['text'])
            targetWords = self.extractText(targetLine['text'])

            if inputWords and targetWords:  # Filter wrong samples (if one of the lists is empty)
                self.trainingSamples.append([inputWords, targetWords])

    def extractText(self, line):
        """Extract the words from a sample line
        Args:
            line (str): a line containing the text to extract
        Return:
            list<list<int>>: the list of sentences of word ids of the sentence
        """
        sentences = []  # List[List[str]]

        # Extract sentences
        sentencesToken = nltk.sent_tokenize(line)

        # We add sentence by sentence until we reach the maximum length
        for i in range(len(sentencesToken)):
            tokens = nltk.word_tokenize(sentencesToken[i])

            tempWords = []
            for token in tokens:
                tempWords.append(self.getWordId(token))  # Create the vocabulary and the training sentences

            sentences.append(tempWords)

        return sentences

    def getWordId(self, word, create=True):
        """Get the id of the word (and add it to the dictionary if not existing). If the word does not exist and
        create is set to False, the function will return the unknownToken value
        Args:
            word (str): word to add
            create (Bool): if True and the word does not exist already, the word will be added
        Return:
            int: the id of the word created
        """
        # Should we keep only words with more than one occurrence ?

        word = word.lower()  # Ignore case

        # At inference, we simply look up for the word
        if not create:
            wordId = self.word2id.get(word, self.unknownToken)
        # Get the id if the word already exists
        elif word in self.word2id:
            wordId = self.word2id[word]
            self.idCount[wordId] += 1
        # If not, we create a new entry
        else:
            wordId = len(self.word2id)
            self.word2id[word] = wordId
            self.id2word[wordId] = word
            self.idCount[wordId] = 1

        return wordId

    def printBatch(self, batch):
        """Print a complete batch, useful for debugging
        Args:
            batch (Batch): a batch object
        """
        print('----- Print batch -----')
        for i in range(len(batch.encoderSeqs[0])):  # Batch size
            print('Encoder: {}'.format(self.batchSeq2str(batch.encoderSeqs, seqId=i)))
            print('Decoder: {}'.format(self.batchSeq2str(batch.decoderSeqs, seqId=i)))
            print('Targets: {}'.format(self.batchSeq2str(batch.targetSeqs, seqId=i)))
            print('Weights: {}'.format(' '.join([str(weight) for weight in [batchWeight[i] for batchWeight in batch.weights]])))

    def sequence2str(self, sequence, clean=False, reverse=False):
        """Convert a list of integer into a human readable string
        Args:
            sequence (list<int>): the sentence to print
            clean (Bool): if set, remove the <pad>, <go> and <eos> tokens
            reverse (Bool): for the input, option to restore the standard order
        Return:
            str: the sentence
        """

        if not sequence:
            return ''

        if not clean:
            return ' '.join([self.id2word[idx] for idx in sequence])

        sentence = []
        for wordId in sequence:
            if wordId == self.eosToken:  # End of generated sentence
                break
            elif wordId != self.padToken and wordId != self.goToken:
                sentence.append(self.id2word[wordId])

        if reverse:  # Reverse means input so no <eos> (otherwise pb with previous early stop)
            sentence.reverse()

        return self.detokenize(sentence)

    def detokenize(self, tokens):
        """Slightly cleaner version of joining with spaces.
        Args:
            tokens (list<string>): the sentence to print
        Return:
            str: the sentence
        """
        return ''.join([
            ' ' + t if not t.startswith('\'') and
            t not in string.punctuation
            else t
            for t in tokens]).strip().capitalize()

    def batchSeq2str(self, batchSeq, seqId=0, **kwargs):
        """Convert a list of integer into a human readable string.
        The difference with the previous function is that on a batch object, the values have been reorganized as
        batch instead of sentence.
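
        For example, with batchSeq = [[a1, b1], [a2, b2], [a3, b3]] (3 timesteps,
        batch of 2), seqId=0 extracts the sentence [a1, a2, a3].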
        Args:
            batchSeq (list<list<int>>): the sentence(s) to print
            seqId (int): the position of the sequence inside the batch
            kwargs: the formatting options (see sequence2str())
        Return:
            str: the sentence
        """
        sequence = []
        for i in range(len(batchSeq)):  # Sequence length
            sequence.append(batchSeq[i][seqId])
        return self.sequence2str(sequence, **kwargs)

    def sentence2enco(self, sentence):
        """Encode a sequence and return a batch as an input for the model
        Return:
            Batch: a batch object containing the sentence, or None if something went wrong
        """

        if sentence == '':
            return None

        # First step: Divide the sentence in tokens
        tokens = nltk.word_tokenize(sentence)
        if len(tokens) > self.args.maxLength:
            return None

        # Second step: Convert the tokens in word ids
        wordIds = []
        for token in tokens:
            wordIds.append(self.getWordId(token, create=False))  # Create the vocabulary and the training sentences

        # Third step: creating the batch (add padding, reverse)
        batch = self._createBatch([[wordIds, []]])  # Mono batch, no target output

        return batch

    def deco2sentence(self, decoderOutputs):
        """Decode the output of the decoder and return a human friendly sentence
        decoderOutputs (list<np.array>):
        """
        sequence = []

        # Choose the words with the highest prediction score
        for out in decoderOutputs:
            sequence.append(np.argmax(out))  # Adding each predicted word ids

        return sequence  # We return the raw sentence. Let the caller do some cleaning eventually

    def playDataset(self):
        """Print a random dialogue from the dataset
        """
        print('Randomly play samples:')
        for i in range(self.args.playDataset):
            idSample = random.randint(0, len(self.trainingSamples) - 1)
            print('Q: {}'.format(self.sequence2str(self.trainingSamples[idSample][0], clean=True)))
            print('A: {}'.format(self.sequence2str(self.trainingSamples[idSample][1], clean=True)))
            print()
        pass


def tqdm_wrap(iterable, *args, **kwargs):
    """Forward an iterable, possibly wrapped in a tqdm decorator.
    The iterable is only wrapped if it contains enough elements
    Args:
        iterable (list): An iterable object which defines the __len__ method
        *args, **kwargs: the tqdm parameters
    Return:
        iter: The iterable, possibly decorated
    """
    if len(iterable) > 100:
        return tqdm(iterable, *args, **kwargs)
    return iterable
--------------------------------------------------------------------------------
/corpus/ubuntudata.py:
--------------------------------------------------------------------------------
# Copyright 2015 Conchylicultor. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | # ==============================================================================
15 |
16 | import os
17 |
18 | from tqdm import tqdm
19 |
20 | """
21 | Ubuntu Dialogue Corpus
22 |
23 | http://arxiv.org/abs/1506.08909
24 |
25 | """
26 |
27 | class UbuntuData:
28 |     """Load the Ubuntu Dialogue Corpus conversations from disk.
29 |     """
30 |
31 |     def __init__(self, dirName):
32 |         """
33 |         Args:
34 |             dirName (string): directory from which to load the corpus
35 |         """
36 |         self.MAX_NUMBER_SUBDIR = 10
37 |         self.conversations = []
38 |         __dir = os.path.join(dirName, "dialogs")
39 |         number_subdir = 0
40 |         for sub in tqdm(os.scandir(__dir), desc="Ubuntu dialogs subfolders", total=len(os.listdir(__dir))):
41 |             if number_subdir == self.MAX_NUMBER_SUBDIR:
42 |                 print("WARNING: Early stopping, only extracting {} directories".format(self.MAX_NUMBER_SUBDIR))
43 |                 return
44 |
45 |             if sub.is_dir():
46 |                 number_subdir += 1
47 |                 for f in os.scandir(sub.path):
48 |                     if f.name.endswith(".tsv"):
49 |                         self.conversations.append({"lines": self.loadLines(f.path)})
50 |
51 |
52 |     def loadLines(self, fileName):
53 |         """
54 |         Args:
55 |             fileName (str): file to load
56 |         Return:
57 |             list<dict>: the extracted fields ({"text": str}) for each line
58 |         """
59 |         lines = []
60 |         with open(fileName, 'r') as f:
61 |             for line in f:
62 |                 l = line[line.rindex("\t")+1:].strip()  # Strip metadata (timestamps, speaker names)
63 |
64 |                 lines.append({"text": l})
65 |
66 |         return lines
67 |
68 |
69 |     def getConversations(self):
70 |         return self.conversations
71 |
--------------------------------------------------------------------------------
/disc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geekberu/chatbot_al/b732a587fca68867d63c020f10997d4f1d7087d8/disc/__init__.py
--------------------------------------------------------------------------------
/disc/hier_disc.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | import os
4 | import time
5 | import random
6 | from disc.hier_rnn_model import Hier_rnn_model
7 | import sys
8 |
9 |
10 | def evaluate(session, model, config, evl_inputs, evl_labels, evl_masks):
11 |     total_num = len(evl_inputs[0])
12 |
13 |     fetches = [model.correct_num, model.prediction, model.logits, model.target]
14 |     feed_dict = {}
15 |     for i in range(config.max_len):
16 |         feed_dict[model.input_data[i].name] = evl_inputs[i]
17 |     feed_dict[model.target.name] = evl_labels
18 |     feed_dict[model.mask_x.name] = evl_masks
19 |     correct_num, prediction, logits, target = session.run(fetches, feed_dict)
20 |
21 |     print("total_num: ", total_num)
22 |     print("correct_num: ", correct_num)
23 |     print("prediction: ", prediction)
24 |     print("target: ", target)
25 |
26 |     accuracy = float(correct_num) / total_num
27 |     return accuracy
28 |
29 |
30 | def hier_get_batch(config, max_set_len, query_set, answer_set, gen_set):
31 |     batch_size = config.batch_size
32 |     if batch_size % 2 == 1:
33 |         raise ValueError("The discriminator batch_size must be even")
34 |     train_query = []
35 |     train_answer = []
36 |     train_labels = []
37 |     half_size = batch_size // 2
38 |     is_random_choose = False
39 |     if max_set_len > half_size:
40 |         is_random_choose = True
41 |     for i in range(half_size):
42 |         if is_random_choose:
43 |             index = random.randint(0, max_set_len)
44 |         else:
45 |             index = i
46 |         train_query.append(query_set[index])
47 |         train_answer.append(answer_set[index])
48 |         train_labels.append(1)  # Real (query, answer) pair -> label 1
49 |         train_query.append(query_set[index])
50 |         train_answer.append(gen_set[index])
51 |         train_labels.append(0)  # Same query with a generated answer -> label 0
52
| 53 | return train_query, train_answer, train_labels 54 | 55 | 56 | def create_model(sess, config, vocab_size, name_scope, initializer=None): 57 | with tf.variable_scope(name_or_scope=name_scope, initializer=initializer): 58 | model = Hier_rnn_model(config=config, vocab_size=vocab_size, name_scope=name_scope) 59 | disc_ckpt_dir = os.path.abspath(os.path.join(config.train_dir, "checkpoints")) 60 | ckpt = tf.train.get_checkpoint_state(disc_ckpt_dir) 61 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 62 | print("Reading Hier Disc model parameters from %s" % ckpt.model_checkpoint_path) 63 | model.saver.restore(sess, ckpt.model_checkpoint_path) 64 | else: 65 | print("Created Hier Disc model with fresh parameters.") 66 | disc_global_variables = [gv for gv in tf.global_variables() if name_scope in gv.name] 67 | sess.run(tf.variables_initializer(disc_global_variables)) 68 | return model 69 | 70 | 71 | def softmax(x): 72 | prob = np.exp(x) / np.sum(np.exp(x), axis=0) 73 | return prob 74 | 75 | 76 | def hier_train(config_disc, config_evl, vocab_size, train_set): 77 | config_evl.keep_prob = 1.0 78 | 79 | print("Disc begin training...") 80 | 81 | with tf.Session() as session: 82 | 83 | query_set = train_set[0] 84 | answer_set = train_set[1] 85 | gen_set = train_set[2] 86 | 87 | train_bucket_sizes = [len(query_set[b]) for b in range(len(config_disc.buckets))] 88 | train_total_size = float(sum(train_bucket_sizes)) 89 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size 90 | for i in range(len(train_bucket_sizes))] 91 | 92 | total_qa_size = 0 93 | for i, set in enumerate(query_set): 94 | length = len(set) 95 | print("Discriminator train_set_{} len: {}".format(i, length)) 96 | total_qa_size += length 97 | print("Discriminator train_set total size is {} QA".format(total_qa_size)) 98 | 99 | model = create_model(session, config_disc, vocab_size, name_scope=config_disc.name_model) 100 | 101 | step_time, loss = 0.0, 0.0 102 | current_step = 0 103 | # previous_losses = [] 104 | step_loss_summary = tf.Summary() 105 | disc_writer = tf.summary.FileWriter(config_disc.tensorboard_dir, session.graph) 106 | 107 | while True: 108 | random_number_01 = np.random.random_sample() 109 | bucket_id = min([i for i in range(len(train_buckets_scale)) 110 | if train_buckets_scale[i] > random_number_01]) 111 | 112 | start_time = time.time() 113 | 114 | b_query, b_answer, b_gen = query_set[bucket_id], answer_set[bucket_id], gen_set[bucket_id] 115 | 116 | train_query, train_answer, train_labels = hier_get_batch(config_disc, len(b_query)-1, 117 | b_query, b_answer, b_gen) 118 | 119 | train_query = np.transpose(train_query) 120 | train_answer = np.transpose(train_answer) 121 | 122 | feed_dict = {} 123 | for i in range(config_disc.buckets[bucket_id][0]): 124 | feed_dict[model.query[i].name] = train_query[i] 125 | for i in range(config_disc.buckets[bucket_id][1]): 126 | feed_dict[model.answer[i].name] = train_answer[i] 127 | feed_dict[model.target.name] = train_labels 128 | 129 | fetches = [model.b_train_op[bucket_id], model.b_logits[bucket_id], model.b_loss[bucket_id], model.target] 130 | train_op, logits, step_loss, target = session.run(fetches, feed_dict) 131 | 132 | step_time += (time.time() - start_time) / config_disc.steps_per_checkpoint 133 | loss += step_loss /config_disc.steps_per_checkpoint 134 | current_step += 1 135 | 136 | if current_step % config_disc.steps_per_checkpoint == 0: 137 | 138 | disc_loss_value = step_loss_summary.value.add() 139 | disc_loss_value.tag = 
config_disc.name_loss 140 | disc_loss_value.simple_value = float(loss) 141 | 142 | disc_writer.add_summary(step_loss_summary, int(session.run(model.global_step))) 143 | 144 | print("logits shape: ", np.shape(logits)) 145 | 146 | # softmax operation 147 | logits = np.transpose(softmax(np.transpose(logits))) 148 | 149 | reward, gen_num = 0.0, 0 150 | for logit, label in zip(logits, train_labels): 151 | if label == 0: 152 | reward += logit[1] # only for true probability 153 | gen_num += 1 154 | # reward = reward / len(train_labels) 155 | reward = reward / gen_num 156 | print("reward: ", reward) 157 | 158 | print("current_step: %d, step_loss: %.4f" %(current_step, step_loss)) 159 | 160 | if current_step % (config_disc.steps_per_checkpoint * 6) == 0: 161 | print("current_step: %d, save_model" % (current_step)) 162 | disc_ckpt_dir = os.path.abspath(os.path.join(config_disc.train_dir, "checkpoints")) 163 | if not os.path.exists(disc_ckpt_dir): 164 | os.makedirs(disc_ckpt_dir) 165 | disc_model_path = os.path.join(disc_ckpt_dir, "disc_pretrain.model") 166 | model.saver.save(session, disc_model_path, global_step=model.global_step) 167 | 168 | step_time, loss = 0.0, 0.0 169 | sys.stdout.flush() 170 | 171 | -------------------------------------------------------------------------------- /disc/hier_rnn_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.contrib import rnn 4 | 5 | 6 | class Hier_rnn_model(object): 7 | def __init__(self, config, vocab_size, name_scope, dtype=tf.float32): 8 | # with tf.variable_scope(name_or_scope=scope_name): 9 | emb_dim = config.embed_dim 10 | num_layers = config.num_layers 11 | # vocab_size = config.vocab_size 12 | # max_len = config.max_len 13 | num_class = config.num_class 14 | buckets = config.buckets 15 | self.lr = config.lr 16 | self.global_step = tf.Variable(initial_value=0, trainable=False) 17 | 18 | self.query = [] 19 | self.answer = [] 20 | for i in range(buckets[-1][0]): 21 | self.query.append(tf.placeholder(dtype=tf.int32, shape=[None], name="query{0}".format(i))) 22 | for i in range(buckets[-1][1]): 23 | self.answer.append(tf.placeholder(dtype=tf.int32, shape=[None], name="answer{0}".format(i))) 24 | 25 | self.target = tf.placeholder(dtype=tf.int64, shape=[None], name="target") 26 | 27 | # encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(emb_dim) 28 | encoder_cell = tf.nn.rnn_cell.LSTMCell(emb_dim) 29 | encoder_mutil = tf.nn.rnn_cell.MultiRNNCell([encoder_cell] * num_layers) 30 | encoder_emb = rnn.EmbeddingWrapper(encoder_mutil, embedding_classes=vocab_size, 31 | embedding_size=emb_dim) 32 | 33 | # context_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=emb_dim) 34 | context_cell = tf.nn.rnn_cell.LSTMCell(num_units=emb_dim) 35 | context_multi = tf.nn.rnn_cell.MultiRNNCell([context_cell] * num_layers) 36 | 37 | self.b_query_state = [] 38 | self.b_answer_state = [] 39 | self.b_state = [] 40 | self.b_logits = [] 41 | self.b_loss = [] 42 | # self.b_cost = [] 43 | self.b_train_op = [] 44 | for i, bucket in enumerate(buckets): 45 | with tf.variable_scope(name_or_scope="Hier_RNN_encoder", reuse=True if i > 0 else None) as var_scope: 46 | query_output, query_state = rnn.static_rnn(encoder_emb, inputs=self.query[:bucket[0]], dtype=tf.float32) 47 | # output [max_len, batch_size, emb_dim] state [num_layer, 2, batch_size, emb_dim] 48 | var_scope.reuse_variables() 49 | answer_output, answer_state = rnn.static_rnn(encoder_emb, inputs=self.answer[:bucket[1]], 
dtype=tf.float32) 50 | self.b_query_state.append(query_state) 51 | self.b_answer_state.append(answer_state) 52 | context_input = [query_state[-1][1], answer_state[-1][1]] 53 | 54 | with tf.variable_scope(name_or_scope="Hier_RNN_context", reuse=True if i > 0 else None): 55 | output, state = rnn.static_rnn(context_multi, context_input, dtype=tf.float32) 56 | self.b_state.append(state) 57 | top_state = state[-1][1] # [batch_size, emb_dim] 58 | 59 | with tf.variable_scope("Softmax_layer_and_output", reuse=True if i > 0 else None): 60 | softmax_w = tf.get_variable("softmax_w", [emb_dim, num_class], dtype=tf.float32) 61 | softmax_b = tf.get_variable("softmax_b", [num_class], dtype=tf.float32) 62 | logits = tf.matmul(top_state, softmax_w) + softmax_b 63 | self.b_logits.append(logits) 64 | 65 | with tf.name_scope("loss"): 66 | loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.target, logits=logits) 67 | mean_loss = tf.reduce_mean(loss) 68 | self.b_loss.append(mean_loss) 69 | 70 | with tf.name_scope("gradient_descent"): 71 | disc_params = [var for var in tf.trainable_variables() if name_scope in var.name] 72 | grads, norm = tf.clip_by_global_norm(tf.gradients(mean_loss, disc_params), config.max_grad_norm) 73 | # optimizer = tf.train.GradientDescentOptimizer(self.lr) 74 | optimizer = tf.train.AdamOptimizer(self.lr) 75 | train_op = optimizer.apply_gradients(zip(grads, disc_params), global_step=self.global_step) 76 | self.b_train_op.append(train_op) 77 | 78 | all_variables = [v for v in tf.global_variables() if name_scope in v.name] 79 | self.saver = tf.train.Saver(all_variables) 80 | 81 | 82 | class Config(object): 83 | embed_dim = 12 84 | lr = 0.1 85 | num_class = 2 86 | train_dir = './disc_data/' 87 | name_model = "disc_model" 88 | tensorboard_dir = "./tensorboard/disc_log/" 89 | name_loss = "disc_loss" 90 | num_layers = 3 91 | vocab_size = 10 92 | max_len = 50 93 | batch_size = 1 94 | init_scale = 0.1 95 | buckets = [(5, 10), (10, 15), (20, 25), (40, 50), (50, 50)] 96 | max_grad_norm = 5 97 | 98 | 99 | def main(_): 100 | with tf.Session() as sess: 101 | query = [[1],[2],[3],[4],[5]] 102 | answer = [[6],[7],[8],[9],[0],[0],[0],[0],[0],[0]] 103 | target = [1] 104 | config = Config 105 | initializer = tf.random_uniform_initializer(-1 * config.init_scale, 1 * config.init_scale) 106 | with tf.variable_scope(name_or_scope="rnn_model", initializer=initializer): 107 | model = Hier_rnn_model(config, name_scope=config.name_model) 108 | sess.run(tf.global_variables_initializer()) 109 | input_feed = {} 110 | for i in range(config.buckets[0][0]): 111 | input_feed[model.query[i].name] = query[i] 112 | for i in range(config.buckets[0][1]): 113 | input_feed[model.answer[i].name] = answer[i] 114 | input_feed[model.target.name] = target 115 | 116 | fetches = [model.b_train_op[0], model.b_query_state[0], model.b_state[0], model.b_logits[0]] 117 | 118 | train_op, query, state, logits = sess.run(fetches=fetches, feed_dict=input_feed) 119 | 120 | print("query: ", np.shape(query)) 121 | 122 | pass 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /gen/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geekberu/chatbot_al/b732a587fca68867d63c020f10997d4f1d7087d8/gen/__init__.py -------------------------------------------------------------------------------- /gen/gen_model.py: -------------------------------------------------------------------------------- 1 | from 
__future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import sys 5 | import numpy as np 6 | import tensorflow as tf 7 | import gen.seq2seq as rl_seq2seq 8 | 9 | 10 | class Seq2SeqModel(object): 11 | 12 | def __init__(self, config, vocab_size, name_scope, forward_only=False, num_samples=512, dtype=tf.float32): 13 | 14 | # self.scope_name = scope_name 15 | # with tf.variable_scope(self.scope_name): 16 | source_vocab_size = vocab_size 17 | target_vocab_size = vocab_size 18 | emb_dim = config.emb_dim 19 | 20 | self.buckets = config.buckets 21 | # self.learning_rate = tf.Variable(float(config.learning_rate), trainable=False, dtype=dtype) 22 | self.learning_rate = config.learning_rate 23 | # self.learning_rate_decay_op = self.learning_rate.assign(self.learning_rate * config.learning_rate_decay_factor) 24 | self.global_step = tf.Variable(0, trainable=False) 25 | self.batch_size = config.batch_size 26 | self.num_layers = config.num_layers 27 | self.max_gradient_norm = config.max_gradient_norm 28 | self.mc_search = tf.placeholder(tf.bool, name="mc_search") 29 | self.forward_only = tf.placeholder(tf.bool, name="forward_only") 30 | self.up_reward = tf.placeholder(tf.bool, name="up_reward") 31 | self.reward_bias = tf.get_variable("reward_bias", [1], dtype=tf.float32) 32 | # If we use sampled softmax, we need an output projection. 33 | output_projection = None 34 | softmax_loss_function = None 35 | # Sampled softmax only makes sense if we sample less than vocabulary size. 36 | if num_samples > 0 and num_samples < target_vocab_size: 37 | w_t = tf.get_variable("proj_w", [target_vocab_size, emb_dim], dtype=dtype) 38 | w = tf.transpose(w_t) 39 | b = tf.get_variable("proj_b", [target_vocab_size], dtype=dtype) 40 | output_projection = (w, b) 41 | 42 | def sampled_loss(inputs, labels): 43 | labels = tf.reshape(labels, [-1, 1]) 44 | # We need to compute the sampled_softmax_loss using 32bit floats to 45 | # avoid numerical instabilities. 46 | local_w_t = tf.cast(w_t, tf.float32) 47 | local_b = tf.cast(b, tf.float32) 48 | local_inputs = tf.cast(inputs, tf.float32) 49 | return tf.cast( 50 | # tf.nn.sampled_softmax_loss(local_w_t, local_b, local_inputs, labels, 51 | # num_samples, target_vocab_size), dtype) 52 | tf.nn.sampled_softmax_loss(local_w_t, local_b, labels, local_inputs, 53 | num_samples, target_vocab_size), dtype) 54 | 55 | softmax_loss_function = sampled_loss 56 | 57 | # Create the internal multi-layer cell for our RNN. 58 | single_cell = tf.nn.rnn_cell.GRUCell(emb_dim) 59 | cell = single_cell 60 | if self.num_layers > 1: 61 | cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * self.num_layers) 62 | 63 | # The seq2seq function: we use embedding for the input and attention. 64 | def seq2seq_f(encoder_inputs, decoder_inputs, do_decode): 65 | return rl_seq2seq.embedding_attention_seq2seq( 66 | encoder_inputs, 67 | decoder_inputs, 68 | cell, 69 | num_encoder_symbols=source_vocab_size, 70 | num_decoder_symbols=target_vocab_size, 71 | embedding_size=emb_dim, 72 | output_projection=output_projection, 73 | feed_previous=do_decode, 74 | mc_search=self.mc_search, 75 | dtype=dtype) 76 | 77 | # Feeds for inputs. 78 | self.encoder_inputs = [] 79 | self.decoder_inputs = [] 80 | self.target_weights = [] 81 | for i in range(self.buckets[-1][0]): # Last bucket is the biggest one. 
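            # Illustrative note, inferred from step() below: the placeholders are
            # created once, sized for the largest bucket; a smaller bucket such as
            # (5, 10) simply feeds the first encoder_size/decoder_size of them,
            # so no per-bucket graph inputs are needed.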
82 | self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="encoder{0}".format(i))) 83 | # for i in range(self.buckets[-1][1] + 2 + 1): 84 | for i in range(self.buckets[-1][1] + 1): 85 | self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None], name="decoder{0}".format(i))) 86 | self.target_weights.append(tf.placeholder(dtype, shape=[None], name="weight{0}".format(i))) 87 | self.reward = [tf.placeholder(tf.float32, name="reward_%i" % i) for i in range(len(self.buckets))] 88 | 89 | # Our targets are decoder inputs shifted by one. 90 | targets = [self.decoder_inputs[i + 1] for i in range(len(self.decoder_inputs) - 1)] 91 | 92 | self.outputs, self.losses, self.encoder_state = rl_seq2seq.model_with_buckets( 93 | self.encoder_inputs, self.decoder_inputs, targets, self.target_weights, 94 | self.buckets, source_vocab_size, self.batch_size, 95 | lambda x, y: seq2seq_f(x, y, tf.where(self.forward_only, True, False)), 96 | output_projection=output_projection, softmax_loss_function=softmax_loss_function) 97 | 98 | for b in range(len(self.buckets)): 99 | self.outputs[b] = [ 100 | tf.cond( 101 | self.forward_only, 102 | lambda: tf.matmul(output, output_projection[0]) + output_projection[1], 103 | lambda: output 104 | ) 105 | for output in self.outputs[b] 106 | ] 107 | 108 | if not forward_only: 109 | with tf.name_scope("gradient_descent"): 110 | self.gradient_norms = [] 111 | self.updates = [] 112 | self.aj_losses = [] 113 | self.gen_params = [p for p in tf.trainable_variables() if name_scope in p.name] 114 | # opt = tf.train.GradientDescentOptimizer(self.learning_rate) 115 | opt = tf.train.AdamOptimizer(self.learning_rate) 116 | for b in range(len(self.buckets)): 117 | R = tf.subtract(self.reward[b], self.reward_bias) 118 | # self.reward[b] = self.reward[b] - reward_bias 119 | adjusted_loss = tf.cond(self.up_reward, 120 | lambda:tf.multiply(self.losses[b], self.reward[b]), 121 | lambda: self.losses[b]) 122 | 123 | # adjusted_loss = tf.cond(self.up_reward, 124 | # lambda: tf.mul(self.losses[b], R), 125 | # lambda: self.losses[b]) 126 | self.aj_losses.append(adjusted_loss) 127 | gradients = tf.gradients(adjusted_loss, self.gen_params) 128 | clipped_gradients, norm = tf.clip_by_global_norm(gradients, self.max_gradient_norm) 129 | self.gradient_norms.append(norm) 130 | self.updates.append(opt.apply_gradients( 131 | zip(clipped_gradients, self.gen_params), global_step=self.global_step)) 132 | 133 | self.gen_variables = [k for k in tf.global_variables() if name_scope in k.name] 134 | self.saver = tf.train.Saver(self.gen_variables) 135 | 136 | def step(self, session, encoder_inputs, decoder_inputs, target_weights, 137 | bucket_id, forward_only=True, reward=1, mc_search=False, up_reward=False, debug=True): 138 | # Check if the sizes match. 139 | # Q_size, A_size = self.buckets[bucket_id] 140 | # encoder_size = Q_size 141 | # decoder_size = A_size + 2 142 | encoder_size, decoder_size = self.buckets[bucket_id] 143 | if len(encoder_inputs) != encoder_size: 144 | raise ValueError("Encoder length must be equal to the one in bucket," 145 | " %d != %d." % (len(encoder_inputs), encoder_size)) 146 | if len(decoder_inputs) != decoder_size: 147 | raise ValueError("Decoder length must be equal to the one in bucket," 148 | " %d != %d." % (len(decoder_inputs), decoder_size)) 149 | if len(target_weights) != decoder_size: 150 | raise ValueError("Weights length must be equal to the one in bucket," 151 | " %d != %d." 
% (len(target_weights), decoder_size)) 152 | 153 | # Input feed: encoder inputs, decoder inputs, target_weights, as provided. 154 | 155 | input_feed = { 156 | self.forward_only.name: forward_only, 157 | self.up_reward.name: up_reward, 158 | self.mc_search.name: mc_search 159 | } 160 | for l in range(len(self.buckets)): 161 | input_feed[self.reward[l].name] = reward 162 | for l in range(encoder_size): 163 | input_feed[self.encoder_inputs[l].name] = encoder_inputs[l] 164 | for l in range(decoder_size): 165 | input_feed[self.decoder_inputs[l].name] = decoder_inputs[l] 166 | input_feed[self.target_weights[l].name] = target_weights[l] 167 | 168 | # Since our targets are decoder inputs shifted by one, we need one more. 169 | last_target = self.decoder_inputs[decoder_size].name 170 | input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32) 171 | 172 | # Output feed: depends on whether we do a backward step or not. 173 | if not forward_only: # normal training 174 | output_feed = [self.updates[bucket_id], # Update Op that does SGD. 175 | self.aj_losses[bucket_id], # Gradient norm. 176 | self.losses[bucket_id]] # Loss for this batch. 177 | else: # testing or reinforcement learning 178 | output_feed = [self.encoder_state[bucket_id], self.losses[bucket_id]] # Loss for this batch. 179 | for l in range(decoder_size): # Output logits. 180 | output_feed.append(self.outputs[bucket_id][l]) 181 | 182 | outputs = session.run(output_feed, input_feed) 183 | if not forward_only: 184 | return outputs[1], outputs[2], outputs[0] # Gradient norm, loss, no outputs. 185 | else: 186 | return outputs[0], outputs[1], outputs[2:] # encoder_state, loss, outputs. 187 | -------------------------------------------------------------------------------- /gen/generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os 5 | import random 6 | import sys 7 | import time 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | 12 | import utils.conf as conf 13 | import gen.gen_model as seq2seq_model 14 | import nltk # For tokenize 15 | 16 | # We use a number of buckets and pad to the closest one for efficiency. 17 | # See seq2seq_model.Seq2SeqModel for details of how they work. 
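# Illustrative sketch (added; `_demo_choose_bucket` and its bucket values are
# hypothetical, not used elsewhere): how create_train_set() below assigns a
# (source, target) pair to a bucket. The target must leave two free slots,
# presumably for the <go> and <eos> tokens appended when batches are built.
def _demo_choose_bucket(source_len, target_len,
                        buckets=((5, 10), (10, 15), (20, 25))):
    """Return the index of the first bucket that fits, or None if none does."""
    for bucket_id, (source_size, target_size) in enumerate(buckets):
        if source_len < source_size and target_len < (target_size - 2):
            return bucket_id  # e.g. _demo_choose_bucket(4, 7) -> 0
    return None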
18 | 19 | SENTENCES_PREFIX = ['Q: ', 'A: '] 20 | 21 | 22 | # create the train_set, format: [bucket_id][source_ids, target_ids] 23 | def create_train_set(config, text_data): 24 | data_set = [[] for _ in config.buckets] 25 | samples = text_data.trainingSamples 26 | for sample in samples: 27 | source = sample[0] 28 | target = sample[1] 29 | # source_ids = [int(x) for x in source.split()] 30 | # target_ids = [int(x) for x in target.split()] 31 | for bucket_id, (source_size, target_size) in enumerate( 32 | config.buckets): # [bucket_id, (source_size, target_size)] 33 | if len(source) < source_size and len(target) < (target_size - 2): 34 | data_set[bucket_id].append([source, target]) 35 | break 36 | return data_set 37 | 38 | 39 | def create_disc_train_set(config, text_data, bucket_id=-1, train_set=None, batch_num=1, sess=None, gen_model=None): 40 | if train_set is None: 41 | train_set = create_train_set(config, text_data) 42 | random_bucket_id = False 43 | if bucket_id is -1: 44 | train_bucket_sizes = [len(train_set[b]) for b in range(len(config.buckets))] 45 | train_total_size = float(sum(train_bucket_sizes)) 46 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size 47 | for i in range(len(train_bucket_sizes))] 48 | random_bucket_id = True 49 | 50 | query_set = [[] for _ in config.buckets] 51 | answer_set = [[] for _ in config.buckets] 52 | gen_set = [[] for _ in config.buckets] 53 | 54 | model = gen_model 55 | is_close_sess = False 56 | if sess is None: 57 | sess = tf.Session() 58 | model = create_model(sess, config, text_data.getVocabularySize(), forward_only=True, 59 | name_scope=config.name_model) 60 | is_close_sess = True 61 | 62 | num_step = 0 63 | print("total generating steps: ", batch_num) 64 | while num_step < batch_num: 65 | print("generating num_step: ", num_step) 66 | if random_bucket_id: 67 | random_number_01 = np.random.random_sample() 68 | bucket_id = min([i for i in range(len(train_buckets_scale)) 69 | if train_buckets_scale[i] > random_number_01]) 70 | 71 | encoder_inputs, decoder_inputs, target_weights, batch_source_encoder, batch_source_decoder = \ 72 | get_batch(config, train_set, bucket_id, config.batch_size, text_data) 73 | 74 | _, _, out_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, 75 | forward_only=True) 76 | seq_tokens = [] 77 | for seq in out_logits: 78 | row_token = [] 79 | for t in seq: 80 | row_token.append(int(np.argmax(t, axis=0))) 81 | seq_tokens.append(row_token) 82 | 83 | seq_tokens_t = [] 84 | for col in range(len(seq_tokens[0])): 85 | seq_tokens_t.append([seq_tokens[row][col] for row in range(len(seq_tokens))]) 86 | 87 | for i in range(config.batch_size): 88 | query_set[bucket_id].append(batch_source_encoder[i]) 89 | answer_set[bucket_id].append(batch_source_decoder[i]) 90 | gen_set[bucket_id].append(seq_tokens_t[i]) 91 | 92 | num_step += 1 93 | 94 | train_set = [query_set, answer_set, gen_set] 95 | if is_close_sess: 96 | sess.close() 97 | return train_set 98 | 99 | 100 | def create_model(session, gen_config, vocab_size, forward_only, name_scope, initializer=None): 101 | """Create translation model and initialize or load parameters in session.""" 102 | with tf.variable_scope(name_or_scope=name_scope, initializer=initializer): 103 | model = seq2seq_model.Seq2SeqModel(gen_config, vocab_size=vocab_size, name_scope=name_scope, 104 | forward_only=forward_only) 105 | gen_ckpt_dir = os.path.abspath(os.path.join(gen_config.train_dir, "checkpoints")) 106 | ckpt = tf.train.get_checkpoint_state(gen_ckpt_dir) 
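        # Note: the restore-or-initialize branch below touches only variables
        # whose names contain name_scope, which is what lets the generator and
        # discriminator share one graph and session without re-initializing
        # each other's weights.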
107 | if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path): 108 | print("Reading Gen model parameters from %s" % ckpt.model_checkpoint_path) 109 | model.saver.restore(session, ckpt.model_checkpoint_path) 110 | else: 111 | print("Created Gen model with fresh parameters.") 112 | gen_global_variables = [gv for gv in tf.global_variables() if name_scope in gv.name] 113 | session.run(tf.variables_initializer(gen_global_variables)) 114 | return model 115 | 116 | 117 | def softmax(x): 118 | prob = np.exp(x) / np.sum(np.exp(x), axis=0) 119 | return prob 120 | 121 | 122 | def get_batch(config, train_set, bucket_id, batch_size, text_data): 123 | # Q_size, A_size = config.buckets[bucket_id] 124 | encoder_size, decoder_size = config.buckets[bucket_id] 125 | # pad them if needed, reverse encoder inputs and add GO to decoder. 126 | batch_source_encoder, batch_source_decoder, samples = [], [], [] 127 | # print("bucket_id: %s" %bucket_id) 128 | for batch_i in range(batch_size): 129 | encoder_input, decoder_input = random.choice(train_set[bucket_id]) 130 | sample = [encoder_input, decoder_input] 131 | samples.append(sample) 132 | encoder_input = text_data.add_pad(encoder_input, config.buckets[bucket_id][0]) 133 | batch_source_encoder.append(encoder_input) 134 | decoder_input = text_data.add_pad(decoder_input, config.buckets[bucket_id][1]) 135 | batch_source_decoder.append(decoder_input) 136 | 137 | # Now we create batch-major vectors from the disc_data selected above. 138 | batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], [] 139 | 140 | # batch = text_data.get_batch(samples, Q_size, A_size) 141 | batch = text_data.get_batch(samples, encoder_size, decoder_size) 142 | batch_encoder_inputs = batch.encoderSeqs 143 | batch_decoder_inputs = batch.decoderSeqs 144 | batch_weights = batch.weights 145 | 146 | return batch_encoder_inputs, batch_decoder_inputs, batch_weights, batch_source_encoder, batch_source_decoder 147 | 148 | 149 | def train(gen_config, text_data): 150 | # vocab, rev_vocab, train_set = prepare_data(gen_config) 151 | train_set = create_train_set(gen_config, text_data) 152 | 153 | total_qa_size = 0 154 | for i, set in enumerate(train_set): 155 | length = len(set) 156 | print("Generator train_set_{} len: {}".format(i, length)) 157 | total_qa_size += length 158 | print("Generator train_set total size is {} QA".format(total_qa_size)) 159 | 160 | with tf.Session() as sess: 161 | #with tf.device("/gpu:1"): 162 | # Create model. 163 | print("Creating %d layers of %d units." % (gen_config.num_layers, gen_config.emb_dim)) 164 | vocab_size = text_data.getVocabularySize() 165 | model = create_model(sess, gen_config, vocab_size, forward_only=False, 166 | name_scope=gen_config.name_model) 167 | 168 | train_bucket_sizes = [len(train_set[b]) for b in range(len(gen_config.buckets))] 169 | train_total_size = float(sum(train_bucket_sizes)) 170 | train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size 171 | for i in range(len(train_bucket_sizes))] 172 | 173 | # This is the training loop. 174 | step_time, loss = 0.0, 0.0 175 | current_step = 0 176 | # previous_losses = [] 177 | 178 | gen_loss_summary = tf.Summary() 179 | gen_writer = tf.summary.FileWriter(gen_config.tensorboard_dir, sess.graph) 180 | 181 | while True: 182 | # Choose a bucket according to disc_data distribution. We pick a random number 183 | # in [0, 1] and use the corresponding interval in train_buckets_scale. 
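            # Worked example (hypothetical sizes): with bucket sizes
            # [100, 300, 600], train_buckets_scale == [0.1, 0.4, 1.0], so a
            # draw of 0.35 picks bucket 1; buckets are thus sampled in
            # proportion to how much data they hold.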
184 | random_number_01 = np.random.random_sample() 185 | bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01]) 186 | 187 | # Get a batch and make a step. 188 | start_time = time.time() 189 | encoder_inputs, decoder_inputs, target_weights, _, _ = get_batch(gen_config, train_set, bucket_id, 190 | gen_config.batch_size, text_data) 191 | 192 | _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, 193 | target_weights, bucket_id, forward_only=False) 194 | 195 | step_time += (time.time() - start_time) / gen_config.steps_per_checkpoint 196 | loss += step_loss / gen_config.steps_per_checkpoint 197 | current_step += 1 198 | 199 | # Once in a while, we save checkpoint, print statistics, and run evals. 200 | if current_step % gen_config.steps_per_checkpoint == 0: 201 | 202 | bucket_value = gen_loss_summary.value.add() 203 | bucket_value.tag = gen_config.name_loss 204 | bucket_value.simple_value = float(loss) 205 | gen_writer.add_summary(gen_loss_summary, int(model.global_step.eval())) 206 | 207 | # Print statistics for the previous epoch. 208 | # perplexity = math.exp(loss) if loss < 300 else float('inf') 209 | # print("global step %d learning rate %.4f step-time %.2f perplexity " 210 | # "%.2f" % (model.global_step.eval(), gen_config.learning_rate, 211 | # step_time, perplexity)) 212 | print("global step %d learning rate %.4f step-time %.2f loss " 213 | "%.2f" % (model.global_step.eval(), gen_config.learning_rate, 214 | step_time, loss)) 215 | # Decrease learning rate if no improvement was seen over last 3 times. 216 | # if len(previous_losses) > 2 and loss > max(previous_losses[-3:]): 217 | # sess.run(model.learning_rate_decay_op) 218 | # previous_losses.append(loss) 219 | 220 | # Save checkpoint and zero timer and loss. 221 | if current_step % (gen_config.steps_per_checkpoint * 6) == 0: 222 | print("current_step: %d, save model" %(current_step)) 223 | gen_ckpt_dir = os.path.abspath(os.path.join(gen_config.train_dir, "checkpoints")) 224 | if not os.path.exists(gen_ckpt_dir): 225 | os.makedirs(gen_ckpt_dir) 226 | checkpoint_path = os.path.join(gen_ckpt_dir, "gen_pretrain.model") 227 | model.saver.save(sess, checkpoint_path, global_step=model.global_step) 228 | 229 | step_time, loss = 0.0, 0.0 230 | # Run evals on development set and print their perplexity. 231 | # for bucket_id in range(len(gen_config.buckets)): 232 | # encoder_inputs, decoder_inputs, target_weights = model.get_batch( 233 | # dev_set, bucket_id) 234 | # _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, 235 | # target_weights, bucket_id, True) 236 | # eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf') 237 | # print(" eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx)) 238 | sys.stdout.flush() 239 | 240 | 241 | def inference_interactive(text_data): 242 | config = conf.gen_config 243 | with tf.Session() as sess: 244 | model = create_model(sess, config, text_data.getVocabularySize(), forward_only=True, name_scope=config.name_model) 245 | model.batch_size = 1 246 | # print('Testing: Launch interactive mode:') 247 | print('**************************************************************************************') 248 | print('* Welcome to the interactive mode, here you can ask Chatbot the sentence you want. *\n' 249 | '* Don\'t have high expectation. *\n' 250 | '* Type \'exit\' or just press ENTER to quit the program. Have fun. 
*') 251 | print('**************************************************************************************') 252 | while True: 253 | question = input(SENTENCES_PREFIX[0]) 254 | if question == '' or question == 'exit': 255 | break 256 | # First step: Divide the sentence in token 257 | tokens = nltk.word_tokenize(question) 258 | # Second step: Convert the token in word ids 259 | word_ids = [] 260 | bucket_id = len(config.buckets) - 1 261 | for token in tokens: 262 | word_ids.append(text_data.getWordId(token, create=False)) # Create the vocabulary and the training sentences 263 | for i, bucket in enumerate(config.buckets): 264 | if bucket[0] >= len(word_ids): 265 | bucket_id = i 266 | break 267 | else: 268 | print('Warning: sentence too long, sorry. Maybe try a shorter sentence.') 269 | 270 | samples = [] 271 | sample = [word_ids, []] 272 | samples.append(sample) 273 | # Q_size, A_size = config.buckets[bucket_id] 274 | # batch = text_data.get_batch(samples, Q_size, A_size) 275 | encoder_size, decoder_size = config.buckets[bucket_id] 276 | batch = text_data.get_batch(samples, encoder_size, decoder_size) 277 | encoder_inputs = batch.encoderSeqs 278 | decoder_inputs = batch.decoderSeqs 279 | weights = batch.weights 280 | start = time.time() 281 | _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, weights, bucket_id, True) 282 | end = time.time() 283 | process_time = end - start 284 | 285 | print("output_logits shape: ", np.shape(output_logits)) 286 | print("inference time: ", process_time) 287 | 288 | outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits] 289 | outputs = text_data.sequence2str(outputs, clean=True) 290 | print('{}{}'.format(SENTENCES_PREFIX[1], outputs)) 291 | 292 | -------------------------------------------------------------------------------- /gen/seq2seq.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | # We disable pylint because we need python3 compatibility. 6 | from six.moves import xrange # pylint: disable=redefined-builtin 7 | from six.moves import zip # pylint: disable=redefined-builtin 8 | import tensorflow as tf 9 | from tensorflow.python import shape 10 | from tensorflow.python.framework import dtypes 11 | from tensorflow.python.framework import ops 12 | from tensorflow.python.ops import array_ops 13 | from tensorflow.python.ops import control_flow_ops 14 | from tensorflow.python.ops import embedding_ops 15 | from tensorflow.python.ops import math_ops 16 | from tensorflow.python.ops import nn_ops 17 | # from tensorflow.python.ops import rnn 18 | from tensorflow.contrib import rnn 19 | # from tensorflow.python.ops import rnn_cell 20 | # from tensorflow.contrib.rnn.python.ops import rnn_cell 21 | from tensorflow.contrib.rnn.python.ops import core_rnn_cell 22 | from tensorflow.python.ops import variable_scope 23 | from tensorflow.python.util import nest 24 | 25 | # TODO(ebrevdo): Remove once _linear is fully deprecated. 
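# Illustrative sketch (added; `_demo_next_symbol` is a hypothetical helper, not
# part of this module's API): the two choices _argmax_or_mcsearch() below
# switches between. Greedy decoding takes the argmax of the projected logits;
# MC search samples from their softmax instead, giving the varied rollouts the
# adversarial reward signal needs.
def _demo_next_symbol(logits_row, mc_search=False):
    import numpy as np  # local import: numpy is otherwise unused in this module
    if not mc_search:
        return int(np.argmax(logits_row))  # greedy symbol
    probs = np.exp(logits_row - np.max(logits_row))  # numerically stable softmax
    probs = probs / probs.sum()
    return int(np.random.choice(len(probs), p=probs))  # sample, like tf.multinomial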
26 | linear = core_rnn_cell._linear # pylint: disable=protected-access 27 | 28 | 29 | def _argmax_or_mcsearch(embedding, output_projection=None, update_embedding=True, mc_search=False): 30 | def loop_function(prev, _): 31 | if output_projection is not None: 32 | prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1]) 33 | 34 | 35 | if isinstance(mc_search, bool): 36 | prev_symbol = tf.reshape(tf.multinomial(prev, 1), [-1]) if mc_search else math_ops.argmax(prev, 1) 37 | else: 38 | prev_symbol = tf.cond(mc_search, lambda: tf.reshape(tf.multinomial(prev, 1), [-1]), lambda: tf.argmax(prev, 1)) 39 | 40 | 41 | emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol) 42 | if not update_embedding: 43 | emb_prev = array_ops.stop_gradient(emb_prev) 44 | return emb_prev 45 | return loop_function 46 | 47 | def _extract_argmax_and_embed(embedding, output_projection=None, update_embedding=True): 48 | """Get a loop_function that extracts the previous symbol and embeds it. 49 | 50 | Args: 51 | embedding: embedding tensor for symbols. 52 | output_projection: None or a pair (W, B). If provided, each fed previous 53 | output will first be multiplied by W and added B. 54 | update_embedding: Boolean; if False, the gradients will not propagate 55 | through the embeddings. 56 | 57 | Returns: 58 | A loop function. 59 | """ 60 | def loop_function(prev, _): 61 | if output_projection is not None: 62 | prev = nn_ops.xw_plus_b( 63 | prev, output_projection[0], output_projection[1]) 64 | prev_symbol = math_ops.argmax(prev, 1) 65 | # Note that gradients will not propagate through the second parameter of 66 | # embedding_lookup. 67 | emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol) 68 | if not update_embedding: 69 | emb_prev = array_ops.stop_gradient(emb_prev) 70 | return emb_prev 71 | return loop_function 72 | 73 | 74 | def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, 75 | scope=None): 76 | """RNN decoder for the sequence-to-sequence model. 77 | 78 | Args: 79 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 80 | initial_state: 2D Tensor with shape [batch_size x cell.state_size]. 81 | cell: rnn_cell.RNNCell defining the cell function and size. 82 | loop_function: If not None, this function will be applied to the i-th output 83 | in order to generate the i+1-st input, and decoder_inputs will be ignored, 84 | except for the first element ("GO" symbol). This can be used for decoding, 85 | but also for training to emulate http://arxiv.org/abs/1506.03099. 86 | Signature -- loop_function(prev, i) = next 87 | * prev is a 2D Tensor of shape [batch_size x output_size], 88 | * i is an integer, the step number (when advanced control is needed), 89 | * next is a 2D Tensor of shape [batch_size x input_size]. 90 | scope: VariableScope for the created subgraph; defaults to "rnn_decoder". 91 | 92 | Returns: 93 | A tuple of the form (outputs, state), where: 94 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 95 | shape [batch_size x output_size] containing generated outputs. 96 | state: The state of each cell at the final time-step. 97 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 98 | (Note that in some cases, like basic RNN cell or GRU cell, outputs and 99 | states can be the same. They are different for LSTM cells though.) 
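
    Illustrative note: when loop_function is set, only decoder_inputs[0] (the
    "GO" symbol) is actually consumed; every later input is rebuilt from the
    previous output, which is how the greedy and MC-search decoding in this
    module feed the decoder at inference time.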
100 | """ 101 | with variable_scope.variable_scope(scope or "rnn_decoder"): 102 | state = initial_state 103 | outputs = [] 104 | prev = None 105 | for i, inp in enumerate(decoder_inputs): 106 | if loop_function is not None and prev is not None: 107 | with variable_scope.variable_scope("loop_function", reuse=True): 108 | inp = loop_function(prev, i) 109 | if i > 0: 110 | variable_scope.get_variable_scope().reuse_variables() 111 | output, state = cell(inp, state) 112 | outputs.append(output) 113 | if loop_function is not None: 114 | prev = output 115 | return outputs, state 116 | 117 | 118 | def basic_rnn_seq2seq( 119 | encoder_inputs, decoder_inputs, cell, dtype=dtypes.float32, scope=None): 120 | """Basic RNN sequence-to-sequence model. 121 | 122 | This model first runs an RNN to encode encoder_inputs into a state vector, 123 | then runs decoder, initialized with the last encoder state, on decoder_inputs. 124 | Encoder and decoder use the same RNN cell type, but don't share parameters. 125 | 126 | Args: 127 | encoder_inputs: A list of 2D Tensors [batch_size x input_size]. 128 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 129 | cell: rnn_cell.RNNCell defining the cell function and size. 130 | dtype: The dtype of the initial state of the RNN cell (default: tf.float32). 131 | scope: VariableScope for the created subgraph; default: "basic_rnn_seq2seq". 132 | 133 | Returns: 134 | A tuple of the form (outputs, state), where: 135 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 136 | shape [batch_size x output_size] containing the generated outputs. 137 | state: The state of each decoder cell in the final time-step. 138 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 139 | """ 140 | with variable_scope.variable_scope(scope or "basic_rnn_seq2seq"): 141 | _, enc_state = rnn.static_rnn(cell, encoder_inputs, dtype=dtype) 142 | return rnn_decoder(decoder_inputs, enc_state, cell) 143 | 144 | 145 | def tied_rnn_seq2seq(encoder_inputs, decoder_inputs, cell, 146 | loop_function=None, dtype=dtypes.float32, scope=None): 147 | """RNN sequence-to-sequence model with tied encoder and decoder parameters. 148 | 149 | This model first runs an RNN to encode encoder_inputs into a state vector, and 150 | then runs decoder, initialized with the last encoder state, on decoder_inputs. 151 | Encoder and decoder use the same RNN cell and share parameters. 152 | 153 | Args: 154 | encoder_inputs: A list of 2D Tensors [batch_size x input_size]. 155 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 156 | cell: rnn_cell.RNNCell defining the cell function and size. 157 | loop_function: If not None, this function will be applied to i-th output 158 | in order to generate i+1-th input, and decoder_inputs will be ignored, 159 | except for the first element ("GO" symbol), see rnn_decoder for details. 160 | dtype: The dtype of the initial state of the rnn cell (default: tf.float32). 161 | scope: VariableScope for the created subgraph; default: "tied_rnn_seq2seq". 162 | 163 | Returns: 164 | A tuple of the form (outputs, state), where: 165 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 166 | shape [batch_size x output_size] containing the generated outputs. 167 | state: The state of each decoder cell in each time-step. This is a list 168 | with length len(decoder_inputs) -- one item for each time-step. 169 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 
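
    Note: "tied" means the encoder and decoder run under one shared variable
    scope (see the reuse_variables() call below), so a single set of RNN
    weights serves both sides.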
170 | """ 171 | with variable_scope.variable_scope("combined_tied_rnn_seq2seq"): 172 | scope = scope or "tied_rnn_seq2seq" 173 | _, enc_state = rnn.static_rnn( 174 | cell, encoder_inputs, dtype=dtype, scope=scope) 175 | variable_scope.get_variable_scope().reuse_variables() 176 | return rnn_decoder(decoder_inputs, enc_state, cell, 177 | loop_function=loop_function, scope=scope) 178 | 179 | 180 | def embedding_rnn_decoder(decoder_inputs, 181 | initial_state, 182 | cell, 183 | num_symbols, 184 | embedding_size, 185 | output_projection=None, 186 | feed_previous=False, 187 | update_embedding_for_previous=True, 188 | scope=None): 189 | 190 | with variable_scope.variable_scope(scope or "embedding_rnn_decoder") as scope: 191 | if output_projection is not None: 192 | dtype = scope.dtype 193 | proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) 194 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 195 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 196 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 197 | 198 | embedding = variable_scope.get_variable("embedding", 199 | [num_symbols, embedding_size]) 200 | loop_function = _extract_argmax_and_embed( 201 | embedding, output_projection, 202 | update_embedding_for_previous) if feed_previous else None 203 | emb_inp = ( 204 | embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs) 205 | return rnn_decoder(emb_inp, initial_state, cell, 206 | loop_function=loop_function) 207 | 208 | 209 | def embedding_rnn_seq2seq(encoder_inputs, 210 | decoder_inputs, 211 | cell, 212 | num_encoder_symbols, 213 | num_decoder_symbols, 214 | embedding_size, 215 | output_projection=None, 216 | feed_previous=False, 217 | dtype=None, 218 | scope=None): 219 | 220 | with variable_scope.variable_scope(scope or "embedding_rnn_seq2seq") as scope: 221 | if dtype is not None: 222 | scope.set_dtype(dtype) 223 | else: 224 | dtype = scope.dtype 225 | 226 | # Encoder. 227 | encoder_cell = core_rnn_cell.EmbeddingWrapper( 228 | cell, embedding_classes=num_encoder_symbols, 229 | embedding_size=embedding_size) 230 | _, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype) 231 | 232 | # Decoder. 233 | if output_projection is None: 234 | cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols) 235 | 236 | if isinstance(feed_previous, bool): 237 | return embedding_rnn_decoder( 238 | decoder_inputs, 239 | encoder_state, 240 | cell, 241 | num_decoder_symbols, 242 | embedding_size, 243 | output_projection=output_projection, 244 | feed_previous=feed_previous, 245 | scope=scope) 246 | 247 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 248 | def decoder(feed_previous_bool): 249 | reuse = None if feed_previous_bool else True 250 | with variable_scope.variable_scope( 251 | variable_scope.get_variable_scope(), reuse=reuse) as scope: 252 | outputs, state = embedding_rnn_decoder( 253 | decoder_inputs, encoder_state, cell, num_decoder_symbols, 254 | embedding_size, output_projection=output_projection, 255 | feed_previous=feed_previous_bool, 256 | update_embedding_for_previous=False) 257 | state_list = [state] 258 | if nest.is_sequence(state): 259 | state_list = nest.flatten(state) 260 | return outputs + state_list 261 | 262 | outputs_and_state = control_flow_ops.cond(feed_previous, 263 | lambda: decoder(True), 264 | lambda: decoder(False)) 265 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 
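  # Note: decoder(True)/decoder(False) each return the list outputs + state_list,
  # so the first outputs_len entries are the per-step outputs and the tail is the
  # flattened final state, repacked just below to match encoder_state's structure.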
266 | state_list = outputs_and_state[outputs_len:] 267 | state = state_list[0] 268 | if nest.is_sequence(encoder_state): 269 | state = nest.pack_sequence_as(structure=encoder_state, 270 | flat_sequence=state_list) 271 | return outputs_and_state[:outputs_len], state 272 | 273 | 274 | def embedding_tied_rnn_seq2seq(encoder_inputs, 275 | decoder_inputs, 276 | cell, 277 | num_symbols, 278 | embedding_size, 279 | num_decoder_symbols=None, 280 | output_projection=None, 281 | feed_previous=False, 282 | dtype=None, 283 | scope=None): 284 | """Embedding RNN sequence-to-sequence model with tied (shared) parameters. 285 | 286 | This model first embeds encoder_inputs by a newly created embedding (of shape 287 | [num_symbols x input_size]). Then it runs an RNN to encode embedded 288 | encoder_inputs into a state vector. Next, it embeds decoder_inputs using 289 | the same embedding. Then it runs RNN decoder, initialized with the last 290 | encoder state, on embedded decoder_inputs. The decoder output is over symbols 291 | from 0 to num_decoder_symbols - 1 if num_decoder_symbols is none; otherwise it 292 | is over 0 to num_symbols - 1. 293 | 294 | Args: 295 | encoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 296 | decoder_inputs: A list of 1D int32 Tensors of shape [batch_size]. 297 | cell: rnn_cell.RNNCell defining the cell function and size. 298 | num_symbols: Integer; number of symbols for both encoder and decoder. 299 | embedding_size: Integer, the length of the embedding vector for each symbol. 300 | num_decoder_symbols: Integer; number of output symbols for decoder. If 301 | provided, the decoder output is over symbols 0 to num_decoder_symbols - 1. 302 | Otherwise, decoder output is over symbols 0 to num_symbols - 1. Note that 303 | this assumes that the vocabulary is set up such that the first 304 | num_decoder_symbols of num_symbols are part of decoding. 305 | output_projection: None or a pair (W, B) of output projection weights and 306 | biases; W has shape [output_size x num_symbols] and B has 307 | shape [num_symbols]; if provided and feed_previous=True, each 308 | fed previous output will first be multiplied by W and added B. 309 | feed_previous: Boolean or scalar Boolean Tensor; if True, only the first 310 | of decoder_inputs will be used (the "GO" symbol), and all other decoder 311 | inputs will be taken from previous outputs (as in embedding_rnn_decoder). 312 | If False, decoder_inputs are used as given (the standard decoder case). 313 | dtype: The dtype to use for the initial RNN states (default: tf.float32). 314 | scope: VariableScope for the created subgraph; defaults to 315 | "embedding_tied_rnn_seq2seq". 316 | 317 | Returns: 318 | A tuple of the form (outputs, state), where: 319 | outputs: A list of the same length as decoder_inputs of 2D Tensors with 320 | shape [batch_size x output_symbols] containing the generated 321 | outputs where output_symbols = num_decoder_symbols if 322 | num_decoder_symbols is not None otherwise output_symbols = num_symbols. 323 | state: The state of each decoder cell at the final time-step. 324 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 325 | 326 | Raises: 327 | ValueError: When output_projection has the wrong shape. 
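
    Illustrative example: with num_symbols=40000 and num_decoder_symbols=10000,
    the decoder scores only ids 0..9999, so the vocabulary must be laid out
    with every decodable symbol first.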
328 | """ 329 | with variable_scope.variable_scope( 330 | scope or "embedding_tied_rnn_seq2seq", dtype=dtype) as scope: 331 | dtype = scope.dtype 332 | 333 | if output_projection is not None: 334 | proj_weights = ops.convert_to_tensor(output_projection[0], dtype=dtype) 335 | proj_weights.get_shape().assert_is_compatible_with([None, num_symbols]) 336 | proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype) 337 | proj_biases.get_shape().assert_is_compatible_with([num_symbols]) 338 | 339 | embedding = variable_scope.get_variable( 340 | "embedding", [num_symbols, embedding_size], dtype=dtype) 341 | 342 | emb_encoder_inputs = [embedding_ops.embedding_lookup(embedding, x) 343 | for x in encoder_inputs] 344 | emb_decoder_inputs = [embedding_ops.embedding_lookup(embedding, x) 345 | for x in decoder_inputs] 346 | 347 | output_symbols = num_symbols 348 | if num_decoder_symbols is not None: 349 | output_symbols = num_decoder_symbols 350 | if output_projection is None: 351 | cell = core_rnn_cell.OutputProjectionWrapper(cell, output_symbols) 352 | 353 | if isinstance(feed_previous, bool): 354 | loop_function = _extract_argmax_and_embed( 355 | embedding, output_projection, True) if feed_previous else None 356 | return tied_rnn_seq2seq(emb_encoder_inputs, emb_decoder_inputs, cell, 357 | loop_function=loop_function, dtype=dtype) 358 | 359 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 360 | def decoder(feed_previous_bool): 361 | loop_function = _extract_argmax_and_embed( 362 | embedding, output_projection, False) if feed_previous_bool else None 363 | reuse = None if feed_previous_bool else True 364 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 365 | reuse=reuse): 366 | outputs, state = tied_rnn_seq2seq( 367 | emb_encoder_inputs, emb_decoder_inputs, cell, 368 | loop_function=loop_function, dtype=dtype) 369 | state_list = [state] 370 | if nest.is_sequence(state): 371 | state_list = nest.flatten(state) 372 | return outputs + state_list 373 | 374 | outputs_and_state = control_flow_ops.cond(feed_previous, 375 | lambda: decoder(True), 376 | lambda: decoder(False)) 377 | outputs_len = len(decoder_inputs) # Outputs length same as decoder inputs. 378 | state_list = outputs_and_state[outputs_len:] 379 | state = state_list[0] 380 | # Calculate zero-state to know it's structure. 381 | static_batch_size = encoder_inputs[0].get_shape()[0] 382 | for inp in encoder_inputs[1:]: 383 | static_batch_size.merge_with(inp.get_shape()[0]) 384 | batch_size = static_batch_size.value 385 | if batch_size is None: 386 | batch_size = array_ops.shape(encoder_inputs[0])[0] 387 | zero_state = cell.zero_state(batch_size, dtype) 388 | if nest.is_sequence(zero_state): 389 | state = nest.pack_sequence_as(structure=zero_state, 390 | flat_sequence=state_list) 391 | return outputs_and_state[:outputs_len], state 392 | 393 | 394 | def attention_decoder(decoder_inputs, 395 | initial_state, 396 | attention_states, 397 | cell, 398 | output_size=None, 399 | num_heads=1, 400 | loop_function=None, 401 | dtype=None, 402 | scope=None, 403 | initial_state_attention=False): 404 | """RNN decoder with attention for the sequence-to-sequence model. 405 | 406 | In this context "attention" means that, during decoding, the RNN can look up 407 | information in the additional tensor attention_states, and it does this by 408 | focusing on a few entries from the tensor. This model has proven to yield 409 | especially good results in a number of sequence-to-sequence tasks. 
This 410 | implementation is based on http://arxiv.org/abs/1412.7449 (see below for 411 | details). It is recommended for complex sequence-to-sequence tasks. 412 | 413 | Args: 414 | decoder_inputs: A list of 2D Tensors [batch_size x input_size]. 415 | initial_state: 2D Tensor [batch_size x cell.state_size]. 416 | attention_states: 3D Tensor [batch_size x attn_length x attn_size]. 417 | cell: rnn_cell.RNNCell defining the cell function and size. 418 | output_size: Size of the output vectors; if None, we use cell.output_size. 419 | num_heads: Number of attention heads that read from attention_states. 420 | loop_function: If not None, this function will be applied to i-th output 421 | in order to generate i+1-th input, and decoder_inputs will be ignored, 422 | except for the first element ("GO" symbol). This can be used for decoding, 423 | but also for training to emulate http://arxiv.org/abs/1506.03099. 424 | Signature -- loop_function(prev, i) = next 425 | * prev is a 2D Tensor of shape [batch_size x output_size], 426 | * i is an integer, the step number (when advanced control is needed), 427 | * next is a 2D Tensor of shape [batch_size x input_size]. 428 | dtype: The dtype to use for the RNN initial state (default: tf.float32). 429 | scope: VariableScope for the created subgraph; default: "attention_decoder". 430 | initial_state_attention: If False (default), initial attentions are zero. 431 | If True, initialize the attentions from the initial state and attention 432 | states -- useful when we wish to resume decoding from a previously 433 | stored decoder state and attention states. 434 | 435 | Returns: 436 | A tuple of the form (outputs, state), where: 437 | outputs: A list of the same length as decoder_inputs of 2D Tensors of 438 | shape [batch_size x output_size]. These represent the generated outputs. 439 | Output i is computed from input i (which is either the i-th element 440 | of decoder_inputs or loop_function(output {i-1}, i)) as follows. 441 | First, we run the cell on a combination of the input and previous 442 | attention masks: 443 | cell_output, new_state = cell(linear(input, prev_attn), prev_state). 444 | Then, we calculate new attention masks: 445 | new_attn = softmax(V^T * tanh(W * attention_states + U * new_state)) 446 | and then we calculate the output: 447 | output = linear(cell_output, new_attn). 448 | state: The state of each decoder cell the final time-step. 449 | It is a 2D Tensor of shape [batch_size x cell.state_size]. 450 | 451 | Raises: 452 | ValueError: when num_heads is not positive, there are no inputs, shapes 453 | of attention_states are not set, or input size cannot be inferred 454 | from the input. 455 | """ 456 | if not decoder_inputs: 457 | raise ValueError("Must provide at least 1 input to attention decoder.") 458 | if num_heads < 1: 459 | raise ValueError("With less than 1 heads, use a non-attention decoder.") 460 | if attention_states.get_shape()[2].value is None: 461 | raise ValueError("Shape[2] of attention_states must be known: %s" 462 | % attention_states.get_shape()) 463 | if output_size is None: 464 | output_size = cell.output_size 465 | 466 | with variable_scope.variable_scope( 467 | scope or "attention_decoder", dtype=dtype) as scope: 468 | dtype = scope.dtype 469 | 470 | batch_size = array_ops.shape(decoder_inputs[0])[0] # Needed for reshaping. 
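    # Note: the attention score softmax(V^T * tanh(W * attention_states + U * new_state))
    # is split so that the W term over all encoder positions is precomputed below
    # as a 1x1 convolution on `hidden` ([batch, attn_length, 1, attn_size]); only
    # the query term is recomputed at each decoding step inside attention().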
471 |     attn_length = attention_states.get_shape()[1].value
472 |     if attn_length is None:
473 |       attn_length = array_ops.shape(attention_states)[1]
474 |     attn_size = attention_states.get_shape()[2].value
475 | 
476 |     # To calculate W1 * h_t we use a 1-by-1 convolution, hence we reshape first.
477 |     hidden = array_ops.reshape(
478 |         attention_states, [-1, attn_length, 1, attn_size])
479 |     hidden_features = []
480 |     v = []
481 |     attention_vec_size = attn_size  # Size of query vectors for attention.
482 |     for a in xrange(num_heads):
483 |       k = variable_scope.get_variable("AttnW_%d" % a,
484 |                                       [1, 1, attn_size, attention_vec_size])
485 |       hidden_features.append(nn_ops.conv2d(hidden, k, [1, 1, 1, 1], "SAME"))
486 |       v.append(
487 |           variable_scope.get_variable("AttnV_%d" % a, [attention_vec_size]))
488 | 
489 |     state = initial_state
490 | 
491 |     def attention(query):
492 |       """Put attention masks on hidden using hidden_features and query."""
493 |       ds = []  # Results of attention reads will be stored here.
494 |       if nest.is_sequence(query):  # If the query is a tuple, flatten it.
495 |         query_list = nest.flatten(query)
496 |         for q in query_list:  # Check that ndims == 2 if specified.
497 |           ndims = q.get_shape().ndims
498 |           if ndims:
499 |             assert ndims == 2
500 |         query = array_ops.concat(query_list, 1)
501 |       for a in xrange(num_heads):
502 |         with variable_scope.variable_scope("Attention_%d" % a):
503 |           y = linear(query, attention_vec_size, True)
504 |           y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size])
505 |           # Attention mask is a softmax of v^T * tanh(...).
506 |           s = math_ops.reduce_sum(
507 |               v[a] * math_ops.tanh(hidden_features[a] + y), [2, 3])
508 |           attn_weights = nn_ops.softmax(s)
509 |           # Now calculate the attention-weighted vector d.
510 |           d = math_ops.reduce_sum(
511 |               array_ops.reshape(attn_weights, [-1, attn_length, 1, 1]) * hidden,
512 |               [1, 2])
513 |           ds.append(array_ops.reshape(d, [-1, attn_size]))
514 |       return ds
515 | 
516 |     outputs = []
517 |     prev = None
518 |     batch_attn_size = array_ops.stack([batch_size, attn_size])
519 |     attns = [array_ops.zeros(batch_attn_size, dtype=dtype)
520 |              for _ in xrange(num_heads)]
521 |     for a in attns:  # Ensure the second shape of attention vectors is set.
522 |       a.set_shape([None, attn_size])
523 |     if initial_state_attention:
524 |       attns = attention(initial_state)
525 |     for i, inp in enumerate(decoder_inputs):
526 |       if i > 0:
527 |         variable_scope.get_variable_scope().reuse_variables()
528 |       # If loop_function is set, we use it instead of decoder_inputs.
529 |       if loop_function is not None and prev is not None:
530 |         with variable_scope.variable_scope("loop_function", reuse=True):
531 |           inp = loop_function(prev, i)
532 |       # Merge input and previous attentions into one vector of the right size.
533 |       input_size = inp.get_shape().with_rank(2)[1]
534 |       if input_size.value is None:
535 |         raise ValueError("Could not infer input size from input: %s" % inp.name)
536 |       x = linear([inp] + attns, input_size, True)
537 |       # Run the RNN.
538 |       cell_output, state = cell(x, state)
539 |       # Run the attention mechanism.
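      # When initial_state_attention is True, attention() has already been
      # called on initial_state before this loop, so the first step must reuse
      # those attention variables instead of creating them a second time.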
540 |       if i == 0 and initial_state_attention:
541 |         with variable_scope.variable_scope(variable_scope.get_variable_scope(),
542 |                                            reuse=True):
543 |           attns = attention(state)
544 |       else:
545 |         attns = attention(state)
546 | 
547 |       with variable_scope.variable_scope("AttnOutputProjection"):
548 |         output = linear([cell_output] + attns, output_size, True)
549 |       if loop_function is not None:
550 |         prev = output
551 |       outputs.append(output)
552 | 
553 |   return outputs, state
554 | 
555 | 
556 | def embedding_attention_decoder(decoder_inputs,
557 |                                 initial_state,
558 |                                 attention_states,
559 |                                 cell,
560 |                                 num_symbols,
561 |                                 embedding_size,
562 |                                 num_heads=1,
563 |                                 output_size=None,
564 |                                 output_projection=None,
565 |                                 feed_previous=False,
566 |                                 update_embedding_for_previous=True,
567 |                                 dtype=None,
568 |                                 scope=None,
569 |                                 initial_state_attention=False,
570 |                                 mc_search=False):
571 |   """RNN decoder with embedding and attention and a pure-decoding option.
572 | 
573 |   Args:
574 |     decoder_inputs: A list of 1D batch-sized int32 Tensors (decoder inputs).
575 |     initial_state: 2D Tensor [batch_size x cell.state_size].
576 |     attention_states: 3D Tensor [batch_size x attn_length x attn_size].
577 |     cell: rnn_cell.RNNCell defining the cell function.
578 |     num_symbols: Integer, how many symbols come into the embedding.
579 |     embedding_size: Integer, the length of the embedding vector for each symbol.
580 |     num_heads: Number of attention heads that read from attention_states.
581 |     output_size: Size of the output vectors; if None, cell.output_size is used.
582 |     output_projection: None or a pair (W, B) of output projection weights and
583 |       biases; W has shape [output_size x num_symbols] and B has shape
584 |       [num_symbols]; if provided and feed_previous=True, each fed previous
585 |       output will first be multiplied by W and have B added.
586 |     feed_previous: Boolean; if True, only the first of decoder_inputs will be
587 |       used (the "GO" symbol), and all other decoder inputs will be generated by:
588 |         next = embedding_lookup(embedding, argmax(previous_output)).
589 |       In effect, this implements a greedy decoder. It can also be used
590 |       during training to emulate http://arxiv.org/abs/1506.03099.
591 |       If False, decoder_inputs are used as given (the standard decoder case).
592 |     update_embedding_for_previous: Boolean; if False and feed_previous=True,
593 |       only the embedding for the first symbol of decoder_inputs (the "GO"
594 |       symbol) will be updated by back propagation. Embeddings for the symbols
595 |       generated from the decoder itself remain unchanged. This parameter has
596 |       no effect if feed_previous=False.
597 |     dtype: The dtype to use for the RNN initial states (default: tf.float32).
598 |     scope: VariableScope for the created subgraph; defaults to
599 |       "embedding_attention_decoder".
600 |     initial_state_attention: If False (default), initial attentions are zero.
601 |       If True, initialize the attentions from the initial state and attention
602 |       states -- useful when we wish to resume decoding from a previously
603 |       stored decoder state and attention states.
604 | 
605 |   Returns:
606 |     A tuple of the form (outputs, state), where:
607 |       outputs: A list of the same length as decoder_inputs of 2D Tensors with
608 |         shape [batch_size x output_size] containing the generated outputs.
609 |       state: The state of each decoder cell at the final time-step.
610 |         It is a 2D Tensor of shape [batch_size x cell.state_size].
611 | 
612 |   Raises:
613 |     ValueError: When output_projection has the wrong shape.
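
  Note:
    mc_search: Boolean (or scalar Boolean Tensor); if True, the previous-output
      loop function samples each next input token from the output distribution
      (Monte Carlo search) instead of taking its argmax. This is how candidate
      answers are rolled out for the discriminator during adversarial training.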
614 |   """
615 |   if output_size is None:
616 |     output_size = cell.output_size
617 |   if output_projection is not None:
618 |     proj_biases = ops.convert_to_tensor(output_projection[1], dtype=dtype)
619 |     proj_biases.get_shape().assert_is_compatible_with([num_symbols])
620 | 
621 |   with variable_scope.variable_scope(
622 |       scope or "embedding_attention_decoder", dtype=dtype) as scope:
623 | 
624 |     embedding = variable_scope.get_variable("embedding",
625 |                                             [num_symbols, embedding_size])
626 | 
627 |     loop_function = None
628 |     if feed_previous:
629 |       loop_function = _argmax_or_mcsearch(embedding, output_projection, update_embedding_for_previous, mc_search)
630 |     # if isinstance(mc_search, bool):
631 |     #     if feed_previous == True and mc_search == True:
632 |     #         loop_function = _mc_argmax_and_embed(embedding, output_projection, update_embedding_for_previous)
633 |     #     elif feed_previous == True and mc_search == False:
634 |     #         loop_function = _extract_argmax_and_embed(embedding, output_projection, update_embedding_for_previous)
635 |     # elif (feed_previous == True):
636 |     #     loop_function = control_flow_ops.cond(mc_search,
637 |     #         _mc_argmax_and_embed(embedding, output_projection, update_embedding_for_previous),
638 |     #         _extract_argmax_and_embed(embedding, output_projection, update_embedding_for_previous))
639 | 
640 |     emb_inp = [
641 |         embedding_ops.embedding_lookup(embedding, i) for i in decoder_inputs]
642 |     return attention_decoder(
643 |         emb_inp,
644 |         initial_state,
645 |         attention_states,
646 |         cell,
647 |         output_size=output_size,
648 |         num_heads=num_heads,
649 |         loop_function=loop_function,
650 |         initial_state_attention=initial_state_attention,
651 |         scope=scope)
652 | 
653 | 
654 | def embedding_attention_seq2seq(encoder_inputs,
655 |                                 decoder_inputs,
656 |                                 cell,
657 |                                 num_encoder_symbols,
658 |                                 num_decoder_symbols,
659 |                                 embedding_size,
660 |                                 num_heads=1,
661 |                                 output_projection=None,
662 |                                 feed_previous=False,
663 |                                 dtype=None,
664 |                                 scope=None,
665 |                                 initial_state_attention=False,
666 |                                 mc_search=False):
667 |   """Embedding sequence-to-sequence model with attention; returns (outputs, state, encoder_state)."""
668 |   with variable_scope.variable_scope(
669 |       scope or "embedding_attention_seq2seq", dtype=dtype) as scope:
670 |     dtype = scope.dtype
671 |     # Encoder.
672 |     encoder_cell = core_rnn_cell.EmbeddingWrapper(
673 |         cell, embedding_classes=num_encoder_symbols,
674 |         embedding_size=embedding_size)
675 |     encoder_outputs, encoder_state = rnn.static_rnn(
676 |         encoder_cell, encoder_inputs, dtype=dtype)
677 | 
678 |     # First calculate a concatenation of encoder outputs to put attention on.
679 |     top_states = [array_ops.reshape(e, [-1, 1, cell.output_size])
680 |                   for e in encoder_outputs]
681 |     # attention_states = array_ops.concat(1, top_states)
682 |     attention_states = array_ops.concat(top_states, 1)
683 | 
684 |     # Decoder.
685 |     output_size = None
686 |     if output_projection is None:
687 |       cell = core_rnn_cell.OutputProjectionWrapper(cell, num_decoder_symbols)
688 |       output_size = num_decoder_symbols
689 | 
690 |     if isinstance(feed_previous, bool):
691 |       outputs, state = embedding_attention_decoder(
692 |           decoder_inputs,
693 |           encoder_state,
694 |           attention_states,
695 |           cell,
696 |           num_decoder_symbols,
697 |           embedding_size,
698 |           num_heads=num_heads,
699 |           output_size=output_size,
700 |           output_projection=output_projection,
701 |           feed_previous=feed_previous,
702 |           initial_state_attention=initial_state_attention,
703 |           mc_search=mc_search,
704 |           scope=scope)
705 |       return outputs, state, encoder_state
706 | 
707 |     # If feed_previous is a Tensor, we construct 2 graphs and use cond.
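    # Both branches build the same decoder with shared variables (the second
    # call sets reuse=True), and control_flow_ops.cond picks one at run time.
    # The decoder state is flattened into a list because cond must return a
    # flat list of Tensors; it is re-packed to the encoder state's structure
    # after the cond below.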
708 |     def decoder(feed_previous_bool):
709 |       reuse = None if feed_previous_bool else True
710 |       with variable_scope.variable_scope(
711 |           variable_scope.get_variable_scope(), reuse=reuse) as scope:
712 |         outputs, state = embedding_attention_decoder(
713 |             decoder_inputs,
714 |             encoder_state,
715 |             attention_states,
716 |             cell,
717 |             num_decoder_symbols,
718 |             embedding_size,
719 |             num_heads=num_heads,
720 |             output_size=output_size,
721 |             output_projection=output_projection,
722 |             feed_previous=feed_previous_bool,
723 |             update_embedding_for_previous=False,
724 |             initial_state_attention=initial_state_attention,
725 |             mc_search=mc_search,
726 |             scope=scope)
727 |         state_list = [state]
728 |         if nest.is_sequence(state):
729 |           state_list = nest.flatten(state)
730 |         return outputs + state_list
731 | 
732 |     outputs_and_state = control_flow_ops.cond(feed_previous,
733 |                                               lambda: decoder(True),
734 |                                               lambda: decoder(False))
735 |     outputs_len = len(decoder_inputs)  # Outputs length same as decoder inputs.
736 |     state_list = outputs_and_state[outputs_len:]
737 |     state = state_list[0]
738 |     if nest.is_sequence(encoder_state):
739 |       state = nest.pack_sequence_as(structure=encoder_state,
740 |                                     flat_sequence=state_list)
741 |     return outputs_and_state[:outputs_len], state, encoder_state
742 | 
743 | 
744 | def one2many_rnn_seq2seq(encoder_inputs,
745 |                          decoder_inputs_dict,
746 |                          cell,
747 |                          num_encoder_symbols,
748 |                          num_decoder_symbols_dict,
749 |                          embedding_size,
750 |                          feed_previous=False,
751 |                          dtype=None,
752 |                          scope=None):
753 |   """One-to-many RNN sequence-to-sequence model (multi-task).
754 | 
755 |   This is a multi-task sequence-to-sequence model with one encoder and multiple
756 |   decoders. A reference for multi-task sequence-to-sequence learning can be
757 |   found here: http://arxiv.org/abs/1511.06114
758 | 
759 |   Args:
760 |     encoder_inputs: A list of 1D int32 Tensors of shape [batch_size].
761 |     decoder_inputs_dict: A dictionary mapping decoder name (string) to
762 |       the corresponding decoder_inputs; each decoder_inputs is a list of 1D
763 |       Tensors of shape [batch_size]; num_decoders is defined as
764 |       len(decoder_inputs_dict).
765 |     cell: rnn_cell.RNNCell defining the cell function and size.
766 |     num_encoder_symbols: Integer; number of symbols on the encoder side.
767 |     num_decoder_symbols_dict: A dictionary mapping decoder name (string) to an
768 |       integer specifying the number of symbols for the corresponding decoder;
769 |       len(num_decoder_symbols_dict) must be equal to num_decoders.
770 |     embedding_size: Integer, the length of the embedding vector for each symbol.
771 |     feed_previous: Boolean or scalar Boolean Tensor; if True, only the first of
772 |       decoder_inputs will be used (the "GO" symbol), and all other decoder
773 |       inputs will be taken from previous outputs (as in embedding_rnn_decoder).
774 |       If False, decoder_inputs are used as given (the standard decoder case).
775 |     dtype: The dtype of the initial state for both the encoder and decoder
776 |       RNN cells (default: tf.float32).
777 |     scope: VariableScope for the created subgraph; defaults to
778 |       "one2many_rnn_seq2seq".
779 | 
780 |   Returns:
781 |     A tuple of the form (outputs_dict, state_dict), where:
782 |       outputs_dict: A mapping from decoder name (string) to a list of the same
783 |         length as decoder_inputs_dict[name]; each element in the list is a 2D
784 |         Tensor with shape [batch_size x num_decoder_symbol_list[name]]
785 |         containing the generated outputs.
786 | state_dict: A mapping from decoder name (string) to the final state of the 787 | corresponding decoder RNN; it is a 2D Tensor of shape 788 | [batch_size x cell.state_size]. 789 | """ 790 | outputs_dict = {} 791 | state_dict = {} 792 | 793 | with variable_scope.variable_scope( 794 | scope or "one2many_rnn_seq2seq", dtype=dtype) as scope: 795 | dtype = scope.dtype 796 | 797 | # Encoder. 798 | encoder_cell = core_rnn_cell.EmbeddingWrapper( 799 | cell, embedding_classes=num_encoder_symbols, 800 | embedding_size=embedding_size) 801 | _, encoder_state = rnn.static_rnn(encoder_cell, encoder_inputs, dtype=dtype) 802 | 803 | # Decoder. 804 | for name, decoder_inputs in decoder_inputs_dict.items(): 805 | num_decoder_symbols = num_decoder_symbols_dict[name] 806 | 807 | with variable_scope.variable_scope("one2many_decoder_" + str( 808 | name)) as scope: 809 | decoder_cell = core_rnn_cell.OutputProjectionWrapper(cell, 810 | num_decoder_symbols) 811 | if isinstance(feed_previous, bool): 812 | outputs, state = embedding_rnn_decoder( 813 | decoder_inputs, encoder_state, decoder_cell, num_decoder_symbols, 814 | embedding_size, feed_previous=feed_previous) 815 | else: 816 | # If feed_previous is a Tensor, we construct 2 graphs and use cond. 817 | def filled_embedding_rnn_decoder(feed_previous): 818 | """The current decoder with a fixed feed_previous parameter.""" 819 | # pylint: disable=cell-var-from-loop 820 | reuse = None if feed_previous else True 821 | vs = variable_scope.get_variable_scope() 822 | with variable_scope.variable_scope(vs, reuse=reuse): 823 | outputs, state = embedding_rnn_decoder( 824 | decoder_inputs, encoder_state, decoder_cell, 825 | num_decoder_symbols, embedding_size, 826 | feed_previous=feed_previous) 827 | # pylint: enable=cell-var-from-loop 828 | state_list = [state] 829 | if nest.is_sequence(state): 830 | state_list = nest.flatten(state) 831 | return outputs + state_list 832 | 833 | outputs_and_state = control_flow_ops.cond( 834 | feed_previous, 835 | lambda: filled_embedding_rnn_decoder(True), 836 | lambda: filled_embedding_rnn_decoder(False)) 837 | # Outputs length is the same as for decoder inputs. 838 | outputs_len = len(decoder_inputs) 839 | outputs = outputs_and_state[:outputs_len] 840 | state_list = outputs_and_state[outputs_len:] 841 | state = state_list[0] 842 | if nest.is_sequence(encoder_state): 843 | state = nest.pack_sequence_as(structure=encoder_state, 844 | flat_sequence=state_list) 845 | outputs_dict[name] = outputs 846 | state_dict[name] = state 847 | 848 | return outputs_dict, state_dict 849 | 850 | 851 | # def sequence_loss_by_mle(logits, targets, emb_dim, sequence_length, batch_size, name=None): 852 | # pass 853 | # if len(targets) != len(logits) or len(weights) != len(logits): 854 | # raise ValueError("Lengths of logits, weights, and targets must be the same " 855 | # "%d, %d, %d." 
% (len(logits), len(weights), len(targets))) 856 | # with ops.name_scope(name, "sequence_loss_by_mle", 857 | # logits + targets + weights): 858 | # 859 | # pretrain_loss = -tf.reduce_sum( 860 | # tf.one_hot(tf.to_int32(tf.reshape(targets, [-1])), emb_dim, 1.0, 0.0) * tf.log( 861 | # tf.clip_by_value(tf.reshape(logits, [-1, emb_dim]), 1e-20, 1.0) 862 | # ) 863 | # ) / (sequence_length * batch_size) 864 | # 865 | # 866 | # 867 | # log_perp_list = [] 868 | # for logit, target, weight in zip(logits, targets, weights): 869 | # pass 870 | 871 | # def sequence_loss_by_example(logits, targets, weights, 872 | # average_across_timesteps=True, 873 | # softmax_loss_function=None,up_reward=None, policy_gradient=None, name=None): 874 | # if len(targets) != len(logits) or len(weights) != len(logits): 875 | # raise ValueError("Lengths of logits, weights, and targets must be the same " 876 | # "%d, %d, %d." % (len(logits), len(weights), len(targets))) 877 | # with ops.name_scope(name, "sequence_loss_by_example", 878 | # logits + targets + weights): 879 | # log_perp_list = [] 880 | # for logit, target, weight in zip(logits, targets, weights): 881 | # if softmax_loss_function is None: 882 | # # TODO(irving,ebrevdo): This reshape is needed because 883 | # # sequence_loss_by_example is called with scalars sometimes, which 884 | # # violates our general scalar strictness policy. 885 | # target = array_ops.reshape(target, [-1]) 886 | # crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( 887 | # logit, target) 888 | # else: 889 | # #crossent = softmax_loss_function(logit, target) 890 | # crossent = tf.cond(up_reward, 891 | # lambda :policy_gradient(logit, target), 892 | # lambda :softmax_loss_function(logit,target)) 893 | # log_perp_list.append(crossent * weight) 894 | # log_perps = math_ops.add_n(log_perp_list) 895 | # if average_across_timesteps: 896 | # total_size = math_ops.add_n(weights) 897 | # total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. 898 | # log_perps /= total_size 899 | # return log_perps 900 | 901 | def sequence_loss_by_example(logits, targets, weights, 902 | average_across_timesteps=True, 903 | softmax_loss_function=None, name=None): 904 | if len(targets) != len(logits) or len(weights) != len(logits): 905 | raise ValueError("Lengths of logits, weights, and targets must be the same " 906 | "%d, %d, %d." % (len(logits), len(weights), len(targets))) 907 | with ops.name_scope(name, "sequence_loss_by_example", 908 | logits + targets + weights): 909 | log_perp_list = [] 910 | for logit, target, weight in zip(logits, targets, weights): 911 | if softmax_loss_function is None: 912 | # TODO(irving,ebrevdo): This reshape is needed because 913 | # sequence_loss_by_example is called with scalars sometimes, which 914 | # violates our general scalar strictness policy. 915 | target = array_ops.reshape(target, [-1]) 916 | crossent = nn_ops.sparse_softmax_cross_entropy_with_logits( 917 | logit, target) 918 | else: 919 | crossent = softmax_loss_function(logit, target) 920 | log_perp_list.append(crossent * weight) 921 | log_perps = math_ops.add_n(log_perp_list) 922 | if average_across_timesteps: 923 | total_size = math_ops.add_n(weights) 924 | total_size += 1e-12 # Just to avoid division by 0 for all-0 weights. 
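      # math_ops.add_n(weights) is, per batch element, the (weighted) number of
      # real time steps, so the division below turns the summed cross-entropy
      # into an average loss per target token.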
925 |       log_perps /= total_size
926 |   return log_perps
927 | 
928 | 
929 | def sequence_loss(logits, targets, weights,
930 |                   average_across_timesteps=True, average_across_batch=True,
931 |                   softmax_loss_function=None, name=None):
932 |   """Weighted cross-entropy loss for a sequence of logits, batch-collapsed.
933 | 
934 |   Args:
935 |     logits: List of 2D Tensors of shape [batch_size x num_decoder_symbols].
936 |     targets: List of 1D batch-sized int32 Tensors of the same length as logits.
937 |     weights: List of 1D batch-sized float-Tensors of the same length as logits.
938 |     average_across_timesteps: If set, divide the returned cost by the total
939 |       label weight.
940 |     average_across_batch: If set, divide the returned cost by the batch size.
941 |     softmax_loss_function: Function (inputs-batch, labels-batch) -> loss-batch
942 |       to be used instead of the standard softmax (the default if this is None).
943 |     name: Optional name for this operation, defaults to "sequence_loss".
944 | 
945 |   Returns:
946 |     A scalar float Tensor: The average log-perplexity per symbol (weighted).
947 | 
948 |   Raises:
949 |     ValueError: If len(logits) is different from len(targets) or len(weights).
950 |   """
951 |   with ops.name_scope(name, "sequence_loss", logits + targets + weights):
952 |     cost = math_ops.reduce_sum(sequence_loss_by_example(
953 |         logits, targets, weights,
954 |         average_across_timesteps=average_across_timesteps,
955 |         softmax_loss_function=softmax_loss_function))
956 |     if average_across_batch:
957 |       batch_size = array_ops.shape(targets[0])[0]
958 |       return cost / math_ops.cast(batch_size, cost.dtype)
959 |     else:
960 |       return cost
961 | 
962 | def sequence_loss_by_mle(logits, targets, vocab_size, sequence_length, batch_size, output_projection=None):
963 |   # print("logits: ", np.shape(logits[0]))
964 |   # logits: [seq_len, batch_size, emb_dim]
965 |   # targets: [seq_len, batch_size] =====transpose====> [batch_size, seq_len]
966 |   # labels = tf.to_int32(tf.transpose(targets))
967 |   # targets: [seq_len, batch_size] ====reshape([-1])====> [seq_len * batch_size]
968 |   labels = tf.to_int32(tf.reshape(targets, [-1]))
969 | 
970 |   if output_projection is not None:
971 |     # logits = nn_ops.xw_plus_b(logits, output_projection[0], output_projection[1])
972 |     logits = [tf.matmul(logit, output_projection[0]) + output_projection[1] for logit in logits]
973 | 
974 |   reshape_logits = tf.reshape(logits, [-1, vocab_size])  # [seq_len * batch_size, vocab_size]
975 | 
976 |   prediction = tf.clip_by_value(reshape_logits, 1e-20, 1.0)
977 | 
978 |   pretrain_loss = -tf.reduce_sum(
979 |       # [seq_len * batch_size, vocab_size]
980 |       tf.one_hot(labels, vocab_size, 1.0, 0.0) * tf.log(prediction)
981 |   ) / (sequence_length * batch_size)
982 |   return pretrain_loss
983 | 
984 | 
985 | def model_with_buckets(encoder_inputs, decoder_inputs, targets, weights, buckets, vocab_size, batch_size, seq2seq,
986 |                        output_projection=None, softmax_loss_function=None, per_example_loss=False, name=None):
987 |   if len(encoder_inputs) < buckets[-1][0]:
988 |     raise ValueError("Length of encoder_inputs (%d) must be at least that "
989 |                      "of last bucket (%d)." % (len(encoder_inputs), buckets[-1][0]))
990 |   if len(targets) < buckets[-1][1]:
991 |     raise ValueError("Length of targets (%d) must be at least that of last "
992 |                      "bucket (%d)." % (len(targets), buckets[-1][1]))
993 |   if len(weights) < buckets[-1][1]:
994 |     raise ValueError("Length of weights (%d) must be at least that of last "
995 |                      "bucket (%d)."
% (len(weights), buckets[-1][1])) 996 | 997 | all_inputs = encoder_inputs + decoder_inputs + targets + weights 998 | losses = [] 999 | outputs = [] 1000 | encoder_states = [] 1001 | with ops.name_scope(name, "model_with_buckets", all_inputs): 1002 | for j, bucket in enumerate(buckets): 1003 | with variable_scope.variable_scope(variable_scope.get_variable_scope(), 1004 | reuse=True if j > 0 else None): 1005 | bucket_outputs, decoder_states, encoder_state = seq2seq(encoder_inputs[:bucket[0]], 1006 | decoder_inputs[:bucket[1]]) 1007 | outputs.append(bucket_outputs) 1008 | #print("bucket outputs: %s" %bucket_outputs) 1009 | encoder_states.append(encoder_state) 1010 | if per_example_loss: 1011 | losses.append(sequence_loss_by_example( 1012 | outputs[-1], targets[:bucket[1]], weights[:bucket[1]], 1013 | softmax_loss_function=softmax_loss_function)) 1014 | else: 1015 | # losses.append(sequence_loss_by_mle(outputs[-1], targets[:bucket[1]], vocab_size, bucket[1], batch_size, output_projection)) 1016 | losses.append(sequence_loss(outputs[-1], targets[:bucket[1]], weights[:bucket[1]], softmax_loss_function=softmax_loss_function)) 1017 | 1018 | return outputs, losses, encoder_states -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | import sys 6 | import time 7 | import gen.generator as gens 8 | import disc.hier_disc as h_disc 9 | import random 10 | import utils.conf as conf 11 | import argparse # Command line parsing 12 | from corpus.textdata import TextData 13 | 14 | import connection 15 | 16 | gen_config = conf.gen_config 17 | disc_config = conf.disc_config 18 | evl_config = conf.disc_config 19 | 20 | G_STEPS = 1 21 | D_STEPS = 5 22 | 23 | # text_data = None 24 | 25 | 26 | # pre train discriminator 27 | def disc_pre_train(text_data): 28 | train_set = gens.create_disc_train_set(gen_config, text_data, -1, None, gen_config.disc_data_batch_num) 29 | h_disc.hier_train(disc_config, evl_config, text_data.getVocabularySize(), train_set) 30 | 31 | 32 | # pre train generator 33 | def gen_pre_train(text_data): 34 | gens.train(gen_config, text_data) 35 | 36 | 37 | # test gen model 38 | def gen_test_interactive(text_data): 39 | gens.inference_interactive(text_data) 40 | 41 | 42 | def get_negative_decoder_inputs(sess, gen_model, encoder_inputs, decoder_inputs, 43 | target_weights, bucket_id, mc_search=False): 44 | _, _, out_logits = gen_model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, 45 | forward_only=True, mc_search=mc_search) 46 | seq_tokens = [] 47 | for seq in out_logits: 48 | row_token = [] 49 | for t in seq: 50 | row_token.append(int(np.argmax(t, axis=0))) 51 | seq_tokens.append(row_token) 52 | 53 | seq_tokens_t = [] 54 | for col in range(len(seq_tokens[0])): 55 | seq_tokens_t.append([seq_tokens[row][col] for row in range(len(seq_tokens))]) 56 | 57 | return seq_tokens_t 58 | 59 | 60 | def softmax(x): 61 | prob = np.exp(x) / np.sum(np.exp(x), axis=0) 62 | return prob 63 | 64 | 65 | # discriminator api 66 | def disc_step(sess, bucket_id, disc_model, train_query, train_answer, train_labels, forward_only=False): 67 | feed_dict = {} 68 | 69 | for i in range(len(train_query)): 70 | 71 | feed_dict[disc_model.query[i].name] = train_query[i] 72 | 73 | for i in range(len(train_answer)): 74 | feed_dict[disc_model.answer[i].name] = train_answer[i] 75 | 76 | feed_dict[disc_model.target.name] = 
train_labels
77 | 
78 |     loss = 0.0
79 |     if forward_only:
80 |         fetches = [disc_model.b_logits[bucket_id]]
81 |         logits = sess.run(fetches, feed_dict)
82 |         logits = logits[0]
83 |     else:
84 |         fetches = [disc_model.b_train_op[bucket_id], disc_model.b_loss[bucket_id], disc_model.b_logits[bucket_id]]
85 |         train_op, loss, logits = sess.run(fetches, feed_dict)
86 | 
87 |     # softmax operation
88 |     logits = np.transpose(softmax(np.transpose(logits)))
89 | 
90 |     reward, gen_num = 0.0, 0
91 |     for logit, label in zip(logits, train_labels):
92 |         # label == 0 means the answer was machine generated;
93 |         # logit[0] is the probability of class 0 and logit[1] of class 1.
94 |         # So for a label of 0, logit[1] is the discriminator's estimate that
95 |         # the generated answer is human, which is used as the generator's reward.
96 |         if int(label) == 0:
97 |             reward += logit[1]
98 |             gen_num += 1
99 |     reward = reward / gen_num if gen_num else 0.0  # Guard against batches without generated answers.
100 | 
101 |     return reward, loss
102 | 
103 | 
104 | # Adversarial Learning for Neural Dialogue Generation
105 | def al_train(text_data):
106 |     with tf.Session() as sess:
107 |         train_set = gens.create_train_set(gen_config, text_data)
108 | 
109 |         total_qa_size = 0
110 |         for i, sub_set in enumerate(train_set):
111 |             length = len(sub_set)
112 |             print("Generator train_set_{} len: {}".format(i, length))
113 |             total_qa_size += length
114 |         print("Generator train_set total size is {} QA".format(total_qa_size))
115 | 
116 |         train_bucket_sizes = [len(train_set[b]) for b in range(len(gen_config.buckets))]
117 |         train_total_size = float(sum(train_bucket_sizes))
118 |         train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
119 |                                for i in range(len(train_bucket_sizes))]
120 |         vocab_size = text_data.getVocabularySize()
121 |         disc_model = h_disc.create_model(sess, disc_config, vocab_size, disc_config.name_model)
122 |         gen_model = gens.create_model(sess, gen_config, vocab_size, forward_only=False,
123 |                                       name_scope=gen_config.name_model)
124 | 
125 |         current_step = 0
126 |         step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
127 |         gen_loss_summary = tf.Summary()
128 |         disc_loss_summary = tf.Summary()
129 | 
130 |         gen_writer = tf.summary.FileWriter(gen_config.tensorboard_dir, sess.graph)
131 |         disc_writer = tf.summary.FileWriter(disc_config.tensorboard_dir, sess.graph)
132 | 
133 |         while True:
134 |             current_step += 1
135 |             random_number_01 = np.random.random_sample()
136 |             bucket_id = min([i for i in range(len(train_buckets_scale))
137 |                              if train_buckets_scale[i] > random_number_01])
138 |             start_time = time.time()
139 |             print("==================Update Discriminator: %d==================" % current_step)
140 |             for i in range(D_STEPS):
141 |                 print("=============Discriminator update %d of %d in current step=============" % (i + 1, D_STEPS))
142 | 
143 |                 # 1. Sample (X,Y) from real data and sample ^Y from G(*|X)
144 |                 query_set, answer_set, gen_set = gens.create_disc_train_set(gen_config, text_data, bucket_id,
145 |                                                                             train_set, 1, sess, gen_model)
146 | 
147 |                 b_query, b_answer, b_gen = query_set[bucket_id], answer_set[bucket_id], gen_set[bucket_id]
148 | 
149 |                 train_query, train_answer, train_labels = h_disc.hier_get_batch(disc_config, len(b_query) - 1,
150 |                                                                                 b_query, b_answer, b_gen)
151 |                 train_query = np.transpose(train_query)
152 |                 train_answer = np.transpose(train_answer)
153 | 
154 |                 # 2. Update D using (X,Y) as positive examples and (X,^Y) as negative examples
155 |                 _, disc_step_loss = disc_step(sess, bucket_id, disc_model, train_query, train_answer,
156 |                                               train_labels, forward_only=False)
157 |                 disc_loss += disc_step_loss / (D_STEPS * disc_config.steps_per_checkpoint)
158 |                 if i == D_STEPS - 1:
159 |                     print("disc_step_loss: ", disc_step_loss)
160 | 
161 |             print("==================Update Generator: %d==================" % current_step)
162 |             for j in range(G_STEPS):
163 |                 print("=============Generator update %d of %d in current step=============" % (j + 1, G_STEPS))
164 |                 # 1. Sample (X,Y) from real data
165 |                 encoder_inputs, decoder_inputs, target_weights, \
166 |                     source_inputs, source_outputs = gens.get_batch(gen_config, train_set, bucket_id,
167 |                                                                    gen_config.batch_size, text_data)
168 | 
169 |                 # 2. Sample ^Y from G(*|X) for the generator update
170 |                 decoder_inputs_negative = get_negative_decoder_inputs(sess, gen_model, encoder_inputs,
171 |                                                                       decoder_inputs, target_weights, bucket_id)
172 |                 decoder_inputs_negative = np.transpose(decoder_inputs_negative)
173 | 
174 |                 # 3. Sample ^Y from G(*|X) with Monte Carlo search
175 |                 train_query, train_answer, train_labels = [], [], []
176 |                 for query, answer in zip(source_inputs, source_outputs):
177 |                     train_query.append(query)
178 |                     train_answer.append(answer)
179 |                     train_labels.append(1)
180 |                 for _ in range(gen_config.beam_size):
181 |                     gen_set = get_negative_decoder_inputs(sess, gen_model, encoder_inputs, decoder_inputs,
182 |                                                           target_weights, bucket_id, mc_search=True)
183 |                     for i, output in enumerate(gen_set):
184 |                         train_query.append(train_query[i])  # reuse the i-th real query for the sampled answer
185 |                         train_answer.append(output)
186 |                         train_labels.append(0)
187 | 
188 |                 train_query = np.transpose(train_query)
189 |                 train_answer = np.transpose(train_answer)
190 | 
191 |                 # 4. Compute reward r for (X,^Y) using D (based on the Monte Carlo search)
192 |                 reward, _ = disc_step(sess, bucket_id, disc_model, train_query, train_answer,
193 |                                       train_labels, forward_only=True)
194 |                 batch_reward += reward / gen_config.steps_per_checkpoint
195 |                 print("step_reward: ", reward)
196 | 
197 |                 # 5. Update G on (X,^Y) using reward r
198 |                 gan_adjusted_loss, gen_step_loss, _ = gen_model.step(sess, encoder_inputs, decoder_inputs_negative,
199 |                                                                      target_weights, bucket_id, forward_only=False,
200 |                                                                      reward=reward, up_reward=True, debug=True)
201 |                 gen_loss += gen_step_loss / gen_config.steps_per_checkpoint
202 | 
203 |                 print("gen_step_loss: ", gen_step_loss)
204 |                 print("gen_step_adjusted_loss: ", gan_adjusted_loss)
205 | 
206 |                 # 6. Teacher-Forcing: Update G on (X,Y)
207 |                 t_adjusted_loss, t_step_loss, _ = gen_model.step(sess, encoder_inputs, decoder_inputs,
208 |                                                                  target_weights, bucket_id, forward_only=False)
209 |                 t_loss += t_step_loss / (G_STEPS * gen_config.steps_per_checkpoint)
210 | 
211 |                 print("t_step_loss: ", t_step_loss)
212 |                 print("t_adjusted_loss: ", t_adjusted_loss)
213 | 
214 |             if current_step % gen_config.steps_per_checkpoint == 0:
215 | 
216 |                 step_time += (time.time() - start_time) / gen_config.steps_per_checkpoint
217 | 
218 |                 print("current_steps: %d, step time: %.4f, disc_loss: %.3f, gen_loss: %.3f, t_loss: %.3f, reward: %.3f "
219 |                       % (current_step, step_time, disc_loss, gen_loss, t_loss, batch_reward))
220 | 
221 |                 disc_loss_value = disc_loss_summary.value.add()
222 |                 disc_loss_value.tag = disc_config.name_loss
223 |                 disc_loss_value.simple_value = float(disc_loss)
224 |                 disc_writer.add_summary(disc_loss_summary, int(sess.run(disc_model.global_step)))
225 | 
226 |                 gen_global_steps = sess.run(gen_model.global_step)
227 |                 gen_loss_value = gen_loss_summary.value.add()
228 |                 gen_loss_value.tag = gen_config.name_loss
229 |                 gen_loss_value.simple_value = float(gen_loss)
230 |                 t_loss_value = gen_loss_summary.value.add()
231 |                 t_loss_value.tag = gen_config.teacher_loss
232 |                 t_loss_value.simple_value = float(t_loss)
233 |                 batch_reward_value = gen_loss_summary.value.add()
234 |                 batch_reward_value.tag = gen_config.reward_name
235 |                 batch_reward_value.simple_value = float(batch_reward)
236 |                 gen_writer.add_summary(gen_loss_summary, int(gen_global_steps))
237 | 
238 |                 if current_step % (gen_config.steps_per_checkpoint * 4) == 0:
239 |                     print("current_steps: %d, save disc model" % current_step)
240 |                     disc_ckpt_dir = os.path.abspath(os.path.join(disc_config.train_dir, "checkpoints"))
241 |                     if not os.path.exists(disc_ckpt_dir):
242 |                         os.makedirs(disc_ckpt_dir)
243 |                     disc_model_path = os.path.join(disc_ckpt_dir, "disc.model")
244 |                     disc_model.saver.save(sess, disc_model_path, global_step=disc_model.global_step)
245 | 
246 |                     print("current_steps: %d, save gen model" % current_step)
247 |                     gen_ckpt_dir = os.path.abspath(os.path.join(gen_config.train_dir, "checkpoints"))
248 |                     if not os.path.exists(gen_ckpt_dir):
249 |                         os.makedirs(gen_ckpt_dir)
250 |                     gen_model_path = os.path.join(gen_ckpt_dir, "gen.model")
251 |                     gen_model.saver.save(sess, gen_model_path, global_step=gen_model.global_step)
252 | 
253 |                 step_time, disc_loss, gen_loss, t_loss, batch_reward = 0.0, 0.0, 0.0, 0.0, 0.0
254 |                 sys.stdout.flush()
255 | 
256 | 
257 | def parse_args():
258 |     """
259 |     Parse the command-line arguments.
260 | 
261 |     The arguments are read from sys.argv via argparse and returned as a namespace.
262 |     """
263 | 
264 |     parser = argparse.ArgumentParser()
265 |     parser.add_argument('--test', type=int, default=0, help='Test mode')
266 |     parser.add_argument('--maxLength', type=int, default=40,
267 |                         help='Maximum length of the sentence (for input and output); defines the maximum number of steps of the RNN')
268 |     parser.add_argument('--filterVocab', type=int, default=1,
269 |                         help='Remove rarely used words (by default, words used only once).
0 to keep all words.') 270 | parser.add_argument('--vocabularySize', type=int, default=40000, 271 | help='Limit the number of words in the vocabulary (0 for unlimited)') 272 | parser.add_argument('--corpus', choices=TextData.corpusChoices(), default=TextData.corpusChoices()[0], 273 | help='Corpus on which extract the dataset.') 274 | parser.add_argument('--rootDir', type=str, default='corpus', help='Folder where to look for the models and data') 275 | parser.add_argument('--datasetTag', type=str, default='', 276 | help='Add a tag to the dataset (file where to load the vocabulary and the precomputed samples, not the original corpus). Useful to manage multiple versions. Also used to define the file used for the lightweight format.') # The samples are computed from the corpus if it does not exist already. There are saved in \'data/samples/\' 277 | parser.add_argument('--skipLines', action='store_true', default=True, 278 | help='Generate training samples by only using even conversation lines as questions (and odd lines as answer). Useful to train the network on a particular person.') 279 | args = parser.parse_args() 280 | return args 281 | 282 | 283 | def main(): 284 | # global text_data 285 | args = parse_args() 286 | text_data = TextData(args) 287 | try: 288 | if args.test: 289 | gen_test_interactive(text_data) 290 | else: 291 | # Step 1: Pre train the Generator and get the GEN_0 model 292 | gen_pre_train(text_data) 293 | 294 | # Step 2: GEN model test 295 | # gen_test_interactive(text_data) 296 | 297 | # Step 3: Pre train the Discriminator and get the DISC_0 model 298 | # disc_pre_train(text_data) 299 | 300 | # Step 4: Train the GEN model and DISC model using AL/RL 301 | # al_train(text_data) 302 | 303 | # Step 5: GEN model test 304 | # gen_test_interactive(text_data) 305 | 306 | # integration test 307 | # connection.start_server(text_data, True) 308 | except KeyboardInterrupt: 309 | pass 310 | 311 | 312 | if __name__ == "__main__": 313 | main() 314 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geekberu/chatbot_al/b732a587fca68867d63c020f10997d4f1d7087d8/utils/__init__.py -------------------------------------------------------------------------------- /utils/conf.py: -------------------------------------------------------------------------------- 1 | __author__ = 'liuyuemaicha' 2 | import os 3 | 4 | 5 | class disc_config(object): 6 | # batch_size = 256 7 | batch_size = 16 8 | lr = 0.001 9 | lr_decay = 0.9 10 | embed_dim = 512 11 | steps_per_checkpoint = 100 12 | #hidden_neural_size = 128 13 | num_layers = 2 14 | train_dir = './disc_data/' 15 | name_model = "disc_model" 16 | tensorboard_dir = "./tensorboard/disc_log/" 17 | name_loss = "disc_loss" 18 | max_len = 50 19 | piece_size = batch_size * steps_per_checkpoint 20 | piece_dir = "./disc_data/batch_piece/" 21 | #query_len = 0 22 | valid_num = 100 23 | init_scale = 0.1 24 | num_class = 2 25 | keep_prob = 0.5 26 | #num_epoch = 60 27 | #max_decay_epoch = 30 28 | max_grad_norm = 5 29 | buckets = [(5, 10), (10, 15), (20, 25), (40, 50)] 30 | epoch_num = 100 31 | 32 | 33 | class gen_config(object): 34 | # batch_size = 128 35 | batch_size = 8 36 | beam_size = 7 37 | learning_rate = 0.001 38 | learning_rate_decay_factor = 0.99 39 | max_gradient_norm = 5.0 40 | disc_data_batch_num = 100 41 | emb_dim = 512 42 | num_layers = 2 43 | train_dir = "./gen_data/" 44 | name_model = 
"gen_model" 45 | tensorboard_dir = "./tensorboard/gen_log/" 46 | name_loss = "gen_loss" 47 | teacher_loss = "teacher_loss" 48 | reward_name = "reward" 49 | max_train_data_size = 0 50 | steps_per_checkpoint = 100 51 | # bucket->(source_size, target_size), source is the query, target is the answer 52 | buckets = [(5, 10), (10, 15), (20, 25), (40, 50)] 53 | buckets_concat = [(5, 10), (10, 15), (20, 25), (40, 50), (100, 50)] 54 | 55 | 56 | 57 | --------------------------------------------------------------------------------