├── README.md ├── checkpoints └── 占位.txt ├── config.py ├── data └── 占位.txt ├── data_loader.py ├── logs └── bert.log ├── main.py ├── model.py ├── model_hub └── 占位.txt ├── preprocess.py └── utils ├── __pycache__ ├── common_utils.cpython-37.pyc ├── metric_utils.cpython-37.pyc └── train_utils.cpython-37.pyc ├── common_utils.py ├── metric_utils.py └── train_utils.py /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_casrel_triple_extraction 2 | 基于pytorch的CasRel进行三元组抽取。 3 | 4 | 之前的[基于pytorch的中文三元组提取(命名实体识别+关系抽取)](https://github.com/taishan1994/pytorch_triple_extraction)是先抽取主体和客体,再进行关系分类,这里是使用casrel,先抽取主体,再识别客体和关系。具体使用说明: 5 | 6 | - 1、在data/ske/raw_data下是原始数据,新建一个process.py,主要是得到mid_data下的关系的类型。 7 | - 2、针对于不同的数据源,在data_loader.py中修改MyDataset类下,返回的是一个列表,列表中的每个元素是:(text, labels),其中labels是[[主体,类别,客体]]。 8 | - 3、运行main.py进行训练、验证、测试和预测。 9 | 10 | 数据和模型下载地址:链接:https://pan.baidu.com/s/12w8vC1Arfo8FTH3AuN1WAw?pwd=h7hu 提取码:h7hu 11 | 12 | # 依赖 13 | 14 | ``` 15 | pytorch==1.6.0 16 | transformers==4.5.0 17 | ``` 18 | 19 | # 运行 20 | 21 | ```python 22 | python main.py \ 23 | --bert_dir="model_hub/chinese-bert-wwm-ext/" \ 24 | --data_dir="./data/ske/" \ 25 | --log_dir="./logs/" \ 26 | --output_dir="./checkpoints/" \ 27 | --num_tags=49 \ # 关系类别 28 | --seed=123 \ 29 | --gpu_ids="0" \ 30 | --max_seq_len=256 \ # 句子最大长度 31 | --lr=5e-5 \ 32 | --other_lr=5e-5 \ 33 | --train_batch_size=32 \ 34 | --train_epochs=1 \ 35 | --eval_batch_size=8 \ 36 | --max_grad_norm=1 \ 37 | --warmup_proportion=0.1 \ 38 | --adam_epsilon=1e-8 \ 39 | --weight_decay=0.01 \ 40 | --dropout_prob=0.1 \ 41 | --use_tensorboard="False" \ # 是否使用tensorboardX可视化 42 | --use_dev_num=1000 \ # 使用多少验证数据进行验证 43 | ``` 44 | 45 | ### 结果 46 | 这里使用batch_size=32训练了3000步。 47 | ``` 48 | ========metric======== 49 | precision:0.7087378640776699 recall:0.7237960339943342 f1:0.7161878065872459 50 | ``` 51 | ```python 52 | 53 | 文本: 查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部 54 | 主体: [['查尔斯·阿兰基斯']] 55 | 客体: [['智利', '智利圣地亚哥', '1989年4月17日']] 56 | 关系: [[('查尔斯·阿兰基斯', '国籍', '智利'), ('查尔斯·阿兰基斯', '出生地', '智利圣地亚哥'), ('查尔斯·阿兰基斯', '出生日期', '1989年4月17日')]] 57 | ==================================================================================================== 58 | 文本: 《离开》是由张宇谱曲,演唱 59 | 主体: [['离开']] 60 | 客体: [['张宇']] 61 | 关系: [[('离开', '作词', '张宇'), ('离开', '歌手', '张宇'), ('离开', '作曲', '张宇')]] 62 | ==================================================================================================== 63 | 文本: 《愤怒的唐僧》由北京吴意波影视文化工作室与优酷电视剧频道联合制作,故事以喜剧元素为主,讲述唐僧与佛祖打牌,得罪了佛祖,被踢下人间再渡九九八十一难的故事 64 | 主体: [['愤怒的唐僧']] 65 | 客体: [['北京吴意波影视文化工作室']] 66 | 关系: [[('愤怒的唐僧', '出品公司', '北京吴意波影视文化工作室')]] 67 | ==================================================================================================== 68 | 文本: 李治即位后,萧淑妃受宠,王皇后为了排挤萧淑妃,答应李治让身在感业寺的武则天续起头发,重新纳入后宫 69 | 主体: [['萧淑妃']] 70 | 客体: [[]] 71 | 关系: [[]] 72 | ==================================================================================================== 73 | 文本: 《工业4.0》是2015年机械工业出版社出版的图书,作者是(德)阿尔冯斯·波特霍夫,恩斯特·安德雷亚斯·哈特曼 74 | 主体: [['工业4.0']] 75 | 客体: [['阿尔冯斯·波特霍夫', '恩斯特·安德雷亚斯·哈特曼', '机械工业出版社']] 76 | 关系: [[('工业4.0', '作者', '阿尔冯斯·波特霍夫'), ('工业4.0', '作者', '恩斯特·安德雷亚斯·哈特曼'), ('工业4.0', '出版社', '机械工业出版社')]] 77 | ==================================================================================================== 78 | 文本: 周佛海被捕入狱之后,其妻杨淑慧散尽家产请蒋介石枪下留人,于是周佛海从死刑变为无期,不过此人或许作恶多端,改判没多久便病逝于监狱,据悉是心脏病发作 79 | 主体: [['周佛海', '杨淑慧']] 80 | 客体: [[]] 81 | 关系: [[]] 82 | 
==================================================================================================== 83 | 文本: 《李烈钧自述》是2011年11月1日人民日报出版社出版的图书,作者是李烈钧 84 | 主体: [['李烈钧自述']] 85 | 客体: [['人民日报出版社']] 86 | 关系: [[('李烈钧自述', '出版社', '人民日报出版社')]] 87 | ==================================================================================================== 88 | 文本: 除演艺事业外,李冰冰热心公益,发起并亲自参与多项环保慈善活动,积极投身其中,身体力行担起了回馈社会的责任于02年出演《少年包青天》,进入大家视线 89 | 主体: [['少年包青天']] 90 | 客体: [['李冰冰']] 91 | 关系: [[('少年包青天', '主演', '李冰冰')]] 92 | ==================================================================================================== 93 | 文本: 马志舟,1907年出生,陕西三原人,汉族,中国共产党,任红四团第一连连长,1933年逝世 94 | 主体: [['马志舟']] 95 | 客体: [['中国', '陕西三原', '1907年', '汉族']] 96 | 关系: [[('马志舟', '国籍', '中国'), ('马志舟', '出生地', '陕西三原'), ('马志舟', '出生日期', '1907年'), ('马志舟', '民族', '汉族')]] 97 | ==================================================================================================== 98 | 文本: 斑刺莺是雀形目、剌嘴莺科的一种动物,分布于澳大利亚和新西兰,包括澳大利亚、新西兰、塔斯马尼亚及其附近的岛屿 99 | 主体: [['斑刺莺']] 100 | 客体: [['雀形目']] 101 | 关系: [[('斑刺莺', '目', '雀形目')]] 102 | ==================================================================================================== 103 | 文本: 《课本上学不到的生物学2》是2013年上海科技教育出版社出版的图书 104 | 主体: [['课本上学不到的生物学2']] 105 | 客体: [['上海科技教育出版社']] 106 | 关系: [[('课本上学不到的生物学2', '出版社', '上海科技教育出版社')]] 107 | ==================================================================================================== 108 | ``` 109 | 110 | # 参考 111 | 112 | > 模型参考:[bert4torch/task_relation_extraction.py at master · Tongjilibo/bert4torch (github.com)](https://github.com/Tongjilibo/bert4torch/blob/master/examples/relation_extraction/task_relation_extraction.py) 113 | 114 | -------------------------------------------------------------------------------- /checkpoints/占位.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_casrel_triple_extraction/fb15e5807e6815f770092ab72e1ea7db019ab5c3/checkpoints/占位.txt -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | class Args: 5 | @staticmethod 6 | def parse(): 7 | parser = argparse.ArgumentParser() 8 | return parser 9 | 10 | @staticmethod 11 | def initialize(parser): 12 | # args for path 13 | parser.add_argument('--output_dir', default='./checkpoints/', 14 | help='the output dir for model checkpoints') 15 | 16 | parser.add_argument('--bert_dir', default='../model_hub/bert-base-chinese/', 17 | help='bert dir for uer') 18 | parser.add_argument('--data_dir', default='./data/cner/', 19 | help='data dir for uer') 20 | parser.add_argument('--log_dir', default='./logs/', 21 | help='log dir for uer') 22 | 23 | # other args 24 | parser.add_argument('--num_tags', default=53, type=int, 25 | help='number of tags') 26 | parser.add_argument('--seed', type=int, default=123, help='random seed') 27 | 28 | parser.add_argument('--gpu_ids', type=str, default='0', 29 | help='gpu ids to use, -1 for cpu, "0,1" for multi gpu') 30 | 31 | parser.add_argument('--max_seq_len', default=256, type=int) 32 | 33 | parser.add_argument('--eval_batch_size', default=12, type=int) 34 | 35 | 36 | # train args 37 | parser.add_argument('--train_epochs', default=15, type=int, 38 | help='Max training epoch') 39 | 40 | parser.add_argument('--dropout_prob', default=0.1, type=float, 41 | help='drop out probability') 42 | 43 | # 2e-5 44 | parser.add_argument('--lr', 
default=3e-5, type=float, 45 | help='bert学习率') 46 | # 2e-3 47 | parser.add_argument('--other_lr', default=3e-4, type=float, 48 | help='bilstm和多层感知机学习率') 49 | # 0.5 50 | parser.add_argument('--max_grad_norm', default=1, type=float, 51 | help='max grad clip') 52 | parser.add_argument('--use_tensorboard', default="True", 53 | help='max grad clip') 54 | 55 | parser.add_argument('--warmup_proportion', default=0.1, type=float) 56 | 57 | parser.add_argument('--weight_decay', default=0.01, type=float) 58 | 59 | parser.add_argument('--adam_epsilon', default=1e-8, type=float) 60 | 61 | parser.add_argument('--train_batch_size', default=32, type=int) 62 | parser.add_argument('--use_dev_num', default=32, type=int, help="用于验证和测试的数量") 63 | 64 | 65 | 66 | return parser 67 | 68 | def get_parser(self): 69 | parser = self.parse() 70 | parser = self.initialize(parser) 71 | return parser.parse_args() -------------------------------------------------------------------------------- /data/占位.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_casrel_triple_extraction/fb15e5807e6815f770092ab72e1ea7db019ab5c3/data/占位.txt -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import DataLoader, Dataset 5 | from utils.common_utils import sequence_padding 6 | 7 | 8 | class ListDataset(Dataset): 9 | def __init__(self, file_path=None, data=None, **kwargs): 10 | self.kwargs = kwargs 11 | if isinstance(file_path, (str, list)): 12 | self.data = self.load_data(file_path) 13 | elif isinstance(data, list): 14 | self.data = data 15 | else: 16 | raise ValueError('The input args shall be str format file_path / list format dataset') 17 | 18 | def __len__(self): 19 | return len(self.data) 20 | 21 | def __getitem__(self, index): 22 | return self.data[index] 23 | 24 | @staticmethod 25 | def load_data(file_path): 26 | return file_path 27 | 28 | 29 | 30 | # 加载数据集 31 | class MyDataset(ListDataset): 32 | @staticmethod 33 | def load_data(filename): 34 | examples = [] 35 | with open(filename, encoding='utf-8') as f: 36 | raw_examples = f.readlines() 37 | # 这里是从json数据中的字典中获取 38 | for i, item in enumerate(raw_examples): 39 | # print(i,item) 40 | item = json.loads(item) 41 | text = item['text'] 42 | spo_list = item['spo_list'] 43 | labels = [] # [subject, predicate, object] 44 | for spo in spo_list: 45 | subject = spo['subject'] 46 | object = spo['object'] 47 | predicate = spo['predicate'] 48 | labels.append([subject, predicate, object]) 49 | examples.append((text, labels)) 50 | return examples 51 | 52 | class Collate: 53 | def __init__(self, max_len, tag2id, device, tokenizer): 54 | self.maxlen = max_len 55 | self.tag2id = tag2id 56 | self.id2tag = {v:k for k,v in tag2id.items()} 57 | self.device = device 58 | self.tokenizer = tokenizer 59 | 60 | def collate_fn(self, batch): 61 | def search(pattern, sequence): 62 | """从sequence中寻找子串pattern 63 | 如果找到,返回第一个下标;否则返回-1。 64 | """ 65 | n = len(pattern) 66 | for i in range(len(sequence)): 67 | if sequence[i:i + n] == pattern: 68 | return i 69 | return -1 70 | batch_subject_labels = [] 71 | batch_object_labels = [] 72 | batch_subject_ids = [] 73 | batch_token_ids = [] 74 | batch_attention_mask = [] 75 | batch_token_type_ids = [] 76 | callback = [] 77 | for i, (text, text_labels) in enumerate(batch): 78 | if 
len(text) > self.maxlen: 79 | text = text[:self.maxlen] 80 | tokens = [i for i in text] 81 | spoes = {} 82 | callback_text_labels = [] 83 | for s, p, o in text_labels: 84 | p = self.tag2id[p] 85 | s_idx = search(s, text) 86 | o_idx = search(o, text) 87 | if s_idx != -1 and o_idx != -1: 88 | callback_text_labels.append((s, self.id2tag[p], o)) 89 | s = (s_idx, s_idx + len(s) - 1) 90 | o = (o_idx, o_idx + len(o) - 1, p) 91 | if s not in spoes: 92 | spoes[s] = [] 93 | spoes[s].append(o) 94 | # print(text_labels) 95 | # print(text) 96 | # print(spoes) 97 | if spoes: 98 | # subject标签 99 | subject_labels = np.zeros((len(tokens), 2)) 100 | for s in spoes: 101 | subject_labels[s[0], 0] = 1 # subject首 102 | subject_labels[s[1], 1] = 1 # subject尾 103 | start, end = np.array(list(spoes.keys())).T 104 | start = np.random.choice(start) 105 | end = np.random.choice(end[end >= start]) 106 | # 这里取出的可能不是一个真实的subject 107 | subject_ids = (start, end) 108 | # 对应的object标签 109 | object_labels = np.zeros((len(tokens), len(self.tag2id), 2)) 110 | for o in spoes.get(subject_ids, []): 111 | object_labels[o[0], o[2], 0] = 1 112 | object_labels[o[1], o[2], 1] = 1 113 | # 构建batch 114 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 115 | batch_token_ids.append(token_ids) # 前面已经限制了长度 116 | batch_attention_mask.append([1] * len(token_ids)) 117 | batch_token_type_ids.append([0] * len(token_ids)) 118 | batch_subject_labels.append(subject_labels) 119 | batch_object_labels.append(object_labels) 120 | batch_subject_ids.append(subject_ids) 121 | callback.append((text, callback_text_labels)) 122 | batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, length=self.maxlen), dtype=torch.long, device=self.device) 123 | attention_mask = torch.tensor(sequence_padding(batch_attention_mask, length=self.maxlen), dtype=torch.long, device=self.device) 124 | token_type_ids = torch.tensor(sequence_padding(batch_token_type_ids, length=self.maxlen), dtype=torch.long, device=self.device) 125 | batch_subject_labels = torch.tensor(sequence_padding(batch_subject_labels, length=self.maxlen), dtype=torch.float, device=self.device) 126 | batch_object_labels = torch.tensor(sequence_padding(batch_object_labels, length=self.maxlen), dtype=torch.float, device=self.device) 127 | batch_subject_ids = torch.tensor(batch_subject_ids, dtype=torch.long, device=self.device) 128 | 129 | return batch_token_ids, attention_mask, token_type_ids, batch_subject_labels, batch_object_labels, batch_subject_ids, callback 130 | 131 | 132 | if __name__ == "__main__": 133 | from transformers import BertTokenizer 134 | max_len = 256 135 | tokenizer = BertTokenizer.from_pretrained('model_hub/chinese-bert-wwm-ext/vocab.txt') 136 | train_dataset = MyDataset(file_path='data/ske/raw_data/train_data.json', 137 | tokenizer=tokenizer, 138 | max_len=max_len) 139 | # print(train_dataset[0]) 140 | 141 | with open('data/ske/mid_data/predicates.json') as fp: 142 | labels = json.load(fp) 143 | id2tag = {} 144 | tag2id = {} 145 | for i,label in enumerate(labels): 146 | id2tag[i] = label 147 | tag2id[label] = i 148 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 149 | collate = Collate(max_len=max_len, tag2id=tag2id, device=device, tokenizer=tokenizer) 150 | # collate.collate_fn(train_dataset[:20]) 151 | batch_size = 2 152 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate.collate_fn) 153 | 154 | for i, batch in enumerate(train_dataloader): 155 | print(batch) 156 | print(batch[0].shape) 157 | 
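        # Shapes produced by Collate.collate_fn (derived from the padding above):
        #   batch[0] token_ids       -> [batch_size, max_len]
        #   batch[1] attention_mask  -> [batch_size, max_len]
        #   batch[2] token_type_ids  -> [batch_size, max_len]
        #   batch[3] subject_labels  -> [batch_size, max_len, 2]            start/end of every subject
        #   batch[4] object_labels   -> [batch_size, max_len, num_tags, 2]  start/end per relation, for the sampled subject
        #   batch[5] subject_ids     -> [batch_size, 2]                     the one subject sampled for training
        #   batch[6] callback        -> list of (text, [(s, p, o), ...]) kept for evaluation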
print(batch[1].shape) 158 | print(batch[2].shape) 159 | print(batch[3].shape) 160 | print(batch[4].shape) 161 | print(batch[5].shape) 162 | print(batch[6]) 163 | break -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from re import template 4 | import numpy as np 5 | from collections import defaultdict 6 | import torch 7 | from torch.utils.data import DataLoader, RandomSampler 8 | from transformers import BertTokenizer 9 | 10 | import config 11 | import data_loader 12 | from model import Casrel 13 | from utils.common_utils import set_seed, set_logger, read_json, fine_grade_tokenize 14 | from utils.train_utils import load_model_and_parallel, build_optimizer_and_scheduler, save_model 15 | from utils.metric_utils import calculate_metric_relation, get_p_r_f 16 | from tensorboardX import SummaryWriter 17 | 18 | args = config.Args().get_parser() 19 | set_seed(args.seed) 20 | logger = logging.getLogger(__name__) 21 | 22 | if args.use_tensorboard == "True": 23 | writer = SummaryWriter(log_dir='./tensorboard') 24 | 25 | def get_spo(object_preds, subject_ids, length, example, id2tag): 26 | # object_preds:[batchsize, maxlen, num_labels, 2] 27 | num_label = object_preds.shape[2] 28 | num_subject = len(subject_ids) 29 | relations = [] 30 | subjects = [] 31 | objects = [] 32 | # print(object_preds.shape, subject, length, example) 33 | for b in range(num_subject): 34 | tmp = object_preds[b, ...] 35 | subject_start, subject_end = subject_ids[b].cpu().numpy() 36 | subject = example[subject_start:subject_end+1] 37 | if subject not in subjects: 38 | subjects.append(subject) 39 | for label_id in range(num_label): 40 | start = tmp[:, label_id, :1] 41 | end = tmp[:, label_id, 1:] 42 | start = start.squeeze()[:length] 43 | end = end.squeeze()[:length] 44 | for i, st in enumerate(start): 45 | if st > 0.5: 46 | s = i 47 | for j in range(i, length): 48 | if end[j] > 0.5: 49 | e = j 50 | object = example[s:e+1] 51 | if object not in objects: 52 | objects.append(object) 53 | if (subject, id2tag[label_id], object) not in relations: 54 | relations.append((subject, id2tag[label_id], object)) 55 | break 56 | # print(relations) 57 | return relations, subjects, objects 58 | 59 | 60 | 61 | def get_subject_ids(subject_preds, mask): 62 | lengths = torch.sum(mask, -1) 63 | starts = subject_preds[:, :, :1] 64 | ends = subject_preds[:, :, 1:] 65 | subject_ids = [] 66 | for start, end, l in zip(starts, ends, lengths): 67 | tmp = [] 68 | start = start.squeeze()[:l] 69 | end = end.squeeze()[:l] 70 | for i, st in enumerate(start): 71 | if st > 0.5: 72 | s = i 73 | for j in range(i, l): 74 | if end[j] > 0.5: 75 | e = j 76 | if (s,e) not in subject_ids: 77 | tmp.append([s,e]) 78 | break 79 | 80 | subject_ids.append(tmp) 81 | return subject_ids 82 | 83 | class BertForRe: 84 | def __init__(self, args, train_loader, dev_loader, test_loader, id2tag, tag2id, model, device): 85 | self.train_loader = train_loader 86 | self.dev_loader = dev_loader 87 | self.test_loader = test_loader 88 | self.args = args 89 | self.id2tag = id2tag 90 | self.tag2id = tag2id 91 | self.model = model 92 | self.device = device 93 | if train_loader is not None: 94 | self.t_total = len(self.train_loader) * args.train_epochs 95 | self.optimizer, self.scheduler = build_optimizer_and_scheduler(args, model, self.t_total) 96 | 97 | def train(self): 98 | # Train 99 | global_step = 0 100 | self.model.zero_grad() 101 
| eval_steps = 500 #每多少个step打印损失及进行验证 102 | best_f1 = 0.0 103 | for epoch in range(self.args.train_epochs): 104 | for step, batch_data in enumerate(self.train_loader): 105 | self.model.train() 106 | for batch in batch_data[:-1]: 107 | batch = batch.to(self.device) 108 | # batch_token_ids, attention_mask, token_type_ids, batch_subject_labels, batch_object_labels, batch_subject_ids 109 | loss = self.model(batch_data[0], batch_data[1], batch_data[2], batch_data[3], batch_data[4], batch_data[5]) 110 | 111 | # loss.backward(loss.clone().detach()) 112 | loss.backward() 113 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 114 | self.optimizer.step() 115 | self.scheduler.step() 116 | self.model.zero_grad() 117 | logger.info('【train】 epoch:{} {}/{} loss:{:.4f}'.format(epoch, global_step, self.t_total, loss.item())) 118 | global_step += 1 119 | if self.args.use_tensorboard == "True": 120 | writer.add_scalar('data/loss', loss.item(), global_step) 121 | if global_step % eval_steps == 0: 122 | precision, recall, f1_score = self.dev() 123 | logger.info('[eval] precision={:.4f} recall={:.4f} f1_score={:.4f}'.format(precision, recall, f1_score)) 124 | if f1_score > best_f1: 125 | save_model(self.args, self.model, model_name, global_step) 126 | best_f1 = f1_score 127 | 128 | 129 | def dev(self): 130 | self.model.eval() 131 | spos = [] 132 | true_spos = [] 133 | subjects = [] 134 | objects = [] 135 | all_examples = [] 136 | with torch.no_grad(): 137 | for eval_step, dev_batch_data in enumerate(self.dev_loader): 138 | for dev_batch in dev_batch_data[:-1]: 139 | dev_batch = dev_batch.to(self.device) 140 | 141 | seq_output, subject_preds = self.model.predict_subject(dev_batch_data[0], dev_batch_data[1],dev_batch_data[2]) 142 | # 注意这里需要先获取subject,然后再来获取object和关系,和训练直接使用subject_ids不一样 143 | cur_batch_size = dev_batch_data[0].shape[0] 144 | dev_examples = dev_batch_data[-1] 145 | true_spos += [i[1] for i in dev_examples] 146 | all_examples += [i[0] for i in dev_examples] 147 | subject_labels = dev_batch_data[3].cpu().numpy() 148 | object_labels = dev_batch_data[4].cpu().numpy() 149 | subject_ids = get_subject_ids(subject_preds, dev_batch_data[1]) 150 | 151 | example_lengths = torch.sum(dev_batch_data[1].cpu(), -1) 152 | 153 | for i in range(cur_batch_size): 154 | seq_output_tmp = seq_output[i, ...] 
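                    # CasRel decoding: the encoder output for this example is tiled once per
                    # predicted subject span, and predict_object conditions on each span (via
                    # conditional LayerNorm) to score object start/end positions for every
                    # relation, using the same 0.5 threshold as the subject decoding.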
155 | subject_ids_tmp = subject_ids[i] 156 | length = example_lengths[i] 157 | example = dev_examples[i][0] 158 | if subject_ids_tmp: 159 | seq_output_tmp = seq_output_tmp.unsqueeze(0).repeat(len(subject_ids_tmp), 1, 1) 160 | subject_ids_tmp = torch.tensor(subject_ids_tmp, dtype=torch.long, device=device) 161 | if len(seq_output_tmp.shape) == 2: 162 | seq_output_tmp = seq_output_tmp.unsqueeze(0) 163 | object_preds = model.predict_object([seq_output_tmp, subject_ids_tmp]) 164 | spo, subject, object = get_spo(object_preds, subject_ids_tmp, length, example, self.id2tag) 165 | spos.append(spo) 166 | subjects.append(subject) 167 | objects.append(object) 168 | else: 169 | spos.append([]) 170 | subjects.append([]) 171 | objects.append([]) 172 | 173 | # for m,n, ex in zip(spos, true_spos, all_examples): 174 | # print(ex) 175 | # print(m, n) 176 | # print('='*100) 177 | tp, fp, fn = calculate_metric_relation(spos, true_spos) 178 | p, r, f1 = get_p_r_f(tp, fp, fn) 179 | # print("========metric========") 180 | # print("precision:{} recall:{} f1:{}".format(p, r, f1)) 181 | 182 | return p, r, f1 183 | 184 | 185 | 186 | def test(self, model_path): 187 | model = Casrel(self.args, self.tag2id) 188 | model, device = load_model_and_parallel(model, self.args.gpu_ids, model_path) 189 | model.eval() 190 | spos = [] 191 | true_spos = [] 192 | subjects = [] 193 | objects = [] 194 | all_examples = [] 195 | with torch.no_grad(): 196 | for eval_step, dev_batch_data in enumerate(dev_loader): 197 | for dev_batch in dev_batch_data[:-1]: 198 | dev_batch = dev_batch.to(device) 199 | 200 | seq_output, subject_preds = model.predict_subject(dev_batch_data[0], dev_batch_data[1],dev_batch_data[2]) 201 | # 注意这里需要先获取subject,然后再来获取object和关系,和训练直接使用subject_ids不一样 202 | cur_batch_size = dev_batch_data[0].shape[0] 203 | dev_examples = dev_batch_data[-1] 204 | true_spos += [i[1] for i in dev_examples] 205 | all_examples += [i[0] for i in dev_examples] 206 | subject_labels = dev_batch_data[3].cpu().numpy() 207 | object_labels = dev_batch_data[4].cpu().numpy() 208 | subject_ids = get_subject_ids(subject_preds, dev_batch_data[1]) 209 | 210 | example_lengths = torch.sum(dev_batch_data[1].cpu(), -1) 211 | 212 | for i in range(cur_batch_size): 213 | seq_output_tmp = seq_output[i, ...] 
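                    # subject_ids[i] holds the [start, end] spans whose scores crossed 0.5 for
                    # this example; when the list is empty, the else branch below records empty
                    # subjects / objects / relations so predictions stay aligned with true_spos.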
214 | subject_ids_tmp = subject_ids[i] 215 | length = example_lengths[i] 216 | example = dev_examples[i][0] 217 | if subject_ids_tmp: 218 | seq_output_tmp = seq_output_tmp.unsqueeze(0).repeat(len(subject_ids_tmp), 1, 1) 219 | subject_ids_tmp = torch.tensor(subject_ids_tmp, dtype=torch.long, device=device) 220 | if len(seq_output_tmp.shape) == 2: 221 | seq_output_tmp = seq_output_tmp.unsqueeze(0) 222 | object_preds = model.predict_object([seq_output_tmp, subject_ids_tmp]) 223 | spo, subject, object = get_spo(object_preds, subject_ids_tmp, length, example, self.id2tag) 224 | spos.append(spo) 225 | subjects.append(subject) 226 | objects.append(object) 227 | else: 228 | spos.append([]) 229 | subjects.append([]) 230 | objects.append([]) 231 | 232 | 233 | 234 | for i, (m,n, ex) in enumerate(zip(spos, true_spos, all_examples)): 235 | if i <= 10: 236 | print(ex) 237 | print(m, n) 238 | print('='*100) 239 | tp, fp, fn = calculate_metric_relation(spos, true_spos) 240 | p, r, f1 = get_p_r_f(tp, fp, fn) 241 | print("========metric========") 242 | print("precision:{} recall:{} f1:{}".format(p, r, f1)) 243 | 244 | return p, r, f1 245 | 246 | def predict(self, raw_text, model, tokenizer): 247 | model.eval() 248 | with torch.no_grad(): 249 | tokens = [i for i in raw_text] 250 | if len(tokens) > self.args.max_seq_len: 251 | tokens = tokens[:self.args.max_seq_len] 252 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 253 | attention_masks = [1] * len(token_ids) 254 | token_type_ids = [0] * len(token_ids) 255 | if len(token_ids) < self.args.max_seq_len: 256 | token_ids = token_ids + [0] * (self.args.max_seq_len - len(tokens)) 257 | attention_masks = attention_masks + [0] * (self.args.max_seq_len - len(tokens)) 258 | token_type_ids = token_type_ids + [0] * (self.args.max_seq_len - len(tokens)) 259 | assert len(token_ids) == self.args.max_seq_len 260 | assert len(attention_masks) == self.args.max_seq_len 261 | assert len(token_type_ids) == self.args.max_seq_len 262 | token_ids = torch.from_numpy(np.array(token_ids)).unsqueeze(0).to(self.device) 263 | attention_masks = torch.from_numpy(np.array(attention_masks, dtype=np.uint8)).unsqueeze(0).to(self.device) 264 | token_type_ids = torch.from_numpy(np.array(token_type_ids)).unsqueeze(0).to(self.device) 265 | seq_output, subject_preds = model.predict_subject(token_ids, attention_masks, token_type_ids) 266 | subject_ids = get_subject_ids(subject_preds, attention_masks) 267 | 268 | cur_batch_size = seq_output.shape[0] 269 | spos = [] 270 | subjects = [] 271 | objects = [] 272 | for i in range(cur_batch_size): 273 | seq_output_tmp = seq_output[i, ...] 
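                # Because the text was split character by character, the predicted span indices
                # point directly into raw_text, so get_spo can recover each mention with
                # example[s:e + 1] without any subword-to-character re-alignment.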
274 | subject_ids_tmp = subject_ids[i] 275 | length = len(tokens) 276 | example = raw_text 277 | if any(subject_ids_tmp): 278 | seq_output_tmp = seq_output_tmp.unsqueeze(0).repeat(len(subject_ids_tmp), 1, 1) 279 | 280 | subject_ids_tmp = torch.tensor(subject_ids_tmp, dtype=torch.long, device=device) 281 | if len(seq_output_tmp.shape) == 2: 282 | seq_output_tmp = seq_output_tmp.unsqueeze(0) 283 | object_preds = model.predict_object([seq_output_tmp, subject_ids_tmp]) 284 | 285 | spo, subject, object = get_spo(object_preds, subject_ids_tmp, length, example, self.id2tag) 286 | 287 | subjects.append(subject) 288 | objects.append(object) 289 | spos.append(spo) 290 | print("文本:", raw_text) 291 | print('主体:', subjects) 292 | print('客体:', objects) 293 | print('关系:', spos) 294 | print("="*100) 295 | 296 | if __name__ == '__main__': 297 | data_name = 'ske' 298 | model_name = 'bert' 299 | 300 | set_logger(os.path.join(args.log_dir, '{}.log'.format(model_name))) 301 | if data_name == "ske": 302 | args.data_dir = './data/ske' 303 | data_path = os.path.join(args.data_dir, 'raw_data') 304 | label_list = read_json(os.path.join(args.data_dir, 'mid_data'), 'predicates') 305 | tag2id = {} 306 | id2tag = {} 307 | for k,v in enumerate(label_list): 308 | tag2id[v] = k 309 | id2tag[k] = v 310 | 311 | logger.info(args) 312 | max_seq_len = args.max_seq_len 313 | tokenizer = BertTokenizer.from_pretrained('model_hub/chinese-bert-wwm-ext/vocab.txt') 314 | 315 | model = Casrel(args, tag2id) 316 | model, device = load_model_and_parallel(model, args.gpu_ids) 317 | 318 | collate = data_loader.Collate(max_len=max_seq_len, tag2id=tag2id, device=device, tokenizer=tokenizer) 319 | 320 | 321 | train_dataset = data_loader.MyDataset(file_path=os.path.join(data_path, 'train_data.json'), 322 | tokenizer=tokenizer, 323 | max_len=max_seq_len) 324 | 325 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate.collate_fn) 326 | dev_dataset = data_loader.MyDataset(file_path=os.path.join(data_path, 'dev_data.json'), 327 | tokenizer=tokenizer, 328 | max_len=max_seq_len) 329 | 330 | dev_dataset = dev_dataset[:args.use_dev_num] 331 | dev_loader = DataLoader(dev_dataset, batch_size=args.eval_batch_size, shuffle=False, collate_fn=collate.collate_fn) 332 | 333 | 334 | bertForNer = BertForRe(args, train_loader, dev_loader, dev_loader, id2tag, tag2id, model, device) 335 | # bertForNer.train() 336 | 337 | model_path = './checkpoints/bert/model.pt'.format(model_name) 338 | # bertForNer.test(model_path) 339 | 340 | texts = [ 341 | '查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部', 342 | '《离开》是由张宇谱曲,演唱', 343 | '《愤怒的唐僧》由北京吴意波影视文化工作室与优酷电视剧频道联合制作,故事以喜剧元素为主,讲述唐僧与佛祖打牌,得罪了佛祖,被踢下人间再渡九九八十一难的故事', 344 | '李治即位后,萧淑妃受宠,王皇后为了排挤萧淑妃,答应李治让身在感业寺的武则天续起头发,重新纳入后宫', 345 | '《工业4.0》是2015年机械工业出版社出版的图书,作者是(德)阿尔冯斯·波特霍夫,恩斯特·安德雷亚斯·哈特曼', 346 | '周佛海被捕入狱之后,其妻杨淑慧散尽家产请蒋介石枪下留人,于是周佛海从死刑变为无期,不过此人或许作恶多端,改判没多久便病逝于监狱,据悉是心脏病发作', 347 | '《李烈钧自述》是2011年11月1日人民日报出版社出版的图书,作者是李烈钧', 348 | '除演艺事业外,李冰冰热心公益,发起并亲自参与多项环保慈善活动,积极投身其中,身体力行担起了回馈社会的责任于02年出演《少年包青天》,进入大家视线', 349 | '马志舟,1907年出生,陕西三原人,汉族,中国共产党,任红四团第一连连长,1933年逝世', 350 | '斑刺莺是雀形目、剌嘴莺科的一种动物,分布于澳大利亚和新西兰,包括澳大利亚、新西兰、塔斯马尼亚及其附近的岛屿', 351 | '《课本上学不到的生物学2》是2013年上海科技教育出版社出版的图书', 352 | ] 353 | model = Casrel(args, tag2id) 354 | model, device = load_model_and_parallel(model, args.gpu_ids, model_path) 355 | for text in texts: 356 | bertForNer.predict(text, model, tokenizer) 357 | 358 | -------------------------------------------------------------------------------- 
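For quick reference, the block below is a minimal standalone prediction sketch distilled from the `predict` flow in `main.py`; the checkpoint path, `bert_dir` and the example sentence are assumptions — adjust them to your local setup.

```python
# Hypothetical predict_demo.py — a sketch only; paths below are assumptions.
import torch
from transformers import BertTokenizer

import config
from model import Casrel
from main import get_subject_ids, get_spo          # stateless helpers from main.py
from utils.common_utils import read_json
from utils.train_utils import load_model_and_parallel

args = config.Args().get_parser()
args.bert_dir = 'model_hub/chinese-bert-wwm-ext/'  # assumed, same as the README run command

# relation label list produced by process.py (see README step 1)
label_list = read_json('./data/ske/mid_data', 'predicates')
tag2id = {label: i for i, label in enumerate(label_list)}
id2tag = {i: label for i, label in enumerate(label_list)}

tokenizer = BertTokenizer.from_pretrained('model_hub/chinese-bert-wwm-ext/vocab.txt')
model = Casrel(args, tag2id)
# set gpu_ids='-1' for CPU; the checkpoint path matches model_path in main.py
model, device = load_model_and_parallel(model, args.gpu_ids, './checkpoints/bert/model.pt')

text = '《离开》是由张宇谱曲,演唱'
tokens = list(text)[:args.max_seq_len]             # char-level tokens, as in main.predict
pad_len = args.max_seq_len - len(tokens)
token_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens) + [0] * pad_len],
                         dtype=torch.long, device=device)
attention_mask = torch.tensor([[1] * len(tokens) + [0] * pad_len],
                              dtype=torch.long, device=device)
token_type_ids = torch.zeros_like(token_ids)

# stage 1: subject spans; stage 2: objects/relations conditioned on each subject span
seq_output, subject_preds = model.predict_subject(token_ids, attention_mask, token_type_ids)
subject_spans = get_subject_ids(subject_preds, attention_mask)[0]   # [[start, end], ...]
if subject_spans:
    subject_ids = torch.tensor(subject_spans, dtype=torch.long, device=device)
    seq_rep = seq_output[0].unsqueeze(0).repeat(len(subject_spans), 1, 1)
    object_preds = model.predict_object([seq_rep, subject_ids])
    spo, subjects, objects = get_spo(object_preds, subject_ids, len(tokens), text, id2tag)
    print('主体:', subjects, '客体:', objects, '关系:', spo)
```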
/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from transformers import BertModel 4 | 5 | 6 | class MyBCELoss(nn.BCELoss): 7 | def __init__(self, **kwargs): 8 | super().__init__(**kwargs) 9 | def forward(self, inputs, targets): 10 | # subject_preds: torch.Size([16, 256, 2]) 11 | # subject_labels: torch.Size([16, 256, 2]) 12 | # object_labels: torch.Size([16, 256, 49, 2]) 13 | # object_preds: torch.Size([16, 256, 49, 2]) 14 | subject_preds, object_preds = inputs 15 | subject_labels, object_labels, mask = targets 16 | # sujuect部分loss 17 | subject_loss = super().forward(subject_preds, subject_labels) 18 | subject_loss = subject_loss.mean(dim=-1) 19 | subject_loss = (subject_loss * mask).sum() / mask.sum() 20 | # object部分loss 21 | object_loss = super().forward(object_preds, object_labels) 22 | object_loss = object_loss.mean(dim=-1).sum(dim=-1) 23 | object_loss = (object_loss * mask).sum() / mask.sum() 24 | return subject_loss + object_loss 25 | 26 | 27 | class LayerNorm(nn.Module): 28 | def __init__(self, hidden_size, eps=1e-12, conditional_size=False, weight=True, bias=True, norm_mode='normal', **kwargs): 29 | """layernorm 层,这里自行实现,目的是为了兼容 conditianal layernorm,使得可以做条件文本生成、条件分类等任务 30 | 条件layernorm来自于苏剑林的想法,详情:https://spaces.ac.cn/archives/7124 31 | """ 32 | super(LayerNorm, self).__init__() 33 | 34 | # 兼容roformer_v2不包含weight 35 | if weight: 36 | self.weight = nn.Parameter(torch.ones(hidden_size)) 37 | # 兼容t5不包含bias项, 和t5使用的RMSnorm 38 | if bias: 39 | self.bias = nn.Parameter(torch.zeros(hidden_size)) 40 | self.norm_mode = norm_mode 41 | 42 | self.eps = eps 43 | self.conditional_size = conditional_size 44 | if conditional_size: 45 | # 条件layernorm, 用于条件文本生成, 46 | # 这里采用全零初始化, 目的是在初始状态不干扰原来的预训练权重 47 | self.dense1 = nn.Linear(conditional_size, hidden_size, bias=False) 48 | self.dense1.weight.data.uniform_(0, 0) 49 | self.dense2 = nn.Linear(conditional_size, hidden_size, bias=False) 50 | self.dense2.weight.data.uniform_(0, 0) 51 | 52 | def forward(self, x): 53 | inputs = x[0] # 这里是visible_hiddens 54 | 55 | if self.norm_mode == 'rmsnorm': 56 | # t5使用的是RMSnorm 57 | variance = inputs.to(torch.float32).pow(2).mean(-1, keepdim=True) 58 | o = inputs * torch.rsqrt(variance + self.eps) 59 | else: 60 | # 归一化是针对于inputs 61 | u = inputs.mean(-1, keepdim=True) 62 | s = (inputs - u).pow(2).mean(-1, keepdim=True) 63 | o = (inputs - u) / torch.sqrt(s + self.eps) 64 | 65 | if not hasattr(self, 'weight'): 66 | self.weight = 1 67 | if not hasattr(self, 'bias'): 68 | self.bias = 0 69 | 70 | if self.conditional_size: 71 | cond = x[1] # 这里是repeat_hiddens 72 | # 三者的形状都是一致的 73 | # print(inputs.shape, cond.shape, o.shape) 74 | for _ in range(len(inputs.shape) - len(cond.shape)): 75 | cond = cond.unsqueeze(dim=1) 76 | 77 | return (self.weight + self.dense1(cond)) * o + (self.bias + self.dense2(cond)) 78 | else: 79 | return self.weight * o + self.bias 80 | 81 | 82 | # 定义bert上的模型结构 83 | class Casrel(nn.Module): 84 | def __init__(self, args, tag2id): 85 | super().__init__() 86 | self.bert = BertModel.from_pretrained(args.bert_dir) 87 | self.tag2id = tag2id 88 | self.linear1 = nn.Linear(768, 2) 89 | # 768*2 90 | self.condLayerNorm = LayerNorm(hidden_size=768, conditional_size=768*2) 91 | self.linear2 = nn.Linear(768, len(tag2id)*2) 92 | self.crierion = MyBCELoss() 93 | 94 | @staticmethod 95 | def extract_subject(inputs): 96 | """根据subject_ids从output中取出subject的向量表征 97 | """ 98 | output, subject_ids = inputs 99 | start = 
torch.gather(output, dim=1, index=subject_ids[:, :1].unsqueeze(2).expand(-1, -1, output.shape[-1])) 100 | end = torch.gather(output, dim=1, index=subject_ids[:, 1:].unsqueeze(2).expand(-1, -1, output.shape[-1])) 101 | subject = torch.cat([start, end], 2) 102 | # print(subject.shape) 103 | return subject[:, 0] 104 | 105 | def forward(self, 106 | token_ids, 107 | attention_masks, 108 | token_type_ids, 109 | subject_labels=None, 110 | object_labels=None, 111 | subject_ids=None): 112 | # 预测subject 113 | bert_outputs = self.bert( 114 | input_ids=token_ids, 115 | attention_mask=attention_masks, 116 | token_type_ids=token_type_ids 117 | ) 118 | seq_output = bert_outputs[0] # [btz, seq_len, hdsz] 119 | subject_preds = (torch.sigmoid(self.linear1(seq_output)))**2 # [btz, seq_len, 2] 120 | 121 | # 传入subject,预测object 122 | # 通过Conditional Layer Normalization将subject融入到object的预测中 123 | # 理论上应该用LayerNorm前的,但是这样只能返回各个block顶层输出,这里和keras实现不一致 124 | subject = self.extract_subject([seq_output, subject_ids]) 125 | output = self.condLayerNorm([seq_output, subject]) 126 | output = (torch.sigmoid(self.linear2(output)))**4 127 | object_preds = output.reshape(*output.shape[:2], len(self.tag2id), 2) 128 | # print(object_preds.shape, object_labels.shape) 129 | loss = self.crierion([subject_preds, object_preds], [subject_labels, object_labels, attention_masks]) 130 | return loss 131 | 132 | def predict_subject(self, token_ids, attention_masks, token_type_ids): 133 | self.eval() 134 | with torch.no_grad(): 135 | bert_outputs = self.bert( 136 | input_ids=token_ids, 137 | attention_mask=attention_masks, 138 | token_type_ids=token_type_ids 139 | ) 140 | seq_output = bert_outputs[0] # [btz, seq_len, hdsz] 141 | subject_preds = (torch.sigmoid(self.linear1(seq_output)))**2 # [btz, seq_len, 2] 142 | return seq_output, subject_preds 143 | 144 | def predict_object(self, inputs): 145 | self.eval() 146 | with torch.no_grad(): 147 | seq_output, subject_ids = inputs 148 | subject = self.extract_subject([seq_output, subject_ids]) 149 | output = self.condLayerNorm([seq_output, subject]) 150 | output = (torch.sigmoid(self.linear2(output)))**4 151 | object_preds = output.reshape(*output.shape[:2], len(self.tag2id), 2) 152 | return object_preds -------------------------------------------------------------------------------- /model_hub/占位.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_casrel_triple_extraction/fb15e5807e6815f770092ab72e1ea7db019ab5c3/model_hub/占位.txt -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | from transformers import BertTokenizer 5 | from utils import common_utils 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class InputExample: 11 | def __init__(self, set_type, text, labels=None): 12 | self.set_type = set_type 13 | self.text = text 14 | self.labels = labels 15 | 16 | def __repr__(self): 17 | string = "" 18 | for key, value in self.__dict__.items(): 19 | string += f"{key}: {value}\n" 20 | return f"<{string}>" 21 | 22 | class ReProcessor: 23 | @staticmethod 24 | def read_json(file_path): 25 | pass 26 | 27 | def get_examples(self, raw_examples, set_type): 28 | pass 29 | 30 | 31 | class SKEProcessor(ReProcessor): 32 | @staticmethod 33 | def read_json(file_path): 34 | with open(file_path, encoding='utf-8') as f: 35 | raw_examples = 
f.readlines() 36 | return raw_examples 37 | 38 | def get_examples(self, raw_examples, set_type): 39 | examples = [] 40 | # 这里是从json数据中的字典中获取 41 | for i, item in enumerate(raw_examples): 42 | # print(i,item) 43 | item = json.loads(item) 44 | text = item['text'] 45 | spo_list = item['spo_list'] 46 | labels = [] # [subject, predicate, object] 47 | for spo in spo_list: 48 | subject = spo['subject'] 49 | object = spo['object'] 50 | predicate = spo['predicate'] 51 | labels.append([subject, predicate, object]) 52 | examples.append(InputExample(set_type=set_type, 53 | text=text, 54 | labels=labels)) 55 | 56 | return examples 57 | 58 | 59 | 60 | if __name__ == "__main__": 61 | skeProcessor = SKEProcessor() 62 | raw_examples = skeProcessor.read_json('data/ske/raw_data/train_data.json') 63 | examples = skeProcessor.get_examples(raw_examples, set_type="train") 64 | for i in range(5): 65 | print(examples[i]) -------------------------------------------------------------------------------- /utils/__pycache__/common_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_casrel_triple_extraction/fb15e5807e6815f770092ab72e1ea7db019ab5c3/utils/__pycache__/common_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metric_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_casrel_triple_extraction/fb15e5807e6815f770092ab72e1ea7db019ab5c3/utils/__pycache__/metric_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/train_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_casrel_triple_extraction/fb15e5807e6815f770092ab72e1ea7db019ab5c3/utils/__pycache__/train_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import random 3 | import os 4 | import json 5 | import logging 6 | import time 7 | import pickle 8 | import numpy as np 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | 13 | def trans_ij2k(seq_len, i, j): 14 | '''把第i行,第j列转化成上三角flat后的序号 15 | ''' 16 | if (i > seq_len - 1) or (j > seq_len - 1) or (i > j): 17 | return 0 18 | return int(0.5*(2*seq_len-i+1)*i+(j-i)) 19 | 20 | def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'): 21 | """将序列padding到同一长度 22 | """ 23 | if isinstance(inputs[0], (np.ndarray, list)): 24 | if length is None: 25 | length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0) 26 | elif not hasattr(length, '__getitem__'): 27 | length = [length] 28 | 29 | slices = [np.s_[:length[i]] for i in range(seq_dims)] 30 | slices = tuple(slices) if len(slices) > 1 else slices[0] 31 | pad_width = [(0, 0) for _ in np.shape(inputs[0])] 32 | 33 | outputs = [] 34 | for x in inputs: 35 | x = x[slices] 36 | for i in range(seq_dims): 37 | if mode == 'post': 38 | pad_width[i] = (0, length[i] - np.shape(x)[i]) 39 | elif mode == 'pre': 40 | pad_width[i] = (length[i] - np.shape(x)[i], 0) 41 | else: 42 | raise ValueError('"mode" argument must be "post" or "pre".') 43 | x = np.pad(x, pad_width, 'constant', constant_values=value) 44 | 
outputs.append(x) 45 | 46 | return np.array(outputs) 47 | 48 | elif isinstance(inputs[0], torch.Tensor): 49 | assert mode == 'post', '"mode" argument must be "post" when element is torch.Tensor' 50 | if length is not None: 51 | inputs = [i[:length] for i in inputs] 52 | return pad_sequence(inputs, padding_value=value, batch_first=True) 53 | else: 54 | raise ValueError('"input" argument must be tensor/list/ndarray.') 55 | 56 | 57 | def timer(func): 58 | """ 59 | 函数计时器 60 | :param func: 61 | :return: 62 | """ 63 | 64 | @functools.wraps(func) 65 | def wrapper(*args, **kwargs): 66 | start = time.time() 67 | res = func(*args, **kwargs) 68 | end = time.time() 69 | print("{}共耗时约{:.4f}秒".format(func.__name__, end - start)) 70 | return res 71 | 72 | return wrapper 73 | 74 | 75 | def set_seed(seed=123): 76 | """ 77 | 设置随机数种子,保证实验可重现 78 | :param seed: 79 | :return: 80 | """ 81 | random.seed(seed) 82 | torch.manual_seed(seed) 83 | np.random.seed(seed) 84 | torch.cuda.manual_seed_all(seed) 85 | 86 | 87 | def set_logger(log_path): 88 | """ 89 | 配置log 90 | :param log_path:s 91 | :return: 92 | """ 93 | logger = logging.getLogger() 94 | logger.setLevel(logging.INFO) 95 | 96 | # 由于每调用一次set_logger函数,就会创建一个handler,会造成重复打印的问题,因此需要判断root logger中是否已有该handler 97 | if not any(handler.__class__ == logging.FileHandler for handler in logger.handlers): 98 | file_handler = logging.FileHandler(log_path) 99 | formatter = logging.Formatter( 100 | '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s') 101 | file_handler.setFormatter(formatter) 102 | logger.addHandler(file_handler) 103 | 104 | if not any(handler.__class__ == logging.StreamHandler for handler in logger.handlers): 105 | stream_handler = logging.StreamHandler() 106 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 107 | logger.addHandler(stream_handler) 108 | 109 | 110 | def save_json(data_dir, data, desc): 111 | """保存数据为json""" 112 | with open(os.path.join(data_dir, '{}.json'.format(desc)), 'w', encoding='utf-8') as f: 113 | json.dump(data, f, ensure_ascii=False, indent=2) 114 | 115 | 116 | def read_json(data_dir, desc): 117 | """读取数据为json""" 118 | with open(os.path.join(data_dir, '{}.json'.format(desc)), 'r', encoding='utf-8') as f: 119 | data = json.load(f) 120 | return data 121 | 122 | 123 | def save_pkl(data_dir, data, desc): 124 | """保存.pkl文件""" 125 | with open(os.path.join(data_dir, '{}.pkl'.format(desc)), 'wb') as f: 126 | pickle.dump(data, f) 127 | 128 | 129 | def read_pkl(data_dir, desc): 130 | """读取.pkl文件""" 131 | with open(os.path.join(data_dir, '{}.pkl'.format(desc)), 'rb') as f: 132 | data = pickle.load(f) 133 | return data 134 | 135 | 136 | def fine_grade_tokenize(raw_text, tokenizer): 137 | """ 138 | 序列标注任务 BERT 分词器可能会导致标注偏移, 139 | 用 char-level 来 tokenize 140 | """ 141 | tokens = [] 142 | 143 | for _ch in raw_text: 144 | if _ch in [' ', '\t', '\n']: 145 | tokens.append('[BLANK]') 146 | else: 147 | if not len(tokenizer.tokenize(_ch)): 148 | tokens.append('[INV]') 149 | else: 150 | tokens.append(_ch) 151 | 152 | return tokens -------------------------------------------------------------------------------- /utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | 9 | 10 | def calculate_metric_relation(gt, predict): 11 | tp, fp, fn = 0, 0, 0 12 | for 
entity_predict, entity_gt in zip(predict, gt): 13 | # print(entity_predict, entity_gt) 14 | flag = 0 15 | for ent in entity_predict: 16 | if ent in entity_gt: 17 | tp += 1 18 | else: 19 | fp += 1 20 | fn = sum([len(i) for i in gt]) - tp 21 | return tp, fp, fn 22 | 23 | 24 | 25 | def calculate_metric(gt, predict): 26 | """ 27 | 计算 tp fp fn 28 | """ 29 | tp, fp, fn = 0, 0, 0 30 | for entity_predict in predict: 31 | flag = 0 32 | for entity_gt in gt: 33 | if entity_predict[0] == entity_gt[0] and entity_predict[1] == entity_gt[1]: 34 | flag = 1 35 | tp += 1 36 | break 37 | if flag == 0: 38 | fp += 1 39 | 40 | fn = len(gt) - tp 41 | 42 | return np.array([tp, fp, fn]) 43 | 44 | 45 | def get_p_r_f(tp, fp, fn): 46 | p = tp / (tp + fp) if tp + fp != 0 else 0 47 | r = tp / (tp + fn) if tp + fn != 0 else 0 48 | f1 = 2 * p * r / (p + r) if p + r != 0 else 0 49 | return np.array([p, r, f1]) 50 | 51 | def classification_report(metrics_matrix, label_list, id2label, total_count, digits=2, suffix=False): 52 | name_width = max([len(label) for label in label_list]) 53 | last_line_heading = 'micro-f1' 54 | width = max(name_width, len(last_line_heading), digits) 55 | 56 | headers = ["precision", "recall", "f1-score", "support"] 57 | head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers) 58 | report = head_fmt.format(u'', *headers, width=width) 59 | report += u'\n\n' 60 | 61 | row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n' 62 | 63 | ps, rs, f1s, s = [], [], [], [] 64 | for label_id, label_matrix in enumerate(metrics_matrix): 65 | type_name = id2label[label_id] 66 | p,r,f1 = get_p_r_f(label_matrix[0],label_matrix[1],label_matrix[2]) 67 | nb_true = total_count[label_id] 68 | report += row_fmt.format(*[type_name, p, r, f1, nb_true], width=width, digits=digits) 69 | ps.append(p) 70 | rs.append(r) 71 | f1s.append(f1) 72 | s.append(nb_true) 73 | 74 | report += u'\n' 75 | mirco_metrics = np.sum(metrics_matrix, axis=0) 76 | mirco_metrics = get_p_r_f(mirco_metrics[0], mirco_metrics[1], mirco_metrics[2]) 77 | # compute averages 78 | print('precision:{:.4f} recall:{:.4f} micro_f1:{:.4f}'.format(mirco_metrics[0],mirco_metrics[1],mirco_metrics[2])) 79 | report += row_fmt.format(last_line_heading, 80 | mirco_metrics[0], 81 | mirco_metrics[1], 82 | mirco_metrics[2], 83 | np.sum(s), 84 | width=width, digits=digits) 85 | 86 | return -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import logging 4 | from transformers import AdamW, get_linear_schedule_with_warmup 5 | import torch 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def build_optimizer_and_scheduler(args, model, t_total): 11 | module = ( 12 | model.module if hasattr(model, "module") else model 13 | ) 14 | 15 | # 差分学习率 16 | no_decay = ["bias", "LayerNorm.weight"] 17 | model_param = list(module.named_parameters()) 18 | 19 | bert_param_optimizer = [] 20 | other_param_optimizer = [] 21 | 22 | for name, para in model_param: 23 | space = name.split('.') 24 | # print(name) 25 | if space[0] == 'bert_module': 26 | bert_param_optimizer.append((name, para)) 27 | else: 28 | other_param_optimizer.append((name, para)) 29 | 30 | optimizer_grouped_parameters = [ 31 | # bert other module 32 | {"params": [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)], 33 | "weight_decay": args.weight_decay, 'lr': args.lr}, 34 | {"params": [p for n, p in bert_param_optimizer 
if any(nd in n for nd in no_decay)], 35 | "weight_decay": 0.0, 'lr': args.lr}, 36 | 37 | # 其他模块,差分学习率 38 | {"params": [p for n, p in other_param_optimizer if not any(nd in n for nd in no_decay)], 39 | "weight_decay": args.weight_decay, 'lr': args.other_lr}, 40 | {"params": [p for n, p in other_param_optimizer if any(nd in n for nd in no_decay)], 41 | "weight_decay": 0.0, 'lr': args.other_lr}, 42 | ] 43 | 44 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon) 45 | scheduler = get_linear_schedule_with_warmup( 46 | optimizer, num_warmup_steps=int(args.warmup_proportion * t_total), num_training_steps=t_total 47 | ) 48 | 49 | return optimizer, scheduler 50 | 51 | def save_model(args, model, model_name, global_step): 52 | """保存最好的验证集效果最好那个模型""" 53 | output_dir = os.path.join(args.output_dir, '{}'.format(model_name, global_step)) 54 | if not os.path.exists(output_dir): 55 | os.makedirs(output_dir, exist_ok=True) 56 | 57 | # take care of model distributed / parallel training 58 | model_to_save = ( 59 | model.module if hasattr(model, "module") else model 60 | ) 61 | logger.info('Saving model checkpoint to {}'.format(output_dir)) 62 | torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'model.pt')) 63 | 64 | def save_model_step(args, model, global_step): 65 | """根据global_step来保存模型""" 66 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) 67 | if not os.path.exists(output_dir): 68 | os.makedirs(output_dir, exist_ok=True) 69 | 70 | # take care of model distributed / parallel training 71 | model_to_save = ( 72 | model.module if hasattr(model, "module") else model 73 | ) 74 | logger.info('Saving model & optimizer & scheduler checkpoint to {}.format(output_dir)') 75 | torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'model.pt')) 76 | 77 | def load_model_and_parallel(model, gpu_ids, ckpt_path=None, strict=True): 78 | """ 79 | 加载模型 & 放置到 GPU 中(单卡 / 多卡) 80 | """ 81 | gpu_ids = gpu_ids.split(',') 82 | 83 | # set to device to the first cuda 84 | device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0]) 85 | 86 | if ckpt_path is not None: 87 | logger.info('Load ckpt from {}'.format(ckpt_path)) 88 | model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')), strict=strict) 89 | 90 | model.to(device) 91 | 92 | if len(gpu_ids) > 1: 93 | logger.info('Use multi gpus in: {}'.format(gpu_ids)) 94 | gpu_ids = [int(x) for x in gpu_ids] 95 | model = torch.nn.DataParallel(model, device_ids=gpu_ids) 96 | else: 97 | logger.info('Use single gpu in: {}'.format(gpu_ids)) 98 | 99 | return model, device --------------------------------------------------------------------------------
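The README's step 1 mentions creating a `process.py` to derive the relation types for `mid_data`, but that script is not included above. Below is a minimal sketch, assuming the raw SKE file is line-delimited JSON with a `spo_list` field (the same format `preprocess.py` and `data_loader.py` read); it writes `data/ske/mid_data/predicates.json` as the plain label list those modules expect.

```python
# process.py — a minimal sketch; file locations follow the repo's data/ske layout.
import json
import os

raw_path = 'data/ske/raw_data/train_data.json'
mid_dir = 'data/ske/mid_data'
os.makedirs(mid_dir, exist_ok=True)

predicates = []
with open(raw_path, encoding='utf-8') as f:
    for line in f:
        item = json.loads(line)
        for spo in item['spo_list']:
            if spo['predicate'] not in predicates:
                predicates.append(spo['predicate'])

with open(os.path.join(mid_dir, 'predicates.json'), 'w', encoding='utf-8') as f:
    json.dump(predicates, f, ensure_ascii=False, indent=2)

print('num predicates:', len(predicates))  # should equal --num_tags (49 for SKE)
```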