├── DSTmodels.py ├── LICENSE ├── README.md ├── bucket_io.py ├── cnntrack.py ├── data.py ├── dataset_walker.py ├── extract_vocab.py ├── gen_custom_data.py ├── get_embbeding.py ├── lectrack.py ├── mat_data.py ├── mat_io.py ├── mod_lectrack.py ├── offline_model.py ├── offline_model_dstc.py ├── turnbow_io.py ├── turnsent_io.py ├── vocab_actN.dict └── vocab_matNN.dict /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018 Liliang Ren 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Universal Dialogue State Tracking 2 | The code for the EMNLP 2018 paper: "Towards Universal Dialogue State Tracking" 3 | [Paper Link](https://arxiv.org/abs/1810.09587) 4 | [Oral Slides](https://drive.google.com/open?id=1aUTcBzDA44fOgU40vPspyNuWu2aR5cgV) 5 | 6 | We applied turn-length bucketing for better data parallelism. The StateNet model is defined with the class name "doublelstm" in the file "DSTmodels.py". 7 | -------------------------------------------------------------------------------- /bucket_io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import copy 4 | import json 5 | 6 | import numpy as np 7 | import mxnet as mx 8 | 9 | # The interface of a data iter that works for bucketing 10 | # 11 | # DataIter 12 | # - default_bucket_key: the bucket key for the default symbol. 
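#   (in this repo default_bucket_key is max(buckets): the largest bucket is
#   bound first so that its memory can be shared by all smaller buckets --
#   see DSTSentenceIter below)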
13 | # 14 | # DataBatch 15 | # - provide_data: same as DataIter, but specific to this batch 16 | # - provide_label: same as DataIter, but specific to this batch 17 | # - bucket_key: the key for the bucket that should be used for this batch 18 | 19 | def read_1best_dialog_content(dialog, labelIdx): 20 | dialog_sentences, dialog_scores, dialog_labels = [], [], [] 21 | sentence = "" 22 | score = [] 23 | for turn in dialog["turns"]: 24 | dialog_labels.append(turn["labelIdx"][labelIdx]) 25 | sentence +=" #turn# " 26 | score.append(1) 27 | 28 | for saPair in turn["machine_output"]: 29 | act = saPair["act"] 30 | slots = " " 31 | for slot in saPair["slots"]: 32 | #count never appears in train/dev set# 33 | if "count" in slot: 34 | #slot[1] = str(slot[1]) 35 | continue 36 | slots += " ".join(slot) 37 | slots += " " 38 | machine_act=(act+slots) 39 | for _ in range(len(machine_act.split())): 40 | score.append(1) 41 | sentence += machine_act 42 | 43 | asr = turn["user_input"][0]["asr-hyp"] 44 | if len(asr.split()) > 0: 45 | sentence += turn["user_input"][0]["asr-hyp"] + " " 46 | score.extend([turn["user_input"][0]["score"]] * len(asr.split())) 47 | 48 | #sentence += " " 49 | #score.append(1) 50 | assert(len(sentence.split())==len(score)) 51 | dialog_sentences.append(sentence) 52 | dialog_scores.append(score[:]) 53 | return dialog_sentences, dialog_scores, dialog_labels 54 | 55 | def default_read_content(path, labelIdx): 56 | sentences, scores, labels = [], [], [] 57 | with open(path) as json_file: 58 | data = json.load(json_file) 59 | for dialog in data: 60 | dialog_sentences, dialog_scores, dialog_labels = read_1best_dialog_content(dialog, labelIdx) 61 | sentences.extend(dialog_sentences) 62 | scores.extend(dialog_scores) 63 | labels.extend(dialog_labels) 64 | return sentences, scores, labels 65 | 66 | def default_text2id(sentence, the_vocab): 67 | words = sentence.split() 68 | words = [(the_vocab[w] if w in the_vocab else 0) for w in words if len(w) > 0] 69 | return words 70 | 71 | def default_gen_buckets(sentences, batch_size, the_vocab): 72 | len_dict = {} 73 | max_len = -1 74 | for sentence in sentences: 75 | words = default_text2id(sentence, the_vocab) 76 | if len(words) == 0: 77 | continue 78 | if len(words) > max_len: 79 | max_len = len(words) 80 | if len(words) in len_dict: 81 | len_dict[len(words)] += 1 82 | else: 83 | len_dict[len(words)] = 1 84 | #print(len_dict) 85 | 86 | tl = 0 87 | buckets = [] 88 | for l, n in len_dict.items(): # TODO: There are better heuristic ways to do this 89 | if n + tl >= batch_size*6: 90 | buckets.append(l) 91 | tl = 0 92 | else: 93 | tl += n 94 | if tl > 0 and len(buckets) > 0: 95 | buckets[-1] = max_len 96 | return buckets 97 | 98 | class SimpleBatch(object): 99 | def __init__(self, data_names, data, label_names, label, bucket_key, pad=0): 100 | self.data = data 101 | self.label = label 102 | self.data_names = data_names 103 | self.label_names = label_names 104 | self.bucket_key = bucket_key 105 | 106 | self.pad = pad 107 | self.index = None # TODO: what is index? 
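# provide_data / provide_label below expose the (name, shape) pairs of this
# batch so that a bucketing module can bind the executor matching bucket_key.
#
# Minimal usage sketch (illustrative only; the module construction and variable
# names are assumptions, not code from this repo):
#   train_iter = DSTSentenceIter(path, labelIdx, vocab, buckets=[], batch_size=32,
#                                init_states=init_states, data_components=data_components)
#   for batch in train_iter:     # each batch is a SimpleBatch carrying its bucket_key
#       module.forward(batch)    # mx.mod.BucketingModule picks the symbol for that key
#       module.backward()
#       module.update()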
108 | 109 | @property 110 | def provide_data(self): 111 | return [(n, x.shape) for n, x in zip(self.data_names, self.data)] 112 | 113 | @property 114 | def provide_label(self): 115 | return [(n, x.shape) for n, x in zip(self.label_names, self.label)] 116 | 117 | class DSTSentenceIter(mx.io.DataIter): 118 | def __init__(self, path, labelIdx, vocab, buckets, batch_size, 119 | init_states, data_components, 120 | seperate_char=' ', text2id=None, read_content=None, label_out=1): 121 | super(DSTSentenceIter, self).__init__() 122 | self.padding_id = vocab[''] 123 | 124 | self.label_out = label_out 125 | if text2id == None: 126 | self.text2id = default_text2id 127 | else: 128 | self.text2id = text2id 129 | if read_content == None: 130 | self.read_content = default_read_content 131 | else: 132 | self.read_content = read_content 133 | #content = self.read_content(path) 134 | sentences,scores,labels = self.read_content(path, labelIdx) 135 | 136 | if len(buckets) == 0: 137 | buckets = default_gen_buckets(sentences, batch_size, vocab) 138 | 139 | self.vocab_size = len(vocab) 140 | 141 | buckets.sort() 142 | self.buckets = buckets 143 | self.data = [[] for _ in buckets] 144 | self.data_score = [[] for _ in buckets] 145 | self.label = [[] for _ in buckets] 146 | 147 | # pre-allocate with the largest bucket for better memory sharing 148 | self.default_bucket_key = max(buckets) 149 | 150 | for i in range(len(sentences)): 151 | sentence = sentences[i] 152 | score = scores[i] 153 | label = labels[i] 154 | sentence = self.text2id(sentence, vocab) 155 | if len(sentence) == 0: 156 | continue 157 | for i, bkt in enumerate(buckets): 158 | if bkt >= len(sentence): 159 | assert(len(sentence)==len(score)) 160 | self.data[i].append(sentence) 161 | self.data_score[i].append(score) 162 | self.label[i].append(label) 163 | break 164 | # we just ignore the sentence it is longer than the maximum 165 | # bucket size here 166 | 167 | # re-arrange buckets to include as much as possible corpus 168 | for i in xrange(len(self.data)-1): 169 | tmp_num = len(self.data[i]) / batch_size 170 | self.data[i+1].extend(self.data[i][tmp_num*batch_size:]) 171 | self.data[i] = self.data[i][:tmp_num*batch_size] 172 | self.data_score[i+1].extend(self.data_score[i][tmp_num*batch_size:]) 173 | self.data_score[i] = self.data_score[i][:tmp_num*batch_size] 174 | self.label[i+1].extend(self.label[i][tmp_num*batch_size:]) 175 | self.label[i] = self.label[i][:tmp_num*batch_size] 176 | 177 | # convert data into ndarrays for better speed during training 178 | #data = [np.zeros((len(x), buckets[i])) for i, x in enumerate(self.data)] 179 | data = [np.full((len(x), buckets[i]), self.padding_id) for i, x in enumerate(self.data)] 180 | data_mask_len = [np.zeros((len(x), )) for i, x in enumerate(self.data)] 181 | data_score = [np.zeros((len(x), buckets[i])) for i, x in enumerate(self.data_score)] 182 | label = [np.zeros((len(x), self.label_out)) for i, x in enumerate(self.label)] 183 | for i_bucket in range(len(self.buckets)): 184 | for j in range(len(self.data[i_bucket])): 185 | sentence = self.data[i_bucket][j] 186 | data[i_bucket][j, :len(sentence)] = sentence 187 | data_mask_len[i_bucket][j] = len(sentence) 188 | score = self.data_score[i_bucket][j] 189 | #print(sentence) 190 | #print(score) 191 | data_score[i_bucket][j, :len(score)] = score 192 | label[i_bucket][j] = self.label[i_bucket][j] 193 | 194 | self.data = data 195 | self.data_mask_len = data_mask_len 196 | self.data_score = data_score 197 | self.label = label 198 | 199 | # backup corpus 200 | 
self.all_data = copy.deepcopy(self.data) 201 | self.all_data_mask_len = copy.deepcopy(self.data_mask_len) 202 | self.all_data_score = copy.deepcopy(self.data_score) 203 | self.all_label = copy.deepcopy(self.label) 204 | 205 | # Get the size of each bucket, so that we could sample 206 | # uniformly from the bucket 207 | sizeS=0 208 | bucket_sizes = [len(x) for x in self.data] 209 | print("Summary of dataset ==================") 210 | for bkt, size in zip(buckets, bucket_sizes): 211 | sizeS+=size 212 | print("bucket of len %3d : %d samples" % (bkt, size)) 213 | 214 | self.batch_size = batch_size 215 | #self.make_data_iter_plan() 216 | 217 | self.init_states = init_states 218 | self.data_components = data_components 219 | self.size=int(sizeS/batch_size) 220 | self.provide_data = self.data_components + self.init_states 221 | self.provide_label = [('softmax_label', (self.batch_size, self.label_out))] 222 | 223 | 224 | def make_data_iter_plan(self): 225 | "make a random data iteration plan" 226 | # truncate each bucket into multiple of batch-size 227 | bucket_n_batches = [] 228 | for i in range(len(self.data)): 229 | # shuffle data before truncate 230 | index_shuffle = range(len(self.data[i])) 231 | np.random.shuffle(index_shuffle) 232 | self.data[i] = self.all_data[i][index_shuffle] 233 | self.data_mask_len[i] = self.all_data_mask_len[i][index_shuffle] 234 | self.data_score[i] = self.all_data_score[i][index_shuffle] 235 | self.label[i] = self.all_label[i][index_shuffle] 236 | 237 | bucket_n_batches.append(len(self.data[i]) / self.batch_size) 238 | self.data[i] = self.data[i][:int(bucket_n_batches[i]*self.batch_size)] 239 | self.data_mask_len[i] = self.data_mask_len[i][:int(bucket_n_batches[i]*self.batch_size)] 240 | self.data_score[i] = self.data_score[i][:int(bucket_n_batches[i]*self.batch_size)] 241 | 242 | bucket_plan = np.hstack([np.zeros(n, int)+i for i, n in enumerate(bucket_n_batches)]) 243 | np.random.shuffle(bucket_plan) 244 | 245 | bucket_idx_all = [np.random.permutation(len(x)) for x in self.data] 246 | 247 | self.bucket_plan = bucket_plan 248 | self.bucket_idx_all = bucket_idx_all 249 | self.bucket_curr_idx = [0 for x in self.data] 250 | 251 | self.data_buffer = [] 252 | self.data_mask_len_buffer = [] 253 | self.data_score_buffer = [] 254 | self.label_buffer = [] 255 | for i_bucket in range(len(self.data)): 256 | data = np.zeros((self.batch_size, self.buckets[i_bucket])) 257 | data_mask_len = np.zeros((self.batch_size,)) 258 | data_score = np.zeros((self.batch_size, self.buckets[i_bucket])) 259 | label = np.zeros((self.batch_size, self.label_out)) 260 | self.data_buffer.append(data) 261 | self.data_mask_len_buffer.append(data_mask_len) 262 | self.data_score_buffer.append(data_score) 263 | self.label_buffer.append(label) 264 | 265 | def __iter__(self): 266 | self.make_data_iter_plan() 267 | for i_bucket in self.bucket_plan: 268 | data = self.data_buffer[i_bucket] 269 | data_mask_len = self.data_mask_len_buffer[i_bucket] 270 | data_score = self.data_score_buffer[i_bucket] 271 | i_idx = self.bucket_curr_idx[i_bucket] 272 | idx = self.bucket_idx_all[i_bucket][i_idx:i_idx+self.batch_size] 273 | self.bucket_curr_idx[i_bucket] += self.batch_size 274 | 275 | # Data parallelism 276 | data[:] = self.data[i_bucket][idx] 277 | data_mask_len[:] = self.data_mask_len[i_bucket][idx] 278 | data_score[:] = self.data_score[i_bucket][idx] 279 | 280 | for sentence in data: 281 | assert len(sentence) == self.buckets[i_bucket] 282 | 283 | label = self.label_buffer[i_bucket] 284 | label[:] = 
self.label[i_bucket][idx] 285 | 286 | data_names = [x[0] for x in self.provide_data] 287 | init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states] 288 | data_all = [mx.nd.array(data)] 289 | if 'score' in data_names: 290 | data_all += [mx.nd.array(data_score)] 291 | if 'data_mask_len' in data_names: 292 | data_all += [mx.nd.array(data_mask_len)] 293 | data_all += init_state_arrays 294 | 295 | label_names = ['softmax_label'] 296 | label_all = [mx.nd.array(label)] 297 | 298 | data_batch = SimpleBatch(data_names, data_all, label_names, label_all, self.buckets[i_bucket]) 299 | yield data_batch 300 | 301 | 302 | def reset(self): 303 | self.bucket_curr_idx = [0 for x in self.data] 304 | -------------------------------------------------------------------------------- /cnntrack.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys,os 4 | import mxnet as mx 5 | import numpy as np 6 | import time 7 | import math 8 | from collections import namedtuple 9 | 10 | import logging 11 | logging.basicConfig(level=logging.INFO) 12 | logger = logging.getLogger(__name__) # get a logger so that accuracies are printed 13 | 14 | 15 | from data import vocab, ontologyDict as ontology 16 | from bucket_io import default_text2id, default_read_content 17 | 18 | # directory of this script 19 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 20 | 21 | 22 | CNNModel = namedtuple("CNNModel", ['cnn_exec', 'symbol', 'data', 'label', 'param_blocks']) 23 | 24 | def text_cnn(sentence_size, num_embed, batch_size, vocab_size, 25 | num_label, filter_list=[3, 4, 5], num_filter=200, 26 | dropout=0., with_embedding=True): 27 | 28 | input_x = mx.sym.Variable('data') # placeholder for input 29 | input_y = mx.sym.Variable('softmax_label') # placeholder for output 30 | 31 | # embedding layer 32 | if not with_embedding: 33 | embed_layer = mx.sym.Embedding(data=input_x, input_dim=vocab_size, output_dim=num_embed, name='vocab_embed') 34 | conv_input = mx.sym.Reshape(data=embed_layer, target_shape=(batch_size, 1, sentence_size, num_embed)) 35 | else: 36 | conv_input = input_x 37 | 38 | # create convolution + (max) pooling layer for each filter operation 39 | pooled_outputs = [] 40 | for i, filter_size in enumerate(filter_list): 41 | convi = mx.sym.Convolution(data=conv_input, kernel=(filter_size, num_embed), num_filter=num_filter) 42 | relui = mx.sym.Activation(data=convi, act_type='relu') 43 | pooli = mx.sym.Pooling(data=relui, pool_type='max', kernel=(sentence_size - filter_size + 1, 1), stride=(1,1)) 44 | pooled_outputs.append(pooli) 45 | 46 | # combine all pooled outputs 47 | total_filters = num_filter * len(filter_list) 48 | concat = mx.sym.Concat(*pooled_outputs, dim=1) 49 | h_pool = mx.sym.Reshape(data=concat, target_shape=(batch_size, total_filters)) 50 | 51 | # dropout layer 52 | if dropout > 0.0: 53 | h_drop = mx.sym.Dropout(data=h_pool, p=dropout) 54 | else: 55 | h_drop = h_pool 56 | 57 | # fully connected 58 | cls_weight = mx.sym.Variable('cls_weight') 59 | cls_bias = mx.sym.Variable('cls_bias') 60 | 61 | fc = mx.sym.FullyConnected(data=h_drop, weight=cls_weight, bias=cls_bias, num_hidden=num_label) 62 | 63 | # softmax output 64 | #sm = mx.sym.SoftmaxOutput(data=fc, label=input_y, name='softmax', normalization='batch') 65 | sm = mx.sym.SoftmaxOutput(data=fc, label=input_y, name='softmax') 66 | 67 | return sm 68 | 69 | 70 | def setup_cnn_model(ctx, batch_size, sentence_size, num_embed, vocab_size, num_label, 71 | dropout=0.5, 
initializer=mx.initializer.Uniform(0.1), with_embedding=True): 72 | #dropout=0.5, initializer=mx.init.Xavier(magnitude=2.34), with_embedding=True): 73 | 74 | cnn = text_cnn(sentence_size, num_embed, batch_size=batch_size, 75 | vocab_size=vocab_size, num_label=num_label, dropout=dropout, with_embedding=with_embedding) 76 | arg_names = cnn.list_arguments() 77 | 78 | input_shapes = {} 79 | if with_embedding: 80 | input_shapes['data'] = (batch_size, 1, sentence_size, num_embed) 81 | else: 82 | input_shapes['data'] = (batch_size, sentence_size) 83 | 84 | arg_shape, out_shape, aux_shape = cnn.infer_shape(**input_shapes) 85 | arg_arrays = [mx.nd.zeros(s, ctx) for s in arg_shape] 86 | args_grad = {} 87 | for shape, name in zip(arg_shape, arg_names): 88 | if name in ['softmax_label', 'data']: # input, output 89 | continue 90 | args_grad[name] = mx.nd.zeros(shape, ctx) 91 | 92 | cnn_exec = cnn.bind(ctx=ctx, args=arg_arrays, args_grad=args_grad, grad_req='add') 93 | 94 | param_blocks = [] 95 | arg_dict = dict(zip(arg_names, cnn_exec.arg_arrays)) 96 | for i, name in enumerate(arg_names): 97 | if name in ['softmax_label', 'data']: # input, output 98 | continue 99 | initializer(name, arg_dict[name]) 100 | 101 | param_blocks.append( (i, arg_dict[name], args_grad[name], name) ) 102 | 103 | out_dict = dict(zip(cnn.list_outputs(), cnn_exec.outputs)) 104 | 105 | data = cnn_exec.arg_dict['data'] 106 | label = cnn_exec.arg_dict['softmax_label'] 107 | 108 | return CNNModel(cnn_exec=cnn_exec, symbol=cnn, data=data, label=label, param_blocks=param_blocks) 109 | 110 | 111 | def train_cnn(model, X_train_batch, y_train_batch, X_dev_batch, y_dev_batch, X_test_batch, y_test_batch, batch_size, 112 | #optimizer='rmsprop', max_grad_norm=5.0, learning_rate=0.0005, epoch=200): 113 | #optimizer='adadelta', max_grad_norm=5.0, learning_rate=0.0005, epoch=200): 114 | optimizer='adam', max_grad_norm=5.0, learning_rate=0.0005, epoch=100): 115 | m = model 116 | # create optimizer 117 | opt = mx.optimizer.create(optimizer) 118 | opt.lr = learning_rate 119 | 120 | updater = mx.optimizer.get_updater(opt) 121 | 122 | dev_acc_list = [0.0] 123 | for iteration in range(epoch): 124 | tic = time.time() 125 | num_correct = 0 126 | num_total = 0 127 | for begin in range(0, X_train_batch.shape[0], batch_size): 128 | batchX = X_train_batch[begin:begin+batch_size] 129 | batchY = y_train_batch[begin:begin+batch_size] 130 | if batchX.shape[0] != batch_size: 131 | continue 132 | 133 | m.data[:] = batchX 134 | m.label[:] = batchY 135 | 136 | # forward 137 | m.cnn_exec.forward(is_train=True) 138 | 139 | # backward 140 | m.cnn_exec.backward() 141 | 142 | # eval on training data 143 | num_correct += sum(batchY == np.argmax(m.cnn_exec.outputs[0].asnumpy(), axis=1)) 144 | num_total += len(batchY) 145 | 146 | # update weights 147 | norm = 0 148 | for idx, weight, grad, name in m.param_blocks: 149 | grad /= batch_size 150 | l2_norm = mx.nd.norm(grad).asscalar() 151 | norm += l2_norm * l2_norm 152 | 153 | norm = math.sqrt(norm) 154 | for idx, weight, grad, name in m.param_blocks: 155 | if norm > max_grad_norm: 156 | grad *= (max_grad_norm / norm) 157 | 158 | updater(idx, grad, weight) 159 | 160 | # reset gradient to zero 161 | grad[:] = 0.0 162 | 163 | # decay learning rate 164 | #if iteration % 50 == 0 and iteration > 0: 165 | # opt.lr *= 0.5 166 | # print('reset learning rate to %g' % opt.lr) 167 | 168 | # end of training loop 169 | toc = time.time() 170 | train_time = toc - tic 171 | train_acc = num_correct * 100 / float(num_total) 172 | 173 | # saving 
checkpoint 174 | if (iteration + 1) % 10 == 0: 175 | prefix = 'cnn' 176 | m.symbol.save('checkpoint/%s-symbol.json' % prefix) 177 | save_dict = {('arg:%s' % k) :v for k, v in m.cnn_exec.arg_dict.items()} 178 | save_dict.update({('aux:%s' % k) : v for k, v in m.cnn_exec.aux_dict.items()}) 179 | param_name = 'checkpoint/%s-%04d.params' % (prefix, iteration) 180 | mx.nd.save(param_name, save_dict) 181 | print('Saved checkpoint to %s' % param_name) 182 | 183 | def evaluate_dataset(X_batch, y_batch): 184 | # evaluate on some data set 185 | num_correct = 0 186 | num_total = 0 187 | for begin in range(0, X_batch.shape[0], batch_size): 188 | batchX = X_batch[begin:begin+batch_size] 189 | batchY = y_batch[begin:begin+batch_size] 190 | if batchX.shape[0] != batch_size: 191 | continue 192 | 193 | m.data[:] = batchX 194 | m.cnn_exec.forward(is_train=False) 195 | num_correct += sum(batchY == np.argmax(m.cnn_exec.outputs[0].asnumpy(), axis=1)) 196 | num_total += len(batchY) 197 | acc = num_correct * 100 / float(num_total) 198 | return acc 199 | 200 | dev_acc = evaluate_dataset(X_dev_batch, y_dev_batch) 201 | test_acc = evaluate_dataset(X_test_batch, y_test_batch) 202 | print('Iter [%d] Train: Time: %.3fs, Training Accuracy: %.3f \ 203 | --- Dev Accuracy thus far: %.3f \ 204 | --- Test Accuracy thus far: %.3f' % (iteration, train_time, train_acc, dev_acc, test_acc)) 205 | sys.stdout.flush() 206 | sys.stderr.flush() 207 | 208 | # decay learning rate 209 | #if dev_acc < dev_acc_list[-1]: 210 | # opt.lr *= 0.5 211 | # print('reset learning rate to %g' % opt.lr) 212 | #dev_acc_list.append(dev_acc) 213 | 214 | 215 | 216 | def get_x_y_from_data(data_json_file, labelIdx): 217 | raw_sentences, scores, labels = default_read_content(data_json_file, labelIdx) 218 | sentences = [] 219 | for i in xrange(len(raw_sentences)): 220 | raw_sentence = raw_sentences[i] 221 | sentences.append(default_text2id(raw_sentence, vocab)) 222 | 223 | # padding to max sentence length with '' 224 | padding_word = '' 225 | sequence_length = 360 226 | padded_sentences = [] 227 | for i in xrange(len(sentences)): 228 | sentence = sentences[i] 229 | num_padding = sequence_length - len(sentence) 230 | new_sentence = sentence + [vocab[padding_word]] * num_padding 231 | padded_sentences.append(new_sentence) 232 | 233 | # convert to np array 234 | x = np.array(padded_sentences) 235 | y = np.array(labels) 236 | return x, y 237 | 238 | 239 | def train_without_pretrained_embedding(labelIdx, config_dict={}): 240 | # ################################## 241 | label_index_list = ['goal_food', 'goal_pricerange', 'goal_name', 'goal_area', 'method', 'requested'] 242 | num_label_dict = { 243 | 'goal_food': len(ontology["informable"]["food"])+1, 244 | 'goal_pricerange': len(ontology["informable"]["pricerange"])+1, 245 | 'goal_name': len(ontology["informable"]["name"])+1, 246 | 'goal_area': len(ontology["informable"]["area"])+1, 247 | 'method': len(ontology["method"]), 248 | 'requested': len(ontology["requestable"]) 249 | } 250 | vocab_size = len(vocab) 251 | np.random.seed(10) 252 | 253 | # ################################## 254 | train_json_file = config_dict.get('train_json', os.path.join(cur_dir, 'train_nbest.json')) 255 | dev_json_file = config_dict.get('dev_json', os.path.join(cur_dir, 'dev_nbest.json')) 256 | test_json_file = config_dict.get('test_json', os.path.join(cur_dir, 'test_nbest.json')) 257 | 258 | label_name = label_index_list[labelIdx] 259 | num_label = num_label_dict[label_name] 260 | x_train, y_train = get_x_y_from_data(train_json_file, 
labelIdx) 261 | x_dev, y_dev = get_x_y_from_data(dev_json_file, labelIdx) 262 | x_test, y_test = get_x_y_from_data(test_json_file, labelIdx) 263 | 264 | # ################################## 265 | # randomly shuffle data 266 | shuffle_indices = np.random.permutation(np.arange(len(y_train))) 267 | x_train = x_train[shuffle_indices] 268 | y_train = y_train[shuffle_indices] 269 | shuffle_indices = np.random.permutation(np.arange(len(y_dev))) 270 | x_dev = x_dev[shuffle_indices] 271 | y_dev = y_dev[shuffle_indices] 272 | shuffle_indices = np.random.permutation(np.arange(len(y_test))) 273 | x_test = x_test[shuffle_indices] 274 | y_test = y_test[shuffle_indices] 275 | print('Train/Dev split: %d/%d' % (len(y_train), len(y_dev))) 276 | print('train shape:', x_train.shape) 277 | print('dev shape:', x_dev.shape) 278 | print('test shape:', x_test.shape) 279 | print('vocab_size', vocab_size) 280 | 281 | batch_size = 32 282 | num_embed = 100 283 | sentence_size = x_train.shape[1] 284 | 285 | print('batch size', batch_size) 286 | print('sentence max words', sentence_size) 287 | print('embedding size', num_embed) 288 | 289 | cnn_model = setup_cnn_model(mx.gpu(0), batch_size, sentence_size, num_embed, vocab_size, num_label, dropout=0.5, with_embedding=False) 290 | train_cnn(cnn_model, x_train, y_train, x_dev, y_dev, x_test, y_test, batch_size) 291 | 292 | 293 | #class CnnTrack(object): 294 | # """CnnTrack implementation, config_dict: 295 | # output_type: can be one of ['softmax', 'sigmoid'] 296 | # N: number of cpu/gpu cores 297 | # pretrain_embed: whether to use pre-trained word embedding 298 | # embed_matrix: pre-trained embed_matrix file 299 | # """ 300 | # def __init__(self, config_dict): 301 | # pass 302 | 303 | 304 | if __name__ == '__main__': 305 | if not os.path.exists("checkpoint"): 306 | os.mkdir("checkpoint") 307 | train_without_pretrained_embedding(0) 308 | #for i in xrange(5): 309 | # print('labelIdx: ', i) 310 | # train_without_pretrained_embedding(i) 311 | 312 | 313 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import copy 5 | import time 6 | import pickle 7 | import argparse 8 | import json 9 | import math 10 | import numpy as np 11 | 12 | import dataset_walker 13 | 14 | 15 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 16 | ontology_path = os.path.join(cur_dir, "config/ontology_dstc2.json") 17 | vocab_path = os.path.join(cur_dir, 'vocab.dict') 18 | 19 | # TODO !!! For consistency, this should be the ONLY place for loading ontology and vocab, and modifying them. 20 | # !!! Other files should import these data from here. 
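# The labelIdx produced by label2vec() below is a 6-element list:
#   [food_idx, pricerange_idx, name_idx, area_idx, method_idx, request_vector]
# Each goal index points into ontologyDict['informable'][slot] ('dontcare' is
# appended to each informable slot list just below, and the last index is
# reserved for the special value "none"); method_idx points into
# ontologyDict['method']; request_vector is a 0/1 list over
# ontologyDict['requestable'].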
21 | vocab = pickle.load(open(vocab_path,'rb')) 22 | ontologyDict = json.load(open(ontology_path, 'r')) 23 | for key in ontologyDict[u'informable']: 24 | ontologyDict[u'informable'][key].append('dontcare') 25 | 26 | # TODO 27 | #max_tag_index = 5 28 | #for i in xrange(1, max_tag_index+1): 29 | # ontologyDict['informable']['food'].append('#food%d#'%i) 30 | # ontologyDict['informable']['name'].append('#name%d#'%i) 31 | 32 | 33 | # ################################## 34 | 35 | # TODO replace un-informable values with tags, for example: 36 | # a machine act is "inform(addr='alibaba qijian dian')", then it is replaced as "inform(addr=<addr>)" 37 | # default to [] to disable any such replacement 38 | #replace_un_informable_slots = ['phone', 'postcode', 'addr'] 39 | replace_un_informable_slots = [] 40 | 41 | label_slot_order = ['food', 'pricerange', 'name', 'area'] 42 | 43 | def label2vec(labelDict, method, reqList): 44 | ''' 45 | Parameters: 46 | 1. goal 47 | 2. method 48 | 3. requests 49 | Return Value: 50 | 1. resIdx 51 | ''' 52 | resIdx = list() 53 | 54 | for slot in label_slot_order: 55 | if slot in labelDict and labelDict[slot] in ontologyDict['informable'][slot]: 56 | resIdx.append(ontologyDict['informable'][slot].index(labelDict[slot])) 57 | else: 58 | # the max index is for the special value: "none" 59 | resIdx.append(len(ontologyDict['informable'][slot])) 60 | 61 | resIdx.append(ontologyDict['method'].index(method)) 62 | 63 | reqVec = [0.0] * len(ontologyDict['requestable']) 64 | for req in reqList: 65 | reqVec[ontologyDict['requestable'].index(req)] = 1 66 | resIdx.append(reqVec) 67 | 68 | return resIdx 69 | 70 | def genTurnData_nbest(turn, labelJson): 71 | turnData = dict() 72 | 73 | # process user_input : exp scores 74 | user_input = turn["input"]["live"]["asr-hyps"] 75 | for asr_pair in user_input: 76 | asr_pair['score'] = math.exp(float(asr_pair['score'])) 77 | 78 | # process machine_output : replace un-informable value with tags 79 | machine_output = turn["output"]["dialog-acts"] 80 | for slot in replace_un_informable_slots : 81 | for act in machine_output: 82 | for pair in act["slots"]: 83 | if len(pair) >= 2 and pair[0] == slot: 84 | pair[1] = '<%s>' % slot 85 | 86 | # generate labelIdx 87 | labelIdx = label2vec(labelJson['goal-labels'], labelJson['method-label'], labelJson['requested-slots']) 88 | 89 | turnData["user_input"] = user_input 90 | turnData["machine_output"] = machine_output 91 | turnData["labelIdx"] = labelIdx 92 | return turnData 93 | 94 | # ################################## 95 | def tagTurnData(turnData, ontology): 96 | """Replace slot values (food/name) in one turn's data with placeholder tags.""" 97 | tagged_turnData = copy.deepcopy(turnData) 98 | tag_dict = {} 99 | for slot in ["food", "name"]: 100 | val_ind = 1 101 | for slot_val in ontology["informable"][slot]: 102 | if slot_val.startswith("#%s"%slot): 103 | continue 104 | cur_tag = "#%s%d#" % (slot, val_ind) 105 | replace_flag = False 106 | 107 | # process user_input 108 | for i in xrange(len(tagged_turnData["user_input"])): 109 | sentence = tagged_turnData["user_input"][i]['asr-hyp'] 110 | tag_sentence = sentence.replace(slot_val, cur_tag) 111 | if tag_sentence != sentence: 112 | tagged_turnData["user_input"][i]['asr-hyp'] = tag_sentence 113 | tag_dict[cur_tag] = slot_val 114 | replace_flag = True 115 | 116 | # process machine_output 117 | for act in tagged_turnData["machine_output"]: 118 | for pair in act["slots"]: 119 | if len(pair) >= 2 and pair[0] == slot and pair[1] == slot_val: 120 | pair[1] = cur_tag 121 | tag_dict[cur_tag] = slot_val 122 | replace_flag = True 
123 | 124 | if replace_flag: 125 | val_ind += 1 126 | if val_ind > max_tag_index: 127 | break 128 | 129 | # process labelIdx 130 | val_ind_dict = {ontology["informable"][slot].index(v):ontology["informable"][slot].index(k) 131 | for k, v in tag_dict.items() if k.startswith("#%s"%slot)} 132 | labelIdx_ind = label_slot_order.index(slot) 133 | labelIdx = tagged_turnData["labelIdx"][labelIdx_ind] 134 | if labelIdx in val_ind_dict: 135 | tagged_turnData["labelIdx"][labelIdx_ind] = val_ind_dict[labelIdx] 136 | 137 | # add tag_dict to tagged_turnData 138 | tagged_turnData["tag_dict"] = tag_dict 139 | 140 | 141 | return tagged_turnData 142 | 143 | def genTurnData_nbest_tagged(turn, labelJson): 144 | turnData = genTurnData_nbest(turn, labelJson) 145 | turnData = tagTurnData(turnData, ontologyDict) 146 | return turnData 147 | 148 | # ################################## 149 | def main(): 150 | parser = argparse.ArgumentParser(description='Simple hand-crafted dialog state tracker baseline.') 151 | parser.add_argument('--dataset', dest='dataset', action='store', metavar='DATASET', required=True, 152 | help='The dataset to analyze') 153 | parser.add_argument('--dataroot',dest='dataroot',action='store',required=True,metavar='PATH', 154 | help='Will look for corpus in //...') 155 | parser.add_argument('--output_type',dest='output_type',action='store',default='nbest', 156 | help='the type of output json') 157 | args = parser.parse_args() 158 | dataset = dataset_walker.dataset_walker(args.dataset, dataroot=args.dataroot, labels=True) 159 | 160 | def gen_data(func_genTurnData): 161 | data = [] 162 | for call in dataset: 163 | fileData = dict() 164 | fileData["session-id"] = call.log["session-id"] 165 | fileData["turns"] = list() 166 | #print {"session-id":call.log["session-id"]} 167 | for turn, labelJson in call: 168 | turnData = func_genTurnData(turn, labelJson) 169 | fileData["turns"].append(turnData) 170 | data.append(fileData) 171 | return data 172 | 173 | # different output type 174 | if args.output_type == 'nbest': 175 | res_data = gen_data(genTurnData_nbest) 176 | elif args.output_type == 'nbest_tagged': 177 | res_data1 = gen_data(genTurnData_nbest) 178 | res_data2 = gen_data(genTurnData_nbest_tagged) 179 | res_data = res_data1 + res_data2 180 | 181 | # write to json file 182 | file_prefix = args.dataset.split('_')[-1] 183 | res_file = "%s_%s.json" % (file_prefix, args.output_type) 184 | with open(res_file, "w") as fw: 185 | fw.write(json.dumps(res_data, indent=2)) 186 | 187 | 188 | if __name__ == '__main__': 189 | start_time = time.time() 190 | main() 191 | end_time = time.time() 192 | print 'time: ', end_time - start_time, 's' 193 | -------------------------------------------------------------------------------- /dataset_walker.py: -------------------------------------------------------------------------------- 1 | import os, json, re 2 | class dataset_walker(object): 3 | def __init__(self,dataset,labels=False,dataroot=None): 4 | if "[" in dataset : 5 | self.datasets = json.loads(dataset) 6 | elif type(dataset) == type([]) : 7 | self.datasets= dataset 8 | else: 9 | self.datasets = [dataset] 10 | self.dataset = dataset 11 | self.install_root = os.path.abspath(os.path.dirname(os.path.abspath(__file__))) 12 | self.dataset_session_lists = [os.path.join(self.install_root,'config',dataset + '.flist') for dataset in self.datasets] 13 | 14 | self.labels = labels 15 | if (dataroot == None): 16 | install_parent = os.path.dirname(self.install_root) 17 | self.dataroot = os.path.join(install_parent,'data') 18 | 
else: 19 | self.dataroot = os.path.join(os.path.abspath(dataroot)) 20 | 21 | # load dataset (list of calls) 22 | self.session_list = [] 23 | for dataset_session_list in self.dataset_session_lists : 24 | f = open(dataset_session_list) 25 | for line in f: 26 | line = line.strip() 27 | #line = re.sub('/',r'\\',line) 28 | #line = re.sub(r'\\+$','',line) 29 | if (line in self.session_list): 30 | raise RuntimeError,'Call appears twice: %s' % (line) 31 | self.session_list.append(line) 32 | f.close() 33 | 34 | def __iter__(self): 35 | for session_id in self.session_list: 36 | session_id_list = session_id.split('/') 37 | session_dirname = os.path.join(self.dataroot,*session_id_list) 38 | applog_filename = os.path.join(session_dirname,'log.json') 39 | if (self.labels): 40 | labels_filename = os.path.join(session_dirname,'label.json') 41 | if (not os.path.exists(labels_filename)): 42 | raise RuntimeError,'Cant score : cant open labels file %s' % (labels_filename) 43 | else: 44 | labels_filename = None 45 | call = Call(applog_filename,labels_filename) 46 | call.dirname = session_dirname 47 | yield call 48 | def __len__(self, ): 49 | return len(self.session_list) 50 | 51 | 52 | class Call(object): 53 | def __init__(self,applog_filename,labels_filename): 54 | self.applog_filename = applog_filename 55 | self.labels_filename = labels_filename 56 | f = open(applog_filename) 57 | self.log = json.load(f) 58 | f.close() 59 | if (labels_filename != None): 60 | f = open(labels_filename) 61 | self.labels = json.load(f) 62 | f.close() 63 | else: 64 | self.labels = None 65 | 66 | def __iter__(self): 67 | if (self.labels_filename != None): 68 | for (log,labels) in zip(self.log['turns'],self.labels['turns']): 69 | yield (log,labels) 70 | else: 71 | for log in self.log['turns']: 72 | yield (log,None) 73 | 74 | def __len__(self, ): 75 | return len(self.log['turns']) 76 | 77 | 78 | if __name__ == '__main__': 79 | import misc 80 | dataset = dataset_walker("HDCCN", dataroot="data", labels=True) 81 | for call in dataset : 82 | if call.log["session-id"]=="voip-f32f2cfdae-130328_192703" : 83 | for turn, label in call : 84 | print misc.S(turn) 85 | -------------------------------------------------------------------------------- /extract_vocab.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pickle 4 | import time 5 | import os 6 | import json 7 | #from collections import OrderedDict 8 | 9 | import dataset_walker 10 | 11 | # location of the ontology file 12 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 13 | ontology_path = os.path.join(cur_dir, 'config/ontology_dstc2.json') 14 | ontology = json.load(open(ontology_path, 'r')) 15 | 16 | dataset_name = 'dstc2_train' 17 | dataroot = 'dstc2_traindev/data' 18 | 19 | start_time = time.time() 20 | 21 | 22 | # vocab to be generated 23 | vocab_dict = {} 24 | # minimum corpus frequency for a word to be kept in the vocab 25 | oov_threshold = 20 26 | # all words seen in the data, with their counts 27 | word_dict = {} 28 | 29 | # pre-set some words 30 | vocab_dict[""] = len(vocab_dict) 31 | vocab_dict[""] = len(vocab_dict) 32 | vocab_dict["#turn#"] = len(vocab_dict) 33 | 34 | # TODO 35 | vocab_dict["#food#"] = len(vocab_dict) 36 | # vocab_dict["#food2#"] = len(vocab_dict) 37 | # vocab_dict["#food3#"] = len(vocab_dict) 38 | # vocab_dict["#food4#"] = len(vocab_dict) 39 | # vocab_dict["#food5#"] = len(vocab_dict) 40 | vocab_dict["#name#"] = len(vocab_dict) 41 | vocab_dict["#slot#"] = len(vocab_dict) 42 | # vocab_dict["#name2#"] = len(vocab_dict) 43 | # vocab_dict["#name3#"] = len(vocab_dict) 44 | # vocab_dict["#name4#"] = len(vocab_dict) 45 | # vocab_dict["#name5#"] = len(vocab_dict) 46 | 47 | # TODO 48 | # vocab_dict[""] = len(vocab_dict) 49 | # vocab_dict[""] = len(vocab_dict) 50 | # vocab_dict[""] = len(vocab_dict) 51 | 52 | def add_word(word): 53 | """Add a single word to word_dict and count its occurrences.""" 54 | word=word.encode('utf-8') 55 | word_dict[word] = word_dict.get(word, 0) + 1 56 | 57 | def add_words(words): 58 | """Add every whitespace-separated word in the given string to word_dict.""" 59 | word_list = words.split() 60 | # add 1-gram word 61 | for word in word_list: 62 | add_word(word) 63 | # TODO add 2-gram word 64 | #for word in [' '.join(word_list[i:i+2]) for i in xrange(len(word_list)-1)]: 65 | # add_word(word) 66 | # TODO add 3-gram word 67 | #for word in [' '.join(word_list[i:i+3]) for i in xrange(len(word_list)-2)]: 68 | # add_word(word) 69 | 70 | 71 | # include all ontology values into vocab 72 | # TODO 73 | add_words("none") 74 | add_words("dontcare") 75 | 76 | for i in range(59): # repeat ontology words so their counts always exceed oov_threshold 77 | for key in ontology: 78 | add_words(key) 79 | if key in ["requestable", "method"]: 80 | for val in ontology[key]: 81 | add_words(val) 82 | elif key == "informable": 83 | for slot in ["area", "pricerange"]: 84 | add_words(slot) 85 | for val in ontology[key][slot]: 86 | add_words(val) 87 | # TODO 88 | for slot in ["food"]: 89 | add_words(slot) 90 | for val in ontology[key][slot]: 91 | add_words(val) 92 | # TODO 93 | for slot in ["name"]: 94 | add_words(slot) 95 | for val in ontology[key][slot]: 96 | add_words(val) 97 | 98 | 99 | # include asr words and slu words appearing in the data set 100 | dataset = dataset_walker.dataset_walker(dataset_name, dataroot=dataroot, labels=True) 101 | add_words("asr") 102 | add_words("slots") 103 | add_words("act") 104 | for call in dataset: 105 | for turn, labelJson in call: 106 | asrs = turn["input"]["live"]["asr-hyps"] 107 | 108 | # 1best 109 | add_words(asrs[0]["asr-hyp"]) 110 | 111 | # 2best - nbest 112 | # TODO 113 | for asr in asrs[1:]: 114 | add_words(asr["asr-hyp"]) 115 | 116 | # dialog acts 117 | machine_act_words = [] 118 | for act_item in turn["output"]["dialog-acts"]: 119 | if "act" in act_item: 120 | machine_act_words.append(act_item["act"]) 121 | if "slots" in act_item: 122 | for item in act_item["slots"]: 123 | for item_val in item: 124 | machine_act_words.append(item_val) 125 | machine_act = ' '.join(machine_act_words) 126 | add_words(machine_act) 127 | 128 | 129 | # save vocab to file 130 | # TODO modify file name if needed 131 | with open('vocab_matNN.dict', 'wb') as f: 132 | for word, freq in word_dict.items(): 133 | if freq >= oov_threshold: 134 | vocab_dict[word] = len(vocab_dict) 135 | pickle.dump(vocab_dict, f) 136 | 137 | 138 | end_time = time.time() 139 | print "vocab size:", len(vocab_dict) 140 | #print vocab_dict 141 | print "cost time: ", end_time-start_time, 's' 142 | -------------------------------------------------------------------------------- /gen_custom_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | 5 | from data import genTurnData_nbest 6 | 7 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 8 | 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser(description='Simple hand-crafted dialog state tracker baseline.') 12 | parser.add_argument('--dst_history_dir', dest='dst_history_dir', action='store', default="dst_history", 13 | help='The dir of custom dst_history.') 14 | parser.add_argument('--start_ind', dest='start_ind', action='store', type=int, default=0, 15 | help='start index.') 16 | 
parser.add_argument('--end_ind', dest='end_ind', action='store', type=int, default=2999, 17 | help='end index.') 18 | parser.add_argument('--output_name', dest='output_name', action='store', default="train_custom.json", 19 | help='The output filename.') 20 | args = parser.parse_args() 21 | 22 | log_list = [os.path.join(cur_dir, args.dst_history_dir, 'log-%d.json'%i) 23 | for i in xrange(args.start_ind, args.end_ind+1)] 24 | label_list = [os.path.join(cur_dir, args.dst_history_dir, 'label-%d.json'%i) 25 | for i in xrange(args.start_ind, args.end_ind+1)] 26 | 27 | data = [] 28 | for i in xrange(len(log_list)): 29 | fileData = dict() 30 | fileData["turns"] = list() 31 | with open(log_list[i], 'r') as log_file: 32 | with open(label_list[i], 'r') as label_file: 33 | log_json = json.load(log_file) 34 | label_json = json.load(label_file) 35 | for i in xrange(len(log_json["turns"])): 36 | turnData = genTurnData_nbest(log_json["turns"][i], label_json["turns"][i]) 37 | fileData["turns"].append(turnData) 38 | data.append(fileData) 39 | with open(args.output_name, "w") as fw: 40 | fw.write(json.dumps(data, indent=1)) 41 | 42 | if __name__ == '__main__': 43 | main() 44 | -------------------------------------------------------------------------------- /get_embbeding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import pickle 5 | import numpy as np 6 | 7 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 8 | 9 | EMBEDDING_DIM =300 10 | vocab_version = 'vN3' 11 | 12 | pretrained_file ='paragram_300_sl999.txt'# 13 | vocab_dict_file = 'vocab_matNN.dict' 14 | embedding_matrix_file = 'embed_%s.npy' % vocab_version 15 | vocab_embed_txt = 'vocab2embed_%s.txt' % vocab_version 16 | 17 | 18 | # load vocab_dict 19 | vocab_dict = pickle.load(open(os.path.join(cur_dir, vocab_dict_file))) 20 | #print vocab_dict 21 | 22 | 23 | 24 | # load pretrained embeddings 25 | embeddings_index = {} 26 | 27 | embedding_matrix = np.zeros((len(vocab_dict), EMBEDDING_DIM)) 28 | #print vocab_dict.keys() 29 | 30 | m=0 31 | with open(os.path.join(cur_dir, pretrained_file)) as f: 32 | for line in f: 33 | values = line.split() 34 | word = values[0] 35 | if word in vocab_dict.keys(): 36 | coefs = np.asarray(values[1:]) 37 | embedding_matrix[vocab_dict.get(word)]=coefs 38 | m+=1 39 | print 'Found %s word vectors.' 
% m 40 | print len(coefs) 41 | embedding_matrix[1] = np.asarray([0.0]*EMBEDDING_DIM) 42 | embedding_matrix[0] = np.asarray([0.0]*EMBEDDING_DIM) 43 | 44 | save_path = os.path.join(cur_dir, 'vocab_set', embedding_matrix_file) 45 | np.save(save_path, embedding_matrix) 46 | print 'Save embedding_matrix to:', save_path 47 | 48 | 49 | 50 | # generate vocab2embed_list 51 | vocab2embed_list = [None] * len(vocab_dict) 52 | for word, i in vocab_dict.items(): 53 | vocab2embed_list[i] = [word] + [str(num) for num in embedding_matrix[i].tolist()] 54 | file_string = '\n'.join([' '.join(item) for item in vocab2embed_list]) 55 | 56 | save_path = os.path.join(cur_dir, 'vocab_set', vocab_embed_txt) 57 | with open(save_path, 'w') as f: 58 | f.write(file_string) 59 | print 'Save vocab_embed_txt to:', save_path 60 | 61 | -------------------------------------------------------------------------------- /lectrack.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | #import sys 4 | import numpy as np 5 | import mxnet as mx 6 | 7 | from DSTmodels import lstm_unroll 8 | from DSTmodels import lstmcnn_unroll 9 | from DSTmodels import blstm_unroll 10 | from DSTmodels import lstmn_unroll 11 | from DSTmodels import text_cnn 12 | from DSTmodels import bowlstm_unroll 13 | from DSTmodels import cnnlstm_unroll 14 | from DSTmodels import cnncnnlstm_unroll 15 | from DSTmodels import doublelstm_unroll 16 | from DSTmodels import reslstm_unroll 17 | from bucket_io import SimpleBatch 18 | from bucket_io import default_text2id 19 | from mat_data import ontologyDict 20 | 21 | 22 | class LecTrack(object): 23 | """LecTrack implementation, config_dict: 24 | output_type: can be one of ['softmax', 'sigmoid'] 25 | N: number of cpu/gpu cores 26 | pretrain_embed: whether to use pre-trained word embedding 27 | embed_matrix: pre-trained embed_matrix file 28 | fix_embed: whether to keep embed_matrix unchanged while training 29 | """ 30 | def __init__(self, config_dict): 31 | # Configuration 32 | self.input_size = config_dict.get('input_size') 33 | self.num_label = config_dict.get('num_label') 34 | self.nn_type = config_dict.get('nn_type', 'lstm') 35 | self.output_type = config_dict.get('output_type', 'softmax') 36 | self.context_type = config_dict.get('context_type', 'cpu') 37 | self.dropout = config_dict.get('dropout', 0.) 
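# An illustrative config_dict for the StateNet variant (nn_type 'doublelstm',
# see README); the concrete values and the num_label_list name below are
# assumptions for illustration, not released experiment settings:
#   config = {
#       'nn_type': 'doublelstm',
#       'input_size': len(vocab),        # vocabulary size
#       'num_label': num_label_list,     # one entry per informable slot (values + "none")
#       'output_type': 'softmax',
#       'context_type': 'gpu',
#       'batch_size': 32,
#   }
#   tracker = LecTrack(config)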
38 | self.batch_size = config_dict.get('batch_size', 32) 39 | self.optimizer = config_dict.get('optimizer', 'adam') 40 | self.initializer = config_dict.get('initializer', 'xavier') 41 | 42 | # Configurations that usually does not need to tune 43 | self.N = config_dict.get('N', 1) 44 | self.enable_mask = config_dict.get('enable_mask', False) 45 | self.num_embed = config_dict.get('num_embed', 300) 46 | self.num_lstm_layer = config_dict.get('num_lstm_layer', 1) 47 | self.num_lstm_o = config_dict.get('num_lstm_o', 128) 48 | self.num_hidden = config_dict.get('num_hidden', 128) 49 | self.learning_rate = config_dict.get('learning_rate', 0.01) 50 | self.pretrain_embed = config_dict.get('pretrain_embed', False) 51 | self.embed_matrix = config_dict.get('embed_matrix', 'embed_mat.npy') 52 | self.fix_embed = config_dict.get('fix_embed', True) 53 | self.buckets = config_dict.get('buckets', []) 54 | 55 | # ################################## 56 | if self.nn_type in ['lstmn']: 57 | self.buckets = [3, 5, 8, 10, 13, 15, 18, 22, 26, 30, 34, 38, 42, 46, 50, 55, 60, 65, 70, 80, 90, 110, 150, 350] 58 | elif self.nn_type in ['cnn']: 59 | self.buckets = [10, 13, 15, 18, 22, 26, 30, 34, 38, 42, 46, 50, 55, 60, 65, 70, 80, 90, 110, 150, 350, 500] 60 | elif self.nn_type in ['reslstm','matlstm','bowlstm', 'cnnlstm', 'cnncnnlstm','doublelstm']: 61 | self.buckets = range(1, 31) 62 | else: 63 | self.buckets = [3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 500] 64 | self.default_bucket_key = 500 if len(self.buckets) == 0 else max(self.buckets) 65 | print '[INFO]: nn_type is %s, buckets are %s' % (self.nn_type, str(self.buckets)) 66 | 67 | self.batch_size *= self.N 68 | if self.context_type == 'gpu': 69 | self.contexts = [mx.context.gpu(i) for i in range(self.N)] 70 | else: 71 | self.contexts = [mx.context.cpu(i) for i in range(self.N)] 72 | 73 | # ################################## 74 | if self.nn_type == "lstm": 75 | self.nn_unroll = lstm_unroll 76 | self.data_components = [('data', (self.batch_size, self.default_bucket_key)), \ 77 | ('score', (self.batch_size, self.default_bucket_key))] 78 | self.init_c = [('l%d_init_c'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 79 | self.init_h = [('l%d_init_h'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 80 | self.init_states = self.init_c + self.init_h 81 | elif self.nn_type == "lstmcnn": 82 | self.nn_unroll = lstmcnn_unroll 83 | self.data_components = [('data', (self.batch_size, self.default_bucket_key)),]# \ 84 | #('score', (self.batch_size, self.default_bucket_key))] 85 | self.init_c = [('l%d_init_c'%l, (self.batch_size, self.num_embed)) for l in range(self.num_lstm_layer)] 86 | self.init_h = [('l%d_init_h'%l, (self.batch_size, self.num_embed)) for l in range(self.num_lstm_layer)] 87 | self.init_states = self.init_c + self.init_h 88 | elif self.nn_type == "blstm": 89 | self.nn_unroll = blstm_unroll 90 | self.data_components = [('data', (self.batch_size, self.default_bucket_key)), \ 91 | ('score', (self.batch_size, self.default_bucket_key))] 92 | self.forward_init_c = [('forward_l%d_init_c'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 93 | self.forward_init_h = [('forward_l%d_init_h'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 94 | self.backward_init_c = [('backward_l%d_init_c'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 95 | self.backward_init_h = [('backward_l%d_init_h'%l, (self.batch_size, self.num_lstm_o)) for l in 
range(self.num_lstm_layer)] 96 | self.init_states = self.forward_init_c + self.forward_init_h + self.backward_init_c + self.backward_init_h 97 | elif self.nn_type == "lstmn": 98 | self.nn_unroll = lstmn_unroll 99 | self.data_components = [('data', (self.batch_size, self.default_bucket_key)), \ 100 | ('score', (self.batch_size, self.default_bucket_key))] 101 | self.init_c = [('l%d_init_c'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 102 | self.init_h = [('l%d_init_h'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 103 | self.init_states = self.init_c + self.init_h 104 | elif self.nn_type == "cnn": 105 | self.nn_unroll = text_cnn 106 | self.data_components = [('data', (self.batch_size, self.default_bucket_key))] 107 | self.init_states = [] 108 | elif self.nn_type == "bowlstm": 109 | self.nn_unroll = bowlstm_unroll 110 | self.data_components = [('data', (self.batch_size, self.default_bucket_key, self.input_size)), \ 111 | ('data_act', (self.batch_size, self.default_bucket_key, self.input_size))] 112 | self.init_c = [('l%d_init_c'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 113 | self.init_h = [('l%d_init_h'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 114 | self.init_states = self.init_c + self.init_h 115 | elif self.nn_type == "cnnlstm": 116 | self.nn_unroll = cnnlstm_unroll 117 | self.max_nbest = 10 # the maximum number of nbest asr in train/dev/test is 10 118 | self.max_sentlen = 30 # the maximum length of user utterance in train/dev/test is 23 119 | self.data_components = [('data', (self.batch_size, self.default_bucket_key,self.max_nbest,self.max_sentlen)), \ 120 | ('data_act', (self.batch_size, self.default_bucket_key, self.max_sentlen)), \ 121 | ('score', (self.batch_size, self.default_bucket_key, self.max_nbest))] 122 | self.init_c = [('l%d_init_c'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 123 | self.init_h = [('l%d_init_h'%l, (self.batch_size, self.num_label)) for l in range(self.num_lstm_layer)] 124 | self.init_states = self.init_c + self.init_h 125 | elif self.nn_type == "reslstm": 126 | self.numM=2 127 | self.nn_unroll = reslstm_unroll 128 | self.max_nbest = 10 # the maximum number of nbest asr in train/dev/test is 10 129 | self.max_sentlen = 30 # the maximum length of user utterance in train/dev/test is 23 130 | self.data_components = [('data', (self.batch_size, self.default_bucket_key, 2,self.input_size)), \ 131 | ('data_act', (self.batch_size, self.default_bucket_key,2, self.input_size))] 132 | # ('score', (self.batch_size, self.default_bucket_key, self.max_nbest))] 133 | self.init_c = [('l%d_init_c'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 134 | self.init_h = [('l%d_init_h'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 135 | self.init_m=[('m_init_m%d'%i,(self.batch_size,2,self.input_size))for i in range(self.numM)] 136 | #self.init_con=[('one',(self.batch_size,1,self.num_label))] 137 | self.init_states = self.init_m +self.init_c + self.init_h#+self.init_con 138 | 139 | elif self.nn_type == "doublelstm": 140 | self.nn_unroll = doublelstm_unroll 141 | self.max_nbest = 10 # the maximum number of nbest asr in train/dev/test is 10 142 | self.max_sentlen = 30 # the maximum length of user utterance in train/dev/test is 23 143 | 144 | self.numM=2 145 | 146 | val_comp=[] 147 | for i,nv in enumerate(self.num_label): 148 | 
val_comp.append(('value_%d'%i,(self.batch_size,nv,300))) 149 | self.data_components = [('data', (self.batch_size, self.default_bucket_key, self.max_sentlen,300)), \ 150 | ('data_act', (self.batch_size, self.default_bucket_key, self.input_size)),\ 151 | ('slot',(self.batch_size,len(self.num_label),300)),\ 152 | ]+val_comp 153 | 154 | #('value',(self.batch_size,self.num_label,300)) ] 155 | # ('score', (self.batch_size, self.default_bucket_key, self.max_nbest))] 156 | #self.forward_init_c = [('forward_l%d_init_c'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 157 | #self.forward_init_h = [('forward_l%d_init_h'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 158 | self.backward_init_c = [('backward_l%d_init_c'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 159 | self.backward_init_h = [('backward_l%d_init_h'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 160 | #self.init_m=[('m_init_m%d'%i,(self.batch_size,2,128))for i in range(self.numM)] 161 | self.init_states = self.backward_init_c + self.backward_init_h#+self.init_m 162 | elif self.nn_type == "cnncnnlstm": 163 | self.nn_unroll = cnncnnlstm_unroll 164 | self.max_nbest = 10 # the maximum number of nbest asr in train/dev/test is 10 165 | self.max_sentlen = 35 # the maximum length of user utterance in train/dev/test is 23 166 | self.data_components = [('data', (self.batch_size, self.default_bucket_key, self.max_nbest, self.max_sentlen)), \ 167 | ('data_act', (self.batch_size, self.default_bucket_key, self.max_sentlen)), \ 168 | ('score', (self.batch_size, self.default_bucket_key, self.max_nbest))] 169 | self.init_c = [('l%d_init_c'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 170 | self.init_h = [('l%d_init_h'%l, (self.batch_size, self.num_lstm_o)) for l in range(self.num_lstm_layer)] 171 | self.init_states = self.init_c + self.init_h 172 | 173 | if self.enable_mask and self.nn_type in ['reslstm','matlstm','lstm', 'blstm', 'lstmn']: 174 | self.data_components += [('data_mask_len', (self.batch_size,))] 175 | print '[INFO]: Enable data mask' 176 | self.default_provide_data = self.data_components + self.init_states 177 | 178 | tmp_label_out = self.num_label if self.output_type == 'sigmoid' else len(self.num_label) 179 | if self.nn_type in ['bowlstm','reslstm','matlstm', 'cnnlstm', 'cnncnnlstm','doublelstm']: 180 | self.default_provide_label = [('softmax_label', (self.batch_size, self.default_bucket_key, tmp_label_out))] 181 | else: 182 | self.default_provide_label = [('softmax_label', (self.batch_size, tmp_label_out))] 183 | 184 | # ################################## 185 | self.model = mx.mod.BucketingModule( 186 | sym_gen = self.sym_gen, 187 | default_bucket_key = self.default_bucket_key, 188 | context = self.contexts) 189 | self.model.bind(data_shapes = self.default_provide_data, label_shapes = self.default_provide_label) 190 | 191 | # ################################## 192 | if self.initializer == 'xavier': 193 | default_init = mx.init.Xavier(magnitude=2.) 194 | elif self.initializer == 'uniform': 195 | default_init = mx.init.Uniform(0.1) 196 | print '[INFO]: Using initializer - %s' % self.initializer 197 | 198 | if self.pretrain_embed and self.embed_matrix: 199 | print '[INFO]: Using pre-trained word embedding.' 
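# embed_matrix is expected to be a numpy .npy array of shape (vocab_size, num_embed)
# indexed by vocab id; get_embbeding.py builds such a matrix from 300-dim paragram
# vectors, with zero rows for the two reserved ids.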
200 | embed_weight = np.load(self.embed_matrix) 201 | #print len(embed_weight) 202 | init = mx.initializer.Load(param={"embed_weight": embed_weight}, default_init=default_init, verbose=True) 203 | 204 | self.model.init_params(initializer=init) #aux_params=auxDic) 205 | else: 206 | if self.pretrain_embed and not self.embed_matrix: 207 | print '[WARNNING]: Pre-trained word embedding is not used.' 208 | print '[WARNNING]: Because: pretrain_embed is True while embed_matrix is not given.' 209 | #auxDic={} 210 | #auxDic['multn']=2.0 211 | #auxDic['one']=mx.nd.full((32,1,self.num_label),0.25) 212 | #cons=mx.init.Load(auxDic) 213 | #init=mx.init.Mixed(['one','.*'],[cons,default_init]) 214 | gamma=mx.nd.full((512,),0.1) 215 | gamma1=mx.nd.full((128,),0.1) 216 | init = mx.initializer.Load(param={"g50_gamma": gamma,"g60_gamma":gamma,"g70_gamma":gamma1}, default_init=default_init, verbose=True) 217 | self.model.init_params(initializer=default_init)#mx.init.MSRAPrelu()) 218 | 219 | 220 | def sym_gen(self, seq_len): 221 | if self.nn_type in ['lstm','lstmcnn', 'blstm', 'lstmn']: 222 | return self.nn_unroll( 223 | num_lstm_layer = self.num_lstm_layer, 224 | seq_len = seq_len, 225 | input_size = self.input_size, 226 | num_hidden = self.num_hidden, 227 | num_embed = self.num_embed, 228 | num_lstm_o = self.num_lstm_o, 229 | num_label = self.num_label, 230 | output_type = self.output_type, 231 | dropout = self.dropout, 232 | fix_embed = self.fix_embed, 233 | enable_mask = self.enable_mask 234 | ) 235 | elif self.nn_type in ['cnn']: 236 | return self.nn_unroll( 237 | seq_len = seq_len, 238 | num_embed = self.num_embed, 239 | input_size = self.input_size, 240 | num_label = self.num_label, 241 | filter_list = [3, 4, 5], 242 | num_filter = 200, 243 | output_type = self.output_type, 244 | dropout = self.dropout, 245 | fix_embed = self.fix_embed 246 | ) 247 | elif self.nn_type in ['bowlstm']: 248 | return self.nn_unroll( 249 | num_lstm_layer = self.num_lstm_layer, 250 | seq_len = seq_len, 251 | input_size = self.input_size, 252 | num_hidden = self.num_hidden, 253 | num_embed = self.num_embed, 254 | num_lstm_o = self.num_lstm_o, 255 | num_label = self.num_label, 256 | output_type = self.output_type, 257 | dropout = self.dropout 258 | ) 259 | elif self.nn_type in ['reslstm','matlstm','cnnlstm', 'cnncnnlstm','doublelstm']: 260 | return self.nn_unroll( 261 | num_lstm_layer = self.num_lstm_layer, 262 | seq_len = seq_len, 263 | input_size = self.input_size, 264 | num_hidden = self.num_hidden, 265 | num_embed = self.num_embed, 266 | num_lstm_o = self.num_lstm_o, 267 | num_label = self.num_label, 268 | filter_list = [3, 4, 5], 269 | num_filter = 100, 270 | max_nbest = self.max_nbest, 271 | max_sentlen = self.max_sentlen, 272 | output_type = self.output_type, 273 | dropout = self.dropout 274 | ) 275 | 276 | def load_params(self, params_path): 277 | self.model.load_params(params_path) 278 | 279 | def save_params(self, params_path): 280 | self.model.save_params(params_path) 281 | 282 | def train(self, data_batch): 283 | self.model.forward(data_batch) 284 | self.model.backward() 285 | self.model.update() 286 | 287 | def predict(self, data_batch): 288 | self.model.forward(data_batch, is_train=False) 289 | return self.model.get_outputs() 290 | 291 | def getMatchKey(self, sentence_len): 292 | if len(self.buckets) > 0: 293 | for key in self.buckets: 294 | if key >= sentence_len: 295 | return key 296 | return sentence_len 297 | 298 | def oneSentenceBatch(self, cur_sentence, cur_score, cur_label, label_out): 299 | cur_bucket_key = 
self.getMatchKey(len(cur_sentence)) 300 | 301 | data = np.zeros((self.batch_size, cur_bucket_key)) 302 | data_mask_len = np.zeros((self.batch_size, )) 303 | data_score = np.zeros((self.batch_size, cur_bucket_key)) 304 | label = np.zeros((self.batch_size, label_out)) 305 | data[:, :len(cur_sentence)] = cur_sentence 306 | data_mask_len[:] = len(cur_sentence) 307 | data_score[:, :len(cur_score)] = cur_score 308 | label[:, :label_out] = cur_label 309 | 310 | data_names = [x[0] for x in self.default_provide_data] 311 | init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states] 312 | data_all = [mx.nd.array(data)] + init_state_arrays 313 | if 'score' in data_names: 314 | data_all += [mx.nd.array(data_score)] 315 | if 'data_mask_len' in data_names: 316 | data_all += [mx.nd.array(data_mask_len)] 317 | 318 | label_names = ['softmax_label'] 319 | label_all = [mx.nd.array(label)] 320 | 321 | data_batch = SimpleBatch(data_names, data_all, label_names, label_all, cur_bucket_key) 322 | return data_batch 323 | 324 | def multiWordSentBatch(self, sentences, scores, labels, label_out): 325 | assert(len(sentences) <= self.batch_size) 326 | cur_bucket_key = self.getMatchKey(max([len(s) for s in sentences])) 327 | 328 | data = np.zeros((self.batch_size, cur_bucket_key)) 329 | data_mask_len = np.zeros((self.batch_size, )) 330 | data_score = np.zeros((self.batch_size, cur_bucket_key)) 331 | label = np.zeros((self.batch_size, label_out)) 332 | for i in xrange(len(sentences)): 333 | data[i, :len(sentences[i])] = sentences[i] 334 | data_mask_len[i] = len(sentences[i]) 335 | data_score[i, :len(scores[i])] = scores[i] 336 | label[i, :label_out] = labels[i] 337 | 338 | data_names = [x[0] for x in self.default_provide_data] 339 | init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states] 340 | data_all = [mx.nd.array(data)] + init_state_arrays 341 | if 'score' in data_names: 342 | data_all += [mx.nd.array(data_score)] 343 | if 'data_mask_len' in data_names: 344 | data_all += [mx.nd.array(data_mask_len)] 345 | 346 | label_names = ['softmax_label'] 347 | label_all = [mx.nd.array(label)] 348 | 349 | data_batch = SimpleBatch(data_names, data_all, label_names, label_all, cur_bucket_key) 350 | return data_batch 351 | 352 | def multiTurnBowBatch(self, sentences, acts, labels, label_out): 353 | assert(len(sentences) <= self.batch_size) 354 | cur_bucket_key = self.getMatchKey(max([len(s) for s in sentences])) 355 | 356 | data = np.zeros((self.batch_size, cur_bucket_key, self.input_size)) 357 | data_act = np.zeros((self.batch_size, cur_bucket_key, self.input_size)) 358 | label = np.zeros((self.batch_size, cur_bucket_key, label_out)) 359 | for i in xrange(len(sentences)): 360 | for j in xrange(len(sentences[i])): 361 | data[i, j, :len(sentences[i][j])] = sentences[i][j] 362 | data_act[i, j, :len(acts[i][j])] = acts[i][j] 363 | label[i, j, :label_out] = labels[i][j] 364 | 365 | data_names = [x[0] for x in self.default_provide_data] 366 | init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states] 367 | data_all = [mx.nd.array(data), mx.nd.array(data_act)] 368 | data_all += init_state_arrays 369 | 370 | label_names = ['softmax_label'] 371 | label_all = [mx.nd.array(label)] 372 | 373 | data_batch = SimpleBatch(data_names, data_all, label_names, label_all, cur_bucket_key) 374 | return data_batch 375 | 376 | 377 | def multiTurnBatch(self, labelIdx,sentences, acts, scores, labels, label_out, vocab, vocab1,feature_type='bowbow'): 378 | assert(len(sentences) <= self.batch_size) 379 | print len(vocab) 380 | 381 | print 
len(vocab1) 382 | cur_bucket_key = self.getMatchKey(max([len(s) for s in sentences])) 383 | 384 | padding_id = vocab[''] 385 | len_sent = self.max_sentlen if feature_type in ['sentsent', 'sentbow'] else len(vocab) 386 | len_act_sent = self.max_sentlen if feature_type in ['sentsent', 'bowsent'] else len(vocab1) 387 | 388 | embed_weight = mx.nd.array(np.load('embed_vN3.npy')) 389 | # convert data into ndarrays for better speed during training 390 | 391 | slotsent="food pricerange name area" 392 | slota=default_text2id(slotsent, vocab) 393 | slotarr=slotsent.split() 394 | #print slota 395 | 396 | val_len=len(ontologyDict[u'informable'][slotarr[labelIdx]]) 397 | 398 | vl=[] 399 | for key in ontologyDict[u'informable'][slotarr[labelIdx]]: 400 | #print key 401 | v=default_text2id(key,vocab) 402 | tmp=mx.nd.array(v) 403 | tmp= mx.nd.Embedding(data=tmp, input_dim=len(vocab), weight=embed_weight, output_dim=300, name='embed') 404 | tmp=mx.nd.sum(tmp,axis=0) 405 | v=tmp.asnumpy() 406 | vl.append(v) 407 | vl=np.asarray(vl) 408 | #print vl 409 | #print len(vl) 410 | 411 | tmp=mx.nd.array([slota[labelIdx]]) 412 | tmp= mx.nd.Embedding(data=tmp, input_dim=len(vocab), weight=embed_weight, output_dim=300, name='embed') 413 | slota=tmp.asnumpy() 414 | 415 | 416 | 417 | value=np.zeros((self.batch_size,val_len,300)) 418 | slot=np.zeros((self.batch_size,300)) 419 | for i in range(self.batch_size): 420 | slot[i]=slota 421 | value[i]=vl 422 | 423 | 424 | 425 | 426 | 427 | datatmp = np.full((self.batch_size, cur_bucket_key,2, self.max_nbest, len_sent), padding_id,dtype=np.double) 428 | data_act = np.full((self.batch_size, cur_bucket_key,2, len_act_sent), padding_id,dtype=np.double) 429 | data_score = np.zeros((self.batch_size, cur_bucket_key, self.max_nbest)) 430 | label = np.zeros((self.batch_size, cur_bucket_key, label_out)) 431 | 432 | data = np.full((self.batch_size, cur_bucket_key,2, len_sent, 300), padding_id,dtype=np.double) 433 | 434 | for i_diag in range(len(sentences)): 435 | for i_turn in range(len(sentences[i_diag])): 436 | act = acts[i_diag][i_turn] 437 | for i in range(2): 438 | data_act[i_diag, i_turn,i, :len(act[i])] = act[i] 439 | label[i_diag, i_turn, :] = labels[i_diag][i_turn] 440 | # be careful that, here, max_nbest can be smaller than current turn nbest number. extra-best will be truncated. 
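# Illustrative sketch (made-up example, not from the corpus): each n-best hypothesis below is
# embedded word-by-word into 300-d vectors via mx.nd.Embedding, multiplied by its ASR
# confidence, and the weighted hypotheses are summed into a single (len_sent, 300) turn
# representation. For two hypotheses with scores 0.7 and 0.3 this amounts to roughly
#     rep = 0.7 * embed("cheap restaurant") + 0.3 * embed("the restaurant")
# which is what the loop over datatmp / embed_weight / score computes.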
441 | for i_data in range(2): 442 | tempsent=[] 443 | for i_nbest in range(min(len(sentences[i_diag][i_turn][i_data]), self.max_nbest)): 444 | sentence = sentences[i_diag][i_turn][i_data][i_nbest] 445 | datatmp[i_diag, i_turn, i_data,i_nbest, :len(sentence)] = sentence 446 | tmp=mx.nd.array(datatmp[i_diag, i_turn, i_data,i_nbest]) 447 | tmp= mx.nd.Embedding(data=tmp, input_dim=len(vocab), weight=embed_weight, output_dim=300, name='embed') 448 | sentence=tmp.asnumpy() 449 | score = scores[i_diag][i_turn][i_nbest] 450 | #preprocess 451 | sent =sentence*score 452 | tempsent.append(sent) 453 | data_score[i_diag, i_turn, i_nbest] = score 454 | tempsent=np.asarray(tempsent) 455 | scoredsent=np.sum(tempsent,axis=0) 456 | #scoredsent=scoredsent*2-1 457 | data[i_diag, i_turn, i_data] = scoredsent 458 | 459 | 460 | 461 | 462 | data_names = [x[0] for x in self.default_provide_data] 463 | init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states] 464 | data_all = [mx.nd.array(data), mx.nd.array(data_act)] 465 | if 'score' in data_names: 466 | data_all += [mx.nd.array(data_score)] 467 | if 'slot' in data_names: 468 | data_all += [mx.nd.array(slot)] 469 | if 'value' in data_names: 470 | data_all += [mx.nd.array(value)] 471 | 472 | 473 | data_all += init_state_arrays 474 | 475 | label_names = ['softmax_label'] 476 | label_all = [mx.nd.array(label)] 477 | 478 | data_batch = SimpleBatch(data_names, data_all, label_names, label_all, cur_bucket_key) 479 | return data_batch 480 | -------------------------------------------------------------------------------- /mat_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import copy 5 | import time 6 | import pickle 7 | import argparse 8 | import json 9 | import math 10 | import numpy as np 11 | 12 | import dataset_walker 13 | 14 | 15 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 16 | ontology_path = os.path.join(cur_dir, "config/ontology_dstc2.json") 17 | vocab_path = os.path.join(cur_dir, 'vocab_matNN.dict') 18 | 19 | # TODO !!! For consistency, this should be the ONLY place for loading ontology and vocab, and modifying them. 20 | # !!! Other files should import these data from here. 21 | vocab = pickle.load(open(vocab_path,'rb')) 22 | 23 | vocab_path2 = os.path.join(cur_dir, 'vocab_actN.dict') 24 | 25 | # TODO !!! For consistency, this should be the ONLY place for loading ontology and vocab, and modifying them. 26 | # !!! Other files should import these data from here. 
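# For example, other modules in this repo pull the shared resources in with
#
#     from mat_data import vocab, vocab1, ontologyDict
#
# (see mod_lectrack.py and offline_model.py), so any change to the vocabulary or the ontology
# only needs to be made in this file.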
27 | vocab1 = pickle.load(open(vocab_path2,'rb'))
28 | 
29 | ontologyDict = json.load(open(ontology_path, 'r'))
30 | for key in ontologyDict[u'informable']:
31 | ontologyDict[u'informable'][key].append('dontcare')
32 | ontologyDict[u'informable'][key].append('none')
33 | 
34 | # TODO
35 | #max_tag_index = 5
36 | #for i in xrange(1, max_tag_index+1):
37 | # ontologyDict['informable']['food'].append('#food%d#'%i)
38 | # ontologyDict['informable']['name'].append('#name%d#'%i)
39 | 
40 | 
41 | # ##################################
42 | 
43 | # TODO replace un-informable values with tags, for example:
44 | # a machine act is "inform(addr='alibaba qijian dian')", then it is replaced as "inform(addr=<addr>)"
45 | # default to [] to disable any such replacement
46 | #replace_un_informable_slots = ['phone', 'postcode', 'addr']
47 | replace_un_informable_slots = []
48 | 
49 | # the order in which the different slots are processed when tagging slot values in a sentence
50 | label_slot_order = ['food', 'pricerange', 'name', 'area']
51 | 
52 | def label2vec(labelDict, method, reqList):
53 | '''
54 | Parameters:
55 | 1. goal
56 | 2. method
57 | 3. requests
58 | Return Value:
59 | 1. resIdx
60 | '''
61 | resIdx = list()
62 | 
63 | for slot in label_slot_order:
64 | if slot in labelDict and labelDict[slot] in ontologyDict['informable'][slot]:
65 | resIdx.append(ontologyDict['informable'][slot].index(labelDict[slot]))
66 | else:
67 | # the max index is for the special value: "none"
68 | resIdx.append(len(ontologyDict['informable'][slot]))
69 | 
70 | resIdx.append(ontologyDict['method'].index(method))
71 | 
72 | reqVec = [0.0] * len(ontologyDict['requestable'])
73 | for req in reqList:
74 | reqVec[ontologyDict['requestable'].index(req)] = 1
75 | resIdx.append(reqVec)
76 | 
77 | return resIdx
78 | 
79 | def genTurnData_nbest(turn, labelJson):
80 | turnData = dict()
81 | 
82 | # process user_input : exp scores
83 | user_input = turn["input"]["live"]["asr-hyps"]
84 | ssum=0.
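# The asr-hyp scores in the input logs are treated as log-probabilities; the two loops below
# exponentiate and renormalise them so that each turn's n-best scores form a probability
# distribution:
#     score_i <- exp(score_i) / sum_j exp(score_j)
# Small worked example (made-up numbers): log-scores [-0.1, -2.3] become roughly [0.90, 0.10]
# after this normalisation.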
85 | for asr_pair in user_input:
86 | ssum+= math.exp(float(asr_pair['score']))
87 | 
88 | for asr_pair in user_input:
89 | asr_pair['score'] = math.exp(float(asr_pair['score']))/ssum
90 | 
91 | # process machine_output : replace un-informable value with tags
92 | machine_output = turn["output"]["dialog-acts"]
93 | for slot in replace_un_informable_slots :
94 | for act in machine_output:
95 | for pair in act["slots"]:
96 | if len(pair) >= 2 and pair[0] == slot:
97 | pair[1] = '<%s>' % slot
98 | 
99 | # generate labelIdx
100 | labelIdx = label2vec(labelJson['goal-labels'], labelJson['method-label'], labelJson['requested-slots'])
101 | 
102 | turnData["user_input"] = user_input
103 | turnData["machine_output"] = machine_output
104 | turnData["labelIdx"] = labelIdx
105 | return turnData
106 | 
107 | ###############################
108 | def tagTurnData(turnData, ontology):
109 | """Replace slot values in a single turn's data with tags."""
110 | tagged_turnData = copy.deepcopy(turnData)
111 | tag_dict = {}
112 | for slot in ["food", "name"]:
113 | val_ind = 1
114 | for slot_val in ontology["informable"][slot]:
115 | if slot_val.startswith("#%s"%slot):
116 | continue
117 | cur_tag = "#%s%d#" % (slot, val_ind)
118 | replace_flag = False
119 | 
120 | # process user_input
121 | for i in xrange(len(tagged_turnData["user_input"])):
122 | sentence = tagged_turnData["user_input"][i]['asr-hyp']
123 | tag_sentence = sentence.replace(slot_val, cur_tag)
124 | if tag_sentence != sentence:
125 | tagged_turnData["user_input"][i]['asr-hyp'] = tag_sentence
126 | tag_dict[cur_tag] = slot_val
127 | replace_flag = True
128 | 
129 | # process machine_output
130 | for act in tagged_turnData["machine_output"]:
131 | for pair in act["slots"]:
132 | if len(pair) >= 2 and pair[0] == slot and pair[1] == slot_val:
133 | pair[1] = cur_tag
134 | tag_dict[cur_tag] = slot_val
135 | replace_flag = True
136 | 
137 | if replace_flag:
138 | val_ind += 1
139 | if val_ind > max_tag_index: # NOTE: max_tag_index is commented out at the top of this file; it must be defined before tagTurnData is called
140 | break
141 | 
142 | # process labelIdx
143 | val_ind_dict = {ontology["informable"][slot].index(v):ontology["informable"][slot].index(k)
144 | for k, v in tag_dict.items() if k.startswith("#%s"%slot)}
145 | labelIdx_ind = label_slot_order.index(slot)
146 | labelIdx = tagged_turnData["labelIdx"][labelIdx_ind]
147 | if labelIdx in val_ind_dict:
148 | tagged_turnData["labelIdx"][labelIdx_ind] = val_ind_dict[labelIdx]
149 | 
150 | # add tag_dict to tagged_turnData
151 | tagged_turnData["tag_dict"] = tag_dict
152 | 
153 | 
154 | return tagged_turnData
155 | 
156 | 
157 | 
158 | ###################################
159 | def tagTurnData_matFS(turnData, ontology):
160 | """Replace slot values in a single turn's data with tags."""
161 | tagged_turnData = copy.deepcopy(turnData)
162 | tag_dict = {}
163 | for slot in ["food", "name"]:
164 | # val_ind = 1
165 | for slot_val in ontology["informable"][slot]:
166 | if slot_val.startswith("#%s"%slot):
167 | continue
168 | cur_tag = "#%s#" % (slot,)
169 | replace_flag = False
170 | 
171 | # process user_input
172 | for i in xrange(len(tagged_turnData["user_input"])):
173 | sentence = tagged_turnData["user_input"][i]['asr-hyp']
174 | tag_sentence = sentence.replace(slot_val, cur_tag)
175 | if tag_sentence != sentence:
176 | tagged_turnData["user_input"][i]['asr-hyp'] = tag_sentence
177 | tag_dict[cur_tag] = slot_val
178 | replace_flag = True
179 | 
180 | # process machine_output
181 | for act in tagged_turnData["machine_output"]:
182 | for pair in act["slots"]:
183 | if len(pair) >= 2 and pair[0] == slot and pair[1] == slot_val:
184 | pair[0] = "#slot#"
185 | pair[1] = cur_tag
186 |
tag_dict[cur_tag] = slot_val 187 | replace_flag = True 188 | 189 | #if replace_flag: 190 | # val_ind += 1 191 | #if val_ind > max_tag_index: 192 | # break 193 | 194 | for i in xrange(len(tagged_turnData["user_input"])): 195 | sentence = tagged_turnData["user_input"][i]['asr-hyp'] 196 | tag_sentence = sentence.replace(slot, "#slot#") 197 | 198 | # # process labelIdx 199 | # val_ind_dict = {ontology["informable"][slot].index(v):ontology["informable"][slot].index(k) 200 | # for k, v in tag_dict.items() if k.startswith("#%s"%slot)} 201 | # labelIdx_ind = label_slot_order.index(slot) 202 | # labelIdx = tagged_turnData["labelIdx"][labelIdx_ind] 203 | # if labelIdx in val_ind_dict: 204 | # tagged_turnData["labelIdx"][labelIdx_ind] = val_ind_dict[labelIdx] 205 | 206 | # # add tag_dict to tagged_turnData 207 | # tagged_turnData["tag_dict"] = tag_dict 208 | 209 | return tagged_turnData 210 | 211 | def genTurnData_nbest_tagged(turn, labelJson): 212 | turnData = genTurnData_nbest(turn, labelJson) 213 | turnData = tagTurnData_matFS(turnData, ontologyDict) 214 | return turnData 215 | 216 | # def genTurnData_matFV(turn,labelJson): 217 | # tagged_turnData =genTurnData_nbest(turn, labelJson) 218 | # tag_dict = {} 219 | # for slot in ["food", "name"]: 220 | # val_ind = 1 221 | # for slot_val in ontology["informable"][slot]: 222 | # if slot_val.startswith("#%s"%slot): 223 | # continue 224 | # cur_tag = "#value%d#" % (val_ind) 225 | # replace_flag = False 226 | 227 | # # process user_input 228 | # for i in xrange(len(tagged_turnData["user_input"])): 229 | # sentence = tagged_turnData["user_input"][i]['asr-hyp'] 230 | # tag_sentence = sentence.replace(slot_val, cur_tag) 231 | # if tag_sentence != sentence: 232 | # tagged_turnData["user_input"][i]['asr-hyp'] = tag_sentence 233 | # tag_dict[cur_tag] = slot_val 234 | # replace_flag = True 235 | 236 | # # process machine_output 237 | # for act in tagged_turnData["machine_output"]: 238 | # for pair in act["slots"]: 239 | # if len(pair) >= 2 and pair[0] == slot and pair[1] == slot_val: 240 | # pair[1] = cur_tag 241 | # tag_dict[cur_tag] = slot_val 242 | # replace_flag = True 243 | 244 | # if replace_flag: 245 | # val_ind += 1 246 | # if val_ind > max_tag_index: 247 | # break 248 | 249 | # # process labelIdx 250 | # val_ind_dict = {ontology["informable"][slot].index(v):ontology["informable"][slot].index(k) 251 | # for k, v in tag_dict.items() if k.startswith("#%s"%slot)} 252 | # labelIdx_ind = label_slot_order.index(slot) 253 | # labelIdx = tagged_turnData["labelIdx"][labelIdx_ind] 254 | # if labelIdx in val_ind_dict: 255 | # tagged_turnData["labelIdx"][labelIdx_ind] = val_ind_dict[labelIdx] 256 | 257 | # # add tag_dict to tagged_turnData 258 | # tagged_turnData["tag_dict"] = tag_dict 259 | 260 | def gen_resdata(dataset,output_type): 261 | 262 | def gen_data(func_genTurnData): 263 | data = [] 264 | for call in dataset: 265 | fileData = dict() 266 | fileData["session-id"] = call.log["session-id"] 267 | fileData["turns"] = list() 268 | #print {"session-id":call.log["session-id"]} 269 | for turn, labelJson in call: 270 | turnData = func_genTurnData(turn, labelJson) 271 | fileData["turns"].append(turnData) 272 | data.append(fileData) 273 | return data 274 | # different output type 275 | if output_type == 'nbest': 276 | res_data = gen_data(genTurnData_nbest) 277 | elif output_type == 'nbest_tagged': 278 | data = [] 279 | res_data1 = gen_data(genTurnData_nbest) 280 | data.append(res_data1) 281 | res_data2 = gen_data(genTurnData_nbest_tagged) 282 | data.append(res_data2) 
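# Note: with output_type 'nbest_tagged' the result written to json is a two-element list,
# data[0] holding the untagged dialogues and data[1] the tagged ones; readers such as
# mat_io.turn_read_content(path, labelIdx, dataIdx) then pick one of the two with dataIdx.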
283 | res_data = data 284 | return res_data 285 | 286 | 287 | 288 | 289 | 290 | # return tagged_turnData 291 | 292 | # ################################## 293 | def main(): 294 | parser = argparse.ArgumentParser(description='Simple hand-crafted dialog state tracker baseline.') 295 | parser.add_argument('--dataset', dest='dataset', action='store', metavar='DATASET', required=True, 296 | help='The dataset to analyze') 297 | parser.add_argument('--dataroot',dest='dataroot',action='store',required=True,metavar='PATH', 298 | help='Will look for corpus in //...') 299 | parser.add_argument('--output_type',dest='output_type',action='store',default='nbest', 300 | help='the type of output json') 301 | args = parser.parse_args() 302 | dataset = dataset_walker.dataset_walker(args.dataset, dataroot=args.dataroot, labels=True) 303 | 304 | 305 | 306 | res_data=gen_resdata(dataset,args.output_type) 307 | # write to json file 308 | file_prefix = args.dataset.split('_')[-1] 309 | res_file = "%s_%s.json" % (file_prefix, args.output_type) 310 | with open(res_file, "w") as fw: 311 | fw.write(json.dumps(res_data, indent=2)) 312 | 313 | 314 | if __name__ == '__main__': 315 | start_time = time.time() 316 | main() 317 | end_time = time.time() 318 | print 'time: ', end_time - start_time, 's' 319 | -------------------------------------------------------------------------------- /mat_io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import copy 4 | import json 5 | import math 6 | import os 7 | import numpy as np 8 | import mxnet as mx 9 | 10 | from bucket_io import SimpleBatch 11 | from bucket_io import default_text2id 12 | 13 | # The interface of a data iter that works for bucketing 14 | # 15 | # DataIter 16 | # - default_bucket_key: the bucket key for the default symbol. 
17 | # 18 | # DataBatch 19 | # - provide_data: same as DataIter, but specific to this batch 20 | # - provide_label: same as DataIter, but specific to this batch 21 | # - bucket_key: the key for the bucket that should be used for this batch 22 | from mat_data import ontologyDict 23 | 24 | 25 | def read_nbest_dialog_content(dialog, labelIdx): 26 | """ 27 | dialog_sentences: (turn_num, nbest_num, sentence_len) 28 | dialog_scores: (turn_num, nbest_num) 29 | machine_acts: (turn_num, machine_act_len) 30 | dialog_labels: (turn_num, ) 31 | """ 32 | dialog_sentences, dialog_scores, machine_acts, dialog_labels = [], [], [], [] 33 | for turn in dialog["turns"]: 34 | dialog_labels.append(turn["labelIdx"][labelIdx]) 35 | 36 | machine_act = "" 37 | for saPair in turn["machine_output"]: 38 | act = saPair["act"] 39 | slots = " " 40 | for slot in saPair["slots"]: 41 | #count never appears in train/dev set# 42 | if "count" in slot: 43 | #slot[1] = str(slot[1]) 44 | continue 45 | slots += " ".join(slot) 46 | slots += " " 47 | machine_act_item=(act+slots) 48 | machine_act += machine_act_item 49 | machine_act = machine_act.strip() 50 | machine_acts.append(machine_act) 51 | 52 | nbest_sentences = [] 53 | nbest_scores = [] 54 | for asr_hyp in turn["user_input"]: 55 | if len(asr_hyp["asr-hyp"].split()) == 0: 56 | continue 57 | nbest_scores.append(asr_hyp["score"]) 58 | sentence = "" 59 | #sentence +=" #turn# " 60 | sentence += asr_hyp["asr-hyp"] 61 | #sentence += " " 62 | nbest_sentences.append(sentence) 63 | dialog_sentences.append(nbest_sentences) 64 | dialog_scores.append(nbest_scores) 65 | 66 | return dialog_sentences, dialog_scores, machine_acts, dialog_labels 67 | 68 | def turn_read_content(path, labelIdx, dataIdx): 69 | """ 70 | sentences: (dialog_num, turn_num, nbest_num, sentence_len) 71 | scores: (dialog_num, turn_num, nbest_num) 72 | acts: (dialog_num, turn_num, machine_act_len) 73 | labels: (dialog_num, turn_num, [label_dim]) 74 | """ 75 | sentences, scores, acts, labels = [], [], [], [] 76 | with open(path) as json_file: 77 | data = json.load(json_file) 78 | #print data["data"][dataIdx] 79 | for dialog in data[dataIdx]: 80 | dialog_sentences, dialog_scores, machine_acts, dialog_labels = read_nbest_dialog_content(dialog, labelIdx) 81 | sentences.append(dialog_sentences) 82 | scores.append(dialog_scores) 83 | acts.append(machine_acts) 84 | labels.append(dialog_labels) 85 | return sentences, scores, acts, labels 86 | 87 | def text2bow(sentence, the_vocab): 88 | res = np.zeros(len(the_vocab)) 89 | for word in the_vocab: 90 | if word in sentence: 91 | res[the_vocab[word]] = 1 92 | # words = sentence.split() 93 | # for word in words: 94 | # if word in the_vocab: 95 | # res[the_vocab[word]] = 1 96 | return res 97 | 98 | class MATTurnSentIter(mx.io.DataIter): 99 | """ 100 | feature_type,['bowbow', 'sentsent', 'bowsent', 'sentbow'] 101 | """ 102 | def __init__(self, path, labelIdx,vocab, vocab1, buckets, batch_size, max_nbest, max_sentlen, 103 | init_states, data_components, label_out=1, feature_type='bowbow'): 104 | super(MATTurnSentIter, self).__init__() 105 | self.vocab = vocab 106 | self.vocab1=vocab1 107 | self.padding_id = self.vocab[''] 108 | 109 | self.label_out = label_out 110 | 111 | self.max_nbest = max_nbest 112 | self.max_sentlen = max_sentlen 113 | self.feature_type = feature_type 114 | self.len_sent = self.max_sentlen if self.feature_type in ['sentsent', 'sentbow'] else len(self.vocab) 115 | self.len_act_sent = self.max_sentlen if self.feature_type in ['sentsent', 'bowsent'] else 
len(self.vocab1) 116 | 117 | sentences, scores, acts, labels = turn_read_content(path, labelIdx[0],0) 118 | #sentences1, scores1, acts1, labels1 = turn_read_content(path, labelIdx[0],1) 119 | 120 | 121 | lab=[] 122 | for i in labelIdx: 123 | se,sc,ac,l=turn_read_content(path, i,0) 124 | lab.append(l) 125 | 126 | labels0=[] 127 | for i in range(len(labels)): 128 | d0=[] 129 | for j in range(len(labels[i])): 130 | ll=[] 131 | for lb in lab: 132 | ll.append(lb[i][j]) 133 | d0.append(ll) 134 | labels0.append(d0) 135 | labels=labels0 136 | 137 | 138 | """ 139 | sentences: (dialog_num, turn_num, nbest_num, sentence_len) 140 | scores: (dialog_num, turn_num, nbest_num) 141 | acts: (dialog_num, turn_num, machine_act_len) 142 | labels: (dialog_num, turn_num, ) 143 | """ 144 | 145 | """ 146 | new 147 | sentences: (dialog_num, turn_num, 2, nbest_num, sentence_len) 148 | scores: (dialog_num, turn_num, nbest_num) 149 | acts: (dialog_num, turn_num, 2 , machine_act_len) 150 | labels: (dialog_num, turn_num,4 ) 151 | """ 152 | 153 | buckets.sort() 154 | self.buckets = buckets 155 | self.data = [[] for _ in buckets] 156 | self.data_act = [[] for _ in buckets] 157 | self.data_score = [[] for _ in buckets] 158 | self.label = [[] for _ in buckets] 159 | 160 | # pre-allocate with the largest bucket for better memory sharing 161 | self.default_bucket_key = max(buckets) 162 | 163 | for i in range(len(sentences)): 164 | sentence = sentences[i] 165 | score = scores[i] 166 | act = acts[i] 167 | label = labels[i] 168 | for turn_id in range(len(sentence)): 169 | # user sentence feature 170 | #for i in range(2): 171 | for nbest_id in range(len(sentence[turn_id])): 172 | if self.feature_type in ['sentsent', 'sentbow']: 173 | sentence[turn_id][nbest_id] = default_text2id(sentence[turn_id][nbest_id], self.vocab) 174 | elif self.feature_type in ['bowsent', 'bowbow']: 175 | sentence[turn_id][nbest_id] = text2bow(sentence[turn_id][nbest_id], self.vocab) 176 | # sys act feature 177 | if self.feature_type in ['sentbow', 'bowbow']: 178 | act[turn_id] = text2bow(act[turn_id], self.vocab1) 179 | elif self.feature_type in ['sentsent', 'bowsent']: 180 | act[turn_id] = default_text2id(act[turn_id], self.vocab1) 181 | for i, bkt in enumerate(buckets): 182 | if bkt == len(sentence): 183 | self.data[i].append(sentence) 184 | self.data_score[i].append(score) 185 | self.data_act[i].append(act) 186 | self.label[i].append(label) 187 | break 188 | """ 189 | sentence: (turn_num, nbest_num, len_sent) 190 | score: (turn_num, nbest_num) 191 | act: (turn_num, len_act_sent) 192 | label: (turn_num, label_out) 193 | """ 194 | # we just ignore the sentence it is longer than the maximum 195 | # bucket size here 196 | 197 | embed_weight = mx.nd.array(np.load('embed_vN3.npy')) 198 | slotsent="food pricerange name area" 199 | slota=default_text2id(slotsent, self.vocab) 200 | slotarr=slotsent.split() 201 | #print slota 202 | label_len=len(labelIdx) 203 | val_len=[] 204 | for i in labelIdx: 205 | val_len.append(len(ontologyDict[u'informable'][slotarr[i]])) 206 | 207 | vl=[] 208 | for i in labelIdx: 209 | vla=[] 210 | for key in ontologyDict[u'informable'][slotarr[i]]: 211 | #print key 212 | v=default_text2id(key,self.vocab) 213 | tmp=mx.nd.array(v) 214 | tmp= mx.nd.Embedding(data=tmp, input_dim=len(self.vocab), weight=embed_weight, output_dim=300, name='embed') 215 | tmp=mx.nd.sum(tmp,axis=0) 216 | v=tmp.asnumpy() 217 | vla.append(v) 218 | vla=np.asarray(vla) 219 | vl.append(vla) 220 | #vl=np.asarray(vl) 221 | #print vl 222 | #print len(vl) 223 | 224 | 
sa=[] 225 | for i in labelIdx: 226 | tmp=mx.nd.array([slota[i]]) 227 | tmp= mx.nd.Embedding(data=tmp, input_dim=len(self.vocab), weight=embed_weight, output_dim=300, name='embed') 228 | sa.append(tmp.asnumpy()) 229 | slota=np.squeeze(np.asarray(sa)) 230 | 231 | 232 | slot=np.zeros((batch_size,label_len,300)) 233 | for i in range(batch_size): 234 | slot[i]=slota 235 | 236 | value=[] 237 | for j in range(label_len): 238 | tmp=np.zeros((batch_size,val_len[j],300)) 239 | for i in range(batch_size): 240 | tmp[i]=vl[j] 241 | value.append(tmp) 242 | 243 | vl_name=[] 244 | for i in range(label_len): 245 | vl_name.append("value_%d"%i) 246 | self.vl_name=vl_name 247 | 248 | 249 | # convert data into ndarrays for better speed during training 250 | data_mask_len =[np.zeros((len(x), )) for i, x in enumerate(self.data)] 251 | datatmp = [np.full((len(x), buckets[i],self.max_nbest,self.len_sent), self.padding_id) for i, x in enumerate(self.data)] 252 | data_act = [np.full((len(x), buckets[i],self.len_act_sent), 0.0) for i, x in enumerate(self.data_act)] 253 | data_score =[np.zeros((len(x), buckets[i], self.max_nbest)) for i, x in enumerate(self.data_score)] 254 | label = [np.zeros((len(x), buckets[i], self.label_out)) for i, x in enumerate(self.label)] 255 | data =[np.zeros((len(x), buckets[i],self.len_sent,300)) for i, x in enumerate(self.data)] 256 | 257 | #slot =[np.zeros((len(x), buckets[i],300),dtype=np.float32) for i, x in enumerate(self.data)] 258 | for i_bucket in range(len(self.buckets)): 259 | for i_diag in range(len(self.data[i_bucket])): 260 | data_mask_len[i_bucket][i_diag]=len(self.data[i_bucket][i_diag]) 261 | for i_turn in range(len(self.data[i_bucket][i_diag])): 262 | 263 | act = self.data_act[i_bucket][i_diag][i_turn] 264 | #for i in range(2): 265 | data_act[i_bucket][i_diag, i_turn, :len(act)] = act 266 | label[i_bucket][i_diag, i_turn, :] = self.label[i_bucket][i_diag][i_turn] 267 | # be careful that, here, max_nbest can be smaller than current turn nbest number. extra-best will be truncated. 
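# Note on the slot/value features prepared above: each informable slot and each candidate
# value is represented by a single 300-d vector obtained by summing the word embeddings of
# its words (illustrative example: a two-word value such as "modern european" becomes
# embed("modern") + embed("european")), and these vectors are tiled over the batch as
# slot: (batch_size, label_len, 300) and value_i: (batch_size, val_len[i], 300), fed to the
# network under the data names "slot" and "value_%d".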
268 | #for i_data in range(2): 269 | tempsent=[] 270 | for i_nbest in range(min(len(self.data[i_bucket][i_diag][i_turn]), self.max_nbest)): 271 | sentence = self.data[i_bucket][i_diag][i_turn][i_nbest] 272 | datatmp[i_bucket][i_diag, i_turn,i_nbest, :len(sentence)] = sentence 273 | tmp=mx.nd.array(datatmp[i_bucket][i_diag, i_turn,i_nbest]) 274 | tmp= mx.nd.Embedding(data=tmp, input_dim=len(self.vocab), weight=embed_weight, output_dim=300, name='embed') 275 | sentence=tmp.asnumpy() 276 | score = self.data_score[i_bucket][i_diag][i_turn][i_nbest] 277 | #preprocess 278 | sent =sentence*score 279 | #if i_nbest ==0: 280 | tempsent.append(sent) 281 | data_score[i_bucket][i_diag, i_turn, i_nbest] = score 282 | tempsent=np.asarray(tempsent) 283 | scoredsent=np.sum(tempsent,axis=0) 284 | #scoredsent=scoredsent*2-1 285 | data[i_bucket][i_diag, i_turn] = scoredsent 286 | 287 | """ 288 | data: (bucket_num, dialog_num, bucket_size/turn_num, max_nbest, len_sent) 289 | score: (bucket_num, dialog_num, bucket_size/turn_num, max_nbest) 290 | data_act: (bucket_num, dialog_num, bucket_size/turn_num, len_act_sent) 291 | label: (bucket_num, dialog_num, bucket_size/turn_num, label_out) 292 | """ 293 | self.data_mask_len=data_mask_len 294 | self.data = data 295 | self.data_act = data_act 296 | self.data_score = data_score 297 | self.label = label 298 | 299 | self.slot=slot 300 | 301 | self.value=value 302 | 303 | # backup corpus 304 | self.all_data_mask_len = copy.deepcopy(self.data_mask_len) 305 | self.all_data = copy.deepcopy(self.data) 306 | self.all_data_act = copy.deepcopy(self.data_act) 307 | self.all_data_score = copy.deepcopy(self.data_score) 308 | self.all_label = copy.deepcopy(self.label) 309 | 310 | # Get the size of each bucket, so that we could sample 311 | # uniformly from the bucket 312 | sizeS=0 313 | bucket_sizes = [len(x) for x in self.data] 314 | print("Summary of dataset ==================") 315 | for bkt, size in zip(buckets, bucket_sizes): 316 | sizeS+=size 317 | print("bucket of len %3d : %d samples" % (bkt, size)) 318 | 319 | self.batch_size = batch_size 320 | #self.make_data_iter_plan() 321 | 322 | self.init_states = init_states 323 | self.data_components = data_components 324 | self.size=int(sizeS/batch_size) 325 | self.provide_data = self.data_components + self.init_states 326 | 327 | 328 | def make_data_iter_plan(self): 329 | "make a random data iteration plan" 330 | # truncate each bucket into multiple of batch-size 331 | bucket_n_batches = [] 332 | for i in range(len(self.data)): 333 | # shuffle data before truncate 334 | index_shuffle = range(len(self.data[i])) 335 | np.random.shuffle(index_shuffle) 336 | self.data[i] = self.all_data[i][index_shuffle] 337 | self.data_mask_len[i] = self.all_data_mask_len[i][index_shuffle] 338 | self.data_act[i] = self.all_data_act[i][index_shuffle] 339 | self.data_score[i] = self.all_data_score[i][index_shuffle] 340 | self.label[i] = self.all_label[i][index_shuffle] 341 | 342 | bucket_n_batches.append(int(math.ceil(1.0*len(self.data[i]) / self.batch_size))) 343 | self.data[i] = self.data[i][:int(bucket_n_batches[i]*self.batch_size)] 344 | self.data_mask_len[i] = self.data_mask_len[i][:int(bucket_n_batches[i]*self.batch_size)] 345 | self.data_act[i] = self.data_act[i][:int(bucket_n_batches[i]*self.batch_size)] 346 | self.data_score[i] = self.data_score[i][:int(bucket_n_batches[i]*self.batch_size)] 347 | self.label[i] = self.label[i][:int(bucket_n_batches[i]*self.batch_size)] 348 | 349 | bucket_plan = np.hstack([np.zeros(n, int)+i for i, n in 
enumerate(bucket_n_batches)]) 350 | np.random.shuffle(bucket_plan) 351 | 352 | bucket_idx_all = [np.random.permutation(len(x)) for x in self.data] 353 | 354 | self.bucket_plan = bucket_plan 355 | self.bucket_idx_all = bucket_idx_all 356 | self.bucket_curr_idx = [0 for x in self.data] 357 | 358 | self.data_buffer = [] 359 | self.data_mask_len_buffer = [] 360 | self.data_act_buffer = [] 361 | self.data_score_buffer = [] 362 | self.label_buffer = [] 363 | for i_bucket in range(len(self.data)): 364 | data = np.zeros((self.batch_size, self.buckets[i_bucket], self.len_sent,300)) 365 | data_mask_len=np.zeros((self.batch_size,)) 366 | data_act = np.zeros((self.batch_size, self.buckets[i_bucket], self.len_act_sent)) 367 | data_score = np.zeros((self.batch_size, self.buckets[i_bucket], self.max_nbest)) 368 | label = np.zeros((self.batch_size, self.buckets[i_bucket], self.label_out)) 369 | self.data_buffer.append(data) 370 | self.data_mask_len_buffer.append(data_mask_len) 371 | self.data_act_buffer.append(data_act) 372 | self.data_score_buffer.append(data_score) 373 | self.label_buffer.append(label) 374 | 375 | def __iter__(self): 376 | self.make_data_iter_plan() 377 | for i_bucket in self.bucket_plan: 378 | data = self.data_buffer[i_bucket] 379 | data_mask_len = self.data_mask_len_buffer[i_bucket] 380 | data_act = self.data_act_buffer[i_bucket] 381 | data_score = self.data_score_buffer[i_bucket] 382 | label = self.label_buffer[i_bucket] 383 | data.fill(self.padding_id) 384 | data_act.fill(self.padding_id) 385 | data_score.fill(0) 386 | label.fill(0) 387 | 388 | i_idx = self.bucket_curr_idx[i_bucket] 389 | idx = self.bucket_idx_all[i_bucket][i_idx:i_idx+self.batch_size] 390 | self.bucket_curr_idx[i_bucket] += self.batch_size 391 | 392 | # Data parallelism 393 | data[:len(idx)] = self.data[i_bucket][idx] 394 | data_mask_len[:len(idx)] = self.data_mask_len[i_bucket][idx] 395 | data_act[:len(idx)] = self.data_act[i_bucket][idx] 396 | data_score[:len(idx)] = self.data_score[i_bucket][idx] 397 | label[:len(idx)] = self.label[i_bucket][idx] 398 | 399 | data_names = [x[0] for x in self.provide_data] 400 | init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states] 401 | data_all = [mx.nd.array(data), mx.nd.array(data_act)] 402 | if 'score' in data_names: 403 | data_all += [mx.nd.array(data_score)] 404 | if 'data_mask_len' in data_names: 405 | data_all += [mx.nd.array(data_mask_len)] 406 | 407 | if 'slot' in data_names: 408 | data_all += [mx.nd.array(self.slot)] 409 | 410 | label_names = ['softmax_label'] 411 | 412 | label_all = [mx.nd.array(label)] 413 | #labels=mx.nd.split(mx.nd.array(label),axis=-1,num_outputs=self.label_out,squeeze_axis=1) 414 | 415 | #data_batch=[] 416 | for i in range(self.label_out): 417 | if self.vl_name[i] in data_names: 418 | data_all += [mx.nd.array(self.value[i])] 419 | 420 | data_all += init_state_arrays 421 | 422 | 423 | pad = self.batch_size - len(idx) 424 | 425 | data_batch=SimpleBatch(data_names, data_all, label_names, label_all, self.buckets[i_bucket], pad) 426 | yield data_batch 427 | 428 | def reset(self): 429 | self.bucket_curr_idx = [0 for x in self.data] 430 | -------------------------------------------------------------------------------- /mod_lectrack.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import sys 4 | import copy 5 | import math 6 | import os 7 | import json 8 | 9 | from lectrack import LecTrack 10 | from bucket_io import read_1best_dialog_content, default_text2id 11 | from 
mat_io import read_nbest_dialog_content, text2bow
12 | from turnbow_io import nbest_text2bow
13 | from mat_data import genTurnData_nbest
14 | 
15 | from mat_data import vocab, vocab1, ontologyDict
16 | 
17 | # directory containing this script
18 | cur_dir = os.path.dirname(os.path.abspath(__file__))
19 | 
20 | 
21 | class ModTracker(object):
22 | """ModTracker implementation with LecTrack
23 | ------
24 | config_dict {
25 | nn_type: string, "lstm" or "blstm"
26 | for_train: boolean, whether to train the models at run time
27 | pre_train: boolean, whether to load pre-trained models
28 | model_dir: string, dir of pre_train models
29 | }
30 | """
31 | def __init__(self, config_dict=None):
32 | self.ontology = ontologyDict
33 | 
34 | 
35 | # configuration
36 | config_dict = config_dict or {}
37 | self.context_type = config_dict.get('context_type', 'cpu')
38 | self.nn_type = config_dict.get('nn_type', "lstm")
39 | self.batch_size = config_dict.get('batch_size', 32)
40 | self.for_train = config_dict.get('for_train', False)
41 | self.pre_train = config_dict.get('pre_train', True)
42 | self.model_dir = config_dict.get('model_dir', os.path.join(cur_dir, 'models'))
43 | 
44 | # ##################################
45 | self.label_index_list = ['goal_food', 'goal_pricerange', 'goal_name', 'goal_area']
46 | self.num_label_dict = {
47 | 'goal_food': len(self.ontology["informable"]["food"]),
48 | 'goal_pricerange': len(self.ontology["informable"]["pricerange"]),
49 | 'goal_name': len(self.ontology["informable"]["name"]),
50 | 'goal_area': len(self.ontology["informable"]["area"]),
51 | #'method': len(self.ontology["method"]),
52 | #'requested': len(self.ontology["requestable"])
53 | }
54 | 
55 | # ##################################
56 | self.track_dict = {}
57 | self.vocab = {}
58 | for label_name, num_label in self.num_label_dict.items():
59 | #if label_name=='goal_food' or label_name=='goal_name':
60 | self.vocab[label_name]= vocab1
61 | #else:
62 | # self.vocab[label_name]= vocab2
63 | lectrack_config_dict = {
64 | "input_size": len(self.vocab[label_name]),
65 | "num_label": num_label,
66 | "output_type": 'sigmoid' if label_name == 'requested' else 'softmax',
67 | "nn_type": self.nn_type,
68 | "batch_size": self.batch_size,
69 | "context_type": self.context_type,
70 | }
71 | print label_name
72 | print lectrack_config_dict["input_size"]
73 | self.track_dict[label_name] = LecTrack(lectrack_config_dict)
74 | 
75 | # load pre-trained params if needed
76 | if self.pre_train == True:
77 | for label_index, label_name in enumerate(self.label_index_list):
78 | # TODO
79 | 
80 | #params_path = os.path.join(self.model_dir, 'labelIdx-%d-final.params' % label_index)
81 | params_path = os.path.join(self.model_dir, 'labelIdx-%d-testbest.params' % label_index)
82 | 
83 | print label_index,label_name
84 | self.track_dict[label_name].load_params(params_path)
85 | 
86 | # about dialogue history
87 | self.history_index = 0
88 | self.history_label = {
89 | "turns": []
90 | }
91 | self.history_log = {
92 | "turns": []
93 | }
94 | # reset
95 | self.reset()
96 | 
97 | def reset(self):
98 | """Reset/initialize the member variables."""
99 | self.turn = 0
100 | 
101 | # TODO CALL ONLY WHEN NEEDED!!! save current dialogue history to file
102 | #self._saveHistory()
103 | 
104 | if len(self.history_log["turns"]) > 0:
105 | self.history_index += 1
106 | self.history_label = {"turns": []}
107 | self.history_log = {"turns": []}
108 | 
109 | def _saveHistory(self):
110 | """Save the current dialogue's history to file (call this only when needed)."""
111 | save_dialog_dir = os.path.join(cur_dir, 'dst_history')
112 | if not os.path.exists(save_dialog_dir):
113 | os.makedirs(save_dialog_dir)
114 | if len(self.history_log["turns"]) > 0:
115 | label_path = os.path.join(save_dialog_dir, 'label-%d.json' % self.history_index)
116 | json.dump(self.history_label, open(label_path, 'w'), indent=4)
117 | log_path = os.path.join(save_dialog_dir, 'log-%d.json' % self.history_index)
118 | json.dump(self.history_log, open(log_path, 'w'), indent=4)
119 | 
120 | def _updateHistory(self, dm_output, asr_output, us_live_goal):
121 | """Update the history of the current dialogue (same session)."""
122 | # transfer us_live_goal to specific format if needed
123 | tmp_us_live_goal = copy.deepcopy(us_live_goal)
124 | if isinstance(tmp_us_live_goal["requested-slots"], dict):
125 | tmp_us_live_goal["requested-slots"] = []
126 | # update history_label
127 | self.history_label["turns"].append(tmp_us_live_goal)
128 | 
129 | # transfer dm_output format to specific format if needed
130 | log_output = {}
131 | if "dialog-acts" in dm_output:
132 | log_output = dm_output
133 | else:
134 | tmp_dm_output = []
135 | for act in dm_output:
136 | new_act = {
137 | "act": act["act_type"],
138 | "slots": []
139 | }
140 | if "slot_name" in act and "slot_val" in act:
141 | new_act["slots"].append([act["slot_name"], act["slot_val"]])
142 | elif "slot_name" in act:
143 | new_act["slots"].append(["slot", act["slot_name"]])
144 | tmp_dm_output.append(new_act)
145 | log_output["dialog-acts"] = tmp_dm_output
146 | 
147 | # transfer asr_output score to log-score if needed
148 | log_input = {}
149 | if "live" in asr_output:
150 | log_input = asr_output
151 | else:
152 | tmp_asr_output = []
153 | for hyp in asr_output["gen+asr"]:
154 | tmp_asr_output.append({
155 | "asr-hyp": hyp["asr-hyp"],
156 | "score": math.log(hyp['score'])
157 | })
158 | log_input = {
159 | "live": {
160 | "asr-hyps": tmp_asr_output
161 | }
162 | }
163 | 
164 | # update history_log
165 | self.history_log["turns"].append({
166 | "input": log_input,
167 | "output": log_output
168 | })
169 | 
170 | def _updateState(self, cur_state, cur_outputs, label_name, top_n=sys.maxint):
171 | """Update cur_state in place according to the LSTM outputs and the current label_name; a call to this function must not overwrite the values of other, unrelated labels."""
172 | def float_floor(num):
173 | return float('%0.4f'%(num-0.00005 if num-0.00005>0 else 0.0))
174 | 
175 | key_list = label_name.split('_')
176 | if key_list[0] == 'requested':
177 | cur_state["requested-slots"] = {}
178 | for i in xrange(len(cur_outputs)-1):
179 | slot_name = self.ontology["requestable"][i]
180 | cur_state["requested-slots"][slot_name] = float_floor(cur_outputs[i])
181 | #cur_state["requested-slots"][slot_name] = cur_outputs[i]
182 | elif key_list[0] == 'method':
183 | cur_state["method-label"] = {}
184 | for i in xrange(len(cur_outputs)):
185 | tmp_key = self.ontology["method"][i]
186 | cur_state["method-label"][tmp_key] = float_floor(cur_outputs[i])
187 | elif key_list[0] == 'goal':
188 | # keep only the top-n highest-probability values plus "none" as candidate values
189 | cur_pairs = [(cur_outputs[i], i) for i in xrange(len(cur_outputs))]
190 | max_part = sorted(cur_pairs[:-1], key=lambda x:x[0], reverse=True)[:top_n]
191 | max_part += [cur_pairs[-1]]
192 | tmp_sum = sum([p[0] for p in max_part])
193 | max_part = [(float_floor(p[0]/tmp_sum), p[1]) for p in max_part]
194 | slot_name = key_list[1]
195 | cur_state["goal-labels"][slot_name] = {}
196 | for p in max_part:
197 | tmp_key = self.ontology["informable"][slot_name][p[1]] if p[1] != len(cur_outputs)-1 else 'none'
198 | if p[0] > 0.0:
199 | cur_state["goal-labels"][slot_name][tmp_key] = p[0]
200 | 
201 | def get_new_state(self, dm_output, asr_output, pre_state=None, us_live_goal=None):
202 | """[For online use] Generate a new state. Calling this method changes the state of self, i.e. the current input turn is treated as a continuation of the previous turns."""
203 | self.turn += 1
204 | cur_state = {
205 | "goal-labels": {},
206 | "method-label": {
207 | "none": 1.0
208 | },
209 | "requested-slots": {}
210 | }
211 | self._updateHistory(dm_output, asr_output, us_live_goal)
212 | 
213 | # construct data format for generating DataBatch
214 | fileData = {}
215 | fileData["turns"] = []
216 | for i in xrange(len(self.history_label["turns"])):
217 | turnData = genTurnData_nbest(self.history_log["turns"][i], self.history_label["turns"][i])
218 | fileData["turns"].append(turnData)
219 | 
220 | # update state
221 | for label_index, label_name in enumerate(self.label_index_list):
222 | dialog_sentences, dialog_scores, dialog_labels = read_1best_dialog_content(fileData, label_index)
223 | cur_sentence, cur_score, cur_label = dialog_sentences[-1], dialog_scores[-1], dialog_labels[-1]
224 | 
225 | cur_sentence = default_text2id(cur_sentence, self.vocab[label_name])
226 | assert len(cur_sentence) > 0 and len(cur_sentence) == len(cur_score)
227 | 
228 | tmp_label_out = len(cur_label) if self.track_dict[label_name].output_type == 'sigmoid' else 1
229 | data_batch = self.track_dict[label_name].oneSentenceBatch(cur_sentence, cur_score, cur_label, tmp_label_out)
230 | cur_outputs = self.track_dict[label_name].predict(data_batch)[0]
231 | cur_outputs = cur_outputs[0].asnumpy()
232 | self._updateState(cur_state, cur_outputs, label_name, top_n=5)
233 | 
234 | # remove "signature" from requested_slots
235 | if "signature" in cur_state["requested-slots"]:
236 | del cur_state["requested-slots"]["signature"]
237 | 
238 | return cur_state
239 | 
240 | def get_batch_new_state(self, fileDatas):
241 | """[For offline use] Generate several new states at once; calling this method does not change the state of self (note the number of samples must not exceed LecTrack's batch_size).
242 | Each example in the batch is one dialogue whose turns are concatenated into a single long sentence, and the output is one prediction per example."""
243 | assert(len(fileDatas) <= self.batch_size)
244 | tracker_outputs = []
245 | for i in xrange(len(fileDatas)):
246 | tracker_outputs.append({
247 | "goal-labels": {},
248 | "method-label": {
249 | "none": 1.0
250 | },
251 | "requested-slots": {}
252 | })
253 | 
254 | for label_index, label_name in enumerate(self.label_index_list):
255 | sentences, scores, labels = [], [], []
256 | for fileData in fileDatas:
257 | dialog_sentences, dialog_scores, dialog_labels = read_1best_dialog_content(fileData, label_index)
258 | cur_sentence, cur_score, cur_label = dialog_sentences[-1], dialog_scores[-1], dialog_labels[-1]
259 | 
260 | cur_sentence = default_text2id(cur_sentence, self.vocab[label_name])
261 | assert len(cur_sentence) > 0 and len(cur_sentence) == len(cur_score)
262 | 
263 | sentences.append(cur_sentence)
264 | scores.append(cur_score)
265 | labels.append(cur_label)
266 | tmp_label_out = len(cur_label) if self.track_dict[label_name].output_type == 'sigmoid' else 1
267 | 
268 | data_batch = self.track_dict[label_name].multiWordSentBatch(sentences, scores, labels, tmp_label_out)
269 | outputs = self.track_dict[label_name].predict(data_batch)[0]
270 | 
271 | for i in xrange(len(tracker_outputs)):
272 | cur_outputs = outputs[i].asnumpy()
273 | self._updateState(tracker_outputs[i], cur_outputs, label_name, top_n=10)
274 | 
275 | # remove "signature" from requested_slots
276 | for i in xrange(len(tracker_outputs)):
277 | if "signature" in tracker_outputs[i]["requested-slots"]:
278 | del tracker_outputs[i]["requested-slots"]["signature"]
279 | 
280 | return tracker_outputs
281 | 
282 | def get_turn_batch_state(self, fileDatas, feature_type='bow'):
283 | """[For offline use] Generate several new states at once; calling this method does not change the state of self.
284 | Each example in the batch is one dialogue made up of individual turns, each turn being one sentence, and the output is one prediction for every turn of every example.
285 | Parameters:
286 | feature_type, only effective when level=turn, one of ['bow', 'sentsent', 'sentbow'], indicating how the turn-level features are generated"""
287 | tracker_outputs = []
288 | for i in xrange(len(fileDatas[0])):
289 | tracker_outputs.append([])
290 | for j in xrange(len(fileDatas[0][i]["turns"])):
291 | tracker_outputs[i].append({
292 | "goal-labels": {},
293 | "method-label": {
294 | "none": 1.0
295 | },
296 | "requested-slots": {}
297 | })
298 | 
299 | 
300 | for label_index, label_name in enumerate(self.label_index_list):
301 | # generate data for batch use according to current feature_type
302 | if feature_type == 'bow':
303 | sentences, acts, labels = [], [], []
304 | for fileData in fileDatas:
305 | dialog_sentences, dialog_scores, machine_acts, dialog_labels = read_nbest_dialog_content(fileData, label_index)
306 | sentence_turn, act_turn, label_turn = [], [], []
307 | for turn_id in xrange(len(dialog_sentences)):
308 | cur_sentence = nbest_text2bow(dialog_sentences[turn_id], dialog_scores[turn_id], self.vocab[label_name])
309 | cur_act = text2bow(machine_acts[turn_id], self.vocab[label_name])
310 | cur_label = dialog_labels[turn_id]
311 | tmp_label_out = len(cur_label) if self.track_dict[label_name].output_type == 'sigmoid' else 1
312 | sentence_turn.append(cur_sentence)
313 | act_turn.append(cur_act)
314 | label_turn.append(cur_label)
315 | sentences.append(sentence_turn)
316 | acts.append(act_turn)
317 | labels.append(label_turn)
318 | data_batch = self.track_dict[label_name].multiTurnBowBatch(sentences, acts, labels, tmp_label_out)
319 | 
320 | elif feature_type in ['sentsent', 'sentbow', 'bowsent', 'bowbow']:
321 | def turn_read_content(fileDatas,dataIdx,feature_type):
322 | sentences, acts, scores, labels = [], [], [], []
323 | 
324 | for fileData in fileDatas[dataIdx]:
325 | dialog_sentences, dialog_scores, machine_acts, dialog_labels = read_nbest_dialog_content(fileData, label_index)
326 | 
327 | sentence_turn, act_turn, score_turn, label_turn = [], [], [], []
328 | 
329 | for turn_id in xrange(len(dialog_sentences)):
330 | cur_sentence = []
331 | # user sentence feature
332 | for nbest_id in range(len(dialog_sentences[turn_id])):
333 | if feature_type in ['sentsent', 'sentbow']:
334 | cur_sentbest = default_text2id(dialog_sentences[turn_id][nbest_id], vocab)
335 | elif feature_type in ['bowsent', 'bowbow']:
336 | cur_sentbest = text2bow(dialog_sentences[turn_id][nbest_id], self.vocab[label_name])
337 | cur_sentence.append(cur_sentbest)
338 | # sys act feature
339 | if feature_type in ['sentbow', 'bowbow']:
340 | cur_act = text2bow(machine_acts[turn_id], self.vocab[label_name])
341 | elif feature_type in ['sentsent', 'bowsent']:
342 | cur_act = default_text2id(machine_acts[turn_id], self.vocab[label_name])
343 | cur_score = dialog_scores[turn_id]
344 | cur_label = dialog_labels[turn_id]
345 | tmp_label_out = len(cur_label) if self.track_dict[label_name].output_type == 'sigmoid' else 1
346 | sentence_turn.append(cur_sentence)
347 | act_turn.append(cur_act)
348 | score_turn.append(cur_score)
349 | label_turn.append(cur_label)
350 | 351 | sentences.append(sentence_turn) 352 | scores.append(score_turn) 353 | acts.append(act_turn) 354 | labels.append(label_turn) 355 | return sentences, scores, acts, labels,tmp_label_out 356 | 357 | sentences, scores, acts, labels,tmp_label_out = turn_read_content(fileDatas,0,feature_type) 358 | sentences1, scores1, acts1, labels1,tmp_label_out= turn_read_content(fileDatas,1,feature_type) 359 | sentences0=[] 360 | for i in range(len(sentences)): 361 | dialog0=[] 362 | for j in range(len(sentences[i])): 363 | sent=[] 364 | sent.append(sentences[i][j]) 365 | sent.append(sentences1[i][j]) 366 | dialog0.append(sent) 367 | sentences0.append(dialog0) 368 | sentences=sentences0 369 | act0=[] 370 | for i in range(len(acts)): 371 | dialog0=[] 372 | for j in range(len(acts[i])): 373 | act=[] 374 | act.append(acts[i][j]) 375 | act.append(acts1[i][j]) 376 | dialog0.append(act) 377 | act0.append(dialog0) 378 | acts=act0 379 | 380 | data_batch = self.track_dict[label_name].multiTurnBatch(label_index,sentences, acts, scores, labels, tmp_label_out, vocab, self.vocab[label_name],feature_type) 381 | 382 | outputs = self.track_dict[label_name].predict(data_batch)[0] 383 | if outputs.shape[0] != self.batch_size: 384 | outputs = outputs.reshape((self.batch_size, -1,) + outputs.shape[1:]) 385 | 386 | for i in xrange(len(tracker_outputs)): 387 | for j in xrange(len(tracker_outputs[i])): 388 | cur_outputs = outputs[i][j].asnumpy() 389 | self._updateState(tracker_outputs[i][j], cur_outputs, label_name, top_n=10) 390 | 391 | # remove "signature" from requested_slots 392 | for i in xrange(len(tracker_outputs)): 393 | for j in xrange(len(tracker_outputs[i])): 394 | if "signature" in tracker_outputs[i][j]["requested-slots"]: 395 | del tracker_outputs[i][j]["requested-slots"]["signature"] 396 | 397 | return tracker_outputs 398 | 399 | 400 | if __name__ == '__main__': 401 | import logging 402 | head = '%(asctime)-15s %(message)s' 403 | logging.basicConfig(level=logging.DEBUG, format=head) 404 | 405 | mod_tracker = ModTracker() 406 | 407 | tmp_us_live_goal = { 408 | "goal-labels": { 409 | "food": "french", 410 | "name": "hotel du vin and bistro" 411 | }, 412 | "method-label": "byname", 413 | "requested-slots": [] 414 | } 415 | tmp_dm_output = [ 416 | #{ 417 | # "slot_name": "chinese", 418 | # "act_type": "inform", 419 | # "slot_val": "food", 420 | #} 421 | #{ 422 | # "slot_name": "pricerange", 423 | # "act_type": "request", 424 | #} 425 | { 426 | "act_type": "welcomemsg", 427 | } 428 | ] 429 | tmp_asr_output = { 430 | "gen+asr": [ 431 | { 432 | #"asr-hyp": "i dont care", 433 | "asr-hyp": "expensive restaurant in west", 434 | #"asr-hyp": "what about the price", 435 | "score": 0.9 436 | }, 437 | { 438 | #"asr-hyp": "any will do", 439 | "asr-hyp": "i am", 440 | "score": 0.1 441 | } 442 | ], 443 | } 444 | 445 | tmp_output = mod_tracker.get_new_state(tmp_dm_output, tmp_asr_output, pre_state=None, us_live_goal=tmp_us_live_goal) 446 | print '------mod_tracker--------' 447 | print tmp_output 448 | print json.dumps(tmp_output, indent=4) 449 | print '------------------------------------------------' 450 | -------------------------------------------------------------------------------- /offline_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import time 4 | import os 5 | import sys 6 | import pickle 7 | import numpy as np 8 | 9 | import mxnet as mx 10 | 11 | from lectrack import LecTrack 12 | from bucket_io import DSTSentenceIter 13 | from 
turnbow_io import DSTTurnIter
14 | from mat_io import MATTurnSentIter
15 | from turnsent_io import DSTTurnSentIter
16 | 
17 | from mat_data import vocab,vocab1, ontologyDict
18 | 
19 | # directory containing this script
20 | cur_dir = os.path.dirname(os.path.abspath(__file__))
21 | 
22 | 
23 | class OfflineModel(object):
24 | def __init__(self, labelIdx,ini, context_type='cpu', config_dict=None):
25 | self.labelIdx = labelIdx
26 | self.ontology = ontologyDict
27 | self.vocab = vocab
28 | self.ini = ini
29 | self.vocab1 = vocab1
30 | # configuration
31 | config_dict = config_dict or {}
32 | self.nn_type = config_dict.get('nn_type', "lstm")
33 | self.model_dir = config_dict.get('model_dir', os.path.join(cur_dir, 'models'))
34 | self.train_json_file = config_dict.get('train_json', os.path.join(cur_dir, 'train_custom.json'))
35 | self.dev_json_file = config_dict.get('dev_json', os.path.join(cur_dir, 'dev_custom.json'))
36 | self.test_json_file = config_dict.get('test_json', os.path.join(cur_dir, 'test_custom.json'))
37 | 
38 | if not os.path.exists(self.model_dir):
39 | os.makedirs(self.model_dir)
40 | print(mx.__version__)
41 | # ##################################
42 | self.label_index_list = ['goal_food', 'goal_pricerange', 'goal_name', 'goal_area', 'method', 'requested']
43 | self.num_label_dict = {
44 | 'goal_food': len(self.ontology["informable"]["food"]),
45 | 'goal_pricerange': len(self.ontology["informable"]["pricerange"]),
46 | 'goal_name': len(self.ontology["informable"]["name"]),
47 | 'goal_area': len(self.ontology["informable"]["area"]),
48 | 'method': len(self.ontology["method"]),
49 | 'requested': len(self.ontology["requestable"])
50 | }
51 | print self.num_label_dict
52 | # ##################################
53 | self.label_name=[]
54 | for i in labelIdx:
55 | self.label_name.append(self.label_index_list[i])
56 | self.num_label=[]
57 | for k in self.label_name:
58 | self.num_label.append(self.num_label_dict[k])
59 | 
60 | self.batch_size = 32
61 | if self.nn_type in ['lstmn']:
62 | self.buckets = [3, 5, 8, 10, 13, 15, 18, 22, 26, 30, 34, 38, 42, 46, 50, 55, 60, 65, 70, 80, 90, 110, 150, 350]
63 | elif self.nn_type in ['cnn']:
64 | self.buckets = [10, 13, 15, 18, 22, 26, 30, 34, 38, 42, 46, 50, 55, 60, 65, 70, 80, 90, 110, 150, 350, 500]
65 | elif self.nn_type in ['reslstm','matlstm','bowlstm', 'cnnlstm', 'cnncnnlstm','doublelstm']:
66 | self.buckets = range(1, 31)
67 | # FIXME TODO
68 | self.batch_size = 32
69 | else:
70 | self.buckets = []
71 | 
72 | if self.nn_type in ['reslstm','matlstm']:
73 | self.batch_size=32
74 | self.opt='rmsprop'
75 | self.drop=0.
76 | self.lr=0.0005
77 | elif self.nn_type in ['lstmcnn']:
78 | self.opt='sgd'
79 | self.drop=0.#0.2
80 | self.lr=0.001
81 | else:
82 | if self.ini==1:
83 | self.opt='rmsprop'#rmsprop
84 | self.drop=0.#0.
85 | self.lr=0.0005#0.0005
86 | else:
87 | self.opt='adam'#rmsprop
88 | self.drop=0.#0.
89 | self.lr=0.001#0.0005 90 | # TODO be careful about "pretrain_embed" 91 | lectrack_config_dict = { 92 | "batch_size": self.batch_size, 93 | "input_size": len(self.vocab1), 94 | "num_label": self.num_label, 95 | "output_type": 'sigmoid' if 'requested' in self.label_name else 'softmax', 96 | "nn_type": self.nn_type, 97 | "context_type": context_type, 98 | "dropout":self.drop,#0.2 99 | "optimizer": self.opt,#adam 100 | "buckets": self.buckets, 101 | "pretrain_embed": False, 102 | "fix_embed": True, 103 | "num_lstm_o": 128,#256 104 | "num_hidden": 128, 105 | "learning_rate": self.lr,#0.01mat,0.001dlstm 106 | "enable_mask": False, 107 | "N": 1, 108 | "num_embed": 300 109 | } 110 | self.lectrack = LecTrack(lectrack_config_dict) 111 | self.init_states = self.lectrack.init_states 112 | self.data_components = self.lectrack.data_components 113 | if 'requested' in self.label_name: 114 | self.label_out = self.num_label_dict['requested'] 115 | else: 116 | self.label_out = len(self.labelIdx) 117 | 118 | # ################################## 119 | if self.nn_type in ['bowlstm']: 120 | self.data_train = DSTTurnIter(self.train_json_file, self.labelIdx, self.vocab, self.buckets, self.batch_size, 121 | self.init_states, self.data_components, label_out=self.label_out) 122 | self.data_val = DSTTurnIter(self.dev_json_file, self.labelIdx, self.vocab, self.buckets, self.batch_size, 123 | self.init_states, self.data_components, label_out=self.label_out) 124 | self.data_test = DSTTurnIter(self.test_json_file, self.labelIdx, self.vocab, self.buckets, self.batch_size, 125 | self.init_states, self.data_components, label_out=self.label_out) 126 | elif self.nn_type in ['reslstm','matlstm','doublelstm']: 127 | self.max_nbest = self.lectrack.max_nbest 128 | self.max_sentlen = self.lectrack.max_sentlen 129 | if self.nn_type in ['doublelstm']: 130 | feature_type = 'sentbow' 131 | elif self.nn_type in ['reslstm','matlstm']: 132 | feature_type= 'bowbow' 133 | else: 134 | feature_type = 'sentsent' 135 | self.data_train = MATTurnSentIter(self.train_json_file, self.labelIdx, self.vocab, self.vocab1,self.buckets, self.batch_size, self.max_nbest, self.max_sentlen, 136 | self.init_states, self.data_components, label_out=self.label_out, feature_type=feature_type) 137 | self.data_val = MATTurnSentIter(self.dev_json_file, self.labelIdx, self.vocab, self.vocab1,self.buckets, self.batch_size, self.max_nbest, self.max_sentlen, 138 | self.init_states, self.data_components, label_out=self.label_out, feature_type=feature_type) 139 | self.data_test = MATTurnSentIter(self.test_json_file, self.labelIdx, self.vocab, self.vocab1,self.buckets, self.batch_size, self.max_nbest, self.max_sentlen, 140 | self.init_states, self.data_components, label_out=self.label_out, feature_type=feature_type) 141 | elif self.nn_type in ['cnncnnlstm','cnnlstm']: 142 | self.max_nbest = self.lectrack.max_nbest 143 | self.max_sentlen = self.lectrack.max_sentlen 144 | feature_type = 'sentsent' 145 | self.data_train = DSTTurnSentIter(self.train_json_file, self.labelIdx, self.vocab, self.buckets, self.batch_size, self.max_nbest, self.max_sentlen, 146 | self.init_states, self.data_components, label_out=self.label_out, feature_type=feature_type) 147 | self.data_val = DSTTurnSentIter(self.dev_json_file, self.labelIdx, self.vocab, self.buckets, self.batch_size, self.max_nbest, self.max_sentlen, 148 | self.init_states, self.data_components, label_out=self.label_out, feature_type=feature_type) 149 | self.data_test = DSTTurnSentIter(self.test_json_file, self.labelIdx, 
self.vocab, self.buckets, self.batch_size, self.max_nbest, self.max_sentlen, 150 | self.init_states, self.data_components, label_out=self.label_out, feature_type=feature_type) 151 | else: 152 | self.data_train = DSTSentenceIter(self.train_json_file, self.labelIdx, self.vocab, self.buckets, self.batch_size, 153 | self.init_states, self.data_components, label_out=self.label_out) 154 | self.data_val = DSTSentenceIter(self.dev_json_file, self.labelIdx, self.vocab, self.buckets, self.batch_size, 155 | self.init_states, self.data_components, label_out=self.label_out) 156 | self.data_test = DSTSentenceIter(self.test_json_file, self.labelIdx, self.vocab, self.buckets, self.batch_size, 157 | self.init_states, self.data_components, label_out=self.label_out) 158 | # ################################## 159 | if self.lectrack.optimizer == 'adam': 160 | lrFactor=0.5 161 | epochSize=self.data_train.size 162 | step_epochs=[150] 163 | steps=[epochSize*x for x in step_epochs] 164 | scheduler= mx.lr_scheduler.MultiFactorScheduler(step=steps,factor=lrFactor) 165 | self.lectrack.model.init_optimizer(optimizer='adam',optimizer_params=(('learning_rate',self.lectrack.learning_rate),('clip_gradient',2.0),))#, ('lr_scheduler',scheduler),)) 166 | #self.model.init_optimizer(optimizer='adam',optimizer_params=(('learning_rate',self.learning_rate), ('clip_gradient',5.0))) 167 | elif self.lectrack.optimizer == 'adagrad': 168 | self.lectrack.model.init_optimizer(optimizer='adagrad') 169 | elif self.lectrack.optimizer == 'rmsprop': 170 | self.lectrack.model.init_optimizer(optimizer='rmsprop',optimizer_params=(('learning_rate',self.lectrack.learning_rate),('centered',False), ))#('gamma1',0.95),('clip_gradient',10), )) 171 | elif self.lectrack.optimizer == 'sgd': 172 | lrFactor=0.5 173 | epochSize=self.data_train.size 174 | step_epochs=[30,80,130] 175 | steps=[epochSize*x for x in step_epochs] 176 | scheduler= mx.lr_scheduler.MultiFactorScheduler(step=steps,factor=lrFactor) 177 | self.lectrack.model.init_optimizer(optimizer='sgd',optimizer_params=(('learning_rate',self.lectrack.learning_rate),('lr_scheduler',scheduler),('momentum',0.9), ('clip_gradient',5.0), ('wd',0.0001), )) 178 | print '[INFO]: Using optimizer - %s' % self.lectrack.optimizer 179 | 180 | if self.ini==0: 181 | params_path = os.path.join(self.model_dir, 'labelIdx-99-devbest.params' ) 182 | if os.path.exists(params_path): 183 | self.lectrack.load_params(params_path) 184 | 185 | 186 | 187 | # ################################## 188 | if 'requested' in self.label_name: 189 | def customAcc(labels, preds): 190 | ret = 0 191 | if len(labels.shape) > len(preds.shape): 192 | preds_true_shape = (self.batch_size, preds.shape[0]/self.batch_size) + preds.shape[1:] 193 | preds = preds.reshape(preds_true_shape) 194 | for label, pred_label in zip(labels, preds): 195 | pred_label = (pred_label + 0.5).astype('int32') 196 | label = label.astype('int32') 197 | ret += np.sum(np.all(np.equal(pred_label, label), axis=-1)) 198 | return (ret, len(labels.reshape((-1, labels.shape[-1])))) 199 | self.metric = mx.metric.CustomMetric(customAcc) 200 | else: 201 | def customAcc(labels, preds): 202 | ret = 0 203 | l=0 204 | for label, pred_label in zip(labels, preds): 205 | # if pred_label.shape != label.shape: 206 | # pred_label = np.argmax(pred_label, axis=-1) 207 | pred_label = pred_label.astype('int32') 208 | label = label.astype('int32') 209 | ret += (pred_label == label).all(-1).sum() 210 | l+= len((pred_label == label).all(-1)) 211 | return (ret, l) 212 | self.metric = 
mx.metric.CustomMetric(customAcc,output_names=['softmax_output'],label_names=['softmax_label']) 213 | 214 | def custom_score(self, eval_data, eval_metric): 215 | eval_data.reset() 216 | eval_metric.reset() 217 | pad_count = 0 218 | 219 | for nbatch, eval_batch in enumerate(eval_data): 220 | label_shape = eval_batch.provide_label[0][1] 221 | pad_count += eval_batch.pad * label_shape[1] 222 | self.lectrack.model.forward(eval_batch, is_train=False) 223 | self.lectrack.model.update_metric(eval_metric, eval_batch.label) 224 | eval_metric.sum_metric -= pad_count 225 | eval_metric.num_inst -= pad_count 226 | return eval_metric.get_name_value() 227 | 228 | 229 | def offline_train(self, num_epoch=20): 230 | print '====== Train labelIdx-%d:' % 99# self.labelIdx 231 | print '[Start] ', time.strftime('%Y-%m-%d %H-%M', time.localtime(time.time())) 232 | 233 | if self.ini==1: 234 | model_name='labelIdx-%d-devbest.params' % 99 235 | else: 236 | model_name='labelIdx-%d-devbest.params' % 999 237 | 238 | best_dev_info = { 239 | "params_path": os.path.join(self.model_dir, model_name),#self.labelIdx), 240 | "epoch_num": -1, 241 | "acc": 0.0 242 | } 243 | best_test_info = { 244 | "params_path": os.path.join(self.model_dir, 'labelIdx-joint-testbest.params' ), 245 | "epoch_num": -1, 246 | "acc": 0.0 247 | } 248 | 249 | 250 | 251 | for i in range(num_epoch): 252 | print("=============BEGIN EPOCH %d==================="%i) 253 | for batch in self.data_train: 254 | self.lectrack.model.forward(batch, is_train=True) 255 | self.lectrack.model.update_metric(self.metric, batch.label) 256 | self.lectrack.model.backward() 257 | self.lectrack.model.update() 258 | print(self.metric.get()) 259 | print '[Training over] ', time.strftime('%Y-%m-%d %H-%M', time.localtime(time.time())) 260 | 261 | #params_path = os.path.join(self.model_dir, 'labelIdx-%d-epoch-%d.params' % (self.labelIdx, i)) 262 | #self.lectrack.model.save_params(params_path) 263 | self.data_train.reset() 264 | 265 | # evaluate on different data set and save best model automatically 266 | # FIXME using custom score function 267 | print 'train:', self.custom_score(self.data_train, self.metric) 268 | dev_score = self.custom_score(self.data_val, self.metric) 269 | print 'dev :', dev_score 270 | test_score = self.custom_score(self.data_test, self.metric) 271 | print 'test :', test_score 272 | print '[Testing over] ', time.strftime('%Y-%m-%d %H-%M', time.localtime(time.time())) 273 | sys.stdout.flush() 274 | if dev_score[0][1] > best_dev_info["acc"]: 275 | best_dev_info["acc"] = dev_score[0][1] 276 | best_dev_info["epoch_num"] = i 277 | self.lectrack.model.save_params(best_dev_info["params_path"]) 278 | if test_score[0][1] > best_test_info["acc"]: 279 | best_test_info["acc"] = test_score[0][1] 280 | best_test_info["epoch_num"] = i 281 | self.lectrack.model.save_params(best_test_info["params_path"]) 282 | 283 | print 'devbest epoch: %s, acc: %s' % (best_dev_info['epoch_num'], best_dev_info['acc']) 284 | print 'testbest epoch: %s, acc: %s' % (best_test_info['epoch_num'], best_test_info['acc']) 285 | 286 | print '[End] ', time.strftime('%Y-%m-%d %H-%M',time.localtime(time.time())) 287 | print "="*50 288 | sys.stdout.flush() 289 | 290 | 291 | def offline_eval(self): 292 | print '====== Eval labelIdx-%d:' % 99 293 | print '[Start] ', time.strftime('%Y-%m-%d %H-%M', time.localtime(time.time())) 294 | 295 | # load pre-trained params 296 | params_path = os.path.join(self.model_dir, 'labelIdx-%d-devbest.params' % 999) 297 | self.lectrack.load_params(params_path) 298 | 299 | 
print 'train:', self.custom_score(self.data_train, self.metric) 300 | dev_score = self.custom_score(self.data_val, self.metric) 301 | print 'dev :', dev_score 302 | test_score = self.custom_score(self.data_test, self.metric) 303 | print 'test :', test_score 304 | print '[Testing over] ', time.strftime('%Y-%m-%d %H-%M', time.localtime(time.time())) 305 | sys.stdout.flush() 306 | 307 | #print 'train:', self.lectrack.model.score(self.data_train, self.metric) 308 | #print 'dev :', self.lectrack.model.score(self.data_val, self.metric) 309 | #print 'test :', self.lectrack.model.score(self.data_test, self.metric) 310 | 311 | #print '[End] ', time.strftime('%Y-%m-%d %H-%M',time.localtime(time.time())) 312 | 313 | 314 | def evalAll(): 315 | ctx = 'cpu' 316 | 317 | #tmp_model = OfflineModel(5, ctx) 318 | #tmp_model.offline_eval() 319 | for i in xrange(6): 320 | tmp_model = OfflineModel(i, ctx) 321 | tmp_model.offline_eval() 322 | 323 | def trainAll(): 324 | ctx = 'cpu' 325 | 326 | #tmp_model = OfflineModel(5, ctx) 327 | #tmp_model.offline_train() 328 | for i in xrange(6): 329 | tmp_model = OfflineModel(i, ctx) 330 | tmp_model.offline_train() 331 | 332 | 333 | if __name__ == '__main__': 334 | pass 335 | #evalAll() 336 | #trainAll() 337 | -------------------------------------------------------------------------------- /offline_model_dstc.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | import time 5 | import json 6 | import copy 7 | import math 8 | import numpy as np 9 | import mxnet as mx 10 | from offline_model import OfflineModel 11 | from mod_lectrack import ModTracker 12 | from mat_data import genTurnData_nbest, genTurnData_nbest_tagged,gen_resdata 13 | 14 | import dataset_walker 15 | 16 | # directory containing this script 17 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 18 | 19 | # TODO offline config 20 | # nn_type can be ['lstm', 'blstm', 'lstmn', 'cnn', 'bowlstm'] 21 | modSel=1 22 | if modSel==0: 23 | # lstmcnn config 24 | offline_config_dict = { 25 | 'nn_type': 'lstmcnn', 26 | 'model_dir': os.path.join(cur_dir, 'models_dstc_lstmcnnRes'), 27 | 'train_json': os.path.join(cur_dir, 'train_nbest.json'), 28 | 'dev_json': os.path.join(cur_dir, 'dev_nbest.json'), 29 | 'test_json': os.path.join(cur_dir, 'test_nbest.json') 30 | } 31 | elif modSel==1: 32 | #lstm config 33 | offline_config_dict = { 34 | 'nn_type': 'doublelstm', 35 | 'model_dir': os.path.join(cur_dir, 'models_dstc_caplstmSCLT'), 36 | 'train_json': os.path.join(cur_dir, 'train_nbest_tagged.json'), 37 | 'dev_json': os.path.join(cur_dir, 'dev_nbest_tagged.json'), 38 | 'test_json': os.path.join(cur_dir, 'test_nbest_tagged.json') 39 | } 40 | else: 41 | # mat config 42 | offline_config_dict = { 43 | 'nn_type': 'reslstm', 44 | 'model_dir': os.path.join(cur_dir, 'models_dstc_reslstmTest1'), 45 | 'train_json': os.path.join(cur_dir, 'train_nbest_tagged.json'), 46 | 'dev_json': os.path.join(cur_dir, 'dev_nbest_tagged.json'), 47 | 'test_json': os.path.join(cur_dir, 'test_nbest_tagged.json') 48 | } 49 | 50 | 51 | # 1best tagged lstm config 52 | #offline_config_dict = { 53 | # 'nn_type': 'lstm', 54 | # 'model_dir': os.path.join(cur_dir, 'models_dstc2_tagged'), 55 | # 'train_json': os.path.join(cur_dir, 'train_nbest_tagged.json'), 56 | # 'dev_json': os.path.join(cur_dir, 'dev_nbest_tagged.json'), 57 | # 'test_json': os.path.join(cur_dir, 'test_nbest_tagged.json') 58 | #} 59 | 60 | # cnn config 61 | #offline_config_dict = { 62 | # 'nn_type': 'cnn', 63 | # 'model_dir':
os.path.join(cur_dir, 'models_cnn'), 64 | # 'train_json': os.path.join(cur_dir, 'train_nbest.json'), 65 | # 'dev_json': os.path.join(cur_dir, 'dev_nbest.json'), 66 | # 'test_json': os.path.join(cur_dir, 'test_nbest.json') 67 | #} 68 | 69 | # bowlstm config 70 | #offline_config_dict = { 71 | # 'nn_type': 'bowlstm', 72 | # 'model_dir': os.path.join(cur_dir, 'models_bowlstm'), 73 | # 'train_json': os.path.join(cur_dir, 'train_nbest.json'), 74 | # 'dev_json': os.path.join(cur_dir, 'dev_nbest.json'), 75 | # 'test_json': os.path.join(cur_dir, 'test_nbest.json') 76 | #} 77 | 78 | # cnnlstm config 79 | #offline_config_dict = { 80 | # 'nn_type': 'cnnlstm', 81 | # 'model_dir': os.path.join(cur_dir, 'models_cnnlstm'), 82 | # 'train_json': os.path.join(cur_dir, 'train_nbest.json'), 83 | # 'dev_json': os.path.join(cur_dir, 'dev_nbest.json'), 84 | # 'test_json': os.path.join(cur_dir, 'test_nbest.json') 85 | #} 86 | 87 | # cnncnnlstm config 88 | # offline_config_dict = { 89 | # 'nn_type': 'cnncnnlstm', 90 | # 'model_dir': os.path.join(cur_dir, 'models_cnncnnlstm'), 91 | # 'train_json': os.path.join(cur_dir, 'train_nbest.json'), 92 | # 'dev_json': os.path.join(cur_dir, 'dev_nbest.json'), 93 | # 'test_json': os.path.join(cur_dir, 'test_nbest.json') 94 | # } 95 | 96 | def train_dstc2(ini): 97 | #params_path = os.path.join(offline_config_dict['model_dir'], 'labelIdx-joint-testbest.params' ) 98 | #if os.path.exists(params_path): 99 | # os.remove(params_path) 100 | for t in range(1,2): 101 | np.random.seed(t) 102 | ctx = 'gpu' 103 | if ini==1: 104 | index=[0] 105 | else: 106 | index=[0,1,3] 107 | for i in [index]: 108 | tmp_model = OfflineModel(i, ini, ctx, offline_config_dict) 109 | tmp_model.offline_train(150) 110 | #tmp_model.offline_eval() 111 | 112 | def del_none_val(turn_output): 113 | """delete all "none" values in turn_output. 
In-place operation""" 114 | if "none" in turn_output["requested-slots"]: 115 | del turn_output["requested-slots"]["none"] 116 | for _, vals in turn_output["goal-labels"].items(): 117 | if "none" in vals: 118 | del vals["none"] 119 | 120 | def tag_to_val(turn_output, tag_dict): 121 | goal_output = turn_output["goal-labels"] 122 | for slot in goal_output: 123 | for slotval in copy.deepcopy(goal_output[slot]): 124 | if slotval.startswith('#'): 125 | if slotval in tag_dict: 126 | goal_output[slot][tag_dict[slotval]] = goal_output[slot].get(tag_dict[slotval], 0.0) + goal_output[slot][slotval] 127 | del goal_output[slot][slotval] 128 | 129 | 130 | def gen_baseline_ground(dataset_name, dataroot): 131 | res_ground = { 132 | 'dataset': dataset_name, 133 | 'sessions': [] 134 | } 135 | dataset = dataset_walker.dataset_walker(dataset_name, dataroot=dataroot, labels=True) 136 | for call in dataset: 137 | res_dialogue = dict() 138 | res_dialogue["session-id"] = call.log["session-id"] 139 | res_dialogue["turns"] = list() 140 | for turn, labelJson in call: 141 | turn_label = { 142 | "goal-labels": labelJson["goal-labels"], 143 | "method-label": labelJson["method-label"], 144 | "requested-slots": labelJson["requested-slots"] 145 | } 146 | res_dialogue["turns"].append(turn_label) 147 | res_ground["sessions"].append(res_dialogue) 148 | json.dump(res_ground, open('baseline_ground_%s.json'%dataset_name, 'wb'), indent=4) 149 | 150 | 151 | def gen_baseline(dataset_name, dataroot, tagged=False): 152 | res = { 153 | 'dataset': dataset_name, 154 | 'sessions': [] 155 | } 156 | dataset = dataset_walker.dataset_walker(dataset_name, dataroot=dataroot, labels=True) 157 | mod_config_dict = { 158 | 'context_type': 'cpu', 159 | 'nn_type': offline_config_dict["nn_type"], 160 | 'model_dir':offline_config_dict["model_dir"] 161 | } 162 | if mod_config_dict['nn_type'] in ['doublelstm','reslstm','matlstm','cnnlstm', 'cnncnnlstm']: 163 | mod_config_dict['batch_size'] = 32 164 | 165 | mod_tracker = ModTracker(config_dict=mod_config_dict) 166 | start_time = time.time() 167 | 168 | # decide how to process data 169 | if mod_config_dict['nn_type'] in ['bowlstm']: 170 | level = 'turn' 171 | feature_type = 'bow' 172 | elif mod_config_dict['nn_type'] in ['reslstm','matlstm','cnnlstm']: 173 | level = 'turn' 174 | feature_type = 'bowbow' 175 | elif mod_config_dict['nn_type'] in ['doublelstm','cnncnnlstm']: 176 | level = 'turn' 177 | feature_type = 'sentbow' 178 | else: 179 | level = 'word' 180 | 181 | # process by word-level dialogue 182 | if level == 'word': 183 | for call in dataset: 184 | res_dialogue = dict() 185 | res_dialogue["session-id"] = call.log["session-id"] 186 | res_dialogue["turns"] = list() 187 | 188 | fileDatas = [] 189 | tag_dicts = [] 190 | 191 | fileData = {} 192 | fileData["turns"] = [] 193 | for turn, labelJson in call: 194 | if tagged: 195 | turnData = genTurnData_nbest_tagged(turn, labelJson) 196 | tag_dicts.append(turnData["tag_dict"]) 197 | else: 198 | turnData = genTurnData_nbest(turn, labelJson) 199 | fileData["turns"].append(turnData) 200 | fileDatas.append(copy.deepcopy(fileData)) 201 | 202 | tracker_outputs = mod_tracker.get_batch_new_state(fileDatas) 203 | for i in xrange(len(tracker_outputs)): 204 | del_none_val(tracker_outputs[i]) 205 | if tagged: 206 | tag_to_val(tracker_outputs[i], tag_dicts[i]) 207 | res_dialogue["turns"].append(tracker_outputs[i]) 208 | res["sessions"].append(res_dialogue) 209 | print "processed dialogue no.:", len(res["sessions"]) 210 | 211 | # process by turn-level dialogue 212 | 
elif level == 'turn': 213 | 214 | batch_size = mod_tracker.batch_size 215 | 216 | fileDatas_all=gen_resdata(dataset,'nbest_tagged') 217 | # fileDatas_all = [] 218 | # for call in dataset: 219 | # fileData = {} 220 | # fileData["turns"] = [] 221 | # fileData["session-id"] = call.log["session-id"] 222 | # for turn, labelJson in call: 223 | # turnData = genTurnData_nbest(turn, labelJson) 224 | # fileData["turns"].append(turnData) 225 | # fileDatas_all.append(fileData) 226 | 227 | batch_num = int(math.ceil(len(fileDatas_all[0]) / float(batch_size))) 228 | for j in xrange(batch_num): 229 | fileDatas0 = fileDatas_all[0][batch_size*j: batch_size*(j+1)] 230 | fileDatas1 = fileDatas_all[1][batch_size*j: batch_size*(j+1)] 231 | fileDatas=[] 232 | fileDatas.append(fileDatas0) 233 | fileDatas.append(fileDatas1) 234 | tracker_outputs = mod_tracker.get_turn_batch_state(fileDatas, feature_type) 235 | 236 | for i in xrange(len(fileDatas[0])): 237 | res_dialogue = dict() 238 | res_dialogue["session-id"] = fileDatas[0][i]["session-id"] 239 | res_dialogue["turns"] = tracker_outputs[i] 240 | for turn_output in res_dialogue["turns"]: 241 | del_none_val(turn_output) 242 | res["sessions"].append(res_dialogue) 243 | print "processed dialogue no.:", len(res["sessions"]) 244 | 245 | end_time = time.time() 246 | res['wall-time'] = end_time - start_time 247 | if tagged: 248 | baseline_json_file = 'baseline_%s_tagged.json'%dataset_name 249 | else: 250 | baseline_json_file = 'baseline_%s_dlstm.json'%dataset_name 251 | json.dump(res, open(baseline_json_file, 'wb'), indent=4) 252 | 253 | 254 | if __name__ == '__main__': 255 | # ######################### 256 | # Decide data set 257 | # ######################### 258 | #dataset_name = 'dstc2_train' 259 | #dataroot = 'dstc2_traindev/data' 260 | 261 | #dataset_name = 'dstc2_test' 262 | #dataroot = 'dstc2_traindev/data' 263 | 264 | 265 | # ######################### 266 | # Choose operation 267 | # ######################### 268 | train_dstc2(0) 269 | ##gen_baseline_ground(dataset_name, dataroot) 270 | ##gen_baseline(dataset_name, dataroot, tagged=True) 271 | #gen_baseline(dataset_name, dataroot) 272 | -------------------------------------------------------------------------------- /turnbow_io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import copy 4 | import json 5 | import math 6 | 7 | import numpy as np 8 | import mxnet as mx 9 | 10 | from bucket_io import SimpleBatch 11 | from turnsent_io import turn_read_content, text2bow 12 | 13 | # The interface of a data iter that works for bucketing 14 | # 15 | # DataIter 16 | # - default_bucket_key: the bucket key for the default symbol. 
17 | # 18 | # DataBatch 19 | # - provide_data: same as DataIter, but specific to this batch 20 | # - provide_label: same as DataIter, but specific to this batch 21 | # - bucket_key: the key for the bucket that should be used for this batch 22 | 23 | def nbest_text2bow(nbest_sentence, nbest_score, the_vocab, ngram=1): 24 | res = np.zeros(len(the_vocab)) 25 | for i in range(len(nbest_sentence)): 26 | words = list(set(nbest_sentence[i].split())) 27 | for word in words: 28 | if word in the_vocab: 29 | res[the_vocab[word]] += nbest_score[i] 30 | if ngram >= 2: 31 | for word in [' '.join(words[j:j+2]) for j in xrange(len(words)-1)]: 32 | if word in the_vocab: 33 | res[the_vocab[word]] += nbest_score[i] 34 | if ngram >= 3: 35 | for word in [' '.join(words[j:j+3]) for j in xrange(len(words)-2)]: 36 | if word in the_vocab: 37 | res[the_vocab[word]] += nbest_score[i] 38 | return res 39 | 40 | class DSTTurnIter(mx.io.DataIter): 41 | def __init__(self, path, labelIdx, vocab, buckets, batch_size, 42 | init_states, data_components, label_out=1): 43 | super(DSTTurnIter, self).__init__() 44 | self.vocab = vocab 45 | self.padding_id = self.vocab[''] 46 | 47 | self.label_out = label_out 48 | 49 | sentences, scores, acts, labels = turn_read_content(path, labelIdx) 50 | """ 51 | sentences: (dialog_num, turn_num, nbest_num, sentence_len) 52 | scores: (dialog_num, turn_num, nbest_num) 53 | acts: (dialog_num, turn_num, machine_act_len) 54 | labels: (dialog_num, turn_num, ) 55 | """ 56 | 57 | buckets.sort() 58 | self.buckets = buckets 59 | self.data = [[] for _ in buckets] 60 | self.data_act = [[] for _ in buckets] 61 | self.label = [[] for _ in buckets] 62 | 63 | # pre-allocate with the largest bucket for better memory sharing 64 | self.default_bucket_key = max(buckets) 65 | 66 | for i in range(len(sentences)): 67 | sentence = sentences[i] 68 | score = scores[i] 69 | act = acts[i] 70 | label = labels[i] 71 | for turn_id in range(len(sentence)): 72 | sentence[turn_id] = nbest_text2bow(sentence[turn_id], score[turn_id], self.vocab, ngram=1) 73 | act[turn_id] = text2bow(act[turn_id], self.vocab) 74 | for i, bkt in enumerate(buckets): 75 | if bkt == len(sentence): 76 | self.data[i].append(sentence) 77 | self.data_act[i].append(act) 78 | self.label[i].append(label) 79 | break 80 | """ 81 | sentence: (turn_num, vocab_size) 82 | act: (turn_num, vocab_size) 83 | label: (turn_num, label_out) 84 | """ 85 | # we just ignore the sentence it is longer than the maximum 86 | # bucket size here 87 | 88 | # convert data into ndarrays for better speed during training 89 | data = [np.array(x) for i, x in enumerate(self.data)] 90 | data_act = [np.array(x) for i, x in enumerate(self.data_act)] 91 | label = [np.array(x).reshape((len(x), buckets[i], self.label_out)) for i, x in enumerate(self.label)] 92 | 93 | self.data = data 94 | self.data_act = data_act 95 | self.label = label 96 | 97 | # backup corpus 98 | self.all_data = copy.deepcopy(self.data) 99 | self.all_data_act = copy.deepcopy(self.data_act) 100 | self.all_label = copy.deepcopy(self.label) 101 | 102 | # Get the size of each bucket, so that we could sample 103 | # uniformly from the bucket 104 | bucket_sizes = [len(x) for x in self.data] 105 | print("Summary of dataset ==================") 106 | for bkt, size in zip(buckets, bucket_sizes): 107 | print("bucket of len %3d : %d samples" % (bkt, size)) 108 | 109 | self.batch_size = batch_size 110 | #self.make_data_iter_plan() 111 | 112 | self.init_states = init_states 113 | self.data_components = data_components 114 | 115 | 
self.provide_data = self.data_components + self.init_states 116 | 117 | 118 | def make_data_iter_plan(self): 119 | "make a random data iteration plan" 120 | # truncate each bucket into multiple of batch-size 121 | bucket_n_batches = [] 122 | for i in range(len(self.data)): 123 | # shuffle data before truncate 124 | index_shuffle = range(len(self.data[i])) 125 | np.random.shuffle(index_shuffle) 126 | self.data[i] = self.all_data[i][index_shuffle] 127 | self.data_act[i] = self.all_data_act[i][index_shuffle] 128 | self.label[i] = self.all_label[i][index_shuffle] 129 | 130 | bucket_n_batches.append(int(math.ceil(1.0*len(self.data[i]) / self.batch_size))) 131 | self.data[i] = self.data[i][:int(bucket_n_batches[i]*self.batch_size)] 132 | self.data_act[i] = self.data_act[i][:int(bucket_n_batches[i]*self.batch_size)] 133 | self.label[i] = self.label[i][:int(bucket_n_batches[i]*self.batch_size)] 134 | 135 | bucket_plan = np.hstack([np.zeros(n, int)+i for i, n in enumerate(bucket_n_batches)]) 136 | np.random.shuffle(bucket_plan) 137 | 138 | bucket_idx_all = [np.random.permutation(len(x)) for x in self.data] 139 | 140 | self.bucket_plan = bucket_plan 141 | self.bucket_idx_all = bucket_idx_all 142 | self.bucket_curr_idx = [0 for x in self.data] 143 | 144 | self.data_buffer = [] 145 | self.data_act_buffer = [] 146 | self.label_buffer = [] 147 | for i_bucket in range(len(self.data)): 148 | data = np.zeros((self.batch_size, self.buckets[i_bucket], len(self.vocab))) 149 | data_act = np.zeros((self.batch_size, self.buckets[i_bucket], len(self.vocab))) 150 | label = np.zeros((self.batch_size, self.buckets[i_bucket], self.label_out)) 151 | self.data_buffer.append(data) 152 | self.data_act_buffer.append(data_act) 153 | self.label_buffer.append(label) 154 | 155 | def __iter__(self): 156 | self.make_data_iter_plan() 157 | for i_bucket in self.bucket_plan: 158 | data = self.data_buffer[i_bucket] 159 | data_act = self.data_act_buffer[i_bucket] 160 | label = self.label_buffer[i_bucket] 161 | data.fill(0) 162 | data_act.fill(0) 163 | label.fill(0) 164 | 165 | i_idx = self.bucket_curr_idx[i_bucket] 166 | idx = self.bucket_idx_all[i_bucket][i_idx:i_idx+self.batch_size] 167 | self.bucket_curr_idx[i_bucket] += self.batch_size 168 | 169 | # Data parallelism 170 | data[:len(idx)] = self.data[i_bucket][idx] 171 | data_act[:len(idx)] = self.data_act[i_bucket][idx] 172 | label[:len(idx)] = self.label[i_bucket][idx] 173 | 174 | data_names = [x[0] for x in self.provide_data] 175 | init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states] 176 | data_all = [mx.nd.array(data), mx.nd.array(data_act)] 177 | data_all += init_state_arrays 178 | 179 | label_names = ['softmax_label'] 180 | label_all = [mx.nd.array(label)] 181 | 182 | pad = self.batch_size - len(idx) 183 | data_batch = SimpleBatch(data_names, data_all, label_names, label_all, self.buckets[i_bucket], pad) 184 | yield data_batch 185 | 186 | def reset(self): 187 | self.bucket_curr_idx = [0 for x in self.data] 188 | -------------------------------------------------------------------------------- /turnsent_io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import copy 4 | import json 5 | import math 6 | 7 | import numpy as np 8 | import mxnet as mx 9 | 10 | from bucket_io import SimpleBatch 11 | from bucket_io import default_text2id 12 | 13 | # The interface of a data iter that works for bucketing 14 | # 15 | # DataIter 16 | # - default_bucket_key: the bucket key for the default symbol. 
17 | # 18 | # DataBatch 19 | # - provide_data: same as DataIter, but specific to this batch 20 | # - provide_label: same as DataIter, but specific to this batch 21 | # - bucket_key: the key for the bucket that should be used for this batch 22 | 23 | def read_nbest_dialog_content(dialog, labelIdx): 24 | """Generate the samples for one dialogue. Note the shape of this function's outputs; the sketch below is only approximate, since the actual outputs are nested lists with no fixed shape: 25 | dialog_sentences: (turn_num, nbest_num, sentence_len) 26 | dialog_scores: (turn_num, nbest_num) 27 | machine_acts: (turn_num, machine_act_len) 28 | dialog_labels: (turn_num, )""" 29 | dialog_sentences, dialog_scores, machine_acts, dialog_labels = [], [], [], [] 30 | for turn in dialog["turns"]: 31 | dialog_labels.append(turn["labelIdx"][labelIdx]) 32 | 33 | machine_act = "" 34 | for saPair in turn["machine_output"]: 35 | act = saPair["act"] 36 | slots = " " 37 | for slot in saPair["slots"]: 38 | #count never appears in train/dev set# 39 | if "count" in slot: 40 | #slot[1] = str(slot[1]) 41 | continue 42 | slots += " ".join(slot) 43 | slots += " " 44 | machine_act_item=(act+slots) 45 | machine_act += machine_act_item 46 | machine_act = machine_act.strip() 47 | machine_acts.append(machine_act) 48 | 49 | nbest_sentences = [] 50 | nbest_scores = [] 51 | for asr_hyp in turn["user_input"]: 52 | if len(asr_hyp["asr-hyp"].split()) == 0: 53 | continue 54 | nbest_scores.append(asr_hyp["score"]) 55 | sentence = "" 56 | #sentence +=" #turn# " 57 | sentence += asr_hyp["asr-hyp"] 58 | #sentence += " " 59 | nbest_sentences.append(sentence) 60 | dialog_sentences.append(nbest_sentences) 61 | dialog_scores.append(nbest_scores) 62 | 63 | return dialog_sentences, dialog_scores, machine_acts, dialog_labels 64 | 65 | def turn_read_content(path, labelIdx): 66 | """Note the shape of this function's outputs; the sketch below is only approximate, since the actual outputs are nested lists with no fixed shape: 67 | sentences: (dialog_num, turn_num, nbest_num, sentence_len) 68 | scores: (dialog_num, turn_num, nbest_num) 69 | acts: (dialog_num, turn_num, machine_act_len) 70 | labels: (dialog_num, turn_num, [label_dim]) 71 | """ 72 | sentences, scores, acts, labels = [], [], [], [] 73 | with open(path) as json_file: 74 | data = json.load(json_file) 75 | for dialog in data: 76 | dialog_sentences, dialog_scores, machine_acts, dialog_labels = read_nbest_dialog_content(dialog, labelIdx) 77 | sentences.append(dialog_sentences) 78 | scores.append(dialog_scores) 79 | acts.append(machine_acts) 80 | labels.append(dialog_labels) 81 | 82 | 83 | return sentences, scores, acts, labels 84 | 85 | def text2bow(sentence, the_vocab): 86 | res = np.zeros(len(the_vocab)) 87 | words = sentence.split() 88 | for word in words: 89 | if word in the_vocab: 90 | res[the_vocab[word]] = 1 91 | return res 92 | 93 | class DSTTurnSentIter(mx.io.DataIter): 94 | """ 95 | feature_type: one of ['bowbow', 'sentsent', 'bowsent', 'sentbow'] 96 | """ 97 | def __init__(self, path, labelIdx, vocab, buckets, batch_size, max_nbest, max_sentlen, 98 | init_states, data_components, label_out=1, feature_type='bowbow'): 99 | super(DSTTurnSentIter, self).__init__() 100 | self.vocab = vocab 101 | self.padding_id = self.vocab[''] 102 | 103 | self.label_out = label_out 104 | 105 | self.max_nbest = max_nbest 106 | self.max_sentlen = max_sentlen 107 | self.feature_type = feature_type 108 | self.len_sent = self.max_sentlen if self.feature_type in ['sentsent', 'sentbow'] else len(self.vocab) 109 | self.len_act_sent = self.max_sentlen if self.feature_type in ['sentsent', 'bowsent'] else len(self.vocab) 110 | 111 | sentences, scores, acts, labels = turn_read_content(path, labelIdx) 112 | """ 113 |
sentences: (dialog_num, turn_num, nbest_num, sentence_len) 114 | scores: (dialog_num, turn_num, nbest_num) 115 | acts: (dialog_num, turn_num, machine_act_len) 116 | labels: (dialog_num, turn_num, ) 117 | """ 118 | 119 | buckets.sort() 120 | self.buckets = buckets 121 | self.data = [[] for _ in buckets] 122 | self.data_act = [[] for _ in buckets] 123 | self.data_score = [[] for _ in buckets] 124 | self.label = [[] for _ in buckets] 125 | 126 | # pre-allocate with the largest bucket for better memory sharing 127 | self.default_bucket_key = max(buckets) 128 | 129 | for i in range(len(sentences)): 130 | sentence = sentences[i] 131 | score = scores[i] 132 | act = acts[i] 133 | label = labels[i] 134 | for turn_id in range(len(sentence)): 135 | # user sentence feature 136 | for nbest_id in range(len(sentence[turn_id])): 137 | if self.feature_type in ['sentsent', 'sentbow']: 138 | sentence[turn_id][nbest_id] = default_text2id(sentence[turn_id][nbest_id], self.vocab) 139 | elif self.feature_type in ['bowsent', 'bowbow']: 140 | sentence[turn_id][nbest_id] = text2bow(sentence[turn_id][nbest_id], self.vocab) 141 | # sys act feature 142 | if self.feature_type in ['sentbow', 'bowbow']: 143 | act[turn_id] = text2bow(act[turn_id], self.vocab) 144 | elif self.feature_type in ['sentsent', 'bowsent']: 145 | act[turn_id] = default_text2id(act[turn_id], self.vocab) 146 | for i, bkt in enumerate(buckets): 147 | if bkt == len(sentence): 148 | self.data[i].append(sentence) 149 | self.data_score[i].append(score) 150 | self.data_act[i].append(act) 151 | self.label[i].append(label) 152 | break 153 | """ 154 | sentence: (turn_num, nbest_num, len_sent) 155 | score: (turn_num, nbest_num) 156 | act: (turn_num, len_act_sent) 157 | label: (turn_num, label_out) 158 | """ 159 | # we just ignore the sentence it is longer than the maximum 160 | # bucket size here 161 | 162 | # convert data into ndarrays for better speed during training 163 | data = [np.full((len(x), buckets[i], self.max_nbest, self.len_sent), self.padding_id) for i, x in enumerate(self.data)] 164 | data_act = [np.full((len(x), buckets[i], self.len_act_sent), self.padding_id) for i, x in enumerate(self.data_act)] 165 | data_score =[np.zeros((len(x), buckets[i], self.max_nbest)) for i, x in enumerate(self.data_score)] 166 | label = [np.zeros((len(x), buckets[i], self.label_out)) for i, x in enumerate(self.label)] 167 | for i_bucket in range(len(self.buckets)): 168 | for i_diag in range(len(self.data[i_bucket])): 169 | for i_turn in range(len(self.data[i_bucket][i_diag])): 170 | act = self.data_act[i_bucket][i_diag][i_turn] 171 | data_act[i_bucket][i_diag, i_turn, :len(act)] = act 172 | label[i_bucket][i_diag, i_turn, :] = self.label[i_bucket][i_diag][i_turn] 173 | # be careful that, here, max_nbest can be smaller than current turn nbest number. extra-best will be truncated. 
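# e.g., if max_nbest were 10 and a turn carried 12 ASR hypotheses, only the first 10
# hypotheses (and their scores) would be copied into data/data_score below; the rest are dropped.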
174 | for i_nbest in range(min(len(self.data[i_bucket][i_diag][i_turn]), self.max_nbest)): 175 | sentence = self.data[i_bucket][i_diag][i_turn][i_nbest] 176 | score = self.data_score[i_bucket][i_diag][i_turn][i_nbest] 177 | data[i_bucket][i_diag, i_turn, i_nbest, :len(sentence)] = sentence 178 | data_score[i_bucket][i_diag, i_turn, i_nbest] = score 179 | """ 180 | data: (bucket_num, dialog_num, bucket_size/turn_num, max_nbest, len_sent) 181 | score: (bucket_num, dialog_num, bucket_size/turn_num, max_nbest) 182 | data_act: (bucket_num, dialog_num, bucket_size/turn_num, len_act_sent) 183 | label: (bucket_num, dialog_num, bucket_size/turn_num, label_out) 184 | """ 185 | 186 | self.data = data 187 | self.data_act = data_act 188 | self.data_score = data_score 189 | self.label = label 190 | 191 | # backup corpus 192 | self.all_data = copy.deepcopy(self.data) 193 | self.all_data_act = copy.deepcopy(self.data_act) 194 | self.all_data_score = copy.deepcopy(self.data_score) 195 | self.all_label = copy.deepcopy(self.label) 196 | 197 | # Get the size of each bucket, so that we could sample 198 | # uniformly from the bucket 199 | bucket_sizes = [len(x) for x in self.data] 200 | print("Summary of dataset ==================") 201 | for bkt, size in zip(buckets, bucket_sizes): 202 | print("bucket of len %3d : %d samples" % (bkt, size)) 203 | 204 | self.batch_size = batch_size 205 | #self.make_data_iter_plan() 206 | 207 | self.init_states = init_states 208 | self.data_components = data_components 209 | 210 | self.provide_data = self.data_components + self.init_states 211 | 212 | 213 | def make_data_iter_plan(self): 214 | "make a random data iteration plan" 215 | # truncate each bucket into multiple of batch-size 216 | bucket_n_batches = [] 217 | for i in range(len(self.data)): 218 | # shuffle data before truncate 219 | index_shuffle = range(len(self.data[i])) 220 | np.random.shuffle(index_shuffle) 221 | self.data[i] = self.all_data[i][index_shuffle] 222 | self.data_act[i] = self.all_data_act[i][index_shuffle] 223 | self.data_score[i] = self.all_data_score[i][index_shuffle] 224 | self.label[i] = self.all_label[i][index_shuffle] 225 | 226 | bucket_n_batches.append(int(math.ceil(1.0*len(self.data[i]) / self.batch_size))) 227 | self.data[i] = self.data[i][:int(bucket_n_batches[i]*self.batch_size)] 228 | self.data_act[i] = self.data_act[i][:int(bucket_n_batches[i]*self.batch_size)] 229 | self.data_score[i] = self.data_score[i][:int(bucket_n_batches[i]*self.batch_size)] 230 | self.label[i] = self.label[i][:int(bucket_n_batches[i]*self.batch_size)] 231 | 232 | bucket_plan = np.hstack([np.zeros(n, int)+i for i, n in enumerate(bucket_n_batches)]) 233 | np.random.shuffle(bucket_plan) 234 | 235 | bucket_idx_all = [np.random.permutation(len(x)) for x in self.data] 236 | 237 | self.bucket_plan = bucket_plan 238 | self.bucket_idx_all = bucket_idx_all 239 | self.bucket_curr_idx = [0 for x in self.data] 240 | 241 | self.data_buffer = [] 242 | self.data_act_buffer = [] 243 | self.data_score_buffer = [] 244 | self.label_buffer = [] 245 | for i_bucket in range(len(self.data)): 246 | data = np.zeros((self.batch_size, self.buckets[i_bucket], self.max_nbest, self.len_sent)) 247 | data_act = np.zeros((self.batch_size, self.buckets[i_bucket], self.len_act_sent)) 248 | data_score = np.zeros((self.batch_size, self.buckets[i_bucket], self.max_nbest)) 249 | label = np.zeros((self.batch_size, self.buckets[i_bucket], self.label_out)) 250 | self.data_buffer.append(data) 251 | self.data_act_buffer.append(data_act) 252 | 
self.data_score_buffer.append(data_score) 253 | self.label_buffer.append(label) 254 | 255 | def __iter__(self): 256 | self.make_data_iter_plan() 257 | for i_bucket in self.bucket_plan: 258 | data = self.data_buffer[i_bucket] 259 | data_act = self.data_act_buffer[i_bucket] 260 | data_score = self.data_score_buffer[i_bucket] 261 | label = self.label_buffer[i_bucket] 262 | data.fill(self.padding_id) 263 | data_act.fill(self.padding_id) 264 | data_score.fill(0) 265 | label.fill(0) 266 | 267 | i_idx = self.bucket_curr_idx[i_bucket] 268 | idx = self.bucket_idx_all[i_bucket][i_idx:i_idx+self.batch_size] 269 | self.bucket_curr_idx[i_bucket] += self.batch_size 270 | 271 | # Data parallelism 272 | data[:len(idx)] = self.data[i_bucket][idx] 273 | data_act[:len(idx)] = self.data_act[i_bucket][idx] 274 | data_score[:len(idx)] = self.data_score[i_bucket][idx] 275 | label[:len(idx)] = self.label[i_bucket][idx] 276 | 277 | data_names = [x[0] for x in self.provide_data] 278 | init_state_arrays = [mx.nd.zeros(x[1]) for x in self.init_states] 279 | data_all = [mx.nd.array(data), mx.nd.array(data_act)] 280 | if 'score' in data_names: 281 | data_all += [mx.nd.array(data_score)] 282 | data_all += init_state_arrays 283 | 284 | label_names = ['softmax_label'] 285 | label_all = [mx.nd.array(label)] 286 | 287 | pad = self.batch_size - len(idx) 288 | data_batch = SimpleBatch(data_names, data_all, label_names, label_all, self.buckets[i_bucket], pad) 289 | yield data_batch 290 | 291 | def reset(self): 292 | self.bucket_curr_idx = [0 for x in self.data] 293 | -------------------------------------------------------------------------------- /vocab_matNN.dict: -------------------------------------------------------------------------------- 1 | (dp0 2 | S'all' 3 | p1 4 | I367 5 | sS'code' 6 | p2 7 | I531 8 | sS'chinese' 9 | p3 10 | I368 11 | sS'ali' 12 | p4 13 | I369 14 | sS'global' 15 | p5 16 | I186 17 | sS'four' 18 | p6 19 | I6 20 | sS'mediterranean' 21 | p7 22 | I542 23 | sS'asian' 24 | p8 25 | I7 26 | sS'dish' 27 | p9 28 | I370 29 | sS'mill' 30 | p10 31 | I188 32 | sS'shiraz' 33 | p11 34 | I534 35 | sS'vietnamese' 36 | p12 37 | I388 38 | sS'tv' 39 | p13 40 | I371 41 | sS'shanghai' 42 | p14 43 | I536 44 | sS'to' 45 | p15 46 | I372 47 | sS'charge' 48 | p16 49 | I190 50 | sS'does' 51 | p17 52 | I537 53 | sS'sorry' 54 | p18 55 | I8 56 | sS'town' 57 | p19 58 | I397 59 | sS'g4' 60 | p20 61 | I538 62 | sS'garden' 63 | p21 64 | I539 65 | sS'very' 66 | p22 67 | I373 68 | sS'michaelhouse' 69 | p23 70 | I191 71 | sS'every' 72 | p24 73 | I9 74 | sS'telling' 75 | p25 76 | I540 77 | sS'canapes' 78 | p26 79 | I374 80 | sS'cool' 81 | p27 82 | I192 83 | sS'requestable' 84 | p28 85 | I532 86 | sS'school' 87 | p29 88 | I10 89 | sS'venue' 90 | p30 91 | I11 92 | sS'try' 93 | p31 94 | I541 95 | sS'tang' 96 | p32 97 | I12 98 | sS'steakhouse' 99 | p33 100 | I376 101 | sS'magdalene' 102 | p34 103 | I377 104 | sS'290' 105 | p35 106 | I13 107 | sS'welcomemsg' 108 | p36 109 | I187 110 | sS"i'll" 111 | p37 112 | I378 113 | sS'tea' 114 | p38 115 | I379 116 | sS'midsummer' 117 | p39 118 | I14 119 | sS'japanese' 120 | p40 121 | I15 122 | sS'sign' 123 | p41 124 | I193 125 | sS'burger' 126 | p42 127 | I381 128 | sS'street' 129 | p43 130 | I16 131 | sS'crescent' 132 | p44 133 | I195 134 | sS'go' 135 | p45 136 | I533 137 | sS'yippee' 138 | p46 139 | I196 140 | sS'vince' 141 | p47 142 | I621 143 | sS'drinks' 144 | p48 145 | I197 146 | sS'will' 147 | p49 148 | I626 149 | sS'what' 150 | p50 151 | I382 152 | sS'reqmore' 153 | p51 154 | I198 155 | sS'addr' 
156 | p52 157 | I199 158 | sS'histon' 159 | p53 160 | I493 161 | sS'indian' 162 | p54 163 | I543 164 | sS'new' 165 | p55 166 | I17 167 | sS'method' 168 | p56 169 | I383 170 | sS'told' 171 | p57 172 | I18 173 | sS'full' 174 | p58 175 | I384 176 | sS'panasian' 177 | p59 178 | I201 179 | sS'jinling' 180 | p60 181 | I385 182 | sS'french' 183 | p61 184 | I202 185 | sS'bangkok' 186 | p62 187 | I20 188 | sS'108' 189 | p63 190 | I21 191 | sS'andrews' 192 | p64 193 | I386 194 | sS'address' 195 | p65 196 | I203 197 | sS'01733' 198 | p66 199 | I22 200 | sS'finders' 201 | p67 202 | I23 203 | sS'change' 204 | p68 205 | I204 206 | sS'great' 207 | p69 208 | I545 209 | sS'rajmahal' 210 | p70 211 | I205 212 | sS'my' 213 | p71 214 | I38 215 | sS'k' 216 | p72 217 | I24 218 | sS'turkish' 219 | p73 220 | I535 221 | sS'byname' 222 | p74 223 | I25 224 | sS'35' 225 | p75 226 | I546 227 | sS'guide' 228 | p76 229 | I494 230 | sS'pasquale' 231 | p77 232 | I206 233 | sS"i'd" 234 | p78 235 | I26 236 | sS'romanian' 237 | p79 238 | I290 239 | sS'raza' 240 | p80 241 | I389 242 | sS"i'm" 243 | p81 244 | I27 245 | sS'options' 246 | p82 247 | I390 248 | sS'king' 249 | p83 250 | I450 251 | sS'golden' 252 | p84 253 | I28 254 | sS'gastro' 255 | p85 256 | I547 257 | sS'family' 258 | p86 259 | I391 260 | sS'vin' 261 | p87 262 | I392 263 | sS'mexican' 264 | p88 265 | I409 266 | sS'chesterton' 267 | p89 268 | I497 269 | sS'select' 270 | p90 271 | I393 272 | sS'yourself' 273 | p91 274 | I254 275 | sS'use' 276 | p92 277 | I548 278 | sS'521260' 279 | p93 280 | I207 281 | sS'would' 282 | p94 283 | I29 284 | sS'fen' 285 | p95 286 | I550 287 | sS'two' 288 | p96 289 | I394 290 | sS'next' 291 | p97 292 | I551 293 | sS'call' 294 | p98 295 | I30 296 | sS'308681' 297 | p99 298 | I552 299 | sS'australian' 300 | p100 301 | I208 302 | sS'type' 303 | p101 304 | I31 305 | sS'tell' 306 | p102 307 | I32 308 | sS'today' 309 | p103 310 | I209 311 | sS'more' 312 | p104 313 | I395 314 | sS'sort' 315 | p105 316 | I521 317 | sS'entrance' 318 | p106 319 | I210 320 | sS'afford' 321 | p107 322 | I211 323 | sS"didn't" 324 | p108 325 | I216 326 | sS'it' 327 | p109 328 | I135 329 | sS'phone' 330 | p110 331 | I33 332 | sS'varsity' 333 | p111 334 | I34 335 | sS'particular' 336 | p112 337 | I396 338 | sS'ok' 339 | p113 340 | I555 341 | sS'me' 342 | p114 343 | I35 344 | sS'grill' 345 | p115 346 | I398 347 | sS'none' 348 | p116 349 | I399 350 | sS'f' 351 | p117 352 | I556 353 | sS'mm' 354 | p118 355 | I36 356 | sS'car' 357 | p119 358 | I212 359 | sS'anywhere' 360 | p120 361 | I558 362 | sS'can' 363 | p121 364 | I213 365 | sS'#food#' 366 | p122 367 | I3 368 | sS'erm' 369 | p123 370 | I37 371 | sS'bites' 372 | p124 373 | I214 374 | sS'something' 375 | p125 376 | I561 377 | sS'canthelp.exception' 378 | p126 379 | I215 380 | sS'give' 381 | p127 382 | I39 383 | sS'loch' 384 | p128 385 | I559 386 | sS'india' 387 | p129 388 | I40 389 | sS'315232' 390 | p130 391 | I217 392 | sS'foods' 393 | p131 394 | I41 395 | sS'numbers' 396 | p132 397 | I400 398 | sS'want' 399 | p133 400 | I42 401 | sS'cantonese' 402 | p134 403 | I43 404 | sS'taj' 405 | p135 406 | I562 407 | sS'swedish' 408 | p136 409 | I218 410 | sS'sir' 411 | p137 412 | I563 413 | sS'needs' 414 | p138 415 | I401 416 | sS'end' 417 | p139 418 | I44 419 | sS'expl-conf' 420 | p140 421 | I564 422 | sS'byalternatives' 423 | p141 424 | I219 425 | sS'1' 426 | p142 427 | I573 428 | sS'how' 429 | p143 430 | I45 431 | sS'trumpington' 432 | p144 433 | I220 434 | sS'galleria' 435 | p145 436 | I566 437 | sS"doesn't" 438 | p146 439 
| I72 440 | sS'pizza' 441 | p147 442 | I46 443 | sS'pub' 444 | p148 445 | I431 446 | sS'lan' 447 | p149 448 | I47 449 | sS'express' 450 | p150 451 | I468 452 | sS'okay' 453 | p151 454 | I403 455 | sS'seoul' 456 | p152 457 | I404 458 | sS'cocktail' 459 | p153 460 | I221 461 | sS'may' 462 | p154 463 | I222 464 | sS'southern' 465 | p155 466 | I586 467 | sS'hong' 468 | p156 469 | I138 470 | sS'wrong' 471 | p157 472 | I48 473 | sS'polynesian' 474 | p158 475 | I405 476 | sS'cheaper' 477 | p159 478 | I469 479 | sS'a' 480 | p160 481 | I406 482 | sS'binh' 483 | p161 484 | I568 485 | sS'light' 486 | p162 487 | I569 488 | sS'meghna' 489 | p163 490 | I223 491 | sS'inform' 492 | p164 493 | I224 494 | sS'cote' 495 | p165 496 | I470 497 | sS'so' 498 | p166 499 | I225 500 | sS'meze' 501 | p167 502 | I407 503 | sS'goodbye' 504 | p168 505 | I52 506 | sS'african' 507 | p169 508 | I226 509 | sS'order' 510 | p170 511 | I53 512 | sS'wine' 513 | p171 514 | I54 515 | sS'dojo' 516 | p172 517 | I227 518 | sS'serving' 519 | p173 520 | I228 521 | sS'singaporean' 522 | p174 523 | I55 524 | sS'help' 525 | p175 526 | I410 527 | sS"don't" 528 | p176 529 | I411 530 | sS'over' 531 | p177 532 | I56 533 | sS'northampton' 534 | p178 535 | I229 536 | sS'through' 537 | p179 538 | I412 539 | sS'austrian' 540 | p180 541 | I571 542 | sS'avenue' 543 | p181 544 | I413 545 | sS'its' 546 | p182 547 | I300 548 | sS'before' 549 | p183 550 | I57 551 | sS'24' 552 | p184 553 | I415 554 | sS'saigon' 555 | p185 556 | I58 557 | sS'20' 558 | p186 559 | I416 560 | sS'thank' 561 | p187 562 | I230 563 | sS'thanh' 564 | p188 565 | I231 566 | sS'cookhouse' 567 | p189 568 | I402 569 | sS'bloomsbury' 570 | p190 571 | I59 572 | sS'actually' 573 | p191 574 | I418 575 | sS'451' 576 | p192 577 | I591 578 | sS'cotto' 579 | p193 580 | I576 581 | sS'barnwell' 582 | p194 583 | I375 584 | sS'then' 585 | p195 586 | I61 587 | sS'them' 588 | p196 589 | I62 590 | sS'good' 591 | p197 592 | I419 593 | sS'ya' 594 | p198 595 | I315 596 | sS'uhm' 597 | p199 598 | I421 599 | sS'eat' 600 | p200 601 | I654 602 | sS'canthelp' 603 | p201 604 | I578 605 | sS'prezzo' 606 | p202 607 | I579 608 | sS'they' 609 | p203 610 | I63 611 | sS'not' 612 | p204 613 | I233 614 | sS'now' 615 | p205 616 | I234 617 | sS'yu' 618 | p206 619 | I423 620 | sS'day' 621 | p207 622 | I580 623 | sS'nor' 624 | p208 625 | I235 626 | sS'309147' 627 | p209 628 | I583 629 | sS'name' 630 | p210 631 | I236 632 | sS'peking' 633 | p211 634 | I582 635 | sS'01799' 636 | p212 637 | I237 638 | sS'luca' 639 | p213 640 | I66 641 | sS'323178' 642 | p214 643 | I293 644 | sS'244955' 645 | p215 646 | I425 647 | sS'out' 648 | p216 649 | I584 650 | sS'side' 651 | p217 652 | I67 653 | sS'luck' 654 | p218 655 | I68 656 | sS'house' 657 | p219 658 | I426 659 | sS'yeah' 660 | p220 661 | I238 662 | sS'persian' 663 | p221 664 | I239 665 | sS'year' 666 | p222 667 | I240 668 | sS'er' 669 | p223 670 | I241 671 | sS'god' 672 | p224 673 | I69 674 | sS'looking' 675 | p225 676 | I242 677 | sS're' 678 | p226 679 | I70 680 | sS'oriental' 681 | p227 682 | I585 683 | sS'hill' 684 | p228 685 | I428 686 | sS'7' 687 | p229 688 | I429 689 | sS'internet' 690 | p230 691 | I243 692 | sS'got' 693 | p231 694 | I71 695 | sS'pizzeria' 696 | p232 697 | I244 698 | sS'correct' 699 | p233 700 | I245 701 | sS'191' 702 | p234 703 | I380 704 | sS'bistro' 705 | p235 706 | I430 707 | sS'thong' 708 | p236 709 | I246 710 | sS'free' 711 | p237 712 | I73 713 | sS'standard' 714 | p238 715 | I74 716 | sS'wagamama' 717 | p239 718 | I587 719 | sS'plea' 720 | p240 
721 | I247 722 | sS'ask' 723 | p241 724 | I432 725 | sS'care' 726 | p242 727 | I248 728 | sS'yard' 729 | p243 730 | I588 731 | sS'could' 732 | p244 733 | I589 734 | sS'kohinoor' 735 | p245 736 | I75 737 | sS'british' 738 | p246 739 | I249 740 | sS'americas' 741 | p247 742 | I433 743 | sS'maharajah' 744 | p248 745 | I250 746 | sS'gourmet' 747 | p249 748 | I442 749 | sS'american' 750 | p250 751 | I434 752 | sS'place' 753 | p251 754 | I251 755 | sS'signature' 756 | p252 757 | I76 758 | sS'castle' 759 | p253 760 | I435 761 | sS'retail' 762 | p254 763 | I49 764 | sS'think' 765 | p255 766 | I252 767 | sS'south' 768 | p256 769 | I590 770 | sS'first' 771 | p257 772 | I253 773 | sS'323737' 774 | p258 775 | I77 776 | sS'rang' 777 | p259 778 | I78 779 | sS'zizzi' 780 | p260 781 | I60 782 | sS'number' 783 | p261 784 | I437 785 | sS'one' 786 | p262 787 | I65 788 | sS'lettuce' 789 | p263 790 | I79 791 | sS'done' 792 | p264 793 | I438 794 | sS'long' 795 | p265 796 | I515 797 | sS'another' 798 | p266 799 | I80 800 | sS'spanish' 801 | p267 802 | I256 803 | sS'millers' 804 | p268 805 | I81 806 | sS'sounds' 807 | p269 808 | I257 809 | sS'fitzbillies' 810 | p270 811 | I440 812 | sS'city' 813 | p271 814 | I258 815 | sS'little' 816 | p272 817 | I259 818 | sS'guest' 819 | p273 820 | I441 821 | sS'553355' 822 | p274 823 | I260 824 | sS'st.' 825 | p275 826 | I443 827 | sS'moderately' 828 | p276 829 | I444 830 | sS'system' 831 | p277 832 | I593 833 | sS'their' 834 | p278 835 | I594 836 | sS'566388' 837 | p279 838 | I445 839 | sS'2' 840 | p280 841 | I261 842 | sS'too' 843 | p281 844 | I82 845 | sS'saint' 846 | p282 847 | I446 848 | sS'molecular' 849 | p283 850 | I627 851 | sS'danish' 852 | p284 853 | I83 854 | sS'that' 855 | p285 856 | I262 857 | sS'ditton' 858 | p286 859 | I263 860 | sS'hum' 861 | p287 862 | I264 863 | sS'hotel' 864 | p288 865 | I447 866 | sS'serve' 867 | p289 868 | I84 869 | sS'hut' 870 | p290 871 | I265 872 | sS'#turn#' 873 | p291 874 | I2 875 | sS'eritrean' 876 | p292 877 | I449 878 | sS'western' 879 | p293 880 | I85 881 | sS'tuscan' 882 | p294 883 | I595 884 | sS'thai' 885 | p295 886 | I266 887 | sS'milton' 888 | p296 889 | I267 890 | sS'10' 891 | p297 892 | I268 893 | sS'kind' 894 | p298 895 | I451 896 | sS'b' 897 | p299 898 | I452 899 | sS'15' 900 | p300 901 | I269 902 | sS'17' 903 | p301 904 | I270 905 | sS'scandinavian' 906 | p302 907 | I454 908 | sS'matter' 909 | p303 910 | I86 911 | sS'cost' 912 | p304 913 | I194 914 | sS'budget' 915 | p305 916 | I491 917 | sS'alimentum' 918 | p306 919 | I455 920 | sS'and' 921 | p307 922 | I271 923 | sS'bridge' 924 | p308 925 | I87 926 | sS'palace' 927 | p309 928 | I88 929 | sS'huntingdon' 930 | p310 931 | I89 932 | sS'modern' 933 | p311 934 | I90 935 | sS'mind' 936 | p312 937 | I91 938 | sS'crossover' 939 | p313 940 | I272 941 | sS'nirala' 942 | p314 943 | I273 944 | sS'graffiti' 945 | p315 946 | I456 947 | sS'halal' 948 | p316 949 | I597 950 | sS'have' 951 | p317 952 | I598 953 | sS'close' 954 | p318 955 | I477 956 | sS'need' 957 | p319 958 | I599 959 | sS'sells' 960 | p320 961 | I274 962 | sS'any' 963 | p321 964 | I275 965 | sS'these' 966 | p322 967 | I308 968 | sS'greek' 969 | p323 970 | I50 971 | sS'-' 972 | p324 973 | I92 974 | sS'also' 975 | p325 976 | I458 977 | sS'high' 978 | p326 979 | I560 980 | sS'take' 981 | p327 982 | I276 983 | sS'which' 984 | p328 985 | I600 986 | sS'confirm-domain' 987 | p329 988 | I93 989 | sS'green' 990 | p330 991 | I51 992 | sS'korean' 993 | p331 994 | I601 995 | sS'sure' 996 | p332 997 | I277 998 | sS'205' 999 | p333 
1000 | I278 1001 | sS'though' 1002 | p334 1003 | I94 1004 | sS'park' 1005 | p335 1006 | I459 1007 | sS'price' 1008 | p336 1009 | I279 1010 | sS'takeaway' 1011 | p337 1012 | I570 1013 | sS'reach' 1014 | p338 1015 | I460 1016 | sS'victoria' 1017 | p339 1018 | I95 1019 | sS"where's" 1020 | p340 1021 | I171 1022 | sS'zealand' 1023 | p341 1024 | I96 1025 | sS'cuban' 1026 | p342 1027 | I280 1028 | sS'don' 1029 | p343 1030 | I97 1031 | sS'pipasha' 1032 | p344 1033 | I281 1034 | sS'sala' 1035 | p345 1036 | I282 1037 | sS'seafood' 1038 | p346 1039 | I283 1040 | sS'duckling' 1041 | p347 1042 | I98 1043 | sS'm' 1044 | p348 1045 | I99 1046 | sS'traditional' 1047 | p349 1048 | I461 1049 | sS"you'll" 1050 | p350 1051 | I607 1052 | sS'354755' 1053 | p351 1054 | I284 1055 | sS"that's" 1056 | p352 1057 | I408 1058 | sS'shop' 1059 | p353 1060 | I285 1061 | sS'german' 1062 | p354 1063 | I286 1064 | sS'queen' 1065 | p355 1066 | I100 1067 | sS'cheap' 1068 | p356 1069 | I287 1070 | sS'moroccan' 1071 | p357 1072 | I462 1073 | sS'2g' 1074 | p358 1075 | I424 1076 | sS'postcode' 1077 | p359 1078 | I608 1079 | sS'corner' 1080 | p360 1081 | I288 1082 | sS'fine' 1083 | p361 1084 | I463 1085 | sS'find' 1086 | p362 1087 | I464 1088 | sS'slot' 1089 | p363 1090 | I289 1091 | sS'chiquito' 1092 | p364 1093 | I465 1094 | sS'ha' 1095 | p365 1096 | I653 1097 | sS'northern' 1098 | p366 1099 | I466 1100 | sS'menu' 1101 | p367 1102 | I101 1103 | sS'should' 1104 | p368 1105 | I609 1106 | sS'only' 1107 | p369 1108 | I292 1109 | sS'pretty' 1110 | p370 1111 | I467 1112 | sS'lodge' 1113 | p371 1114 | I102 1115 | sS'rice' 1116 | p372 1117 | I103 1118 | sS'do' 1119 | p373 1120 | I104 1121 | sS'mimosa' 1122 | p374 1123 | I604 1124 | sS'hungarian' 1125 | p375 1126 | I294 1127 | sS'get' 1128 | p376 1129 | I295 1130 | sS'de' 1131 | p377 1132 | I105 1133 | sS'stop' 1134 | p378 1135 | I567 1136 | sS'lucky' 1137 | p379 1138 | I610 1139 | sS'leisure' 1140 | p380 1141 | I106 1142 | sS'da' 1143 | p381 1144 | I107 1145 | sS'cannot' 1146 | p382 1147 | I296 1148 | sS'international' 1149 | p383 1150 | I200 1151 | sS'mahal' 1152 | p384 1153 | I453 1154 | sS'du' 1155 | p385 1156 | I574 1157 | sS'byconstraints' 1158 | p386 1159 | I108 1160 | sS'areas' 1161 | p387 1162 | I611 1163 | sS'bar' 1164 | p388 1165 | I109 1166 | sS'413000' 1167 | p389 1168 | I110 1169 | sS'malaysian' 1170 | p390 1171 | I297 1172 | sS'bad' 1173 | p391 1174 | I111 1175 | sS'clowns' 1176 | p392 1177 | I298 1178 | sS'river' 1179 | p393 1180 | I472 1181 | sS'where' 1182 | p394 1183 | I299 1184 | sS'restaurants' 1185 | p395 1186 | I473 1187 | sS'eraina' 1188 | p396 1189 | I112 1190 | sS'see' 1191 | p397 1192 | I474 1193 | sS'are' 1194 | p398 1195 | I475 1196 | sS'sea' 1197 | p399 1198 | I476 1199 | sS'polish' 1200 | p400 1201 | I322 1202 | sS'brasserie' 1203 | p401 1204 | I113 1205 | sS'venetian' 1206 | p402 1207 | I656 1208 | sS'please' 1209 | p403 1210 | I478 1211 | sS'327908' 1212 | p404 1213 | I414 1214 | sS'3' 1215 | p405 1216 | I291 1217 | sS"there's" 1218 | p406 1219 | I114 1220 | sS'wok' 1221 | p407 1222 | I479 1223 | sS'copper' 1224 | p408 1225 | I436 1226 | sS'barbeque' 1227 | p409 1228 | I612 1229 | sS'perfect' 1230 | p410 1231 | I572 1232 | sS'we' 1233 | p411 1234 | I115 1235 | sS'latin' 1236 | p412 1237 | I480 1238 | sS'missing' 1239 | p413 1240 | I481 1241 | sS'fyne' 1242 | p414 1243 | I301 1244 | sS'courtyard' 1245 | p415 1246 | I613 1247 | sS'coffee' 1248 | p416 1249 | I482 1250 | sS'here' 1251 | p417 1252 | I19 1253 | sS'travellers' 1254 | p418 1255 | I116 1256 | 
sS'informable' 1257 | p419 1258 | I117 1259 | sS'kitchen' 1260 | p420 1261 | I118 1262 | sS'newnham' 1263 | p421 1264 | I483 1265 | sS'toward' 1266 | p422 1267 | I484 1268 | sS'cow' 1269 | p423 1270 | I119 1271 | sS'restaurant' 1272 | p424 1273 | I485 1274 | sS'la' 1275 | p425 1276 | I575 1277 | sS'darrys' 1278 | p426 1279 | I614 1280 | sS'kosher' 1281 | p427 1282 | I615 1283 | sS'gastropub' 1284 | p428 1285 | I635 1286 | sS's' 1287 | p429 1288 | I302 1289 | sS'rose' 1290 | p430 1291 | I324 1292 | sS"can't" 1293 | p431 1294 | I303 1295 | sS'uno' 1296 | p432 1297 | I120 1298 | sS'others' 1299 | p433 1300 | I544 1301 | sS'152' 1302 | p434 1303 | I180 1304 | sS'basque' 1305 | p435 1306 | I122 1307 | sS'boat' 1308 | p436 1309 | I305 1310 | sS'unusual' 1311 | p437 1312 | I335 1313 | sS'kymmoy' 1314 | p438 1315 | I616 1316 | sS'west' 1317 | p439 1318 | I306 1319 | sS'prince' 1320 | p440 1321 | I502 1322 | sS'three' 1323 | p441 1324 | I123 1325 | sS'tiny' 1326 | p442 1327 | I124 1328 | sS'beer' 1329 | p443 1330 | I125 1331 | sS'much' 1332 | p444 1333 | I126 1334 | sS'01223' 1335 | p445 1336 | I618 1337 | sS'ah' 1338 | p446 1339 | I358 1340 | sS'kettle' 1341 | p447 1342 | I643 1343 | sS'vinci' 1344 | p448 1345 | I619 1346 | sS'eastern' 1347 | p449 1348 | I127 1349 | sS'chan' 1350 | p450 1351 | I522 1352 | sS"what's" 1353 | p451 1354 | I487 1355 | sS'else' 1356 | p452 1357 | I488 1358 | sS'hmm' 1359 | p453 1360 | I622 1361 | sS'finished' 1362 | p454 1363 | I623 1364 | sS'prices' 1365 | p455 1366 | I489 1367 | sS'sound' 1368 | p456 1369 | I307 1370 | sS'lebanese' 1371 | p457 1372 | I625 1373 | sS'bye' 1374 | p458 1375 | I577 1376 | sS'look' 1377 | p459 1378 | I490 1379 | sS'400170' 1380 | p460 1381 | I129 1382 | sS'panahar' 1383 | p461 1384 | I130 1385 | sS'cash' 1386 | p462 1387 | I309 1388 | sS'sesame' 1389 | p463 1390 | I492 1391 | sS'ugly' 1392 | p464 1393 | I131 1394 | sS'near' 1395 | p465 1396 | I132 1397 | sS'vegetarian' 1398 | p466 1399 | I387 1400 | sS'seven' 1401 | p467 1402 | I133 1403 | sS'#name#' 1404 | p468 1405 | I4 1406 | sS'is' 1407 | p469 1408 | I134 1409 | sS'telephone' 1410 | p470 1411 | I311 1412 | sS'sitar' 1413 | p471 1414 | I495 1415 | sS'middle' 1416 | p472 1417 | I312 1418 | sS'in' 1419 | p473 1420 | I136 1421 | sS'tandoori' 1422 | p474 1423 | I457 1424 | sS'if' 1425 | p475 1426 | I137 1427 | sS'different' 1428 | p476 1429 | I313 1430 | sS'lankan' 1431 | p477 1432 | I628 1433 | sS'pay' 1434 | p478 1435 | I314 1436 | sS'food' 1437 | p479 1438 | I420 1439 | sS'caffe' 1440 | p480 1441 | I139 1442 | sS'eclectic' 1443 | p481 1444 | I422 1445 | sS'dontcare' 1446 | p482 1447 | I496 1448 | sS'european' 1449 | p483 1450 | I140 1451 | sS'welsh' 1452 | p484 1453 | I141 1454 | sS'grafton' 1455 | p485 1456 | I629 1457 | sS'backstreet' 1458 | p486 1459 | I498 1460 | sS'353110' 1461 | p487 1462 | I499 1463 | sS'812660' 1464 | p488 1465 | I317 1466 | sS'cocum' 1467 | p489 1468 | I318 1469 | sS'afghan' 1470 | p490 1471 | I630 1472 | sS'' 1473 | p491 1474 | I0 1475 | sS'slug' 1476 | p492 1477 | I644 1478 | sS'noodle' 1479 | p493 1480 | I500 1481 | sS'catalan' 1482 | p494 1483 | I505 1484 | sS'i' 1485 | p495 1486 | I631 1487 | sS'charlie' 1488 | p496 1489 | I142 1490 | sS'well' 1491 | p497 1492 | I632 1493 | sS'21' 1494 | p498 1495 | I417 1496 | sS'english' 1497 | p499 1498 | I633 1499 | sS'swiss' 1500 | p500 1501 | I316 1502 | sS'the' 1503 | p501 1504 | I143 1505 | sS'frankie' 1506 | p502 1507 | I320 1508 | sS'just' 1509 | p503 1510 | I144 1511 | sS'from' 1512 | p504 1513 | I549 1514 | 
sS'being' 1515 | p505 1516 | I321 1517 | sS'oak' 1518 | p506 1519 | I658 1520 | sS'rest' 1521 | p507 1522 | I471 1523 | sS'schools' 1524 | p508 1525 | I64 1526 | sS'thanks' 1527 | p509 1528 | I145 1529 | sS'part' 1530 | p510 1531 | I448 1532 | sS'yep' 1533 | p511 1534 | I146 1535 | sS'yes' 1536 | p512 1537 | I147 1538 | sS'' 1539 | p513 1540 | I1 1541 | sS'yet' 1542 | p514 1543 | I148 1544 | sS'hotpot' 1545 | p515 1546 | I503 1547 | sS'hills' 1548 | p516 1549 | I149 1550 | sS'thinking' 1551 | p517 1552 | I323 1553 | sS'cherry' 1554 | p518 1555 | I150 1556 | sS'regent' 1557 | p519 1558 | I325 1559 | sS'royal' 1560 | p520 1561 | I121 1562 | sS'4' 1563 | p521 1564 | I326 1565 | sS'cucina' 1566 | p522 1567 | I151 1568 | sS'riverside' 1569 | p523 1570 | I636 1571 | sS'east' 1572 | p524 1573 | I128 1574 | sS'indonesian' 1575 | p525 1576 | I504 1577 | sS'351880' 1578 | p526 1579 | I327 1580 | sS'fusion' 1581 | p527 1582 | I152 1583 | sS'five' 1584 | p528 1585 | I637 1586 | sS'know' 1587 | p529 1588 | I638 1589 | sS'312598' 1590 | p530 1591 | I153 1592 | sS'world' 1593 | p531 1594 | I501 1595 | sS'margherita' 1596 | p532 1597 | I154 1598 | sS'like' 1599 | p533 1600 | I506 1601 | sS'serves' 1602 | p534 1603 | I328 1604 | sS'gandhi' 1605 | p535 1606 | I155 1607 | sS'night' 1608 | p536 1609 | I156 1610 | sS'served' 1611 | p537 1612 | I329 1613 | sS'tower' 1614 | p538 1615 | I330 1616 | sS'italian' 1617 | p539 1618 | I639 1619 | sS'portuguese' 1620 | p540 1621 | I157 1622 | sS'right' 1623 | p541 1624 | I158 1625 | sS'old' 1626 | p542 1627 | I159 1628 | sS'hakka' 1629 | p543 1630 | I160 1631 | sS'scottish' 1632 | p544 1633 | I641 1634 | sS'some' 1635 | p545 1636 | I507 1637 | sS'back' 1638 | p546 1639 | I508 1640 | sS'homerton' 1641 | p547 1642 | I161 1643 | sS'anatolia' 1644 | p548 1645 | I162 1646 | sS'caribbean' 1647 | p549 1648 | I331 1649 | sS'for' 1650 | p550 1651 | I163 1652 | sS'#slot#' 1653 | p551 1654 | I5 1655 | sS'centre' 1656 | p552 1657 | I602 1658 | sS'350688' 1659 | p553 1660 | I332 1661 | sS'creative' 1662 | p554 1663 | I189 1664 | sS'everything' 1665 | p555 1666 | I164 1667 | sS'asking' 1668 | p556 1669 | I165 1670 | sS'expensive' 1671 | p557 1672 | I486 1673 | sS'362456' 1674 | p558 1675 | I333 1676 | sS'moderate' 1677 | p559 1678 | I334 1679 | sS'christmas' 1680 | p560 1681 | I166 1682 | sS't' 1683 | p561 1684 | I509 1685 | sS'be' 1686 | p562 1687 | I510 1688 | sS'who' 1689 | p563 1690 | I603 1691 | sS'corn' 1692 | p564 1693 | I167 1694 | sS'efes' 1695 | p565 1696 | I168 1697 | sS'newmarket' 1698 | p566 1699 | I169 1700 | sS'thailand' 1701 | p567 1702 | I336 1703 | sS'post' 1704 | p568 1705 | I170 1706 | sS'by' 1707 | p569 1708 | I511 1709 | sS'on' 1710 | p570 1711 | I337 1712 | sS'about' 1713 | p571 1714 | I645 1715 | sS'central' 1716 | p572 1717 | I338 1718 | sS'anything' 1719 | p573 1720 | I512 1721 | sS'oh' 1722 | p574 1723 | I339 1724 | sS'of' 1725 | p575 1726 | I340 1727 | sS'am' 1728 | p576 1729 | I620 1730 | sS'sri' 1731 | p577 1732 | I341 1733 | sS'chop' 1734 | p578 1735 | I342 1736 | sS'range' 1737 | p579 1738 | I513 1739 | sS'afternoon' 1740 | p580 1741 | I172 1742 | sS"it's" 1743 | p581 1744 | I617 1745 | sS'mean' 1746 | p582 1747 | I343 1748 | sS'curry' 1749 | p583 1750 | I173 1751 | sS'or' 1752 | p584 1753 | I344 1754 | sS'road' 1755 | p585 1756 | I345 1757 | sS'244277' 1758 | p586 1759 | I174 1760 | sS'this' 1761 | p587 1762 | I557 1763 | sS'cambridge' 1764 | p588 1765 | I346 1766 | sS'johns' 1767 | p589 1768 | I232 1769 | sS'clifton' 1770 | p590 1771 | I646 1772 | 
sS'down' 1773 | p591 1774 | I175 1775 | sS'gastronomy' 1776 | p592 1777 | I581 1778 | sS'because' 1779 | p593 1780 | I640 1781 | sS'been' 1782 | p594 1783 | I647 1784 | sS'your' 1785 | p595 1786 | I347 1787 | sS'corsica' 1788 | p596 1789 | I348 1790 | sS'351707' 1791 | p597 1792 | I596 1793 | sS'360966' 1794 | p598 1795 | I349 1796 | sS'her' 1797 | p599 1798 | I350 1799 | sS'area' 1800 | p600 1801 | I351 1802 | sS'there' 1803 | p601 1804 | I352 1805 | sS'sock' 1806 | p602 1807 | I514 1808 | sS'fast' 1809 | p603 1810 | I255 1811 | sS'start' 1812 | p604 1813 | I353 1814 | sS'727410' 1815 | p605 1816 | I648 1817 | sS'way' 1818 | p606 1819 | I176 1820 | sS'brazilian' 1821 | p607 1822 | I354 1823 | sS'was' 1824 | p608 1825 | I177 1826 | sS'north' 1827 | p609 1828 | I649 1829 | sS'offer' 1830 | p610 1831 | I178 1832 | sS'but' 1833 | p611 1834 | I650 1835 | sS'hk' 1836 | p612 1837 | I651 1838 | sS'hi' 1839 | p613 1840 | I652 1841 | sS'hear' 1842 | p614 1843 | I179 1844 | sS'russian' 1845 | p615 1846 | I516 1847 | sS'trying' 1848 | p616 1849 | I355 1850 | sS'with' 1851 | p617 1852 | I356 1853 | sS'hinton' 1854 | p618 1855 | I517 1856 | sS'spice' 1857 | p619 1858 | I363 1859 | sS'j' 1860 | p620 1861 | I655 1862 | sS'um' 1863 | p621 1864 | I518 1865 | sS'jamaican' 1866 | p622 1867 | I357 1868 | sS'uh' 1869 | p623 1870 | I519 1871 | sS'tasca' 1872 | p624 1873 | I605 1874 | sS'belgian' 1875 | p625 1876 | I553 1877 | sS'called' 1878 | p626 1879 | I520 1880 | sS'baba' 1881 | p627 1882 | I554 1883 | sS'irish' 1884 | p628 1885 | I657 1886 | sS'bedouin' 1887 | p629 1888 | I304 1889 | sS'located' 1890 | p630 1891 | I565 1892 | sS'154' 1893 | p631 1894 | I181 1895 | sS'nandos' 1896 | p632 1897 | I182 1898 | sS'australasian' 1899 | p633 1900 | I359 1901 | sS'an' 1902 | p634 1903 | I624 1904 | sS'as' 1905 | p635 1906 | I360 1907 | sS'pricerange' 1908 | p636 1909 | I310 1910 | sS'at' 1911 | p637 1912 | I361 1913 | sS'request' 1914 | p638 1915 | I606 1916 | sS'cafe' 1917 | p639 1918 | I523 1919 | sS'again' 1920 | p640 1921 | I362 1922 | sS'no' 1923 | p641 1924 | I319 1925 | sS'when' 1926 | p642 1927 | I183 1928 | sS'saffron' 1929 | p643 1930 | I524 1931 | sS'tight' 1932 | p644 1933 | I634 1934 | sS'other' 1935 | p645 1936 | I659 1937 | sS'5' 1938 | p646 1939 | I364 1940 | sS'c.b' 1941 | p647 1942 | I528 1943 | sS'you' 1944 | p648 1945 | I365 1946 | sS'really' 1947 | p649 1948 | I427 1949 | sS"you're" 1950 | p650 1951 | I592 1952 | sS'repeat' 1953 | p651 1954 | I660 1955 | sS'star' 1956 | p652 1957 | I642 1958 | sS'gardenia' 1959 | p653 1960 | I525 1961 | sS'bennys' 1962 | p654 1963 | I184 1964 | sS'priced' 1965 | p655 1966 | I661 1967 | sS'lane' 1968 | p656 1969 | I526 1970 | sS'stazione' 1971 | p657 1972 | I662 1973 | sS'e' 1974 | p658 1975 | I527 1976 | sS'allright' 1977 | p659 1978 | I439 1979 | sS'u' 1980 | p660 1981 | I366 1982 | sS'time' 1983 | p661 1984 | I185 1985 | sS'far' 1986 | p662 1987 | I529 1988 | sS'hello' 1989 | p663 1990 | I530 1991 | s. --------------------------------------------------------------------------------
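The dump above is the tail of one of the repository's vocabulary dictionary files: a protocol-0 (text) pickle of a token-to-index mapping, where each `sS'...'` / `I...` pair is a string key followed by its integer id and the final `s.` closes the dictionary. A minimal sketch of loading such a file and mapping a tokenised sentence to ids is shown below; the file path, and the fallback of unknown tokens to index 0, are illustrative assumptions rather than something the dump itself specifies.

```python
# -*- coding: utf-8 -*-
import pickle

# Assumption: the .dict file is a protocol-0 pickle of {token: int_id},
# as the sS'...' / I... opcode pairs in the dump suggest.
with open("vocab.dict", "rb") as f:      # hypothetical path
    vocab = pickle.load(f)

def sentence_to_ids(sentence, vocab, unk_id=0):
    """Map a whitespace-tokenised sentence to integer ids.

    Unknown tokens fall back to `unk_id` (an assumed convention here).
    """
    return [vocab.get(w, unk_id) for w in sentence.split() if w]

# Example usage: look up a single token and encode a short utterance.
print(vocab.get("cheap"))
print(sentence_to_ids("cheap italian food in the centre", vocab))
```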