├── EUR_Cap.py ├── EUR_Cap_grad.py ├── EUR_eval.py ├── README.md ├── data_helpers.py ├── layer.py ├── network.py ├── utils.py └── w2v.py /EUR_Cap.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import os 7 | import json 8 | import random 9 | import time 10 | from torch.autograd import Variable 11 | from torch.optim import Adam 12 | from network import CapsNet_Text,BCE_loss 13 | from w2v import load_word2vec 14 | import data_helpers 15 | 16 | 17 | torch.manual_seed(0) 18 | torch.cuda.manual_seed(0) 19 | np.random.seed(0) 20 | random.seed(0) 21 | 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument('--dataset', type=str, default='eurlex_raw_text.p', 25 | help='Options: eurlex_raw_text.p, rcv1_raw_text.p, wiki30k_raw_text.p') 26 | parser.add_argument('--vocab_size', type=int, default=30001, help='vocabulary size') 27 | parser.add_argument('--vec_size', type=int, default=300, help='embedding size') 28 | parser.add_argument('--sequence_length', type=int, default=500, help='the length of documents') 29 | parser.add_argument('--is_AKDE', type=bool, default=True, help='if Adaptive KDE routing is enabled') 30 | parser.add_argument('--num_epochs', type=int, default=30, help='Number of training epochs') 31 | parser.add_argument('--tr_batch_size', type=int, default=256, help='Batch size for training') 32 | parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate for training') 33 | parser.add_argument('--start_from', type=str, default='', help='') 34 | 35 | parser.add_argument('--num_compressed_capsule', type=int, default=128, help='The number of compact capsules') 36 | parser.add_argument('--dim_capsule', type=int, default=16, help='The number of dimensions for capsules') 37 | 38 | parser.add_argument('--learning_rate_decay_start', type=int, default=0, 39 | help='at what iteration to start decaying learning rate? 
(-1 = dont) (in epoch)') 40 | parser.add_argument('--learning_rate_decay_every', type=int, default=20, 41 | help='how many iterations thereafter to drop LR?(in epoch)') 42 | parser.add_argument('--learning_rate_decay_rate', type=float, default=0.95, 43 | help='how many iterations thereafter to drop LR?(in epoch)') 44 | 45 | 46 | 47 | args = parser.parse_args() 48 | params = vars(args) 49 | print(json.dumps(params, indent = 2)) 50 | 51 | X_trn, Y_trn, Y_trn_o, X_tst, Y_tst, Y_tst_o, vocabulary, vocabulary_inv = data_helpers.load_data(args.dataset, 52 | max_length=args.sequence_length, 53 | vocab_size=args.vocab_size) 54 | Y_trn = Y_trn.toarray() 55 | Y_tst = Y_tst.toarray() 56 | 57 | X_trn = X_trn.astype(np.int32) 58 | X_tst = X_tst.astype(np.int32) 59 | Y_trn = Y_trn.astype(np.int32) 60 | Y_tst = Y_tst.astype(np.int32) 61 | 62 | embedding_weights = load_word2vec('glove', vocabulary_inv, args.vec_size) 63 | 64 | args.num_classes = Y_trn.shape[1] 65 | 66 | capsule_net = CapsNet_Text(args, embedding_weights) 67 | capsule_net = nn.DataParallel(capsule_net).cuda() 68 | 69 | 70 | def transformLabels(labels): 71 | label_index = list(set([l for _ in labels for l in _])) 72 | label_index.sort() 73 | 74 | variable_num_classes = len(label_index) 75 | target = [] 76 | for _ in labels: 77 | tmp = np.zeros([variable_num_classes], dtype=np.float32) 78 | tmp[[label_index.index(l) for l in _]] = 1 79 | target.append(tmp) 80 | target = np.array(target) 81 | return label_index, target 82 | 83 | current_lr = args.learning_rate 84 | 85 | optimizer = Adam(capsule_net.parameters(), lr=current_lr) 86 | 87 | def set_lr(optimizer, lr): 88 | for group in optimizer.param_groups: 89 | group['lr'] = lr 90 | 91 | for epoch in range(args.num_epochs): 92 | torch.cuda.empty_cache() 93 | 94 | nr_trn_num = X_trn.shape[0] 95 | nr_batches = int(np.ceil(nr_trn_num / float(args.tr_batch_size))) 96 | 97 | if epoch > args.learning_rate_decay_start and args.learning_rate_decay_start >= 0: 98 | frac = (epoch - args.learning_rate_decay_start) // args.learning_rate_decay_every 99 | decay_factor = args.learning_rate_decay_rate ** frac 100 | current_lr = current_lr * decay_factor 101 | print(current_lr) 102 | set_lr(optimizer, current_lr) 103 | 104 | capsule_net.train() 105 | for iteration, batch_idx in enumerate(np.random.permutation(range(nr_batches))): 106 | start = time.time() 107 | start_idx = batch_idx * args.tr_batch_size 108 | end_idx = min((batch_idx + 1) * args.tr_batch_size, nr_trn_num) 109 | 110 | X = X_trn[start_idx:end_idx] 111 | Y = Y_trn_o[start_idx:end_idx] 112 | data = Variable(torch.from_numpy(X).long()).cuda() 113 | 114 | batch_labels, batch_target = transformLabels(Y) 115 | batch_target = Variable(torch.from_numpy(batch_target).float()).cuda() 116 | optimizer.zero_grad() 117 | poses, activations = capsule_net(data, batch_labels) 118 | loss = BCE_loss(activations, batch_target) 119 | loss.backward() 120 | optimizer.step() 121 | torch.cuda.empty_cache() 122 | done = time.time() 123 | elapsed = done - start 124 | 125 | print("\rIteration: {}/{} ({:.1f}%) Loss: {:.5f} {:.5f}".format( 126 | iteration, nr_batches, 127 | iteration * 100 / nr_batches, 128 | loss.item(), elapsed), 129 | end="") 130 | 131 | torch.cuda.empty_cache() 132 | 133 | if (epoch + 1) > 20: 134 | checkpoint_path = os.path.join('save', 'model-eur-akde-' + str(epoch + 1) + '.pth') 135 | torch.save(capsule_net.state_dict(), checkpoint_path) 136 | print("model saved to {}".format(checkpoint_path)) 137 | 138 | 
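During training, EUR_Cap.py routes only to the labels that occur in the current batch: transformLabels collects the batch's label set and builds a dense multi-hot target over that reduced label space, and the sorted label_index is passed to the capsule network so that FCCaps only instantiates output capsules for those classes. A small standalone sketch of that mapping (toy label ids, not real EUR-Lex classes):

```python
# Standalone sketch of the transformLabels mapping used in EUR_Cap.py (toy label ids).
import numpy as np

def transform_labels(labels):
    # union of all labels in the batch, sorted for a stable column order
    label_index = sorted({l for doc_labels in labels for l in doc_labels})
    target = np.zeros((len(labels), len(label_index)), dtype=np.float32)
    for row, doc_labels in enumerate(labels):
        for l in doc_labels:
            target[row, label_index.index(l)] = 1.0
    return label_index, target

batch = [[3, 17], [17, 42, 99], [3]]
label_index, target = transform_labels(batch)
print(label_index)  # [3, 17, 42, 99]
print(target)       # 3 x 4 multi-hot matrix over the batch's label subset
```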
-------------------------------------------------------------------------------- /EUR_Cap_grad.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import os 7 | import json 8 | import random 9 | import time 10 | from torch.autograd import Variable 11 | from torch.optim import Adam 12 | from network import CapsNet_Text,BCE_loss, CNN_KIM 13 | from w2v import load_word2vec 14 | import data_helpers 15 | 16 | 17 | torch.manual_seed(0) 18 | torch.cuda.manual_seed(0) 19 | np.random.seed(0) 20 | random.seed(0) 21 | 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument('--dataset', type=str, default='eurlex_raw_text.p', 25 | help='Options: eurlex_raw_text.p, rcv1_raw_text.p, wiki30k_raw_text.p') 26 | parser.add_argument('--vocab_size', type=int, default=30001, help='vocabulary size') 27 | parser.add_argument('--vec_size', type=int, default=300, help='embedding size') 28 | parser.add_argument('--sequence_length', type=int, default=500, help='the length of documents') 29 | parser.add_argument('--is_AKDE', type=bool, default=True, help='if Adaptive KDE routing is enabled') 30 | parser.add_argument('--num_epochs', type=int, default=30, help='Number of training epochs') 31 | parser.add_argument('--tr_batch_size', type=int, default=256, help='Batch size for training') 32 | parser.add_argument('--ts_batch_size', type=int, default=16, help='Batch size for training') 33 | 34 | parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate for training') 35 | parser.add_argument('--start_from', type=str, default='', help='') 36 | 37 | parser.add_argument('--num_compressed_capsule', type=int, default=128, help='The number of compact capsules') 38 | parser.add_argument('--dim_capsule', type=int, default=16, help='The number of dimensions for capsules') 39 | 40 | parser.add_argument('--learning_rate_decay_start', type=int, default=0, 41 | help='at what iteration to start decaying learning rate? 
(-1 = dont) (in epoch)') 42 | parser.add_argument('--learning_rate_decay_every', type=int, default=20, 43 | help='how many iterations thereafter to drop LR?(in epoch)') 44 | parser.add_argument('--learning_rate_decay_rate', type=float, default=0.95, 45 | help='multiplicative factor applied to the learning rate at each decay step') 46 | 47 | parser.add_argument('--gradient_accumulation_steps', type=int, default=8, help='Number of gradient accumulation steps per batch') 48 | 49 | parser.add_argument('--re_ranking', type=int, default=200, help='Size of the re-ranking candidate set') 50 | 51 | 52 | args = parser.parse_args() 53 | params = vars(args) 54 | print(json.dumps(params, indent = 2)) 55 | 56 | X_trn, Y_trn, Y_trn_o, X_tst, Y_tst, Y_tst_o, vocabulary, vocabulary_inv = data_helpers.load_data(args.dataset, 57 | max_length=args.sequence_length, 58 | vocab_size=args.vocab_size) 59 | Y_trn = Y_trn.toarray() 60 | Y_tst = Y_tst.toarray() 61 | 62 | X_trn = X_trn.astype(np.int32) 63 | X_tst = X_tst.astype(np.int32) 64 | Y_trn = Y_trn.astype(np.int32) 65 | Y_tst = Y_tst.astype(np.int32) 66 | 67 | embedding_weights = load_word2vec('glove', vocabulary_inv, args.vec_size) 68 | 69 | args.num_classes = Y_trn.shape[1] 70 | 71 | capsule_net = CapsNet_Text(args, embedding_weights) 72 | capsule_net = nn.DataParallel(capsule_net).cuda() 73 | 74 | model_name = 'model-EUR-CNN-40.pth' 75 | baseline = CNN_KIM(args, embedding_weights) 76 | baseline.load_state_dict(torch.load(os.path.join('save_new', model_name))) 77 | baseline = nn.DataParallel(baseline).cuda() 78 | print(model_name + ' loaded') 79 | 80 | def transformLabels(labels, total_labels): 81 | label_index = list(set([l for _ in total_labels for l in _])) 82 | label_index.sort() 83 | 84 | variable_num_classes = len(label_index) 85 | target = [] 86 | for _ in labels: 87 | tmp = np.zeros([variable_num_classes], dtype=np.float32) 88 | tmp[[label_index.index(l) for l in _]] = 1 89 | target.append(tmp) 90 | target = np.array(target) 91 | return label_index, target 92 | 93 | current_lr = args.learning_rate 94 | 95 | optimizer = Adam(capsule_net.parameters(), lr=current_lr) 96 | 97 | def set_lr(optimizer, lr): 98 | for group in optimizer.param_groups: 99 | group['lr'] = lr 100 | 101 | from network import CNN_KIM,CapsNet_Text 102 | import random 103 | from utils import evaluate 104 | import data_helpers 105 | import scipy.sparse as sp 106 | from w2v import load_word2vec 107 | import os 108 | 109 | for epoch in range(args.num_epochs): 110 | 111 | nr_trn_num = X_trn.shape[0] 112 | nr_batches = int(np.ceil(nr_trn_num / float(args.tr_batch_size))) 113 | 114 | if epoch > args.learning_rate_decay_start and args.learning_rate_decay_start >= 0: 115 | frac = (epoch - args.learning_rate_decay_start) // args.learning_rate_decay_every 116 | decay_factor = args.learning_rate_decay_rate ** frac 117 | current_lr = current_lr * decay_factor 118 | print(current_lr) 119 | set_lr(optimizer, current_lr) 120 | 121 | capsule_net.train() 122 | for iteration, batch_idx in enumerate(np.random.permutation(range(nr_batches))): 123 | start = time.time() 124 | start_idx = batch_idx * args.tr_batch_size 125 | end_idx = min((batch_idx + 1) * args.tr_batch_size, nr_trn_num) 126 | 127 | X = X_trn[start_idx:end_idx] 128 | Y = Y_trn_o[start_idx:end_idx] 129 | 130 | batch_steps = int(np.ceil(len(X) / (float(args.tr_batch_size) / float(args.gradient_accumulation_steps)))) 131 | batch_loss = 0 132 | for i in range(batch_steps): 133 | step_size = int(float(args.tr_batch_size) // float(args.gradient_accumulation_steps)) 134 | step_X = X[i * step_size: (i+1) * step_size] 135 | 
step_Y = Y[i * step_size: (i+1) * step_size] 136 | 137 | step_X = Variable(torch.from_numpy(step_X).long()).cuda() 138 | step_labels, step_target = transformLabels(step_Y, Y) 139 | step_target = Variable(torch.from_numpy(step_target).float()).cuda() 140 | 141 | poses, activations = capsule_net(step_X, step_labels) 142 | step_loss = BCE_loss(activations, step_target) 143 | step_loss = step_loss / args.gradient_accumulation_steps 144 | step_loss.backward() 145 | batch_loss += step_loss.item() 146 | 147 | optimizer.step() 148 | optimizer.zero_grad() 149 | done = time.time() 150 | elapsed = done - start 151 | 152 | print("\rIteration: {}/{} ({:.1f}%) Loss: {:.5f} {:.5f}".format( 153 | iteration, nr_batches, 154 | iteration * 100 / nr_batches, 155 | batch_loss, elapsed), 156 | end="") 157 | 158 | if (epoch + 1) > 20 and (epoch + 1)<30: 159 | 160 | nr_tst_num = X_tst.shape[0] 161 | nr_batches = int(np.ceil(nr_tst_num / float(args.ts_batch_size))) 162 | 163 | n, k_trn = Y_trn.shape 164 | m, k_tst = Y_tst.shape 165 | print ('k_trn:', k_trn) 166 | print ('k_tst:', k_tst) 167 | 168 | capsule_net.eval() 169 | top_k = 50 170 | row_idx_list, col_idx_list, val_idx_list = [], [], [] 171 | for batch_idx in range(nr_batches): 172 | start = time.time() 173 | start_idx = batch_idx * args.ts_batch_size 174 | end_idx = min((batch_idx + 1) * args.ts_batch_size, nr_tst_num) 175 | X = X_tst[start_idx:end_idx] 176 | Y = Y_tst_o[start_idx:end_idx] 177 | data = Variable(torch.from_numpy(X).long()).cuda() 178 | 179 | candidates = baseline(data) 180 | candidates = candidates.data.cpu().numpy() 181 | 182 | Y_pred = np.zeros([candidates.shape[0], args.num_classes]) 183 | for i in range(candidates.shape[0]): 184 | candidate_labels = candidates[i, :].argsort()[-args.re_ranking:][::-1].tolist() 185 | _, activations_2nd = capsule_net(data[i, :].unsqueeze(0), candidate_labels) 186 | Y_pred[i, candidate_labels] = activations_2nd.squeeze(2).data.cpu().numpy() 187 | 188 | for i in range(Y_pred.shape[0]): 189 | sorted_idx = np.argpartition(-Y_pred[i, :], top_k)[:top_k] 190 | row_idx_list += [i + start_idx] * top_k 191 | col_idx_list += (sorted_idx).tolist() 192 | val_idx_list += Y_pred[i, sorted_idx].tolist() 193 | 194 | done = time.time() 195 | elapsed = done - start 196 | 197 | print("\r Epoch: {} Reranking: {} Iteration: {}/{} ({:.1f}%) Loss: {:.5f} {:.5f}".format( 198 | (epoch + 1), args.re_ranking, batch_idx, nr_batches, 199 | batch_idx * 100 / nr_batches, 200 | 0, elapsed), 201 | end="") 202 | 203 | m = max(row_idx_list) + 1 204 | n = max(k_trn, k_tst) 205 | print(elapsed) 206 | Y_tst_pred = sp.csr_matrix((val_idx_list, (row_idx_list, col_idx_list)), shape=(m, n)) 207 | 208 | if k_trn >= k_tst: 209 | Y_tst_pred = Y_tst_pred[:, :k_tst] 210 | 211 | evaluate(Y_tst_pred.toarray(), Y_tst) 212 | 213 | # checkpoint_path = os.path.join('save_new', 'model-eur-akde-' + str(epoch + 1) + '.pth') 214 | # torch.save(capsule_net.state_dict(), checkpoint_path) 215 | # print("model saved to {}".format(checkpoint_path)) 216 | 217 | -------------------------------------------------------------------------------- /EUR_eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, unicode_literals 2 | import argparse 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | from network import CNN_KIM,CapsNet_Text 8 | import random 9 | import time 10 | from utils import evaluate 11 | import data_helpers 12 | 
import scipy.sparse as sp 13 | from w2v import load_word2vec 14 | import os 15 | 16 | torch.manual_seed(0) 17 | torch.cuda.manual_seed(0) 18 | np.random.seed(0) 19 | random.seed(0) 20 | 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument('--dataset', type=str, default='eurlex_raw_text.p', 24 | help='Options: eurlex_raw_text.p, rcv1_raw_text.p, wiki30k_raw_text.p') 25 | parser.add_argument('--vocab_size', type=int, default=30001, help='vocabulary size') 26 | parser.add_argument('--vec_size', type=int, default=300, help='embedding size') 27 | parser.add_argument('--sequence_length', type=int, default=500, help='the length of documents') 28 | parser.add_argument('--is_AKDE', type=bool, default=True, help='if Adaptive KDE routing is enabled') 29 | parser.add_argument('--num_epochs', type=int, default=30, help='Number of training epochs') 30 | parser.add_argument('--ts_batch_size', type=int, default=32, help='Batch size for training') 31 | parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate for training') 32 | parser.add_argument('--start_from', type=str, default='save', help='') 33 | 34 | parser.add_argument('--num_compressed_capsule', type=int, default=128, help='The number of compact capsules') 35 | parser.add_argument('--dim_capsule', type=int, default=16, help='The number of dimensions for capsules') 36 | 37 | parser.add_argument('--re_ranking', type=int, default=200, help='The number of re-ranking size') 38 | 39 | import json 40 | args = parser.parse_args() 41 | params = vars(args) 42 | print(json.dumps(params, indent = 2)) 43 | 44 | X_trn, Y_trn, Y_trn_o, X_tst, Y_tst, Y_tst_o, vocabulary, vocabulary_inv = data_helpers.load_data(args.dataset, 45 | max_length=args.sequence_length, 46 | vocab_size=args.vocab_size) 47 | Y_trn = Y_trn.toarray() 48 | Y_tst = Y_tst.toarray() 49 | 50 | X_trn = X_trn.astype(np.int32) 51 | X_tst = X_tst.astype(np.int32) 52 | Y_trn = Y_trn.astype(np.int32) 53 | Y_tst = Y_tst.astype(np.int32) 54 | 55 | 56 | embedding_weights = load_word2vec('glove', vocabulary_inv, args.vec_size) 57 | args.num_classes = Y_trn.shape[1] 58 | 59 | capsule_net = CapsNet_Text(args, embedding_weights) 60 | capsule_net = nn.DataParallel(capsule_net).cuda() 61 | model_name = 'model-eur-akde-29.pth' 62 | capsule_net.load_state_dict(torch.load(os.path.join(args.start_from, model_name))) 63 | print(model_name + ' loaded') 64 | 65 | 66 | model_name = 'model-EUR-CNN-40.pth' 67 | baseline = CNN_KIM(args, embedding_weights) 68 | baseline.load_state_dict(torch.load(os.path.join(args.start_from, model_name))) 69 | baseline = nn.DataParallel(baseline).cuda() 70 | print(model_name + ' loaded') 71 | 72 | 73 | nr_tst_num = X_tst.shape[0] 74 | nr_batches = int(np.ceil(nr_tst_num / float(args.ts_batch_size))) 75 | 76 | n, k_trn = Y_trn.shape 77 | m, k_tst = Y_tst.shape 78 | print ('k_trn:', k_trn) 79 | print ('k_tst:', k_tst) 80 | 81 | capsule_net.eval() 82 | top_k = 50 83 | row_idx_list, col_idx_list, val_idx_list = [], [], [] 84 | for batch_idx in range(nr_batches): 85 | start = time.time() 86 | start_idx = batch_idx * args.ts_batch_size 87 | end_idx = min((batch_idx + 1) * args.ts_batch_size, nr_tst_num) 88 | X = X_tst[start_idx:end_idx] 89 | Y = Y_tst_o[start_idx:end_idx] 90 | data = Variable(torch.from_numpy(X).long()).cuda() 91 | 92 | candidates = baseline(data) 93 | candidates = candidates.data.cpu().numpy() 94 | 95 | Y_pred = np.zeros([candidates.shape[0], args.num_classes]) 96 | for i in range(candidates.shape[0]): 97 | candidate_labels = 
candidates[i, :].argsort()[-args.re_ranking:][::-1].tolist() 98 | _, activations_2nd = capsule_net(data[i, :].unsqueeze(0), candidate_labels) 99 | Y_pred[i, candidate_labels] = activations_2nd.squeeze(2).data.cpu().numpy() 100 | 101 | for i in range(Y_pred.shape[0]): 102 | sorted_idx = np.argpartition(-Y_pred[i, :], top_k)[:top_k] 103 | row_idx_list += [i + start_idx] * top_k 104 | col_idx_list += (sorted_idx).tolist() 105 | val_idx_list += Y_pred[i, sorted_idx].tolist() 106 | 107 | done = time.time() 108 | elapsed = done - start 109 | 110 | print("\r Reranking: {} Iteration: {}/{} ({:.1f}%) Loss: {:.5f} {:.5f}".format( 111 | args.re_ranking, batch_idx, nr_batches, 112 | batch_idx * 100 / nr_batches, 113 | 0, elapsed), 114 | end="") 115 | 116 | m = max(row_idx_list) + 1 117 | n = max(k_trn, k_tst) 118 | print(elapsed) 119 | Y_tst_pred = sp.csr_matrix((val_idx_list, (row_idx_list, col_idx_list)), shape=(m, n)) 120 | 121 | if k_trn >= k_tst: 122 | Y_tst_pred = Y_tst_pred[:, :k_tst] 123 | 124 | evaluate(Y_tst_pred.toarray(), Y_tst) 125 | 126 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Scalable and Reliable Capsule Networks for Challenging NLP Applications 2 | 3 | ACL-19: https://www.aclweb.org/anthology/P19-1150/ 4 | 5 | Requirements: the code is written in Python 3 and requires PyTorch. 6 | 7 | # Preparation 8 | For a quick start, please download [the dataset and trained model](https://drive.google.com/open?id=1gPYAMyYo4YLrmx_Egc9wjCqzWx15D7U8). 9 | 10 | # Code Explanation 11 | data_helpers.py implements the data-processing functions. 12 | 13 | layer.py implements the main building blocks of the capsule network, including KDE routing, Adaptive KDE routing, and the Primary Capsule layer. 14 | 15 | network.py provides the wrapper of our model as well as the baseline models used for comparison. 16 | 17 | utils.py provides the evaluation functions, such as Precision@1,3,5 and NDCG@1,3,5. 18 | 19 | EUR_Cap.py and EUR_eval.py are used for training and inference, respectively; EUR_Cap_grad.py trains the same model with gradient accumulation on a single GPU.
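The pickled datasets follow the layout that data_helpers.load_data expects: a list [train, test, vocab, catgy], where each document is a dict with a 'text' string and a 'catgy' list of integer label ids. A minimal sketch (not part of the repo) for inspecting the downloaded file; if the pickle was written under Python 2, you may need pickle.load(fin, encoding='latin1'):

```python
# Sketch: peek at the EUR-Lex pickle that data_helpers.load_data unpacks.
import pickle

with open('eurlex_raw_text.p', 'rb') as fin:
    train, test, vocab, catgy = pickle.load(fin)

doc = train[0]
print(len(train), len(test))   # number of training / test documents
print(doc['text'][:200])       # raw document text (cleaned later by clean_str)
print(doc['catgy'])            # list of integer label ids for this document
```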
20 | # Quick start 21 | 22 | ```bash 23 | CUDA_VISIBLE_DEVICES=0 python EUR_eval.py 24 | 25 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python EUR_Cap.py 26 | 27 | CUDA_VISIBLE_DEVICES=0 python EUR_Cap_grad.py # train CapNet on single GPU with accumulated gradients 28 | ``` 29 | 30 | # Performance on EUR-Lex dataset 31 | 32 | ```bash 33 | NLP-Capsule with Adaptive KDE routing: 34 | 35 | Epoch: 20 Iteration: 120/121 (99.2%) Loss: 0.00000 0.33459 36 | Tst Prec@1,3,5: [0.7948253557567917, 0.65605864596808838, 0.53666235446312649] 37 | Tst NDCG@1,3,5: [0.7948253557567917, 0.70826730037244034, 0.6843311797551882] 38 | 39 | Epoch: 21 Iteration: 120/121 (99.2%) Loss: 0.00000 0.24704 40 | Tst Prec@1,3,5: [0.79301423027166884, 0.6552824493316064, 0.53666235446312793] 41 | Tst NDCG@1,3,5: [0.79301423027166884, 0.70672871614554134, 0.68443643153244704] 42 | 43 | Epoch: 22 Iteration: 120/121 (99.2%) Loss: 0.00000 0.24949 44 | Tst Prec@1,3,5: [0.79404915912031049, 0.65554118154376773, 0.53800776196636135] 45 | Tst NDCG@1,3,5: [0.79404915912031049, 0.70816714976829975, 0.68780244631961929] 46 | 47 | Epoch: 23 Iteration: 120/121 (99.2%) Loss: 0.00000 0.25533 48 | Tst Prec@1,3,5: [0.8046571798188874, 0.65890470030185422, 0.53604139715394228] 49 | Tst NDCG@1,3,5: [0.8046571798188874, 0.71380071010660562, 0.69040247647419262] 50 | 51 | Epoch: 24 Iteration: 120/121 (99.2%) Loss: 0.00000 0.26880 52 | Tst Prec@1,3,5: [0.80620957309184993, 0.65614489003880982, 0.53661060802069527] 53 | Tst NDCG@1,3,5: [0.80620957309184993, 0.7133596479633022, 0.69571103238443532] 54 | 55 | Epoch: 25 Iteration: 120/121 (99.2%) Loss: 0.00000 0.25847 56 | Tst Prec@1,3,5: [0.80155239327296246, 0.65329883570504454, 0.53448900388098108] 57 | Tst NDCG@1,3,5: [0.80155239327296246, 0.7096033706441367, 0.69201706652281636] 58 | 59 | Epoch: 26 Iteration: 120/121 (99.2%) Loss: 0.00000 0.26063 60 | Tst Prec@1,3,5: [0.80000000000000004, 0.65381630012936431, 0.53350582147477121] 61 | Tst NDCG@1,3,5: [0.80000000000000004, 0.71043623399753963, 0.69499344732549306] 62 | 63 | Epoch: 27 Iteration: 120/121 (99.2%) Loss: 0.00000 0.26004 64 | Tst Prec@1,3,5: [0.79689521345407499, 0.65398878827080587, 0.53376455368693132] 65 | Tst NDCG@1,3,5: [0.79689521345407499, 0.71269493382033577, 0.69812854866301688] 66 | 67 | Epoch: 28 Iteration: 120/121 (99.2%) Loss: 0.00000 0.27287 68 | Tst Prec@1,3,5: [0.79818887451487708, 0.65588615782664883, 0.53500646830530163] 69 | Tst NDCG@1,3,5: [0.79818887451487708, 0.71429911265714374, 0.70057615675866636] 70 | 71 | 72 | XML-CNN: 73 | Epoch: 31 Iteration: 45/46 (97.8%) Loss: 0.00006 0.15460 74 | Tst Prec@1,3,5: [0.7583441138421734, 0.6164726175075479, 0.5073738680465716] 75 | Tst NDCG@1,3,5: [0.7583441138421734, 0.6661232856458101, 0.644838787586548] 76 | 77 | Epoch: 32 Iteration: 45/46 (97.8%) Loss: 0.00005 0.15354 78 | Tst Prec@1,3,5: [0.759379042690815, 0.6143165157395448, 0.5062871927554978] 79 | Tst NDCG@1,3,5: [0.759379042690815, 0.6648180435110952, 0.6434396675410785] 80 | 81 | Epoch: 33 Iteration: 45/46 (97.8%) Loss: 0.00005 0.15399 82 | Tst Prec@1,3,5: [0.757567917205692, 0.6169038378611481, 0.507373868046571] 83 | Tst NDCG@1,3,5: [0.757567917205692, 0.666160785036582, 0.6440332351720106] 84 | 85 | Epoch: 34 Iteration: 45/46 (97.8%) Loss: 0.00004 0.15153 86 | Tst Prec@1,3,5: [0.7573091849935317, 0.616645105648988, 0.5099094437257432] 87 | Tst NDCG@1,3,5: [0.7573091849935317, 0.6659194956789641, 0.6458294426678642] 88 | 89 | Epoch: 35 Iteration: 45/46 (97.8%) Loss: 0.00005 0.15212 90 | Tst Prec@1,3,5: 
[0.7552393272962484, 0.6153514445881856, 0.5092367399741262] 91 | Tst NDCG@1,3,5: [0.7552393272962484, 0.6648419426927356, 0.6453632713906606] 92 | 93 | Epoch: 36 Iteration: 45/46 (97.8%) Loss: 0.00004 0.15231 94 | Tst Prec@1,3,5: [0.7596377749029755, 0.6157826649417857, 0.5093402328589907] 95 | Tst NDCG@1,3,5: [0.7596377749029755, 0.6661452963066051, 0.646133349811576] 96 | 97 | Epoch: 37 Iteration: 45/46 (97.8%) Loss: 0.00006 0.15357 98 | Tst Prec@1,3,5: [0.7570504527813713, 0.6175937904269097, 0.5088227684346699] 99 | Tst NDCG@1,3,5: [0.7570504527813713, 0.6670823259018512, 0.6455866525334287] 100 | 101 | Epoch: 38 Iteration: 45/46 (97.8%) Loss: 0.00006 0.16400 102 | Tst Prec@1,3,5: [0.7583441138421734, 0.6162138852953867, 0.5085122897800777] 103 | Tst NDCG@1,3,5: [0.7583441138421734, 0.6658377730303046, 0.6448260229129755] 104 | 105 | Epoch: 39 Iteration: 45/46 (97.8%) Loss: 0.00004 0.15555 106 | Tst Prec@1,3,5: [0.7578266494178525, 0.6173350582147488, 0.509029754204398] 107 | Tst NDCG@1,3,5: [0.7578266494178525, 0.6667396690496684, 0.645590263852396] 108 | 109 | Epoch: 40 Iteration: 45/46 (97.8%) Loss: 0.00004 0.15414 110 | Tst Prec@1,3,5: [0.7565329883570504, 0.61811125485123, 0.5087192755498058] 111 | Tst NDCG@1,3,5: [0.7565329883570504, 0.6674559324640292, 0.6452839523583206] 112 | 113 | ``` 114 | 115 | # Reference 116 | If you find our source code useful, please consider citing our work. 117 | ``` 118 | @inproceedings{zhao2019capsule, 119 | title = "Towards Scalable and Reliable Capsule Networks for Challenging {NLP} Applications", 120 | author = "Zhao, Wei and Peng, Haiyun and Eger, Steffen and Cambria, Erik and Yang, Min", 121 | booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics", 122 | month = jul, 123 | year = "2019", 124 | address = "Florence, Italy", 125 | publisher = "Association for Computational Linguistics", 126 | url = "https://www.aclweb.org/anthology/P19-1150", 127 | doi = "10.18653/v1/P19-1150", 128 | pages = "1549--1559" 129 | } 130 | ``` 131 | -------------------------------------------------------------------------------- /data_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import re 4 | import itertools 5 | import scipy.sparse as sp 6 | import cPickle as pickle 7 | from collections import Counter 8 | from nltk.corpus import stopwords 9 | 10 | cachedStopWords = stopwords.words("english") 11 | 12 | 13 | def clean_str(string): 14 | # remove stopwords 15 | # string = ' '.join([word for word in string.split() if word not in cachedStopWords]) 16 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string) 17 | string = re.sub(r"\'s", " \'s", string) 18 | string = re.sub(r"\'ve", " \'ve", string) 19 | string = re.sub(r"n\'t", " n\'t", string) 20 | string = re.sub(r"\'re", " \'re", string) 21 | string = re.sub(r"\'d", " \'d", string) 22 | string = re.sub(r"\'ll", " \'ll", string) 23 | string = re.sub(r",", " , ", string) 24 | string = re.sub(r"!", " ! ", string) 25 | string = re.sub(r"\(", " \( ", string) 26 | string = re.sub(r"\)", " \) ", string) 27 | string = re.sub(r"\?", " \? 
", string) 28 | string = re.sub(r"\s{2,}", " ", string) 29 | return string.strip().lower() 30 | 31 | 32 | def pad_sentences(sentences, padding_word="", max_length=500): 33 | sequence_length = min(max(len(x) for x in sentences), max_length) 34 | padded_sentences = [] 35 | for i in range(len(sentences)): 36 | sentence = sentences[i] 37 | if len(sentence) < max_length: 38 | num_padding = sequence_length - len(sentence) 39 | new_sentence = sentence + [padding_word] * num_padding 40 | else: 41 | new_sentence = sentence[:max_length] 42 | padded_sentences.append(new_sentence) 43 | return padded_sentences 44 | 45 | 46 | def load_data_and_labels(data): 47 | x_text = [clean_str(doc['text']) for doc in data] 48 | x_text = [s.split(" ") for s in x_text] 49 | labels = [doc['catgy'] for doc in data] 50 | row_idx, col_idx, val_idx = [], [], [] 51 | for i in range(len(labels)): 52 | l_list = list(set(labels[i])) # remove duplicate cateories to avoid double count 53 | for y in l_list: 54 | row_idx.append(i) 55 | col_idx.append(y) 56 | val_idx.append(1) 57 | m = max(row_idx) + 1 58 | n = max(col_idx) + 1 59 | Y = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, n)) 60 | return [x_text, Y, labels] 61 | 62 | 63 | def build_vocab(sentences, vocab_size=50000): 64 | word_counts = Counter(itertools.chain(*sentences)) 65 | vocabulary_inv = [x[0] for x in word_counts.most_common(vocab_size)] 66 | vocabulary = {x: i for i, x in enumerate(vocabulary_inv)} 67 | # append symbol to the vocabulary 68 | vocabulary[''] = len(vocabulary) 69 | vocabulary_inv.append('') 70 | return [vocabulary, vocabulary_inv] 71 | 72 | 73 | def build_input_data(sentences, vocabulary): 74 | x = np.array([[vocabulary[word] if word in vocabulary else vocabulary[''] for word in sentence] for sentence in sentences]) 75 | #x = np.array([[vocabulary[word] if word in vocabulary else len(vocabulary) for word in sentence] for sentence in sentences]) 76 | return x 77 | 78 | 79 | def load_data(data_path, max_length=500, vocab_size=50000): 80 | # Load and preprocess data 81 | with open(os.path.join(data_path), 'rb') as fin: 82 | [train, test, vocab, catgy] = pickle.load(fin) 83 | 84 | # dirty trick to prevent errors happen when test is empty 85 | if len(test) == 0: 86 | test[:5] = train[:5] 87 | 88 | trn_sents, Y_trn, Y_trn_o = load_data_and_labels(train) 89 | tst_sents, Y_tst, Y_tst_o = load_data_and_labels(test) 90 | trn_sents_padded = pad_sentences(trn_sents, max_length=max_length) 91 | tst_sents_padded = pad_sentences(tst_sents, max_length=max_length) 92 | vocabulary, vocabulary_inv = build_vocab(trn_sents_padded + tst_sents_padded, vocab_size=vocab_size) 93 | X_trn = build_input_data(trn_sents_padded, vocabulary) 94 | X_tst = build_input_data(tst_sents_padded, vocabulary) 95 | return X_trn, Y_trn, Y_trn_o, X_tst, Y_tst, Y_tst_o, vocabulary, vocabulary_inv 96 | # return X_trn, Y_trn, vocabulary, vocabulary_inv 97 | 98 | 99 | def batch_iter(data, batch_size, num_epochs): 100 | """ 101 | Generates a batch iterator for a dataset. 
102 | """ 103 | data = np.array(data) 104 | data_size = len(data) 105 | num_batches_per_epoch = int(len(data)/batch_size) + 1 106 | for epoch in range(num_epochs): 107 | # Shuffle the data at each epoch 108 | shuffle_indices = np.random.permutation(np.arange(data_size)) 109 | shuffled_data = data[shuffle_indices] 110 | for batch_num in range(num_batches_per_epoch): 111 | start_index = batch_num * batch_size 112 | end_index = min((batch_num + 1) * batch_size, data_size) 113 | yield shuffled_data[start_index:end_index] 114 | -------------------------------------------------------------------------------- /layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | 7 | def squash_v1(x, axis): 8 | s_squared_norm = (x ** 2).sum(axis, keepdim=True) 9 | scale = torch.sqrt(s_squared_norm)/ (0.5 + s_squared_norm) 10 | return scale * x 11 | 12 | def dynamic_routing(batch_size, b_ij, u_hat, input_capsule_num): 13 | num_iterations = 3 14 | 15 | for i in range(num_iterations): 16 | if True: 17 | leak = torch.zeros_like(b_ij).sum(dim=2, keepdim=True) 18 | leaky_logits = torch.cat((leak, b_ij),2) 19 | leaky_routing = F.softmax(leaky_logits, dim=2) 20 | c_ij = leaky_routing[:,:,1:,:].unsqueeze(4) 21 | else: 22 | c_ij = F.softmax(b_ij, dim=2).unsqueeze(4) 23 | v_j = squash_v1((c_ij * u_hat).sum(dim=1, keepdim=True), axis=3) 24 | if i < num_iterations - 1: 25 | b_ij = b_ij + (torch.cat([v_j] * input_capsule_num, dim=1) * u_hat).sum(3) 26 | 27 | poses = v_j.squeeze(1) 28 | activations = torch.sqrt((poses ** 2).sum(2)) 29 | return poses, activations 30 | 31 | 32 | def Adaptive_KDE_routing(batch_size, b_ij, u_hat): 33 | last_loss = 0.0 34 | while True: 35 | if False: 36 | leak = torch.zeros_like(b_ij).sum(dim=2, keepdim=True) 37 | leaky_logits = torch.cat((leak, b_ij),2) 38 | leaky_routing = F.softmax(leaky_logits, dim=2) 39 | c_ij = leaky_routing[:,:,1:,:].unsqueeze(4) 40 | else: 41 | c_ij = F.softmax(b_ij, dim=2).unsqueeze(4) 42 | c_ij = c_ij/c_ij.sum(dim=1, keepdim=True) 43 | v_j = squash_v1((c_ij * u_hat).sum(dim=1, keepdim=True), axis=3) 44 | dd = 1 - ((squash_v1(u_hat, axis=3)-v_j)** 2).sum(3) 45 | b_ij = b_ij + dd 46 | 47 | c_ij = c_ij.view(batch_size, c_ij.size(1), c_ij.size(2)) 48 | dd = dd.view(batch_size, dd.size(1), dd.size(2)) 49 | 50 | kde_loss = torch.mul(c_ij, dd).sum()/batch_size 51 | kde_loss = np.log(kde_loss.item()) 52 | 53 | if abs(kde_loss - last_loss) < 0.05: 54 | break 55 | else: 56 | last_loss = kde_loss 57 | poses = v_j.squeeze(1) 58 | activations = torch.sqrt((poses ** 2).sum(2)) 59 | return poses, activations 60 | 61 | 62 | def KDE_routing(batch_size, b_ij, u_hat): 63 | num_iterations = 3 64 | for i in range(num_iterations): 65 | if False: 66 | leak = torch.zeros_like(b_ij).sum(dim=2, keepdim=True) 67 | leaky_logits = torch.cat((leak, b_ij),2) 68 | leaky_routing = F.softmax(leaky_logits, dim=2) 69 | c_ij = leaky_routing[:,:,1:,:].unsqueeze(4) 70 | else: 71 | c_ij = F.softmax(b_ij, dim=2).unsqueeze(4) 72 | 73 | c_ij = c_ij/c_ij.sum(dim=1, keepdim=True) 74 | v_j = squash_v1((c_ij * u_hat).sum(dim=1, keepdim=True), axis=3) 75 | 76 | if i < num_iterations - 1: 77 | dd = 1 - ((squash_v1(u_hat, axis=3)-v_j)** 2).sum(3) 78 | b_ij = b_ij + dd 79 | poses = v_j.squeeze(1) 80 | activations = torch.sqrt((poses ** 2).sum(2)) 81 | return poses, activations 82 | 83 | class FlattenCaps(nn.Module): 84 | def __init__(self): 85 
| super(FlattenCaps, self).__init__() 86 | def forward(self, p, a): 87 | poses = p.view(p.size(0), p.size(2) * p.size(3) * p.size(4), -1) 88 | activations = a.view(a.size(0), a.size(1) * a.size(2) * a.size(3), -1) 89 | return poses, activations 90 | 91 | class PrimaryCaps(nn.Module): 92 | def __init__(self, num_capsules, in_channels, out_channels, kernel_size, stride): 93 | super(PrimaryCaps, self).__init__() 94 | 95 | self.capsules = nn.Conv1d(in_channels, out_channels * num_capsules, kernel_size, stride) 96 | 97 | torch.nn.init.xavier_uniform_(self.capsules.weight) 98 | 99 | self.out_channels = out_channels 100 | self.num_capsules = num_capsules 101 | 102 | def forward(self, x): 103 | batch_size = x.size(0) 104 | u = self.capsules(x).view(batch_size, self.num_capsules, self.out_channels, -1, 1) 105 | poses = squash_v1(u, axis=1) 106 | activations = torch.sqrt((poses ** 2).sum(1)) 107 | return poses, activations 108 | 109 | class FCCaps(nn.Module): 110 | def __init__(self, args, output_capsule_num, input_capsule_num, in_channels, out_channels): 111 | super(FCCaps, self).__init__() 112 | 113 | self.in_channels = in_channels 114 | self.out_channels = out_channels 115 | self.input_capsule_num = input_capsule_num 116 | self.output_capsule_num = output_capsule_num 117 | 118 | self.W1 = nn.Parameter(torch.FloatTensor(1, input_capsule_num, output_capsule_num, out_channels, in_channels)) 119 | torch.nn.init.xavier_uniform_(self.W1) 120 | 121 | self.is_AKDE = args.is_AKDE 122 | self.sigmoid = nn.Sigmoid() 123 | 124 | 125 | def forward(self, x, y, labels): 126 | batch_size = x.size(0) 127 | variable_output_capsule_num = len(labels) 128 | W1 = self.W1[:,:,labels,:,:] 129 | 130 | x = torch.stack([x] * variable_output_capsule_num, dim=2).unsqueeze(4) 131 | 132 | W1 = W1.repeat(batch_size, 1, 1, 1, 1) 133 | u_hat = torch.matmul(W1, x) 134 | 135 | b_ij = Variable(torch.zeros(batch_size, self.input_capsule_num, variable_output_capsule_num, 1)).cuda() 136 | 137 | if self.is_AKDE == True: 138 | poses, activations = Adaptive_KDE_routing(batch_size, b_ij, u_hat) 139 | else: 140 | #poses, activations = dynamic_routing(batch_size, b_ij, u_hat, self.input_capsule_num) 141 | poses, activations = KDE_routing(batch_size, b_ij, u_hat) 142 | return poses, activations 143 | 144 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from layer import PrimaryCaps, FCCaps, FlattenCaps 5 | 6 | def BCE_loss(x, target): 7 | return nn.BCELoss()(x.squeeze(2), target) 8 | 9 | class CapsNet_Text(nn.Module): 10 | def __init__(self, args, w2v): 11 | super(CapsNet_Text, self).__init__() 12 | self.num_classes = args.num_classes 13 | self.embed = nn.Embedding(args.vocab_size, args.vec_size) 14 | self.embed.weight = nn.Parameter(torch.from_numpy(w2v)) 15 | 16 | self.ngram_size = [2,4,8] 17 | 18 | self.convs_doc = nn.ModuleList([nn.Conv1d(args.sequence_length, 32, K, stride=2) for K in self.ngram_size]) 19 | torch.nn.init.xavier_uniform_(self.convs_doc[0].weight) 20 | torch.nn.init.xavier_uniform_(self.convs_doc[1].weight) 21 | torch.nn.init.xavier_uniform_(self.convs_doc[2].weight) 22 | 23 | self.primary_capsules_doc = PrimaryCaps(num_capsules=args.dim_capsule, in_channels=32, out_channels=32, kernel_size=1, stride=1) 24 | 25 | self.flatten_capsules = FlattenCaps() 26 | 27 | self.W_doc = nn.Parameter(torch.FloatTensor(14272, 
args.num_compressed_capsule)) 28 | torch.nn.init.xavier_uniform_(self.W_doc) 29 | 30 | self.fc_capsules_doc_child = FCCaps(args, output_capsule_num=args.num_classes, input_capsule_num=args.num_compressed_capsule, 31 | in_channels=args.dim_capsule, out_channels=args.dim_capsule) 32 | 33 | def compression(self, poses, W): 34 | poses = torch.matmul(poses.permute(0,2,1), W).permute(0,2,1) 35 | activations = torch.sqrt((poses ** 2).sum(2)) 36 | return poses, activations 37 | 38 | def forward(self, data, labels): 39 | data = self.embed(data) 40 | nets_doc_l = [] 41 | for i in range(len(self.ngram_size)): 42 | nets = self.convs_doc[i](data) 43 | nets_doc_l.append(nets) 44 | nets_doc = torch.cat((nets_doc_l[0], nets_doc_l[1], nets_doc_l[2]), 2) 45 | poses_doc, activations_doc = self.primary_capsules_doc(nets_doc) 46 | poses, activations = self.flatten_capsules(poses_doc, activations_doc) 47 | poses, activations = self.compression(poses, self.W_doc) 48 | poses, activations = self.fc_capsules_doc_child(poses, activations, labels) 49 | return poses, activations 50 | 51 | 52 | class CNN_KIM(nn.Module): 53 | 54 | def __init__(self, args, w2v): 55 | super(CNN_KIM, self).__init__() 56 | self.embed = nn.Embedding(args.vocab_size, args.vec_size) 57 | self.embed.weight = nn.Parameter(torch.from_numpy(w2v)) 58 | self.conv13 = nn.Conv2d(1, 128, (3, args.vec_size)) 59 | self.conv14 = nn.Conv2d(1, 128, (4, args.vec_size)) 60 | self.conv15 = nn.Conv2d(1, 128, (5, args.vec_size)) 61 | 62 | self.fc1 = nn.Linear(3 * 128, args.num_classes) 63 | self.m = nn.Sigmoid() 64 | 65 | def conv_and_pool(self, x, conv): 66 | x = F.relu(conv(x)).squeeze(3) 67 | x = F.max_pool1d(x, x.size(2)).squeeze(2) 68 | return x 69 | 70 | def loss(self, x, target): 71 | return nn.BCELoss()(x, target) 72 | 73 | def forward(self, x): 74 | x = self.embed(x).unsqueeze(1) 75 | x1 = self.conv_and_pool(x,self.conv13) 76 | x2 = self.conv_and_pool(x,self.conv14) 77 | x3 = self.conv_and_pool(x,self.conv15) 78 | x = torch.cat((x1, x2, x3), 1) 79 | activations = self.fc1(x) 80 | return self.m(activations) 81 | 82 | class XML_CNN(nn.Module): 83 | 84 | def __init__(self, args, w2v): 85 | super(XML_CNN, self).__init__() 86 | self.embed = nn.Embedding(args.vocab_size, args.vec_size) 87 | self.embed.weight = nn.Parameter(torch.from_numpy(w2v)) 88 | self.conv13 = nn.Conv1d(500, 32, 2, stride=2) 89 | self.conv14 = nn.Conv1d(500, 32, 4, stride=2) 90 | self.conv15 = nn.Conv1d(500, 32, 8, stride=2) 91 | 92 | self.fc1 = nn.Linear(14272, 512) 93 | self.fc2 = nn.Linear(512, args.num_classes) 94 | self.m = nn.Sigmoid() 95 | def conv_and_pool(self, x, conv): 96 | x = F.relu(conv(x)).squeeze(3) 97 | return x 98 | 99 | def loss(self, x, target): 100 | return nn.BCELoss()(x, target) 101 | 102 | def forward(self, x): 103 | x = self.embed(x) 104 | batch_size = x.shape[0] 105 | 106 | x1 = self.conv13(x).reshape(batch_size, -1) 107 | x2 = self.conv14(x).reshape(batch_size, -1) 108 | x3 = self.conv15(x).reshape(batch_size, -1) 109 | x = torch.cat((x1, x2, x3), 1) 110 | hidden = self.fc1(x) 111 | activations = self.fc2(hidden) 112 | return self.m(activations) 113 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from multiprocessing import Pool 3 | def precision_at_k(r, k): 4 | assert k >= 1 5 | r = np.asarray(r)[:k] != 0 6 | if r.size != k: 7 | raise ValueError('Relevance score length < k') 8 | return np.mean(r) 9 | 10 | def 
dcg_at_k(r, k): 11 | r = np.asfarray(r)[:k] 12 | return np.sum(r / np.log2(np.arange(2, r.size + 2))) 13 | 14 | 15 | def ndcg_at_k(r, k): 16 | dcg_max = dcg_at_k(sorted(r, reverse=True), k) 17 | if not dcg_max: 18 | return 0. 19 | return dcg_at_k(r, k) / dcg_max 20 | 21 | 22 | def get_result(args): 23 | 24 | (y_pred, y_true)=args 25 | 26 | top_k = 50 27 | pred_topk_index = sorted(range(len(y_pred)), key=lambda i: y_pred[i], reverse=True)[:top_k] 28 | pos_index = set([k for k, v in enumerate(y_true) if v == 1]) 29 | 30 | r = [1 if k in pos_index else 0 for k in pred_topk_index[:top_k]] 31 | 32 | p_1 = precision_at_k(r, 1) 33 | p_3 = precision_at_k(r, 3) 34 | p_5 = precision_at_k(r, 5) 35 | 36 | ndcg_1 = ndcg_at_k(r, 1) 37 | ndcg_3 = ndcg_at_k(r, 3) 38 | ndcg_5 = ndcg_at_k(r, 5) 39 | 40 | return np.array([p_1, p_3, p_5, ndcg_1, ndcg_3, ndcg_5]) 41 | 42 | def evaluate(Y_tst_pred, Y_tst): 43 | pool = Pool(12) 44 | results = pool.map(get_result,zip(list(Y_tst_pred), list(Y_tst))) 45 | pool.terminate() 46 | tst_result = list(np.mean(np.array(results),0)) 47 | print ('\rTst Prec@1,3,5: ', tst_result[:3], ' Tst NDCG@1,3,5: ', tst_result[3:]) 48 | -------------------------------------------------------------------------------- /w2v.py: -------------------------------------------------------------------------------- 1 | from os.path import join, exists, split 2 | import os 3 | import numpy as np 4 | 5 | def load_word2vec(model_type, vocabulary_inv, num_features=300): 6 | """ 7 | loads Word2Vec model 8 | Returns initial weights for embedding layer. 9 | 10 | inputs: 11 | model_type # GoogleNews / glove 12 | vocabulary_inv # dict {str:int} 13 | num_features # Word vector dimensionality 14 | """ 15 | 16 | model_dir = 'word2vec_models' 17 | 18 | if model_type == 'glove': 19 | model_name = join(model_dir, 'glove.6B.%dd.txt' % (num_features)) 20 | assert(exists(model_name)) 21 | print('Loading existing Word2Vec model (Glove.6B.%dd)' % (num_features)) 22 | 23 | # dictionary, where key is word, value is word vectors 24 | embedding_model = {} 25 | for line in open(model_name, 'r', encoding="utf-8"): 26 | tmp = line.strip().split() 27 | word, vec = tmp[0], list(map(float, tmp[1:])) 28 | assert(len(vec) == num_features) 29 | if word not in embedding_model: 30 | embedding_model[word] = vec 31 | assert(len(embedding_model) == 400000) 32 | 33 | else: 34 | raise ValueError('Unknown pretrain model type: %s!' % (model_type)) 35 | 36 | embedding_weights = [embedding_model[w] if w in embedding_model 37 | else np.random.uniform(-0.25, 0.25, num_features) 38 | for w in vocabulary_inv] 39 | embedding_weights = np.array(embedding_weights).astype('float32') 40 | 41 | return embedding_weights 42 | --------------------------------------------------------------------------------
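For reference, a small usage sketch of load_word2vec as defined above. It assumes the GloVe file has been downloaded to word2vec_models/glove.6B.300d.txt, since that path (and the 400000-entry assertion) is hard-coded in w2v.py; the toy vocabulary below is only for illustration:

```python
# Usage sketch for w2v.load_word2vec: build an embedding matrix for a small vocabulary.
# Assumes word2vec_models/glove.6B.300d.txt exists, as required by w2v.py.
from w2v import load_word2vec

vocabulary_inv = ['the', 'court', 'regulation', 'directive']  # index -> word, as produced by data_helpers.build_vocab
weights = load_word2vec('glove', vocabulary_inv, num_features=300)
print(weights.shape)   # (len(vocabulary_inv), 300), dtype float32
# Words missing from GloVe are initialized uniformly in [-0.25, 0.25] (see w2v.py).
```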