├── README.md ├── checkpoints └── 占位.txt ├── config.py ├── data └── ske │ ├── mid_data │ └── predicates.json │ └── raw_data │ ├── all_50_schemas │ ├── dev_data.json │ └── process.py ├── data_loader.py ├── logs └── bert.log ├── main.py ├── model.py ├── model_hub └── 占位.txt └── utils ├── __pycache__ ├── common_utils.cpython-37.pyc ├── metric_utils.cpython-37.pyc └── train_utils.cpython-37.pyc ├── common_utils.py ├── metric_utils.py └── train_utils.py /README.md: -------------------------------------------------------------------------------- 1 | # pytorch_GlobalPointer_triple_extraction 2 | 3 | 基于pytorch的GlobalPointer进行三元组抽取。 4 | 5 | 具体使用说明: 6 | 7 | - 1、在data/ske/raw_data下是原始数据,新建一个process.py,主要是得到mid_data下的关系的类型。 8 | - 2、针对于不同的数据源,在data_loader.py中修改MyDataset类下,返回的是一个列表,列表中的每个元素是:(text, labels),其中labels是[[主体,类别,客体]]。 9 | - 3、运行main.py进行训练、验证、测试和预测。 10 | 11 | 数据和模型下载地址:链接:https://pan.baidu.com/s/1HOaGUiRsknIBtXS_ASNF-w?pwd=rm2u 提取码:rm2u 12 | 13 | # 依赖 14 | 15 | ``` 16 | pytorch==1.6.0 17 | transformers==4.5.0 18 | tensorboardX 19 | ``` 20 | 21 | **特别注意**:eval_steps要根据总的steps进行合理设置,不要设置得太小,否则在初期就进行验证会极慢(因为预测的负样本太多了)。 22 | # 运行 23 | 24 | ```python 25 | python main.py \ 26 | --bert_dir="model_hub/chinese-bert-wwm-ext/" \ 27 | --data_dir="./data/ske/" \ 28 | --log_dir="./logs/" \ 29 | --output_dir="./checkpoints/" \ 30 | --num_tags=49 \ 31 | --seed=123 \ 32 | --gpu_ids="0" \ 33 | --max_seq_len=256 \ 34 | --lr=5e-5 \ 35 | --other_lr=5e-5 \ 36 | --train_batch_size=32 \ 37 | --train_epochs=1 \ 38 | --eval_steps=500 \ 39 | --eval_batch_size=8 \ 40 | --max_grad_norm=1 \ 41 | --warmup_proportion=0.1 \ 42 | --adam_epsilon=1e-8 \ 43 | --weight_decay=0.01 \ 44 | --dropout_prob=0.1 \ 45 | --use_tensorboard="False" \ 46 | --use_dev_num=1000 47 | ``` 48 | 49 | ### 结果 50 | 51 | 这里以batch_size=32运行了2000步。 52 | 53 | ```python 54 | precision:0.7831715210355987 recall:0.7298578199052133 f1:0.7555753791257805 55 | ``` 56 | 57 | ```python 58 | 文本: 查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部 59 | 主体: [['查尔斯·阿兰基斯']] 60 | 客体: [['智利', '圣地亚哥', '智利圣地亚哥', '1989年4月17日']] 61 | 关系: [[('查尔斯·阿兰基斯', '出生日期', '1989年4月17日'), ('查尔斯·阿兰基斯', '出生地', '智利'), ('查尔斯·阿兰基斯', '国籍', '智利'), ('查尔斯·阿兰基斯', '出生地', '智利圣地亚哥'), ('查尔斯·阿兰基斯', '国籍', '智利圣地亚哥'), ('查尔斯·阿兰基斯', '出生地', '圣地亚哥')]] 62 | ==================================================================================================== 63 | 文本: 《离开》是由张宇谱曲,演唱 64 | 主体: [['离开']] 65 | 客体: [['张宇']] 66 | 关系: [[('离开', '歌手', '张宇'), ('离开', '作曲', '张宇')]] 67 | ==================================================================================================== 68 | 文本: 《愤怒的唐僧》由北京吴意波影视文化工作室与优酷电视剧频道联合制作,故事以喜剧元素为主,讲述唐僧与佛祖打牌,得罪了佛祖,被踢下人间再渡九九八十一难的故事 69 | 主体: [['愤怒的唐僧']] 70 | 客体: [['北京吴意波影视文化工作室']] 71 | 关系: [[('愤怒的唐僧', '出品公司', '北京吴意波影视文化工作室')]] 72 | ==================================================================================================== 73 | 文本: 李治即位后,萧淑妃受宠,王皇后为了排挤萧淑妃,答应李治让身在感业寺的武则天续起头发,重新纳入后宫 74 | 主体: [['李治', '萧淑妃']] 75 | 客体: [['李治', '萧淑妃']] 76 | 关系: [[('李治', '妻子', '萧淑妃'), ('萧淑妃', '丈夫', '李治')]] 77 | ==================================================================================================== 78 | 文本: 《工业4.0》是2015年机械工业出版社出版的图书,作者是(德)阿尔冯斯·波特霍夫,恩斯特·安德雷亚斯·哈特曼 79 | 主体: [['工业4.0']] 80 | 客体: [['机械工业出版社', '阿尔冯斯·波特霍夫']] 81 | 关系: [[('工业4.0', '出版社', '机械工业出版社'), ('工业4.0', '作者', '阿尔冯斯·波特霍夫')]] 82 | ==================================================================================================== 83 | 文本: 周佛海被捕入狱之后,其妻杨淑慧散尽家产请蒋介石枪下留人,于是周佛海从死刑变为无期,不过此人或许作恶多端,改判没多久便病逝于监狱,据悉是心脏病发作 84 | 主体: [['周佛海', '杨淑慧', '蒋介石']] 85 | 客体: [['周佛海', '杨淑慧', '蒋介石']] 86 | 关系: [[('周佛海', '妻子', '杨淑慧'), ('杨淑慧', '丈夫', '周佛海'), ('杨淑慧', '丈夫', '蒋介石'), ('蒋介石', '妻子', '杨淑慧')]] 87 | ==================================================================================================== 88 | 文本: 《李烈钧自述》是2011年11月1日人民日报出版社出版的图书,作者是李烈钧 89 | 主体: [['李烈钧自述']] 90 | 客体: [['李烈钧', '人民日报出版社']] 91 | 关系: [[('李烈钧自述', '作者', '李烈钧'), ('李烈钧自述', '出版社', '人民日报出版社')]] 92 | ==================================================================================================== 93 | 文本: 除演艺事业外,李冰冰热心公益,发起并亲自参与多项环保慈善活动,积极投身其中,身体力行担起了回馈社会的责任于02年出演《少年包青天》,进入大家视线 94 | 主体: [['少年包青天']] 95 | 客体: [['李冰冰']] 96 | 关系: [[('少年包青天', '主演', '李冰冰')]] 97 | ==================================================================================================== 98 | 文本: 马志舟,1907年出生,陕西三原人,汉族,中国共产党,任红四团第一连连长,1933年逝世 99 | 主体: [['马志舟']] 100 | 客体: [['汉族', '1907年', '中国', '陕西三原']] 101 | 关系: [[('马志舟', '出生日期', '1907年'), ('马志舟', '出生地', '陕西三原'), ('马志舟', '民族', '汉族'), ('马志舟', '国籍', '中国')]] 102 | ==================================================================================================== 103 | 文本: 斑刺莺是雀形目、剌嘴莺科的一种动物,分布于澳大利亚和新西兰,包括澳大利亚、新西兰、塔斯马尼亚及其附近的岛屿 104 | 主体: [['斑刺莺']] 105 | 客体: [['雀形目']] 106 | 关系: [[('斑刺莺', '目', '雀形目')]] 107 | ==================================================================================================== 108 | 文本: 《课本上学不到的生物学2》是2013年上海科技教育出版社出版的图书 109 | 主体: [['课本上学不到的生物学2']] 110 | 客体: [['上海科技教育出版社']] 111 | 关系: [[('课本上学不到的生物学2', '出版社', '上海科技教育出版社')]] 112 | ==================================================================================================== 113 | ``` 114 | 115 | # 参考 116 | 117 | > 模型参考:[bert4torch/task_relation_extraction_gplinker.py](https://github.com/Tongjilibo/bert4torch/blob/master/examples/relation_extraction/task_relation_extraction_gplinker.py) 118 | > 119 | > [将“softmax+交叉熵”推广到多标签分类问题 - 科学空间|Scientific Spaces](https://spaces.ac.cn/archives/7359) 120 | > 121 | > [GPLinker:基于GlobalPointer的实体关系联合抽取 - 科学空间|Scientific Spaces](https://spaces.ac.cn/archives/8888) 122 | 123 | -------------------------------------------------------------------------------- /checkpoints/占位.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_GlobalPointer_triple_extraction/538c9b9261c4b6c6436a6831d09e88552a2db774/checkpoints/占位.txt -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | class Args: 5 | @staticmethod 6 | def parse(): 7 | parser = argparse.ArgumentParser() 8 | return parser 9 | 10 | @staticmethod 11 | def initialize(parser): 12 | # args for path 13 | parser.add_argument('--output_dir', default='./checkpoints/', 14 | help='the output dir for model checkpoints') 15 | 16 | parser.add_argument('--bert_dir', default='../model_hub/bert-base-chinese/', 17 | help='bert dir for uer') 18 | parser.add_argument('--data_dir', default='./data/cner/', 19 | help='data dir for uer') 20 | parser.add_argument('--log_dir', default='./logs/', 21 | help='log dir for uer') 22 | 23 | # other args 24 | parser.add_argument('--num_tags', default=53, type=int, 25 | help='number of tags') 26 | parser.add_argument('--seed', type=int, default=123, help='random seed') 27 | 28 | parser.add_argument('--gpu_ids', type=str, default='0', 29 | help='gpu ids to use, -1 for cpu, "0,1" for multi gpu') 30 | 31 | parser.add_argument('--max_seq_len', default=256, type=int) 32 | 33 | parser.add_argument('--eval_batch_size', default=12, type=int) 34 | 35 | 36 | # train args 37 | parser.add_argument('--train_epochs', default=15, type=int, 38 | help='Max training epoch') 39 | 40 | parser.add_argument('--dropout_prob', default=0.1, type=float, 41 | help='drop out probability') 42 | 43 | # 2e-5 44 | parser.add_argument('--lr', default=3e-5, type=float, 45 | help='bert学习率') 46 | # 2e-3 47 | parser.add_argument('--other_lr', default=3e-4, type=float, 48 | help='bilstm和多层感知机学习率') 49 | # 0.5 50 | parser.add_argument('--max_grad_norm', default=1, type=float, 51 | help='max grad clip') 52 | parser.add_argument('--use_tensorboard', default="True", 53 | help='max grad clip') 54 | 55 | parser.add_argument('--warmup_proportion', default=0.1, type=float) 56 | 57 | parser.add_argument('--weight_decay', default=0.01, type=float) 58 | 59 | parser.add_argument('--adam_epsilon', default=1e-8, type=float) 60 | 61 | parser.add_argument('--train_batch_size', default=32, type=int) 62 | parser.add_argument('--use_dev_num', default=32, type=int, help="用于验证和测试的数量") 63 | parser.add_argument('--eval_steps', default=32, type=int, help="多少步进行验证") 64 | 65 | 66 | 67 | return parser 68 | 69 | def get_parser(self): 70 | parser = self.parse() 71 | parser = self.initialize(parser) 72 | return parser.parse_args() -------------------------------------------------------------------------------- /data/ske/mid_data/predicates.json: -------------------------------------------------------------------------------- 1 | ["目", "面积", "作词", "作者", "祖籍", "海拔", "号", "修业年限", "编剧", "父亲", "出版社", "所属专辑", "歌手", "国籍", "嘉宾", "所在城市", "邮政编码", "主演", "出生地", "占地面积", "毕业院校", "丈夫", "朝代", "字", "注册资本", "专业代码", "出生日期", "主持人", "身高", "创始人", "简称", "连载网站", "人口数量", "总部地点", "母亲", "董事长", "作曲", "民族", "制片人", "成立日期", "导演", "气候", "首都", "改编自", "出品公司", "主角", "上映时间", "官方语言", "妻子"] -------------------------------------------------------------------------------- /data/ske/raw_data/all_50_schemas: -------------------------------------------------------------------------------- 1 | {"object_type": "地点", "predicate": "祖籍", "subject_type": "人物"} 2 | {"object_type": "人物", "predicate": "父亲", "subject_type": "人物"} 3 | {"object_type": "地点", "predicate": "总部地点", "subject_type": "企业"} 4 | {"object_type": "地点", "predicate": "出生地", "subject_type": "人物"} 5 | {"object_type": "目", "predicate": "目", "subject_type": "生物"} 6 | {"object_type": "Number", "predicate": "面积", "subject_type": "行政区"} 7 | {"object_type": "Text", "predicate": "简称", "subject_type": "机构"} 8 | {"object_type": "Date", "predicate": "上映时间", "subject_type": "影视作品"} 9 | {"object_type": "人物", "predicate": "妻子", "subject_type": "人物"} 10 | {"object_type": "音乐专辑", "predicate": "所属专辑", "subject_type": "歌曲"} 11 | {"object_type": "Number", "predicate": "注册资本", "subject_type": "企业"} 12 | {"object_type": "城市", "predicate": "首都", "subject_type": "国家"} 13 | {"object_type": "人物", "predicate": "导演", "subject_type": "影视作品"} 14 | {"object_type": "Text", "predicate": "字", "subject_type": "历史人物"} 15 | {"object_type": "Number", "predicate": "身高", "subject_type": "人物"} 16 | {"object_type": "企业", "predicate": "出品公司", "subject_type": "影视作品"} 17 | {"object_type": "Number", "predicate": "修业年限", "subject_type": "学科专业"} 18 | {"object_type": "Date", "predicate": "出生日期", "subject_type": "人物"} 19 | {"object_type": "人物", "predicate": "制片人", "subject_type": "影视作品"} 20 | {"object_type": "人物", "predicate": "母亲", "subject_type": "人物"} 21 | {"object_type": "人物", "predicate": "编剧", "subject_type": "影视作品"} 22 | {"object_type": "国家", "predicate": "国籍", "subject_type": "人物"} 23 | {"object_type": "Number", "predicate": "海拔", "subject_type": "地点"} 24 | {"object_type": "网站", "predicate": "连载网站", "subject_type": "网络小说"} 25 | {"object_type": "人物", "predicate": "丈夫", "subject_type": "人物"} 26 | {"object_type": "Text", "predicate": "朝代", "subject_type": "历史人物"} 27 | {"object_type": "Text", "predicate": "民族", "subject_type": "人物"} 28 | {"object_type": "Text", "predicate": "号", "subject_type": "历史人物"} 29 | {"object_type": "出版社", "predicate": "出版社", "subject_type": "书籍"} 30 | {"object_type": "人物", "predicate": "主持人", "subject_type": "电视综艺"} 31 | {"object_type": "Text", "predicate": "专业代码", "subject_type": "学科专业"} 32 | {"object_type": "人物", "predicate": "歌手", "subject_type": "歌曲"} 33 | {"object_type": "人物", "predicate": "作词", "subject_type": "歌曲"} 34 | {"object_type": "人物", "predicate": "主角", "subject_type": "网络小说"} 35 | {"object_type": "人物", "predicate": "董事长", "subject_type": "企业"} 36 | {"object_type": "Date", "predicate": "成立日期", "subject_type": "机构"} 37 | {"object_type": "学校", "predicate": "毕业院校", "subject_type": "人物"} 38 | {"object_type": "Number", "predicate": "占地面积", "subject_type": "机构"} 39 | {"object_type": "语言", "predicate": "官方语言", "subject_type": "国家"} 40 | {"object_type": "Text", "predicate": "邮政编码", "subject_type": "行政区"} 41 | {"object_type": "Number", "predicate": "人口数量", "subject_type": "行政区"} 42 | {"object_type": "城市", "predicate": "所在城市", "subject_type": "景点"} 43 | {"object_type": "人物", "predicate": "作者", "subject_type": "图书作品"} 44 | {"object_type": "Date", "predicate": "成立日期", "subject_type": "企业"} 45 | {"object_type": "人物", "predicate": "作曲", "subject_type": "歌曲"} 46 | {"object_type": "气候", "predicate": "气候", "subject_type": "行政区"} 47 | {"object_type": "人物", "predicate": "嘉宾", "subject_type": "电视综艺"} 48 | {"object_type": "人物", "predicate": "主演", "subject_type": "影视作品"} 49 | {"object_type": "作品", "predicate": "改编自", "subject_type": "影视作品"} 50 | {"object_type": "人物", "predicate": "创始人", "subject_type": "企业"} 51 | -------------------------------------------------------------------------------- /data/ske/raw_data/process.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict, Counter 4 | from tqdm import tqdm 5 | import pandas as pd 6 | 7 | if not os.path.exists('../mid_data'): 8 | os.mkdir("../mid_data") 9 | 10 | predicates = set() 11 | 12 | with open('all_50_schemas', 'r') as fp: 13 | data = fp.readlines() 14 | for d in data: 15 | d = json.loads(d) 16 | predicates.add(d['predicate']) 17 | 18 | with open('../mid_data/predicates.json', 'w', encoding='utf-8') as fp: 19 | json.dump(list(predicates), fp, ensure_ascii=False) 20 | 21 | 22 | with open('train_data.json', 'r', encoding='utf-8') as fp: 23 | data = fp.readlines() 24 | 25 | count_lengths = [] 26 | count_predicates = defaultdict(int) 27 | for d in tqdm(data, ncols=100): 28 | d = json.loads(d) 29 | text = d['text'] 30 | for spo in d['spo_list']: 31 | count_predicates[spo['predicate']] += 1 32 | count_lengths.append(len(text)) 33 | 34 | lengths = Counter(count_lengths) 35 | print(lengths) 36 | print(count_predicates) 37 | 38 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | from torch.utils.data import DataLoader, Dataset 5 | from utils.common_utils import sequence_padding 6 | 7 | 8 | class ListDataset(Dataset): 9 | def __init__(self, file_path=None, data=None, **kwargs): 10 | self.kwargs = kwargs 11 | if isinstance(file_path, (str, list)): 12 | self.data = self.load_data(file_path) 13 | elif isinstance(data, list): 14 | self.data = data 15 | else: 16 | raise ValueError('The input args shall be str format file_path / list format dataset') 17 | 18 | def __len__(self): 19 | return len(self.data) 20 | 21 | def __getitem__(self, index): 22 | return self.data[index] 23 | 24 | @staticmethod 25 | def load_data(file_path): 26 | return file_path 27 | 28 | 29 | 30 | # 加载数据集 31 | class MyDataset(ListDataset): 32 | @staticmethod 33 | def load_data(filename): 34 | examples = [] 35 | with open(filename, encoding='utf-8') as f: 36 | raw_examples = f.readlines() 37 | # 这里是从json数据中的字典中获取 38 | for i, item in enumerate(raw_examples): 39 | # print(i,item) 40 | item = json.loads(item) 41 | text = item['text'] 42 | spo_list = item['spo_list'] 43 | labels = [] # [subject, predicate, object] 44 | for spo in spo_list: 45 | subject = spo['subject'] 46 | object = spo['object'] 47 | predicate = spo['predicate'] 48 | labels.append([subject, predicate, object]) 49 | examples.append((text, labels)) 50 | return examples 51 | 52 | class Collate: 53 | def __init__(self, max_len, tag2id, device, tokenizer): 54 | self.maxlen = max_len 55 | self.tag2id = tag2id 56 | self.id2tag = {v:k for k,v in tag2id.items()} 57 | self.device = device 58 | self.tokenizer = tokenizer 59 | 60 | def collate_fn(self, batch): 61 | def search(pattern, sequence): 62 | """从sequence中寻找子串pattern 63 | 如果找到,返回第一个下标;否则返回-1。 64 | """ 65 | n = len(pattern) 66 | for i in range(len(sequence)): 67 | if sequence[i:i + n] == pattern: 68 | return i 69 | return -1 70 | batch_head_labels = [] 71 | batch_tail_labels = [] 72 | batch_entity_labels = [] 73 | batch_token_ids = [] 74 | batch_attention_mask = [] 75 | batch_token_type_ids = [] 76 | callback = [] 77 | for i, (text, text_labels) in enumerate(batch): 78 | if len(text) > self.maxlen - 2: 79 | text = text[:self.maxlen - 2] 80 | tokens = [i for i in text] 81 | tokens = ['[CLS]'] + tokens + ['[SEP]'] 82 | spoes = set() 83 | callback_text_labels = [] 84 | for s, p, o in text_labels: 85 | p = self.tag2id[p] 86 | s = [i for i in s] 87 | o = [i for i in o] 88 | s_idx = search(s, tokens) # 主体的头 89 | o_idx = search(o, tokens) # 客体的头 90 | if s_idx != -1 and o_idx != -1: 91 | callback_text_labels.append(("".join(s), self.id2tag[p], "".join(o))) 92 | spoes.add((s_idx, s_idx + len(s) - 1, p, o_idx, o_idx + len(o) - 1)) 93 | # print(text_labels) 94 | # print(text) 95 | # print(spoes) 96 | # 构建标签 97 | entity_labels = [set() for _ in range(2)] # [主体, 客体] 98 | head_labels = [set() for _ in range(len(self.tag2id))] # 每个关系中主体和客体的头 99 | tail_labels = [set() for _ in range(len(self.tag2id))] # 每个关系中主体和客体的尾 100 | for sh, st, p, oh, ot in spoes: 101 | entity_labels[0].add((sh, st)) 102 | entity_labels[1].add((oh, ot)) 103 | head_labels[p].add((sh, oh)) 104 | tail_labels[p].add((st, ot)) 105 | 106 | 107 | for label in entity_labels + head_labels + tail_labels: 108 | if not label: # 至少要有一个标签 109 | label.add((0, 0)) # 如果没有则用0填充 110 | 111 | # entity_labels:(2, 1, 2) head_labels:(49, 1, 2) tail_labels:(49, 1, 2) 112 | """ 113 | 对于entity_labels而言,第一个集合是主体,第二个集合是客体,使用pading补全到相同长度 114 | [{(0, 2)}, {(21, 22), (5, 9)}] 115 | [[[ 0 2] 116 | [ 0 0]] 117 | 118 | [[21 22] 119 | [ 5 9]]] 120 | [['九玄珠', '连载网站', '纵横中文网'], ['九玄珠', '作者', '龙马']] 121 | """ 122 | 123 | entity_labels = sequence_padding([list(l) for l in entity_labels]) # [subject/object=2, 实体个数, 实体起终点] 124 | head_labels = sequence_padding([list(l) for l in head_labels]) # [关系个数, 该关系下subject/object配对数, subject/object起点] 125 | tail_labels = sequence_padding([list(l) for l in tail_labels]) # [关系个数, 该关系下subject/object配对数, subject/object终点] 126 | 127 | 128 | token_ids = self.tokenizer.convert_tokens_to_ids(tokens) 129 | batch_token_ids.append(token_ids) # 前面已经限制了长度 130 | batch_attention_mask.append([1] * len(token_ids)) 131 | batch_token_type_ids.append([0] * len(token_ids)) 132 | batch_head_labels.append(head_labels) 133 | batch_tail_labels.append(tail_labels) 134 | batch_entity_labels.append(entity_labels) 135 | callback.append((text, callback_text_labels)) 136 | batch_token_ids = torch.tensor(sequence_padding(batch_token_ids, length=self.maxlen), dtype=torch.long, device=self.device) 137 | attention_mask = torch.tensor(sequence_padding(batch_attention_mask, length=self.maxlen), dtype=torch.long, device=self.device) 138 | token_type_ids = torch.tensor(sequence_padding(batch_token_type_ids, length=self.maxlen), dtype=torch.long, device=self.device) 139 | batch_head_labels = torch.tensor(sequence_padding(batch_head_labels, seq_dims=2), dtype=torch.float, device=self.device) 140 | batch_tail_labels = torch.tensor(sequence_padding(batch_tail_labels, seq_dims=2), dtype=torch.float, device=self.device) 141 | batch_entity_labels = torch.tensor(sequence_padding(batch_entity_labels, seq_dims=2), dtype=torch.float, device=self.device) 142 | 143 | return batch_token_ids, attention_mask, token_type_ids, batch_head_labels, batch_tail_labels, batch_entity_labels, callback 144 | 145 | 146 | if __name__ == "__main__": 147 | from transformers import BertTokenizer 148 | max_len = 256 149 | tokenizer = BertTokenizer.from_pretrained('model_hub/chinese-bert-wwm-ext/vocab.txt') 150 | train_dataset = MyDataset(file_path='data/ske/raw_data/train_data.json', 151 | tokenizer=tokenizer, 152 | max_len=max_len) 153 | # print(train_dataset[0]) 154 | 155 | with open('data/ske/mid_data/predicates.json') as fp: 156 | labels = json.load(fp) 157 | id2tag = {} 158 | tag2id = {} 159 | for i,label in enumerate(labels): 160 | id2tag[i] = label 161 | tag2id[label] = i 162 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 163 | collate = Collate(max_len=max_len, tag2id=tag2id, device=device, tokenizer=tokenizer) 164 | # collate.collate_fn(train_dataset[:16]) 165 | batch_size = 2 166 | train_dataset = train_dataset[:10] 167 | train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate.collate_fn) 168 | 169 | """ 170 | torch.Size([2, 256]) 171 | torch.Size([2, 256]) 172 | torch.Size([2, 256]) 173 | torch.Size([2, 49, 1, 2]) 174 | torch.Size([2, 49, 1, 2]) 175 | torch.Size([2, 2, 1, 2]) 176 | """ 177 | for i, batch in enumerate(train_dataloader): 178 | leng = len(batch) - 1 179 | for j in range(leng): 180 | print(batch[j].shape) 181 | break -------------------------------------------------------------------------------- /logs/bert.log: -------------------------------------------------------------------------------- 1 | 2022-06-29 11:48:12,004 - INFO - main.py - - 315 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 2 | 2022-06-29 11:49:58,527 - INFO - main.py - - 315 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 3 | 2022-06-29 11:50:31,279 - INFO - main.py - - 315 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 4 | 2022-06-29 11:50:47,094 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 5 | 2022-06-29 11:51:29,678 - INFO - main.py - - 315 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 6 | 2022-06-29 11:51:36,701 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 7 | 2022-06-30 02:11:39,922 - INFO - main.py - - 294 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 8 | 2022-06-30 02:12:03,590 - INFO - main.py - - 294 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 9 | 2022-06-30 02:12:30,170 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 10 | 2022-06-30 02:12:41,438 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 11 | 2022-06-30 02:12:42,516 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 12 | 2022-06-30 02:26:07,338 - INFO - main.py - - 245 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 13 | 2022-06-30 02:26:14,214 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 14 | 2022-06-30 02:26:22,612 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 15 | 2022-06-30 02:26:23,623 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 16 | 2022-06-30 02:27:46,714 - INFO - main.py - - 245 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 17 | 2022-06-30 02:27:53,669 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 18 | 2022-06-30 02:28:02,258 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 19 | 2022-06-30 02:28:03,191 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 20 | 2022-06-30 02:28:52,891 - INFO - main.py - - 245 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 21 | 2022-06-30 02:28:59,734 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 22 | 2022-06-30 02:29:08,295 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 23 | 2022-06-30 02:29:09,217 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 24 | 2022-06-30 02:32:00,288 - INFO - main.py - - 247 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 25 | 2022-06-30 02:32:07,429 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 26 | 2022-06-30 02:32:16,005 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 27 | 2022-06-30 02:32:17,131 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 28 | 2022-06-30 02:32:59,456 - INFO - main.py - - 247 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 29 | 2022-06-30 02:33:06,349 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 30 | 2022-06-30 02:33:14,933 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 31 | 2022-06-30 02:33:15,858 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 32 | 2022-06-30 02:35:54,312 - INFO - main.py - - 246 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 33 | 2022-06-30 02:36:01,613 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 34 | 2022-06-30 02:36:10,140 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 35 | 2022-06-30 02:36:11,075 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 36 | 2022-06-30 02:36:43,697 - INFO - main.py - - 247 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 37 | 2022-06-30 02:36:50,572 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 38 | 2022-06-30 02:36:59,058 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 39 | 2022-06-30 02:36:59,988 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 40 | 2022-06-30 02:38:37,012 - INFO - main.py - - 248 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 41 | 2022-06-30 02:38:43,990 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 42 | 2022-06-30 02:38:52,993 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 43 | 2022-06-30 02:38:54,328 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 44 | 2022-06-30 02:44:44,645 - INFO - main.py - - 258 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 45 | 2022-06-30 02:44:51,599 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 46 | 2022-06-30 02:45:00,068 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 47 | 2022-06-30 02:45:01,028 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 48 | 2022-06-30 02:45:54,371 - INFO - main.py - - 258 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 49 | 2022-06-30 02:46:01,411 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 50 | 2022-06-30 02:46:09,887 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 51 | 2022-06-30 02:46:10,848 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 52 | 2022-06-30 02:47:06,798 - INFO - main.py - - 259 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 53 | 2022-06-30 02:47:14,152 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 54 | 2022-06-30 02:47:23,172 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 55 | 2022-06-30 02:47:24,137 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 56 | 2022-06-30 02:50:17,231 - INFO - main.py - - 261 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 57 | 2022-06-30 02:50:24,204 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 58 | 2022-06-30 02:50:32,708 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 59 | 2022-06-30 02:50:33,671 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 60 | 2022-06-30 02:52:15,179 - INFO - main.py - - 266 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 61 | 2022-06-30 02:52:22,175 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 62 | 2022-06-30 02:52:30,733 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 63 | 2022-06-30 02:52:31,705 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 64 | 2022-06-30 02:55:04,553 - INFO - main.py - - 267 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 65 | 2022-06-30 02:55:12,311 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 66 | 2022-06-30 02:55:21,395 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 67 | 2022-06-30 02:55:22,548 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 68 | 2022-06-30 02:57:29,336 - INFO - main.py - - 281 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 69 | 2022-06-30 02:57:36,520 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 70 | 2022-06-30 02:57:45,085 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 71 | 2022-06-30 02:57:46,024 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 72 | 2022-06-30 02:58:51,474 - INFO - main.py - - 281 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 73 | 2022-06-30 02:58:58,431 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 74 | 2022-06-30 02:59:06,975 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 75 | 2022-06-30 02:59:07,942 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 76 | 2022-06-30 03:09:03,482 - INFO - main.py - - 276 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 77 | 2022-06-30 03:09:10,471 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 78 | 2022-06-30 03:09:19,054 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 79 | 2022-06-30 03:09:20,034 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 80 | 2022-06-30 03:10:58,013 - INFO - main.py - - 276 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 81 | 2022-06-30 03:11:04,963 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 82 | 2022-06-30 03:11:13,499 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 83 | 2022-06-30 03:11:14,463 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 84 | 2022-06-30 03:14:59,353 - INFO - main.py - - 279 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 85 | 2022-06-30 03:15:06,319 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 86 | 2022-06-30 03:15:14,828 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 87 | 2022-06-30 03:15:15,757 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 88 | 2022-06-30 03:17:16,243 - INFO - main.py - - 279 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 89 | 2022-06-30 03:17:23,241 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 90 | 2022-06-30 03:17:31,867 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 91 | 2022-06-30 03:17:32,805 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 92 | 2022-06-30 03:17:44,900 - INFO - main.py - - 279 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 93 | 2022-06-30 03:17:51,854 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 94 | 2022-06-30 03:18:00,426 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 95 | 2022-06-30 03:18:01,536 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 96 | 2022-06-30 03:19:47,623 - INFO - main.py - - 279 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 97 | 2022-06-30 03:19:54,624 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 98 | 2022-06-30 03:20:03,068 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 99 | 2022-06-30 03:20:04,020 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 100 | 2022-06-30 03:24:38,076 - INFO - main.py - - 312 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 101 | 2022-06-30 03:24:45,025 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 102 | 2022-06-30 03:24:53,688 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 103 | 2022-06-30 03:24:54,664 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 104 | 2022-06-30 03:26:50,811 - INFO - main.py - - 313 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 105 | 2022-06-30 03:26:57,934 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 106 | 2022-06-30 03:27:06,607 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 107 | 2022-06-30 03:27:07,579 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 108 | 2022-06-30 03:30:37,398 - INFO - main.py - - 314 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 109 | 2022-06-30 03:30:44,394 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 110 | 2022-06-30 03:30:52,927 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 111 | 2022-06-30 03:30:53,908 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 112 | 2022-06-30 03:31:15,955 - INFO - main.py - - 314 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 113 | 2022-06-30 03:31:22,841 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 114 | 2022-06-30 03:31:31,424 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 115 | 2022-06-30 03:31:32,384 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 116 | 2022-06-30 03:33:07,563 - INFO - main.py - - 316 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 117 | 2022-06-30 03:33:14,542 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 118 | 2022-06-30 03:33:23,172 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 119 | 2022-06-30 03:33:24,134 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 120 | 2022-06-30 03:35:02,581 - INFO - main.py - - 315 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 121 | 2022-06-30 03:35:09,602 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 122 | 2022-06-30 03:35:18,845 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 123 | 2022-06-30 03:35:19,804 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 124 | 2022-06-30 03:36:22,396 - INFO - main.py - - 315 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=100, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 125 | 2022-06-30 03:36:33,832 - INFO - main.py - - 315 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=10, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 126 | 2022-06-30 03:36:40,826 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 127 | 2022-06-30 03:36:48,812 - INFO - main.py - train - 64 - 【train】 epoch:0 0/5406 loss:269.4182 entity_loss:16.4784 head_loss:395.4211 tail_loss:396.3550 128 | 2022-06-30 03:36:50,498 - INFO - main.py - train - 64 - 【train】 epoch:0 1/5406 loss:270.2010 entity_loss:16.5278 head_loss:396.7143 tail_loss:397.3611 129 | 2022-06-30 03:36:52,189 - INFO - main.py - train - 64 - 【train】 epoch:0 2/5406 loss:266.3738 entity_loss:16.2139 head_loss:391.1121 tail_loss:391.7955 130 | 2022-06-30 03:36:53,879 - INFO - main.py - train - 64 - 【train】 epoch:0 3/5406 loss:269.4011 entity_loss:16.5687 head_loss:395.5172 tail_loss:396.1174 131 | 2022-06-30 03:36:55,569 - INFO - main.py - train - 64 - 【train】 epoch:0 4/5406 loss:250.0257 entity_loss:15.3507 head_loss:367.0298 tail_loss:367.6967 132 | 2022-06-30 03:36:57,281 - INFO - main.py - train - 64 - 【train】 epoch:0 5/5406 loss:259.6663 entity_loss:15.8807 head_loss:381.1756 tail_loss:381.9426 133 | 2022-06-30 03:36:58,986 - INFO - main.py - train - 64 - 【train】 epoch:0 6/5406 loss:256.6382 entity_loss:15.6712 head_loss:376.7637 tail_loss:377.4796 134 | 2022-06-30 03:37:00,694 - INFO - main.py - train - 64 - 【train】 epoch:0 7/5406 loss:271.7689 entity_loss:16.6315 head_loss:398.9182 tail_loss:399.7569 135 | 2022-06-30 03:37:02,421 - INFO - main.py - train - 64 - 【train】 epoch:0 8/5406 loss:258.0942 entity_loss:15.7037 head_loss:378.8968 tail_loss:379.6821 136 | 2022-06-30 03:37:04,130 - INFO - main.py - train - 64 - 【train】 epoch:0 9/5406 loss:258.2488 entity_loss:15.8331 head_loss:379.0878 tail_loss:379.8255 137 | 2022-06-30 03:48:08,606 - INFO - main.py - - 315 - Namespace(adam_epsilon=1e-08, bert_dir='model_hub/chinese-bert-wwm-ext/', data_dir='./data/ske', dropout_prob=0.1, eval_batch_size=8, eval_steps=10, gpu_ids='0', log_dir='./logs/', lr=5e-05, max_grad_norm=1.0, max_seq_len=256, num_tags=49, other_lr=5e-05, output_dir='./checkpoints/', seed=123, train_batch_size=32, train_epochs=1, use_dev_num=1000, use_tensorboard='False', warmup_proportion=0.1, weight_decay=0.01) 138 | 2022-06-30 03:48:27,057 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 139 | 2022-06-30 03:48:35,709 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 140 | 2022-06-30 03:48:38,125 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 141 | 2022-06-30 03:49:01,627 - INFO - train_utils.py - load_model_and_parallel - 87 - Load ckpt from ./checkpoints/bert/model.pt 142 | 2022-06-30 03:49:02,644 - INFO - train_utils.py - load_model_and_parallel - 97 - Use single gpu in: ['0'] 143 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from re import template 4 | import numpy as np 5 | from collections import defaultdict 6 | import torch 7 | from torch.utils.data import DataLoader, RandomSampler 8 | from transformers import BertTokenizer 9 | 10 | import config 11 | import data_loader 12 | from model import GlobalPointerRe 13 | from utils.common_utils import set_seed, set_logger, read_json, fine_grade_tokenize 14 | from utils.train_utils import load_model_and_parallel, build_optimizer_and_scheduler, save_model 15 | from utils.metric_utils import calculate_metric_relation, get_p_r_f 16 | from tensorboardX import SummaryWriter 17 | 18 | args = config.Args().get_parser() 19 | set_seed(args.seed) 20 | logger = logging.getLogger(__name__) 21 | 22 | if args.use_tensorboard == "True": 23 | writer = SummaryWriter(log_dir='./tensorboard') 24 | 25 | 26 | class BertForRe: 27 | def __init__(self, args, train_loader, dev_loader, test_loader, id2tag, tag2id, model, device): 28 | self.train_loader = train_loader 29 | self.dev_loader = dev_loader 30 | self.test_loader = test_loader 31 | self.args = args 32 | self.id2tag = id2tag 33 | self.tag2id = tag2id 34 | self.model = model 35 | self.device = device 36 | if train_loader is not None: 37 | self.t_total = len(self.train_loader) * args.train_epochs 38 | self.optimizer, self.scheduler = build_optimizer_and_scheduler(args, model, self.t_total) 39 | 40 | def train(self): 41 | # Train 42 | global_step = 0 43 | self.model.zero_grad() 44 | eval_steps = args.eval_steps #每多少个step打印损失及进行验证 45 | best_f1 = 0.0 46 | for epoch in range(self.args.train_epochs): 47 | for step, batch_data in enumerate(self.train_loader): 48 | self.model.train() 49 | for batch in batch_data[:-1]: 50 | batch = batch.to(self.device) 51 | # batch_token_ids, attention_mask, token_type_ids, batch_head_labels, batch_tail_labels, batch_entity_ids 52 | all_loss = self.model(batch_data[0], batch_data[1], batch_data[2], batch_data[3], batch_data[4], batch_data[5]) 53 | loss = all_loss['loss'] 54 | entity_loss = all_loss['entity_loss'] 55 | head_loss = all_loss['head_loss'] 56 | tail_loss = all_loss['tail_loss'] 57 | # loss.backward(loss.clone().detach()) 58 | loss.backward() 59 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm) 60 | self.optimizer.step() 61 | self.scheduler.step() 62 | self.model.zero_grad() 63 | logger.info('【train】 epoch:{} {}/{} loss:{:.4f} entity_loss:{:.4f} head_loss:{:.4f} tail_loss:{:.4f}'.format( 64 | epoch, global_step, self.t_total, loss.item(), entity_loss.item(), head_loss.item(), tail_loss.item())) 65 | global_step += 1 66 | 67 | if self.args.use_tensorboard == "True": 68 | writer.add_scalar('data/loss', loss.item(), global_step) 69 | if global_step % eval_steps == 0: 70 | precision, recall, f1_score = self.dev() 71 | logger.info('[eval] precision={:.4f} recall={:.4f} f1_score={:.4f}'.format(precision, recall, f1_score)) 72 | if f1_score > best_f1: 73 | save_model(self.args, self.model, model_name, global_step) 74 | best_f1 = f1_score 75 | 76 | 77 | def dev(self): 78 | self.model.eval() 79 | spos = [] 80 | true_spos = [] 81 | subjects = [] 82 | objects = [] 83 | all_examples = [] 84 | with torch.no_grad(): 85 | for eval_step, dev_batch_data in enumerate(dev_loader): 86 | for dev_batch in dev_batch_data[:-1]: 87 | dev_batch = dev_batch.to(device) 88 | 89 | entity_output, head_output, tail_output = model(dev_batch_data[0], dev_batch_data[1], dev_batch_data[2]) 90 | cur_batch_size = dev_batch_data[0].shape[0] 91 | dev_examples = dev_batch_data[-1] 92 | true_spos += [i[1] for i in dev_examples] 93 | all_examples += [i[0] for i in dev_examples] 94 | 95 | # torch.Size([8, 2, 256, 256]) torch.Size([8, 49, 256, 256]) torch.Size([8, 49, 256, 256]) 96 | # print(entity_output.shape, head_output.shape, tail_output.shape) 97 | for i in range(cur_batch_size): 98 | example = dev_examples[i][0] 99 | l = len(example) 100 | subject = [] 101 | object = [] 102 | subject_ids = [] 103 | object_ids = [] 104 | spo = [] 105 | single_entity_output = entity_output[i, ...] 106 | single_head_output = head_output[i, ...] 107 | single_tail_output = tail_output[i, ...] 108 | single_head_output = single_head_output[:, 1:l+1:, 1:l+1] 109 | single_tail_output = single_tail_output[:, 1:l+1:, 1:l+1] 110 | subject_entity_outpout = single_entity_output[:1, 1:l+1:, 1:l+1].squeeze() 111 | object_entity_output = single_entity_output[1:, 1:l+1:, 1:l+1].squeeze() 112 | # 注意这里阈值为什么是0 113 | subject_entity_outpout = np.where(subject_entity_outpout.cpu().numpy() > 0) 114 | object_entity_output = np.where(object_entity_output.cpu().numpy() > 0) 115 | for m,n in zip(*subject_entity_outpout): 116 | subject_ids.append((m, n)) 117 | for m,n in zip(*object_entity_output): 118 | object_ids.append((m, n)) 119 | for sh, st in subject_ids: 120 | for oh, ot in object_ids: 121 | # print(example[sh:st+1], example[oh:ot+1]) 122 | # print(np.where(single_head_output[:, sh, oh].cpu().numpy() > 0)) 123 | # print(np.where(single_tail_output[:, st, ot].cpu().numpy() > 0)) 124 | subj = example[sh:st+1] 125 | obj = example[oh:ot+1] 126 | subject.append(subj) 127 | object.append(obj) 128 | 129 | re1 = np.where(single_head_output[:, sh, oh].cpu().numpy() > 0)[0] 130 | re2 = np.where(single_tail_output[:, st, ot].cpu().numpy() > 0)[0] 131 | res = set(re1) & set(re2) 132 | for r in res: 133 | spo.append((subj, self.id2tag[r], obj)) 134 | subjects.append(subject) 135 | objects.append(object) 136 | spos.append(spo) 137 | 138 | 139 | tp, fp, fn = calculate_metric_relation(spos, true_spos) 140 | p, r, f1 = get_p_r_f(tp, fp, fn) 141 | 142 | return p, r, f1 143 | 144 | 145 | 146 | def test(self, model_path): 147 | model = GlobalPointerRe(self.args) 148 | model, device = load_model_and_parallel(model, self.args.gpu_ids, model_path) 149 | model.eval() 150 | spos = [] 151 | true_spos = [] 152 | subjects = [] 153 | objects = [] 154 | all_examples = [] 155 | with torch.no_grad(): 156 | for eval_step, dev_batch_data in enumerate(dev_loader): 157 | for dev_batch in dev_batch_data[:-1]: 158 | dev_batch = dev_batch.to(device) 159 | 160 | entity_output, head_output, tail_output = model(dev_batch_data[0], dev_batch_data[1], dev_batch_data[2]) 161 | cur_batch_size = dev_batch_data[0].shape[0] 162 | dev_examples = dev_batch_data[-1] 163 | true_spos += [i[1] for i in dev_examples] 164 | all_examples += [i[0] for i in dev_examples] 165 | 166 | # torch.Size([8, 2, 256, 256]) torch.Size([8, 49, 256, 256]) torch.Size([8, 49, 256, 256]) 167 | # print(entity_output.shape, head_output.shape, tail_output.shape) 168 | for i in range(cur_batch_size): 169 | example = dev_examples[i][0] 170 | l = len(example) 171 | subject = [] 172 | object = [] 173 | subject_ids = [] 174 | object_ids = [] 175 | spo = [] 176 | single_entity_output = entity_output[i, ...] 177 | single_head_output = head_output[i, ...] 178 | single_tail_output = tail_output[i, ...] 179 | single_head_output = single_head_output[:, 1:l+1:, 1:l+1] 180 | single_tail_output = single_tail_output[:, 1:l+1:, 1:l+1] 181 | subject_entity_outpout = single_entity_output[:1, 1:l+1:, 1:l+1].squeeze() 182 | object_entity_output = single_entity_output[1:, 1:l+1:, 1:l+1].squeeze() 183 | # 注意这里阈值为什么是0 184 | subject_entity_outpout = np.where(subject_entity_outpout.cpu().numpy() > 0) 185 | object_entity_output = np.where(object_entity_output.cpu().numpy() > 0) 186 | for m,n in zip(*subject_entity_outpout): 187 | subject_ids.append((m, n)) 188 | for m,n in zip(*object_entity_output): 189 | object_ids.append((m, n)) 190 | for sh, st in subject_ids: 191 | for oh, ot in object_ids: 192 | # print(example[sh:st+1], example[oh:ot+1]) 193 | # print(np.where(single_head_output[:, sh, oh].cpu().numpy() > 0)) 194 | # print(np.where(single_tail_output[:, st, ot].cpu().numpy() > 0)) 195 | subj = example[sh:st+1] 196 | obj = example[oh:ot+1] 197 | subject.append(subj) 198 | object.append(obj) 199 | 200 | re1 = np.where(single_head_output[:, sh, oh].cpu().numpy() > 0)[0] 201 | re2 = np.where(single_tail_output[:, st, ot].cpu().numpy() > 0)[0] 202 | res = set(re1) & set(re2) 203 | for r in res: 204 | spo.append((subj, self.id2tag[r], obj)) 205 | subjects.append(subject) 206 | objects.append(object) 207 | spos.append(spo) 208 | 209 | 210 | # for i, (m, n, ex) in enumerate(zip(spos, true_spos, all_examples)): 211 | # if i <= 10: 212 | # print(ex) 213 | # print(m, n) 214 | # print('='*100) 215 | # print(len(all_examples)) 216 | # print(len(true_spos)) 217 | # print(len(spos)) 218 | tp, fp, fn = calculate_metric_relation(spos, true_spos) 219 | p, r, f1 = get_p_r_f(tp, fp, fn) 220 | print("========metric========") 221 | print("precision:{} recall:{} f1:{}".format(p, r, f1)) 222 | 223 | return p, r, f1 224 | 225 | def predict(self, raw_text, model, tokenizer): 226 | model.eval() 227 | with torch.no_grad(): 228 | tokens = [i for i in raw_text] 229 | if len(tokens) > self.args.max_seq_len - 2: 230 | tokens = tokens[:self.args.max_seq_len - 2] 231 | tokens = ['[CLS]'] + tokens + ['[SEP]'] 232 | token_ids = tokenizer.convert_tokens_to_ids(tokens) 233 | attention_masks = [1] * len(token_ids) 234 | token_type_ids = [0] * len(token_ids) 235 | if len(token_ids) < self.args.max_seq_len: 236 | token_ids = token_ids + [0] * (self.args.max_seq_len - len(tokens)) 237 | attention_masks = attention_masks + [0] * (self.args.max_seq_len - len(tokens)) 238 | token_type_ids = token_type_ids + [0] * (self.args.max_seq_len - len(tokens)) 239 | assert len(token_ids) == self.args.max_seq_len 240 | assert len(attention_masks) == self.args.max_seq_len 241 | assert len(token_type_ids) == self.args.max_seq_len 242 | token_ids = torch.from_numpy(np.array(token_ids)).unsqueeze(0).to(self.device) 243 | attention_masks = torch.from_numpy(np.array(attention_masks, dtype=np.uint8)).unsqueeze(0).to(self.device) 244 | token_type_ids = torch.from_numpy(np.array(token_type_ids)).unsqueeze(0).to(self.device) 245 | entity_output, head_output, tail_output = model(token_ids, attention_masks, token_type_ids) 246 | 247 | cur_batch_size = entity_output.shape[0] 248 | spos = [] 249 | subjects = [] 250 | objects = [] 251 | # print(entity_output.shape, head_output.shape, tail_output.shape) 252 | for i in range(cur_batch_size): 253 | example = raw_text 254 | l = len(example) 255 | subject = [] 256 | object = [] 257 | subject_ids = [] 258 | object_ids = [] 259 | spo = [] 260 | single_entity_output = entity_output[i, ...] 261 | single_head_output = head_output[i, ...] 262 | single_tail_output = tail_output[i, ...] 263 | single_head_output = single_head_output[:, 1:l+1:, 1:l+1] 264 | single_tail_output = single_tail_output[:, 1:l+1:, 1:l+1] 265 | subject_entity_outpout = single_entity_output[:1, 1:l+1:, 1:l+1].squeeze() 266 | object_entity_output = single_entity_output[1:, 1:l+1:, 1:l+1].squeeze() 267 | # 注意这里阈值为什么是0 268 | subject_entity_outpout = np.where(subject_entity_outpout.cpu().numpy() > 0) 269 | object_entity_output = np.where(object_entity_output.cpu().numpy() > 0) 270 | for m,n in zip(*subject_entity_outpout): 271 | subject_ids.append((m, n)) 272 | for m,n in zip(*object_entity_output): 273 | object_ids.append((m, n)) 274 | for sh, st in subject_ids: 275 | for oh, ot in object_ids: 276 | # print(example[sh:st+1], example[oh:ot+1]) 277 | # print(np.where(single_head_output[:, sh, oh].cpu().numpy() > 0)) 278 | # print(np.where(single_tail_output[:, st, ot].cpu().numpy() > 0)) 279 | subj = example[sh:st+1] 280 | obj = example[oh:ot+1] 281 | subject.append(subj) 282 | object.append(obj) 283 | 284 | re1 = np.where(single_head_output[:, sh, oh].cpu().numpy() > 0)[0] 285 | re2 = np.where(single_tail_output[:, st, ot].cpu().numpy() > 0)[0] 286 | res = set(re1) & set(re2) 287 | for r in res: 288 | spo.append((subj, self.id2tag[r], obj)) 289 | 290 | subjects.append(subject) 291 | objects.append(object) 292 | spos.append(spo) 293 | 294 | print("文本:", raw_text) 295 | print('主体:', [list(set(i)) for i in subjects]) 296 | print('客体:', [list(set(i)) for i in objects]) 297 | print('关系:', spos) 298 | print("="*100) 299 | 300 | if __name__ == '__main__': 301 | data_name = 'ske' 302 | model_name = 'bert' 303 | 304 | set_logger(os.path.join(args.log_dir, '{}.log'.format(model_name))) 305 | if data_name == "ske": 306 | args.data_dir = './data/ske' 307 | data_path = os.path.join(args.data_dir, 'raw_data') 308 | label_list = read_json(os.path.join(args.data_dir, 'mid_data'), 'predicates') 309 | tag2id = {} 310 | id2tag = {} 311 | for k,v in enumerate(label_list): 312 | tag2id[v] = k 313 | id2tag[k] = v 314 | 315 | logger.info(args) 316 | max_seq_len = args.max_seq_len 317 | tokenizer = BertTokenizer.from_pretrained('model_hub/chinese-bert-wwm-ext/vocab.txt') 318 | 319 | model = GlobalPointerRe(args) 320 | model, device = load_model_and_parallel(model, args.gpu_ids) 321 | 322 | collate = data_loader.Collate(max_len=max_seq_len, tag2id=tag2id, device=device, tokenizer=tokenizer) 323 | 324 | 325 | train_dataset = data_loader.MyDataset(file_path=os.path.join(data_path, 'train_data.json'), 326 | tokenizer=tokenizer, 327 | max_len=max_seq_len) 328 | 329 | train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate.collate_fn) 330 | dev_dataset = data_loader.MyDataset(file_path=os.path.join(data_path, 'dev_data.json'), 331 | tokenizer=tokenizer, 332 | max_len=max_seq_len) 333 | 334 | dev_dataset = dev_dataset[:args.use_dev_num] 335 | dev_loader = DataLoader(dev_dataset, batch_size=args.eval_batch_size, shuffle=False, collate_fn=collate.collate_fn) 336 | 337 | 338 | bertForNer = BertForRe(args, train_loader, dev_loader, dev_loader, id2tag, tag2id, model, device) 339 | # bertForNer.train() 340 | 341 | model_path = './checkpoints/bert/model.pt'.format(model_name) 342 | bertForNer.test(model_path) 343 | 344 | texts = [ 345 | '查尔斯·阿兰基斯(Charles Aránguiz),1989年4月17日出生于智利圣地亚哥,智利职业足球运动员,司职中场,效力于德国足球甲级联赛勒沃库森足球俱乐部', 346 | '《离开》是由张宇谱曲,演唱', 347 | '《愤怒的唐僧》由北京吴意波影视文化工作室与优酷电视剧频道联合制作,故事以喜剧元素为主,讲述唐僧与佛祖打牌,得罪了佛祖,被踢下人间再渡九九八十一难的故事', 348 | '李治即位后,萧淑妃受宠,王皇后为了排挤萧淑妃,答应李治让身在感业寺的武则天续起头发,重新纳入后宫', 349 | '《工业4.0》是2015年机械工业出版社出版的图书,作者是(德)阿尔冯斯·波特霍夫,恩斯特·安德雷亚斯·哈特曼', 350 | '周佛海被捕入狱之后,其妻杨淑慧散尽家产请蒋介石枪下留人,于是周佛海从死刑变为无期,不过此人或许作恶多端,改判没多久便病逝于监狱,据悉是心脏病发作', 351 | '《李烈钧自述》是2011年11月1日人民日报出版社出版的图书,作者是李烈钧', 352 | '除演艺事业外,李冰冰热心公益,发起并亲自参与多项环保慈善活动,积极投身其中,身体力行担起了回馈社会的责任于02年出演《少年包青天》,进入大家视线', 353 | '马志舟,1907年出生,陕西三原人,汉族,中国共产党,任红四团第一连连长,1933年逝世', 354 | '斑刺莺是雀形目、剌嘴莺科的一种动物,分布于澳大利亚和新西兰,包括澳大利亚、新西兰、塔斯马尼亚及其附近的岛屿', 355 | '《课本上学不到的生物学2》是2013年上海科技教育出版社出版的图书', 356 | ] 357 | model = GlobalPointerRe(args) 358 | model, device = load_model_and_parallel(model, args.gpu_ids, model_path) 359 | for text in texts: 360 | bertForNer.predict(text, model, tokenizer) 361 | 362 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | from transformers import BertModel 6 | 7 | class SparseMultilabelCategoricalCrossentropy(nn.Module): 8 | """稀疏版多标签分类的交叉熵 9 | 说明: 10 | 1. y_true.shape=[..., num_positive], 11 | y_pred.shape=[..., num_classes]; 12 | 2. 请保证y_pred的值域是全体实数,换言之一般情况下 13 | y_pred不用加激活函数,尤其是不能加sigmoid或者 14 | softmax; 15 | 3. 预测阶段则输出y_pred大于0的类; 16 | 4. 详情请看:https://kexue.fm/archives/7359 。 17 | """ 18 | def __init__(self, mask_zero=False, epsilon=1e-7, **kwargs): 19 | super().__init__(**kwargs) 20 | self.mask_zero = mask_zero 21 | self.epsilon = epsilon 22 | 23 | def forward(self, y_pred, y_true): 24 | zeros = torch.zeros_like(y_pred[..., :1]) 25 | y_pred = torch.cat([y_pred, zeros], dim=-1) 26 | if self.mask_zero: 27 | infs = zeros + float('inf') 28 | y_pred = torch.cat([infs, y_pred[..., 1:]], dim=-1) 29 | y_pos_2 = torch.gather(y_pred, dim=-1, index=y_true) 30 | y_pos_1 = torch.cat([y_pos_2, zeros], dim=-1) 31 | if self.mask_zero: 32 | y_pred = torch.cat([-infs, y_pred[..., 1:]], dim=-1) 33 | y_pos_2 = torch.gather(y_pred, dim=-1, index=y_true) 34 | pos_loss = torch.logsumexp(-y_pos_1, dim=-1) 35 | all_loss = torch.logsumexp(y_pred, dim=-1) # a 36 | aux_loss = torch.logsumexp(y_pos_2, dim=-1) - all_loss # b-a 37 | aux_loss = torch.clamp(1 - torch.exp(aux_loss), self.epsilon, 1) # 1-exp(b-a) 38 | neg_loss = all_loss + torch.log(aux_loss) # a + log[1-exp(b-a)] 39 | return pos_loss + neg_loss 40 | 41 | 42 | class MyLoss(SparseMultilabelCategoricalCrossentropy): 43 | def __init__(self, **kwargs): 44 | super().__init__(**kwargs) 45 | def forward(self, y_preds, y_trues): 46 | ''' y_preds: [Tensor], shape为[btz, heads, seq_len ,seq_len] 47 | ''' 48 | loss_list = [] 49 | for y_pred, y_true in zip(y_preds, y_trues): 50 | shape = y_pred.shape 51 | # 乘以seq_len是因为(i, j)在展开到seq_len*seq_len维度对应的下标是i*seq_len+j 52 | y_true = y_true[..., 0] * shape[2] + y_true[..., 1] # [btz, heads, 实体起终点的下标] 53 | y_pred = y_pred.reshape(shape[0], -1, np.prod(shape[2:])) # [btz, heads, seq_len*seq_len] 54 | loss = super().forward(y_pred, y_true.long()) 55 | loss = torch.mean(torch.sum(loss, dim=1)) 56 | loss_list.append(loss) 57 | return {'loss': sum(loss_list)/3, 'entity_loss': loss_list[0], 'head_loss': loss_list[1], 'tail_loss': loss_list[2]} 58 | 59 | 60 | def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): 61 | '''Returns: [seq_len, d_hid] 62 | ''' 63 | position = torch.arange(0, n_position, dtype=torch.float).unsqueeze(1) 64 | div_term = torch.exp(torch.arange(0, d_hid, 2).float() * (-math.log(10000.0) / d_hid)) 65 | embeddings_table = torch.zeros(n_position, d_hid) 66 | embeddings_table[:, 0::2] = torch.sin(position * div_term) 67 | embeddings_table[:, 1::2] = torch.cos(position * div_term) 68 | return embeddings_table 69 | 70 | 71 | class RoPEPositionEncoding(nn.Module): 72 | """旋转式位置编码: https://kexue.fm/archives/8265 73 | """ 74 | def __init__(self, max_position, embedding_size): 75 | super(RoPEPositionEncoding, self).__init__() 76 | position_embeddings = get_sinusoid_encoding_table(max_position, embedding_size) # [seq_len, hdsz] 77 | cos_position = position_embeddings[:, 1::2].repeat_interleave(2, dim=-1) 78 | sin_position = position_embeddings[:, ::2].repeat_interleave(2, dim=-1) 79 | # register_buffer是为了最外层model.to(device),不用内部指定device 80 | self.register_buffer('cos_position', cos_position) 81 | self.register_buffer('sin_position', sin_position) 82 | 83 | def forward(self, qw, seq_dim=-2): 84 | # 默认最后两个维度为[seq_len, hdsz] 85 | seq_len = qw.shape[seq_dim] 86 | qw2 = torch.stack([-qw[..., 1::2], qw[..., ::2]], dim=-1).reshape_as(qw) 87 | return qw * self.cos_position[:seq_len] + qw2 * self.sin_position[:seq_len] 88 | 89 | 90 | class EfficientGlobalPointer(nn.Module): 91 | """更加参数高效的GlobalPointer 92 | 参考:https://kexue.fm/archives/8877 93 | """ 94 | def __init__(self, hidden_size, heads, head_size, RoPE=True, max_len=512, use_bias=True, tril_mask=True): 95 | super().__init__() 96 | self.heads = heads 97 | self.head_size = head_size 98 | self.RoPE = RoPE 99 | self.tril_mask = tril_mask 100 | 101 | self.p_dense = nn.Linear(hidden_size, head_size * 2, bias=use_bias) 102 | self.q_dense = nn.Linear(head_size * 2, heads * 2, bias=use_bias) 103 | if self.RoPE: 104 | self.position_embedding = RoPEPositionEncoding(max_len, head_size) 105 | 106 | def forward(self, inputs, mask=None): 107 | ''' inputs: [..., hdsz] 108 | mask: [bez, seq_len], padding部分为0 109 | ''' 110 | sequence_output = self.p_dense(inputs) # [..., head_size*2] 111 | qw, kw = sequence_output[..., :self.head_size], sequence_output[..., self.head_size:] # [..., heads, head_size] 112 | 113 | # ROPE编码 114 | if self.RoPE: 115 | qw = self.position_embedding(qw) 116 | kw = self.position_embedding(kw) 117 | 118 | # 计算内积 119 | logits = torch.einsum('bmd,bnd->bmn', qw, kw) / self.head_size**0.5 # [btz, seq_len, seq_len] 120 | bias_input = self.q_dense(sequence_output) # [..., heads*2] 121 | bias = torch.stack(torch.chunk(bias_input, self.heads, dim=-1), dim=-2).transpose(1,2) # [btz, head_size, seq_len,2] 122 | logits = logits.unsqueeze(1) + bias[..., :1] + bias[..., 1:].transpose(2, 3) # [btz, head_size, seq_len, seq_len] 123 | 124 | # 排除padding 125 | if mask is not None: 126 | attention_mask1 = 1 - mask.unsqueeze(1).unsqueeze(3) # [btz, 1, seq_len, 1] 127 | attention_mask2 = 1 - mask.unsqueeze(1).unsqueeze(2) # [btz, 1, 1, seq_len] 128 | logits = logits.masked_fill(attention_mask1.bool(), value=-float('inf')) 129 | logits = logits.masked_fill(attention_mask2.bool(), value=-float('inf')) 130 | 131 | # 排除下三角 132 | if self.tril_mask: 133 | logits = logits - torch.tril(torch.ones_like(logits), -1) * 1e12 134 | 135 | return logits 136 | 137 | 138 | class GlobalPointer(nn.Module): 139 | """全局指针模块 140 | 将序列的每个(start, end)作为整体来进行判断 141 | 参考:https://kexue.fm/archives/8373 142 | """ 143 | def __init__(self, hidden_size, heads, head_size, RoPE=True, max_len=512, use_bias=True, tril_mask=True): 144 | super().__init__() 145 | self.heads = heads 146 | self.head_size = head_size 147 | self.RoPE = RoPE 148 | self.tril_mask = tril_mask 149 | 150 | self.dense = nn.Linear(hidden_size, heads * head_size * 2, bias=use_bias) 151 | if self.RoPE: 152 | self.position_embedding = RoPEPositionEncoding(max_len, head_size) 153 | 154 | 155 | 156 | def forward(self, inputs, mask=None): 157 | ''' inputs: [..., hdsz] 158 | mask: [bez, seq_len], padding部分为0 159 | ''' 160 | # [batchsize, 150, 8*64*2] 161 | sequence_output = self.dense(inputs) # [..., heads*head_size*2] 162 | # torch.chunk(sequence_output, self.heads, dim=-1) 8个(batchsize, 150, 64*2) 163 | # [batchsize, 150, 8, 64*2] 164 | sequence_output = torch.stack(torch.chunk(sequence_output, self.heads, dim=-1), dim=-2) # [..., heads, head_size*2] 165 | # qw:[batchsize, 150, 8, 64], kw:[batchsize, 150, 8, 64] 166 | qw, kw = sequence_output[..., :self.head_size], sequence_output[..., self.head_size:] # [..., heads, head_size] 167 | 168 | # ROPE编码 169 | if self.RoPE: 170 | qw = self.position_embedding(qw) 171 | kw = self.position_embedding(kw) 172 | 173 | # 计算内积 174 | logits = torch.einsum('bmhd,bnhd->bhmn', qw, kw) # [btz, heads, seq_len, seq_len] 175 | 176 | # 排除padding 177 | if mask is not None: 178 | attention_mask1 = 1 - mask.unsqueeze(1).unsqueeze(3) # [btz, 1, seq_len, 1] 179 | attention_mask2 = 1 - mask.unsqueeze(1).unsqueeze(2) # [btz, 1, 1, seq_len] 180 | logits = logits.masked_fill(attention_mask1.bool(), value=-float('inf')) 181 | logits = logits.masked_fill(attention_mask2.bool(), value=-float('inf')) 182 | 183 | # 排除下三角 184 | if self.tril_mask: 185 | logits = logits - torch.tril(torch.ones_like(logits), -1) * 1e12 186 | 187 | return logits / self.head_size**0.5 188 | 189 | 190 | class GlobalPointerRe(nn.Module): 191 | def __init__(self, args): 192 | super().__init__() 193 | self.bert = BertModel.from_pretrained(args.bert_dir, output_hidden_states=True, 194 | hidden_dropout_prob=args.dropout_prob) 195 | self.entity_output = GlobalPointer(hidden_size=768, heads=2, head_size=64) 196 | self.head_output = GlobalPointer(hidden_size=768, heads=args.num_tags, head_size=64, RoPE=False, tril_mask=False) 197 | self.tail_output = GlobalPointer(hidden_size=768, heads=args.num_tags, head_size=64, RoPE=False, tril_mask=False) 198 | self.criterion = MyLoss(mask_zero=True) 199 | 200 | def forward(self, 201 | token_ids, 202 | attention_masks, 203 | token_type_ids, 204 | head_labels=None, 205 | tail_labels=None, 206 | entity_labels=None): 207 | bert_output = self.bert(token_ids, attention_masks, token_type_ids) # [btz, seq_len, hdsz] 208 | hidden_states = bert_output[0] 209 | mask = attention_masks 210 | 211 | entity_output = self.entity_output(hidden_states, mask) # [btz, heads, seq_len, seq_len] 212 | head_output = self.head_output(hidden_states, mask) # [btz, heads, seq_len, seq_len] 213 | tail_output = self.tail_output(hidden_states, mask) # [btz, heads, seq_len, seq_len] 214 | if head_labels is None: 215 | return entity_output, head_output, tail_output 216 | loss = self.criterion([entity_output, head_output, tail_output], [entity_labels, head_labels, tail_labels]) 217 | return loss -------------------------------------------------------------------------------- /model_hub/占位.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_GlobalPointer_triple_extraction/538c9b9261c4b6c6436a6831d09e88552a2db774/model_hub/占位.txt -------------------------------------------------------------------------------- /utils/__pycache__/common_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_GlobalPointer_triple_extraction/538c9b9261c4b6c6436a6831d09e88552a2db774/utils/__pycache__/common_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/metric_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_GlobalPointer_triple_extraction/538c9b9261c4b6c6436a6831d09e88552a2db774/utils/__pycache__/metric_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/train_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/taishan1994/pytorch_GlobalPointer_triple_extraction/538c9b9261c4b6c6436a6831d09e88552a2db774/utils/__pycache__/train_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import random 3 | import os 4 | import json 5 | import logging 6 | import time 7 | import pickle 8 | import numpy as np 9 | import torch 10 | from torch.nn.utils.rnn import pad_sequence 11 | 12 | 13 | def trans_ij2k(seq_len, i, j): 14 | '''把第i行,第j列转化成上三角flat后的序号 15 | ''' 16 | if (i > seq_len - 1) or (j > seq_len - 1) or (i > j): 17 | return 0 18 | return int(0.5*(2*seq_len-i+1)*i+(j-i)) 19 | 20 | def sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post'): 21 | """将序列padding到同一长度 22 | """ 23 | if isinstance(inputs[0], (np.ndarray, list)): 24 | if length is None: 25 | length = np.max([np.shape(x)[:seq_dims] for x in inputs], axis=0) 26 | elif not hasattr(length, '__getitem__'): 27 | length = [length] 28 | 29 | slices = [np.s_[:length[i]] for i in range(seq_dims)] 30 | slices = tuple(slices) if len(slices) > 1 else slices[0] 31 | pad_width = [(0, 0) for _ in np.shape(inputs[0])] 32 | 33 | outputs = [] 34 | for x in inputs: 35 | x = x[slices] 36 | for i in range(seq_dims): 37 | if mode == 'post': 38 | pad_width[i] = (0, length[i] - np.shape(x)[i]) 39 | elif mode == 'pre': 40 | pad_width[i] = (length[i] - np.shape(x)[i], 0) 41 | else: 42 | raise ValueError('"mode" argument must be "post" or "pre".') 43 | x = np.pad(x, pad_width, 'constant', constant_values=value) 44 | outputs.append(x) 45 | 46 | return np.array(outputs) 47 | 48 | elif isinstance(inputs[0], torch.Tensor): 49 | assert mode == 'post', '"mode" argument must be "post" when element is torch.Tensor' 50 | if length is not None: 51 | inputs = [i[:length] for i in inputs] 52 | return pad_sequence(inputs, padding_value=value, batch_first=True) 53 | else: 54 | raise ValueError('"input" argument must be tensor/list/ndarray.') 55 | 56 | 57 | def timer(func): 58 | """ 59 | 函数计时器 60 | :param func: 61 | :return: 62 | """ 63 | 64 | @functools.wraps(func) 65 | def wrapper(*args, **kwargs): 66 | start = time.time() 67 | res = func(*args, **kwargs) 68 | end = time.time() 69 | print("{}共耗时约{:.4f}秒".format(func.__name__, end - start)) 70 | return res 71 | 72 | return wrapper 73 | 74 | 75 | def set_seed(seed=123): 76 | """ 77 | 设置随机数种子,保证实验可重现 78 | :param seed: 79 | :return: 80 | """ 81 | random.seed(seed) 82 | torch.manual_seed(seed) 83 | np.random.seed(seed) 84 | torch.cuda.manual_seed_all(seed) 85 | 86 | 87 | def set_logger(log_path): 88 | """ 89 | 配置log 90 | :param log_path:s 91 | :return: 92 | """ 93 | logger = logging.getLogger() 94 | logger.setLevel(logging.INFO) 95 | 96 | # 由于每调用一次set_logger函数,就会创建一个handler,会造成重复打印的问题,因此需要判断root logger中是否已有该handler 97 | if not any(handler.__class__ == logging.FileHandler for handler in logger.handlers): 98 | file_handler = logging.FileHandler(log_path) 99 | formatter = logging.Formatter( 100 | '%(asctime)s - %(levelname)s - %(filename)s - %(funcName)s - %(lineno)d - %(message)s') 101 | file_handler.setFormatter(formatter) 102 | logger.addHandler(file_handler) 103 | 104 | if not any(handler.__class__ == logging.StreamHandler for handler in logger.handlers): 105 | stream_handler = logging.StreamHandler() 106 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 107 | logger.addHandler(stream_handler) 108 | 109 | 110 | def save_json(data_dir, data, desc): 111 | """保存数据为json""" 112 | with open(os.path.join(data_dir, '{}.json'.format(desc)), 'w', encoding='utf-8') as f: 113 | json.dump(data, f, ensure_ascii=False, indent=2) 114 | 115 | 116 | def read_json(data_dir, desc): 117 | """读取数据为json""" 118 | with open(os.path.join(data_dir, '{}.json'.format(desc)), 'r', encoding='utf-8') as f: 119 | data = json.load(f) 120 | return data 121 | 122 | 123 | def save_pkl(data_dir, data, desc): 124 | """保存.pkl文件""" 125 | with open(os.path.join(data_dir, '{}.pkl'.format(desc)), 'wb') as f: 126 | pickle.dump(data, f) 127 | 128 | 129 | def read_pkl(data_dir, desc): 130 | """读取.pkl文件""" 131 | with open(os.path.join(data_dir, '{}.pkl'.format(desc)), 'rb') as f: 132 | data = pickle.load(f) 133 | return data 134 | 135 | 136 | def fine_grade_tokenize(raw_text, tokenizer): 137 | """ 138 | 序列标注任务 BERT 分词器可能会导致标注偏移, 139 | 用 char-level 来 tokenize 140 | """ 141 | tokens = [] 142 | 143 | for _ch in raw_text: 144 | if _ch in [' ', '\t', '\n']: 145 | tokens.append('[BLANK]') 146 | else: 147 | if not len(tokenizer.tokenize(_ch)): 148 | tokens.append('[INV]') 149 | else: 150 | tokens.append(_ch) 151 | 152 | return tokens -------------------------------------------------------------------------------- /utils/metric_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | 9 | 10 | def calculate_metric_relation(predict, gt): 11 | tp, fp, fn = 0, 0, 0 12 | for entity_predict, entity_gt in zip(predict, gt): 13 | # print(entity_predict, entity_gt) 14 | flag = 0 15 | for ent in entity_predict: 16 | if ent in entity_gt: 17 | tp += 1 18 | else: 19 | fp += 1 20 | fn = sum([len(i) for i in gt]) - tp 21 | return tp, fp, fn 22 | 23 | 24 | 25 | def calculate_metric(gt, predict): 26 | """ 27 | 计算 tp fp fn 28 | """ 29 | tp, fp, fn = 0, 0, 0 30 | for entity_predict in predict: 31 | flag = 0 32 | for entity_gt in gt: 33 | if entity_predict[0] == entity_gt[0] and entity_predict[1] == entity_gt[1]: 34 | flag = 1 35 | tp += 1 36 | break 37 | if flag == 0: 38 | fp += 1 39 | 40 | fn = len(gt) - tp 41 | 42 | return np.array([tp, fp, fn]) 43 | 44 | 45 | def get_p_r_f(tp, fp, fn): 46 | p = tp / (tp + fp) if tp + fp != 0 else 0 47 | r = tp / (tp + fn) if tp + fn != 0 else 0 48 | f1 = 2 * p * r / (p + r) if p + r != 0 else 0 49 | return np.array([p, r, f1]) 50 | 51 | def classification_report(metrics_matrix, label_list, id2label, total_count, digits=2, suffix=False): 52 | name_width = max([len(label) for label in label_list]) 53 | last_line_heading = 'micro-f1' 54 | width = max(name_width, len(last_line_heading), digits) 55 | 56 | headers = ["precision", "recall", "f1-score", "support"] 57 | head_fmt = u'{:>{width}s} ' + u' {:>9}' * len(headers) 58 | report = head_fmt.format(u'', *headers, width=width) 59 | report += u'\n\n' 60 | 61 | row_fmt = u'{:>{width}s} ' + u' {:>9.{digits}f}' * 3 + u' {:>9}\n' 62 | 63 | ps, rs, f1s, s = [], [], [], [] 64 | for label_id, label_matrix in enumerate(metrics_matrix): 65 | type_name = id2label[label_id] 66 | p,r,f1 = get_p_r_f(label_matrix[0],label_matrix[1],label_matrix[2]) 67 | nb_true = total_count[label_id] 68 | report += row_fmt.format(*[type_name, p, r, f1, nb_true], width=width, digits=digits) 69 | ps.append(p) 70 | rs.append(r) 71 | f1s.append(f1) 72 | s.append(nb_true) 73 | 74 | report += u'\n' 75 | mirco_metrics = np.sum(metrics_matrix, axis=0) 76 | mirco_metrics = get_p_r_f(mirco_metrics[0], mirco_metrics[1], mirco_metrics[2]) 77 | # compute averages 78 | print('precision:{:.4f} recall:{:.4f} micro_f1:{:.4f}'.format(mirco_metrics[0],mirco_metrics[1],mirco_metrics[2])) 79 | report += row_fmt.format(last_line_heading, 80 | mirco_metrics[0], 81 | mirco_metrics[1], 82 | mirco_metrics[2], 83 | np.sum(s), 84 | width=width, digits=digits) 85 | 86 | return 87 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | import os 3 | import logging 4 | from transformers import AdamW, get_linear_schedule_with_warmup 5 | import torch 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def build_optimizer_and_scheduler(args, model, t_total): 11 | module = ( 12 | model.module if hasattr(model, "module") else model 13 | ) 14 | 15 | # 差分学习率 16 | no_decay = ["bias", "LayerNorm.weight"] 17 | model_param = list(module.named_parameters()) 18 | 19 | bert_param_optimizer = [] 20 | other_param_optimizer = [] 21 | 22 | for name, para in model_param: 23 | space = name.split('.') 24 | # print(name) 25 | if space[0] == 'bert_module': 26 | bert_param_optimizer.append((name, para)) 27 | else: 28 | other_param_optimizer.append((name, para)) 29 | 30 | optimizer_grouped_parameters = [ 31 | # bert other module 32 | {"params": [p for n, p in bert_param_optimizer if not any(nd in n for nd in no_decay)], 33 | "weight_decay": args.weight_decay, 'lr': args.lr}, 34 | {"params": [p for n, p in bert_param_optimizer if any(nd in n for nd in no_decay)], 35 | "weight_decay": 0.0, 'lr': args.lr}, 36 | 37 | # 其他模块,差分学习率 38 | {"params": [p for n, p in other_param_optimizer if not any(nd in n for nd in no_decay)], 39 | "weight_decay": args.weight_decay, 'lr': args.other_lr}, 40 | {"params": [p for n, p in other_param_optimizer if any(nd in n for nd in no_decay)], 41 | "weight_decay": 0.0, 'lr': args.other_lr}, 42 | ] 43 | 44 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=args.adam_epsilon) 45 | scheduler = get_linear_schedule_with_warmup( 46 | optimizer, num_warmup_steps=int(args.warmup_proportion * t_total), num_training_steps=t_total 47 | ) 48 | 49 | return optimizer, scheduler 50 | 51 | def save_model(args, model, model_name, global_step): 52 | """保存最好的验证集效果最好那个模型""" 53 | output_dir = os.path.join(args.output_dir, '{}'.format(model_name, global_step)) 54 | if not os.path.exists(output_dir): 55 | os.makedirs(output_dir, exist_ok=True) 56 | 57 | # take care of model distributed / parallel training 58 | model_to_save = ( 59 | model.module if hasattr(model, "module") else model 60 | ) 61 | logger.info('Saving model checkpoint to {}'.format(output_dir)) 62 | torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'model.pt')) 63 | 64 | def save_model_step(args, model, global_step): 65 | """根据global_step来保存模型""" 66 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) 67 | if not os.path.exists(output_dir): 68 | os.makedirs(output_dir, exist_ok=True) 69 | 70 | # take care of model distributed / parallel training 71 | model_to_save = ( 72 | model.module if hasattr(model, "module") else model 73 | ) 74 | logger.info('Saving model & optimizer & scheduler checkpoint to {}.format(output_dir)') 75 | torch.save(model_to_save.state_dict(), os.path.join(output_dir, 'model.pt')) 76 | 77 | def load_model_and_parallel(model, gpu_ids, ckpt_path=None, strict=True): 78 | """ 79 | 加载模型 & 放置到 GPU 中(单卡 / 多卡) 80 | """ 81 | gpu_ids = gpu_ids.split(',') 82 | 83 | # set to device to the first cuda 84 | device = torch.device("cpu" if gpu_ids[0] == '-1' else "cuda:" + gpu_ids[0]) 85 | 86 | if ckpt_path is not None: 87 | logger.info('Load ckpt from {}'.format(ckpt_path)) 88 | model.load_state_dict(torch.load(ckpt_path, map_location=torch.device('cpu')), strict=strict) 89 | 90 | model.to(device) 91 | 92 | if len(gpu_ids) > 1: 93 | logger.info('Use multi gpus in: {}'.format(gpu_ids)) 94 | gpu_ids = [int(x) for x in gpu_ids] 95 | model = torch.nn.DataParallel(model, device_ids=gpu_ids) 96 | else: 97 | logger.info('Use single gpu in: {}'.format(gpu_ids)) 98 | 99 | return model, device --------------------------------------------------------------------------------