├── .idea
│   ├── .gitignore
│   ├── OneRel.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── bert-base-chinese
│   └── placeholder.txt
├── config
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── config.cpython-38.pyc
│   └── config.py
├── dataloader
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── dataloader.cpython-38.pyc
│   └── dataloader.py
├── dataset
│   ├── CIE
│   │   ├── dev_data.json
│   │   ├── schema.json
│   │   └── train_data.json
│   └── tag2id.json
├── dev_result
│   └── dev_result.json
├── framework
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── framework.cpython-38.pyc
│   └── framework.py
├── main.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── models.cpython-38.pyc
│   └── models.py
└── test.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/OneRel.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OneRel
2 | Chinese relation extraction
3 | A reproduction of the AAAI 2022 paper "OneRel: Joint Entity and Relation Extraction with One Module in One Step"
4 |
5 | ## Environment
6 | python==3.8
7 | torch==1.8.1
8 | transformers==4.3.1
9 |
10 | ## Usage
11 | 1. Model hyperparameters are set in the config file config.py
12 | 2. Use the data under dataset/ as a format example
13 | 3. Run python main.py to train the model
14 | 4. test.py runs prediction on a single sentence
15 |
16 | ## Results
17 | The model was trained on the CCKS 2020 Chinese medical entity and relation extraction data; after 100 epochs the final dev-set f1_score is about 60.78%
18 |
19 | ## Model comparison
20 | CasRel: f1=51.28%
21 | GPLinker: f1=59.15%
22 | OneRel: f1=60.78%
--------------------------------------------------------------------------------
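The README points at dataset/ for the data format. For reference, REDataset in dataloader/dataloader.py reads each record's "text" and "spo_list" fields, where every spo entry is a [subject, predicate, object] triple and the predicate must be a key of dataset/CIE/schema.json. A minimal sketch of one training record, with an invented sentence and entities (only the field layout is taken from the code):

```python
# Illustrative record layout for train_data.json / dev_data.json.
# Field names come from dataloader/dataloader.py; the sentence and triple are made up.
example_record = {
    "text": "胃癌早期可无明显症状",
    "spo_list": [
        # [subject, predicate, object]; subject and object must appear verbatim in "text",
        # and the predicate must be listed in dataset/CIE/schema.json.
        ["胃癌", "疾病/临床表现/症状", "无明显症状"],
    ],
}
# train_data.json / dev_data.json are JSON arrays of such records (loaded with json.load).
```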
/bert-base-chinese/placeholder.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/bert-base-chinese/placeholder.txt
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/config/__init__.py
--------------------------------------------------------------------------------
/config/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/config/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/config/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/config/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 |
2 | class Config():
3 |     def __init__(self):
4 |         self.num_rel = 53
5 |         # self.rel_num = 53
6 |         self.train_file = "./dataset/CIE/train_data.json"
7 |         self.dev_file = "./dataset/CIE/dev_data.json"
8 |         self.schema_fn = "./dataset/CIE/schema.json"
9 |         self.bert_path = "./bert-base-chinese"
10 |         self.tags = "./dataset/tag2id.json"
11 |         self.bert_dim = 768
12 |         self.tag_size = 4
13 |         self.batch_size = 4
14 |         self.max_len = 510
15 |         self.learning_rate = 1e-5
16 |         self.epochs = 100
17 |         self.checkpoint = "checkpoint/OneRel_self.pt"
18 |         self.dev_result = "dev_result/dev_result.json"
19 |         self.dropout_prob = 0.1
20 |         self.entity_pair_dropout = 0.2
--------------------------------------------------------------------------------
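The schema and tag files referenced above each store a pair of mappings (name-to-id and id-to-name), which is why the rest of the code indexes json.load(...)[0] or [1]. A small sanity-check sketch, not part of the repo, that ties the hard-coded num_rel=53 and tag_size=4 back to those files (it assumes the dataset files exist at the configured paths):

```python
import json
from config.config import Config

config = Config()
with open(config.schema_fn, "r", encoding="utf-8") as f:
    rel2id = json.load(f)[0]   # first entry: relation name -> id
with open(config.tags, "r", encoding="utf-8") as f:
    tag2id = json.load(f)[1]   # second entry: tag name -> id
assert config.num_rel == len(rel2id)    # 53 relation types in schema.json
assert config.tag_size == len(tag2id)   # 4 tags: A, HB-TB, HB-TE, HE-TE
```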
/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/dataloader/__init__.py
--------------------------------------------------------------------------------
/dataloader/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/dataloader/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/dataloader/__pycache__/dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloader/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import json
4 | from transformers import BertTokenizer
5 | import numpy as np
6 |
7 |
8 | def find_idx(token, target):
9 |     target_length = len(target)
10 |     for k, v in enumerate(token):
11 |         if token[k: k + target_length] == target:
12 |             return k
13 |     return -1
14 |
15 |
16 | class REDataset(Dataset):
17 |     def __init__(self, config, file, is_test=False):
18 |         self.config = config
19 |         with open(file, "r", encoding="utf-8") as f:
20 |             self.data = json.load(f)
21 |         with open(self.config.schema_fn, "r", encoding="utf-8") as fs:
22 |             self.rel2id = json.load(fs)[0]
23 |         with open(self.config.tags) as ft:
24 |             self.tag2id = json.load(ft)[1]
25 |         self.tokenizer = BertTokenizer.from_pretrained(self.config.bert_path)
26 |         self.is_test = is_test
27 |
28 |     def __len__(self):
29 |         return len(self.data)
30 |
31 |     def __getitem__(self, idx):
32 |         ins_json_data = self.data[idx]
33 |         sentence = ins_json_data["text"]
34 |         triple = ins_json_data["spo_list"]
35 |         token = ['[CLS]'] + list(sentence)[:self.config.max_len] + ['[SEP]']
36 |         token_len = len(token)
37 |
38 |         token2id = self.tokenizer.convert_tokens_to_ids(token)
39 |         input_ids = np.array(token2id)
40 |         mask = [0] * token_len
41 |         mask = np.array(mask) + 1
42 |         mask_len = len(mask)
43 |         loss_mask = np.ones((mask_len, mask_len))
44 |
45 |         if not self.is_test:
46 |             s2po = {}
47 |             for spo in triple:
48 |                 triple_tuple = (list(spo[0]), spo[1], list(spo[2]))
49 |                 sub_head = find_idx(token, triple_tuple[0])
50 |                 obj_head = find_idx(token, triple_tuple[2])
51 |                 if sub_head != -1 and obj_head != -1:
52 |                     sub = (sub_head, sub_head + len(triple_tuple[0]) - 1)
53 |                     obj = (obj_head, obj_head + len(triple_tuple[2]) - 1, self.rel2id[triple_tuple[1]])
54 |                     if sub not in s2po:
55 |                         s2po[sub] = []
56 |                     s2po[sub].append(obj)
57 |
58 |             if len(s2po) > 0:
59 |
60 |                 matrix = np.zeros((self.config.num_rel, token_len, token_len))
61 |                 for sub in s2po:
62 |                     sub_head = sub[0]
63 |                     sub_tail = sub[1]
64 |                     for obj in s2po.get((sub_head, sub_tail), []):
65 |                         obj_head, obj_tail, rel = obj
66 |                         matrix[rel][sub_head][obj_head] = self.tag2id["HB-TB"]
67 |                         matrix[rel][sub_head][obj_tail] = self.tag2id["HB-TE"]
68 |                         matrix[rel][sub_tail][obj_tail] = self.tag2id["HE-TE"]
69 |                 return sentence, triple, input_ids, mask, token_len, matrix, token, loss_mask
70 |             else:
71 |                 return None
72 |         else:
73 |             # token2id = self.tokenizer.convert_tokens_to_ids(token)
74 |             # input_ids = np.array(token2id)
75 |             # mask = [0] * token_len
76 |             # mask = np.array(mask) + 1
77 |             # mask_len = len(mask)
78 |             # loss_mask = np.ones((mask_len, mask_len))
79 |             matrix = np.zeros((self.config.num_rel, token_len, token_len))
80 |             return sentence, triple, input_ids, mask, token_len, matrix, token, loss_mask
81 |
82 | def collate_fn(batch):
83 |
84 |     batch = list(filter(lambda x: x is not None, batch))
85 |     batch.sort(key=lambda x: x[4], reverse=True)
86 |     sentence, triple, input_ids, mask, token_len, matrix, token, loss_mask = zip(*batch)
87 |
88 |     cur_batch = len(batch)
89 |     max_token_len = max(token_len)
90 |
91 |     batch_input_ids = torch.LongTensor(cur_batch, max_token_len).zero_()
92 |     batch_attention_mask = torch.LongTensor(cur_batch, max_token_len).zero_()
93 |     batch_loss_mask = torch.LongTensor(cur_batch, 1, max_token_len, max_token_len).zero_()
94 |     # 53 here is the number of relation types, i.e. it must match num_rel in config.py
95 |     batch_matrix = torch.LongTensor(cur_batch, 53, max_token_len, max_token_len).zero_()
96 |
97 |     for i in range(cur_batch):
98 |         batch_input_ids[i, :token_len[i]].copy_(torch.from_numpy(input_ids[i]))
99 |         batch_attention_mask[i, :token_len[i]].copy_(torch.from_numpy(mask[i]))
100 |         batch_loss_mask[i, 0, :token_len[i], :token_len[i]].copy_(torch.from_numpy(loss_mask[i]))
101 |         batch_matrix[i, :, :token_len[i], :token_len[i]].copy_(torch.from_numpy(matrix[i]))
102 |
103 |     return {"sentence": sentence,
104 |             "token": token,
105 |             "triple": triple,
106 |             "input_ids": batch_input_ids,
107 |             "attention_mask": batch_attention_mask,
108 |             "loss_mask": batch_loss_mask,
109 |             "matrix": batch_matrix}
110 |
111 |
-------------------------------------------------------------------------------- /dataset/CIE/schema.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "疾病/预防/其他": 0, 4 | "疾病/阶段/其他": 1, 5 | "疾病/就诊科室/其他": 2, 6 | "其他/同义词/其他": 3, 7 | "疾病/辅助治疗/其他治疗": 4, 8 | "疾病/化疗/其他治疗": 5, 9 | "疾病/放射治疗/其他治疗": 6, 10 | "其他治疗/同义词/其他治疗": 7, 11 | "疾病/手术治疗/手术治疗": 8, 12 | "手术治疗/同义词/手术治疗": 9, 13 | "疾病/实验室检查/检查": 10, 14 | "疾病/影像学检查/检查": 11, 15 |
"疾病/辅助检查/检查": 12, 16 | "疾病/组织学检查/检查": 13, 17 | "检查/同义词/检查": 14, 18 | "疾病/内窥镜检查/检查": 15, 19 | "疾病/筛查/检查": 16, 20 | "疾病/多发群体/流行病学": 17, 21 | "疾病/发病率/流行病学": 18, 22 | "疾病/发病年龄/流行病学": 19, 23 | "疾病/多发地区/流行病学": 20, 24 | "疾病/发病性别倾向/流行病学": 21, 25 | "疾病/死亡率/流行病学": 22, 26 | "疾病/多发季节/流行病学": 23, 27 | "疾病/传播途径/流行病学": 24, 28 | "流行病学/同义词/流行病学": 25, 29 | "疾病/同义词/疾病": 26, 30 | "疾病/并发症/疾病": 27, 31 | "疾病/病理分型/疾病": 28, 32 | "疾病/相关(导致)/疾病": 29, 33 | "疾病/鉴别诊断/疾病": 30, 34 | "疾病/相关(转化)/疾病": 31, 35 | "疾病/相关(症状)/疾病": 32, 36 | "疾病/临床表现/症状": 33, 37 | "疾病/治疗后症状/症状": 34, 38 | "疾病/侵及周围组织转移的症状/症状": 35, 39 | "症状/同义词/症状": 36, 40 | "疾病/病因/社会学": 37, 41 | "疾病/高危因素/社会学": 38, 42 | "疾病/风险评估因素/社会学": 39, 43 | "疾病/病史/社会学": 40, 44 | "疾病/遗传因素/社会学": 41, 45 | "社会学/同义词/社会学": 42, 46 | "疾病/发病机制/社会学": 43, 47 | "疾病/病理生理/社会学": 44, 48 | "疾病/药物治疗/药物": 45, 49 | "药物/同义词/药物": 46, 50 | "疾病/发病部位/部位": 47, 51 | "疾病/转移部位/部位": 48, 52 | "疾病/外侵部位/部位": 49, 53 | "部位/同义词/部位": 50, 54 | "疾病/预后状况/预后": 51, 55 | "疾病/预后生存率/预后": 52 56 | }, 57 | { 58 | "0": "疾病/预防/其他", 59 | "1": "疾病/阶段/其他", 60 | "2": "疾病/就诊科室/其他", 61 | "3": "其他/同义词/其他", 62 | "4": "疾病/辅助治疗/其他治疗", 63 | "5": "疾病/化疗/其他治疗", 64 | "6": "疾病/放射治疗/其他治疗", 65 | "7": "其他治疗/同义词/其他治疗", 66 | "8": "疾病/手术治疗/手术治疗", 67 | "9": "手术治疗/同义词/手术治疗", 68 | "10": "疾病/实验室检查/检查", 69 | "11": "疾病/影像学检查/检查", 70 | "12": "疾病/辅助检查/检查", 71 | "13": "疾病/组织学检查/检查", 72 | "14": "检查/同义词/检查", 73 | "15": "疾病/内窥镜检查/检查", 74 | "16": "疾病/筛查/检查", 75 | "17": "疾病/多发群体/流行病学", 76 | "18": "疾病/发病率/流行病学", 77 | "19": "疾病/发病年龄/流行病学", 78 | "20": "疾病/多发地区/流行病学", 79 | "21": "疾病/发病性别倾向/流行病学", 80 | "22": "疾病/死亡率/流行病学", 81 | "23": "疾病/多发季节/流行病学", 82 | "24": "疾病/传播途径/流行病学", 83 | "25": "流行病学/同义词/流行病学", 84 | "26": "疾病/同义词/疾病", 85 | "27": "疾病/并发症/疾病", 86 | "28": "疾病/病理分型/疾病", 87 | "29": "疾病/相关(导致)/疾病", 88 | "30": "疾病/鉴别诊断/疾病", 89 | "31": "疾病/相关(转化)/疾病", 90 | "32": "疾病/相关(症状)/疾病", 91 | "33": "疾病/临床表现/症状", 92 | "34": "疾病/治疗后症状/症状", 93 | "35": "疾病/侵及周围组织转移的症状/症状", 94 | "36": "症状/同义词/症状", 95 | "37": "疾病/病因/社会学", 96 | "38": "疾病/高危因素/社会学", 97 | "39": "疾病/风险评估因素/社会学", 98 | "40": "疾病/病史/社会学", 99 | "41": "疾病/遗传因素/社会学", 100 | "42": "社会学/同义词/社会学", 101 | "43": "疾病/发病机制/社会学", 102 | "44": "疾病/病理生理/社会学", 103 | "45": "疾病/药物治疗/药物", 104 | "46": "药物/同义词/药物", 105 | "47": "疾病/发病部位/部位", 106 | "48": "疾病/转移部位/部位", 107 | "49": "疾病/外侵部位/部位", 108 | "50": "部位/同义词/部位", 109 | "51": "疾病/预后状况/预后", 110 | "52": "疾病/预后生存率/预后" 111 | } 112 | ] -------------------------------------------------------------------------------- /dataset/tag2id.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "0": "A", 4 | "1": "HB-TB", 5 | "2": "HB-TE", 6 | "3": "HE-TE" 7 | }, 8 | { 9 | "A": 0, 10 | "HB-TB": 1, 11 | "HB-TE": 2, 12 | "HE-TE": 3 13 | } 14 | ] -------------------------------------------------------------------------------- /framework/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/framework/__init__.py -------------------------------------------------------------------------------- /framework/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/framework/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /framework/__pycache__/framework.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/framework/__pycache__/framework.cpython-38.pyc -------------------------------------------------------------------------------- /framework/framework.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | 8 | from dataloader.dataloader import REDataset, collate_fn 9 | from models.models import OneRel 10 | 11 | 12 | class Framework(): 13 | def __init__(self, config): 14 | self.config = config 15 | with open(self.config.tags, "r", encoding="utf-8") as f: 16 | self.tag2id = json.load(f)[1] 17 | with open(self.config.schema_fn, "r", encoding="utf-8") as fs: 18 | self.id2rel = json.load(fs)[1] 19 | self.loss_function = torch.nn.CrossEntropyLoss(reduction="none") 20 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 21 | 22 | def train(self): 23 | 24 | def cal_loss(predict, target, mask): 25 | loss_ = self.loss_function(predict, target) 26 | loss = torch.sum(loss_ * mask) / torch.sum(mask) 27 | return loss 28 | 29 | train_dataset = REDataset(self.config, self.config.train_file) 30 | train_dataloader = DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True, collate_fn=collate_fn) 31 | 32 | dev_dataset = REDataset(self.config, self.config.dev_file) 33 | dev_dataloader = DataLoader(dev_dataset, batch_size=1, collate_fn=collate_fn) 34 | 35 | model = OneRel(self.config).to(self.device) 36 | optimizer = torch.optim.AdamW(model.parameters(), lr=self.config.learning_rate) 37 | 38 | global_step = 0 39 | global_loss = 0 40 | best_epoch = 0 41 | 42 | best_f1_score = 0 43 | best_recall = 0 44 | best_precision = 0 45 | 46 | for epoch in range(self.config.epochs): 47 | print("[{}/{}]".format(epoch+1, self.config.epochs)) 48 | for data in tqdm(train_dataloader): 49 | output = model(data) 50 | 51 | optimizer.zero_grad() 52 | loss = cal_loss(output, data["matrix"].to(self.device), data["loss_mask"].to(self.device)) 53 | global_loss += loss.item() 54 | 55 | loss.backward() 56 | optimizer.step() 57 | if (global_step + 1) % 2000 == 0: 58 | print("epoch: {} global_step: {} global_loss: {:5.4f}".format(epoch + 1, global_step + 1, global_loss)) 59 | global_loss = 0 60 | 61 | if (epoch + 1) % 5 == 0: 62 | precision, recall, f1_score, predict = self.evaluate(dev_dataloader, model) 63 | if f1_score > best_f1_score: 64 | best_f1_score = f1_score 65 | best_recall = recall 66 | best_precision = precision 67 | best_epoch = epoch + 1 68 | print("save model ......") 69 | torch.save(model.state_dict(), self.config.checkpoint) 70 | json.dump(predict, open(self.config.dev_result, "w", encoding="utf-8"), indent=4, ensure_ascii=False) 71 | print("epoch:{} best_epoch:{} best_recall:{:5.4f} best_precision:{:5.4f} best_f1_score:{:5.4f}".format(epoch+1, best_epoch, best_recall, best_precision, best_f1_score)) 72 | print("best_epoch:{} best_recall:{:5.4f} best_precision:{:5.4f} best_f1_score:{:5.4f}".format(best_epoch, best_recall, best_precision, best_f1_score)) 73 | 74 | def evaluate(self, dataloader, model): 75 | print("eval mode......") 76 | model.eval() 77 | predict_num, gold_num, correct_num = 0, 0, 0 78 | predict = [] 79 | def to_ret(data): 80 | ret = [] 81 | for i in data: 82 | ret.append(tuple(i)) 83 | return tuple(ret) 84 | 85 | with 
torch.no_grad(): 86 | for data in tqdm(dataloader): 87 | # [num_rel, seq_len, seq_len] 88 | pred_triple_matrix = model(data, train=False).cpu()[0] 89 | number_rel, seq_lens, seq_lens = pred_triple_matrix.shape 90 | relations, heads, tails = np.where(pred_triple_matrix > 0) 91 | 92 | token = data["token"][0] 93 | gold = data["triple"][0] 94 | pair_numbers = len(relations) 95 | predict_triple = [] 96 | if pair_numbers > 0: 97 | for i in range(pair_numbers): 98 | r_index = relations[i] 99 | h_start_idx = heads[i] 100 | t_start_idx = tails[i] 101 | if pred_triple_matrix[r_index][h_start_idx][t_start_idx] == self.tag2id["HB-TB"] and i + 1 < pair_numbers: 102 | t_end_idx = tails[i + 1] 103 | if pred_triple_matrix[r_index][h_start_idx][t_end_idx] == self.tag2id["HB-TE"]: 104 | for h_end_index in range(h_start_idx, seq_lens): 105 | if pred_triple_matrix[r_index][h_end_index][t_end_idx] == self.tag2id["HE-TE"]: 106 | 107 | subject_head, subject_tail = h_start_idx, h_end_index 108 | object_head, object_tail = t_start_idx, t_end_idx 109 | subject = ''.join(token[subject_head: subject_tail + 1]) 110 | object = ''.join(token[object_head: object_tail + 1]) 111 | relation = self.id2rel[str(int(r_index))] 112 | if len(subject) > 0 and len(object) > 0: 113 | predict_triple.append((subject, relation, object)) 114 | break 115 | gold = to_ret(gold) 116 | predict_triple = to_ret(predict_triple) 117 | gold_num += len(gold) 118 | predict_num += len(predict_triple) 119 | correct_num += len(set(gold) & set(predict_triple)) 120 | lack = set(gold) - set(predict_triple) 121 | new = set(predict_triple) - set(gold) 122 | predict.append({"text": data["sentence"][0], "gold": gold, "predict": predict_triple, 123 | "lack": list(lack), "new": list(new)}) 124 | 125 | precision = correct_num / (predict_num + 1e-10) 126 | recall = correct_num / (gold_num + 1e-10) 127 | f1_score = 2 * precision * recall / (precision + recall + 1e-10) 128 | print("predict_num: {} gold_num: {} correct_num: {}".format(predict_num, gold_num, correct_num)) 129 | model.train() 130 | return precision, recall, f1_score, predict -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from framework.framework import Framework 2 | from config.config import Config 3 | import torch 4 | import numpy as np 5 | 6 | seed = 1234 7 | torch.manual_seed(seed) 8 | torch.cuda.manual_seed(seed) 9 | np.random.seed(seed) 10 | torch.backends.cudnn.deterministic = True 11 | torch.backends.cudnn.benchmark = False 12 | 13 | config = Config() 14 | fw = Framework(config) 15 | fw.train() 16 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/models/__init__.py -------------------------------------------------------------------------------- /models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /models/__pycache__/models.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/models/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import BertModel
4 |
5 |
6 | class OneRel(nn.Module):
7 |     def __init__(self, config):
8 |         super(OneRel, self).__init__()
9 |         self.config = config
10 |         self.bert = BertModel.from_pretrained(self.config.bert_path)
11 |         self.relation_linear = nn.Linear(self.config.bert_dim * 3, self.config.num_rel * self.config.tag_size)
12 |         self.project_matrix = nn.Linear(self.config.bert_dim * 2, self.config.bert_dim * 3)
13 |         self.dropout = nn.Dropout(0.2)
14 |         self.dropout_2 = nn.Dropout(0.1)
15 |         self.activation = nn.ReLU()
16 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 |
18 |     def get_encoded_text(self, input_ids, mask):
19 |         bert_encoded_text = self.bert(input_ids=input_ids, attention_mask=mask)[0]
20 |         return bert_encoded_text
21 |
22 |     def get_triple_score(self, bert_encoded_text, train):
23 |         batch_size, seq_len, bert_dim = bert_encoded_text.size()
24 |         # [batch_size, seq_len*seq_len, bert_dim]
25 |         head_rep = bert_encoded_text.unsqueeze(dim=2).expand(batch_size, seq_len, seq_len, bert_dim).reshape(batch_size, seq_len * seq_len, bert_dim)
26 |         tail_rep = bert_encoded_text.repeat(1, seq_len, 1)
27 |         # [batch_size, seq_len*seq_len, bert_dim * 2]
28 |         entity_pair = torch.cat([head_rep, tail_rep], dim=-1)
29 |
30 |         # [batch_size, seq_len*seq_len, bert_dim * 3]
31 |         entity_pair = self.project_matrix(entity_pair)
32 |         entity_pair = self.dropout_2(entity_pair)
33 |         entity_pair = self.activation(entity_pair)
34 |
35 |         # [batch_size, seq_len*seq_len, num_rel*tag_size]
36 |         matrix_score = self.relation_linear(entity_pair).reshape(batch_size, seq_len, seq_len, self.config.num_rel, self.config.tag_size)
37 |         if train:
38 |             return matrix_score.permute(0, 4, 3, 1, 2)
39 |         else:
40 |             return matrix_score.argmax(dim=-1).permute(0, 3, 1, 2)
41 |
42 |     def forward(self, data, train=True):
43 |         input_ids = data["input_ids"].to(self.device)
44 |         attention_mask = data["attention_mask"].to(self.device)
45 |
46 |         bert_encoded_text = self.get_encoded_text(input_ids, attention_mask)
47 |         bert_encoded_text = self.dropout(bert_encoded_text)
48 |
49 |         matrix_score = self.get_triple_score(bert_encoded_text, train)
50 |
51 |         return matrix_score
--------------------------------------------------------------------------------
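get_triple_score builds every (head token, tail token) pair by expanding the encoder output twice and concatenating, so the pair tensor grows with seq_len squared; the scores come back as [batch, tag_size, num_rel, seq_len, seq_len] in training mode and as an argmax'd [batch, num_rel, seq_len, seq_len] tag matrix at inference. A standalone sketch of just that reshaping, with random tensors in place of BERT and toy sizes chosen for illustration:

```python
import torch

# Toy sizes standing in for the real ones (bert_dim=768, num_rel=53, tag_size=4).
batch_size, seq_len, bert_dim, num_rel, tag_size = 2, 5, 8, 3, 4
encoded = torch.randn(batch_size, seq_len, bert_dim)  # stand-in for the BERT output

# Same expansion as models.py: head_rep repeats each token seq_len times in a row,
# tail_rep tiles the whole sequence, so row-major index i*seq_len + j is the pair (i, j).
head_rep = encoded.unsqueeze(2).expand(batch_size, seq_len, seq_len, bert_dim) \
                  .reshape(batch_size, seq_len * seq_len, bert_dim)
tail_rep = encoded.repeat(1, seq_len, 1)
entity_pair = torch.cat([head_rep, tail_rep], dim=-1)   # [batch, seq_len*seq_len, 2*bert_dim]

# A stand-in scoring head with the same output layout as relation_linear.
scores = torch.randn(batch_size, seq_len * seq_len, num_rel * tag_size)
scores = scores.reshape(batch_size, seq_len, seq_len, num_rel, tag_size)

train_view = scores.permute(0, 4, 3, 1, 2)              # [batch, tag_size, num_rel, len, len] for CrossEntropyLoss
eval_view = scores.argmax(dim=-1).permute(0, 3, 1, 2)   # [batch, num_rel, len, len] tag ids
print(entity_pair.shape, train_view.shape, eval_view.shape)
# torch.Size([2, 25, 16]) torch.Size([2, 4, 3, 5, 5]) torch.Size([2, 3, 5, 5])
```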
/test.py:
--------------------------------------------------------------------------------
1 | from models.models import OneRel
2 | import torch
3 | from config.config import Config
4 | import numpy as np
5 | import json
6 | from transformers import BertTokenizer
7 |
8 | seed = 1234
9 | torch.manual_seed(seed)
10 | torch.cuda.manual_seed(seed)
11 | np.random.seed(seed)
12 | torch.backends.cudnn.deterministic = True
13 | torch.backends.cudnn.benchmark = False
14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15 | config = Config()
16 |
17 | id2label = json.load(open(config.schema_fn, "r", encoding="utf-8"))[1]
18 | id2tag = json.load(open(config.tags, "r", encoding="utf-8"))[1]
19 |
20 | tokenizer = BertTokenizer.from_pretrained(config.bert_path)
21 | model = OneRel(config)
22 | model.load_state_dict(torch.load(config.checkpoint, map_location=device))
23 | model.to(device)
24 | model.eval()
25 |
26 |
27 | def parser(sentence):
28 |     token = ['[CLS]'] + list(sentence)[:510] + ['[SEP]']
29 |     token2id = [tokenizer.convert_tokens_to_ids(token)]
30 |     mask = [[1]*len(token)]
31 |     data = {"input_ids": torch.LongTensor(token2id), "attention_mask": torch.LongTensor(mask)}
32 |     # [num_rel, seq_len, seq_len]
33 |     output = model(data, False).cpu()[0]
34 |     num_rel, seq_lens, seq_lens = output.shape
35 |
36 |     relations, heads, tails = np.where(output > 0)
37 |
38 |     predict = []
39 |     relation_num = len(relations)
40 |     predict = {"text": sentence, "predict": []}
41 |     if relation_num > 0:
42 |         for r in range(relation_num):
43 |             rel2indx = relations[r]
44 |             h_start_index = heads[r]
45 |             t_start_index = tails[r]
46 |             if output[rel2indx][h_start_index][t_start_index] == id2tag["HB-TB"] and r + 1 < relation_num:
47 |                 t_end_index = tails[r + 1]
48 |                 if output[rel2indx][h_start_index][t_end_index] == id2tag["HB-TE"]:
49 |                     for h_end_index in range(h_start_index, seq_lens):
50 |                         if output[rel2indx][h_end_index][t_end_index] == id2tag["HE-TE"]:
51 |                             subject_head, subject_tail = h_start_index, h_end_index
52 |                             object_head, object_tail = t_start_index, t_end_index
53 |                             subject = "".join(token[subject_head: subject_tail+1])
54 |                             object = "".join(token[object_head: object_tail+1])
55 |                             relation = id2label[str(rel2indx)]
56 |                             if len(subject) > 0 and len(object) > 0:
57 |                                 predict["predict"].append((subject, relation, object))
58 |                             break
59 |     return predict
60 |
61 | if __name__ == '__main__':
62 |     while True:
63 |         sentence = input("Enter a sentence: ")
64 |         predict = parser(sentence)
65 |         print(json.dumps(predict, indent=4, ensure_ascii=False))
--------------------------------------------------------------------------------
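Both framework.evaluate and test.parser decode the predicted tag matrix the same way: find a cell tagged HB-TB (subject head, object head), take the column of the next non-zero cell as the candidate object tail and confirm it is tagged HB-TE for the same subject head, then scan down the rows for HE-TE to recover the subject tail. A toy check of that decoding loop on a hand-filled matrix; the sizes and span positions below are made up for illustration:

```python
import numpy as np

# Tag ids as in dataset/tag2id.json.
A, HB_TB, HB_TE, HE_TE = 0, 1, 2, 3
num_rel, seq_len = 2, 7

# Pretend the model predicted one triple under relation 1:
# subject spans tokens 2..3, object spans tokens 5..6.
matrix = np.zeros((num_rel, seq_len, seq_len), dtype=int)
matrix[1, 2, 5] = HB_TB   # (sub_head, obj_head)
matrix[1, 2, 6] = HB_TE   # (sub_head, obj_tail)
matrix[1, 3, 6] = HE_TE   # (sub_tail, obj_tail)

# Same decoding logic as in framework.py / test.py.
relations, heads, tails = np.where(matrix > 0)
triples = []
for i in range(len(relations)):
    r, h_start, t_start = relations[i], heads[i], tails[i]
    if matrix[r, h_start, t_start] == HB_TB and i + 1 < len(relations):
        t_end = tails[i + 1]
        if matrix[r, h_start, t_end] == HB_TE:
            for h_end in range(h_start, seq_len):
                if matrix[r, h_end, t_end] == HE_TE:
                    triples.append((int(r), (int(h_start), int(h_end)), (int(t_start), int(t_end))))
                    break
print(triples)  # [(1, (2, 3), (5, 6))]
```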