├── .idea
│   ├── .gitignore
│   ├── OneRel.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── bert-base-chinese
│   └── placeholder.txt
├── config
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── config.cpython-38.pyc
│   └── config.py
├── dataloader
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── dataloader.cpython-38.pyc
│   └── dataloader.py
├── dataset
│   ├── CIE
│   │   ├── dev_data.json
│   │   ├── schema.json
│   │   └── train_data.json
│   └── tag2id.json
├── dev_result
│   └── dev_result.json
├── framework
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── framework.cpython-38.pyc
│   └── framework.py
├── main.py
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── models.cpython-38.pyc
│   └── models.py
└── test.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/.idea/OneRel.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# OneRel
Chinese relation extraction.

A reimplementation of the AAAI 2022 paper "OneRel: Joint Entity and Relation Extraction with One Module in One Step".

## Environment
python==3.8
torch==1.8.1
transformers==4.3.1

## Usage
1. Model hyperparameters are defined in the configuration file config.py.
2. Prepare data in the same format as the files under dataset/ (a sketch of the expected record format follows below).
3. Run `python main.py` to train the model.
4. Run `python test.py` to test single sentences interactively.

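Each data file is a JSON list of records with a `text` field and a `spo_list` of `[subject, predicate, object]` triples, as read by `dataloader.py`. A minimal sketch of one record follows; the sentence and triple are hypothetical, and the predicate must be a key of dataset/CIE/schema.json:

```python
# Hypothetical example of a single record in train_data.json / dev_data.json
record = {
    "text": "胃癌可行手术治疗",                          # raw sentence
    "spo_list": [
        ["胃癌", "疾病/手术治疗/手术治疗", "手术治疗"]    # [subject, predicate, object]
    ],
}
```
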
## Results
The model was trained on the CCKS 2020 Chinese medical entity and relation extraction data; after 100 training epochs the F1 score on the dev set is about 60.78%.

## Model comparison
CasRel: f1 = 51.28%
GPLinker: f1 = 59.15%
OneRel: f1 = 60.78%
--------------------------------------------------------------------------------
/bert-base-chinese/placeholder.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/bert-base-chinese/placeholder.txt
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/config/__init__.py
--------------------------------------------------------------------------------
/config/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/config/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/config/__pycache__/config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/config/__pycache__/config.cpython-38.pyc
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 |
2 | class Config():
3 | def __init__(self):
4 | self.num_rel = 53  # number of relation types; must match dataset/CIE/schema.json
5 | # self.rel_num = 53
6 | self.train_file = "./dataset/CIE/train_data.json"
7 | self.dev_file = "./dataset/CIE/dev_data.json"
8 | self.schema_fn = "./dataset/CIE/schema.json"
9 | self.bert_path = "./bert-base-chinese"
10 | self.tags = "./dataset/tag2id.json"
11 | self.bert_dim = 768
12 | self.tag_size = 4
13 | self.batch_size = 4
14 | self.max_len = 510  # leaves room for [CLS] and [SEP] within BERT's 512-token limit
15 | self.learning_rate = 1e-5
16 | self.epochs = 100
17 | self.checkpoint = "checkpoint/OneRel_self.pt"  # the checkpoint/ directory must exist before training
18 | self.dev_result = "dev_result/dev_result.json"
19 | self.dropout_prob = 0.1
20 | self.entity_pair_dropout = 0.2
21 |
--------------------------------------------------------------------------------
/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/dataloader/__init__.py
--------------------------------------------------------------------------------
/dataloader/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/dataloader/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloader/__pycache__/dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/dataloader/__pycache__/dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloader/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import json
4 | from transformers import BertTokenizer
5 | import numpy as np
6 |
7 |
8 | def find_idx(token, target):  # first index where the character list target occurs in token, or -1 if absent
9 | target_length = len(target)
10 | for k in range(len(token)):
11 | if token[k: k + target_length] == target:
12 | return k
13 | return -1
14 |
15 |
16 | class REDataset(Dataset):
17 | def __init__(self, config, file, is_test=False):
18 | self.config = config
19 | with open(file, "r", encoding="utf-8") as f:
20 | self.data = json.load(f)
21 | with open(self.config.schema_fn, "r", encoding="utf-8") as fs:
22 | self.rel2id = json.load(fs)[0]
23 | with open(self.config.tags) as ft:
24 | self.tag2id = json.load(ft)[1]
25 | self.tokenizer = BertTokenizer.from_pretrained(self.config.bert_path)
26 | self.is_test = is_test
27 |
28 | def __len__(self):
29 | return len(self.data)
30 |
31 | def __getitem__(self, idx):
32 | ins_json_data = self.data[idx]
33 | sentence = ins_json_data["text"]
34 | triple = ins_json_data["spo_list"]
35 | token = ['[CLS]'] + list(sentence)[:self.config.max_len] + ['[SEP]']
36 | token_len = len(token)
37 |
38 | token2id = self.tokenizer.convert_tokens_to_ids(token)
39 | input_ids = np.array(token2id)
40 | # attention mask: all ones here, padding is added later in collate_fn
41 | mask = np.ones(token_len, dtype=np.int64)
42 | mask_len = len(mask)
43 | loss_mask = np.ones((mask_len, mask_len))
44 |
45 | if not self.is_test:
46 | s2po = {}
47 | for spo in triple:
48 | triple_tuple = (list(spo[0]), spo[1], list(spo[2]))
49 | sub_head = find_idx(token, triple_tuple[0])
50 | obj_head = find_idx(token, triple_tuple[2])
51 | if sub_head != -1 and obj_head != -1:
52 | sub = (sub_head, sub_head + len(triple_tuple[0]) - 1)
53 | obj = (obj_head, obj_head + len(triple_tuple[2]) - 1, self.rel2id[triple_tuple[1]])
54 | if sub not in s2po:
55 | s2po[sub] = []
56 | s2po[sub].append(obj)
57 |
58 | if len(s2po) > 0:
59 |
60 | matrix = np.zeros((self.config.num_rel, token_len, token_len))
61 | for sub in s2po:
62 | sub_head = sub[0]
63 | sub_tail = sub[1]
64 | for obj in s2po.get((sub_head, sub_tail), []):
65 | obj_head, obj_tail, rel = obj
66 | matrix[rel][sub_head][obj_head] = self.tag2id["HB-TB"]  # subject-begin row, object-begin column
67 | matrix[rel][sub_head][obj_tail] = self.tag2id["HB-TE"]  # subject-begin row, object-end column
68 | matrix[rel][sub_tail][obj_tail] = self.tag2id["HE-TE"]  # subject-end row, object-end column
69 | return sentence, triple, input_ids, mask, token_len, matrix, token, loss_mask
70 | else:
71 | return None
72 | else:
73 | # token2id = self.tokenizer.convert_tokens_to_ids(token)
74 | # input_ids = np.array(token2id)
75 | # mask = [0] * token_len
76 | # mask = np.array(mask) + 1
77 | # mask_len = len(mask)
78 | # loss_mask = np.ones((mask_len, mask_len))
79 | matrix = np.zeros((self.config.num_rel, token_len, token_len))
80 | return sentence, triple, input_ids, mask, token_len, matrix, token, loss_mask
81 |
82 | def collate_fn(batch):
83 |
84 | batch = list(filter(lambda x: x is not None, batch))
85 | batch.sort(key=lambda x: x[4], reverse=True)
86 | sentence, triple, input_ids, mask, token_len, matrix, token, loss_mask = zip(*batch)
87 |
88 | cur_batch = len(batch)
89 | max_token_len = max(token_len)
90 |
91 | batch_input_ids = torch.LongTensor(cur_batch, max_token_len).zero_()
92 | batch_attention_mask = torch.LongTensor(cur_batch, max_token_len).zero_()
93 | batch_loss_mask = torch.LongTensor(cur_batch, 1, max_token_len, max_token_len).zero_()
94 | # infer the relation count from the sample matrices; it equals num_rel (53 here) in config.py
95 | batch_matrix = torch.LongTensor(cur_batch, matrix[0].shape[0], max_token_len, max_token_len).zero_()
96 |
97 | for i in range(cur_batch):
98 | batch_input_ids[i, :token_len[i]].copy_(torch.from_numpy(input_ids[i]))
99 | batch_attention_mask[i, :token_len[i]].copy_(torch.from_numpy(mask[i]))
100 | batch_loss_mask[i, 0, :token_len[i], :token_len[i]].copy_(torch.from_numpy(loss_mask[i]))
101 | batch_matrix[i, :, :token_len[i], :token_len[i]].copy_(torch.from_numpy(matrix[i]))
102 |
103 | return {"sentence": sentence,
104 | "token": token,
105 | "triple": triple,
106 | "input_ids": batch_input_ids,
107 | "attention_mask": batch_attention_mask,
108 | "loss_mask": batch_loss_mask,
109 | "matrix": batch_matrix}
110 |
111 |
--------------------------------------------------------------------------------
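For reference, a minimal sketch of how `REDataset.__getitem__` in dataloader.py above encodes one (subject, relation, object) triple into the `[num_rel, seq_len, seq_len]` tag matrix. The sentence, indices, and relation id are hypothetical toy values; the tag ids come from dataset/tag2id.json:

```python
import numpy as np

tag2id = {"A": 0, "HB-TB": 1, "HB-TE": 2, "HE-TE": 3}   # from dataset/tag2id.json

# Hypothetical tokenized sentence: [CLS] 胃 癌 可 行 手 术 治 疗 [SEP]
token = ['[CLS]', '胃', '癌', '可', '行', '手', '术', '治', '疗', '[SEP]']
num_rel = 53
rel = 8                                    # "疾病/手术治疗/手术治疗" in dataset/CIE/schema.json

sub_head, sub_tail = 1, 2                  # subject "胃癌"    -> token[1:3]
obj_head, obj_tail = 5, 8                  # object "手术治疗" -> token[5:9]

matrix = np.zeros((num_rel, len(token), len(token)), dtype=np.int64)
matrix[rel][sub_head][obj_head] = tag2id["HB-TB"]   # subject begin / object begin
matrix[rel][sub_head][obj_tail] = tag2id["HB-TE"]   # subject begin / object end
matrix[rel][sub_tail][obj_tail] = tag2id["HE-TE"]   # subject end / object end
# All other cells keep tag "A" (0); the model is trained to predict this matrix.
```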
/dataset/CIE/schema.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "疾病/预防/其他": 0,
4 | "疾病/阶段/其他": 1,
5 | "疾病/就诊科室/其他": 2,
6 | "其他/同义词/其他": 3,
7 | "疾病/辅助治疗/其他治疗": 4,
8 | "疾病/化疗/其他治疗": 5,
9 | "疾病/放射治疗/其他治疗": 6,
10 | "其他治疗/同义词/其他治疗": 7,
11 | "疾病/手术治疗/手术治疗": 8,
12 | "手术治疗/同义词/手术治疗": 9,
13 | "疾病/实验室检查/检查": 10,
14 | "疾病/影像学检查/检查": 11,
15 | "疾病/辅助检查/检查": 12,
16 | "疾病/组织学检查/检查": 13,
17 | "检查/同义词/检查": 14,
18 | "疾病/内窥镜检查/检查": 15,
19 | "疾病/筛查/检查": 16,
20 | "疾病/多发群体/流行病学": 17,
21 | "疾病/发病率/流行病学": 18,
22 | "疾病/发病年龄/流行病学": 19,
23 | "疾病/多发地区/流行病学": 20,
24 | "疾病/发病性别倾向/流行病学": 21,
25 | "疾病/死亡率/流行病学": 22,
26 | "疾病/多发季节/流行病学": 23,
27 | "疾病/传播途径/流行病学": 24,
28 | "流行病学/同义词/流行病学": 25,
29 | "疾病/同义词/疾病": 26,
30 | "疾病/并发症/疾病": 27,
31 | "疾病/病理分型/疾病": 28,
32 | "疾病/相关(导致)/疾病": 29,
33 | "疾病/鉴别诊断/疾病": 30,
34 | "疾病/相关(转化)/疾病": 31,
35 | "疾病/相关(症状)/疾病": 32,
36 | "疾病/临床表现/症状": 33,
37 | "疾病/治疗后症状/症状": 34,
38 | "疾病/侵及周围组织转移的症状/症状": 35,
39 | "症状/同义词/症状": 36,
40 | "疾病/病因/社会学": 37,
41 | "疾病/高危因素/社会学": 38,
42 | "疾病/风险评估因素/社会学": 39,
43 | "疾病/病史/社会学": 40,
44 | "疾病/遗传因素/社会学": 41,
45 | "社会学/同义词/社会学": 42,
46 | "疾病/发病机制/社会学": 43,
47 | "疾病/病理生理/社会学": 44,
48 | "疾病/药物治疗/药物": 45,
49 | "药物/同义词/药物": 46,
50 | "疾病/发病部位/部位": 47,
51 | "疾病/转移部位/部位": 48,
52 | "疾病/外侵部位/部位": 49,
53 | "部位/同义词/部位": 50,
54 | "疾病/预后状况/预后": 51,
55 | "疾病/预后生存率/预后": 52
56 | },
57 | {
58 | "0": "疾病/预防/其他",
59 | "1": "疾病/阶段/其他",
60 | "2": "疾病/就诊科室/其他",
61 | "3": "其他/同义词/其他",
62 | "4": "疾病/辅助治疗/其他治疗",
63 | "5": "疾病/化疗/其他治疗",
64 | "6": "疾病/放射治疗/其他治疗",
65 | "7": "其他治疗/同义词/其他治疗",
66 | "8": "疾病/手术治疗/手术治疗",
67 | "9": "手术治疗/同义词/手术治疗",
68 | "10": "疾病/实验室检查/检查",
69 | "11": "疾病/影像学检查/检查",
70 | "12": "疾病/辅助检查/检查",
71 | "13": "疾病/组织学检查/检查",
72 | "14": "检查/同义词/检查",
73 | "15": "疾病/内窥镜检查/检查",
74 | "16": "疾病/筛查/检查",
75 | "17": "疾病/多发群体/流行病学",
76 | "18": "疾病/发病率/流行病学",
77 | "19": "疾病/发病年龄/流行病学",
78 | "20": "疾病/多发地区/流行病学",
79 | "21": "疾病/发病性别倾向/流行病学",
80 | "22": "疾病/死亡率/流行病学",
81 | "23": "疾病/多发季节/流行病学",
82 | "24": "疾病/传播途径/流行病学",
83 | "25": "流行病学/同义词/流行病学",
84 | "26": "疾病/同义词/疾病",
85 | "27": "疾病/并发症/疾病",
86 | "28": "疾病/病理分型/疾病",
87 | "29": "疾病/相关(导致)/疾病",
88 | "30": "疾病/鉴别诊断/疾病",
89 | "31": "疾病/相关(转化)/疾病",
90 | "32": "疾病/相关(症状)/疾病",
91 | "33": "疾病/临床表现/症状",
92 | "34": "疾病/治疗后症状/症状",
93 | "35": "疾病/侵及周围组织转移的症状/症状",
94 | "36": "症状/同义词/症状",
95 | "37": "疾病/病因/社会学",
96 | "38": "疾病/高危因素/社会学",
97 | "39": "疾病/风险评估因素/社会学",
98 | "40": "疾病/病史/社会学",
99 | "41": "疾病/遗传因素/社会学",
100 | "42": "社会学/同义词/社会学",
101 | "43": "疾病/发病机制/社会学",
102 | "44": "疾病/病理生理/社会学",
103 | "45": "疾病/药物治疗/药物",
104 | "46": "药物/同义词/药物",
105 | "47": "疾病/发病部位/部位",
106 | "48": "疾病/转移部位/部位",
107 | "49": "疾病/外侵部位/部位",
108 | "50": "部位/同义词/部位",
109 | "51": "疾病/预后状况/预后",
110 | "52": "疾病/预后生存率/预后"
111 | }
112 | ]
--------------------------------------------------------------------------------
/dataset/tag2id.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "0": "A",
4 | "1": "HB-TB",
5 | "2": "HB-TE",
6 | "3": "HE-TE"
7 | },
8 | {
9 | "A": 0,
10 | "HB-TB": 1,
11 | "HB-TE": 2,
12 | "HE-TE": 3
13 | }
14 | ]
--------------------------------------------------------------------------------
/framework/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/framework/__init__.py
--------------------------------------------------------------------------------
/framework/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/framework/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/framework/__pycache__/framework.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/framework/__pycache__/framework.cpython-38.pyc
--------------------------------------------------------------------------------
/framework/framework.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from tqdm import tqdm
7 |
8 | from dataloader.dataloader import REDataset, collate_fn
9 | from models.models import OneRel
10 |
11 |
12 | class Framework():
13 | def __init__(self, config):
14 | self.config = config
15 | with open(self.config.tags, "r", encoding="utf-8") as f:
16 | self.tag2id = json.load(f)[1]
17 | with open(self.config.schema_fn, "r", encoding="utf-8") as fs:
18 | self.id2rel = json.load(fs)[1]
19 | self.loss_function = torch.nn.CrossEntropyLoss(reduction="none")
20 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
21 |
22 | def train(self):
23 |
24 | def cal_loss(predict, target, mask):  # predict: [B, tag_size, num_rel, L, L], target: [B, num_rel, L, L], mask: [B, 1, L, L]
25 | loss_ = self.loss_function(predict, target)
26 | loss = torch.sum(loss_ * mask) / torch.sum(mask)
27 | return loss
28 |
29 | train_dataset = REDataset(self.config, self.config.train_file)
30 | train_dataloader = DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True, collate_fn=collate_fn)
31 |
32 | dev_dataset = REDataset(self.config, self.config.dev_file)
33 | dev_dataloader = DataLoader(dev_dataset, batch_size=1, collate_fn=collate_fn)
34 |
35 | model = OneRel(self.config).to(self.device)
36 | optimizer = torch.optim.AdamW(model.parameters(), lr=self.config.learning_rate)
37 |
38 | global_step = 0
39 | global_loss = 0
40 | best_epoch = 0
41 |
42 | best_f1_score = 0
43 | best_recall = 0
44 | best_precision = 0
45 |
46 | for epoch in range(self.config.epochs):
47 | print("[{}/{}]".format(epoch+1, self.config.epochs))
48 | for data in tqdm(train_dataloader):
49 | output = model(data)
50 |
51 | optimizer.zero_grad()
52 | loss = cal_loss(output, data["matrix"].to(self.device), data["loss_mask"].to(self.device))
53 | global_loss += loss.item()
54 |
55 | loss.backward()
56 | optimizer.step()
57 | if (global_step + 1) % 2000 == 0:
58 | print("epoch: {} global_step: {} global_loss: {:5.4f}".format(epoch + 1, global_step + 1, global_loss))
59 | global_loss = 0
60 |
61 | if (epoch + 1) % 5 == 0:
62 | precision, recall, f1_score, predict = self.evaluate(dev_dataloader, model)
63 | if f1_score > best_f1_score:
64 | best_f1_score = f1_score
65 | best_recall = recall
66 | best_precision = precision
67 | best_epoch = epoch + 1
68 | print("save model ......")
69 | torch.save(model.state_dict(), self.config.checkpoint)
70 | json.dump(predict, open(self.config.dev_result, "w", encoding="utf-8"), indent=4, ensure_ascii=False)
71 | print("epoch:{} best_epoch:{} best_recall:{:5.4f} best_precision:{:5.4f} best_f1_score:{:5.4f}".format(epoch+1, best_epoch, best_recall, best_precision, best_f1_score))
72 | print("best_epoch:{} best_recall:{:5.4f} best_precision:{:5.4f} best_f1_score:{:5.4f}".format(best_epoch, best_recall, best_precision, best_f1_score))
73 |
74 | def evaluate(self, dataloader, model):
75 | print("eval mode......")
76 | model.eval()
77 | predict_num, gold_num, correct_num = 0, 0, 0
78 | predict = []
79 | def to_ret(data):
80 | ret = []
81 | for i in data:
82 | ret.append(tuple(i))
83 | return tuple(ret)
84 |
85 | with torch.no_grad():
86 | for data in tqdm(dataloader):
87 | # [num_rel, seq_len, seq_len]
88 | pred_triple_matrix = model(data, train=False).cpu()[0]
89 | _, seq_lens, _ = pred_triple_matrix.shape
90 | relations, heads, tails = np.where(pred_triple_matrix > 0)
91 |
92 | token = data["token"][0]
93 | gold = data["triple"][0]
94 | pair_numbers = len(relations)
95 | predict_triple = []
96 | if pair_numbers > 0:
97 | for i in range(pair_numbers):
98 | r_index = relations[i]
99 | h_start_idx = heads[i]
100 | t_start_idx = tails[i]
101 | if pred_triple_matrix[r_index][h_start_idx][t_start_idx] == self.tag2id["HB-TB"] and i + 1 < pair_numbers:  # subject-begin / object-begin cell
102 | t_end_idx = tails[i + 1]  # the next tagged cell in the same row should mark the object end
103 | if pred_triple_matrix[r_index][h_start_idx][t_end_idx] == self.tag2id["HB-TE"]:  # subject-begin / object-end cell
104 | for h_end_index in range(h_start_idx, seq_lens):
105 | if pred_triple_matrix[r_index][h_end_index][t_end_idx] == self.tag2id["HE-TE"]:  # subject-end / object-end cell completes the triple
106 |
107 | subject_head, subject_tail = h_start_idx, h_end_index
108 | object_head, object_tail = t_start_idx, t_end_idx
109 | subject = ''.join(token[subject_head: subject_tail + 1])
110 | object = ''.join(token[object_head: object_tail + 1])
111 | relation = self.id2rel[str(int(r_index))]
112 | if len(subject) > 0 and len(object) > 0:
113 | predict_triple.append((subject, relation, object))
114 | break
115 | gold = to_ret(gold)
116 | predict_triple = to_ret(predict_triple)
117 | gold_num += len(gold)
118 | predict_num += len(predict_triple)
119 | correct_num += len(set(gold) & set(predict_triple))
120 | lack = set(gold) - set(predict_triple)
121 | new = set(predict_triple) - set(gold)
122 | predict.append({"text": data["sentence"][0], "gold": gold, "predict": predict_triple,
123 | "lack": list(lack), "new": list(new)})
124 |
125 | precision = correct_num / (predict_num + 1e-10)
126 | recall = correct_num / (gold_num + 1e-10)
127 | f1_score = 2 * precision * recall / (precision + recall + 1e-10)
128 | print("predict_num: {} gold_num: {} correct_num: {}".format(predict_num, gold_num, correct_num))
129 | model.train()
130 | return precision, recall, f1_score, predict
--------------------------------------------------------------------------------
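As a companion to the encoding sketch after dataloader.py, a self-contained toy illustration of the cell-by-cell decoding performed in `Framework.evaluate` above (the matrix, sentence, and relation subset are hypothetical):

```python
import numpy as np

tag2id = {"A": 0, "HB-TB": 1, "HB-TE": 2, "HE-TE": 3}
id2rel = {"8": "疾病/手术治疗/手术治疗"}     # subset of dataset/CIE/schema.json for illustration
token = ['[CLS]', '胃', '癌', '可', '行', '手', '术', '治', '疗', '[SEP]']

# Toy predicted matrix containing a single triple (the same three cells as the encoding sketch).
matrix = np.zeros((53, len(token), len(token)), dtype=np.int64)
matrix[8][1][5], matrix[8][1][8], matrix[8][2][8] = 1, 2, 3

relations, heads, tails = np.where(matrix > 0)        # tagged cells, sorted by (rel, head, tail)
triples = []
for i in range(len(relations)):
    r, h, t = relations[i], heads[i], tails[i]
    if matrix[r][h][t] == tag2id["HB-TB"] and i + 1 < len(relations):
        t_end = tails[i + 1]                          # next tagged cell in the same row marks the object end
        if matrix[r][h][t_end] == tag2id["HB-TE"]:
            for h_end in range(h, len(token)):        # scan down the column for the subject end
                if matrix[r][h_end][t_end] == tag2id["HE-TE"]:
                    subject = ''.join(token[h: h_end + 1])
                    obj = ''.join(token[t: t_end + 1])
                    triples.append((subject, id2rel[str(r)], obj))
                    break

print(triples)   # [('胃癌', '疾病/手术治疗/手术治疗', '手术治疗')]
```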
/main.py:
--------------------------------------------------------------------------------
1 | from framework.framework import Framework
2 | from config.config import Config
3 | import torch
4 | import numpy as np
5 |
6 | seed = 1234
7 | torch.manual_seed(seed)
8 | torch.cuda.manual_seed(seed)
9 | np.random.seed(seed)
10 | torch.backends.cudnn.deterministic = True
11 | torch.backends.cudnn.benchmark = False
12 |
13 | config = Config()
14 | fw = Framework(config)
15 | fw.train()
16 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/models/__init__.py
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/models/__pycache__/models.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huangjie-nlp/OneRel/c1e38c6566725a8376787c277a322d3f36ba3c11/models/__pycache__/models.cpython-38.pyc
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from transformers import BertModel
4 |
5 |
6 | class OneRel(nn.Module):
7 | def __init__(self, config):
8 | super(OneRel, self).__init__()
9 | self.config = config
10 | self.bert = BertModel.from_pretrained(self.config.bert_path)
11 | self.relation_linear = nn.Linear(self.config.bert_dim * 3, self.config.num_rel * self.config.tag_size)
12 | self.project_matrix = nn.Linear(self.config.bert_dim * 2, self.config.bert_dim * 3)
13 | self.dropout = nn.Dropout(0.2)  # applied to the BERT output (hard-coded; config.dropout_prob / entity_pair_dropout are not wired in here)
14 | self.dropout_2 = nn.Dropout(0.1)  # applied to the projected entity-pair representation
15 | self.activation = nn.ReLU()
16 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17 |
18 | def get_encoded_text(self, input_ids, mask):
19 | bert_encoded_text = self.bert(input_ids=input_ids, attention_mask=mask)[0]
20 | return bert_encoded_text
21 |
22 | def get_triple_score(self, bert_encoded_text, train):
23 | batch_size, seq_len, bert_dim = bert_encoded_text.size()
24 | # [batch_size, seq_len*seq_len, bert_dim]
25 | head_rep = bert_encoded_text.unsqueeze(dim=2).expand(batch_size, seq_len, seq_len, bert_dim).reshape(batch_size, seq_len * seq_len, bert_dim)
26 | tail_rep = bert_encoded_text.repeat(1, seq_len, 1)
27 | # [batch_size, seq_len*seq_len, bert_dim * 2]
28 | entity_pair = torch.cat([head_rep, tail_rep], dim=-1)
29 |
30 | # [batch_size, seq_len*seq_len, bert_dim * 3]
31 | entity_pair = self.project_matrix(entity_pair)
32 | entity_pair = self.dropout_2(entity_pair)
33 | entity_pair = self.activation(entity_pair)
34 |
35 | # [batch_size, seq_len*seq_len, num_rel*tag_size]
36 | matrix_score = self.relation_linear(entity_pair).reshape(batch_size, seq_len, seq_len, self.config.num_rel, self.config.tag_size)
37 | if train:
38 | return matrix_score.permute(0, 4, 3, 1, 2)  # [batch_size, tag_size, num_rel, seq_len, seq_len], as expected by CrossEntropyLoss
39 | else:
40 | return matrix_score.argmax(dim=-1).permute(0, 3, 1, 2)  # [batch_size, num_rel, seq_len, seq_len] of predicted tag ids
41 |
42 | def forward(self, data, train=True):
43 | input_ids = data["input_ids"].to(self.device)
44 | attention_mask = data["attention_mask"].to(self.device)
45 |
46 | bert_encoded_text = self.get_encoded_text(input_ids, attention_mask)
47 | bert_encoded_text = self.dropout(bert_encoded_text)
48 |
49 | matrix_score = self.get_triple_score(bert_encoded_text, train)
50 |
51 | return matrix_score
--------------------------------------------------------------------------------
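A small toy check (hypothetical tiny dimensions, shapes only) of the head/tail pairing used in `OneRel.get_triple_score` above, showing how `unsqueeze`/`expand` and `repeat` enumerate every (head token, tail token) pair before `project_matrix` and `relation_linear` are applied:

```python
import torch

batch_size, seq_len, bert_dim = 1, 3, 2   # tiny stand-ins for readability
enc = torch.arange(batch_size * seq_len * bert_dim, dtype=torch.float).reshape(batch_size, seq_len, bert_dim)

# Row k of head_rep repeats token i = k // seq_len; row k of tail_rep cycles token j = k % seq_len.
head_rep = enc.unsqueeze(dim=2).expand(batch_size, seq_len, seq_len, bert_dim).reshape(batch_size, seq_len * seq_len, bert_dim)
tail_rep = enc.repeat(1, seq_len, 1)
entity_pair = torch.cat([head_rep, tail_rep], dim=-1)

print(entity_pair.shape)   # torch.Size([1, 9, 4]) == [batch_size, seq_len * seq_len, 2 * bert_dim]
print(entity_pair[0, 1])   # tensor([0., 1., 2., 3.]) -> head token 0 concatenated with tail token 1
```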
/test.py:
--------------------------------------------------------------------------------
1 | from models.models import OneRel
2 | import torch
3 | from config.config import Config
4 | import numpy as np
5 | import json
6 | from transformers import BertTokenizer
7 |
8 | seed = 1234
9 | torch.manual_seed(seed)
10 | torch.cuda.manual_seed(seed)
11 | np.random.seed(seed)
12 | torch.backends.cudnn.deterministic = True
13 | torch.backends.cudnn.benchmark = False
14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
15 | config = Config()
16 |
17 | id2label = json.load(open(config.schema_fn, "r", encoding="utf-8"))[1]
18 | id2tag = json.load(open(config.tags, "r", encoding="utf-8"))[1]  # actually the tag -> id mapping ("A"/"HB-TB"/"HB-TE"/"HE-TE" to 0..3)
19 |
20 | tokenizer = BertTokenizer.from_pretrained(config.bert_path)
21 | model = OneRel(config)
22 | model.load_state_dict(torch.load(config.checkpoint, map_location=device))
23 | model.to(device)
24 | model.eval()
25 |
26 |
27 | def parser(sentence):
28 | token = ['[CLS]'] + list(sentence)[:510] + ['[SEP]']
29 | token2id = [tokenizer.convert_tokens_to_ids(token)]
30 | mask = [[1]*len(token)]
31 | data = {"input_ids": torch.LongTensor(token2id), "attention_mask": torch.LongTensor(mask)}
32 | # [num_rel, seq_len, seq_len]
33 | output = model(data, False).cpu()[0]
34 | _, seq_lens, _ = output.shape
35 |
36 | relations, heads, tails = np.where(output > 0)
37 |
38 | # predicted triples are collected below, keyed by the input sentence
39 | relation_num = len(relations)
40 | predict = {"text": sentence, "predict": []}
41 | if relation_num > 0:
42 | for r in range(relation_num):
43 | rel2indx = relations[r]
44 | h_start_index = heads[r]
45 | t_start_index = tails[r]
46 | if output[rel2indx][h_start_index][t_start_index] == id2tag["HB-TB"] and r + 1 < relation_num:
47 | t_end_index = tails[r + 1]
48 | if output[rel2indx][h_start_index][t_end_index] == id2tag["HB-TE"]:
49 | for h_end_index in range(h_start_index, seq_lens):
50 | if output[rel2indx][h_end_index][t_end_index] == id2tag["HE-TE"]:
51 | subject_head, subject_tail = h_start_index, h_end_index
52 | object_head, object_tail = t_start_index, t_end_index
53 | subject = "".join(token[subject_head: subject_tail+1])
54 | object = "".join(token[object_head: object_tail+1])
55 | relation = id2label[str(rel2indx)]
56 | if len(subject) > 0 and len(object) > 0:
57 | predict["predict"].append((subject, relation, object))
58 | break
59 | return predict
60 |
61 | if __name__ == '__main__':
62 | while True:
63 | sentence = input("请输入:")  # prompt: "Please input:"
64 | predict = parser(sentence)
65 | print(json.dumps(predict, indent=4, ensure_ascii=False))
--------------------------------------------------------------------------------