├── LICENSE
├── ReadMe
├── camrest
│   ├── data_handler.py
│   ├── evaluate.py
│   ├── model.py
│   ├── single.json
│   ├── test.py
│   ├── train.py
│   └── vocab.json
├── data
│   ├── CamRest
│   │   ├── test.json
│   │   ├── train.json
│   │   └── val.json
│   ├── InCar
│   │   ├── test.json
│   │   ├── train.json
│   │   └── val.json
│   └── Maluuba
│       ├── entities.json
│       ├── test.json
│       ├── train.json
│       └── val.json
├── incar
│   ├── data_handler.py
│   ├── evaluate.py
│   ├── model.py
│   ├── test.py
│   ├── train.py
│   └── vocab.json
└── maluuba
    ├── data_handler.py
    ├── evaluate.py
    ├── model.py
    ├── test.py
    ├── train.py
    └── vocab.json

/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /ReadMe: -------------------------------------------------------------------------------- 1 | This folder contains the data used in experiments for the paper. 
2 | 3 | It contains processed data for the following publicly available datasets: 4 | 5 | 1) Incar : https://nlp.stanford.edu/blog/a-new-multi-turn-multi-domain-task-oriented-dialogue-dataset/ 6 | 7 | 2) CamRest : https://www.repository.cam.ac.uk/handle/1810/260970 8 | 9 | 3) Maluuba Frames : https://datasets.maluuba.com/Frames 10 | 11 | Each sub-folder contains the train, validation and test split of the dataset. 12 | Each data sample contains the following: 13 | 14 | 1) Dialog context 15 | 2) Gold response 16 | 3) KB queries (if any) 17 | 4) KB results corresponding to the queries. 18 | 19 | Before running the code, you will need to download the Glove vectors into the "data" folder. 20 | Download the 200 dimensional vectors and name the file: glove.6B.200d.txt -------------------------------------------------------------------------------- /camrest/data_handler.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import random 4 | import nltk 5 | import os 6 | import sys 7 | import numpy as np 8 | import logging 9 | logging.getLogger().setLevel(logging.INFO) 10 | 11 | single_word_entities = json.load(open("./single.json")) 12 | 13 | class DataHandler(object): 14 | 15 | def __init__(self,emb_dim,batch_size,train_path,val_path,test_path,vocab_path,glove_path): 16 | 17 | self.batch_size = batch_size 18 | self.train_path = train_path 19 | self.vocab_threshold = 3 20 | self.val_path = val_path 21 | self.test_path = test_path 22 | self.vocab_path = vocab_path 23 | self.emb_dim = emb_dim 24 | self.glove_path = glove_path 25 | 26 | self.vocab = self.load_vocab() 27 | self.input_vocab_size = self.vocab['input_vocab_size'] 28 | self.output_vocab_size = self.vocab['output_vocab_size'] 29 | self.generate_vocab_size = self.vocab['generate_vocab_size'] 30 | self.emb_init = self.load_glove_vectors() 31 | 32 | self.train_data = json.load(open(self.train_path)) 33 | self.val_data = json.load(open(self.val_path)) 34 | self.test_data = json.load(open(self.test_path)) 35 | 36 | random.shuffle(self.train_data) 37 | random.shuffle(self.val_data) 38 | random.shuffle(self.test_data) 39 | 40 | self.val_data_full = self.append_dummy_data(self.val_data) 41 | 42 | self.train_index = 0 43 | self.val_index = 0 44 | self.train_num = len(self.train_data) 45 | self.val_num = len(self.val_data_full) 46 | 47 | def append_dummy_data(self,data): 48 | new_data = [] 49 | for i in range(0,len(data)): 50 | data[i]['dummy'] = 0 51 | new_data.append(copy.copy(data[i])) 52 | 53 | last = data[-1] 54 | last['dummy'] = 1 55 | for _ in range(0,self.batch_size - len(data)%self.batch_size): 56 | new_data.append(copy.copy(last)) 57 | 58 | return copy.copy(new_data) 59 | 60 | 61 | def load_glove_vectors(self): 62 | logging.info("Loading pre-trained Word Embeddings") 63 | filename = self.glove_path + "glove.6B.200d.txt" 64 | glove = {} 65 | file = open(filename,'r') 66 | for line in file.readlines(): 67 | row = line.strip().split(' ') 68 | glove[row[0]] = np.asarray(row[1:]) 69 | logging.info('Loaded GloVe!') 70 | file.close() 71 | embeddings_init = np.random.normal(size=(self.vocab['input_vocab_size'],self.emb_dim)).astype('f') 72 | count = 0 73 | for word in self.vocab['vocab_mapping']: 74 | if word in glove: 75 | count = count + 1 76 | embeddings_init[self.vocab['vocab_mapping'][word]] = glove[word] 77 | 78 | del glove 79 | 80 | logging.info("Loaded "+str(count)+" pre-trained Word Embeddings") 81 | return embeddings_init 82 | 83 | 84 | def load_vocab(self): 85 | if 
os.path.isfile(self.vocab_path): 86 | logging.info("Loading vocab from file") 87 | with open(self.vocab_path) as f: 88 | return json.load(f) 89 | else: 90 | logging.info("Vocab file not found. Computing Vocab") 91 | with open(self.train_path) as f: 92 | train_data = json.load(f) 93 | with open(self.val_path) as f: 94 | val_data = json.load(f) 95 | with open(self.test_path) as f: 96 | test_data = json.load(f) 97 | 98 | full_data = [] 99 | full_data.extend(train_data) 100 | full_data.extend(val_data) 101 | full_data.extend(test_data) 102 | 103 | return self.get_vocab(full_data) 104 | 105 | def get_vocab(self,data): 106 | 107 | vocab = {} 108 | for d in data: 109 | utts = [] 110 | utts.append(d['output']) 111 | utts.extend(d['context']) 112 | for utt in utts: 113 | tokens = utt.split(" ") 114 | for token in tokens: 115 | if token.lower() not in vocab: 116 | vocab[token.lower()] = 1 117 | else: 118 | vocab[token.lower()] = vocab[token.lower()] + 1 119 | 120 | for item in d['kb']: 121 | for key in item: 122 | if key.lower() not in vocab: 123 | vocab[key.lower()] = 1 124 | else: 125 | vocab[key.lower()] = vocab[key.lower()] + 1 126 | token = item[key] 127 | if token.lower() not in vocab: 128 | vocab[token.lower()] = 1 129 | else: 130 | vocab[token.lower()] = vocab[token.lower()] + 1 131 | 132 | words = vocab.keys() 133 | words.append("$STOP$") 134 | words.append("$PAD$") 135 | 136 | for i in range(1,9): 137 | words.append("$u"+str(i)+"$") 138 | words.append("$s"+str(i)+"$") 139 | words.append("$u9$") 140 | 141 | generate_words = [] 142 | copy_words = [] 143 | for word in words: 144 | if word in single_word_entities or '_' in word: 145 | if word != 'api_call': 146 | copy_words.append(word) 147 | else: 148 | generate_words.append(word) 149 | else: 150 | generate_words.append(word) 151 | 152 | output_vocab_size = len(words) + 1 153 | 154 | generate_indices = [i for i in range(1,len(generate_words)+1)] 155 | copy_indices = [i for i in range(len(generate_words)+1,output_vocab_size)] 156 | random.shuffle(generate_indices) 157 | random.shuffle(copy_indices) 158 | 159 | mapping = {} 160 | rev_mapping = {} 161 | 162 | for i in range(0,len(generate_words)): 163 | mapping[generate_words[i]] = generate_indices[i] 164 | rev_mapping[str(generate_indices[i])] = generate_words[i] 165 | 166 | for i in range(0,len(copy_words)): 167 | mapping[copy_words[i]] = copy_indices[i] 168 | rev_mapping[str(copy_indices[i])] = copy_words[i] 169 | 170 | mapping["$GO$"] = 0 171 | rev_mapping[0] = "$GO$" 172 | vocab_dict = {} 173 | vocab_dict['vocab_mapping'] = mapping 174 | vocab_dict['rev_mapping'] = rev_mapping 175 | vocab_dict['input_vocab_size'] = len(words) + 1 176 | vocab_dict['generate_vocab_size'] = len(generate_words) + 1 177 | vocab_dict['output_vocab_size'] = output_vocab_size 178 | 179 | with open(self.vocab_path,'w') as f: 180 | json.dump(vocab_dict,f) 181 | 182 | logging.info("Vocab file created") 183 | 184 | return vocab_dict 185 | 186 | def get_sentinel(self,i,context): 187 | if i%2 == 0: 188 | speaker = "u" 189 | turn = (context - i + 1)/2 190 | else: 191 | speaker = "s" 192 | turn = (context - i)/2 193 | return "$"+speaker+str(turn)+"$" 194 | 195 | def vectorize(self,batch,train): 196 | vectorized = {} 197 | vectorized['inp_utt'] = [] 198 | vectorized['out_utt'] = [] 199 | vectorized['inp_len'] = [] 200 | vectorized['context_len'] = [] 201 | vectorized['out_len'] = [] 202 | vectorized['kb'] = [] 203 | vectorized['kb_mask'] = [] 204 | vectorized['keys'] = [] 205 | vectorized['keys_mask'] = [] 206 | 207 | 
vectorized['dummy'] = [] 208 | vectorized['empty'] = [] 209 | 210 | vectorized['context'] = [] 211 | vectorized['knowledge'] = [] 212 | max_inp_utt_len = 0 213 | max_out_utt_len = 0 214 | max_context_len = 0 215 | kb_len = 0 216 | keys_len = 7 217 | 218 | for item in batch: 219 | 220 | if len(item['context']) > max_context_len: 221 | max_context_len = len(item['context']) 222 | 223 | for utt in item['context']: 224 | tokens = utt.split(" ") 225 | 226 | if len(tokens) > max_inp_utt_len: 227 | max_inp_utt_len = len(tokens) 228 | 229 | tokens = item['output'].split(" ") 230 | if len(tokens) > max_out_utt_len: 231 | max_out_utt_len = len(tokens) 232 | 233 | if len(item['kb']) > kb_len: 234 | kb_len = len(item['kb']) 235 | 236 | max_inp_utt_len = max_inp_utt_len + 1 237 | 238 | max_out_utt_len = max_out_utt_len + 1 239 | vectorized['max_out_utt_len'] = max_out_utt_len 240 | 241 | 242 | for item in batch: 243 | 244 | vectorized['context'].append(item['context']) 245 | vectorized['knowledge'].append(item['kb']) 246 | 247 | if item['kb'] == []: 248 | vectorized['empty'].append(0) 249 | else: 250 | vectorized['empty'].append(1) 251 | if not train: 252 | vectorized['dummy'].append(item['dummy']) 253 | vector_inp = [] 254 | vector_len = [] 255 | for i in range(0,len(item['context'])): 256 | utt = item['context'][i] 257 | inp = [] 258 | sentinel = self.get_sentinel(i,len(item['context'])) 259 | tokens = utt.split(" ") + [sentinel] 260 | for token in tokens: 261 | inp.append(self.vocab['vocab_mapping'][token]) 262 | 263 | vector_len.append(len(tokens)) 264 | for _ in range(0,max_inp_utt_len - len(tokens)): 265 | inp.append(self.vocab['vocab_mapping']["$PAD$"]) 266 | vector_inp.append(copy.copy(inp)) 267 | 268 | vectorized['context_len'].append(len(item['context'])) 269 | 270 | for _ in range(0,max_context_len - len(item['context'])): 271 | vector_len.append(0) 272 | inp = [] 273 | for _ in range(0,max_inp_utt_len): 274 | inp.append(self.vocab['vocab_mapping']["$PAD$"]) 275 | vector_inp.append(copy.copy(inp)) 276 | 277 | vectorized['inp_utt'].append(copy.copy(vector_inp)) 278 | vectorized['inp_len'].append(vector_len) 279 | 280 | vector_out = [] 281 | tokens = item['output'].split(" ") 282 | tokens.append('$STOP$') 283 | for token in tokens: 284 | vector_out.append(self.vocab['vocab_mapping'][token]) 285 | 286 | for _ in range(0,max_out_utt_len - len(tokens)): 287 | vector_out.append(self.vocab['vocab_mapping']["$PAD$"]) 288 | vectorized['out_utt'].append(copy.copy(vector_out)) 289 | vectorized['out_len'].append(len(tokens)) 290 | 291 | vector_keys = [] 292 | vector_keys_mask = [] 293 | vector_kb = [] 294 | vector_kb_mask = [] 295 | 296 | for result in item['kb']: 297 | vector_result = [] 298 | vector_result_keys = [] 299 | vector_result_keys_mask = [] 300 | vector_kb_mask.append(1) 301 | for key in result: 302 | vector_result.append(self.vocab['vocab_mapping'][result[key]]) 303 | vector_result_keys.append(self.vocab['vocab_mapping'][key]) 304 | vector_result_keys_mask.append(1) 305 | 306 | for _ in range(0,keys_len-len(result.keys())): 307 | vector_result_keys.append(self.vocab['vocab_mapping']["$PAD$"]) 308 | vector_result_keys_mask.append(0) 309 | vector_result.append(self.vocab['vocab_mapping']["$PAD$"]) 310 | vector_keys.append(copy.copy(vector_result_keys)) 311 | vector_keys_mask.append(copy.copy(vector_result_keys_mask)) 312 | vector_kb.append(copy.copy(vector_result)) 313 | 314 | if item['kb'] == []: 315 | vector_kb_mask.append(1) 316 | vector_kb.append([self.vocab['vocab_mapping']["$PAD$"] for 
_ in range(0,keys_len)]) 317 | vector_keys.append([self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,keys_len)]) 318 | vector_keys_mask.append([1] + [0 for _ in range(0,keys_len-1)]) 319 | 320 | current_kb_len = len(vector_kb_mask) 321 | 322 | for _ in range(0,kb_len - current_kb_len): 323 | vector_kb_mask.append(0) 324 | vector_kb.append([self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,keys_len)]) 325 | vector_keys.append([self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,keys_len)]) 326 | vector_keys_mask.append([1] + [0 for _ in range(0,keys_len-1)]) 327 | 328 | vectorized['kb'].append(copy.copy(vector_kb)) 329 | vectorized['kb_mask'].append(copy.copy(vector_kb_mask)) 330 | vectorized['keys'].append(copy.copy(vector_keys)) 331 | vectorized['keys_mask'].append(copy.copy(vector_keys_mask)) 332 | 333 | return vectorized 334 | 335 | def get_batch(self,train): 336 | 337 | epoch_done = False 338 | 339 | if train: 340 | index = self.train_index 341 | batch = self.vectorize(self.train_data[index:index+self.batch_size],train) 342 | self.train_index = self.train_index + self.batch_size 343 | 344 | if self.train_index + self.batch_size > self.train_num: 345 | self.train_index = 0 346 | random.shuffle(self.train_data) 347 | epoch_done = True 348 | 349 | else: 350 | index = self.val_index 351 | batch = self.vectorize(self.val_data_full[index:index+self.batch_size],train) 352 | self.val_index = self.val_index + self.batch_size 353 | 354 | if self.val_index + self.batch_size > self.val_num: 355 | self.val_index = 0 356 | random.shuffle(self.val_data) 357 | self.val_data_full = self.append_dummy_data(self.val_data) 358 | epoch_done = True 359 | 360 | 361 | return batch,epoch_done 362 | -------------------------------------------------------------------------------- /camrest/evaluate.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import copy 3 | import json 4 | import csv 5 | from collections import Counter 6 | from nltk.util import ngrams 7 | from nltk.corpus import stopwords 8 | from nltk.tokenize import word_tokenize 9 | from nltk.stem import WordNetLemmatizer 10 | import math, re, argparse 11 | import functools 12 | 13 | entities = json.load(open('./single.json')) 14 | 15 | def score(parallel_corpus): 16 | 17 | # containers 18 | count = [0, 0, 0, 0] 19 | clip_count = [0, 0, 0, 0] 20 | r = 0 21 | c = 0 22 | weights = [0.25, 0.25, 0.25, 0.25] 23 | 24 | # accumulate ngram statistics 25 | for hyps, refs in parallel_corpus: 26 | hyps = [hyp.split() for hyp in hyps] 27 | refs = [ref.split() for ref in refs] 28 | for hyp in hyps: 29 | 30 | for i in range(4): 31 | # accumulate ngram counts 32 | hypcnts = Counter(ngrams(hyp, i + 1)) 33 | cnt = sum(hypcnts.values()) 34 | count[i] += cnt 35 | 36 | # compute clipped counts 37 | max_counts = {} 38 | for ref in refs: 39 | refcnts = Counter(ngrams(ref, i + 1)) 40 | for ng in hypcnts: 41 | max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) 42 | clipcnt = dict((ng, min(count, max_counts[ng])) \ 43 | for ng, count in hypcnts.items()) 44 | clip_count[i] += sum(clipcnt.values()) 45 | 46 | # accumulate r & c 47 | bestmatch = [1000, 1000] 48 | for ref in refs: 49 | if bestmatch[0] == 0: break 50 | diff = abs(len(ref) - len(hyp)) 51 | if diff < bestmatch[0]: 52 | bestmatch[0] = diff 53 | bestmatch[1] = len(ref) 54 | r += bestmatch[1] 55 | c += len(hyp) 56 | 57 | # computing bleu score 58 | p0 = 1e-7 59 | bp = 1 if c > r else math.exp(1 - float(r) / float(c)) 60 | p_ns = 
[float(clip_count[i]) / float(count[i] + p0) + p0 \ 61 | for i in range(4)] 62 | s = math.fsum(w * math.log(p_n) \ 63 | for w, p_n in zip(weights, p_ns) if p_n) 64 | bleu = bp * math.exp(s) 65 | return bleu 66 | 67 | 68 | data = pickle.load(open("needed.p")) 69 | vocab = json.load(open("./vocab.json")) 70 | outs = [] 71 | golds = [] 72 | 73 | tp_prec = 0.0 74 | tp_recall = 0.0 75 | total_prec = 0.0 76 | total_recall = 0.0 77 | 78 | for i in range(0,len(data['sentences'])): 79 | sentence = data['sentences'][i] 80 | sentence = list(sentence) 81 | if vocab['vocab_mapping']['$STOP$'] not in sentence: 82 | index = len(sentence) 83 | else: 84 | index = sentence.index(vocab['vocab_mapping']['$STOP$']) 85 | predicted = [str(sentence[j]) for j in range(0,index)] 86 | ground = data['output'][i] 87 | ground = list(ground) 88 | index = ground.index(vocab['vocab_mapping']['$STOP$']) 89 | ground_truth = [str(ground[j]) for j in range(0,index)] 90 | 91 | gold_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in ground_truth ] 92 | out_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in predicted ] 93 | 94 | for word in out_anon: 95 | if word in entities or '_' in word: 96 | if word != 'api_call': 97 | total_prec = total_prec + 1 98 | if word in gold_anon: 99 | tp_prec = tp_prec + 1 100 | 101 | for word in gold_anon: 102 | if word in entities or '_' in word: 103 | if word != 'api_call': 104 | total_recall = total_recall + 1 105 | if word in out_anon: 106 | tp_recall = tp_recall + 1 107 | 108 | gold = gold_anon 109 | out = out_anon 110 | golds.append(" ".join(gold)) 111 | outs.append(" ".join(out)) 112 | 113 | wrap_generated = [[_] for _ in outs] 114 | wrap_truth = [[_] for _ in golds] 115 | prec = tp_prec/total_prec 116 | recall = tp_recall/total_recall 117 | print "Bleu: %.3f, Prec: %.3f, Recall: %.3f, F1: %.3f" % (score(zip(wrap_generated, wrap_truth)),prec,recall,2*prec*recall/(prec+recall)) 118 | -------------------------------------------------------------------------------- /camrest/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.ops import embedding_ops, array_ops, math_ops, tensor_array_ops, control_flow_ops 4 | 5 | class DialogueModel(object): 6 | 7 | def __init__(self,device,batch_size,inp_vocab_size,out_vocab_size,generate_size,emb_init,emb_dim,enc_hid_dim,dec_hid_dim,attn_size): 8 | 9 | self.device = device 10 | self.batch_size = batch_size 11 | self.emb_dim = emb_dim 12 | self.inp_vocab_size = inp_vocab_size 13 | self.out_vocab_size = out_vocab_size 14 | self.generate_size = generate_size 15 | self.emb_init = emb_init 16 | self.enc_hid_dim = enc_hid_dim 17 | self.dec_hid_dim = dec_hid_dim 18 | self.attn_size = attn_size 19 | self.generate_size = generate_size 20 | 21 | self.inp_utt = tf.placeholder( 22 | name='inp_utt', dtype=tf.int64, 23 | shape=[self.batch_size, None, None], 24 | ) 25 | 26 | self.inp_len = tf.placeholder( 27 | name='inp_len', dtype=tf.int64, 28 | shape=[self.batch_size, None], 29 | ) 30 | 31 | self.context_len = tf.placeholder( 32 | name='context_len', dtype=tf.int64, 33 | shape=[self.batch_size], 34 | ) 35 | 36 | self.out_utt = tf.placeholder( 37 | name='out_utt', dtype=tf.int64, 38 | shape=[self.batch_size, None], 39 | ) 40 | 41 | self.out_len = tf.placeholder( 42 | name='out_len', dtype=tf.float32, 43 | shape=[self.batch_size], 44 | ) 45 | 46 | self.kb = tf.placeholder( 47 | name='kb', dtype=tf.int64, 48 | 
shape=[self.batch_size,None,7], 49 | ) 50 | 51 | self.kb_mask = tf.placeholder( 52 | name='kb_mask', dtype=tf.float32, 53 | shape=[self.batch_size,None], 54 | ) 55 | 56 | self.keys = tf.placeholder( 57 | name='keys', dtype=tf.int64, 58 | shape=[self.batch_size,None,7], 59 | ) 60 | 61 | self.keys_mask = tf.placeholder( 62 | name='keys_mask', dtype=tf.float32, 63 | shape=[self.batch_size,None,7], 64 | ) 65 | 66 | self.max_out_utt_len = tf.placeholder( 67 | name = 'max_out_utt_len', dtype=tf.int32, 68 | shape = (), 69 | ) 70 | 71 | self.db_empty = tf.placeholder( 72 | name='db_empty', dtype=tf.float32, 73 | shape=[self.batch_size], 74 | ) 75 | 76 | self.buildArch() 77 | 78 | def buildArch(self): 79 | 80 | with tf.device(self.device): 81 | 82 | self.embeddings = tf.get_variable("embeddings", initializer=tf.constant(self.emb_init)) 83 | self.inp_utt_emb = embedding_ops.embedding_lookup(self.embeddings, self.inp_utt) 84 | 85 | with tf.variable_scope("encoder"): 86 | self.encoder_cell_1 = tf.contrib.rnn.GRUCell(self.enc_hid_dim) 87 | self.encoder_cell_2 = tf.contrib.rnn.GRUCell(2*self.enc_hid_dim) 88 | self.flat_inp_emb = tf.reshape(self.inp_utt_emb,shape=[-1,tf.shape(self.inp_utt)[2],self.emb_dim]) 89 | self.flat_inp_len = tf.reshape(self.inp_len,shape=[-1]) 90 | 91 | outputs,output_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=self.encoder_cell_1,cell_bw=self.encoder_cell_1,inputs=self.flat_inp_emb,dtype=tf.float32,sequence_length=self.flat_inp_len,time_major=False) 92 | self.flat_encoder_states = tf.concat(outputs,axis=2) 93 | self.utt_reps = tf.concat(output_states,axis=1) 94 | 95 | self.utt_rep_second = tf.reshape(self.utt_reps,shape=[self.batch_size,-1,2*self.enc_hid_dim]) 96 | self.hidden_states, self.inp_utt_rep = tf.nn.dynamic_rnn(self.encoder_cell_2,self.utt_rep_second,dtype=tf.float32,sequence_length=self.context_len,time_major=False) 97 | self.encoder_states = tf.reshape(tf.reshape(self.flat_encoder_states,shape=[self.batch_size,-1,tf.shape(self.inp_utt)[2],2*self.enc_hid_dim]), shape=[self.batch_size,-1,2*self.enc_hid_dim]) 98 | 99 | 100 | self.kb_emb = embedding_ops.embedding_lookup(self.embeddings, self.kb) 101 | self.keys_emb = embedding_ops.embedding_lookup(self.embeddings, self.keys) 102 | self.result_rep = tf.einsum('ij,ijk->ijk',tf.pow(tf.reduce_sum(self.keys_mask,2),-1),tf.reduce_sum(tf.einsum('ijk,ijkl->ijkl',self.keys_mask,self.kb_emb),2)) 103 | 104 | self.start_token = tf.constant([0] * self.batch_size, dtype=tf.int32) 105 | self.out_utt_emb = embedding_ops.embedding_lookup(self.embeddings, self.out_utt) 106 | self.processed_x = tf.transpose(self.out_utt_emb,perm=[1,0,2]) 107 | 108 | with tf.variable_scope("decoder"): 109 | self.decoder_cell = tf.contrib.rnn.GRUCell(self.dec_hid_dim) 110 | 111 | self.h0 = self.inp_utt_rep 112 | self.g_output_unit = self.create_output_unit() 113 | 114 | gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.max_out_utt_len, 115 | dynamic_size=False, infer_shape=True) 116 | 117 | def _g_recurrence(i, x_t, h_tm1, gen_x): 118 | _,h_t = self.decoder_cell(x_t,h_tm1) 119 | o_t = self.g_output_unit(h_t) # batch x vocab , prob 120 | next_token = tf.cast(tf.reshape(tf.argmax(o_t, 1), [self.batch_size]), tf.int32) 121 | x_tp1 = embedding_ops.embedding_lookup(self.embeddings,next_token) # batch x emb_dim 122 | gen_x = gen_x.write(i, next_token) # indices, batch_size 123 | return i + 1, x_tp1, h_t, gen_x 124 | 125 | _, _, _, self.gen_x = control_flow_ops.while_loop( 126 | cond=lambda i, _1, _2, _3: i < self.max_out_utt_len, 127 | 
body=_g_recurrence, 128 | loop_vars=(tf.constant(0, dtype=tf.int32), 129 | embedding_ops.embedding_lookup(self.embeddings,self.start_token), 130 | self.h0, gen_x)) 131 | 132 | self.gen_x = self.gen_x.stack() # seq_length x batch_size 133 | self.gen_x = tf.transpose(self.gen_x, perm=[1, 0]) # batch_size x seq_length 134 | 135 | # gen_x contains the colours sampled as outputs 136 | # Hence, gen_x is used while calculating accuracy 137 | 138 | g_predictions = tensor_array_ops.TensorArray( 139 | dtype=tf.float32, size=self.max_out_utt_len, 140 | dynamic_size=False, infer_shape=True) 141 | 142 | ta_emb_x = tensor_array_ops.TensorArray( 143 | dtype=tf.float32, size=self.max_out_utt_len) 144 | ta_emb_x = ta_emb_x.unstack(self.processed_x) 145 | 146 | def _train_recurrence(i, x_t, h_tm1, g_predictions): 147 | _,h_t = self.decoder_cell(x_t,h_tm1) 148 | o_t = self.g_output_unit(h_t) 149 | g_predictions = g_predictions.write(i, o_t) # batch x vocab_size 150 | x_tp1 = ta_emb_x.read(i) 151 | return i + 1, x_tp1, h_t, g_predictions 152 | 153 | _, _, _, self.g_predictions = control_flow_ops.while_loop( 154 | cond=lambda i, _1, _2, _3: i < self.max_out_utt_len, 155 | body=_train_recurrence, 156 | loop_vars=(tf.constant(0, dtype=tf.int32), 157 | embedding_ops.embedding_lookup(self.embeddings,self.start_token), 158 | self.h0, g_predictions)) 159 | 160 | self.g_predictions = tf.transpose(self.g_predictions.stack(), perm=[1, 0, 2]) # batch_size x seq_length x vocab_size 161 | 162 | self.loss_mask = tf.sequence_mask(self.out_len,self.max_out_utt_len,dtype=tf.float32) 163 | self.ground_truth = tf.one_hot(self.out_utt,on_value=tf.constant(1,dtype=tf.float32),off_value=tf.constant(0,dtype=tf.float32),depth=self.out_vocab_size,dtype=tf.float32) 164 | self.log_predictions = tf.log(self.g_predictions + 1e-20) 165 | self.cross_entropy = tf.multiply(self.ground_truth,self.log_predictions) 166 | self.cross_entropy_sum = tf.reduce_sum(self.cross_entropy,2) 167 | self.masked_cross_entropy = tf.multiply(self.loss_mask,self.cross_entropy_sum) 168 | self.sentence_loss = tf.divide(tf.reduce_sum(self.masked_cross_entropy,1),tf.reduce_sum(self.loss_mask,1)) 169 | self.loss = -tf.reduce_mean(self.sentence_loss) 170 | 171 | def create_output_unit(self): 172 | 173 | self.W1 = tf.get_variable("W1",shape=[2*self.enc_hid_dim+self.dec_hid_dim,2*self.enc_hid_dim],dtype=tf.float32) 174 | self.W2 = tf.get_variable("W2",shape=[2*self.enc_hid_dim,self.attn_size],dtype=tf.float32) 175 | self.w = tf.get_variable("w",shape=[self.attn_size,1],dtype=tf.float32) 176 | self.U = tf.get_variable("U",shape=[self.dec_hid_dim+2*self.enc_hid_dim,self.generate_size],dtype=tf.float32) 177 | self.W_1 = tf.get_variable("W_1",shape=[self.emb_dim+self.dec_hid_dim+2*self.enc_hid_dim,2*self.dec_hid_dim],dtype=tf.float32) 178 | self.W_2 = tf.get_variable("W_2",shape=[self.emb_dim+self.dec_hid_dim+2*self.enc_hid_dim,2*self.dec_hid_dim],dtype=tf.float32) 179 | self.W_12 = tf.get_variable("W_12",shape=[2*self.dec_hid_dim,self.attn_size],dtype=tf.float32) 180 | self.W_22 = tf.get_variable("W_22",shape=[2*self.dec_hid_dim,self.attn_size],dtype=tf.float32) 181 | self.r_1 = tf.get_variable("r_1",shape=[self.attn_size,1],dtype=tf.float32) 182 | self.r_2 = tf.get_variable("r_2",shape=[self.attn_size,1],dtype=tf.float32) 183 | self.b1 = tf.get_variable("b1",shape=[self.generate_size],dtype=tf.float32) 184 | self.b2 = tf.get_variable("b2",shape=[1],dtype=tf.float32) 185 | self.b3 = tf.get_variable("b3",shape=[1],dtype=tf.float32) 186 | self.W3 = 
tf.get_variable("W3",shape=[self.dec_hid_dim+2*self.enc_hid_dim+self.emb_dim,1],dtype=tf.float32) 187 | self.W4 = tf.get_variable("W4",shape=[self.dec_hid_dim+2*self.enc_hid_dim+self.emb_dim,1],dtype=tf.float32) 188 | 189 | def unit(hidden_state): 190 | 191 | hidden_state_expanded_attn = tf.tile(array_ops.expand_dims(hidden_state,1),[1,tf.shape(self.encoder_states)[1],1]) 192 | attn_rep = tf.concat([self.encoder_states,hidden_state_expanded_attn],axis=2) 193 | attn_rep = tf.nn.tanh(tf.einsum('ijk,kl->ijl',tf.nn.tanh(tf.einsum("ijk,kl->ijl",attn_rep,self.W1)),self.W2)) 194 | u_i = tf.squeeze(tf.einsum('ijk,kl->ijl',attn_rep,self.w),2) 195 | inp_len_mask = tf.sequence_mask(self.inp_len,tf.shape(self.inp_utt)[2],dtype=tf.float32) 196 | attn_mask = tf.reshape(inp_len_mask,shape=[self.batch_size,-1]) 197 | exp_u_i_masked = tf.multiply(tf.cast(attn_mask,dtype=tf.float64),tf.exp(tf.cast(u_i,dtype=tf.float64))) 198 | a = tf.cast(tf.einsum('i,ij->ij',tf.pow(tf.reduce_sum(exp_u_i_masked,1),-1),exp_u_i_masked),dtype=tf.float32) 199 | inp_attn = tf.reduce_sum(tf.einsum('ij,ijk->ijk',a,self.encoder_states),1) 200 | 201 | generate_dist = tf.nn.softmax(math_ops.matmul(tf.concat([hidden_state,inp_attn],axis=1),self.U) + self.b1) 202 | extra_zeros = tf.zeros([self.batch_size,self.out_vocab_size - self.generate_size]) 203 | extended_generate_dist = tf.concat([generate_dist,extra_zeros],axis=1) 204 | 205 | hidden_state_expanded_result = tf.tile(array_ops.expand_dims(hidden_state,1),[1,tf.shape(self.kb)[1],1]) 206 | inp_attn_expanded_result = tf.tile(array_ops.expand_dims(inp_attn,1),[1,tf.shape(self.kb)[1],1]) 207 | result_attn_rep = tf.concat([self.result_rep,hidden_state_expanded_result,inp_attn_expanded_result],axis=2) 208 | result_attn_rep = tf.nn.tanh(tf.einsum("ijk,kl->ijl",tf.nn.tanh(tf.einsum("ijk,kl->ijl",result_attn_rep,self.W_1)),self.W_12)) 209 | beta_logits = tf.squeeze(tf.einsum('ijk,kl->ijl',result_attn_rep,self.r_1),2) 210 | beta_masked = tf.multiply(tf.cast(self.kb_mask,dtype=tf.float64),tf.exp(tf.cast(beta_logits,dtype=tf.float64))) 211 | beta = tf.cast(tf.einsum('i,ij->ij',tf.pow(tf.reduce_sum(beta_masked,1),-1),beta_masked),dtype=tf.float32) 212 | 213 | hidden_state_expanded_keys = tf.tile(array_ops.expand_dims(array_ops.expand_dims(hidden_state,1),1),[1,tf.shape(self.kb)[1],tf.shape(self.kb)[2],1]) 214 | inp_attn_expanded_keys = tf.tile(array_ops.expand_dims(array_ops.expand_dims(inp_attn,1),1),[1,tf.shape(self.kb)[1],tf.shape(self.kb)[2],1]) 215 | result_key_rep = tf.concat([self.keys_emb,hidden_state_expanded_keys,inp_attn_expanded_keys],axis=3) 216 | result_key_rep = tf.nn.tanh(tf.einsum('ijkl,lm->ijkm',tf.nn.tanh(tf.einsum('ijkl,lm->ijkm',result_key_rep,self.W_2)),self.W_22)) 217 | gamma_logits = tf.squeeze(tf.einsum('ijkl,lm->ijkm',result_key_rep,self.r_2),3) 218 | gamma_masked = tf.multiply(tf.cast(self.keys_mask,dtype=tf.float64),tf.exp(tf.cast(gamma_logits,dtype=tf.float64))) 219 | gamma = tf.einsum('ij,ijk->ijk',beta,tf.cast(tf.einsum('ij,ijk->ijk',tf.pow(tf.reduce_sum(gamma_masked,2),-1),gamma_masked),dtype=tf.float32)) 220 | 221 | batch_nums_context = array_ops.expand_dims(tf.range(0, limit=self.batch_size, dtype=tf.int64),1) 222 | batch_nums_tiled_context = tf.tile(batch_nums_context,[1,tf.shape(self.encoder_states)[1]]) 223 | flat_inp_utt = tf.reshape(self.inp_utt,shape=[self.batch_size,-1]) 224 | indices_context = tf.stack([batch_nums_tiled_context,flat_inp_utt],axis=2) 225 | shape = [self.batch_size,self.out_vocab_size] 226 | context_copy_dist = 
tf.scatter_nd(indices_context,a,shape) 227 | 228 | db_rep = tf.reduce_sum(tf.einsum('ij,ijk->ijk',beta,self.result_rep),1) 229 | 230 | p_db = tf.nn.sigmoid(tf.matmul(tf.concat([hidden_state,inp_attn,db_rep],axis=1),self.W4)+self.b3) 231 | p_db = tf.tile(p_db,[1,self.out_vocab_size]) 232 | one_minus_fn = lambda x: 1 - x 233 | one_minus_pdb = tf.map_fn(one_minus_fn, p_db) 234 | 235 | p_gens = tf.nn.sigmoid(tf.matmul(tf.concat([hidden_state,inp_attn,db_rep],axis=1),self.W3)+self.b2) 236 | p_gens = tf.tile(p_gens,[1,self.out_vocab_size]) 237 | one_minus_fn = lambda x: 1 - x 238 | one_minus_pgens = tf.map_fn(one_minus_fn, p_gens) 239 | 240 | batch_nums = array_ops.expand_dims(tf.range(0, limit=self.batch_size, dtype=tf.int64),1) 241 | kb_ids = tf.reshape(self.kb,shape=[self.batch_size,-1]) 242 | num_kb_ids = tf.shape(kb_ids)[1] 243 | batch_nums_tiled = tf.tile(batch_nums,[1,num_kb_ids]) 244 | indices = tf.stack([batch_nums_tiled,kb_ids],axis=2) 245 | updates = tf.reshape(gamma,shape=[self.batch_size,-1]) 246 | shape = [self.batch_size,self.out_vocab_size] 247 | kb_dist = tf.scatter_nd(indices,updates,shape) 248 | kb_dist = tf.einsum('i,ij->ij',self.db_empty,kb_dist) 249 | 250 | copy_dist = tf.multiply(p_db,kb_dist) + tf.multiply(one_minus_pdb,context_copy_dist) 251 | final_dist = tf.multiply(p_gens,extended_generate_dist) + tf.multiply(one_minus_pgens,copy_dist) 252 | 253 | return final_dist 254 | 255 | return unit 256 | 257 | def get_feed_dict(self,batch): 258 | 259 | fd = { 260 | self.inp_utt : batch['inp_utt'], 261 | self.inp_len : batch['inp_len'], 262 | self.context_len: batch['context_len'], 263 | self.out_utt : batch['out_utt'], 264 | self.out_len : batch['out_len'], 265 | self.kb : batch['kb'], 266 | self.kb_mask : batch['kb_mask'], 267 | self.keys : batch['keys'], 268 | self.keys_mask : batch['keys_mask'], 269 | self.db_empty : batch['empty'], 270 | self.max_out_utt_len : batch['max_out_utt_len'] 271 | } 272 | 273 | return fd 274 | -------------------------------------------------------------------------------- /camrest/single.json: -------------------------------------------------------------------------------- 1 | ["centre", "italian", "cheap", "cb21ab", "east", "international", "cb259aq", "indian", "expensive", "cb21dp", "south", "chinese", "cb17ag", "moderate", "cb17dy", "cb17aa", "cb28pb", "cb58pa", "cb23nj", "cb58jj", "cb21db", "eraina", "european", "cb23rh", "british", "ask", "cb21uf", "cb23pp", "cb21aw", "gastropub", "cb12qa", "west", "cb43le", "cb21rt", "kohinoor", "cb12as", "mexican", "prezzo", "cb30ad", "cb12bd", "lebanese", "cb21nt", "cb23ar", "cb11ln", "cb12az", "cocum", "cb30ah", "cb21su", "cb39ey", "cb21qa", "cb23dt", "vietnamese", "cb30af", "spanish", "north", "cb41jy", "french", "cb43ax", "cb21tw", "wagamama", "japanese", "cb12lf", "nandos", "portuguese", "cb21eg", "cb21sj", "cb21rg", "korean", "turkish", "cb13nf", "cb21rq", "cb11lh", "cb41uy", "cotto", "cb11bg", "cb58wr", "galleria", "cb21uw", "cb23qf", "bedouin", "african", "mediterranean", "cb23ll", "cb41eh", "hakka", "cb43hl", "meghna", "cb43lf", "cb21la", "cb58aq", "cb11dg", "seafood", "cb21qy", "cb58ba", "thai", "cb41nl", "cb28nx", "graffiti", "cb30lx", "cb30dq", "cb41ep", "anatolia", "cb21uj", "panahar", "cb11hr", "cb23ju", "cb21ug", "cb30df", "cb13nl", "rajmahal", "cb58rg", "kymmoy", "cb19hx", "cb41ha", "cote", "cb21nw", "cb23jx"] -------------------------------------------------------------------------------- /camrest/test.py: -------------------------------------------------------------------------------- 1 | 
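# (Added summary, not in the original file.) camrest/test.py restores the most
# recent checkpoint from --checkpoint (./trainDir/ by default), then runs a full
# pass over the test split with greedy decoding (model.gen_x). Note that it points
# both val_path and test_path at test.json, so get_batch(train=False) iterates the
# test set; dummy examples that pad the final batch are filtered out via the
# 'dummy' flag. Decoded sentences, gold outputs, contexts and KB results are
# pickled to needed.p, and evaluate() reports BLEU along with entity
# precision/recall/F1 computed against the entity list in single.json.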
import json 2 | import copy 3 | import numpy as np 4 | from data_handler import DataHandler 5 | from model import DialogueModel 6 | import os 7 | import tensorflow as tf 8 | import cPickle as pickle 9 | import nltk 10 | import sys 11 | import csv 12 | from collections import Counter 13 | from nltk.util import ngrams 14 | from nltk.corpus import stopwords 15 | from nltk.tokenize import word_tokenize 16 | from nltk.stem import WordNetLemmatizer 17 | import math, re, argparse 18 | import functools 19 | import logging 20 | logging.getLogger().setLevel(logging.INFO) 21 | 22 | class Trainer(object): 23 | 24 | def __init__(self,model,handler,ckpt_path,num_epochs,learning_rate): 25 | self.handler = handler 26 | self.model = model 27 | self.ckpt_path = ckpt_path 28 | self.epochs = num_epochs 29 | self.learning_rate = learning_rate 30 | 31 | if not os.path.exists(self.ckpt_path): 32 | os.makedirs(self.ckpt_path) 33 | 34 | self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None) 35 | self.optimizer = tf.contrib.layers.optimize_loss( 36 | loss=self.model.loss, 37 | global_step=self.global_step, 38 | learning_rate=self.learning_rate, 39 | optimizer=tf.train.AdamOptimizer, 40 | clip_gradients=10.0, 41 | name='optimizer_loss' 42 | ) 43 | self.saver = tf.train.Saver(max_to_keep=5) 44 | self.sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)) 45 | init = tf.global_variables_initializer() 46 | self.sess.run(init) 47 | 48 | checkpoint = tf.train.latest_checkpoint(self.ckpt_path) 49 | if checkpoint: 50 | self.saver.restore(self.sess, checkpoint) 51 | logging.info("Loaded parameters from checkpoint") 52 | 53 | def score(self,parallel_corpus): 54 | 55 | # containers 56 | count = [0, 0, 0, 0] 57 | clip_count = [0, 0, 0, 0] 58 | r = 0 59 | c = 0 60 | weights = [0.25, 0.25, 0.25, 0.25] 61 | 62 | # accumulate ngram statistics 63 | for hyps, refs in parallel_corpus: 64 | hyps = [hyp.split() for hyp in hyps] 65 | refs = [ref.split() for ref in refs] 66 | for hyp in hyps: 67 | 68 | for i in range(4): 69 | # accumulate ngram counts 70 | hypcnts = Counter(ngrams(hyp, i + 1)) 71 | cnt = sum(hypcnts.values()) 72 | count[i] += cnt 73 | 74 | # compute clipped counts 75 | max_counts = {} 76 | for ref in refs: 77 | refcnts = Counter(ngrams(ref, i + 1)) 78 | for ng in hypcnts: 79 | max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) 80 | clipcnt = dict((ng, min(count, max_counts[ng])) \ 81 | for ng, count in hypcnts.items()) 82 | clip_count[i] += sum(clipcnt.values()) 83 | 84 | # accumulate r & c 85 | bestmatch = [1000, 1000] 86 | for ref in refs: 87 | if bestmatch[0] == 0: break 88 | diff = abs(len(ref) - len(hyp)) 89 | if diff < bestmatch[0]: 90 | bestmatch[0] = diff 91 | bestmatch[1] = len(ref) 92 | r += bestmatch[1] 93 | c += len(hyp) 94 | 95 | # computing bleu score 96 | p0 = 1e-7 97 | bp = 1 if c > r else math.exp(1 - float(r) / float(c)) 98 | p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ 99 | for i in range(4)] 100 | s = math.fsum(w * math.log(p_n) \ 101 | for w, p_n in zip(weights, p_ns) if p_n) 102 | bleu = bp * math.exp(s) 103 | return bleu 104 | 105 | def evaluate(self,data,vocab): 106 | entities = json.load(open("./single.json")) 107 | outs = [] 108 | golds = [] 109 | 110 | tp_prec = 0.0 111 | tp_recall = 0.0 112 | total_prec = 0.0 113 | total_recall = 0.0 114 | 115 | for i in range(0,len(data['sentences'])): 116 | sentence = data['sentences'][i] 117 | sentence = list(sentence) 118 | if vocab['vocab_mapping']['$STOP$'] not in 
sentence: 119 | index = len(sentence) 120 | else: 121 | index = sentence.index(vocab['vocab_mapping']['$STOP$']) 122 | predicted = [str(sentence[j]) for j in range(0,index)] 123 | ground = data['output'][i] 124 | ground = list(ground) 125 | index = ground.index(vocab['vocab_mapping']['$STOP$']) 126 | ground_truth = [str(ground[j]) for j in range(0,index)] 127 | 128 | gold_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in ground_truth ] 129 | out_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in predicted ] 130 | 131 | for word in out_anon: 132 | if word in entities or '_' in word: 133 | if word != 'api_call': 134 | total_prec = total_prec + 1 135 | if word in gold_anon: 136 | tp_prec = tp_prec + 1 137 | 138 | for word in gold_anon: 139 | if word in entities or '_' in word: 140 | if word != 'api_call': 141 | total_recall = total_recall + 1 142 | if word in out_anon: 143 | tp_recall = tp_recall + 1 144 | 145 | gold = gold_anon 146 | out = out_anon 147 | golds.append(" ".join(gold)) 148 | outs.append(" ".join(out)) 149 | 150 | wrap_generated = [[_] for _ in outs] 151 | wrap_truth = [[_] for _ in golds] 152 | prec = tp_prec/total_prec 153 | recall = tp_recall/total_recall 154 | if prec == 0 or recall == 0: 155 | f1 = 0.0 156 | else: 157 | f1 = 2*prec*recall/(prec+recall) 158 | overall_f1 = f1 159 | print "Bleu: %.3f, Prec: %.3f, Recall: %.3f, F1: %.3f" % (self.score(zip(wrap_generated, wrap_truth)),prec,recall,f1) 160 | return overall_f1 161 | 162 | def test(self): 163 | test_epoch_done = False 164 | 165 | teststep = 0 166 | testLoss = 0.0 167 | needed = {} 168 | needed['sentences'] = [] 169 | needed['output'] = [] 170 | needed['context'] = [] 171 | needed['kb'] = [] 172 | 173 | while not test_epoch_done: 174 | teststep = teststep + 1 175 | batch, test_epoch_done = self.handler.get_batch(train=False) 176 | feedDict = self.model.get_feed_dict(batch) 177 | sentences = self.sess.run(self.model.gen_x,feed_dict=feedDict) 178 | 179 | if 1 not in batch['dummy']: 180 | needed['sentences'].extend(sentences) 181 | needed['output'].extend(batch['out_utt']) 182 | needed['context'].extend(batch['context']) 183 | needed['kb'].extend(batch['knowledge']) 184 | else: 185 | index = batch['dummy'].index(1) 186 | needed['sentences'].extend(sentences[0:index]) 187 | needed['output'].extend(batch['out_utt'][0:index]) 188 | needed['context'].extend(batch['context'][0:index]) 189 | needed['kb'].extend(batch['knowledge'][0:index]) 190 | pickle.dump(needed,open("needed.p","w")) 191 | self.evaluate(needed,self.handler.vocab) 192 | 193 | def main(): 194 | 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument('--batch_size', type=int, default=32) 197 | parser.add_argument('--emb_dim', type=int, default=200) 198 | parser.add_argument('--enc_hid_dim', type=int, default=128) 199 | parser.add_argument('--dec_hid_dim', type=int, default=256) 200 | parser.add_argument('--attn_size', type=int, default=200) 201 | parser.add_argument('--epochs', type=int, default=25) 202 | parser.add_argument('--learning_rate', type=float, default=2.5e-4) 203 | parser.add_argument('--dataset_path', type=str, default='../data/CamRest/') 204 | parser.add_argument('--glove_path', type=str, default='../data/') 205 | parser.add_argument('--checkpoint', type=str, default="./trainDir/") 206 | config = parser.parse_args() 207 | 208 | DEVICE = "/gpu:0" 209 | 210 | logging.info("Loading Data") 211 | 212 | handler = DataHandler( 213 | emb_dim = config.emb_dim, 214 | batch_size = config.batch_size, 215 | train_path = 
config.dataset_path + "train.json", 216 | val_path = config.dataset_path + "test.json", 217 | test_path = config.dataset_path + "test.json", 218 | vocab_path = "./vocab.json", 219 | glove_path = config.glove_path) 220 | 221 | logging.info("Loading Architecture") 222 | 223 | model = DialogueModel( 224 | device = DEVICE, 225 | batch_size = config.batch_size, 226 | inp_vocab_size = handler.input_vocab_size, 227 | out_vocab_size = handler.output_vocab_size, 228 | generate_size = handler.generate_vocab_size, 229 | emb_init = handler.emb_init, 230 | emb_dim = config.emb_dim, 231 | enc_hid_dim = config.enc_hid_dim, 232 | dec_hid_dim = config.dec_hid_dim, 233 | attn_size = config.attn_size) 234 | 235 | logging.info("Loading Trainer") 236 | 237 | trainer = Trainer( 238 | model=model, 239 | handler=handler, 240 | ckpt_path=config.checkpoint, 241 | num_epochs=config.epochs, 242 | learning_rate = config.learning_rate) 243 | 244 | trainer.test() 245 | 246 | main() -------------------------------------------------------------------------------- /camrest/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import numpy as np 4 | from data_handler import DataHandler 5 | from model import DialogueModel 6 | import os 7 | import tensorflow as tf 8 | import cPickle as pickle 9 | import nltk 10 | import sys 11 | import csv 12 | from collections import Counter 13 | from nltk.util import ngrams 14 | from nltk.corpus import stopwords 15 | from nltk.tokenize import word_tokenize 16 | from nltk.stem import WordNetLemmatizer 17 | import math, re, argparse 18 | import functools 19 | import logging 20 | logging.getLogger().setLevel(logging.INFO) 21 | 22 | class Trainer(object): 23 | 24 | def __init__(self,model,handler,ckpt_path,num_epochs,learning_rate): 25 | self.handler = handler 26 | self.model = model 27 | self.ckpt_path = ckpt_path 28 | self.epochs = num_epochs 29 | self.learning_rate = learning_rate 30 | 31 | if not os.path.exists(self.ckpt_path): 32 | os.makedirs(self.ckpt_path) 33 | 34 | self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None) 35 | self.optimizer = tf.contrib.layers.optimize_loss( 36 | loss=self.model.loss, 37 | global_step=self.global_step, 38 | learning_rate=self.learning_rate, 39 | optimizer=tf.train.AdamOptimizer, 40 | clip_gradients=10.0, 41 | name='optimizer_loss' 42 | ) 43 | self.saver = tf.train.Saver(max_to_keep=5) 44 | self.sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)) 45 | init = tf.global_variables_initializer() 46 | self.sess.run(init) 47 | 48 | checkpoint = tf.train.latest_checkpoint(self.ckpt_path) 49 | if checkpoint: 50 | self.saver.restore(self.sess, checkpoint) 51 | logging.info("Loaded parameters from checkpoint") 52 | 53 | def trainData(self): 54 | curEpoch = 0 55 | step = 0 56 | epochLoss = [] 57 | 58 | logging.info("Training the model") 59 | 60 | best_f1 = 0.0 61 | 62 | while curEpoch <= self.epochs: 63 | step = step + 1 64 | 65 | batch, epoch_done = self.handler.get_batch(train=True) 66 | feedDict = self.model.get_feed_dict(batch) 67 | 68 | fetch = [self.global_step, self.model.loss, self.optimizer] 69 | mod_step,loss,_ = self.sess.run(fetch,feed_dict = feedDict) 70 | epochLoss.append(loss) 71 | 72 | if step % 40 == 0: 73 | outstr = "step: "+str(step)+" Loss: "+str(loss) 74 | logging.info(outstr) 75 | 76 | if epoch_done: 77 | train_loss = np.mean(np.asarray(epochLoss)) 78 | 79 | val_epoch_done = False 80 | valstep = 0 81 | 
valLoss = 0.0 82 | needed = {} 83 | needed['sentences'] = [] 84 | needed['output'] = [] 85 | 86 | while not val_epoch_done: 87 | valstep = valstep + 1 88 | batch, val_epoch_done = self.handler.get_batch(train=False) 89 | feedDict = self.model.get_feed_dict(batch) 90 | val_loss,sentences = self.sess.run([self.model.loss,self.model.gen_x],feed_dict=feedDict) 91 | if 1 not in batch['dummy']: 92 | needed['sentences'].extend(sentences) 93 | needed['output'].extend(batch['out_utt']) 94 | else: 95 | index = batch['dummy'].index(1) 96 | needed['sentences'].extend(sentences[0:index]) 97 | needed['output'].extend(batch['out_utt'][0:index]) 98 | valLoss = valLoss + val_loss 99 | 100 | valLoss = valLoss / float(valstep) 101 | outstr = "Train-info: "+ "Epoch: ",str(curEpoch)+" Loss: "+str(train_loss) 102 | logging.info(outstr) 103 | outstr = "Val-info: "+"Epoch "+str(curEpoch)+" Loss: "+str(valLoss) 104 | logging.info(outstr) 105 | if curEpoch > 2: 106 | current_f1 = self.evaluate(needed,self.handler.vocab) 107 | if current_f1 >= best_f1: 108 | best_f1 = current_f1 109 | self.saver.save(self.sess, os.path.join(self.ckpt_path, 'model'), global_step=curEpoch) 110 | 111 | epochLoss = [] 112 | curEpoch = curEpoch + 1 113 | 114 | def score(self,parallel_corpus): 115 | 116 | # containers 117 | count = [0, 0, 0, 0] 118 | clip_count = [0, 0, 0, 0] 119 | r = 0 120 | c = 0 121 | weights = [0.25, 0.25, 0.25, 0.25] 122 | 123 | # accumulate ngram statistics 124 | for hyps, refs in parallel_corpus: 125 | hyps = [hyp.split() for hyp in hyps] 126 | refs = [ref.split() for ref in refs] 127 | for hyp in hyps: 128 | 129 | for i in range(4): 130 | # accumulate ngram counts 131 | hypcnts = Counter(ngrams(hyp, i + 1)) 132 | cnt = sum(hypcnts.values()) 133 | count[i] += cnt 134 | 135 | # compute clipped counts 136 | max_counts = {} 137 | for ref in refs: 138 | refcnts = Counter(ngrams(ref, i + 1)) 139 | for ng in hypcnts: 140 | max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) 141 | clipcnt = dict((ng, min(count, max_counts[ng])) \ 142 | for ng, count in hypcnts.items()) 143 | clip_count[i] += sum(clipcnt.values()) 144 | 145 | # accumulate r & c 146 | bestmatch = [1000, 1000] 147 | for ref in refs: 148 | if bestmatch[0] == 0: break 149 | diff = abs(len(ref) - len(hyp)) 150 | if diff < bestmatch[0]: 151 | bestmatch[0] = diff 152 | bestmatch[1] = len(ref) 153 | r += bestmatch[1] 154 | c += len(hyp) 155 | 156 | # computing bleu score 157 | p0 = 1e-7 158 | bp = 1 if c > r else math.exp(1 - float(r) / float(c)) 159 | p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ 160 | for i in range(4)] 161 | s = math.fsum(w * math.log(p_n) \ 162 | for w, p_n in zip(weights, p_ns) if p_n) 163 | bleu = bp * math.exp(s) 164 | return bleu 165 | 166 | def evaluate(self,data,vocab): 167 | entities = json.load(open("./single.json")) 168 | outs = [] 169 | golds = [] 170 | 171 | tp_prec = 0.0 172 | tp_recall = 0.0 173 | total_prec = 0.0 174 | total_recall = 0.0 175 | 176 | for i in range(0,len(data['sentences'])): 177 | sentence = data['sentences'][i] 178 | sentence = list(sentence) 179 | if vocab['vocab_mapping']['$STOP$'] not in sentence: 180 | index = len(sentence) 181 | else: 182 | index = sentence.index(vocab['vocab_mapping']['$STOP$']) 183 | predicted = [str(sentence[j]) for j in range(0,index)] 184 | ground = data['output'][i] 185 | ground = list(ground) 186 | index = ground.index(vocab['vocab_mapping']['$STOP$']) 187 | ground_truth = [str(ground[j]) for j in range(0,index)] 188 | 189 | gold_anon = 
[vocab['rev_mapping'][word].encode('utf-8') for word in ground_truth ] 190 | out_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in predicted ] 191 | 192 | for word in out_anon: 193 | if word in entities or '_' in word: 194 | if word != 'api_call': 195 | total_prec = total_prec + 1 196 | if word in gold_anon: 197 | tp_prec = tp_prec + 1 198 | 199 | for word in gold_anon: 200 | if word in entities or '_' in word: 201 | if word != 'api_call': 202 | total_recall = total_recall + 1 203 | if word in out_anon: 204 | tp_recall = tp_recall + 1 205 | 206 | gold = gold_anon 207 | out = out_anon 208 | golds.append(" ".join(gold)) 209 | outs.append(" ".join(out)) 210 | 211 | wrap_generated = [[_] for _ in outs] 212 | wrap_truth = [[_] for _ in golds] 213 | prec = tp_prec/total_prec 214 | recall = tp_recall/total_recall 215 | if prec == 0 or recall == 0: 216 | f1 = 0.0 217 | else: 218 | f1 = 2*prec*recall/(prec+recall) 219 | overall_f1 = f1 220 | print "Bleu: %.3f, Prec: %.3f, Recall: %.3f, F1: %.3f" % (self.score(zip(wrap_generated, wrap_truth)),prec,recall,f1) 221 | return overall_f1 222 | 223 | def test(self): 224 | test_epoch_done = False 225 | 226 | teststep = 0 227 | testLoss = 0.0 228 | needed = {} 229 | needed['sentences'] = [] 230 | needed['output'] = [] 231 | needed['context'] = [] 232 | needed['kb'] = [] 233 | 234 | while not test_epoch_done: 235 | teststep = teststep + 1 236 | batch, test_epoch_done = self.handler.get_batch(train=False) 237 | feedDict = self.model.get_feed_dict(batch) 238 | sentences = self.sess.run(self.model.gen_x,feed_dict=feedDict) 239 | 240 | if 1 not in batch['dummy']: 241 | needed['sentences'].extend(sentences) 242 | needed['output'].extend(batch['out_utt']) 243 | needed['context'].extend(batch['context']) 244 | needed['kb'].extend(batch['knowledge']) 245 | else: 246 | index = batch['dummy'].index(1) 247 | needed['sentences'].extend(sentences[0:index]) 248 | needed['output'].extend(batch['out_utt'][0:index]) 249 | needed['context'].extend(batch['context'][0:index]) 250 | needed['kb'].extend(batch['knowledge'][0:index]) 251 | pickle.dump(needed,open("needed.p","w")) 252 | self.evaluate(needed,self.handler.vocab) 253 | 254 | def main(): 255 | 256 | parser = argparse.ArgumentParser() 257 | parser.add_argument('--batch_size', type=int, default=32) 258 | parser.add_argument('--emb_dim', type=int, default=200) 259 | parser.add_argument('--enc_hid_dim', type=int, default=128) 260 | parser.add_argument('--dec_hid_dim', type=int, default=256) 261 | parser.add_argument('--attn_size', type=int, default=200) 262 | parser.add_argument('--epochs', type=int, default=25) 263 | parser.add_argument('--learning_rate', type=float, default=2.5e-4) 264 | parser.add_argument('--dataset_path', type=str, default='../data/CamRest/') 265 | parser.add_argument('--glove_path', type=str, default='../data/') 266 | parser.add_argument('--checkpoint', type=str, default="./trainDir/") 267 | config = parser.parse_args() 268 | 269 | DEVICE = "/gpu:0" 270 | 271 | logging.info("Loading Data") 272 | 273 | handler = DataHandler( 274 | emb_dim = config.emb_dim, 275 | batch_size = config.batch_size, 276 | train_path = config.dataset_path + "train.json", 277 | val_path = config.dataset_path + "val.json", 278 | test_path = config.dataset_path + "test.json", 279 | vocab_path = "./vocab.json", 280 | glove_path = config.glove_path) 281 | 282 | logging.info("Loading Architecture") 283 | 284 | model = DialogueModel( 285 | device = DEVICE, 286 | batch_size = config.batch_size, 287 | inp_vocab_size = 
handler.input_vocab_size, 288 | out_vocab_size = handler.output_vocab_size, 289 | generate_size = handler.generate_vocab_size, 290 | emb_init = handler.emb_init, 291 | emb_dim = config.emb_dim, 292 | enc_hid_dim = config.enc_hid_dim, 293 | dec_hid_dim = config.dec_hid_dim, 294 | attn_size = config.attn_size) 295 | 296 | logging.info("Loading Trainer") 297 | 298 | trainer = Trainer( 299 | model=model, 300 | handler=handler, 301 | ckpt_path=config.checkpoint, 302 | num_epochs=config.epochs, 303 | learning_rate = config.learning_rate) 304 | 305 | trainer.trainData() 306 | 307 | main() -------------------------------------------------------------------------------- /camrest/vocab.json: -------------------------------------------------------------------------------- 1 | {"generate_vocab_size": 826, "output_vocab_size": 1248, "rev_mapping": {"0": "$GO$", "199": "certain", "1200": "hk_fusion", "1175": "01223_323178", "861": "01223_461661", "1144": "cb23dt", "1218": "cb41ep", "1145": "cb41nl", "818": "hungarian", "819": "mind", "346": "tonight", "347": "both", "340": "me", "341": "code", "342": "best", "343": "01223352500.", "810": "should", "811": "appear", "812": "probably", "813": "$s6$", "814": "numbers", "815": "southern", "816": "servies", "817": "462354", "1149": "the_cambridge_chop_house", "719": "other", "718": "works", "717": "11", "716": "because", "715": "im", "714": "had", "713": "indonesian", "712": "woah", "711": "mid", "710": "match", "618": "trouble", "915": "the_golden_curry", "914": "restaurant_one_seven", "880": "little_seoul", "917": "napier_street_city_centre", "594": "change", "1068": "the_copper_kettle", "916": "cb30ad", "1182": "taj_tandoori", "195": "caribbean", "1061": "cambridge_city_football_club_milton_road_chesterton", "911": "12_market_hill_city_centre", "1168": "japanese", "1062": "01223_323361", "1065": "01223_366668", "1064": "the_gandhi", "619": "menu", "910": "jesus_lane_fen_ditton", "913": "chinese", "298": "seve", "299": "moderatley", "296": "thank", "297": "oh", "294": "postcode", "295": "kind", "292": "recommend", "293": "cuisine", "290": "for", "291": "get", "591": "victoria", "590": "further", "593": "binh", "592": "01223308681", "595": "bar", "198": "glad", "597": "system", "596": "malasian", "599": "ca", "598": "canapes", "197": "would", "196": "towards", "191": "16", "190": "oriental", "193": "somewhere", "192": "$u2$", "270": "464630.", "271": "choice", "272": "tandoon", "273": "option", "274": "them", "275": "woluld", "276": "medium", "277": "anything", "278": "end", "279": "modern", "1067": "cb41eh", "524": "01223247877", "525": "whether", "214": "couple", "527": "expand", "520": "whats", "521": "secondary", "522": "01223311053.", "523": "also", "1014": "peking_restaurant", "1015": "01223_302010", "1016": "don_pasquale_pizzeria", "1017": "01223_312598", "528": "restauant", "529": "moderatre", "1012": "stazione_restaurant_and_coffee_bar", "1103": "free_school_lane_city_centre", "1234": "pizza_express_fen_ditton", "1025": "lan_hong_house", "449": "does", "448": "6", "1230": "2_rose_crescent_city_centre", "1231": "cb23jx", "1232": "205_victoria_road_chesterton", "1233": "ask", "443": "any", "442": "place", "441": "you", "440": "high", "447": "ciquito", "446": "fine", "445": "more", "444": "restaraunt", "108": "green", "109": "3qf", "1135": "01223_324033", "102": "see", "103": "01223302010.", "100": "y", "101": "at", "106": "settle", "107": "pricey", "104": "info", "105": "3.h", "902": "thanh_binh", "903": "54_king_street_city_centre", "39": 
"01223568988", "38": "like", "906": "cambridge_lodge_restaurant", "907": "gourmet_burger_kitchen", "904": "cocum", "905": "pizza_hut_city_centre", "33": "regent", "32": "total", "31": "quite", "30": "priced", "37": "01223362525", "36": "01223315232.", "601": "24", "34": "thinking", "125": "variety", "640": "moderately", "643": "look", "642": "milton", "645": "list", "644": "dinner", "438": "areas", "646": "star", "436": "back", "648": "afraid", "434": "hills", "435": "have", "432": "downtown", "433": "dine", "430": "$s4$", "431": "requested", "1156": "riverside_brasserie", "339": "$s2$", "338": "$u9$", "1236": "moderate", "335": "dish", "334": "regarding", "337": "singapore", "336": "swiss", "331": "type", "330": "pardon", "333": "severing", "332": "are", "1120": "anatolia", "559": "01223", "1000": "cb21rg", "558": "$STOP$", "854": "golden_house", "344": "01223727410.", "856": "pipasha_restaurant", "857": "01223_249955", "850": "cb12as", "851": "40428_king_street_city_centre", "852": "cb17aa", "345": "price", "858": "01223_412299", "859": "01223_309147", "1081": "travellers_rest", "748": "of", "6": "apologize", "1150": "curry_garden", "900": "01223_464630", "848": "01223_727410", "99": "dontcare", "98": "d", "844": "meghna", "1085": "gastropub", "91": "fairly", "90": "fact", "93": "average", "92": "made", "95": "were", "94": "malaysian", "97": "is", "96": "restaurants", "1216": "cb11bg", "348": "on", "349": "01223248882.", "153": "problem", "740": "01223727410", "741": "sound", "742": "lovely", "743": "requirements", "744": "call", "745": "nice", "746": "01223350688.", "747": "while", "555": "send", "554": "services", "557": "so", "556": "!", "551": "hear", "550": "b", "553": "needed", "552": "city", "238": "northampton", "239": "try", "234": "appears", "235": "american", "236": "9", "237": "specify", "230": "suits", "231": "restuarants", "232": "been", "233": "giving", "1050": "fitzbillies_restaurant", "1198": "68_histon_road_chesterton", "1052": "01223_412430", "1053": "01223_359506", "1054": "01223_308871", "1055": "01223_351707", "1056": "yu_garden", "1057": "21_burleigh_street_city_centre", "1058": "sitar_tandoori", "1190": "cambridge_leisure_park_clifton_way", "1193": "cb43le", "1192": "cb13nf", "1195": "st._michael's_church_trinity_street_city_centre", "1194": "meze_bar_restaurant", "1197": "01223_367755", "1196": "34_-_35_green_street", "1": "someone", "614": "desire", "146": "job", "147": "specifications", "144": "mistake", "145": "this", "142": "01223351707.", "143": "occasion", "140": "resteraunt", "141": "locate", "612": "fulfill", "613": "much", "610": "sir", "611": "301030", "616": "else", "617": "closer", "148": "got", "615": "selections", "1007": "mexican", "912": "1_kings_parade", "1139": "01223_356555", "1006": "east", "1091": "hakka", "951": "4_kings_parade_city_centre", "1005": "cb21ab", "194": "preferences", "1004": "italian", "948": "01223_352607", "949": "la_margherita", "946": "cb23rh", "1003": "01223_350106", "944": "15_-_19_trumpington_street", "879": "01223_315232", "942": "grafton_hotel_619_newmarket_road_fen_ditton", "943": "victoria_avenue_chesterton", "940": "cb21uj", "1002": "01223_361763", "876": "north_american", "689": "preferably", "688": "ring", "1059": "eraina", "685": ".", "684": "hungry", "687": "english", "686": "upscale", "681": "search", "680": "addresses", "683": "108", "682": "fusion", "458": "clarify", "622": "polish", "1225": "01223_301761", "133": "served", "132": "there", "131": "crossover", "130": "atmosphere", "137": "an", "136": "help", 
"135": "hmm", "134": "increase", "494": "suggest", "495": "restuarant", "496": "crescent", "138": "provide", "490": "c", "491": "available", "492": "'ll", "493": "shiraz", "24": "differnt", "25": "halal", "26": "still", "27": "going", "20": "require", "21": "these", "22": "01223566388", "23": "01223352500", "1243": "01223_277977", "927": "finders_corner_newmarket_road", "28": "romanian", "29": "hi", "407": "'ve", "406": "many", "405": "by", "404": "cantonese", "403": "that", "402": "belgian", "401": "from", "400": "bad", "933": "01223_354382", "932": "196_mill_road_city_centre", "931": "01223_244277", "930": "01223_259988", "937": "01223_464550", "629": "possible", "409": "perform", "408": "205", "1069": "01223_364917", "453": "$s1$", "1028": "7_milton_road_chesterton", "1212": "de_vere_university_arms_regent_street_city_centre", "1241": "clowns_cafe", "1229": "regent_street_city_centre", "1228": "01223_323639", "1018": "sala_thong", "379": "1,2", "378": "01223366552.", "228": "eat", "829": "152_-_154_hills_road", "828": "01223_323737", "1051": "thai", "1060": "golden_wok", "825": "01223307581", "824": "difficulty", "373": "wide", "372": "christmas", "821": "a", "374": "vegetarian", "823": "result", "822": "sorry", "1083": "de_luca_cucina_and_bar", "708": "details", "709": "up", "1176": "01223_358399", "704": "appeals", "705": "choosing", "706": "3-34", "707": "...", "700": "alimentum", "701": "shot", "702": "sanit", "703": "categories", "393": "dining", "392": "14", "391": "locations", "89": "my", "397": "phone", "396": "time", "395": "australian", "394": "money", "82": "saint", "83": "too", "399": "inquiries", "81": "fish", "86": "places", "87": "head", "84": "query", "85": "evening", "797": "your", "796": "apologies", "795": "meant", "794": "1.uy", "793": "suggestion", "792": "andrews", "791": "priced..", "790": "only", "1170": "royal_spice", "799": "fits", "798": "excellent", "7": "nearest", "899": "cb58jj", "1246": "cb21aw", "586": ">", "587": "absolutely", "584": "instead", "585": "address", "582": "bbq", "583": "trying", "580": "requirement", "581": "around", "1133": "01223_369299", "1132": "huntingdon_road_city_centre", "1131": "191_histon_road_chesterton", "1130": "charlie_chan", "1137": "efes_restaurant", "1063": "01223_363471", "588": "int", "589": "it", "245": "ya", "244": "area", "247": "each", "246": "four", "241": "nine", "240": "correct", "243": "10", "242": "traditional", "149": "goodbye", "249": "go", "248": "whichever", "1117": "zizzi_cambridge", "1213": "41518_castle_street_city_centre", "924": "01223_337766", "970": "01223_518111", "925": "39_burleigh_street_city_centre", "519": "australasian", "518": "epensive", "926": "rice_boat", "1009": "6_lensfield_road", "1008": "cb30af", "511": "venetian", "510": "where'a", "513": "style", "512": "maybe", "515": "169", "514": "here", "517": "$s5$", "516": "cool", "623": "th", "459": "german", "621": "prefer", "620": "scottish", "627": "tuscan", "626": "then", "625": "care", "624": "first", "450": "offering", "451": "ant", "452": "7", "628": "servings", "454": "01223400170", "455": "pricerange", "456": "second", "457": "another", "1084": "cb21qa", "179": "anywhere", "178": "2", "177": "kosher", "176": "steakhouse", "175": "eastern", "174": "resturant", "173": "21-24", "172": "01223351707", "171": "tnahh", "170": "able", "977": "43_high_street_cherry_hinton_cherry_hinton", "656": "hope", "975": "g4_cambridge_leisure_park_clifton_way_cherry_hinton", "974": "01223_308681", "973": "bedouin", "972": "30_bridge_street_city_centre", 
"971": "cb17ag", "183": "varisty", "1080": "01223_356666", "1013": "01223_362054", "654": "which", "979": "european", "978": "cb30df", "182": "singaporean", "657": "anythig", "180": "listings", "2": "might", "652": "01223311053", "187": "unfortunately", "184": "needs", "651": "covered", "886": "11_peas_hill_city_centre", "186": "welconme", "188": "$u6$", "189": "seem", "1122": "01223_352500", "658": "100", "653": "l", "1123": "korean", "650": "panasian", "1124": "panahar", "185": "although", "1125": "market_hill_city_centre", "1148": "7_barnwell_road_fen_ditton", "1089": "grafton_hotel_restaurant", "1127": "529_newmarket_road_fen_ditton", "1227": "hotel_du_vin_and_bistro", "11": ":", "10": "how", "13": "'", "12": "good-bye", "15": "listing", "14": "offer", "17": "01223350688", "16": "suggestions", "19": "restaraunts", "18": "most", "863": "47-53_regent_street", "862": "40270_king_street_city_centre", "865": "01223_306306", "864": "cb21dp", "867": "cb11ln", "866": "efes_restaurants", "884": "01223_353110", "947": "shiraz_restaurant", "938": "451_newmarket_road_fen_ditton", "659": "$u4$", "1226": "cb41jy", "883": "maharajah_tandoori_restaurant", "753": "alright", "881": "01223_244149", "945": "01223_568988", "887": "the_oak_bistro", "831": "cb12qa", "885": "32_bridge_street_city_centre", "752": "consider", "928": "the_lucky_star", "62": "away", "888": "cb23ar", "1032": "south", "1115": "01223_448620", "756": "name", "1179": "14_-16_bridge_street", "929": "rajmahal", "63": "where", "809": "but", "322": "196", "323": "features", "320": "next", "354": "creative", "326": "serving", "327": "eatery", "324": "as", "325": "mill", "1224": "cb58aq", "328": "provided", "329": "out", "562": "meets", "775": "looking", "200": "c.b.1", "777": "server", "202": "25", "205": "dont", "204": "same", "773": "multiple", "206": "make", "209": "yo", "208": "moderatly", "779": "something", "778": "either", "889": "the_good_luck_chinese_food_takeaway", "77": "just", "76": "contacting", "75": "15", "74": "okay", "73": "selling", "72": "dbye", "71": "gold", "70": "chips", "655": "exist", "79": "post", "78": "thanks", "1043": "01223_413000", "1042": "cb30lx", "1041": "bangkok_city", "1040": "84_regent_street_city_centre", "1047": "01223_247877", "1046": "the_gardenia", "1045": "west", "1044": "01223_362433", "1049": "j_restaurant", "1048": "cb28nx", "1222": "01223_356354", "359": "very", "358": "selection", "868": "cb43ax", "1142": "the_hotpot", "1143": "prezzo", "1140": "01223_506055", "1141": "0871_942_9180", "1146": "59_hills_road_city_centre", "1147": "35_newnham_road_newnham", "669": "barbeque", "668": "less", "667": "no", "666": "basque", "665": "lucky", "664": "course", "663": "280", "662": "street", "661": "rose", "660": "may", "1221": "seafood", "215": "location", "692": "actually", "693": "today", "690": "interested", "691": "$u1$", "696": "nah", "697": "cherry", "694": "ther", "695": "inquiry", "698": "types", "699": "01223360966.", "526": "better", "542": "cost", "543": "special", "540": "specializes", "541": "pleasure", "546": "d.y", "547": "those", "544": "01223367660", "545": "$s3$", "8": "wecome", "548": "n't", "549": "pipasha", "68": "asked", "1019": "vietnamese", "995": "the_missing_sock", "994": "galleria", "997": "cb21tw", "996": "64_cherry_hinton_road_cherry_hinton", "991": "graffiti", "990": "midsummer_common", "993": "pizza_hut_cherry_hinton", "992": "20_milton_road_chesterton", "999": "bloomsbury_restaurant", "998": "cb12bd", "120": "alternatively", "121": "thier", "122": "017-335-3355", "123": 
"suppose", "124": "found", "764": "8", "126": "such", "127": "huh", "128": "amount", "129": "resaurants", "765": "nothing", "69": "town.is", "1136": "17_hills_road_city_centre", "1188": "33-34_saint_andrews_street", "259": "ok", "1010": "royal_standard", "563": "long", "1189": "12_lensfield_road_city_centre", "1011": "01223_366552", "414": "restrauant", "415": "sure", "416": "moment", "417": "q.a", "410": "options", "411": "35", "412": "therea", "413": "everything", "920": "8_norfolk_street_city_centre", "498": "$PAD$", "922": "01223_302330", "923": "21_-_24_northampton_street", "418": "persian", "419": "meal", "776": "far", "499": "20", "319": "goo", "318": "corsica", "1235": "83_regent_street", "313": "$s8$", "312": "the", "311": "with", "310": "5", "317": "id", "316": "really", "315": "perfect", "314": "want", "1237": "01462_432565", "1177": "01223_360966", "1066": "01223_355711", "921": "01223_244955", "1126": "saigon_city", "1181": "01223_400170", "139": "bye", "1134": "01223_365068", "832": "kymmoy", "833": "wagamama", "830": "01223_307581", "497": "near", "836": "01223_362372", "837": "hills_road_city_centre", "834": "2g_cambridge_leisure_park_cherry_hinton_road_cherry_hinton", "835": "king_street_city_centre", "838": "hotel_felix_whitehouse_lane_huntingdon_road", "839": "nandos", "3": "$", "1030": "01223_329432", "368": "awesome", "369": "be", "366": "will", "367": "..good", "364": "cambridge", "365": "290", "362": "its", "363": "01223302010", "360": "important", "361": "$u8$", "959": "cb21uf", "1138": "the_nirala", "952": "01223_367660", "1238": "12_st._johns_street_city_centre", "882": "01223_462354", "1239": "01223_351027", "380": "", "381": "01223323178", "382": "located", "383": "that'it", "384": "description", "385": "choise", "386": "six", "387": "welcome", "388": "wait", "389": "hotel", "784": "calling", "785": "shall", "786": "than", "787": "need..", "780": "$u5$", "781": "again", "782": "restarant", "783": "wok", "1079": "yippee_noodle_bar", "788": "parameters", "789": "down", "1174": "01223_566188", "860": "spanish", "1223": "01223_241387", "151": "hey", "579": "popular", "578": "danish", "1088": "la_tasca", "604": "irish", "573": "postode", "572": "third", "571": "under", "570": "hello", "577": "diverse", "576": "p", "575": "32", "574": "believe", "60": "he", "61": "along", "258": "visit", "606": "qualifications", "64": "criteria", "65": "know", "66": "particularly", "67": "afghan", "252": "and", "253": "resturants", "250": "rated", "251": "stay", "256": "25.00", "257": "system..good", "254": "establishment", "154": "names", "603": "unfortuntately", "869": "5_jordans_yard_bridge_street_city_centre", "1172": "01223_368786", "602": "range", "939": "cb21su", "731": "come", "730": "4", "733": "day", "732": "ood", "735": "two", "734": "definitely", "737": "swedish", "736": "telephone", "506": "-", "738": "hinton", "504": "return", "505": "84", "502": "reached", "503": "chiquito", "500": "thank-you", "501": "prices", "630": "wonderful", "631": "i", "632": "22", "633": "none", "469": "anyone", "635": "serves", "636": "within", "637": "great", "465": "service", "464": "food", "467": "database", "466": "ranges", "461": "well", "460": "17", "463": "restauarnt", "462": "or", "901": "01223_577786", "390": "preferred", "169": "52", "164": "ids", "165": "yummy", "166": "ones", "167": "01223353110", "160": "expensively", "161": "01223413000", "162": "scandinavian", "163": "hate", "964": "chiquito_restaurant_bar", "965": "modern_european", "966": "cotto", "967": "british", "960": 
"saffron_brasserie", "961": "tang_chinese", "962": "100_mill_road_city_centre", "963": "cb21uw", "1129": "cb21eg", "968": "12_norfolk_street_city_centre", "969": "dojo_noodle_bar", "936": "pizza_hut_fen_ditton", "1106": "kohinoor", "1107": "mahal_of_cambridge", "1104": "37_newnham_road_newnham", "1105": "cb41uy", "1102": "tandoori_palace", "935": "cb30ah", "1100": "01223_354755", "1101": "169_high_street_chesterton_chesterton", "934": "newmarket_road_fen_ditton", "908": "cb28pb", "1108": "24_green_street_city_centre", "1109": "cb21qy", "909": "da_vinci_pizzeria", "1128": "01223_360409", "35": "recommended", "1159": "22_chesterton_road_chesterton", "1158": "international", "641": "yes", "878": "01223_350688", "1240": "01223_248882", "1183": "shanghai_family_restaurant", "877": "french", "874": "curry_queen", "875": "cb23qf", "872": "cb39ey", "873": "01223_355012", "870": "01223_355909", "871": "01223_276182", "1155": "ali_baba", "1242": "portuguese", "1154": "darrys_cookhouse_and_wine_shop", "9": "4-6", "1245": "01223_362525", "1071": "52_mill_road_city_centre", "1157": "mediterranean", "1244": "01223_365599", "890": "cambridge_leisure_park_clifton_way_cherry_hinton", "891": "restaurant_two_two", "892": "01223_350420", "893": "indian", "894": "74_mill_road_city_centre", "647": "brazilian", "896": "cb21nw", "897": "the_little_rose_37_trumpington_street", "898": "cb12lf", "1098": "88_mill_road_city_centre", "1087": "01223_355166", "439": "our", "1116": "01799_521260", "1099": "midsummer_house_restaurant", "255": "night", "649": "check", "1199": "01223_363270", "1247": "corn_exchange_street", "1153": "city_stop_restaurant", "437": "using", "1152": "pizza_express", "1036": "72_regent_street_city_centre", "1086": "01223_327908", "357": "section", "356": "assist", "355": "01223327908", "808": "information", "353": "results", "352": "finding", "351": "experience", "350": "now", "803": "venues", "802": "chesterton", "801": "plus", "800": "searching", "807": "welsh", "806": "hit", "805": "part", "804": "choose", "216": "q.", "217": "did", "768": "few", "769": "three", "212": "peking", "213": "01223327908.", "210": "nirala", "211": "classy", "762": "think", "763": ",", "760": "following", "761": "choices", "766": "direct", "767": "abou", "218": "they", "219": "find", "957": "cb58pa", "956": "cheap", "1033": "the_cow_pizza_kitchen_and_bar", "1169": "01223_358899", "1078": "milton_road_chesterton", "1110": "106_regent_street_city_centre", "1076": "01223_566388", "1077": "rice_house", "1074": "nandos_city_centre", "1075": "cb259aq", "1072": "01223_353942", "1073": "cb11hr", "1070": "183_east_road_city_centre", "1178": "cambridge_retail_park_newmarket_road_fen_ditton", "289": "one", "288": "austrian", "321": "servers", "4": "59", "281": "catalan", "280": "interest", "283": "spend", "282": "let", "285": "'s", "284": "all", "287": "whatever", "286": "cuban", "1094": "the_river_bar_steakhouse_and_grill", "1095": "charlie_chan-", "1096": "106_mill_road_city_centre", "1097": "cb11lh", "678": "deciding", "679": "jamaican", "1092": "01733_553355", "1093": "108_regent_street_city_centre", "674": "grafton", "675": "venue", "676": "619", "677": "asian", "670": "delicious", "671": "anymore", "672": "bill", "673": "goobye", "263": "hang", "262": "directions", "261": "line", "260": "takes", "267": "pricing", "266": "town", "265": "restaurnt", "264": "$u7$", "1121": "north", "1031": "quayside_off_bridge_street", "269": "closest", "268": "moroccan", "1082": "01223_307030", "59": "meeting", "58": "expinsive", "1215": 
"2_sturton_street_city_centre", "55": "entree", "54": "sounds", "57": "full", "56": "thing", "51": "request", "50": "west.would", "53": "mid-range", "52": "was", "537": "sytem", "536": "category", "535": "appreciate", "534": ";", "533": "matching", "532": "reasonably", "531": "could", "530": "least", "539": "c.b", "538": "restraunt", "987": "saint_johns_chop_house", "201": "$s7$", "988": "15_magdalene_street_city_centre", "989": "01223_356060", "1171": "cb21nt", "774": "bridge", "982": "3_-_5_millers_yard_mill_lane", "983": "cb43hl", "980": "cb11dg", "981": "33_bridge_street", "986": "01223_812660", "203": "lot", "984": "expensive", "985": "31_newnham_road_newnham", "115": "business", "114": "entries", "117": "01223307030.", "116": "convenient", "111": "?", "110": "139", "113": "happy", "112": "what", "771": "good", "119": "mathcing", "118": "1", "770": "iscb259aq", "207": "river", "1205": "290_mill_road_city_centre", "772": "called", "953": "turkish", "429": "road", "428": "bistro", "1034": "cb21la", "919": "01223_324351", "918": "loch_fyne", "421": "take", "420": "preference", "423": "recommendations", "422": "greek", "425": "assistance", "424": "u.f", "427": "rest", "426": "certainly", "308": "looks", "309": "to", "1191": "66_chesterton_road_chesterton", "855": "17_magdalene_street_city_centre", "300": "some", "301": "currently", "302": "serve", "303": "tell", "304": "enjoy", "305": "reach", "306": "wanting", "307": "alternate", "895": "36_saint_andrews_street", "371": "sort", "181": "russian", "370": "several", "827": "cb21rq", "847": "cambridge_lodge_hotel_139_huntingdon_road_city_centre", "846": "india_house", "845": "cb23ll", "826": "cb21db", "843": "82_cherry_hinton_road_cherry_hinton", "842": "cb23nj", "841": "asian_oriental", "840": "01223_362456", "1151": "35_saint_andrews_street_city_centre", "375": "postal", "1090": "01223_302800", "853": "40210_millers_yard_city_centre", "849": "curry_king", "820": "wanted", "377": "1223", "1209": "caffe_uno", "376": "we", "1164": "01223_311053", "1210": "bridge_street_city_centre", "954": "cb13nl", "1208": "the_slug_and_lettuce", "950": "cb19hx", "1165": "cb41ha", "941": "01223_500005", "1038": "ugly_duckling", "1039": "restaurant_alimentum", "1220": "lebanese", "1166": "01223_357187", "568": "locdated", "569": "listed", "751": "specification", "750": "toward", "757": "seems", "508": "do", "755": "am", "754": "152-154", "560": "'m", "561": "pretty", "759": "show", "758": "unusual", "564": "bummer", "565": "1.d", "566": "their", "567": "inexpensive", "739": "us", "229": "world", "507": "give", "227": "named", "226": "if", "225": "queries", "224": "eateries", "223": "ten", "222": "sitar", "221": "matter", "220": "01223400170.", "1001": "cb58ba", "1024": "cb17dy", "1027": "cb30dq", "1026": "10_homerton_street_city_centre", "1021": "thompsons_lane_fen_ditton", "1020": "01223_311911", "1023": "doubletree_by_hilton_cambridge_granta_place_mill_lane", "1022": "12_bridge_street_city_centre", "1186": "michaelhouse_cafe", "1187": "mill_road_city_centre", "1184": "21_-_24_northampton_road", "1185": "cb21rt", "1029": "4_-_6_rose_crescent", "88": "tasty", "1180": "backstreet_bistro", "1037": "cb21ug", "726": "nope", "727": "particular", "724": "centr", "725": "talking", "722": "f", "723": "specific", "720": "fitting", "721": "tow", "1160": "sesame_restaurant_and_bar", "1035": "cb12az", "1167": "cote", "728": "api_call", "729": "restaurant", "605": "side", "150": "about", "607": "five", "152": "records", "155": "3", "600": "can", "157": "in", "156": 
"different", "159": "01223351880", "158": "number", "1207": "la_raza", "1206": "crowne_plaza_hotel_20_downing_street", "609": "however", "608": ".cb21ab", "1203": "cb23pp", "1202": "cb43lf", "1204": "cb21sj", "1211": "51_trumpington_street_city_centre", "976": "la_mimosa", "1161": "01223_354679", "634": "others", "80": "elese", "468": "$u3$", "749": "additional", "958": "jinling_noodle_bar", "398": "eritrean", "1214": "centre", "48": "01223812660", "49": "offers", "46": "questions", "47": "varying", "44": "not", "45": "fit", "42": "loated", "43": "center", "40": "nearby", "41": "c.b2", "1111": "71_castle_street_city_centre", "638": "having", "1113": "86_regent_street_city_centre", "1112": "cb23ju", "5": "already", "1114": "african", "1162": "01223_351880", "639": "'re", "1119": "cb58wr", "1118": "curry_prince", "1173": "cb58rg", "1219": "frankie_and_bennys", "1217": "01223_227330", "1163": "01842_753771", "489": "tanh", "488": "01799521260", "487": "work", "486": "01223462354.", "485": "narrow", "484": "per", "483": "ok.", "482": "use", "481": "please", "480": "broaden", "509": "contact", "955": "01223_505015", "1201": "the_varsity_restaurant", "472": "showing", "473": "meet", "470": "'d", "471": "past", "476": "desired", "477": "has", "474": "matches", "475": "requests", "168": "right", "478": "need", "479": "costs"}, "input_vocab_size": 1248, "vocab_mapping": {"regent_street_city_centre": 1229, "2_rose_crescent_city_centre": 1230, "all": 284, "code": 341, "consider": 752, "pardon": 330, "backstreet_bistro": 1180, "victoria_avenue_chesterton": 943, "broaden": 480, "particular": 727, "1,2": 379, "results": 353, "01223_367660": 952, "four": 246, "cb259aq": 1075, "asian": 677, "152_-_154_hills_road": 829, "go": 249, "query": 84, "mill": 325, "hate": 163, "looking": 775, "certainly": 426, "shiraz": 493, "yummy": 165, "chinese": 913, "everything": 413, "01223_307030": 1082, "21_burleigh_street_city_centre": 1057, "send": 555, "suits": 230, "to": 309, "01223_301761": 1225, "12_norfolk_street_city_centre": 968, "th": 623, "under": 571, "sorry": 822, "47-53_regent_street": 863, "01223351707.": 142, "town": 266, "tanh": 489, "crowne_plaza_hotel_20_downing_street": 1206, "need..": 787, "ciquito": 447, "very": 359, "pizza_hut_fen_ditton": 936, "none": 633, "choice": 271, "dojo_noodle_bar": 969, "cb58pa": 957, "entries": 114, "j_restaurant": 1049, "trouble": 618, "canapes": 598, "cool": 516, "cb21db": 826, "01223323178": 381, "191_histon_road_chesterton": 1131, "did": 217, "restuarants": 231, "venue": 675, "la_margherita": 949, "1_kings_parade": 912, "cb17ag": 971, "try": 239, "p": 576, "pizza_express": 1152, "settle": 106, "steakhouse": 176, "01223_448620": 1115, "01223_365599": 1244, "de_vere_university_arms_regent_street_city_centre": 1212, "290": 365, "enjoy": 304, "specializes": 540, "midsummer_house_restaurant": 1099, "01223307030.": 117, "ten": 223, "196": 322, "mediterranean": 1157, "direct": 766, "$s8$": 313, "past": 471, "dish": 335, "second": 456, "cost": 542, "crescent": 496, "further": 590, "$u5$": 780, "choices": 761, "01223_350106": 1003, "resaurants": 129, "cambridge": 364, "will": 366, "what": 112, "the_oak_bistro": 887, "$s1$": 453, "preferences": 194, "giving": 233, "section": 357, "requirements": 743, "near": 497, "swiss": 336, "01223_312598": 1017, "restauant": 528, "indian": 893, "royal_standard": 1010, "cb30ah": 935, "cb30af": 1008, ";": 534, "cb30ad": 916, "full": 57, "01223_367755": 1197, "01223_518111": 970, "meze_bar_restaurant": 1194, "panasian": 650, 
"106_regent_street_city_centre": 1110, "here": 514, "ranges": 466, "desired": 476, "108": 683, "let": 282, "andrews": 792, "address": 585, "directions": 262, "loch_fyne": 918, "100": 658, "appears": 234, "change": 594, "wait": 388, "basque": 666, "great": 637, "rajmahal": 929, "cb43le": 1193, "14": 392, "difficulty": 824, "32": 575, "vegetarian": 374, "vietnamese": 1019, "12_market_hill_city_centre": 911, "68_histon_road_chesterton": 1198, "experience": 351, "amount": 128, "01223_577786": 901, "suggestion": 793, "menu": 619, "narrow": 485, "d.y": 546, "options": 410, "fulfill": 612, "01223_311053": 1164, "named": 227, "addresses": 680, "search": 681, "prefer": 621, "maharajah_tandoori_restaurant": 883, "01223_356666": 1080, "names": 154, "15_magdalene_street_city_centre": 988, "total": 32, "cb23dt": 1144, "wecome": 8, "cb58wr": 1119, "use": 482, "15_-_19_trumpington_street": 944, "from": 401, "takes": 260, "would": 197, "cambridge_retail_park_newmarket_road_fen_ditton": 1178, "curry_garden": 1150, "visit": 258, "two": 735, "next": 320, "few": 768, "call": 744, "183_east_road_city_centre": 1070, "criteria": 64, "yippee_noodle_bar": 1079, "hinton": 738, "type": 331, "tell": 303, "today": 693, "belgian": 402, "sort": 371, "ood": 732, "expensively": 160, "taj_tandoori": 1182, "evening": 85, "newmarket_road_fen_ditton": 934, "12_bridge_street_city_centre": 1022, "01223_355166": 1087, "it": 589, "downtown": 432, "phone": 397, "sala_thong": 1018, "mathcing": 119, "01223311053": 652, "far": 776, "excellent": 798, "royal_spice": 1170, "modern_european": 965, "glad": 198, "me": 340, "stazione_restaurant_and_coffee_bar": 1012, "462354": 817, "01223_356060": 989, "6_lensfield_road": 1009, "f": 722, "this": 145, "work": 487, "king_street_city_centre": 835, "..good": 367, "curry_prince": 1118, "cambridge_lodge_restaurant": 906, "anywhere": 179, "nine": 241, "can": 600, "following": 760, "meet": 473, "q.": 216, "my": 89, "cb41eh": 1067, "01223352500.": 343, "nearest": 7, "listings": 180, "$PAD$": 498, "give": 507, "panahar": 1124, "awesome": 368, "suggestions": 16, "ids": 164, "high": 440, "de_luca_cucina_and_bar": 1083, "charlie_chan-": 1095, "84_regent_street_city_centre": 1040, "numbers": 814, "want": 314, "cantonese": 404, "$s4$": 430, "301030": 611, "!": 556, "information": 808, "needs": 184, "end": 278, "82_cherry_hinton_road_cherry_hinton": 843, "01223_500005": 941, "provide": 138, "01223_306306": 865, "cb23rh": 946, "get": 291, "range": 602, "01223_354382": 933, "chiquito_restaurant_bar": 964, "1": 118, "how": 10, "01223_461661": 861, "zizzi_cambridge": 1117, "01223_355909": 870, "rice_boat": 926, "galleria": 994, "instead": 584, "establishment": 254, "3-34": 706, "cb41jy": 1226, "01223_353942": 1072, "swedish": 737, "entree": 55, "okay": 74, "description": 384, "sir": 610, "may": 660, "4_kings_parade_city_centre": 951, "southern": 815, "01223_244955": 921, "cb21ab": 1005, "quayside_off_bridge_street": 1031, "ali_baba": 1155, "fits": 799, "such": 126, "cb21aw": 1246, "types": 698, "a": 821, "malasian": 596, "40210_millers_yard_city_centre": 853, "binh": 593, "third": 572, ".cb21ab": 608, "maybe": 512, "meghna": 844, "appreciate": 535, "greek": 422, "01223_350420": 892, "midsummer_common": 990, "so": 557, "south": 1032, "goodbye": 149, "classy": 211, "pleasure": 541, "finders_corner_newmarket_road": 927, "01223311053.": 522, "39_burleigh_street_city_centre": 925, "frankie_and_bennys": 1219, "serving": 326, "singaporean": 182, "help": 136, "01223_351707": 1055, "01223812660": 48, "cb43lf": 1202, 
"areas": 438, "northampton": 238, "01223353110": 167, "cb11lh": 1097, "course": 664, "already": 5, "looks": 308, "austrian": 288, "139": 110, "still": 26, "01223_353110": 884, "its": 362, "inquiries": 399, "ok.": 483, "perfect": 315, "25": 202, "style": 513, "51_trumpington_street_city_centre": 1211, "20": 499, "thank": 296, "fit": 45, "located": 382, "somewhere": 193, "21_-_24_northampton_road": 1184, "72_regent_street_city_centre": 1036, "4_-_6_rose_crescent": 1029, ",": 763, "actually": 692, "better": 526, "cb21ug": 1037, "offers": 49, "choose": 804, "bummer": 564, "listing": 15, "cotto": 966, "covered": 651, "bye": 139, "might": 2, "saint_johns_chop_house": 987, "then": 626, "them": 274, "someone": 1, "return": 504, "yo": 209, "01223_309147": 859, "food": 464, "cb12az": 1035, "01223_276182": 871, "eat": 228, "cb21uf": 959, "records": 152, "they": 218, "not": 44, "40428_king_street_city_centre": 851, "now": 350, "day": 733, "20_milton_road_chesterton": 992, "01223_259988": 930, "tow": 721, "100_mill_road_city_centre": 962, "seve": 298, "curry_king": 849, "reasonably": 532, "l": 653, "01223247877": 524, "peking": 212, "gourmet_burger_kitchen": 907, "each": 247, "01223_249955": 857, "found": 124, "european": 979, "out": 329, "side": 605, "recommend": 292, "cb21dp": 864, "cb21nt": 1171, "another": 457, "$STOP$": 558, "meets": 562, "01223_355012": 873, "list": 645, "fish": 81, "peking_restaurant": 1014, "01223351880": 159, "mill_road_city_centre": 1187, "619": 676, "01223413000": 161, "our": 439, "01223_354755": 1100, "ring": 688, "84": 505, "really": 316, "category": 536, "goo": 319, "the_river_bar_steakhouse_and_grill": 1094, "'": 13, "01223_366552": 1011, "saffron_brasserie": 960, "2g_cambridge_leisure_park_cherry_hinton_road_cherry_hinton": 834, "increase": 134, "specifications": 147, "7_barnwell_road_fen_ditton": 1148, "oriental": 190, "asian_oriental": 841, "7": 452, "01223_362456": 840, "01223_247877": 1047, "got": 148, "clarify": 458, "cb17aa": 852, "cb23nj": 842, "correct": 240, "74_mill_road_city_centre": 894, "36_saint_andrews_street": 895, "bistro": 428, "assist": 356, "88_mill_road_city_centre": 1098, "hang": 263, "queries": 225, "cb13nf": 1192, "01223727410.": 344, "the_cambridge_chop_house": 1149, "quite": 31, "bloomsbury_restaurant": 999, "01223350688": 17, "64_cherry_hinton_road_cherry_hinton": 996, "30_bridge_street_city_centre": 972, "wagamama": 833, "01223_315232": 879, "cb21eg": 1129, "ask": 1233, "cambridge_leisure_park_clifton_way": 1190, "wanted": 820, "$u4$": 659, "care": 625, "sytem": 537, "01223308681": 592, "shanghai_family_restaurant": 1183, "selections": 615, "the_varsity_restaurant": 1201, "could": 531, "inexpensive": 567, "kohinoor": 1106, "'s": 285, "la_mimosa": 976, "thing": 56, "american": 235, "place": 442, "01223_727410": 848, "differnt": 24, "01223_464550": 937, "think": 762, "'d": 470, "first": 624, "dine": 433, "appeals": 704, "1223": 377, "dont": 205, "37_newnham_road_newnham": 1104, "$u3$": 468, "01223350688.": 746, "number": 158, "cb12bd": 998, "one": 289, "01223302010.": 103, "jesus_lane_fen_ditton": 910, "cb21nw": 896, "delicious": 670, "spanish": 860, "reached": 502, "sounds": 54, "locdated": 568, "city": 552, "rated": 250, "01223_241387": 1223, "service": 465, "their": 566, "01223_277977": 1243, "moderately": 640, "system": 597, "least": 530, "anyone": 469, "needed": 553, "01223_362372": 836, "wonderful": 630, "too": 83, "saint": 82, "cb28pb": 908, "10_homerton_street_city_centre": 1026, "listed": 569, "danish": 578, "selling": 73, "option": 
273, "napier_street_city_centre": 917, "that": 403, "hotel": 389, "serve": 302, "milton_road_chesterton": 1078, "japanese": 1168, "part": 805, "eritrean": 398, "c.b.1": 200, "tuscan": 627, "thai": 1051, "believe": 574, "than": 786, "specify": 237, "milton": 642, "11": 717, "wide": 373, "kind": 295, "b": 550, "15": 75, "unfortunately": 187, "17": 460, "16": 191, "sitar_tandoori": 1058, "require": 20, "01223_362433": 1044, "scandinavian": 162, "matter": 221, "street": 662, "were": 95, "alimentum": 700, "are": 332, "and": 252, "bridge": 774, "01223_324351": 919, "290_mill_road_city_centre": 1205, "modern": 279, "mind": 819, "locations": 391, "crossover": 131, "tonight": 346, "woah": 712, "nirala": 210, "graffiti": 991, "halal": 25, "have": 435, "need": 478, "seem": 189, "01223400170": 454, "3qf": 109, "01223_302800": 1090, "01223_350688": 878, "01223_356354": 1222, "cb23jx": 1231, "afraid": 648, "01223_360966": 1177, "convenient": 116, "-": 506, "mid": 711, "cb23ju": 1112, "also": 523, "fitting": 720, "contact": 509, "take": 421, "which": 654, "finding": 352, "wanting": 306, "korean": 1123, "india_house": 846, "169_high_street_chesterton_chesterton": 1101, "restaurant_two_two": 891, "where'a": 510, "towards": 196, "multiple": 773, "shall": 785, "huh": 127, "price": 345, "african": 1114, "reach": 305, "restaurnt": 265, "victoria": 591, "1.d": 565, "most": 18, "01223352500": 23, "the_gandhi": 1064, "services": 554, "g4_cambridge_leisure_park_clifton_way_cherry_hinton": 975, "charlie_chan": 1130, "$s7$": 201, "21-24": 173, "59_hills_road_city_centre": 1146, "cuban": 286, "pipasha": 549, "01223_368786": 1172, "average": 93, "request": 51, "hungry": 684, "01223566388": 22, "traditional": 242, "definitely": 734, "priced..": 791, "01223_363270": 1199, "01223360966.": 699, "mid-range": 53, "71_castle_street_city_centre": 1111, "specification": 751, "fact": 90, "yu_garden": 1056, "atmosphere": 130, "selection": 358, "shot": 701, "gold": 71, "show": 759, "german": 459, "106_mill_road_city_centre": 1096, "cheap": 956, "12_lensfield_road_city_centre": 1189, "moroccan": 268, "$s5$": 517, "russian": 181, "showing": 472, "$u7$": 264, "postcode": 294, "free_school_lane_city_centre": 1103, "restrauant": 414, "01223_812660": 986, "mexican": 1007, "that'it": 383, "australian": 395, "fine": 446, "find": 219, "01223462354.": 486, "doubletree_by_hilton_cambridge_granta_place_mill_lane": 1023, "32_bridge_street_city_centre": 885, "chiquito": 503, "01223_359506": 1053, "parameters": 788, "romanian": 28, "the_good_luck_chinese_food_takeaway": 889, "woluld": 275, "should": 810, "01223_357187": 1166, "only": 790, "going": 27, "205_victoria_road_chesterton": 1232, "cb21tw": 997, "pretty": 561, "money": 394, "the_missing_sock": 995, "8": 764, "52": 169, "hope": 656, "choise": 385, "meant": 795, "do": 508, "hit": 806, "hungarian": 818, "10": 243, "01223_352607": 948, "33_bridge_street": 981, "lucky": 665, "preferred": 390, "lebanese": 1220, "01223_358899": 1169, "ones": 166, "01223_244149": 881, "international": 1158, "529_newmarket_road_fen_ditton": 1127, "2": 178, "01223_302010": 1015, "chips": 70, "01462_432565": 1237, "secondary": 521, "regarding": 334, "bar": 595, "resteraunt": 140, "golden_house": 854, "22_chesterton_road_chesterton": 1159, "hello": 570, "malaysian": 94, "calling": 784, "jinling_noodle_bar": 958, "bad": 400, "recommendations": 423, "river": 207, "where": 63, "severing": 333, "thompsons_lane_fen_ditton": 1021, "restaurants": 96, "moderatre": 529, "available": 491, "requirement": 580, "eraina": 1059, 
"the_golden_curry": 915, "riverside_brasserie": 1156, "tandoori_palace": 1102, "$s2$": 339, "restaraunts": 19, "see": 102, "cb11dg": 980, "result": 823, "apologies": 796, "hotel_felix_whitehouse_lane_huntingdon_road": 838, "restaurant_one_seven": 914, "best": 342, "closer": 617, "appear": 811, "0871_942_9180": 1141, "away": 62, "currently": 301, "please": 481, "ugly_duckling": 1038, "3": 155, "'re": 639, "features": 323, "cb58jj": 899, "wok": 783, "14_-16_bridge_street": 1179, "probably": 812, "cb13nl": 954, "nope": 726, "barbeque": 669, "01223327908": 355, "24": 601, "we": 376, "cb43ax": 868, "pricerange": 455, "7_milton_road_chesterton": 1028, "cb23ar": 888, "86_regent_street_city_centre": 1113, "01223_323361": 1062, "however": 609, "da_vinci_pizzeria": 909, "seafood": 1221, "job": 146, "preferably": 689, "french": 877, "cb58rg": 1173, "come": 731, "$": 3, "corn_exchange_street": 1247, "both": 347, "c": 490, "toward": 750, "22": 632, "restaurant": 729, "many": 406, "01223_413000": 1043, "kosher": 177, "gastropub": 1085, "british": 967, "cambridge_lodge_hotel_139_huntingdon_road_city_centre": 847, "cherry": 697, "expand": 527, "requests": 475, "01223_462354": 882, "asked": 68, "nearby": 40, "bridge_street_city_centre": 1210, "regent": 33, "cb21qy": 1109, "deciding": 678, "had": 714, "ca": 599, "others": 634, "pizza_hut_cherry_hinton": 993, "venues": 803, "the_lucky_star": 928, "scottish": 620, "pizza_hut_city_centre": 905, "hotel_du_vin_and_bistro": 1227, "01223302010": 363, "cb21qa": 1084, "along": 61, "expensive": 984, "cambridge_city_football_club_milton_road_chesterton": 1061, "unusual": 758, "api_call": 728, "kymmoy": 832, "loated": 42, "$s3$": 545, "three": 769, "been": 232, ".": 685, "north_american": 876, "int": 588, "35_saint_andrews_street_city_centre": 1151, "much": 613, "31_newnham_road_newnham": 985, "01223": 559, "interest": 280, "nandos": 839, "01223_323639": 1228, "lovely": 742, "meeting": 59, "mahal_of_cambridge": 1107, "bbq": 582, ">": 586, "17_magdalene_street_city_centre": 855, "451_newmarket_road_fen_ditton": 938, "01223_412430": 1052, "eastern": 175, "am": 755, "$u2$": 192, "anythig": 657, "cb12lf": 898, "else": 616, "01223_327908": 1086, "hmm": 135, "moderatley": 299, "2_sturton_street_city_centre": 1215, "169": 515, "01223_337766": 924, "prices": 501, "turkish": 953, "bangkok_city": 1041, "those": 547, "sound": 741, "ant": 451, "cote": 1167, "restarant": 782, "look": 643, "4-6": 9, "talking": 725, "these": 21, "bill": 672, "therea": 412, "la_raza": 1207, "resturant": 174, "n't": 548, "while": 747, "c.b2": 41, "suppose": 123, "abou": 767, "152-154": 754, "rice_house": 1077, "exist": 655, "21_-_24_northampton_street": 923, "golden_wok": 1060, "mistake": 144, "eateries": 224, "01223_358399": 1176, "venetian": 511, "43_high_street_cherry_hinton_cherry_hinton": 977, "centr": 724, "is": 97, "telephone": 736, "cb58ba": 1001, "sitar": 222, "cambridge_leisure_park_clifton_way_cherry_hinton": 890, "good": 771, "im": 715, "q.a": 417, "city_stop_restaurant": 1153, "in": 157, "01223_366668": 1065, "464630.": 270, "id": 317, "if": 226, "cuisine": 293, "different": 156, "anymore": 671, "cb39ey": 872, "perform": 409, "suggest": 494, "make": 206, "the_nirala": 1138, "01842_753771": 1163, "same": 204, "any": 443, "grafton_hotel_restaurant": 1089, "matches": 474, "01223_307581": 830, "inquiry": 695, "$u9$": 338, "9": 236, "cb28nx": 1048, "the_slug_and_lettuce": 1208, "several": 370, "couple": 214, "dontcare": 99, "chesterton": 802, "fairly": 91, "elese": 80, "closest": 269, 
"singapore": 337, "welsh": 807, "grafton": 674, "resturants": 253, "qualifications": 606, "pipasha_restaurant": 856, "market_hill_city_centre": 1125, "moment": 416, "epensive": 518, "01223_464630": 900, "clowns_cafe": 1241, "cocum": 904, "afghan": 67, "six": 386, "west.would": 50, "01223_566188": 1174, "pizza_express_fen_ditton": 1234, "center": 43, "database": 467, "i": 631, "no": 667, "well": 461, "whichever": 248, "know": 65, "costs": 479, "english": 687, "y": 100, "town.is": 69, "the": 312, "01223_248882": 1240, "spend": 283, "01223_369299": 1133, "restuarant": 495, "goobye": 673, "01223_351027": 1239, "cb12qa": 831, "01223_566388": 1076, "travellers_rest": 1081, "sesame_restaurant_and_bar": 1160, "just": 77, "less": 668, "ya": 245, "able": 170, "rest": 427, "01223_360409": 1128, "restauarnt": 463, "01223351707": 172, "thanks": 78, "questions": 46, "world": 229, "cb41ep": 1218, "polish": 622, "yes": 641, "cb23ll": 845, "01223_302330": 922, "24_green_street_city_centre": 1108, "the_hotpot": 1142, "cb41ha": 1165, "like": 38, "hills": 434, "efes_restaurant": 1137, "dining": 393, "thinking": 34, "rose": 661, "01223248882.": 349, "alternate": 307, "seems": 757, "recommended": 35, "01223_323178": 1175, "ther": 694, "01223362525": 37, "varisty": 183, "interested": 690, "4": 730, "cb23qf": 875, "280": 663, "has": 477, "cb23pp": 1203, "match": 710, "details": 708, "sanit": 702, "...": 707, "indonesian": 713, "around": 581, "34_-_35_green_street": 1196, "01799_521260": 1116, "tnahh": 171, "pricey": 107, "possible": 629, "catalan": 281, "fusion": 682, "five": 607, "preference": 420, "01223_356555": 1139, "$u6$": 188, "using": 437, "cb41nl": 1145, "postal": 375, "cb12as": 850, "name": 756, "desire": 614, "$s6$": 813, "thank-you": 500, "59": 4, "d": 98, "cb19hx": 950, "01223_323737": 828, "alternatively": 120, "serves": 635, "server": 777, "specific": 723, "01223367660": 544, "12_st._johns_street_city_centre": 1238, "either": 778, "night": 255, "popular": 579, "served": 133, "works": 718, "8_norfolk_street_city_centre": 920, "italian": 1004, "54_king_street_city_centre": 903, "right": 168, "hakka": 1091, "st._michael's_church_trinity_street_city_centre": 1195, "absolutely": 587, "welconme": 186, "01223307581": 825, "some": 300, "back": 436, "cb21sj": 1204, "curry_queen": 874, "01223400170.": 220, "sure": 415, "pricing": 267, "choosing": 705, "cb58aq": 1224, "dbye": 72, "205": 408, "'ve": 407, "matching": 533, "provided": 328, "caribbean": 195, "for": 290, "3_-_5_millers_yard_mill_lane": 982, "centre": 1214, "portuguese": 1242, "cb21su": 939, "per": 484, "creative": 354, "35": 411, "does": 449, "restaurant_alimentum": 1039, "the_cow_pizza_kitchen_and_bar": 1033, "moderate": 1236, "the_copper_kettle": 1068, "christmas": 372, "?": 111, "cb11hr": 1073, "locate": 141, "be": 369, "varying": 47, "grafton_hotel_619_newmarket_road_fen_ditton": 942, "business": 115, "unfortuntately": 603, "shiraz_restaurant": 947, "servies": 816, "3.h": 105, "01223_355711": 1066, "01223_311911": 1020, "although": 185, "the_gardenia": 1046, "efes_restaurants": 866, "post": 79, "by": 405, "moderatly": 208, "on": 348, "about": 150, "ok": 259, "hk_fusion": 1200, "anything": 277, "oh": 297, "of": 748, "cb21uj": 940, "cb21rt": 1185, "location": 215, "meal": 419, "dinner": 644, "important": 360, "plus": 801, "cb21uw": 963, "01223_365068": 1134, "01223_412299": 858, "52_mill_road_city_centre": 1071, "or": 462, "road": 429, "cb11bg": 1216, "01223327908.": 213, "whats": 520, "01223_244277": 931, "11_peas_hill_city_centre": 886, 
"don_pasquale_pizzeria": 1016, "system..good": 257, "01223_351880": 1162, "within": 636, "the_little_rose_37_trumpington_street": 897, "cb43hl": 983, "servers": 321, "down": 789, "nothing": 765, "because": 716, "01223_354679": 1161, "apologize": 6, "your": 797, "particularly": 66, "corsica": 318, "east": 1006, "fitzbillies_restaurant": 1050, "additional": 749, "01223_308871": 1054, "area": 244, "cb21la": 1034, "there": 132, "hey": 151, "long": 563, "caffe_uno": 1209, "hills_road_city_centre": 837, "la_tasca": 1088, "huntingdon_road_city_centre": 1132, "lot": 203, "brazilian": 647, ":": 11, "was": 52, "happy": 113, "01223_308681": 974, "head": 87, "medium": 276, "north": 1121, "michaelhouse_cafe": 1186, "offering": 450, "offer": 14, "something": 779, "40270_king_street_city_centre": 862, "6": 448, "cb30lx": 1042, "but": 809, "searching": 800, "01223_324033": 1135, "hi": 29, "hear": 551, "good-bye": 12, "33-34_saint_andrews_street": 1188, "line": 261, "trying": 583, "with": 311, "expinsive": 58, "he": 60, "info": 104, "01223_227330": 1217, "made": 92, "cb11ln": 867, "places": 86, "whether": 525, "restaraunt": 444, "$GO$": 0, "01223_364917": 1069, "up": 709, "us": 739, "01223315232.": 36, "01223_363471": 1063, "west": 1045, "persian": 418, "jamaican": 679, "01223_505015": 955, "servings": 628, "problem": 153, "more": 445, "called": 772, "assistance": 425, "prezzo": 1143, "irish": 604, "bedouin": 973, "whatever": 287, "certain": 199, "upscale": 686, "17_hills_road_city_centre": 1136, "australasian": 519, "cb17dy": 1024, "$u1$": 691, "an": 137, "saigon_city": 1126, "as": 324, "diverse": 577, "tasty": 88, "at": 101, "u.f": 424, "thier": 121, "01223_506055": 1140, "66_chesterton_road_chesterton": 1191, "01223_362525": 1245, "check": 649, "again": 781, "cb41uy": 1105, "'ll": 492, "variety": 125, "nandos_city_centre": 1074, "nah": 696, "green": 108, "35_newnham_road_newnham": 1147, "thanh_binh": 902, "": 380, "01223_362054": 1013, "iscb259aq": 770, "other": 719, "5": 310, "special": 543, "5_jordans_yard_bridge_street_city_centre": 869, "01223366552.": 378, "you": 441, "1.uy": 794, "01223_329432": 1030, "01733_553355": 1092, "contacting": 76, "nice": 745, "requested": 431, "cb30df": 978, "196_mill_road_city_centre": 932, "star": 646, "darrys_cookhouse_and_wine_shop": 1154, "01223_400170": 1181, "anatolia": 1120, "cb21rq": 827, "welcome": 387, "categories": 703, "41518_castle_street_city_centre": 1213, "stay": 251, "'m": 560, "postode": 573, "cb30dq": 1027, "occasion": 143, "108_regent_street_city_centre": 1093, "25.00": 256, "priced": 30, "restraunt": 538, "01223_352500": 1122, "cb21rg": 1000, "01223_568988": 945, "017-335-3355": 122, "tang_chinese": 961, "eatery": 327, "01223727410": 740, "01799521260": 488, "little_seoul": 880, "lan_hong_house": 1025, "tandoon": 272, "c.b": 539, "83_regent_street": 1235, "time": 396, "alright": 753, "$u8$": 361, "01223568988": 39, "having": 638, "01223_361763": 1002}} -------------------------------------------------------------------------------- /incar/data_handler.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import random 4 | import nltk 5 | import os 6 | import sys 7 | import numpy as np 8 | import logging 9 | logging.getLogger().setLevel(logging.INFO) 10 | 11 | single_word_entities = 
['monday','tuesday','wednesday','thursday','friday','saturday','sunday','medicine','conference','dinner','lab','yoga','tennis','doctor','meeting','swimming','optometrist','football','dentist',"overcast","snow", "stormy", "hail","hot", "rain", "cold","cloudy", "warm", "windy","foggy", "humid", "frost", "blizzard", "drizzle", "dry", "dew", "misty","friend", "home", "coffee", "chinese","pizza", "grocery", "rest", "shopping", "parking","gas", "hospital"] 12 | 13 | class DataHandler(object): 14 | 15 | def __init__(self,emb_dim,batch_size,train_path,val_path,test_path,vocab_path,glove_path): 16 | 17 | self.batch_size = batch_size 18 | self.train_path = train_path 19 | self.vocab_threshold = 3 20 | self.val_path = val_path 21 | self.test_path = test_path 22 | self.vocab_path = vocab_path 23 | self.emb_dim = emb_dim 24 | self.glove_path = glove_path 25 | 26 | self.vocab = self.load_vocab() 27 | self.input_vocab_size = self.vocab['input_vocab_size'] 28 | self.output_vocab_size = self.vocab['output_vocab_size'] 29 | self.generate_vocab_size = self.vocab['generate_vocab_size'] 30 | self.emb_init = self.load_glove_vectors() 31 | 32 | self.train_data = json.load(open(self.train_path)) 33 | self.val_data = json.load(open(self.val_path)) 34 | self.test_data = json.load(open(self.test_path)) 35 | 36 | random.shuffle(self.train_data) 37 | random.shuffle(self.val_data) 38 | random.shuffle(self.test_data) 39 | 40 | self.val_data_full = self.append_dummy_data(self.val_data) 41 | 42 | self.train_index = 0 43 | self.val_index = 0 44 | self.train_num = len(self.train_data) 45 | self.val_num = len(self.val_data_full) 46 | 47 | def append_dummy_data(self,data): 48 | new_data = [] 49 | for i in range(0,len(data)): 50 | data[i]['dummy'] = 0 51 | new_data.append(copy.copy(data[i])) 52 | 53 | last = data[-1] 54 | last['dummy'] = 1 55 | for _ in range(0,self.batch_size - len(data)%self.batch_size): 56 | new_data.append(copy.copy(last)) 57 | 58 | return copy.copy(new_data) 59 | 60 | 61 | def load_glove_vectors(self): 62 | logging.info("Loading pre-trained Word Embeddings") 63 | filename = self.glove_path + "glove.6B.200d.txt" 64 | glove = {} 65 | file = open(filename,'r') 66 | for line in file.readlines(): 67 | row = line.strip().split(' ') 68 | glove[row[0]] = np.asarray(row[1:]) 69 | logging.info('Loaded GloVe!') 70 | file.close() 71 | embeddings_init = np.random.normal(size=(self.vocab['input_vocab_size'],self.emb_dim)).astype('f') 72 | count = 0 73 | for word in self.vocab['vocab_mapping']: 74 | if word in glove: 75 | count = count + 1 76 | embeddings_init[self.vocab['vocab_mapping'][word]] = glove[word] 77 | 78 | del glove 79 | 80 | logging.info("Loaded "+str(count)+" pre-trained Word Embeddings") 81 | return embeddings_init 82 | 83 | 84 | def load_vocab(self): 85 | if os.path.isfile(self.vocab_path): 86 | logging.info("Loading vocab from file") 87 | with open(self.vocab_path) as f: 88 | return json.load(f) 89 | else: 90 | logging.info("Vocab file not found. 
Computing Vocab") 91 | with open(self.train_path) as f: 92 | train_data = json.load(f) 93 | with open(self.val_path) as f: 94 | val_data = json.load(f) 95 | with open(self.test_path) as f: 96 | test_data = json.load(f) 97 | 98 | full_data = [] 99 | full_data.extend(train_data) 100 | full_data.extend(val_data) 101 | full_data.extend(test_data) 102 | 103 | return self.get_vocab(full_data) 104 | 105 | def get_vocab(self,data): 106 | 107 | vocab = {} 108 | for d in data: 109 | utts = [] 110 | utts.append(d['output']) 111 | utts.extend(d['context']) 112 | for utt in utts: 113 | tokens = utt.split(" ") 114 | for token in tokens: 115 | if token.lower() not in vocab: 116 | vocab[token.lower()] = 1 117 | else: 118 | vocab[token.lower()] = vocab[token.lower()] + 1 119 | 120 | for item in d['kb']: 121 | for key in item: 122 | if key.lower() not in vocab: 123 | vocab[key.lower()] = 1 124 | else: 125 | vocab[key.lower()] = vocab[key.lower()] + 1 126 | token = item[key] 127 | if token.lower() not in vocab: 128 | vocab[token.lower()] = 1 129 | else: 130 | vocab[token.lower()] = vocab[token.lower()] + 1 131 | 132 | words = vocab.keys() 133 | words.append("$STOP$") 134 | words.append("$PAD$") 135 | 136 | for i in range(1,6): 137 | words.append("$u"+str(i)+"$") 138 | words.append("$s"+str(i)+"$") 139 | words.append("$u6$") 140 | 141 | generate_words = [] 142 | copy_words = [] 143 | for word in words: 144 | if word in single_word_entities or '_' in word: 145 | copy_words.append(word) 146 | else: 147 | generate_words.append(word) 148 | 149 | output_vocab_size = len(words) + 1 150 | 151 | generate_indices = [i for i in range(1,len(generate_words)+1)] 152 | copy_indices = [i for i in range(len(generate_words)+1,output_vocab_size)] 153 | random.shuffle(generate_indices) 154 | random.shuffle(copy_indices) 155 | 156 | mapping = {} 157 | rev_mapping = {} 158 | 159 | for i in range(0,len(generate_words)): 160 | mapping[generate_words[i]] = generate_indices[i] 161 | rev_mapping[str(generate_indices[i])] = generate_words[i] 162 | 163 | for i in range(0,len(copy_words)): 164 | mapping[copy_words[i]] = copy_indices[i] 165 | rev_mapping[str(copy_indices[i])] = copy_words[i] 166 | 167 | mapping["$GO$"] = 0 168 | rev_mapping[0] = "$GO$" 169 | vocab_dict = {} 170 | vocab_dict['vocab_mapping'] = mapping 171 | vocab_dict['rev_mapping'] = rev_mapping 172 | vocab_dict['input_vocab_size'] = len(words) + 1 173 | vocab_dict['generate_vocab_size'] = len(generate_words) + 1 174 | vocab_dict['output_vocab_size'] = output_vocab_size 175 | 176 | with open(self.vocab_path,'w') as f: 177 | json.dump(vocab_dict,f) 178 | 179 | logging.info("Vocab file created") 180 | 181 | return vocab_dict 182 | 183 | def get_sentinel(self,i,context): 184 | if i%2 == 0: 185 | speaker = "u" 186 | turn = (context - i + 1)/2 187 | else: 188 | speaker = "s" 189 | turn = (context - i)/2 190 | return "$"+speaker+str(turn)+"$" 191 | 192 | def vectorize(self,batch,train): 193 | vectorized = {} 194 | vectorized['inp_utt'] = [] 195 | vectorized['out_utt'] = [] 196 | vectorized['inp_len'] = [] 197 | vectorized['context_len'] = [] 198 | vectorized['out_len'] = [] 199 | vectorized['kb'] = [] 200 | vectorized['kb_mask'] = [] 201 | vectorized['keys'] = [] 202 | vectorized['keys_mask'] = [] 203 | vectorized['mapping'] = [] 204 | vectorized['rev_mapping'] = [] 205 | vectorized['type'] = [] 206 | vectorized['dummy'] = [] 207 | vectorized['empty'] = [] 208 | 209 | vectorized['knowledge'] = [] 210 | vectorized['context'] = [] 211 | max_inp_utt_len = 0 212 | max_out_utt_len = 
0 213 | max_context_len = 0 214 | kb_len = 0 215 | keys_len = 6 216 | 217 | for item in batch: 218 | 219 | if len(item['context']) > max_context_len: 220 | max_context_len = len(item['context']) 221 | 222 | for utt in item['context']: 223 | tokens = utt.split(" ") 224 | 225 | if len(tokens) > max_inp_utt_len: 226 | max_inp_utt_len = len(tokens) 227 | 228 | tokens = item['output'].split(" ") 229 | if len(tokens) > max_out_utt_len: 230 | max_out_utt_len = len(tokens) 231 | 232 | if len(item['kb']) > kb_len: 233 | kb_len = len(item['kb']) 234 | 235 | max_inp_utt_len = max_inp_utt_len + 1 236 | 237 | max_out_utt_len = max_out_utt_len + 1 238 | vectorized['max_out_utt_len'] = max_out_utt_len 239 | 240 | 241 | for item in batch: 242 | vectorized['context'].append(item['context']) 243 | vectorized['knowledge'].append(item['kb']) 244 | vectorized['mapping'].append(item['mapping']) 245 | vectorized['rev_mapping'].append(item['rev_mapping']) 246 | vectorized['type'].append(item['type']) 247 | if item['kb'] == []: 248 | vectorized['empty'].append(0) 249 | else: 250 | vectorized['empty'].append(1) 251 | if not train: 252 | vectorized['dummy'].append(item['dummy']) 253 | vector_inp = [] 254 | vector_len = [] 255 | 256 | for i in range(0,len(item['context'])): 257 | utt = item['context'][i] 258 | inp = [] 259 | sentinel = self.get_sentinel(i,len(item['context'])) 260 | tokens = utt.split(" ") + [sentinel] 261 | for token in tokens: 262 | inp.append(self.vocab['vocab_mapping'][token]) 263 | 264 | vector_len.append(len(tokens)) 265 | for _ in range(0,max_inp_utt_len - len(tokens)): 266 | inp.append(self.vocab['vocab_mapping']["$PAD$"]) 267 | vector_inp.append(copy.copy(inp)) 268 | 269 | vectorized['context_len'].append(len(item['context'])) 270 | 271 | for _ in range(0,max_context_len - len(item['context'])): 272 | vector_len.append(0) 273 | inp = [] 274 | for _ in range(0,max_inp_utt_len): 275 | inp.append(self.vocab['vocab_mapping']["$PAD$"]) 276 | vector_inp.append(copy.copy(inp)) 277 | 278 | vectorized['inp_utt'].append(copy.copy(vector_inp)) 279 | vectorized['inp_len'].append(vector_len) 280 | 281 | vector_out = [] 282 | tokens = item['output'].split(" ") 283 | tokens.append('$STOP$') 284 | for token in tokens: 285 | vector_out.append(self.vocab['vocab_mapping'][token]) 286 | 287 | for _ in range(0,max_out_utt_len - len(tokens)): 288 | vector_out.append(self.vocab['vocab_mapping']["$PAD$"]) 289 | vectorized['out_utt'].append(copy.copy(vector_out)) 290 | vectorized['out_len'].append(len(tokens)) 291 | 292 | vector_keys = [] 293 | vector_keys_mask = [] 294 | vector_kb = [] 295 | vector_kb_mask = [] 296 | 297 | for result in item['kb']: 298 | vector_result = [] 299 | vector_result_keys = [] 300 | vector_result_keys_mask = [] 301 | vector_kb_mask.append(1) 302 | for key in result: 303 | vector_result.append(self.vocab['vocab_mapping'][result[key]]) 304 | vector_result_keys.append(self.vocab['vocab_mapping'][key]) 305 | vector_result_keys_mask.append(1) 306 | 307 | for _ in range(0,keys_len-len(result.keys())): 308 | vector_result_keys.append(self.vocab['vocab_mapping']["$PAD$"]) 309 | vector_result_keys_mask.append(0) 310 | vector_result.append(self.vocab['vocab_mapping']["$PAD$"]) 311 | vector_keys.append(copy.copy(vector_result_keys)) 312 | vector_keys_mask.append(copy.copy(vector_result_keys_mask)) 313 | vector_kb.append(copy.copy(vector_result)) 314 | 315 | if item['kb'] == []: 316 | vector_kb_mask.append(1) 317 | vector_kb.append([self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,keys_len)]) 318 
| vector_keys.append([self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,keys_len)]) 319 | vector_keys_mask.append([1] + [0 for _ in range(0,keys_len-1)]) 320 | 321 | current_kb_len = len(vector_kb_mask) 322 | 323 | for _ in range(0,kb_len - current_kb_len): 324 | vector_kb_mask.append(0) 325 | vector_kb.append([self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,keys_len)]) 326 | vector_keys.append([self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,keys_len)]) 327 | vector_keys_mask.append([1] + [0 for _ in range(0,keys_len-1)]) 328 | 329 | vectorized['kb'].append(copy.copy(vector_kb)) 330 | vectorized['kb_mask'].append(copy.copy(vector_kb_mask)) 331 | vectorized['keys'].append(copy.copy(vector_keys)) 332 | vectorized['keys_mask'].append(copy.copy(vector_keys_mask)) 333 | 334 | return vectorized 335 | 336 | def get_batch(self,train): 337 | 338 | epoch_done = False 339 | 340 | if train: 341 | index = self.train_index 342 | batch = self.vectorize(self.train_data[index:index+self.batch_size],train) 343 | self.train_index = self.train_index + self.batch_size 344 | 345 | if self.train_index + self.batch_size > self.train_num: 346 | self.train_index = 0 347 | random.shuffle(self.train_data) 348 | epoch_done = True 349 | 350 | else: 351 | index = self.val_index 352 | batch = self.vectorize(self.val_data_full[index:index+self.batch_size],train) 353 | self.val_index = self.val_index + self.batch_size 354 | 355 | if self.val_index + self.batch_size > self.val_num: 356 | self.val_index = 0 357 | random.shuffle(self.val_data) 358 | self.val_data_full = self.append_dummy_data(self.val_data) 359 | epoch_done = True 360 | 361 | 362 | return batch,epoch_done 363 | -------------------------------------------------------------------------------- /incar/evaluate.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import copy 3 | import json 4 | import csv 5 | from collections import Counter 6 | from nltk.util import ngrams 7 | from nltk.corpus import stopwords 8 | from nltk.tokenize import word_tokenize 9 | from nltk.stem import WordNetLemmatizer 10 | import math, re, argparse 11 | import functools 12 | 13 | entities = ['monday','tuesday','wednesday','thursday','friday','saturday','sunday','medicine','conference','dinner','lab','yoga','tennis','doctor','meeting','swimming','optometrist','football','dentist',"overcast","snow", "stormy", "hail","hot", "rain", "cold","cloudy", "warm", "windy","foggy", "humid", "frost", "blizzard", "drizzle", "dry", "dew", "misty","friend", "home", "coffee", "chinese","pizza", "grocery", "rest", "shopping", "parking","gas", "hospital"] 14 | 15 | def score(parallel_corpus): 16 | 17 | # containers 18 | count = [0, 0, 0, 0] 19 | clip_count = [0, 0, 0, 0] 20 | r = 0 21 | c = 0 22 | weights = [0.25, 0.25, 0.25, 0.25] 23 | 24 | # accumulate ngram statistics 25 | for hyps, refs in parallel_corpus: 26 | hyps = [hyp.split() for hyp in hyps] 27 | refs = [ref.split() for ref in refs] 28 | for hyp in hyps: 29 | 30 | for i in range(4): 31 | # accumulate ngram counts 32 | hypcnts = Counter(ngrams(hyp, i + 1)) 33 | cnt = sum(hypcnts.values()) 34 | count[i] += cnt 35 | 36 | # compute clipped counts 37 | max_counts = {} 38 | for ref in refs: 39 | refcnts = Counter(ngrams(ref, i + 1)) 40 | for ng in hypcnts: 41 | max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) 42 | clipcnt = dict((ng, min(count, max_counts[ng])) \ 43 | for ng, count in hypcnts.items()) 44 | clip_count[i] += sum(clipcnt.values()) 45 | 46 | # 
accumulate r & c 47 | bestmatch = [1000, 1000] 48 | for ref in refs: 49 | if bestmatch[0] == 0: break 50 | diff = abs(len(ref) - len(hyp)) 51 | if diff < bestmatch[0]: 52 | bestmatch[0] = diff 53 | bestmatch[1] = len(ref) 54 | r += bestmatch[1] 55 | c += len(hyp) 56 | 57 | # computing bleu score 58 | p0 = 1e-7 59 | bp = 1 if c > r else math.exp(1 - float(r) / float(c)) 60 | p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ 61 | for i in range(4)] 62 | s = math.fsum(w * math.log(p_n) \ 63 | for w, p_n in zip(weights, p_ns) if p_n) 64 | bleu = bp * math.exp(s) 65 | return bleu 66 | 67 | 68 | data = pickle.load(open("needed.p")) 69 | vocab = json.load(open("./vocab.json")) 70 | outs = [] 71 | golds = [] 72 | domain_wise = {} 73 | for domain in ['schedule','navigate','weather']: 74 | domain_wise[domain] = {} 75 | domain_wise[domain]['tp_prec'] = 0.0 76 | domain_wise[domain]['tp_recall'] = 0.0 77 | domain_wise[domain]['total_prec'] = 0.0 78 | domain_wise[domain]['total_recall'] = 0.0 79 | domain_wise[domain]['gold'] = [] 80 | domain_wise[domain]['output'] = [] 81 | 82 | tp_prec = 0.0 83 | tp_recall = 0.0 84 | total_prec = 0.0 85 | total_recall = 0.0 86 | 87 | for i in range(0,len(data['sentences'])): 88 | sentence = data['sentences'][i] 89 | domain = data['type'][i] 90 | sentence = list(sentence) 91 | if vocab['vocab_mapping']['$STOP$'] not in sentence: 92 | index = len(sentence) 93 | else: 94 | index = sentence.index(vocab['vocab_mapping']['$STOP$']) 95 | predicted = [str(sentence[j]) for j in range(0,index)] 96 | ground = data['output'][i] 97 | ground = list(ground) 98 | index = ground.index(vocab['vocab_mapping']['$STOP$']) 99 | ground_truth = [str(ground[j]) for j in range(0,index)] 100 | 101 | gold_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in ground_truth ] 102 | out_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in predicted ] 103 | 104 | for word in out_anon: 105 | if word in entities or '_' in word: 106 | total_prec = total_prec + 1 107 | domain_wise[domain]['total_prec'] = domain_wise[domain]['total_prec'] + 1 108 | if word in gold_anon: 109 | tp_prec = tp_prec + 1 110 | domain_wise[domain]['tp_prec'] = domain_wise[domain]['tp_prec'] + 1 111 | 112 | for word in gold_anon: 113 | if word in entities or '_' in word: 114 | total_recall = total_recall + 1 115 | domain_wise[domain]['total_recall'] = domain_wise[domain]['total_recall'] + 1 116 | if word in out_anon: 117 | tp_recall = tp_recall + 1 118 | domain_wise[domain]['tp_recall'] = domain_wise[domain]['tp_recall'] + 1 119 | 120 | gold = gold_anon 121 | out = out_anon 122 | 123 | domain_wise[domain]['gold'].append(" ".join(gold)) 124 | golds.append(" ".join(gold)) 125 | domain_wise[domain]['output'].append(" ".join(out)) 126 | outs.append(" ".join(out)) 127 | 128 | with open('output', 'w') as output_file: 129 | for line in outs: 130 | output_file.write(line+"\n") 131 | 132 | with open('reference', 'w') as output_file: 133 | for line in golds: 134 | output_file.write(line+"\n") 135 | 136 | wrap_generated = [[_] for _ in outs] 137 | wrap_truth = [[_] for _ in golds] 138 | prec = tp_prec/total_prec 139 | recall = tp_recall/total_recall 140 | print 'prec',tp_prec,total_prec 141 | print 'recall',tp_recall,total_recall 142 | print "Bleu: %.3f, Prec: %.3f, Recall: %.3f, F1: %.3f" % (score(zip(wrap_generated, wrap_truth)),prec,recall,2*prec*recall/(prec+recall)) 143 | for domain in ['schedule','navigate','weather']: 144 | prec = domain_wise[domain]['tp_prec']/domain_wise[domain]['total_prec'] 145 | recall 
= domain_wise[domain]['tp_recall']/domain_wise[domain]['total_recall'] 146 | print "prec",domain_wise[domain]['tp_prec'],domain_wise[domain]['total_prec'] 147 | print "recall",domain_wise[domain]['tp_recall'],domain_wise[domain]['total_recall'] 148 | wrap_generated = [[_] for _ in domain_wise[domain]['output']] 149 | wrap_truth = [[_] for _ in domain_wise[domain]['gold']] 150 | print "Domain: " + str(domain) + ", Bleu: %.3f, Prec: %.3f, Recall: %.3f, F1: %.3f" % (score(zip(wrap_generated, wrap_truth)),prec,recall,2*prec*recall/(prec+recall)) 151 | -------------------------------------------------------------------------------- /incar/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.ops import embedding_ops, array_ops, math_ops, tensor_array_ops, control_flow_ops 4 | 5 | class DialogueModel(object): 6 | 7 | def __init__(self,device,batch_size,inp_vocab_size,out_vocab_size,generate_size,emb_init,emb_dim,enc_hid_dim,dec_hid_dim,attn_size): 8 | 9 | self.device = device 10 | self.batch_size = batch_size 11 | self.emb_dim = emb_dim 12 | self.inp_vocab_size = inp_vocab_size 13 | self.out_vocab_size = out_vocab_size 14 | self.generate_size = generate_size 15 | self.emb_init = emb_init 16 | self.enc_hid_dim = enc_hid_dim 17 | self.dec_hid_dim = dec_hid_dim 18 | self.attn_size = attn_size 19 | self.generate_size = generate_size 20 | 21 | self.inp_utt = tf.placeholder( 22 | name='inp_utt', dtype=tf.int64, 23 | shape=[self.batch_size, None, None], 24 | ) 25 | 26 | self.inp_len = tf.placeholder( 27 | name='inp_len', dtype=tf.int64, 28 | shape=[self.batch_size, None], 29 | ) 30 | 31 | self.context_len = tf.placeholder( 32 | name='context_len', dtype=tf.int64, 33 | shape=[self.batch_size], 34 | ) 35 | 36 | self.out_utt = tf.placeholder( 37 | name='out_utt', dtype=tf.int64, 38 | shape=[self.batch_size, None], 39 | ) 40 | 41 | self.out_len = tf.placeholder( 42 | name='out_len', dtype=tf.float32, 43 | shape=[self.batch_size], 44 | ) 45 | 46 | self.kb = tf.placeholder( 47 | name='kb', dtype=tf.int64, 48 | shape=[self.batch_size,None,6], 49 | ) 50 | 51 | self.kb_mask = tf.placeholder( 52 | name='kb_mask', dtype=tf.float32, 53 | shape=[self.batch_size,None], 54 | ) 55 | 56 | self.keys = tf.placeholder( 57 | name='keys', dtype=tf.int64, 58 | shape=[self.batch_size,None,6], 59 | ) 60 | 61 | self.keys_mask = tf.placeholder( 62 | name='keys_mask', dtype=tf.float32, 63 | shape=[self.batch_size,None,6], 64 | ) 65 | 66 | self.max_out_utt_len = tf.placeholder( 67 | name = 'max_out_utt_len', dtype=tf.int32, 68 | shape = (), 69 | ) 70 | 71 | self.db_empty = tf.placeholder( 72 | name='db_empty', dtype=tf.float32, 73 | shape=[self.batch_size], 74 | ) 75 | 76 | self.buildArch() 77 | 78 | def buildArch(self): 79 | 80 | with tf.device(self.device): 81 | 82 | self.embeddings = tf.get_variable("embeddings", initializer=tf.constant(self.emb_init)) 83 | self.inp_utt_emb = embedding_ops.embedding_lookup(self.embeddings, self.inp_utt) 84 | 85 | with tf.variable_scope("encoder"): 86 | self.encoder_cell_1 = tf.contrib.rnn.GRUCell(self.enc_hid_dim) 87 | self.encoder_cell_2 = tf.contrib.rnn.GRUCell(2*self.enc_hid_dim) 88 | self.flat_inp_emb = tf.reshape(self.inp_utt_emb,shape=[-1,tf.shape(self.inp_utt)[2],self.emb_dim]) 89 | self.flat_inp_len = tf.reshape(self.inp_len,shape=[-1]) 90 | 91 | outputs,output_states = 
tf.nn.bidirectional_dynamic_rnn(cell_fw=self.encoder_cell_1,cell_bw=self.encoder_cell_1,inputs=self.flat_inp_emb,dtype=tf.float32,sequence_length=self.flat_inp_len,time_major=False) 92 | self.flat_encoder_states = tf.concat(outputs,axis=2) 93 | self.utt_reps = tf.concat(output_states,axis=1) 94 | 95 | self.utt_rep_second = tf.reshape(self.utt_reps,shape=[self.batch_size,-1,2*self.enc_hid_dim]) 96 | self.hidden_states, self.inp_utt_rep = tf.nn.dynamic_rnn(self.encoder_cell_2,self.utt_rep_second,dtype=tf.float32,sequence_length=self.context_len,time_major=False) 97 | self.encoder_states = tf.reshape(tf.reshape(self.flat_encoder_states,shape=[self.batch_size,-1,tf.shape(self.inp_utt)[2],2*self.enc_hid_dim]), shape=[self.batch_size,-1,2*self.enc_hid_dim]) 98 | 99 | 100 | self.kb_emb = embedding_ops.embedding_lookup(self.embeddings, self.kb) 101 | self.keys_emb = embedding_ops.embedding_lookup(self.embeddings, self.keys) 102 | self.result_rep = tf.einsum('ij,ijk->ijk',tf.pow(tf.reduce_sum(self.keys_mask,2),-1),tf.reduce_sum(tf.einsum('ijk,ijkl->ijkl',self.keys_mask,self.kb_emb),2)) 103 | 104 | self.start_token = tf.constant([0] * self.batch_size, dtype=tf.int32) 105 | self.out_utt_emb = embedding_ops.embedding_lookup(self.embeddings, self.out_utt) 106 | self.processed_x = tf.transpose(self.out_utt_emb,perm=[1,0,2]) 107 | 108 | with tf.variable_scope("decoder"): 109 | self.decoder_cell = tf.contrib.rnn.GRUCell(self.dec_hid_dim) 110 | 111 | self.h0 = self.inp_utt_rep 112 | self.g_output_unit = self.create_output_unit() 113 | 114 | gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.max_out_utt_len, 115 | dynamic_size=False, infer_shape=True) 116 | 117 | def _g_recurrence(i, x_t, h_tm1, gen_x): 118 | _,h_t = self.decoder_cell(x_t,h_tm1) 119 | o_t = self.g_output_unit(h_t) # batch x vocab , prob 120 | next_token = tf.cast(tf.reshape(tf.argmax(o_t, 1), [self.batch_size]), tf.int32) 121 | x_tp1 = embedding_ops.embedding_lookup(self.embeddings,next_token) # batch x emb_dim 122 | gen_x = gen_x.write(i, next_token) # indices, batch_size 123 | return i + 1, x_tp1, h_t, gen_x 124 | 125 | _, _, _, self.gen_x = control_flow_ops.while_loop( 126 | cond=lambda i, _1, _2, _3: i < self.max_out_utt_len, 127 | body=_g_recurrence, 128 | loop_vars=(tf.constant(0, dtype=tf.int32), 129 | embedding_ops.embedding_lookup(self.embeddings,self.start_token), 130 | self.h0, gen_x)) 131 | 132 | self.gen_x = self.gen_x.stack() # seq_length x batch_size 133 | self.gen_x = tf.transpose(self.gen_x, perm=[1, 0]) # batch_size x seq_length 134 | 135 | # gen_x contains the colours sampled as outputs 136 | # Hence, gen_x is used while calculating accuracy 137 | 138 | g_predictions = tensor_array_ops.TensorArray( 139 | dtype=tf.float32, size=self.max_out_utt_len, 140 | dynamic_size=False, infer_shape=True) 141 | 142 | ta_emb_x = tensor_array_ops.TensorArray( 143 | dtype=tf.float32, size=self.max_out_utt_len) 144 | ta_emb_x = ta_emb_x.unstack(self.processed_x) 145 | 146 | def _train_recurrence(i, x_t, h_tm1, g_predictions): 147 | _,h_t = self.decoder_cell(x_t,h_tm1) 148 | o_t = self.g_output_unit(h_t) 149 | g_predictions = g_predictions.write(i, o_t) # batch x vocab_size 150 | x_tp1 = ta_emb_x.read(i) 151 | return i + 1, x_tp1, h_t, g_predictions 152 | 153 | _, _, _, self.g_predictions = control_flow_ops.while_loop( 154 | cond=lambda i, _1, _2, _3: i < self.max_out_utt_len, 155 | body=_train_recurrence, 156 | loop_vars=(tf.constant(0, dtype=tf.int32), 157 | embedding_ops.embedding_lookup(self.embeddings,self.start_token), 158 
| self.h0, g_predictions)) 159 | 160 | self.g_predictions = tf.transpose(self.g_predictions.stack(), perm=[1, 0, 2]) # batch_size x seq_length x vocab_size 161 | 162 | self.loss_mask = tf.sequence_mask(self.out_len,self.max_out_utt_len,dtype=tf.float32) 163 | self.ground_truth = tf.one_hot(self.out_utt,on_value=tf.constant(1,dtype=tf.float32),off_value=tf.constant(0,dtype=tf.float32),depth=self.out_vocab_size,dtype=tf.float32) 164 | self.log_predictions = tf.log(self.g_predictions + 1e-20) 165 | self.cross_entropy = tf.multiply(self.ground_truth,self.log_predictions) 166 | self.cross_entropy_sum = tf.reduce_sum(self.cross_entropy,2) 167 | self.masked_cross_entropy = tf.multiply(self.loss_mask,self.cross_entropy_sum) 168 | self.sentence_loss = tf.divide(tf.reduce_sum(self.masked_cross_entropy,1),tf.reduce_sum(self.loss_mask,1)) 169 | self.loss = -tf.reduce_mean(self.sentence_loss) 170 | 171 | def create_output_unit(self): 172 | 173 | self.W1 = tf.get_variable("W1",shape=[2*self.enc_hid_dim+self.dec_hid_dim,2*self.enc_hid_dim],dtype=tf.float32) 174 | self.W2 = tf.get_variable("W2",shape=[2*self.enc_hid_dim,self.attn_size],dtype=tf.float32) 175 | self.w = tf.get_variable("w",shape=[self.attn_size,1],dtype=tf.float32) 176 | self.U = tf.get_variable("U",shape=[self.dec_hid_dim+2*self.enc_hid_dim,self.generate_size],dtype=tf.float32) 177 | self.W_1 = tf.get_variable("W_1",shape=[self.emb_dim+self.dec_hid_dim+2*self.enc_hid_dim,2*self.dec_hid_dim],dtype=tf.float32) 178 | self.W_2 = tf.get_variable("W_2",shape=[self.emb_dim+self.dec_hid_dim+2*self.enc_hid_dim,2*self.dec_hid_dim],dtype=tf.float32) 179 | self.W_12 = tf.get_variable("W_12",shape=[2*self.dec_hid_dim,self.attn_size],dtype=tf.float32) 180 | self.W_22 = tf.get_variable("W_22",shape=[2*self.dec_hid_dim,self.attn_size],dtype=tf.float32) 181 | self.r_1 = tf.get_variable("r_1",shape=[self.attn_size,1],dtype=tf.float32) 182 | self.r_2 = tf.get_variable("r_2",shape=[self.attn_size,1],dtype=tf.float32) 183 | self.b1 = tf.get_variable("b1",shape=[self.generate_size],dtype=tf.float32) 184 | self.b2 = tf.get_variable("b2",shape=[1],dtype=tf.float32) 185 | self.b3 = tf.get_variable("b3",shape=[1],dtype=tf.float32) 186 | self.W3 = tf.get_variable("W3",shape=[self.dec_hid_dim+2*self.enc_hid_dim+self.emb_dim,1],dtype=tf.float32) 187 | self.W4 = tf.get_variable("W4",shape=[self.dec_hid_dim+2*self.enc_hid_dim+self.emb_dim,1],dtype=tf.float32) 188 | 189 | def unit(hidden_state): 190 | 191 | hidden_state_expanded_attn = tf.tile(array_ops.expand_dims(hidden_state,1),[1,tf.shape(self.encoder_states)[1],1]) 192 | attn_rep = tf.concat([self.encoder_states,hidden_state_expanded_attn],axis=2) 193 | attn_rep = tf.nn.tanh(tf.einsum('ijk,kl->ijl',tf.nn.tanh(tf.einsum("ijk,kl->ijl",attn_rep,self.W1)),self.W2)) 194 | u_i = tf.squeeze(tf.einsum('ijk,kl->ijl',attn_rep,self.w),2) 195 | inp_len_mask = tf.sequence_mask(self.inp_len,tf.shape(self.inp_utt)[2],dtype=tf.float32) 196 | attn_mask = tf.reshape(inp_len_mask,shape=[self.batch_size,-1]) 197 | exp_u_i_masked = tf.multiply(tf.cast(attn_mask,dtype=tf.float64),tf.exp(tf.cast(u_i,dtype=tf.float64))) 198 | a = tf.cast(tf.einsum('i,ij->ij',tf.pow(tf.reduce_sum(exp_u_i_masked,1),-1),exp_u_i_masked),dtype=tf.float32) 199 | inp_attn = tf.reduce_sum(tf.einsum('ij,ijk->ijk',a,self.encoder_states),1) 200 | 201 | generate_dist = tf.nn.softmax(math_ops.matmul(tf.concat([hidden_state,inp_attn],axis=1),self.U) + self.b1) 202 | extra_zeros = tf.zeros([self.batch_size,self.out_vocab_size - self.generate_size]) 203 | 
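# NOTE (annotation, not part of the original source): the block below mixes three
# distributions over the output vocabulary with two learned sigmoid gates, in the
# style of a pointer/copy mechanism:
#
#   copy_dist  = p_db   * kb_dist                + (1 - p_db)   * context_copy_dist
#   final_dist = p_gens * extended_generate_dist + (1 - p_gens) * copy_dist
#
# kb_dist scatters the key-level attention (gamma) onto the KB token ids and is
# zeroed through db_empty when the KB is empty, context_copy_dist scatters the
# encoder attention (a) onto the input token ids, and the generate distribution is
# padded with zeros so that all three share the same out_vocab_size support.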
extended_generate_dist = tf.concat([generate_dist,extra_zeros],axis=1) 204 | 205 | hidden_state_expanded_result = tf.tile(array_ops.expand_dims(hidden_state,1),[1,tf.shape(self.kb)[1],1]) 206 | inp_attn_expanded_result = tf.tile(array_ops.expand_dims(inp_attn,1),[1,tf.shape(self.kb)[1],1]) 207 | result_attn_rep = tf.concat([self.result_rep,hidden_state_expanded_result,inp_attn_expanded_result],axis=2) 208 | result_attn_rep = tf.nn.tanh(tf.einsum("ijk,kl->ijl",tf.nn.tanh(tf.einsum("ijk,kl->ijl",result_attn_rep,self.W_1)),self.W_12)) 209 | beta_logits = tf.squeeze(tf.einsum('ijk,kl->ijl',result_attn_rep,self.r_1),2) 210 | beta_masked = tf.multiply(tf.cast(self.kb_mask,dtype=tf.float64),tf.exp(tf.cast(beta_logits,dtype=tf.float64))) 211 | beta = tf.cast(tf.einsum('i,ij->ij',tf.pow(tf.reduce_sum(beta_masked,1),-1),beta_masked),dtype=tf.float32) 212 | 213 | hidden_state_expanded_keys = tf.tile(array_ops.expand_dims(array_ops.expand_dims(hidden_state,1),1),[1,tf.shape(self.kb)[1],tf.shape(self.kb)[2],1]) 214 | inp_attn_expanded_keys = tf.tile(array_ops.expand_dims(array_ops.expand_dims(inp_attn,1),1),[1,tf.shape(self.kb)[1],tf.shape(self.kb)[2],1]) 215 | result_key_rep = tf.concat([self.keys_emb,hidden_state_expanded_keys,inp_attn_expanded_keys],axis=3) 216 | result_key_rep = tf.nn.tanh(tf.einsum('ijkl,lm->ijkm',tf.nn.tanh(tf.einsum('ijkl,lm->ijkm',result_key_rep,self.W_2)),self.W_22)) 217 | gamma_logits = tf.squeeze(tf.einsum('ijkl,lm->ijkm',result_key_rep,self.r_2),3) 218 | gamma_masked = tf.multiply(tf.cast(self.keys_mask,dtype=tf.float64),tf.exp(tf.cast(gamma_logits,dtype=tf.float64))) 219 | gamma = tf.einsum('ij,ijk->ijk',beta,tf.cast(tf.einsum('ij,ijk->ijk',tf.pow(tf.reduce_sum(gamma_masked,2),-1),gamma_masked),dtype=tf.float32)) 220 | 221 | batch_nums_context = array_ops.expand_dims(tf.range(0, limit=self.batch_size, dtype=tf.int64),1) 222 | batch_nums_tiled_context = tf.tile(batch_nums_context,[1,tf.shape(self.encoder_states)[1]]) 223 | flat_inp_utt = tf.reshape(self.inp_utt,shape=[self.batch_size,-1]) 224 | indices_context = tf.stack([batch_nums_tiled_context,flat_inp_utt],axis=2) 225 | shape = [self.batch_size,self.out_vocab_size] 226 | context_copy_dist = tf.scatter_nd(indices_context,a,shape) 227 | 228 | db_rep = tf.reduce_sum(tf.einsum('ij,ijk->ijk',beta,self.result_rep),1) 229 | 230 | p_db = tf.nn.sigmoid(tf.matmul(tf.concat([hidden_state,inp_attn,db_rep],axis=1),self.W4)+self.b3) 231 | p_db = tf.tile(p_db,[1,self.out_vocab_size]) 232 | one_minus_fn = lambda x: 1 - x 233 | one_minus_pdb = tf.map_fn(one_minus_fn, p_db) 234 | 235 | p_gens = tf.nn.sigmoid(tf.matmul(tf.concat([hidden_state,inp_attn,db_rep],axis=1),self.W3)+self.b2) 236 | p_gens = tf.tile(p_gens,[1,self.out_vocab_size]) 237 | one_minus_fn = lambda x: 1 - x 238 | one_minus_pgens = tf.map_fn(one_minus_fn, p_gens) 239 | 240 | batch_nums = array_ops.expand_dims(tf.range(0, limit=self.batch_size, dtype=tf.int64),1) 241 | kb_ids = tf.reshape(self.kb,shape=[self.batch_size,-1]) 242 | num_kb_ids = tf.shape(kb_ids)[1] 243 | batch_nums_tiled = tf.tile(batch_nums,[1,num_kb_ids]) 244 | indices = tf.stack([batch_nums_tiled,kb_ids],axis=2) 245 | updates = tf.reshape(gamma,shape=[self.batch_size,-1]) 246 | shape = [self.batch_size,self.out_vocab_size] 247 | kb_dist = tf.scatter_nd(indices,updates,shape) 248 | kb_dist = tf.einsum('i,ij->ij',self.db_empty,kb_dist) 249 | 250 | copy_dist = tf.multiply(p_db,kb_dist) + tf.multiply(one_minus_pdb,context_copy_dist) 251 | final_dist = tf.multiply(p_gens,extended_generate_dist) + 
tf.multiply(one_minus_pgens,copy_dist) 252 | 253 | return final_dist 254 | 255 | return unit 256 | 257 | def get_feed_dict(self,batch): 258 | 259 | fd = { 260 | self.inp_utt : batch['inp_utt'], 261 | self.inp_len : batch['inp_len'], 262 | self.context_len: batch['context_len'], 263 | self.out_utt : batch['out_utt'], 264 | self.out_len : batch['out_len'], 265 | self.kb : batch['kb'], 266 | self.kb_mask : batch['kb_mask'], 267 | self.keys : batch['keys'], 268 | self.keys_mask : batch['keys_mask'], 269 | self.db_empty : batch['empty'], 270 | self.max_out_utt_len : batch['max_out_utt_len'] 271 | } 272 | 273 | return fd 274 | -------------------------------------------------------------------------------- /incar/test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import numpy as np 4 | from data_handler import DataHandler 5 | from model import DialogueModel 6 | import os 7 | import tensorflow as tf 8 | import cPickle as pickle 9 | import nltk 10 | import sys 11 | import csv 12 | from collections import Counter 13 | from nltk.util import ngrams 14 | from nltk.corpus import stopwords 15 | from nltk.tokenize import word_tokenize 16 | from nltk.stem import WordNetLemmatizer 17 | import math, re, argparse 18 | import functools 19 | import logging 20 | logging.getLogger().setLevel(logging.INFO) 21 | 22 | class Trainer(object): 23 | 24 | def __init__(self,model,handler,ckpt_path,num_epochs,learning_rate): 25 | self.handler = handler 26 | self.model = model 27 | self.ckpt_path = ckpt_path 28 | self.epochs = num_epochs 29 | self.learning_rate = learning_rate 30 | 31 | if not os.path.exists(self.ckpt_path): 32 | os.makedirs(self.ckpt_path) 33 | 34 | self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None) 35 | self.optimizer = tf.contrib.layers.optimize_loss( 36 | loss=self.model.loss, 37 | global_step=self.global_step, 38 | learning_rate=self.learning_rate, 39 | optimizer=tf.train.AdamOptimizer, 40 | clip_gradients=10.0, 41 | name='optimizer_loss' 42 | ) 43 | self.saver = tf.train.Saver(max_to_keep=5) 44 | self.sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)) 45 | init = tf.global_variables_initializer() 46 | self.sess.run(init) 47 | 48 | checkpoint = tf.train.latest_checkpoint(self.ckpt_path) 49 | if checkpoint: 50 | self.saver.restore(self.sess, checkpoint) 51 | logging.info("Loaded parameters from checkpoint") 52 | 53 | def score(self,parallel_corpus): 54 | 55 | # containers 56 | count = [0, 0, 0, 0] 57 | clip_count = [0, 0, 0, 0] 58 | r = 0 59 | c = 0 60 | weights = [0.25, 0.25, 0.25, 0.25] 61 | 62 | # accumulate ngram statistics 63 | for hyps, refs in parallel_corpus: 64 | hyps = [hyp.split() for hyp in hyps] 65 | refs = [ref.split() for ref in refs] 66 | for hyp in hyps: 67 | 68 | for i in range(4): 69 | # accumulate ngram counts 70 | hypcnts = Counter(ngrams(hyp, i + 1)) 71 | cnt = sum(hypcnts.values()) 72 | count[i] += cnt 73 | 74 | # compute clipped counts 75 | max_counts = {} 76 | for ref in refs: 77 | refcnts = Counter(ngrams(ref, i + 1)) 78 | for ng in hypcnts: 79 | max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) 80 | clipcnt = dict((ng, min(count, max_counts[ng])) \ 81 | for ng, count in hypcnts.items()) 82 | clip_count[i] += sum(clipcnt.values()) 83 | 84 | # accumulate r & c 85 | bestmatch = [1000, 1000] 86 | for ref in refs: 87 | if bestmatch[0] == 0: break 88 | diff = abs(len(ref) - len(hyp)) 89 | if diff < bestmatch[0]: 90 | bestmatch[0] = 
diff 91 | bestmatch[1] = len(ref) 92 | r += bestmatch[1] 93 | c += len(hyp) 94 | 95 | # computing bleu score 96 | p0 = 1e-7 97 | bp = 1 if c > r else math.exp(1 - float(r) / float(c)) 98 | p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ 99 | for i in range(4)] 100 | s = math.fsum(w * math.log(p_n) \ 101 | for w, p_n in zip(weights, p_ns) if p_n) 102 | bleu = bp * math.exp(s) 103 | return bleu 104 | 105 | def evaluate(self,data,vocab): 106 | entities = ['monday','tuesday','wednesday','thursday','friday','saturday','sunday','medicine','conference','dinner','lab','yoga','tennis','doctor','meeting','swimming','optometrist','football','dentist',"overcast","snow", "stormy", "hail","hot", "rain", "cold","cloudy", "warm", "windy","foggy", "humid", "frost", "blizzard", "drizzle", "dry", "dew", "misty","friend", "home", "coffee", "chinese","pizza", "grocery", "rest", "shopping", "parking","gas", "hospital"] 107 | outs = [] 108 | golds = [] 109 | domain_wise = {} 110 | for domain in ['schedule','navigate','weather']: 111 | domain_wise[domain] = {} 112 | domain_wise[domain]['tp_prec'] = 0.0 113 | domain_wise[domain]['tp_recall'] = 0.0 114 | domain_wise[domain]['total_prec'] = 0.0 115 | domain_wise[domain]['total_recall'] = 0.0 116 | domain_wise[domain]['gold'] = [] 117 | domain_wise[domain]['output'] = [] 118 | 119 | tp_prec = 0.0 120 | tp_recall = 0.0 121 | total_prec = 0.0 122 | total_recall = 0.0 123 | 124 | for i in range(0,len(data['sentences'])): 125 | sentence = data['sentences'][i] 126 | domain = data['type'][i] 127 | sentence = list(sentence) 128 | if vocab['vocab_mapping']['$STOP$'] not in sentence: 129 | index = len(sentence) 130 | else: 131 | index = sentence.index(vocab['vocab_mapping']['$STOP$']) 132 | predicted = [str(sentence[j]) for j in range(0,index)] 133 | ground = data['output'][i] 134 | ground = list(ground) 135 | index = ground.index(vocab['vocab_mapping']['$STOP$']) 136 | ground_truth = [str(ground[j]) for j in range(0,index)] 137 | 138 | gold_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in ground_truth ] 139 | out_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in predicted ] 140 | 141 | for word in out_anon: 142 | if word in entities or '_' in word: 143 | total_prec = total_prec + 1 144 | domain_wise[domain]['total_prec'] = domain_wise[domain]['total_prec'] + 1 145 | if word in gold_anon: 146 | tp_prec = tp_prec + 1 147 | domain_wise[domain]['tp_prec'] = domain_wise[domain]['tp_prec'] + 1 148 | 149 | for word in gold_anon: 150 | if word in entities or '_' in word: 151 | total_recall = total_recall + 1 152 | domain_wise[domain]['total_recall'] = domain_wise[domain]['total_recall'] + 1 153 | if word in out_anon: 154 | tp_recall = tp_recall + 1 155 | domain_wise[domain]['tp_recall'] = domain_wise[domain]['tp_recall'] + 1 156 | 157 | gold = gold_anon 158 | out = out_anon 159 | 160 | domain_wise[domain]['gold'].append(" ".join(gold)) 161 | golds.append(" ".join(gold)) 162 | domain_wise[domain]['output'].append(" ".join(out)) 163 | outs.append(" ".join(out)) 164 | 165 | wrap_generated = [[_] for _ in outs] 166 | wrap_truth = [[_] for _ in golds] 167 | prec = tp_prec/total_prec 168 | recall = tp_recall/total_recall 169 | if prec == 0 or recall == 0: 170 | f1 = 0.0 171 | else: 172 | f1 = 2*prec*recall/(prec+recall) 173 | overall_f1 = f1 174 | print "Bleu: %.3f, Prec: %.3f, Recall: %.3f, F1: %.3f" % (self.score(zip(wrap_generated, wrap_truth)),prec,recall,f1) 175 | for domain in ['schedule','navigate','weather']: 176 | prec = 
domain_wise[domain]['tp_prec']/domain_wise[domain]['total_prec'] 177 | recall = domain_wise[domain]['tp_recall']/domain_wise[domain]['total_recall'] 178 | if prec == 0 or recall == 0: 179 | f1 = 0.0 180 | else: 181 | f1 = 2*prec*recall/(prec+recall) 182 | wrap_generated = [[_] for _ in domain_wise[domain]['output']] 183 | wrap_truth = [[_] for _ in domain_wise[domain]['gold']] 184 | print "Domain: " + str(domain) + ", Bleu: %.3f, F1: %.3f" % (self.score(zip(wrap_generated, wrap_truth)),f1) 185 | return overall_f1 186 | 187 | def test(self): 188 | test_epoch_done = False 189 | 190 | teststep = 0 191 | testLoss = 0.0 192 | needed = {} 193 | needed['sentences'] = [] 194 | needed['output'] = [] 195 | needed['rev_mapping'] = [] 196 | needed['mapping'] = [] 197 | needed['type'] = [] 198 | needed['context'] = [] 199 | needed['kb'] = [] 200 | 201 | while not test_epoch_done: 202 | teststep = teststep + 1 203 | batch, test_epoch_done = self.handler.get_batch(train=False) 204 | feedDict = self.model.get_feed_dict(batch) 205 | sentences = self.sess.run(self.model.gen_x,feed_dict=feedDict) 206 | 207 | if 1 not in batch['dummy']: 208 | needed['sentences'].extend(sentences) 209 | needed['output'].extend(batch['out_utt']) 210 | needed['mapping'].extend(batch['mapping']) 211 | needed['rev_mapping'].extend(batch['rev_mapping']) 212 | needed['type'].extend(batch['type']) 213 | needed['context'].extend(batch['context']) 214 | needed['kb'].extend(batch['knowledge']) 215 | else: 216 | index = batch['dummy'].index(1) 217 | needed['sentences'].extend(sentences[0:index]) 218 | needed['output'].extend(batch['out_utt'][0:index]) 219 | needed['mapping'].extend(batch['mapping'][0:index]) 220 | needed['rev_mapping'].extend(batch['rev_mapping'][0:index]) 221 | needed['type'].extend(batch['type'][0:index]) 222 | needed['context'].extend(batch['context'][0:index]) 223 | needed['kb'].extend(batch['knowledge'][0:index]) 224 | 225 | pickle.dump(needed,open("needed.p","w")) 226 | self.evaluate(needed,self.handler.vocab) 227 | 228 | def main(): 229 | 230 | parser = argparse.ArgumentParser() 231 | parser.add_argument('--batch_size', type=int, default=16) 232 | parser.add_argument('--emb_dim', type=int, default=200) 233 | parser.add_argument('--enc_hid_dim', type=int, default=128) 234 | parser.add_argument('--dec_hid_dim', type=int, default=256) 235 | parser.add_argument('--attn_size', type=int, default=200) 236 | parser.add_argument('--epochs', type=int, default=25) 237 | parser.add_argument('--learning_rate', type=float, default=2.5e-4) 238 | parser.add_argument('--dataset_path', type=str, default='../data/InCar/') 239 | parser.add_argument('--glove_path', type=str, default='../data/') 240 | parser.add_argument('--checkpoint', type=str, default="./trainDir/") 241 | config = parser.parse_args() 242 | 243 | DEVICE = "/gpu:0" 244 | 245 | logging.info("Loading Data") 246 | 247 | handler = DataHandler( 248 | emb_dim = config.emb_dim, 249 | batch_size = config.batch_size, 250 | train_path = config.dataset_path + "train.json", 251 | val_path = config.dataset_path + "test.json", 252 | test_path = config.dataset_path + "test.json", 253 | vocab_path = "./vocab.json", 254 | glove_path = config.glove_path) 255 | 256 | logging.info("Loading Architecture") 257 | 258 | model = DialogueModel( 259 | device = DEVICE, 260 | batch_size = config.batch_size, 261 | inp_vocab_size = handler.input_vocab_size, 262 | out_vocab_size = handler.output_vocab_size, 263 | generate_size = handler.generate_vocab_size, 264 | emb_init = handler.emb_init, 265 | 
emb_dim = config.emb_dim, 266 | enc_hid_dim = config.enc_hid_dim, 267 | dec_hid_dim = config.dec_hid_dim, 268 | attn_size = config.attn_size) 269 | 270 | logging.info("Loading Trainer") 271 | 272 | trainer = Trainer( 273 | model=model, 274 | handler=handler, 275 | ckpt_path=config.checkpoint, 276 | num_epochs=config.epochs, 277 | learning_rate = config.learning_rate) 278 | 279 | trainer.test() 280 | 281 | main() -------------------------------------------------------------------------------- /incar/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import numpy as np 4 | from data_handler import DataHandler 5 | from model import DialogueModel 6 | import os 7 | import tensorflow as tf 8 | import cPickle as pickle 9 | import nltk 10 | import sys 11 | import csv 12 | from collections import Counter 13 | from nltk.util import ngrams 14 | from nltk.corpus import stopwords 15 | from nltk.tokenize import word_tokenize 16 | from nltk.stem import WordNetLemmatizer 17 | import math, re, argparse 18 | import functools 19 | import logging 20 | logging.getLogger().setLevel(logging.INFO) 21 | 22 | class Trainer(object): 23 | 24 | def __init__(self,model,handler,ckpt_path,num_epochs,learning_rate): 25 | self.handler = handler 26 | self.model = model 27 | self.ckpt_path = ckpt_path 28 | self.epochs = num_epochs 29 | self.learning_rate = learning_rate 30 | 31 | if not os.path.exists(self.ckpt_path): 32 | os.makedirs(self.ckpt_path) 33 | 34 | self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None) 35 | self.optimizer = tf.contrib.layers.optimize_loss( 36 | loss=self.model.loss, 37 | global_step=self.global_step, 38 | learning_rate=self.learning_rate, 39 | optimizer=tf.train.AdamOptimizer, 40 | clip_gradients=10.0, 41 | name='optimizer_loss' 42 | ) 43 | self.saver = tf.train.Saver(max_to_keep=5) 44 | self.sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)) 45 | init = tf.global_variables_initializer() 46 | self.sess.run(init) 47 | 48 | checkpoint = tf.train.latest_checkpoint(self.ckpt_path) 49 | if checkpoint: 50 | self.saver.restore(self.sess, checkpoint) 51 | logging.info("Loaded parameters from checkpoint") 52 | 53 | def trainData(self): 54 | curEpoch = 0 55 | step = 0 56 | epochLoss = [] 57 | 58 | logging.info("Training the model") 59 | 60 | best_f1 = 0.0 61 | 62 | while curEpoch <= self.epochs: 63 | step = step + 1 64 | 65 | batch, epoch_done = self.handler.get_batch(train=True) 66 | feedDict = self.model.get_feed_dict(batch) 67 | 68 | fetch = [self.global_step, self.model.loss, self.optimizer] 69 | mod_step,loss,_ = self.sess.run(fetch,feed_dict = feedDict) 70 | epochLoss.append(loss) 71 | 72 | if step % 80 == 0: 73 | outstr = "step: "+str(step)+" Loss: "+str(loss) 74 | logging.info(outstr) 75 | 76 | if epoch_done: 77 | train_loss = np.mean(np.asarray(epochLoss)) 78 | 79 | val_epoch_done = False 80 | valstep = 0 81 | valLoss = 0.0 82 | needed = {} 83 | needed['sentences'] = [] 84 | needed['output'] = [] 85 | needed['rev_mapping'] = [] 86 | needed['mapping'] = [] 87 | needed['type'] = [] 88 | 89 | while not val_epoch_done: 90 | valstep = valstep + 1 91 | batch, val_epoch_done = self.handler.get_batch(train=False) 92 | feedDict = self.model.get_feed_dict(batch) 93 | val_loss,sentences = self.sess.run([self.model.loss,self.model.gen_x],feed_dict=feedDict) 94 | if 1 not in batch['dummy']: 95 | needed['sentences'].extend(sentences) 96 | 
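# NOTE (annotation, not part of the original source): batch['dummy'] flags the
# copies of the last validation example that the data handler appends so the final
# batch is full; the else branch below slices everything from the first dummy entry
# onwards so that padded examples never influence the reported validation metrics.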
needed['output'].extend(batch['out_utt']) 97 | needed['mapping'].extend(batch['mapping']) 98 | needed['rev_mapping'].extend(batch['rev_mapping']) 99 | needed['type'].extend(batch['type']) 100 | else: 101 | index = batch['dummy'].index(1) 102 | needed['sentences'].extend(sentences[0:index]) 103 | needed['output'].extend(batch['out_utt'][0:index]) 104 | needed['mapping'].extend(batch['mapping'][0:index]) 105 | needed['rev_mapping'].extend(batch['rev_mapping'][0:index]) 106 | needed['type'].extend(batch['type'][0:index]) 107 | valLoss = valLoss + val_loss 108 | 109 | valLoss = valLoss / float(valstep) 110 | outstr = "Train-info: "+ "Epoch: ",str(curEpoch)+" Loss: "+str(train_loss) 111 | logging.info(outstr) 112 | outstr = "Val-info: "+"Epoch "+str(curEpoch)+" Loss: "+str(valLoss) 113 | logging.info(outstr) 114 | if curEpoch > 2: 115 | current_f1 = self.evaluate(needed,self.handler.vocab) 116 | if current_f1 >= best_f1: 117 | best_f1 = current_f1 118 | self.saver.save(self.sess, os.path.join(self.ckpt_path, 'model'), global_step=curEpoch) 119 | 120 | epochLoss = [] 121 | curEpoch = curEpoch + 1 122 | 123 | def score(self,parallel_corpus): 124 | 125 | # containers 126 | count = [0, 0, 0, 0] 127 | clip_count = [0, 0, 0, 0] 128 | r = 0 129 | c = 0 130 | weights = [0.25, 0.25, 0.25, 0.25] 131 | 132 | # accumulate ngram statistics 133 | for hyps, refs in parallel_corpus: 134 | hyps = [hyp.split() for hyp in hyps] 135 | refs = [ref.split() for ref in refs] 136 | for hyp in hyps: 137 | 138 | for i in range(4): 139 | # accumulate ngram counts 140 | hypcnts = Counter(ngrams(hyp, i + 1)) 141 | cnt = sum(hypcnts.values()) 142 | count[i] += cnt 143 | 144 | # compute clipped counts 145 | max_counts = {} 146 | for ref in refs: 147 | refcnts = Counter(ngrams(ref, i + 1)) 148 | for ng in hypcnts: 149 | max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) 150 | clipcnt = dict((ng, min(count, max_counts[ng])) \ 151 | for ng, count in hypcnts.items()) 152 | clip_count[i] += sum(clipcnt.values()) 153 | 154 | # accumulate r & c 155 | bestmatch = [1000, 1000] 156 | for ref in refs: 157 | if bestmatch[0] == 0: break 158 | diff = abs(len(ref) - len(hyp)) 159 | if diff < bestmatch[0]: 160 | bestmatch[0] = diff 161 | bestmatch[1] = len(ref) 162 | r += bestmatch[1] 163 | c += len(hyp) 164 | 165 | # computing bleu score 166 | p0 = 1e-7 167 | bp = 1 if c > r else math.exp(1 - float(r) / float(c)) 168 | p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ 169 | for i in range(4)] 170 | s = math.fsum(w * math.log(p_n) \ 171 | for w, p_n in zip(weights, p_ns) if p_n) 172 | bleu = bp * math.exp(s) 173 | return bleu 174 | 175 | def evaluate(self,data,vocab): 176 | entities = ['monday','tuesday','wednesday','thursday','friday','saturday','sunday','medicine','conference','dinner','lab','yoga','tennis','doctor','meeting','swimming','optometrist','football','dentist',"overcast","snow", "stormy", "hail","hot", "rain", "cold","cloudy", "warm", "windy","foggy", "humid", "frost", "blizzard", "drizzle", "dry", "dew", "misty","friend", "home", "coffee", "chinese","pizza", "grocery", "rest", "shopping", "parking","gas", "hospital"] 177 | outs = [] 178 | golds = [] 179 | domain_wise = {} 180 | for domain in ['schedule','navigate','weather']: 181 | domain_wise[domain] = {} 182 | domain_wise[domain]['tp_prec'] = 0.0 183 | domain_wise[domain]['tp_recall'] = 0.0 184 | domain_wise[domain]['total_prec'] = 0.0 185 | domain_wise[domain]['total_recall'] = 0.0 186 | domain_wise[domain]['gold'] = [] 187 | domain_wise[domain]['output'] = [] 
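# NOTE (illustration, not part of the original source): the counters below compute
# micro-averaged entity precision and recall over the generated responses. A
# predicted token counts towards precision when it is an entity (it appears in the
# entities list above or contains '_') and is a true positive when it also occurs
# in the gold response; recall mirrors this from the gold side. A toy example of
# the same computation, assuming the entities list above:
#
#   pred = "your dinner is on monday with your sister".split()
#   gold = "dinner is scheduled for tuesday".split()
#   prec_total = sum(1 for w in pred if w in entities)                 # 2: dinner, monday
#   prec_tp    = sum(1 for w in pred if w in entities and w in gold)   # 1: dinner
#   rec_total  = sum(1 for w in gold if w in entities)                 # 2: dinner, tuesday
#   rec_tp     = sum(1 for w in gold if w in entities and w in pred)   # 1: dinner
#   # precision = recall = 0.5, so F1 = 0.5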
188 | 189 | tp_prec = 0.0 190 | tp_recall = 0.0 191 | total_prec = 0.0 192 | total_recall = 0.0 193 | 194 | for i in range(0,len(data['sentences'])): 195 | sentence = data['sentences'][i] 196 | domain = data['type'][i] 197 | sentence = list(sentence) 198 | if vocab['vocab_mapping']['$STOP$'] not in sentence: 199 | index = len(sentence) 200 | else: 201 | index = sentence.index(vocab['vocab_mapping']['$STOP$']) 202 | predicted = [str(sentence[j]) for j in range(0,index)] 203 | ground = data['output'][i] 204 | ground = list(ground) 205 | index = ground.index(vocab['vocab_mapping']['$STOP$']) 206 | ground_truth = [str(ground[j]) for j in range(0,index)] 207 | 208 | gold_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in ground_truth ] 209 | out_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in predicted ] 210 | 211 | for word in out_anon: 212 | if word in entities or '_' in word: 213 | total_prec = total_prec + 1 214 | domain_wise[domain]['total_prec'] = domain_wise[domain]['total_prec'] + 1 215 | if word in gold_anon: 216 | tp_prec = tp_prec + 1 217 | domain_wise[domain]['tp_prec'] = domain_wise[domain]['tp_prec'] + 1 218 | 219 | for word in gold_anon: 220 | if word in entities or '_' in word: 221 | total_recall = total_recall + 1 222 | domain_wise[domain]['total_recall'] = domain_wise[domain]['total_recall'] + 1 223 | if word in out_anon: 224 | tp_recall = tp_recall + 1 225 | domain_wise[domain]['tp_recall'] = domain_wise[domain]['tp_recall'] + 1 226 | 227 | gold = gold_anon 228 | out = out_anon 229 | 230 | domain_wise[domain]['gold'].append(" ".join(gold)) 231 | golds.append(" ".join(gold)) 232 | domain_wise[domain]['output'].append(" ".join(out)) 233 | outs.append(" ".join(out)) 234 | 235 | wrap_generated = [[_] for _ in outs] 236 | wrap_truth = [[_] for _ in golds] 237 | prec = tp_prec/total_prec 238 | recall = tp_recall/total_recall 239 | if prec == 0 or recall == 0: 240 | f1 = 0.0 241 | else: 242 | f1 = 2*prec*recall/(prec+recall) 243 | overall_f1 = f1 244 | print "Bleu: %.3f, Prec: %.3f, Recall: %.3f, F1: %.3f" % (self.score(zip(wrap_generated, wrap_truth)),prec,recall,f1) 245 | for domain in ['schedule','navigate','weather']: 246 | prec = domain_wise[domain]['tp_prec']/domain_wise[domain]['total_prec'] 247 | recall = domain_wise[domain]['tp_recall']/domain_wise[domain]['total_recall'] 248 | if prec == 0 or recall == 0: 249 | f1 = 0.0 250 | else: 251 | f1 = 2*prec*recall/(prec+recall) 252 | wrap_generated = [[_] for _ in domain_wise[domain]['output']] 253 | wrap_truth = [[_] for _ in domain_wise[domain]['gold']] 254 | print "Domain: " + str(domain) + ", Bleu: %.3f, F1: %.3f" % (self.score(zip(wrap_generated, wrap_truth)),f1) 255 | return overall_f1 256 | 257 | def test(self): 258 | test_epoch_done = False 259 | 260 | teststep = 0 261 | testLoss = 0.0 262 | needed = {} 263 | needed['sentences'] = [] 264 | needed['output'] = [] 265 | needed['rev_mapping'] = [] 266 | needed['mapping'] = [] 267 | needed['type'] = [] 268 | needed['context'] = [] 269 | needed['kb'] = [] 270 | 271 | while not test_epoch_done: 272 | teststep = teststep + 1 273 | batch, test_epoch_done = self.handler.get_batch(train=False) 274 | feedDict = self.model.get_feed_dict(batch) 275 | sentences = self.sess.run(self.model.gen_x,feed_dict=feedDict) 276 | 277 | if 1 not in batch['dummy']: 278 | needed['sentences'].extend(sentences) 279 | needed['output'].extend(batch['out_utt']) 280 | needed['mapping'].extend(batch['mapping']) 281 | needed['rev_mapping'].extend(batch['rev_mapping']) 282 | 
needed['type'].extend(batch['type']) 283 | needed['context'].extend(batch['context']) 284 | needed['kb'].extend(batch['knowledge']) 285 | else: 286 | index = batch['dummy'].index(1) 287 | needed['sentences'].extend(sentences[0:index]) 288 | needed['output'].extend(batch['out_utt'][0:index]) 289 | needed['mapping'].extend(batch['mapping'][0:index]) 290 | needed['rev_mapping'].extend(batch['rev_mapping'][0:index]) 291 | needed['type'].extend(batch['type'][0:index]) 292 | needed['context'].extend(batch['context'][0:index]) 293 | needed['kb'].extend(batch['knowledge'][0:index]) 294 | 295 | pickle.dump(needed,open("needed.p","w")) 296 | self.evaluate(needed,self.handler.vocab) 297 | 298 | def main(): 299 | 300 | parser = argparse.ArgumentParser() 301 | parser.add_argument('--batch_size', type=int, default=16) 302 | parser.add_argument('--emb_dim', type=int, default=200) 303 | parser.add_argument('--enc_hid_dim', type=int, default=128) 304 | parser.add_argument('--dec_hid_dim', type=int, default=256) 305 | parser.add_argument('--attn_size', type=int, default=200) 306 | parser.add_argument('--epochs', type=int, default=25) 307 | parser.add_argument('--learning_rate', type=float, default=2.5e-4) 308 | parser.add_argument('--dataset_path', type=str, default='../data/InCar/') 309 | parser.add_argument('--glove_path', type=str, default='../data/') 310 | parser.add_argument('--checkpoint', type=str, default="./trainDir/") 311 | config = parser.parse_args() 312 | 313 | DEVICE = "/gpu:0" 314 | 315 | logging.info("Loading Data") 316 | 317 | handler = DataHandler( 318 | emb_dim = config.emb_dim, 319 | batch_size = config.batch_size, 320 | train_path = config.dataset_path + "train.json", 321 | val_path = config.dataset_path + "val.json", 322 | test_path = config.dataset_path + "test.json", 323 | vocab_path = "./vocab.json", 324 | glove_path = config.glove_path) 325 | 326 | logging.info("Loading Architecture") 327 | 328 | model = DialogueModel( 329 | device = DEVICE, 330 | batch_size = config.batch_size, 331 | inp_vocab_size = handler.input_vocab_size, 332 | out_vocab_size = handler.output_vocab_size, 333 | generate_size = handler.generate_vocab_size, 334 | emb_init = handler.emb_init, 335 | emb_dim = config.emb_dim, 336 | enc_hid_dim = config.enc_hid_dim, 337 | dec_hid_dim = config.dec_hid_dim, 338 | attn_size = config.attn_size) 339 | 340 | logging.info("Loading Trainer") 341 | 342 | trainer = Trainer( 343 | model=model, 344 | handler=handler, 345 | ckpt_path=config.checkpoint, 346 | num_epochs=config.epochs, 347 | learning_rate = config.learning_rate) 348 | 349 | trainer.trainData() 350 | 351 | main() -------------------------------------------------------------------------------- /maluuba/data_handler.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import random 4 | import nltk 5 | import os 6 | import sys 7 | import numpy as np 8 | import logging 9 | logging.getLogger().setLevel(logging.INFO) 10 | 11 | single_word_entities = ["business","economy","breakfast","wifi","gym","parking","spa","park","museum","beach","shopping","market","airport","university","mall","cathedral","downtown","palace","theatre"] 12 | 13 | class DataHandler(object): 14 | 15 | def __init__(self,batch_size,emb_dim,train_path,val_path,test_path,vocab_path,entities_path,glove_path): 16 | 17 | 18 | self.batch_size = batch_size 19 | self.train_path = train_path 20 | self.emb_dim = emb_dim 21 | self.vocab_threshold = 3 22 | self.val_path = val_path 23 | 
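# NOTE (annotation, not part of the original source): vocab_threshold is applied in
# get_vocab() below: tokens seen at most this many times are dropped from the
# vocabulary unless they are known entities or digits, and out-of-vocabulary tokens
# are mapped to $UNK$ when utterances are vectorized.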
self.test_path = test_path 24 | self.vocab_path = vocab_path 25 | self.entities_path = entities_path 26 | self.all_entities = json.load(open(self.entities_path)) 27 | self.glove_path = glove_path 28 | 29 | self.result_keys = self.initialise_keys() 30 | self.vocab = self.load_vocab() 31 | self.input_vocab_size = self.vocab['input_vocab_size'] 32 | self.output_vocab_size = self.vocab['output_vocab_size'] 33 | self.generate_vocab_size = self.vocab['generate_vocab_size'] 34 | self.emb_init = self.load_glove_vectors() 35 | self.result_keys_vector = self.keys_vector() 36 | 37 | self.train_data = json.load(open(self.train_path)) 38 | self.val_data = json.load(open(self.val_path)) 39 | self.test_data = json.load(open(self.test_path)) 40 | 41 | random.shuffle(self.train_data) 42 | random.shuffle(self.val_data) 43 | random.shuffle(self.test_data) 44 | 45 | self.val_data_full = self.append_dummy_data(self.val_data) 46 | 47 | self.train_index = 0 48 | self.val_index = 0 49 | self.train_num = len(self.train_data) 50 | self.val_num = len(self.val_data) 51 | 52 | def append_dummy_data(self,data): 53 | new_data = [] 54 | for i in range(0,len(data)): 55 | data[i]['dummy'] = 0 56 | new_data.append(copy.copy(data[i])) 57 | 58 | last = data[-1] 59 | last['dummy'] = 1 60 | for _ in range(0,self.batch_size - len(data)%self.batch_size): 61 | new_data.append(copy.copy(last)) 62 | 63 | return copy.copy(new_data) 64 | 65 | def initialise_keys(self): 66 | return ['airport', 'arr_time_dst', 'arr_time_or', 'beach', 'breakfast', 'cathedral', 'dep_time_dst', 'dep_time_or', 'destination', 'downtown', 'duration', 'end', 'guest', 'gym', 'mall', 'market', 'museum', 'name', 'origin', 'palace', 'park', 'parking', 'price', 'rating', 'seat', 'shopping', 'spa', 'start', 'theatre', 'university', 'wifi'] 67 | 68 | def keys_vector(self): 69 | 70 | result_keys_vector = [] 71 | for key in self.result_keys: 72 | result_keys_vector.append(self.vocab['vocab_mapping'][key]) 73 | result_keys_vector.append(self.vocab['vocab_mapping']["$EMPTY$"]) 74 | return result_keys_vector 75 | 76 | def load_glove_vectors(self): 77 | logging.info("Loading pre-trained Word Embeddings") 78 | filename = self.glove_path + "glove.6B.200d.txt" 79 | glove = {} 80 | file = open(filename,'r') 81 | for line in file.readlines(): 82 | row = line.strip().split(' ') 83 | glove[row[0]] = np.asarray(row[1:]) 84 | logging.info('Loaded GloVe!') 85 | file.close() 86 | embeddings_init = np.random.normal(size=(self.vocab['input_vocab_size'],self.emb_dim)).astype('f') 87 | count = 0 88 | for word in self.vocab['vocab_mapping']: 89 | if word in glove: 90 | count = count + 1 91 | embeddings_init[self.vocab['vocab_mapping'][word]] = glove[word] 92 | 93 | del glove 94 | 95 | logging.info("Loaded "+str(count)+" pre-trained Word Embeddings") 96 | return embeddings_init 97 | 98 | def load_vocab(self): 99 | if os.path.isfile(self.vocab_path): 100 | logging.info("Loading vocab from file") 101 | with open(self.vocab_path) as f: 102 | return json.load(f) 103 | else: 104 | logging.info("Vocab file not found. 
Computing Vocab") 105 | with open(self.train_path) as f: 106 | train_data = json.load(f) 107 | with open(self.val_path) as f: 108 | val_data = json.load(f) 109 | with open(self.test_path) as f: 110 | test_data = json.load(f) 111 | 112 | full_data = [] 113 | full_data.extend(train_data) 114 | full_data.extend(val_data) 115 | full_data.extend(test_data) 116 | 117 | return self.get_vocab(full_data) 118 | 119 | def get_vocab(self,data): 120 | 121 | vocab = {} 122 | for d in data: 123 | utts = [] 124 | utts.append(d['output']) 125 | for utt in d['context']: 126 | utts.append(utt['utt']) 127 | for utt in utts: 128 | tokens = utt.split(" ") 129 | for token in tokens: 130 | if token.encode('ascii',errors='ignore').lower() not in vocab: 131 | vocab[token.encode('ascii',errors='ignore').lower()] = 1 132 | else: 133 | vocab[token.encode('ascii',errors='ignore').lower()] = vocab[token.encode('ascii',errors='ignore').lower()] + 1 134 | 135 | 136 | for query in d['kb']: 137 | search = query[0] 138 | results = query[1] 139 | for key in search: 140 | if key.encode('ascii',errors='ignore').lower() not in vocab: 141 | vocab[key.encode('ascii',errors='ignore').lower()] = 1 142 | else: 143 | vocab[key.encode('ascii',errors='ignore').lower()] = vocab[key.encode('ascii',errors='ignore').lower()] + 1 144 | 145 | if search[key].encode('ascii',errors='ignore').lower() not in vocab: 146 | vocab[search[key].encode('ascii',errors='ignore').lower()] = 1 147 | else: 148 | vocab[search[key].encode('ascii',errors='ignore').lower()] = vocab[search[key].encode('ascii',errors='ignore').lower()] + 1 149 | 150 | for result in results: 151 | for key in result: 152 | 153 | if result[key].encode('ascii',errors='ignore').lower() not in vocab: 154 | vocab[result[key].encode('ascii',errors='ignore').lower()] = 1 155 | else: 156 | vocab[result[key].encode('ascii',errors='ignore').lower()] = vocab[result[key].encode('ascii',errors='ignore').lower()] + 1 157 | 158 | words = [] 159 | for v in vocab: 160 | if vocab[v] > self.vocab_threshold or v in self.all_entities or v.isdigit(): 161 | words.append(v) 162 | 163 | for key in self.result_keys: 164 | if key not in words: 165 | words.append(key) 166 | 167 | words.append("$UNK$") 168 | words.append("$STOP$") 169 | words.append("$PAD$") 170 | words.append("$EMPTY$") 171 | 172 | for i in range(1,31): 173 | words.append("$u"+str(i)+"$") 174 | words.append("$s"+str(i)+"$") 175 | words.append("$u31$") 176 | 177 | generate_words = [] 178 | copy_words = [] 179 | for word in words: 180 | if word in single_word_entities or word in self.all_entities: 181 | copy_words.append(word) 182 | else: 183 | generate_words.append(word) 184 | 185 | output_vocab_size = len(words) + 1 186 | 187 | generate_indices = [i for i in range(1,len(generate_words)+1)] 188 | copy_indices = [i for i in range(len(generate_words)+1,output_vocab_size)] 189 | random.shuffle(generate_indices) 190 | random.shuffle(copy_indices) 191 | 192 | mapping = {} 193 | rev_mapping = {} 194 | 195 | for i in range(0,len(generate_words)): 196 | mapping[generate_words[i]] = generate_indices[i] 197 | rev_mapping[str(generate_indices[i])] = generate_words[i] 198 | 199 | for i in range(0,len(copy_words)): 200 | mapping[copy_words[i]] = copy_indices[i] 201 | rev_mapping[str(copy_indices[i])] = copy_words[i] 202 | 203 | mapping["$GO$"] = 0 204 | rev_mapping[0] = "$GO$" 205 | vocab_dict = {} 206 | vocab_dict['vocab_mapping'] = mapping 207 | vocab_dict['rev_mapping'] = rev_mapping 208 | vocab_dict['input_vocab_size'] = len(words) + 1 209 | 
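# NOTE (annotation, not part of the original source): the index layout puts $GO$ at
# 0, the shuffled generate words at indices 1..len(generate_words), and the copyable
# entity words on the remaining indices, so a generate head only needs
# generate_vocab_size logits while the copy mechanism can still point at the higher
# entity indices.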
vocab_dict['generate_vocab_size'] = len(generate_words) + 1 210 | vocab_dict['output_vocab_size'] = output_vocab_size 211 | 212 | with open(self.vocab_path,'w') as f: 213 | json.dump(vocab_dict,f) 214 | 215 | logging.info("Vocab file created") 216 | 217 | return vocab_dict 218 | 219 | def get_sentinel(self,i,context): 220 | if i%2 == 0: 221 | speaker = "u" 222 | turn = (context - i + 1)/2 223 | else: 224 | speaker = "s" 225 | turn = (context - i)/2 226 | return "$"+speaker+str(turn)+"$" 227 | 228 | def vectorize(self,batch,train): 229 | vectorized = {} 230 | vectorized['inp_utt'] = [] 231 | vectorized['out_utt'] = [] 232 | vectorized['inp_len'] = [] 233 | vectorized['out_len'] = [] 234 | vectorized['context_len'] = [] 235 | vectorized['query_mask'] = [] 236 | vectorized['query_actual_mask'] = [] 237 | vectorized['search_keys'] = [] 238 | vectorized['search_mask'] = [] 239 | vectorized['search_values'] = [] 240 | vectorized['results_mask'] = [] 241 | vectorized['results_actual_mask'] = [] 242 | vectorized['result_keys_mask'] = [] 243 | vectorized['result_values'] = [] 244 | vectorized['dummy'] = [] 245 | vectorized['empty'] = [] 246 | 247 | vectorized['kb'] = [] 248 | vectorized['context'] = [] 249 | 250 | max_query_num = 0 251 | max_inp_utt_len = 0 252 | max_out_utt_len = 0 253 | max_context_len = 0 254 | 255 | search_keys_len = 8 256 | result_keys_len = 31 257 | 258 | for item in batch: 259 | if len(item['kb']) > max_query_num: 260 | max_query_num = len(item['kb']) 261 | 262 | if len(item['context']) > max_context_len: 263 | max_context_len = len(item['context']) 264 | 265 | 266 | for utt in item['context']: 267 | tokens = utt['utt'].split(" ") 268 | 269 | if len(tokens) > max_inp_utt_len: 270 | max_inp_utt_len = len(tokens) 271 | 272 | tokens = item['output'].split(" ") 273 | if len(tokens) > max_out_utt_len: 274 | max_out_utt_len = len(tokens) 275 | 276 | max_inp_utt_len = max_inp_utt_len + 1 277 | 278 | max_out_utt_len = max_out_utt_len + 1 279 | vectorized['max_out_utt_len'] = max_out_utt_len 280 | 281 | for item in batch: 282 | 283 | vectorized['context'].append(item['context']) 284 | vectorized['kb'].append(item['kb']) 285 | 286 | if item['kb'] == []: 287 | vectorized['empty'].append(0) 288 | else: 289 | vectorized['empty'].append(1) 290 | 291 | if not train: 292 | vectorized['dummy'].append(item['dummy']) 293 | 294 | vector_inp = [] 295 | vector_len = [] 296 | for i in range(0,len(item['context'])): 297 | utt = item['context'][i] 298 | inp = [] 299 | sentinel = self.get_sentinel(i,len(item['context'])) 300 | tokens = utt['utt'].split(" ")+[sentinel] 301 | for token in tokens: 302 | if token in self.vocab['vocab_mapping']: 303 | inp.append(self.vocab['vocab_mapping'][token.encode('ascii',errors='ignore')]) 304 | else: 305 | inp.append(self.vocab['vocab_mapping']["$UNK$"]) 306 | 307 | vector_len.append(len(tokens)) 308 | for _ in range(0,max_inp_utt_len - len(tokens)): 309 | inp.append(self.vocab['vocab_mapping']["$PAD$"]) 310 | vector_inp.append(copy.copy(inp)) 311 | 312 | vectorized['context_len'].append(len(item['context'])) 313 | 314 | for _ in range(0,max_context_len - len(item['context'])): 315 | vector_len.append(0) 316 | inp = [] 317 | for _ in range(0,max_inp_utt_len): 318 | inp.append(self.vocab['vocab_mapping']["$PAD$"]) 319 | vector_inp.append(copy.copy(inp)) 320 | 321 | vectorized['inp_utt'].append(copy.copy(vector_inp)) 322 | vectorized['inp_len'].append(copy.copy(vector_len)) 323 | 324 | vector_out = [] 325 | 326 | tokens = item['output'].split(" ") 327 | 
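# NOTE (annotation, not part of the original source): each target sequence is
# terminated with $STOP$ and right-padded with $PAD$ up to max_out_utt_len, and
# out_len counts the real tokens including $STOP$, so a sequence-length loss mask
# built from out_len covers the stop symbol but not the padding.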
tokens.append("$STOP$") 328 | for token in tokens: 329 | if token in self.vocab['vocab_mapping']: 330 | vector_out.append(self.vocab['vocab_mapping'][token.encode('ascii',errors='ignore')]) 331 | else: 332 | vector_out.append(self.vocab['vocab_mapping']["$UNK$"]) 333 | 334 | for _ in range(0,max_out_utt_len - len(tokens)): 335 | vector_out.append(self.vocab['vocab_mapping']["$PAD$"]) 336 | vectorized['out_utt'].append(copy.copy(vector_out)) 337 | vectorized['out_len'].append(len(tokens)) 338 | 339 | vector_search_mask = [] 340 | vector_search_keys = [] 341 | vector_search_values = [] 342 | vector_query_mask = [] 343 | vector_query_actual_mask = [] 344 | vector_result_keys_mask = [] 345 | vector_result_values = [] 346 | vector_results_mask = [] 347 | vector_results_actual_mask = [] 348 | 349 | kb = [] 350 | 351 | for query in item['kb']: 352 | 353 | search_mask = [] 354 | search_keys = [] 355 | search_values = [] 356 | 357 | for key in query[0]: 358 | search_mask.append(1) 359 | search_keys.append(self.vocab['vocab_mapping'][key]) 360 | if str(query[0][key]).encode('ascii',errors='ignore').lower() in self.vocab['vocab_mapping']: 361 | search_values.append(self.vocab['vocab_mapping'][str(query[0][key]).encode('ascii',errors='ignore').lower()]) 362 | else: 363 | search_values.append(self.vocab['vocab_mapping']["$UNK$"]) 364 | 365 | for _ in range(0,search_keys_len - len(query[0].keys())): 366 | search_mask.append(0) 367 | search_keys.append(self.vocab['vocab_mapping']["$PAD$"]) 368 | search_values.append(self.vocab['vocab_mapping']["$PAD$"]) 369 | 370 | results_mask = [] 371 | results_actual_mask = [] 372 | result_values = [] 373 | result_keys_mask = [] 374 | result_keys = [] 375 | 376 | for result in query[1]: 377 | results_mask.append(1) 378 | results_actual_mask.append(1) 379 | mask = [] 380 | values = [] 381 | 382 | for key in self.result_keys: 383 | if key in result: 384 | mask.append(1) 385 | values.append(self.vocab['vocab_mapping'][str(result[key].encode('ascii',errors='ignore')).lower()]) 386 | else: 387 | mask.append(0) 388 | values.append(self.vocab['vocab_mapping']["$PAD$"]) 389 | if 1 not in mask: 390 | mask.append(1) 391 | else: 392 | mask.append(0) 393 | values.append(self.vocab['vocab_mapping']["$PAD$"]) 394 | 395 | result_keys_mask.append(copy.copy(mask)) 396 | result_values.append(copy.copy(values)) 397 | 398 | if len(results_mask) == 0: 399 | results_mask.append(1) 400 | results_actual_mask.append(0) 401 | result_keys_mask.append([0 for _ in range(0,len(self.result_keys))]+[1]) 402 | result_values.append([self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,len(self.result_keys))]+[self.vocab['vocab_mapping']["$EMPTY$"]]) 403 | 404 | for _ in range(0,10-len(results_mask)): 405 | results_mask.append(0) 406 | results_actual_mask.append(0) 407 | result_keys_mask.append([0 for _ in range(0,len(self.result_keys))]+[1]) 408 | result_values.append([self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,len(self.result_keys))]+[self.vocab['vocab_mapping']["$EMPTY$"]]) 409 | 410 | 411 | kb.append([1,1,copy.copy(search_mask),copy.copy(search_values),copy.copy(search_keys),copy.copy(results_mask),copy.copy(results_actual_mask),copy.copy(result_keys_mask),copy.copy(result_values)]) 412 | 413 | if len(kb) == 0: 414 | newitem = [] 415 | newitem.extend([1,0,[1]+[0 for _ in range(0,search_keys_len-1)],[self.vocab["vocab_mapping"]["$PAD$"] for _ in range(0,search_keys_len)],[self.vocab["vocab_mapping"]["$PAD$"] for _ in range(0,search_keys_len)]]) 416 | newitem.append([1]+[0 for _ in 
range(0,9)]) 417 | newitem.append([0 for _ in range(0,10)]) 418 | newitem.append([[0 for _ in range(0,len(self.result_keys))]+[1] for _ in range(0,10)]) 419 | newitem.append([[self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,len(self.result_keys))]+[self.vocab['vocab_mapping']["$EMPTY$"]] for _ in range(0,10)]) 420 | kb.append(copy.copy(newitem)) 421 | 422 | kb_length = len(kb) 423 | 424 | for _ in range(0,max_query_num - kb_length): 425 | newitem = [] 426 | newitem.extend([0,0,[1]+[0 for _ in range(0,search_keys_len-1)],[self.vocab["vocab_mapping"]["$PAD$"] for _ in range(0,search_keys_len)],[self.vocab["vocab_mapping"]["$PAD$"] for _ in range(0,search_keys_len)]]) 427 | newitem.append([1]+[0 for _ in range(0,9)]) 428 | newitem.append([0 for _ in range(0,10)]) 429 | newitem.append([[0 for _ in range(0,len(self.result_keys))]+[1] for _ in range(0,10)]) 430 | newitem.append([[self.vocab['vocab_mapping']["$PAD$"] for _ in range(0,len(self.result_keys)+1)] for _ in range(0,10)]) 431 | kb.append(copy.copy(newitem)) 432 | 433 | random.shuffle(kb) 434 | 435 | for each in kb: 436 | vector_query_mask.append(each[0]) 437 | vector_query_actual_mask.append(each[1]) 438 | vector_search_mask.append(each[2]) 439 | vector_search_values.append(each[3]) 440 | vector_search_keys.append(each[4]) 441 | vector_results_mask.append(each[5]) 442 | vector_results_actual_mask.append(each[6]) 443 | vector_result_keys_mask.append(each[7]) 444 | vector_result_values.append(each[8]) 445 | 446 | vectorized['query_mask'].append(copy.copy(vector_query_mask)) 447 | vectorized['query_actual_mask'].append(copy.copy(vector_query_actual_mask)) 448 | vectorized['search_mask'].append(copy.copy(vector_search_mask)) 449 | vectorized['search_values'].append(copy.copy(vector_search_values)) 450 | vectorized['search_keys'].append(copy.copy(vector_search_keys)) 451 | vectorized['results_mask'].append(copy.copy(vector_results_mask)) 452 | vectorized['results_actual_mask'].append(copy.copy(vector_results_actual_mask)) 453 | vectorized['result_keys_mask'].append(copy.copy(vector_result_keys_mask)) 454 | vectorized['result_values'].append(copy.copy(vector_result_values)) 455 | 456 | return vectorized 457 | 458 | def get_batch(self,train): 459 | 460 | epoch_done = False 461 | 462 | if train: 463 | index = self.train_index 464 | batch = self.vectorize(self.train_data[index:index+self.batch_size],train) 465 | self.train_index = self.train_index + self.batch_size 466 | 467 | if self.train_index + self.batch_size > self.train_num: 468 | self.train_index = 0 469 | random.shuffle(self.train_data) 470 | epoch_done = True 471 | 472 | else: 473 | index = self.val_index 474 | batch = self.vectorize(self.val_data_full[index:index+self.batch_size],train) 475 | self.val_index = self.val_index + self.batch_size 476 | 477 | if self.val_index + self.batch_size > self.val_num: 478 | self.val_index = 0 479 | random.shuffle(self.val_data) 480 | self.val_data_full = self.append_dummy_data(self.val_data) 481 | epoch_done = True 482 | 483 | return batch,epoch_done 484 | -------------------------------------------------------------------------------- /maluuba/evaluate.py: -------------------------------------------------------------------------------- 1 | import cPickle as pickle 2 | import copy 3 | import json 4 | import csv 5 | from collections import Counter 6 | from nltk.util import ngrams 7 | from nltk.corpus import stopwords 8 | from nltk.tokenize import word_tokenize 9 | from nltk.stem import WordNetLemmatizer 10 | import math, re, argparse 11 | 
import functools 12 | 13 | entities = ["business","economy","breakfast","wifi","gym","parking","spa","park","museum","beach","shopping","market","airport","university","mall","cathedral","downtown","palace","theatre"] 14 | 15 | def score(parallel_corpus): 16 | 17 | # containers 18 | count = [0, 0, 0, 0] 19 | clip_count = [0, 0, 0, 0] 20 | r = 0 21 | c = 0 22 | weights = [0.25, 0.25, 0.25, 0.25] 23 | 24 | # accumulate ngram statistics 25 | for hyps, refs in parallel_corpus: 26 | hyps = [hyp.split() for hyp in hyps] 27 | refs = [ref.split() for ref in refs] 28 | for hyp in hyps: 29 | 30 | for i in range(4): 31 | # accumulate ngram counts 32 | hypcnts = Counter(ngrams(hyp, i + 1)) 33 | cnt = sum(hypcnts.values()) 34 | count[i] += cnt 35 | 36 | # compute clipped counts 37 | max_counts = {} 38 | for ref in refs: 39 | refcnts = Counter(ngrams(ref, i + 1)) 40 | for ng in hypcnts: 41 | max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) 42 | clipcnt = dict((ng, min(count, max_counts[ng])) \ 43 | for ng, count in hypcnts.items()) 44 | clip_count[i] += sum(clipcnt.values()) 45 | 46 | # accumulate r & c 47 | bestmatch = [1000, 1000] 48 | for ref in refs: 49 | if bestmatch[0] == 0: break 50 | diff = abs(len(ref) - len(hyp)) 51 | if diff < bestmatch[0]: 52 | bestmatch[0] = diff 53 | bestmatch[1] = len(ref) 54 | r += bestmatch[1] 55 | c += len(hyp) 56 | 57 | # computing bleu score 58 | p0 = 1e-7 59 | bp = 1 if c > r else math.exp(1 - float(r) / float(c)) 60 | p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ 61 | for i in range(4)] 62 | s = math.fsum(w * math.log(p_n) \ 63 | for w, p_n in zip(weights, p_ns) if p_n) 64 | bleu = bp * math.exp(s) 65 | return bleu 66 | 67 | 68 | data = pickle.load(open("needed.p")) 69 | vocab = json.load(open("./vocab.json")) 70 | all_entities = json.load(open("../data/Maluuba/entities.json")) 71 | outs = [] 72 | golds = [] 73 | 74 | tp_prec = 0.0 75 | tp_recall = 0.0 76 | total_prec = 0.0 77 | total_recall = 0.0 78 | 79 | for i in range(0,len(data['sentences'])): 80 | sentence = data['sentences'][i] 81 | sentence = list(sentence) 82 | if vocab['vocab_mapping']['$STOP$'] not in sentence: 83 | index = len(sentence) 84 | else: 85 | index = sentence.index(vocab['vocab_mapping']['$STOP$']) 86 | predicted = [str(sentence[j]) for j in range(0,index)] 87 | ground = data['output'][i] 88 | ground = list(ground) 89 | index = ground.index(vocab['vocab_mapping']['$STOP$']) 90 | ground_truth = [str(ground[j]) for j in range(0,index)] 91 | 92 | gold_anon = [vocab['rev_mapping'][word].encode('ascii',errors='ignore') for word in ground_truth ] 93 | out_anon = [vocab['rev_mapping'][word].encode('ascii',errors='ignore') for word in predicted ] 94 | 95 | for word in out_anon: 96 | if word in entities or word.isdigit() or word in all_entities: 97 | total_prec = total_prec + 1 98 | if word in gold_anon: 99 | tp_prec = tp_prec + 1 100 | 101 | for word in gold_anon: 102 | if word in entities or word.isdigit() or word in all_entities: 103 | total_recall = total_recall + 1 104 | if word in out_anon: 105 | tp_recall = tp_recall + 1 106 | 107 | gold = gold_anon 108 | out = out_anon 109 | 110 | golds.append(" ".join(gold)) 111 | outs.append(" ".join(out)) 112 | 113 | with open('output', 'w') as output_file: 114 | for line in outs: 115 | output_file.write(line+"\n") 116 | 117 | with open('reference', 'w') as output_file: 118 | for line in golds: 119 | output_file.write(line+"\n") 120 | 121 | wrap_generated = [[_] for _ in outs] 122 | wrap_truth = [[_] for _ in golds] 123 | prec = 
tp_prec/total_prec 124 | recall = tp_recall/total_recall 125 | print 'prec',tp_prec,total_prec 126 | print 'recall',tp_recall,total_recall 127 | print "Bleu: %.3f, Prec: %.3f, Recall: %.3f, F1: %.3f" % (score(zip(wrap_generated, wrap_truth)),prec,recall,2*prec*recall/(prec+recall)) 128 | -------------------------------------------------------------------------------- /maluuba/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from tensorflow.python.ops import embedding_ops, array_ops, math_ops, tensor_array_ops, control_flow_ops 4 | 5 | class DialogueModel(object): 6 | 7 | def __init__(self,device,batch_size,inp_vocab_size,out_vocab_size,generate_size,emb_init,result_keys_vector,emb_dim,enc_hid_dim,dec_hid_dim,attn_size): 8 | 9 | self.device = device 10 | self.batch_size = batch_size 11 | self.emb_dim = emb_dim 12 | self.inp_vocab_size = inp_vocab_size 13 | self.out_vocab_size = out_vocab_size 14 | self.generate_size = generate_size 15 | self.emb_init = emb_init 16 | self.result_keys = result_keys_vector 17 | self.enc_hid_dim = enc_hid_dim 18 | self.dec_hid_dim = dec_hid_dim 19 | self.attn_size = attn_size 20 | self.generate_size = generate_size 21 | 22 | self.inp_utt = tf.placeholder( 23 | name='inp_utt', dtype=tf.int64, 24 | shape=[self.batch_size, None, None], 25 | ) 26 | 27 | self.inp_len = tf.placeholder( 28 | name='inp_len', dtype=tf.int64, 29 | shape=[self.batch_size, None], 30 | ) 31 | 32 | self.context_len = tf.placeholder( 33 | name='context_len', dtype=tf.int64, 34 | shape=[self.batch_size], 35 | ) 36 | 37 | self.out_utt = tf.placeholder( 38 | name='out_utt', dtype=tf.int64, 39 | shape=[self.batch_size, None], 40 | ) 41 | 42 | self.out_len = tf.placeholder( 43 | name='out_len', dtype=tf.float32, 44 | shape=[self.batch_size], 45 | ) 46 | 47 | self.query_mask = tf.placeholder( 48 | name='query_mask', dtype=tf.float32, 49 | shape=[self.batch_size,None], 50 | ) 51 | 52 | self.search_mask = tf.placeholder( 53 | name='search_mask', dtype=tf.float32, 54 | shape=[self.batch_size,None,8], 55 | ) 56 | 57 | self.search_values = tf.placeholder( 58 | name='search_values', dtype=tf.int64, 59 | shape=[self.batch_size,None,8], 60 | ) 61 | 62 | self.results_mask = tf.placeholder( 63 | name='results_mask', dtype=tf.float32, 64 | shape=[self.batch_size,None,10], 65 | ) 66 | 67 | self.result_keys_mask = tf.placeholder( 68 | name='result_keys_mask', dtype=tf.float32, 69 | shape=[self.batch_size,None,10,len(self.result_keys)], 70 | ) 71 | 72 | self.result_values = tf.placeholder( 73 | name='result_values', dtype=tf.int64, 74 | shape=[self.batch_size,None,10,len(self.result_keys)], 75 | ) 76 | 77 | self.max_out_utt_len = tf.placeholder( 78 | name = 'max_out_utt_len', dtype=tf.int32, 79 | shape = (), 80 | ) 81 | 82 | self.db_empty = tf.placeholder( 83 | name='db_empty', dtype=tf.float32, 84 | shape=[self.batch_size], 85 | ) 86 | 87 | self.buildArch() 88 | 89 | def buildArch(self): 90 | 91 | with tf.device(self.device): 92 | 93 | self.embeddings = tf.get_variable("embeddings", initializer=tf.constant(self.emb_init)) 94 | self.inp_utt_emb = embedding_ops.embedding_lookup(self.embeddings, self.inp_utt) 95 | 96 | with tf.variable_scope("encoder"): 97 | self.encoder_cell_1 = tf.contrib.rnn.GRUCell(self.enc_hid_dim) 98 | self.encoder_cell_2 = tf.contrib.rnn.GRUCell(2*self.enc_hid_dim) 99 | self.flat_inp_emb = tf.reshape(self.inp_utt_emb,shape=[-1,tf.shape(self.inp_utt)[2],self.emb_dim]) 100 | self.flat_inp_len = 
tf.reshape(self.inp_len,shape=[-1]) 101 | 102 | outputs,output_states = tf.nn.bidirectional_dynamic_rnn(cell_fw=self.encoder_cell_1,cell_bw=self.encoder_cell_1,inputs=self.flat_inp_emb,dtype=tf.float32,sequence_length=self.flat_inp_len,time_major=False) 103 | self.flat_encoder_states = tf.concat(outputs,axis=2) 104 | self.utt_reps = tf.concat(output_states,axis=1) 105 | 106 | self.utt_rep_second = tf.reshape(self.utt_reps,shape=[self.batch_size,-1,2*self.enc_hid_dim]) 107 | self.hidden_states, self.inp_utt_rep = tf.nn.dynamic_rnn(self.encoder_cell_2,self.utt_rep_second,dtype=tf.float32,sequence_length=self.context_len,time_major=False) 108 | self.encoder_states = tf.reshape(tf.reshape(self.flat_encoder_states,shape=[self.batch_size,-1,tf.shape(self.inp_utt)[2],2*self.enc_hid_dim]), shape=[self.batch_size,-1,2*self.enc_hid_dim]) 109 | 110 | 111 | 112 | self.search_values_emb = embedding_ops.embedding_lookup(self.embeddings, self.search_values) 113 | self.search_values_rep = tf.einsum('ij,ijk->ijk',tf.pow(tf.reduce_sum(self.search_mask,2),-1),tf.reduce_sum(tf.einsum('ijk,ijkl->ijkl',self.search_mask,self.search_values_emb),2)) 114 | 115 | self.result_values_emb = embedding_ops.embedding_lookup(self.embeddings, self.result_values) 116 | self.result_values_rep = tf.einsum('ijk,ijkl->ijkl',tf.pow(tf.reduce_sum(self.result_keys_mask,3),-1),tf.reduce_sum(tf.einsum('ijkl,ijklm->ijklm',self.result_keys_mask,self.result_values_emb),3)) 117 | 118 | self.result_keys_batch = tf.reshape(tf.tile(self.result_keys,[self.batch_size]),shape=(self.batch_size,len(self.result_keys))) 119 | self.result_keys_emb = embedding_ops.embedding_lookup(self.embeddings, self.result_keys_batch) 120 | 121 | self.start_token = tf.constant([0] * self.batch_size, dtype=tf.int32) 122 | self.out_utt_emb = embedding_ops.embedding_lookup(self.embeddings, self.out_utt) 123 | self.processed_x = tf.transpose(self.out_utt_emb,perm=[1,0,2]) 124 | 125 | with tf.variable_scope("decoder"): 126 | self.decoder_cell = tf.contrib.rnn.GRUCell(self.dec_hid_dim) 127 | 128 | self.h0 = self.inp_utt_rep 129 | self.g_output_unit = self.create_output_unit() 130 | 131 | gen_x = tensor_array_ops.TensorArray(dtype=tf.int32, size=self.max_out_utt_len, 132 | dynamic_size=False, infer_shape=True) 133 | 134 | def _g_recurrence(i, x_t, h_tm1, gen_x): 135 | _,h_t = self.decoder_cell(x_t,h_tm1) 136 | o_t = self.g_output_unit(h_t) # batch x vocab , prob 137 | next_token = tf.cast(tf.reshape(tf.argmax(o_t, 1), [self.batch_size]), tf.int32) 138 | x_tp1 = embedding_ops.embedding_lookup(self.embeddings,next_token) # batch x emb_dim 139 | gen_x = gen_x.write(i, next_token) # indices, batch_size 140 | return i + 1, x_tp1, h_t, gen_x 141 | 142 | _, _, _, self.gen_x = control_flow_ops.while_loop( 143 | cond=lambda i, _1, _2, _3: i < self.max_out_utt_len, 144 | body=_g_recurrence, 145 | loop_vars=(tf.constant(0, dtype=tf.int32), 146 | embedding_ops.embedding_lookup(self.embeddings,self.start_token), 147 | self.h0, gen_x)) 148 | 149 | self.gen_x = self.gen_x.stack() # seq_length x batch_size 150 | self.gen_x = tf.transpose(self.gen_x, perm=[1, 0]) # batch_size x seq_length 151 | 152 | # gen_x contains the colours sampled as outputs 153 | # Hence, gen_x is used while calculating accuracy 154 | 155 | g_predictions = tensor_array_ops.TensorArray( 156 | dtype=tf.float32, size=self.max_out_utt_len, 157 | dynamic_size=False, infer_shape=True) 158 | 159 | ta_emb_x = tensor_array_ops.TensorArray( 160 | dtype=tf.float32, size=self.max_out_utt_len) 161 | ta_emb_x = 
ta_emb_x.unstack(self.processed_x) 162 | 163 | def _train_recurrence(i, x_t, h_tm1, g_predictions): 164 | _,h_t = self.decoder_cell(x_t,h_tm1) 165 | o_t = self.g_output_unit(h_t) 166 | g_predictions = g_predictions.write(i, o_t) # batch x vocab_size 167 | x_tp1 = ta_emb_x.read(i) 168 | return i + 1, x_tp1, h_t, g_predictions 169 | 170 | _, _, _, self.g_predictions = control_flow_ops.while_loop( 171 | cond=lambda i, _1, _2, _3: i < self.max_out_utt_len, 172 | body=_train_recurrence, 173 | loop_vars=(tf.constant(0, dtype=tf.int32), 174 | embedding_ops.embedding_lookup(self.embeddings,self.start_token), 175 | self.h0, g_predictions)) 176 | 177 | self.g_predictions = tf.transpose(self.g_predictions.stack(), perm=[1, 0, 2]) # batch_size x seq_length x vocab_size 178 | 179 | self.loss_mask = tf.sequence_mask(self.out_len,self.max_out_utt_len,dtype=tf.float32) 180 | self.ground_truth = tf.one_hot(self.out_utt,on_value=tf.constant(1,dtype=tf.float32),off_value=tf.constant(0,dtype=tf.float32),depth=self.out_vocab_size,dtype=tf.float32) 181 | self.log_predictions = tf.log(self.g_predictions + 1e-20) 182 | self.cross_entropy = tf.multiply(self.ground_truth,self.log_predictions) 183 | self.cross_entropy_sum = tf.reduce_sum(self.cross_entropy,2) 184 | self.masked_cross_entropy = tf.multiply(self.loss_mask,self.cross_entropy_sum) 185 | self.sentence_loss = tf.divide(tf.reduce_sum(self.masked_cross_entropy,1),tf.reduce_sum(self.loss_mask,1)) 186 | self.loss = -tf.reduce_mean(self.sentence_loss) 187 | 188 | def create_output_unit(self): 189 | 190 | self.W1 = tf.get_variable("W1",shape=[2*self.enc_hid_dim+self.dec_hid_dim,2*self.enc_hid_dim],dtype=tf.float32) 191 | self.W2 = tf.get_variable("W2",shape=[2*self.enc_hid_dim,self.attn_size],dtype=tf.float32) 192 | self.w = tf.get_variable("w",shape=[self.attn_size,1],dtype=tf.float32) 193 | self.U = tf.get_variable("U",shape=[self.dec_hid_dim+2*self.enc_hid_dim,self.generate_size],dtype=tf.float32) 194 | self.W_1 = tf.get_variable("W_1",shape=[self.emb_dim+self.dec_hid_dim+2*self.enc_hid_dim,2*self.dec_hid_dim],dtype=tf.float32) 195 | self.W_2 = tf.get_variable("W_2",shape=[self.emb_dim+self.dec_hid_dim+2*self.enc_hid_dim,2*self.dec_hid_dim],dtype=tf.float32) 196 | self.W_3 = tf.get_variable("W_3",shape=[self.emb_dim+self.dec_hid_dim+2*self.enc_hid_dim,2*self.dec_hid_dim],dtype=tf.float32) 197 | self.W_12 = tf.get_variable("W_12",shape=[2*self.dec_hid_dim,self.attn_size],dtype=tf.float32) 198 | self.W_22 = tf.get_variable("W_22",shape=[2*self.dec_hid_dim,self.attn_size],dtype=tf.float32) 199 | self.W_32 = tf.get_variable("W_32",shape=[2*self.dec_hid_dim,self.attn_size],dtype=tf.float32) 200 | self.r_1 = tf.get_variable("r_1",shape=[self.attn_size,1],dtype=tf.float32) 201 | self.r_2 = tf.get_variable("r_2",shape=[self.attn_size,1],dtype=tf.float32) 202 | self.r_3 = tf.get_variable("r_3",shape=[self.attn_size,1],dtype=tf.float32) 203 | self.b1 = tf.get_variable("b1",shape=[self.generate_size],dtype=tf.float32) 204 | self.b2 = tf.get_variable("b2",shape=[1],dtype=tf.float32) 205 | self.b3 = tf.get_variable("b3",shape=[1],dtype=tf.float32) 206 | self.W3 = tf.get_variable("W3",shape=[self.dec_hid_dim+2*self.enc_hid_dim+self.emb_dim,1],dtype=tf.float32) 207 | self.W4 = tf.get_variable("W4",shape=[self.dec_hid_dim+2*self.enc_hid_dim+self.emb_dim,1],dtype=tf.float32) 208 | 209 | def unit(hidden_state): 210 | 211 | hidden_state_expanded_attn = tf.tile(array_ops.expand_dims(hidden_state,1),[1,tf.shape(self.encoder_states)[1],1]) 212 | attn_rep = 
tf.concat([self.encoder_states,hidden_state_expanded_attn],axis=2) 213 | attn_rep = tf.nn.tanh(tf.einsum('ijk,kl->ijl',tf.nn.tanh(tf.einsum("ijk,kl->ijl",attn_rep,self.W1)),self.W2)) 214 | u_i = tf.squeeze(tf.einsum('ijk,kl->ijl',attn_rep,self.w),2) 215 | inp_len_mask = tf.sequence_mask(self.inp_len,tf.shape(self.inp_utt)[2],dtype=tf.float32) 216 | attn_mask = tf.reshape(inp_len_mask,shape=[self.batch_size,-1]) 217 | exp_u_i_masked = tf.multiply(tf.cast(attn_mask,dtype=tf.float64),tf.exp(tf.cast(u_i,dtype=tf.float64))) 218 | a = tf.cast(tf.einsum('i,ij->ij',tf.pow(tf.reduce_sum(exp_u_i_masked,1),-1),exp_u_i_masked),dtype=tf.float32) 219 | inp_attn = tf.reduce_sum(tf.einsum('ij,ijk->ijk',a,self.encoder_states),1) 220 | 221 | generate_dist = tf.nn.softmax(math_ops.matmul(tf.concat([hidden_state,inp_attn],axis=1),self.U) + self.b1) 222 | extra_zeros = tf.zeros([self.batch_size,self.out_vocab_size - self.generate_size]) 223 | extended_generate_dist = tf.concat([generate_dist,extra_zeros],axis=1) 224 | 225 | hidden_state_expanded_query = tf.tile(array_ops.expand_dims(hidden_state,1),[1,tf.shape(self.query_mask)[1],1]) 226 | inp_attn_expanded_query = tf.tile(array_ops.expand_dims(inp_attn,1),[1,tf.shape(self.query_mask)[1],1]) 227 | query_attn_rep = tf.concat([self.search_values_rep,hidden_state_expanded_query,inp_attn_expanded_query],axis=2) 228 | query_attn_rep = tf.nn.tanh(tf.einsum("ijk,kl->ijl",tf.nn.tanh(tf.einsum("ijk,kl->ijl",query_attn_rep,self.W_3)),self.W_32)) 229 | alpha_logits = tf.squeeze(tf.einsum('ijk,kl->ijl',query_attn_rep,self.r_3),2) 230 | alpha_masked = tf.multiply(tf.cast(self.query_mask,dtype=tf.float64),tf.exp(tf.cast(alpha_logits,dtype=tf.float64))) 231 | alpha = tf.cast(tf.einsum('i,ij->ij',tf.pow(tf.reduce_sum(alpha_masked,1),-1),alpha_masked),dtype=tf.float32) 232 | 233 | hidden_state_expanded_result = tf.tile(array_ops.expand_dims(array_ops.expand_dims(hidden_state,1),1),[1,tf.shape(self.results_mask)[1],tf.shape(self.results_mask)[2],1]) 234 | inp_attn_expanded_result = tf.tile(array_ops.expand_dims(array_ops.expand_dims(inp_attn,1),1),[1,tf.shape(self.results_mask)[1],tf.shape(self.results_mask)[2],1]) 235 | result_attn_rep = tf.concat([self.result_values_rep,hidden_state_expanded_result,inp_attn_expanded_result],axis=3) 236 | result_attn_rep = tf.nn.tanh(tf.einsum("ijkl,lm->ijkm",tf.nn.tanh(tf.einsum("ijkl,lm->ijkm",result_attn_rep,self.W_1)),self.W_12)) 237 | beta_logits = tf.squeeze(tf.einsum('ijkl,lm->ijkm',result_attn_rep,self.r_1),3) 238 | beta_masked = tf.multiply(tf.cast(self.results_mask,dtype=tf.float64),tf.exp(tf.cast(beta_logits,dtype=tf.float64))) 239 | beta = tf.einsum('ij,ijk->ijk',alpha,tf.cast(tf.einsum('ij,ijk->ijk',tf.pow(tf.reduce_sum(beta_masked,2),-1),beta_masked),dtype=tf.float32)) 240 | 241 | hidden_state_expanded_keys = tf.tile(array_ops.expand_dims(hidden_state,1),[1,len(self.result_keys),1]) 242 | inp_attn_expanded_keys = tf.tile(array_ops.expand_dims(inp_attn,1),[1,len(self.result_keys),1]) 243 | result_key_rep = tf.concat([self.result_keys_emb,hidden_state_expanded_keys,inp_attn_expanded_keys],axis=2) 244 | result_key_rep = tf.nn.tanh(tf.einsum("ijk,kl->ijl",tf.nn.tanh(tf.einsum("ijk,kl->ijl",result_key_rep,self.W_2)),self.W_22)) 245 | gamma_logits = tf.squeeze(tf.einsum('ijk,kl->ijl',result_key_rep,self.r_2),2) 246 | gamma_logits_expanded = tf.tile(array_ops.expand_dims(array_ops.expand_dims(gamma_logits,1),1),[1,tf.shape(self.result_keys_mask)[1],tf.shape(self.result_keys_mask)[2],1]) 247 | gamma_masked = 
tf.multiply(tf.cast(self.result_keys_mask,dtype=tf.float64),tf.exp(tf.cast(gamma_logits_expanded,dtype=tf.float64))) 248 | gamma = tf.einsum('ijk,ijkl->ijkl',beta,tf.cast(tf.einsum('ijk,ijkl->ijkl',tf.pow(tf.reduce_sum(gamma_masked,3),-1),gamma_masked),dtype=tf.float32)) 249 | 250 | batch_nums_context = array_ops.expand_dims(tf.range(0, limit=self.batch_size, dtype=tf.int64),1) 251 | batch_nums_tiled_context = tf.tile(batch_nums_context,[1,tf.shape(self.encoder_states)[1]]) 252 | flat_inp_utt = tf.reshape(self.inp_utt,shape=[self.batch_size,-1]) 253 | indices_context = tf.stack([batch_nums_tiled_context,flat_inp_utt],axis=2) 254 | shape = [self.batch_size,self.out_vocab_size] 255 | context_copy_dist = tf.scatter_nd(indices_context,a,shape) 256 | 257 | all_betas = tf.reshape(beta,[self.batch_size,-1]) 258 | all_results = tf.reshape(self.result_values_rep,[self.batch_size,-1,self.emb_dim]) 259 | db_rep = tf.reduce_sum(tf.einsum('ij,ijk->ijk',all_betas,all_results),1) 260 | 261 | p_db = tf.nn.sigmoid(tf.matmul(tf.concat([hidden_state,inp_attn,db_rep],axis=1),self.W4)+self.b3) 262 | p_db = tf.tile(p_db,[1,self.out_vocab_size]) 263 | one_minus_fn = lambda x: 1 - x 264 | one_minus_pdb = tf.map_fn(one_minus_fn, p_db) 265 | 266 | p_gens = tf.nn.sigmoid(tf.matmul(tf.concat([hidden_state,inp_attn,db_rep],axis=1),self.W3)+self.b2) 267 | p_gens = tf.tile(p_gens,[1,self.out_vocab_size]) 268 | one_minus_fn = lambda x: 1 - x 269 | one_minus_pgens = tf.map_fn(one_minus_fn, p_gens) 270 | 271 | batch_nums = array_ops.expand_dims(tf.range(0, limit=self.batch_size, dtype=tf.int64),1) 272 | kb_ids = tf.reshape(self.result_values,shape=[self.batch_size,-1]) 273 | num_kb_ids = tf.shape(kb_ids)[1] 274 | batch_nums_tiled = tf.tile(batch_nums,[1,num_kb_ids]) 275 | indices = tf.stack([batch_nums_tiled,kb_ids],axis=2) 276 | updates = tf.reshape(gamma,shape=[self.batch_size,-1]) 277 | shape = [self.batch_size,self.out_vocab_size] 278 | kb_dist = tf.scatter_nd(indices,updates,shape) 279 | kb_dist = tf.einsum('i,ij->ij',self.db_empty,kb_dist) 280 | 281 | copy_dist = tf.multiply(p_db,kb_dist) + tf.multiply(one_minus_pdb,context_copy_dist) 282 | final_dist = tf.multiply(p_gens,extended_generate_dist) + tf.multiply(one_minus_pgens,copy_dist) 283 | 284 | return final_dist 285 | 286 | return unit 287 | 288 | def get_feed_dict(self,batch): 289 | 290 | fd = { 291 | self.inp_utt : batch['inp_utt'], 292 | self.inp_len : batch['inp_len'], 293 | self.out_utt : batch['out_utt'], 294 | self.out_len : batch['out_len'], 295 | self.context_len: batch['context_len'], 296 | self.query_mask : batch['query_mask'], 297 | self.search_mask : batch['search_mask'], 298 | self.search_values : batch['search_values'], 299 | self.results_mask : batch['results_mask'], 300 | self.result_keys_mask : batch['result_keys_mask'], 301 | self.result_values : batch['result_values'], 302 | self.db_empty : batch['empty'], 303 | self.max_out_utt_len : batch['max_out_utt_len'] 304 | } 305 | 306 | return fd 307 | -------------------------------------------------------------------------------- /maluuba/test.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import numpy as np 4 | from data_handler import DataHandler 5 | from model import DialogueModel 6 | import os 7 | import tensorflow as tf 8 | import cPickle as pickle 9 | import nltk 10 | import sys 11 | import csv 12 | from collections import Counter 13 | from nltk.util import ngrams 14 | from nltk.corpus import stopwords 15 | from 
nltk.tokenize import word_tokenize 16 | from nltk.stem import WordNetLemmatizer 17 | import math, re, argparse 18 | import functools 19 | import logging 20 | logging.getLogger().setLevel(logging.INFO) 21 | 22 | class Trainer(object): 23 | 24 | def __init__(self,model,handler,ckpt_path,num_epochs,learning_rate): 25 | self.handler = handler 26 | self.model = model 27 | self.ckpt_path = ckpt_path 28 | self.epochs = num_epochs 29 | self.learning_rate = learning_rate 30 | 31 | if not os.path.exists(self.ckpt_path): 32 | os.makedirs(self.ckpt_path) 33 | 34 | self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None) 35 | self.optimizer = tf.contrib.layers.optimize_loss( 36 | loss=self.model.loss, 37 | global_step=self.global_step, 38 | learning_rate=self.learning_rate, 39 | optimizer=tf.train.AdamOptimizer, 40 | clip_gradients=10.0, 41 | name='optimizer_loss' 42 | ) 43 | self.saver = tf.train.Saver(max_to_keep=5) 44 | self.sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)) 45 | init = tf.global_variables_initializer() 46 | self.sess.run(init) 47 | 48 | checkpoint = tf.train.latest_checkpoint(self.ckpt_path) 49 | if checkpoint: 50 | self.saver.restore(self.sess, checkpoint) 51 | logging.info("Loaded parameters from checkpoint") 52 | 53 | def score(self,parallel_corpus): 54 | 55 | # containers 56 | count = [0, 0, 0, 0] 57 | clip_count = [0, 0, 0, 0] 58 | r = 0 59 | c = 0 60 | weights = [0.25, 0.25, 0.25, 0.25] 61 | 62 | # accumulate ngram statistics 63 | for hyps, refs in parallel_corpus: 64 | hyps = [hyp.split() for hyp in hyps] 65 | refs = [ref.split() for ref in refs] 66 | for hyp in hyps: 67 | 68 | for i in range(4): 69 | # accumulate ngram counts 70 | hypcnts = Counter(ngrams(hyp, i + 1)) 71 | cnt = sum(hypcnts.values()) 72 | count[i] += cnt 73 | 74 | # compute clipped counts 75 | max_counts = {} 76 | for ref in refs: 77 | refcnts = Counter(ngrams(ref, i + 1)) 78 | for ng in hypcnts: 79 | max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) 80 | clipcnt = dict((ng, min(count, max_counts[ng])) \ 81 | for ng, count in hypcnts.items()) 82 | clip_count[i] += sum(clipcnt.values()) 83 | 84 | # accumulate r & c 85 | bestmatch = [1000, 1000] 86 | for ref in refs: 87 | if bestmatch[0] == 0: break 88 | diff = abs(len(ref) - len(hyp)) 89 | if diff < bestmatch[0]: 90 | bestmatch[0] = diff 91 | bestmatch[1] = len(ref) 92 | r += bestmatch[1] 93 | c += len(hyp) 94 | 95 | # computing bleu score 96 | p0 = 1e-7 97 | bp = 1 if c > r else math.exp(1 - float(r) / float(c)) 98 | p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ 99 | for i in range(4)] 100 | s = math.fsum(w * math.log(p_n) \ 101 | for w, p_n in zip(weights, p_ns) if p_n) 102 | bleu = bp * math.exp(s) 103 | return bleu 104 | 105 | def evaluate(self,data,vocab): 106 | entities = ["business","economy","breakfast","wifi","gym","parking","spa","park","museum","beach","shopping","market","airport","university","mall","cathedral","downtown","palace","theatre"] 107 | outs = [] 108 | golds = [] 109 | 110 | tp_prec = 0.0 111 | tp_recall = 0.0 112 | total_prec = 0.0 113 | total_recall = 0.0 114 | 115 | for i in range(0,len(data['sentences'])): 116 | sentence = data['sentences'][i] 117 | sentence = list(sentence) 118 | if vocab['vocab_mapping']['$STOP$'] not in sentence: 119 | index = len(sentence) 120 | else: 121 | index = sentence.index(vocab['vocab_mapping']['$STOP$']) 122 | predicted = [str(sentence[j]) for j in range(0,index)] 123 | ground = data['output'][i] 124 | ground = 
list(ground) 125 | index = ground.index(vocab['vocab_mapping']['$STOP$']) 126 | ground_truth = [str(ground[j]) for j in range(0,index)] 127 | 128 | gold_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in ground_truth ] 129 | out_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in predicted ] 130 | 131 | for word in out_anon: 132 | if word in entities or word.isdigit() or word in self.handler.all_entities: 133 | total_prec = total_prec + 1 134 | if word in gold_anon: 135 | tp_prec = tp_prec + 1 136 | 137 | for word in gold_anon: 138 | if word in entities or word.isdigit() or word in self.handler.all_entities: 139 | total_recall = total_recall + 1 140 | if word in out_anon: 141 | tp_recall = tp_recall + 1 142 | 143 | gold = gold_anon 144 | out = out_anon 145 | 146 | golds.append(" ".join(gold)) 147 | outs.append(" ".join(out)) 148 | 149 | wrap_generated = [[_] for _ in outs] 150 | wrap_truth = [[_] for _ in golds] 151 | prec = tp_prec/total_prec 152 | recall = tp_recall/total_recall 153 | if prec == 0 or recall == 0: 154 | f1 = 0.0 155 | else: 156 | f1 = 2*prec*recall/(prec+recall) 157 | overall_f1 = f1 158 | print "Bleu: %.3f, Prec: %.3f, Recall: %.3f, F1: %.3f" % (self.score(zip(wrap_generated, wrap_truth)),prec,recall,f1) 159 | return overall_f1 160 | 161 | def test(self): 162 | test_epoch_done = False 163 | 164 | teststep = 0 165 | testLoss = 0.0 166 | needed = {} 167 | needed['sentences'] = [] 168 | needed['output'] = [] 169 | needed['context'] = [] 170 | needed['kb'] = [] 171 | 172 | while not test_epoch_done: 173 | teststep = teststep + 1 174 | batch, test_epoch_done = self.handler.get_batch(train=False) 175 | feedDict = self.model.get_feed_dict(batch) 176 | sentences = self.sess.run(self.model.gen_x,feed_dict=feedDict) 177 | 178 | if 1 not in batch['dummy']: 179 | needed['sentences'].extend(sentences) 180 | needed['output'].extend(batch['out_utt']) 181 | needed['context'].extend(batch['context']) 182 | needed['kb'].extend(batch['kb']) 183 | else: 184 | index = batch['dummy'].index(1) 185 | needed['sentences'].extend(sentences[0:index]) 186 | needed['output'].extend(batch['out_utt'][0:index]) 187 | needed['context'].extend(batch['context'][0:index]) 188 | needed['kb'].extend(batch['kb'][0:index]) 189 | 190 | pickle.dump(needed,open("needed.p","w")) 191 | self.evaluate(needed,self.handler.vocab) 192 | 193 | def main(): 194 | 195 | parser = argparse.ArgumentParser() 196 | parser.add_argument('--batch_size', type=int, default=4) 197 | parser.add_argument('--emb_dim', type=int, default=200) 198 | parser.add_argument('--enc_hid_dim', type=int, default=128) 199 | parser.add_argument('--dec_hid_dim', type=int, default=256) 200 | parser.add_argument('--attn_size', type=int, default=200) 201 | parser.add_argument('--epochs', type=int, default=12) 202 | parser.add_argument('--learning_rate', type=float, default=2.5e-4) 203 | parser.add_argument('--dataset_path', type=str, default='../data/Maluuba/') 204 | parser.add_argument('--glove_path', type=str, default='../data/') 205 | parser.add_argument('--checkpoint', type=str, default="./trainDir/") 206 | config = parser.parse_args() 207 | 208 | DEVICE = "/gpu:0" 209 | 210 | logging.info("Loading Data") 211 | 212 | handler = DataHandler( 213 | emb_dim = config.emb_dim, 214 | batch_size = config.batch_size, 215 | train_path = config.dataset_path + "train.json", 216 | val_path = config.dataset_path + "test.json", 217 | test_path = config.dataset_path + "test.json", 218 | vocab_path = "./vocab.json", 219 | entities_path = 
config.dataset_path + "entities.json", 220 | glove_path = config.glove_path) 221 | 222 | logging.info("Loading Architecture") 223 | 224 | model = DialogueModel( 225 | device = DEVICE, 226 | batch_size = config.batch_size, 227 | inp_vocab_size = handler.input_vocab_size, 228 | out_vocab_size = handler.output_vocab_size, 229 | generate_size = handler.generate_vocab_size, 230 | emb_init = handler.emb_init, 231 | result_keys_vector = handler.result_keys_vector, 232 | emb_dim = config.emb_dim, 233 | enc_hid_dim = config.enc_hid_dim, 234 | dec_hid_dim = config.dec_hid_dim, 235 | attn_size = config.attn_size) 236 | 237 | logging.info("Loading Trainer") 238 | 239 | trainer = Trainer( 240 | model=model, 241 | handler=handler, 242 | ckpt_path="./trainDir/", 243 | num_epochs=config.epochs, 244 | learning_rate = config.learning_rate) 245 | 246 | trainer.test() 247 | 248 | main() -------------------------------------------------------------------------------- /maluuba/train.py: -------------------------------------------------------------------------------- 1 | import json 2 | import copy 3 | import numpy as np 4 | from data_handler import DataHandler 5 | from model import DialogueModel 6 | import os 7 | import tensorflow as tf 8 | import cPickle as pickle 9 | import nltk 10 | import sys 11 | import csv 12 | from collections import Counter 13 | from nltk.util import ngrams 14 | from nltk.corpus import stopwords 15 | from nltk.tokenize import word_tokenize 16 | from nltk.stem import WordNetLemmatizer 17 | import math, re, argparse 18 | import functools 19 | import logging 20 | logging.getLogger().setLevel(logging.INFO) 21 | 22 | class Trainer(object): 23 | 24 | def __init__(self,model,handler,ckpt_path,num_epochs,learning_rate): 25 | self.handler = handler 26 | self.model = model 27 | self.ckpt_path = ckpt_path 28 | self.epochs = num_epochs 29 | self.learning_rate = learning_rate 30 | 31 | if not os.path.exists(self.ckpt_path): 32 | os.makedirs(self.ckpt_path) 33 | 34 | self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None) 35 | self.optimizer = tf.contrib.layers.optimize_loss( 36 | loss=self.model.loss, 37 | global_step=self.global_step, 38 | learning_rate=self.learning_rate, 39 | optimizer=tf.train.AdamOptimizer, 40 | clip_gradients=10.0, 41 | name='optimizer_loss' 42 | ) 43 | self.saver = tf.train.Saver(max_to_keep=5) 44 | self.sess = tf.Session(config=tf.ConfigProto(log_device_placement=False,allow_soft_placement=True)) 45 | init = tf.global_variables_initializer() 46 | self.sess.run(init) 47 | 48 | checkpoint = tf.train.latest_checkpoint(self.ckpt_path) 49 | if checkpoint: 50 | self.saver.restore(self.sess, checkpoint) 51 | logging.info("Loaded parameters from checkpoint") 52 | 53 | def trainData(self): 54 | curEpoch = 0 55 | step = 0 56 | epochLoss = [] 57 | 58 | logging.info("Training the model") 59 | 60 | best_f1 = 0.0 61 | 62 | while curEpoch <= self.epochs: 63 | step = step + 1 64 | 65 | batch, epoch_done = self.handler.get_batch(train=True) 66 | feedDict = self.model.get_feed_dict(batch) 67 | 68 | fetch = [self.global_step, self.model.loss, self.optimizer] 69 | mod_step,loss,_ = self.sess.run(fetch,feed_dict = feedDict) 70 | epochLoss.append(loss) 71 | 72 | if step % 300 == 0: 73 | outstr = "step: "+str(step)+" Loss: "+str(loss) 74 | logging.info(outstr) 75 | 76 | if epoch_done: 77 | train_loss = np.mean(np.asarray(epochLoss)) 78 | 79 | val_epoch_done = False 80 | valstep = 0 81 | valLoss = 0.0 82 | needed = {} 83 | needed['sentences'] = [] 84 | needed['output'] = [] 85 
| 86 | while not val_epoch_done: 87 | valstep = valstep + 1 88 | batch, val_epoch_done = self.handler.get_batch(train=False) 89 | feedDict = self.model.get_feed_dict(batch) 90 | val_loss,sentences = self.sess.run([self.model.loss,self.model.gen_x],feed_dict=feedDict) 91 | if 1 not in batch['dummy']: 92 | needed['sentences'].extend(sentences) 93 | needed['output'].extend(batch['out_utt']) 94 | else: 95 | index = batch['dummy'].index(1) 96 | needed['sentences'].extend(sentences[0:index]) 97 | needed['output'].extend(batch['out_utt'][0:index]) 98 | valLoss = valLoss + val_loss 99 | 100 | valLoss = valLoss / float(valstep) 101 | outstr = "Train-info: "+"Epoch: "+str(curEpoch)+" Loss: "+str(train_loss) 102 | logging.info(outstr) 103 | outstr = "Val-info: "+"Epoch "+str(curEpoch)+" Loss: "+str(valLoss) 104 | logging.info(outstr) 105 | if curEpoch > 5: 106 | current_f1 = self.evaluate(needed,self.handler.vocab) 107 | if current_f1 >= best_f1: 108 | best_f1 = current_f1 109 | self.saver.save(self.sess, os.path.join(self.ckpt_path, 'model'), global_step=curEpoch) 110 | 111 | epochLoss = [] 112 | curEpoch = curEpoch + 1 113 | 114 | def score(self,parallel_corpus): 115 | 116 | # containers 117 | count = [0, 0, 0, 0] 118 | clip_count = [0, 0, 0, 0] 119 | r = 0 120 | c = 0 121 | weights = [0.25, 0.25, 0.25, 0.25] 122 | 123 | # accumulate ngram statistics 124 | for hyps, refs in parallel_corpus: 125 | hyps = [hyp.split() for hyp in hyps] 126 | refs = [ref.split() for ref in refs] 127 | for hyp in hyps: 128 | 129 | for i in range(4): 130 | # accumulate ngram counts 131 | hypcnts = Counter(ngrams(hyp, i + 1)) 132 | cnt = sum(hypcnts.values()) 133 | count[i] += cnt 134 | 135 | # compute clipped counts 136 | max_counts = {} 137 | for ref in refs: 138 | refcnts = Counter(ngrams(ref, i + 1)) 139 | for ng in hypcnts: 140 | max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng]) 141 | clipcnt = dict((ng, min(count, max_counts[ng])) \ 142 | for ng, count in hypcnts.items()) 143 | clip_count[i] += sum(clipcnt.values()) 144 | 145 | # accumulate r & c 146 | bestmatch = [1000, 1000] 147 | for ref in refs: 148 | if bestmatch[0] == 0: break 149 | diff = abs(len(ref) - len(hyp)) 150 | if diff < bestmatch[0]: 151 | bestmatch[0] = diff 152 | bestmatch[1] = len(ref) 153 | r += bestmatch[1] 154 | c += len(hyp) 155 | 156 | # computing bleu score 157 | p0 = 1e-7 158 | bp = 1 if c > r else math.exp(1 - float(r) / float(c)) 159 | p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 \ 160 | for i in range(4)] 161 | s = math.fsum(w * math.log(p_n) \ 162 | for w, p_n in zip(weights, p_ns) if p_n) 163 | bleu = bp * math.exp(s) 164 | return bleu 165 | 166 | def evaluate(self,data,vocab): 167 | entities = ["business","economy","breakfast","wifi","gym","parking","spa","park","museum","beach","shopping","market","airport","university","mall","cathedral","downtown","palace","theatre"] 168 | outs = [] 169 | golds = [] 170 | 171 | tp_prec = 0.0 172 | tp_recall = 0.0 173 | total_prec = 0.0 174 | total_recall = 0.0 175 | 176 | for i in range(0,len(data['sentences'])): 177 | sentence = data['sentences'][i] 178 | sentence = list(sentence) 179 | if vocab['vocab_mapping']['$STOP$'] not in sentence: 180 | index = len(sentence) 181 | else: 182 | index = sentence.index(vocab['vocab_mapping']['$STOP$']) 183 | predicted = [str(sentence[j]) for j in range(0,index)] 184 | ground = data['output'][i] 185 | ground = list(ground) 186 | index = ground.index(vocab['vocab_mapping']['$STOP$']) 187 | ground_truth = [str(ground[j]) for j in 
range(0,index)] 188 | 189 | gold_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in ground_truth ] 190 | out_anon = [vocab['rev_mapping'][word].encode('utf-8') for word in predicted ] 191 | 192 | for word in out_anon: 193 | if word in entities or word.isdigit() or word in self.handler.all_entities: 194 | total_prec = total_prec + 1 195 | if word in gold_anon: 196 | tp_prec = tp_prec + 1 197 | 198 | for word in gold_anon: 199 | if word in entities or word.isdigit() or word in self.handler.all_entities: 200 | total_recall = total_recall + 1 201 | if word in out_anon: 202 | tp_recall = tp_recall + 1 203 | 204 | gold = gold_anon 205 | out = out_anon 206 | 207 | golds.append(" ".join(gold)) 208 | outs.append(" ".join(out)) 209 | 210 | wrap_generated = [[_] for _ in outs] 211 | wrap_truth = [[_] for _ in golds] 212 | prec = tp_prec/total_prec 213 | recall = tp_recall/total_recall 214 | if prec == 0 or recall == 0: 215 | f1 = 0.0 216 | else: 217 | f1 = 2*prec*recall/(prec+recall) 218 | overall_f1 = f1 219 | print "Bleu: %.3f, Prec: %.3f, Recall: %.3f, F1: %.3f" % (self.score(zip(wrap_generated, wrap_truth)),prec,recall,f1) 220 | return overall_f1 221 | 222 | def test(self): 223 | test_epoch_done = False 224 | 225 | teststep = 0 226 | testLoss = 0.0 227 | needed = {} 228 | needed['sentences'] = [] 229 | needed['output'] = [] 230 | needed['context'] = [] 231 | needed['kb'] = [] 232 | 233 | while not test_epoch_done: 234 | teststep = teststep + 1 235 | batch, test_epoch_done = self.handler.get_batch(train=False) 236 | feedDict = self.model.get_feed_dict(batch) 237 | sentences = self.sess.run(self.model.gen_x,feed_dict=feedDict) 238 | 239 | if 1 not in batch['dummy']: 240 | needed['sentences'].extend(sentences) 241 | needed['output'].extend(batch['out_utt']) 242 | needed['context'].extend(batch['context']) 243 | needed['kb'].extend(batch['kb']) 244 | else: 245 | index = batch['dummy'].index(1) 246 | needed['sentences'].extend(sentences[0:index]) 247 | needed['output'].extend(batch['out_utt'][0:index]) 248 | needed['context'].extend(batch['context'][0:index]) 249 | needed['kb'].extend(batch['kb'][0:index]) 250 | 251 | pickle.dump(needed,open("needed.p","w")) 252 | self.evaluate(needed,self.handler.vocab) 253 | 254 | def main(): 255 | 256 | parser = argparse.ArgumentParser() 257 | parser.add_argument('--batch_size', type=int, default=4) 258 | parser.add_argument('--emb_dim', type=int, default=200) 259 | parser.add_argument('--enc_hid_dim', type=int, default=128) 260 | parser.add_argument('--dec_hid_dim', type=int, default=256) 261 | parser.add_argument('--attn_size', type=int, default=200) 262 | parser.add_argument('--epochs', type=int, default=12) 263 | parser.add_argument('--learning_rate', type=float, default=2.5e-4) 264 | parser.add_argument('--dataset_path', type=str, default='../data/Maluuba/') 265 | parser.add_argument('--glove_path', type=str, default='../data/') 266 | parser.add_argument('--checkpoint', type=str, default="./trainDir/") 267 | config = parser.parse_args() 268 | 269 | DEVICE = "/gpu:0" 270 | 271 | logging.info("Loading Data") 272 | 273 | handler = DataHandler( 274 | emb_dim = config.emb_dim, 275 | batch_size = config.batch_size, 276 | train_path = config.dataset_path + "train.json", 277 | val_path = config.dataset_path + "val.json", 278 | test_path = config.dataset_path + "test.json", 279 | vocab_path = "./vocab.json", 280 | entities_path = config.dataset_path + "entities.json", 281 | glove_path = config.glove_path) 282 | 283 | logging.info("Loading Architecture") 284 
| 285 | model = DialogueModel( 286 | device = DEVICE, 287 | batch_size = config.batch_size, 288 | inp_vocab_size = handler.input_vocab_size, 289 | out_vocab_size = handler.output_vocab_size, 290 | generate_size = handler.generate_vocab_size, 291 | emb_init = handler.emb_init, 292 | result_keys_vector = handler.result_keys_vector, 293 | emb_dim = config.emb_dim, 294 | enc_hid_dim = config.enc_hid_dim, 295 | dec_hid_dim = config.dec_hid_dim, 296 | attn_size = config.attn_size) 297 | 298 | logging.info("Loading Trainer") 299 | 300 | trainer = Trainer( 301 | model=model, 302 | handler=handler, 303 | ckpt_path=config.checkpoint, 304 | num_epochs=config.epochs, 305 | learning_rate = config.learning_rate) 306 | 307 | trainer.trainData() 308 | 309 | main() --------------------------------------------------------------------------------
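Note (added for clarity, not part of the original repository): the decoder's output unit in maluuba/model.py mixes three probability distributions with two sigmoid gates: a generate distribution over the fixed vocabulary (extended_generate_dist), a copy distribution over the dialogue context (context_copy_dist, scattered attention weights), and a copy distribution over the KB results (kb_dist, from the hierarchical query/result/key attention, zeroed via db_empty when the dialogue has no KB results). The sketch below is a minimal NumPy illustration of that mixing arithmetic only; the shapes, random inputs, and the helper name normalize are made up, and the real model computes the three distributions with the attention machinery shown above.

import numpy as np

batch_size, out_vocab_size = 2, 6

def normalize(x):
    # Toy stand-in for the model's softmax/attention outputs: each row sums to 1.
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

generate_dist = normalize(np.random.rand(batch_size, out_vocab_size))      # plays the role of extended_generate_dist
context_copy_dist = normalize(np.random.rand(batch_size, out_vocab_size))  # plays the role of context_copy_dist
kb_copy_dist = normalize(np.random.rand(batch_size, out_vocab_size))       # plays the role of kb_dist

p_gen = np.random.rand(batch_size, 1)  # gate between generating and copying (p_gens in model.py)
p_db = np.random.rand(batch_size, 1)   # gate between copying from the KB and from the context (p_db in model.py)

copy_dist = p_db * kb_copy_dist + (1 - p_db) * context_copy_dist
final_dist = p_gen * generate_dist + (1 - p_gen) * copy_dist
assert np.allclose(final_dist.sum(axis=1), 1.0)  # still a valid distribution over the output vocabulary
print(final_dist.argmax(axis=1))                 # greedy next token, as in the _g_recurrence decode loop

Typical usage of the maluuba scripts, assuming a Python 2 / TensorFlow 1.x environment with NLTK installed and GloVe vectors available under the default ../data/ path: python train.py trains the model (checkpoints go to the --checkpoint directory, ./trainDir/ by default, and the vocabulary is cached in ./vocab.json); python test.py restores the latest checkpoint, decodes the test set, writes needed.p, and prints BLEU/Prec/Recall/F1; evaluate.py recomputes the same metrics from needed.p.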