├── conf ├── __init__.py └── config.py ├── scripts ├── __init__.py ├── sim.py ├── utils.py ├── encoder.py ├── data_formatter.py ├── dataloader.py ├── predictor.py ├── dataprocessor.py ├── api.py ├── matcher.py ├── batch_generator.py └── trainer.py ├── predict_bert.sh ├── predict_protonet.sh ├── predict_shenyuan.sh ├── train_protonet.sh ├── train_shenyuan.sh ├── train_bert_classifier.sh ├── log └── .gitignore └── README.md /conf/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /predict_bert.sh: -------------------------------------------------------------------------------- 1 | python scripts/api.py --model bert --mode predict 2 | -------------------------------------------------------------------------------- /predict_protonet.sh: -------------------------------------------------------------------------------- 1 | python scripts/api.py --mode predict --model protonet 2 | -------------------------------------------------------------------------------- /predict_shenyuan.sh: -------------------------------------------------------------------------------- 1 | python scripts/api.py --mode predict --model shenyuan 2 | -------------------------------------------------------------------------------- /train_protonet.sh: -------------------------------------------------------------------------------- 1 | python -W ignore scripts/api.py --model protonet --mode train 2 | -------------------------------------------------------------------------------- /train_shenyuan.sh: -------------------------------------------------------------------------------- 1 | python -W ignore scripts/api.py --model shenyuan --mode train 2 | -------------------------------------------------------------------------------- /train_bert_classifier.sh: -------------------------------------------------------------------------------- 1 | python -W ignore scripts/api.py --model bert --mode train 2 | -------------------------------------------------------------------------------- /log/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /scripts/sim.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | def word_jaccard(seg1, seg2): 4 | a = list(set(seg1).intersection(set(seg2))) 5 | b = list(set(seg1).union(set(seg2))) 6 | return float(len(a) / len(b)) 7 | 8 | 9 | def char_jaccard(sen1, sen2): 10 | a = list(set(list(sen1)).intersection(set(list(sen2)))) 11 | b = list(set(list(sen1)).union(set(list(sen2)))) 12 | return float(len(a) / len(b)) 13 | 14 | if __name__ == '__main__': 15 | a = '你 是 谁啊' 16 | b = '我 是 什么人呢' 17 | print(word_jaccard(a.split(),b.split())) 18 | print(char_jaccard(a,b)) -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | from collections import defaultdict 4 | from keras.preprocessing.sequence import pad_sequences 5 | import os 6 | os.environ['KMP_WARNINGS'] = '0' 7 | def 
sample(from_array, n_sample, replace=True, p=None): 8 | x = np.asarray(from_array) 9 | lenx = len(x) 10 | ind_range = np.arange(lenx) 11 | if lenx < n_sample: 12 | if replace: 13 | sampled_inds = np.random.choice(ind_range, n_sample, replace=True, p=p) 14 | else: 15 | sampled_inds = np.random.choice(ind_range, lenx, replace=False, p=p) 16 | else: 17 | sampled_inds = np.random.choice(ind_range, n_sample, replace=False, p=p) 18 | 19 | samples = x[sampled_inds] 20 | return samples.tolist() 21 | 22 | 23 | def expand_items(d): 24 | expanded = [] 25 | items = d.items() 26 | for key, vlist in items: 27 | expanded += [(key, x) for x in vlist] 28 | return expanded 29 | 30 | def texts_to_indices(X, max_len, tokenizer): 31 | X = [tokenizer.encode(sent, add_special_tokens=True) for sent in X] 32 | X = pad_sequences(X, maxlen=max_len, padding='post', truncating='post').tolist() 33 | return X 34 | 35 | def calc_acc(preds, targets): 36 | return np.mean(np.asarray(preds) == np.asarray(targets)) 37 | 38 | 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /scripts/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | from torch.nn import * 5 | 6 | def create_emb_layer(weights_matrix, non_trainable=False): 7 | num_embeddings, embedding_dim = weights_matrix.shape 8 | emb_layer = nn.Embedding(num_embeddings, embedding_dim) 9 | # copy the pretrained weights into the embedding layer 10 | emb_layer.load_state_dict({'weight': torch.as_tensor(weights_matrix, dtype=torch.float32)}) 11 | if non_trainable: 12 | emb_layer.weight.requires_grad = False 13 | 14 | return emb_layer 15 | 16 | 17 | class BILSTM_encoder(nn.Module): 18 | def __init__(self, out_dim=300, in_dim=300, dropout=0.5, embedding_matrix=None, vocab_size=10000): 19 | super(BILSTM_encoder, self).__init__() 20 | self.out_dim = out_dim 21 | self.in_dim = in_dim 22 | self.dropout = dropout 23 | if embedding_matrix is None: 24 | # default to an all-zero embedding matrix when no pretrained embedding is given 25 | embedding_matrix = np.zeros((vocab_size, in_dim)) 26 | self.emb_layer = create_emb_layer(embedding_matrix) 27 | self.bilstm = LSTM(in_dim, out_dim, bidirectional=True, dropout=0.2, num_layers=1) 28 | 29 | def load_pretrained_embedding(self, embedding_matrix): 30 | self.emb_layer = create_emb_layer(embedding_matrix) 31 | 32 | def forward(self, x): 33 | x = self.emb_layer(x) 34 | t_h, (_, _) = self.bilstm(x) 35 | return t_h, t_h[-1] 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Introduction** 2 | *This project targets text classification tasks with insufficient training data. With a few few-shot learning tricks (ProtoNet, etc.), performance on such tasks improves and has the potential to improve further. However, ProtoNet+Bert converges much more slowly than ordinary Bert fine-tuning, and GPU memory is also a key limitation on further improvement (the number of supports at evaluation time cannot be set very large; #TODO to fix this in the future).* 3 | 4 | *Few-shot multi-class text classification model (only tested on short text); it is currently initialized from Bert, and initializing from Sentence-BERT instead gives better results.* 5 | 6 | **Classification Models** 7 | 1. ProtoNet+Bert (optimized for few-shot learning; can achieve better performance on some small datasets) 8 | 2. Ordinary Bert classification (for normal datasets; also works in few-shot settings thanks to the strength of BERT pretraining) 9 | 3. 
A Mysterious Algorithm from my colleague (optimized for matching tasks; do not train it for normal classification tasks, it is only included for experimental purposes, just for fun) 10 | 11 | 12 | **Usage:** 13 | 1. Put your data into the ./data folder. 14 | 2. Write your own script (or use one of the helper functions in data_formatter.py) 15 | to format your training/evaluation data so that each line is a sentence and its label separated by a tab. 16 | 3. Modify the configuration in conf/config.py under the config class for your chosen model. 17 | * Mandatory settings: 18 | * for the Bert classifier: set the number of classes and the max sentence length, 19 | * for ProtoNet: set "k" and "shot"; k should be between 20% and 100% of the total number of classes, and shot is commonly between 2 and 10 depending on the data size. 20 | * Optional settings: 21 | * for the Bert classifier: batch_size 22 | * for ProtoNet: n_support and eval_n_support (the number of supporting samples for each class; read the ProtoNet paper for more details), 23 | you can leave them unchanged; the bigger the better, but large values may exceed GPU memory limits, 24 | especially at evaluation time when the number of classes is large. 25 | * general settings: learning rate, warmup, paths to essential data/model files, device, etc. 26 | 4. Alternatively, if you do not want to modify the config file, or you want to train multiple models with different configs, you can use <*python scripts/api.py*> directly; 27 | every setting can be redefined on the command line, overriding what is in config.py. Type <*python scripts/api.py -h*> for more details, or see the example below. 28 | 5. Run one of the three training shell scripts as needed. 29 | 6. Predict with the other three shell scripts; don't forget to check all load paths before running.
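30 | 31 | For example, a command like the following trains ProtoNet+Bert with settings overridden on the command line (the paths and hyperparameter values here are illustrative only; <*python scripts/api.py -h*> lists every available flag): 32 | 33 | ```bash 34 | # train ProtoNet+Bert, overriding config.py from the command line (illustrative values) 35 | python -W ignore scripts/api.py --model protonet --mode train \ 36 |     --train_path data/ptrain.txt --test_path data/ptest.txt \ 37 |     --k 10 --shot 5 --epoch 500 38 | 39 | # predict interactively with the trained model 40 | python scripts/api.py --model protonet --mode predict 41 | ```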
42 | 43 | **Requirements:** 44 | pytorch, transformers, pytorch_pretrained_bert, keras, sklearn, etc. 45 | 46 | **Note:** 47 | Recommended hyperparameters are left as they are in conf/config.py, except those that are task-specific. All experiments use bert-base-chinese; other languages have not been tested, but you can always try them (remember to change bert_type in the config). 48 | 49 | **TODO:** 50 | 1. support an unlimited number of supports at evaluation/prediction time 51 | 2. support Meta-Learning 52 | 3. replace Euclidean distance with RE2 and BCELoss 53 | 54 | 55 | **Results:** 56 | 57 | | | ProtoNet+Bert | Bert | Training size | Test size | Balanced | Class Count | 58 | | ------ | ------ |------ |------ |------ |------ |------ | 59 | | Intent Classification (downsampled to 1%) | **88.3%** | 87.4% | ≈60*15 | 1333 | True | 15 | 60 | | Intent Classification (downsampled to 10%) | **91.9%** | 91.7% | ≈600*15 | 1333 | True | 15 | 61 | | Intent Classification | >93.7% (too slow to train) | **94.6%** (?) | ≈6000*15 | 1333 | True | 15 | 62 | | Anonymous Dataset 1 | **87.8%** | 87.2% | 3200 | 352 | False | 86 | 63 | | Anonymous Dataset 2 | **84.9%** | 84.3% | 1300 | 434 | False | 20 | 64 | | Anonymous Dataset 3 | **88.1%** | 83.9% | 5000 | 320 | True | 68 | 65 | 66 | -------------------------------------------------------------------------------- /scripts/data_formatter.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import train_test_split 2 | 3 | 4 | def swap_sent_label(path_train, path_test): 5 | with open(path_test, 'r', encoding='utf-8') as f: 6 | lines = f.readlines() 7 | lines = [x.strip().split('\t') for x in lines] 8 | lines = ['\t'.join([x[1],x[0]])+'\n' for x in lines] 9 | with open(path_test,'w', encoding='utf-8') as f: 10 | f.writelines(lines) 11 | with open(path_train, 'r', encoding='utf-8') as f: 12 | lines = f.readlines() 13 | lines = [x.strip().split('\t') for x in lines] 14 | lines = ['\t'.join([x[1],x[0]])+'\n' for x in lines] 15 | with open(path_train,'w', encoding='utf-8') as f: 16 | f.writelines(lines) 17 | 18 | def matching_to_classification(path_train, path_test): 19 | with open(path_test, 'r', encoding='utf-8') as f: 20 | lines = f.readlines() 21 | lines = [x.strip().split('\t') for x in lines] 22 | lines = ['\t'.join(x[:2])+'\n' for x in lines if x[-1]=="Y"] 23 | with open(path_test+'.formatted','w', encoding='utf-8') as f: 24 | f.writelines(lines) 25 | with open(path_train, 'r', encoding='utf-8') as f: 26 | lines = f.readlines() 27 | lines = [x.strip().split('\t') for x in lines] 28 | lines = ['\t'.join(x[:2])+'\n' for x in lines if x[2]=="Y"] 29 | with open(path_train+'.formatted','w', encoding='utf-8') as f: 30 | f.writelines(lines) 31 | 32 | def matching_to_classification_with_other(path_train, path_test): 33 | with open(path_test, 'r', encoding='utf-8') as f: 34 | lines = f.readlines() 35 | lines = [x.strip().split('\t') for x in lines] 36 | lines = ['\t'.join(x[:2])+'\n' if x[2]=='Y' else '\t'.join([x[0],'其他'])+'\n' for x in lines] 37 | with open(path_test+'.o','w', encoding='utf-8') as f: 38 | f.writelines(lines) 39 | with open(path_train, 'r', encoding='utf-8') as f: 40 | lines = f.readlines() 41 | lines = [x.strip().split('\t') for x in lines] 42 | lines = ['\t'.join(x[:2])+'\n' if x[2]=='Y' else '\t'.join([x[0],'其他'])+'\n' for x in lines] 43 | with open(path_train+'.o','w', encoding='utf-8') as f: 44 | f.writelines(lines) 45 | 46 | def split_corpus(fn): 47 | with open(fn, 'r', encoding='utf-8') as f: 48 | lines = f.readlines() 49 | lines = [line.strip().split('\t') for line in lines] 50 | X = [line[0] for line in lines] 51 | Y = [line[1] for line in lines] 52 | trainX, testX, trainY, testY = train_test_split(X, Y, stratify=Y, test_size=0.1) 53 | train_lines = ['\t'.join([x[0],x[1]])+'\n' for x in zip(trainX,trainY)] 54 | test_lines = ['\t'.join([x[0],x[1]])+'\n' for x in zip(testX, testY)] 55 | with open('data/ptrain.txt', 'w', encoding='utf-8') as f: 56 | f.writelines(train_lines) 57 | with open('data/ptest.txt', 'w', encoding='utf-8') as f: 58 | f.writelines(test_lines) 59 | 60 | def down_sampling(fn, sampling_rate=0.01): 61 | with open(fn, 'r', encoding='utf-8') as f: 62 | lines = f.readlines() 63 | lines = [line.strip().split('\t') for line in lines] 64 | X = [line[0] for line in lines] 65 | Y = [line[1] for line in lines] 66 | trainX, testX, trainY, testY = train_test_split(X, Y, stratify=Y, test_size=sampling_rate) 67 | from collections import defaultdict 68 | data_dict = 
defaultdict(list) 69 | for x, y in zip(testX, testY): 70 | data_dict[y].append(x) 71 | data_dict = {x:len(y) for x, y in data_dict.items()} 72 | print(data_dict) 73 | print(min(data_dict.values())) 74 | test_lines = ['\t'.join([x[0],x[1]])+'\n' for x in zip(testX,testY)] 75 | with open(fn+'.small', 'w', encoding='utf-8') as f: 76 | f.writelines(test_lines) 77 | 78 | if __name__ == '__main__': 79 | # swap_sent_label('data/train_intent.txt','data/test_intent.txt') 80 | # matching_to_classification_with_other('data/train.tsv','data/test.tsv') 81 | # split_corpus('data/new_perfor1020') 82 | down_sampling('../data/train_intent.txt') 83 | -------------------------------------------------------------------------------- /scripts/dataloader.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import train_test_split 2 | from transformers import BertTokenizer 3 | from torch.utils.data import DataLoader, TensorDataset 4 | from dataprocessor import * 5 | from batch_generator import * 6 | from conf.config import * 7 | import random 8 | class FewShotDataLoader: 9 | def __init__(self, config): 10 | self.config = config 11 | if hasattr(self.config,'bert_initialization_path'): 12 | print('loading specified tokenizer from ', self.config.bert_initialization_path, '...') 13 | self.tokenizer = BertTokenizer.from_pretrained(self.config.bert_initialization_path) 14 | print('tokenizer loaded') 15 | else: 16 | print('loading tokenzier from transformers') 17 | print('if this hangs too long, pls restart program') 18 | self.tokenizer = BertTokenizer.from_pretrained(self.config.bert_type) 19 | print('tokenizer loaded') 20 | self.dp = DataPreProcessor(if_del_serial=True) 21 | 22 | def transform_labels(self, Y): 23 | labels = list(set(Y)) 24 | labels = list(sorted(labels)) 25 | label_map = OrderedDict({y:i for i,y in enumerate(labels)}) 26 | self.class_map = label_map 27 | newY = [label_map[y] for y in Y] 28 | return newY 29 | 30 | def get_dataloader(self, lines, trainer='protonet'): 31 | config = self.config 32 | rawX = [self.dp.proc_sent(x[0]) for x in lines] 33 | 34 | max_len = np.max([len(s) for s in rawX]) 35 | 36 | print("max sentence length: ", max_len) 37 | labels = [x[1] for x in lines] 38 | 39 | Y = self.transform_labels(labels) 40 | print(self.class_map, len(self.class_map)) 41 | X, testX, Y, testY = train_test_split(rawX, Y, test_size=self.test_num, shuffle=False) 42 | 43 | if trainer == 'protonet': 44 | config = ProtoNetConfig() 45 | dataloader = ProtoNetBatchGenerator(X=X, Y=Y, tokenizer=self.tokenizer,class_map=self.class_map, config=config) 46 | eval_dataloader = ProtoNetEvalBatchGenerator(X=X, Y=Y,class_map=self.class_map, tokenizer=self.tokenizer, testX=testX, testY=testY, config=config) 47 | 48 | # elif trainer == 'matching': 49 | # config = 50 | # dataloader = RandomBatchGenerator(X=X, Y=Y, tokenizer=self.tokenizer, n_supports=self.config.n_support, k=self.config.k, n_batch=self.config.nstep, n_pos=self.config.batch_pos, n_neg=self.config.batch_neg) 51 | # eval_dataloader = ProtoNetEvalBatchGenerator(X=X, Y=Y,class_map=self.class_map, tokenizer=self.tokenizer, testX=testX, testY=testY, n_supports=self.config.eval_n_support, batch_size=self.config.eval_batch_size) 52 | elif trainer == 'bert': 53 | trainX = texts_to_indices(X, config.max_sent_len, self.tokenizer) 54 | train_dataset = TensorDataset(T.LongTensor(trainX), T.LongTensor(Y)) 55 | dataloader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True) 56 | testX = 
texts_to_indices(testX, config.max_sent_len, self.tokenizer) 57 | test_dataset = TensorDataset(T.LongTensor(testX), T.LongTensor(testY)) 58 | eval_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False) 59 | elif trainer == 'shenyuan': 60 | dataloader = ShenyuanBatchGenerator(X=X, Y=Y, tokenizer=self.tokenizer, class_map=self.class_map, config=config) 61 | eval_dataloader = ShenyuanEvalBatchGenerator(X=X, Y=Y, testX=testX, tokenizer=self.tokenizer, testY=testY, class_map=self.class_map, config=config) 62 | else: 63 | raise NotImplementedError 64 | return dataloader, eval_dataloader 65 | 66 | 67 | 68 | def load_data(self): 69 | with open(self.config.train_path, encoding='utf-8') as f: 70 | train_lines = f.readlines()[:] 71 | random.shuffle(train_lines) 72 | with open(self.config.test_path, encoding='utf-8') as f: 73 | test_lines = f.readlines()[:] 74 | self.test_num = len(test_lines) 75 | lines = train_lines + test_lines 76 | lines = [x.strip().split('\t') for x in lines] 77 | 78 | return lines 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | if __name__ == '__main__': 88 | dl = DataLoader() 89 | dl.load_data() 90 | 91 | -------------------------------------------------------------------------------- /scripts/predictor.py: -------------------------------------------------------------------------------- 1 | import torch as T 2 | from transformers import * 3 | from utils import * 4 | from dataprocessor import DataPreProcessor 5 | import pickle 6 | import json 7 | from torch import save, load 8 | 9 | class Predictor: 10 | def __init__(self, config): 11 | self.config = config 12 | self.dp = DataPreProcessor(if_del_serial=True) 13 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') 14 | 15 | def load_model_components(self): 16 | encoder_path = self.config.encoder_load_path 17 | self.encoder = BertModel.from_pretrained(encoder_path) 18 | self.encoder.eval() 19 | class_map_path = self.config.class_map_load_path 20 | with open(class_map_path, 'r', encoding='utf-8') as f: 21 | self.class_map = json.load(f) 22 | 23 | 24 | def _preprocess(self, sent): 25 | sent = self.dp.proc_sent(sent) 26 | sent = self.tokenizer.encode(sent) 27 | sent = pad_sequences([sent], maxlen=self.config.max_sent_len, padding='post', truncating='post').tolist() 28 | return sent 29 | 30 | def _preprocess_pair(self, a, b): 31 | senta = self.dp.proc_sent(a) 32 | sentb = self.dp.proc_sent(b) 33 | senta = senta[:min([self.config.max_sent_len, len(senta)])] 34 | sentb = sentb[:min([self.config.max_sent_len, len(sentb)])] 35 | sent = self.tokenizer.encode(senta, text_pair=sentb, add_special_tokens=True) 36 | return sent 37 | 38 | def predict(self, sent): 39 | pass 40 | 41 | class ShenyuanPredictor(Predictor): 42 | def load_model_components(self): 43 | super(ShenyuanPredictor, self).load_model_components() 44 | corpus_path = self.config.corpus_path 45 | with open(corpus_path, 'r', encoding='utf-8') as f: 46 | self.data_dict = json.load(f) 47 | self.classes = list(self.data_dict.keys()) 48 | self.encoder.to(self.config.device) 49 | self.matcher = load(self.config.matcher_load_path) 50 | self.matcher.to(self.config.device) 51 | self.matcher.eval() 52 | 53 | def predict(self,sent): 54 | topns = self._sample_topN_supports(self.config.eval_n_support, sent) 55 | paired_topns = [] 56 | for i, topn_cls in enumerate(topns): 57 | topn_cls = [self._preprocess_pair(sent,self.class_map[str(i)]+a) for a in topn_cls] 58 | topn_cls = pad_sequences(topn_cls, maxlen=self.config.max_sent_len, padding='post', 
truncating='post').tolist() 59 | paired_topns.append(topn_cls) 60 | topn = T.LongTensor(paired_topns).to(self.config.device) 61 | topn = topn.view(len(self.class_map)*self.config.eval_n_support, self.config.max_sent_len) 62 | topn = self.encoder(topn)[1] 63 | logits = self.matcher(topn).squeeze() 64 | logits = logits.view(len(self.class_map), self.config.eval_n_support) 65 | logits = T.mean(logits, dim=-1).squeeze() 66 | pred = T.argmax(logits, dim=-1).squeeze().item() 67 | max_score = T.max(logits,dim=-1)[0].squeeze().item() 68 | if max_score < .5: 69 | print("OTHERS") 70 | else: 71 | print(self.class_map[str(pred)]) 72 | 73 | 74 | 75 | 76 | def _sample_topN_supports(self, n, x): 77 | topn_all_class = [] 78 | for cls in self.classes: 79 | candidates = self.data_dict[cls] 80 | if n > len(candidates): 81 | sims = [(y, char_jaccard(x, y)) for y in candidates] 82 | topn = list(sorted(sims, key=lambda x: x[1]))[-n:] 83 | topn = [x[0] for x in topn] 84 | else: 85 | topn = sample(candidates, n, replace=True) 86 | topn_all_class.append(topn) 87 | return topn_all_class 88 | 89 | class ProtoNetPredictor(Predictor): 90 | 91 | def load_model_components(self): 92 | super(ProtoNetPredictor, self).load_model_components() 93 | center_path = self.config.center_load_path 94 | with open(center_path, 'rb') as f: 95 | self.centers = pickle.load(f) 96 | self.centers = T.FloatTensor(self.centers) 97 | 98 | def predict(self, sent): 99 | sent = self._preprocess(sent) 100 | sent = T.LongTensor(sent) 101 | sent = self.encoder(sent)[1] 102 | dists = T.cdist(sent,self.centers).squeeze() 103 | pred = dists.argmin(dim=-1).item() 104 | raw_pred = self.class_map[str(pred)] 105 | print(raw_pred) 106 | 107 | class BertPredictor(Predictor): 108 | def load_model_components(self): 109 | super(BertPredictor,self).load_model_components() 110 | matcher_path = self.config.matcher_path 111 | self.matcher = load(matcher_path) 112 | self.matcher.to(self.config.cpu) 113 | self.matcher.eval() 114 | 115 | def predict(self, sent): 116 | sent = self._preprocess(sent) 117 | sent = T.LongTensor(sent) 118 | sent = self.encoder(sent)[1] 119 | pred = T.argmax(self.matcher(sent),dim=-1).item() 120 | raw_pred = self.class_map[str(pred)] 121 | print(raw_pred) 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /scripts/dataprocessor.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | 5 | class DataPreProcessor(object): 6 | """Runs end-to-end pre-processor.""" 7 | 8 | def __init__(self, if_del_serial=False, clean_ratio=0.6, keep_punc=True): 9 | self.punc = "!?。%()+,-:∶;≤<=>≥@[]{|}~、》「」『』【】〔〕〖〗《》〝〞–—‘’‛“”„‟‧﹏①②③④⑤⑥⑦⑧⑨⑩" 10 | self.punc += string.punctuation 11 | self.if_del_serial = if_del_serial 12 | self.clean_ratio = clean_ratio 13 | self.keep_punc = keep_punc 14 | 15 | def clear_rare_char(self, input_char): 16 | if u'\u4e00' <= input_char <= u'\u9fa5' \ 17 | or input_char in self.punc \ 18 | or u'\u0030' <= input_char <= u'\u0039' \ 19 | or u'\u0041' <= input_char <= u'\u005A' \ 20 | or u'\u0061' <= input_char <= u'\u007A': 21 | return input_char 22 | return '' 23 | 24 | @staticmethod 25 | def strQ2B(ustring): 26 | """把字符串全角转半角""" 27 | ss = [] 28 | for s in ustring: 29 | rstring = "" 30 | for uchar in s: 31 | inside_code = ord(uchar) 32 | if inside_code == 12288: # 全角空格直接转换 33 | inside_code = 32 34 | elif (inside_code >= 65281 and inside_code <= 65374): # 全角字符(除空格)根据关系转化 35 | inside_code -= 
65248 36 | rstring += chr(inside_code) 37 | ss.append(rstring) 38 | return ''.join(ss) 39 | 40 | @staticmethod 41 | def del_serial(in_str): 42 | re_str = re.sub('^\d+(\.\d+)+', '', in_str) 43 | re_str = re.sub( 44 | '^[第((]*[a-zA-Z①②③④⑤⑥⑦⑧⑨⑩一二三四五六七八九十\d]+[、.::))。条节章点](?!\d+)', '', re_str) 45 | match_flag = re.match('^\d+(个|年|月|日|岁|天|小时|点|张|寸|分|:\d+|:\d+)', re_str) 46 | if match_flag == None: 47 | re_str = re.sub('^\d+', '', re_str) 48 | re_str = re.sub('^[第\((]*[①②③④⑤⑥⑦⑧⑨⑩][、.::\))。条节章点]*', '', re_str) 49 | # re_str = re.sub('\d\.\d\.\d', '', re_str) 50 | return re_str 51 | 52 | @staticmethod 53 | def format_sent(sent): 54 | sent = sent.replace(' ', '') 55 | sent = sent.replace(' ', '') 56 | sent = sent.replace(' ', '') 57 | sent = sent.replace('\t', '') 58 | sent = sent.replace('工作描述:', '') 59 | sent = sent.replace('工作经历:', '') 60 | sent = sent.replace('工作范围:', '') 61 | sent = sent.replace('工作职责:', '') 62 | sent = sent.replace('工作业绩:', '') 63 | sent = sent.replace('主要业绩:', '') 64 | sent = sent.replace('主要职责:', '') 65 | sent = sent.replace('主要工作:', '') 66 | sent = sent.rstrip('\n\r') 67 | sent = sent.lstrip('+,。;:’‘“”:】、·))??>,.-—_=+!~`@#$%^&*') 68 | return sent 69 | 70 | @staticmethod 71 | def split_sent(sents, keep_punc=True): 72 | """split sentences(output has no ending punc)""" 73 | sents = sents.replace('[SEP]', '。') 74 | sents = re.split('(。|;|\?|?|\\n|\\\\n)', sents) 75 | if keep_punc: 76 | sents.append('') 77 | sents = ["".join(i) for i in zip(sents[0::2], sents[1::2])] 78 | sents = list(filter(lambda x: len(x) > 1, sents)) 79 | return sents 80 | 81 | @classmethod 82 | def split_doc_file(cls, doc_path, del_empty_line=False): 83 | split_txt = [] 84 | for line in open(doc_path, encoding='utf8'): 85 | # split_line = cls.split_sent(line.strip()) 86 | # split_txt += split_line 87 | split_txt.append(line.strip()) 88 | if del_empty_line: 89 | split_txt = list(filter(lambda x: len(x) > 1, split_txt)) 90 | return split_txt 91 | 92 | @classmethod 93 | def split_doc_list(cls, doc_list, del_empty_line=False): 94 | split_txt = [] 95 | for line in doc_list: 96 | split_line = cls.split_sent(line.strip()) 97 | split_txt += split_line 98 | if del_empty_line: 99 | split_txt = list(filter(lambda x: len(x) > 0, split_txt)) 100 | return split_txt 101 | 102 | def clean_sent(self, sent): 103 | cnt = 0 104 | clean_sent = '' 105 | for char in sent: 106 | if char in self.punc: 107 | cnt += 1 108 | clean_sent += self.clear_rare_char(char) 109 | if len(clean_sent) == 0: 110 | return '' 111 | if float(cnt) / float(len(clean_sent)) > self.clean_ratio: 112 | return '' 113 | return clean_sent 114 | 115 | def proc_sent(self, sent): 116 | """process single sentence""" 117 | sent = self.strQ2B(sent) 118 | sent = self.clean_sent(sent) 119 | sent = self.format_sent(sent) 120 | if len(sent) > 0: 121 | if self.if_del_serial: 122 | sent = self.format_sent(self.del_serial(sent)) 123 | return sent 124 | 125 | def proc_sents(self, sents): 126 | """split and process sentence[s]""" 127 | sents = self.split_sent(sents, self.keep_punc) 128 | sents = [self.proc_sent(sent) for sent in sents] 129 | sents = list(filter(lambda x: len(x) > 0, sents)) 130 | return sents 131 | 132 | 133 | if __name__ == '__main__': 134 | # =====测试====== 135 | t1 = "● 2017.7-2017.9,轮岗管理医疗部门(12人),打通医疗部门与业务部门的合作" 136 | t2 = "2:轮岗至百度直通车团队,管理百度直通车团队,对手机百度,百度地图,百度糯米产品的新客户开发以及老客户维护;" 137 | t3 = "工作描述:销售部门轮岗(1)完成每月公司的分销目标及其他KPI的达成" 138 | t4 = "*任职期间曾在公司资产管理部轮岗,熟悉如何利用资管工具为客户提供更为综合的金融服务" 139 | # print(DataPreProcessor.split_sent(test_str)) 140 | 141 | proc = 
DataPreProcessor(if_del_serial=True) # if_del_serial 是否删除开头序号 142 | print(proc.proc_sent(t3)) 143 | # print(proc.proc_sents(test_str)) 144 | -------------------------------------------------------------------------------- /conf/config.py: -------------------------------------------------------------------------------- 1 | import torch as T 2 | from string import punctuation 3 | 4 | class BaseConfig: 5 | def __init__(self): 6 | '''data path''' 7 | self.test_path = 'data/test.tsv.formatted' 8 | self.train_path = 'data/train.tsv.formatted' 9 | self.test_path = 'data/test_intent.txt' 10 | self.train_path = 'data/train_intent.txt.small' 11 | self.test_path = 'data/ptest.txt' 12 | self.train_path = 'data/ptrain.txt' 13 | ''' log file dir''' 14 | self.log_dir = 'log/' 15 | 16 | '''maximum sentence length''' 17 | self.max_sent_len = 30 18 | ''' learning rate''' 19 | self.lr = 5e-5 20 | '''encoder output dimension''' 21 | self.hidden_dim = 768 22 | ''' cuda device number ''' 23 | self.device = T.device("cuda:2" if T.cuda.is_available() else "cpu") 24 | ''' just in case, no gpu''' 25 | self.cpu = T.device("cpu") 26 | '''some punctuations''' 27 | self.puncs = '【】{}()!-=——+@#¥%……&*()——+、。,;‘“:、|、·~《》,。、?' + punctuation 28 | '''dropout rate for classifier''' 29 | self.dropout = 0.3 30 | '''path to save index-to-classname map (for inference)''' 31 | self.class_map_save_path = 'model/bert/class_map.json' 32 | self.class_map_load_path = 'model/bert/class_map.json' 33 | 34 | '''learning rate decay''' 35 | #self.lr_decay = 0.98 36 | '''learning rate decay interval''' 37 | #self.lr_decay_step_size = 5 38 | '''if you need to use other pretrained encoder in BERT file format, pls specify dir path, otherwise default bert-base-chinese will be used''' 39 | # self.encoder_initialization_path = './model/ernie/' 40 | '''warm up ratio''' 41 | self.warmup = 0.05 42 | '''evaluation interval''' 43 | self.eval_interval = 3 44 | '''pretrained bert from 'transformers' package ''' 45 | self.bert_type = 'bert-base-chinese' 46 | 47 | class BertClassificationConfig(BaseConfig): 48 | def __init__(self): 49 | super(BertClassificationConfig, self).__init__() 50 | '''this should be a directory if using bert, ernie etc.''' 51 | self.encoder_save_path = './model/bert/' 52 | self.matcher_save_path = './model/bert/matcher.model' 53 | self.encoder_load_path = './model/bert/' 54 | self.matcher_load_path = './model/bert/matcher.model' 55 | ''' number of epoch ''' 56 | self.epoch = 50 57 | '''number of classes''' 58 | self.n_classes = 86 59 | '''training batch size''' 60 | self.batch_size = 64 61 | '''total number of training step''' 62 | with open(self.train_path,'r') as f: 63 | nstep = len(f.readlines())//self.batch_size 64 | self.t_total = nstep * self.epoch 65 | print("TOTAL TRAINING STEP: ", self.t_total) 66 | 67 | 68 | class ProtoNetConfig(BaseConfig): 69 | def __init__(self): 70 | super(ProtoNetConfig, self).__init__() 71 | '''encoder path''' 72 | self.encoder_save_path = './model/protonet/' 73 | self.encoder_load_path = './model/protonet/' 74 | self.center_load_path = './model/protonet/centers.pkl' 75 | self.center_save_path = './model/protonet/centers.pkl' 76 | ''' number of epoch ''' 77 | self.epoch = 1000 78 | '''evaluation batch size''' 79 | self.eval_batch_size = 2 80 | '''number of supports to compute center of each class''' 81 | self.n_support = 5 82 | '''number of supports to compute center at eval time''' 83 | self.eval_n_support = 5 84 | '''number of training steps per epoch''' 85 | self.n_batch = 1 86 | ''' number of 
classes and number of sampled instances for training (k-way-n-shot)''' 87 | self.k = 30 88 | ''' number of samples per step (denoted as N), note that negative samples will automatically be computed, 89 | leaving N positive samples and (k-1)*N negative samples''' 90 | self.shot = 2 91 | '''total number of training step''' 92 | self.t_total = self.n_batch * self.epoch 93 | print("TOTAL TRAINING STEP: ", self.t_total) 94 | '''distance epsilon''' 95 | self.dist_epsilon = 1e-1 96 | '''path to save centers corresponding to each class (as list)''' 97 | # self.lr_decay = 0.98 98 | # self.lr_decay_step_size = 5 99 | 100 | class ShenyuanConfig(BaseConfig): 101 | def __init__(self): 102 | super(ShenyuanConfig, self).__init__() 103 | self.encoder_save_path = './model/shenyuan/' 104 | self.matcher_save_path = './model/shenyuan/matcher.model' 105 | self.encoder_load_path = './model/shenyuan/' 106 | self.matcher_load_path = './model/shenyuan/matcher.model' 107 | ''' number of epoch ''' 108 | self.epoch = 100 109 | '''ratio of neg:pos training data sampling''' 110 | self.neg_ratio = 1 111 | '''training batch size''' 112 | self.batch_size = 64 113 | '''evaluation batch size''' 114 | self.eval_batch_size = 4 115 | '''number of supports to compute center of each class''' 116 | self.eval_n_support = 3 117 | '''number of training steps per epoch''' 118 | self.nstep = 2 119 | '''total number of training step''' 120 | self.t_total = self.batch_size*self.nstep 121 | '''path to save reference sentences for each class''' 122 | self.corpus_path = './model/shenyuan/data_dict.json' 123 | '''if you need to use other pretrained encoder in BERT file format, pls specify dir path, otherwise default bert-base-chinese will be used''' 124 | # self.encoder_initialization_path = './model/ernie/' 125 | -------------------------------------------------------------------------------- /scripts/api.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | sys.path.append('../') 4 | from transformers import BertModel 5 | import argparse 6 | from trainer import * 7 | from conf.config import * 8 | from dataloader import FewShotDataLoader 9 | from matcher import * 10 | import pickle 11 | from predictor import * 12 | import numpy as np 13 | from datetime import datetime 14 | import os 15 | os.environ['KMP_WARNINGS'] = '0' 16 | #np.random.seed(1) 17 | #torch.manual_seed(1) 18 | #torch.cuda.manual_seed_all(1) 19 | #torch.backends.cudnn.deterministic = True 20 | class API: 21 | def __init__(self, args): 22 | self.parse_args(args) 23 | 24 | def parse_args(self, args): 25 | self.model = args.model 26 | self.mode = args.mode 27 | if self.model == 'protonet': 28 | self.config = ProtoNetConfig() 29 | self.override_config(args, self.config) 30 | dl = FewShotDataLoader(self.config) 31 | encoder = self.get_encoder(self.config) 32 | trainer = ProtoNetTrainer(encoder, None, dl, self.config, EuclideanLoss(self.config.dist_epsilon)) 33 | predictor = ProtoNetPredictor(self.config) 34 | 35 | elif self.model == 'bert': 36 | self.config = BertClassificationConfig() 37 | self.override_config(args, self.config) 38 | encoder = self.get_encoder(self.config) 39 | dl = FewShotDataLoader(self.config) 40 | trainer = BertClassificationTrainer(encoder, SimpleClassifier(self.config.n_classes, dropout=self.config.dropout), dl, self.config, CrossEntropyLoss()) 41 | predictor = BertPredictor(self.config) 42 | 43 | elif self.model == 'shenyuan': 44 | self.config = ShenyuanConfig() 45 | 
self.override_config(args, self.config) 46 | dl = FewShotDataLoader(self.config) 47 | encoder = self.get_encoder(self.config) 48 | trainer = ShenyuanTrainer(encoder, SimpleRegressor(), dl, self.config, BCELoss()) 49 | predictor = ShenyuanPredictor(self.config) 50 | else: 51 | raise NotImplementedError 52 | self.trainer = trainer 53 | self.predictor = predictor 54 | 55 | def override_config(self, args, config): 56 | #parsed_args = args.parse_args() 57 | for arg in vars(args): 58 | arg_val = getattr(args, arg) 59 | if arg_val: 60 | if arg == 'device': 61 | arg_val = T.device("cuda:{}".format(arg_val) if T.cuda.is_available() else "cpu") 62 | 63 | setattr(config, arg, arg_val) 64 | if args.model == 'protonet': 65 | config.t_total = config.epoch 66 | elif args.model == 'bert': 67 | with open(config.train_path,'r') as f: 68 | nstep = len(f.readlines())//config.batch_size 69 | config.t_total = nstep * config.epoch 70 | else: 71 | config.t_total = config.nstep * config.epoch 72 | 73 | def get_encoder(self,config): 74 | if hasattr(config, "encoder_initialization_path"): 75 | fn = config.encoder_initialization_path 76 | print("loading encoder from ", fn, '...') 77 | encoder = BertModel.from_pretrained(config.encoder_initialization_path) 78 | print('encoder loaded.') 79 | else: 80 | t = config.bert_type 81 | print("loading pretrained encoder of type ", t, ' from transformers ...') 82 | print("if this hangs too long, pls restart the program") 83 | encoder = BertModel.from_pretrained(config.bert_type) 84 | print('encoder loaded.') 85 | return encoder 86 | 87 | def train(self, load=False): 88 | if load: 89 | self.trainer.load_model(self.config.encoder_load_path, self.config.matcher_load_path) 90 | self.trainer.load_data() 91 | best_acc, best_report, best_epoch = self.trainer.train() 92 | print(best_report) 93 | s = "Model{}, achieved best accuracy {} at epoch {}".format(self.model, best_acc, best_epoch) 94 | print(s) 95 | print("TRAINING FINISHED : ") 96 | log_fn = self.config.log_dir+"{}_{}.log".format(datetime.now().date(), self.model) 97 | with open(log_fn, 'a', encoding='utf-8') as f: 98 | s = '\n'.join([best_report, s])+'\n\n' 99 | f.write(s) 100 | print('log written to {}'.format(log_fn)) 101 | 102 | def evaluate(self, load=False): 103 | if load: 104 | self.trainer.load_model(self.config.encoder_load_path, self.config.matcher_load_path) 105 | self.trainer.load_data() 106 | self.trainer.evaluate() 107 | 108 | def predict(self): 109 | self.predictor.load_model_components() 110 | while 1: 111 | # try: 112 | sent = input("pls input sentence: ") 113 | self.predictor.predict(sent) 114 | # except: 115 | # print("sth goes wrong, try again") 116 | 117 | def exec(self): 118 | if self.mode == 'predict': 119 | self.predict() 120 | else: 121 | self.train() 122 | 123 | 124 | if __name__ == '__main__': 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument("--model", default='protonet', help='train with protonet or ordinary bert finetuning {protonet, bert, matching}', type=str) 127 | parser.add_argument("--mode", default='train', help="choose from {train, predict} mode", type=str) 128 | parser.add_argument("--epoch", help='num of epoch', type=int) 129 | parser.add_argument('--lr', help='learning rate', type=float) 130 | parser.add_argument('--warmup', help='portion of steps to warm up', type=float) 131 | parser.add_argument('--batch_size', help='batch size', type=int) 132 | parser.add_argument('--train_path', help='training data file path', type=str) 133 | parser.add_argument('--test_path', 
help='testing data file path', type=str) 134 | parser.add_argument('--max_sent_len', help='max sentence length to pad', type=int) 135 | parser.add_argument('--encoder_load_path', help='path to load pretrained/continue-training encoder', type=str) 136 | parser.add_argument('--encoder_save_path', help='path to save encoder', type=str) 137 | parser.add_argument('--matcher_load_path', help='path to load pretrained/continue-training matcher', type=str) 138 | parser.add_argument('--matcher_save_path', help='path to save matcher', type=str) 139 | parser.add_argument('--shot', help='number of samples per class per epoch for protonet (K-way-N-shot tasks)', type=int) 140 | parser.add_argument('--k', help='number of classes per epoch for protonet (K-way-N-shot tasks)', type=int) 141 | parser.add_argument('--n_classes', help='total number of classes', type=int) 142 | parser.add_argument('--dist_epsilon', help='distance epsilon to be added to euclidean distance', type=float) 143 | parser.add_argument('--device', help='cuda device index or cpu to be used', type=int) 144 | parser.add_argument('--dropout', help='dropout ratio (not keep ratio)', type=float) 145 | parser.add_argument('--class_map_save_path', help='path to save class mapping json', type=str) 146 | parser.add_argument('--class_map_load_path', help='path to load class mapping json', type=str) 147 | parser.add_argument('--n_support', help='number of supporting samples per class per epoch', type=int) 148 | parser.add_argument('--eval_n_support', help='number of supporting samples per class per epoch', type=int) 149 | parser.add_argument('--eval_batch_size', help='eval batch_size, note this should be small for protonet if \ 150 | total number of classes or eval n_support is big') 151 | parser.add_argument('--center_save_path', help='path to save centers', type=str) 152 | parser.add_argument('--center_load_path', help='path to load centers', type=str) 153 | 154 | args = parser.parse_args() 155 | api = API(args) 156 | api.exec() 157 | 158 | 159 | -------------------------------------------------------------------------------- /scripts/matcher.py: -------------------------------------------------------------------------------- 1 | import torch as T 2 | from torch.nn import * 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import math 6 | from pytorch_pretrained_bert.modeling import BertLayerNorm 7 | 8 | 9 | # 10 | # def init_ff_layer(layer, f1=None): 11 | # weight_size = layer.weight.data.size()[0] 12 | # if not f1: 13 | # f1 = 1 / math.sqrt(weight_size) 14 | # nn.init.uniform_(layer.weight.data, -f1, f1) 15 | # nn.init.uniform_(layer.bias.data, -f1, f1) 16 | 17 | def init_bert_weights(module): 18 | """ Initialize the weights. 
19 | """ 20 | if isinstance(module, (nn.Linear, nn.Embedding)): 21 | # Slightly different from the TF version which uses truncated_normal for initialization 22 | # cf https://github.com/pytorch/pytorch/pull/5617 23 | module.weight.data.normal_(mean=0.0, std=0.02) 24 | elif isinstance(module, BertLayerNorm): 25 | module.bias.data.zero_() 26 | module.weight.data.fill_(1.0) 27 | if isinstance(module, nn.Linear) and module.bias is not None: 28 | module.bias.data.zero_() 29 | 30 | class EuclideanMatcher(Module): 31 | def __init__(self): 32 | super(EuclideanMatcher, self).__init__() 33 | self.dist = T.cdist 34 | 35 | def forward(self, a, b): 36 | return self.dist(a, b) 37 | 38 | class CosMatcher(Module): 39 | def __init__(self): 40 | super(CosMatcher, self).__init__() 41 | self.cos = CosineSimilarity(dim=1) 42 | 43 | def forward(self, a, b): 44 | return (self.cos(a,b)+1)/2 45 | 46 | class EuclideanLoss(Module): 47 | def __init__(self, eps=1e-6): 48 | super(EuclideanLoss, self).__init__() 49 | self.eps = eps 50 | 51 | def forward(self, queries, centers, ind): 52 | 53 | dists = T.cdist(queries, centers)+self.eps 54 | lossp = dists[:, ind].sum(0) 55 | tmp1 = dists[:, :ind] 56 | tmp2 = dists[:, ind + 1:] 57 | tmp = T.cat([tmp1, tmp2], dim=1) 58 | exp_sum = T.exp(-tmp).sum(1) 59 | lossn = T.log(exp_sum).sum(0) 60 | loss = (lossp + lossn) / queries.size(0) 61 | return loss 62 | 63 | class DotMatcher(Module): 64 | def forward(self, a, b): 65 | return F.dropout(a.view(a.size(0),1,-1).bmm(b.view(b.size(0),-1,1)), p=0.2) 66 | 67 | class SimpleClassifier(Module): 68 | def __init__(self, n_class, dim=768, inner_dim=300, dropout=0.3): 69 | super(SimpleClassifier, self).__init__() 70 | self.dim = dim 71 | self.dropout = nn.Dropout(0.3) 72 | self.fc_hidden = Linear(dim, inner_dim) 73 | # init_ff_layer(self.fc_hidden) 74 | self.fc = Linear(dim, n_class) 75 | # init_ff_layer(self.fc) 76 | 77 | 78 | def forward(self, x): 79 | #x = F.relu(self.fc_hidden(x)) 80 | return self.dropout(self.fc(x)) 81 | return self.fc(x) 82 | 83 | 84 | class SimpleRegressor(Module): 85 | def __init__(self, dim=768, dropout=0.3): 86 | super(SimpleRegressor, self).__init__() 87 | self.dim = dim 88 | self.dropout = dropout 89 | self.fc = Linear(dim, 1) 90 | # init_ff_layer(self.fc) 91 | 92 | 93 | def forward(self, x): 94 | return T.sigmoid(F.dropout(self.fc(x), self.dropout)) 95 | 96 | 97 | class SimpleMatcher(Module): 98 | def __init__(self, dim=768, dropout=0.3): 99 | super(SimpleMatcher, self).__init__() 100 | self.dim = dim 101 | # self.fc = Linear(dim, dim) 102 | self.fcout = Linear(dim*4, 1) 103 | # init_ff_layer(self.fc) 104 | # init_ff_layer(self.fcout) 105 | self.dropout = nn.Dropout(0.3) 106 | # self.norm = LayerNorm(dim*4) 107 | 108 | 109 | def forward(self, a, b): 110 | c = T.cat([a, b, T.abs(a - b), a * b], dim=-1) 111 | out = F.sigmoid(F.dropout(self.fcout(c),p=self.dropout)) 112 | return out 113 | 114 | 115 | class RE2(Module): 116 | def __init__(self, dim=768, num_block=2): 117 | super(RE2, self).__init__() 118 | self.num_block = num_block 119 | self.blocks = [] 120 | self.predictor = Prediction(dim) 121 | self.aligner = Aligner(dim) 122 | for i in range(num_block): 123 | self.blocks.append(RE2Block(dim, self.aligner)) 124 | 125 | 126 | def forward(self, emb_a, emb_b): 127 | o_t2a = T.zeros_like(emb_a, dtype=T.float32) 128 | o_t1a = T.zeros_like(emb_a, dtype=T.float32) 129 | o_t2b = T.zeros_like(emb_a, dtype=T.float32) 130 | o_t1b = T.zeros_like(emb_a, dtype=T.float32) 131 | assert self.num_block > 0 132 | for i in 
range(self.num_block): 133 | inp_a = T.cat([emb_a, o_t1a + o_t2a], dim=-1) 134 | inp_b = T.cat([emb_b, o_t1b + o_t2b], dim=-1) 135 | o_a, o_b = self.blocks[i](inp_a, inp_b) 136 | o_t2a = o_t1a 137 | o_t1a = o_a 138 | o_t2b = o_t1b 139 | o_t1b = o_b 140 | y = self.predictor(o_a, o_b) 141 | return y 142 | 143 | 144 | class RE2Block(Module): 145 | def __init__(self, dim, aligner): 146 | super(RE2Block, self).__init__() 147 | self.encoder = Encoder(dim) 148 | self.fuser = Fuser(dim) 149 | self.aligner = aligner 150 | 151 | 152 | def forward(self, inp_a, inp_b): 153 | encoded_a = self.encoder(inp_a) 154 | encoded_b = self.encoder(inp_b) 155 | aligned_a = self.aligner(encoded_a, encoded_b) 156 | fused_a = self.fuser(encoded_a, aligned_a) 157 | aligned_b = self.aligner(encoded_b, encoded_a) 158 | fused_b = self.fuser(encoded_b, aligned_b) 159 | return fused_a, fused_b 160 | 161 | 162 | class Encoder(Module): 163 | def __init__(self, dim): 164 | super(Encoder, self).__init__() 165 | self.bilstm = LSTM(dim * 2, dim, bidirectional=True, dropout=0.2, num_layers=1) 166 | 167 | 168 | def forward(self, x): 169 | t_h, (_, _) = self.bilstm(x) 170 | o = T.cat([x, t_h], dim=-1) 171 | return o 172 | 173 | 174 | class Fuser(Module): 175 | def __init__(self, dim): 176 | super(Fuser, self).__init__() 177 | self.G1 = Linear(dim * 8, dim) 178 | # init_ff_layer(self.G1) 179 | self.G2 = Linear(dim * 8, dim) 180 | self.G3 = Linear(dim * 8, dim) 181 | self.G = Linear(dim * 3, dim) 182 | # init_ff_layer(self.G2) 183 | # init_ff_layer(self.G3) 184 | # init_ff_layer(self.G) 185 | 186 | 187 | def _fuse(self, z, z_new): 188 | z1 = self.G1(T.cat([z, z_new], dim=-1)) 189 | z2 = self.G2(T.cat([z, z - z_new], dim=-1)) 190 | z3 = self.G3(T.cat([z, z * z_new], dim=-1)) 191 | z_o = F.dropout(self.G(T.cat([z1, z2, z3], dim=-1)), p=0.2) 192 | return z_o 193 | 194 | def forward(self, z, z_new): 195 | z = self._fuse(z, z_new) 196 | return z 197 | 198 | 199 | class Aligner(Module): 200 | def __init__(self, dim): 201 | super(Aligner, self).__init__() 202 | self.ff = Linear(dim * 4, dim * 4) 203 | # init_ff_layer(self.ff) 204 | 205 | 206 | def forward(self, ix, iother): 207 | ex = F.dropout(self.ff(ix), p=0.2) 208 | eother = F.dropout(self.ff(iother), p=0.2) 209 | align_x = ex.bmm(eother.transpose(1, 2)) 210 | align_x = F.softmax(align_x, dim=-1) 211 | aligned = align_x.bmm(iother) 212 | return aligned 213 | 214 | 215 | class Prediction(Module): 216 | def __init__(self, dim): 217 | super(Prediction, self).__init__() 218 | self.H = Linear(dim * 4, 1) 219 | # init_ff_layer(self.H) 220 | 221 | 222 | def forward(self, a, b): 223 | a = T.max(a, dim=1)[0] 224 | b = T.max(b, dim=1)[0] 225 | y = F.dropout(self.H(T.cat([a, b, T.abs(a - b), a * b], dim=-1)), p=0.5) 226 | y = T.sigmoid(y) 227 | return y 228 | 229 | 230 | if __name__ == '__main__': 231 | a = T.ones(32, 768, dtype=T.float32) 232 | b = T.ones(32, 768, dtype=T.float32) 233 | y = EuclideanMatcher()(a,b) 234 | print(y.size()) 235 | -------------------------------------------------------------------------------- /scripts/batch_generator.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict, OrderedDict 2 | from random import choice 3 | from utils import * 4 | from sim import * 5 | import itertools 6 | import json 7 | class BatchGenerator: 8 | def __init__(self, X, Y, tokenizer, class_map, config, indiced_data_dict=True): 9 | self.class_map = class_map 10 | self.inverse_map = {y: x for x, y in self.class_map.items()} 
11 | self.classes = list(class_map.values()) 12 | self.config = config 13 | self.tokenizer = tokenizer 14 | self.indiced_data_dict = indiced_data_dict 15 | self.X = X 16 | self.Y = Y 17 | self.data_dict = self._compute_data_dict(X, Y, index=indiced_data_dict) 18 | self.labelled_data_dict = self._compute_labelled_data_dict(X, Y, index=indiced_data_dict) 19 | self.count_dict = {x: len(y) for x, y in self.data_dict.items()} 20 | print(self.count_dict) 21 | 22 | def sample_batch(self): 23 | pass 24 | 25 | def sample_support_dict(self, classes, n_support): 26 | support_dict = defaultdict(list) 27 | for cls in classes: 28 | supports = sample(from_array=self.labelled_data_dict[cls], n_sample=n_support, replace=False) 29 | assert supports, (cls, len(self.labelled_data_dict[cls])) 30 | support_dict[cls] = supports 31 | return support_dict 32 | 33 | def __len__(self): 34 | pass 35 | 36 | def _compute_labelled_data_dict(self, X, Y, index=True): 37 | X = [self.inverse_map[y] + '=' + x for x, y in list(zip(X, Y))] 38 | if index: 39 | X = texts_to_indices(X, self.config.max_sent_len, self.tokenizer) 40 | data_dict = defaultdict(list) 41 | for i, (x, y) in enumerate(zip(X, Y)): 42 | data_dict[y].append(x) 43 | return data_dict 44 | 45 | def _compute_data_dict(self, X, Y, index=True): 46 | if index: 47 | X = texts_to_indices(X, self.config.max_sent_len, self.tokenizer) 48 | data_dict = defaultdict(list) 49 | for i, (x, y) in enumerate(zip(X, Y)): 50 | data_dict[y].append(x) 51 | return data_dict 52 | 53 | 54 | ''' need to pass in a dictionary with keys being classes(int), and values being positive samples of that class, data in format [sent, class] ''' 55 | ''' 56 | data_dict: {class_nm:instances} dict 57 | n_pos, n_neg: num of pos/neg samples 58 | all_sents: in case not all sents are in data_dict (e.g. 
some sents belong to no classes) 59 | ''' 60 | 61 | 62 | class RandomBatchGenerator(BatchGenerator): 63 | def __init__(self, *args, **kwargs): 64 | super(RandomBatchGenerator, self).__init__(*args, **kwargs) 65 | self.k = self.config.k 66 | self.n_batch = self.config.n_batch 67 | self.classes = list(self.data_dict.keys()) 68 | self.n_pos = self.config.n_pos 69 | self.neg_ratio = self.config.neg_ratio 70 | self.n_support = self.config.n_support 71 | 72 | def sample_batch(self): 73 | for i in range(self.n_batch): 74 | chosen_classes = sample(self.classes, self.k, replace=False) 75 | batch_pos_data = [] 76 | batch_neg_data = [] 77 | for cls in chosen_classes: 78 | cls_pos_data = sample(self.data_dict[cls], self.n_pos, replace=True) 79 | corr_classes = list(set(chosen_classes).difference([cls])) * self.n_pos 80 | tmp = list(itertools.chain.from_iterable(itertools.repeat(x, self.k - 1) for x in cls_pos_data)) 81 | cls_pos_data = list(zip([cls] * self.n_pos, cls_pos_data)) 82 | cls_neg_data = list(zip(corr_classes, tmp)) 83 | cls_neg_data = sample(cls_neg_data, self.neg_ratio * self.n_pos, replace=False) 84 | batch_pos_data += cls_pos_data 85 | batch_neg_data += cls_neg_data 86 | supports_dict = self.sample_support_dict(chosen_classes, self.n_support) 87 | if batch_pos_data: 88 | yield batch_pos_data, batch_neg_data, supports_dict 89 | 90 | def __len__(self): 91 | return self.n_batch 92 | 93 | 94 | class ProtoNetBatchGenerator(BatchGenerator): 95 | def __init__(self, *args, **kwargs): 96 | super(ProtoNetBatchGenerator, self).__init__(*args, **kwargs) 97 | self.k = self.config.k 98 | self.n_batch = self.config.n_batch 99 | self.shot = self.config.shot 100 | self.n_support = self.config.n_support 101 | self.class_weight = [] 102 | for cls in self.classes: 103 | self.class_weight.append(self.count_dict[cls]) 104 | self.class_weight = np.asarray(self.class_weight) 105 | self.class_weight = (self.class_weight / self.class_weight.sum(0)).tolist() 106 | print(self.class_weight) 107 | 108 | def sample_batch(self): 109 | for i in range(self.n_batch): 110 | query_dict = OrderedDict() 111 | if self.k < len(self.classes): 112 | p = self.class_weight 113 | else: 114 | p = None 115 | chosen_classes = sample(self.classes, self.k, replace=False, p=p) 116 | assert len(set(chosen_classes).difference(self.classes)) == 0 117 | for cls in chosen_classes: 118 | cls_pos_data = sample(self.data_dict[cls], self.shot, replace=False) 119 | query_dict[cls] = cls_pos_data 120 | # all_y = list(set([x[0] for x in batch_neg_data+batch_pos_data])) 121 | supports_dict = self.sample_support_dict(chosen_classes, self.n_support) 122 | assert query_dict.keys() == supports_dict.keys() 123 | if query_dict and supports_dict: 124 | yield supports_dict, query_dict 125 | 126 | def __len__(self): 127 | return self.n_batch 128 | 129 | 130 | class ProtoNetEvalBatchGenerator(BatchGenerator): 131 | def __init__(self, testX, testY, *args, **kwargs): 132 | super(ProtoNetEvalBatchGenerator, self).__init__(*args, **kwargs) 133 | self.testX = texts_to_indices(testX, self.config.max_sent_len, self.tokenizer) 134 | self.testY = testY 135 | print("Length of Test Data", len(testY)) 136 | self.batch_size = self.config.eval_batch_size 137 | if len(self.testY)%self.batch_size==0: 138 | self.n_batch = len(self.testY) // self.batch_size 139 | else: 140 | self.n_batch = len(self.testY) // self.batch_size + 1 141 | 142 | def __len__(self): 143 | return self.n_batch 144 | 145 | def get_support_dict(self): 146 | return self.sample_support_dict(self.classes, 
self.config.eval_n_support) 147 | 148 | def sample_batch(self): 149 | for i in range(self.n_batch): 150 | head = i * self.batch_size 151 | tail = min(len(self.testY), head + self.batch_size) 152 | batch_testX = self.testX[head:tail] 153 | batch_testY = self.testY[head:tail] 154 | if batch_testX: 155 | yield batch_testX, batch_testY 156 | 157 | 158 | class ShenyuanBatchGenerator(BatchGenerator): 159 | def __init__(self, *args, **kwargs): 160 | super(ShenyuanBatchGenerator, self).__init__(indiced_data_dict=False, *args, **kwargs) 161 | self.batch_size = self.config.batch_size 162 | if len(self.X) % self.batch_size == 0: 163 | self.n_batch = len(self.X) // self.batch_size 164 | else: 165 | self.n_batch = len(self.X) // self.batch_size + 1 166 | with open(self.config.corpus_path, 'w', encoding='utf-8') as f: 167 | json.dump(self.data_dict, f, ensure_ascii=False) 168 | 169 | def __len__(self): 170 | return self.n_batch 171 | 172 | def sample_batch(self): 173 | for i in range(self.n_batch): 174 | head = i * self.batch_size 175 | tail = min(len(self.X), head + self.batch_size) 176 | batchX = self.X[head:tail] 177 | batchY = self.Y[head:tail] 178 | batch_paired_X = [] 179 | batch_match_Y = [] 180 | for x, y in zip(batchX, batchY): 181 | other_classes = list(set(self.classes).difference([y])) 182 | random_class = choice(other_classes) 183 | sampled_neg = sample(self.data_dict[random_class], 1)[0] 184 | sampled_pos = sample(self.data_dict[y], 1)[0] 185 | pos = self._join(x, sampled_pos) 186 | neg = self._join(x, sampled_neg) 187 | batch_paired_X.append(pos) 188 | batch_paired_X.append(neg) 189 | batch_match_Y += [1, 0] 190 | if batch_match_Y: 191 | batch_paired_X = texts_to_indices(batch_paired_X, max_len=self.config.max_sent_len, 192 | tokenizer=self.tokenizer) 193 | yield batch_paired_X, batch_match_Y 194 | 195 | def _join(self, a, b): 196 | cut_len_a = min(self.config.max_sent_len // 2, len(a)) 197 | cut_len_b = min(self.config.max_sent_len // 2, len(b)) 198 | return a[:cut_len_a] + '[SEP]' + b[:cut_len_b] 199 | 200 | 201 | class ShenyuanEvalBatchGenerator(BatchGenerator): 202 | def __init__(self, testX, testY, *args, **kwargs): 203 | super(ShenyuanEvalBatchGenerator, self).__init__(indiced_data_dict=False, *args, **kwargs) 204 | self.batch_size = self.config.eval_batch_size 205 | self.testX = testX 206 | self.testY = testY 207 | self.n_support = self.config.eval_n_support 208 | if len(self.testX) % self.batch_size == 0: 209 | self.n_batch = len(self.testX) // self.batch_size 210 | else: 211 | self.n_batch = len(self.testX) // self.batch_size + 1 212 | 213 | 214 | def sample_batch(self): 215 | for i in range(self.n_batch): 216 | head = i * self.batch_size 217 | tail = min(len(self.testY), head + self.batch_size) 218 | batchX = self.testX[head:tail] 219 | batchY = self.testY[head:tail] 220 | batch_paired_X = [] 221 | batch_labels = [] 222 | for x, y in zip(batchX, batchY): 223 | supports = self._sample_topN_supports(self.n_support, x) 224 | pairs = [[self._join(x, b) for b in ss] for ss in supports] 225 | batch_paired_X += [texts_to_indices(p, max_len=self.config.max_sent_len, tokenizer=self.tokenizer) for p 226 | in pairs] 227 | batch_labels.append(y) 228 | if batch_labels: 229 | yield batch_paired_X, batch_labels 230 | 231 | def __len__(self): 232 | return self.n_batch 233 | 234 | def _join(self, a, b): 235 | cut_len_a = min(self.config.max_sent_len // 2, len(a)) 236 | cut_len_b = min(self.config.max_sent_len // 2, len(b)) 237 | return a[:cut_len_a] + '[SEP]' + b[:cut_len_b] 238 | 239 | def 
_sample_topN_supports(self, n, x):
240 |         topn_all_class = []
241 | 
242 |         for cls in self.classes:
243 |             candidates = self.data_dict[cls]
244 |             if len(candidates) >= n:
245 |                 # Enough candidates: keep the n supports most similar to the query (char-level Jaccard).
246 |                 sims = [(cand, char_jaccard(x, cand)) for cand in candidates]
247 |                 topn = [cand for cand, _ in sorted(sims, key=lambda pair: pair[1])[-n:]]
248 |             else:
249 |                 # Too few candidates: sample with replacement so every class still yields exactly n supports.
250 |                 topn = sample(candidates, n, replace=True)
251 |             topn_all_class.append(topn)
252 |         return topn_all_class
--------------------------------------------------------------------------------
/scripts/trainer.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | # from matcher import *
3 | import torch as T
4 | from sklearn.metrics import roc_auc_score
5 | import torch.nn.functional as F
6 | from tqdm import tqdm
7 | from torch import save, load
8 | from utils import *
9 | import numpy as np
10 | from torch.optim.lr_scheduler import StepLR
11 | from torch.optim import Adam
12 | from sklearn.metrics import classification_report
13 | import pickle
14 | import json
15 | import torch.nn as nn
16 | from pytorch_pretrained_bert.optimization import BertAdam
17 | 
18 | from pytorch_pretrained_bert.modeling import BertModel  # needed by Trainer.load_model
19 | from matcher import init_bert_weights
20 | 
21 | def pool_from_encoder(encoding, pooling='cls'):
22 |     if pooling == 'max':
23 |         encoding = T.max(encoding[0], dim=1)[0]
24 |     elif pooling == 'mean':
25 |         encoding = T.mean(encoding[0], dim=1)
26 |     elif pooling == 'cls':
27 |         encoding = encoding[1]
28 |     else:
29 |         encoding = encoding[0]
30 |     return encoding
31 | 
32 | class Trainer:
33 |     def __init__(self, encoder, matcher, dl, config, loss_fn, *args, **kwargs):
34 |         self.device = config.device
35 |         self.config = config
36 |         self.dl = dl
37 |         self.matcher = matcher
38 |         print(self.device)
39 |         self.loss_fn = loss_fn
40 |         self.encoder = encoder
41 |         # init_bert_weights(self.encoder)
42 |         self.encoder.to(self.device)
43 |         for param in self.encoder.parameters():
44 |             param.requires_grad = True
45 |         self.params = list(self.encoder.named_parameters())
46 |         if self.matcher:
47 |             self.matcher.to(self.device)
48 |             self.params += list(self.matcher.named_parameters())
49 |         param_optimizer = self.params
50 |         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
51 |         self.params = [{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
52 |         if self.config.warmup > 0:
53 |             self.optimizer = BertAdam(lr=self.config.lr, params=self.params, t_total=self.config.t_total, warmup=self.config.warmup)
54 |         else:
55 |             self.optimizer = BertAdam(lr=self.config.lr, params=self.params)
56 |         # self.optimizer = Adam(lr=self.config.lr, params=self.params)
57 |         self.steplr = None
58 |         if hasattr(config, 'lr_decay'):
59 |             if config.lr_decay:
60 |                 self.steplr = StepLR(self.optimizer, step_size=config.lr_decay_step_size, gamma=config.lr_decay)
61 |         self.best_report = ''
62 |         self.best_acc = 0.
63 |         self.best_epoch = 0
64 | 65 | def load_data(self): 66 | pass 67 | 68 | def train(self): 69 | pass 70 | 71 | def evaluate(self): 72 | pass 73 | 74 | 75 | def save_model(self): 76 | encoder_path = self.config.encoder_save_path 77 | try: 78 | encoder_path = '/'.join(encoder_path.split('/')[:-1]) 79 | self.encoder.save_pretrained(encoder_path) 80 | 81 | except: 82 | save(self.encoder, encoder_path) 83 | if self.matcher and self.matcher.parameters(): 84 | matcher_path = self.config.matcher_save_path 85 | save(self.matcher, matcher_path) 86 | 87 | def load_model(self): 88 | encoder_path = self.config.encoder_load_path 89 | try: 90 | encoder_path = '/'.join(encoder_path.split('/')[:-1]) 91 | self.encoder = BertModel.from_pretrained(encoder_path) 92 | except: 93 | self.encoder = load(encoder_path) 94 | try: 95 | matcher_path = self.config.matcher_load_path 96 | self.matcher = load(matcher_path) 97 | except: 98 | print('no need to load matcher or trained matcher does not exist') 99 | 100 | 101 | class MatchingTrainer(Trainer): 102 | 103 | def load_data(self): 104 | self.lines = self.dl.load_data() 105 | self.training_data_generator, self.eval_data_generator = self.dl.get_dataloader(self.lines, trainer='matching') 106 | 107 | 108 | def train(self): 109 | for i in range(self.config.epoch): 110 | print("{}/{} epoch".format(i, self.config.epoch)) 111 | epoch_auc = 0. 112 | self.encoder.train() 113 | self.matcher.train() 114 | for j, (pos_batch, neg_batch, supports_dict) in tqdm( 115 | enumerate(self.training_data_generator.sample_batch()), total=len(self.training_data_generator)): 116 | labels = [1]*len(pos_batch)+[0]*len(neg_batch) 117 | batch_data = pos_batch+neg_batch 118 | queries = [x[1] for x in batch_data] 119 | domains = [x[0] for x in batch_data] 120 | encoded_center_dict, encoded_queries = self._encode_batch(queries, supports_dict) 121 | loss, auc = self._get_loss(encoded_center_dict, encoded_queries, labels, domains) 122 | self.optimizer.zero_grad() 123 | 124 | loss.backward() 125 | # nn.utils.clip_grad_norm_(self.params, 10) 126 | self.optimizer.step() 127 | epoch_auc += auc 128 | print("EPOCH AVG AUC: {}".format(epoch_auc / len(self.training_data_generator))) 129 | if i % self.config.eval_interval == 1: 130 | with T.no_grad(): 131 | eval_ACC, eval_report = self.evaluate() 132 | if eval_ACC > self.best_acc: 133 | self.best_acc = eval_ACC 134 | self.best_report = eval_report 135 | self.best_epoch = i 136 | self.save_model() 137 | print(eval_report) 138 | print('saved model...') 139 | 140 | print("current best ACC: ", self.best_acc) 141 | 142 | 143 | def evaluate(self): 144 | print("Evaluating...") 145 | self.encoder.eval() 146 | self.matcher.eval() 147 | supports_dict = self.eval_data_generator.get_support_dict() 148 | encoded_center_dict = OrderedDict() 149 | for cls in supports_dict: 150 | supports = supports_dict[cls] 151 | encoded_supports = self.encoder(T.LongTensor(supports).to(self.device)) 152 | encoded_supports = pool_from_encoder(encoded_supports) 153 | encoded_center_dict[cls] = encoded_supports.mean(0).unsqueeze(0) 154 | all_preds = [] 155 | all_Y = [] 156 | for k, (batch_X, batch_Y) in enumerate(self.eval_data_generator.sample_batch()): 157 | preds = self._encode_eval_batch(batch_X, encoded_center_dict) 158 | all_preds += preds 159 | all_Y += batch_Y 160 | acc = calc_acc(all_preds, all_Y) 161 | print("EPOCH ACC : ", acc) 162 | report = classification_report(all_Y, all_preds) 163 | return acc, report 164 | 165 | def _encode_batch(self, queries, supports_dict): 166 | center_dict = OrderedDict() 167 | 
classes = supports_dict.keys()
168 |         for c in classes:
169 |             supports = supports_dict[c]
170 |             encoded_supports = self.encoder(T.LongTensor(supports).to(self.device))
171 |             encoded_supports = pool_from_encoder(encoded_supports)
172 | 
173 |             center_dict[c] = encoded_supports.mean(0).unsqueeze(0)
174 |         queries = np.asarray(queries)
175 | 
176 |         encoded_queries = self.encoder(T.LongTensor(queries).to(self.device))
177 |         encoded_queries = pool_from_encoder(encoded_queries)
178 |         return center_dict, encoded_queries
179 | 
180 |     def _encode_eval_batch(self, queries, center_dict):
181 |         encoded_queries = self.encoder(T.LongTensor(queries).to(self.device))
182 |         encoded_queries = pool_from_encoder(encoded_queries)
183 |         classes = list(center_dict.keys())
184 |         # Pair every query with every class center: queries are repeat-interleaved (q1, q1, ..., q2, q2, ...)
185 |         # while the centers below are tiled (c1, ..., ck, c1, ..., ck), so the reshape gives scores[i][j] = score(q_i, c_j).
186 |         repeated_queries = encoded_queries.repeat_interleave(len(center_dict), dim=0)
187 |         target_centers = T.cat(len(queries) * [center_dict[x] for x in classes])
188 |         scores = self.matcher(repeated_queries, target_centers).view(len(queries), len(classes))
189 |         preds = T.argmax(scores, dim=-1).tolist()  # no squeeze, so a single-query batch still yields a list
190 | 
191 |         return preds
192 | 
193 |     def _get_loss(self, encoded_center_dict, encoded_queries, labels, domains):
194 |         # centers = list(encoded_center_dict.values())
195 |         # centers = T.cat(centers, dim=0)
196 |         centers = T.cat([encoded_center_dict[x] for x in domains], dim=0).to(self.device)
197 |         scores = self.matcher(encoded_queries, centers).squeeze()
198 |         # print(scores, truth)
199 |         with T.no_grad():
200 |             auc = roc_auc_score(labels, scores.detach().tolist())
201 |         loss = self.loss_fn(scores, T.FloatTensor(labels).to(self.device))
202 |         return loss, auc
203 | 
204 | class ProtoNetTrainer(Trainer):
205 | 
206 |     def load_data(self):
207 |         self.lines = self.dl.load_data()
208 |         self.training_data_generator, self.eval_data_generator = self.dl.get_dataloader(self.lines, trainer='protonet')
209 | 
210 |     def train(self):
211 |         for i in range(self.config.epoch):
212 |             # if self.steplr:
213 |             #     self.steplr.step()
214 |             self.encoder.train()
215 |             print("{}/{} epoch".format(i, self.config.epoch))
216 |             print("TRAINING...")
217 |             for j, (supports_dict, queries_dict) in enumerate(self.training_data_generator.sample_batch()):
218 |                 center_dict, q_dict = self._process_batch(supports_dict, queries_dict)
219 |                 centers = list(center_dict.values())
220 |                 centers = T.cat(centers, dim=0)
221 |                 loss = 0.
222 | for ind, cls in enumerate(supports_dict): 223 | qs = q_dict[cls] 224 | loss += self.loss_fn(qs, centers, ind) 225 | loss /= len(supports_dict) 226 | self.optimizer.zero_grad() 227 | loss.backward() 228 | #nn.utils.clip_grad_norm_(self.params, 10) 229 | self.optimizer.step() 230 | print('TRAINING loss: ', loss.item()) 231 | 232 | if i % self.config.eval_interval == 1: 233 | with T.no_grad(): 234 | eval_ACC, eval_report = self.evaluate() 235 | if eval_ACC > self.best_acc: 236 | self.best_acc = eval_ACC 237 | self.best_report = eval_report 238 | self.best_epoch = i 239 | self.save_model() 240 | print(eval_report) 241 | print('saved model...') 242 | 243 | print("current best ACC: ", self.best_acc) 244 | return self.best_acc, self.best_report, self.best_epoch 245 | 246 | def evaluate(self): 247 | print("EVALUATING...") 248 | self.encoder.eval() 249 | all_preds = [] 250 | all_Y = [] 251 | supports_dict = self.eval_data_generator.get_support_dict() 252 | encoded_center_dict = OrderedDict() 253 | order_keys = list(supports_dict.keys()) 254 | self.classes = order_keys 255 | for cls in supports_dict: 256 | supports = supports_dict[cls] 257 | encoded_supports = self.encoder(T.LongTensor(supports).to(self.device)) 258 | encoded_supports = pool_from_encoder(encoded_supports) 259 | encoded_center_dict[cls] = encoded_supports.mean(0).unsqueeze(0) 260 | centers = T.cat(list(encoded_center_dict.values()), dim=0) 261 | self.centers = centers.detach().tolist() 262 | total = len(self.eval_data_generator) 263 | for k, (batch_X, batch_Y) in tqdm(enumerate(self.eval_data_generator.sample_batch()), total=total): 264 | preds = self._process_eval_batch(centers, batch_X) 265 | preds = [order_keys[m] for m in preds] 266 | all_preds += preds 267 | all_Y += batch_Y 268 | accs = calc_acc(all_preds, all_Y) 269 | report = classification_report(all_Y, all_preds) 270 | print("EPOCH evaluation ACC: ", accs) 271 | return accs, report 272 | 273 | def _process_batch(self, supports_dict, queries_dict): 274 | center_dict = OrderedDict() 275 | q_dict = OrderedDict() 276 | # print(supports_dict.keys()) 277 | for i in supports_dict: 278 | samples = supports_dict[i] 279 | assert samples, (i, samples) 280 | samples = self.encoder(T.LongTensor(samples).to(self.device)) 281 | qsamples = queries_dict[i] 282 | qsamples = self.encoder(T.LongTensor(qsamples).to(self.device)) 283 | samples = pool_from_encoder(samples) 284 | qsamples = pool_from_encoder(qsamples) 285 | center = samples.mean(0) 286 | center_dict[i] = center.unsqueeze(0) 287 | q_dict[i] = qsamples 288 | return center_dict, q_dict 289 | 290 | def _process_eval_batch(self, centers, queries): 291 | encoded_batchX = self.encoder(T.LongTensor(queries).to(self.device)) 292 | encoded_batchX = pool_from_encoder(encoded_batchX) 293 | dists = T.cdist(encoded_batchX, centers) 294 | 295 | lbs = T.argmin(dists, dim=-1).detach().tolist() 296 | #print() 297 | #print("PREDS: ", lbs) 298 | 299 | return lbs 300 | 301 | def _get_loss(self): 302 | pass 303 | 304 | def save_model(self): 305 | super(ProtoNetTrainer, self).save_model() 306 | with open(self.config.center_save_path, 'wb') as f: 307 | pickle.dump(self.centers, f) 308 | with open(self.config.class_map_save_path, 'w', encoding='utf-8') as f: 309 | inverse_class_map = {y:x for x,y in self.dl.class_map.items()} 310 | json.dump(inverse_class_map, f) 311 | 312 | class BertClassificationTrainer(Trainer): 313 | def load_data(self): 314 | self.lines = self.dl.load_data() 315 | self.training_data_generator, self.eval_data_generator = 
self.dl.get_dataloader(self.lines, trainer='bert') 316 | #init_bert_weights(self.matcher) 317 | 318 | def train(self): 319 | for i in range(self.config.epoch): 320 | self.encoder.train() 321 | self.matcher.train() 322 | print("{}/{} epoch".format(i, self.config.epoch)) 323 | print("TRAINING...") 324 | total = len(self.training_data_generator) 325 | accs = 0. 326 | losses = 0. 327 | for j, (batchX, batchY) in tqdm(enumerate(self.training_data_generator), total=total): 328 | batchX = batchX.to(self.device) 329 | batchY = batchY.to(self.device) 330 | encoded_X = self.encoder(batchX)[1] 331 | logits = self.matcher(encoded_X) 332 | # loss = self.loss_fn(logits, batchY) 333 | loss = F.cross_entropy(logits, batchY) 334 | preds = T.argmax(logits, dim=-1).detach().tolist() 335 | acc = calc_acc(preds,batchY.detach().tolist()) 336 | accs += acc 337 | losses+=loss 338 | self.optimizer.zero_grad() 339 | loss.backward() 340 | # nn.utils.clip_grad_norm_(self.params, 10) 341 | self.optimizer.step() 342 | print('TRAINING loss: ', losses.item()/total) 343 | print("Training acc: ", accs/total) 344 | if i % self.config.eval_interval == 1: 345 | with T.no_grad(): 346 | eval_ACC, eval_report = self.evaluate() 347 | if eval_ACC > self.best_acc: 348 | self.best_acc = eval_ACC 349 | self.best_report = eval_report 350 | self.best_epoch = i 351 | self.save_model() 352 | print(eval_report) 353 | print('saved model...') 354 | 355 | print("current best ACC: ", self.best_acc) 356 | return self.best_acc, self.best_report, self.best_epoch 357 | 358 | def evaluate(self): 359 | self.encoder.eval() 360 | self.matcher.eval() 361 | print("Evaluating...") 362 | total = len(self.eval_data_generator) 363 | all_preds = [] 364 | all_Y = [] 365 | for j, (batchX, batchY) in tqdm(enumerate(self.eval_data_generator), total=total): 366 | batchX = batchX.to(self.device) 367 | batchY = batchY.to(self.device) 368 | encoded_X = self.encoder(batchX)[1] 369 | logits = self.matcher(encoded_X) 370 | preds = T.argmax(logits, dim=-1).detach().tolist() 371 | all_preds += preds 372 | all_Y += batchY.detach().tolist() 373 | accs = calc_acc(all_preds, all_Y) 374 | report = classification_report(all_Y, all_preds) 375 | print("EVALUATION ACC: ", accs) 376 | return accs, report 377 | 378 | def save_model(self): 379 | super(BertClassificationTrainer, self).save_model() 380 | with open(self.config.class_map_save_path, 'w', encoding='utf-8') as f: 381 | inverse_class_map = {y:x for x,y in self.dl.class_map.items()} 382 | json.dump(inverse_class_map, f) 383 | 384 | 385 | class ShenyuanTrainer(Trainer): 386 | # def __init__(self, *args, **kwargs): 387 | # super(ShenyuanTrainer, self).__init__(*args, **kwargs) 388 | 389 | def load_data(self): 390 | self.lines = self.dl.load_data() 391 | self.training_data_generator, self.eval_data_generator = self.dl.get_dataloader(self.lines, trainer='shenyuan') 392 | 393 | def train(self): 394 | self.eval_batches = [] 395 | self.eval_Y = [] 396 | for i in range(self.config.epoch): 397 | #if self.steplr: 398 | # self.steplr.step() 399 | self.encoder.train() 400 | self.matcher.train() 401 | print("{}/{} epoch".format(i, self.config.epoch)) 402 | print("TRAINING...") 403 | total = len(self.training_data_generator) 404 | accs = 0. 405 | losses = 0. 
406 |             for j, (batchX, batchY) in tqdm(enumerate(self.training_data_generator.sample_batch()), total=total):
407 |                 batchX = T.LongTensor(batchX)
408 |                 batchY = T.FloatTensor(batchY)
409 |                 batchX = batchX.to(self.device)
410 |                 batchY = batchY.to(self.device)
411 |                 encoded_X = self.encoder(batchX)[1]
412 |                 logits = self.matcher(encoded_X).squeeze()
413 |                 loss = self.loss_fn(logits, batchY)
414 |                 # print(logits.size())
415 |                 # preds = T.argmax(logits, dim=-1).detach().tolist()
416 |                 # acc = calc_acc(preds, batchY.detach().tolist())
417 |                 # accs += acc
418 |                 losses += loss.item()  # accumulate a plain float so the graph is not kept for the whole epoch
419 |                 self.optimizer.zero_grad()
420 |                 loss.backward()
421 |                 # nn.utils.clip_grad_norm_(self.params, 10)
422 |                 self.optimizer.step()
423 |             print('TRAINING loss: ', losses / total)
424 |             # print("Training acc: ", accs / total)
425 |             with T.no_grad():
426 |                 eval_acc, report = self.evaluate()
427 |             if eval_acc > self.best_acc:
428 |                 self.best_acc = eval_acc
429 |                 self.best_report = report
430 |                 self.best_epoch = i
431 |                 self.save_model()  # save_model takes no arguments; the save paths are read from the config
432 |             print("Current Best ACC: ", self.best_acc)
433 |         return self.best_acc, self.best_report, self.best_epoch
434 | 
435 |     def evaluate(self):
436 |         self.encoder.eval()
437 |         self.matcher.eval()
438 |         print("Evaluating...")
439 |         total = len(self.eval_data_generator)
440 |         classes = self.eval_data_generator.classes
441 |         save_flag = not self.eval_batches  # tokenised eval batches are built once and cached for later epochs
442 |         all_Y = []
443 |         all_preds = []
444 |         if save_flag:
445 |             for j, (batchX, batchY) in tqdm(enumerate(self.eval_data_generator.sample_batch()), total=total):
446 |                 batch_size = len(batchX) // len(classes)
447 |                 batchX = T.LongTensor(batchX)
448 |                 # batchX = batchX.to(self.device)
449 |                 batchX = batchX.view(batch_size * len(classes) * self.config.eval_n_support, self.config.max_sent_len)
450 |                 self.eval_batches.append(batchX)
451 |                 self.eval_Y.append(batchY)
452 |         for i, batchX in enumerate(self.eval_batches):
453 |             batchY = self.eval_Y[i]
454 |             if i % 10 == 0:
455 |                 print("{}/{}".format(i, len(self.eval_batches)))
456 |             batch_size = batchX.size(0) // (len(classes) * self.config.eval_n_support)
457 |             batchX = batchX.to(self.device)
458 |             encoded_X = self.encoder(batchX)[1]
459 |             logits = self.matcher(encoded_X).squeeze()
460 |             logits = logits.view(batch_size, len(classes), self.config.eval_n_support)
461 |             logits = T.mean(logits, dim=-1)  # (batch_size, n_classes): average matcher scores over the supports
462 |             preds = T.argmax(logits, dim=-1).detach().tolist()  # stays a list even when batch_size == 1
463 |             all_preds += preds
464 |             all_Y += batchY
465 |         accs = calc_acc(all_preds, all_Y)
466 |         report = classification_report(all_Y, all_preds)
467 |         print("EVALUATION ACC: ", accs)  # calc_acc already averages over all predictions
468 |         return accs, report
469 | 
470 |     def save_model(self):
471 |         class_map_path = self.config.class_map_save_path
472 |         super(ShenyuanTrainer, self).save_model()
473 |         with open(class_map_path, 'w', encoding='utf-8') as f:
474 |             inverse_class_map = {y: x for x, y in self.dl.class_map.items()}
475 |             json.dump(inverse_class_map, f, ensure_ascii=False)
476 | 
477 | if __name__ == '__main__':
478 |     # Leftover smoke-test stub; the module is normally driven from scripts/api.py.
479 |     # A Trainer cannot be built from a matcher alone -- it also needs an encoder,
480 |     # a dataloader, a config and a loss function, roughly:
481 |     #     from matcher import CosMatcher  # alternatives: RE2(), SimpleMatcher(), DotMatcher()
482 |     #     trainer = ShenyuanTrainer(encoder, CosMatcher(), dl, config, loss_fn)
483 |     #     trainer.load_data()
484 |     #     trainer.train()
485 |     pass
486 | 
487 | 
488 | 
--------------------------------------------------------------------------------
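
For readers skimming the trainers above, the core ProtoNet prediction step used in ProtoNetTrainer.evaluate and _process_eval_batch (encode the supports, average them into per-class centers, then assign each query to the nearest center) can be illustrated with a small self-contained sketch. The tensor shapes and random values below are only stand-ins for the pooled BERT encodings produced by pool_from_encoder; nothing here is the project's actual data.

# Minimal sketch of prototype-based prediction (nearest class center), assuming
# random tensors in place of pooled BERT encodings.
import torch

n_classes, n_support, n_query, dim = 4, 5, 3, 768

# One (n_support, dim) block of encoded support sentences per class.
support_encodings = [torch.randn(n_support, dim) for _ in range(n_classes)]

# Class centers are the mean of each class's support encodings.
centers = torch.stack([s.mean(dim=0) for s in support_encodings])  # (n_classes, dim)

# Encoded query sentences.
queries = torch.randn(n_query, dim)

# Euclidean distance from every query to every center, then nearest-center assignment,
# mirroring T.cdist + T.argmin in ProtoNetTrainer._process_eval_batch.
dists = torch.cdist(queries, centers)          # (n_query, n_classes)
preds = torch.argmin(dists, dim=-1).tolist()   # predicted class index per query
print(preds)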