├── .gitignore ├── README.md ├── code ├── dataset.py ├── model.py ├── run.py └── utils.py └── data └── nlu_data ├── test.txt ├── train_shuffle.txt └── wiki.en.vec /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | saved_models/ 4 | *.ipynb 5 | *.pt 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GPU-accelerated PyTorch implementation of Zero-shot User Intent Detection via Capsule Neural Networks 2 | 3 | This repository implements a capsule model IntentCapsNet-ZSL on the SNIPS-NLU dataset in Python 3 4 | with PyTorch, first introduced in the paper _Zero-shot User Intent Detection via Capsule Neural Networks_. 5 | 6 | The code aims to follow PyTorch best practices, using `torch` instead of `numpy` where possible, and using 7 | `.cuda()` for GPU computation. Feel free to contribute via pull requests. 8 | 9 | # Requirements 10 | 11 | python 3.6+ 12 | 13 | torch 1.0.1 14 | 15 | numpy 16 | 17 | gensim 18 | 19 | scikit-learn 20 | 21 | # Usage and Modification 22 | 23 | * To run the training-validation loop: `python run.py`. 24 | * The custom `Dataset` class is implemented in `dataset.py`. 25 | 26 | # Acknowledgements 27 | * Original repository (TensorFlow, Python 2): https://github.com/congyingxia/ZeroShotCapsule 28 | * Re-implementation (PyTorch, Python 2): https://github.com/nhhoang96/ZeroShotCapsule-PyTorch- 29 | 30 | Please see the following paper for the details: 31 | 32 | Congying Xia, Chenwei Zhang, Xiaohui Yan, Yi Chang, Philip S. Yu. Zero-shot User 33 | Intent Detection via Capsule Neural Networks. In Proceedings of the 2018 Conference on 34 | Empirical Methods in Natural Language Processing (EMNLP), 2018. 35 | 36 | https://arxiv.org/abs/1809.00385 37 | 38 | 39 | ``` 40 | @article{xia2018zero, 41 | title={Zero-shot User Intent Detection via Capsule Neural Networks}, 42 | author={Xia, Congying and Zhang, Chenwei and Yan, Xiaohui and Chang, Yi and Yu, Philip S}, 43 | journal={arXiv preprint arXiv:1809.00385}, 44 | year={2018} 45 | } 46 | ``` 47 | # References 48 | * https://github.com/soskek/dynamic_routing_between_capsules 49 | * https://github.com/ExplorerFreda/Structured-Self-Attentive-Sentence-Embedding 50 | 51 | -------------------------------------------------------------------------------- /code/dataset.py: -------------------------------------------------------------------------------- 1 | from gensim.models.keyedvectors import KeyedVectors 2 | import utils 3 | 4 | import torch 5 | from torch.utils.data import Dataset 6 | from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence 7 | 8 | class IntentDataset(Dataset): 9 | def __init__(self, class_list, w2v_path, data_path): 10 | self.class_list = class_list 11 | self.w2v = self.load_w2v(w2v_path) 12 | self.class_indices, self.class_w2v = self.process_labels(class_list) 13 | self.embedding = self.load_embedding() 14 | self.corpus = self.load_data(data_path) 15 | 16 | def __getitem__(self, idx): 17 | """ 18 | Each sample is a dictionary with keys 19 | {sentence (list of w2v indices), label_onehot, label_w2v}. 
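        An illustrative usage sketch (the paths and class names follow run.py;
        the exact tensor sizes depend on the loaded word vectors):

            >>> ds = IntentDataset(['music', 'search', 'movie', 'weather', 'restaurant'],
            ...                    '../data/nlu_data/wiki.en.vec',
            ...                    '../data/nlu_data/train_shuffle.txt')
            >>> sample = ds[0]
            >>> sorted(sample.keys())
            ['label_onehot', 'label_w2v', 'length', 'sentence_w2v']
            >>> sample['label_onehot'].shape, sample['label_w2v'].shape
            (torch.Size([5]), torch.Size([300]))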
20 | """ 21 | return self.corpus[idx] 22 | 23 | def __len__(self): 24 | return len(self.corpus) 25 | 26 | def load_w2v(self, model_path): 27 | """ 28 | load w2v model 29 | input: model filepath 30 | returns: w2v model 31 | """ 32 | w2v = KeyedVectors.load_word2vec_format( 33 | model_path, binary=False) 34 | return w2v 35 | 36 | def process_labels(self, class_list): 37 | """ 38 | pre-process class labels 39 | input: list of string labels 40 | returns: two dicts, {labels-to-indices} and {labels-to-w2v} 41 | """ 42 | label_indices = {} 43 | label_w2v = {} 44 | for k in range(len(class_list)): 45 | 46 | label = class_list[k] 47 | 48 | # Account for classes 49 | # with multiple words. 50 | label_list = label.split(' ') 51 | 52 | for word in label_list: 53 | if word not in self.w2v.vocab.keys(): 54 | raise Exception("{} not in W2V model".format(word)) 55 | 56 | # Compute the sum of the 57 | # constituent words in the label. 58 | vector_sum = 0 59 | for word in label_list: 60 | vector_sum += torch.Tensor(self.w2v[word]) 61 | 62 | label_w2v[label] = vector_sum 63 | label_indices[label] = k 64 | 65 | return label_indices, label_w2v 66 | 67 | def load_data(self, file_path): 68 | """ 69 | Loads samples into a list. Each sample is a dictionary with keys 70 | {sentence_w2v (list of w2v indices), length, label_onehot, label_w2v}. 71 | 72 | input: text file path 73 | returns: dataset, a list of dicts. 74 | """ 75 | 76 | dataset = [] 77 | 78 | for line in open(file_path): 79 | arr = line.strip().split('\t') 80 | label = [w for w in arr[0].split(' ')] 81 | sentence = [w for w in arr[1].split(' ')] 82 | cname = ' '.join(label) 83 | 84 | # Reject lines whose class is 85 | # not in the provided class list. 86 | if cname not in self.class_list: 87 | raise Exception("{} not in class list.".format(cname)) 88 | 89 | # Build the sample dictionary. 90 | sample = {} 91 | sample['sentence_w2v'] = [] 92 | 93 | for word in sentence: 94 | if word not in self.w2v.vocab.keys(): 95 | continue # skip words missing from the w2v vocabulary 96 | 97 | # In the embedding table (see self.load_embedding()), one 98 | # additional row of zeros is stacked in front to handle padding. 99 | # Thus here we append the embedding index plus one. 100 | sample['sentence_w2v'].append(torch.Tensor([self.w2v.vocab[word].index + 1])) 101 | 102 | sample['length'] = len(sample['sentence_w2v']) 103 | sample['label_onehot'] = self.onehot(self.class_indices[cname]) 104 | sample['label_w2v'] = self.class_w2v[cname] 105 | dataset.append(sample) 106 | 107 | return dataset 108 | 109 | def categorical(self, onehot): 110 | return torch.argmax(onehot, dim=1) 111 | 112 | def onehot(self, idx): 113 | onehot = torch.zeros(len(self.class_indices)) 114 | onehot[idx] = 1. 115 | return onehot 116 | 117 | def load_embedding(self): 118 | # load normalized word embeddings 119 | embedding = self.w2v.syn0 120 | norm_embedding = utils.norm_matrix(embedding) 121 | 122 | # Stack one layer of zeros on the embedding 123 | # to handle padding. So the total length of 124 | # the embedding increases by one.
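        # Illustration of the resulting index convention (shapes assume the
        # 300-d fastText vectors): if self.w2v.syn0 has shape (V, 300), the
        # returned table has shape (V + 1, 300), row 0 is the all-zero padding
        # row, and a word with gensim vocabulary index i is looked up at row
        # i + 1 -- which is why load_data() stores index + 1 and IntentBatch
        # pads with padding_value=0.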
125 | # See: https://datascience.stackexchange.com/questions/32345/initial-embeddings-for-unknown-padding 126 | # See: https://discuss.pytorch.org/t/padding-zeros-in-nn-embedding-while-using-pre-train-word-vectors/8443/4 127 | 128 | emb = torch.from_numpy(norm_embedding) 129 | zeros = torch.zeros(1, emb.shape[1]) 130 | pad_enabled_embedding = torch.cat((zeros, emb)) 131 | return pad_enabled_embedding 132 | 133 | class IntentBatch: 134 | def __init__(self, batch): 135 | batch.sort(reverse=True, key=lambda x: x['length']) 136 | batch = list(zip(*map(dict.values, batch))) 137 | 138 | sentences_w2v = [torch.LongTensor(x) for x in batch[0]] 139 | lengths = torch.Tensor(batch[1]).long() # cpu 140 | label_onehot = torch.stack(batch[2]) 141 | label_w2v = batch[3] 142 | 143 | self.sentences_w2v = pad_sequence(sentences_w2v, padding_value=0, batch_first=True) 144 | self.lengths = lengths 145 | self.label_onehot = label_onehot 146 | self.label_w2v = label_w2v 147 | 148 | def batch_function(batch): 149 | return IntentBatch(batch) -------------------------------------------------------------------------------- /code/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 8 | 9 | """ 10 | Parameters: 11 | Vocab size: T 12 | hidden_dim: hidden dimension 13 | input_dim (D_W): pre-trained using skip-gram model (300) 14 | """ 15 | class CapsuleNetwork(nn.Module): 16 | def __init__(self, config, pretrained_embedding = None): 17 | super(CapsuleNetwork, self).__init__() 18 | self.hidden_size = config['hidden_size'] 19 | self.vocab_size = config['vocab_size'] 20 | self.word_emb_size = config['word_emb_size'] 21 | self.learning_rate = config['learning_rate'] 22 | self.batch_size = config['batch_size'] 23 | 24 | self.word_embedding = nn.Embedding(config['vocab_size'], config['word_emb_size']) 25 | self.bilstm = nn.LSTM(config['word_emb_size'], config['hidden_size'], 26 | config['nlayers'], bidirectional=True, batch_first=True) 27 | self.drop = nn.Dropout(config['keep_prob']) 28 | 29 | # parameters for self-attention 30 | self.n = config['max_time'] 31 | self.d = config['word_emb_size'] 32 | self.d_a = config['d_a'] 33 | self.u = config['hidden_size'] 34 | self.r = config['r'] 35 | self.alpha = config['alpha'] 36 | 37 | # attention 38 | self.ws1 = nn.Linear(config['hidden_size'] * 2, config['d_a'], bias=False) 39 | self.ws2 = nn.Linear(config['d_a'], config['r'], bias=False) 40 | self.tanh = nn.Tanh() 41 | self.softmax = nn.Softmax() 42 | 43 | self.s_cnum = config['s_cnum'] 44 | self.margin = config['margin'] 45 | self.keep_prob = config['keep_prob'] 46 | self.num_routing = config['num_routing'] 47 | self.output_atoms = config['output_atoms'] 48 | self.nlayers = 2 49 | 50 | # for capsule 51 | self.input_dim = self.r 52 | self.input_atoms = self.hidden_size * 2 53 | self.output_dim = self.s_cnum 54 | self.capsule_weights = nn.Parameter(torch.zeros((self.r, self.hidden_size * 2, 55 | self.s_cnum * self.output_atoms))) 56 | self.init_weights() 57 | 58 | 59 | def forward(self, input, len, embedding, hc): 60 | self.s_len = len 61 | input = input.transpose(0,1) #(Bach,Length,D) => (L,B,D) 62 | # Attention 63 | if (embedding.nelement() != 0): 64 | self.word_embedding = nn.Embedding.from_pretrained(embedding) 65 | 66 | emb = self.word_embedding(input) 67 | packed_emb = 
pack_padded_sequence(emb, len) 68 | 69 | outp = self.bilstm(packed_emb, hc)[0] ## [bsz, len, d_h * 2] 70 | outp = pad_packed_sequence(outp)[0].transpose(0,1).contiguous() 71 | size = outp.size() 72 | compressed_embeddings = outp.view(-1, size[2]) # [bsz * len, d_h * 2] 73 | hbar = self.tanh(self.ws1(self.drop(compressed_embeddings))) 74 | alphas = self.ws2(hbar).view(size[0], size[1], -1) # [bsz, len, hop] 75 | 76 | self.attention = torch.transpose(alphas, 1, 2).contiguous() # [bsz, hop, len] 77 | self.sentence_embedding = torch.bmm(self.attention, outp) 78 | 79 | ## capsule 80 | dropout_emb = self.drop(self.sentence_embedding) 81 | 82 | input_tiled = torch.unsqueeze(dropout_emb, -1).repeat(1, 1, 1, self.output_dim * self.output_atoms) 83 | votes = torch.sum(input_tiled * self.capsule_weights, dim=2) 84 | votes_reshaped = torch.reshape(votes, [-1, self.input_dim, self.output_dim, self.output_atoms]) 85 | input_shape = self.sentence_embedding.shape 86 | logit_shape = np.stack([input_shape[0], self.input_dim, self.output_dim]) 87 | 88 | self.activation, self.weights_b, self.weights_c = self.routing(votes = votes_reshaped, 89 | logit_shape=logit_shape, 90 | num_dims=4) 91 | self.logits = self.get_logits() 92 | self.votes = votes_reshaped 93 | 94 | def get_logits(self): 95 | logits = torch.norm(self.activation, dim=-1) 96 | return logits 97 | 98 | def routing(self, votes, logit_shape, num_dims): 99 | votes_t_shape = [3, 0, 1, 2] 100 | for i in range(num_dims - 4): 101 | votes_t_shape += [i + 4] 102 | r_t_shape = [1, 2, 3, 0] 103 | for i in range(num_dims - 4): 104 | r_t_shape += [i + 4] 105 | 106 | votes_trans = votes.permute(votes_t_shape) 107 | logits = nn.Parameter(torch.zeros(logit_shape[0], logit_shape[1], logit_shape[2])).cuda() 108 | activations = [] 109 | 110 | # Iterative routing. 111 | for iteration in range(self.num_routing): 112 | route = F.softmax(logits, dim=2).cuda() 113 | preactivate_unrolled = route * votes_trans 114 | preact_trans = preactivate_unrolled.permute(r_t_shape) 115 | # delete bias to fit for unseen classes 116 | preactivate = torch.sum(preact_trans, dim=1) 117 | activation = self._squash(preactivate) 118 | activations.append(activation) 119 | # distances: [batch, input_dim, output_dim] 120 | act_3d = activation.unsqueeze(1) 121 | tile_shape = np.ones(num_dims, dtype=np.int32).tolist() 122 | tile_shape[1] = self.input_dim 123 | act_replicated = act_3d.repeat(tile_shape) 124 | distances = torch.sum(votes * act_replicated, dim=3) 125 | logits = logits + distances 126 | 127 | return activations[self.num_routing - 1], logits, route 128 | 129 | def _squash(self, input_tensor): 130 | norm = torch.norm(input_tensor, dim=2, keepdim= True) 131 | norm_squared = norm * norm 132 | return (input_tensor / norm) * (norm_squared / (0.5 + norm_squared)) 133 | 134 | 135 | def init_weights(self): 136 | nn.init.xavier_uniform_(self.ws1.weight) 137 | nn.init.xavier_uniform_(self.ws2.weight) 138 | nn.init.xavier_uniform_(self.capsule_weights) 139 | 140 | self.ws1.weight.requires_grad_(True) 141 | self.ws2.weight.requires_grad_(True) 142 | self.capsule_weights.requires_grad_(True) 143 | 144 | def _margin_loss(self, labels, raw_logits, margin=0.4, downweight=0.5): 145 | """Penalizes deviations from margin for each logit. 146 | Each wrong logit costs its distance to margin. For negative logits margin is 147 | 0.1 and for positives it is 0.9. First subtract 0.5 from all logits. Now 148 | margin is 0.4 from each side. 149 | Args: 150 | labels: tensor, one hot encoding of ground truth. 
151 | raw_logits: tensor, model predictions in range [0, 1] 152 | margin: scalar, the margin after subtracting 0.5 from raw_logits. 153 | downweight: scalar, the factor for negative cost. 154 | Returns: 155 | A tensor with cost for each data point of shape [batch_size]. 156 | """ 157 | logits = raw_logits - 0.5 158 | positive_cost = labels * (logits < margin).float() * ((logits - margin) ** 2) 159 | negative_cost = (1 - labels) * (logits > -margin).float() * ((logits + margin) ** 2) 160 | return 0.5 * positive_cost + downweight * 0.5 * negative_cost 161 | 162 | def loss(self, label): 163 | loss_val = self._margin_loss(label, self.logits) 164 | loss_val = torch.mean(loss_val) 165 | 166 | self_atten_mul = torch.matmul(self.attention, self.attention.permute([0, 2, 1])).float() 167 | sample_num, att_matrix_size, _ = self_atten_mul.shape 168 | self_atten_loss = (torch.norm(self_atten_mul - torch.from_numpy(np.identity(att_matrix_size)).float().cuda()).float()) ** 2 169 | 170 | return 1000 * loss_val + self.alpha * torch.mean(self_atten_loss) -------------------------------------------------------------------------------- /code/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | from random import * 3 | import time 4 | 5 | from dataset import IntentDataset, batch_function 6 | from model import CapsuleNetwork 7 | 8 | import numpy as np 9 | import torch 10 | import torch.optim as optim 11 | from torch.optim.lr_scheduler import ReduceLROnPlateau 12 | import torch.nn.functional as F 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | import utils 16 | import math 17 | 18 | from sklearn.metrics import classification_report 19 | from sklearn.preprocessing import normalize 20 | from sklearn.metrics import accuracy_score 21 | 22 | a = Random() 23 | a.seed(1) 24 | torch.cuda.manual_seed_all(17) 25 | device = torch.device("cuda:0") 26 | 27 | def setting(train_set, test_set, embedding): 28 | vocab_size, word_emb_size = embedding.shape 29 | max_time = sorted(train_set, reverse=True, key=lambda x: x['length'])[0]['length'] # length of the longest training sentence 30 | train_num = len(train_set) 31 | test_num = len(test_set) 32 | s_cnum = len(train_set.class_list) 33 | u_cnum = len(test_set.class_list) 34 | config = {} 35 | config['keep_prob'] = 0.5 # dropout probability (passed to nn.Dropout) 36 | config['hidden_size'] = 64 # BiLSTM hidden size per direction 37 | config['batch_size'] = 32 # mini-batch size 38 | config['vocab_size'] = vocab_size # vocab size of the padded embedding table (w2v vocab + 1) 39 | config['num_epochs'] = 5 # number of epochs 40 | config['max_time'] = max_time 41 | config['sample_num'] = train_num # sample number of training data 42 | config['test_num'] = test_num # number of test data 43 | config['s_cnum'] = s_cnum # seen class num 44 | config['u_cnum'] = u_cnum # unseen class num 45 | config['word_emb_size'] = word_emb_size # embedding size of word vectors (300) 46 | config['d_a'] = 20 # self-attention hidden units 47 | config['output_atoms'] = 10 # capsule output atoms 48 | config['r'] = 3 # self-attention hops 49 | config['num_routing'] = 2 # capsule routing iterations 50 | config['alpha'] = 0.0001 # coefficient of self-attention loss 51 | config['margin'] = 1.0 # ranking loss margin 52 | config['learning_rate'] = 0.00005 53 | config['sim_scale'] = 4 # scale of the label-similarity kernel (utils.compute_label_sim) 54 | config['nlayers'] = 2 # default for bilstm 55 | config['ckpt_dir'] = './saved_models/' # checkpoint dir 56 | return config 57 | 58 | def get_sim(train_set, test_set): 59 | """ 60 | get
unseen and seen categories similarity. 61 | """ 62 | seen = normalize(torch.stack(list(train_set.class_w2v.values()))) 63 | unseen = normalize(torch.stack(list(test_set.class_w2v.values()))) 64 | sim = utils.compute_label_sim(unseen, seen, config['sim_scale']) 65 | return torch.from_numpy(sim) 66 | 67 | def _squash(input_tensor): 68 | norm = torch.norm(input_tensor, dim=2, keepdim=True) 69 | norm_squared = norm * norm 70 | return (input_tensor / norm) * (norm_squared / (0.5 + norm_squared)) 71 | 72 | def update_unseen_routing(votes, config, num_routing=3): 73 | votes_t_shape = [3, 0, 1, 2] 74 | r_t_shape = [1, 2, 3, 0] 75 | votes_trans = votes.permute(votes_t_shape) 76 | num_dims = 4 77 | input_dim = config['r'] 78 | output_dim = config['u_cnum'] 79 | input_shape = votes.shape 80 | logit_shape = np.stack([input_shape[0], input_dim, output_dim]) 81 | logits = torch.zeros(logit_shape[0], logit_shape[1], logit_shape[2]).cuda() 82 | activations = [] 83 | 84 | for iteration in range(num_routing): 85 | route = F.softmax(logits, dim=2).cuda() 86 | preactivate_unrolled = route * votes_trans 87 | preact_trans = preactivate_unrolled.permute(r_t_shape) 88 | 89 | # delete bias to fit for unseen classes 90 | preactivate = torch.sum(preact_trans, dim=1) 91 | activation = _squash(preactivate) 92 | # activations = activations.write(i, activation) 93 | activations.append(activation) 94 | # distances: [batch, input_dim, output_dim] 95 | act_3d = torch.unsqueeze(activation, 1) 96 | tile_shape = np.ones(num_dims, dtype=np.int32).tolist() 97 | tile_shape[1] = input_dim 98 | act_replicated = act_3d.repeat(tile_shape) 99 | distances = torch.sum(votes * act_replicated, dim=3) 100 | logits = logits + distances 101 | 102 | return activations[num_routing-1], route 103 | 104 | data_prefix = '../data/nlu_data/' 105 | w2v_path = data_prefix + 'wiki.en.vec' 106 | training_data_path = data_prefix + 'train_shuffle.txt' 107 | test_data_path = data_prefix + 'test.txt' 108 | 109 | seen_classes = ['music', 'search', 'movie', 'weather', 'restaurant'] 110 | unseen_classes = ['playlist', 'book'] 111 | 112 | train_set = IntentDataset(seen_classes, w2v_path, training_data_path) 113 | test_set = IntentDataset(unseen_classes, w2v_path, test_data_path) 114 | 115 | embedding = train_set.embedding 116 | categorical = train_set.categorical 117 | config = setting(train_set, test_set, embedding) 118 | similarity = get_sim(train_set, test_set).to(device) 119 | 120 | train_loader = DataLoader(train_set, batch_size=config['batch_size'], shuffle=True, 121 | collate_fn=batch_function, num_workers=4) 122 | test_loader = DataLoader(test_set, batch_size=config['batch_size'], shuffle=True, 123 | collate_fn=batch_function, num_workers=4) 124 | 125 | model = CapsuleNetwork(config).to(device) 126 | optimizer = optim.Adam(model.parameters(), lr=config['learning_rate']) 127 | #scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True, threshold=0.001) 128 | 129 | if os.path.exists(config['ckpt_dir'] + 'best_model.pt'): 130 | print("Restoring weights from previously trained rnn model.") 131 | model.load_state_dict(torch.load(config['ckpt_dir'] + 'best_model.pt' )) 132 | else: 133 | print('Initializing Variables') 134 | if not os.path.exists(config['ckpt_dir']): 135 | os.mkdir(config['ckpt_dir']) 136 | 137 | def train(epoch, train_loader, config, model, embedding, train_time): 138 | model.train() 139 | avg_acc = 0 140 | avg_loss = 0 141 | start_time = time.time() 142 | 143 | for idx, batch in enumerate(train_loader): 144 | 
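        # Each IntentBatch (built by batch_function in dataset.py) is sorted by
        # descending sentence length and exposes: sentences_w2v (padded word
        # indices, padding index 0), lengths (LongTensor kept on the CPU),
        # label_onehot (one-hot over the seen classes), and label_w2v (label
        # word vectors, not used in this loop).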
input = batch.sentences_w2v.cuda() 145 | lengths = batch.lengths 146 | target = batch.label_onehot.cuda() 147 | label_w2v = batch.label_w2v 148 | 149 | batch_size = len(input) 150 | hc = (Variable(torch.zeros(4, input.shape[0], config['hidden_size'])).cuda(), 151 | Variable(torch.zeros(4, input.shape[0], config['hidden_size'])).cuda()) 152 | 153 | output = model(input, lengths, embedding.cuda(), hc) 154 | loss = model.loss(target.float()) 155 | 156 | optimizer.zero_grad() 157 | loss.backward() 158 | optimizer.step() 159 | 160 | clone_logits = model.logits.detach().clone() 161 | pred = torch.argmax(clone_logits, 1).cpu() 162 | acc = accuracy_score(categorical(target.cpu()), pred) 163 | print("Epoch: {}\t| Batch: {:03d}/{}\t| Batch Loss: {}\t| Acc: {}%".format( 164 | epoch, (idx+1), len(train_loader), round(loss.item(), 4), round(acc * 100., 2))) 165 | avg_loss += loss.item() 166 | avg_acc += acc 167 | 168 | epoch_time = time.time() - start_time 169 | train_time += epoch_time 170 | avg_loss /= len(train_loader) 171 | avg_acc /= len(train_loader) 172 | 173 | print("Epoch: {}\t| Average Loss: {}\t| Average Acc: {}%\t| Train Time: {}s".format( 174 | epoch, round(avg_loss, 4), round(avg_acc * 100., 2), round(train_time, 2))) 175 | 176 | return avg_loss, avg_acc, train_time 177 | 178 | def test(epoch, test_loader, config, model, embedding, similarity): 179 | # zero-shot testing state 180 | # seen votes shape (110, 2, 34, 10) 181 | # get unseen and seen categories similarity 182 | # sim shape (8, 34) 183 | model.eval() 184 | start_time = time.time() 185 | 186 | with torch.no_grad(): 187 | for idx, batch in enumerate(test_loader): 188 | input = batch.sentences_w2v.cuda() 189 | lengths = batch.lengths 190 | target = batch.label_onehot.long().cuda() 191 | label_w2v = batch.label_w2v 192 | 193 | batch_size = len(input) 194 | hc = (Variable(torch.zeros(4, input.shape[0], config['hidden_size'])).cuda(), 195 | Variable(torch.zeros(4, input.shape[0], config['hidden_size'])).cuda()) 196 | 197 | output = model.forward(input, lengths, embedding.cuda(), hc) 198 | attentions, seen_logits, seen_votes, seen_weights_c = model.attention, model.logits, \ 199 | model.votes, model.weights_c 200 | sim = similarity.unsqueeze(0) 201 | sim = sim.repeat(seen_votes.shape[1], 1, 1).unsqueeze(0) 202 | sim = sim.repeat(seen_votes.shape[0], 1, 1, 1) 203 | seen_weights_c = seen_weights_c.unsqueeze(-1) 204 | seen_weights_c = seen_weights_c.repeat(1, 1, 1, config['output_atoms']) 205 | mul = seen_votes * seen_weights_c 206 | 207 | # compute unseen features 208 | # unseen votes shape (110, 2, 8, 10) 209 | unseen_votes = torch.matmul(sim, mul) 210 | 211 | # routing unseen classes 212 | u_activations, u_weights_c = update_unseen_routing(unseen_votes, config, 3) 213 | unseen_logits = torch.norm(u_activations, dim=-1) 214 | batch_pred = torch.argmax(unseen_logits, dim=1).unsqueeze(1).cuda() 215 | 216 | if idx == 0: 217 | total_pred = batch_pred 218 | total_target = target 219 | else: 220 | total_pred = torch.cat((total_pred.cuda(), batch_pred)) 221 | total_target = torch.cat((total_target.cuda(), target)) 222 | 223 | print (" zero-shot intent detection test set performance ") 224 | cpu_target = categorical(total_target.cpu()) 225 | cpu_pred = total_pred.flatten().cpu() 226 | acc = accuracy_score(cpu_target, cpu_pred) 227 | print (classification_report(cpu_target, cpu_pred, digits=4)) 228 | 229 | test_time = time.time() - start_time 230 | return acc, test_time 231 | 232 | best_acc = 0 233 | train_time, test_time = 0, 0 234 | 235 | for 
epoch in range(1, config['num_epochs'] + 1): 236 | train_loss, train_acc, train_time = train(epoch, train_loader, config, model, embedding, train_time) 237 | test_acc, test_time = test(epoch, test_loader, config, model, embedding, similarity) 238 | #scheduler.step(test_acc) 239 | 240 | if test_acc > best_acc: 241 | best_acc = test_acc 242 | torch.save(model.state_dict(), config['ckpt_dir'] + 'best_model.pt') 243 | 244 | print("test_acc", test_acc) 245 | print("best_acc", best_acc) 246 | print("Testing time", round(test_time, 4)) 247 | 248 | print("Overall training time", train_time) 249 | print("Overall testing time", test_time) -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.spatial.distance as ds 3 | 4 | def norm_matrix(matrix): 5 | """Normalize each row of the matrix to sum to one 6 | input: numpy array, dtype = float32 7 | output: normalized numpy array, dtype = float32 8 | """ 9 | 10 | # check dtype of the input matrix 11 | np.testing.assert_equal(type(matrix).__name__, 'ndarray') 12 | np.testing.assert_equal(matrix.dtype, np.float32) 13 | # sum along axis = 1, i.e. across the columns of each row 14 | row_sums = matrix.sum(axis = 1) # one sum per row 15 | 16 | # Replace zero denominators 17 | row_sums[row_sums == 0] = 1 18 | # [:, np.newaxis] expands the row sums by one unit-length dimension, 19 | # shape (rows,) -> (rows, 1), so that the division below 20 | # broadcasts across each row of the matrix 21 | norm_matrix = matrix / row_sums[:, np.newaxis] 22 | 23 | return norm_matrix 24 | 25 | def replace_nan(X): 26 | """ 27 | replace nan and inf with 0 28 | """ 29 | X[np.isnan(X)] = 0 30 | X[np.isinf(X)] = 0 31 | 32 | return X 33 | 34 | def compute_label_sim(sig_y1, sig_y2, sim_scale): 35 | """ 36 | compute the similarity of unseen (sig_y1) to seen (sig_y2) class label vectors; each row is normalized to sum to one (see the toy example below) 37 | """ 38 | dist = ds.cdist(sig_y1, sig_y2, 'euclidean') 39 | dist = dist.astype(np.float32) 40 | similarity = np.exp(-np.square(dist) * sim_scale) 41 | s = np.sum(similarity, axis=1) 42 | similarity = replace_nan(similarity / s[:, None]) 43 | 44 | return similarity 45 | --------------------------------------------------------------------------------
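For reference, `utils.compute_label_sim` turns the Euclidean distances between row-normalized label vectors into an RBF kernel, `exp(-d^2 * sim_scale)`, and then normalizes each row, so every unseen class is expressed as a convex combination of the seen classes. A toy sketch with hand-made 3-d label vectors standing in for the 300-d fastText vectors:

```
import numpy as np
from utils import compute_label_sim

# Toy label embeddings: three "seen" classes and one "unseen" class.
seen = np.array([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=np.float32)
unseen = np.array([[0.9, 0.1, 0.0]], dtype=np.float32)

sim = compute_label_sim(unseen, seen, 4)  # sim_scale = 4, as in run.py
print(sim.shape)        # (1, 3): one row per unseen class, one column per seen class
print(sim.sum(axis=1))  # each row sums to 1
```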
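The zero-shot step in `test()` (run.py) reuses the capsule votes computed for the seen classes: each vote is weighted by its routing coefficient, mixed into the unseen classes through the label-similarity matrix, and the resulting unseen votes are re-routed with `update_unseen_routing`. A shape-only sketch with random stand-ins for `model.votes`, `model.weights_c`, and the similarity matrix, using the dimensions set in `setting()` (r = 3, s_cnum = 5, u_cnum = 2, output_atoms = 10) and an arbitrary batch size:

```
import torch

B, r, s_cnum, u_cnum, output_atoms = 4, 3, 5, 2, 10

seen_votes = torch.randn(B, r, s_cnum, output_atoms)      # stands in for model.votes
route = torch.softmax(torch.randn(B, r, s_cnum), dim=2)   # stands in for model.weights_c
sim = torch.rand(u_cnum, s_cnum)                          # stands in for get_sim(train_set, test_set)
sim = sim / sim.sum(dim=1, keepdim=True)                  # rows sum to 1

# Weight every seen-class vote by its routing coefficient ...
weighted = seen_votes * route.unsqueeze(-1)               # (B, r, s_cnum, output_atoms)

# ... then mix seen classes into unseen classes via the similarity matrix.
unseen_votes = sim.expand(B, r, u_cnum, s_cnum) @ weighted
print(unseen_votes.shape)                                 # torch.Size([4, 3, 2, 10])
```

`update_unseen_routing` then squashes and routes these unseen votes exactly as during training, and the prediction is the unseen class whose activation has the largest norm.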