├── EUR_Cap.py
├── EUR_Cap_grad.py
├── EUR_eval.py
├── README.md
├── data_helpers.py
├── layer.py
├── network.py
├── utils.py
└── w2v.py
/EUR_Cap.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, unicode_literals
2 | import argparse
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import os
7 | import json
8 | import random
9 | import time
10 | from torch.autograd import Variable
11 | from torch.optim import Adam
12 | from network import CapsNet_Text,BCE_loss
13 | from w2v import load_word2vec
14 | import data_helpers
15 |
16 |
17 | torch.manual_seed(0)
18 | torch.cuda.manual_seed(0)
19 | np.random.seed(0)
20 | random.seed(0)
21 |
22 | parser = argparse.ArgumentParser()
23 |
24 | parser.add_argument('--dataset', type=str, default='eurlex_raw_text.p',
25 | help='Options: eurlex_raw_text.p, rcv1_raw_text.p, wiki30k_raw_text.p')
26 | parser.add_argument('--vocab_size', type=int, default=30001, help='vocabulary size')
27 | parser.add_argument('--vec_size', type=int, default=300, help='embedding size')
28 | parser.add_argument('--sequence_length', type=int, default=500, help='the length of documents')
29 | parser.add_argument('--is_AKDE', type=bool, default=True, help='enable Adaptive KDE routing')  # note: argparse parses any non-empty string value as True here
30 | parser.add_argument('--num_epochs', type=int, default=30, help='Number of training epochs')
31 | parser.add_argument('--tr_batch_size', type=int, default=256, help='Batch size for training')
32 | parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate for training')
33 | parser.add_argument('--start_from', type=str, default='', help='')
34 |
35 | parser.add_argument('--num_compressed_capsule', type=int, default=128, help='The number of compact capsules')
36 | parser.add_argument('--dim_capsule', type=int, default=16, help='The number of dimensions for capsules')
37 |
38 | parser.add_argument('--learning_rate_decay_start', type=int, default=0,
39 | help='at which epoch to start decaying the learning rate (-1 = never)')
40 | parser.add_argument('--learning_rate_decay_every', type=int, default=20,
41 | help='decay the learning rate every this many epochs thereafter')
42 | parser.add_argument('--learning_rate_decay_rate', type=float, default=0.95,
43 | help='multiplicative factor applied to the learning rate at each decay step')
44 |
45 |
46 |
47 | args = parser.parse_args()
48 | params = vars(args)
49 | print(json.dumps(params, indent = 2))
50 |
51 | X_trn, Y_trn, Y_trn_o, X_tst, Y_tst, Y_tst_o, vocabulary, vocabulary_inv = data_helpers.load_data(args.dataset,
52 | max_length=args.sequence_length,
53 | vocab_size=args.vocab_size)
54 | Y_trn = Y_trn.toarray()
55 | Y_tst = Y_tst.toarray()
56 |
57 | X_trn = X_trn.astype(np.int32)
58 | X_tst = X_tst.astype(np.int32)
59 | Y_trn = Y_trn.astype(np.int32)
60 | Y_tst = Y_tst.astype(np.int32)
61 |
62 | embedding_weights = load_word2vec('glove', vocabulary_inv, args.vec_size)
63 |
64 | args.num_classes = Y_trn.shape[1]
65 |
66 | capsule_net = CapsNet_Text(args, embedding_weights)
67 | capsule_net = nn.DataParallel(capsule_net).cuda()
68 |
69 |
70 | def transformLabels(labels):
71 |     label_index = list(set([l for doc_labels in labels for l in doc_labels]))
72 |     label_index.sort()
73 |     # target is a dense 0/1 matrix over label_index, the labels present in this batch
74 |     variable_num_classes = len(label_index)
75 |     target = []
76 |     for doc_labels in labels:
77 |         tmp = np.zeros([variable_num_classes], dtype=np.float32)
78 |         tmp[[label_index.index(l) for l in doc_labels]] = 1
79 |         target.append(tmp)
80 |     target = np.array(target)
81 |     return label_index, target
82 |
83 | current_lr = args.learning_rate
84 |
85 | optimizer = Adam(capsule_net.parameters(), lr=current_lr)
86 |
87 | def set_lr(optimizer, lr):
88 | for group in optimizer.param_groups:
89 | group['lr'] = lr
90 |
91 | for epoch in range(args.num_epochs):
92 | torch.cuda.empty_cache()
93 |
94 | nr_trn_num = X_trn.shape[0]
95 | nr_batches = int(np.ceil(nr_trn_num / float(args.tr_batch_size)))
96 |
97 | if epoch > args.learning_rate_decay_start and args.learning_rate_decay_start >= 0:
98 | frac = (epoch - args.learning_rate_decay_start) // args.learning_rate_decay_every
99 | decay_factor = args.learning_rate_decay_rate ** frac
100 | current_lr = current_lr * decay_factor
101 | print(current_lr)
102 | set_lr(optimizer, current_lr)
103 |
104 | capsule_net.train()
105 | for iteration, batch_idx in enumerate(np.random.permutation(range(nr_batches))):
106 | start = time.time()
107 | start_idx = batch_idx * args.tr_batch_size
108 | end_idx = min((batch_idx + 1) * args.tr_batch_size, nr_trn_num)
109 |
110 | X = X_trn[start_idx:end_idx]
111 | Y = Y_trn_o[start_idx:end_idx]
112 | data = Variable(torch.from_numpy(X).long()).cuda()
113 |
114 | batch_labels, batch_target = transformLabels(Y)
115 | batch_target = Variable(torch.from_numpy(batch_target).float()).cuda()
116 | optimizer.zero_grad()
117 | poses, activations = capsule_net(data, batch_labels)
118 | loss = BCE_loss(activations, batch_target)
119 | loss.backward()
120 | optimizer.step()
121 | torch.cuda.empty_cache()
122 | done = time.time()
123 | elapsed = done - start
124 |
125 | print("\rIteration: {}/{} ({:.1f}%) Loss: {:.5f} {:.5f}".format(
126 | iteration, nr_batches,
127 | iteration * 100 / nr_batches,
128 | loss.item(), elapsed),
129 | end="")
130 |
131 | torch.cuda.empty_cache()
132 |
133 | if (epoch + 1) > 20:
134 | checkpoint_path = os.path.join('save', 'model-eur-akde-' + str(epoch + 1) + '.pth')
135 | torch.save(capsule_net.state_dict(), checkpoint_path)
136 | print("model saved to {}".format(checkpoint_path))
137 |
138 |
--------------------------------------------------------------------------------
/EUR_Cap_grad.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, unicode_literals
2 | import argparse
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import os
7 | import json
8 | import random
9 | import time
10 | from torch.autograd import Variable
11 | from torch.optim import Adam
12 | from network import CapsNet_Text,BCE_loss, CNN_KIM
13 | from w2v import load_word2vec
14 | import data_helpers
15 |
16 |
17 | torch.manual_seed(0)
18 | torch.cuda.manual_seed(0)
19 | np.random.seed(0)
20 | random.seed(0)
21 |
22 | parser = argparse.ArgumentParser()
23 |
24 | parser.add_argument('--dataset', type=str, default='eurlex_raw_text.p',
25 | help='Options: eurlex_raw_text.p, rcv1_raw_text.p, wiki30k_raw_text.p')
26 | parser.add_argument('--vocab_size', type=int, default=30001, help='vocabulary size')
27 | parser.add_argument('--vec_size', type=int, default=300, help='embedding size')
28 | parser.add_argument('--sequence_length', type=int, default=500, help='the length of documents')
29 | parser.add_argument('--is_AKDE', type=bool, default=True, help='if Adaptive KDE routing is enabled')
30 | parser.add_argument('--num_epochs', type=int, default=30, help='Number of training epochs')
31 | parser.add_argument('--tr_batch_size', type=int, default=256, help='Batch size for training')
32 | parser.add_argument('--ts_batch_size', type=int, default=16, help='Batch size for evaluation')
33 |
34 | parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate for training')
35 | parser.add_argument('--start_from', type=str, default='', help='')
36 |
37 | parser.add_argument('--num_compressed_capsule', type=int, default=128, help='The number of compact capsules')
38 | parser.add_argument('--dim_capsule', type=int, default=16, help='The number of dimensions for capsules')
39 |
40 | parser.add_argument('--learning_rate_decay_start', type=int, default=0,
41 | help='at which epoch to start decaying the learning rate (-1 = never)')
42 | parser.add_argument('--learning_rate_decay_every', type=int, default=20,
43 | help='decay the learning rate every this many epochs thereafter')
44 | parser.add_argument('--learning_rate_decay_rate', type=float, default=0.95,
45 | help='multiplicative factor applied to the learning rate at each decay step')
46 |
47 | parser.add_argument('--gradient_accumulation_steps', type=int, default=8)
48 |
49 | parser.add_argument('--re_ranking', type=int, default=200, help='The number of re-ranking size')
50 |
51 |
52 | args = parser.parse_args()
53 | params = vars(args)
54 | print(json.dumps(params, indent = 2))
55 |
56 | X_trn, Y_trn, Y_trn_o, X_tst, Y_tst, Y_tst_o, vocabulary, vocabulary_inv = data_helpers.load_data(args.dataset,
57 | max_length=args.sequence_length,
58 | vocab_size=args.vocab_size)
59 | Y_trn = Y_trn.toarray()
60 | Y_tst = Y_tst.toarray()
61 |
62 | X_trn = X_trn.astype(np.int32)
63 | X_tst = X_tst.astype(np.int32)
64 | Y_trn = Y_trn.astype(np.int32)
65 | Y_tst = Y_tst.astype(np.int32)
66 |
67 | embedding_weights = load_word2vec('glove', vocabulary_inv, args.vec_size)
68 |
69 | args.num_classes = Y_trn.shape[1]
70 |
71 | capsule_net = CapsNet_Text(args, embedding_weights)
72 | capsule_net = nn.DataParallel(capsule_net).cuda()
73 |
74 | model_name = 'model-EUR-CNN-40.pth'
75 | baseline = CNN_KIM(args, embedding_weights)
76 | baseline.load_state_dict(torch.load(os.path.join('save_new', model_name)))
77 | baseline = nn.DataParallel(baseline).cuda()
78 | print(model_name + ' loaded')
79 |
80 | def transformLabels(labels, total_labels):
81 |     label_index = list(set([l for doc_labels in total_labels for l in doc_labels]))
82 |     label_index.sort()
83 |     # the label space comes from the whole batch (total_labels); target only covers the micro-batch (labels)
84 |     variable_num_classes = len(label_index)
85 |     target = []
86 |     for doc_labels in labels:
87 |         tmp = np.zeros([variable_num_classes], dtype=np.float32)
88 |         tmp[[label_index.index(l) for l in doc_labels]] = 1
89 |         target.append(tmp)
90 |     target = np.array(target)
91 |     return label_index, target
92 |
93 | current_lr = args.learning_rate
94 |
95 | optimizer = Adam(capsule_net.parameters(), lr=current_lr)
96 |
97 | def set_lr(optimizer, lr):
98 | for group in optimizer.param_groups:
99 | group['lr'] = lr
100 |
101 | # imports used only by the in-training evaluation block below (the other
102 | # modules referenced there are already imported at the top of this file)
103 | from utils import evaluate
104 | import scipy.sparse as sp
105 |
106 | # each epoch trains with gradient accumulation: every batch is split into
107 | # micro-batches whose scaled losses are back-propagated before one optimizer
108 | # step; for epochs 21-29 the test set is then re-ranked and P@k / NDCG@k printed
109 | for epoch in range(args.num_epochs):
110 |
111 | nr_trn_num = X_trn.shape[0]
112 | nr_batches = int(np.ceil(nr_trn_num / float(args.tr_batch_size)))
113 |
114 | if epoch > args.learning_rate_decay_start and args.learning_rate_decay_start >= 0:
115 | frac = (epoch - args.learning_rate_decay_start) // args.learning_rate_decay_every
116 | decay_factor = args.learning_rate_decay_rate ** frac
117 | current_lr = current_lr * decay_factor
118 | print(current_lr)
119 | set_lr(optimizer, current_lr)
120 |
121 | capsule_net.train()
122 | for iteration, batch_idx in enumerate(np.random.permutation(range(nr_batches))):
123 | start = time.time()
124 | start_idx = batch_idx * args.tr_batch_size
125 | end_idx = min((batch_idx + 1) * args.tr_batch_size, nr_trn_num)
126 |
127 | X = X_trn[start_idx:end_idx]
128 | Y = Y_trn_o[start_idx:end_idx]
129 |
130 |         batch_steps = int(np.ceil(len(X) / (float(args.tr_batch_size) / float(args.gradient_accumulation_steps))))  # number of micro-batches in this batch
131 | batch_loss = 0
132 | for i in range(batch_steps):
133 | step_size = int(float(args.tr_batch_size) // float(args.gradient_accumulation_steps))
134 | step_X = X[i * step_size: (i+1) * step_size]
135 | step_Y = Y[i * step_size: (i+1) * step_size]
136 |
137 | step_X = Variable(torch.from_numpy(step_X).long()).cuda()
138 | step_labels, step_target = transformLabels(step_Y, Y)
139 | step_target = Variable(torch.from_numpy(step_target).float()).cuda()
140 |
141 | poses, activations = capsule_net(step_X, step_labels)
142 | step_loss = BCE_loss(activations, step_target)
143 | step_loss = step_loss / args.gradient_accumulation_steps
144 | step_loss.backward()
145 | batch_loss += step_loss.item()
146 |
147 | optimizer.step()
148 | optimizer.zero_grad()
149 | done = time.time()
150 | elapsed = done - start
151 |
152 | print("\rIteration: {}/{} ({:.1f}%) Loss: {:.5f} {:.5f}".format(
153 | iteration, nr_batches,
154 | iteration * 100 / nr_batches,
155 | batch_loss, elapsed),
156 | end="")
157 |
158 | if (epoch + 1) > 20 and (epoch + 1)<30:
159 |
160 | nr_tst_num = X_tst.shape[0]
161 | nr_batches = int(np.ceil(nr_tst_num / float(args.ts_batch_size)))
162 |
163 | n, k_trn = Y_trn.shape
164 | m, k_tst = Y_tst.shape
165 | print ('k_trn:', k_trn)
166 | print ('k_tst:', k_tst)
167 |
168 | capsule_net.eval()
169 | top_k = 50
170 | row_idx_list, col_idx_list, val_idx_list = [], [], []
171 | for batch_idx in range(nr_batches):
172 | start = time.time()
173 | start_idx = batch_idx * args.ts_batch_size
174 | end_idx = min((batch_idx + 1) * args.ts_batch_size, nr_tst_num)
175 | X = X_tst[start_idx:end_idx]
176 | Y = Y_tst_o[start_idx:end_idx]
177 | data = Variable(torch.from_numpy(X).long()).cuda()
178 |
179 | candidates = baseline(data)
180 | candidates = candidates.data.cpu().numpy()
181 |
182 | Y_pred = np.zeros([candidates.shape[0], args.num_classes])
183 | for i in range(candidates.shape[0]):
184 | candidate_labels = candidates[i, :].argsort()[-args.re_ranking:][::-1].tolist()
185 | _, activations_2nd = capsule_net(data[i, :].unsqueeze(0), candidate_labels)
186 | Y_pred[i, candidate_labels] = activations_2nd.squeeze(2).data.cpu().numpy()
187 |
188 | for i in range(Y_pred.shape[0]):
189 | sorted_idx = np.argpartition(-Y_pred[i, :], top_k)[:top_k]
190 | row_idx_list += [i + start_idx] * top_k
191 | col_idx_list += (sorted_idx).tolist()
192 | val_idx_list += Y_pred[i, sorted_idx].tolist()
193 |
194 | done = time.time()
195 | elapsed = done - start
196 |
197 | print("\r Epoch: {} Reranking: {} Iteration: {}/{} ({:.1f}%) Loss: {:.5f} {:.5f}".format(
198 | (epoch + 1), args.re_ranking, batch_idx, nr_batches,
199 | batch_idx * 100 / nr_batches,
200 | 0, elapsed),
201 | end="")
202 |
203 | m = max(row_idx_list) + 1
204 | n = max(k_trn, k_tst)
205 | print(elapsed)
206 | Y_tst_pred = sp.csr_matrix((val_idx_list, (row_idx_list, col_idx_list)), shape=(m, n))
207 |
208 | if k_trn >= k_tst:
209 | Y_tst_pred = Y_tst_pred[:, :k_tst]
210 |
211 | evaluate(Y_tst_pred.toarray(), Y_tst)
212 |
213 | # checkpoint_path = os.path.join('save_new', 'model-eur-akde-' + str(epoch + 1) + '.pth')
214 | # torch.save(capsule_net.state_dict(), checkpoint_path)
215 | # print("model saved to {}".format(checkpoint_path))
216 |
217 |
--------------------------------------------------------------------------------
/EUR_eval.py:
--------------------------------------------------------------------------------
1 | from __future__ import division, print_function, unicode_literals
2 | import argparse
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | from network import CNN_KIM,CapsNet_Text
8 | import random
9 | import time
10 | from utils import evaluate
11 | import data_helpers
12 | import scipy.sparse as sp
13 | from w2v import load_word2vec
14 | import os
15 |
16 | torch.manual_seed(0)
17 | torch.cuda.manual_seed(0)
18 | np.random.seed(0)
19 | random.seed(0)
20 |
21 | parser = argparse.ArgumentParser()
22 |
23 | parser.add_argument('--dataset', type=str, default='eurlex_raw_text.p',
24 | help='Options: eurlex_raw_text.p, rcv1_raw_text.p, wiki30k_raw_text.p')
25 | parser.add_argument('--vocab_size', type=int, default=30001, help='vocabulary size')
26 | parser.add_argument('--vec_size', type=int, default=300, help='embedding size')
27 | parser.add_argument('--sequence_length', type=int, default=500, help='the length of documents')
28 | parser.add_argument('--is_AKDE', type=bool, default=True, help='if Adaptive KDE routing is enabled')
29 | parser.add_argument('--num_epochs', type=int, default=30, help='Number of training epochs')
30 | parser.add_argument('--ts_batch_size', type=int, default=32, help='Batch size for evaluation')
31 | parser.add_argument('--learning_rate', type=float, default=1e-3, help='Learning rate for training')
32 | parser.add_argument('--start_from', type=str, default='save', help='')
33 |
34 | parser.add_argument('--num_compressed_capsule', type=int, default=128, help='The number of compact capsules')
35 | parser.add_argument('--dim_capsule', type=int, default=16, help='The number of dimensions for capsules')
36 |
37 | parser.add_argument('--re_ranking', type=int, default=200, help='The number of re-ranking size')
38 |
39 | import json
40 | args = parser.parse_args()
41 | params = vars(args)
42 | print(json.dumps(params, indent = 2))
43 |
44 | X_trn, Y_trn, Y_trn_o, X_tst, Y_tst, Y_tst_o, vocabulary, vocabulary_inv = data_helpers.load_data(args.dataset,
45 | max_length=args.sequence_length,
46 | vocab_size=args.vocab_size)
47 | Y_trn = Y_trn.toarray()
48 | Y_tst = Y_tst.toarray()
49 |
50 | X_trn = X_trn.astype(np.int32)
51 | X_tst = X_tst.astype(np.int32)
52 | Y_trn = Y_trn.astype(np.int32)
53 | Y_tst = Y_tst.astype(np.int32)
54 |
55 |
56 | embedding_weights = load_word2vec('glove', vocabulary_inv, args.vec_size)
57 | args.num_classes = Y_trn.shape[1]
58 |
59 | capsule_net = CapsNet_Text(args, embedding_weights)
60 | capsule_net = nn.DataParallel(capsule_net).cuda()
61 | model_name = 'model-eur-akde-29.pth'
62 | capsule_net.load_state_dict(torch.load(os.path.join(args.start_from, model_name)))
63 | print(model_name + ' loaded')
64 |
65 |
66 | model_name = 'model-EUR-CNN-40.pth'
67 | baseline = CNN_KIM(args, embedding_weights)
68 | baseline.load_state_dict(torch.load(os.path.join(args.start_from, model_name)))
69 | baseline = nn.DataParallel(baseline).cuda()
70 | print(model_name + ' loaded')
71 |
72 | # two-stage inference: the baseline CNN proposes the top `re_ranking` candidates per document; the capsule network re-scores them
73 | nr_tst_num = X_tst.shape[0]
74 | nr_batches = int(np.ceil(nr_tst_num / float(args.ts_batch_size)))
75 |
76 | n, k_trn = Y_trn.shape
77 | m, k_tst = Y_tst.shape
78 | print ('k_trn:', k_trn)
79 | print ('k_tst:', k_tst)
80 |
81 | capsule_net.eval()
82 | top_k = 50
83 | row_idx_list, col_idx_list, val_idx_list = [], [], []
84 | for batch_idx in range(nr_batches):
85 | start = time.time()
86 | start_idx = batch_idx * args.ts_batch_size
87 | end_idx = min((batch_idx + 1) * args.ts_batch_size, nr_tst_num)
88 | X = X_tst[start_idx:end_idx]
89 | Y = Y_tst_o[start_idx:end_idx]
90 | data = Variable(torch.from_numpy(X).long()).cuda()
91 |
92 | candidates = baseline(data)
93 | candidates = candidates.data.cpu().numpy()
94 |
95 | Y_pred = np.zeros([candidates.shape[0], args.num_classes])
96 | for i in range(candidates.shape[0]):
97 | candidate_labels = candidates[i, :].argsort()[-args.re_ranking:][::-1].tolist()
98 | _, activations_2nd = capsule_net(data[i, :].unsqueeze(0), candidate_labels)
99 | Y_pred[i, candidate_labels] = activations_2nd.squeeze(2).data.cpu().numpy()
100 |
101 | for i in range(Y_pred.shape[0]):
102 | sorted_idx = np.argpartition(-Y_pred[i, :], top_k)[:top_k]
103 | row_idx_list += [i + start_idx] * top_k
104 | col_idx_list += (sorted_idx).tolist()
105 | val_idx_list += Y_pred[i, sorted_idx].tolist()
106 |
107 | done = time.time()
108 | elapsed = done - start
109 |
110 | print("\r Reranking: {} Iteration: {}/{} ({:.1f}%) Loss: {:.5f} {:.5f}".format(
111 | args.re_ranking, batch_idx, nr_batches,
112 | batch_idx * 100 / nr_batches,
113 | 0, elapsed),
114 | end="")
115 |
116 | m = max(row_idx_list) + 1
117 | n = max(k_trn, k_tst)
118 | print(elapsed)
119 | Y_tst_pred = sp.csr_matrix((val_idx_list, (row_idx_list, col_idx_list)), shape=(m, n))
120 |
121 | if k_trn >= k_tst:
122 | Y_tst_pred = Y_tst_pred[:, :k_tst]
123 |
124 | evaluate(Y_tst_pred.toarray(), Y_tst)
125 |
126 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Towards Scalable and Reliable Capsule Networks for Challenging NLP Applications
2 |
3 | ACL-19: https://www.aclweb.org/anthology/P19-1150/
4 |
5 | Requirements: The code is written in Python 3 and requires PyTorch, NumPy, SciPy and NLTK (with the English stopwords corpus downloaded), plus the pre-trained GloVe vectors expected at word2vec_models/glove.6B.300d.txt.
6 |
7 | # Preparation
8 | For quick start, please download [the dataset and trained model](https://drive.google.com/open?id=1gPYAMyYo4YLrmx_Egc9wjCqzWx15D7U8).
9 |
10 | # Code Explanation
11 | data_helpers.py implements the data preprocessing functions (text cleaning, padding, vocabulary building and batching).
12 |
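A rough usage sketch (it assumes `eurlex_raw_text.p` from the Preparation step sits in the working directory; the variable names mirror the training scripts):

```python
import data_helpers

# returns padded token-id matrices, sparse label matrices, per-document label lists,
# and the vocabulary (word -> id) plus its inverse list (id -> word)
X_trn, Y_trn, Y_trn_o, X_tst, Y_tst, Y_tst_o, vocab, vocab_inv = data_helpers.load_data(
    'eurlex_raw_text.p', max_length=500, vocab_size=30001)
print(X_trn.shape, Y_trn.shape)
```
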
13 | layer.py implements the main building blocks of the capsule network, including KDE routing, Adaptive KDE routing and the Primary Capsule layer.
14 |
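For example, the squashing non-linearity used throughout the capsule layers rescales vectors so that their norm stays below 1 while keeping their direction; a quick sanity check (assuming the repository root is on the Python path):

```python
import torch
from layer import squash_v1

u = torch.randn(2, 16, 32, 446, 1)   # toy primary-capsule pre-activations
v = squash_v1(u, axis=1)             # squash along the 16-dimensional capsule axis
print(float(v.norm(dim=1).max()))    # always below 1
```
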
15 | network.py provides the wrapper of our model (CapsNet_Text) as well as the baseline models (CNN_KIM and XML_CNN) used for comparison; a usage sketch follows below.
16 |
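A minimal, hypothetical sketch of constructing and calling the wrapper (not taken from the repository): it assumes a CUDA GPU, since the routing code places its logits on the GPU, and it uses random embeddings and token ids purely to show the expected input/output shapes. Real runs build `args` with argparse and load GloVe vectors and EUR-Lex data exactly as in EUR_Cap.py.

```python
import argparse
import numpy as np
import torch
from network import CapsNet_Text

# default hyper-parameters from EUR_Cap.py; num_classes=100 is a toy value here
# (the scripts set it to Y_trn.shape[1], the size of the real label set)
args = argparse.Namespace(
    vocab_size=30001, vec_size=300, sequence_length=500, num_classes=100,
    num_compressed_capsule=128, dim_capsule=16, is_AKDE=True)

w2v = np.random.uniform(-0.25, 0.25, (args.vocab_size, args.vec_size)).astype('float32')
model = CapsNet_Text(args, w2v).cuda()

data = torch.randint(0, args.vocab_size, (2, args.sequence_length)).cuda()
labels = list(range(10))    # label ids to score for this batch (cf. transformLabels)
poses, activations = model(data, labels)
print(poses.shape, activations.shape)   # one pose vector and one activation per label
```
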
17 | utils.py provides the evaluation functions, i.e. Precision@1,3,5 and NDCG@1,3,5 (illustrated below).
18 |
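A small sanity-check sketch of the per-document metrics (again assuming the repository root is on the Python path); `r` is the relevance vector of one document's ranked predictions, and `evaluate` simply averages these scores over all test documents:

```python
from utils import precision_at_k, ndcg_at_k

# relevance of the top-5 ranked labels for one document: the 1st and 3rd are correct
r = [1, 0, 1, 0, 0]
print([precision_at_k(r, k) for k in (1, 3, 5)])   # [1.0, 0.666..., 0.4]
print([ndcg_at_k(r, k) for k in (1, 3, 5)])
```
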
19 | EUR_Cap.py and EUR_eval.py are for training and inference, respectively; EUR_Cap_grad.py trains the same model on a single GPU using gradient accumulation.
20 | # Quick start
21 |
22 | ```bash
23 | CUDA_VISIBLE_DEVICES=0 python EUR_eval.py
24 |
25 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python EUR_Cap.py
26 |
27 | CUDA_VISIBLE_DEVICES=0 python EUR_Cap_grad.py # train CapsNet on a single GPU with gradient accumulation
28 | ```
29 |
30 | # Performance on EUR-Lex dataset
31 |
32 | ```bash
33 | NLP-Capsule with Adaptive KDE routing:
34 |
35 | Epoch: 20 Iteration: 120/121 (99.2%) Loss: 0.00000 0.33459
36 | Tst Prec@1,3,5: [0.7948253557567917, 0.65605864596808838, 0.53666235446312649]
37 | Tst NDCG@1,3,5: [0.7948253557567917, 0.70826730037244034, 0.6843311797551882]
38 |
39 | Epoch: 21 Iteration: 120/121 (99.2%) Loss: 0.00000 0.24704
40 | Tst Prec@1,3,5: [0.79301423027166884, 0.6552824493316064, 0.53666235446312793]
41 | Tst NDCG@1,3,5: [0.79301423027166884, 0.70672871614554134, 0.68443643153244704]
42 |
43 | Epoch: 22 Iteration: 120/121 (99.2%) Loss: 0.00000 0.24949
44 | Tst Prec@1,3,5: [0.79404915912031049, 0.65554118154376773, 0.53800776196636135]
45 | Tst NDCG@1,3,5: [0.79404915912031049, 0.70816714976829975, 0.68780244631961929]
46 |
47 | Epoch: 23 Iteration: 120/121 (99.2%) Loss: 0.00000 0.25533
48 | Tst Prec@1,3,5: [0.8046571798188874, 0.65890470030185422, 0.53604139715394228]
49 | Tst NDCG@1,3,5: [0.8046571798188874, 0.71380071010660562, 0.69040247647419262]
50 |
51 | Epoch: 24 Iteration: 120/121 (99.2%) Loss: 0.00000 0.26880
52 | Tst Prec@1,3,5: [0.80620957309184993, 0.65614489003880982, 0.53661060802069527]
53 | Tst NDCG@1,3,5: [0.80620957309184993, 0.7133596479633022, 0.69571103238443532]
54 |
55 | Epoch: 25 Iteration: 120/121 (99.2%) Loss: 0.00000 0.25847
56 | Tst Prec@1,3,5: [0.80155239327296246, 0.65329883570504454, 0.53448900388098108]
57 | Tst NDCG@1,3,5: [0.80155239327296246, 0.7096033706441367, 0.69201706652281636]
58 |
59 | Epoch: 26 Iteration: 120/121 (99.2%) Loss: 0.00000 0.26063
60 | Tst Prec@1,3,5: [0.80000000000000004, 0.65381630012936431, 0.53350582147477121]
61 | Tst NDCG@1,3,5: [0.80000000000000004, 0.71043623399753963, 0.69499344732549306]
62 |
63 | Epoch: 27 Iteration: 120/121 (99.2%) Loss: 0.00000 0.26004
64 | Tst Prec@1,3,5: [0.79689521345407499, 0.65398878827080587, 0.53376455368693132]
65 | Tst NDCG@1,3,5: [0.79689521345407499, 0.71269493382033577, 0.69812854866301688]
66 |
67 | Epoch: 28 Iteration: 120/121 (99.2%) Loss: 0.00000 0.27287
68 | Tst Prec@1,3,5: [0.79818887451487708, 0.65588615782664883, 0.53500646830530163]
69 | Tst NDCG@1,3,5: [0.79818887451487708, 0.71429911265714374, 0.70057615675866636]
70 |
71 |
72 | XML-CNN:
73 | Epoch: 31 Iteration: 45/46 (97.8%) Loss: 0.00006 0.15460
74 | Tst Prec@1,3,5: [0.7583441138421734, 0.6164726175075479, 0.5073738680465716]
75 | Tst NDCG@1,3,5: [0.7583441138421734, 0.6661232856458101, 0.644838787586548]
76 |
77 | Epoch: 32 Iteration: 45/46 (97.8%) Loss: 0.00005 0.15354
78 | Tst Prec@1,3,5: [0.759379042690815, 0.6143165157395448, 0.5062871927554978]
79 | Tst NDCG@1,3,5: [0.759379042690815, 0.6648180435110952, 0.6434396675410785]
80 |
81 | Epoch: 33 Iteration: 45/46 (97.8%) Loss: 0.00005 0.15399
82 | Tst Prec@1,3,5: [0.757567917205692, 0.6169038378611481, 0.507373868046571]
83 | Tst NDCG@1,3,5: [0.757567917205692, 0.666160785036582, 0.6440332351720106]
84 |
85 | Epoch: 34 Iteration: 45/46 (97.8%) Loss: 0.00004 0.15153
86 | Tst Prec@1,3,5: [0.7573091849935317, 0.616645105648988, 0.5099094437257432]
87 | Tst NDCG@1,3,5: [0.7573091849935317, 0.6659194956789641, 0.6458294426678642]
88 |
89 | Epoch: 35 Iteration: 45/46 (97.8%) Loss: 0.00005 0.15212
90 | Tst Prec@1,3,5: [0.7552393272962484, 0.6153514445881856, 0.5092367399741262]
91 | Tst NDCG@1,3,5: [0.7552393272962484, 0.6648419426927356, 0.6453632713906606]
92 |
93 | Epoch: 36 Iteration: 45/46 (97.8%) Loss: 0.00004 0.15231
94 | Tst Prec@1,3,5: [0.7596377749029755, 0.6157826649417857, 0.5093402328589907]
95 | Tst NDCG@1,3,5: [0.7596377749029755, 0.6661452963066051, 0.646133349811576]
96 |
97 | Epoch: 37 Iteration: 45/46 (97.8%) Loss: 0.00006 0.15357
98 | Tst Prec@1,3,5: [0.7570504527813713, 0.6175937904269097, 0.5088227684346699]
99 | Tst NDCG@1,3,5: [0.7570504527813713, 0.6670823259018512, 0.6455866525334287]
100 |
101 | Epoch: 38 Iteration: 45/46 (97.8%) Loss: 0.00006 0.16400
102 | Tst Prec@1,3,5: [0.7583441138421734, 0.6162138852953867, 0.5085122897800777]
103 | Tst NDCG@1,3,5: [0.7583441138421734, 0.6658377730303046, 0.6448260229129755]
104 |
105 | Epoch: 39 Iteration: 45/46 (97.8%) Loss: 0.00004 0.15555
106 | Tst Prec@1,3,5: [0.7578266494178525, 0.6173350582147488, 0.509029754204398]
107 | Tst NDCG@1,3,5: [0.7578266494178525, 0.6667396690496684, 0.645590263852396]
108 |
109 | Epoch: 40 Iteration: 45/46 (97.8%) Loss: 0.00004 0.15414
110 | Tst Prec@1,3,5: [0.7565329883570504, 0.61811125485123, 0.5087192755498058]
111 | Tst NDCG@1,3,5: [0.7565329883570504, 0.6674559324640292, 0.6452839523583206]
112 |
113 | ```
114 |
115 | # Reference
116 | If you find our source code useful, please consider citing our work.
117 | ```
118 | @inproceedings{zhao2019capsule,
119 | title = "Towards Scalable and Reliable Capsule Networks for Challenging {NLP} Applications",
120 | author = "Zhao, Wei and Peng, Haiyun and Eger, Steffen and Cambria, Erik and Yang, Min",
121 | booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
122 | month = jul,
123 | year = "2019",
124 | address = "Florence, Italy",
125 | publisher = "Association for Computational Linguistics",
126 | url = "https://www.aclweb.org/anthology/P19-1150",
127 | doi = "10.18653/v1/P19-1150",
128 | pages = "1549--1559"
129 | }
130 | ```
131 |
--------------------------------------------------------------------------------
/data_helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import re
4 | import itertools
5 | import scipy.sparse as sp
6 | import pickle
7 | from collections import Counter
8 | from nltk.corpus import stopwords
9 |
10 | cachedStopWords = stopwords.words("english")
11 |
12 |
13 | def clean_str(string):
14 | # remove stopwords
15 | # string = ' '.join([word for word in string.split() if word not in cachedStopWords])
16 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
17 | string = re.sub(r"\'s", " \'s", string)
18 | string = re.sub(r"\'ve", " \'ve", string)
19 | string = re.sub(r"n\'t", " n\'t", string)
20 | string = re.sub(r"\'re", " \'re", string)
21 | string = re.sub(r"\'d", " \'d", string)
22 | string = re.sub(r"\'ll", " \'ll", string)
23 | string = re.sub(r",", " , ", string)
24 | string = re.sub(r"!", " ! ", string)
25 | string = re.sub(r"\(", " \( ", string)
26 | string = re.sub(r"\)", " \) ", string)
27 | string = re.sub(r"\?", " \? ", string)
28 | string = re.sub(r"\s{2,}", " ", string)
29 | return string.strip().lower()
30 |
31 |
32 | def pad_sentences(sentences, padding_word="<PAD/>", max_length=500):
33 | sequence_length = min(max(len(x) for x in sentences), max_length)
34 | padded_sentences = []
35 | for i in range(len(sentences)):
36 | sentence = sentences[i]
37 | if len(sentence) < max_length:
38 | num_padding = sequence_length - len(sentence)
39 | new_sentence = sentence + [padding_word] * num_padding
40 | else:
41 | new_sentence = sentence[:max_length]
42 | padded_sentences.append(new_sentence)
43 | return padded_sentences
44 |
45 |
46 | def load_data_and_labels(data):
47 | x_text = [clean_str(doc['text']) for doc in data]
48 | x_text = [s.split(" ") for s in x_text]
49 | labels = [doc['catgy'] for doc in data]
50 | row_idx, col_idx, val_idx = [], [], []
51 | for i in range(len(labels)):
52 |         l_list = list(set(labels[i]))  # remove duplicate categories to avoid double counting
53 | for y in l_list:
54 | row_idx.append(i)
55 | col_idx.append(y)
56 | val_idx.append(1)
57 | m = max(row_idx) + 1
58 | n = max(col_idx) + 1
59 | Y = sp.csr_matrix((val_idx, (row_idx, col_idx)), shape=(m, n))
60 | return [x_text, Y, labels]
61 |
62 |
63 | def build_vocab(sentences, vocab_size=50000):
64 | word_counts = Counter(itertools.chain(*sentences))
65 | vocabulary_inv = [x[0] for x in word_counts.most_common(vocab_size)]
66 | vocabulary = {x: i for i, x in enumerate(vocabulary_inv)}
67 |     # append the <PAD/> symbol to the vocabulary
68 |     vocabulary['<PAD/>'] = len(vocabulary)
69 |     vocabulary_inv.append('<PAD/>')
70 | return [vocabulary, vocabulary_inv]
71 |
72 |
73 | def build_input_data(sentences, vocabulary):
74 |     x = np.array([[vocabulary[word] if word in vocabulary else vocabulary['<PAD/>'] for word in sentence] for sentence in sentences])
75 | #x = np.array([[vocabulary[word] if word in vocabulary else len(vocabulary) for word in sentence] for sentence in sentences])
76 | return x
77 |
78 |
79 | def load_data(data_path, max_length=500, vocab_size=50000):
80 | # Load and preprocess data
81 | with open(os.path.join(data_path), 'rb') as fin:
82 | [train, test, vocab, catgy] = pickle.load(fin)
83 |
84 |     # dirty trick to avoid errors when the test split is empty
85 | if len(test) == 0:
86 | test[:5] = train[:5]
87 |
88 | trn_sents, Y_trn, Y_trn_o = load_data_and_labels(train)
89 | tst_sents, Y_tst, Y_tst_o = load_data_and_labels(test)
90 | trn_sents_padded = pad_sentences(trn_sents, max_length=max_length)
91 | tst_sents_padded = pad_sentences(tst_sents, max_length=max_length)
92 | vocabulary, vocabulary_inv = build_vocab(trn_sents_padded + tst_sents_padded, vocab_size=vocab_size)
93 | X_trn = build_input_data(trn_sents_padded, vocabulary)
94 | X_tst = build_input_data(tst_sents_padded, vocabulary)
95 | return X_trn, Y_trn, Y_trn_o, X_tst, Y_tst, Y_tst_o, vocabulary, vocabulary_inv
96 | # return X_trn, Y_trn, vocabulary, vocabulary_inv
97 |
98 |
99 | def batch_iter(data, batch_size, num_epochs):
100 | """
101 | Generates a batch iterator for a dataset.
102 | """
103 | data = np.array(data)
104 | data_size = len(data)
105 |     num_batches_per_epoch = int(np.ceil(data_size / float(batch_size)))
106 | for epoch in range(num_epochs):
107 | # Shuffle the data at each epoch
108 | shuffle_indices = np.random.permutation(np.arange(data_size))
109 | shuffled_data = data[shuffle_indices]
110 | for batch_num in range(num_batches_per_epoch):
111 | start_index = batch_num * batch_size
112 | end_index = min((batch_num + 1) * batch_size, data_size)
113 | yield shuffled_data[start_index:end_index]
114 |
--------------------------------------------------------------------------------
/layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import numpy as np
6 |
7 | def squash_v1(x, axis):  # squash non-linearity: rescales x so its norm along `axis` stays below 1, preserving direction
8 |     s_squared_norm = (x ** 2).sum(axis, keepdim=True)
9 |     scale = torch.sqrt(s_squared_norm) / (0.5 + s_squared_norm)
10 |     return scale * x
11 |
12 | def dynamic_routing(batch_size, b_ij, u_hat, input_capsule_num):
13 | num_iterations = 3
14 |
15 | for i in range(num_iterations):
16 |         # leaky softmax routing: an extra all-zero "leak" logit lets an input
17 |         # capsule route part of its output to none of the output capsules
18 |         leak = torch.zeros_like(b_ij).sum(dim=2, keepdim=True)
19 |         leaky_logits = torch.cat((leak, b_ij), 2)
20 |         leaky_routing = F.softmax(leaky_logits, dim=2)
21 |         c_ij = leaky_routing[:, :, 1:, :].unsqueeze(4)
22 |
23 | v_j = squash_v1((c_ij * u_hat).sum(dim=1, keepdim=True), axis=3)
24 | if i < num_iterations - 1:
25 | b_ij = b_ij + (torch.cat([v_j] * input_capsule_num, dim=1) * u_hat).sum(3)
26 |
27 | poses = v_j.squeeze(1)
28 | activations = torch.sqrt((poses ** 2).sum(2))
29 | return poses, activations
30 |
31 | # adaptive KDE routing: the agreement loop runs until the change in the log KDE objective drops below 0.05
32 | def Adaptive_KDE_routing(batch_size, b_ij, u_hat):
33 | last_loss = 0.0
34 | while True:
35 |         # routing weights: softmax of the logits over the output capsules (the
36 |         # leaky-routing variant used in dynamic_routing is not needed here),
37 |         # re-normalised over the input capsules so that each output capsule is
38 |         # a convex combination of the predictions made for it
39 |         c_ij = F.softmax(b_ij, dim=2).unsqueeze(4)
40 |         c_ij = c_ij/c_ij.sum(dim=1, keepdim=True)
41 |         v_j = squash_v1((c_ij * u_hat).sum(dim=1, keepdim=True), axis=3)
42 |         # agreement between each prediction and the fused output capsule; adding
43 |         # it to the logits sharpens the routing on the next pass of the loop
44 | dd = 1 - ((squash_v1(u_hat, axis=3)-v_j)** 2).sum(3)
45 | b_ij = b_ij + dd
46 |
47 | c_ij = c_ij.view(batch_size, c_ij.size(1), c_ij.size(2))
48 | dd = dd.view(batch_size, dd.size(1), dd.size(2))
49 |
50 | kde_loss = torch.mul(c_ij, dd).sum()/batch_size
51 | kde_loss = np.log(kde_loss.item())
52 |
53 | if abs(kde_loss - last_loss) < 0.05:
54 | break
55 | else:
56 | last_loss = kde_loss
57 | poses = v_j.squeeze(1)
58 | activations = torch.sqrt((poses ** 2).sum(2))
59 | return poses, activations
60 |
61 |
62 | def KDE_routing(batch_size, b_ij, u_hat):
63 | num_iterations = 3
64 | for i in range(num_iterations):
65 |         # routing weights: softmax of the logits over the output capsules; the
66 |         # leaky-routing variant used in dynamic_routing is not needed here.
67 |         # The re-normalisation over the input capsules below makes each output
68 |         # capsule a convex combination of the predictions made for it, and the
69 |         # agreement term dd updates the logits for the next iteration.
70 |         c_ij = F.softmax(b_ij, dim=2).unsqueeze(4)
71 |
72 |
73 | c_ij = c_ij/c_ij.sum(dim=1, keepdim=True)
74 | v_j = squash_v1((c_ij * u_hat).sum(dim=1, keepdim=True), axis=3)
75 |
76 | if i < num_iterations - 1:
77 | dd = 1 - ((squash_v1(u_hat, axis=3)-v_j)** 2).sum(3)
78 | b_ij = b_ij + dd
79 | poses = v_j.squeeze(1)
80 | activations = torch.sqrt((poses ** 2).sum(2))
81 | return poses, activations
82 |
83 | class FlattenCaps(nn.Module):
84 | def __init__(self):
85 | super(FlattenCaps, self).__init__()
86 | def forward(self, p, a):
87 | poses = p.view(p.size(0), p.size(2) * p.size(3) * p.size(4), -1)
88 | activations = a.view(a.size(0), a.size(1) * a.size(2) * a.size(3), -1)
89 | return poses, activations
90 |
91 | class PrimaryCaps(nn.Module):
92 | def __init__(self, num_capsules, in_channels, out_channels, kernel_size, stride):
93 | super(PrimaryCaps, self).__init__()
94 |
95 | self.capsules = nn.Conv1d(in_channels, out_channels * num_capsules, kernel_size, stride)
96 |
97 | torch.nn.init.xavier_uniform_(self.capsules.weight)
98 |
99 | self.out_channels = out_channels
100 | self.num_capsules = num_capsules
101 |
102 | def forward(self, x):
103 | batch_size = x.size(0)
104 | u = self.capsules(x).view(batch_size, self.num_capsules, self.out_channels, -1, 1)
105 | poses = squash_v1(u, axis=1)
106 | activations = torch.sqrt((poses ** 2).sum(1))
107 | return poses, activations
108 |
109 | class FCCaps(nn.Module):
110 | def __init__(self, args, output_capsule_num, input_capsule_num, in_channels, out_channels):
111 | super(FCCaps, self).__init__()
112 |
113 | self.in_channels = in_channels
114 | self.out_channels = out_channels
115 | self.input_capsule_num = input_capsule_num
116 | self.output_capsule_num = output_capsule_num
117 |
118 | self.W1 = nn.Parameter(torch.FloatTensor(1, input_capsule_num, output_capsule_num, out_channels, in_channels))
119 | torch.nn.init.xavier_uniform_(self.W1)
120 |
121 | self.is_AKDE = args.is_AKDE
122 | self.sigmoid = nn.Sigmoid()
123 |
124 |
125 | def forward(self, x, y, labels):
126 | batch_size = x.size(0)
127 | variable_output_capsule_num = len(labels)
128 | W1 = self.W1[:,:,labels,:,:]
129 |
130 | x = torch.stack([x] * variable_output_capsule_num, dim=2).unsqueeze(4)
131 |
132 | W1 = W1.repeat(batch_size, 1, 1, 1, 1)
133 | u_hat = torch.matmul(W1, x)
134 |
135 | b_ij = Variable(torch.zeros(batch_size, self.input_capsule_num, variable_output_capsule_num, 1)).cuda()
136 |
137 | if self.is_AKDE == True:
138 | poses, activations = Adaptive_KDE_routing(batch_size, b_ij, u_hat)
139 | else:
140 | #poses, activations = dynamic_routing(batch_size, b_ij, u_hat, self.input_capsule_num)
141 | poses, activations = KDE_routing(batch_size, b_ij, u_hat)
142 | return poses, activations
143 |
144 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from layer import PrimaryCaps, FCCaps, FlattenCaps
5 |
6 | def BCE_loss(x, target):
7 | return nn.BCELoss()(x.squeeze(2), target)
8 |
9 | class CapsNet_Text(nn.Module):
10 | def __init__(self, args, w2v):
11 | super(CapsNet_Text, self).__init__()
12 | self.num_classes = args.num_classes
13 | self.embed = nn.Embedding(args.vocab_size, args.vec_size)
14 | self.embed.weight = nn.Parameter(torch.from_numpy(w2v))
15 |
16 | self.ngram_size = [2,4,8]
17 |
18 | self.convs_doc = nn.ModuleList([nn.Conv1d(args.sequence_length, 32, K, stride=2) for K in self.ngram_size])
19 | torch.nn.init.xavier_uniform_(self.convs_doc[0].weight)
20 | torch.nn.init.xavier_uniform_(self.convs_doc[1].weight)
21 | torch.nn.init.xavier_uniform_(self.convs_doc[2].weight)
22 |
23 | self.primary_capsules_doc = PrimaryCaps(num_capsules=args.dim_capsule, in_channels=32, out_channels=32, kernel_size=1, stride=1)
24 |
25 | self.flatten_capsules = FlattenCaps()
26 |
27 |         self.W_doc = nn.Parameter(torch.FloatTensor(14272, args.num_compressed_capsule))  # 14272 = 32 * (150 + 149 + 147) flattened primary capsules (from the 2/4/8-gram convs over 300-d embeddings)
28 | torch.nn.init.xavier_uniform_(self.W_doc)
29 |
30 | self.fc_capsules_doc_child = FCCaps(args, output_capsule_num=args.num_classes, input_capsule_num=args.num_compressed_capsule,
31 | in_channels=args.dim_capsule, out_channels=args.dim_capsule)
32 |
33 | def compression(self, poses, W):
34 | poses = torch.matmul(poses.permute(0,2,1), W).permute(0,2,1)
35 | activations = torch.sqrt((poses ** 2).sum(2))
36 | return poses, activations
37 |
38 | def forward(self, data, labels):
39 | data = self.embed(data)
40 | nets_doc_l = []
41 | for i in range(len(self.ngram_size)):
42 | nets = self.convs_doc[i](data)
43 | nets_doc_l.append(nets)
44 | nets_doc = torch.cat((nets_doc_l[0], nets_doc_l[1], nets_doc_l[2]), 2)
45 | poses_doc, activations_doc = self.primary_capsules_doc(nets_doc)
46 | poses, activations = self.flatten_capsules(poses_doc, activations_doc)
47 | poses, activations = self.compression(poses, self.W_doc)
48 | poses, activations = self.fc_capsules_doc_child(poses, activations, labels)
49 | return poses, activations
50 |
51 |
52 | class CNN_KIM(nn.Module):
53 |
54 | def __init__(self, args, w2v):
55 | super(CNN_KIM, self).__init__()
56 | self.embed = nn.Embedding(args.vocab_size, args.vec_size)
57 | self.embed.weight = nn.Parameter(torch.from_numpy(w2v))
58 | self.conv13 = nn.Conv2d(1, 128, (3, args.vec_size))
59 | self.conv14 = nn.Conv2d(1, 128, (4, args.vec_size))
60 | self.conv15 = nn.Conv2d(1, 128, (5, args.vec_size))
61 |
62 | self.fc1 = nn.Linear(3 * 128, args.num_classes)
63 | self.m = nn.Sigmoid()
64 |
65 | def conv_and_pool(self, x, conv):
66 | x = F.relu(conv(x)).squeeze(3)
67 | x = F.max_pool1d(x, x.size(2)).squeeze(2)
68 | return x
69 |
70 | def loss(self, x, target):
71 | return nn.BCELoss()(x, target)
72 |
73 | def forward(self, x):
74 | x = self.embed(x).unsqueeze(1)
75 | x1 = self.conv_and_pool(x,self.conv13)
76 | x2 = self.conv_and_pool(x,self.conv14)
77 | x3 = self.conv_and_pool(x,self.conv15)
78 | x = torch.cat((x1, x2, x3), 1)
79 | activations = self.fc1(x)
80 | return self.m(activations)
81 |
82 | class XML_CNN(nn.Module):
83 |
84 | def __init__(self, args, w2v):
85 | super(XML_CNN, self).__init__()
86 | self.embed = nn.Embedding(args.vocab_size, args.vec_size)
87 | self.embed.weight = nn.Parameter(torch.from_numpy(w2v))
88 | self.conv13 = nn.Conv1d(500, 32, 2, stride=2)
89 | self.conv14 = nn.Conv1d(500, 32, 4, stride=2)
90 | self.conv15 = nn.Conv1d(500, 32, 8, stride=2)
91 |
92 | self.fc1 = nn.Linear(14272, 512)
93 | self.fc2 = nn.Linear(512, args.num_classes)
94 | self.m = nn.Sigmoid()
95 | def conv_and_pool(self, x, conv):
96 | x = F.relu(conv(x)).squeeze(3)
97 | return x
98 |
99 | def loss(self, x, target):
100 | return nn.BCELoss()(x, target)
101 |
102 | def forward(self, x):
103 | x = self.embed(x)
104 | batch_size = x.shape[0]
105 |
106 | x1 = self.conv13(x).reshape(batch_size, -1)
107 | x2 = self.conv14(x).reshape(batch_size, -1)
108 | x3 = self.conv15(x).reshape(batch_size, -1)
109 | x = torch.cat((x1, x2, x3), 1)
110 | hidden = self.fc1(x)
111 | activations = self.fc2(hidden)
112 | return self.m(activations)
113 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from multiprocessing import Pool
3 | def precision_at_k(r, k):
4 | assert k >= 1
5 | r = np.asarray(r)[:k] != 0
6 | if r.size != k:
7 | raise ValueError('Relevance score length < k')
8 | return np.mean(r)
9 |
10 | def dcg_at_k(r, k):
11 | r = np.asfarray(r)[:k]
12 | return np.sum(r / np.log2(np.arange(2, r.size + 2)))
13 |
14 |
15 | def ndcg_at_k(r, k):
16 | dcg_max = dcg_at_k(sorted(r, reverse=True), k)
17 | if not dcg_max:
18 | return 0.
19 | return dcg_at_k(r, k) / dcg_max
20 |
21 |
22 | def get_result(args):
23 |
24 | (y_pred, y_true)=args
25 |
26 | top_k = 50
27 | pred_topk_index = sorted(range(len(y_pred)), key=lambda i: y_pred[i], reverse=True)[:top_k]
28 | pos_index = set([k for k, v in enumerate(y_true) if v == 1])
29 |
30 | r = [1 if k in pos_index else 0 for k in pred_topk_index[:top_k]]
31 |
32 | p_1 = precision_at_k(r, 1)
33 | p_3 = precision_at_k(r, 3)
34 | p_5 = precision_at_k(r, 5)
35 |
36 | ndcg_1 = ndcg_at_k(r, 1)
37 | ndcg_3 = ndcg_at_k(r, 3)
38 | ndcg_5 = ndcg_at_k(r, 5)
39 |
40 | return np.array([p_1, p_3, p_5, ndcg_1, ndcg_3, ndcg_5])
41 |
42 | def evaluate(Y_tst_pred, Y_tst):
43 | pool = Pool(12)
44 | results = pool.map(get_result,zip(list(Y_tst_pred), list(Y_tst)))
45 | pool.terminate()
46 | tst_result = list(np.mean(np.array(results),0))
47 | print ('\rTst Prec@1,3,5: ', tst_result[:3], ' Tst NDCG@1,3,5: ', tst_result[3:])
48 |
--------------------------------------------------------------------------------
/w2v.py:
--------------------------------------------------------------------------------
1 | from os.path import join, exists, split
2 | import os
3 | import numpy as np
4 |
5 | def load_word2vec(model_type, vocabulary_inv, num_features=300):
6 | """
7 | loads Word2Vec model
8 | Returns initial weights for embedding layer.
9 |
10 | inputs:
11 | model_type # GoogleNews / glove
12 | vocabulary_inv # list of words (index -> word)
13 | num_features # Word vector dimensionality
14 | """
15 |
16 | model_dir = 'word2vec_models'
17 |
18 | if model_type == 'glove':
19 | model_name = join(model_dir, 'glove.6B.%dd.txt' % (num_features))
20 | assert(exists(model_name))
21 | print('Loading existing Word2Vec model (Glove.6B.%dd)' % (num_features))
22 |
23 | # dictionary, where key is word, value is word vectors
24 | embedding_model = {}
25 | for line in open(model_name, 'r', encoding="utf-8"):
26 | tmp = line.strip().split()
27 | word, vec = tmp[0], list(map(float, tmp[1:]))
28 | assert(len(vec) == num_features)
29 | if word not in embedding_model:
30 | embedding_model[word] = vec
31 | assert(len(embedding_model) == 400000)
32 |
33 | else:
34 | raise ValueError('Unknown pretrain model type: %s!' % (model_type))
35 |
36 | embedding_weights = [embedding_model[w] if w in embedding_model
37 | else np.random.uniform(-0.25, 0.25, num_features)
38 | for w in vocabulary_inv]
39 | embedding_weights = np.array(embedding_weights).astype('float32')
40 |
41 | return embedding_weights
42 |
--------------------------------------------------------------------------------