├── data
│   └── baidu
│       └── rel.json
├── model
│   ├── config.py
│   ├── callback.py
│   ├── casRel.py
│   ├── evaluate.py
│   └── data.py
├── README.md
└── Run.py

/data/baidu/rel.json:
--------------------------------------------------------------------------------
1 | {
2 |   "0": "出品公司",
3 |   "1": "国籍",
4 |   "2": "出生地",
5 |   "3": "民族",
6 |   "4": "出生日期",
7 |   "5": "毕业院校",
8 |   "6": "歌手",
9 |   "7": "所属专辑",
10 |   "8": "作词",
11 |   "9": "作曲",
12 |   "10": "连载网站",
13 |   "11": "作者",
14 |   "12": "出版社",
15 |   "13": "主演",
16 |   "14": "导演",
17 |   "15": "编剧",
18 |   "16": "上映时间",
19 |   "17": "成立日期"
20 | }
--------------------------------------------------------------------------------
/model/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | 
3 | 
4 | class Config(object):
5 |     def __init__(self, args):
6 |         self.args = args
7 |         self.lr = args.lr
8 |         self.dataset = args.dataset
9 |         self.batch_size = args.batch_size
10 |         self.max_epoch = args.max_epoch
11 |         self.max_len = args.max_len
12 |         self.bert_name = args.bert_name
13 |         self.bert_dim = args.bert_dim
14 | 
15 |         self.train_path = 'data/' + self.dataset + '/train.json'
16 |         self.test_path = 'data/' + self.dataset + '/test.json'
17 |         self.dev_path = 'data/' + self.dataset + '/dev.json'
18 |         self.rel_path = 'data/' + self.dataset + '/rel.json'
19 |         self.num_relations = len(json.load(open(self.rel_path, 'r')))
20 | 
21 |         self.save_weights_dir = 'saved_weights/' + self.dataset + '/'
22 |         self.save_logs_dir = 'saved_logs/' + self.dataset + '/'
23 |         self.result_dir = 'results/' + self.dataset + '/'
24 | 
25 |         self.period = 200
26 |         self.test_epoch = 3
27 |         self.weights_save_name = 'model.pt'
28 |         self.log_save_name = 'model.out'
29 |         self.result_save_name = 'result.json'
30 | 
--------------------------------------------------------------------------------
/model/callback.py:
--------------------------------------------------------------------------------
1 | from fastNLP import Callback
2 | import os
3 | from model.evaluate import metric
4 | import torch
5 | 
6 | 
7 | class MyCallBack(Callback):
8 |     def __init__(self, data_iter, rel_vocab, config):
9 |         super().__init__()
10 |         self.best_epoch = 0
11 |         self.best_recall = 0
12 |         self.best_precision = 0
13 |         self.best_f1_score = 0
14 | 
15 |         self.data_iter = data_iter
16 |         self.rel_vocab = rel_vocab
17 |         self.config = config
18 | 
19 |     def logging(self, s, print_=True, log_=False):
20 |         if print_:
21 |             print(s)
22 |         if log_:
23 |             with open(os.path.join(self.config.save_logs_dir, self.config.log_save_name), 'a+') as f_log:
24 |                 f_log.write(s + '\n')
25 | 
26 |     def on_train_begin(self):
27 |         self.logging("-" * 5 + "Begin Training" + "-" * 5)
28 | 
29 |     def on_epoch_end(self):
30 |         precision, recall, f1_score = metric(self.data_iter, self.rel_vocab, self.config, self.model)
31 |         self.logging('epoch {:3d}, f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}'
32 |                      .format(self.epoch, f1_score, precision, recall))
33 | 
34 |         if f1_score > self.best_f1_score:
35 |             self.best_f1_score = f1_score
36 |             self.best_epoch = self.epoch
37 |             self.best_precision = precision
38 |             self.best_recall = recall
39 |             self.logging("Saving the model, epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".
40 |                          format(self.best_epoch, self.best_f1_score, precision, recall))
41 |             path = os.path.join(self.config.save_weights_dir, self.config.weights_save_name)
42 |             torch.save(self.model.state_dict(), path)
43 | 
44 |     def on_train_end(self):
45 |         self.logging("-" * 5 + "Finish training" + "-" * 5)
46 |         self.logging("best epoch: {:3d}, best f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}".
47 |                      format(self.best_epoch, self.best_f1_score, self.best_precision, self.best_recall))
48 | 
49 | 
--------------------------------------------------------------------------------
/model/casRel.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from transformers import BertModel
4 | 
5 | 
6 | class CasRel(nn.Module):
7 |     def __init__(self, config):
8 |         super(CasRel, self).__init__()
9 |         self.config = config
10 |         self.bert = BertModel.from_pretrained(self.config.bert_name)
11 |         self.sub_heads_linear = nn.Linear(self.config.bert_dim, 1)
12 |         self.sub_tails_linear = nn.Linear(self.config.bert_dim, 1)
13 |         self.obj_heads_linear = nn.Linear(self.config.bert_dim, self.config.num_relations)
14 |         self.obj_tails_linear = nn.Linear(self.config.bert_dim, self.config.num_relations)
15 | 
16 |     def get_encoded_text(self, token_ids, mask):
17 |         encoded_text = self.bert(token_ids, attention_mask=mask)[0]
18 |         return encoded_text
19 | 
20 |     def get_subs(self, encoded_text):
21 |         pred_sub_heads = torch.sigmoid(self.sub_heads_linear(encoded_text))
22 |         pred_sub_tails = torch.sigmoid(self.sub_tails_linear(encoded_text))
23 |         return pred_sub_heads, pred_sub_tails
24 | 
25 |     def get_objs_for_specific_sub(self, sub_head_mapping, sub_tail_mapping, encoded_text):
26 |         # sub_head_mapping [batch, 1, seq] * encoded_text [batch, seq, dim]
27 |         sub_head = torch.matmul(sub_head_mapping, encoded_text)
28 |         sub_tail = torch.matmul(sub_tail_mapping, encoded_text)
29 |         sub = (sub_head + sub_tail) / 2
30 |         encoded_text = encoded_text + sub
31 |         pred_obj_heads = torch.sigmoid(self.obj_heads_linear(encoded_text))
32 |         pred_obj_tails = torch.sigmoid(self.obj_tails_linear(encoded_text))
33 |         return pred_obj_heads, pred_obj_tails
34 | 
35 |     def forward(self, token_ids, mask, sub_head, sub_tail):
36 |         encoded_text = self.get_encoded_text(token_ids, mask)
37 |         pred_sub_heads, pred_sub_tails = self.get_subs(encoded_text)
38 |         sub_head_mapping = sub_head.unsqueeze(1)
39 |         sub_tail_mapping = sub_tail.unsqueeze(1)
40 |         pred_obj_heads, pred_obj_tails = self.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping, encoded_text)
41 | 
42 |         return {
43 |             "sub_heads": pred_sub_heads,
44 |             "sub_tails": pred_sub_tails,
45 |             "obj_heads": pred_obj_heads,
46 |             "obj_tails": pred_obj_tails,
47 |         }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Notice
2 | This repository is no longer maintained.
3 | If you run into any issues, please debug them yourself.
4 | # CasRel Model PyTorch Reimplementation
5 | This code is a PyTorch reimplementation of the ACL 2020 paper "A Novel Cascade Binary Tagging Framework for Relational Triple Extraction".
6 | The [official code](https://github.com/weizhepei/CasRel) was written in Keras.
7 | 
8 | I ran into a lot of trouble with the Keras version, so I decided to rewrite the code in PyTorch.
9 | # Introduction
10 | I followed the previous work of [longlongman](https://github.com/longlongman/CasRel-pytorch-reimplement)
11 | and [JuliaSun623](https://github.com/JuliaSun623/CasRel_fastNLP).
12 | 
13 | I would like to express my sincere thanks to both of them.
14 | 
15 | I made some changes so that the model works better on Chinese datasets.
16 | The changes are listed below:
17 | - I changed the tokenizer from HBTokenizer to BertTokenizer, so Chinese sentences are tokenized character by character.
18 | (As a result, nothing from the Keras implementation is needed.)
19 | - I replaced the original pretrained model with 'bert-base-chinese'.
20 | - I used fastNLP to build the datasets.
21 | - I changed the encoding and decoding methods to fit Chinese datasets.
22 | - I restructured the code for readability.
23 | # Requirements
24 | - torch==1.8.0+cu111
25 | - transformers==4.3.3
26 | - fastNLP==0.6.0
27 | - tqdm==4.59.0
28 | - numpy==1.20.1
29 | # Dataset
30 | I preprocessed the open-source dataset released by Baidu. After some cleaning, the provided data contain 18 relation types.
31 | Some noisy samples were removed.
32 | 
33 | The data are stored as JSON. Here is one example:
34 | ```json
35 | {
36 |   "text": "陶喆的一首《好好说再见》推荐给大家,希望你们能够喜欢",
37 |   "spo_list": [
38 |     {
39 |       "predicate": "歌手",
40 |       "object_type": "人物",
41 |       "subject_type": "歌曲",
42 |       "object": "陶喆",
43 |       "subject": "好好说再见"
44 |     }
45 |   ]
46 | }
47 | ```
48 | In fact, the fields object_type and subject_type are not used.
49 | 
50 | If you have your own data, you can organize it in the same format.
51 | # Usage
52 | ```
53 | python Run.py
54 | ```
55 | The default configuration is already set, but you can adjust it in model/config.py or via the command-line arguments of Run.py.
56 | # Results
57 | The best F1 score on the test data is 0.78, with a precision of 0.80 and a recall of 0.76.
58 | 
59 | This meets my expectations, although it is probably not the best result achievable.
60 | 
61 | I have also trained the [SpERT](https://github.com/lavis-nlp/spert) model,
62 | and CasRel turns out to perform better.
63 | More experiments are needed, though, since there are slight differences in both the evaluation criteria and the datasets.
64 | 
65 | # Experiences
66 | - A learning rate of 1e-5 seems to be a good choice; changing it affects the model dramatically (see the command-line example below).
67 | - Substituting RoBERTa for BERT brings little improvement.
68 | - It is crucial to shuffle the datasets in order to avoid overfitting.
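
For reference, here is a minimal sketch of how the command-line arguments defined in Run.py can be combined. Every flag below exists in Run.py, and the values shown are simply its defaults, so adjust them to your needs:
```
python Run.py --dataset baidu --bert_name bert-base-chinese --bert_dim 768 \
              --lr 1e-5 --batch_size 8 --max_epoch 10 --max_len 300
```
The `--dataset` value must match a folder under `data/` containing `train.json`, `dev.json`, `test.json`, and `rel.json`, because `model/config.py` builds every path from that name.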
69 | 70 | 71 | 72 | -------------------------------------------------------------------------------- /Run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.optim as optim 4 | from model.casRel import CasRel 5 | from model.callback import MyCallBack 6 | from model.data import load_data, get_data_iterator 7 | from model.config import Config 8 | from model.evaluate import metric 9 | import torch.nn.functional as F 10 | from fastNLP import Trainer, LossBase 11 | 12 | seed = 226 13 | torch.manual_seed(seed) 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | 16 | parser = argparse.ArgumentParser(description='Model Controller') 17 | parser.add_argument('--lr', type=float, default=1e-5, help='learning rate') 18 | parser.add_argument('--batch_size', type=int, default=8) 19 | parser.add_argument('--max_epoch', type=int, default=10) 20 | parser.add_argument('--max_len', type=int, default=300) 21 | parser.add_argument('--dataset', default='baidu', type=str, help='define your own dataset names') 22 | parser.add_argument("--bert_name", default='bert-base-chinese', type=str, help='choose pretrained bert name') 23 | parser.add_argument('--bert_dim', default=768, type=int) 24 | args = parser.parse_args() 25 | con = Config(args) 26 | 27 | 28 | class MyLoss(LossBase): 29 | def __init__(self): 30 | super(MyLoss, self).__init__() 31 | 32 | def get_loss(self, predict, target): 33 | mask = target['mask'] 34 | 35 | def loss_fn(pred, gold, mask): 36 | pred = pred.squeeze(-1) 37 | loss = F.binary_cross_entropy(pred, gold, reduction='none') 38 | if loss.shape != mask.shape: 39 | mask = mask.unsqueeze(-1) 40 | loss = torch.sum(loss * mask) / torch.sum(mask) 41 | return loss 42 | 43 | return loss_fn(predict['sub_heads'], target['sub_heads'], mask) + \ 44 | loss_fn(predict['sub_tails'], target['sub_tails'], mask) + \ 45 | loss_fn(predict['obj_heads'], target['obj_heads'], mask) + \ 46 | loss_fn(predict['obj_tails'], target['obj_tails'], mask) 47 | 48 | def __call__(self, pred_dict, target_dict, check=False): 49 | loss = self.get_loss(pred_dict, target_dict) 50 | return loss 51 | 52 | 53 | if __name__ == '__main__': 54 | model = CasRel(con).to(device) 55 | data_bundle, rel_vocab = load_data(con.train_path, con.dev_path, con.test_path, con.rel_path) 56 | train_data = get_data_iterator(con, data_bundle.get_dataset('train'), rel_vocab) 57 | dev_data = get_data_iterator(con, data_bundle.get_dataset('dev'), rel_vocab, is_test=True) 58 | test_data = get_data_iterator(con, data_bundle.get_dataset('test'), rel_vocab, is_test=True) 59 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=con.lr) 60 | trainer = Trainer(train_data=train_data, model=model, optimizer=optimizer, batch_size=con.batch_size, 61 | n_epochs=con.max_epoch, loss=MyLoss(), print_every=con.period, use_tqdm=True, 62 | callbacks=MyCallBack(dev_data, rel_vocab, con)) 63 | trainer.train() 64 | print("-" * 5 + "Begin Testing" + "-" * 5) 65 | metric(test_data, rel_vocab, con, model) 66 | -------------------------------------------------------------------------------- /model/evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | from transformers import BertTokenizer 5 | from tqdm import tqdm 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | 10 | def to_tuple(triple_list): 11 | ret = [] 12 | for 
triple in triple_list: 13 | ret.append((triple['subject'], triple['predicate'], triple['object'])) 14 | return ret 15 | 16 | 17 | def metric(data_iter, rel_vocab, config, model, output=True, h_bar=0.5, t_bar=0.5): 18 | 19 | orders = ['subject', 'relation', 'object'] 20 | correct_num, predict_num, gold_num = 0, 0, 0 21 | tokenizer = BertTokenizer.from_pretrained(config.bert_name) 22 | 23 | for batch_x, batch_y in tqdm(data_iter): 24 | with torch.no_grad(): 25 | token_ids = batch_x['token_ids'] 26 | mask = batch_x['mask'] 27 | encoded_text = model.get_encoded_text(token_ids, mask) 28 | pred_sub_heads, pred_sub_tails = model.get_subs(encoded_text) 29 | sub_heads = torch.where(pred_sub_heads[0] > h_bar)[0] 30 | sub_tails = torch.where(pred_sub_tails[0] > t_bar)[0] 31 | subjects = [] 32 | for sub_head in sub_heads: 33 | sub_tail = sub_tails[sub_tails >= sub_head] 34 | if len(sub_tail) > 0: 35 | sub_tail = sub_tail[0] 36 | subject = ''.join(tokenizer.decode(token_ids[0][sub_head: sub_tail + 1]).split()) 37 | subjects.append((subject, sub_head, sub_tail)) 38 | if subjects: 39 | triple_list = [] 40 | repeated_encoded_text = encoded_text.repeat(len(subjects), 1, 1) 41 | sub_head_mapping = torch.zeros((len(subjects), 1, encoded_text.size(1)), dtype=torch.float, 42 | device=device) 43 | sub_tail_mapping = torch.zeros((len(subjects), 1, encoded_text.size(1)), dtype=torch.float, 44 | device=device) 45 | for subject_idx, subject in enumerate(subjects): 46 | sub_head_mapping[subject_idx][0][subject[1]] = 1 47 | sub_tail_mapping[subject_idx][0][subject[2]] = 1 48 | pred_obj_heads, pred_obj_tails = model.get_objs_for_specific_sub(sub_head_mapping, sub_tail_mapping, 49 | repeated_encoded_text) 50 | for subject_idx, subject in enumerate(subjects): 51 | sub = subject[0] 52 | obj_heads = torch.where(pred_obj_heads[subject_idx] > h_bar) 53 | obj_tails = torch.where(pred_obj_tails[subject_idx] > t_bar) 54 | for obj_head, rel_head in zip(*obj_heads): 55 | for obj_tail, rel_tail in zip(*obj_tails): 56 | if obj_head <= obj_tail and rel_head == rel_tail: 57 | rel = rel_vocab.to_word(int(rel_head)) 58 | obj = ''.join(tokenizer.decode(token_ids[0][obj_head: obj_tail + 1]).split()) 59 | triple_list.append((sub, rel, obj)) 60 | break 61 | 62 | triple_set = set() 63 | for s, r, o in triple_list: 64 | triple_set.add((s, r, o)) 65 | pred_list = list(triple_set) 66 | 67 | else: 68 | pred_list = [] 69 | 70 | pred_triples = set(pred_list) 71 | gold_triples = set(to_tuple(batch_y['triples'][0])) 72 | correct_num += len(pred_triples & gold_triples) 73 | predict_num += len(pred_triples) 74 | gold_num += len(gold_triples) 75 | 76 | if output: 77 | if not os.path.exists(config.result_dir): 78 | os.mkdir(config.result_dir) 79 | path = os.path.join(config.result_dir, config.result_save_name) 80 | fw = open(path, 'w') 81 | result = json.dumps({ 82 | 'triple_list_gold': [ 83 | dict(zip(orders, triple)) for triple in gold_triples 84 | ], 85 | 'triple_list_pred': [ 86 | dict(zip(orders, triple)) for triple in pred_triples 87 | ], 88 | 'new': [ 89 | dict(zip(orders, triple)) for triple in pred_triples - gold_triples 90 | ], 91 | 'lack': [ 92 | dict(zip(orders, triple)) for triple in gold_triples - pred_triples 93 | ] 94 | }, ensure_ascii=False) 95 | fw.write(result + '\n') 96 | 97 | print("correct_num: {:3d}, predict_num: {:3d}, gold_num: {:3d}".format(correct_num, predict_num, gold_num)) 98 | precision = correct_num / (predict_num + 1e-10) 99 | recall = correct_num / (gold_num + 1e-10) 100 | f1_score = 2 * precision * recall / 
(precision + recall + 1e-10) 101 | print('f1: {:4.2f}, precision: {:4.2f}, recall: {:4.2f}'.format(f1_score, precision, recall)) 102 | return precision, recall, f1_score 103 | -------------------------------------------------------------------------------- /model/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | from random import choice 3 | from fastNLP import TorchLoaderIter, DataSet, Vocabulary, Sampler 4 | from fastNLP.io import JsonLoader 5 | import torch 6 | import numpy as np 7 | from transformers import BertTokenizer 8 | from collections import defaultdict 9 | from torch.nn.utils.rnn import pad_sequence 10 | 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | num_rel = 18 13 | 14 | 15 | def load_data(train_path, dev_path, test_path, rel_dict_path): 16 | paths = {'train': train_path, 'dev': dev_path, 'test': test_path} 17 | loader = JsonLoader({"text": "text", "spo_list": "spo_list"}) 18 | data_bundle = loader.load(paths) 19 | id2rel = json.load(open(rel_dict_path)) 20 | rel_vocab = Vocabulary(unknown=None, padding=None) 21 | rel_vocab.add_word_lst(list(id2rel.values())) 22 | return data_bundle, rel_vocab 23 | 24 | 25 | def find_head_idx(source, target): 26 | target_len = len(target) 27 | for i in range(len(source)): 28 | if source[i: i + target_len] == target: 29 | return i 30 | return -1 31 | 32 | 33 | class MyDataset(DataSet): 34 | def __init__(self, config, dataset, rel_vocab, is_test): 35 | super().__init__() 36 | self.config = config 37 | self.dataset = dataset 38 | self.rel_vocab = rel_vocab 39 | self.is_test = is_test 40 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese') 41 | 42 | def __getitem__(self, item): 43 | json_data = self.dataset[item] 44 | text = json_data['text'] 45 | tokenized = self.tokenizer(text, max_length=self.config.max_len, truncation=True) 46 | tokens = tokenized['input_ids'] 47 | masks = tokenized['attention_mask'] 48 | text_len = len(tokens) 49 | 50 | token_ids = torch.tensor(tokens, dtype=torch.long) 51 | masks = torch.tensor(masks, dtype=torch.bool) 52 | sub_heads, sub_tails = torch.zeros(text_len), torch.zeros(text_len) 53 | sub_head, sub_tail = torch.zeros(text_len), torch.zeros(text_len) 54 | obj_heads = torch.zeros((text_len, self.config.num_relations)) 55 | obj_tails = torch.zeros((text_len, self.config.num_relations)) 56 | 57 | if not self.is_test: 58 | s2ro_map = defaultdict(list) 59 | for spo in json_data['spo_list']: 60 | triple = (self.tokenizer(spo['subject'], add_special_tokens=False)['input_ids'], 61 | self.rel_vocab.to_index(spo['predicate']), 62 | self.tokenizer(spo['object'], add_special_tokens=False)['input_ids']) 63 | sub_head_idx = find_head_idx(tokens, triple[0]) 64 | obj_head_idx = find_head_idx(tokens, triple[2]) 65 | if sub_head_idx != -1 and obj_head_idx != -1: 66 | sub = (sub_head_idx, sub_head_idx + len(triple[0]) - 1) 67 | s2ro_map[sub].append( 68 | (obj_head_idx, obj_head_idx + len(triple[2]) - 1, triple[1])) 69 | 70 | if s2ro_map: 71 | for s in s2ro_map: 72 | sub_heads[s[0]] = 1 73 | sub_tails[s[1]] = 1 74 | sub_head_idx, sub_tail_idx = choice(list(s2ro_map.keys())) 75 | sub_head[sub_head_idx] = 1 76 | sub_tail[sub_tail_idx] = 1 77 | for ro in s2ro_map.get((sub_head_idx, sub_tail_idx), []): 78 | obj_heads[ro[0]][ro[2]] = 1 79 | obj_tails[ro[1]][ro[2]] = 1 80 | 81 | return token_ids, masks, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, json_data['spo_list'] 82 | 83 | def __len__(self): 84 | return 
len(self.dataset) 85 | 86 | 87 | def my_collate_fn(batch): 88 | batch = list(filter(lambda x: x is not None, batch)) 89 | token_ids, masks, sub_heads, sub_tails, sub_head, sub_tail, obj_heads, obj_tails, triples = zip(*batch) 90 | batch_token_ids = pad_sequence(token_ids, batch_first=True) 91 | batch_masks = pad_sequence(masks, batch_first=True) 92 | batch_sub_heads = pad_sequence(sub_heads, batch_first=True) 93 | batch_sub_tails = pad_sequence(sub_tails, batch_first=True) 94 | batch_sub_head = pad_sequence(sub_head, batch_first=True) 95 | batch_sub_tail = pad_sequence(sub_tail, batch_first=True) 96 | batch_obj_heads = pad_sequence(obj_heads, batch_first=True) 97 | batch_obj_tails = pad_sequence(obj_tails, batch_first=True) 98 | 99 | return {"token_ids": batch_token_ids.to(device), 100 | "mask": batch_masks.to(device), 101 | "sub_head": batch_sub_head.to(device), 102 | "sub_tail": batch_sub_tail.to(device), 103 | "sub_heads": batch_sub_heads.to(device), 104 | }, \ 105 | {"mask": batch_masks.to(device), 106 | "sub_heads": batch_sub_heads.to(device), 107 | "sub_tails": batch_sub_tails.to(device), 108 | "obj_heads": batch_obj_heads.to(device), 109 | "obj_tails": batch_obj_tails.to(device), 110 | "triples": triples 111 | } 112 | 113 | 114 | class MyRandomSampler(Sampler): 115 | def __call__(self, data_set): 116 | return np.random.permutation(len(data_set)).tolist() 117 | 118 | 119 | def get_data_iterator(config, dataset, rel_vocab, is_test=False, collate_fn=my_collate_fn): 120 | dataset = MyDataset(config, dataset, rel_vocab, is_test) 121 | return TorchLoaderIter(dataset=dataset, 122 | collate_fn=collate_fn, 123 | batch_size=config.batch_size if not is_test else 1, 124 | sampler=MyRandomSampler()) 125 | --------------------------------------------------------------------------------
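
A minimal usage sketch (not part of the repository) for inspecting the data pipeline on its own: it relies only on functions shown above (`Config`, `load_data`, `get_data_iterator`), the `Namespace` simply mirrors the command-line arguments defined in Run.py, and it assumes the baidu dataset files exist under `data/baidu/`.
```python
from argparse import Namespace

from model.config import Config
from model.data import load_data, get_data_iterator

# Mirror the defaults of Run.py's argparse options.
args = Namespace(lr=1e-5, dataset='baidu', batch_size=8, max_epoch=10,
                 max_len=300, bert_name='bert-base-chinese', bert_dim=768)
con = Config(args)

# Build the fastNLP data bundle and relation vocabulary, then wrap the
# training split in the batched iterator that the Trainer consumes.
data_bundle, rel_vocab = load_data(con.train_path, con.dev_path, con.test_path, con.rel_path)
train_iter = get_data_iterator(con, data_bundle.get_dataset('train'), rel_vocab)

# Each iteration yields an (inputs, targets) pair of dicts, as assembled by my_collate_fn.
batch_x, batch_y = next(iter(train_iter))
print(batch_x['token_ids'].shape)  # [batch_size, padded_seq_len]
print(batch_y['obj_heads'].shape)  # [batch_size, padded_seq_len, num_relations]
print(batch_y['triples'][0])       # the raw spo_list of the first sample in the batch
```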