├── .gitignore ├── README.md ├── data └── data_preprocess.py ├── fewshot_re_kit ├── __init__.py ├── data_loader.py ├── framework.py └── util.py ├── models ├── __init__.py └── proto.py ├── run_main.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # folders 2 | check_points 3 | test_result 4 | 5 | # files 6 | *.pyc 7 | *.swp 8 | *.tar 9 | *.sh 10 | sbatch* 11 | *.ipynb 12 | backup_model 13 | *.xlsx 14 | # data 15 | pretrain 16 | *.json 17 | # virtualenv 18 | .virtual 19 | 20 | # result 21 | result* 22 | 23 | # test data 24 | data/test_wiki.json 25 | data/test_pubmed.json 26 | data/train_wiki+pubmed.json 27 | data/pubmed.json 28 | 29 | # tmp file 30 | pretrain.tar 31 | cn_check_points 32 | 33 | # mac cache 34 | .DS_Store 35 | 36 | # editor cache 37 | .idea 38 | .vscode 39 | __pycache__ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prototypical Networks 2 | 3 | An implementation of the models from the paper ["Prototypical Networks for Few-shot Learning"](https://arxiv.org/pdf/1703.05175.pdf), published at NIPS 2017. -------------------------------------------------------------------------------- /data/data_preprocess.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import logging 3 | import numpy as np 4 | import collections 5 | 6 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 7 | datefmt='%m/%d/%Y %H:%M:%S', 8 | level=logging.INFO) 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def filter_data(data, threshold): # keep only root causes with more than `threshold` samples 13 | filtered_data = {} 14 | for k, v in data.items(): 15 | if len(v) > threshold: 16 | filtered_data[k] = v 17 | return filtered_data 18 | 19 | 20 | def read_data(data_path, threshold): 21 | data = pd.read_excel(data_path, sheet_name='Sheet1', header=[0], usecols='A,B,C').fillna(0) 22 | data_label = {} # root cause -> list of [abstract, description, row index] 23 | count = 0 24 | for i in range(len(data)): 25 | if data.iloc[i]["Abstract"] == 0 or data.iloc[i]["Problem Description"] == 0: # skip rows with a missing field 26 | count += 1 27 | continue 28 | text = [data.iloc[i]["Abstract"], data.iloc[i]["Problem Description"], i] 29 | label = data.iloc[i]["Root Cause"] 30 | if label not in data_label: 31 | data_label[label] = [text] 32 | else: 33 | data_label[label].append(text) 34 | filtered_data = filter_data(data_label, threshold=threshold) 35 | logger.info("dropped %d samples" % count) 36 | 37 | logger.info("start sampling eval data") 38 | train_data = collections.defaultdict(list) 39 | eval_data = collections.defaultdict(list) 40 | 41 | train_data_nums = 0 42 | eval_data_nums = 0 43 | for k, v in filtered_data.items(): 44 | eval_nums = len(v) // 10 # roughly 10% of each class goes to the eval split, at least one sample 45 | if eval_nums == 0: 46 | eval_nums = 1 47 | eval_data_nums += eval_nums 48 | train_data_nums += len(v) - eval_nums 49 | indices = np.random.choice(list(range(len(v))), eval_nums, replace=False) # sample without replacement so the eval indices are distinct 50 | for j in range(len(v)): 51 | if j in indices: 52 | eval_data[k].append(v[j]) 53 | else: 54 | train_data[k].append(v[j]) 55 | logger.info('train data nums: %d, eval data nums: %d' % (train_data_nums, eval_data_nums)) 56 | return train_data, eval_data 57 | 58 | 59 | if __name__ == "__main__": 60 | a, b = read_data('./data/source_data.xlsx', 5) -------------------------------------------------------------------------------- /fewshot_re_kit/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import data_loader 2 | from . import framework 3 | from . import util -------------------------------------------------------------------------------- /fewshot_re_kit/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import numpy as np 4 | import pandas as pd 5 | import random 6 | import collections 7 | import json 8 | import logging 9 | 10 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | datefmt='%m/%d/%Y %H:%M:%S', 12 | level=logging.INFO) 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def get_rc_data(data, rc_lists): # keep only the root causes contained in rc_lists 17 | new_data = {} 18 | count = 0 19 | for k, v in data.items(): 20 | if k in rc_lists: 21 | new_data[k] = v 22 | count += len(v) 23 | logger.info('remaining sample count: %d' % count) 24 | return new_data 25 | 26 | 27 | def RC_align(path): # decorator: restrict the loaded data to the root-cause list stored at `path` 28 | def RC_align_out(function): 29 | def wrap(data_path, threshold, is_train=True): 30 | rc_list = json.load(open(path, 'r', encoding='utf8')) 31 | data = function(data_path, threshold, is_train) 32 | if is_train: 33 | train_data, eval_data = data 34 | train_data, eval_data = get_rc_data(train_data, rc_list), get_rc_data(eval_data, rc_list) 35 | return train_data, eval_data 36 | else: 37 | filtered_data = data 38 | filtered_data = get_rc_data(filtered_data, rc_list) 39 | return filtered_data 40 | return wrap 41 | return RC_align_out 42 | 43 | 44 | def filter_data(data, threshold): 45 | filtered_data = {} 46 | for k, v in data.items(): 47 | if len(v) > threshold: 48 | filtered_data[k] = v 49 | logger.info('number of root causes: %d' % len(filtered_data)) 50 | return filtered_data 51 | 52 | 53 | @RC_align('./data/rc_list.json') 54 | def read_data(data_path, threshold, is_train=True): 55 | data = pd.read_excel(data_path, sheet_name='Sheet1', header=[0], usecols='A,B,C').fillna(0) 56 | data_label = {} 57 | count = 0 58 | for i in range(len(data)): 59 | if data.iloc[i]["Abstract"] == 0 or data.iloc[i]["Problem Description"] == 0: 60 | count += 1 61 | continue 62 | text = [data.iloc[i]["Abstract"], data.iloc[i]["Problem Description"], i] 63 | label = data.iloc[i]["Root Cause"] 64 | if label not in data_label: 65 | data_label[label] = [text] 66 | else: 67 | data_label[label].append(text) 68 | filtered_data = filter_data(data_label, threshold=threshold) 69 | logger.info("dropped %d samples" % count) 70 | if is_train: 71 | logger.info("start sampling eval data") 72 | train_data = collections.defaultdict(list) 73 | eval_data = collections.defaultdict(list) 74 | 75 | train_data_nums = 0 76 | eval_data_nums = 0 77 | for k, v in filtered_data.items(): 78 | eval_nums = len(v) // 10 79 | if eval_nums == 0: 80 | eval_nums = 1 81 | eval_data_nums += eval_nums 82 | train_data_nums += len(v) - eval_nums 83 | indices = np.random.choice(list(range(len(v))), eval_nums, replace=False) # sample without replacement so the eval indices are distinct 84 | for j in range(len(v)): 85 | if j in indices: 86 | eval_data[k].append(v[j]) 87 | else: 88 | train_data[k].append(v[j]) 89 | logger.info('train data nums: %d, eval data nums: %d' % (train_data_nums, eval_data_nums)) 90 | return train_data, eval_data 91 | else: 92 | return filtered_data 93 | 94 | 95 | class ThinkpadDataset(data.Dataset): 96 | """ 97 | Thinkpad dataset: samples N-way K-shot episodes from the {root cause: samples} dict 98 | """ 99 | def __init__(self, data, tokenizer, max_seq_len, N, K, Q): 100 | self.data = data 101 | self.classes = list(self.data.keys()) 102 | self.tokenizer = tokenizer 103 | self.cls_token = self.tokenizer.cls_token 104 | self.sep_token = self.tokenizer.sep_token 105 | 
self.max_seq_len = max_seq_len 106 | self.N = N 107 | self.K = K 108 | self.Q = Q 109 | 110 | def tokenize(self, text): 111 | tokens = self.tokenizer.tokenize(text) 112 | 113 | if len(tokens) > self.max_seq_len - 2: 114 | tokens = tokens[: self.max_seq_len - 2] 115 | 116 | tokens = [self.cls_token] + tokens + [self.sep_token] 117 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 118 | return token_ids 119 | 120 | def __getitem__(self, idx): 121 | target_classes = random.sample(self.classes, self.N) 122 | support_abstract_set = [] 123 | support_description_set = [] 124 | query_abstract_set = [] 125 | query_description_set = [] 126 | query_label = [] 127 | for i, class_name in enumerate(target_classes): 128 | indices = np.random.choice( 129 | list(range(len(self.data[class_name]))), 130 | self.K + self.Q, False) # randomly draw K+Q samples per class, without replacement 131 | count = 0 132 | for j in indices: 133 | abstract, description, _ = self.data[class_name][j] 134 | abstract = self.tokenize(abstract) 135 | description = self.tokenize(description) 136 | if count < self.K: 137 | support_abstract_set.append(abstract) # the first K samples form the support set 138 | support_description_set.append(description) 139 | else: 140 | query_abstract_set.append(abstract) # the remaining Q samples form the query set 141 | query_description_set.append(description) 142 | count += 1 143 | query_label += [i] * self.Q 144 | return support_abstract_set, support_description_set, query_abstract_set, query_description_set, query_label 145 | 146 | def __len__(self): 147 | return 1000000000 # effectively unbounded: episodes are sampled on the fly 148 | 149 | 150 | def pad_seq(insts): 151 | return_list = [] 152 | 153 | max_len = max(len(inst) for inst in insts) 154 | 155 | # input ids 156 | inst_data = np.array( 157 | [inst + list([0] * (max_len - len(inst))) for inst in insts], 158 | ) 159 | return_list += [inst_data.astype("int64")] 160 | 161 | # input sentence type 162 | return_list += [np.zeros_like(inst_data).astype("int64")] 163 | 164 | # input position 165 | inst_pos = np.array([list(range(0, len(inst))) + [0] * (max_len - len(inst)) for inst in insts]) 166 | return_list += [inst_pos.astype("int64")] 167 | 168 | # input mask 169 | input_mask_data = np.array([[1] * len(inst) + [0] * (max_len - len(inst)) for inst in insts]) 170 | return_list += [input_mask_data.astype("float32")] 171 | 172 | return return_list 173 | 174 | 175 | def collate_fn(data): 176 | # [batch_size, N*K, seq_len] \ [batch_size, N*Q, seq_len] 177 | support_abstract_sets, support_description_sets, query_abstract_sets, query_description_sets, query_labels = zip(*data) 178 | batch_s_abs = [] 179 | batch_s_des = [] 180 | batch_q_abs = [] 181 | batch_q_des = [] 182 | batch_q_label = [] 183 | for i in range(len(support_abstract_sets)): 184 | batch_s_abs.extend(support_abstract_sets[i]) 185 | batch_s_des.extend(support_description_sets[i]) 186 | batch_q_abs.extend(query_abstract_sets[i]) 187 | batch_q_des.extend(query_description_sets[i]) 188 | batch_q_label.extend(query_labels[i]) 189 | padded_token_s_abs_ids, padded_text_type_s_abs_ids, padded_position_s_abs_ids, input_s_abs_mask = pad_seq(batch_s_abs) 190 | padded_token_s_des_ids, padded_text_type_s_des_ids, padded_position_s_des_ids, input_s_des_mask = pad_seq(batch_s_des) 191 | padded_token_q_abs_ids, padded_text_type_q_abs_ids, padded_position_q_abs_ids, input_q_abs_mask = pad_seq(batch_q_abs) 192 | padded_token_q_des_ids, padded_text_type_q_des_ids, padded_position_q_des_ids, input_q_des_mask = pad_seq(batch_q_des) 193 | 194 | return_list = [ 195 | padded_token_s_abs_ids, padded_text_type_s_abs_ids, 
padded_position_s_abs_ids, input_s_abs_mask, 196 | padded_token_s_des_ids, padded_text_type_s_des_ids, padded_position_s_des_ids, input_s_des_mask, 197 | padded_token_q_abs_ids, padded_text_type_q_abs_ids, padded_position_q_abs_ids, input_q_abs_mask, 198 | padded_token_q_des_ids, padded_text_type_q_des_ids, padded_position_q_des_ids, input_q_des_mask, 199 | batch_q_label 200 | ] 201 | return_list = [torch.tensor(batch_data) for batch_data in return_list] 202 | return return_list 203 | 204 | 205 | def get_loader(train_data, encoder, max_seq_len, N, K, Q, batch_size, num_workers=8, collate_fn=collate_fn): 206 | dataset = ThinkpadDataset(train_data, encoder, max_seq_len, N, K, Q) 207 | data_loader = data.DataLoader(dataset=dataset, 208 | batch_size=batch_size, 209 | shuffle=False, 210 | pin_memory=True, 211 | # num_workers=num_workers, 212 | collate_fn=collate_fn) 213 | return iter(data_loader) 214 | 215 | 216 | def output_data_to_excel(data, output_path): 217 | rc = [] 218 | abstract = [] 219 | des = [] 220 | index = [] 221 | for k, v in data.items(): 222 | for _ in v: 223 | rc.append(k) 224 | abstract.append(_[0]) 225 | des.append(_[1]) 226 | index.append(_[2]) 227 | pd_data = pd.DataFrame(data={'abstract': abstract, 'des': des, 'root_cause': rc}) 228 | print(index) 229 | pd_data.to_excel(output_path) 230 | 231 | 232 | if __name__ == "__main__": 233 | np.random.seed(100) 234 | train_data, eval_data = read_data('./data/source_add_CN_V2.xlsx', 5) 235 | # print(len(train_data.keys())) 236 | # train_data, eval_data = read_data('./data/source_data.xlsx', 5) 237 | # print(len(train_data.keys())) 238 | 239 | # output_data_to_excel(train_data, './data/t1.xlsx') 240 | # output_data_to_excel(eval_data, './data/t2.xlsx') 241 | # logger.info("heihei") 242 | 243 | # train_rc_emb = json.load(open('./data/train_rc_emb.json', 'r', encoding='utf8')) 244 | # rc_list = list(train_rc_emb.keys()) 245 | # json.dump(rc_list, open('./data/rc_list.json', 'w', encoding='utf8'), ensure_ascii=False) 246 | -------------------------------------------------------------------------------- /fewshot_re_kit/framework.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | from fewshot_re_kit import util 5 | import torch 6 | from torch import nn 7 | 8 | # from pytorch_pretrained_bert import BertAdam 9 | from transformers import AdamW, get_linear_schedule_with_warmup 10 | import logging 11 | 12 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 13 | datefmt='%m/%d/%Y %H:%M:%S', 14 | level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def warmup_linear(global_step, warmup_step): 19 | if global_step < warmup_step: 20 | return global_step / warmup_step 21 | else: 22 | return 1.0 23 | 24 | 25 | class FewShotREModel(nn.Module): 26 | def __init__(self, sentence_encoder): 27 | ''' 28 | sentence_encoder: Sentence encoder 29 | 30 | You need to set self.cost as your own loss function. 31 | ''' 32 | nn.Module.__init__(self) 33 | self.sentence_encoder = nn.DataParallel(sentence_encoder) 34 | self.cost = nn.CrossEntropyLoss() 35 | 36 | def forward(self, support, query, N, K, Q): 37 | ''' 38 | support: Inputs of the support set. 39 | query: Inputs of the query set. 
40 | N: Num of classes 41 | K: Num of instances for each class in the support set 42 | Q: Num of instances for each class in the query set 43 | return: logits, pred 44 | ''' 45 | raise NotImplementedError 46 | 47 | def loss(self, logits, label): 48 | ''' 49 | logits: Logits with the size (..., class_num) 50 | label: Label with whatever size. 51 | return: [Loss] (A single value) 52 | ''' 53 | N = logits.size(-1) 54 | return self.cost(logits.view(-1, N), label.view(-1)) 55 | 56 | def accuracy(self, pred, label): 57 | ''' 58 | pred: Prediction results with whatever size 59 | label: Label with whatever size 60 | return: [Accuracy] (A single value) 61 | ''' 62 | return torch.mean((pred.view(-1) == label.view(-1)).type(torch.FloatTensor)) 63 | 64 | 65 | class FewShotREFramework: 66 | 67 | def __init__(self, 68 | tokenizer=None, 69 | train_data_loader=None, 70 | val_data_loader=None, 71 | test_data_loader=None, 72 | train_data=None, 73 | eval_data=None): 74 | ''' 75 | train_data_loader: DataLoader for training. 76 | val_data_loader: DataLoader for validating. 77 | test_data_loader: DataLoader for testing. 78 | ''' 79 | self.train_data_loader = train_data_loader 80 | self.val_data_loader = val_data_loader 81 | self.test_data_loader = test_data_loader 82 | 83 | self.train_data = train_data 84 | self.eval_data = eval_data 85 | 86 | self.tokenizer = tokenizer 87 | 88 | def __load_model__(self, ckpt): 89 | ''' 90 | ckpt: Path of the checkpoint 91 | return: Checkpoint dict 92 | ''' 93 | if os.path.isfile(ckpt): 94 | checkpoint = torch.load(ckpt) 95 | logger.info("Successfully loaded checkpoint '%s'" % ckpt) 96 | return checkpoint 97 | else: 98 | raise Exception("No checkpoint found at '%s'" % ckpt) 99 | 100 | def item(self, x): 101 | ''' 102 | PyTorch before and after 0.4 103 | ''' 104 | torch_version = torch.__version__.split('.') 105 | if int(torch_version[0]) == 0 and int(torch_version[1]) < 4: 106 | return x[0] 107 | else: 108 | return x.item() 109 | 110 | def eval(self, model, opt): 111 | train_data_emb, train_rc_emb = util.get_emb(model, self.tokenizer, self.train_data, opt) 112 | eval_data_emb, _ = util.get_emb(model, self.tokenizer, self.eval_data, opt) 113 | if not opt.proto_emb: 114 | eval_acc = util.single_acc(train_data_emb, eval_data_emb) 115 | else: 116 | eval_acc = util.proto_acc(train_rc_emb, eval_data_emb) 117 | return eval_acc 118 | 119 | def train(self, model, B, N_for_train, K, Q, opt): 120 | logger.info("Start training...") 121 | 122 | # Init 123 | logger.info('Use bert optim!') 124 | parameters_to_optimize = list(model.named_parameters()) 125 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] 126 | parameters_to_optimize = [ 127 | {'params': [p for n, p in parameters_to_optimize 128 | if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01}, 129 | {'params': [p for n, p in parameters_to_optimize 130 | if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 131 | ] 132 | optimizer = AdamW(parameters_to_optimize, lr=opt.lr, correct_bias=False) 133 | 134 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=opt.warmup_rate*opt.train_iter, num_training_steps=opt.train_iter) 135 | 136 | start_iter = 0 137 | 138 | model.train() 139 | 140 | # Training 141 | iter_loss = 0.0 142 | iter_right = 0.0 143 | iter_sample = 0.0 144 | iter_time = 0.0 145 | begin_time = time.time() 146 | for it in range(start_iter, start_iter + opt.train_iter): 147 | batch = next(self.train_data_loader) 148 | if opt.use_cuda: 149 | batch = tuple(t.cuda() for t in batch) 150 | 
logits, pred = model(batch, N_for_train, K, Q * N_for_train) 151 | loss = model.loss(logits, batch[-1]) 152 | right = model.accuracy(pred, batch[-1]) 153 | loss.backward() 154 | # torch.nn.utils.clip_grad_norm_(model.parameters(), 10) 155 | 156 | optimizer.step() 157 | scheduler.step() 158 | optimizer.zero_grad() 159 | 160 | iter_loss += self.item(loss.data) 161 | iter_right += self.item(right.data) 162 | iter_sample += 1 163 | 164 | step_time = time.time() - begin_time 165 | iter_time += step_time 166 | sys.stdout.write('step: %d | loss: %.6f, accuracy: %.2f, time/step: %.4f' % (it + 1, iter_loss / iter_sample, 100 * iter_right / iter_sample, iter_time / iter_sample) +'\r') 167 | sys.stdout.flush() 168 | 169 | if iter_sample % opt.eval_step == 0: 170 | if opt.do_eval: 171 | eval_start_time = time.time() 172 | model.eval() 173 | eval_model = model.sentence_encoder.module 174 | eval_acc = self.eval(eval_model, opt) 175 | logger.info("eval used time: %.4f —— proto: %s eval accuracy: [top1: %.4f] [top3: %.4f] [top5: %.4f]" % (time.time() - eval_start_time, opt.proto_emb, eval_acc[0], eval_acc[1], eval_acc[2])) 176 | model.train() 177 | 178 | if iter_sample % opt.save_step == 0: 179 | logger.info("save model into %s steps: %d" % (opt.save_ckpt, iter_sample)) 180 | torch.save(model.state_dict(), os.path.join(opt.save_ckpt, 'model_%d.bin') % iter_sample) 181 | begin_time = time.time() 182 | logger.info("\n####################\n") 183 | logger.info("Finish training") 184 | torch.save(model.state_dict(), os.path.join(opt.save_ckpt, 'model_final.bin')) -------------------------------------------------------------------------------- /fewshot_re_kit/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import json 4 | 5 | 6 | def text_to_tensor(text, tokenizer, max_seq_len): 7 | tokens = tokenizer.tokenize(text) 8 | if len(tokens) > max_seq_len - 2: 9 | tokens = tokens[: max_seq_len - 2] 10 | tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token] 11 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 12 | return torch.tensor([token_ids]) 13 | 14 | 15 | def get_emb(model, tokenizer, data, opt): 16 | id_to_emd = {} 17 | root_cause_emb = {} 18 | for root_cause, samples in data.items(): 19 | rc_emb = [] 20 | for i in range(len(samples)): 21 | abstract, description, index = samples[i] 22 | abs_id = text_to_tensor(abstract, tokenizer, opt.max_length) 23 | des_id = text_to_tensor(description, tokenizer, opt.max_length) 24 | if opt.use_cuda: 25 | abs_id = abs_id.cuda() 26 | des_id = des_id.cuda() 27 | emb = model(abs_id)[1] + model(des_id)[1] 28 | emb = emb.view(-1).detach().cpu().numpy().tolist() 29 | rc_emb.append(emb) 30 | id_to_emd[index] = [abstract, description, emb, root_cause] 31 | root_cause_emb[root_cause] = np.mean(np.array(rc_emb), axis=0).tolist() 32 | return id_to_emd, root_cause_emb 33 | 34 | 35 | def get_series_emb(emb): 36 | embeddings = [] 37 | label_to_id = {} 38 | i = 0 39 | for index, value in emb.items(): 40 | label_to_id[i] = index 41 | embeddings.append(value[2]) 42 | i += 1 43 | return embeddings, label_to_id 44 | 45 | 46 | def get_rc_emb(emb): 47 | embeddings = [] 48 | label_to_id = {} 49 | i = 0 50 | for index, value in emb.items(): 51 | label_to_id[i] = index 52 | embeddings.append(value) 53 | i += 1 54 | return embeddings, label_to_id 55 | 56 | 57 | def calculate_distance(S, Q): 58 | S = S.unsqueeze(0) # [1, N, D] 59 | Q = Q.unsqueeze(1) # [Q, 1, D] 60 | return -(torch.pow(S - Q, 2)).sum(2) 
# [Q, N] 61 | 62 | 63 | def get_similarity(id_to_emd_1, id_to_emd_2, proto=False): 64 | if not proto: 65 | emb_1, label_to_id_1 = get_series_emb(id_to_emd_1) 66 | else: 67 | emb_1, label_to_id_1 = get_rc_emb(id_to_emd_1) 68 | emb_2, label_to_id_2 = get_series_emb(id_to_emd_2) 69 | emb_1 = torch.tensor(emb_1) # train emb 70 | emb_2 = torch.tensor(emb_2) # eval emb 71 | similarity = calculate_distance(emb_1, emb_2) # [Q, N] 72 | return similarity, label_to_id_1, label_to_id_2 73 | 74 | 75 | def single_acc(id_to_emd_1, id_to_emd_2): 76 | ''' 77 | id_to_emd_1: embeddings that are searched against (the train side) 78 | id_to_emd_2: query embeddings (the eval side) 79 | ''' 80 | similarity, label_to_id_1, label_to_id_2 = get_similarity(id_to_emd_1, id_to_emd_2) 81 | 82 | acc = [] 83 | res = {} 84 | for k in [1, 3, 5, 10, 50]: 85 | tmp_search_rc = [] 86 | true_num = 0 87 | _, indices = similarity.topk(k, dim=-1) 88 | indices = indices.numpy().tolist() 89 | for j in range(len(indices)): 90 | cur_res = indices[j] 91 | 92 | # collect the root causes of the retrieved candidates 93 | tmp_rc = [] 94 | for m in cur_res: 95 | rc = id_to_emd_1[label_to_id_1[m]][-1] 96 | tmp_rc.append(rc) 97 | tmp_search_rc.append(tmp_rc) 98 | 99 | for m in cur_res: 100 | rc = id_to_emd_1[label_to_id_1[m]][-1] 101 | if rc == id_to_emd_2[label_to_id_2[j]][-1]: 102 | true_num += 1 103 | break 104 | res['top%d' % k] = tmp_search_rc 105 | total_num = len(indices) 106 | acc.append(true_num / total_num) 107 | json.dump(res, open('./data/predict/single.json', 'w', encoding='utf8')) 108 | return acc 109 | 110 | 111 | def proto_acc(id_to_emd_1, id_to_emd_2): 112 | ''' 113 | id_to_emd_1: prototype embeddings, one per root cause 114 | id_to_emd_2: query embeddings (the eval side) 115 | ''' 116 | similarity, label_to_id_1, label_to_id_2 = get_similarity(id_to_emd_1, id_to_emd_2, True) 117 | 118 | acc = [] 119 | res = {} 120 | for k in [1, 3, 5, 10]: 121 | tmp_search_rc = [] 122 | true_num = 0 123 | _, indices = similarity.topk(k, dim=-1) 124 | indices = indices.numpy().tolist() 125 | for j in range(len(indices)): 126 | cur_res = indices[j] 127 | 128 | # collect the retrieved root causes 129 | tmp_rc = [] 130 | for m in cur_res: 131 | rc = label_to_id_1[m] 132 | tmp_rc.append(rc) 133 | tmp_search_rc.append(tmp_rc) 134 | 135 | for m in cur_res: 136 | rc = label_to_id_1[m] 137 | if rc == id_to_emd_2[label_to_id_2[j]][-1]: 138 | true_num += 1 139 | break 140 | res['top%d' % k] = tmp_search_rc 141 | total_num = len(indices) 142 | acc.append(true_num / total_num) 143 | json.dump(res, open('./data/predict/proto.json', 'w', encoding='utf8')) 144 | return acc 145 | 146 | 147 | def get_topK_RC(data, K): 148 | res = [] 149 | for k, v in data.items(): 150 | res.append((k, v)) 151 | res.sort(key=lambda tt: tt[1], reverse=True) 152 | r = [_[0] for _ in res[: K]] 153 | return r 154 | 155 | 156 | def policy_acc(train_data_emb, eval_data_emb, recall_num=100): # 57 157 | similarity, label_to_id_1, label_to_id_2 = get_similarity(train_data_emb, eval_data_emb) 158 | # group the retrieved candidates by root cause; recall_num has been tried at 30/15/10 (classes with very few samples are sensitive to this setting) 159 | 160 | acc = [] 161 | res = {} # predicted root cause results 162 | for x in [1, 3, 5, 10, 50]: 163 | tmp_search_rc = [] 164 | true_num = 0 165 | _, indices = similarity.topk(similarity.shape[-1], dim=-1) 166 | 167 | indices = indices.numpy().tolist() # [Q, N] 168 | for j in range(len(indices)): 169 | cur_res = indices[j][:recall_num] 170 | root_cause_score = {} 171 | for m in cur_res: 172 | similar_data = train_data_emb[label_to_id_1[m]] 173 | rc = similar_data[-1] 174 | distance = similarity[j][m].item() 175 | distance = 1 / (-distance + 1e-5) # turn the negative squared distance into a positive similarity weight 176 | if rc not in root_cause_score: 177 | root_cause_score[rc] = distance 178 | else: 179 | 
root_cause_score[rc] += distance 180 | predict_cause = get_topK_RC(root_cause_score, x) 181 | tmp_search_rc.append(predict_cause) 182 | cur_case = eval_data_emb[label_to_id_2[j]] 183 | cur_cause = cur_case[-1] 184 | if cur_cause in predict_cause: 185 | true_num += 1 186 | total_num = len(indices) 187 | acc.append(true_num / total_num) 188 | res['top%d' % x] = tmp_search_rc 189 | json.dump(res, open('./data/predict/policy.json', 'w', encoding='utf8')) 190 | return acc 191 | 192 | 193 | def vote_acc(t1, e1, t2, e2, proto=False, policy=False, recall_num=60): 194 | similarity_1, label_to_id_1, label_to_id_2 = get_similarity(t1, e1, proto) 195 | similarity_2, _, _ = get_similarity(t2, e2, proto) 196 | similarity = similarity_1 + similarity_2 197 | 198 | acc = [] 199 | if not policy: 200 | for k in [1, 3, 5, 10, 50]: 201 | true_num = 0 202 | _, indices = similarity.topk(k, dim=-1) 203 | indices = indices.numpy().tolist() 204 | for j in range(len(indices)): 205 | cur_res = indices[j] 206 | for m in cur_res: 207 | rc = t1[label_to_id_1[m]][-1] 208 | if rc == e1[label_to_id_2[j]][-1]: 209 | true_num += 1 210 | break 211 | total_num = len(indices) 212 | acc.append(true_num / total_num) 213 | else: 214 | for x in [1, 3, 5, 10, 50]: 215 | true_num = 0 216 | _, indices = similarity.topk(similarity.shape[-1], dim=-1) 217 | 218 | indices = indices.numpy().tolist() # [Q, N] 219 | for j in range(len(indices)): 220 | cur_res = indices[j][:recall_num] 221 | root_cause_score = {} 222 | for m in cur_res: 223 | similar_data = t1[label_to_id_1[m]] 224 | rc = similar_data[-1] 225 | distance = similarity[j][m].item() 226 | distance = 1 / (-distance + 1e-5) 227 | if rc not in root_cause_score: 228 | root_cause_score[rc] = distance 229 | else: 230 | root_cause_score[rc] += distance 231 | predict_cause = get_topK_RC(root_cause_score, x) 232 | cur_case = e1[label_to_id_2[j]] 233 | cur_cause = cur_case[-1] 234 | if cur_cause in predict_cause: 235 | true_num += 1 236 | total_num = len(indices) 237 | acc.append(true_num / total_num) 238 | 239 | return acc 240 | 241 | 242 | 243 | if __name__ == "__main__": 244 | train_data_emb = json.load(open('./data/train_emb.json', 'r', encoding='utf8')) 245 | eval_data_emb = json.load(open('./data/eval_emb.json', 'r', encoding='utf8')) 246 | train_rc_emb = json.load(open('./data/train_rc_emb.json', 'r', encoding='utf8')) 247 | 248 | acc1 = single_acc(train_data_emb, eval_data_emb) 249 | print(acc1) 250 | acc2 = proto_acc(train_rc_emb, eval_data_emb) 251 | print(acc2) 252 | acc3 = policy_acc(train_data_emb, eval_data_emb, 57) 253 | print(acc3) 254 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import proto -------------------------------------------------------------------------------- /models/proto.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import fewshot_re_kit 4 | 5 | 6 | class Proto(fewshot_re_kit.framework.FewShotREModel): 7 | def __init__(self, sentence_encoder, config): 8 | fewshot_re_kit.framework.FewShotREModel.__init__(self, sentence_encoder) 9 | self.drop = nn.Dropout(config.dropout) 10 | 11 | def __dist__(self, x, y, dim): 12 | return -(torch.pow(x - y, 2)).sum(dim) # negative squared Euclidean distance 13 | 14 | def __batch_dist__(self, S, Q): 15 | return self.__dist__(S.unsqueeze(1), Q.unsqueeze(2), 3) 16 | 17 | def forward(self, batch, N, K, total_Q): 18 | ''' 19 | batch: Padded support-set and query-set inputs (see collate_fn) 20 | N: Num of classes 21 | K: Num of instances for each class in the support set 22 | total_Q: Total num of query instances in the batch (N * Q) 23 | return: logits, pred 24 | ''' 25 | support_abs_emb = self.sentence_encoder(batch[0], batch[3], batch[1], batch[2])[1] # (B * N * K, D), where D is the hidden size 26 | support_des_emb = self.sentence_encoder(batch[4], batch[7], batch[5], batch[6])[1] 27 | support_emb = support_abs_emb + support_des_emb 28 | 29 | query_abs_emb = self.sentence_encoder(batch[8], batch[11], batch[9], batch[10])[1] 30 | query_des_emb = self.sentence_encoder(batch[12], batch[15], batch[13], batch[14])[1] 31 | query_emb = query_abs_emb + query_des_emb # (B * total_Q, D) 32 | 33 | hidden_size = support_emb.size(-1) 34 | support = self.drop(support_emb) 35 | query = self.drop(query_emb) 36 | support = support.view(-1, N, K, hidden_size) # (B, N, K, D) 37 | query = query.view(-1, total_Q, hidden_size) # (B, total_Q, D) 38 | 39 | B = support.size(0) # Batch size 40 | 41 | # Prototypical Networks 42 | support = torch.mean(support, 2) # Calculate prototype for each class 43 | logits = self.__batch_dist__(support, query) # (B, total_Q, N) 44 | _, pred = torch.max(logits.view(-1, N), 1) 45 | return logits, pred -------------------------------------------------------------------------------- /run_main.py: -------------------------------------------------------------------------------- 1 | from fewshot_re_kit.data_loader import get_loader, read_data 2 | from fewshot_re_kit.framework import FewShotREFramework 3 | from fewshot_re_kit import util 4 | from models.proto import Proto 5 | import torch 6 | import numpy as np 7 | import argparse 8 | from transformers import BertModel, BertConfig, BertTokenizer 9 | import logging 10 | import random 11 | import json 12 | import os 13 | # import test 14 | 15 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 16 | datefmt='%m/%d/%Y %H:%M:%S', 17 | level=logging.INFO) 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def main(): 22 | parser = argparse.ArgumentParser() 23 | 24 | # model-related arguments (note: argparse's type=bool treats any non-empty string as True, so toggle these via their defaults) 25 | parser.add_argument('--do_train', default=True, type=bool, help='do train') 26 | parser.add_argument('--do_eval', default=True, type=bool, help='do eval') 27 | parser.add_argument('--do_predict', default=False, type=bool, help='do predict') 28 | parser.add_argument('--do_cn_eval', default=False, type=bool, help='do CN eval') 29 | 30 | parser.add_argument('--proto_emb', default=False, help='Get root cause proto emb or sentence emb. 
Require do_predict=True') 31 | parser.add_argument('--train_file', default='./data/source_add_CN_V2.xlsx', help='source file') 32 | 33 | parser.add_argument('--trainN', default=5, type=int, help='N in train') 34 | parser.add_argument('--N', default=5, type=int, help='N way') 35 | parser.add_argument('--K', default=3, type=int, help='K shot') 36 | parser.add_argument('--Q', default=2, type=int, help='Num of query per class') 37 | parser.add_argument('--batch_size', default=8, type=int, help='batch size') 38 | parser.add_argument('--train_iter', default=10000, type=int, help='num of iters in training') 39 | parser.add_argument('--warmup_rate', default=0.1, type=float) 40 | parser.add_argument('--max_length', default=128, type=int, help='max length') 41 | parser.add_argument('--lr', default=1e-5, type=float, help='learning rate') 42 | parser.add_argument('--dropout', default=0.0, type=float, help='dropout rate') 43 | parser.add_argument('--seed', default=100, type=int) # 100 44 | 45 | # saving and loading 46 | parser.add_argument('--load_ckpt', default='./check_points/model_54000.bin', help='load ckpt') 47 | parser.add_argument('--save_ckpt', default='./check_points/', help='save ckpt') 48 | parser.add_argument('--save_emb', default='./data/emb.json', help='save sentence embeddings') 49 | parser.add_argument('--save_root_emb', default='./data/root_emb.json', help='save root cause prototype embeddings') 50 | 51 | parser.add_argument('--use_cuda', default=True, help='whether to use cuda') 52 | parser.add_argument('--eval_step', default=100, type=int) 53 | parser.add_argument('--save_step', default=500, type=int) 54 | parser.add_argument('--threshold', default=5, type=int) 55 | 56 | # bert pretrain 57 | parser.add_argument("--vocab_file", default="./pretrain/vocab.txt", type=str, help="Init vocab to resume training from.") 58 | parser.add_argument("--config_path", default="./pretrain/bert_config.json", type=str, help="Init config to resume training from.") 59 | parser.add_argument("--init_checkpoint", default="./pretrain/pytorch_model.bin", type=str, help="Init checkpoint to resume training from.") 60 | 61 | opt = parser.parse_args() 62 | trainN = opt.trainN 63 | K = opt.K 64 | Q = opt.Q 65 | batch_size = opt.batch_size 66 | max_length = opt.max_length 67 | 68 | logger.info("{}-way-{}-shot Few-Shot Diagnosis".format(trainN, K)) 69 | logger.info("max_length: {}".format(max_length)) 70 | 71 | random.seed(opt.seed) 72 | np.random.seed(opt.seed) 73 | torch.manual_seed(opt.seed) 74 | 75 | if not os.path.exists(opt.save_ckpt): 76 | os.mkdir(opt.save_ckpt) 77 | 78 | bert_tokenizer = BertTokenizer.from_pretrained(opt.vocab_file) 79 | bert_config = BertConfig.from_pretrained(opt.config_path) 80 | bert_model = BertModel.from_pretrained(opt.init_checkpoint, config=bert_config) 81 | model = Proto(bert_model, opt) 82 | if opt.use_cuda: 83 | model.cuda() 84 | 85 | if opt.do_train: 86 | train_data, eval_data = read_data(opt.train_file, opt.threshold) 87 | train_data_loader = get_loader(train_data, bert_tokenizer, max_length, N=trainN, K=K, Q=Q, batch_size=batch_size) 88 | 89 | framework = FewShotREFramework(tokenizer=bert_tokenizer, 90 | train_data_loader=train_data_loader, 91 | train_data=train_data, 92 | eval_data=eval_data) 93 | framework.train(model, batch_size, trainN, K, Q, opt) 94 | 95 | if opt.do_eval: 96 | train_data, eval_data = read_data(opt.train_file, opt.threshold) 97 | state_dict = torch.load(opt.load_ckpt) 98 | own_state = bert_model.state_dict() 99 | for name, param in state_dict.items(): 100 | name = name.replace('sentence_encoder.module.', '') 101 | if name not in 
own_state: 102 | continue 103 | own_state[name].copy_(param) 104 | step = opt.load_ckpt.split('/')[-1].replace('model_', '').split('.')[0] 105 | bert_model.eval() 106 | train_data_emb, train_rc_emb = util.get_emb(bert_model, bert_tokenizer, train_data, opt) 107 | eval_data_emb, _ = util.get_emb(bert_model, bert_tokenizer, eval_data, opt) 108 | acc1 = util.single_acc(train_data_emb, eval_data_emb) 109 | acc2 = util.proto_acc(train_rc_emb, eval_data_emb) 110 | acc3 = util.policy_acc(train_data_emb, eval_data_emb) 111 | logger.info("single eval accuracy: [top1: %.4f] [top3: %.4f] [top5: %.4f]" % (acc1[0], acc1[1], acc1[2])) 112 | logger.info("proto eval accuracy: [top1: %.4f] [top3: %.4f] [top5: %.4f]" % (acc2[0], acc2[1], acc2[2])) 113 | logger.info("policy eval accuracy: [top1: %.4f] [top3: %.4f] [top5: %.4f]" % (acc3[0], acc3[1], acc3[2])) 114 | 115 | with open('./data/train_emb_%s.json' % step, 'w', encoding='utf8') as f: 116 | json.dump(train_data_emb, f, ensure_ascii=False) 117 | 118 | with open('./data/train_rc_emb_%s.json' % step, 'w', encoding='utf8') as f: 119 | json.dump(train_rc_emb, f, ensure_ascii=False) 120 | 121 | with open('./data/eval_emb_%s.json' % step, 'w', encoding='utf8') as f: 122 | json.dump(eval_data_emb, f, ensure_ascii=False) 123 | 124 | if opt.do_predict: 125 | test_data = read_data(opt.train_file, opt.threshold, False) 126 | # predict proto emb or sentence emb 127 | state_dict = torch.load(opt.load_ckpt) 128 | own_state = bert_model.state_dict() 129 | for name, param in state_dict.items(): 130 | name = name.replace('sentence_encoder.module.', '') 131 | if name not in own_state: 132 | continue 133 | own_state[name].copy_(param) 134 | bert_model.eval() 135 | id_to_emd, root_cause_emb = util.get_emb(bert_model, bert_tokenizer, test_data, opt) 136 | 137 | if opt.save_emb and opt.save_root_emb: 138 | with open(opt.save_emb, 'w', encoding='utf8') as f: 139 | json.dump(id_to_emd, f, ensure_ascii=False) 140 | 141 | with open(opt.save_root_emb, 'w', encoding='utf8') as f: 142 | json.dump(root_cause_emb, f, ensure_ascii=False) 143 | 144 | 145 | if __name__ == "__main__": 146 | main() 147 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from fewshot_re_kit.data_loader import read_data 2 | import torch 3 | import json 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import logging 8 | 9 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 10 | datefmt='%m/%d/%Y %H:%M:%S', 11 | level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def text_to_tensor(text, bert_tokenizer, max_seq_len): 16 | tokens = bert_tokenizer.tokenize(text) 17 | if len(tokens) > max_seq_len - 2: 18 | tokens = tokens[: max_seq_len - 2] 19 | tokens = [bert_tokenizer.cls_token] + tokens + [bert_tokenizer.sep_token] 20 | token_ids = bert_tokenizer.convert_tokens_to_ids(tokens) 21 | return torch.tensor([token_ids]) 22 | 23 | 24 | def predict(opt, bert_model, bert_tokenizer): 25 | if not os.path.exists(opt.save_emb): 26 | state_dict = torch.load(opt.load_ckpt) 27 | own_state = bert_model.state_dict() 28 | for name, param in state_dict.items(): 29 | name = name.replace('sentence_encoder.module.', '') 30 | if name not in own_state: 31 | continue 32 | own_state[name].copy_(param) 33 | bert_model.eval() 34 | root_dict = read_data(opt.train_file, opt.threshold, False) # is_train=False: we need the flat {root cause: samples} dict, not the train/eval split 35 | id_to_emd = {} 36 | root_cause_emb = {} 37 | for 
root_cause, samples in tqdm(root_dict.items()): 38 | rc_emb = [] 39 | for i in range(len(samples)): 40 | abstract, description, index = samples[i] 41 | abs_id = text_to_tensor(abstract, bert_tokenizer, opt.max_length) 42 | des_id = text_to_tensor(description, bert_tokenizer, opt.max_length) 43 | if opt.use_cuda: 44 | abs_id = abs_id.cuda() 45 | des_id = des_id.cuda() 46 | emb = bert_model(abs_id)[1] + bert_model(des_id)[1] 47 | emb = emb.view(-1).detach().cpu().numpy().tolist() 48 | rc_emb.append(emb) 49 | id_to_emd[index] = [abstract, description, emb, root_cause] 50 | root_cause_emb[root_cause] = np.mean(np.array(rc_emb), axis=0).tolist() 51 | json.dump(id_to_emd, open(opt.save_emb, 'w', encoding='utf8'), ensure_ascii=False) 52 | json.dump(root_cause_emb, open(opt.save_root_emb, 'w', encoding='utf8'), ensure_ascii=False) 53 | else: 54 | id_to_emd = json.load(open(opt.save_emb, 'r', encoding='utf8')) 55 | root_cause_emb = json.load(open(opt.save_root_emb, 'r', encoding='utf8')) # load the prototype embeddings, not the per-sample file 56 | return id_to_emd, root_cause_emb 57 | 58 | 59 | def calculate_distance(S, Q): 60 | S = S.unsqueeze(0) # [1, N, D] 61 | Q = Q.unsqueeze(1) # [Q, 1, D] 62 | return -(torch.pow(S - Q, 2)).sum(2) # [Q, N] negative squared Euclidean distance 63 | 64 | 65 | def test_acc(id_to_emd=None): 66 | if id_to_emd is None: 67 | id_to_emd = json.load(open('./data/emb.json', 'r', encoding='utf8')) 68 | embeddings = [] 69 | label_to_id = {} 70 | i = 0 71 | for index, value in id_to_emd.items(): 72 | label_to_id[i] = index 73 | embeddings.append(value[2]) 74 | i += 1 75 | embeddings = torch.tensor(embeddings) # [N, 768] 76 | similarity = calculate_distance(embeddings, embeddings) # [N, N] 77 | mask = torch.triu(torch.ones_like(similarity), diagonal=1) 78 | mask = mask + mask.transpose(1, 0) 79 | mask = (1 - mask) * -10000 # large negative value on the diagonal so a sample cannot retrieve itself 80 | similarity = similarity + mask 81 | 82 | for k in [1, 3, 5]: 83 | true_num = 0 84 | _, indices = similarity.topk(k, dim=-1) 85 | indices = indices.numpy().tolist() 86 | for j in range(len(indices)): 87 | cur_res = indices[j] 88 | for m in cur_res: 89 | rc = id_to_emd[label_to_id[m]][-1] 90 | if rc == id_to_emd[label_to_id[j]][-1]: 91 | true_num += 1 92 | break 93 | total_num = len(indices) 94 | acc = true_num / total_num 95 | logger.info('top %d accuracy: %.4f' % (k, acc)) 96 | logger.info('total samples: %d' % total_num) 97 | logger.info('correct predictions: %d' % true_num) 98 | 99 | 100 | def classification_test(id_to_emd=None, root_cause_emb=None): 101 | if id_to_emd is None and root_cause_emb is None: 102 | id_to_emd = json.load(open('./data/emb.json', 'r', encoding='utf8')) 103 | root_cause_emb = json.load(open('./data/root_emb.json', 'r', encoding='utf8')) 104 | pass 105 | 106 | 107 | if __name__ == "__main__": 108 | test_acc() 109 | print('OK') --------------------------------------------------------------------------------
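For intuition, the episode computation in models/proto.py boils down to the short standalone sketch below. The tensors are random stand-ins for the summed abstract+description BERT embeddings, and the sizes B, N, K, Q, D are made-up illustrative values, not values taken from the repository:

import torch

B, N, K, Q, D = 2, 5, 3, 2, 768  # episodes per batch, N-way, K-shot, Q queries per class, hidden size
support = torch.randn(B, N * K, D)  # stand-in for the support-set sentence embeddings
query = torch.randn(B, N * Q, D)    # stand-in for the query-set sentence embeddings

# class prototype = mean of its K support embeddings, as in torch.mean(support, 2)
prototypes = support.view(B, N, K, D).mean(dim=2)                         # (B, N, D)
# negative squared Euclidean distance between every query and every prototype (__batch_dist__)
logits = -((prototypes.unsqueeze(1) - query.unsqueeze(2)) ** 2).sum(-1)   # (B, N*Q, N)
pred = logits.view(-1, N).argmax(dim=1)                                   # predicted class index per query
print(logits.shape, pred.shape)  # torch.Size([2, 10, 5]) torch.Size([20])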