├── .gitignore
├── LICENSE
├── data
│   ├── train.json
│   └── val.json
├── fewshot_re_kit
│   ├── __init__.py
│   ├── data_loader.py
│   ├── framework.py
│   ├── network
│   │   ├── __init__.py
│   │   ├── embedding.py
│   │   └── encoder.py
│   └── sentence_encoder.py
├── models
│   ├── __init__.py
│   ├── gnn.py
│   ├── gnn_iclr.py
│   ├── metanet.py
│   ├── proto.py
│   └── snail.py
├── paper
│   └── fewrel.pdf
├── readme.md
├── test_demo.py
└── train_demo.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # folders 2 | checkpoint 3 | test_result 4 | 5 | # files 6 | *.pyc 7 | *.swp 8 | *.tar 9 | *.sh 10 | sbatch* 11 | *.ipynb 12 | 13 | # data 14 | _processed_data 15 | data/test.json 16 | data/glove.6B.50d.json 17 | 18 | # virtualenv 19 | .virtual 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 THUNLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /fewshot_re_kit/__init__.py: -------------------------------------------------------------------------------- 1 | from fewshot_re_kit import data_loader 2 | from fewshot_re_kit import framework 3 | from fewshot_re_kit import sentence_encoder 4 | 5 | -------------------------------------------------------------------------------- /fewshot_re_kit/data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import multiprocessing 4 | import numpy as np 5 | import random 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | class FileDataLoader: 10 | def next_batch(self, B, N, K, Q): 11 | ''' 12 | B: batch size.
13 | N: the number of relations for each batch 14 | K: the number of support instances for each relation 15 | Q: the number of query instances for each relation 16 | return: support_set, query_set, query_label 17 | ''' 18 | raise NotImplementedError 19 | 20 | class JSONFileDataLoader(FileDataLoader): 21 | def _load_preprocessed_file(self): 22 | name_prefix = '.'.join(self.file_name.split('/')[-1].split('.')[:-1]) 23 | word_vec_name_prefix = '.'.join(self.word_vec_file_name.split('/')[-1].split('.')[:-1]) 24 | processed_data_dir = '_processed_data' 25 | if not os.path.isdir(processed_data_dir): 26 | return False 27 | word_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_word.npy') 28 | pos1_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos1.npy') 29 | pos2_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_pos2.npy') 30 | mask_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_mask.npy') 31 | length_npy_file_name = os.path.join(processed_data_dir, name_prefix + '_length.npy') 32 | rel2scope_file_name = os.path.join(processed_data_dir, name_prefix + '_rel2scope.json') 33 | word_vec_mat_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy') 34 | word2id_file_name = os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json') 35 | if not os.path.exists(word_npy_file_name) or \ 36 | not os.path.exists(pos1_npy_file_name) or \ 37 | not os.path.exists(pos2_npy_file_name) or \ 38 | not os.path.exists(mask_npy_file_name) or \ 39 | not os.path.exists(length_npy_file_name) or \ 40 | not os.path.exists(rel2scope_file_name) or \ 41 | not os.path.exists(word_vec_mat_file_name) or \ 42 | not os.path.exists(word2id_file_name): 43 | return False 44 | print("Pre-processed files exist. Loading them...") 45 | self.data_word = np.load(word_npy_file_name) 46 | self.data_pos1 = np.load(pos1_npy_file_name) 47 | self.data_pos2 = np.load(pos2_npy_file_name) 48 | self.data_mask = np.load(mask_npy_file_name) 49 | self.data_length = np.load(length_npy_file_name) 50 | self.rel2scope = json.load(open(rel2scope_file_name)) 51 | self.word_vec_mat = np.load(word_vec_mat_file_name) 52 | self.word2id = json.load(open(word2id_file_name)) 53 | if self.data_word.shape[1] != self.max_length: 54 | print("Pre-processed files don't match current settings. Reprocessing...") 55 | return False 56 | print("Finish loading") 57 | return True 58 | 59 | def __init__(self, file_name, word_vec_file_name, max_length=40, case_sensitive=False, reprocess=False, cuda=True): 60 | ''' 61 | file_name: JSON file storing the data in the following format 62 | { 63 | "P155": # relation id 64 | [ 65 | { 66 | "h": ["song for a future generation", "Q7561099", [[16, 17, ...]]], # head entity [mention, Wikidata id, token positions] 67 | "t": ["whammy kiss", "Q7990594", [[11, 12]]], # tail entity [mention, Wikidata id, token positions] 68 | "tokens": ["Hot", "Dance", "Club", ...], # tokenized sentence 69 | }, 70 | ... 71 | ], 72 | "P177": 73 | [ 74 | ... 75 | ] 76 | ... 77 | } 78 | word_vec_file_name: JSON file storing word vectors in the following format 79 | [ 80 | {'word': 'the', 'vec': [0.418, 0.24968, ...]}, 81 | {'word': ',', 'vec': [0.013441, 0.23682, ...]}, 82 | ... 83 | ] 84 | max_length: The length to which all sentences are padded or truncated. 85 | case_sensitive: Whether the data processing is case-sensitive. Default: False. 86 | reprocess: Redo the pre-processing even if pre-processed files exist. Default: False. 87 | cuda: Use CUDA or not. Default: True.
88 | ''' 89 | self.file_name = file_name 90 | self.word_vec_file_name = word_vec_file_name 91 | self.case_sensitive = case_sensitive 92 | self.max_length = max_length 93 | self.cuda = cuda 94 | 95 | if reprocess or not self._load_preprocessed_file(): # Try to load pre-processed files first 96 | # Check files 97 | if file_name is None or not os.path.isfile(file_name): 98 | raise Exception("[ERROR] Data file doesn't exist") 99 | if word_vec_file_name is None or not os.path.isfile(word_vec_file_name): 100 | raise Exception("[ERROR] Word vector file doesn't exist") 101 | 102 | # Load files 103 | print("Loading data file...") 104 | self.ori_data = json.load(open(self.file_name, "r")) 105 | print("Finish loading") 106 | print("Loading word vector file...") 107 | self.ori_word_vec = json.load(open(self.word_vec_file_name, "r")) 108 | print("Finish loading") 109 | 110 | # Eliminate case sensitivity 111 | if not case_sensitive: 112 | print("Eliminating case sensitivity...") 113 | for relation in self.ori_data: 114 | for ins in self.ori_data[relation]: 115 | for i in range(len(ins['tokens'])): 116 | ins['tokens'][i] = ins['tokens'][i].lower() 117 | print("Finish eliminating") 118 | 119 | 120 | # Pre-process word vec 121 | self.word2id = {} 122 | self.word_vec_tot = len(self.ori_word_vec) 123 | UNK = self.word_vec_tot 124 | BLANK = self.word_vec_tot + 1 125 | self.word_vec_dim = len(self.ori_word_vec[0]['vec']) 126 | print("Got {} words of {} dims".format(self.word_vec_tot, self.word_vec_dim)) 127 | print("Building word vector matrix and mapping...") 128 | self.word_vec_mat = np.zeros((self.word_vec_tot, self.word_vec_dim), dtype=np.float32) 129 | for cur_id, word in enumerate(self.ori_word_vec): 130 | w = word['word'] 131 | if not case_sensitive: 132 | w = w.lower() 133 | self.word2id[w] = cur_id 134 | self.word_vec_mat[cur_id, :] = word['vec'] 135 | self.word_vec_mat[cur_id] = self.word_vec_mat[cur_id] / np.sqrt(np.sum(self.word_vec_mat[cur_id] ** 2)) # normalize each word vector to unit length 136 | self.word2id['UNK'] = UNK 137 | self.word2id['BLANK'] = BLANK 138 | print("Finish building") 139 | 140 | # Pre-process data 141 | print("Pre-processing data...") 142 | self.instance_tot = 0 143 | for relation in self.ori_data: 144 | self.instance_tot += len(self.ori_data[relation]) 145 | self.data_word = np.zeros((self.instance_tot, self.max_length), dtype=np.int32) 146 | self.data_pos1 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32) 147 | self.data_pos2 = np.zeros((self.instance_tot, self.max_length), dtype=np.int32) 148 | self.data_mask = np.zeros((self.instance_tot, self.max_length), dtype=np.int32) 149 | self.data_length = np.zeros((self.instance_tot), dtype=np.int32) 150 | self.rel2scope = {} # left-closed, right-open intervals 151 | i = 0 152 | for relation in self.ori_data: 153 | self.rel2scope[relation] = [i, i] 154 | for ins in self.ori_data[relation]: 155 | head = ins['h'][0] 156 | tail = ins['t'][0] 157 | pos1 = ins['h'][2][0][0] 158 | pos2 = ins['t'][2][0][0] 159 | words = ins['tokens'] 160 | cur_ref_data_word = self.data_word[i] 161 | for j, word in enumerate(words): 162 | if j < max_length: 163 | if word in self.word2id: 164 | cur_ref_data_word[j] = self.word2id[word] 165 | else: 166 | cur_ref_data_word[j] = UNK 167 | for j in range(j + 1, max_length): 168 | cur_ref_data_word[j] = BLANK 169 | self.data_length[i] = len(words) 170 | if len(words) > max_length: 171 | self.data_length[i] = max_length 172 | if pos1 >= max_length: 173 | pos1 = max_length - 1 174 | if pos2 >= max_length: 175 | pos2 = max_length - 1 176 | pos_min =
min(pos1, pos2) 177 | pos_max = max(pos1, pos2) 178 | for j in range(max_length): 179 | self.data_pos1[i][j] = j - pos1 + max_length 180 | self.data_pos2[i][j] = j - pos2 + max_length 181 | if j >= self.data_length[i]: 182 | self.data_mask[i][j] = 0 183 | elif j <= pos_min: 184 | self.data_mask[i][j] = 1 185 | elif j <= pos_max: 186 | self.data_mask[i][j] = 2 187 | else: 188 | self.data_mask[i][j] = 3 189 | i += 1 190 | self.rel2scope[relation][1] = i 191 | 192 | print("Finish pre-processing") 193 | 194 | print("Storing processed files...") 195 | name_prefix = '.'.join(file_name.split('/')[-1].split('.')[:-1]) 196 | word_vec_name_prefix = '.'.join(word_vec_file_name.split('/')[-1].split('.')[:-1]) 197 | processed_data_dir = '_processed_data' 198 | if not os.path.isdir(processed_data_dir): 199 | os.mkdir(processed_data_dir) 200 | np.save(os.path.join(processed_data_dir, name_prefix + '_word.npy'), self.data_word) 201 | np.save(os.path.join(processed_data_dir, name_prefix + '_pos1.npy'), self.data_pos1) 202 | np.save(os.path.join(processed_data_dir, name_prefix + '_pos2.npy'), self.data_pos2) 203 | np.save(os.path.join(processed_data_dir, name_prefix + '_mask.npy'), self.data_mask) 204 | np.save(os.path.join(processed_data_dir, name_prefix + '_length.npy'), self.data_length) 205 | json.dump(self.rel2scope, open(os.path.join(processed_data_dir, name_prefix + '_rel2scope.json'), 'w')) 206 | np.save(os.path.join(processed_data_dir, word_vec_name_prefix + '_mat.npy'), self.word_vec_mat) 207 | json.dump(self.word2id, open(os.path.join(processed_data_dir, word_vec_name_prefix + '_word2id.json'), 'w')) 208 | print("Finish storing") 209 | 210 | def next_one(self, N, K, Q): 211 | target_classes = random.sample(list(self.rel2scope.keys()), N) # list() keeps this compatible with newer Python 3 212 | support_set = {'word': [], 'pos1': [], 'pos2': [], 'mask': []} 213 | query_set = {'word': [], 'pos1': [], 'pos2': [], 'mask': []} 214 | query_label = [] 215 | 216 | for i, class_name in enumerate(target_classes): 217 | scope = self.rel2scope[class_name] 218 | indices = np.random.choice(list(range(scope[0], scope[1])), K + Q, False) 219 | word = self.data_word[indices] 220 | pos1 = self.data_pos1[indices] 221 | pos2 = self.data_pos2[indices] 222 | mask = self.data_mask[indices] 223 | support_word, query_word, _ = np.split(word, [K, K + Q]) 224 | support_pos1, query_pos1, _ = np.split(pos1, [K, K + Q]) 225 | support_pos2, query_pos2, _ = np.split(pos2, [K, K + Q]) 226 | support_mask, query_mask, _ = np.split(mask, [K, K + Q]) 227 | support_set['word'].append(support_word) 228 | support_set['pos1'].append(support_pos1) 229 | support_set['pos2'].append(support_pos2) 230 | support_set['mask'].append(support_mask) 231 | query_set['word'].append(query_word) 232 | query_set['pos1'].append(query_pos1) 233 | query_set['pos2'].append(query_pos2) 234 | query_set['mask'].append(query_mask) 235 | query_label += [i] * Q 236 | 237 | support_set['word'] = np.stack(support_set['word'], 0) 238 | support_set['pos1'] = np.stack(support_set['pos1'], 0) 239 | support_set['pos2'] = np.stack(support_set['pos2'], 0) 240 | support_set['mask'] = np.stack(support_set['mask'], 0) 241 | query_set['word'] = np.concatenate(query_set['word'], 0) 242 | query_set['pos1'] = np.concatenate(query_set['pos1'], 0) 243 | query_set['pos2'] = np.concatenate(query_set['pos2'], 0) 244 | query_set['mask'] = np.concatenate(query_set['mask'], 0) 245 | query_label = np.array(query_label) 246 | 247 | perm = np.random.permutation(N * Q) 248 | query_set['word'] = query_set['word'][perm] 249 | query_set['pos1'] =
query_set['pos1'][perm] 250 | query_set['pos2'] = query_set['pos2'][perm] 251 | query_set['mask'] = query_set['mask'][perm] 252 | query_label = query_label[perm] 253 | 254 | return support_set, query_set, query_label 255 | 256 | def next_batch(self, B, N, K, Q): 257 | support = {'word': [], 'pos1': [], 'pos2': [], 'mask': []} 258 | query = {'word': [], 'pos1': [], 'pos2': [], 'mask': []} 259 | label = [] 260 | for one_sample in range(B): 261 | current_support, current_query, current_label = self.next_one(N, K, Q) 262 | support['word'].append(current_support['word']) 263 | support['pos1'].append(current_support['pos1']) 264 | support['pos2'].append(current_support['pos2']) 265 | support['mask'].append(current_support['mask']) 266 | query['word'].append(current_query['word']) 267 | query['pos1'].append(current_query['pos1']) 268 | query['pos2'].append(current_query['pos2']) 269 | query['mask'].append(current_query['mask']) 270 | label.append(current_label) 271 | support['word'] = Variable(torch.from_numpy(np.stack(support['word'], 0)).long().view(-1, self.max_length)) 272 | support['pos1'] = Variable(torch.from_numpy(np.stack(support['pos1'], 0)).long().view(-1, self.max_length)) 273 | support['pos2'] = Variable(torch.from_numpy(np.stack(support['pos2'], 0)).long().view(-1, self.max_length)) 274 | support['mask'] = Variable(torch.from_numpy(np.stack(support['mask'], 0)).long().view(-1, self.max_length)) 275 | query['word'] = Variable(torch.from_numpy(np.stack(query['word'], 0)).long().view(-1, self.max_length)) 276 | query['pos1'] = Variable(torch.from_numpy(np.stack(query['pos1'], 0)).long().view(-1, self.max_length)) 277 | query['pos2'] = Variable(torch.from_numpy(np.stack(query['pos2'], 0)).long().view(-1, self.max_length)) 278 | query['mask'] = Variable(torch.from_numpy(np.stack(query['mask'], 0)).long().view(-1, self.max_length)) 279 | label = Variable(torch.from_numpy(np.stack(label, 0).astype(np.int64)).long()) 280 | 281 | # To cuda 282 | if self.cuda: 283 | for key in support: 284 | support[key] = support[key].cuda() 285 | for key in query: 286 | query[key] = query[key].cuda() 287 | label = label.cuda() 288 | 289 | return support, query, label 290 | -------------------------------------------------------------------------------- /fewshot_re_kit/framework.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sklearn.metrics 3 | import numpy as np 4 | import sys 5 | import time 6 | from . import sentence_encoder 7 | from . import data_loader 8 | import torch 9 | from torch import autograd, optim, nn 10 | from torch.autograd import Variable 11 | from torch.nn import functional as F 12 | 13 | class FewShotREModel(nn.Module): 14 | def __init__(self, sentence_encoder): 15 | ''' 16 | sentence_encoder: Sentence encoder 17 | 18 | You need to set self.cost as your own loss function. 19 | ''' 20 | nn.Module.__init__(self) 21 | self.sentence_encoder = sentence_encoder 22 | self.cost = nn.CrossEntropyLoss() 23 | 24 | def forward(self, support, query, N, K, Q): 25 | ''' 26 | support: Inputs of the support set. 27 | query: Inputs of the query set. 28 | N: Num of classes 29 | K: Num of instances for each class in the support set 30 | Q: Num of instances for each class in the query set 31 | return: logits, pred 32 | ''' 33 | raise NotImplementedError 34 | 35 | def loss(self, logits, label): 36 | ''' 37 | logits: Logits with the size (..., class_num) 38 | label: Label with whatever size. 
39 | return: [Loss] (A single value) 40 | ''' 41 | N = logits.size(-1) 42 | return self.cost(logits.view(-1, N), label.view(-1)) 43 | 44 | def accuracy(self, pred, label): 45 | ''' 46 | pred: Prediction results with whatever size 47 | label: Label with whatever size 48 | return: [Accuracy] (A single value) 49 | ''' 50 | return torch.mean((pred.view(-1) == label.view(-1)).type(torch.FloatTensor)) 51 | 52 | 53 | class FewShotREFramework: 54 | 55 | def __init__(self, train_data_loader, val_data_loader, test_data_loader): 56 | ''' 57 | train_data_loader: DataLoader for training. 58 | val_data_loader: DataLoader for validation. 59 | test_data_loader: DataLoader for testing. 60 | ''' 61 | self.train_data_loader = train_data_loader 62 | self.val_data_loader = val_data_loader 63 | self.test_data_loader = test_data_loader 64 | 65 | def __load_model__(self, ckpt): 66 | ''' 67 | ckpt: Path of the checkpoint 68 | return: Checkpoint dict 69 | ''' 70 | if os.path.isfile(ckpt): 71 | checkpoint = torch.load(ckpt) 72 | print("Successfully loaded checkpoint '%s'" % ckpt) 73 | return checkpoint 74 | else: 75 | raise Exception("No checkpoint found at '%s'" % ckpt) 76 | 77 | def item(self, x): 78 | ''' 79 | Extract a Python scalar, compatible with PyTorch both before and after 0.4 80 | ''' 81 | torch_version = torch.__version__.split('.') 82 | if int(torch_version[0]) == 0 and int(torch_version[1]) < 4: 83 | return x[0] 84 | else: 85 | return x.item() 86 | 87 | def train(self, 88 | model, 89 | model_name, 90 | B, N_for_train, N_for_eval, K, Q, 91 | ckpt_dir='./checkpoint', 92 | test_result_dir='./test_result', 93 | learning_rate=1e-1, 94 | lr_step_size=20000, 95 | weight_decay=1e-5, 96 | train_iter=30000, 97 | val_iter=1000, 98 | val_step=2000, 99 | test_iter=3000, 100 | cuda=True, 101 | pretrain_model=None, 102 | optimizer=optim.SGD): 103 | ''' 104 | model: a FewShotREModel instance 105 | model_name: Name of the model 106 | B: Batch size 107 | N_for_train / N_for_eval: Num of classes for each batch during training / evaluation 108 | K: Num of instances for each class in the support set 109 | Q: Num of instances for each class in the query set 110 | ckpt_dir: Directory of checkpoints 111 | test_result_dir: Directory of test results 112 | learning_rate: Initial learning rate 113 | lr_step_size: Decay learning rate every lr_step_size steps 114 | weight_decay: Weight decay rate 115 | train_iter: Num of iterations of training 116 | val_iter: Num of iterations of validating 117 | val_step: Validate every val_step steps 118 | test_iter: Num of iterations of testing 119 | cuda: Use CUDA or not 120 | pretrain_model: Pre-trained checkpoint path 121 | ''' 122 | print("Start training...") 123 | 124 | # Init 125 | parameters_to_optimize = list(filter(lambda x: x.requires_grad, model.parameters())) # materialized as a list so it can be reused for gradient clipping below 126 | optimizer = optimizer(parameters_to_optimize, learning_rate, weight_decay=weight_decay) 127 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_step_size) 128 | if pretrain_model: 129 | checkpoint = self.__load_model__(pretrain_model) 130 | model.load_state_dict(checkpoint['state_dict']) 131 | start_iter = checkpoint['iter'] + 1 132 | else: 133 | start_iter = 0 134 | 135 | if cuda: 136 | model = model.cuda() 137 | model.train() 138 | 139 | # Training 140 | best_acc = 0 141 | not_best_count = 0 # Stop training after several epochs without improvement.
142 | iter_loss = 0.0 143 | iter_right = 0.0 144 | iter_sample = 0.0 145 | for it in range(start_iter, start_iter + train_iter): 146 | scheduler.step() 147 | support, query, label = self.train_data_loader.next_batch(B, N_for_train, K, Q) 148 | logits, pred = model(support, query, N_for_train, K, Q) 149 | loss = model.loss(logits, label) 150 | right = model.accuracy(pred, label) 151 | optimizer.zero_grad() 152 | loss.backward() 153 | nn.utils.clip_grad_norm(parameters_to_optimize, 10) 154 | optimizer.step() 155 | 156 | iter_loss += self.item(loss.data) 157 | iter_right += self.item(right.data) 158 | iter_sample += 1 159 | sys.stdout.write('step: {0:4} | loss: {1:2.6f}, accuracy: {2:3.2f}%'.format(it + 1, iter_loss / iter_sample, 100 * iter_right / iter_sample) +'\r') 160 | sys.stdout.flush() 161 | 162 | if it % val_step == 0: 163 | iter_loss = 0. 164 | iter_right = 0. 165 | iter_sample = 0. 166 | 167 | if (it + 1) % val_step == 0: 168 | acc = self.eval(model, B, N_for_eval, K, Q, val_iter) 169 | model.train() 170 | if acc > best_acc: 171 | print('Best checkpoint') 172 | if not os.path.exists(ckpt_dir): 173 | os.makedirs(ckpt_dir) 174 | save_path = os.path.join(ckpt_dir, model_name + ".pth.tar") 175 | torch.save({'state_dict': model.state_dict()}, save_path) 176 | best_acc = acc 177 | 178 | print("\n####################\n") 179 | print("Finish training " + model_name) 180 | test_acc = self.eval(model, B, N_for_eval, K, Q, test_iter, ckpt=os.path.join(ckpt_dir, model_name + '.pth.tar')) 181 | print("Test accuracy: {}".format(test_acc)) 182 | 183 | def eval(self, 184 | model, 185 | B, N, K, Q, 186 | eval_iter, 187 | ckpt=None): 188 | ''' 189 | model: a FewShotREModel instance 190 | B: Batch size 191 | N: Num of classes for each batch 192 | K: Num of instances for each class in the support set 193 | Q: Num of instances for each class in the query set 194 | eval_iter: Num of iterations 195 | ckpt: Checkpoint path. Set as None if using current model parameters. 196 | return: Accuracy 197 | ''' 198 | print("") 199 | model.eval() 200 | if ckpt is None: 201 | eval_dataset = self.val_data_loader 202 | else: 203 | checkpoint = self.__load_model__(ckpt) 204 | model.load_state_dict(checkpoint['state_dict']) 205 | eval_dataset = self.test_data_loader 206 | 207 | iter_right = 0.0 208 | iter_sample = 0.0 209 | for it in range(eval_iter): 210 | support, query, label = eval_dataset.next_batch(B, N, K, Q) 211 | logits, pred = model(support, query, N, K, Q) 212 | right = model.accuracy(pred, label) 213 | iter_right += self.item(right.data) 214 | iter_sample += 1 215 | 216 | sys.stdout.write('[EVAL] step: {0:4} | accuracy: {1:3.2f}%'.format(it + 1, 100 * iter_right / iter_sample) +'\r') 217 | sys.stdout.flush() 218 | print("") 219 | return iter_right / iter_sample 220 | -------------------------------------------------------------------------------- /fewshot_re_kit/network/__init__.py: -------------------------------------------------------------------------------- 1 | from . import embedding 2 | from . 
import encoder 3 | -------------------------------------------------------------------------------- /fewshot_re_kit/network/embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | 7 | class Embedding(nn.Module): 8 | 9 | def __init__(self, word_vec_mat, max_length, word_embedding_dim=50, pos_embedding_dim=5): 10 | nn.Module.__init__(self) 11 | 12 | self.max_length = max_length 13 | self.word_embedding_dim = word_embedding_dim 14 | self.pos_embedding_dim = pos_embedding_dim 15 | 16 | # Word embedding: two extra rows for the UNK and BLANK tokens 17 | unk = torch.randn(1, word_embedding_dim) / math.sqrt(word_embedding_dim) 18 | blk = torch.zeros(1, word_embedding_dim) 19 | word_vec_mat = torch.from_numpy(word_vec_mat) 20 | self.word_embedding = nn.Embedding(word_vec_mat.shape[0] + 2, self.word_embedding_dim, padding_idx=word_vec_mat.shape[0] + 1) 21 | self.word_embedding.weight.data.copy_(torch.cat((word_vec_mat, unk, blk), 0)) 22 | 23 | # Position Embedding 24 | self.pos1_embedding = nn.Embedding(2 * max_length, pos_embedding_dim, padding_idx=0) 25 | self.pos2_embedding = nn.Embedding(2 * max_length, pos_embedding_dim, padding_idx=0) 26 | 27 | def forward(self, inputs): 28 | word = inputs['word'] 29 | pos1 = inputs['pos1'] 30 | pos2 = inputs['pos2'] 31 | 32 | x = torch.cat([self.word_embedding(word), 33 | self.pos1_embedding(pos1), 34 | self.pos2_embedding(pos2)], 2) 35 | return x 36 | 37 | 38 | -------------------------------------------------------------------------------- /fewshot_re_kit/network/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from torch import optim 7 | 8 | class Encoder(nn.Module): 9 | def __init__(self, max_length, word_embedding_dim=50, pos_embedding_dim=5, hidden_size=230): 10 | nn.Module.__init__(self) 11 | 12 | self.max_length = max_length 13 | self.hidden_size = hidden_size 14 | self.embedding_dim = word_embedding_dim + pos_embedding_dim * 2 15 | self.conv = nn.Conv1d(self.embedding_dim, self.hidden_size, 3, padding=1) 16 | self.pool = nn.MaxPool1d(max_length) 17 | 18 | # For PCNN 19 | self.mask_embedding = nn.Embedding(4, 3) 20 | self.mask_embedding.weight.data.copy_(torch.FloatTensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 0]])) 21 | self.mask_embedding.weight.requires_grad = False 22 | self._minus = -100 23 | 24 | def forward(self, inputs): 25 | return self.cnn(inputs) 26 | 27 | def cnn(self, inputs): 28 | x = self.conv(inputs.transpose(1, 2)) 29 | x = F.relu(x) 30 | x = self.pool(x) 31 | return x.squeeze(2) # n x hidden_size 32 | 33 | def pcnn(self, inputs, mask): 34 | x = self.conv(inputs.transpose(1, 2)) # n x hidden x length 35 | mask = 1 - self.mask_embedding(mask).transpose(1, 2) # n x 3 x length 36 | pool1 = self.pool(F.relu(x + self._minus * mask[:, 0:1, :])) 37 | pool2 = self.pool(F.relu(x + self._minus * mask[:, 1:2, :])) 38 | pool3 = self.pool(F.relu(x + self._minus * mask[:, 2:3, :])) 39 | x = torch.cat([pool1, pool2, pool3], 1) 40 | x = x.squeeze(2) # n x (hidden_size * 3) 41 | return x 42 | -------------------------------------------------------------------------------- /fewshot_re_kit/sentence_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import
torch.nn.functional as F 4 | import math 5 | from torch import optim 6 | from . import network 7 | 8 | class CNNSentenceEncoder(nn.Module): 9 | 10 | def __init__(self, word_vec_mat, max_length, word_embedding_dim=50, pos_embedding_dim=5, hidden_size=230): 11 | nn.Module.__init__(self) 12 | self.hidden_size = hidden_size 13 | self.max_length = max_length 14 | self.embedding = network.embedding.Embedding(word_vec_mat, max_length, word_embedding_dim, pos_embedding_dim) 15 | self.encoder = network.encoder.Encoder(max_length, word_embedding_dim, pos_embedding_dim, hidden_size) 16 | 17 | def forward(self, inputs): 18 | x = self.embedding(inputs) 19 | x = self.encoder(x) 20 | return x 21 | 22 | class PCNNSentenceEncoder(nn.Module): 23 | 24 | def __init__(self, word_vec_mat, max_length, word_embedding_dim=50, pos_embedding_dim=5, hidden_size=230): 25 | nn.Module.__init__(self) 26 | self.hidden_size = hidden_size 27 | self.max_length = max_length 28 | self.embedding = network.embedding.Embedding(word_vec_mat, max_length, word_embedding_dim, pos_embedding_dim) 29 | self.encoder = network.encoder.Encoder(max_length, word_embedding_dim, pos_embedding_dim, hidden_size) 30 | 31 | def forward(self, inputs): 32 | x = self.embedding(inputs) 33 | x = self.encoder.pcnn(x, inputs['mask']) 34 | return x 35 | 36 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models import proto 2 | from models import snail 3 | from models import gnn 4 | from models import metanet 5 | -------------------------------------------------------------------------------- /models/gnn.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | from . import gnn_iclr 9 | 10 | class GNN(fewshot_re_kit.framework.FewShotREModel): 11 | 12 | def __init__(self, sentence_encoder, N, hidden_size=230): 13 | ''' 14 | N: Num of classes 15 | ''' 16 | fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 17 | self.hidden_size = hidden_size 18 | self.node_dim = hidden_size + N 19 | self.gnn_obj = gnn_iclr.GNN_nl(N, self.node_dim, nf=96, J=1) 20 | 21 | def forward(self, support, query, N, K, Q): 22 | ''' 23 | support: Inputs of the support set. 24 | query: Inputs of the query set. 
25 | N: Num of classes 26 | K: Num of instances for each class in the support set 27 | Q: Num of instances for each class in the query set 28 | ''' 29 | support = self.sentence_encoder(support) 30 | query = self.sentence_encoder(query) 31 | support = support.view(-1, N, K, self.hidden_size) 32 | query = query.view(-1, N * Q, self.hidden_size) 33 | 34 | B = support.size(0) 35 | NQ = query.size(1) 36 | D = self.hidden_size 37 | 38 | support = support.unsqueeze(1).expand(-1, NQ, -1, -1, -1).contiguous().view(-1, N * K, D) # (B * NQ, N * K, D) 39 | query = query.view(-1, 1, D) # (B * NQ, 1, D) 40 | labels = Variable(torch.zeros((B * NQ, 1 + N * K, N), dtype=torch.float)).cuda() 41 | for b in range(B * NQ): 42 | for i in range(N): 43 | for k in range(K): 44 | labels[b][1 + i * K + k][i] = 1 45 | nodes = torch.cat([torch.cat([query, support], 1), labels], -1) # (B * NQ, 1 + N * K, D + N) 46 | 47 | logits = self.gnn_obj(nodes) # (B * NQ, N) 48 | _, pred = torch.max(logits, 1) 49 | return logits, pred 50 | -------------------------------------------------------------------------------- /models/gnn_iclr.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | 4 | # Pytorch requirements 5 | 6 | ''' 7 | GNN models implemented by vgsatorras from https://github.com/vgsatorras/few-shot-gnn 8 | ''' 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.autograd import Variable 13 | import torch.nn.functional as F 14 | 15 | if torch.cuda.is_available(): 16 | dtype = torch.cuda.FloatTensor 17 | dtype_l = torch.cuda.LongTensor 18 | else: 19 | dtype = torch.FloatTensor 20 | dtype_l = torch.LongTensor 21 | 22 | 23 | def gmul(input): 24 | W, x = input 25 | # x is a tensor of size (bs, N, num_features) 26 | # W is a tensor of size (bs, N, N, J) 27 | x_size = x.size() 28 | W_size = W.size() 29 | N = W_size[-2] 30 | W = W.split(1, 3) 31 | W = torch.cat(W, 1).squeeze(3) # W is now a tensor of size (bs, J*N, N) 32 | output = torch.bmm(W, x) # output has size (bs, J*N, num_features) 33 | output = output.split(N, 1) 34 | output = torch.cat(output, 2) # output has size (bs, N, J*num_features) 35 | return output 36 | 37 | 38 | class Gconv(nn.Module): 39 | def __init__(self, nf_input, nf_output, J, bn_bool=True): 40 | super(Gconv, self).__init__() 41 | self.J = J 42 | self.num_inputs = J*nf_input 43 | self.num_outputs = nf_output 44 | self.fc = nn.Linear(self.num_inputs, self.num_outputs) 45 | 46 | self.bn_bool = bn_bool 47 | if self.bn_bool: 48 | self.bn = nn.BatchNorm1d(self.num_outputs) 49 | 50 | def forward(self, input): 51 | W = input[0] 52 | x = gmul(input) # out has size (bs, N, num_inputs) 53 | #if self.J == 1: 54 | # x = torch.abs(x) 55 | x_size = x.size() 56 | x = x.contiguous() 57 | x = x.view(-1, self.num_inputs) 58 | x = self.fc(x) # has size (bs*N, num_outputs) 59 | 60 | if self.bn_bool: 61 | x = self.bn(x) 62 | 63 | x = x.view(x_size[0], x_size[1], self.num_outputs) 64 | return W, x 65 | 66 | 67 | class Wcompute(nn.Module): 68 | def __init__(self, input_features, nf, operator='J2', activation='softmax', ratio=[2,2,1,1], num_operators=1, drop=False): 69 | super(Wcompute, self).__init__() 70 | self.num_features = nf 71 | self.operator = operator 72 | self.conv2d_1 = nn.Conv2d(input_features, int(nf * ratio[0]), 1, stride=1) 73 | self.bn_1 = nn.BatchNorm2d(int(nf * ratio[0])) 74 | self.drop = drop 75 | if self.drop: 76 | self.dropout = nn.Dropout(0.3) 77 | self.conv2d_2 = nn.Conv2d(int(nf * ratio[0]), int(nf *
ratio[1]), 1, stride=1) 78 | self.bn_2 = nn.BatchNorm2d(int(nf * ratio[1])) 79 | self.conv2d_3 = nn.Conv2d(int(nf * ratio[1]), nf*ratio[2], 1, stride=1) 80 | self.bn_3 = nn.BatchNorm2d(nf*ratio[2]) 81 | self.conv2d_4 = nn.Conv2d(nf*ratio[2], nf*ratio[3], 1, stride=1) 82 | self.bn_4 = nn.BatchNorm2d(nf*ratio[3]) 83 | self.conv2d_last = nn.Conv2d(nf, num_operators, 1, stride=1) 84 | self.activation = activation 85 | 86 | def forward(self, x, W_id): 87 | W1 = x.unsqueeze(2) 88 | W2 = torch.transpose(W1, 1, 2) #size: bs x N x N x num_features 89 | W_new = torch.abs(W1 - W2) #size: bs x N x N x num_features 90 | W_new = torch.transpose(W_new, 1, 3) #size: bs x num_features x N x N 91 | 92 | W_new = self.conv2d_1(W_new) 93 | W_new = self.bn_1(W_new) 94 | W_new = F.leaky_relu(W_new) 95 | if self.drop: 96 | W_new = self.dropout(W_new) 97 | 98 | W_new = self.conv2d_2(W_new) 99 | W_new = self.bn_2(W_new) 100 | W_new = F.leaky_relu(W_new) 101 | 102 | W_new = self.conv2d_3(W_new) 103 | W_new = self.bn_3(W_new) 104 | W_new = F.leaky_relu(W_new) 105 | 106 | W_new = self.conv2d_4(W_new) 107 | W_new = self.bn_4(W_new) 108 | W_new = F.leaky_relu(W_new) 109 | 110 | W_new = self.conv2d_last(W_new) 111 | W_new = torch.transpose(W_new, 1, 3) #size: bs x N x N x 1 112 | 113 | if self.activation == 'softmax': 114 | W_new = W_new - W_id.expand_as(W_new) * 1e8 115 | W_new = torch.transpose(W_new, 2, 3) 116 | # Applying Softmax 117 | W_new = W_new.contiguous() 118 | W_new_size = W_new.size() 119 | W_new = W_new.view(-1, W_new.size(3)) 120 | W_new = F.softmax(W_new) 121 | W_new = W_new.view(W_new_size) 122 | # Softmax applied 123 | W_new = torch.transpose(W_new, 2, 3) 124 | 125 | elif self.activation == 'sigmoid': 126 | W_new = F.sigmoid(W_new) 127 | W_new *= (1 - W_id) 128 | elif self.activation == 'none': 129 | W_new *= (1 - W_id) 130 | else: 131 | raise (NotImplementedError) 132 | 133 | if self.operator == 'laplace': 134 | W_new = W_id - W_new 135 | elif self.operator == 'J2': 136 | W_new = torch.cat([W_id, W_new], 3) 137 | else: 138 | raise(NotImplementedError) 139 | 140 | return W_new 141 | 142 | 143 | class GNN_nl_omniglot(nn.Module): 144 | def __init__(self, args, input_features, nf, J): 145 | super(GNN_nl_omniglot, self).__init__() 146 | self.args = args 147 | self.input_features = input_features 148 | self.nf = nf 149 | self.J = J 150 | 151 | self.num_layers = 2 152 | for i in range(self.num_layers): 153 | module_w = Wcompute(self.input_features + int(nf / 2) * i, 154 | self.input_features + int(nf / 2) * i, 155 | operator='J2', activation='softmax', ratio=[2, 1.5, 1, 1], drop=False) 156 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 157 | self.add_module('layer_w{}'.format(i), module_w) 158 | self.add_module('layer_l{}'.format(i), module_l) 159 | 160 | self.w_comp_last = Wcompute(self.input_features + int(self.nf / 2) * self.num_layers, 161 | self.input_features + int(self.nf / 2) * (self.num_layers - 1), 162 | operator='J2', activation='softmax', ratio=[2, 1.5, 1, 1], drop=True) 163 | self.layer_last = Gconv(self.input_features + int(self.nf / 2) * self.num_layers, args.train_N_way, 2, bn_bool=True) 164 | 165 | def forward(self, x): 166 | W_init = Variable(torch.eye(x.size(1)).unsqueeze(0).repeat(x.size(0), 1, 1).unsqueeze(3)) 167 | if self.args.cuda: 168 | W_init = W_init.cuda() 169 | 170 | for i in range(self.num_layers): 171 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 172 | 173 | x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 174 | x = 
torch.cat([x, x_new], 2) 175 | 176 | Wl=self.w_comp_last(x, W_init) 177 | out = self.layer_last([Wl, x])[1] 178 | 179 | return out[:, 0, :] 180 | 181 | 182 | class GNN_nl(nn.Module): 183 | def __init__(self, N, input_features, nf, J): 184 | super(GNN_nl, self).__init__() 185 | # self.args = args 186 | self.input_features = input_features 187 | self.nf = nf 188 | self.J = J 189 | 190 | self.num_layers = 2 191 | 192 | for i in range(self.num_layers): 193 | if i == 0: 194 | module_w = Wcompute(self.input_features, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 195 | module_l = Gconv(self.input_features, int(nf / 2), 2) 196 | else: 197 | module_w = Wcompute(self.input_features + int(nf / 2) * i, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 198 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 199 | self.add_module('layer_w{}'.format(i), module_w) 200 | self.add_module('layer_l{}'.format(i), module_l) 201 | 202 | self.w_comp_last = Wcompute(self.input_features + int(self.nf / 2) * self.num_layers, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 203 | self.layer_last = Gconv(self.input_features + int(self.nf / 2) * self.num_layers, N, 2, bn_bool=False) 204 | 205 | def forward(self, x): 206 | W_init = Variable(torch.eye(x.size(1)).unsqueeze(0).repeat(x.size(0), 1, 1).unsqueeze(3)) 207 | W_init = W_init.cuda() 208 | 209 | for i in range(self.num_layers): 210 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 211 | 212 | x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 213 | x = torch.cat([x, x_new], 2) 214 | 215 | Wl=self.w_comp_last(x, W_init) 216 | out = self.layer_last([Wl, x])[1] 217 | 218 | return out[:, 0, :] 219 | 220 | class GNN_active(nn.Module): 221 | def __init__(self, args, input_features, nf, J): 222 | super(GNN_active, self).__init__() 223 | self.args = args 224 | self.input_features = input_features 225 | self.nf = nf 226 | self.J = J 227 | 228 | self.num_layers = 2 229 | for i in range(self.num_layers // 2): 230 | if i == 0: 231 | module_w = Wcompute(self.input_features, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 232 | module_l = Gconv(self.input_features, int(nf / 2), 2) 233 | else: 234 | module_w = Wcompute(self.input_features + int(nf / 2) * i, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 235 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 236 | 237 | self.add_module('layer_w{}'.format(i), module_w) 238 | self.add_module('layer_l{}'.format(i), module_l) 239 | 240 | self.conv_active_1 = nn.Conv1d(self.input_features + int(nf / 2) * 1, self.input_features + int(nf / 2) * 1, 1) 241 | self.bn_active = nn.BatchNorm1d(self.input_features + int(nf / 2) * 1) 242 | self.conv_active_2 = nn.Conv1d(self.input_features + int(nf / 2) * 1, 1, 1) 243 | 244 | for i in range(int(self.num_layers/2), self.num_layers): 245 | if i == 0: 246 | module_w = Wcompute(self.input_features, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 247 | module_l = Gconv(self.input_features, int(nf / 2), 2) 248 | else: 249 | module_w = Wcompute(self.input_features + int(nf / 2) * i, nf, operator='J2', activation='softmax', ratio=[2, 2, 1, 1]) 250 | module_l = Gconv(self.input_features + int(nf / 2) * i, int(nf / 2), 2) 251 | self.add_module('layer_w{}'.format(i), module_w) 252 | self.add_module('layer_l{}'.format(i), module_l) 253 | 254 | self.w_comp_last = Wcompute(self.input_features + int(self.nf / 2) * self.num_layers, nf, operator='J2', 
activation='softmax', ratio=[2, 2, 1, 1]) 255 | self.layer_last = Gconv(self.input_features + int(self.nf / 2) * self.num_layers, args.train_N_way, 2, bn_bool=False) 256 | 257 | def active(self, x, oracles_yi, hidden_labels): 258 | x_active = torch.transpose(x, 1, 2) 259 | x_active = self.conv_active_1(x_active) 260 | x_active = F.leaky_relu(self.bn_active(x_active)) 261 | x_active = self.conv_active_2(x_active) 262 | x_active = torch.transpose(x_active, 1, 2) 263 | 264 | x_active = x_active.squeeze(-1) 265 | x_active = x_active - (1-hidden_labels)*1e8 266 | x_active = F.softmax(x_active) 267 | x_active = x_active*hidden_labels 268 | 269 | if self.args.active_random == 1: 270 | #print('random active') 271 | x_active.data.fill_(1./x_active.size(1)) 272 | decision = torch.multinomial(x_active, 1) 273 | x_active = x_active.detach() 274 | else: 275 | if self.training: 276 | decision = torch.multinomial(x_active, 1) 277 | else: 278 | _, decision = torch.max(x_active, 1) 279 | decision = decision.unsqueeze(-1) 280 | 281 | decision = decision.detach() 282 | 283 | mapping = torch.FloatTensor(decision.size(0),x_active.size(1)).zero_() 284 | mapping = Variable(mapping) 285 | if self.args.cuda: 286 | mapping = mapping.cuda() 287 | mapping.scatter_(1, decision, 1) 288 | 289 | mapping_bp = (x_active*mapping).unsqueeze(-1) 290 | mapping_bp = mapping_bp.expand_as(oracles_yi) 291 | 292 | label2add = mapping_bp*oracles_yi #bsxNodesxN_way 293 | padd = torch.zeros(x.size(0), x.size(1), x.size(2) - label2add.size(2)) 294 | padd = Variable(padd).detach() 295 | if self.args.cuda: 296 | padd = padd.cuda() 297 | label2add = torch.cat([label2add, padd], 2) 298 | 299 | x = x+label2add 300 | return x 301 | 302 | 303 | def forward(self, x, oracles_yi, hidden_labels): 304 | W_init = Variable(torch.eye(x.size(1)).unsqueeze(0).repeat(x.size(0), 1, 1).unsqueeze(3)) 305 | if self.args.cuda: 306 | W_init = W_init.cuda() 307 | 308 | for i in range(self.num_layers // 2): 309 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 310 | x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 311 | x = torch.cat([x, x_new], 2) 312 | 313 | x = self.active(x, oracles_yi, hidden_labels) 314 | 315 | for i in range(int(self.num_layers/2), self.num_layers): 316 | Wi = self._modules['layer_w{}'.format(i)](x, W_init) 317 | x_new = F.leaky_relu(self._modules['layer_l{}'.format(i)]([Wi, x])[1]) 318 | x = torch.cat([x, x_new], 2) 319 | 320 | 321 | Wl=self.w_comp_last(x, W_init) 322 | out = self.layer_last([Wl, x])[1] 323 | 324 | return out[:, 0, :] 325 | 326 | if __name__ == '__main__': 327 | # test modules 328 | bs = 4 329 | nf = 10 330 | num_layers = 5 331 | N = 8 332 | x = torch.ones((bs, N, nf)) 333 | W1 = torch.eye(N).unsqueeze(0).unsqueeze(-1).expand(bs, N, N, 1) 334 | W2 = torch.ones(N).unsqueeze(0).unsqueeze(-1).expand(bs, N, N, 1) 335 | J = 2 336 | W = torch.cat((W1, W2), 3) 337 | input = [Variable(W), Variable(x)] 338 | ######################### test gmul ############################## 339 | # feature_maps = [num_features, num_features, num_features] 340 | # out = gmul(input) 341 | # print(out[0, :, num_features:]) 342 | ######################### test gconv ############################## 343 | # feature_maps = [num_features, num_features, num_features] 344 | # gconv = Gconv(feature_maps, J) 345 | # _, out = gconv(input) 346 | # print(out.size()) 347 | ######################### test gnn ############################## 348 | # x = torch.ones((bs, N, 1)) 349 | # input = [Variable(W), Variable(x)] 350 | # gnn =
GNN(num_features, num_layers, J) 351 | # out = gnn(input) 352 | # print(out.size()) 353 | 354 | 355 | -------------------------------------------------------------------------------- /models/metanet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | from fewshot_re_kit.network.embedding import Embedding 5 | from fewshot_re_kit.network.encoder import Encoder 6 | import torch 7 | from torch import autograd, optim, nn 8 | from torch.autograd import Variable 9 | from torch.nn import functional as F 10 | import numpy as np 11 | 12 | def log_and_sign(inputs, k=7): 13 | eps = 1e-7 14 | log = torch.log(torch.abs(inputs) + eps) / k 15 | log[log < -1.0] = -1.0 16 | sign = log * np.exp(k) 17 | sign[sign < -1.0] = -1.0 18 | sign[sign > 1.0] = 1.0 19 | return torch.cat([log, sign], 1) 20 | 21 | class LearnerForAttention(nn.Module): 22 | 23 | def __init__(self): 24 | nn.Module.__init__(self) 25 | self.conv_lstm = nn.LSTM(2, 20, batch_first=True) 26 | self.conv_fc = nn.Linear(20, 1) 27 | self.fc_lstm = nn.LSTM(2, 20, batch_first=True) 28 | self.fc_fc = nn.Linear(20, 1) 29 | 30 | def forward(self, inputs, is_conv): 31 | size = inputs.size() 32 | x = inputs.view((-1, 1)) 33 | x = log_and_sign(x) # (-1, 2) 34 | 35 | #### NO BACKPROP 36 | x = Variable(x, requires_grad=False).unsqueeze(0) # (1, param_size, 2) 37 | #### 38 | 39 | if is_conv: 40 | x, _ = self.conv_lstm(x) # (1, param_size, 1) 41 | x = x.squeeze() 42 | x = self.conv_fc(x) 43 | else: 44 | x, _ = self.fc_lstm(x) # (1, param_size, 1) 45 | x = x.squeeze() 46 | x = self.fc_fc(x) 47 | return x.view(size) 48 | 49 | class LearnerForBasic(nn.Module): 50 | 51 | def __init__(self): 52 | nn.Module.__init__(self) 53 | self.conv_fc1 = nn.Linear(2, 20) 54 | self.conv_fc2 = nn.Linear(20, 20) 55 | self.conv_fc3 = nn.Linear(20, 1) 56 | self.fc_fc1 = nn.Linear(2, 20) 57 | self.fc_fc2 = nn.Linear(20, 20) 58 | self.fc_fc3 = nn.Linear(20, 1) 59 | 60 | 61 | def forward(self, inputs, is_conv): 62 | size = inputs.size() 63 | x = inputs.view((-1, 1)) 64 | x = log_and_sign(x) # (-1, 2) 65 | 66 | #### NO BACKPROP 67 | x = Variable(x, requires_grad=False) 68 | #### 69 | 70 | if is_conv: 71 | x = F.relu(self.conv_fc1(x)) 72 | x = F.relu(self.conv_fc2(x)) 73 | x = self.conv_fc3(x) 74 | else: 75 | x = F.relu(self.fc_fc1(x)) 76 | x = F.relu(self.fc_fc2(x)) 77 | x = self.fc_fc3(x) 78 | return x.view(size) 79 | 80 | class MetaNet(fewshot_re_kit.framework.FewShotREModel): 81 | 82 | def __init__(self, N, K, word_vec_mat, max_length, hidden_size=230): 83 | ''' 84 | N: num of classes 85 | K: num of instances for each class 86 | word_vec_mat, max_length, hidden_size: same as sentence_encoder 87 | ''' 88 | fewshot_re_kit.framework.FewShotREModel.__init__(self, None) 89 | self.max_length = max_length 90 | self.hidden_size = hidden_size 91 | self.N = N 92 | self.K = K 93 | 94 | self.embedding = Embedding(word_vec_mat, max_length, word_embedding_dim=50, pos_embedding_dim=5) 95 | 96 | self.basic_encoder = Encoder(max_length, word_embedding_dim=50, pos_embedding_dim=5, hidden_size=hidden_size) 97 | self.attention_encoder = Encoder(max_length, word_embedding_dim=50, pos_embedding_dim=5, hidden_size=hidden_size) 98 | 99 | self.basic_fast_conv_W = None 100 | self.attention_fast_conv_W = None 101 | 102 | self.basic_fc = nn.Linear(hidden_size, N, bias=False) 103 | self.attention_fc = nn.Linear(hidden_size, N, bias=False) 104 | 105 | self.basic_fast_fc_W = None 106 | self.attention_fast_fc_W = 
None 107 | 108 | self.learner_basic = LearnerForBasic() 109 | self.learner_attention = LearnerForAttention() 110 | 111 | def basic_emb(self, inputs, size, use_fast=False): 112 | x = self.embedding(inputs) 113 | output = self.basic_encoder(x) 114 | if use_fast: 115 | output += F.relu(F.conv1d(x.transpose(-1, -2), self.basic_fast_conv_W, padding=1)).max(-1)[0] 116 | return output.view(size) 117 | 118 | def attention_emb(self, inputs, size, use_fast=False): 119 | x = self.embedding(inputs) 120 | output = self.attention_encoder(x) 121 | if use_fast: 122 | output += F.relu(F.conv1d(x.transpose(-1, -2), self.attention_fast_conv_W, padding=1)).max(-1)[0] 123 | return output.view(size) 124 | 125 | def attention_score(self, s_att, q_att): 126 | ''' 127 | s_att: (B, N, K, D) 128 | q_att: (B, NQ, D) 129 | ''' 130 | s_att = s_att.view(s_att.size(0), s_att.size(1) * s_att.size(2), s_att.size(3)) # (B, N * K, D) 131 | s_att = s_att.unsqueeze(1) # (B, 1, N * K, D) 132 | q_att = q_att.unsqueeze(2) # (B, NQ, 1, D) 133 | cos = F.cosine_similarity(s_att, q_att, dim=-1) # (B, NQ, N * K) 134 | score = F.softmax(cos, -1) # (B, NQ, N * K) 135 | return score 136 | 137 | def forward(self, support, query, N, K, Q): 138 | ''' 139 | support: Inputs of the support set. 140 | query: Inputs of the query set. 141 | N: Num of classes 142 | K: Num of instances for each class in the support set 143 | Q: Num of instances for each class in the query set 144 | ''' 145 | 146 | # learn fast parameters for attention encoder 147 | s = self.attention_emb(support, (-1, N, K, self.hidden_size)) 148 | logits = self.attention_fc(s) # (B, N, K, N) 149 | 150 | B = s.size(0) 151 | NQ = N * Q 152 | assert(B == 1) 153 | 154 | self.zero_grad() 155 | tmp_label = Variable(torch.tensor([[x] * K for x in range(N)] * B, dtype=torch.long).cuda()) 156 | loss = self.cost(logits.view(-1, N), tmp_label.view(-1)) 157 | loss.backward(retain_graph=True) 158 | 159 | grad_conv = self.attention_encoder.conv.weight.grad 160 | grad_fc = self.attention_fc.weight.grad 161 | 162 | self.attention_fast_conv_W = self.learner_attention(grad_conv, is_conv=True) 163 | self.attention_fast_fc_W = self.learner_attention(grad_fc, is_conv=False) 164 | 165 | # learn fast parameters for basic encoder (each class) 166 | s = self.basic_emb(support, (-1, N, K, self.hidden_size)) 167 | logits = self.basic_fc(s) # (B, N, K, N) 168 | 169 | basic_fast_conv_params = [] 170 | basic_fast_fc_params = [] 171 | for i in range(N): 172 | for j in range(K): 173 | self.zero_grad() 174 | tmp_label = Variable(torch.tensor([i], dtype=torch.long).cuda()) 175 | loss = self.cost(logits[:, i, j].view(-1, N), tmp_label.view(-1)) 176 | loss.backward(retain_graph=True) 177 | 178 | grad_conv = self.basic_encoder.conv.weight.grad 179 | grad_fc = self.basic_fc.weight.grad 180 | 181 | basic_fast_conv_params.append(self.learner_basic(grad_conv, is_conv=True)) 182 | basic_fast_fc_params.append(self.learner_basic(grad_fc, is_conv=False)) 183 | basic_fast_conv_params = torch.stack(basic_fast_conv_params, 0) # (N * K, conv_weight_size) 184 | basic_fast_fc_params = torch.stack(basic_fast_fc_params, 0) # (N * K, fc_weight_size) 185 | 186 | # final 187 | self.zero_grad() 188 | s_att = self.attention_emb(support, (-1, N, K, self.hidden_size), use_fast=True) 189 | q_att = self.attention_emb(query, (-1, NQ, self.hidden_size), use_fast=True) 190 | score = self.attention_score(s_att, q_att).squeeze(0) # assume B = 1, (NQ, N * K) 191 | size_conv_param = basic_fast_conv_params.size()[1:] 192 | size_fc_param = 
basic_fast_fc_params.size()[1:] 193 | final_fast_conv_param = torch.matmul(score, basic_fast_conv_params.view(N * K, -1)) # (NQ, conv_weight_size) 194 | final_fast_fc_param = torch.matmul(score, basic_fast_fc_params.view(N * K, -1)) # (NQ, fc_weight_size) 195 | stack_logits = [] 196 | for i in range(NQ): 197 | self.basic_fast_conv_W = final_fast_conv_param[i].view(size_conv_param) 198 | self.basic_fast_fc_W = final_fast_fc_param[i].view(size_fc_param) 199 | q = self.basic_emb({'word': query['word'][i:i+1], 'pos1': query['pos1'][i:i+1], 'pos2': query['pos2'][i:i+1], 'mask': query['mask'][i:i+1]}, (self.hidden_size), use_fast=True) 200 | logits = self.basic_fc(q) + F.linear(q, self.basic_fast_fc_W) 201 | stack_logits.append(logits) 202 | logits = torch.stack(stack_logits, 0) 203 | 204 | _, pred = torch.max(logits.view(-1, N), 1) 205 | return logits, pred 206 | -------------------------------------------------------------------------------- /models/proto.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | 9 | class Proto(fewshot_re_kit.framework.FewShotREModel): 10 | 11 | def __init__(self, sentence_encoder, hidden_size=230): 12 | fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 13 | self.hidden_size = hidden_size 14 | self.fc = nn.Linear(hidden_size, hidden_size) 15 | self.drop = nn.Dropout() 16 | 17 | def __dist__(self, x, y, dim): 18 | return (torch.pow(x - y, 2)).sum(dim) 19 | 20 | def __batch_dist__(self, S, Q): 21 | return self.__dist__(S.unsqueeze(1), Q.unsqueeze(2), 3) 22 | 23 | def forward(self, support, query, N, K, Q): 24 | ''' 25 | support: Inputs of the support set. 26 | query: Inputs of the query set. 
27 | N: Num of classes 28 | K: Num of instances for each class in the support set 29 | Q: Num of instances for each class in the query set 30 | ''' 31 | support = self.sentence_encoder(support) # (B * N * K, D), where D is the hidden size 32 | query = self.sentence_encoder(query) # (B * N * Q, D) 33 | support = self.drop(support) 34 | query = self.drop(query) 35 | support = support.view(-1, N, K, self.hidden_size) # (B, N, K, D) 36 | query = query.view(-1, N * Q, self.hidden_size) # (B, N * Q, D) 37 | 38 | B = support.size(0) # Batch size 39 | NQ = query.size(1) # Num of instances for each batch in the query set 40 | 41 | # Prototypical Networks 42 | support = torch.mean(support, 2) # Calculate prototype for each class 43 | logits = -self.__batch_dist__(support, query) 44 | _, pred = torch.max(logits.view(-1, N), 1) 45 | return logits, pred 46 | 47 | 48 | 49 | -------------------------------------------------------------------------------- /models/snail.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('..') 3 | import fewshot_re_kit 4 | import torch 5 | from torch import autograd, optim, nn 6 | from torch.autograd import Variable 7 | from torch.nn import functional as F 8 | import numpy as np 9 | 10 | class CausalConv1d(nn.Module): 11 | 12 | def __init__(self, in_channels, out_channels, kernel_size=2, dilation=2): 13 | super(CausalConv1d, self).__init__() 14 | self.padding = dilation 15 | self.causal_conv = nn.Conv1d( 16 | in_channels, 17 | out_channels, 18 | kernel_size, 19 | padding = self.padding, 20 | dilation = dilation) 21 | 22 | def forward(self, minibatch): 23 | return self.causal_conv(minibatch)[:, :, :-self.padding] 24 | 25 | 26 | class DenseBlock(nn.Module): 27 | 28 | def __init__(self, in_channels, filters, dilation=2): 29 | super(DenseBlock, self).__init__() 30 | self.causal_conv1 = CausalConv1d( 31 | in_channels, 32 | filters, 33 | dilation=dilation) 34 | self.causal_conv2 = CausalConv1d( 35 | in_channels, 36 | filters, 37 | dilation=dilation) 38 | 39 | def forward(self, minibatch): 40 | tanh = F.tanh(self.causal_conv1(minibatch)) 41 | sig = F.sigmoid(self.causal_conv2(minibatch)) 42 | out = torch.cat([minibatch, tanh*sig], dim=1) 43 | return out 44 | 45 | class TCBlock(nn.Module): 46 | 47 | def __init__(self, in_channels, filters, seq_len): 48 | super(TCBlock, self).__init__() 49 | layer_count = np.ceil(np.log2(seq_len)).astype(np.int32) 50 | blocks = [] 51 | channel_count = in_channels 52 | for layer in range(layer_count): 53 | block = DenseBlock(channel_count, filters, dilation=2**layer) 54 | blocks.append(block) 55 | channel_count += filters 56 | self.tcblock = nn.Sequential(*blocks) 57 | self._dim = channel_count 58 | 59 | def forward(self, minibatch): 60 | return self.tcblock(minibatch) 61 | 62 | @property 63 | def dim(self): 64 | return self._dim 65 | 66 | class AttentionBlock(nn.Module): 67 | def __init__(self, dims, k_size, v_size, seq_len): 68 | 69 | super(AttentionBlock, self).__init__() 70 | self.key_layer = nn.Linear(dims, k_size) 71 | self.query_layer = nn.Linear(dims, k_size) 72 | self.value_layer = nn.Linear(dims, v_size) 73 | self.sqrt_k = np.sqrt(k_size) 74 | mask = np.tril(np.ones((seq_len, seq_len))).astype(np.float32) 75 | self.mask = nn.Parameter(torch.from_numpy(mask), requires_grad=False) 76 | self.minus = - 100. 
77 | self._dim = dims + v_size 78 | 79 | def forward(self, minibatch, current_seq_len): 80 | keys = self.key_layer(minibatch) 81 | #queries = self.query_layer(minibatch) 82 | queries = keys 83 | values = self.value_layer(minibatch) 84 | current_mask = self.mask[:current_seq_len, :current_seq_len] 85 | logits = current_mask * torch.div(torch.bmm(queries, keys.transpose(2,1)), self.sqrt_k) + self.minus * (1. - current_mask) 86 | probs = F.softmax(logits, 2) 87 | read = torch.bmm(probs, values) 88 | return torch.cat([minibatch, read], dim=2) 89 | 90 | @property 91 | def dim(self): 92 | return self._dim 93 | 94 | class SNAIL(fewshot_re_kit.framework.FewShotREModel): 95 | 96 | def __init__(self, sentence_encoder, N, K, hidden_size=230): 97 | ''' 98 | N: num of classes 99 | K: num of instances for each class in the support set 100 | ''' 101 | fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 102 | self.hidden_size = hidden_size 103 | self.drop = nn.Dropout() 104 | self.seq_len = N * K + 1 105 | self.att0 = AttentionBlock(hidden_size, 64, 32, self.seq_len) 106 | self.tc1 = TCBlock(self.att0.dim, 128, self.seq_len) 107 | self.att1 = AttentionBlock(self.tc1.dim, 256, 128, self.seq_len) 108 | self.tc2 = TCBlock(self.att1.dim, 128, self.seq_len) 109 | self.att2 = AttentionBlock(self.tc2.dim, 512, 256, self.seq_len) 110 | self.disc = nn.Linear(self.att2.dim, N, bias=False) 111 | self.bn1 = nn.BatchNorm2d(self.tc1.dim) 112 | self.bn2 = nn.BatchNorm2d(self.tc2.dim) 113 | 114 | def forward(self, support, query, N, K, Q): 115 | support = self.sentence_encoder(support) # (B * N * K, D), where D is the hidden size 116 | query = self.sentence_encoder(query) # (B * N * Q, D) 117 | # support = self.drop(support) 118 | # query = self.drop(query) 119 | support = support.view(-1, N, K, self.hidden_size) # (B, N, K, D) 120 | query = query.view(-1, N * Q, self.hidden_size) # (B, N * Q, D) 121 | B = support.size(0) # Batch size 122 | NQ = query.size(1) # Num of instances for each batch in the query set 123 | 124 | support = support.unsqueeze(1).expand(-1, NQ, -1, -1, -1).contiguous().view(-1, N * K, self.hidden_size) # (B * NQ, N * K, D) 125 | query = query.view(-1, 1, self.hidden_size) # (B * NQ, 1, D) 126 | minibatch = torch.cat([support, query], 1) 127 | 128 | x = self.att0(minibatch, self.seq_len).transpose(1, 2) 129 | #x = self.bn1(x).transpose(1, 2) 130 | x = self.bn1(self.tc1(x)).transpose(1, 2) 131 | #x = self.tc1(x).transpose(1, 2) 132 | x = self.att1(x, self.seq_len).transpose(1, 2) 133 | x = self.bn2(self.tc2(x)).transpose(1, 2) 134 | #x = self.tc2(x).transpose(1, 2) 135 | x = self.att2(x, self.seq_len) 136 | x = x[:, -1, :] 137 | logits = self.disc(x) 138 | _, pred = torch.max(logits, -1) 139 | return logits, pred 140 | 141 | -------------------------------------------------------------------------------- /paper/fewrel.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ProKil/FewRel/5fdc65d899c7efa21aac9e76be00e027978c2e7a/paper/fewrel.pdf -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # FewRel Dataset, Toolkits and Baseline Models 2 | 3 | FewRel is a large-scale few-shot relation extraction dataset, which contains 70000 natural language sentences expressing 100 different relations. 
This dataset is presented in our EMNLP 2018 paper [FewRel: A Large-Scale Supervised Few-Shot Relation Classification Dataset with State-of-the-Art Evaluation](https://github.com/ProKil/FewRel/blob/master/paper/fewrel.pdf).
4 | 
5 | More info at http://zhuhao.me/fewrel.html .
6 | 
7 | ## Citing
8 | If you use our data, toolkits or baseline models, please cite our paper:
9 | ```
10 | @inproceedings{han2018fewrel,
11 |   title={FewRel: A Large-Scale Supervised Few-Shot Relation Classification Dataset with State-of-the-Art Evaluation},
12 |   author={Han, Xu and Zhu, Hao and Yu, Pengfei and Wang, Ziyun and Yao, Yuan and Liu, Zhiyuan and Sun, Maosong},
13 |   booktitle={EMNLP},
14 |   year={2018}
15 | }
16 | ```
17 | 
18 | 
19 | If you have questions about any part of the paper, the submission, the leaderboard, the code, or the data, please e-mail zhuhao@cmu.edu.
20 | 
21 | ## Contributions
22 | 
23 | Hao Zhu first proposed this problem and designed the approach for building the dataset and the baseline system; Ziyuan Wang built and maintained the crowdsourcing website; Yuan Yao helped download the original data and carried out the preprocessing;
24 | Xu Han, Hao Zhu, Pengfei Yu and Ziyun Wang implemented the baselines and wrote the paper together; Zhiyuan Liu provided thoughtful advice and funding throughout the project. The order of the first four authors was determined by dice rolling.
25 | 
26 | ## Dataset and Word Embedding
27 | 
28 | The dataset is already included in this GitHub repo. However, due to their large size, the GloVe files (pre-trained word embeddings) are not. Please download `glove.6B.50d.json` from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b14bf0d3c9e04ead9c0a/?dl=1) or [Google Drive](https://drive.google.com/open?id=1UnncRYzDpezPkwIqhgkVW6BacIqz6EaB) and put it under the `data/` folder.
29 | 
30 | ## Usage
31 | 
32 | To run our baseline models, use the command
33 | 
34 | ```bash
35 | python train_demo.py {MODEL_NAME}
36 | ```
37 | 
38 | replacing `{MODEL_NAME}` with `proto`, `metanet`, `gnn` or `snail`.
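39 | 
40 | You can also append N and K to choose the few-shot setting yourself (both are parsed in `train_demo.py`); for example
41 | 
42 | ```bash
43 | python train_demo.py proto 10 1
44 | ```
45 | 
46 | trains a 10-way-1-shot model. Without these arguments the demo defaults to 5-way 5-shot.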
47 | 
48 | 
--------------------------------------------------------------------------------
/test_demo.py:
--------------------------------------------------------------------------------
1 | import models
2 | from fewshot_re_kit.data_loader import JSONFileDataLoader
3 | from fewshot_re_kit.framework import FewShotREFramework
4 | from fewshot_re_kit.sentence_encoder import CNNSentenceEncoder
5 | from models.proto import Proto
6 | 
7 | 
8 | max_length = 40
9 | train_data_loader = JSONFileDataLoader('./data/train.json', './data/glove.6B.50d.json', max_length=max_length)
10 | val_data_loader = JSONFileDataLoader('./data/val.json', './data/glove.6B.50d.json', max_length=max_length)
11 | test_data_loader = JSONFileDataLoader('./data/test.json', './data/glove.6B.50d.json', max_length=max_length)
12 | 
13 | framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)
14 | sentence_encoder = CNNSentenceEncoder(train_data_loader.word_vec_mat, max_length)
15 | model = Proto(sentence_encoder).cuda()
16 | 
17 | acc = 0
18 | for i in range(5): # average accuracy over 5 evaluation runs
19 |     acc += framework.eval(model, 4, 5, 5, 100, 3000, ckpt='checkpoint/proto.pth.tar')
20 | acc /= 5.0
21 | print("ACC: {}".format(acc))
22 | 
--------------------------------------------------------------------------------
/train_demo.py:
--------------------------------------------------------------------------------
1 | import models
2 | from fewshot_re_kit.data_loader import JSONFileDataLoader
3 | from fewshot_re_kit.framework import FewShotREFramework
4 | from fewshot_re_kit.sentence_encoder import CNNSentenceEncoder
5 | from models.proto import Proto
6 | from models.gnn import GNN
7 | from models.snail import SNAIL
8 | from models.metanet import MetaNet
9 | import sys
10 | from torch import optim
11 | 
12 | model_name = 'proto' # defaults; override from the command line: python train_demo.py {MODEL_NAME} {N} {K}
13 | N = 5
14 | K = 5
15 | if len(sys.argv) > 1:
16 |     model_name = sys.argv[1]
17 | if len(sys.argv) > 2:
18 |     N = int(sys.argv[2])
19 | if len(sys.argv) > 3:
20 |     K = int(sys.argv[3])
21 | 
22 | print("{}-way-{}-shot Few-Shot Relation Classification".format(N, K))
23 | print("Model: {}".format(model_name))
24 | 
25 | max_length = 40
26 | train_data_loader = JSONFileDataLoader('./data/train.json', './data/glove.6B.50d.json', max_length=max_length)
27 | val_data_loader = JSONFileDataLoader('./data/val.json', './data/glove.6B.50d.json', max_length=max_length)
28 | test_data_loader = JSONFileDataLoader('./data/test.json', './data/glove.6B.50d.json', max_length=max_length)
29 | 
30 | framework = FewShotREFramework(train_data_loader, val_data_loader, test_data_loader)
31 | sentence_encoder = CNNSentenceEncoder(train_data_loader.word_vec_mat, max_length)
32 | 
33 | if model_name == 'proto':
34 |     model = Proto(sentence_encoder)
35 |     framework.train(model, model_name, 4, 20, N, K, 5)
36 | elif model_name == 'gnn':
37 |     model = GNN(sentence_encoder, N)
38 |     framework.train(model, model_name, 2, N, N, K, 1, learning_rate=1e-3, weight_decay=0, optimizer=optim.Adam)
39 | elif model_name == 'snail':
40 |     print("HINT: SNAIL works only in PyTorch 0.3.1")
41 |     model = SNAIL(sentence_encoder, N, K)
42 |     framework.train(model, model_name, 25, N, N, K, 1, learning_rate=1e-2, weight_decay=0, optimizer=optim.SGD)
43 | elif model_name == 'metanet':
44 |     model = MetaNet(N, K, train_data_loader.word_vec_mat, max_length)
45 |     framework.train(model, model_name, 1, N, N, K, 1, learning_rate=5e-3, weight_decay=0, optimizer=optim.Adam, train_iter=300000)
46 | else:
47 |     raise NotImplementedError
48 | 
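49 | 
50 | # Note on the positional arguments above: FewShotREFramework.train in
51 | # fewshot_re_kit/framework.py defines the exact signature, but from the calls
52 | # here they appear to be (batch_size, N_for_train, N_for_eval, K, Q) after
53 | # model_name. Proto, for example, seems to train on harder 20-way episodes
54 | # while still being evaluated N-way.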
--------------------------------------------------------------------------------