├── .gitignore ├── README.md ├── __init__.py ├── data ├── dev.iob └── train.iob ├── data_utils.py ├── images └── model.png ├── main.py ├── model.py ├── test.ipynb └── weight └── model.pkl /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SDEN-Pytorch 2 | 3 | PyTorch implementation of [Sequential Dialogue Context Modeling for Spoken Language 4 | Understanding](https://arxiv.org/pdf/1705.03455.pdf) 5 | 6 | ![Model](images/model.png "SDEN") 7 | 8 | 9 | ## Requirements 10 | 11 | ``` 12 | pytorch==0.4 13 | nltk==3.5.1 14 | sklearn_crfsuite 15 | ``` 16 | 17 | ## Run 18 | 19 | ``` 20 | python3 main.py 21 | ``` 22 | 23 | ## Data 24 | 25 | I modified the [Stanford Multi-turn dataset](https://nlp.stanford.edu/blog/a-new-multi-turn-multi-domain-task-oriented-dialogue-dataset/) to fit this model. *As a result, it contains some noise, especially in the slot tags.* 26 | It covers three domains: `Weather`, `Schedule`, and `Navigate`. I recombined dialogues to create multi-domain dialogues and converted the annotations to the BIO format. 27 | 28 | ### Samples 29 | 30 | #### Single-domain dialogue 31 | 32 | ``` 33 | User : Will it be hot in Inglewood over the next few days? 34 | BOT : It will be warm both Monday and Tuesday in Inglewood. 35 | User : Thank you very much. 36 | BOT : You're welcome. Hope you have a great day. 37 | ``` 38 | 39 | #### Multi-domain dialogue 40 | 41 | ``` 42 | User : is it going to be raining this weekend 43 | BOT : What city are you inquiring about? 44 | User : Alhambra please. 45 | BOT : It will be raining on Saturday and hailing on Sunday in Alhambra. 46 | User : Thanks. 47 | BOT : happy to help 48 | User : I need a gas station 49 | BOT : I have one gas station listed. Want more info? 50 | User : What is the address? 
51 | BOT : 76 is at 91 El Camino Real. 52 | User : Thank you! 53 | BOT : You're welcome, stay safe. 54 | ``` 55 | 56 | 57 | ## Devset Result 58 | 59 | `Intent Detection : 0.9503091367071216 (Accuracy)` 60 | 61 | 62 | `Slot Extraction` 63 | 64 | | | precision| recall | f1-score | support | 65 | |---------------------|----------|--------|----------|---------| 66 | | B-agenda |0.256 |0.278 |0.267 |36 | 67 | | I-agenda |0.733 |0.407 |0.524 |54 | 68 | | B-date |0.826 |0.836 |0.831 |911 | 69 | | I-date |0.533 |0.885 |0.665 |549 | 70 | | B-distance |0.624 |0.674 |0.648 |487 | 71 | | I-distance |0.424 |0.353 |0.386 |167 | 72 | | B-event |0.813 |0.793 |0.803 |517 | 73 | | I-event |0.637 |0.847 |0.727 |367 | 74 | | B-location |0.718 |0.928 |0.809 |572 | 75 | | I-location |0.384 |0.950 |0.547 |280 | 76 | | B-party |0.298 |0.807 |0.435 |187 | 77 | | I-party |0.471 |0.471 |0.471 |17 | 78 | | B-poi_type |0.790 |0.738 |0.763 |534 | 79 | | I-poi_type |0.528 |0.718 |0.608 |301 | 80 | | B-room |1.000 |0.400 |0.571 |35 | 81 | | I-room |0.683 |0.848 |0.757 |33 | 82 | | B-time |0.496 |0.595 |0.541 |220 | 83 | | I-time |0.129 |0.286 |0.178 |14 | 84 | | B-traffic_info |0.661 |0.527 |0.587 |237 | 85 | | I-traffic_info |0.749 |0.636 |0.688 |272 | 86 | | B-weather_attribute |0.904 |0.877 |0.890 |546 | 87 | | I-weather_attribute |0.954 |0.775 |0.855 |80 | 88 | | avg / total |0.683 |0.775 |0.712 |6416 | 89 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | from torch.autograd import Variable 8 | import numpy as np 9 | from konlpy.tag import Mecab 10 | from copy import deepcopy 11 | tagger = Mecab() 12 | flatten = lambda l: [item for sublist in l for item in sublist] 13 | 
from data_utils import * 14 | from model import SDEN 15 | import pickle 16 | 17 | THIS_PATH = os.path.dirname(os.path.abspath(__file__)) 18 | 19 | class ContextNLU: 20 | def __init__(self): 21 | self.word2index = pickle.load(open(THIS_PATH+'/vocab.pkl','rb')) 22 | slot2index = pickle.load(open(THIS_PATH+'/slot.pkl','rb')) 23 | intent2index = pickle.load(open(THIS_PATH+'/intent.pkl','rb')) 24 | self.index2intent = {v:k for k,v in intent2index.items()} 25 | self.index2slot = {v:k for k,v in slot2index.items()} 26 | self.model = SDEN(len(self.word2index),100,64,len(slot2index),len(intent2index)) 27 | self.model.load_state_dict(torch.load(THIS_PATH+'/sden.pkl')) 28 | self.model.eval() 29 | self.history=[Variable(torch.LongTensor([2])).view(1,-1)] 30 | def reset(self): 31 | self.history=[Variable(torch.LongTensor([2])).view(1,-1)] 32 | 33 | def predict(self,current): 34 | current = tagger.morphs(current) 35 | current = prepare_sequence(current,self.word2index).view(1,-1) 36 | history = pad_to_history(self.history,self.word2index) 37 | s,i = self.model(history,current) 38 | slot_p = s.max(1)[1] 39 | intent_p = i.max(1)[1] 40 | slot = [self.index2slot[s] for s in slot_p.data.tolist()] 41 | intent = self.index2intent[intent_p.data[0]] 42 | 43 | if len(self.history)==[Variable(torch.LongTensor([2])).view(1,-1)]: 44 | self.history.pop() 45 | self.history.append(current) 46 | 47 | return slot, intent 48 | 49 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from copy import deepcopy 6 | import random 7 | from tqdm import tqdm 8 | flatten = lambda l: [item for sublist in l for item in sublist] 9 | 10 | 11 | def prepare_dataset(path,built_vocab=None,user_only=False): 12 | data = open(path,"r",encoding="utf-8").readlines() 13 | p_data=[] 14 | 
history=[[""]] 15 | for d in data: 16 | if d=="\n": 17 | history=[[""]] 18 | continue 19 | dd = d.replace("\n","").split("|||") 20 | if len(dd)==1: 21 | if user_only: 22 | pass 23 | else: 24 | bot = dd[0].split() 25 | history.append(bot) 26 | else: 27 | user = dd[0].split() 28 | tag = dd[1].split() 29 | intent = dd[2] 30 | temp = deepcopy(history) 31 | p_data.append([temp,user,tag,intent]) 32 | history.append(user) 33 | 34 | if built_vocab is None: 35 | historys, currents, slots, intents = list(zip(*p_data)) 36 | vocab = list(set(flatten(currents))) 37 | slot_vocab = list(set(flatten(slots))) 38 | intent_vocab = list(set(intents)) 39 | 40 | word2index={"" : 0, "" : 1, "" : 2, "" : 3, "" : 4} 41 | for vo in vocab: 42 | if word2index.get(vo)==None: 43 | word2index[vo] = len(word2index) 44 | 45 | slot2index={"" : 0} 46 | for vo in slot_vocab: 47 | if slot2index.get(vo)==None: 48 | slot2index[vo] = len(slot2index) 49 | 50 | intent2index={} 51 | for vo in intent_vocab: 52 | if intent2index.get(vo)==None: 53 | intent2index[vo] = len(intent2index) 54 | else: 55 | word2index, slot2index, intent2index = built_vocab 56 | 57 | for t in tqdm(p_data): 58 | for i,history in enumerate(t[0]): 59 | t[0][i] = prepare_sequence(history, word2index).view(1, -1) 60 | 61 | t[1] = prepare_sequence(t[1], word2index).view(1, -1) 62 | t[2] = prepare_sequence(t[2], slot2index).view(1, -1) 63 | t[3] = torch.LongTensor([intent2index[t[3]]]).view(1,-1) 64 | 65 | if built_vocab is None: 66 | return p_data, word2index, slot2index, intent2index 67 | else: 68 | return p_data 69 | 70 | def prepare_sequence(seq, to_index): 71 | idxs = list(map(lambda w: to_index[w] if to_index.get(w) is not None else to_index[""], seq)) 72 | return torch.LongTensor(idxs) 73 | 74 | def data_loader(train_data,batch_size,shuffle=False): 75 | if shuffle: random.shuffle(train_data) 76 | sindex = 0 77 | eindex = batch_size 78 | while eindex < len(train_data): 79 | batch = train_data[sindex: eindex] 80 | temp = eindex 81 | 
eindex = eindex + batch_size 82 | sindex = temp 83 | yield batch 84 | 85 | if eindex >= len(train_data): 86 | batch = train_data[sindex:] 87 | yield batch 88 | 89 | def pad_to_batch(batch, w_to_ix,s_to_ix): # for bAbI dataset 90 | history,current,slot,intent = list(zip(*batch)) 91 | max_history = max([len(h) for h in history]) 92 | max_len = max([h.size(1) for h in flatten(history)]) 93 | max_current = max([c.size(1) for c in current]) 94 | max_slot = max([s.size(1) for s in slot]) 95 | 96 | historys, currents, slots = [], [], [] 97 | for i in range(len(batch)): 98 | history_p_t = [] 99 | for j in range(len(history[i])): 100 | if history[i][j].size(1) < max_len: 101 | history_p_t.append(torch.cat([history[i][j], torch.LongTensor([w_to_ix['']] * (max_len - history[i][j].size(1))).view(1, -1)], 1)) 102 | else: 103 | history_p_t.append(history[i][j]) 104 | 105 | while len(history_p_t) < max_history: 106 | history_p_t.append(torch.LongTensor([w_to_ix['']] * max_len).view(1, -1)) 107 | 108 | history_p_t = torch.cat(history_p_t) 109 | historys.append(history_p_t) 110 | 111 | if current[i].size(1) < max_current: 112 | currents.append(torch.cat([current[i], torch.LongTensor([w_to_ix['']] * (max_current - current[i].size(1))).view(1, -1)], 1)) 113 | else: 114 | currents.append(current[i]) 115 | 116 | if slot[i].size(1) < max_slot: 117 | slots.append(torch.cat([slot[i], torch.LongTensor([s_to_ix['']] * (max_slot - slot[i].size(1))).view(1, -1)], 1)) 118 | else: 119 | slots.append(slot[i]) 120 | 121 | currents = torch.cat(currents) 122 | slots = torch.cat(slots) 123 | intents = torch.cat(intent) 124 | 125 | return historys, currents, slots, intents 126 | 127 | def pad_to_history(history, x_to_ix): # this is for inference 128 | 129 | max_x = max([len(s) for s in history]) 130 | x_p = [] 131 | for i in range(len(history)): 132 | h = prepare_sequence(history[i],x_to_ix).unsqueeze(0) 133 | if len(history[i]) < max_x: 134 | x_p.append(torch.cat([h,torch.LongTensor([x_to_ix['']] * 
(max_x - h.size(1))).view(1, -1)], 1)) 135 | else: 136 | x_p.append(h) 137 | 138 | history = torch.cat(x_p) 139 | return [history] -------------------------------------------------------------------------------- /images/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/SDEN-Pytorch/6562ba5223e84d5156b824d01fbb92da1e3a4d93/images/model.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | import numpy as np 6 | from data_utils import * 7 | from model import SDEN 8 | from sklearn_crfsuite import metrics 9 | import argparse 10 | 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | 13 | 14 | def evaluation(model,dev_data): 15 | model.eval() 16 | index2slot = {v:k for k,v in model.slot_vocab.items()} 17 | preds=[] 18 | labels=[] 19 | hits=0 20 | with torch.no_grad(): 21 | for i,batch in enumerate(data_loader(dev_data,32,True)): 22 | h,c,slot,intent = pad_to_batch(batch,model.vocab,model.slot_vocab) 23 | h = [hh.to(device) for hh in h] 24 | c = c.to(device) 25 | slot = slot.to(device) 26 | intent = intent.to(device) 27 | slot_p, intent_p = model(h,c) 28 | 29 | preds.extend([index2slot[i] for i in slot_p.max(1)[1].tolist()]) 30 | labels.extend([index2slot[i] for i in slot.view(-1).tolist()]) 31 | hits+=torch.eq(intent_p.max(1)[1],intent.view(-1)).sum().item() 32 | 33 | 34 | print(hits/len(dev_data)) 35 | 36 | sorted_labels = sorted( 37 | list(set(labels) - {'O',''}), 38 | key=lambda name: (name[1:], name[0]) 39 | ) 40 | 41 | # this is because sklearn_crfsuite.metrics function flatten inputs 42 | preds = [[y] for y in preds] 43 | labels = [[y] for y in labels] 44 | 45 | print(metrics.flat_classification_report( 46 | labels, preds, 
labels = sorted_labels, digits=3 47 | )) 48 | 49 | def save(model,config): 50 | checkpoint = { 51 | 'model': model.state_dict(), 52 | 'vocab': model.vocab, 53 | 'slot_vocab' : model.slot_vocab, 54 | 'intent_vocab' : model.intent_vocab, 55 | 'config' : config, 56 | } 57 | torch.save(checkpoint,config.save_path) 58 | print("Model saved!") 59 | 60 | if __name__ == "__main__": 61 | parser = argparse.ArgumentParser() 62 | # DONOTCHANGE: They are reserved for nsml 63 | parser.add_argument('--mode', type=str, default='train') 64 | parser.add_argument('--pause', type=int, default=0) 65 | parser.add_argument('--iteration', type=str, default='0') 66 | parser.add_argument('--epochs', type=int, default=5, 67 | help='num_epochs') 68 | parser.add_argument('--batch_size', type=int, default=64, 69 | help='batch size') 70 | parser.add_argument('--lr', type=float, default=0.001, 71 | help='learning_rate') 72 | parser.add_argument('--dropout', type=float, default=0.3, 73 | help='dropout') 74 | parser.add_argument('--embed_size', type=int, default=100, 75 | help='embed_size') 76 | parser.add_argument('--hidden_size', type=int, default=64, 77 | help='hidden_size') 78 | parser.add_argument('--save_path', type=str, default='weight/model.pkl', 79 | help='save_path') 80 | 81 | config = parser.parse_args() 82 | 83 | train_data, word2index, slot2index, intent2index = prepare_dataset('data/train.iob') 84 | dev_data = prepare_dataset('data/dev.iob',(word2index,slot2index,intent2index)) 85 | model = SDEN(len(word2index),config.embed_size,config.hidden_size,\ 86 | len(slot2index),len(intent2index),word2index['']) 87 | model.to(device) 88 | model.vocab = word2index 89 | model.slot_vocab = slot2index 90 | model.intent_vocab = intent2index 91 | 92 | slot_loss_function = nn.CrossEntropyLoss(ignore_index=0) 93 | intent_loss_function = nn.CrossEntropyLoss() 94 | optimizer = optim.Adam(model.parameters(),lr=config.lr) 95 | scheduler = 
optim.lr_scheduler.MultiStepLR(gamma=0.1,milestones=[config.epochs//4,config.epochs//2],optimizer=optimizer) 96 | 97 | model.train() 98 | for epoch in range(config.epochs): 99 | losses=[] 100 | scheduler.step() 101 | for i,batch in enumerate(data_loader(train_data,config.batch_size,True)): 102 | h,c,slot,intent = pad_to_batch(batch,model.vocab,model.slot_vocab) 103 | h = [hh.to(device) for hh in h] 104 | c = c.to(device) 105 | slot = slot.to(device) 106 | intent = intent.to(device) 107 | model.zero_grad() 108 | slot_p, intent_p = model(h,c) 109 | 110 | loss_s = slot_loss_function(slot_p,slot.view(-1)) 111 | loss_i = intent_loss_function(intent_p,intent.view(-1)) 112 | loss = loss_s + loss_i 113 | losses.append(loss.item()) 114 | loss.backward() 115 | optimizer.step() 116 | 117 | if i % 100 == 0: 118 | print("[%d/%d] [%d/%d] mean_loss : %.3f" % \ 119 | (epoch,config.epochs,i,len(train_data)//config.batch_size,np.mean(losses))) 120 | losses=[] 121 | 122 | evaluation(model,dev_data) 123 | save(model,config) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils.rnn import pad_packed_sequence as unpack 5 | from torch.nn.utils.rnn import pack_padded_sequence as pack 6 | 7 | class SDEN(nn.Module): 8 | def __init__(self,vocab_size,embed_size,hidden_size,slot_size,intent_size,dropout=0.3,pad_idx=0): 9 | super(SDEN,self).__init__() 10 | 11 | self.pad_idx = 0 12 | self.embed = nn.Embedding(vocab_size,embed_size,padding_idx=self.pad_idx) 13 | self.bigru_m = nn.GRU(embed_size,hidden_size,batch_first=True,bidirectional=True) 14 | self.bigru_c = nn.GRU(embed_size,hidden_size,batch_first=True,bidirectional=True) 15 | self.context_encoder = nn.Sequential(nn.Linear(hidden_size*4,hidden_size*2), 16 | nn.Sigmoid()) 17 | self.session_encoder = 
nn.GRU(hidden_size*2,hidden_size*2,batch_first=True,bidirectional=True) 18 | 19 | self.decoder_1 = nn.GRU(embed_size,hidden_size*2,batch_first=True,bidirectional=True) 20 | self.decoder_2 = nn.LSTM(hidden_size*4,hidden_size*2,batch_first=True,bidirectional=True) 21 | 22 | self.intent_linear = nn.Linear(hidden_size*4,intent_size) 23 | self.slot_linear = nn.Linear(hidden_size*4,slot_size) 24 | self.dropout = nn.Dropout(dropout) 25 | 26 | for param in self.parameters(): 27 | if len(param.size())>1: 28 | nn.init.xavier_uniform_(param) 29 | else: 30 | param.data.zero_() 31 | 32 | def forward(self,history,current): 33 | batch_size = len(history) 34 | H= [] # encoded history 35 | for h in history: 36 | mask = (h!=self.pad_idx) 37 | length = mask.sum(1).long() 38 | embeds = self.embed(h) 39 | embeds = self.dropout(embeds) 40 | lens, indices = torch.sort(length, 0, True) 41 | lens = [l if l>0 else 1 for l in lens.tolist()] # all zero-input 42 | packed_h = pack(embeds[indices], lens, batch_first=True) 43 | outputs, hidden = self.bigru_m(packed_h) 44 | _, _indices = torch.sort(indices, 0) 45 | hidden = torch.cat([hh for hh in hidden],-1) 46 | hidden = hidden[_indices].unsqueeze(0) 47 | H.append(hidden) 48 | 49 | M = torch.cat(H) # B,T_C,2H 50 | M = self.dropout(M) 51 | 52 | embeds = self.embed(current) 53 | embeds = self.dropout(embeds) 54 | mask = (current!=self.pad_idx) 55 | length = mask.sum(1).long() 56 | lens, indices = torch.sort(length, 0, True) 57 | packed_h = pack(embeds[indices], lens.tolist(), batch_first=True) 58 | outputs, hidden = self.bigru_c(packed_h) 59 | _, _indices = torch.sort(indices, 0) 60 | hidden = torch.cat([hh for hh in hidden],-1) 61 | C = hidden[_indices].unsqueeze(1) # B,1,2H 62 | C = self.dropout(C) 63 | 64 | C = C.repeat(1,M.size(1),1) 65 | CONCAT = torch.cat([M,C],-1) # B,T_c,4H 66 | 67 | G = self.context_encoder(CONCAT) 68 | 69 | _,H = self.session_encoder(G) # 2,B,2H 70 | weight = next(self.parameters()) 71 | cell_state = 
weight.new_zeros(H.size()) 72 | O_1,_ = self.decoder_1(embeds) 73 | O_1 = self.dropout(O_1) 74 | 75 | O_2,(S_2,_) = self.decoder_2(O_1,(H,cell_state)) 76 | O_2 = self.dropout(O_2) 77 | S = torch.cat([s for s in S_2],1) 78 | 79 | intent_prob = self.intent_linear(S) 80 | slot_prob = self.slot_linear(O_2.contiguous().view(O_2.size(0)*O_2.size(1),-1)) 81 | 82 | return slot_prob, intent_prob -------------------------------------------------------------------------------- /test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 8, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "import torch.nn.functional as F\n", 12 | "from data_utils import *\n", 13 | "from model import SDEN\n", 14 | "import pickle\n", 15 | "import json\n", 16 | "import random\n", 17 | "import nltk" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "checkpoint = torch.load('weight/model.pkl',map_location=lambda storage, loc: storage)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "metadata": {}, 33 | "outputs": [], 34 | "source": [ 35 | "config = checkpoint['config']" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 4, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "model = SDEN(len(checkpoint['vocab']),config.embed_size,config.hidden_size,\n", 45 | " len(checkpoint['slot_vocab']),len(checkpoint['intent_vocab']))\n", 46 | "model.load_state_dict(checkpoint['model'])" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 5, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "SDEN(\n", 58 | " (embed): Embedding(1179, 100, padding_idx=0)\n", 59 | " (bigru_m): GRU(100, 64, batch_first=True, bidirectional=True)\n", 60 | " 
(bigru_c): GRU(100, 64, batch_first=True, bidirectional=True)\n", 61 | " (context_encoder): Sequential(\n", 62 | " (0): Linear(in_features=256, out_features=128, bias=True)\n", 63 | " (1): Sigmoid()\n", 64 | " )\n", 65 | " (session_encoder): GRU(128, 128, batch_first=True, bidirectional=True)\n", 66 | " (decoder_1): GRU(100, 128, batch_first=True, bidirectional=True)\n", 67 | " (decoder_2): LSTM(256, 128, batch_first=True, bidirectional=True)\n", 68 | " (intent_linear): Linear(in_features=256, out_features=4, bias=True)\n", 69 | " (slot_linear): Linear(in_features=256, out_features=24, bias=True)\n", 70 | " (dropout): Dropout(p=0.3)\n", 71 | ")" 72 | ] 73 | }, 74 | "execution_count": 5, 75 | "metadata": {}, 76 | "output_type": "execute_result" 77 | } 78 | ], 79 | "source": [ 80 | "model.eval()" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": 9, 86 | "metadata": {}, 87 | "outputs": [], 88 | "source": [ 89 | "data = json.load(open('../dataset/kvret/kvret_test_public.json','r'))" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 10, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "test = random.choice(data)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 11, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "index2intent = {v:k for k,v in checkpoint['intent_vocab'].items()}\n", 108 | "index2slot = {v:k for k,v in checkpoint['slot_vocab'].items()}" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 17, 114 | "metadata": {}, 115 | "outputs": [], 116 | "source": [ 117 | "test = random.sample(data,2)\n", 118 | "index = random.choice([i for i in range(len(test[0]['dialogue'])) if i%2==0])\n", 119 | "test = test[0]['dialogue'][:index] + test[1]['dialogue']" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 18, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 
| "text": [ 131 | "['What', 'is', 'the', 'date', 'and', 'time', 'of', 'my', 'next', 'meeting', 'and', 'who', 'will', 'be', 'attending', 'it', '?']\n", 132 | "intent : schedule\n", 133 | "slot : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-event', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n", 134 | "\n", 135 | "['Please', 'give', 'me', 'the', 'address', 'and', 'directions', 'via', 'a', 'route', 'with', 'no', 'traffic', 'to', 'the', 'nearest', 'pizza', 'restaurant', '.']\n", 136 | "intent : navigate\n", 137 | "slot : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-traffic_info', 'O', 'O', 'B-distance', 'B-poi_type', 'I-poi_type', 'O']\n", 138 | "\n", 139 | "['Yes', ',', 'let', \"'s\", 'go', ',', 'thank', 'you', '!']\n", 140 | "intent : thanks\n", 141 | "slot : ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']\n", 142 | "\n" 143 | ] 144 | } 145 | ], 146 | "source": [ 147 | "history=[[\"\"]]\n", 148 | "for d in test:\n", 149 | " utter = d['data']['utterance']\n", 150 | " token = nltk.word_tokenize(utter)\n", 151 | " c = prepare_sequence(token,checkpoint['vocab']).unsqueeze(0)\n", 152 | " h = pad_to_history(history,checkpoint['vocab'])\n", 153 | " with torch.no_grad():\n", 154 | " s,i = model(h,c)\n", 155 | " slot_p = s.max(1)[1]\n", 156 | " intent_p = i.max(1)[1]\n", 157 | " if d['turn']=='driver':\n", 158 | " print(token)\n", 159 | " print('intent : ',index2intent[intent_p.item()])\n", 160 | " print('slot : ',[index2slot[s] for s in slot_p.data.tolist()])\n", 161 | " print(\"\")\n", 162 | " history.append(token)" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": null, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [] 171 | } 172 | ], 173 | "metadata": { 174 | "kernelspec": { 175 | "display_name": "Python 3", 176 | "language": "python", 177 | "name": "python3" 178 | }, 179 | "language_info": { 180 | "codemirror_mode": { 181 | "name": "ipython", 182 | "version": 3 183 | }, 184 | "file_extension": ".py", 185 | "mimetype": 
"text/x-python", 186 | "name": "python", 187 | "nbconvert_exporter": "python", 188 | "pygments_lexer": "ipython3", 189 | "version": "3.6.5" 190 | } 191 | }, 192 | "nbformat": 4, 193 | "nbformat_minor": 2 194 | } 195 | -------------------------------------------------------------------------------- /weight/model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSKSD/SDEN-Pytorch/6562ba5223e84d5156b824d01fbb92da1e3a4d93/weight/model.pkl --------------------------------------------------------------------------------