├── models
│   ├── __init__.py
│   ├── README.md
│   ├── selector_one.py
│   ├── seq2seq_model_t5_dismatch.py
│   ├── seq2seq_model_t5.py
│   ├── seq2seq_model.py
│   ├── seq2seq_model_dismatch.py
│   └── predictor_v5.py
├── utils
│   ├── README.md
│   ├── __init__.py
│   ├── SimCLS.py
│   └── snippets.py
├── data_utils
│   ├── README.md
│   ├── __init__.py
│   ├── data_process_fxy.py
│   ├── search_exp.py
│   ├── cat_baseline_exp.py
│   ├── Thematic_Similarity.py
│   ├── seq2seq_convert_NILE.py
│   ├── generate_faithful.py
│   ├── cat_predictor.py
│   ├── predictor_convert.py
│   ├── predictor_convert_t5.py
│   ├── analyse_data.py
│   ├── extract_data_process.py
│   ├── baseline_dataset.py
│   ├── seq2seq_convert_xLIRE.py
│   ├── seq2seq_convert_cail.py
│   ├── splite_to_ner.py
│   └── seq2seq_convert.py
├── dataset
│   └── README.md
├── requirements.txt
├── LICENSE
└── README.md

/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/utils/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data_utils/README.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from bert_seq2seq.model.gpt2_model import GPT2LMHeadModel, GPT2Config
--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 | # import ir_datasets
2 | # dataset = ir_datasets.load("trec-robust04/fold1")
--------------------------------------------------------------------------------
/dataset/README.md:
--------------------------------------------------------------------------------
1 | eCAIL dataset:
2 | https://drive.google.com/file/d/1ixjnkpGvM8RL7arxFDrCMiVWzJtifQYv/view?usp=sharing
3 | The ELAM dataset is part of the CAIL 2022 competition.
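The extracted data files consumed by the scripts in `data_utils/` (e.g. `data_extract.json`) are JSON-lines files whose records carry `case_A`, `case_B`, `label`, `explanation`, and `relation_label` fields (see `data_utils/extract_data_process.py`). A minimal loader sketch, mirroring `load_data` in `data_utils/analyse_data.py`; the `load_jsonl` helper name and the path below are illustrative:

```python
import json

def load_jsonl(path):
    """Read one JSON object per line, as written by extract_data_process.py."""
    with open(path, encoding='utf-8') as f:
        return [json.loads(line) for line in f]

# Placeholder path; point this at the extracted dataset file.
records = load_jsonl('dataset/our_data/data_extract.json')
example = records[0]
print(example['label'])           # matching label: 0 / 1 / 2
print(len(example['case_A'][0]))  # number of sentences in case A
```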
4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.1+cu111 2 | transformers>=4.20.1 3 | numpy>=1.20.1 4 | jieba>=0.42.1 5 | six>=1.15.0 6 | rouge>=1.0.1 7 | tqdm>=4.62.3 8 | scikit-learn>=1.0.1 9 | pandas>=1.2.4 10 | nni>=2.6.1 11 | matplotlib>=3.3.4 12 | termcolor>=1.1.0 13 | networkx>=2.5 14 | requests>=2.25.1 15 | filelock>=3.0.12 16 | textrank4zh>=0.3 17 | gensim>=3.8.3 18 | openprompt>=1.0 19 | scipy>=1.8.0 20 | seaborn>=0.11.1 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2022] [ELAM dataset] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /data_utils/data_process_fxy.py: -------------------------------------------------------------------------------- 1 | import json 2 | all_data = [] 3 | with open("/home/zhongxiang_sun/code/explanation_project/explanation_model/dataset/fxy/all_clean.json", "r") as f: 4 | for line in f: 5 | item = json.loads(line) 6 | temp_dic = {"case_A":[item["splited_sentence"]]} 7 | data_list = [] 8 | aleady_data = [] 9 | for span_key in item["spans"].keys(): 10 | if span_key[5:8] == "001": 11 | label = 1 12 | if span_key[5:7] == "02": 13 | label = 2 14 | if span_key[5:8] == "003": 15 | label = 3 16 | for span in item["spans"][span_key]: 17 | for sent in range(span[0], span[1]+1): 18 | data_list.append([sent, label]) 19 | aleady_data.append(sent) 20 | for i in range(len(item["splited_sentence"])): 21 | if i in aleady_data: 22 | pass 23 | else: 24 | data_list.append([i, 0]) 25 | temp_dic["case_A"].append(data_list) 26 | all_data.append(temp_dic) 27 | with open("/home/zhongxiang_sun/code/explanation_project/explanation_model/dataset/fxy/fxy.json",'w') as f: 28 | for line in all_data: 29 | f.writelines(json.dumps(line, ensure_ascii=False)) 30 | f.write('\n') 31 | 32 | 33 | -------------------------------------------------------------------------------- /data_utils/search_exp.py: -------------------------------------------------------------------------------- 1 | import json 2 | our_data_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/heat_map/temp.json" 3 | xLIRE_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_prediction_t5_xLIRE.json" 4 | wo_token_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_prediction_t5_wo_token.json" 5 | NILE_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_prediction_t5_NILE.json" 6 | 7 | with open(our_data_file, 'r') as f: 8 | for line in f: 9 | our_data = json.loads(line) 10 | 11 | data_NILE = [] 12 | with open(NILE_file, 'r') as f: 13 | for line in f: 14 | data_NILE.append(json.loads(line)) 15 | 16 | data_xLIRE = [] 17 | with open(xLIRE_file, 'r') as f: 18 | for line in f: 19 | data_xLIRE.append(json.loads(line)) 20 | 21 | data_wo_token = [] 22 | with open(wo_token_file, 'r') as f: 23 | for line in f: 24 | data_wo_token.append(json.loads(line)) 25 | 26 | 27 | print("NILE ", data_NILE[2985]['exp'][1]) 28 | 29 | print("xLIRE ", data_xLIRE[2985]['exp'][1]) 30 | 31 | print("wo_token ", data_wo_token[2985]['exp'][1]) 32 | 33 | print("golden exp", data_wo_token[2985]["explanation"]) 34 | # for i, data in enumerate(data_NILE): 35 | # if data["source_2_dis"][0] == "".join(our_data["case_A"][0]) and data["source_2_dis"][1] == "".join(our_data["case_B"][0]): 36 | # print(i) 37 | 38 | 39 | 40 | 41 | -------------------------------------------------------------------------------- /data_utils/cat_baseline_exp.py: -------------------------------------------------------------------------------- 1 | import json 2 | exp_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_prediction_t5_wo_token.json" 3 | wo_rationale = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/baselines_datasets/ELAM_bert_legal_wo_rationale.json" 4 | rationale = 
"/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/baselines_datasets/ELAM_bert_legal_rationale.json" 5 | all_sents = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/baselines_datasets/ELAM_bert_legal_all_sents.json" 6 | 7 | save_wo_rationale = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/stage3/data_prediction_wo_rationale.json" 8 | save_rationale = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/stage3/data_prediction_rationale.json" 9 | save_all_sents = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/stage3/data_prediction_all_sents.json" 10 | 11 | 12 | def cat_file(exp_files, second_file, save_file): 13 | match_data = [] 14 | midmatch_data = [] 15 | dismatch_data = [] 16 | exp_data = [] 17 | with open(second_file, 'r') as f: 18 | for line in f: 19 | item = json.loads(line) 20 | if item['label'] == 2: 21 | match_data.append(item) 22 | elif item['label'] == 1: 23 | midmatch_data.append(item) 24 | elif item['label'] == 0: 25 | dismatch_data.append(item) 26 | else: 27 | exit() 28 | datas = match_data + midmatch_data + dismatch_data 29 | with open(exp_files, 'r') as f: 30 | for i, line in enumerate(f): 31 | item = json.loads(line) 32 | item['case_a'] = datas[i]['case_a'] 33 | item['case_b'] = datas[i]['case_b'] 34 | exp_data.append(item) 35 | 36 | with open(save_file, 'w') as f: 37 | for item in exp_data: 38 | f.writelines(json.dumps(item, ensure_ascii=False)) 39 | f.write('\n') 40 | 41 | cat_file(exp_file, wo_rationale, save_wo_rationale) 42 | cat_file(exp_file, rationale, save_rationale) 43 | cat_file(exp_file, all_sents, save_all_sents) 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /data_utils/Thematic_Similarity.py: -------------------------------------------------------------------------------- 1 | import json 2 | def process_cail_data(file_path, save_path): 3 | datas = [] 4 | with open(file_path, 'r') as f: 5 | for line in f: 6 | item = json.loads(line) 7 | for k in item["case_A_dic_small"].keys(): 8 | sents = "" 9 | for sent in item["case_A_dic_small"][k]: 10 | sents += sent[0] 11 | item["case_A_dic_small"][k] = sents 12 | for k in item["case_B_dic_small"].keys(): 13 | sents = "" 14 | for sent in item["case_B_dic_small"][k]: 15 | sents += sent[0] 16 | item["case_B_dic_small"][k] = sents 17 | datas.append({"case_A":item["case_A_dic_small"], "case_B":item["case_B_dic_small"], "label":item["label"]}) 18 | with open(save_path, 'w') as f: 19 | for l in datas: 20 | f.writelines(json.dumps(l, ensure_ascii=False)) 21 | f.write('\n') 22 | 23 | 24 | def process_ELAM_data(file_path, save_path): 25 | datas = [] 26 | final_data = [] 27 | for file in file_path: 28 | with open(file, 'r') as f: 29 | datas += json.load(f) 30 | 31 | for item in datas: 32 | final_data.append({"case_A": {item["case_A"][0]["tag"]:item["case_A"][0]["content"], item["case_A"][1]["tag"]:item["case_A"][1]["content"]}, 33 | "case_B": {item["case_B"][0]["tag"]:item["case_B"][0]["content"], item["case_B"][1]["tag"]:item["case_B"][1]["content"]}, 34 | "label":item["gold_label"]}) 35 | with open(save_path, 'w') as f: 36 | for l in final_data: 37 | f.writelines(json.dumps(l, ensure_ascii=False)) 38 | f.write('\n') 39 | 40 | if __name__ == '__main__': 41 | cail_file_path = 
"/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/baselines_datasets/process_data.json" 42 | our_data_train = "/new_disk2/zhongxiang_sun/code/explanation_project/NER/data/law_train.json" 43 | our_data_test = "/new_disk2/zhongxiang_sun/code/explanation_project/NER/data/law_test.json" 44 | our_data_dev = "/new_disk2/zhongxiang_sun/code/explanation_project/NER/data/law_dev.json" 45 | 46 | ELAM_files = [our_data_train, our_data_dev, our_data_test] 47 | process_cail_data(cail_file_path, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/baselines_datasets/CAIL_bert_legal_Thematic_Similarity.json") 48 | process_ELAM_data(ELAM_files, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/baselines_datasets/ELAM_bert_legal_Thematic_Similarity.json") 49 | -------------------------------------------------------------------------------- /utils/SimCLS.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/maszhongming/MatchSum 2 | import torch 3 | from torch import nn 4 | from transformers import RobertaModel 5 | 6 | 7 | def RankingLoss(score, summary_score=None, margin=0, gold_margin=0, gold_weight=1, no_gold=False, no_cand=False): 8 | ones = torch.ones_like(score) 9 | loss_func = torch.nn.MarginRankingLoss(0.0) 10 | TotalLoss = loss_func(score, score, ones) 11 | # candidate loss 12 | n = score.size(1) 13 | if not no_cand: 14 | for i in range(1, n): 15 | pos_score = score[:, :-i] 16 | neg_score = score[:, i:] 17 | pos_score = pos_score.contiguous().view(-1) 18 | neg_score = neg_score.contiguous().view(-1) 19 | ones = torch.ones_like(pos_score) 20 | loss_func = torch.nn.MarginRankingLoss(margin * i) 21 | loss = loss_func(pos_score, neg_score, ones) 22 | TotalLoss += loss 23 | if no_gold: 24 | return TotalLoss 25 | # gold summary loss 26 | pos_score = summary_score.unsqueeze(-1).expand_as(score) 27 | neg_score = score 28 | pos_score = pos_score.contiguous().view(-1) 29 | neg_score = neg_score.contiguous().view(-1) 30 | ones = torch.ones_like(pos_score) 31 | loss_func = torch.nn.MarginRankingLoss(gold_margin) 32 | TotalLoss += gold_weight * loss_func(pos_score, neg_score, ones) 33 | return TotalLoss 34 | 35 | 36 | class ReRanker(nn.Module): 37 | def __init__(self, encoder, pad_token_id): 38 | super(ReRanker, self).__init__() 39 | self.encoder = RobertaModel.from_pretrained(encoder) 40 | self.pad_token_id = pad_token_id 41 | 42 | def forward(self, text_id, candidate_id, summary_id=None, require_gold=True): 43 | 44 | batch_size = text_id.size(0) 45 | 46 | input_mask = text_id != self.pad_token_id 47 | out = self.encoder(text_id, attention_mask=input_mask)[0] 48 | doc_emb = out[:, 0, :] 49 | 50 | if require_gold: 51 | # get reference score 52 | input_mask = summary_id != self.pad_token_id 53 | out = self.encoder(summary_id, attention_mask=input_mask)[0] 54 | summary_emb = out[:, 0, :] 55 | summary_score = torch.cosine_similarity(summary_emb, doc_emb, dim=-1) 56 | 57 | candidate_num = candidate_id.size(1) 58 | candidate_id = candidate_id.view(-1, candidate_id.size(-1)) 59 | input_mask = candidate_id != self.pad_token_id 60 | out = self.encoder(candidate_id, attention_mask=input_mask)[0] 61 | candidate_emb = out[:, 0, :].view(batch_size, candidate_num, -1) 62 | 63 | # get candidate score 64 | doc_emb = doc_emb.unsqueeze(1).expand_as(candidate_emb) 65 | score = torch.cosine_similarity(candidate_emb, doc_emb, dim=-1) 66 | 67 | output = {'score': score} 68 | if 
require_gold: 69 | output['summary_score'] = summary_score 70 | return output -------------------------------------------------------------------------------- /data_utils/seq2seq_convert_NILE.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | from models.selector_two_multi_class_ot_v3 import Selector2_mul_class, args, load_checkpoint, OT, load_data, data_extract_npy, data_extract_json, device 4 | import torch 5 | from utils.snippets import * 6 | import torch 7 | import json 8 | 9 | 10 | 11 | def fold_convert_our_data_ot(data, data_x, type, generate=False, generate_mode = 'cluster'): 12 | """每一fold用对应的模型做数据转换 13 | """ 14 | max_len = 0 15 | with torch.no_grad(): 16 | results = [] 17 | print(type+"ing") 18 | for i, d in enumerate(data): 19 | if type == 'match' and d["label"] == 2 or type == 'midmatch' and d["label"] == 1 or type == 'dismatch' and d["label"] == 0: 20 | case_a = d['case_A'] 21 | case_b = d['case_B'] 22 | source_1_a = ''.join(case_a[0]) 23 | source_1_b = ''.join(case_b[0]) 24 | source_2_a = ''.join(case_a[0]) 25 | source_2_b = ''.join(case_b[0]) 26 | max_len = max(max_len, len(source_1_a+source_1_b)) 27 | max_len = max(max_len, len(source_2_a+source_2_b)) 28 | 29 | # result = { 30 | # 'source_1': source_1_a + source_1_b, 31 | # 'source_2': source_2_a + source_2_b, 32 | # 'explanation': ';'.join(list(d['explanation'].values())), 33 | # 'source_1_dis': [source_1_a, source_1_b], 34 | # 'source_2_dis': [source_2_a, source_2_b], 35 | # 'label': d['label'] 36 | # } 37 | result = { 38 | 'source_1': source_1_a + source_1_b, 39 | 'source_2': source_2_a + source_2_b, 40 | 'explanation': d['explanation'], 41 | 'source_1_dis': [source_1_a, source_1_b], 42 | 'source_2_dis': [source_2_a, source_2_b], 43 | 'label': d['label'] 44 | } 45 | results.append(result) 46 | print(max_len) 47 | if generate: 48 | return results 49 | 50 | 51 | 52 | 53 | def convert(filename, data, data_x, type, generate_mode): 54 | """转换为生成式数据 55 | """ 56 | total_results = fold_convert_our_data_ot(data, data_x, type, generate=True, generate_mode=generate_mode) 57 | 58 | with open(filename, 'w') as f: 59 | for item in total_results: 60 | f.writelines(json.dumps(item, ensure_ascii=False)) 61 | f.write('\n') 62 | 63 | 64 | 65 | if __name__ == '__main__': 66 | data_extract_json = '../dataset/data_extract.json' 67 | data_extract_npy = '../dataset/data_extract.npy' 68 | data = load_data(data_extract_json) 69 | data_x = np.load(data_extract_npy) 70 | da_type = "sort" 71 | match_data_seq2seq_json = '../dataset/match_data_seq2seq_NILE.json' 72 | midmatch_data_seq2seq_json = '../dataset/midmatch_data_seq2seq_NILE.json' 73 | dismatch_data_seq2seq_json = '../dataset/dismatch_data_seq2seq_NILE.json' 74 | convert(match_data_seq2seq_json, data, data_x, type='match', generate_mode=da_type) 75 | convert(midmatch_data_seq2seq_json, data, data_x, type='midmatch', generate_mode=da_type) 76 | convert(dismatch_data_seq2seq_json, data, data_x, type='dismatch', generate_mode=da_type) 77 | 78 | 79 | print(u'输出over!') 80 | -------------------------------------------------------------------------------- /data_utils/generate_faithful.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | ELAM_xLIRE_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_prediction_t5_xLIRE.json" 5 | ELAM_wo_token_file = 
"/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_prediction_t5_wo_token.json" 6 | ELAM_NILE_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_prediction_t5_NILE.json" 7 | 8 | ELAM_xLIRE_file_save = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/ELAM_sample_data_prediction_t5_xLIRE.json" 9 | ELAM_wo_token_file_save = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/ELAM_sample_data_prediction_t5_wo_token.json" 10 | ELAM_NILE_file_save = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/ELAM_sample_data_prediction_t5_NILE.json" 11 | 12 | 13 | 14 | CAIL_xLIRE_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/data_prediction_t5_xLIRE.json" 15 | CAIL_wo_token_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/data_prediction_t5_wo_token.json" 16 | CAIL_NILE_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/data_prediction_t5_NILE.json" 17 | CAIL_xLIRE_file_save = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/CAIL_sample_data_prediction_t5_xLIRE.json" 18 | CAIL_wo_token_file_save = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/CAIL_sample_data_prediction_t5_wo_token.json" 19 | CAIL_NILE_file_save = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/CAIL_sample_data_prediction_t5_NILE.json" 20 | 21 | 22 | 23 | def proecess(file_path, save_path): 24 | random.seed(1) 25 | match_list = [] 26 | midmatch_list = [] 27 | dismatch_list = [] 28 | 29 | with open(file_path, 'r') as f: 30 | for line in f: 31 | item = json.loads(line) 32 | if item['label'] == 2: 33 | match_list.append({"case_A":item["source_1_dis"][0], "case_B":item["source_1_dis"][1], "explanation":item["exp"][2-item["label"]], "golden_exp":item["explanation"]}) 34 | elif item['label'] == 1: 35 | midmatch_list.append({"case_A":item["source_1_dis"][0], "case_B":item["source_1_dis"][1], "explanation":item["exp"][2-item["label"]], "golden_exp":item["explanation"]}) 36 | elif item['label'] == 0: 37 | dismatch_list.append({"case_A":item["source_1_dis"][0], "case_B":item["source_1_dis"][1], "explanation":item["exp"][2-item["label"]], "golden_exp":item["explanation"]}) 38 | 39 | dev_test = dismatch_list[int(0.8*len(dismatch_list)):] + midmatch_list[int(0.8*len(midmatch_list)):] #+ match_list[int(0.8*len(match_list)):] 40 | dev_test_sample = random.sample(dev_test, 100) 41 | with open(save_path, 'w') as f: 42 | for line in dev_test_sample: 43 | f.writelines(json.dumps(line, ensure_ascii=False)) 44 | f.write('\n') 45 | 46 | if __name__ == '__main__': 47 | proecess(ELAM_xLIRE_file, ELAM_xLIRE_file_save) 48 | proecess(ELAM_NILE_file, ELAM_NILE_file_save) 49 | proecess(ELAM_wo_token_file, ELAM_wo_token_file_save) 50 | proecess(CAIL_xLIRE_file, CAIL_xLIRE_file_save) 51 | proecess(CAIL_NILE_file, CAIL_NILE_file_save) 52 | proecess(CAIL_wo_token_file, CAIL_wo_token_file_save) 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /data_utils/cat_predictor.py: -------------------------------------------------------------------------------- 1 | import json 2 | dismatch_files = 
["/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/dismatch_data_prediction_t5_NILE.json", 3 | "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/dismatch_data_prediction_t5_wo_token.json", 4 | "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/dismatch_data_prediction_t5_xLIRE.json"] 5 | 6 | midmatch_files = ["/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/midmatch_data_prediction_t5_NILE.json", 7 | "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/midmatch_data_prediction_t5_wo_token.json", 8 | "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/midmatch_data_prediction_t5_xLIRE.json"] 9 | 10 | match_files = ["/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/match_data_prediction_t5_NILE.json", 11 | "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/match_data_prediction_t5_wo_token.json", 12 | "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/match_data_prediction_t5_xLIRE.json"] 13 | 14 | wo_token_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_prediction_t5_wo_token.json" 15 | NILE_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_prediction_t5_NILE.json" 16 | xLIRE_file = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_prediction_t5_xLIRE.json" 17 | 18 | NILE_data = [] 19 | with open(match_files[0], 'r') as f: 20 | for line in f: 21 | NILE_data.append(json.loads(line)) 22 | 23 | with open(midmatch_files[0], 'r') as f: 24 | for i, line in enumerate(f): 25 | item = json.loads(line) 26 | NILE_data[i]['exp'] += item['exp'] 27 | 28 | with open(dismatch_files[0], 'r') as f: 29 | for i, line in enumerate(f): 30 | item = json.loads(line) 31 | NILE_data[i]['exp'] += item['exp'] 32 | 33 | wo_token_data = [] 34 | with open(match_files[1], 'r') as f: 35 | for line in f: 36 | wo_token_data.append(json.loads(line)) 37 | 38 | with open(midmatch_files[1], 'r') as f: 39 | for i, line in enumerate(f): 40 | item = json.loads(line) 41 | wo_token_data[i]['exp'] += item['exp'] 42 | 43 | with open(dismatch_files[1], 'r') as f: 44 | for i, line in enumerate(f): 45 | item = json.loads(line) 46 | wo_token_data[i]['exp'] += item['exp'] 47 | 48 | xLIRE_data = [] 49 | with open(match_files[2], 'r') as f: 50 | for line in f: 51 | xLIRE_data.append(json.loads(line)) 52 | 53 | with open(midmatch_files[2], 'r') as f: 54 | for i, line in enumerate(f): 55 | item = json.loads(line) 56 | xLIRE_data[i]['exp'] += item['exp'] 57 | 58 | with open(dismatch_files[2], 'r') as f: 59 | for i, line in enumerate(f): 60 | item = json.loads(line) 61 | xLIRE_data[i]['exp'] += item['exp'] 62 | 63 | with open(wo_token_file, 'w') as f: 64 | for item in wo_token_data: 65 | f.writelines(json.dumps(item, ensure_ascii=False)) 66 | f.write('\n') 67 | 68 | with open(NILE_file, 'w') as f: 69 | for item in NILE_data: 70 | f.writelines(json.dumps(item, ensure_ascii=False)) 71 | f.write('\n') 72 | 73 | with open(xLIRE_file, 'w') as f: 74 | for item in xLIRE_data: 75 | f.writelines(json.dumps(item, ensure_ascii=False)) 76 | f.write('\n') 77 | 78 | -------------------------------------------------------------------------------- /data_utils/predictor_convert.py: 
--------------------------------------------------------------------------------
 1 | """
 2 | The final predictor, written following the e-SNLI setup; a contrastive-learning variant may be tried later. This file is used to generate the prediction data.
 3 | """
 4 | import sys
 5 | sys.path.append("..")
 6 | import json
 7 | from models.seq2seq_model import *
 8 | from models.seq2seq_model_dismatch import *
 9 | from models.seq2seq_model import GenerateModel as G_m
10 | from models.seq2seq_model_dismatch import GenerateModel as G_d
11 | from models.seq2seq_model import AutoSummary as AS_m
12 | from models.seq2seq_model_dismatch import AutoSummary as AS_d
13 |
14 | device = torch.device('cuda:'+'1') if torch.cuda.is_available() else torch.device('cpu')
15 |
16 | def load_checkpoint_p(model, optimizer, trained_epoch, file_name=None):
17 |     if file_name==None:
18 |         file_name = args.checkpoint + '/' + f"{args.seq2seq_type}-seq2seq-{trained_epoch}.pkl"
19 |     save_params = torch.load(file_name, map_location=device)
20 |     model.load_state_dict(save_params["model"])
21 |
22 |
23 |
24 | def convert(file_list, save_path):
25 |     with torch.no_grad():
26 |         #match_model = G_m()
27 |         #load_checkpoint_p(match_model, None, 20, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/seq2seq_model/match-seq2seq-8.pkl")
28 |         midmatch_model = G_m()
29 |         load_checkpoint_p(midmatch_model, None, 20,
30 |                           "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/seq2seq_model/midmatch-seq2seq-16.pkl")
31 |         #dismatch_model = G_d()
32 |         #load_checkpoint_p(dismatch_model, None, 20, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/seq2seq_model/dismatch_v2-seq2seq-15.pkl")
33 |         all_data = []
34 |         for file in file_list[2:3]:
35 |             with open(file, 'r') as f:
36 |                 for line in f:
37 |                     all_data.append(json.loads(line))
38 |
39 |         autosummary_midmatch = AS_m(
40 |             start_id=midmatch_model.tokenizer.cls_token_id,
41 |             end_id=midmatch_model.tokenizer.sep_token_id,
42 |             maxlen=args.maxlen // 4,
43 |             model=midmatch_model
44 |         )
45 |         # autosummary_match = AS_m(
46 |         #     start_id=match_model.tokenizer.cls_token_id,
47 |         #     end_id=match_model.tokenizer.sep_token_id,
48 |         #     maxlen=args.maxlen // 4,
49 |         #     model=match_model
50 |         # )
51 |         #
52 |         # autosummary_dismatch = AS_d(
53 |         #     start_id=dismatch_model.tokenizer.cls_token_id,
54 |         #     end_id=dismatch_model.tokenizer.sep_token_id,
55 |         #     maxlen=args.maxlen // 4,
56 |         #     model=dismatch_model
57 |         # )
58 |
59 |         for d in tqdm(all_data, desc=u'Evaluating'):
60 |             # match_exp = autosummary_match.generate(d['source_1'], 1)
61 |             # dismatch_exp = autosummary_dismatch.generate(d['source_1_dis'][0], 1)
62 |             # dismatch_exp += autosummary_dismatch.generate(d['source_1_dis'][1], 1)
63 |             midmatch_exp = autosummary_midmatch.generate(d['source_1'], 5)
64 |             d["exp"].append(midmatch_exp)
65 |
66 |     with open(save_path, 'w') as f:
67 |         for item in all_data:
68 |             f.writelines(json.dumps(item, ensure_ascii=False))
69 |             f.write('\n')
70 |
71 | if __name__ == '__main__':
72 |
73 |     match_data_seq2seq_json = '../dataset/match_data_prediction.json'
74 |     midmatch_data_seq2seq_json = '../dataset/midmatch_data_prediction.json'
75 |     dismatch_data_seq2seq_json = '../dataset/dismatch_data_prediction.json'
76 |     save_path = "../dataset/dismatch_data_prediction_v2.json"
77 |     file_list = [match_data_seq2seq_json, midmatch_data_seq2seq_json, dismatch_data_seq2seq_json]
78 |     convert(file_list, save_path)
79 |
80 |
81 |     print(u'Output done!')
82 |
83 |
84 |
85 |
--------------------------------------------------------------------------------
/models/selector_one.py:
--------------------------------------------------------------------------------
  1 | import sys
  2 | sys.path.append("..")
  3 | import argparse
  4 | from tqdm import tqdm
  5 | from transformers import BertTokenizer, BertModel
  6 | from utils.snippets import *
  7 | import torch.nn as nn
  8 | import torch
  9 |
 10 | parser = argparse.ArgumentParser()
 11 | parser.add_argument('--cuda_pos', type=str, default='0', help='which GPU to use')
 12 | parser.add_argument('--seed', type=int, default=42, help='random seed')
 13 | args = parser.parse_args()
 14 |
 15 | np.random.seed(args.seed)
 16 | torch.manual_seed(args.seed)
 17 | torch.cuda.manual_seed_all(args.seed)
 18 |
 19 | device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu')
 20 |
 21 |
 22 |
 23 | class GlobalAveragePooling1D(nn.Module):
 24 |     """Custom global pooling.
 25 |     Averages the pooled vectors of a sentence; a long sentence is represented by the average over its short-segment poolers.
 26 |     """
 27 |     def __init__(self):
 28 |         super(GlobalAveragePooling1D, self).__init__()
 29 |
 30 |
 31 |     def forward(self, inputs, mask=None):
 32 |         if mask is not None:
 33 |             mask = mask.to(torch.float)[:, :, None]
 34 |             return torch.sum(inputs * mask, dim=1) / torch.sum(mask, dim=1)
 35 |         else:
 36 |             return torch.mean(inputs, dim=1)
 37 |
 38 |
 39 | class Selector_1(nn.Module):
 40 |     def __init__(self):
 41 |         super(Selector_1, self).__init__()
 42 |         self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_fold, mirror='tuna', do_lower_case=True)
 43 |         self.Pooling = GlobalAveragePooling1D()
 44 |         self.encoder = BertModel.from_pretrained(pretrained_bert_fold)
 45 |         self.max_seq_len = 512
 46 |
 47 |
 48 |     def predict(self, texts):
 49 |         """Convert a list of sentences into sentence vectors.
 50 |         """
 51 |         with torch.no_grad():
 52 |             output_1s = []
 53 |             bert_output = self.tokenizer.batch_encode_plus(texts, padding=True, truncation=True, max_length=self.max_seq_len)
 54 |             for p in range(len(texts)):
 55 |                 bert_output_p = {"input_ids": torch.tensor([bert_output["input_ids"][p]], device=device), "token_type_ids": torch.tensor([bert_output["token_type_ids"][p]], device=device),
 56 |                                  "attention_mask": torch.tensor([bert_output["attention_mask"][p]], device=device)}
 57 |                 output_1 = self.encoder(**bert_output_p)["last_hidden_state"]
 58 |                 output_1s.append(output_1[0])
 59 |             output_1_final = torch.stack(output_1s, dim=0)
 60 |             outputs = self.Pooling(output_1_final)
 61 |             return outputs
 62 |
 63 |
 64 |
 65 | def load_data(filename):
 66 |     """Load data.
 67 |     Returns: [texts]
 68 |     """
 69 |     D = []
 70 |     with open(filename) as f:
 71 |         for l in f:
 72 |             texts_a = json.loads(l)['case_A'][0]
 73 |             texts_b = json.loads(l)['case_B'][0]
 74 |             D.append([texts_a, texts_b])
 75 |
 76 |     return D
 77 |
 78 |
 79 |
 80 |
 81 | def convert(data):
 82 |     """Convert all samples.
 83 |     """
 84 |     embeddings = []
 85 |     model = Selector_1()
 86 |     model.to(device)
 87 |     for texts in tqdm(data, desc=u'Vectorizing'):
 88 |         outputs_a = model.predict(texts[0])
 89 |         outputs_b = model.predict(texts[1])
 90 |         embeddings.append(outputs_a)
 91 |         embeddings.append(outputs_b)
 92 |     embeddings = sequence_padding(embeddings)
 93 |     return embeddings
 94 |
 95 |
 96 | if __name__ == '__main__':
 97 |     #
 98 |     data_extract_json = '../dataset/our_data/data_extract_old.json'
 99 |     data_extract_npy = '../dataset/our_data/data_extract_old.npy'
100 |     data = load_data(data_extract_json)
101 |     embeddings = convert(data)
102 |     np.save(data_extract_npy, embeddings)
103 |     print(u'Output path: %s' % data_extract_npy)
104 |
--------------------------------------------------------------------------------
/data_utils/predictor_convert_t5.py:
--------------------------------------------------------------------------------
 1 | """
 2 | The final predictor, written following the e-SNLI setup; a contrastive-learning variant may be tried later. This file is used to generate the prediction data.
 3 | """
 4 | import sys
 5 | sys.path.append("..")
 6 | import json
 7 | import torch
 8 | from utils.snippets import *
 9 | from tqdm import tqdm
10 | import argparse
11 | from transformers import MT5ForConditionalGeneration
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument('--maxlen', type=int, default=1024, help='max_len')
14 | parser.add_argument('--match_type', type=str, default="midmatch", help='[match midmatch dismatch]')
15 | parser.add_argument('--cuda_pos', type=str, default="0", help='[0 1]')
16 | parser.add_argument('--data_type', type=str, default="ELAM", help='ELAM, CAIL')
17 | parser.add_argument('--mode_type', type=str, default="wo_token", help='wo_token NILE xLIRE')
18 | args = parser.parse_args()
19 | print(args)
20 | device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu')
21 |
22 | def generate(text, model, tokenizer, device=device, max_length=30):
23 |     feature = tokenizer.encode(text, return_token_type_ids=True, return_tensors='pt',
24 |                                max_length=args.maxlen, truncation=True)
25 |     feature = {'input_ids': feature}
26 |     feature = {k: v.to(device) for k, v in list(feature.items())}
27 |
28 |     gen = model.generate(max_length=max_length, eos_token_id=tokenizer.sep_token_id,
29 |                          decoder_start_token_id=tokenizer.cls_token_id,
30 |                          **feature).cpu().numpy()[0]
31 |     gen = gen[1:]
32 |     gen = tokenizer.decode(gen, skip_special_tokens=True).replace(' ', '')
33 |     return gen
34 |
35 | def load_checkpoint_p(model, optimizer, trained_epoch, file_name=None):
36 |     save_params = torch.load(file_name, map_location=device)
37 |     model.load_state_dict(save_params["model"])
38 |
39 |
40 |
41 | def convert(file_list, save_path):
42 |     with torch.no_grad():
43 |         tokenizer = T5PegasusTokenizer.from_pretrained(pretrained_t5_fold)
44 |         tokenizer.add_tokens(["[AO]", "[YO]", "[ZO]", '[AI]', "[YI]", "[ZI]"])
45 |         model = MT5ForConditionalGeneration.from_pretrained(pretrained_t5_fold)
46 |         model.resize_token_embeddings(len(tokenizer))
47 |         model = model.to(device)
48 |         load_checkpoint_p(model, None, None,
49 |                           "../models/weights/seq2seq_model/{}-t5-seq2seq-{}-{}.pkl".format(args.match_type, args.data_type,
50 |                                                                                            args.mode_type))
51 |
52 |         all_data = []
53 |         for file in file_list:
54 |             with open(file, 'r') as f:
55 |                 for line in f:
56 |                     all_data.append(json.loads(line))
57 |
58 |         for d in tqdm(all_data, desc=u'Evaluating'):
59 |             match_exp = generate(d['source_1'], model, tokenizer, device, max_length=args.maxlen//4)
60 |             d["exp"] = [match_exp]
61 |
62 |     with open(save_path, 'w') as f:
63 |         for item in all_data:
64 |             f.writelines(json.dumps(item, ensure_ascii=False))
65 |             f.write('\n')
66 |
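# Usage sketch (assumed invocation, not from the original repo): the flags below are the
# argparse options defined at the top of this file; the script is run once per
# --match_type so that data_utils/cat_predictor.py can later merge the
# match/midmatch/dismatch outputs, e.g.:
#   python predictor_convert_t5.py --data_type ELAM --mode_type wo_token --match_type match --cuda_pos 0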
67 | if __name__ == '__main__':
68 |     if args.data_type=='ELAM':
69 |         match_data_seq2seq_json = '../dataset/our_data/match_data_seq2seq_{}.json'.format(args.mode_type)
70 |         midmatch_data_seq2seq_json = '../dataset/our_data/midmatch_data_seq2seq_{}.json'.format(args.mode_type)
71 |         dismatch_data_seq2seq_json = '../dataset/our_data/dismatch_data_seq2seq_{}.json'.format(args.mode_type)
72 |         save_path = "../dataset/our_data/{}_data_prediction_t5_{}.json".format(args.match_type, args.mode_type)
73 |     elif args.data_type=='CAIL':
74 |         match_data_seq2seq_json = '../dataset/match_data_seq2seq_{}.json'.format(args.mode_type)
75 |         midmatch_data_seq2seq_json = '../dataset/midmatch_data_seq2seq_{}.json'.format(args.mode_type)
76 |         dismatch_data_seq2seq_json = '../dataset/dismatch_data_seq2seq_{}.json'.format(args.mode_type)
77 |         save_path = "../dataset/{}_data_prediction_t5_{}.json".format(args.match_type, args.mode_type)
78 |     else:
79 |         exit()
80 |     file_list = [match_data_seq2seq_json, midmatch_data_seq2seq_json, dismatch_data_seq2seq_json]
81 |     convert(file_list, save_path)
82 |
83 |
84 |     print(u'Output done!')
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # IOT-Match
 2 |
 3 | This is the official implementation for the SIGIR 2022 [paper](https://dl.acm.org/doi/pdf/10.1145/3477495.3531974)
 4 | "Explainable Legal Case Matching via Inverse Optimal Transport-based Rationale Extraction".
 5 |
 6 |
 7 | # Overview
 8 | As an essential operation of legal retrieval, legal case matching plays a central role in intelligent legal systems. This task places a high demand on the explainability of matching results because of its critical impacts on downstream applications --- the matched legal cases may provide supportive evidence for the judgments of target cases and thus influence the fairness and justice of legal decisions. Focusing on this challenging task, we propose a novel and explainable method, namely IOT-Match, with the help of computational optimal transport, which formulates the legal case matching problem as an inverse optimal transport (IOT) problem. Different from most existing methods, which merely focus on the sentence-level semantic similarity between legal cases, our IOT-Match learns to extract rationales from paired legal cases based on both the semantics and the legal characteristics of their sentences. The extracted rationales are further applied to generate faithful explanations and conduct matching. Moreover, the proposed IOT-Match is robust to the alignment-label insufficiency issue common in practical legal case matching tasks, and is suitable for both supervised and semi-supervised learning paradigms. To demonstrate the superiority of our IOT-Match method and construct a benchmark for the explainable legal case matching task, we not only extend the well-known Challenge of AI in Law (CAIL) dataset but also build a new Explainable Legal cAse Matching (ELAM) dataset, which contains a large number of legal cases with detailed and explainable annotations. Experiments on these two datasets show that our IOT-Match outperforms state-of-the-art methods consistently on matching prediction, rationale extraction, and explanation generation.
 9 |
10 | # Data
11 | We will provide ELAM to support the [CAIL 2022](http://cail.cipsc.org.cn/) explainable legal case matching track. For a fair competition, we will not release ELAM here. Please stay tuned for CAIL!
12 |
13 | For eCAIL, you can download it [here](https://drive.google.com/file/d/1ixjnkpGvM8RL7arxFDrCMiVWzJtifQYv/view?usp=sharing).
14 | For ELAM, you can download it [here](https://drive.google.com/file/d/1_nHIRJfwshBlMZCF-m-BQ7V5Sj_hKwW_/view?usp=sharing).
15 | For ELAM for CAIL, you can download it [here](https://drive.google.com/file/d/1-FyJOUMC9d0SJa2T9RVBXjnqnnYt95v3/view?usp=sharing).
16 | # Requirements
17 | ```python
18 | python>=3.7
19 | torch>=1.9.1+cu111
20 | transformers>=4.20.1
21 | numpy>=1.20.1
22 | jieba>=0.42.1
23 | six>=1.15.0
24 | rouge>=1.0.1
25 | tqdm>=4.62.3
26 | scikit-learn>=1.0.1
27 | pandas>=1.2.4
28 | nni>=2.6.1
29 | matplotlib>=3.3.4
30 | termcolor>=1.1.0
31 | networkx>=2.5
32 | requests>=2.25.1
33 | filelock>=3.0.12
34 | textrank4zh>=0.3
35 | gensim>=3.8.3
36 | openprompt>=1.0
37 | scipy>=1.8.0
38 | seaborn>=0.11.1
39 | ```
40 | In addition, the pretrained t5-pegasus checkpoint is required (it is loaded as `pretrained_t5_fold` in the seq2seq and predictor scripts).
41 | # Training and Evaluation
42 | ```python
43 | python selector_one.py
44 | python selector_two_multi_class_ot_v3.py
45 | python seq2seq_convert_cail.py
46 | python seq2seq_model_t5.py
47 | python predictor_convert_t5.py
48 | python cat_predictor.py
49 | python predictor_v5.py
50 | ```
51 | The parameters used in the above scripts are provided as default values in their respective files.
52 |
53 | # Acknowledgement
54 | Please cite the following paper as the reference if you use our code or the processed datasets.
55 |
56 | ```bib
57 | @inproceedings{10.1145/3477495.3531974,
58 | author = {Yu, Weijie and Sun, Zhongxiang and Xu, Jun and Dong, Zhenhua and Chen, Xu and Xu, Hongteng and Wen, Ji-Rong},
59 | title = {Explainable Legal Case Matching via Inverse Optimal Transport-Based Rationale Extraction},
60 | year = {2022},
61 | isbn = {9781450387323},
62 | publisher = {Association for Computing Machinery},
63 | address = {New York, NY, USA},
64 | url = {https://doi.org/10.1145/3477495.3531974},
65 | doi = {10.1145/3477495.3531974},
66 | booktitle = {Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval},
67 | pages = {657–668},
68 | numpages = {12},
69 | keywords = {legal retrieval, explainable matching},
70 | location = {Madrid, Spain},
71 | series = {SIGIR '22}
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
/data_utils/analyse_data.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | import json
 3 | import re
 4 | data_extract_json = '../dataset/our_data/data_extract.json'
 5 | data_extract_npy = '../dataset/our_data/data_extract.npy'
 6 |
 7 | def load_data(filename):
 8 |     """Load data.
 9 |     Returns: [(texts, labels, exp)]
10 |     """
11 |     D = []
12 |     with open(filename, encoding='utf-8') as f:
13 |         for l in f:
14 |             D.append(json.loads(l))
15 |     return D
16 | data = load_data(data_extract_json)
17 | label_0, label_1, label_2 = 0, 0, 0
18 | A_all, Y_all, Z_all, A_I, Y_I, Z_I = 0, 0, 0, 0, 0, 0
19 | # sentence = 0
20 | # for i, d in enumerate(data):
21 | #     if d["label"] == 0:
22 | #         label_0 += 1
23 | #     elif d["label"] == 1:
24 | #         label_1 += 1
25 | #     else:
26 | #         label_2 += 1
27 | #
28 | #     case_A_list = d["case_A"][1]
29 | #     sentence += len(d["case_A"][0])
30 | #     case_B_list = d["case_B"][1]
31 | #     sentence += len(d["case_B"][0])
32 | #     for item in case_A_list + case_B_list:
33 | #         if item[1] == 1:
34 | #             A_all += 1
35 | #         elif item[1] == 2:
36 | #             Y_all += 1
37 | #         elif item[1] == 3:
38 | #             Z_all += 1
39 | #         else:
40 | #             exit()
41 | #     case_A_I = set()
42 | #     case_B_I = set()
43 | #     for pa in d["relation_label"]["relation_label_aqss"]:
44 | #         case_A_I.add(pa[0])
45 | #         case_B_I.add(pa[1])
46 | #     A_I += len(case_A_I) + len(case_B_I)
47 | #     case_A_I = set()
48 | #     case_B_I = set()
49 | #     for pa in d["relation_label"]["relation_label_yjss"]:
50 | #         case_A_I.add(pa[0])
51 | #         case_B_I.add(pa[1])
52 | #     Y_I += len(case_A_I) + len(case_B_I)
53 | #
54 | #     case_A_I = set()
55 | #     
case_B_I = set() 56 | # for pa in d["relation_label"]["relation_label_zyjd"]: 57 | # case_A_I.add(pa[0]) 58 | # case_B_I.add(pa[1]) 59 | # Z_I += len(case_A_I) + len(case_B_I) 60 | # print("label 0: {} label 1: {} label 2: {}".format(label_0, label_1, label_2)) 61 | # print("A_all: {} , Y_all: {}, Z_all: {}, A_I: {}, Y_I: {}, Z_I: {}".format(A_all, Y_all, Z_all, A_I, Y_I, Z_I)) 62 | #print(sentence/(label_0+label_1+label_2)) 63 | all, I = 0, 0 64 | sentence = 0 65 | for i, d in enumerate(data): 66 | if d["label"] == 0: 67 | label_0 += 1 68 | elif d["label"] == 1: 69 | label_1 += 1 70 | else: 71 | label_2 += 1 72 | 73 | case_A_list = d["case_A"][1] 74 | sentence += len(d["case_A"][0]) 75 | case_B_list = d["case_B"][1] 76 | sentence += len(d["case_B"][0]) 77 | all += len(case_A_list + case_B_list) 78 | 79 | case_A_I = set() 80 | case_B_I = set() 81 | for pa in d["relation_label"]: 82 | case_A_I.add(pa[0]) 83 | case_B_I.add(pa[1]) 84 | I += len(case_A_I) + len(case_B_I) 85 | print("pro {} con {} per {}".format(I, all - I, all/(2*len(data)))) 86 | 87 | 88 | 89 | sentence = 0 90 | word_num = 0 91 | for i, d in enumerate(data[:int(0.8*len(data))]): 92 | sentence += len(d["case_A"][0]) 93 | for item in d["case_A"][0]: 94 | word_num += len(item) 95 | sentence += len(d["case_B"][0]) 96 | for item in d["case_B"][0]: 97 | word_num += len(item) 98 | print("train_avg_sentence_length {} avg_sentence_num {}".format(word_num/sentence, sentence/(2*i))) 99 | sentence = 0 100 | word_num = 0 101 | for i, d in enumerate(data[int(0.8*len(data)):int(0.9*len(data))]): 102 | sentence += len(d["case_A"][0]) 103 | for item in d["case_A"][0]: 104 | word_num += len(item) 105 | sentence += len(d["case_B"][0]) 106 | for item in d["case_B"][0]: 107 | word_num += len(item) 108 | print("dev_avg_sentence_length {} avg_sentence_num {}".format(word_num/sentence, sentence/(2*i))) 109 | 110 | sentence = 0 111 | word_num = 0 112 | for i, d in enumerate(data[int(0.9*len(data)):]): 113 | sentence += len(d["case_A"][0]) 114 | for item in d["case_A"][0]: 115 | word_num += len(item) 116 | sentence += len(d["case_B"][0]) 117 | for item in d["case_B"][0]: 118 | word_num += len(item) 119 | print("test_avg_sentence_length {} avg_sentence_num {}".format(word_num/sentence, sentence/(2*i))) 120 | sentence = 0 121 | word_num = 0 122 | exp_len = 0 123 | for i, d in enumerate(data): 124 | sentence += len(d["case_A"][0]) 125 | for item in d["case_A"][0]: 126 | word_num += len(item) 127 | sentence += len(d["case_B"][0]) 128 | for item in d["case_B"][0]: 129 | word_num += len(item) 130 | exp_len += len(d["explanation"]) 131 | print("avg_sentence_length {} avg_sentence_num {} exp_len: {}".format(word_num/sentence, sentence/(2*i), exp_len/i)) 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /data_utils/extract_data_process.py: -------------------------------------------------------------------------------- 1 | from utils.snippets import * 2 | import json 3 | data_json = '/new_disk2/zhongxiang_sun/code/law_project/data/CAIL2021/code/persudo_data_new.json' 4 | def load_cail_data(filename, save_path): 5 | """加载数据 6 | 返回:[(text, summary)] 7 | """ 8 | D = [] 9 | with open(filename, encoding='utf-8') as f: 10 | for l in f: 11 | l = json.loads(l) 12 | case_a = l['case_A'] 13 | text_a = case_a['content'] 14 | labels_a = [[d['index'], 1] for d in case_a['evidence']] 15 | 16 | case_b = l['case_B'] 17 | text_b = case_b['content'] 18 | labels_b = [[d['index'], 1] for d 
in case_b['evidence']] 19 | item = {} 20 | item["case_A"] = [text_a, labels_a] 21 | item["case_B"] = [text_b, labels_b] 22 | item["explanation"] = l["features_exp"] 23 | item["label"] = l['label'] 24 | item["relation_label"] = l["relation_label"] 25 | D.append(item) 26 | 27 | 28 | with open(save_path, 'w') as f: 29 | for l in D: 30 | f.writelines(json.dumps(l,ensure_ascii=False)) 31 | f.write('\n') 32 | 33 | 34 | 35 | def load_our_data(filename, save_path): 36 | """加载数据 37 | 返回:[(text, summary)] 38 | """ 39 | D = [] 40 | data = [] 41 | for file in filename: 42 | with open(file, encoding='utf-8') as f: 43 | data += json.load(f) 44 | 45 | for l in data: 46 | relation_label_aqss, relation_label_yjss, relation_label_zyjd = [], [], [] 47 | for it in l["relation"]: 48 | if it['entityLeft'].split('#')[2] == '案情事实': 49 | for a in it["label_Left"]: 50 | for b in it["label_Right"]: 51 | relation_label_aqss.append([a, b]) 52 | 53 | if it['entityLeft'].split('#')[2] == '要件事实': 54 | for a in it["label_Left"]: 55 | for b in it["label_Right"]: 56 | relation_label_yjss.append([a, b]) 57 | 58 | if it['entityLeft'].split('#')[2] == '争议焦点': 59 | for a in it["label_Left"]: 60 | for b in it["label_Right"]: 61 | relation_label_zyjd.append([a, b]) 62 | 63 | case_a = l['case_a_ner'] 64 | text_a = [] 65 | labels_a = [] 66 | for i, it in enumerate(case_a): 67 | text_a.append(it["text"]) 68 | if it["entities"] != []: 69 | if it["entities"][0]["type"] == "aqss": 70 | labels_a.append([i, 1]) 71 | elif it["entities"][0]["type"] == "yjss": 72 | labels_a.append([i, 2]) 73 | elif it["entities"][0]["type"] == "zyjd": 74 | labels_a.append([i, 3]) 75 | else: 76 | print(it["entities"]) 77 | else: 78 | pass 79 | 80 | case_b = l['case_b_ner'] 81 | text_b = [] 82 | labels_b = [] 83 | for i, it in enumerate(case_b): 84 | text_b.append(it["text"].replace('h', '')) 85 | if it["entities"] != []: 86 | if it["entities"][0]["type"] == "aqss": 87 | labels_b.append([i, 1]) 88 | elif it["entities"][0]["type"] == "yjss": 89 | labels_b.append([i, 2]) 90 | elif it["entities"][0]["type"] == "zyjd": 91 | labels_b.append([i, 3]) 92 | else: 93 | print(l) 94 | else: 95 | pass 96 | 97 | item = {} 98 | item["case_A"] = [text_a, labels_a] 99 | item["case_B"] = [text_b, labels_b] 100 | item["explanation"] = l["explanation"] 101 | item["label"] = l['gold_label'] 102 | item["relation_label"] = {"relation_label_aqss": relation_label_aqss, 103 | "relation_label_yjss": relation_label_yjss, 104 | "relation_label_zyjd": relation_label_zyjd} 105 | item['id'] = l['pair_ID'] 106 | D.append(item) 107 | 108 | 109 | with open(save_path, 'w') as f: 110 | for l in D: 111 | f.writelines(json.dumps(l,ensure_ascii=False)) 112 | f.write('\n') 113 | 114 | 115 | 116 | 117 | if __name__ == '__main__': 118 | # """ 119 | # cail data 120 | # """ 121 | # data_extract_json = '../dataset/data_extract.json' 122 | # load_cail_data(data_json, data_extract_json) 123 | 124 | """ 125 | our data 126 | """ 127 | data_extract_json = '../dataset/our_data/data_extract.json' 128 | load_our_data([our_data_train,our_data_dev, our_data_test], data_extract_json) 129 | 130 | 131 | -------------------------------------------------------------------------------- /data_utils/baseline_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | 用于构造: 3 | Bert_legal (all sents) 4 | Bert_legal (rational) 5 | Bert_legal (all sents-rational) 6 | 数据集 7 | """ 8 | import sys 9 | sys.path.append("..") 10 | data_type = 'CAIL' 11 | if data_type == 'CAIL': 12 | from 
models.selector_two_multi_class_ot_cail_v2 import Selector2_mul_class, args, load_checkpoint, OT, load_data, \ 13 | data_extract_npy, data_extract_json, device 14 | elif data_type =='ELAM': 15 | from models.selector_two_multi_class_ot_v3 import Selector2_mul_class, args, load_checkpoint, OT, load_data, data_extract_npy, data_extract_json, device 16 | import torch 17 | from utils.snippets import * 18 | import torch 19 | import json 20 | """ 21 | todo 需要添加一个分类器,对相似句子与不相似句子进行分类 22 | """ 23 | def model_class(model, OT_model, case_A, case_B, seq_len_A, seq_len_B): 24 | """ 25 | 26 | :param model: 27 | :param OT_model: 28 | :param case_A: 29 | :param case_B: 30 | :return: AO, YO, ZO, AI, YI, ZI 31 | """ 32 | 33 | output_batch_A, batch_mask_A = model(case_A) 34 | output_batch_B, batch_mask_B = model(case_B) 35 | plan_list = OT_model(output_batch_A, output_batch_B, case_A, case_B, None, 36 | batch_mask_A, batch_mask_B, model_type='valid') 37 | OT_matrix = torch.ge(plan_list, 1 / case_A.shape[1] / args.threshold_ot).long() 38 | vec_correct_A = torch.argmax(output_batch_A, dim=-1).long()[0][:seq_len_A] 39 | vec_correct_B = torch.argmax(output_batch_B, dim=-1).long()[0][:seq_len_B] 40 | relation_A = torch.sum(OT_matrix[0], dim=1) 41 | relation_B = torch.sum(OT_matrix[0], dim=0) 42 | 43 | if data_type == 'CAIL': 44 | return [vec_correct_A+(torch.ge(relation_A[:seq_len_A], 1)*1)*vec_correct_A, vec_correct_B+(torch.ge(relation_B[:seq_len_B], 1)*1)*vec_correct_B] 45 | elif data_type == 'ELAM': 46 | return [vec_correct_A+(torch.ge(relation_A[:seq_len_A], 1)*3)*vec_correct_A, vec_correct_B+(torch.ge(relation_B[:seq_len_B], 1)*3)*vec_correct_B] 47 | else: 48 | exit() 49 | 50 | 51 | 52 | 53 | def get_extract_text_wo_token(case_a, prediction): 54 | all_sents_a, rationale_a, all_wo_rationale_a = '', '', '' 55 | for i, output_class in enumerate(prediction): 56 | if output_class != 0: 57 | rationale_a += case_a[0][i] 58 | else: 59 | all_wo_rationale_a += case_a[0][i] 60 | all_sents_a += case_a[0][i] 61 | return all_sents_a, rationale_a, all_wo_rationale_a 62 | 63 | 64 | def generate_text_wo_token(case_a, case_b, d, prediction): 65 | 66 | all_sents_a, rationale_a, all_wo_rationale_a = get_extract_text_wo_token(case_a, prediction[0]) 67 | all_sents_b, rationale_b, all_wo_rationale_b = get_extract_text_wo_token(case_b, prediction[1]) 68 | 69 | result_all_sents = { 70 | 'case_a': all_sents_a, 71 | 'case_b': all_sents_b, 72 | 'label': d['label'] 73 | } 74 | result_rationale = { 75 | 'case_a': rationale_a, 76 | 'case_b': rationale_b, 77 | 'label': d['label'] 78 | } 79 | 80 | result_all_wo_rationale = { 81 | 'case_a': all_wo_rationale_a, 82 | 'case_b': all_wo_rationale_b, 83 | 'label': d['label'] 84 | } 85 | 86 | return result_all_sents, result_rationale, result_all_wo_rationale 87 | 88 | 89 | def fold_convert_our_data_ot(data, data_x): 90 | """每一fold用对应的模型做数据转换 91 | """ 92 | 93 | with torch.no_grad(): 94 | model = Selector2_mul_class(args.input_size, args.hidden_size, kernel_size=args.kernel_size, dilation_rate=[1, 2, 4, 8, 1, 1]) 95 | if data_type == 'CAIL': 96 | load_checkpoint(model, None, 2, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/extract/cail_extract-criterion-BCEFocal-onehot-1-ot_mode-max-convert_to_onehot-1-weight-100-simot-1-simpercent-1.0.pkl") 97 | elif data_type == 'ELAM': 98 | load_checkpoint(model, None, 2, 
"/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/extract/extract-criterion-BCEFocal-onehot-1-ot_mode-max-convert_to_onehot-1-weight-100-simot-1-simpercent-1.0.pkl") 99 | else: 100 | exit() 101 | model = model.to(device) 102 | ot_model = OT() 103 | ot_model = ot_model.to(device) 104 | if data_type == 'ELAM': 105 | load_checkpoint(ot_model, None, 2, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/extract/extract_ot-criterion-BCEFocal-onehot-1-ot_mode-max-convert_to_onehot-1-weight-100-simot-1-simpercent-1.0.pkl") 106 | elif data_type == 'CAIL': 107 | load_checkpoint(ot_model, None, 2, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/extract/cail_extract_ot-criterion-BCEFocal-onehot-1-ot_mode-max-convert_to_onehot-1-weight-100-simot-1-simpercent-1.0.pkl") 108 | else: 109 | exit() 110 | results_all_sents, results_rationale, results_wo_rationale = [], [], [] 111 | for i, d in enumerate(data): 112 | case_a = d['case_A'] 113 | case_b = d['case_B'] 114 | 115 | 116 | 117 | 118 | prediction = model_class(model, ot_model, torch.tensor(np.expand_dims(data_x[2*i], axis=0), device=device), 119 | torch.tensor(np.expand_dims(data_x[2*i+1], axis=0), device=device),len(case_a[0]), len(case_b[0])) 120 | 121 | 122 | all_sents, rationale, wo_rationale = generate_text_wo_token(case_a, case_b, d, prediction) 123 | results_all_sents.append(all_sents) 124 | results_rationale.append(rationale) 125 | results_wo_rationale.append(wo_rationale) 126 | 127 | return [results_all_sents, results_rationale, results_wo_rationale] 128 | 129 | 130 | 131 | def convert(filename, data, data_x): 132 | """转换为生成式数据 133 | """ 134 | total_results = fold_convert_our_data_ot(data, data_x) 135 | for i in range(len(total_results)): 136 | with open(filename[i], 'w') as f: 137 | for item in total_results[i]: 138 | f.writelines(json.dumps(item, ensure_ascii=False)) 139 | f.write('\n') 140 | 141 | 142 | 143 | if __name__ == '__main__': 144 | 145 | data = load_data(data_extract_json) 146 | data_x = np.load(data_extract_npy) 147 | bert_legal_all_sents_json = '../dataset/baselines_datasets/{}_bert_legal_all_sents.json'.format(data_type) 148 | bert_legal_rationale_json = '../dataset/baselines_datasets/{}_bert_legal_rationale.json'.format(data_type) 149 | bert_legal_wo_rationale_json = '../dataset/baselines_datasets/{}_bert_legal_wo_rationale.json'.format(data_type) 150 | 151 | 152 | convert([bert_legal_all_sents_json, bert_legal_rationale_json, bert_legal_wo_rationale_json], data, data_x) 153 | 154 | 155 | print(u'输出over!') 156 | -------------------------------------------------------------------------------- /data_utils/seq2seq_convert_xLIRE.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | from baselines.selector_baselines.Attention_seven_class import Selector2_mul_class, args, load_checkpoint, attention, load_data, data_extract_npy, data_extract_json, device 4 | import torch 5 | from utils.snippets import * 6 | import torch 7 | import json 8 | 9 | def get_predition(batch_A, batch_B, case_a_important, case_b_important): 10 | relation_A = torch.ge(case_a_important, args.threshold).long() 11 | relation_B = torch.ge(case_b_important, args.threshold).long() 12 | if args.data_type == 'CAIL': 13 | vec_correct_A = torch.argmax(batch_A, dim=-1) + (relation_A * 1).squeeze() * torch.ge(torch.argmax(batch_A, dim=-1), 1) 14 | vec_correct_B = torch.argmax(batch_B, dim=-1) 
+ (relation_B * 1).squeeze() * torch.ge(torch.argmax(batch_B, dim=-1), 1) 15 | else: 16 | vec_correct_A = torch.argmax(batch_A, dim=-1) + (relation_A * 3).squeeze() * torch.ge(torch.argmax(batch_A, dim=-1), 1) 17 | vec_correct_B = torch.argmax(batch_B, dim=-1) + (relation_B * 3).squeeze() * torch.ge(torch.argmax(batch_B, dim=-1), 1) 18 | 19 | return vec_correct_A, vec_correct_B 20 | 21 | 22 | def model_class(model, OT_model, case_A, case_B, seq_len_A, seq_len_B): 23 | """ 24 | :param model: 25 | :param OT_model: 26 | :param case_A: 27 | :param case_B: 28 | :return: AO, YO, ZO, AI, YI, ZI 29 | """ 30 | 31 | output_batch_A, batch_mask_A = model(case_A) 32 | output_batch_B, batch_mask_B = model(case_B) 33 | case_a_important, case_b_important = OT_model(output_batch_A, output_batch_B, case_A, case_B, None, None, 34 | batch_mask_A, batch_mask_B) 35 | 36 | seven_prediction_A, seven_prediction_B = get_predition(output_batch_A.clone(), output_batch_B.clone(), 37 | case_a_important, case_b_important) 38 | 39 | return seven_prediction_A, seven_prediction_B 40 | 41 | 42 | def get_extract_text(case_a, prediction): 43 | source_1_a = '' 44 | for i, output_class in enumerate(prediction[0]): 45 | if i >= len(case_a[0]): 46 | break 47 | if output_class == 0: 48 | source_1_a += case_a[0][i] 49 | else: 50 | source_1_a += '[' + case_a[0][i] + ']' 51 | return source_1_a 52 | 53 | 54 | 55 | 56 | 57 | def generate_text_sort(case_a, case_b, d, prediction, label): 58 | 59 | source_1_a = get_extract_text(case_a, prediction[0]) 60 | source_1_b = get_extract_text(case_b, prediction[1]) 61 | source_2_a = get_extract_text(case_a, prediction[0]) 62 | source_2_b = get_extract_text(case_b, prediction[1]) 63 | if args.data_type=='CAIL': 64 | result = { 65 | 'source_1': source_1_a + source_1_b, 66 | 'source_2': source_2_a + source_2_b, 67 | 'explanation': d['explanation'], 68 | 'source_1_dis': [source_1_a, source_1_b], 69 | 'source_2_dis': [source_2_a, source_2_b], 70 | 'label': d['label'] 71 | } 72 | else: 73 | result = { 74 | 'source_1': source_1_a + source_1_b, 75 | 'source_2': source_2_a + source_2_b, 76 | 'explanation': ';'.join(list(d['explanation'].values())), 77 | 'source_1_dis': [source_1_a, source_1_b], 78 | 'source_2_dis': [source_2_a, source_2_b], 79 | 'label': d['label'] 80 | } 81 | return result 82 | 83 | 84 | 85 | 86 | def fold_convert_our_data_ot(data, data_x, type, generate=False): 87 | """每一fold用对应的模型做数据转换 88 | """ 89 | 90 | with torch.no_grad(): 91 | model = Selector2_mul_class(args.input_size, args.hidden_size, kernel_size=args.kernel_size, dilation_rate=[1, 2, 4, 8, 1, 1]) 92 | load_checkpoint(model, None, 2, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/baselines/selector_baselines/weights/extract/extract-{}.pkl".format(args.data_type)) 93 | model = model.to(device) 94 | ot_model = attention() 95 | ot_model = ot_model.to(device) 96 | load_checkpoint(ot_model, None, 2, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/baselines/selector_baselines/weights/extract/extract_attention-{}.pkl".format(args.data_type)) 97 | results = [] 98 | print(type+"ing") 99 | for i, d in enumerate(data): 100 | if type == 'match' and d["label"] == 2 or type == 'midmatch' and d["label"] == 1 or type == 'dismatch' and d["label"] == 0: 101 | case_a = d['case_A'] 102 | case_b = d['case_B'] 103 | important_A, important_B = [], [] 104 | data_y_seven_class_A, data_y_seven_class_B = [0]*len(case_a[0]), [0]*len(case_b[0]) 105 | if args.data_type == 'CAIL': 106 | for pos in 
d['relation_label']: 107 | row, col = pos[0], pos[-1] 108 | important_A.append(row) 109 | important_B.append(col) 110 | for j in case_a[1]: 111 | if j[0] in important_A: 112 | data_y_seven_class_A[j[0]] = j[1] + 1 113 | else: 114 | data_y_seven_class_A[j[0]] = j[1] 115 | 116 | for j in case_b[1]: 117 | if j[0] in important_B: 118 | data_y_seven_class_B[j[0]] = j[1] + 1 119 | else: 120 | data_y_seven_class_B[j[0]] = j[1] 121 | else: 122 | for pos_list in d['relation_label'].values(): 123 | for pos in pos_list: 124 | row, col = pos[0], pos[-1] 125 | important_A.append(row) 126 | important_B.append(col) 127 | 128 | for j in case_a[1]: 129 | if j[0] in important_A: 130 | data_y_seven_class_A[j[0]] = j[1] + 3 131 | else: 132 | data_y_seven_class_A[j[0]] = j[1] 133 | 134 | for j in case_b[1]: 135 | if j[0] in important_B: 136 | data_y_seven_class_B[j[0]] = j[1] + 3 137 | else: 138 | data_y_seven_class_B[j[0]] = j[1] 139 | 140 | label = [data_y_seven_class_A, data_y_seven_class_B] 141 | data_x_a = torch.tensor(np.expand_dims(data_x[2*i], axis=0), device=device) 142 | data_x_b = torch.tensor(np.expand_dims(data_x[2*i+1], axis=0), device=device) 143 | 144 | seven_prediction_A, seven_prediction_B = model_class(model, ot_model, data_x_a, data_x_b, len(case_a[0]), len(case_b[0])) 145 | 146 | prediction = [seven_prediction_A, seven_prediction_B] 147 | if generate: 148 | results.append(generate_text_sort(case_a, case_b, d, prediction, label)) 149 | 150 | if generate: 151 | return results 152 | 153 | 154 | def convert(filename, data, data_x, type): 155 | """Convert to seq2seq-style generation data. 156 | """ 157 | total_results = fold_convert_our_data_ot(data, data_x, type, generate=True) 158 | 159 | with open(filename, 'w') as f: 160 | for item in total_results: 161 | f.writelines(json.dumps(item, ensure_ascii=False)) 162 | f.write('\n') 163 | 164 | 165 | 166 | if __name__ == '__main__': 167 | if args.data_type == 'CAIL': 168 | data_extract_json = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/data_extract.json" 169 | data_extract_npy = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/data_extract.npy" 170 | else: 171 | data_extract_json = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_extract.json" 172 | data_extract_npy = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/data_extract.npy" 173 | 174 | data = load_data(data_extract_json) 175 | data_x = np.load(data_extract_npy) 176 | da_type = 'xLIRE' 177 | match_data_seq2seq_json = '../dataset/our_data/match_data_seq2seq_{}.json'.format(da_type) # the folder location determines which dataset variant is written 178 | midmatch_data_seq2seq_json = '../dataset/our_data/midmatch_data_seq2seq_{}.json'.format(da_type) 179 | dismatch_data_seq2seq_json = '../dataset/our_data/dismatch_data_seq2seq_{}.json'.format(da_type) 180 | convert(match_data_seq2seq_json, data, data_x, type='match') 181 | convert(midmatch_data_seq2seq_json, data, data_x, type='midmatch') 182 | convert(dismatch_data_seq2seq_json, data, data_x, type='dismatch') 183 | 184 | 185 | print('Output finished!') 186 | -------------------------------------------------------------------------------- /data_utils/seq2seq_convert_cail.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | from models.selector_two_multi_class_ot_cail_v2 import Selector2_mul_class, args, load_checkpoint, OT, load_data, data_extract_npy, data_extract_json, device 4 | import torch 5 | from
utils.snippets import * 6 | import torch 7 | import json 8 | """ 9 | todo 需要添加一个分类器,对相似句子与不相似句子进行分类 10 | """ 11 | def model_class(model, OT_model, case_A, case_B, seq_len_A, seq_len_B): 12 | """ 13 | :param model: 14 | :param OT_model: 15 | :param case_A: 16 | :param case_B: 17 | :return: AO, YO, ZO, AI, YI, ZI 18 | """ 19 | O_a, I_a = [], [] 20 | O_b, I_b = [], [] 21 | output_batch_A, batch_mask_A = model(case_A) 22 | output_batch_B, batch_mask_B = model(case_B) 23 | plan_list = OT_model(output_batch_A, output_batch_B, case_A, case_B, None, 24 | batch_mask_A, batch_mask_B, model_type="valid") 25 | OT_matrix = torch.ge(plan_list, 1 / case_A.shape[1] / args.threshold_ot).long() 26 | vec_correct_A = torch.argmax(output_batch_A, dim=-1).long()[0][:seq_len_A] 27 | vec_correct_B = torch.argmax(output_batch_B, dim=-1).long()[0][:seq_len_B] 28 | relation_A = torch.sum(OT_matrix[0], dim=1) 29 | relation_B = torch.sum(OT_matrix[0], dim=0) 30 | 31 | for i, label in enumerate(vec_correct_A): 32 | if label == 1: 33 | if relation_A[i] >= 1: 34 | I_a.append(i) 35 | else: 36 | O_a.append(i) 37 | 38 | for i, label in enumerate(vec_correct_B): 39 | if label == 1: 40 | if relation_B[i] >= 1: 41 | I_b.append(i) 42 | else: 43 | O_b.append(i) 44 | O, I = [O_a, O_b], [I_a, I_b] 45 | 46 | return O, I, [vec_correct_A+(torch.ge(relation_A[:seq_len_A], 1)*1)*vec_correct_A, vec_correct_B+(torch.ge(relation_B[:seq_len_B], 1)*1)*vec_correct_B] 47 | 48 | 49 | 50 | 51 | def generate_text_cluster(case_a, case_b, d, O, I, all_true, I_true): 52 | source_1_a = ''.join(["[O]" + case_a[0][i] for i in O[0]] + ["[I]" + case_a[0][i] for i in I[0]]) 53 | 54 | source_1_b = ''.join(["[O]" + case_b[0][i] for i in O[1]] + ["[I]" + case_b[0][i] for i in I[1]]) 55 | 56 | source_2_a = ''.join(["[O]" + case_a[0][i] for i in all_true[0] if i not in I_true[0]] + ["[I]" + case_a[0][i] for i in I_true[0]]) 57 | 58 | source_2_b = ''.join(["[O]" + case_b[0][i] for i in all_true[1] if i not in I_true[1]] + ["[I]" + case_b[0][i] for i in I_true[1]]) 59 | 60 | result = { 61 | 'source_1': source_1_a + source_1_b, 62 | 'source_2': source_2_a + source_2_b, 63 | 'explanation': d['explanation'], 64 | 'source_1_dis': [source_1_a, source_1_b], 65 | 'source_2_dis': [source_2_a, source_2_b], 66 | 'label': d['label'] 67 | } 68 | return result 69 | 70 | 71 | def get_extract_text(case_a, prediction): 72 | source_1_a = '' 73 | for i, output_class in enumerate(prediction): 74 | if output_class == 1: 75 | source_1_a += "[O]" + case_a[0][i] 76 | elif output_class == 2: 77 | source_1_a += "[I]" + case_a[0][i] 78 | else: 79 | pass 80 | return source_1_a 81 | 82 | 83 | def get_extract_text_wo_token(case_a, prediction): 84 | source_1_a = '' 85 | for i, output_class in enumerate(prediction): 86 | if output_class != 0: 87 | source_1_a += case_a[0][i] 88 | else: 89 | pass 90 | return source_1_a 91 | 92 | 93 | def generate_text_sort(case_a, case_b, d, prediction, label): 94 | 95 | source_1_a = get_extract_text(case_a, prediction[0]) 96 | source_1_b = get_extract_text(case_b, prediction[1]) 97 | source_2_a = get_extract_text(case_a, label[0]) 98 | source_2_b = get_extract_text(case_b, label[1]) 99 | 100 | result = { 101 | 'source_1': source_1_a + source_1_b, 102 | 'source_2': source_2_a + source_2_b, 103 | 'explanation': d['explanation'], 104 | 'source_1_dis': [source_1_a, source_1_b], 105 | 'source_2_dis': [source_2_a, source_2_b], 106 | 'label': d['label'] 107 | } 108 | return result 109 | 110 | 111 | def generate_text_wo_token(case_a, case_b, d, prediction, label): 
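# --- editor's note: a hedged sketch, not part of the original file ---
# Of the three generate modes used further down ('cluster', 'sort', 'wo_token'),
# this is the plain-text variant: it keeps every sentence whose predicted class
# is non-zero but drops the [O]/[I] marker tokens that get_extract_text prepends.
# For example, assuming case_a[0] == ['s0', 's1'] and prediction[0] == [0, 2],
# source_1_a becomes 's1' here, versus '[I]s1' in the token-marked variant.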
112 | 113 | source_1_a = get_extract_text_wo_token(case_a, prediction[0]) 114 | source_1_b = get_extract_text_wo_token(case_b, prediction[1]) 115 | source_2_a = get_extract_text_wo_token(case_a, label[0]) 116 | source_2_b = get_extract_text_wo_token(case_b, label[1]) 117 | 118 | result = { 119 | 'source_1': source_1_a + source_1_b, 120 | 'source_2': source_2_a + source_2_b, 121 | 'explanation': d['explanation'], 122 | 'source_1_dis': [source_1_a, source_1_b], 123 | 'source_2_dis': [source_2_a, source_2_b], 124 | 'label': d['label'] 125 | } 126 | return result 127 | 128 | 129 | def fold_convert_cail_ot(data, data_x, type, generate=False, generate_mode = 'cluster'): 130 | """每一fold用对应的模型做数据转换 131 | """ 132 | 133 | with torch.no_grad(): 134 | model = Selector2_mul_class(args.input_size, args.hidden_size, kernel_size=args.kernel_size, dilation_rate=[1, 2, 4, 8, 1, 1]) 135 | load_checkpoint(model, None, 2, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/extract/cail_extract-criterion-BCEFocal-onehot-1-ot_mode-max-convert_to_onehot-1-weight-100-simot-1-simpercent-1.0.pkl") 136 | model = model.to(device) 137 | ot_model = OT() 138 | ot_model = ot_model.to(device) 139 | load_checkpoint(ot_model, None, 2, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/extract/cail_extract_ot-criterion-BCEFocal-onehot-1-ot_mode-max-convert_to_onehot-1-weight-100-simot-1-simpercent-1.0.pkl") 140 | results = [] 141 | print(type+"ing") 142 | for i, d in enumerate(data): 143 | if type == 'match' and d["label"] == 2 or type == 'midmatch' and d["label"] == 1 or type == 'dismatch' and d["label"] == 0: 144 | case_a = d['case_A'] 145 | case_b = d['case_B'] 146 | important_A, important_B = [], [] 147 | data_y_seven_class_A, data_y_seven_class_B = [0]*len(case_a[0]), [0]*len(case_b[0]) 148 | for pos in d['relation_label']: 149 | row, col = pos[0], pos[-1] 150 | important_A.append(row) 151 | important_B.append(col) 152 | 153 | for j in case_a[1]: 154 | if j[0] in important_A: 155 | data_y_seven_class_A[j[0]] = j[1] + 1 156 | else: 157 | data_y_seven_class_A[j[0]] = j[1] 158 | 159 | for j in case_b[1]: 160 | if j[0] in important_B: 161 | data_y_seven_class_B[j[0]] = j[1] + 1 162 | else: 163 | data_y_seven_class_B[j[0]] = j[1] 164 | label = [data_y_seven_class_A, data_y_seven_class_B] 165 | 166 | 167 | 168 | O, I, prediction = model_class(model, ot_model, torch.tensor(np.expand_dims(data_x[2*i], axis=0), device=device), 169 | torch.tensor(np.expand_dims(data_x[2*i+1], axis=0), device=device),len(case_a[0]), len(case_b[0])) 170 | 171 | all_true, I_true = [], [] 172 | 173 | temp_a, temp_b = [], [] 174 | for i in d['relation_label']: 175 | temp_a.append(i[0]) 176 | temp_b.append(i[1]) 177 | I_true.append(temp_a) 178 | I_true.append(temp_b) 179 | 180 | 181 | case_a_temp = [] 182 | for i in case_a[1]: 183 | if i[1] == 1: 184 | case_a_temp.append(i[0]) 185 | all_true.append(case_a_temp) 186 | 187 | case_b_temp = [] 188 | for i in case_b[1]: 189 | if i[1] == 1: 190 | case_b_temp.append(i[0]) 191 | 192 | all_true.append(case_b_temp) 193 | if generate: 194 | if generate_mode == 'cluster': 195 | results.append(generate_text_cluster(case_a, case_b, d, O, I, all_true, I_true)) 196 | elif generate_mode == 'sort': 197 | results.append(generate_text_sort(case_a, case_b, d, prediction, label)) 198 | else: 199 | results.append(generate_text_wo_token(case_a, case_b, d, prediction, label)) 200 | 201 | if generate: 202 | return results 203 | 204 | 205 | 206 | def 
convert(filename, data, data_x, type, generate_mode): 207 | """Convert to seq2seq-style generation data. 208 | """ 209 | total_results = fold_convert_cail_ot(data, data_x, type, generate=True, generate_mode=generate_mode) 210 | 211 | with open(filename, 'w') as f: 212 | for item in total_results: 213 | f.writelines(json.dumps(item, ensure_ascii=False)) 214 | f.write('\n') 215 | 216 | 217 | 218 | if __name__ == '__main__': 219 | 220 | data = load_data(data_extract_json) 221 | data_x = np.load(data_extract_npy) 222 | da_type = "wo_token" 223 | match_data_seq2seq_json = '../dataset/match_data_seq2seq_{}.json'.format(da_type) 224 | midmatch_data_seq2seq_json = '../dataset/midmatch_data_seq2seq_{}.json'.format(da_type) 225 | dismatch_data_seq2seq_json = '../dataset/dismatch_data_seq2seq_{}.json'.format(da_type) 226 | convert(match_data_seq2seq_json, data, data_x, type='match', generate_mode=da_type) 227 | convert(midmatch_data_seq2seq_json, data, data_x, type='midmatch', generate_mode=da_type) 228 | convert(dismatch_data_seq2seq_json, data, data_x, type='dismatch', generate_mode=da_type) 229 | 230 | 231 | print('Output finished!') 232 | -------------------------------------------------------------------------------- /data_utils/splite_to_ner.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | """ 4 | First split the text into sentences at sentence-final punctuation, then split each sentence at the special markers. 5 | \001 case facts \002 case-fact matching features \003 element facts \004 element-fact matching features \005 dispute focus \006 dispute-focus matching features (type codes used below: aqss = case facts, yjss = element facts, zyjd = dispute focus) 6 | """ 7 | 8 | input_path = "/new_disk2/zhongxiang_sun/code/law_project/data/tagger_result/all_data.json" 9 | input_path_raw = "/new_disk2/zhongxiang_sun/code/law_project/data/tagger_result/all_data_raw.json" 10 | 11 | output_path = "../data/tagger_ner.json" 12 | 13 | count_ky = 0 14 | is_continue = False 15 | former_type = "" 16 | def split_chinese_sentence(sentence, delete_notag_para) -> list: 17 | """ 18 | delete_notag_para: whether to drop paragraphs that carry no feature tag (likely useless breaks that were not filtered out) 19 | """ 20 | if delete_notag_para: 21 | para = "" 22 | para_list_raw = sentence.split('\n') 23 | tag_list = ["\\001", "\\002", "\\003", "\\004", "\\005", "\\006"] 24 | for pa in para_list_raw: 25 | for tag in tag_list: 26 | if tag in pa: 27 | para += pa 28 | break 29 | else: 30 | para = sentence.replace('\n','h') 31 | para = re.sub('([。!?\?])([^”’])', r"\1\n\2", para) # single-character sentence terminators 32 | para = re.sub('([;])([^”’])', r"\1\n\2", para) # judicial cases also break on the Chinese semicolon; the colon is not used for now 33 | 34 | para = re.sub('(\.{6})([^”’])', r"\1\n\2", para) # English ellipsis 35 | para = re.sub('(\…{2})([^”’])', r"\1\n\2", para) # Chinese ellipsis 36 | para = re.sub('([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para) 37 | # a closing quote ends the sentence only when a terminator precedes it, so the break \n is moved after the quote; the rules above deliberately keep the quotes 38 | para = para.rstrip() # strip any trailing \n at the end of the paragraph 39 | # many rule sets also treat the semicolon, but it is ignored here, as are dashes and English double quotes; simple tweaks can add them if needed 40 | return para.split("\n") 41 | 42 | def sentence2ner(case_A_marked): 43 | count_signal = 0 44 | marked = False 45 | start = 0 46 | last_type = "" 47 | j = 0 48 | entities = [] 49 | global is_continue 50 | while j < len(case_A_marked): 51 | if case_A_marked[j:j + 4] == '\\001': # todo: check how these '\\001' comparisons end up matching 52 | if marked == False and not (is_continue==True and start==0): # an ordinary sentence, written out without bold or colour; a mismatched marker cannot occur here 53 | sentence = case_A_marked[start:j] 54 | start = j + 4 55 | j += 4 56 | marked = True 57 | last_type = "aqss" 58 | count_signal += 4 59 | else: 60 | sentence = case_A_marked[start:j] 61 | if start < j: 62 | entities.append({"start_idx": start-count_signal, "end_idx": j-1-count_signal, "type": "aqss", "entity":sentence}) 63 | count_signal += 4 64 | start = j + 4 65 | j += 4 66 | marked = False 67 | 68 | elif case_A_marked[j:j + 4] ==
'\\002': 69 | if marked == False and not (is_continue==True and start==0): 70 | sentence = case_A_marked[start:j] 71 | start = j + 4 72 | j += 4 73 | marked = True 74 | last_type = "aqss" 75 | count_signal += 4 76 | 77 | else: # 说明是非关联案情事实不加粗加红写进去 78 | sentence = case_A_marked[start:j] 79 | if start < j: 80 | entities.append({"start_idx": start-count_signal, "end_idx": j-1-count_signal, "type": "aqss", "entity":sentence}) 81 | count_signal += 4 82 | 83 | start = j + 4 84 | j += 4 85 | marked = False 86 | 87 | elif case_A_marked[j:j + 4] == '\\003': 88 | if marked == False and not (is_continue==True and start==0): 89 | sentence = case_A_marked[start:j] 90 | start = j + 4 91 | j += 4 92 | marked = True 93 | last_type = "yjss" 94 | count_signal += 4 95 | 96 | 97 | else: # 说明是非关联案情事实不加粗加红写进去 98 | sentence = case_A_marked[start:j] 99 | if start < j: 100 | entities.append({"start_idx": start-count_signal, "end_idx": j-1-count_signal, "type": "yjss", "entity":sentence}) 101 | count_signal += 4 102 | start = j + 4 103 | j += 4 104 | marked = False 105 | 106 | elif case_A_marked[j:j + 4] == '\\004': 107 | if marked == False and not (is_continue==True and start==0): 108 | sentence = case_A_marked[start:j] 109 | start = j + 4 110 | j += 4 111 | marked = True 112 | last_type = "yjss" 113 | count_signal += 4 114 | 115 | else: # 说明是非关联案情事实不加粗加红写进去 116 | sentence = case_A_marked[start:j] 117 | if start < j: 118 | entities.append({"start_idx": start-count_signal, "end_idx": j-1-count_signal, "type": "yjss", "entity":sentence}) 119 | 120 | start = j + 4 121 | j += 4 122 | marked = False 123 | count_signal += 4 124 | 125 | elif case_A_marked[j:j + 4] == '\\005': 126 | if marked == False and not (is_continue==True and start==0): 127 | sentence = case_A_marked[start:j] 128 | start = j + 4 129 | j += 4 130 | marked = True 131 | last_type = "zyjd" 132 | count_signal += 4 133 | 134 | else: # 说明是非关联案情事实不加粗加红写进去 135 | sentence = case_A_marked[start:j] 136 | if start < j: 137 | entities.append({"start_idx": start-count_signal, "end_idx": j-1-count_signal, "type": "zyjd", "entity":sentence}) 138 | 139 | start = j + 4 140 | j += 4 141 | marked = False 142 | count_signal += 4 143 | 144 | elif case_A_marked[j:j + 4] == '\\006': 145 | if marked == False and not (is_continue==True and start==0): 146 | sentence = case_A_marked[start:j] 147 | start = j + 4 148 | j += 4 149 | marked = True 150 | last_type = "zyjd" 151 | count_signal += 4 152 | 153 | else: # 说明是非关联案情事实不加粗加红写进去 154 | sentence = case_A_marked[start:j] 155 | if start < j: 156 | entities.append({"start_idx": start-count_signal, "end_idx": j-1-count_signal, "type": "zyjd", "entity":sentence}) 157 | 158 | start = j + 4 159 | j += 4 160 | marked = False 161 | count_signal += 4 162 | 163 | else: 164 | j += 1 165 | global former_type 166 | if last_type == '': 167 | pass 168 | else: 169 | former_type = last_type 170 | if marked == True: 171 | """跨越句号了""" 172 | print("跨越句号") 173 | global count_ky 174 | count_ky += 1 175 | is_continue = True 176 | if start < j: 177 | entities.append({"start_idx": start - count_signal, "end_idx": j - 1 - count_signal, "type": last_type, "entity": case_A_marked[start:j]}) 178 | elif start == 0 and is_continue: 179 | entities.append({"start_idx": start - count_signal, "end_idx": j - 1 - count_signal, "type": former_type, 180 | "entity": case_A_marked[start:j]}) 181 | else: 182 | is_continue = False 183 | case_A_marked = case_A_marked.replace('\\001', '').replace('\\002', '').replace('\\003', '').replace('\\004', 184 | '').replace('\\005', 
'').replace('\\006', '') 185 | 186 | return {"text":case_A_marked, "entities": entities} 187 | def get_data(data=None): 188 | final_data = [] 189 | if data==None: 190 | with open(input_path, 'r') as f: 191 | data = json.load(f) 192 | else: 193 | pass 194 | for item in data: 195 | # if item["pair_ID"] == "848dbfd5-4263-ab70-e105-e8d9586a15b5|981a181f-5269-3604-25ec-b1eb7a75b8ac": 196 | # continue 197 | case_a_marked = item["case_A_marked"] 198 | case_b_marked = item["case_B_marked"] 199 | case_a_marked_list = split_chinese_sentence(case_a_marked[0]["content"], delete_notag_para=False) + split_chinese_sentence(case_a_marked[1]["content"], delete_notag_para=False) 200 | case_b_marked_list = split_chinese_sentence(case_b_marked[0]["content"], delete_notag_para=False) + split_chinese_sentence(case_b_marked[1]["content"], delete_notag_para=False) 201 | for i in range(len(case_a_marked_list)): 202 | case_a_marked_list[i] = sentence2ner(case_a_marked_list[i]) 203 | for i in range(len(case_b_marked_list)): 204 | case_b_marked_list[i] = sentence2ner(case_b_marked_list[i]) 205 | item["case_a_ner"] = case_a_marked_list 206 | item["case_b_ner"] = case_b_marked_list 207 | final_data.append(item) 208 | return final_data 209 | 210 | def insert_relation(): 211 | with open(input_path, 'r') as f: 212 | data = json.load(f) 213 | 214 | with open(input_path_raw, 'r') as f: 215 | data_raw = json.load(f) 216 | """ 217 | 先把raw 里的 relation 拼接到data上 218 | """ 219 | for item in data: 220 | for raw in data_raw: 221 | if item['pair_ID'] == raw['id']: 222 | item['relation'] = raw['relationships'] 223 | for item in data: 224 | for relation in item['relation']: 225 | case_a_span = [int(relation['entityLeft'].split('#')[0]), int(relation['entityLeft'].split('#')[1])] 226 | case_b_span = [int(relation['entityRight'].split('#')[0]), int(relation['entityRight'].split('#')[1])] 227 | sentence_a = split_chinese_sentence( 228 | "hhhhh" + item['case_A'][0]['content'], delete_notag_para=False) + split_chinese_sentence('hhhhhh' + item['case_A'][1]['content'], delete_notag_para=False) 229 | 230 | sentence_b = split_chinese_sentence( 231 | "hhhhh" + item['case_B'][0]['content'], delete_notag_para=False) + split_chinese_sentence('hhhhhh' + item['case_B'][1]['content'], delete_notag_para=False) 232 | if item['pair_ID'] == '5d196c7b-b198-bc69-54e7-e54741bf65d0|39ff52de-b2c4-6b75-cd75-0ab966872c1e': 233 | print() 234 | selected_sentence_a = get_relation_label(sentence_a, case_a_span) 235 | selected_sentence_b = get_relation_label(sentence_b, case_b_span) 236 | relation['label_Left'] = selected_sentence_a 237 | relation['label_Right'] = selected_sentence_b 238 | return data 239 | 240 | def get_relation_label(sentence, span1): 241 | selected_sentence = [] 242 | count_all = 0 243 | for i, s in enumerate(sentence): 244 | count_all += len(s) 245 | if count_all > span1[0]: 246 | selected_sentence.append(i) 247 | for j in range(i + 1, len(sentence)): 248 | count_all += len(sentence[j]) 249 | if count_all < span1[1]: 250 | selected_sentence.append(j) 251 | else: 252 | break 253 | break 254 | return selected_sentence 255 | 256 | 257 | if __name__ == '__main__': 258 | add_relation = insert_relation() 259 | ner_data = get_data(add_relation) 260 | train_data = ner_data[:int(len(ner_data)*0.8)] 261 | test_data = ner_data[int(len(ner_data)*0.8):int(len(ner_data)*0.9)] 262 | dev_data = ner_data[int(len(ner_data)*0.9):] 263 | with open("../data/law_train.json", 'w') as f: 264 | json.dump(train_data, f, ensure_ascii=False, indent=4) 265 | with 
open("../data/law_test.json", 'w') as f: 266 | json.dump(test_data, f, ensure_ascii=False, indent=4) 267 | with open("../data/law_dev.json", 'w') as f: 268 | json.dump(dev_data, f, ensure_ascii=False, indent=4) 269 | print(count_ky) 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | -------------------------------------------------------------------------------- /models/seq2seq_model_t5_dismatch.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | import datetime 4 | from transformers import BertTokenizer, AutoTokenizer, GPT2Tokenizer, GPT2Model 5 | import argparse 6 | import torch 7 | from transformers import AdamW 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | from transformers import MT5ForConditionalGeneration 11 | from torch.utils.data import Dataset, DataLoader 12 | import logging 13 | from utils.snippets import * 14 | 15 | # 基本参数 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--batch_size', type=int, default=1, help='batch size') 18 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs') 19 | parser.add_argument('--each_test_epoch', type=int, default=1) 20 | parser.add_argument('--lr', type=float, default=2e-5, help='learning rate') 21 | parser.add_argument('--weight_decay', type=float, default=0., help='decay weight of optimizer') 22 | parser.add_argument('--model_name', type=str, default='t5', help='matching model') 23 | parser.add_argument('--checkpoint', type=str, default="./weights/seq2seq_model", help='checkpoint path') 24 | parser.add_argument('--bert_maxlen', type=int, default=512, help='max length of each case') 25 | parser.add_argument('--maxlen', type=int, default=512, help='max length of each case') 26 | parser.add_argument('--input_size', type=int, default=768) 27 | parser.add_argument('--hidden_size', type=int, default=384) 28 | parser.add_argument('--kernel_size', type=int, default=3) 29 | parser.add_argument('--threshold', type=float, default=0.3) 30 | parser.add_argument('--k_sparse', type=int, default=10) 31 | parser.add_argument('--early_stopping_patience', type=int, default=5) 32 | parser.add_argument('--log_name', type=str, default="log_seq2seq") 33 | parser.add_argument('--seq2seq_type', type=str, default='match') 34 | parser.add_argument('--cuda_pos', type=str, default='0', help='which GPU to use') 35 | parser.add_argument('--seed', type=int, default=42, help='max length of each case') 36 | parser.add_argument('--train', action='store_true') 37 | 38 | args = parser.parse_args() 39 | print(args) 40 | np.random.seed(args.seed) 41 | torch.manual_seed(args.seed) 42 | torch.cuda.manual_seed_all(args.seed) 43 | device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu') 44 | log_name = args.log_name 45 | logging.basicConfig(level=logging.INFO,#控制台打印的日志级别 46 | filename='../logs/{}-{}.log'.format(log_name, args.seq2seq_type), 47 | filemode='a',##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 48 | #a是追加模式,默认如果不写的话,就是追加模式 49 | format= 50 | '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' 51 | #日志格式 52 | ) 53 | 54 | if args.seq2seq_type == 'match': 55 | data_seq2seq_json = '../dataset/match_data_seq2seq.json' 56 | seq2seq_config_json = '../dataset/match_data_seq2seq_config.json' 57 | elif args.seq2seq_type == 'midmatch': 58 | data_seq2seq_json = '../dataset/midmatch_data_seq2seq.json' 59 | seq2seq_config_json = '../dataset/midmatch_data_seq2seq_config.json' 60 | else: 61 
| data_seq2seq_json = '../dataset/dismatch_data_seq2seq.json' 62 | seq2seq_config_json = '../dataset/dismatch_data_seq2seq_config.json' 63 | 64 | 65 | def load_data(filename): 66 | """加载数据 67 | 返回:[{...}] 68 | """ 69 | D = [] 70 | with open(filename) as f: 71 | for l in f: 72 | D.append(json.loads(l)) 73 | return D 74 | 75 | 76 | class DataGenerator(Dataset): 77 | def __init__(self, input_data, random=True): 78 | super(DataGenerator, self).__init__() 79 | self.input_data = input_data 80 | self.random = random 81 | 82 | def __len__(self): 83 | return len(self.input_data) 84 | 85 | def __getitem__(self, idx): 86 | i = np.random.choice(2) + 1 if self.random else 1 87 | source_1, target_1 = self.input_data[idx]['source_%s_dis' % i][0], self.input_data[idx]['explanation_dis'][0] 88 | source_2, target_2 = self.input_data[idx]['source_%s_dis' % i][1], self.input_data[idx]['explanation_dis'][1] 89 | return [source_1, target_1], [source_2, target_2] 90 | 91 | 92 | class Collate: 93 | def __init__(self): 94 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_t5_fold) 95 | self.max_seq_len = args.maxlen 96 | 97 | def __call__(self, batch): 98 | source_batch = [] 99 | target_batch = [] 100 | 101 | for item in batch: 102 | source_batch.append(item[0][0]) 103 | source_batch.append(item[1][0]) 104 | target_batch.append(item[0][1]) 105 | target_batch.append(item[1][1]) 106 | 107 | 108 | enc_source_batch = self.tokenizer(source_batch, max_length=self.max_seq_len, truncation=True, return_tensors='pt',padding=True) 109 | source_ids = enc_source_batch["input_ids"] 110 | source_attention_mask = enc_source_batch["attention_mask"] 111 | enc_target_batch = self.tokenizer(target_batch, max_length=self.max_seq_len, truncation=True,return_tensors='pt',padding=True) 112 | target_ids = enc_target_batch["input_ids"] 113 | target_attention_mask = enc_target_batch["attention_mask"] 114 | 115 | features = {'input_ids': source_ids, 'decoder_input_ids': target_ids, 'attention_mask': source_attention_mask, 116 | 'decoder_attention_mask': target_attention_mask} 117 | 118 | return features 119 | 120 | 121 | def build_pretrain_dataloader(data, batch_size, shuffle=True, num_workers=0,): 122 | data_generator =DataGenerator(data, random=True) 123 | collate = Collate() 124 | return DataLoader( 125 | data_generator, 126 | batch_size=batch_size, 127 | shuffle=shuffle, 128 | num_workers=num_workers, 129 | collate_fn=collate 130 | ) 131 | 132 | 133 | 134 | 135 | def load_checkpoint(model, optimizer, trained_epoch, file_name=None): 136 | if file_name==None: 137 | file_name = args.checkpoint + '/' + f"{args.seq2seq_type}-seq2seq-{trained_epoch}.pkl" 138 | save_params = torch.load(file_name, map_location=device) 139 | model.load_state_dict(save_params["model"]) 140 | #optimizer.load_state_dict(save_params["optimizer"]) 141 | 142 | 143 | def save_checkpoint(model, optimizer, trained_epoch): 144 | save_params = { 145 | "model": model.state_dict(), 146 | "optimizer": optimizer.state_dict(), 147 | "trained_epoch": trained_epoch, 148 | } 149 | if not os.path.exists(args.checkpoint): 150 | # 判断文件夹是否存在,不存在则创建文件夹 151 | os.mkdir(args.checkpoint) 152 | filename = args.checkpoint + '/' + f"{args.seq2seq_type}-{args.model_name}-seq2seq-{trained_epoch}.pkl" 153 | torch.save(save_params, filename) 154 | 155 | 156 | def train_valid(train_data, valid_data, test_data, model): 157 | optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 158 | # ema = EMA(model, 0.9999) 159 | # ema.register() 160 | early_stop = 
EarlyStop(args.early_stopping_patience) 161 | for epoch in range(args.epochs): 162 | epoch_loss = 0. 163 | current_step = 0 164 | model.train() 165 | # for batch_data in tqdm(train_data_loader, ncols=0): 166 | pbar = tqdm(train_data, desc="Iteration", postfix='train') 167 | for batch_data in pbar: 168 | cur = {k: v.to(device) for k, v in batch_data.items()} 169 | prob = model(**cur)[0] 170 | mask = cur['decoder_attention_mask'][:, 1:].reshape(-1).bool() 171 | prob = prob[:, :-1] 172 | prob = prob.reshape((-1, prob.size(-1)))[mask] 173 | labels = cur['decoder_input_ids'][:, 1:].reshape(-1)[mask] 174 | loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100) 175 | loss = loss_fct(prob, labels) 176 | loss.backward() 177 | optimizer.step() 178 | optimizer.zero_grad() 179 | # ema.update() 180 | loss_item = loss.cpu().detach().item() 181 | epoch_loss += loss_item/labels.shape[-1] 182 | current_step += 1 183 | pbar.set_description("train loss {}".format(epoch_loss / current_step)) 184 | if current_step % 100 == 0: 185 | logging.info("train step {} loss {}".format(current_step, epoch_loss / current_step)) 186 | 187 | epoch_loss = epoch_loss / current_step 188 | time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 189 | print('{} train epoch {} loss: {:.4f}'.format(time_str, epoch, epoch_loss)) 190 | logging.info('train epoch {} loss: {:.4f}'.format(epoch, epoch_loss)) 191 | # todo 看一下 EMA是否会让模型准确率提升,如果可以的话在保存模型前加入 ema 192 | with torch.no_grad(): 193 | model.eval() 194 | current_val_metric_value = evaluate(valid_data, model, type='valid')['main'] 195 | is_save = early_stop.step(current_val_metric_value, epoch) 196 | if is_save: 197 | save_checkpoint(model, optimizer, epoch) 198 | else: 199 | pass 200 | if early_stop.stop_training(epoch): 201 | logging.info( 202 | "early stopping at epoch {} since didn't improve from epoch no {}. Best value {}, current value {}".format( 203 | epoch, early_stop.best_epoch, early_stop.best_value, current_val_metric_value 204 | )) 205 | print( 206 | "early stopping at epoch {} since didn't improve from epoch no {}. 
Best value {}, current value {}".format( 207 | epoch, early_stop.best_epoch, early_stop.best_value, current_val_metric_value 208 | )) 209 | break 210 | evaluate(test_data, model, type='test') 211 | 212 | 213 | def generate(text, model, max_length=30): 214 | tokenizer = BertTokenizer.from_pretrained(pretrained_t5_fold) 215 | feature = tokenizer.encode(text, return_token_type_ids=True, return_tensors='pt', 216 | max_length=args.maxlen, truncation=True) 217 | feature = {'input_ids': feature} 218 | feature = {k: v.to(device) for k, v in list(feature.items())} 219 | 220 | gen = model.generate(max_length=max_length, eos_token_id=tokenizer.sep_token_id, 221 | decoder_start_token_id=tokenizer.cls_token_id, 222 | **feature).cpu().numpy()[0] 223 | gen = gen[1:] 224 | gen = tokenizer.decode(gen, skip_special_tokens=True).replace(' ', '') 225 | return gen 226 | 227 | 228 | def evaluate(data, model, filename=None, type='valid'): 229 | """验证集评估 230 | """ 231 | if filename is not None: 232 | F = open(filename, 'w', encoding='utf-8') 233 | total_metrics = {k: 0.0 for k in metric_keys} 234 | for d in tqdm(data, desc=u'评估中'): 235 | pred_summary_1 = generate(d['source_1_dis'][0], model, max_length=args.maxlen//2) 236 | pred_summary_2 = generate(d['source_1_dis'][1], model, max_length=args.maxlen//2) 237 | metrics = compute_metrics(pred_summary_1+pred_summary_2, d['explanation_dis'][0]+d['explanation_dis'][1]) 238 | for k, v in metrics.items(): 239 | total_metrics[k] += v 240 | if filename is not None: 241 | F.write(d['explanation_dis'][0]+d['explanation_dis'][1] + '\t' + pred_summary_1 + pred_summary_2 + '\n') 242 | F.flush() 243 | if filename is not None: 244 | F.close() 245 | print(total_metrics) 246 | logging.info("~~~~~~~~{}~~~~~~~~~~~".format(type)) 247 | for k, v in total_metrics.items(): 248 | logging.info(k+": {} ".format(v/len(data))) 249 | logging.info("~~~~~~~~{}~~~~~~~~~~~".format(type)) 250 | return {k: v / len(data) for k, v in total_metrics.items()} 251 | 252 | if __name__ == '__main__': 253 | # 加载数据 254 | data = load_data(data_seq2seq_json) 255 | train_data = data_split(data, 'train') 256 | valid_data = data_split(data, 'valid') 257 | test_data = data_split(data, 'test') 258 | train_data_loader = build_pretrain_dataloader(train_data, args.batch_size) 259 | G_model = MT5ForConditionalGeneration.from_pretrained(pretrained_t5_fold) 260 | if args.train: 261 | G_model = G_model.to(device) 262 | train_valid(train_data_loader, valid_data, test_data,G_model) 263 | else: 264 | load_checkpoint(G_model, None, None, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/seq2seq_model/match-seq2seq-8.pkl") 265 | with torch.no_grad(): 266 | G_model.eval() 267 | evaluate(valid_data, G_model) 268 | 269 | 270 | 271 | 272 | 273 | -------------------------------------------------------------------------------- /models/seq2seq_model_t5.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | import datetime 4 | from transformers import BertTokenizer, AutoTokenizer, GPT2Tokenizer, GPT2Model, T5Tokenizer 5 | import argparse 6 | import torch 7 | from transformers import AdamW 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | from transformers import MT5ForConditionalGeneration 11 | from torch.utils.data import Dataset, DataLoader 12 | import logging 13 | from utils.snippets import * 14 | 15 | # 基本参数 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--batch_size', type=int, default=1, 
help='batch size') 18 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs') 19 | parser.add_argument('--each_test_epoch', type=int, default=1) 20 | parser.add_argument('--lr', type=float, default=2e-5, help='learning rate') 21 | parser.add_argument('--weight_decay', type=float, default=0., help='decay weight of optimizer') 22 | parser.add_argument('--model_name', type=str, default='t5', help='matching model') 23 | parser.add_argument('--checkpoint', type=str, default="./weights/seq2seq_model", help='checkpoint path') 24 | parser.add_argument('--bert_maxlen', type=int, default=512, help='max length of each case') 25 | parser.add_argument('--maxlen', type=int, default=1024, help='max length of each case') 26 | parser.add_argument('--input_size', type=int, default=768) 27 | parser.add_argument('--hidden_size', type=int, default=384) 28 | parser.add_argument('--kernel_size', type=int, default=3) 29 | parser.add_argument('--threshold', type=float, default=0.3) 30 | parser.add_argument('--k_sparse', type=int, default=10) 31 | parser.add_argument('--early_stopping_patience', type=int, default=5) 32 | parser.add_argument('--log_name', type=str, default="log_seq2seq") 33 | parser.add_argument('--seq2seq_type', type=str, default='match') 34 | parser.add_argument('--cuda_pos', type=str, default='0', help='which GPU to use') 35 | parser.add_argument('--seed', type=int, default=42, help='max length of each case') 36 | parser.add_argument('--train', action='store_true') 37 | parser.add_argument('--data_type', type=str, default='ELAM', help="data type:[ELAM, CAIL]") 38 | parser.add_argument('--data_format', type=str, default='sort', help="data format") 39 | args = parser.parse_args() 40 | print(args) 41 | np.random.seed(args.seed) 42 | torch.manual_seed(args.seed) 43 | torch.cuda.manual_seed_all(args.seed) 44 | device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu') 45 | log_name = args.log_name 46 | logging.basicConfig(level=logging.INFO,#控制台打印的日志级别 47 | filename='../logs/{}-{}-{}-{}.log'.format(log_name, args.model_name, args.seq2seq_type, args.data_type), 48 | filemode='a',##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 49 | #a是追加模式,默认如果不写的话,就是追加模式 50 | format= 51 | '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' 52 | #日志格式 53 | ) 54 | logging.info(args) 55 | if args.seq2seq_type == 'match': 56 | data_seq2seq_json = '../models_v2/data/our_data/match_data_seq2seq_{}.json'.format(args.data_format) 57 | elif args.seq2seq_type == 'midmatch': 58 | data_seq2seq_json = '../models_v2/data/our_data/midmatch_data_seq2seq_{}.json'.format(args.data_format) 59 | else: 60 | data_seq2seq_json = '../models_v2/data/our_data/dismatch_data_seq2seq_{}.json'.format(args.data_format) 61 | 62 | 63 | def load_data(filename): 64 | """加载数据 65 | 返回:[{...}] 66 | """ 67 | D = [] 68 | with open(filename) as f: 69 | for l in f: 70 | D.append(json.loads(l)) 71 | return D 72 | 73 | 74 | class DataGenerator(Dataset): 75 | def __init__(self, input_data, random=True): 76 | super(DataGenerator, self).__init__() 77 | self.input_data = input_data 78 | self.random = random 79 | 80 | def __len__(self): 81 | return len(self.input_data) 82 | 83 | def __getitem__(self, idx): 84 | 85 | i = np.random.choice(2) + 1 if self.random else 1 86 | source, target = self.input_data[idx]['source_%s' % i], self.input_data[idx]['explanation'] 87 | return [source, target] 88 | 89 | 90 | class Collate: 91 | def __init__(self, tokenizer): 92 | self.tokenizer = tokenizer 93 | 
self.max_seq_len = args.maxlen 94 | 95 | def __call__(self, batch): 96 | source_batch = [] 97 | target_batch = [] 98 | for item in batch: 99 | source_batch.append(item[0]) 100 | target_batch.append(item[1]) 101 | 102 | enc_source_batch = self.tokenizer(source_batch, max_length=self.max_seq_len, truncation=True, padding=True, return_tensors='pt') 103 | source_ids = enc_source_batch["input_ids"] 104 | source_attention_mask = enc_source_batch["attention_mask"] 105 | enc_target_batch = self.tokenizer(target_batch, max_length=self.max_seq_len, truncation=True, padding=True, return_tensors='pt') 106 | target_ids = enc_target_batch["input_ids"] 107 | target_attention_mask = enc_target_batch["attention_mask"] 108 | 109 | features = {'input_ids': source_ids, 'decoder_input_ids': target_ids, 'attention_mask': source_attention_mask, 110 | 'decoder_attention_mask': target_attention_mask} 111 | 112 | return features 113 | 114 | 115 | def build_pretrain_dataloader(data, batch_size, tokenizer, shuffle=True, num_workers=0): 116 | data_generator =DataGenerator(data, random=True) 117 | collate = Collate(tokenizer) 118 | return DataLoader( 119 | data_generator, 120 | batch_size=batch_size, 121 | shuffle=shuffle, 122 | num_workers=num_workers, 123 | collate_fn=collate 124 | ) 125 | 126 | 127 | 128 | 129 | def load_checkpoint(model, optimizer, trained_epoch, file_name=None): 130 | if file_name==None: 131 | file_name = args.checkpoint + '/' + f"{args.seq2seq_type}-seq2seq-{trained_epoch}.pkl" 132 | save_params = torch.load(file_name, map_location=device) 133 | model.load_state_dict(save_params["model"]) 134 | #optimizer.load_state_dict(save_params["optimizer"]) 135 | 136 | 137 | def save_checkpoint(model, optimizer, trained_epoch): 138 | save_params = { 139 | "model": model.state_dict(), 140 | "optimizer": optimizer.state_dict(), 141 | "trained_epoch": trained_epoch, 142 | } 143 | if not os.path.exists(args.checkpoint): 144 | # 判断文件夹是否存在,不存在则创建文件夹 145 | os.mkdir(args.checkpoint) 146 | filename = args.checkpoint + '/' + f"{args.seq2seq_type}-{args.model_name}-seq2seq-{args.data_type}-{args.data_format}.pkl" 147 | torch.save(save_params, filename) 148 | 149 | 150 | def train_valid(train_data, valid_data, test_data, model, tokenizer): 151 | optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 152 | # ema = EMA(model, 0.9999) 153 | # ema.register() 154 | early_stop = EarlyStop(args.early_stopping_patience) 155 | for epoch in range(args.epochs): 156 | epoch_loss = 0. 
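# --- editor's note: a hedged sketch of the loss computed below, not in the original ---
# The inner loop is standard teacher forcing with the MT5 encoder-decoder:
# logits at positions [0, T-1) are scored against decoder_input_ids[:, 1:], so
# each step predicts the *next* target token, and decoder_attention_mask[:, 1:]
# removes padded positions from the loss, roughly:
#   prob   = model(**cur)[0][:, :-1]          # [B, T-1, V]
#   labels = cur['decoder_input_ids'][:, 1:]  # [B, T-1], flattened and masked
# Note that epoch_loss additionally divides each batch loss by labels.shape[-1],
# the number of unmasked target tokens, on top of CrossEntropyLoss's mean reduction.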
157 | current_step = 0 158 | model.train() 159 | # for batch_data in tqdm(train_data_loader, ncols=0): 160 | pbar = tqdm(train_data, desc="Iteration", postfix='train') 161 | for batch_data in pbar: 162 | cur = {k: v.to(device) for k, v in batch_data.items()} 163 | prob = model(**cur)[0] 164 | mask = cur['decoder_attention_mask'][:, 1:].reshape(-1).bool() 165 | prob = prob[:, :-1] 166 | prob = prob.reshape((-1, prob.size(-1)))[mask] 167 | labels = cur['decoder_input_ids'][:, 1:].reshape(-1)[mask] 168 | loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100) 169 | loss = loss_fct(prob, labels) 170 | loss.backward() 171 | optimizer.step() 172 | optimizer.zero_grad() 173 | # ema.update() 174 | loss_item = loss.cpu().detach().item() 175 | epoch_loss += loss_item/labels.shape[-1] 176 | current_step += 1 177 | pbar.set_description("train loss {}".format(epoch_loss / current_step)) 178 | if current_step % 100 == 0: 179 | logging.info("train step {} loss {}".format(current_step, epoch_loss / current_step)) 180 | 181 | epoch_loss = epoch_loss / current_step 182 | time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 183 | print('{} train epoch {} loss: {:.4f}'.format(time_str, epoch, epoch_loss)) 184 | logging.info('train epoch {} loss: {:.4f}'.format(epoch, epoch_loss)) 185 | # todo 看一下 EMA是否会让模型准确率提升,如果可以的话在保存模型前加入 ema 186 | with torch.no_grad(): 187 | model.eval() 188 | current_val_metric_value = evaluate(valid_data, model, tokenizer, type='valid')['main'] 189 | is_save = early_stop.step(current_val_metric_value, epoch) 190 | if is_save: 191 | save_checkpoint(model, optimizer, epoch) 192 | else: 193 | pass 194 | if early_stop.stop_training(epoch): 195 | logging.info( 196 | "early stopping at epoch {} since didn't improve from epoch no {}. Best value {}, current value {}".format( 197 | epoch, early_stop.best_epoch, early_stop.best_value, current_val_metric_value 198 | )) 199 | print( 200 | "early stopping at epoch {} since didn't improve from epoch no {}. 
Best value {}, current value {}".format( 201 | epoch, early_stop.best_epoch, early_stop.best_value, current_val_metric_value 202 | )) 203 | break 204 | evaluate(test_data, model, tokenizer, type='test') 205 | 206 | 207 | def generate(text, model, tokenizer, device=device, max_length=30): 208 | feature = tokenizer.encode(text, return_token_type_ids=True, return_tensors='pt', 209 | max_length=args.maxlen, truncation=True) 210 | feature = {'input_ids': feature} 211 | feature = {k: v.to(device) for k, v in list(feature.items())} 212 | 213 | gen = model.generate(max_length=max_length, eos_token_id=tokenizer.sep_token_id, 214 | decoder_start_token_id=tokenizer.cls_token_id, 215 | **feature).cpu().numpy()[0] 216 | gen = gen[1:] 217 | gen = tokenizer.decode(gen, skip_special_tokens=True).replace(' ', '') 218 | return gen 219 | 220 | 221 | def evaluate(data, model, tokenizer, filename=None, type='valid'): 222 | """验证集评估 223 | """ 224 | if filename is not None: 225 | F = open(filename, 'w', encoding='utf-8') 226 | total_metrics = {k: 0.0 for k in metric_keys} 227 | for d in tqdm(data, desc=u'评估中'): 228 | pred_summary = generate(d['source_1'], model, tokenizer, max_length=args.maxlen//4) 229 | metrics = compute_metrics(pred_summary, d['explanation']) 230 | for k, v in metrics.items(): 231 | total_metrics[k] += v 232 | if filename is not None: 233 | F.write(d['explanation'] + '\t' + pred_summary + '\n') 234 | F.flush() 235 | if filename is not None: 236 | F.close() 237 | print(total_metrics) 238 | logging.info("~~~~~~~~{}~~~~~~~~~~~".format(type)) 239 | for k, v in total_metrics.items(): 240 | logging.info(k+": {} ".format(v/len(data))) 241 | logging.info("~~~~~~~~{}~~~~~~~~~~~".format(type)) 242 | return {k: v / len(data) for k, v in total_metrics.items()} 243 | 244 | if __name__ == '__main__': 245 | # 加载数据 246 | tokenizer = T5PegasusTokenizer.from_pretrained(pretrained_t5_fold) 247 | if args.data_type == 'CAIL': 248 | tokenizer.add_tokens(["[O]", '[I]']) 249 | elif args.data_type == 'ELAM': 250 | tokenizer.add_tokens(["[AO]", "[YO]", "[ZO]", '[AI]', "[YI]", "[ZI]"]) 251 | else: 252 | print("data type error") 253 | exit(-1) 254 | data = load_data(data_seq2seq_json) 255 | train_data = data_split(data, 'train') 256 | valid_data = data_split(data, 'valid') 257 | test_data = data_split(data, 'test') 258 | train_data_loader = build_pretrain_dataloader(train_data, args.batch_size, tokenizer) 259 | G_model = MT5ForConditionalGeneration.from_pretrained(pretrained_t5_fold) 260 | G_model.resize_token_embeddings(len(tokenizer)) # 扩充embedding 261 | if args.train: 262 | G_model = G_model.to(device) 263 | train_valid(train_data_loader, valid_data, test_data, G_model, tokenizer) 264 | else: 265 | load_checkpoint(G_model, None, None, "/home/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/seq2seq_model/match-t5-seq2seq-ELAM-sort.pkl") 266 | G_model = G_model.to(device) 267 | with torch.no_grad(): 268 | G_model.eval() 269 | evaluate(valid_data, G_model, tokenizer) 270 | 271 | 272 | 273 | 274 | 275 | -------------------------------------------------------------------------------- /data_utils/seq2seq_convert.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | from models.selector_two_multi_class_ot_v3 import Selector2_mul_class, args, load_checkpoint, OT, load_data, data_extract_npy, data_extract_json, device 4 | import torch 5 | from utils.snippets import * 6 | import torch 7 | import json 8 | """ 9 | todo 
需要添加一个分类器,对相似句子与不相似句子进行分类 10 | """ 11 | def model_class(model, OT_model, case_A, case_B, seq_len_A, seq_len_B): 12 | """ 13 | 14 | :param model: 15 | :param OT_model: 16 | :param case_A: 17 | :param case_B: 18 | :return: AO, YO, ZO, AI, YI, ZI 19 | """ 20 | AO_a, YO_a, ZO_a, AI_a, YI_a, ZI_a = [], [], [], [], [], [] 21 | AO_b, YO_b, ZO_b, AI_b, YI_b, ZI_b = [], [], [], [], [], [] 22 | 23 | output_batch_A, batch_mask_A = model(case_A) 24 | output_batch_B, batch_mask_B = model(case_B) 25 | plan_list = OT_model(output_batch_A, output_batch_B, case_A, case_B, None, 26 | batch_mask_A, batch_mask_B, 'valid') 27 | OT_matrix = torch.ge(plan_list, 1 / case_A.shape[1] / args.threshold_ot).long() 28 | vec_correct_A = torch.argmax(output_batch_A, dim=-1).long()[0][:seq_len_A] 29 | vec_correct_B = torch.argmax(output_batch_B, dim=-1).long()[0][:seq_len_B] 30 | relation_A = torch.sum(OT_matrix[0], dim=1) 31 | relation_B = torch.sum(OT_matrix[0], dim=0) 32 | 33 | for i, label in enumerate(vec_correct_A): 34 | if label == 1: 35 | if relation_A[i] >= 1: 36 | AI_a.append(i) 37 | else: 38 | AO_a.append(i) 39 | elif label == 2: 40 | if relation_A[i] >= 1: 41 | YI_a.append(i) 42 | else: 43 | YO_a.append(i) 44 | elif label == 3: 45 | if relation_A[i] >= 1: 46 | ZI_a.append(i) 47 | else: 48 | ZO_a.append(i) 49 | 50 | for i, label in enumerate(vec_correct_B): 51 | if label == 1: 52 | if relation_B[i] >= 1: 53 | AI_b.append(i) 54 | else: 55 | AO_b.append(i) 56 | elif label == 2: 57 | if relation_B[i] >= 1: 58 | YI_b.append(i) 59 | else: 60 | YO_b.append(i) 61 | elif label == 3: 62 | if relation_B[i] >= 1: 63 | ZI_b.append(i) 64 | else: 65 | ZO_b.append(i) 66 | AO, YO, ZO, AI, YI, ZI = [AO_a, AO_b], [YO_a, YO_b], [ZO_a, ZO_b], [AI_a, AI_b], [YI_a, YI_b], [ZI_a, ZI_b] 67 | 68 | return AO, YO, ZO, AI, YI, ZI, [vec_correct_A+(torch.ge(relation_A[:seq_len_A], 1)*3)*vec_correct_A, vec_correct_B+(torch.ge(relation_B[:seq_len_B], 1)*3)*vec_correct_B] 69 | 70 | 71 | 72 | 73 | def generate_text_cluster(case_a, case_b, d, AO, YO, ZO, AI, YI, ZI, A_all_true, Y_all_true, Z_all_true, AI_true, YI_true, ZI_true): 74 | source_1_a = ''.join(["[AO]" + case_a[0][i] for i in AO[0]] + ["[YO]" + case_a[0][i] for i in YO[0]] + 75 | ["[ZO]" + case_a[0][i] for i in ZO[0]] + ["[AI]" + case_a[0][i] for i in AI[0]] + 76 | ["[YI]" + case_a[0][i] for i in YI[0]] + ["[ZI]" + case_a[0][i] for i in ZI[0]]) 77 | 78 | source_1_b = ''.join(["[AO]" + case_b[0][i] for i in AO[1]] + ["[YO]" + case_b[0][i] for i in YO[1]] + 79 | ["[ZO]" + case_b[0][i] for i in ZO[1]] + ["[AI]" + case_b[0][i] for i in AI[1]] + 80 | ["[YI]" + case_b[0][i] for i in YI[1]] + ["[ZI]" + case_b[0][i] for i in ZI[1]]) 81 | 82 | source_2_a = ''.join( 83 | ["[AO]" + case_a[0][i] for i in A_all_true[0] if i not in AI_true[0]] + ["[YO]" + case_a[0][i] for i in 84 | Y_all_true[0] if i not in YI_true[0]] + 85 | ["[ZO]" + case_a[0][i] for i in Z_all_true[0] if i not in ZI_true[0]] + ["[AI]" + case_a[0][i] for i in 86 | AI_true[0]] + 87 | ["[YI]" + case_a[0][i] for i in YI_true[0]] + ["[ZI]" + case_a[0][i] for i in ZI_true[0]]) 88 | 89 | source_2_b = ''.join( 90 | ["[AO]" + case_b[0][i] for i in A_all_true[1] if i not in AI_true[1]] + ["[YO]" + case_b[0][i] for i in 91 | Y_all_true[1] if i not in YI_true[1]] + 92 | ["[ZO]" + case_b[0][i] for i in Z_all_true[1] if i not in ZI_true[1]] + ["[AI]" + case_b[0][i] for i in 93 | AI_true[1]] + 94 | ["[YI]" + case_b[0][i] for i in YI_true[1]] + ["[ZI]" + case_b[0][i] for i in ZI_true[1]]) 95 | 96 | result = { 97 | 'source_1': source_1_a + 
source_1_b, 98 | 'source_2': source_2_a + source_2_b, 99 | 'explanation': ';'.join(list(d['explanation'].values())), 100 | 'source_1_dis': [source_1_a, source_1_b], 101 | 'source_2_dis': [source_2_a, source_2_b], 102 | 'label': d['label'] 103 | } 104 | return result 105 | 106 | 107 | def get_extract_text(case_a, prediction): 108 | source_1_a = '' 109 | for i, output_class in enumerate(prediction): 110 | if output_class == 1: 111 | source_1_a += "[AO]" + case_a[0][i] 112 | elif output_class == 2: 113 | source_1_a += "[YO]" + case_a[0][i] 114 | elif output_class == 3: 115 | source_1_a += "[ZO]" + case_a[0][i] 116 | elif output_class == 4: 117 | source_1_a += "[AI]" + case_a[0][i] 118 | elif output_class == 5: 119 | source_1_a += "[YI]" + case_a[0][i] 120 | elif output_class == 6: 121 | source_1_a += "[ZI]" + case_a[0][i] 122 | else: 123 | pass 124 | return source_1_a 125 | 126 | 127 | def get_extract_text_wo_token(case_a, prediction): 128 | source_1_a = '' 129 | for i, output_class in enumerate(prediction): 130 | if output_class != 0: 131 | source_1_a += case_a[0][i] 132 | else: 133 | pass 134 | return source_1_a 135 | 136 | 137 | def generate_text_sort(case_a, case_b, d, prediction, label): 138 | 139 | source_1_a = get_extract_text(case_a, prediction[0]) 140 | source_1_b = get_extract_text(case_b, prediction[1]) 141 | source_2_a = get_extract_text(case_a, label[0]) 142 | source_2_b = get_extract_text(case_b, label[1]) 143 | 144 | result = { 145 | 'source_1': source_1_a + source_1_b, 146 | 'source_2': source_2_a + source_2_b, 147 | 'explanation': ';'.join(list(d['explanation'].values())), 148 | 'source_1_dis': [source_1_a, source_1_b], 149 | 'source_2_dis': [source_2_a, source_2_b], 150 | 'label': d['label'] 151 | } 152 | return result 153 | 154 | 155 | def generate_text_wo_token(case_a, case_b, d, prediction, label): 156 | 157 | source_1_a = get_extract_text_wo_token(case_a, prediction[0]) 158 | source_1_b = get_extract_text_wo_token(case_b, prediction[1]) 159 | source_2_a = get_extract_text_wo_token(case_a, label[0]) 160 | source_2_b = get_extract_text_wo_token(case_b, label[1]) 161 | 162 | result = { 163 | 'source_1': source_1_a + source_1_b, 164 | 'source_2': source_2_a + source_2_b, 165 | 'explanation': ';'.join(list(d['explanation'].values())), 166 | 'source_1_dis': [source_1_a, source_1_b], 167 | 'source_2_dis': [source_2_a, source_2_b], 168 | 'label': d['label'] 169 | } 170 | return result 171 | 172 | 173 | def fold_convert_our_data_ot(data, data_x, type, generate=False, generate_mode = 'cluster'): 174 | """每一fold用对应的模型做数据转换 175 | """ 176 | 177 | with torch.no_grad(): 178 | model = Selector2_mul_class(args.input_size, args.hidden_size, kernel_size=args.kernel_size, dilation_rate=[1, 2, 4, 8, 1, 1]) 179 | load_checkpoint(model, None, 2, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/extract/extract-criterion-BCEFocal-onehot-1-ot_mode-max-convert_to_onehot-1-weight-100-simot-1-simpercent-1.0.pkl") 180 | model = model.to(device) 181 | ot_model = OT() 182 | ot_model = ot_model.to(device) 183 | load_checkpoint(ot_model, None, 2, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/extract/extract_ot-criterion-BCEFocal-onehot-1-ot_mode-max-convert_to_onehot-1-weight-100-simot-1-simpercent-1.0.pkl") 184 | results = [] 185 | print(type+"ing") 186 | for i, d in enumerate(data): 187 | if type == 'match' and d["label"] == 2 or type == 'midmatch' and d["label"] == 1 or type == 'dismatch' and d["label"] == 0: 188 | case_a = 
d['case_A'] 189 | case_b = d['case_B'] 190 | important_A, important_B = [], [] 191 | data_y_seven_class_A, data_y_seven_class_B = [0]*len(case_a[0]), [0]*len(case_b[0]) 192 | for pos_list in d['relation_label'].values(): 193 | for pos in pos_list: 194 | row, col = pos[0], pos[-1] 195 | important_A.append(row) 196 | important_B.append(col) 197 | 198 | for j in case_a[1]: 199 | if j[0] in important_A: 200 | data_y_seven_class_A[j[0]] = j[1] + 3 201 | else: 202 | data_y_seven_class_A[j[0]] = j[1] 203 | 204 | for j in case_b[1]: 205 | if j[0] in important_B: 206 | data_y_seven_class_B[j[0]] = j[1] + 3 207 | else: 208 | data_y_seven_class_B[j[0]] = j[1] 209 | label = [data_y_seven_class_A, data_y_seven_class_B] 210 | 211 | 212 | 213 | AO, YO, ZO, AI, YI, ZI, prediction = model_class(model, ot_model, torch.tensor(np.expand_dims(data_x[2*i], axis=0), device=device), 214 | torch.tensor(np.expand_dims(data_x[2*i+1], axis=0), device=device),len(case_a[0]), len(case_b[0])) 215 | 216 | A_all_true, Y_all_true, Z_all_true, AI_true, YI_true, ZI_true = [], [], [], [], [], [] 217 | 218 | temp_a, temp_b = [], [] 219 | for i in d['relation_label']['relation_label_aqss']: 220 | temp_a.append(i[0]) 221 | temp_b.append(i[1]) 222 | AI_true.append(temp_a) 223 | AI_true.append(temp_b) 224 | 225 | temp_a, temp_b = [], [] 226 | for i in d['relation_label']['relation_label_yjss']: 227 | temp_a.append(i[0]) 228 | temp_b.append(i[1]) 229 | YI_true.append(temp_a) 230 | YI_true.append(temp_b) 231 | 232 | temp_a, temp_b = [], [] 233 | for i in d['relation_label']['relation_label_zyjd']: 234 | temp_a.append(i[0]) 235 | temp_b.append(i[1]) 236 | ZI_true.append(temp_a) 237 | ZI_true.append(temp_b) 238 | aqss_temp, yjss_temp, zyjd_temp = [], [], [] 239 | for i in case_a[1]: 240 | if i[1] == 1: 241 | aqss_temp.append(i[0]) 242 | elif i[1] == 2: 243 | yjss_temp.append(i[0]) 244 | elif i[1] == 3: 245 | zyjd_temp.append(i[0]) 246 | A_all_true.append(aqss_temp) 247 | Y_all_true.append(yjss_temp) 248 | Z_all_true.append(zyjd_temp) 249 | 250 | aqss_temp, yjss_temp, zyjd_temp = [], [], [] 251 | for i in case_b[1]: 252 | if i[1] == 1: 253 | aqss_temp.append(i[0]) 254 | elif i[1] == 2: 255 | yjss_temp.append(i[0]) 256 | elif i[1] == 3: 257 | zyjd_temp.append(i[0]) 258 | A_all_true.append(aqss_temp) 259 | Y_all_true.append(yjss_temp) 260 | Z_all_true.append(zyjd_temp) 261 | if generate: 262 | if generate_mode == 'cluster': 263 | results.append(generate_text_cluster(case_a, case_b, d, AO, YO, ZO, AI, YI, ZI, A_all_true, Y_all_true, Z_all_true, AI_true, YI_true, 264 | ZI_true)) 265 | elif generate_mode == 'sort': 266 | results.append(generate_text_sort(case_a, case_b, d, prediction, label)) 267 | else: 268 | results.append(generate_text_wo_token(case_a, case_b, d, prediction, label)) 269 | 270 | if generate: 271 | return results 272 | 273 | 274 | 275 | 276 | 277 | def convert(filename, data, data_x, type, generate_mode): 278 | """转换为生成式数据 279 | """ 280 | total_results = fold_convert_our_data_ot(data, data_x, type, generate=True, generate_mode=generate_mode) 281 | 282 | with open(filename, 'w') as f: 283 | for item in total_results: 284 | f.writelines(json.dumps(item, ensure_ascii=False)) 285 | f.write('\n') 286 | 287 | 288 | 289 | if __name__ == '__main__': 290 | 291 | data = load_data(data_extract_json) 292 | data_x = np.load(data_extract_npy) 293 | da_type = "sort" 294 | match_data_seq2seq_json = '../dataset/our_data/match_data_seq2seq_{}.json'.format(da_type) 295 | midmatch_data_seq2seq_json = 
'../dataset/our_data/midmatch_data_seq2seq_{}.json'.format(da_type) 296 | dismatch_data_seq2seq_json = '../dataset/our_data/dismatch_data_seq2seq_{}.json'.format(da_type) 297 | convert(match_data_seq2seq_json, data, data_x, type='match', generate_mode=da_type) 298 | convert(midmatch_data_seq2seq_json, data, data_x, type='midmatch', generate_mode=da_type) 299 | convert(dismatch_data_seq2seq_json, data, data_x, type='dismatch', generate_mode=da_type) 300 | 301 | 302 | print(u'输出over!') 303 | -------------------------------------------------------------------------------- /models/seq2seq_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | import datetime 4 | from transformers import BertTokenizer, AutoTokenizer, GPT2Tokenizer, GPT2Model 5 | import argparse 6 | import torch 7 | from transformers import AdamW 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | import copy 11 | from torch.utils.data import Dataset, DataLoader 12 | import logging 13 | from utils.snippets import * 14 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 15 | from bert_seq2seq import load_bert, load_gpt 16 | # 基本参数 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--batch_size', type=int, default=1, help='batch size') 19 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs') 20 | parser.add_argument('--each_test_epoch', type=int, default=1) 21 | parser.add_argument('--lr', type=float, default=2e-5, help='learning rate') 22 | parser.add_argument('--weight_decay', type=float, default=0., help='decay weight of optimizer') 23 | parser.add_argument('--model_name', type=str, default='nezha', help='matching model') 24 | parser.add_argument('--checkpoint', type=str, default="./weights/seq2seq_model", help='checkpoint path') 25 | parser.add_argument('--bert_maxlen', type=int, default=512, help='max length of each case') 26 | parser.add_argument('--maxlen', type=int, default=1024, help='max length of each case') 27 | parser.add_argument('--input_size', type=int, default=768) 28 | parser.add_argument('--hidden_size', type=int, default=384) 29 | parser.add_argument('--kernel_size', type=int, default=3) 30 | parser.add_argument('--threshold', type=float, default=0.3) 31 | parser.add_argument('--k_sparse', type=int, default=10) 32 | parser.add_argument('--log_name', type=str, default="log_seq2seq") 33 | parser.add_argument('--seq2seq_type', type=str, default='match') 34 | parser.add_argument('--cuda_pos', type=str, default='1', help='which GPU to use') 35 | parser.add_argument('--seed', type=int, default=42, help='max length of each case') 36 | parser.add_argument('--train', action='store_true') 37 | 38 | args = parser.parse_args() 39 | print(args) 40 | np.random.seed(args.seed) 41 | torch.manual_seed(args.seed) 42 | torch.cuda.manual_seed_all(args.seed) 43 | device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu') 44 | log_name = args.log_name 45 | logging.basicConfig(level=logging.INFO,#控制台打印的日志级别 46 | filename='../logs/{}.log'.format(log_name), 47 | filemode='a',##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 48 | #a是追加模式,默认如果不写的话,就是追加模式 49 | format= 50 | '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' 51 | #日志格式 52 | ) 53 | 54 | if args.seq2seq_type == 'match': 55 | data_seq2seq_json = '../dataset/match_data_seq2seq.json' 56 | seq2seq_config_json = '../dataset/match_data_seq2seq_config.json' 57 | elif args.seq2seq_type == 'midmatch': 
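# --- editor's note, not in the original: the three seq2seq_type branches mirror
# the label scheme used by the convert scripts above (match = label 2,
# midmatch = 1, dismatch = 0), so a separate generator is trained per matching
# level and only the data/config paths below differ between branches.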
58 | data_seq2seq_json = '../dataset/midmatch_data_seq2seq.json' 59 | seq2seq_config_json = '../dataset/midmatch_data_seq2seq_config.json' 60 | else: 61 | data_seq2seq_json = '../dataset/dismatch_data_seq2seq.json' 62 | seq2seq_config_json = '../dataset/dismatch_data_seq2seq_config.json' 63 | 64 | 65 | 66 | 67 | 68 | def load_data(filename): 69 | """加载数据 70 | 返回:[{...}] 71 | """ 72 | D = [] 73 | with open(filename) as f: 74 | for l in f: 75 | D.append(json.loads(l)) 76 | return D 77 | 78 | 79 | 80 | 81 | def generate_copy_labels(source, target): 82 | """构建copy机制对应的label 83 | """ 84 | mapping = longest_common_subsequence(source, target)[1] 85 | source_labels = [0] * len(source) 86 | target_labels = [0] * len(target) 87 | i0, j0 = -2, -2 88 | for i, j in mapping: 89 | if i == i0 + 1 and j == j0 + 1: 90 | source_labels[i] = 2 91 | target_labels[j] = 2 92 | else: 93 | source_labels[i] = 1 94 | target_labels[j] = 1 95 | i0, j0 = i, j 96 | return source_labels, target_labels 97 | 98 | 99 | def random_masking(token_ids_all): 100 | """对输入进行随机mask,增加泛化能力 101 | """ 102 | result = [] 103 | for token_ids in token_ids_all: 104 | rands = np.random.random(len(token_ids)) 105 | result.append([ 106 | t if r > 0.15 else np.random.choice(token_ids) 107 | for r, t in zip(rands, token_ids) 108 | ]) 109 | return result 110 | 111 | 112 | class DataGenerator(Dataset): 113 | def __init__(self, input_data, random=True): 114 | super(DataGenerator, self).__init__() 115 | self.input_data = input_data 116 | self.random = random 117 | 118 | def __len__(self): 119 | return len(self.input_data) 120 | 121 | def __getitem__(self, idx): 122 | 123 | i = np.random.choice(2) + 1 if self.random else 1 124 | source, target = self.input_data[idx]['source_%s' % i], self.input_data[idx]['explanation'] 125 | return [source, target] 126 | 127 | 128 | class Collate: 129 | def __init__(self): 130 | if args.model_name == 'nezha': 131 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_nezha_fold) 132 | self.max_seq_len = args.maxlen 133 | elif args.model_name == 't5': 134 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_t5_fold) 135 | self.max_seq_len = args.maxlen 136 | 137 | def __call__(self, batch): 138 | # assert len(A_batch) == 1 139 | # print("A_batch: ", A_batch) 140 | if args.model_name == 'nezha': 141 | dic_data = self.tokenizer.batch_encode_plus(batch, padding=True, truncation=True, 142 | max_length=self.max_seq_len) 143 | mask_dic_data = copy.deepcopy(dic_data) 144 | 145 | token_ids = dic_data["input_ids"] 146 | 147 | masked_token_ids = random_masking(token_ids) 148 | mask_dic_data['input_ids'] = masked_token_ids 149 | labels = [] 150 | for item_masked_token_ids, item_token_ids in zip(masked_token_ids, token_ids): 151 | idx = item_token_ids.index(self.tokenizer.sep_token_id) + 1 152 | source_labels, target_labels = generate_copy_labels( 153 | item_masked_token_ids[:idx], item_token_ids[idx:] 154 | ) 155 | """ 156 | [CLS]...[SEP] ... 
[SEP] 157 | """ 158 | labels.append(source_labels[1:] + target_labels) # 因为是预测所以第一位后移 159 | 160 | return torch.tensor(dic_data["input_ids"]), torch.tensor(dic_data["token_type_ids"]), torch.tensor(labels) 161 | elif args.model_name == 't5': 162 | source_batch = [] 163 | target_batch = [] 164 | for item in batch: 165 | source_batch.append(item[0]) 166 | target_batch.append(item[1]) 167 | 168 | enc_source_batch = self.tokenizer(source_batch, max_length=self.max_seq_len, truncation='only_first') 169 | source_ids = enc_source_batch["input_ids"] 170 | source_attention_mask = enc_source_batch["attention_mask"] 171 | enc_target_batch = self.tokenizer(target_batch, max_length=self.max_seq_len, truncation='only_first') 172 | target_ids = enc_target_batch["input_ids"] 173 | target_attention_mask = enc_target_batch["attention_mask"] 174 | 175 | features = {'input_ids': source_ids, 'decoder_input_ids': target_ids, 'attention_mask': source_attention_mask, 176 | 'decoder_attention_mask': target_attention_mask} 177 | 178 | return features 179 | 180 | 181 | def build_pretrain_dataloader(data, batch_size, shuffle=True, num_workers=0,): 182 | data_generator =DataGenerator(data, random=True) 183 | collate = Collate() 184 | return DataLoader( 185 | data_generator, 186 | batch_size=batch_size, 187 | shuffle=shuffle, 188 | num_workers=num_workers, 189 | collate_fn=collate 190 | ) 191 | 192 | 193 | def compute_seq2seq_loss(predictions, token_type_id, input_ids, vocab_size): 194 | 195 | predictions = predictions[:, :-1].contiguous() 196 | target_mask = token_type_id[:, 1:].contiguous() 197 | """ 198 | target_mask : 句子a部分和pad部分全为0, 而句子b部分为1 199 | """ 200 | predictions = predictions.view(-1, vocab_size) 201 | labels = input_ids[:, 1:].contiguous() 202 | labels = labels.view(-1) 203 | target_mask = target_mask.view(-1).float() 204 | # 正loss 205 | pos_loss = predictions[list(range(predictions.shape[0])), labels] 206 | # 负loss 207 | y_pred = torch.topk(predictions, k=args.k_sparse)[0] 208 | neg_loss = torch.logsumexp(y_pred, dim=-1) 209 | 210 | loss = neg_loss - pos_loss 211 | return (loss * target_mask).sum() / target_mask.sum() ## 通过mask 取消 pad 和句子a部分预测的影响 212 | 213 | 214 | def compute_copy_loss(predictions, token_type_id, labels): 215 | predictions = predictions[:, :-1].contiguous() 216 | target_mask = token_type_id[:, 1:].contiguous() 217 | """ 218 | target_mask : 句子a部分和pad部分全为0, 而句子b部分为1 219 | """ 220 | predictions = predictions.view(-1, 3) 221 | labels = labels.view(-1) 222 | target_mask = target_mask.view(-1).float() 223 | loss = nn.CrossEntropyLoss(ignore_index=0, reduction="none") 224 | return (loss(predictions, labels) * target_mask).sum() / target_mask.sum() ## 通过mask 取消 pad 和句子a部分预测的影响 225 | 226 | class GenerateModel(nn.Module): 227 | def __init__(self): 228 | super(GenerateModel, self).__init__() 229 | if args.model_name == 'nezha': 230 | self.word2idx = load_chinese_base_vocab(pretrained_nezha_fold+"vocab.txt", simplfied=False) 231 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_nezha_fold) 232 | self.bert_model = load_bert(self.word2idx, model_name=args.model_name, model_class="seq2seq") 233 | ## 加载预训练的模型参数~ 234 | if args.train: 235 | self.bert_model.load_pretrain_params(pretrained_nezha_fold+"pytorch_model.bin") 236 | else: 237 | pass 238 | self.bert_model.set_device(device) 239 | self.configuration = self.bert_model.config 240 | elif args.model_name == 'gpt2': 241 | self.gpt_model = GPT2Model.from_pretrained(pretrained_gpt2_fold) 242 | 243 | self.linear = 
nn.Linear(self.configuration.hidden_size, 3).to(device) 244 | 245 | def forward(self, token_ids, token_type_ids): 246 | hidden_state = None 247 | seq2seq_predictions = None 248 | if args.model_name == 'nezha': 249 | seq2seq_predictions, hidden_state = self.bert_model(token_ids, token_type_ids) 250 | elif args.model_name == 'gpt2': 251 | seq2seq_predictions, hidden_state = self.gpt_model(input_ids=token_ids, token_type_ids=token_type_ids) 252 | copy_predictions = self.linear(nn.GELU()(hidden_state)) 253 | return seq2seq_predictions, copy_predictions 254 | 255 | 256 | def load_checkpoint(model, optimizer, trained_epoch, file_name=None): 257 | if file_name==None: 258 | file_name = args.checkpoint + '/' + f"{args.seq2seq_type}-seq2seq-{trained_epoch}.pkl" 259 | save_params = torch.load(file_name, map_location=device) 260 | model.load_state_dict(save_params["model"]) 261 | #optimizer.load_state_dict(save_params["optimizer"]) 262 | 263 | 264 | def save_checkpoint(model, optimizer, trained_epoch): 265 | save_params = { 266 | "model": model.state_dict(), 267 | "optimizer": optimizer.state_dict(), 268 | "trained_epoch": trained_epoch, 269 | } 270 | if not os.path.exists(args.checkpoint): 271 | # 判断文件夹是否存在,不存在则创建文件夹 272 | os.mkdir(args.checkpoint) 273 | filename = args.checkpoint + '/' + f"{args.seq2seq_type}-seq2seq-{trained_epoch}.pkl" 274 | torch.save(save_params, filename) 275 | 276 | 277 | def train_valid(train_data, valid_data, model): 278 | optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 279 | # ema = EMA(model, 0.9999) 280 | # ema.register() 281 | for epoch in range(args.epochs): 282 | epoch_loss = 0. 283 | current_step = 0 284 | model.train() 285 | # for batch_data in tqdm(train_data_loader, ncols=0): 286 | pbar = tqdm(train_data, desc="Iteration", postfix='train') 287 | for batch_data in pbar: 288 | input_ids, token_type_ids, labels = batch_data 289 | input_ids, token_type_ids, labels = input_ids.to(device), token_type_ids.to(device), labels.to(device) 290 | seq2seq_predictions, copy_predictions = model(input_ids, token_type_ids) 291 | 292 | seq2seq_loss = compute_seq2seq_loss(seq2seq_predictions, token_type_ids, input_ids, 293 | model.configuration.vocab_size) 294 | copy_loss = compute_copy_loss(copy_predictions, token_type_ids, labels) 295 | loss = seq2seq_loss + 2 * copy_loss 296 | optimizer.zero_grad() 297 | loss.backward() 298 | optimizer.step() 299 | # ema.update() 300 | loss_item = loss.cpu().detach().item() 301 | epoch_loss += loss_item 302 | current_step += 1 303 | pbar.set_description("train loss {}".format(epoch_loss / current_step)) 304 | if current_step % 100 == 0: 305 | logging.info("train step {} loss {}".format(current_step, epoch_loss / current_step)) 306 | 307 | epoch_loss = epoch_loss / current_step 308 | time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 309 | print('{} train epoch {} loss: {:.4f}'.format(time_str, epoch, epoch_loss)) 310 | logging.info('train epoch {} loss: {:.4f}'.format(epoch, epoch_loss)) 311 | # todo 看一下 EMA是否会让模型准确率提升,如果可以的话在保存模型前加入 ema 312 | save_checkpoint(model, optimizer, epoch) 313 | with torch.no_grad(): 314 | model.eval() 315 | # ema.apply_shadow() 316 | evaluate(valid_data, model) 317 | # ema.restore() 318 | model.train() 319 | 320 | class AutoSummary(AutoRegressiveDecoder): 321 | """seq2seq解码器 322 | """ 323 | def get_ngram_set(self, x, n): 324 | """生成ngram合集,返回结果格式是: 325 | {(n-1)-gram: set([n-gram的第n个字集合])} 326 | """ 327 | result = {} 328 | for i in range(len(x) - n + 1): 329 | k = 
tuple(x[i:i + n]) 330 | if k[:-1] not in result: 331 | result[k[:-1]] = set() 332 | result[k[:-1]].add(k[-1]) 333 | return result 334 | 335 | @AutoRegressiveDecoder.wraps(default_rtype='logits', use_states=True) 336 | def predict(self, inputs, output_ids, states): 337 | token_ids, segment_ids = inputs 338 | token_ids = np.concatenate([token_ids, output_ids], 1) 339 | segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1) 340 | seq2seq_predictions, copy_predictions = self.model(torch.tensor(token_ids, device=device), torch.tensor(segment_ids, device=device)) 341 | prediction = [seq2seq_predictions[:, -1].cpu().numpy(), torch.softmax(copy_predictions[:, -1], dim=-1).cpu().numpy()] # 返回最后一个字符的预测结果,(1, vocab_size),(1, 3) todo 我这里需要加一个softmax 前面的生成模型给也需要 342 | # states用来缓存ngram的n值 343 | if states is None: 344 | states = [0] 345 | elif len(states) == 1 and len(token_ids) > 1: 346 | states = states * len(token_ids) 347 | # 根据copy标签来调整概率分布 348 | probas = np.zeros_like(prediction[0]) - 1000 # 最终要返回的概率分布 349 | for i, token_ids in enumerate(inputs[0]): 350 | if states[i] == 0: 351 | prediction[1][i, 2] *= -1 # 0不能接2 352 | label = prediction[1][i].argmax() # 当前label 353 | if label < 2: 354 | states[i] = label 355 | else: 356 | states[i] += 1 # 2后面接什么都行 357 | if states[i] > 0: 358 | ngrams = self.get_ngram_set(token_ids, states[i]) 359 | prefix = tuple(output_ids[i, 1 - states[i]:]) 360 | if prefix in ngrams: # 如果确实是适合的ngram 361 | candidates = ngrams[prefix] 362 | else: # 没有的话就退回1gram 363 | ngrams = self.get_ngram_set(token_ids, 1) 364 | candidates = ngrams[tuple()] 365 | states[i] = 1 366 | candidates = list(candidates) 367 | probas[i, candidates] = prediction[0][i, candidates] 368 | else: 369 | probas[i] = prediction[0][i] 370 | idxs = probas[i].argpartition(-args.k_sparse) 371 | probas[i, idxs[:-args.k_sparse]] = -1000 372 | return probas, states 373 | 374 | def generate(self, text, topk=1): 375 | max_c_len = args.maxlen - self.maxlen 376 | encode_text = self.model.tokenizer(text, padding=True, truncation=True, 377 | max_length=max_c_len) 378 | token_ids, segment_ids = encode_text['input_ids'], encode_text['token_type_ids'] 379 | output_ids = self.beam_search([token_ids, segment_ids], 380 | topk) # 基于beam search 381 | return ''.join(self.model.tokenizer.convert_ids_to_tokens(output_ids)) # skip_special_tokens=True 382 | 383 | 384 | 385 | 386 | 387 | def evaluate(data, model, topk=1, filename=None): 388 | """验证集评估 389 | """ 390 | autosummary = AutoSummary( 391 | start_id=model.tokenizer.cls_token_id, 392 | end_id=model.tokenizer.sep_token_id, 393 | maxlen=args.maxlen // 4, 394 | model=model 395 | ) 396 | if filename is not None: 397 | F = open(filename, 'w', encoding='utf-8') 398 | total_metrics = {k: 0.0 for k in metric_keys} 399 | for d in tqdm(data, desc=u'评估中'): 400 | pred_summary = autosummary.generate(d['source_1'], topk) 401 | metrics = compute_metrics(pred_summary, d['explanation']) 402 | for k, v in metrics.items(): 403 | total_metrics[k] += v 404 | if filename is not None: 405 | F.write(d['explanation'] + '\t' + pred_summary + '\n') 406 | F.flush() 407 | if filename is not None: 408 | F.close() 409 | print(total_metrics) 410 | logging.info("~~~~~~~~~~~~~~~~~~~~") 411 | for k, v in total_metrics.items(): 412 | logging.info(k+": {} ".format(v/len(data))) 413 | logging.info("~~~~~~~~~~~~~~~~~~~~") 414 | return {k: v / len(data) for k, v in total_metrics.items()} 415 | 416 | if __name__ == '__main__': 417 | # 加载数据 418 | data = load_data(data_seq2seq_json) 419 | train_data = 
data_split(data, 'train') 420 | valid_data = data_split(data, 'valid') 421 | train_data_loader = build_pretrain_dataloader(train_data, args.batch_size) 422 | G_model = GenerateModel() 423 | if args.train: 424 | G_model = G_model.to(device) 425 | train_valid(train_data_loader, valid_data, G_model) 426 | else: 427 | load_checkpoint(G_model, None, None, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/seq2seq_model/match-seq2seq-8.pkl") 428 | with torch.no_grad(): 429 | G_model.eval() 430 | evaluate(valid_data, G_model) 431 | 432 | 433 | 434 | 435 | 436 | -------------------------------------------------------------------------------- /models/seq2seq_model_dismatch.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | import datetime 4 | from transformers import BertTokenizer, AutoTokenizer 5 | import argparse 6 | import torch 7 | from transformers import AdamW 8 | import torch.nn as nn 9 | from tqdm import tqdm 10 | import copy 11 | from torch.utils.data import Dataset, DataLoader 12 | import logging 13 | from utils.snippets import * 14 | from bert_seq2seq import Tokenizer, load_chinese_base_vocab 15 | from bert_seq2seq import load_bert 16 | # Basic parameters 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--batch_size', type=int, default=2, help='batch size') 19 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs') 20 | parser.add_argument('--each_test_epoch', type=int, default=1) 21 | parser.add_argument('--lr', type=float, default=2e-5, help='learning rate') 22 | parser.add_argument('--weight_decay', type=float, default=0., help='decay weight of optimizer') 23 | parser.add_argument('--model_name', type=str, default='nezha', help='matching model') 24 | parser.add_argument('--checkpoint', type=str, default="./weights/seq2seq_model", help='checkpoint path') 25 | parser.add_argument('--bert_maxlen', type=int, default=512, help='max length of each case') 26 | parser.add_argument('--maxlen', type=int, default=1024, help='max length of each case') 27 | parser.add_argument('--input_size', type=int, default=768) 28 | parser.add_argument('--hidden_size', type=int, default=384) 29 | parser.add_argument('--kernel_size', type=int, default=3) 30 | parser.add_argument('--threshold', type=float, default=0.3) 31 | parser.add_argument('--k_sparse', type=int, default=10) 32 | parser.add_argument('--log_name', type=str, default="log_seq2seq") 33 | parser.add_argument('--seq2seq_type', type=str, default='match') 34 | parser.add_argument('--cuda_pos', type=str, default='0', help='which GPU to use') 35 | parser.add_argument('--seed', type=int, default=42, help='random seed') 36 | parser.add_argument('--train', action='store_true') 37 | 38 | args = parser.parse_args() 39 | print(args) 40 | np.random.seed(args.seed) 41 | torch.manual_seed(args.seed) 42 | torch.cuda.manual_seed_all(args.seed) 43 | device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu') 44 | log_name = args.log_name 45 | logging.basicConfig(level=logging.INFO,  # log level for console output 46 | filename='../logs/{}.log'.format(log_name), 47 | filemode='a',  ## file mode: 'w' rewrites the log file on every run, overwriting earlier logs; 48 | # 'a' appends, and is also the default when omitted 49 | format= 50 | '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' 51 | # log format 52 | ) 53 | 54 | 55 | data_seq2seq_json = '../dataset/dismatch_data_seq2seq_v2.json' 56 | seq2seq_config_json = 
'../dataset/dismatch_data_seq2seq_config.json' 57 | 58 | 59 | 60 | 61 | def load_data(filename): 62 | """加载数据 63 | 返回:[{...}] 64 | """ 65 | D = [] 66 | with open(filename) as f: 67 | for l in f: 68 | D.append(json.loads(l)) 69 | return D 70 | 71 | 72 | 73 | 74 | def generate_copy_labels(source, target): 75 | """构建copy机制对应的label 76 | """ 77 | mapping = longest_common_subsequence(source, target)[1] 78 | source_labels = [0] * len(source) 79 | target_labels = [0] * len(target) 80 | i0, j0 = -2, -2 81 | for i, j in mapping: 82 | if i == i0 + 1 and j == j0 + 1: 83 | source_labels[i] = 2 84 | target_labels[j] = 2 85 | else: 86 | source_labels[i] = 1 87 | target_labels[j] = 1 88 | i0, j0 = i, j 89 | return source_labels, target_labels 90 | 91 | 92 | def random_masking(token_ids_all): 93 | """对输入进行随机mask,增加泛化能力 94 | """ 95 | result = [] 96 | for token_ids in token_ids_all: 97 | rands = np.random.random(len(token_ids)) 98 | result.append([ 99 | t if r > 0.15 else np.random.choice(token_ids) 100 | for r, t in zip(rands, token_ids) 101 | ]) 102 | return result 103 | 104 | 105 | class DataGenerator(Dataset): 106 | def __init__(self, input_data, random=True): 107 | super(DataGenerator, self).__init__() 108 | self.input_data = input_data 109 | self.random = random 110 | 111 | def __len__(self): 112 | return len(self.input_data) 113 | 114 | def __getitem__(self, idx): 115 | 116 | i = np.random.choice(2) + 1 if self.random else 1 117 | source_1, target_1 = self.input_data[idx]['source_%s_dis' % i][0], self.input_data[idx]['explanation_dis'][0] 118 | source_2, target_2 = self.input_data[idx]['source_%s_dis' % i][1], self.input_data[idx]['explanation_dis'][1] 119 | return [source_1, target_1], [source_2, target_2] 120 | 121 | 122 | class Collate: 123 | def __init__(self): 124 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_nezha_fold) 125 | 126 | self.max_seq_len = args.maxlen 127 | 128 | def __call__(self, batch): 129 | # assert len(A_batch) == 1 130 | # print("A_batch: ", A_batch) 131 | text_1, text_2 = [], [] 132 | for item in batch: 133 | text_1.append(item[0]) 134 | text_2.append(item[1]) 135 | dic_data_1, dic_data_2 = self.tokenizer.batch_encode_plus(text_1, padding=True, truncation=True, 136 | max_length=self.max_seq_len), self.tokenizer.batch_encode_plus(text_2, padding=True, truncation=True, 137 | max_length=self.max_seq_len) 138 | mask_dic_data_1, mask_dic_data_2 = copy.deepcopy(dic_data_1), copy.deepcopy(dic_data_2) 139 | 140 | token_ids_1, token_ids_2 = dic_data_1["input_ids"], dic_data_2["input_ids"], 141 | 142 | masked_token_ids_1, masked_token_ids_2 = random_masking(token_ids_1), random_masking(token_ids_2) 143 | mask_dic_data_1['input_ids'], mask_dic_data_2['input_ids'] = masked_token_ids_1, masked_token_ids_2 144 | labels_1, labels_2 = [], [] 145 | for item_masked_token_ids_1, item_token_ids_1 in zip(masked_token_ids_1, token_ids_1): 146 | idx = item_token_ids_1.index(self.tokenizer.sep_token_id) + 1 147 | source_labels_1, target_labels_1 = generate_copy_labels( 148 | item_masked_token_ids_1[:idx], item_token_ids_1[idx:] 149 | ) 150 | """ 151 | [CLS]...[SEP] ... [SEP] 152 | """ 153 | labels_1.append(source_labels_1[1:] + target_labels_1) # 因为是预测所以第一位后移 154 | for item_masked_token_ids_2, item_token_ids_2 in zip(masked_token_ids_2, token_ids_2): 155 | idx = item_token_ids_2.index(self.tokenizer.sep_token_id) + 1 156 | source_labels_2, target_labels_2 = generate_copy_labels( 157 | item_masked_token_ids_2[:idx], item_token_ids_2[idx:] 158 | ) 159 | """ 160 | [CLS]...[SEP] ... 
[SEP] 161 | """ 162 | labels_2.append(source_labels_2[1:] + target_labels_2) # 因为是预测所以第一位后移 163 | 164 | return torch.tensor(dic_data_1["input_ids"]), torch.tensor(dic_data_1["token_type_ids"]), torch.tensor(labels_1), \ 165 | torch.tensor(dic_data_2["input_ids"]), torch.tensor(dic_data_2["token_type_ids"]), torch.tensor(labels_2) 166 | 167 | 168 | 169 | def build_pretrain_dataloader(data, batch_size, shuffle=True, num_workers=0,): 170 | data_generator =DataGenerator(data, random=True) 171 | collate = Collate() 172 | return DataLoader( 173 | data_generator, 174 | batch_size=batch_size, 175 | shuffle=shuffle, 176 | num_workers=num_workers, 177 | collate_fn=collate 178 | ) 179 | 180 | 181 | def compute_seq2seq_loss(predictions, token_type_id, input_ids, vocab_size): 182 | 183 | predictions = predictions[:, :-1].contiguous() 184 | target_mask = token_type_id[:, 1:].contiguous() 185 | """ 186 | target_mask : 句子a部分和pad部分全为0, 而句子b部分为1 187 | """ 188 | predictions = predictions.view(-1, vocab_size) 189 | labels = input_ids[:, 1:].contiguous() 190 | labels = labels.view(-1) 191 | target_mask = target_mask.view(-1).float() 192 | # 正loss 193 | pos_loss = predictions[list(range(predictions.shape[0])), labels] 194 | # 负loss 195 | y_pred = torch.topk(predictions, k=args.k_sparse)[0] 196 | neg_loss = torch.logsumexp(y_pred, dim=-1) 197 | 198 | loss = neg_loss - pos_loss 199 | return (loss * target_mask).sum() / target_mask.sum() ## 通过mask 取消 pad 和句子a部分预测的影响 200 | 201 | 202 | def compute_copy_loss(predictions, token_type_id, labels): 203 | predictions = predictions[:, :-1].contiguous() 204 | target_mask = token_type_id[:, 1:].contiguous() 205 | """ 206 | target_mask : 句子a部分和pad部分全为0, 而句子b部分为1 207 | """ 208 | predictions = predictions.view(-1, 3) 209 | labels = labels.view(-1) 210 | target_mask = target_mask.view(-1).float() 211 | loss = nn.CrossEntropyLoss(ignore_index=0, reduction="none") 212 | return (loss(predictions, labels) * target_mask).sum() / target_mask.sum() ## 通过mask 取消 pad 和句子a部分预测的影响 213 | 214 | class GenerateModel(nn.Module): 215 | def __init__(self): 216 | super(GenerateModel, self).__init__() 217 | self.word2idx = load_chinese_base_vocab(pretrained_nezha_fold+"vocab.txt", simplfied=False) 218 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_nezha_fold) 219 | self.bert_model = load_bert(self.word2idx, model_name=args.model_name, model_class="seq2seq") 220 | ## 加载预训练的模型参数~ 221 | if args.train: 222 | self.bert_model.load_pretrain_params(pretrained_nezha_fold + "pytorch_model.bin") 223 | else: 224 | pass 225 | self.bert_model.set_device(device) 226 | self.configuration = self.bert_model.config 227 | self.linear = nn.Linear(self.configuration.hidden_size, 3).to(device) 228 | 229 | def forward(self, token_ids, token_type_ids): 230 | seq2seq_predictions, hidden_state = self.bert_model(token_ids, token_type_ids) 231 | copy_predictions = self.linear(nn.GELU()(hidden_state)) 232 | 233 | return seq2seq_predictions, copy_predictions 234 | 235 | 236 | def load_checkpoint(model, optimizer, trained_epoch, file_name=None): 237 | if file_name==None: 238 | file_name = args.checkpoint + '/' + f"{args.seq2seq_type}-seq2seq-{trained_epoch}.pkl" 239 | save_params = torch.load(file_name, map_location=device) 240 | model.load_state_dict(save_params["model"]) 241 | #optimizer.load_state_dict(save_params["optimizer"]) 242 | 243 | 244 | def save_checkpoint(model, optimizer, trained_epoch): 245 | save_params = { 246 | "model": model.state_dict(), 247 | "optimizer": optimizer.state_dict(), 248 | 
"trained_epoch": trained_epoch, 249 | } 250 | if not os.path.exists(args.checkpoint): 251 | # 判断文件夹是否存在,不存在则创建文件夹 252 | os.mkdir(args.checkpoint) 253 | filename = args.checkpoint + '/' + f"{args.seq2seq_type}-seq2seq-{trained_epoch}.pkl" 254 | torch.save(save_params, filename) 255 | 256 | 257 | def train_valid(train_data, valid_data, model): 258 | optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 259 | # ema = EMA(model, 0.9999) 260 | # ema.register() 261 | for epoch in range(args.epochs): 262 | epoch_loss = 0. 263 | current_step = 0 264 | model.train() 265 | # for batch_data in tqdm(train_data_loader, ncols=0): 266 | pbar = tqdm(train_data, desc="Iteration", postfix='train') 267 | for batch_data in pbar: 268 | input_ids_1, token_type_ids_1, labels_1, input_ids_2, token_type_ids_2, labels_2 = batch_data 269 | input_ids_1, token_type_ids_1, labels_1, input_ids_2, token_type_ids_2, labels_2 = input_ids_1.to(device), token_type_ids_1.to(device), labels_1.to(device), input_ids_2.to(device), token_type_ids_2.to(device), labels_2.to(device) 270 | seq2seq_predictions_1, copy_predictions_1 = model(input_ids_1, token_type_ids_1) 271 | 272 | 273 | seq2seq_loss_1 = compute_seq2seq_loss(seq2seq_predictions_1, token_type_ids_1, input_ids_1, 274 | model.configuration.vocab_size) 275 | copy_loss_1 = compute_copy_loss(copy_predictions_1, token_type_ids_1, labels_1) 276 | 277 | loss_1 = seq2seq_loss_1 + 2 * copy_loss_1 278 | 279 | optimizer.zero_grad() 280 | loss_1.backward() 281 | optimizer.step() 282 | # ema.update() 283 | loss_item_1 = loss_1.cpu().detach().item() 284 | epoch_loss += loss_item_1 285 | 286 | seq2seq_predictions_2, copy_predictions_2 = model(input_ids_2, token_type_ids_2) 287 | seq2seq_loss_2 = compute_seq2seq_loss(seq2seq_predictions_2, token_type_ids_2, input_ids_2, 288 | model.configuration.vocab_size) 289 | copy_loss_2 = compute_copy_loss(copy_predictions_2, token_type_ids_2, labels_2) 290 | loss_2 = seq2seq_loss_2 + 2 * copy_loss_2 291 | optimizer.zero_grad() 292 | loss_2.backward() 293 | optimizer.step() 294 | # ema.update() 295 | loss_item_2 = loss_2.cpu().detach().item() 296 | epoch_loss += loss_item_2 297 | 298 | 299 | current_step += 1 300 | pbar.set_description("train loss {}".format(epoch_loss / current_step)) 301 | if current_step % 100 == 0: 302 | logging.info("train step {} loss {}".format(current_step, epoch_loss / current_step)) 303 | 304 | epoch_loss = epoch_loss / current_step 305 | time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 306 | print('{} train epoch {} loss: {:.4f}'.format(time_str, epoch, epoch_loss)) 307 | logging.info('train epoch {} loss: {:.4f}'.format(epoch, epoch_loss)) 308 | # todo 看一下 EMA是否会让模型准确率提升,如果可以的话在保存模型前加入 ema 309 | save_checkpoint(model, optimizer, epoch) 310 | with torch.no_grad(): 311 | model.eval() 312 | # ema.apply_shadow() 313 | evaluate(valid_data, model) 314 | # ema.restore() 315 | model.train() 316 | 317 | class AutoSummary(AutoRegressiveDecoder): 318 | """seq2seq解码器 319 | """ 320 | def get_ngram_set(self, x, n): 321 | """生成ngram合集,返回结果格式是: 322 | {(n-1)-gram: set([n-gram的第n个字集合])} 323 | """ 324 | result = {} 325 | for i in range(len(x) - n + 1): 326 | k = tuple(x[i:i + n]) 327 | if k[:-1] not in result: 328 | result[k[:-1]] = set() 329 | result[k[:-1]].add(k[-1]) 330 | return result 331 | 332 | @AutoRegressiveDecoder.wraps(default_rtype='logits', use_states=True) 333 | def predict(self, inputs, output_ids, states): 334 | token_ids, segment_ids = inputs 335 | token_ids = 
np.concatenate([token_ids, output_ids], 1) 336 | segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1) 337 | seq2seq_predictions, copy_predictions = self.model(torch.tensor(token_ids, device=device), torch.tensor(segment_ids, device=device)) 338 | prediction = [seq2seq_predictions[:, -1].cpu().numpy(), torch.softmax(copy_predictions[:, -1], dim=-1).cpu().numpy()] # 返回最后一个字符的预测结果,(1, vocab_size),(1, 3) todo 我这里需要加一个softmax 前面的生成模型给也需要 339 | # states用来缓存ngram的n值 340 | if states is None: 341 | states = [0] 342 | elif len(states) == 1 and len(token_ids) > 1: 343 | states = states * len(token_ids) 344 | # 根据copy标签来调整概率分布 345 | probas = np.zeros_like(prediction[0]) - 1000 # 最终要返回的概率分布 346 | for i, token_ids in enumerate(inputs[0]): 347 | if states[i] == 0: 348 | prediction[1][i, 2] *= -1 # 0不能接2 349 | label = prediction[1][i].argmax() # 当前label 350 | if label < 2: 351 | states[i] = label 352 | else: 353 | states[i] += 1 # 2后面接什么都行 354 | if states[i] > 0: 355 | ngrams = self.get_ngram_set(token_ids, states[i]) 356 | prefix = tuple(output_ids[i, 1 - states[i]:]) 357 | if prefix in ngrams: # 如果确实是适合的ngram 358 | candidates = ngrams[prefix] 359 | else: # 没有的话就退回1gram 360 | ngrams = self.get_ngram_set(token_ids, 1) 361 | candidates = ngrams[tuple()] 362 | states[i] = 1 363 | candidates = list(candidates) 364 | probas[i, candidates] = prediction[0][i, candidates] 365 | else: 366 | probas[i] = prediction[0][i] 367 | idxs = probas[i].argpartition(-args.k_sparse) 368 | probas[i, idxs[:-args.k_sparse]] = -1000 369 | return probas, states 370 | 371 | def generate(self, text, topk=1): 372 | max_c_len = args.maxlen - self.maxlen 373 | encode_text = self.model.tokenizer(text, padding=True, truncation=True, 374 | max_length=max_c_len) 375 | token_ids, segment_ids = encode_text['input_ids'], encode_text['token_type_ids'] 376 | output_ids = self.beam_search([token_ids, segment_ids], 377 | topk) # 基于beam search 378 | return ''.join(self.model.tokenizer.convert_ids_to_tokens(output_ids)) 379 | 380 | 381 | 382 | 383 | 384 | def evaluate(data, model, topk=1, filename=None): 385 | """验证集评估 386 | """ 387 | autosummary = AutoSummary( 388 | start_id=model.tokenizer.cls_token_id, 389 | end_id=model.tokenizer.sep_token_id, 390 | maxlen=args.maxlen // 4, 391 | model=model 392 | ) 393 | if filename is not None: 394 | F = open(filename, 'w', encoding='utf-8') 395 | total_metrics = {k: 0.0 for k in metric_keys} 396 | for d in tqdm(data, desc=u'评估中'): 397 | pred_summary_1 = autosummary.generate(d['source_1_dis'][0], topk) 398 | pred_summary_2 = autosummary.generate(d['source_1_dis'][1], topk) 399 | metrics = compute_metrics(pred_summary_1+pred_summary_2, d['explanation_dis'][0]+d['explanation_dis'][1]) 400 | for k, v in metrics.items(): 401 | total_metrics[k] += v 402 | if filename is not None: 403 | F.write(d['explanation_dis'][0]+d['explanation_dis'][1] + '\t' + pred_summary_1 + pred_summary_2 + '\n') 404 | F.flush() 405 | if filename is not None: 406 | F.close() 407 | print(total_metrics) 408 | logging.info("~~~~~~~~~~~~~~~~~~~~") 409 | for k, v in total_metrics.items(): 410 | logging.info(k+": {} ".format(v/len(data))) 411 | logging.info("~~~~~~~~~~~~~~~~~~~~") 412 | return {k: v / len(data) for k, v in total_metrics.items()} 413 | 414 | 415 | if __name__ == '__main__': 416 | # 加载数据 417 | data = load_data(data_seq2seq_json) 418 | train_data = data_split(data, 'train') 419 | valid_data = data_split(data, 'valid') 420 | train_data_loader = build_pretrain_dataloader(train_data, args.batch_size) 421 | G_model 
= GenerateModel() 422 | if args.train: 423 | G_model = G_model.to(device) 424 | train_valid(train_data_loader, valid_data, G_model) 425 | else: 426 | for i in range(24): 427 | logging.info("epoch: {}".format(i)) 428 | load_checkpoint(G_model, None, None, "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/seq2seq_model/dismatch-seq2seq-{}.pkl".format(i)) 429 | with torch.no_grad(): 430 | G_model.eval() 431 | evaluate(valid_data, G_model) 432 | 433 | 434 | 435 | 436 | 437 | -------------------------------------------------------------------------------- /models/predictor_v5.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classification task + margin loss 3 | """ 4 | import sys 5 | sys.path.append("..") 6 | import datetime 7 | from transformers import BertTokenizer, AutoTokenizer, BertModel, AutoModel 8 | import argparse 9 | import torch 10 | import torch.nn as nn 11 | from tqdm import tqdm 12 | import copy 13 | from torch.utils.data import Dataset, DataLoader 14 | import logging 15 | from utils.snippets import * 16 | from sklearn.metrics import accuracy_score 17 | from sklearn.metrics import recall_score 18 | from sklearn.metrics import precision_score 19 | from sklearn.metrics import f1_score 20 | from sklearn.metrics import cohen_kappa_score 21 | from sklearn.metrics import hamming_loss 22 | from sklearn.metrics import jaccard_score 23 | # Basic parameters 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--batch_size_train', type=int, default=3, help='batch size for training') 26 | parser.add_argument('--batch_size_test', type=int, default=2, help='batch size for evaluation') 27 | parser.add_argument('--epochs', type=int, default=50, help='number of epochs') 28 | parser.add_argument('--lr', type=float, default=2e-5, help='learning rate') 29 | parser.add_argument('--weight_decay', type=float, default=0.0, help='decay weight of optimizer') 30 | parser.add_argument('--model_name', type=str, default='legal_bert', help='[nezha, legal_bert, lawformer]') 31 | parser.add_argument('--checkpoint', type=str, default="./weights/predict_model", help='checkpoint path') 32 | parser.add_argument('--bert_maxlen', type=int, default=512, help='max length of each case') 33 | parser.add_argument('--maxlen', type=int, default=1024, help='max length of each case') 34 | parser.add_argument('--input_size', type=int, default=768) 35 | parser.add_argument('--hidden_size', type=int, default=384) 36 | parser.add_argument('--dropout', type=float, default=0.1) 37 | parser.add_argument('--cuda_pos', type=str, default='1', help='which GPU to use') 38 | parser.add_argument('--seed', type=int, default=1, help='random seed') 39 | parser.add_argument('--train', action='store_true', help='whether to train; otherwise load a checkpoint and evaluate') 40 | parser.add_argument('--early_stopping_patience', type=int, default=5, help='early stopping patience') 41 | parser.add_argument('--log_name', type=str, default="predictor_v3_2", help='log file name') 42 | parser.add_argument('--margin', type=float, default=0.01, help='margin') 43 | parser.add_argument('--weight', type=float, default=1., help='weight') 44 | parser.add_argument('--gold_margin', type=float, default=0., help='gold_margin') 45 | parser.add_argument('--gold_weight', type=float, default=1., help='gold_weight') 46 | parser.add_argument('--scale_in', type=float, default=10., help='scale_in') 47 | parser.add_argument('--scale_out', type=float, default=10., help='scale_out') 48 | parser.add_argument('--warmup_steps', type=int, default=10000, help='warmup_steps')
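# The six margin/scale arguments above parameterize the contrastive part of the
# objective defined further down in this file: in_class_loss() applies a
# MarginRankingLoss (gold_margin, gold_weight) that pushes the gold explanation's
# cosine similarity to the query above the three candidate explanations,
# out_class_loss() (margin, weight) pushes each query's own-label similarity above
# similarities computed against other samples in the batch (diagonal masked out),
# and train_valid() combines them, roughly:
#   loss = scale_in * in_class_loss + scale_out * out_class_loss + BCE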
49 | parser.add_argument('--accumulate_step', type=int, default=12, help='accumulate_step') 50 | parser.add_argument('--data_type', type=str, default="ELAM", help='[ELAM, CAIL]') 51 | parser.add_argument('--mode_type', type=str, default="rationale", help='[all_sents, wo_rationale, rationale]') 52 | parser.add_argument('--eval_metric', type=str, default="linear_out", help='[linear_out, cosine_out]') 53 | 54 | args = parser.parse_args() 55 | print(args) 56 | np.random.seed(args.seed) 57 | random.seed(args.seed) 58 | torch.manual_seed(args.seed) 59 | torch.cuda.manual_seed_all(args.seed) 60 | device = torch.device('cuda:'+args.cuda_pos) if torch.cuda.is_available() else torch.device('cpu') 61 | logging.basicConfig(level=logging.INFO,#控制台打印的日志级别 62 | filename='../logs/predictor_v5_data_type_{}_mode_type_{}.log'.format(args.data_type, args.mode_type), 63 | filemode='a',##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 64 | #a是追加模式,默认如果不写的话,就是追加模式 65 | format= 66 | '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' 67 | #日志格式 68 | ) 69 | 70 | 71 | if args.data_type == 'CAIL': 72 | data_predictor_json = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/stage3/data_prediction_{}.json".format(args.mode_type) 73 | elif args.data_type == 'ELAM': 74 | data_predictor_json = "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/dataset/our_data/stage3/data_prediction_{}.json".format(args.mode_type) 75 | else: 76 | exit() 77 | def load_data(filename): 78 | """加载数据 79 | 返回:[{...}] 80 | """ 81 | all_data = [] 82 | with open(filename) as f: 83 | for l in f: 84 | all_data.append(json.loads(l)) 85 | random.shuffle(all_data) 86 | return all_data 87 | 88 | 89 | def load_checkpoint(model, optimizer, trained_epoch, file_name=None): 90 | save_params = torch.load(file_name, map_location=device) 91 | model.load_state_dict(save_params["model"]) 92 | #optimizer.load_state_dict(save_params["optimizer"]) 93 | 94 | 95 | def save_checkpoint(model, optimizer, trained_epoch): 96 | save_params = { 97 | "model": model.state_dict(), 98 | "optimizer": optimizer.state_dict(), 99 | "trained_epoch": trained_epoch, 100 | } 101 | if not os.path.exists(args.checkpoint): 102 | # 判断文件夹是否存在,不存在则创建文件夹 103 | os.mkdir(args.checkpoint) 104 | filename = args.checkpoint + '/' + "predictor_v5_data_type_{}_mode_type_{}.log".format(args.data_type, args.mode_type) 105 | torch.save(save_params, filename) 106 | 107 | 108 | class PredictorDataset(Dataset): 109 | """ 110 | input data predictor convert的输出就OK 111 | """ 112 | def __init__(self, input_data, random=True): 113 | super(PredictorDataset, self).__init__() 114 | self.data = input_data 115 | self.random = random 116 | 117 | def __len__(self): 118 | return len(self.data) 119 | 120 | def __getitem__(self, index): 121 | """ 122 | 注意exp的为 match dismatch midmatch 123 | :param index: 124 | :return: 125 | """ 126 | i = np.random.choice(2) + 1 if self.random else 1 127 | if i == 1: 128 | return self.data[index]['case_a'], self.data[index]['case_b'], self.data[index]['exp'][0], self.data[index]['exp'][1], self.data[index]['exp'][2], self.data[index]['label'], self.data[index]['explanation'] 129 | else: 130 | return self.data[index]['source_2_dis'][0], self.data[index]['source_2_dis'][1], self.data[index]['exp'][0], self.data[index]['exp'][1], self.data[index]['exp'][2], self.data[index]['label'], self.data[index]['explanation'] 131 | 132 | 133 | 134 | class Collate: 135 | def __init__(self): 136 | if args.model_name=='nezha': 137 | self.tokenizer = 
BertTokenizer.from_pretrained(pretrained_nezha_fold) 138 | self.max_seq_len = args.maxlen 139 | elif args.model_name == 'legal_bert': 140 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_fold) 141 | self.max_seq_len = args.bert_maxlen 142 | elif args.model_name == 'lawformer': 143 | self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext") 144 | self.max_seq_len = args.maxlen 145 | 146 | def __call__(self, batch): 147 | text_a, text_b, exp_match, exp_dismatch, exp_midmatch, labels, gold_exp = [], [], [], [], [], [], [] 148 | for item in batch: 149 | text_a.append(item[0]) 150 | text_b.append(item[1]) 151 | exp_match.append(item[2]) 152 | exp_midmatch.append(item[3]) 153 | exp_dismatch.append(item[4]) 154 | labels.append(item[5]) 155 | gold_exp.append(item[6]) 156 | dic_data_a = self.tokenizer.batch_encode_plus(text_a, padding=True, truncation=True, 157 | max_length=self.max_seq_len, return_tensors='pt') 158 | dic_data_b = self.tokenizer.batch_encode_plus(text_b, padding=True, truncation=True, 159 | max_length=self.max_seq_len, return_tensors='pt') 160 | dic_match = self.tokenizer.batch_encode_plus(exp_match, padding=True, truncation=True, 161 | max_length=self.max_seq_len, return_tensors='pt') 162 | dic_dismatch = self.tokenizer.batch_encode_plus(exp_dismatch, padding=True, truncation=True, 163 | max_length=self.max_seq_len, return_tensors='pt') 164 | dic_midmatch = self.tokenizer.batch_encode_plus(exp_midmatch, padding=True, truncation=True, 165 | max_length=self.max_seq_len, return_tensors='pt') 166 | dic_gold_exp = self.tokenizer.batch_encode_plus(gold_exp, padding=True, truncation=True, 167 | max_length=self.max_seq_len, return_tensors='pt') 168 | return dic_data_a, dic_data_b, dic_match, dic_midmatch, dic_dismatch, torch.tensor(labels), dic_gold_exp 169 | 170 | 171 | def build_pretrain_dataloader(data, batch_size, shuffle=True, num_workers=0, random=True): 172 | data_generator =PredictorDataset(data, random=random) 173 | collate = Collate() 174 | return DataLoader( 175 | data_generator, 176 | batch_size=batch_size, 177 | shuffle=shuffle, 178 | num_workers=num_workers, 179 | collate_fn=collate 180 | ) 181 | 182 | 183 | class PredictorModel(nn.Module): 184 | def __init__(self): 185 | super(PredictorModel, self).__init__() 186 | if args.model_name=='nezha': 187 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_nezha_fold) 188 | self.model = BertModel.from_pretrained(pretrained_nezha_fold) 189 | elif args.model_name == 'legal_bert': 190 | self.tokenizer = BertTokenizer.from_pretrained(pretrained_bert_fold) 191 | self.model = BertModel.from_pretrained(pretrained_bert_fold) 192 | elif args.model_name == 'lawformer': 193 | self.tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext") 194 | self.model = AutoModel.from_pretrained("thunlp/Lawformer") 195 | 196 | self.configuration = self.model.config 197 | 198 | self.n = 2 199 | self.linear1 = nn.Sequential( 200 | nn.Linear(self.n*self.configuration.hidden_size, self.configuration.hidden_size), # self.hidden_dim * 2 for bi-GRU & concat AB 201 | nn.LeakyReLU(), 202 | ) 203 | 204 | 205 | 206 | self.linear2_match = nn.Sequential( 207 | nn.Linear(self.configuration.hidden_size+self.configuration.hidden_size, self.configuration.hidden_size), 208 | nn.LeakyReLU(), 209 | nn.Linear(self.configuration.hidden_size, 1), 210 | nn.Sigmoid() 211 | ) 212 | self.linear2_midmatch = nn.Sequential( 213 | nn.Linear(self.configuration.hidden_size+self.configuration.hidden_size, 
self.configuration.hidden_size), 214 | nn.LeakyReLU(), 215 | nn.Linear(self.configuration.hidden_size, 1), 216 | nn.Sigmoid() 217 | ) 218 | self.linear2_dismatch = nn.Sequential( 219 | nn.Linear(self.configuration.hidden_size+self.configuration.hidden_size, self.configuration.hidden_size), 220 | nn.LeakyReLU(), 221 | nn.Linear(self.configuration.hidden_size, 1), 222 | nn.Sigmoid() 223 | ) 224 | self.cos = torch.nn.CosineSimilarity(dim=-1, eps=1e-6) 225 | 226 | def forward(self, text_a, text_b, match, midmatch, dismatch, gold_exp, batch_label, model_type='train'): 227 | output_text_a = self.model(**text_a)['pooler_output'] 228 | output_text_b = self.model(**text_b)['pooler_output'] 229 | output_exp1 = self.model(**match)['pooler_output'] 230 | output_exp2 = self.model(**midmatch)['pooler_output'] 231 | output_exp3 = self.model(**dismatch)['pooler_output'] 232 | gold_exp_pl = self.model(**gold_exp)['pooler_output'] 233 | data_p = torch.cat([output_text_a, output_text_b], dim=-1) 234 | 235 | query = self.linear1(data_p) 236 | class_match = self.linear2_match(torch.cat([query, output_exp1], dim=-1)) 237 | class_midmatch = self.linear2_midmatch(torch.cat([query, output_exp2], dim=-1)) 238 | class_dismatch = self.linear2_dismatch(torch.cat([query, output_exp3], dim=-1)) 239 | """ 240 | 算一个query与三个exp + golden的cos 241 | """ 242 | exps = torch.stack([output_exp3, output_exp2, output_exp1], dim=1) # (batch_size, 3, dim) 还是要把dismatch放前面 243 | query_1 = query.unsqueeze(1).repeat(1, 3, 1) # (batch_size, 3, dim) 244 | in_cos_score = self.cos(exps, query_1) 245 | golden_cos_similarity = self.cos(gold_exp_pl, query) 246 | """ 247 | 样本间对比操作 248 | query 与 其他数据的exp算得分 249 | """ 250 | if model_type == 'train': 251 | select = exps[:, batch_label.squeeze(), :] 252 | fi_select = select.permute([1, 0, 2]) # (batch_size, batch_size, dim) 253 | out_cos_score = self.cos(fi_select, query.unsqueeze(-2)) 254 | output_scores = torch.cat((class_dismatch, class_midmatch, class_match), dim=-1) 255 | return {"exp_score":output_scores, "in_cos_score":in_cos_score, "golden_cos_score":golden_cos_similarity, "out_cos_score":out_cos_score} # 需要两个mask 一个对角线mask 另一个label mask 256 | else: 257 | output_scores = torch.cat((class_dismatch, class_midmatch, class_match), dim=-1) 258 | return {"exp_score":output_scores, "in_cos_score":in_cos_score} # 需要两个mask 一个对角线mask 另一个label mask 259 | 260 | def in_class_loss(score, summary_score=None, gold_margin=0, gold_weight=1): 261 | pos_score = summary_score.unsqueeze(-1).expand_as(score) 262 | neg_score = score 263 | pos_score = pos_score.contiguous().view(-1) 264 | neg_score = neg_score.contiguous().view(-1) 265 | ones = torch.ones_like(pos_score) 266 | loss_func = torch.nn.MarginRankingLoss(gold_margin) 267 | TotalLoss = gold_weight * loss_func(pos_score, neg_score, ones) 268 | return TotalLoss 269 | 270 | 271 | def out_class_loss(score, summary_score=None, margin=0, weight=1): 272 | select = torch.le(torch.eye(len(summary_score), device=device), 0) 273 | pos_score = summary_score.unsqueeze(-1).expand_as(score) 274 | neg_score = score 275 | pos_score = pos_score.contiguous().view(-1) 276 | neg_score = neg_score.contiguous().view(-1) 277 | select = select.contiguous().view(-1) 278 | ones = torch.ones_like(pos_score, device=device) 279 | loss_func = torch.nn.MarginRankingLoss(margin, reduction='none') 280 | TotalLoss = weight * torch.sum(loss_func(pos_score, neg_score, ones)*select) 281 | return TotalLoss 282 | 283 | 284 | def train_valid(model, train_dataloader, valid_dataloader, 
test_dataloader): 285 | model = model.to(device) 286 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 287 | criterion = nn.BCELoss() 288 | #criterion = nn.CrossEntropyLoss() 289 | early_stop = EarlyStop(args.early_stopping_patience) 290 | for epoch in range(args.epochs): 291 | epoch_loss = 0.0 292 | current_step = 0 293 | model.train() 294 | pbar = tqdm(train_dataloader, desc="Iteration", postfix='train') 295 | for batch_data in pbar: 296 | text_batch_a, text_batch_b, match_batch, midmatch_batch, dismatch_batch, label_batch, gold_exp_batch = batch_data 297 | 298 | text_batch_a, text_batch_b, match_batch, dismatch_batch, midmatch_batch, label_batch, gold_exp_batch = \ 299 | text_batch_a.to(device), text_batch_b.to(device), match_batch.to(device), dismatch_batch.to(device), midmatch_batch.to(device), label_batch.to(device), gold_exp_batch.to(device) 300 | scores = model(text_batch_a, text_batch_b, match_batch, midmatch_batch, dismatch_batch, 301 | gold_exp_batch, label_batch, model_type='train') 302 | 303 | """ 304 | match midmatch dismatch 305 | """ 306 | 307 | linear_similarity, gold_similarity, in_cos_score, out_cos_score = scores['exp_score'], scores['golden_cos_score'], scores['in_cos_score'], scores['out_cos_score'] 308 | 309 | loss_in_class = args.scale_in * in_class_loss(in_cos_score, gold_similarity, args.gold_margin, args.gold_weight) 310 | loss_out_class = args.scale_out * out_class_loss(out_cos_score, in_cos_score[list(range(len(in_cos_score))), label_batch.squeeze()], args.margin, args.weight) 311 | 312 | bce_labels = torch.zeros_like(scores['exp_score']) 313 | bce_labels[list(range(len(bce_labels))), label_batch] = 1 314 | bce_labels = bce_labels.to(device) 315 | 316 | loss_bce = criterion(scores['exp_score'], bce_labels) 317 | loss = loss_in_class + loss_out_class + loss_bce # 多任务了属于是 318 | optimizer.zero_grad() 319 | loss.backward() 320 | optimizer.step() 321 | 322 | loss_item = loss.cpu().detach().item() 323 | epoch_loss += loss_item 324 | current_step += 1 325 | pbar.set_description("train loss {}".format(epoch_loss / current_step)) 326 | if current_step % 100 == 0: 327 | logging.info("train step {} loss {}".format(current_step, epoch_loss / current_step)) 328 | 329 | epoch_loss = epoch_loss / current_step 330 | time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 331 | print('{} train epoch {} loss: {:.4f}'.format(time_str, epoch, epoch_loss)) 332 | logging.info('train epoch {} loss: {:.4f}'.format(epoch, epoch_loss)) 333 | model.eval() 334 | 335 | current_val_metric_value = evaluation(valid_dataloader, model, epoch) 336 | 337 | is_save = early_stop.step(current_val_metric_value, epoch) 338 | if is_save: 339 | save_checkpoint(model, optimizer, epoch) 340 | else: 341 | pass 342 | if early_stop.stop_training(epoch): 343 | logging.info( 344 | "early stopping at epoch {} since didn't improve from epoch no {}. Best value {}, current value {}".format( 345 | epoch, early_stop.best_epoch, early_stop.best_value, current_val_metric_value 346 | )) 347 | print("early stopping at epoch {} since didn't improve from epoch no {}. 
Best value {}, current value {}".format( 348 | epoch, early_stop.best_epoch, early_stop.best_value, current_val_metric_value 349 | )) 350 | break 351 | evaluation(test_dataloader, model, epoch, type='test') 352 | 353 | def evaluation(valid_dataloader, model, epoch, type='valid'): 354 | with torch.no_grad(): 355 | correct = 0 356 | total = 0 357 | current_step = 0 358 | prediction_batch_list, label_batch_list = [], [] 359 | pbar = tqdm(valid_dataloader, desc="Iteration", postfix=type) 360 | for batch_data in pbar: 361 | text_batch_a, text_batch_b, match_batch, midmatch_batch, dismatch_batch, label_batch, gold_exp_batch = batch_data 362 | text_batch_a = text_batch_a.to(device) 363 | text_batch_b = text_batch_b.to(device) 364 | match_batch, dismatch_batch, midmatch_batch, gold_exp_batch = match_batch.to(device), dismatch_batch.to(device), midmatch_batch.to(device), gold_exp_batch.to(device) 365 | label_batch = label_batch.to(device) 366 | label_batch_list.append(label_batch) 367 | """ 368 | todo 这里好好看一下注意改造 mid 0 dis 1 match 2 369 | """ 370 | output_batch = model(text_batch_a, text_batch_b, match_batch, midmatch_batch, dismatch_batch, gold_exp_batch, label_batch, model_type=type) 371 | if args.eval_metric == 'linear_out': 372 | _, predicted_output = torch.max(output_batch["exp_score"], -1) 373 | elif args.eval_metric == 'cosine_out': 374 | _, predicted_output = torch.max(output_batch["in_cos_score"], -1) 375 | else: 376 | exit() 377 | label_batch = label_batch.to(device) 378 | total += len(label_batch) 379 | prediction_batch_list.append(predicted_output) 380 | correct += torch.sum(torch.eq(label_batch, predicted_output)) 381 | pbar.set_description("{} acc {}".format(type, correct / total)) 382 | current_step += 1 383 | if current_step % 100 == 0: 384 | logging.info('{} epoch {} acc {}/{}={:.4f}'.format(type, epoch, correct, total, correct / total)) 385 | prediction_batch_list = torch.cat(prediction_batch_list, dim=0).cpu().tolist() 386 | label_batch_list = torch.cat(label_batch_list, dim=0).cpu().tolist() 387 | accuracy = accuracy_score(label_batch_list, prediction_batch_list) 388 | precision_macro = precision_score(label_batch_list, prediction_batch_list, average='macro') 389 | recall_macro = recall_score(label_batch_list, prediction_batch_list, average='macro') 390 | f1_macro = f1_score(label_batch_list, prediction_batch_list, average='macro') 391 | precision_micro = precision_score(label_batch_list, prediction_batch_list, average='micro') 392 | recall_micro = recall_score(label_batch_list, prediction_batch_list, average='micro') 393 | f1_micro = f1_score(label_batch_list, prediction_batch_list, average='micro') 394 | cohen_kappa = cohen_kappa_score(label_batch_list, prediction_batch_list) 395 | hamming = hamming_loss(label_batch_list, prediction_batch_list) 396 | jaccard_macro = jaccard_score(label_batch_list, prediction_batch_list, average='macro') 397 | jaccard_micro = jaccard_score(label_batch_list, prediction_batch_list, average='micro') 398 | 399 | time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") 400 | print('{} {} acc {} precision ma {} mi {} recall ma {} mi {} f1 ma {} mi {}'.format(time_str, type, accuracy, 401 | precision_macro, 402 | precision_micro, 403 | recall_macro, 404 | recall_micro, f1_macro, 405 | f1_micro)) 406 | print('cohen_kappa {} hamming {} jaccard_macro {} jaccard_micro {}'.format(cohen_kappa, hamming, jaccard_macro, 407 | jaccard_micro)) 408 | logging.info( 409 | '{} {} acc {} precision ma {} mi {} recall ma {} mi {} f1 ma {} mi 
{}'.format(time_str, type, accuracy, 410 | precision_macro, 411 | precision_micro, 412 | recall_macro, recall_micro, 413 | f1_macro, f1_micro)) 414 | logging.info( 415 | 'cohen_kappa {} hamming {} jaccard_macro {} jaccard_micro {}'.format(cohen_kappa, hamming, jaccard_macro, 416 | jaccard_micro)) 417 | return accuracy 418 | 419 | def frozen_model(P_model, unfreeze_layers): 420 | """ 421 | 用于冻结模型 422 | :param model: 423 | :param free_layer: 424 | :return: 425 | """ 426 | for name, param in P_model.named_parameters(): 427 | print(name, param.size()) 428 | print("*" * 30) 429 | print('\n') 430 | 431 | for name, param in P_model.named_parameters(): 432 | param.requires_grad = False 433 | for ele in unfreeze_layers: 434 | if ele in name: 435 | param.requires_grad = True 436 | break 437 | # 验证一下 438 | for name, param in P_model.named_parameters(): 439 | if param.requires_grad: 440 | print(name, param.size()) 441 | 442 | if __name__ == '__main__': 443 | data = load_data(data_predictor_json) 444 | train_data = prediction_data_split(data, 'train', splite_ratio=0.8) 445 | valid_data = prediction_data_split(data, 'valid', splite_ratio=0.8) 446 | test_data = prediction_data_split(data, 'test', splite_ratio=0.8) 447 | train_data_loader = build_pretrain_dataloader(train_data, args.batch_size_train, shuffle=True, random=True) 448 | valid_data_loader = build_pretrain_dataloader(valid_data, args.batch_size_test, shuffle=False, random=False) 449 | test_data_loader = build_pretrain_dataloader(test_data, args.batch_size_test, shuffle=False, random=False) 450 | P_model = PredictorModel() 451 | if args.train: 452 | P_model = P_model.to(device) 453 | train_valid(P_model, train_data_loader, valid_data_loader, test_data_loader) 454 | else: 455 | P_model = P_model.to(device) 456 | load_checkpoint(P_model, None, None, 457 | "/new_disk2/zhongxiang_sun/code/explanation_project/explanation_model/models/weights/predict_model/predictor-0.pkl") 458 | with torch.no_grad(): 459 | P_model.eval() 460 | evaluation(valid_data_loader, P_model, 0) 461 | 462 | -------------------------------------------------------------------------------- /utils/snippets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | from rouge import Rouge 4 | import random 5 | import os, sys 6 | import jieba 7 | import copy 8 | import six 9 | from collections import defaultdict 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | from transformers import BertTokenizer 14 | import argparse 15 | # 自定义词典 16 | user_dict_path = '/home/zhongxiang_sun/code/explanation_project/explanation_model/dataset/user_dict.txt' 17 | user_dict_path_2 = '/home/zhongxiang_sun/code/explanation_project/explanation_model/dataset/user_dict_2.txt' 18 | jieba.load_userdict(user_dict_path) 19 | jieba.initialize() 20 | 21 | # 设置递归深度 22 | sys.setrecursionlimit(1000000) 23 | 24 | # 标注数据 25 | 26 | our_data_train = "/home/zhongxiang_sun/code/explanation_project/NER/data/law_train.json" 27 | our_data_test = "/home/zhongxiang_sun/code/explanation_project/NER/data/law_test.json" 28 | our_data_dev = "/home/zhongxiang_sun/code/explanation_project/NER/data/law_dev.json" 29 | 30 | 31 | # 保存权重的文件夹 32 | if not os.path.exists('weights'): 33 | os.mkdir('weights') 34 | 35 | # pretrained_model 配置 36 | pretrained_bert_fold = "/home/zhongxiang_sun/code/pretrain_model/bert_legal/" 37 | pretrained_chinese_bert_fold = "/home/zhongxiang_sun/code/pretrain_model/chinese_bert/" 38 | 
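# Note: the pretrained_* folders in this block are absolute paths on the authors'
# machine; to run the code elsewhere they need to be repointed at local copies of
# the corresponding checkpoints, e.g. (hypothetical path):
# pretrained_nezha_fold = "/path/to/pretrain_model/NEZHA/"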
pretrained_nezha_fold = "/home/zhongxiang_sun/code/pretrain_model/NEZHA/" 39 | pretrained_gpt2_fold = "/home/zhongxiang_sun/code/pretrain_model/GPT2_CH/" 40 | pretrained_t5_fold = "/home/zhongxiang_sun/code/pretrain_model/T5_PEGASUS/" 41 | pretrained_lawformer_fold = "/home/zhongxiang_sun/code/pretrain_model/lawformer/" 42 | pretrained_bert_legal_civil_fold = "/home/zhongxiang_sun/code/pretrain_model/bert_legal_civil/" 43 | pretrained_bert_legal_criminal_fold = "/home/zhongxiang_sun/code/pretrain_model/bert_legal_criminal/" 44 | 45 | 46 | # 将数据划分N份,一份作为验证集 47 | num_folds = 15 48 | 49 | # 指标名 50 | metric_keys = ['main', 'rouge-1', 'rouge-2', 'rouge-l'] 51 | 52 | # 计算rouge用 53 | rouge = Rouge() 54 | 55 | def idxtobool(idx, size, device): 56 | V = torch.zeros(size, dtype=torch.long, device=device) 57 | if len(size) > 2: 58 | 59 | for i in range(size[0]): 60 | for j in range(size[1]): 61 | subidx = idx[i, j, :] 62 | V[i, j, subidx] = float(1) 63 | 64 | elif len(size) == 2: 65 | 66 | for i in range(size[0]): 67 | subidx = idx[i, :] 68 | V[i, subidx] = float(1) 69 | 70 | else: 71 | 72 | raise argparse.ArgumentTypeError('len(size) should be larger than 1') 73 | 74 | return V 75 | 76 | class BCEFocalLoss(torch.nn.Module): 77 | def __init__(self, gamma=2, alpha=0.25, reduction='mean'): 78 | super(BCEFocalLoss, self).__init__() 79 | self.gamma = gamma 80 | self.alpha = alpha 81 | self.reduction = reduction 82 | 83 | def forward(self, predict, target): 84 | pt = torch.sigmoid(predict) # sigmoide获取概率 85 | #在原始ce上增加动态权重因子,注意alpha的写法,下面多类时不能这样使用 86 | loss = - self.alpha * (1 - pt) ** self.gamma * target * torch.log(pt) - (1 - self.alpha) * pt ** self.gamma * (1 - target) * torch.log(1 - pt) 87 | 88 | if self.reduction == 'mean': 89 | loss = torch.mean(loss) 90 | elif self.reduction == 'sum': 91 | loss = torch.sum(loss) 92 | else: 93 | pass 94 | return loss 95 | 96 | 97 | def dcg_score(y_true, y_score, k=10): 98 | order = np.argsort(y_score)[::-1] 99 | y_true = np.take(y_true, order[:k]) 100 | gains = 2 ** y_true - 1 101 | discounts = np.log2(np.arange(len(y_true)) + 2) 102 | return np.sum(gains / discounts) 103 | 104 | 105 | def ndcg_score(y_true, y_score, k=10): 106 | best = dcg_score(y_true, y_true, k) 107 | actual = dcg_score(y_true, y_score, k) 108 | return actual / best 109 | 110 | 111 | def mrr_score(y_true, y_score): 112 | order = np.argsort(y_score)[::-1] 113 | y_true = np.take(y_true, order) 114 | rr_score = y_true / (np.arange(len(y_true)) + 1) 115 | return np.sum(rr_score) / np.sum(y_true) 116 | 117 | 118 | def hr(y_true, y_score, k=10): 119 | order = np.argsort(y_score)[::-1] 120 | y_tmp = np.take(y_true, order[:k]) 121 | return y_tmp.sum() / np.sum(y_true) 122 | 123 | def ot_acc_score(y_true, y_pred, mask): 124 | y_pred = torch.greater(y_pred, 0.0001).type(torch.long) 125 | y_true = torch.greater(y_true, 0.0001).type(torch.long) 126 | return torch.sum(torch.eq(y_true.to(int), y_pred.to(int))*mask), torch.sum(mask) # TP+TN, TP+TN+FP+FN 127 | 128 | def ot_precision_score(y_true, y_pred, mask): 129 | y_pred = torch.greater(y_pred, 0.0001) 130 | y_true = torch.greater(y_true, 0.0001) 131 | return torch.sum(y_true * y_pred * mask), torch.sum(y_pred * mask) # TP, TP+FP 132 | 133 | 134 | def ot_recall_score(y_true, y_pred, mask): 135 | y_pred = torch.greater(y_pred, 0.0001) 136 | y_true = torch.greater(y_true, 0.0001) 137 | return torch.sum(y_true * y_pred * mask), torch.sum(y_true * mask) #TP, TP+FN 138 | 139 | def ot_sum_1_score(y_true, y_pred): 140 | y_pred = torch.greater(y_pred, 0.0001) 141 | 
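# The ot_* helpers in this block binarize soft (optimal-transport style) alignment
# scores at a 1e-4 threshold and return (numerator, denominator) count pairs --
# e.g. ot_precision_score returns (TP, TP+FP) and ot_recall_score (TP, TP+FN) --
# so callers can accumulate both parts across batches before dividing.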
class T5PegasusTokenizer(BertTokenizer):
    def __init__(self, pre_tokenizer=lambda x: jieba.cut(x, HMM=False), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pre_tokenizer = pre_tokenizer

    def _tokenize(self, text, *args, **kwargs):
        split_tokens = []
        for word in self.pre_tokenizer(text):
            if word in self.vocab:
                split_tokens.append(word)
            else:
                split_tokens.extend(super()._tokenize(word))
        return split_tokens

def submul(x1, x2):
    mul = x1 * x2
    sub = x1 - x2
    return torch.cat([sub, mul], -1)

class EarlyStop:
    def __init__(self, patience, max_or_min="max"):
        if max_or_min not in ('max', 'min'):
            raise ValueError("max_or_min must be 'max' or 'min'")
        self.patience = patience
        # Initialize so that the first observed value always becomes the best;
        # a fixed 0.0 initial value would never be improved on in 'min' mode.
        self.best_value = float('-inf') if max_or_min == 'max' else float('inf')
        self.best_epoch = 0
        self.max_or_min = max_or_min

    def step(self, current_value, current_epoch):
        print("Current:{} Best:{}".format(current_value, self.best_value))
        if self.max_or_min == 'max':
            improved = current_value > self.best_value
        else:
            improved = current_value < self.best_value
        if improved:
            self.best_value = current_value
            self.best_epoch = current_epoch
        return improved

    def stop_training(self, current_epoch) -> bool:
        return current_epoch - self.best_epoch > self.patience
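# --- Illustrative usage (an editor's sketch, not part of the original file). ---
# How EarlyStop is meant to be driven from a training loop; `_demo_early_stop`
# and the fake validation scores are hypothetical. With patience=2, the best
# score arrives at epoch 1, so training stops at epoch 4 (4 - 1 > 2).
def _demo_early_stop():
    stopper = EarlyStop(patience=2, max_or_min='max')
    fake_scores = [0.51, 0.55, 0.54, 0.53, 0.52]  # validation metric per epoch
    for epoch, score in enumerate(fake_scores):
        if stopper.step(score, epoch):  # True when a new best is recorded
            pass                        # e.g. save a checkpoint here
        if stopper.stop_training(epoch):
            print("stopping at epoch {}, best epoch {}".format(epoch, stopper.best_epoch))
            break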
def softmax(x, axis=-1):
    """Numpy version of softmax.
    """
    x = x - x.max(axis=axis, keepdims=True)
    x = np.exp(x)
    return x / x.sum(axis=axis, keepdims=True)

class AutoRegressiveDecoder(object):
    """Generic base class for autoregressive decoding.
    Provides two strategies: beam search and random sampling.
    """
    def __init__(self, start_id, end_id, maxlen, minlen=1, model=None, tokenizer=None):
        self.start_id = start_id
        self.end_id = end_id
        self.maxlen = maxlen
        self.minlen = minlen
        self.model = model
        self.tokenizer = tokenizer
        if start_id is None:
            self.first_output_ids = np.empty((1, 0), dtype=int)
        else:
            self.first_output_ids = np.array([[self.start_id]])

    @staticmethod
    def wraps(default_rtype='probas', use_states=False):
        """Decorator that completes the user-defined predict function.
        It currently: 1. handles the rtype argument;
        2. handles the use of states;
        3. handles the temperature argument.
        """
        def actual_decorator(predict):
            def new_predict(
                self,
                inputs,
                output_ids,
                states,
                temperature=1,
                rtype=default_rtype
            ):
                assert rtype in ['probas', 'logits']
                prediction = predict(self, inputs, output_ids, states)

                if not use_states:
                    prediction = (prediction, None)

                if default_rtype == 'logits':
                    prediction = (
                        softmax(prediction[0] / temperature), prediction[1]
                    )
                elif temperature != 1:
                    probas = np.power(prediction[0], 1.0 / temperature)
                    probas = probas / probas.sum(axis=-1, keepdims=True)
                    prediction = (probas, prediction[1])

                if rtype == 'probas':
                    return prediction
                else:
                    return np.log(prediction[0] + 1e-12), prediction[1]

            return new_predict

        return actual_decorator

    def predict(self, inputs, output_ids, states=None):
        """Recursive prediction function to be defined by the user.
        Note: the definition must be decorated with the wraps method, passing
        default_rtype and use_states. default_rtype is the string 'logits' or
        'probas'; with 'probas' the call returns normalized probabilities,
        while rtype='logits' returns pre-softmax scores or log-probabilities.
        Returns: a tuple (scores or probabilities, states).
        """
        raise NotImplementedError

    def beam_search(self, inputs, topk, states=None, temperature=1, min_ends=1):
        """Beam search decoding.
        Note: topk here is the beam size.
        Returns: the best decoded sequence.
        """
        inputs = [np.array([i]) for i in inputs]
        output_ids, output_scores = self.first_output_ids, np.zeros(1)
        for step in range(self.maxlen):
            scores, states = self.predict(
                inputs, output_ids, states, temperature, 'logits'
            )  # current scores
            if step == 0:  # after the first step, repeat the inputs topk times
                inputs = [np.repeat(i, topk, axis=0) for i in inputs]
            scores = output_scores.reshape((-1, 1)) + scores  # accumulated scores
            indices = scores.argpartition(-topk, axis=None)[-topk:]  # keep only topk
            indices_1 = indices // scores.shape[1]  # row indices
            indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column indices
            output_ids = np.concatenate([output_ids[indices_1], indices_2],
                                        1)  # update outputs
            output_scores = np.take_along_axis(
                scores, indices, axis=None
            )  # update scores
            is_end = output_ids[:, -1] == self.end_id  # ends with the end token?
            end_counts = (output_ids == self.end_id).sum(1)  # count end tokens so far
            if output_ids.shape[1] >= self.minlen:  # minimum-length check
                best = output_scores.argmax()  # the highest-scoring sequence
                if is_end[best] and end_counts[best] >= min_ends:  # already finished
                    return output_ids[best]  # return it directly
                else:  # otherwise keep only the unfinished part
                    flag = ~is_end | (end_counts < min_ends)  # mark unfinished sequences
                    if not flag.all():  # some sequences are finished
                        inputs = [i[flag] for i in inputs]  # drop finished sequences
                        output_ids = output_ids[flag]
                        output_scores = output_scores[flag]
                        end_counts = end_counts[flag]
                        topk = flag.sum()  # shrink topk accordingly
        # maximum length reached; return the best sequence so far
        return output_ids[output_scores.argmax()]

    def random_sample(
        self,
        inputs,
        n,
        topk=None,
        topp=None,
        states=None,
        temperature=1,
        min_ends=1
    ):
        """Randomly sample n decoding results.
        Note: a non-None topk samples from only the topk highest-probability
        tokens at each step; a non-None topp samples from only the
        highest-probability tokens whose cumulative probability just reaches topp.
        Returns: a list of n decoded sequences.
        """
        inputs = [np.array([i]) for i in inputs]
        output_ids = self.first_output_ids
        results = []
        for step in range(self.maxlen):
            probas, states = self.predict(
                inputs, output_ids, states, temperature, 'probas'
            )  # current probabilities
            probas /= probas.sum(axis=1, keepdims=True)  # make sure they are normalized
            if step == 0:  # after the first step, repeat the results n times
                probas = np.repeat(probas, n, axis=0)
                inputs = [np.repeat(i, n, axis=0) for i in inputs]
                output_ids = np.repeat(output_ids, n, axis=0)
            if topk is not None:
                k_indices = probas.argpartition(-topk,
                                                axis=1)[:, -topk:]  # keep only topk
                probas = np.take_along_axis(probas, k_indices, axis=1)  # topk probabilities
                probas /= probas.sum(axis=1, keepdims=True)  # renormalize
            if topp is not None:
                p_indices = probas.argsort(axis=1)[:, ::-1]  # sort descending
                probas = np.take_along_axis(probas, p_indices, axis=1)  # sorted probabilities
                cumsum_probas = np.cumsum(probas, axis=1)  # cumulative probabilities
                flag = np.roll(cumsum_probas >= topp, 1, axis=1)  # mark the part beyond topp
                flag[:, 0] = False  # together with np.roll, this shifts the mask by one
                probas[flag] = 0  # zero out everything beyond topp
                probas /= probas.sum(axis=1, keepdims=True)  # renormalize
            sample_func = lambda p: np.random.choice(len(p), p=p)  # sample by probability
            sample_ids = np.apply_along_axis(sample_func, 1, probas)  # run the sampling
            sample_ids = sample_ids.reshape((-1, 1))  # align shapes
            if topp is not None:
                sample_ids = np.take_along_axis(
                    p_indices, sample_ids, axis=1
                )  # map back to the original ids
            if topk is not None:
                sample_ids = np.take_along_axis(
                    k_indices, sample_ids, axis=1
                )  # map back to the original ids
            output_ids = np.concatenate([output_ids, sample_ids], 1)  # update outputs
            is_end = output_ids[:, -1] == self.end_id  # ends with the end token?
            end_counts = (output_ids == self.end_id).sum(1)  # count end tokens so far
            if output_ids.shape[1] >= self.minlen:  # minimum-length check
                flag = is_end & (end_counts >= min_ends)  # mark finished sequences
                if flag.any():  # some sequences are finished
                    for ids in output_ids[flag]:  # store the finished sequences
                        results.append(ids)
                    flag = ~flag  # mark unfinished sequences
                    inputs = [i[flag] for i in inputs]  # keep only unfinished inputs
                    output_ids = output_ids[flag]  # keep only unfinished candidates
                    end_counts = end_counts[flag]  # keep only unfinished end counts
                    if len(output_ids) == 0:
                        break
        # any sequences still unfinished go straight into the results
        for ids in output_ids:
            results.append(ids)
        return results
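# --- Illustrative subclass (an editor's sketch, not part of the original file). ---
# A minimal example of how `predict` is meant to be overridden: decorate it
# with `AutoRegressiveDecoder.wraps`, score the next token given the source
# ids plus the generated prefix, and let the base class drive the search.
# `_DemoDecoder` is hypothetical, and the assumption that `self.model` is a
# causal LM returning per-position logits as its first output may not match
# the models used elsewhere in this repo.
class _DemoDecoder(AutoRegressiveDecoder):
    @AutoRegressiveDecoder.wraps(default_rtype='logits')
    def predict(self, inputs, output_ids, states):
        token_ids = np.concatenate([inputs[0], output_ids], axis=1)
        with torch.no_grad():
            logits = self.model(torch.tensor(token_ids))[0]
        return logits[:, -1, :].numpy()  # scores for the next token

# e.g. _DemoDecoder(start_id=None, end_id=eos_id, maxlen=64, model=lm).beam_search([token_ids], topk=4)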
class EMA:
    """Maintain an exponential moving average (EMA) of the model weights."""
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        # Snapshot the current weights as the initial shadow weights.
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        # shadow = decay * shadow + (1 - decay) * current
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        # Back up the raw weights and load the shadow weights into the model.
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        # Restore the backed-up raw weights.
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}
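# --- Illustrative usage (an editor's sketch, not part of the original file). ---
# The usual EMA recipe: register once after building the model, update after
# every optimizer step, and swap the shadow weights in only for evaluation.
# `_demo_ema` and its `model` argument are placeholders.
def _demo_ema(model, decay=0.999):
    ema = EMA(model, decay)
    ema.register()
    # ... in the training loop, after optimizer.step():
    ema.update()
    # ... before evaluation:
    ema.apply_shadow()
    # evaluate(model)  # the model now holds the averaged weights
    ema.restore()      # back to the raw training weights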
is_py2 = six.PY2


def convert_to_unicode(text, encoding='utf-8', errors='ignore'):
    """Convert a string to unicode (assuming utf-8 input).
    """
    if is_py2:
        if isinstance(text, str):
            text = text.decode(encoding, errors=errors)
    else:
        if isinstance(text, bytes):
            text = text.decode(encoding, errors=errors)
    return text


def convert_to_str(text, encoding='utf-8', errors='ignore'):
    """Convert a string to str (assuming utf-8 input).
    """
    if isinstance(text, bytes):
        text = text.decode(encoding, errors=errors)
    return text


def is_string(s):
    """Check whether s is a string.
    """
    return isinstance(s, str)


def parallel_apply(
    func,
    iterable,
    workers,
    max_queue_size,
    callback=None,
    dummy=False,
    random_seeds=True,
    unordered=True
):
    """Apply func to every element of iterable with multiple processes or threads.
    Note that this apply is asynchronous and unordered: for inputs a, b, c in
    order, the outputs may come back as func(c), func(a), func(b).
    Arguments:
        callback: callback applied to each single output;
        dummy: False for multiprocessing, True for multithreading;
        random_seeds: random seed for each worker;
        unordered: if False, return results in input order; only effective
            when callback is None.
    """
    generator = parallel_apply_generator(
        func, iterable, workers, max_queue_size, dummy, random_seeds
    )

    if callback is None:
        if unordered:
            return [d for i, d in generator]
        else:
            results = sorted(generator, key=lambda d: d[0])
            return [d for i, d in results]
    else:
        for d in generator:
            callback(d)

def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'):
    """Numpy function that pads sequences to the same length.
    """
    if length is None:
        length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0)
    elif not hasattr(length, '__getitem__'):
        length = [length]

    slices = [np.s_[:length[i]] for i in range(seq_dims)]
    slices = tuple(slices) if len(slices) > 1 else slices[0]
    pad_width = [(0, 0) for _ in np.shape(inputs[0])]

    outputs = []
    for x in inputs:
        x = x[slices]
        for i in range(seq_dims):
            if mode == 'post':
                pad_width[i] = (0, length[i] - np.shape(x)[i])
            elif mode == 'pre':
                pad_width[i] = (length[i] - np.shape(x)[i], 0)
            else:
                raise ValueError('"mode" argument must be "post" or "pre".')
        if hasattr(x, 'cpu'):  # torch tensor: move to CPU numpy before padding
            x = x.cpu().numpy()
        x = np.pad(x, pad_width, 'constant', constant_values=value)
        outputs.append(x)

    return np.array(outputs)
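# --- Illustrative usage (an editor's sketch, not part of the original file). ---
# Padding a small batch of variable-length torch tensors to the longest one;
# `_demo_sequence_padding` is a hypothetical name. Plain numpy arrays or lists
# of equal rank would work the same way.
def _demo_sequence_padding():
    batch = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    padded = sequence_padding(batch, value=0)
    print(padded)  # [[1 2 3] [4 5 0]]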
def parallel_apply_generator(
    func, iterable, workers, max_queue_size, dummy=False, random_seeds=True
):
    """Apply func to every element of iterable with multiple processes or threads.
    Note that this apply is asynchronous and unordered: for inputs a, b, c in
    order, the outputs may come back as func(c), func(a), func(b). The results
    are returned as a generator whose items are (input index, output for that
    input) pairs.
    Arguments:
        dummy: False for multiprocessing, True for multithreading;
        random_seeds: random seed for each worker.
    """
    if dummy:
        from multiprocessing.dummy import Pool, Queue
    else:
        from multiprocessing import Pool, Queue

    in_queue, out_queue, seed_queue = Queue(max_queue_size), Queue(), Queue()
    if random_seeds is True:
        random_seeds = [None] * workers
    elif random_seeds is None or random_seeds is False:
        random_seeds = []
    for seed in random_seeds:
        seed_queue.put(seed)

    def worker_step(in_queue, out_queue):
        """Wrap the single-step function into an endless loop.
        """
        if not seed_queue.empty():
            np.random.seed(seed_queue.get())
        while True:
            i, d = in_queue.get()
            r = func(d)
            out_queue.put((i, r))

    # start the processes/threads
    pool = Pool(workers, worker_step, (in_queue, out_queue))

    # feed in the data and drain the results
    in_count, out_count = 0, 0
    for i, d in enumerate(iterable):
        in_count += 1
        while True:
            try:
                in_queue.put((i, d), block=False)
                break
            except six.moves.queue.Full:
                for _ in range(out_queue.qsize()):
                    yield out_queue.get()
                    out_count += 1
        if in_count % max_queue_size == 0:
            for _ in range(out_queue.qsize()):
                yield out_queue.get()
                out_count += 1

    while out_count != in_count:
        for _ in range(out_queue.qsize()):
            yield out_queue.get()
            out_count += 1

    pool.terminate()

def text_segmentate(text, maxlen, seps='\n', strips=None):
    """Split the text into short pieces at the given separators.
    """
    text = text.strip().strip(strips)
    if seps and len(text) > maxlen:
        pieces = text.split(seps[0])
        text, texts = '', []
        for i, p in enumerate(pieces):
            if text and p and len(text) + len(p) > maxlen - 1:
                texts.extend(text_segmentate(text, maxlen, seps[1:], strips))
                text = ''
            if i + 1 == len(pieces):
                text = text + p
            else:
                text = text + p + seps[0]
        if text:
            texts.extend(text_segmentate(text, maxlen, seps[1:], strips))
        return texts
    else:
        return [text]


def load_user_dict(filename):
    """Load a user dictionary.
    """
    user_dict = []
    with open(filename, encoding='utf-8') as f:
        for l in f:
            if not l.strip():  # skip blank lines instead of raising IndexError
                continue
            w = l.split()[0]
            user_dict.append(w)
    return user_dict


def data_split(data, mode, splite_ratio=0.8, if_random=False):
    """Split the data into training, validation and test sets.
    """
    if if_random:
        data = copy.deepcopy(data)
        random.seed(1)
        random.shuffle(data)
    splite_point1 = int(splite_ratio * len(data))
    splite_point2 = int((splite_ratio + 0.1) * len(data))
    if mode == 'train':
        D = data[:splite_point1]
    elif mode == 'valid':
        D = data[splite_point1:splite_point2]
    elif mode == 'test':
        D = data[splite_point2:]
    else:
        raise ValueError("mode can only be 'train', 'valid' or 'test'")

    if isinstance(data, np.ndarray):
        return np.array(D)
    else:
        return D


class SmoothCrossEntropy(nn.Module):
    """Cross entropy with label smoothing.
    loss = SmoothCrossEntropy()
    input = torch.randn(3, 5, requires_grad=True)
    target = torch.empty(3, dtype=torch.long).random_(5)
    output = loss(input, target)
    """
    def __init__(self, alpha=0.1):
        super(SmoothCrossEntropy, self).__init__()
        self.alpha = alpha

    def forward(self, logits, labels):
        num_classes = logits.shape[-1]
        alpha_div_k = self.alpha / num_classes
        target_probs = F.one_hot(labels, num_classes=num_classes).float() * \
            (1. - self.alpha) + alpha_div_k
        loss = -(target_probs * torch.log_softmax(logits, dim=-1)).sum(dim=-1)
        return loss.mean()
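# --- Illustrative usage (an editor's sketch, not part of the original file). ---
# With alpha=0.1 and 5 classes, the smoothed target puts 0.9 + 0.1/5 = 0.92 on
# the gold class and 0.02 on each of the others. `_demo_smooth_ce` is a
# hypothetical name.
def _demo_smooth_ce():
    loss_fn = SmoothCrossEntropy(alpha=0.1)
    logits = torch.randn(3, 5, requires_grad=True)
    labels = torch.randint(0, 5, (3,))
    loss = loss_fn(logits, labels)
    loss.backward()
    print(loss.item())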
def prediction_data_split(data, mode, splite_ratio=0.8):
    """Split the data into training and validation sets.
    """
    D = []
    splite_point1 = int(splite_ratio * len(data))
    splite_point2 = int((splite_ratio + 0.1) * len(data))

    if mode == 'train':
        D += data[:splite_point1]
    elif mode == 'valid':
        D += data[splite_point1:splite_point2]
    else:
        D += data[splite_point2:]

    return D


def compute_rouge(source, target, unit='word'):
    """Compute rouge-1, rouge-2 and rouge-l.
    """
    # if unit == 'word':
    #     source = jieba.cut(source, HMM=False)
    #     target = jieba.cut(target, HMM=False)
    source, target = ' '.join(source), ' '.join(target)
    try:
        scores = rouge.get_scores(hyps=source, refs=target)
        return {
            'rouge-1': scores[0]['rouge-1']['f'],
            'rouge-2': scores[0]['rouge-2']['f'],
            'rouge-l': scores[0]['rouge-l']['f'],
        }
    except ValueError:
        return {
            'rouge-1': 0.0,
            'rouge-2': 0.0,
            'rouge-l': 0.0,
        }


def compute_metrics(source, target, unit='word'):
    """Compute all metrics.
    """
    metrics = compute_rouge(source, target, unit)
    metrics['main'] = (
        metrics['rouge-1'] * 0.2 + metrics['rouge-2'] * 0.4 +
        metrics['rouge-l'] * 0.4
    )
    return metrics


def compute_main_metric(source, target, unit='word'):
    """Compute the main metric.
    """
    return compute_metrics(source, target, unit)['main']


def longest_common_subsequence(source, target):
    """Longest common subsequence of source and target (not necessarily contiguous).
    Returns: the subsequence length and the mapping (a list of index pairs).
    Note: the longest common subsequence may not be unique; the returned
    mapping represents only one of them.
    """
    c = defaultdict(int)
    for i, si in enumerate(source, 1):
        for j, tj in enumerate(target, 1):
            if si == tj:
                c[i, j] = c[i - 1, j - 1] + 1
            elif c[i, j - 1] > c[i - 1, j]:
                c[i, j] = c[i, j - 1]
            else:
                c[i, j] = c[i - 1, j]
    l, mapping = c[len(source), len(target)], []
    i, j = len(source) - 1, len(target) - 1
    while len(mapping) < l:
        if source[i] == target[j]:
            mapping.append((i, j))
            i, j = i - 1, j - 1
        elif c[i + 1, j] > c[i, j + 1]:
            j = j - 1
        else:
            i = i - 1
    return l, mapping[::-1]
--------------------------------------------------------------------------------