├── deploy ├── test.json ├── rel_predict.json └── demo.py ├── requirements.txt ├── mains ├── __init__.py ├── trainer_ner.py ├── trainer.py └── trainer_rel.py ├── modules ├── __init__.py ├── model_ner.py ├── model_rel.py └── joint_model.py ├── utils ├── __init__.py ├── config.py ├── config_ner.py └── config_rel.py ├── data_loader ├── __init__.py ├── rel2id.json ├── token_type2id.json ├── process_rel.py ├── process_ner.py └── data_process.py ├── README.md └── .gitignore /deploy/test.json: -------------------------------------------------------------------------------- 1 | {"text": "歌曲介绍由大飞作词,深白色作曲,吕绍淳编曲的歌曲《直线》,属于专辑中深具爆发力的一首歌曲"} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy=1.19.2 2 | tqdm=4.56.0 3 | seqeval=1.2.2 4 | transformers=4.3.2 5 | pytorch=1.7.1 6 | pytorch-crf=0.7.2 7 | -------------------------------------------------------------------------------- /mains/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/6 21:18 5 | # @File : __init__.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/6 21:11 5 | # @File : __init__.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/6 21:11 5 | # @File : __init__.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/6 21:11 5 | # @File : __init__.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ -------------------------------------------------------------------------------- /data_loader/rel2id.json: -------------------------------------------------------------------------------- 1 | {"N": 0, "丈夫": 1, "上映时间": 2, "专业代码": 3, "主持人": 4, "主演": 5, "主角": 6, "人口数量": 7, "作曲": 8, "作者": 9, "作词": 10, "修业年限": 11, "出品公司": 12, "出版社": 13, "出生地": 14, "出生日期": 15, "创始人": 16, "制片人": 17, "占地面积": 18, "号": 19, "嘉宾": 20, "国籍": 21, "妻子": 22, "字": 23, "官方语言": 24, "导演": 25, "总部地点": 26, "成立日期": 27, "所在城市": 28, "所属专辑": 29, "改编自": 30, "朝代": 31, "歌手": 32, "母亲": 33, "毕业院校": 34, "民族": 35, "气候": 36, "注册资本": 37, "海拔": 38, "父亲": 39, "目": 40, "祖籍": 41, "简称": 42, "编剧": 43, "董事长": 44, "身高": 45, "连载网站": 46, "邮政编码": 47, "面积": 48, "首都": 49} -------------------------------------------------------------------------------- /data_loader/token_type2id.json: -------------------------------------------------------------------------------- 1 | {"B-Date": 0, "I-Date": 1, "B-Number": 2, "I-Number": 3, "B-Text": 4, "I-Text": 5, "B-书籍": 6, "I-书籍": 7, "B-人物": 8, "I-人物": 9, "B-企业": 10, "I-企业": 11, "B-作品": 12, "I-作品": 13, "B-出版社": 14, "I-出版社": 15, "B-历史人物": 16, "I-历史人物": 17, "B-国家": 18, "I-国家": 19, "B-图书作品": 20, "I-图书作品": 21, "B-地点": 22, "I-地点": 23, "B-城市": 24, "I-城市": 25, "B-学校": 26, "I-学校": 
27, "B-学科专业": 28, "I-学科专业": 29, "B-影视作品": 30, "I-影视作品": 31, "B-景点": 32, "I-景点": 33, "B-机构": 34, "I-机构": 35, "B-歌曲": 36, "I-歌曲": 37, "B-气候": 38, "I-气候": 39, "B-生物": 40, "I-生物": 41, "B-电视综艺": 42, "I-电视综艺": 43, "B-目": 44, "I-目": 45, "B-网站": 46, "I-网站": 47, "B-网络小说": 48, "I-网络小说": 49, "B-行政区": 50, "I-行政区": 51, "B-语言": 52, "I-语言": 53, "B-音乐专辑": 54, "I-音乐专辑": 55, "O": 56} -------------------------------------------------------------------------------- /deploy/rel_predict.json: -------------------------------------------------------------------------------- 1 | {"text": "歌曲介绍由大飞作词,深白色作曲,吕绍淳编曲的歌曲《直线》,属于专辑中深具爆发力的一首歌曲", "spo_list": [{"subject": "大飞", "object": "深白色"}]} 2 | {"text": "歌曲介绍由大飞作词,深白色作曲,吕绍淳编曲的歌曲《直线》,属于专辑中深具爆发力的一首歌曲", "spo_list": [{"subject": "大飞", "object": "直线"}]} 3 | {"text": "歌曲介绍由大飞作词,深白色作曲,吕绍淳编曲的歌曲《直线》,属于专辑中深具爆发力的一首歌曲", "spo_list": [{"subject": "深白色", "object": "大飞"}]} 4 | {"text": "歌曲介绍由大飞作词,深白色作曲,吕绍淳编曲的歌曲《直线》,属于专辑中深具爆发力的一首歌曲", "spo_list": [{"subject": "深白色", "object": "直线"}]} 5 | {"text": "歌曲介绍由大飞作词,深白色作曲,吕绍淳编曲的歌曲《直线》,属于专辑中深具爆发力的一首歌曲", "spo_list": [{"subject": "直线", "object": "大飞"}]} 6 | {"text": "歌曲介绍由大飞作词,深白色作曲,吕绍淳编曲的歌曲《直线》,属于专辑中深具爆发力的一首歌曲", "spo_list": [{"subject": "直线", "object": "深白色"}]} 7 | {"text": "如何演好自己的角色,请读《演员自我修养》《喜剧之王》周星驰崛起于穷困潦倒之中的独门秘笈", "spo_list": [{"object": "周星驰", "subject": "喜剧之王"}]} 8 | {"text": "茶树茶网蝽,Stephanitis chinensis Drake,属半翅目网蝽科冠网椿属的一种昆虫", "spo_list": [{"object": "半翅目", "subject": "茶树茶网蝽"}]} 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EntityRelationExtraction 2 | #### 引流 3 | 4 | 完整的可视化demo见我的另一个[仓库](https://github.com/Xie-Minghui/MultiHeadJointEntityRelationExtraction_simple) 5 | 6 | #### 项目说明 7 | 8 | 项目使用pytorch实现实体关系抽取中的流水线式模型。 9 | 10 | 命名实体识别部分使用的是BiLSTM+CRF。 11 | 12 | 实体关系抽取使用的是Bert进行关系分类。 13 | 14 | 最终的效果比较好。 15 | 16 | #### 数据集说明 17 | 18 | 百度的DUIE数据集是业界规模最大的中文信息抽取数据集。它包含了43万三元组数据、21万中文句子。 19 | 20 | 句子的平均长度为54,每句话中的三元组数量的平均值为2.1。 21 | 下面是一个样本: 22 | {"text": "据了解,《小姨多鹤》主要在大连和丹东大梨树影视城取景,是导演安建继《北风那个吹》之后拍摄的又一部极具东北文化气息的作品", 23 | "spo_list": [{ 24 | "predicate": "导演", 25 | "object_type": "人物", 26 | "subject_type": "影视作品", 27 | "object": "安建", 28 | "subject": "小姨多鹤" 29 | }, { 30 | "predicate": "导演", 31 | "object_type": "人物", 32 | "subject_type": "影视作品", 33 | "object": "安建", 34 | "subject": "北风那个吹" 35 | }] 36 | } 37 | 38 | ##### 数据集和预训练模型的下载 39 | 数据集的下载 40 | 链接:https://pan.baidu.com/s/1XK3v6BQlnsvhGxgg-71IpA 41 | 提取码:nlp0 42 | 43 | 预训练模型和相关文件见 44 | https://huggingface.co/bert-base-chinese/tree/main 45 | #### 使用说明 46 | 47 | ##### 训练 48 | Bert模型使用的是huggingface的Bert-base模型 49 | 50 | 命名实体部分的训练,直接运行mains/train_ner.py 51 | 52 | 关系抽取部分的训练,直接运行mains/train_rel.py 53 | 54 | train.py,config.py是之前联合抽取的代码,在这个项目作废。 55 | 56 | ##### 测试 57 | 58 | 直接运行deploy/demo.py。会首先进行命名实体识别,然后将实体两两组成实体对进行关系分类。 59 | 60 | [![Star History Chart](https://api.star-history.com/svg?repos=Xie-Minghui/EntityRelationExtraction&type=Timeline)](https://star-history.com/#Xie-Minghui/EntityRelationExtraction&Date) 61 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | models/ 3 | bert-base-chinese/ 4 | .idea/ 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | 
downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/3 20:42 5 | # @File : config.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | import torch 12 | 13 | if torch.cuda.is_available(): 14 | USE_CUDA = True 15 | print("USE_CUDA....") 16 | else: 17 | USE_CUDA = False 18 | 19 | 20 | class Config: 21 | def __init__(self, 22 | lr=0.001, 23 | epochs=100, 24 | vocab_size=22000, 25 | embedding_dim=32, 26 | hidden_dim_lstm=128, 27 | num_layers=3, 28 | batch_size=32, 29 | layer_size=128, 30 | token_type_dim=8 31 | ): 32 | self.lr = lr 33 | self.epochs = epochs 34 | self.vocab_size = vocab_size 35 | self.embedding_dim = embedding_dim 36 | self.hidden_dim_lstm = hidden_dim_lstm 37 | self.num_layers = num_layers 38 | self.batch_size = batch_size 39 | self.layer_size = layer_size 40 | self.token_type_dim = token_type_dim 41 | self.relations = ["N", '丈夫', '上映时间', '专业代码', '主持人', '主演', '主角', '人口数量', '作曲', '作者', '作词', '修业年限', '出品公司', '出版社', '出生地', 42 | '出生日期', '创始人', '制片人', '占地面积', '号', '嘉宾', '国籍', '妻子', '字', '官方语言', '导演', '总部地点', '成立日期', '所在城市', '所属专辑', 43 | '改编自', '朝代', '歌手', '母亲', '毕业院校', '民族', '气候', '注册资本', '海拔', '父亲', 
'目', '祖籍', '简称', '编剧', '董事长', '身高', 44 | '连载网站', '邮政编码', '面积', '首都'] 45 | self.num_relations = len(self.relations) 46 | self.token_types_origin = ['Date', 'Number', 'Text', '书籍', '人物', '企业', '作品', '出版社', '历史人物', '国家', '图书作品', '地点', '城市', '学校', '学科专业', 47 | '影视作品', '景点', '机构', '歌曲', '气候', '生物', '电视综艺', '目', '网站', '网络小说', '行政区', '语言', '音乐专辑'] 48 | self.token_types = self.get_token_types() 49 | self.num_token_type = len(self.token_types) 50 | self.vocab_file = '../data/vocab.txt' 51 | self.max_seq_length = 256 52 | self.num_sample = 204800 53 | 54 | self.dropout_embedding = 0.1 # 从0.2到0.1 55 | self.dropout_lstm = 0.1 56 | self.dropout_lstm_output = 0.9 57 | self.dropout_head = 0.9 # 只更改这个参数 0.9到0.5 58 | self.dropout_ner = 0.8 59 | self.use_dropout = True 60 | self.threshold_rel = 0.95 # 从0.7到0.95 61 | self.teach_rate = 0.2 62 | self.checkpoint_path = '../models/' 63 | 64 | def get_token_types(self): 65 | token_type_bio = [] 66 | for token_type in self.token_types_origin: 67 | token_type_bio.append('B-' + token_type) 68 | token_type_bio.append('I-' + token_type) 69 | token_type_bio.append('O') 70 | 71 | return token_type_bio 72 | 73 | -------------------------------------------------------------------------------- /utils/config_ner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/3 20:42 5 | # @File : config_ner.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | import torch 12 | 13 | if torch.cuda.is_available(): 14 | USE_CUDA = True 15 | print("USE_CUDA....") 16 | else: 17 | USE_CUDA = False 18 | 19 | 20 | class ConfigNer: 21 | def __init__(self, 22 | lr=0.001, 23 | epochs=100, 24 | vocab_size=22000, 25 | embedding_dim=32, 26 | hidden_dim_lstm=128, 27 | num_layers=3, 28 | batch_size=32, 29 | layer_size=128, 30 | token_type_dim=8 31 | ): 32 | self.lr = lr 33 | self.epochs = epochs 34 | self.vocab_size = vocab_size 35 | self.embedding_dim = embedding_dim 36 | self.hidden_dim_lstm = hidden_dim_lstm 37 | self.num_layers = num_layers 38 | self.batch_size = batch_size 39 | self.layer_size = layer_size 40 | self.token_type_dim = token_type_dim 41 | self.relations = ["N", '丈夫', '上映时间', '专业代码', '主持人', '主演', '主角', '人口数量', '作曲', '作者', '作词', '修业年限', '出品公司', '出版社', '出生地', 42 | '出生日期', '创始人', '制片人', '占地面积', '号', '嘉宾', '国籍', '妻子', '字', '官方语言', '导演', '总部地点', '成立日期', '所在城市', '所属专辑', 43 | '改编自', '朝代', '歌手', '母亲', '毕业院校', '民族', '气候', '注册资本', '海拔', '父亲', '目', '祖籍', '简称', '编剧', '董事长', '身高', 44 | '连载网站', '邮政编码', '面积', '首都'] 45 | self.num_relations = len(self.relations) 46 | self.token_types_origin = ['Date', 'Number', 'Text', '书籍', '人物', '企业', '作品', '出版社', '历史人物', '国家', '图书作品', '地点', '城市', '学校', '学科专业', 47 | '影视作品', '景点', '机构', '歌曲', '气候', '生物', '电视综艺', '目', '网站', '网络小说', '行政区', '语言', '音乐专辑'] 48 | self.token_types = self.get_token_types() 49 | self.num_token_type = len(self.token_types) 50 | self.vocab_file = '../data/vocab.txt' 51 | self.max_seq_length = 256 52 | self.num_sample = 204800 53 | 54 | self.dropout_embedding = 0.1 # 从0.2到0.1 55 | self.dropout_lstm = 0.1 56 | self.dropout_lstm_output = 0.9 57 | self.dropout_head = 0.9 # 只更改这个参数 0.9到0.5 58 | self.dropout_ner = 0.8 59 | self.use_dropout = True 60 | self.threshold_rel = 0.95 # 从0.7到0.95 61 | self.teach_rate = 0.2 62 | self.ner_checkpoint_path = '../models/sequence_labeling/' 63 | 64 | def get_token_types(self): 65 | token_type_bio = [] 66 | for token_type in self.token_types_origin: 67 | token_type_bio.append('B-' + token_type) 68 | 
token_type_bio.append('I-' + token_type) 69 | token_type_bio.append('O') 70 | 71 | return token_type_bio 72 | 73 | -------------------------------------------------------------------------------- /utils/config_rel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/11 16:53 5 | # @File : config_rel.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | import torch 12 | 13 | if torch.cuda.is_available(): 14 | USE_CUDA = True 15 | print("USE_CUDA....") 16 | else: 17 | USE_CUDA = False 18 | 19 | 20 | class ConfigRel: 21 | def __init__(self, 22 | lr=0.001, 23 | epochs=100, 24 | vocab_size=16116, # 22000, 25 | embedding_dim=100, 26 | hidden_dim_lstm=128, 27 | num_layers=3, 28 | batch_size=32, 29 | layer_size=128, 30 | token_type_dim=8 31 | ): 32 | self.lr = lr 33 | self.epochs = epochs 34 | self.vocab_size = vocab_size 35 | self.embedding_dim = embedding_dim 36 | self.hidden_dim_lstm = hidden_dim_lstm 37 | self.num_layers = num_layers 38 | self.batch_size = batch_size 39 | self.layer_size = layer_size 40 | self.token_type_dim = token_type_dim 41 | self.relations = ["N", '丈夫', '上映时间', '专业代码', '主持人', '主演', '主角', '人口数量', '作曲', '作者', '作词', '修业年限', '出品公司', '出版社', 42 | '出生地', 43 | '出生日期', '创始人', '制片人', '占地面积', '号', '嘉宾', '国籍', '妻子', '字', '官方语言', '导演', '总部地点', '成立日期', 44 | '所在城市', '所属专辑', 45 | '改编自', '朝代', '歌手', '母亲', '毕业院校', '民族', '气候', '注册资本', '海拔', '父亲', '目', '祖籍', '简称', '编剧', '董事长', 46 | '身高', 47 | '连载网站', '邮政编码', '面积', '首都'] 48 | self.num_relations = len(self.relations) 49 | self.token_types_origin = ['Date', 'Number', 'Text', '书籍', '人物', '企业', '作品', '出版社', '历史人物', '国家', '图书作品', '地点', 50 | '城市', '学校', '学科专业', 51 | '影视作品', '景点', '机构', '歌曲', '气候', '生物', '电视综艺', '目', '网站', '网络小说', '行政区', '语言', '音乐专辑'] 52 | self.token_types = self.get_token_types() 53 | self.num_token_type = len(self.token_types) 54 | self.vocab_file = '../data/vocab.txt' 55 | self.max_seq_length = 256 56 | self.num_sample = 1480 57 | 58 | self.dropout_embedding = 0.1 # 从0.2到0.1 59 | self.dropout_lstm = 0.1 60 | self.dropout_lstm_output = 0.9 61 | self.dropout_head = 0.9 # 只更改这个参数 0.9到0.5 62 | self.dropout_ner = 0.8 63 | self.use_dropout = True 64 | self.threshold_rel = 0.95 # 从0.7到0.95 65 | self.teach_rate = 0.2 66 | self.ner_checkpoint_path = '../models/rel_cls/' 67 | self.pretrained = False 68 | self.pad_token_id = 0 69 | self.rel_num = 500 70 | 71 | self.pos_dim = 32 72 | 73 | def get_token_types(self): 74 | token_type_bio = [] 75 | for token_type in self.token_types_origin: 76 | token_type_bio.append('B-' + token_type) 77 | token_type_bio.append('I-' + token_type) 78 | token_type_bio.append('O') 79 | 80 | return token_type_bio 81 | 82 | -------------------------------------------------------------------------------- /deploy/demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/11 19:34 5 | # @File : demo.py 6 | 7 | """ 8 | file description: 9 | 10 | """ 11 | import torch 12 | from modules.model_ner import SeqLabel 13 | from modules.model_rel import AttBiLSTM 14 | from utils.config_ner import ConfigNer, USE_CUDA 15 | from utils.config_rel import ConfigRel 16 | 17 | from data_loader.process_ner import ModelDataPreparation 18 | from data_loader.process_rel import DataPreparationRel 19 | 20 | from mains import trainer_ner, trainer_rel 21 | import json 22 | from transformers import BertForSequenceClassification 23 | 24 | 
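# A small, hypothetical walk-through of get_entities() below (tags and text are made
# up for illustration): given
#     pred_ner[0] = ['B-人物', 'I-人物', 'O', 'B-歌曲', 'I-歌曲', 'O']
#     text[0]     = '大飞唱直线吧'          # character-level tokens
# it returns token_types[0] = ['B-人物I-人物', 'B-歌曲I-歌曲'] and
# entities[0] = ['大飞', '直线']: a 'B-*' tag opens a span, 'I-*' extends it, and the
# span is closed and stored when an 'O' tag is reached (a new 'B-*' while a span is
# still open restarts the span instead of storing it).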
25 | def get_entities(pred_ner, text): 26 | token_types = [[] for _ in range(len(pred_ner))] 27 | entities = [[] for _ in range(len(pred_ner))] 28 | for i in range(len(pred_ner)): 29 | token_type = [] 30 | entity = [] 31 | j = 0 32 | word_begin = False 33 | while j < len(pred_ner[i]): 34 | if pred_ner[i][j][0] == 'B': 35 | if word_begin: 36 | token_type = [] # 防止多个B出现在一起 37 | entity = [] 38 | token_type.append(pred_ner[i][j]) 39 | entity.append(text[i][j]) 40 | word_begin = True 41 | elif pred_ner[i][j][0] == 'I': 42 | if word_begin: 43 | token_type.append(pred_ner[i][j]) 44 | entity.append(text[i][j]) 45 | else: 46 | if word_begin: 47 | token_types[i].append(''.join(token_type)) 48 | token_type = [] 49 | entities[i].append(''.join(entity)) 50 | entity = [] 51 | word_begin = False 52 | j += 1 53 | return token_types, entities 54 | 55 | 56 | def test(): 57 | # test_path = './test.json' 58 | # PATH_NER = '../models/sequence_labeling/60m-f589.90n40236.67ccks2019_ner.pth' 59 | # config_ner = ConfigNer() 60 | # # config_ner.batch_size = 1 61 | # ner_model = SeqLabel(config_ner) 62 | # ner_model_dict = torch.load(PATH_NER) 63 | # ner_model.load_state_dict(ner_model_dict['state_dict']) 64 | # 65 | # ner_data_process = ModelDataPreparation(config_ner) 66 | # _, _, test_loader = ner_data_process.get_train_dev_data(path_test=test_path) 67 | # trainerNer = trainer_ner.Trainer(ner_model, config_ner, test_dataset=test_loader) 68 | # pred_ner = trainerNer.predict() 69 | # # print("haha") 70 | # print(pred_ner) 71 | # text = None 72 | # for data_item in test_loader: 73 | # text = data_item['text'] 74 | # token_types, entities = get_entities(pred_ner, text) 75 | # print(token_types) 76 | # print(entities) 77 | # 78 | # rel_list = [] 79 | # with open('./rel_predict.json', 'w', encoding='utf-8') as f: 80 | # for i in range(len(pred_ner)): 81 | # texti = text[i] 82 | # for j in range(len(entities[i])): 83 | # for k in range(len(entities[i])): 84 | # if j == k: 85 | # continue 86 | # rel_list.append({"text":texti, "spo_list":{"subject": entities[i][j], "object": entities[i][k]}}) 87 | # json.dump(rel_list, f, ensure_ascii=False) 88 | 89 | PATH_REL = '../models/rel_cls/25m-acc0.93ccks2019_rel.pth' 90 | 91 | config_rel = ConfigRel() 92 | config_rel.batch_size = 8 93 | rel_model = BertForSequenceClassification.from_pretrained('../bert-base-chinese', num_labels=config_rel.num_relations) 94 | # rel_model = AttBiLSTM(config_rel) 95 | rel_model_dict = torch.load(PATH_REL) 96 | rel_model.load_state_dict(rel_model_dict['state_dict']) 97 | rel_test_path = './rel_predict.json' 98 | 99 | rel_data_process = DataPreparationRel(config_rel) 100 | _, _, test_loader = rel_data_process.get_train_dev_data(path_test=rel_test_path) 101 | trainREL = trainer_rel.Trainer(rel_model, config_rel, test_dataset=test_loader) 102 | rel_pred = trainREL.bert_predict() 103 | print(rel_pred) 104 | 105 | 106 | 107 | 108 | 109 | if __name__ == '__main__': 110 | test() 111 | -------------------------------------------------------------------------------- /modules/model_ner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/3 10:02 5 | # @File : model_ner.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torchcrf import CRF 15 | import numpy as np 16 | from utils.config_ner import USE_CUDA 17 | 18 | def setup_seed(seed): 19 | torch.manual_seed(seed) 20 
| torch.cuda.manual_seed_all(seed) 21 | 22 | 23 | class SeqLabel(nn.Module): 24 | def __init__(self, config): 25 | super().__init__() 26 | setup_seed(1) 27 | 28 | self.vocab_size = config.vocab_size 29 | self.embedding_dim = config.embedding_dim 30 | self.hidden_dim = config.hidden_dim_lstm 31 | self.num_layers = config.num_layers 32 | self.batch_size = config.batch_size 33 | self.layer_size = config.layer_size # self.hidden_dim, 之前这里没有改 34 | self.num_token_type = config.num_token_type # 实体类型的综述 35 | self.config = config 36 | 37 | self.word_embedding = nn.Embedding(config.vocab_size, config.embedding_dim) 38 | self.token_type_embedding = nn.Embedding(config.num_token_type, config.token_type_dim) 39 | self.gru = nn.GRU(config.embedding_dim, config.hidden_dim_lstm, num_layers=config.num_layers, batch_first=True, 40 | bidirectional=True) 41 | self.is_train = True 42 | if USE_CUDA: 43 | self.weights_rel = (torch.ones(self.config.num_relations) * 100).cuda() 44 | else: 45 | self.weights_rel = torch.ones(self.config.num_relations) * 100 46 | self.weights_rel[0] = 1 47 | 48 | self.V_ner = nn.Parameter(torch.rand((config.num_token_type, self.layer_size))) 49 | self.U_ner = nn.Parameter(torch.rand((self.layer_size, 2 * self.hidden_dim))) 50 | self.b_s_ner = nn.Parameter(torch.rand(self.layer_size)) 51 | self.b_c_ner = nn.Parameter(torch.rand(config.num_token_type)) 52 | 53 | self.dropout_embedding_layer = torch.nn.Dropout(config.dropout_embedding) 54 | self.dropout_ner_layer = torch.nn.Dropout(config.dropout_ner) 55 | self.dropout_lstm_layer = torch.nn.Dropout(config.dropout_lstm) 56 | self.crf_model = CRF(self.num_token_type, batch_first=True) 57 | 58 | def get_ner_score(self, output_lstm): 59 | 60 | res = torch.matmul(output_lstm, self.U_ner.transpose(-1, -2)) + self.b_s_ner # [seq_len, batch, self.layer_size] 61 | res = torch.tanh(res) 62 | # res = F.leaky_relu(res, negative_slope=0.01) 63 | if self.config.use_dropout: 64 | res = self.dropout_ner_layer(res) 65 | 66 | ans = torch.matmul(res, self.V_ner.transpose(-1, -2)) + self.b_c_ner # [seq_len, batch, num_token_type] 67 | 68 | return ans 69 | 70 | def forward(self, data_item, is_test=False): 71 | # 因为不是多跳机制,所以hidden_init不能继承之前的最终隐含态 72 | ''' 73 | 74 | :param data_item: data_item = {'',} 75 | :type data_item: dict 76 | :return: 77 | :rtype: 78 | ''' 79 | # print("hello5") 80 | embeddings = self.word_embedding(data_item['text_tokened'].to(torch.int64)) # 要转化为int64 81 | if self.config.use_dropout: 82 | embeddings = self.dropout_embedding_layer(embeddings) 83 | # if hidden_init is None: 84 | # print("hello6") 85 | if USE_CUDA: 86 | hidden_init = torch.randn(2*self.num_layers, self.batch_size, self.hidden_dim).cuda() 87 | else: 88 | hidden_init = torch.randn(2 * self.num_layers, self.batch_size, self.hidden_dim) 89 | output_lstm, h_n =self.gru(embeddings, hidden_init) 90 | # output_lstm [batch, seq_len, 2*hidden_dim] h_n [2*num_layers, batch, hidden_dim] 91 | # if self.config.use_dropout: 92 | # output_lstm = self.dropout_lstm_layer(output_lstm) # 用了效果变差 93 | ner_score = self.get_ner_score(output_lstm) 94 | # 下面是使用CFR 95 | if USE_CUDA: 96 | self.crf_model = self.crf_model.cuda() 97 | if not is_test: 98 | log_likelihood = self.crf_model(ner_score, data_item['token_type_list'].to(torch.int64), 99 | mask=data_item['mask_tokens']) 100 | loss_ner = -log_likelihood 101 | 102 | pred_ner = self.crf_model.decode(ner_score) # , mask=data_item['mask_tokens'] 103 | 104 | if is_test: 105 | return pred_ner 106 | return loss_ner, pred_ner 107 | 108 | 109 | 
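# A minimal smoke-test sketch for SeqLabel (hypothetical, not part of the original
# training pipeline). It assumes the default ConfigNer values and the batch field
# names used by forward() above ('text_tokened', 'token_type_list', 'mask_tokens');
# real batches are built by data_loader/process_ner.py.
if __name__ == '__main__':
    from utils.config_ner import ConfigNer

    config = ConfigNer()
    config.batch_size = 2  # keep the dummy batch small
    model = SeqLabel(config)

    seq_len = 16
    dummy_batch = {
        'text_tokened': torch.randint(0, config.vocab_size, (config.batch_size, seq_len)),
        'token_type_list': torch.randint(0, config.num_token_type, (config.batch_size, seq_len)),
        'mask_tokens': torch.ones(config.batch_size, seq_len, dtype=torch.uint8),
    }
    if USE_CUDA:
        model = model.cuda()
        dummy_batch = {k: v.cuda() for k, v in dummy_batch.items()}

    loss_ner, pred_ner = model(dummy_batch)  # CRF negative log-likelihood + decoded tag ids
    print(loss_ner.item(), len(pred_ner), len(pred_ner[0]))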
-------------------------------------------------------------------------------- /modules/model_rel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/11 16:36 5 | # @File : model_rel.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from utils.config_rel import USE_CUDA 15 | 16 | def setup_seed(seed): 17 | torch.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) # 使用相同的初始化种子,保证每次初始化结果一直,便于调试 19 | 20 | 21 | class Attention(nn.Module): 22 | def __init__(self, config): 23 | super().__init__() 24 | setup_seed(1) 25 | # self.query = nn.Parameter(torch.randn(1, config.hidden_dim_lstm)) # [batch, 1, hidden_dim] 26 | 27 | # def forward(self, H): 28 | # M = torch.tanh(H) # H [batch_size, sentence_length, hidden_dim_lstm] 29 | # attention_prob = torch.matmul(M, self.query.transpose(-1, -2)) # [batch_size, sentence_length, 1] 30 | # alpha = F.softmax(attention_prob,dim=-1) 31 | # attention_output = torch.matmul(alpha.transpose(-1, -2), H) # [batch_size, 1, hidden_dim_lstm] 32 | # attention_output = attention_output.squeeze(axis=1) 33 | # attention_output = torch.tanh(attention_output) 34 | # return attention_output 35 | 36 | def forward(self, output_lstm, hidden_lstm): 37 | hidden_lstm = torch.sum(hidden_lstm, dim=0) 38 | att_weights = torch.matmul(output_lstm, hidden_lstm.unsqueeze(2)).squeeze(2) 39 | alpha = F.softmax(att_weights, dim=1) 40 | new_hidden = torch.matmul(output_lstm.transpose(-1, -2), alpha.unsqueeze(2)).squeeze(2) 41 | 42 | return new_hidden 43 | 44 | class AttBiLSTM(nn.Module): 45 | def __init__(self, config, embedding_pre=None): 46 | super().__init__() 47 | setup_seed(1) 48 | self.embedding_dim = config.embedding_dim 49 | self.vocab_size = config.vocab_size 50 | self.hidden_dim = config.hidden_dim_lstm 51 | self.num_layers = config.num_layers 52 | self.batch_size = config.batch_size 53 | self.embed_dropout = nn.Dropout(config.dropout_embedding) 54 | self.lstm_dropout = nn.Dropout(config.dropout_lstm_output) 55 | self.pretrained = config.pretrained 56 | self.config = config 57 | self.relation_embed_layer = nn.Embedding(config.num_relations, self.hidden_dim) 58 | self.relations = torch.Tensor([i for i in range(config.num_relations)]) 59 | if USE_CUDA: 60 | self.relations = self.relations.cuda() 61 | self.relation_bias = nn.Parameter(torch.randn(config.num_relations)) 62 | 63 | assert (self.pretrained is True and embedding_pre is not None) or \ 64 | (self.pretrained is False and embedding_pre is None), "预训练必须有训练好的embedding_pre" 65 | # 定义网络层 66 | # 对于关系抽取,命名实体识别和关系抽取共享编码层 67 | if self.pretrained: 68 | # self.word_embedding = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_pre), freeze=False) 69 | self.word_embedding = nn.Embedding.from_pretrained(torch.FloatTensor(embedding_pre), freeze=False) 70 | else: 71 | self.word_embedding = nn.Embedding(config.vocab_size, config.embedding_dim, padding_idx=config.pad_token_id) 72 | 73 | # self.pos1_embedding = nn.Embedding(config.pos_size, config.embedding_dim) 74 | # self.pos2_embedding = nn.Embedding(config.pos_size, config.embedding_dim) 75 | self.gru = nn.GRU(config.embedding_dim+2*config.pos_dim, config.hidden_dim_lstm, num_layers=config.num_layers, batch_first=True, bidirectional=True, 76 | dropout=config.dropout_lstm) 77 | self.attention_layer = Attention(config) 78 | # self.classifier = nn.Linear(config.hidden_dim_lstm, 
config.num_relations) 79 | 80 | if USE_CUDA: 81 | self.weights_rel = (torch.ones(self.config.num_relations) * 6).cuda() 82 | else: 83 | self.weights_rel = torch.ones(self.config.num_relations) * 6 84 | # self.weights_rel[9], self.weights_rel[13], self.weights_rel[14], self.weights_rel[46] = 100, 100, 100, 100 85 | self.weights_rel[0] = 1 86 | self.hidden_init = torch.randn(2 * self.num_layers, self.batch_size, self.hidden_dim) 87 | if USE_CUDA: 88 | self.hidden_init = self.hidden_init.cuda() 89 | # self.pos_embedding_layer = nn.Embedding(config.max_seq_length*4, config.pos_dim) 90 | 91 | def forward(self, data_item, is_test=False): 92 | 93 | word_embeddings = self.word_embedding(data_item['sentence_cls'].to(torch.int64)) 94 | # pos1_embeddings = self.pos_embedding_layer(data_item['position_s'].to(torch.int64)) 95 | # pos2_embeddings = self.pos_embedding_layer(data_item['position_o'].to(torch.int64)) 96 | # embeddings = torch.cat((word_embeddings, pos1_embeddings, pos2_embeddings), 2) # batch_size, seq, word_dim+2*pos_dim 97 | embeddings = word_embeddings 98 | if self.config.use_dropout: 99 | embeddings = self.embed_dropout(embeddings) 100 | 101 | output, h_n = self.gru(embeddings, self.hidden_init) 102 | if self.config.use_dropout: 103 | output = self.lstm_dropout(output) 104 | attention_input = output[:, :, :self.hidden_dim] + output[:, :, self.hidden_dim:] 105 | attention_output = self.attention_layer(attention_input, h_n) 106 | # hidden_cls = torch.tanh(attention_output) 107 | # output_cls = self.classifier(attention_output) 108 | relation_embeds = self.relation_embed_layer(self.relations.to(torch.int64)) 109 | # res = torch.add(torch.matmul(attention_output, relation_embeds.transpose(-1, -2)), self.relation_bias) 110 | res = torch.matmul(attention_output, relation_embeds.transpose(-1, -2)) 111 | 112 | if not is_test: 113 | loss = F.cross_entropy(res, data_item['relation'], self.weights_rel) # loss = F.cross_entropy(attention_output, data_item['relation']) 114 | # loss /= self.config.batch_size 115 | res = F.softmax(res, -1) 116 | pred = res.argmax(dim=-1) 117 | if is_test: 118 | return pred 119 | return loss, pred 120 | -------------------------------------------------------------------------------- /mains/trainer_ner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/3 17:28 5 | # @File : trainer_ner.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | import sys 12 | sys.path.append('/home/xieminghui/Projects/EntityRelationExtraction/') # 添加路径 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from tqdm import tqdm 18 | from utils.config_ner import ConfigNer, USE_CUDA 19 | from modules.model_ner import SeqLabel 20 | from data_loader.process_ner import ModelDataPreparation 21 | import math 22 | 23 | from seqeval.metrics import f1_score 24 | from seqeval.metrics import precision_score 25 | from seqeval.metrics import accuracy_score 26 | from seqeval.metrics import recall_score 27 | from seqeval.metrics import classification_report 28 | import neptune 29 | 30 | 31 | class Trainer: 32 | def __init__(self, 33 | model, 34 | config, 35 | train_dataset=None, 36 | dev_dataset=None, 37 | test_dataset=None, 38 | ): 39 | self.model = model 40 | self.train_dataset = train_dataset 41 | self.dev_dataset = dev_dataset 42 | self.test_dataset = test_dataset 43 | self.config = config 44 | 45 | if USE_CUDA: 46 | self.model = self.model.cuda() 47 | # 初始优化器 
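        # Adam is applied to all model parameters; the ReduceLROnPlateau scheduler created
        # below halves the learning rate after 8 epochs without improvement (down to 1e-5),
        # but note that train() as written never calls self.scheduler.step(...), so the
        # schedule only takes effect if a step call is added.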
48 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config.lr) 49 | # 学习率调控 50 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, 51 | patience=8, min_lr=1e-5, verbose=True) 52 | self.get_id2token_type() 53 | 54 | def get_id2token_type(self): 55 | self.id2token_type = {} 56 | for i, token_type in enumerate(self.config.token_types): 57 | self.id2token_type[i] = token_type 58 | 59 | def train(self): 60 | print('STARTING TRAIN...') 61 | f1_ner_total_best = 0.0 62 | self.num_sample_total = len(self.train_dataset) * self.config.batch_size 63 | for epoch in range(self.config.epochs): 64 | print("Epoch: {}".format(epoch)) 65 | pbar = tqdm(enumerate(self.train_dataset), total=len(self.train_dataset)) 66 | loss_ner_total, f1_ner_total = 0, 0 67 | for i, data_item in pbar: 68 | loss_ner, f1_ner, pred_ner = self.train_batch(data_item) 69 | loss_ner_total += loss_ner 70 | f1_ner_total += f1_ner 71 | 72 | if (epoch+1) % 1 == 0: 73 | self.predict_sample() 74 | loss_ner_train_ave = loss_ner_total/self.num_sample_total 75 | print("train ner loss: {0}, f1 score: {1}".format(loss_ner_train_ave, 76 | f1_ner_total/self.num_sample_total*self.config.batch_size)) 77 | # neptune.log_metric("train ner loss", loss_ner_train_ave) 78 | # pbar.set_description('TRAIN LOSS: {}'.format(loss_total/self.num_sample_total)) 79 | if (epoch+1) % 1 == 0: 80 | self.evaluate() 81 | if epoch > 8 and f1_ner_total >= f1_ner_total_best: 82 | f1_ner_total_best = f1_ner_total 83 | torch.save({ 84 | 'epoch': epoch+1, 'state_dict': self.model.state_dict(), 'f1_best': f1_ner_total, 85 | 'optimizer': self.optimizer.state_dict(), 86 | }, 87 | self.config.ner_checkpoint_path + str(epoch) + 'm-' + 'f'+str("%.2f"%f1_ner_total) + 'n'+ 88 | str("%.2f"%loss_ner_total) +'ccks2019_ner.pth' 89 | ) 90 | 91 | def train_batch(self, data_item): 92 | self.optimizer.zero_grad() 93 | loss_ner, pred_ner = self.model(data_item) 94 | pred_token_type = self.restore_ner(pred_ner, data_item['mask_tokens']) 95 | f1_ner = f1_score(data_item['token_type_origin'], pred_token_type) 96 | loss_ner.backward() 97 | self.optimizer.step() 98 | 99 | return loss_ner,f1_ner, pred_ner 100 | 101 | def restore_ner(self, pred_ner, mask_tokens): 102 | pred_token_type = [] 103 | for i in range(len(pred_ner)): 104 | list_tmp = [] 105 | for j in range(len(pred_ner[0])): 106 | if mask_tokens[i, j] == 0: 107 | break 108 | list_tmp.append(self.id2token_type[pred_ner[i][j]]) 109 | pred_token_type.append(list_tmp) 110 | 111 | return pred_token_type 112 | 113 | def evaluate(self): 114 | print('STARTING EVALUATION...') 115 | self.model.train(False) 116 | pbar_dev = tqdm(enumerate(self.dev_dataset), total=len(self.dev_dataset)) 117 | 118 | loss_total, loss_ner_total = 0, 0 119 | for i, data_item in pbar_dev: 120 | loss_ner, pred_ner = self.model(data_item) 121 | loss_ner_total += loss_ner 122 | 123 | self.model.train(True) 124 | print("eval ner loss: {0}".format(loss_ner_total / (len(self.dev_dataset) * self.config.batch_size))) 125 | # return loss_ner_total / (len(self.dev_dataset) * self.config.batch_size) 126 | 127 | def predict(self): 128 | print('STARTING PREDICTING...') 129 | self.model.train(False) 130 | pbar = tqdm(enumerate(self.test_dataset), total=len(self.test_dataset)) 131 | for i, data_item in pbar: 132 | pred_ner = self.model(data_item, is_test=True) 133 | self.model.train(True) 134 | token_pred = [[] for _ in range(len(pred_ner))] 135 | for i in range(len(pred_ner)): 136 | for item in pred_ner[i]: 137 | 
token_pred[i].append(self.id2token_type[item]) 138 | return token_pred 139 | 140 | def predict_sample(self): 141 | print('STARTING TESTING...') 142 | self.model.train(False) 143 | pbar = tqdm(enumerate(self.test_dataset), total=len(self.test_dataset)) 144 | 145 | for i, data_item in pbar: 146 | pred_ner = self.model(data_item, is_test=True) 147 | data_item0 = data_item 148 | pred_ner = pred_ner[0] 149 | token_pred = [] 150 | for i in pred_ner: 151 | token_pred.append(self.id2token_type[i]) 152 | print("token_pred: {}".format(token_pred)) 153 | print(data_item0['text'][0]) 154 | print(data_item0['spo_list'][0]) 155 | self.model.train(True) 156 | 157 | 158 | if __name__ == '__main__': 159 | # neptune.init(api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiNTM3OTQzY2ItMzRhNC00YjYzLWJhMTktMzI0NTk4NmM4NDc3In0=', project_qualified_name='mangopudding/EntityRelationExtraction') 160 | # neptune.create_experiment('ner_train') 161 | print("Run EntityRelationExtraction NER ...") 162 | config = ConfigNer() 163 | model = SeqLabel(config) 164 | data_processor = ModelDataPreparation(config) 165 | train_loader, dev_loader, test_loader = data_processor.get_train_dev_data( 166 | '../data/train_small.json', 167 | '../data/dev_small.json', 168 | '../data/predict.json') 169 | # train_loader, dev_loader, test_loader = data_processor.get_train_dev_data('../data/train_data_small.json') 170 | trainer = Trainer(model, config, train_loader, dev_loader, test_loader) 171 | trainer.train() -------------------------------------------------------------------------------- /modules/joint_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/3 10:02 5 | # @File : joint_model.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torchcrf import CRF 15 | import numpy as np 16 | from utils.config import USE_CUDA 17 | 18 | def setup_seed(seed): 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | 22 | 23 | class JointModel(nn.Module): 24 | def __init__(self, config): 25 | super().__init__() 26 | setup_seed(1) 27 | 28 | self.vocab_size = config.vocab_size 29 | self.embedding_dim = config.embedding_dim 30 | self.hidden_dim = config.hidden_dim_lstm 31 | self.num_layers = config.num_layers 32 | self.batch_size = config.batch_size 33 | self.layer_size = config.layer_size # self.hidden_dim, 之前这里没有改 34 | self.num_token_type = config.num_token_type # 实体类型的综述 35 | self.config = config 36 | 37 | self.word_embedding = nn.Embedding(config.vocab_size, config.embedding_dim) 38 | self.token_type_embedding = nn.Embedding(config.num_token_type, config.token_type_dim) 39 | self.gru = nn.GRU(config.embedding_dim, config.hidden_dim_lstm, num_layers=config.num_layers, batch_first=True, 40 | bidirectional=True) 41 | self.is_train = True 42 | if USE_CUDA: 43 | self.weights_rel = (torch.ones(self.config.num_relations) * 100).cuda() 44 | else: 45 | self.weights_rel = torch.ones(self.config.num_relations) * 100 46 | self.weights_rel[0] = 1 47 | 48 | self.V_ner = nn.Parameter(torch.rand((config.num_token_type, self.layer_size))) 49 | self.U_ner = nn.Parameter(torch.rand((self.layer_size, 2 * self.hidden_dim))) 50 | self.b_s_ner = nn.Parameter(torch.rand(self.layer_size)) 51 | self.b_c_ner = nn.Parameter(torch.rand(config.num_token_type)) 52 | 53 | self.U_head = 
nn.Parameter(torch.rand((self.layer_size, self.hidden_dim * 2 + self.config.token_type_dim))) 54 | self.W_head = nn.Parameter(torch.rand((self.layer_size, self.hidden_dim * 2 + self.config.token_type_dim))) 55 | self.V_head = nn.Parameter(torch.rand(self.layer_size, len(self.config.relations))) 56 | self.b_s_head = nn.Parameter(torch.rand(self.layer_size)) 57 | 58 | self.dropout_embedding_layer = torch.nn.Dropout(config.dropout_embedding) 59 | self.dropout_head_layer = torch.nn.Dropout(config.dropout_head) 60 | self.dropout_ner_layer = torch.nn.Dropout(config.dropout_ner) 61 | self.dropout_lstm_layer = torch.nn.Dropout(config.dropout_lstm) 62 | self.crf_model = CRF(self.num_token_type, batch_first=True) 63 | 64 | def get_ner_score(self, output_lstm): 65 | 66 | res = torch.matmul(output_lstm, self.U_ner.transpose(-1, -2)) + self.b_s_ner # [seq_len, batch, self.layer_size] 67 | res = torch.tanh(res) 68 | # res = F.leaky_relu(res, negative_slope=0.01) 69 | if self.config.use_dropout: 70 | res = self.dropout_ner_layer(res) 71 | 72 | ans = torch.matmul(res, self.V_ner.transpose(-1, -2)) + self.b_c_ner # [seq_len, batch, num_token_type] 73 | 74 | return ans 75 | 76 | def broadcasting(self, left, right): 77 | left = left.permute(1, 0, 2) 78 | left = left.unsqueeze(3) 79 | 80 | right = right.permute(0, 2, 1) 81 | right = right.unsqueeze(0) 82 | 83 | B = left + right # [seq_len, batch, layer_size, seq_len] = [seq_len, batch, layer_size, 1] + [1, batch, layer_size, seq_len] 84 | B = B.permute(1, 0, 3, 2) 85 | 86 | return B # [batch, seq_len, seq_len, layer_size] 87 | 88 | def getHeadSelectionScores(self, rel_input): 89 | 90 | left = torch.matmul(rel_input, self.U_head.transpose(-1, -2)) # [batch, seq, self.layer_size] 91 | right = torch.matmul(rel_input, self.W_head.transpose(-1, -2)) 92 | 93 | out_sum = self.broadcasting(left, right) 94 | out_sum_bias = out_sum + self.b_s_head 95 | # out_sum_bias = torch.tanh(out_sum_bias) # relu 96 | out_sum_bias = F.elu(out_sum_bias) 97 | # out_sum_bias = F.leaky_relu(out_sum_bias, negative_slope=0.01) # relu 98 | if self.config.use_dropout: 99 | out_sum_bias = self.dropout_head_layer(out_sum_bias) 100 | res = torch.matmul(out_sum_bias, self.V_head) # [layer_size, num_relation] [batch,..., num_relation] 101 | # res的维度应该是 [batch, seq_len, seq_len, num_relation] 102 | # if self.config.use_dropout: 103 | # res = self.dropout_head_layer(res) 104 | return res 105 | 106 | def forward(self, data_item, is_test=False, hidden_init=None): 107 | # 因为不是多跳机制,所以hidden_init不能继承之前的最终隐含态 108 | ''' 109 | 110 | :param data_item: data_item = {'',} 111 | :type data_item: dict 112 | :return: 113 | :rtype: 114 | ''' 115 | # print("hello5") 116 | embeddings = self.word_embedding(data_item['text_tokened'].to(torch.int64)) # 要转化为int64 117 | if self.config.use_dropout: 118 | embeddings = self.dropout_embedding_layer(embeddings) 119 | # if hidden_init is None: 120 | # print("hello6") 121 | if USE_CUDA: 122 | hidden_init = torch.randn(2*self.num_layers, self.batch_size, self.hidden_dim).cuda() 123 | else: 124 | hidden_init = torch.randn(2 * self.num_layers, self.batch_size, self.hidden_dim) 125 | output_lstm, h_n =self.gru(embeddings, hidden_init) 126 | # output_lstm [batch, seq_len, 2*hidden_dim] h_n [2*num_layers, batch, hidden_dim] 127 | # print("hello7") 128 | # if self.config.use_dropout: 129 | # output_lstm = self.dropout_lstm_layer(output_lstm) # 用了效果变差 130 | ner_score = self.get_ner_score(output_lstm) 131 | # print("hello0") 132 | # 下面是使用CFR 133 | 134 | if USE_CUDA: 135 | 
self.crf_model = self.crf_model.cuda() 136 | if not is_test: 137 | log_likelihood = self.crf_model(ner_score, data_item['token_type_list'].to(torch.int64), 138 | mask=data_item['mask_tokens']) 139 | loss_ner = -log_likelihood 140 | 141 | pred_ner = self.crf_model.decode(ner_score) # , mask=data_item['mask_tokens'] 142 | 143 | # 下面使用的是Softmax 144 | # loss_ner = F.softmax(ner_score, data_item['ner_type']) 145 | # pred_ner = torch.argmax(ner_score, 2) 146 | 147 | #--------------------------Relation 148 | if not is_test and torch.rand(1) > self.config.teach_rate: 149 | labels = data_item['token_type_list'] 150 | else: 151 | if USE_CUDA: 152 | labels = torch.Tensor(pred_ner).cuda() 153 | else: 154 | labels = torch.Tensor(pred_ner) 155 | # print("hello1") 156 | label_embeddings = self.token_type_embedding(labels.to(torch.int64)) 157 | rel_input = torch.cat((output_lstm, label_embeddings), 2) 158 | rel_score_matrix = self.getHeadSelectionScores(rel_input) # [batch, seq_len, seq_len, num_relation] 159 | rel_score_prob = torch.sigmoid(rel_score_matrix) 160 | #gold_predicate_matrix_one_hot = F.one_hot(data_item['pred_rel_matrix'], len(self.config.relations)) 161 | if not is_test: 162 | # 这样计算交叉熵有问题吗 163 | # 交叉熵计算不适用 rel_score_prob, 应该是rel_score_matrix 164 | loss_rel = F.cross_entropy(rel_score_prob.permute(0, 3, 1, 2), data_item['pred_rel_matrix'], self.weights_rel) # 要把分类放在第二维度 165 | loss_rel *= rel_score_prob.shape[1] 166 | rel_score_prob = rel_score_prob - (self.config.threshold_rel - 0.5) # 超过了一定阈值之后才能判断关系 167 | pred_rel = torch.round(rel_score_prob).to(torch.int64) 168 | # print("hello2") 169 | if is_test: 170 | return pred_ner, pred_rel 171 | # loss_ner = min(loss_ner, 30000) 172 | # loss_rel = min(loss_rel, 10) 173 | return loss_ner, loss_rel, pred_ner, pred_rel 174 | 175 | 176 | -------------------------------------------------------------------------------- /data_loader/process_rel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/11 15:20 5 | # @File : process_rel.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | import json 12 | import torch 13 | from utils.config_rel import ConfigRel, USE_CUDA 14 | import copy 15 | from transformers import BertTokenizer 16 | import codecs 17 | from collections import defaultdict 18 | 19 | 20 | class DataPreparationRel: 21 | def __init__(self, config): 22 | self.config = config 23 | self.get_token2id() 24 | self.rel_cnt = defaultdict(int) 25 | 26 | def get_data(self, file_path, is_test=False): 27 | data = [] 28 | cnt = 0 29 | with open(file_path, 'r', encoding='utf-8')as f: 30 | for line in f: 31 | cnt += 1 32 | if cnt > self.config.num_sample: 33 | break 34 | data_item = json.loads(line) 35 | spo_list = data_item['spo_list'] 36 | text = data_item['text'] 37 | text = text.lower() 38 | for spo_item in spo_list: 39 | subject = spo_item["subject"] 40 | subject = subject.lower() 41 | object = spo_item["object"] 42 | object = object.lower() 43 | # 增加位置信息 44 | # index_s = text.index(subject) 45 | # index_o = text.index(object) 46 | # position_s, position_o = [], [] 47 | # for i, word in enumerate(text): 48 | # if word not in self.word2id: 49 | # continue 50 | # position_s.append(i-index_s+self.config.max_seq_length*2) 51 | # position_o.append(i-index_o+self.config.max_seq_length*2) 52 | if not is_test: 53 | relation = spo_item['predicate'] 54 | if self.rel_cnt[relation] > self.config.rel_num: 55 | continue 56 | self.rel_cnt[relation] += 1 57 | 58 | 
else: 59 | relation = [] 60 | # sentence_cls = '$'.join([subject, object, text.replace(subject, '#'*len(subject)).replace(object, '#'*len(object))]) 61 | sentence_cls = ''.join([subject, object, text]) 62 | # sentence_cls = text 63 | item = {'sentence_cls': sentence_cls, 'relation': relation, 'text': text, 64 | 'subject': subject, 'object': object} # 'position_s': position_s, 'position_o': position_o} 65 | data.append(item) 66 | # 添加负样本 67 | sentence_neg = '$'.join([object, subject, text]) 68 | # sentence_neg = '$'.join( 69 | # [object, subject, text.replace(subject, '#' * len(subject)).replace(object, '#' * len(object))]) 70 | item_neg = {'sentence_cls': sentence_neg, 'relation': 'N', 'text': text, 71 | 'subject': object, 'object': subject} 72 | data.append(item_neg) 73 | 74 | dataset = Dataset(data) 75 | if is_test: 76 | dataset.is_test = True 77 | data_loader = torch.utils.data.DataLoader( 78 | dataset=dataset, 79 | batch_size=self.config.batch_size, 80 | collate_fn=dataset.collate_fn, 81 | shuffle=True, 82 | drop_last=True 83 | ) 84 | 85 | return data_loader 86 | 87 | def get_token2id(self): 88 | self.word2id = {} 89 | with codecs.open('../data/vec.txt', 'r', encoding='utf-8') as f: 90 | cnt = 0 91 | for line in f.readlines(): 92 | self.word2id[line.split()[0]] = cnt 93 | cnt += 1 94 | 95 | def get_train_dev_data(self, path_train=None, path_dev=None, path_test=None, is_test=False): 96 | train_loader, dev_loader, test_loader = None, None, None 97 | if path_train is not None: 98 | train_loader = self.get_data(path_train) 99 | if path_dev is not None: 100 | dev_loader = self.get_data(path_dev) 101 | if path_test is not None: 102 | test_loader = self.get_data(path_test, is_test=True) 103 | 104 | return train_loader, dev_loader, test_loader 105 | 106 | 107 | class Dataset(torch.utils.data.Dataset): 108 | def __init__(self, data): 109 | self.data = copy.deepcopy(data) 110 | self.is_test = False 111 | with open('../data/rel2id.json', 'r', encoding='utf-8') as f: 112 | self.rel2id = json.load(f) 113 | vocab_file = '../bert-base-chinese/vocab.txt' 114 | self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file) 115 | 116 | # vocab_file = '../data/vec.txt' 117 | # self.get_token2id() 118 | 119 | def get_token2id(self): 120 | self.word2id = {} 121 | with codecs.open('../data/vec.txt', 'r', encoding='utf-8') as f: 122 | cnt = 0 123 | for line in f.readlines(): 124 | self.word2id[line.split()[0]] = cnt 125 | cnt += 1 126 | 127 | def __getitem__(self, index): 128 | sentence_cls = self.data[index]['sentence_cls'] 129 | relation = self.data[index]['relation'] 130 | text = self.data[index]['text'] 131 | subject = self.data[index]['subject'] 132 | object = self.data[index]['object'] 133 | # position_s = self.data[index]['position_s'] 134 | # position_o = self.data[index]['position_o'] 135 | 136 | data_info = {} 137 | for key in self.data[0].keys(): 138 | if key in locals(): 139 | data_info[key] = locals()[key] 140 | 141 | return data_info 142 | 143 | def __len__(self): 144 | return len(self.data) 145 | 146 | def collate_fn(self, data_batch): 147 | def merge(sequences): 148 | lengths = [len(seq) for seq in sequences] 149 | max_length = max(lengths) 150 | # padded_seqs = torch.zeros(len(sequences), max_length) 151 | padded_seqs = torch.zeros(len(sequences), max_length) 152 | tmp_pad = torch.ones(1, max_length) 153 | mask_tokens = torch.zeros(len(sequences), max_length) 154 | for i, seq in enumerate(sequences): 155 | end = lengths[i] 156 | seq = torch.LongTensor(seq) 157 | if len(seq) != 0: 158 | 
padded_seqs[i, :end] = seq[:end] 159 | mask_tokens[i, :end] = tmp_pad[0, :end] 160 | 161 | return padded_seqs, mask_tokens 162 | 163 | item_info = {} 164 | for key in data_batch[0].keys(): 165 | item_info[key] = [d[key] for d in data_batch] 166 | 167 | # 转化为数值 168 | sentence_cls = [self.bert_tokenizer.encode(sentence, add_special_tokens=True) for sentence in item_info['sentence_cls']] 169 | # sentence_cls = [[] for _ in range(len(item_info['sentence_cls']))] 170 | # for i, sentence in enumerate(item_info['sentence_cls']): 171 | # tmp = [] 172 | # for c in sentence: 173 | # if c in self.word2id: 174 | # tmp.append(self.word2id[c]) 175 | # sentence_cls[i] = tmp 176 | 177 | if not self.is_test: 178 | relation = torch.Tensor([self.rel2id[rel] for rel in item_info['relation']]).to(torch.int64) 179 | 180 | # 批量数据对齐 181 | sentence_cls, mask_tokens = merge(sentence_cls) 182 | sentence_cls = sentence_cls.to(torch.int64) 183 | mask_tokens = mask_tokens.to(torch.int64) 184 | relation = relation.to(torch.int64) 185 | # position_s, _ = merge(item_info['position_s']) 186 | # position_o, _ = merge(item_info['position_o']) 187 | if USE_CUDA: 188 | sentence_cls = sentence_cls.contiguous().cuda() 189 | mask_tokens = mask_tokens.contiguous().cuda() 190 | # position_s = position_s.contiguous().cuda() 191 | # position_o = position_o.contiguous().cuda() 192 | else: 193 | sentence_cls = sentence_cls.contiguous() 194 | mask_tokens = mask_tokens.contiguous() 195 | # position_s = position_s.contiguous() 196 | # position_o = position_o.contiguous() 197 | if not self.is_test: 198 | if USE_CUDA: 199 | relation = relation.contiguous().cuda() 200 | else: 201 | relation = relation.contiguous() 202 | 203 | data_info = {"mask_tokens": mask_tokens.to(torch.uint8)} 204 | data_info['text'] = item_info['text'] 205 | data_info['subject'] = item_info['subject'] 206 | data_info['object'] = item_info['object'] 207 | for key in item_info.keys(): 208 | if key in locals(): 209 | data_info[key] = locals()[key] 210 | 211 | return data_info 212 | 213 | 214 | if __name__ == '__main__': 215 | config = ConfigRel() 216 | process = DataPreparationRel(config) 217 | train_loader, dev_loader, test_loader = process.get_train_dev_data('../data/train_small.json') 218 | 219 | for item in train_loader: 220 | print(item) 221 | 222 | 223 | -------------------------------------------------------------------------------- /mains/trainer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/3 17:28 5 | # @File : trainer.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | import sys 12 | sys.path.append('/home/xieminghui/Projects/MultiHeadJointEntityRelationExtraction_simple/') # 添加路径 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from tqdm import tqdm 18 | from utils.config import Config, USE_CUDA 19 | from modules.joint_model import JointModel 20 | from data_loader.data_process import ModelDataPreparation 21 | import math 22 | 23 | from seqeval.metrics import f1_score 24 | from seqeval.metrics import precision_score 25 | from seqeval.metrics import accuracy_score 26 | from seqeval.metrics import recall_score 27 | from seqeval.metrics import classification_report 28 | import numpy as np 29 | # if torch.cuda.is_available(): 30 | # USE_CUDA = True 31 | 32 | 33 | class Trainer: 34 | def __init__(self, 35 | model, 36 | config, 37 | train_dataset=None, 38 | dev_dataset=None, 39 | test_dataset=None, 40 | 
token2id=None 41 | ): 42 | self.model = model 43 | self.train_dataset = train_dataset 44 | self.dev_dataset = dev_dataset 45 | self.test_dataset = test_dataset 46 | self.config = config 47 | self.token2id = token2id 48 | 49 | # 初始优化器 50 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config.lr) 51 | # 学习率调控 52 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, 53 | patience=8, min_lr=1e-5, verbose=True) 54 | if USE_CUDA: 55 | self.model = self.model.cuda() 56 | 57 | self.num_sample_total = len(train_loader) * self.config.batch_size 58 | 59 | self.get_id2rel() 60 | 61 | def get_id2rel(self): 62 | self.id2rel = {} 63 | for i, rel in enumerate(self.config.relations): 64 | self.id2rel[i] = rel 65 | self.id2token_type = {} 66 | for i, token_type in enumerate(self.config.token_types): 67 | self.id2token_type[i] = token_type 68 | 69 | def train(self): 70 | print('STARTING TRAIN...') 71 | f1_ner_total_best = 0 72 | for epoch in range(self.config.epochs): 73 | print("Epoch: {}".format(epoch)) 74 | # print(len(self.train_dataset)) 75 | pbar = tqdm(enumerate(self.train_dataset), total=len(self.train_dataset)) 76 | loss_total, loss_ner_total, loss_rel_total, f1_ner_total = 0, 0, 0, 0 77 | # first = True 78 | # print("haha1") 79 | for i, data_item in pbar: 80 | # print("haha2") 81 | loss_ner, loss_rel, pred_ner, pred_rel, f1_ner = self.train_batch(data_item) 82 | 83 | loss_total += (loss_ner + loss_rel) 84 | loss_ner_total += loss_ner 85 | loss_rel_total += loss_rel 86 | f1_ner_total += f1_ner 87 | 88 | if (epoch+1) % 1 == 0: 89 | self.predict_sample() 90 | print("train ner loss: {0}, rel loss: {1}, f1 score: {2}".format(loss_ner_total/self.num_sample_total, loss_rel_total/self.num_sample_total, 91 | f1_ner_total/self.num_sample_total*self.config.batch_size)) 92 | # pbar.set_description('TRAIN LOSS: {}'.format(loss_total/self.num_sample_total)) 93 | if (epoch+1) % 1 == 0: 94 | self.evaluate() 95 | # if epoch > 10 and f1_ner_total > f1_ner_total_best: 96 | # torch.save({ 97 | # 'epoch': epoch+1, 'state_dict': model.state_dict(), 'f1_best': f1_ner_total, 98 | # 'optimizer': self.optimizer.state_dict(), 99 | # }, 100 | # self.config.checkpoint_path + str(epoch) + 'm-' + 'f'+str("%.4f"%f1_ner_total) + 'n'+str("%.4f"%loss_ner_total) + 101 | # 'r'+str("%.4f"%loss_rel_total) + '.pth' 102 | # ) 103 | 104 | 105 | def train_batch(self, data_item): 106 | # print("haha4") 107 | self.optimizer.zero_grad() 108 | # self.loss_ner, self.loss_rel, self.pred_ner, self.pred_rel = self.model(data_item) 109 | # self.loss = self.loss_ner + self.loss_rel 110 | # self.loss.backward() 111 | # self.optimizer.step() 112 | # print("haha5") 113 | loss_ner, loss_rel, pred_ner, pred_rel = self.model(data_item) 114 | pred_token_type = self.restore_ner(pred_ner, data_item['mask_tokens']) 115 | f1_ner = f1_score(data_item['token_type_origin'], pred_token_type) 116 | loss = (loss_ner + loss_rel) 117 | # print("hello3") 118 | loss.backward() 119 | # print("hello4") 120 | self.optimizer.step() 121 | 122 | return loss_ner, loss_rel, pred_ner, pred_rel, f1_ner 123 | 124 | def restore_ner(self, pred_ner, mask_tokens): 125 | pred_token_type = [] 126 | for i in range(len(pred_ner)): 127 | list_tmp = [] 128 | for j in range(len(pred_ner[0])): 129 | if mask_tokens[i, j] == 0: 130 | break 131 | list_tmp.append(self.id2token_type[pred_ner[i][j]]) 132 | pred_token_type.append(list_tmp) 133 | 134 | return pred_token_type 135 | 136 | def evaluate(self): 137 | print('STARTING EVALUATION...') 138 | 
self.model.train(False) 139 | pbar_dev = tqdm(enumerate(self.dev_dataset), total=len(self.dev_dataset)) 140 | 141 | loss_total, loss_ner_total, loss_rel_total = 0, 0, 0 142 | for i, data_item in pbar_dev: 143 | loss_ner, loss_rel, pred_ner, pred_rel = self.model(data_item) 144 | loss_ner_total += loss_ner 145 | loss_rel_total += loss_rel 146 | # loss_total += (loss_ner + loss_rel) 147 | print("eval ner loss: {0}, rel loss: {1}".format(loss_ner_total, loss_rel_total)) 148 | self.model.train(True) 149 | 150 | return loss_total / (len(self.dev_dataset) * 8) 151 | 152 | def predict(self): 153 | print('STARTING TESTING...') 154 | self.model.train(False) 155 | pbar = tqdm(enumerate(self.test_dataset), total=len(self.test_dataset)) 156 | for i, data_item in pbar: 157 | pred_ner, pred_rel = self.model(data_item, is_test=True) 158 | print("TEST NER:") 159 | print(pred_ner) 160 | print("TEST REL:") 161 | print(pred_rel) 162 | self.model.train(True) 163 | 164 | def predict_sample(self): 165 | print('STARTING TESTING...') 166 | self.model.train(False) 167 | pbar = tqdm(enumerate(self.test_dataset), total=len(self.test_dataset)) 168 | data_item0 = None 169 | for i, data_item in pbar: 170 | 171 | pred_ner, pred_rel = self.model(data_item, is_test=True) 172 | data_item0 = data_item 173 | pred_ner, pred_rel = pred_ner[0], pred_rel[0] 174 | pred_rel_list = [] 175 | for i in range(pred_rel.shape[0]): 176 | for j in range(pred_rel.shape[1]): 177 | for k in range(pred_rel.shape[2]): 178 | if math.fabs(pred_rel[i, j, k] - 1.0) < 0.1: 179 | # print(i, j, k, self.id2rel[k]) 180 | if k != 0: 181 | pred_rel_list.append([i, j, self.id2rel[k]]) 182 | token_pred = [] 183 | for i in pred_ner: 184 | token_pred.append(self.id2token_type[i]) 185 | print("token_pred: {}".format(token_pred)) 186 | print(data_item0['text'][0]) 187 | print(data_item0['spo_list'][0]) 188 | print("pred_rel_list: {}".format(pred_rel_list)) 189 | self.model.train(True) 190 | subject_all, object_all, rel_all = self.convert2StandardOutput(data_item0, token_pred, pred_rel_list) 191 | print("Results:") 192 | print("主体: \n", subject_all) 193 | print("客体: \n", object_all) 194 | print("关系: \n", rel_all) 195 | 196 | def convert2StandardOutput(self, data_item, token_pred, pred_rel_list): 197 | subject_all, object_all, rel_all = [], [], [] 198 | text = [c for c in data_item['text'][0]] 199 | for item in pred_rel_list: 200 | subject, object, rel = [], [], [] 201 | s_start, o_start = item[0], item[1] 202 | if s_start == o_start: # 防止自己和自己构成关系 203 | continue 204 | if token_pred[s_start][0] != 'B' or token_pred[o_start][0] != 'B' or s_start > len(text) or o_start > len(text): 205 | continue 206 | subject.append(text[s_start]) 207 | object.append(text[o_start]) 208 | s_start += 1 209 | o_start += 1 210 | while s_start < len(text) and (token_pred[s_start][0] == 'I' ): # or token_pred[s_start][0] == 'B' 211 | subject.append(text[s_start]) 212 | s_start += 1 213 | while o_start < len(text) and (token_pred[o_start][0] == 'I'): # or token_pred[o_start][0] == 'B' 214 | object.append(text[o_start]) 215 | o_start += 1 216 | subject_all.append(''.join(subject)) 217 | object_all.append(''.join(object)) 218 | rel_all.append(item[2]) 219 | 220 | return subject_all, object_all, rel_all 221 | 222 | 223 | 224 | if __name__ == '__main__': 225 | config = Config() 226 | model = JointModel(config) 227 | data_processor = ModelDataPreparation(config) 228 | train_loader, dev_loader, test_loader = data_processor.get_train_dev_data( 229 | '../data/train_data_small.json', 230 | 
'../data/dev_small.json', 231 | '../data/predict.json') 232 | # train_loader, dev_loader, test_loader = data_processor.get_train_dev_data('../data/train_data_small.json') 233 | trainer = Trainer(model, config, train_loader, dev_loader, test_loader, data_processor.token2id) 234 | trainer.train() -------------------------------------------------------------------------------- /data_loader/process_ner.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/3 14:31 5 | # @File : process_ner.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | 12 | ''' 13 | 针对spo_list的主客体在原文的token进行标注,第一个字标注B-type,后面的字标注I-type,文本中其他词标注为O 14 | (先将所有文本标注为O,然后根据spo_list的内容,将对应位置覆盖) 15 | 16 | ''' 17 | import json 18 | import torch 19 | import copy 20 | from utils.config_ner import ConfigNer, USE_CUDA 21 | import numpy as np 22 | 23 | 24 | class ModelDataPreparation: 25 | def __init__(self, config): 26 | self.config = config 27 | self.get_type2id() 28 | 29 | def subject_object_labeling(self, spo_list, text, text_tokened): 30 | # 在列表 k 中确定列表 q 的位置 31 | def _index_q_list_in_k_list(q_list, k_list): 32 | """Known q_list in k_list, find index(first time) of q_list in k_list""" 33 | q_list_length = len(q_list) 34 | k_list_length = len(k_list) 35 | for idx in range(k_list_length - q_list_length + 1): 36 | t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])] 37 | # print(idx, t) 38 | if all(t): 39 | # print(idx) 40 | idx_start = idx 41 | return idx_start 42 | 43 | # 给主体和客体表上BIO分割式类型标签 44 | def _labeling_type(subject_object, so_type): 45 | so_tokened = [c for c in subject_object] 46 | so_tokened_length = len(so_tokened) 47 | idx_start = _index_q_list_in_k_list(q_list=so_tokened, k_list=text_tokened) 48 | if idx_start is None: 49 | tokener_error_flag = True 50 | ''' 51 | 实体: "1981年" 原句: "●1981年2月27日,中国人口学会成立" 52 | so_tokened ['1981', '年'] text_tokened ['●', '##19', '##81', '年', '2', '月', '27', '日', ',', '中', '国', '人', '口', '学', '会', '成', '立'] 53 | so_tokened 无法在 text_tokened 找到!原因是bert_tokenizer.tokenize 分词增添 “##” 所致! 
54 | ''' 55 | else: # 给实体开始处标 B 其它位置标 I 56 | labeling_list[idx_start] = "B-" + so_type 57 | if so_tokened_length == 2: 58 | labeling_list[idx_start + 1] = "I-" + so_type 59 | elif so_tokened_length >= 3: 60 | labeling_list[idx_start + 1: idx_start + so_tokened_length] = ["I-" + so_type] * ( 61 | so_tokened_length - 1) 62 | return idx_start 63 | 64 | labeling_list = ["O" for _ in range(len(text_tokened))] 65 | have_error = False 66 | for spo_item in spo_list: 67 | subject = spo_item["subject"] 68 | subject_type = spo_item["subject_type"] 69 | object = spo_item["object"] 70 | subject, object = map(self.get_rid_unkonwn_word, (subject, object)) 71 | subject = list(map(lambda x: x.lower(), subject)) 72 | object = list(map(lambda x: x.lower(), object)) 73 | object_type = spo_item["object_type"] 74 | subject_idx_start = _labeling_type(subject, subject_type) 75 | object_idx_start = _labeling_type(object, object_type) 76 | if subject_idx_start is None or object_idx_start is None: 77 | have_error = True 78 | return labeling_list, have_error 79 | #sample_cls = '$'.join([subject, object, text.replace(subject, '#'*len(subject)).replace(object, '#')]) 80 | #cls_list.append(sample_cls) 81 | return labeling_list, have_error 82 | 83 | def get_rid_unkonwn_word(self, text): 84 | text_rid = [] 85 | for token in text: # 删除不在vocab里面的词汇 86 | if token in self.token2id.keys(): 87 | text_rid.append(token) 88 | return text_rid 89 | 90 | def get_type2id(self): 91 | self.token_type2id = {} 92 | for i, token_type in enumerate(self.config.token_types): 93 | self.token_type2id[token_type] = i 94 | # with open('token_type2id.json', 'w', encoding='utf-8') as f: 95 | # json.dump(self.token_type2id, f, ensure_ascii=False) 96 | # with open('rel2id.json', 'w', encoding='utf-8') as f: 97 | # json.dump(self.rel2id, f, ensure_ascii=False) 98 | self.token2id = {} 99 | with open(self.config.vocab_file, 'r', encoding='utf-8') as f: 100 | cnt = 0 101 | for line in f: 102 | line = line.rstrip().split() 103 | self.token2id[line[0]] = cnt 104 | cnt += 1 105 | self.token2id[' '] = cnt 106 | 107 | def get_data(self, file_path, is_test=False): 108 | data = [] 109 | cnt = 0 110 | with open(file_path, 'r', encoding='utf-8') as f: 111 | for line in f: 112 | cnt += 1 113 | if cnt > self.config.num_sample: 114 | break 115 | data_item = json.loads(line) 116 | if not is_test: 117 | spo_list = data_item['spo_list'] 118 | else: 119 | spo_list = [] 120 | text = data_item['text'] 121 | text_tokened = [c.lower() for c in text] # 中文使用简单的分词 122 | token_type_list, token_type_origin = None, None 123 | 124 | text_tokened = self.get_rid_unkonwn_word(text_tokened) 125 | if not is_test: 126 | token_type_list, have_error = self.subject_object_labeling( 127 | spo_list=spo_list, text=text, text_tokened=text_tokened 128 | ) 129 | token_type_origin = token_type_list # 保存没有数值化前的token_type 130 | if have_error: 131 | continue 132 | item = {'text_tokened': text_tokened, 'token_type_list': token_type_list} 133 | item['text_tokened'] = [self.token2id[x] for x in item['text_tokened']] 134 | if not is_test: 135 | item['token_type_list'] = [self.token_type2id[x] for x in item['token_type_list']] 136 | item['text'] = ''.join(text_tokened) # 保存消除异常词汇的文本 137 | item['spo_list'] = spo_list 138 | item['token_type_origin'] = token_type_origin 139 | data.append(item) 140 | dataset = Dataset(data) 141 | if is_test: 142 | dataset.is_test = True 143 | data_loader = torch.utils.data.DataLoader( 144 | dataset=dataset, 145 | batch_size=self.config.batch_size, 146 | 
collate_fn=dataset.collate_fn, 147 | drop_last=True 148 | ) 149 | return data_loader 150 | 151 | def get_train_dev_data(self, path_train=None, path_dev=None, path_test=None): 152 | train_loader, dev_loader, test_loader = None, None, None 153 | if path_train is not None: 154 | train_loader = self.get_data(path_train) 155 | if path_dev is not None: 156 | dev_loader = self.get_data(path_dev) 157 | if path_test is not None: 158 | test_loader = self.get_data(path_test, is_test=True) 159 | 160 | return train_loader, dev_loader, test_loader 161 | 162 | 163 | class Dataset(torch.utils.data.Dataset): 164 | def __init__(self, data): 165 | self.data = copy.deepcopy(data) 166 | self.is_test = False 167 | 168 | def __getitem__(self, index): 169 | text_tokened = self.data[index]['text_tokened'] 170 | token_type_list = self.data[index]['token_type_list'] 171 | 172 | data_info = {} 173 | for key in self.data[0].keys(): 174 | # try: 175 | # data_info[key] = locals()[key] 176 | # except KeyError: 177 | # print('{} cannot be found in locals()'.format(key)) 178 | if key in locals(): 179 | data_info[key] = locals()[key] 180 | 181 | data_info['text'] = self.data[index]['text'] 182 | data_info['spo_list'] = self.data[index]['spo_list'] 183 | data_info['token_type_origin'] = self.data[index]['token_type_origin'] 184 | return data_info 185 | 186 | def __len__(self): 187 | return len(self.data) 188 | 189 | def collate_fn(self, data_batch): 190 | 191 | def merge(sequences): 192 | lengths = [len(seq) for seq in sequences] 193 | max_length = max(lengths) 194 | # padded_seqs = torch.zeros(len(sequences), max_length) 195 | padded_seqs = torch.zeros(len(sequences), max_length) 196 | tmp_pad = torch.ones(1, max_length) 197 | mask_tokens = torch.zeros(len(sequences), max_length) 198 | for i, seq in enumerate(sequences): 199 | end = lengths[i] 200 | seq = torch.LongTensor(seq) 201 | if len(seq) != 0: 202 | padded_seqs[i, :end] = seq[:end] 203 | mask_tokens[i, :end] = tmp_pad[0, :end] 204 | 205 | return padded_seqs, mask_tokens 206 | item_info = {} 207 | for key in data_batch[0].keys(): 208 | item_info[key] = [d[key] for d in data_batch] 209 | token_type_list = None 210 | text_tokened, mask_tokens = merge(item_info['text_tokened']) 211 | if not self.is_test: 212 | token_type_list, _ = merge(item_info['token_type_list']) 213 | # convert to contiguous and cuda 214 | if USE_CUDA: 215 | text_tokened = text_tokened.contiguous().cuda() 216 | mask_tokens = mask_tokens.contiguous().cuda() 217 | else: 218 | text_tokened = text_tokened.contiguous() 219 | mask_tokens = mask_tokens.contiguous() 220 | 221 | if not self.is_test: 222 | if USE_CUDA: 223 | token_type_list = token_type_list.contiguous().cuda() 224 | 225 | else: 226 | token_type_list = token_type_list.contiguous() 227 | 228 | data_info = {"mask_tokens": mask_tokens.to(torch.uint8)} 229 | data_info['text'] = item_info['text'] 230 | data_info['spo_list'] = item_info['spo_list'] 231 | data_info['token_type_origin'] = item_info['token_type_origin'] 232 | for key in item_info.keys(): 233 | # try: 234 | # data_info[key] = locals()[key] 235 | # except KeyError: 236 | # print('{} cannot be found in locals()'.format(key)) 237 | if key in locals(): 238 | data_info[key] = locals()[key] 239 | 240 | return data_info 241 | 242 | if __name__ == '__main__': 243 | config = ConfigN() 244 | process = ModelDataPreparation(config) 245 | train_loader, dev_loader, test_loader = process.get_train_dev_data('../data/train_small.json') 246 | # train_loader, dev_loader, test_loader = 
process.get_train_dev_data('../data/train_data_small.json') 247 | print(train_loader) 248 | for item in train_loader: 249 | print(item) -------------------------------------------------------------------------------- /mains/trainer_rel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/11 16:38 5 | # @File : trainer_rel.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | 12 | # coding=utf-8 13 | import sys 14 | sys.path.append('/home/xieminghui/Projects/EntityRelationExtraction/') 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | from tqdm import tqdm 20 | from utils.config_rel import ConfigRel, USE_CUDA 21 | from modules.model_rel import AttBiLSTM 22 | from data_loader.process_rel import DataPreparationRel 23 | import numpy as np 24 | import codecs 25 | from transformers import BertForSequenceClassification 26 | import neptune 27 | 28 | 29 | class Trainer: 30 | def __init__(self, 31 | model, 32 | config, 33 | train_dataset=None, 34 | dev_dataset=None, 35 | test_dataset=None 36 | ): 37 | self.model = model 38 | self.config = config 39 | self.train_dataset = train_dataset 40 | self.dev_dataset = dev_dataset 41 | self.test_dataset = test_dataset 42 | 43 | if USE_CUDA: 44 | self.model = self.model.cuda() 45 | 46 | self.optimizer = torch.optim.SGD(self.model.parameters(), lr=config.lr) 47 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.5, 48 | patience=8, min_lr=1e-5, verbose=True) 49 | self.get_id2rel() 50 | 51 | def train(self): 52 | print('STARTING TRAIN...') 53 | self.num_sample_total = len(self.train_dataset) * self.config.batch_size 54 | loss_eval_best = 1e8 55 | for epoch in range(self.config.epochs): 56 | print("Epoch: {}".format(epoch)) 57 | pbar = tqdm(enumerate(self.train_dataset), total=len(self.train_dataset)) 58 | loss_rel_total = 0.0 59 | self.optimizer.zero_grad() 60 | # with torch.no_grad(): 61 | for i, data_item in pbar: 62 | loss_rel, pred_rel = self.model(data_item) 63 | loss_rel.backward() 64 | self.optimizer.step() 65 | 66 | loss_rel_total += loss_rel 67 | loss_rel_train_ave = loss_rel_total / self.num_sample_total 68 | print("train rel loss: {0}".format(loss_rel_train_ave)) 69 | neptune.log_metric("train rel loss", loss_rel_train_ave) 70 | if (epoch + 1) % 1 == 0: 71 | loss_rel_ave = self.evaluate() 72 | 73 | if epoch > 0 and (epoch+1) % 2 == 0: 74 | if loss_rel_ave < loss_eval_best: 75 | loss_eval_best = loss_rel_ave 76 | torch.save({ 77 | 'epoch': epoch + 1, 'state_dict': self.model.state_dict(), 'loss_rel_best': loss_eval_best, 78 | 'optimizer': self.optimizer.state_dict(), 79 | }, 80 | self.config.ner_checkpoint_path + str(epoch) + 'm-' + 'loss' + 81 | str("%.2f" % loss_rel_ave) + 'ccks2019_rel.pth' 82 | ) 83 | 84 | def evaluate(self): 85 | print('STARTING EVALUATION...') 86 | self.model.train(False) 87 | pbar_dev = tqdm(enumerate(self.dev_dataset), total=len(self.dev_dataset)) 88 | 89 | loss_rel_total = 0 90 | for i, data_item in pbar_dev: 91 | loss_rel, pred_rel = self.model(data_item) 92 | loss_rel_total += loss_rel 93 | 94 | self.model.train(True) 95 | loss_rel_ave = loss_rel_total / (len(self.dev_dataset) * self.config.batch_size) 96 | print("eval rel loss: {0}".format(loss_rel_ave)) 97 | 98 | print(data_item['text'][1]) 99 | print("subject: {0}, object:{1}".format(data_item['subject'][1], data_item['object'][1])) 100 | print("object rel: {}".format(self.id2rel[int(data_item['relation'][1])])) 101 | 
print("predict rel: {}".format(self.id2rel[int(pred_rel[1])])) 102 | return loss_rel_ave 103 | 104 | def get_id2rel(self): 105 | self.id2rel = {} 106 | for i, rel in enumerate(self.config.relations): 107 | self.id2rel[i] = rel 108 | 109 | def predict(self): 110 | print('STARTING PREDICTING...') 111 | self.model.train(False) 112 | pbar = tqdm(enumerate(self.test_dataset), total=len(self.test_dataset)) 113 | for i, data_item in pbar: 114 | pred_rel = self.model(data_item, is_test=True) 115 | self.model.train(True) 116 | rel_pred = [[] for _ in range(len(pred_rel))] 117 | for i in range(len(pred_rel)): 118 | # for item in pred_rel[i]: 119 | rel_pred[i].append(self.id2rel[int(pred_rel[i])]) 120 | return rel_pred 121 | 122 | def bert_train(self): 123 | print('STARTING TRAIN...') 124 | self.num_sample_total = len(self.train_dataset) * self.config.batch_size 125 | acc_best = 0.0 126 | for epoch in range(self.config.epochs): 127 | print("Epoch: {}".format(epoch)) 128 | pbar = tqdm(enumerate(self.train_dataset), total=len(self.train_dataset)) 129 | loss_rel_total = 0.0 130 | # self.optimizer.zero_grad() 131 | correct = 0 132 | # with torch.no_grad(): 133 | for i, data_item in pbar: 134 | self.optimizer.zero_grad() 135 | output = self.model(data_item['sentence_cls'], attention_mask=data_item['mask_tokens'], labels=data_item['relation']) 136 | loss_rel, logits = output[0], output[1] 137 | loss_rel.backward() 138 | self.optimizer.step() 139 | 140 | _, pred_rel = torch.max(logits.data, 1) 141 | correct += pred_rel.data.eq(data_item['relation'].data).cpu().sum().numpy() 142 | 143 | loss_rel_total += loss_rel 144 | loss_rel_train_ave = loss_rel_total / self.num_sample_total 145 | print("train rel loss: {0}".format(loss_rel_train_ave)) 146 | # neptune.log_metric("train rel loss", loss_rel_train_ave) 147 | print("precision_score: {0}".format(correct / self.num_sample_total)) 148 | if (epoch + 1) % 1 == 0: 149 | acc_eval = self.bert_evaluate() 150 | 151 | if epoch > 0 and (epoch + 1) % 2 == 0: 152 | if acc_eval > acc_best: 153 | acc_best = acc_eval 154 | torch.save({ 155 | 'epoch': epoch + 1, 'state_dict': self.model.state_dict(), 'acc_best': acc_best, 156 | 'optimizer': self.optimizer.state_dict(), 157 | }, 158 | self.config.ner_checkpoint_path + str(epoch) + 'm-' + 'acc' + 159 | str("%.2f" % acc_best) + 'ccks2019_rel.pth' 160 | ) 161 | 162 | def bert_evaluate(self): 163 | print('STARTING EVALUATION...') 164 | self.model.train(False) 165 | pbar_dev = tqdm(enumerate(self.dev_dataset), total=len(self.dev_dataset)) 166 | 167 | loss_rel_total = 0 168 | correct = 0 169 | with torch.no_grad(): 170 | for i, data_item in pbar_dev: 171 | output = self.model(data_item['sentence_cls'], attention_mask=data_item['mask_tokens'], labels=data_item['relation']) 172 | loss_rel, logits = output[0], output[1] 173 | _, pred_rel = torch.max(logits.data, 1) 174 | loss_rel_total += loss_rel 175 | correct += pred_rel.data.eq(data_item['relation'].data).cpu().sum().numpy() 176 | 177 | self.model.train(True) 178 | loss_rel_ave = loss_rel_total / (len(self.dev_dataset) * self.config.batch_size) 179 | correct_ave = correct / (len(self.dev_dataset) * self.config.batch_size) 180 | print("eval rel loss: {0}".format(loss_rel_ave)) 181 | print("precision_score: {0}".format(correct_ave)) 182 | 183 | print(data_item['text'][1]) 184 | print("subject: {0}, object:{1}".format(data_item['subject'][1], data_item['object'][1])) 185 | print("object rel: {}".format(self.id2rel[int(data_item['relation'][1])])) 186 | print("predict rel: 
{}".format(self.id2rel[int(pred_rel[1])])) 187 | return correct_ave 188 | 189 | def bert_predict(self): 190 | print('STARTING PREDICTING...') 191 | self.model.train(False) 192 | pbar = tqdm(enumerate(self.test_dataset), total=len(self.test_dataset)) 193 | for i, data_item in pbar: 194 | output = self.model(data_item['sentence_cls'], attention_mask=data_item['mask_tokens']) 195 | logits = output[0] 196 | _, pred_rel = torch.max(logits.data, 1) 197 | self.model.train(True) 198 | # rel_pred = [[] for _ in range(len(pred_rel))] 199 | rel_pred = [] 200 | for i in range(len(pred_rel)): 201 | # for item in pred_rel[i]: 202 | rel_pred.append(self.id2rel[int(pred_rel[i])]) 203 | return rel_pred 204 | 205 | 206 | def get_embedding_pre(): 207 | # token2id = {} 208 | # with open('../data/vocab.txt', 'r', encoding='utf-8') as f: 209 | # cnt = 0 210 | # for line in f: 211 | # line = line.rstrip().split() 212 | # token2id[line[0]] = cnt 213 | # cnt += 1 214 | 215 | word2id = {} 216 | with codecs.open('../data/vec.txt', 'r', encoding='utf-8') as f: 217 | cnt = 0 218 | for line in f.readlines(): 219 | word2id[line.split()[0]] = cnt 220 | cnt += 1 221 | 222 | word2vec = {} 223 | with codecs.open('../data/vec.txt', 'r', encoding='utf-8') as f: 224 | for line in f.readlines(): 225 | word2vec[line.split()[0]] = list(map(eval, line.split()[1:])) 226 | unkown_pre = [] 227 | unkown_pre.extend([1]*100) 228 | embedding_pre = [] 229 | embedding_pre.append(unkown_pre) 230 | for word in word2id: 231 | if word in word2vec: 232 | embedding_pre.append(word2vec[word]) 233 | else: 234 | embedding_pre.append(unkown_pre) 235 | embedding_pre = np.array(embedding_pre) 236 | return embedding_pre 237 | 238 | 239 | if __name__ == '__main__': 240 | # neptune.init( 241 | # api_token='eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vdWkubmVwdHVuZS5haSIsImFwaV91cmwiOiJodHRwczovL3VpLm5lcHR1bmUuYWkiLCJhcGlfa2V5IjoiNTM3OTQzY2ItMzRhNC00YjYzLWJhMTktMzI0NTk4NmM4NDc3In0=', 242 | # project_qualified_name='mangopudding/EntityRelationExtraction') 243 | # neptune.create_experiment('rel_train') 244 | print("Run EntityRelationExtraction REL BERT ...") 245 | config = ConfigRel() 246 | model = BertForSequenceClassification.from_pretrained('../bert-base-chinese', num_labels=config.num_relations) 247 | data_processor = DataPreparationRel(config) 248 | train_loader, dev_loader, test_loader = data_processor.get_train_dev_data( 249 | '../data/train_data_small.json', 250 | '../data/dev_small.json', 251 | '../data/predict.json') 252 | # train_loader, dev_loader, test_loader = data_processor.get_train_dev_data('../data/train_data_small.json') 253 | trainer = Trainer(model, config, train_loader, dev_loader, test_loader) 254 | trainer.bert_train() 255 | 256 | # if __name__ == '__main__': 257 | # # PATH_NER = '../models/sequence_labeling/60m-f589.90n40236.67ccks2019_ner.pth' 258 | # # ner_model_dict = torch.load(PATH_NER) 259 | # 260 | # print("Run EntityRelationExtraction REL ...") 261 | # config = ConfigRel() 262 | # embedding_pre = get_embedding_pre() 263 | # # embedding_pre = ner_model_dict['state_dict']['word_embedding.weight'] 264 | # model = AttBiLSTM(config, embedding_pre) 265 | # data_processor = DataPreparationRel(config) 266 | # train_loader, dev_loader, test_loader = data_processor.get_train_dev_data( 267 | # '../data/train_data_small.json', 268 | # '../data/dev_small.json', 269 | # '../data/predict.json') 270 | # # train_loader, dev_loader, test_loader = data_processor.get_train_dev_data('../data/train_data_small.json') 271 | # trainer = Trainer(model, config, 
train_loader, dev_loader, test_loader) 272 | # trainer.train() -------------------------------------------------------------------------------- /data_loader/data_process.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # @Author : xmh 4 | # @Time : 2021/3/3 14:31 5 | # @File : data_process.py 6 | 7 | """ 8 | file description:: 9 | 10 | """ 11 | 12 | ''' 13 | 针对spo_list的主客体在原文的token进行标注,第一个字标注B-type,后面的字标注I-type,文本中其他词标注为O 14 | (先将所有文本标注为O,然后根据spo_list的内容,将对应位置覆盖) 15 | 16 | ''' 17 | import json 18 | import torch 19 | import copy 20 | from utils.config import Config, USE_CUDA 21 | import numpy as np 22 | 23 | 24 | class ModelDataPreparation: 25 | def __init__(self, config): 26 | self.config = config 27 | self.get_type_rel2id() 28 | 29 | def subject_object_labeling(self, spo_list, text_tokened): 30 | # 在列表 k 中确定列表 q 的位置 31 | def _index_q_list_in_k_list(q_list, k_list): 32 | """Known q_list in k_list, find index(first time) of q_list in k_list""" 33 | q_list_length = len(q_list) 34 | k_list_length = len(k_list) 35 | for idx in range(k_list_length - q_list_length + 1): 36 | t = [q == k for q, k in zip(q_list, k_list[idx: idx + q_list_length])] 37 | # print(idx, t) 38 | if all(t): 39 | # print(idx) 40 | idx_start = idx 41 | return idx_start 42 | 43 | # 给主体和客体表上BIO分割式类型标签 44 | def _labeling_type(subject_object, so_type): 45 | so_tokened = [c for c in subject_object] 46 | so_tokened_length = len(so_tokened) 47 | idx_start = _index_q_list_in_k_list(q_list=so_tokened, k_list=text_tokened) 48 | if idx_start is None: 49 | tokener_error_flag = True 50 | ''' 51 | 实体: "1981年" 原句: "●1981年2月27日,中国人口学会成立" 52 | so_tokened ['1981', '年'] text_tokened ['●', '##19', '##81', '年', '2', '月', '27', '日', ',', '中', '国', '人', '口', '学', '会', '成', '立'] 53 | so_tokened 无法在 text_tokened 找到!原因是bert_tokenizer.tokenize 分词增添 “##” 所致! 
54 | ''' 55 | else: # 给实体开始处标 B 其它位置标 I 56 | labeling_list[idx_start] = "B-" + so_type 57 | if so_tokened_length == 2: 58 | labeling_list[idx_start + 1] = "I-" + so_type 59 | elif so_tokened_length >= 3: 60 | labeling_list[idx_start + 1: idx_start + so_tokened_length] = ["I-" + so_type] * ( 61 | so_tokened_length - 1) 62 | return idx_start 63 | 64 | labeling_list = ["O" for _ in range(len(text_tokened))] 65 | predicate_value_list = [[] for _ in range(len(text_tokened))] 66 | predicate_location_list = [[] for _ in range(len(text_tokened))] 67 | have_error = False 68 | for spo_item in spo_list: 69 | subject = spo_item["subject"] 70 | subject_type = spo_item["subject_type"] 71 | object = spo_item["object"] 72 | subject, object = map(self.get_rid_unkonwn_word, (subject, object)) 73 | subject = list(map(lambda x: x.lower(), subject)) 74 | object = list(map(lambda x: x.lower(), object)) 75 | object_type = spo_item["object_type"] 76 | predicate_value = spo_item["predicate"] 77 | subject_idx_start = _labeling_type(subject, subject_type) 78 | object_idx_start = _labeling_type(object, object_type) 79 | if subject_idx_start is None or object_idx_start is None: 80 | have_error = True 81 | return labeling_list, predicate_value_list, predicate_location_list, have_error 82 | predicate_value_list[subject_idx_start].append(predicate_value) 83 | predicate_location_list[subject_idx_start].append(object_idx_start) 84 | # 数据集中主体和客体是颠倒的,数据集中的主体是唯一的,这里颠倒一下,这样每行最多只有一个关系 85 | # predicate_value_list[object_idx_start].append(predicate_value) 86 | # predicate_location_list[object_idx_start].append(subject_idx_start) 87 | 88 | # 把 predicate_value_list和predicate_location_list空余位置填充满 89 | for idx in range(len(text_tokened)): 90 | if len(predicate_value_list[idx]) == 0: 91 | predicate_value_list[idx].append("N") # 没有关系的位置,用“N”填充 92 | if len(predicate_location_list[idx]) == 0: 93 | predicate_location_list[idx].append(idx) # 没有关系的位置,用自身的序号填充 94 | 95 | return labeling_list, predicate_value_list, predicate_location_list, have_error 96 | 97 | def get_rid_unkonwn_word(self, text): 98 | text_rid = [] 99 | for token in text: # 删除不在vocab里面的词汇 100 | if token in self.token2id.keys(): 101 | text_rid.append(token) 102 | return text_rid 103 | 104 | def get_type_rel2id(self): 105 | self.token_type2id = {} 106 | for i, token_type in enumerate(self.config.token_types): 107 | self.token_type2id[token_type] = i 108 | 109 | self.rel2id = {} 110 | for i, rel in enumerate(self.config.relations): 111 | self.rel2id[rel] = i 112 | # with open('token_type2id.json', 'w', encoding='utf-8') as f: 113 | # json.dump(self.token_type2id, f, ensure_ascii=False) 114 | # with open('rel2id.json', 'w', encoding='utf-8') as f: 115 | # json.dump(self.rel2id, f, ensure_ascii=False) 116 | self.token2id = {} 117 | with open(self.config.vocab_file, 'r', encoding='utf-8') as f: 118 | cnt = 0 119 | for line in f: 120 | line = line.rstrip().split() 121 | self.token2id[line[0]] = cnt 122 | cnt += 1 123 | self.token2id[' '] = cnt 124 | 125 | def get_data(self, file_path, is_test=False): 126 | data = [] 127 | cnt = 0 128 | with open(file_path, 'r', encoding='utf-8') as f: 129 | for line in f: 130 | cnt += 1 131 | if cnt > self.config.num_sample: 132 | break 133 | data_item = json.loads(line) 134 | if not is_test: 135 | spo_list = data_item['spo_list'] 136 | else: 137 | spo_list = [] 138 | text = data_item['text'] 139 | text_tokened = [c.lower() for c in text] # 中文使用简单的分词 140 | token_type_list, predict_rel_list, predict_location_list, token_type_origin = None, None, 
None, None 141 | 142 | text_tokened = self.get_rid_unkonwn_word(text_tokened) 143 | if not is_test: 144 | token_type_list, predict_rel_list, predict_location_list, have_error = self.subject_object_labeling( 145 | spo_list=spo_list, text_tokened=text_tokened 146 | ) 147 | token_type_origin = token_type_list # 保存没有数值化前的token_type 148 | if have_error: 149 | continue 150 | item = {'text_tokened': text_tokened, 'token_type_list': token_type_list, 151 | 'predict_rel_list': predict_rel_list, 'predict_location_list': predict_location_list} 152 | # print(self.token2id[' ']) 153 | item['text_tokened'] = [self.token2id[x] for x in item['text_tokened']] 154 | if not is_test: 155 | item['token_type_list'] = [self.token_type2id[x] for x in item['token_type_list']] 156 | # item['predict_rel_list'] = [self.rel2id[x] for x in item['predict_rel_list']] 157 | predict_rel_id_tmp = [] 158 | for x in item['predict_rel_list']: 159 | rel_tmp = [] 160 | for y in x: 161 | rel_tmp.append(self.rel2id[y]) 162 | predict_rel_id_tmp.append(rel_tmp) 163 | item['predict_rel_list'] = predict_rel_id_tmp 164 | # item['rel_predct_matrix' ] = self._get_multiple_predicate_matrix(item['predict_rel_list'], 165 | # item['predict_location_list'], 166 | # ) 167 | item['text'] = ''.join(text_tokened) # 保存消除异常词汇的文本 168 | item['spo_list'] = data_item['spo_list'] 169 | item['token_type_origin'] = token_type_origin 170 | data.append(item) 171 | # data.append(data_item['text']) 172 | # data.append(data_item['spo_list']) 173 | # print(len(data)) 174 | dataset = Dataset(data) 175 | if is_test: 176 | dataset.is_test = True 177 | data_loader = torch.utils.data.DataLoader( 178 | dataset=dataset, 179 | batch_size=self.config.batch_size, 180 | collate_fn=dataset.collate_fn, 181 | drop_last=True 182 | ) 183 | return data_loader 184 | 185 | def get_train_dev_data(self, path_train, path_dev=None, path_test=None): 186 | train_loader, dev_loader, test_loader = None, None, None 187 | if path_train is not None: 188 | train_loader = self.get_data(path_train) 189 | if path_dev is not None: 190 | dev_loader = self.get_data(path_dev) 191 | if path_test is not None: 192 | test_loader = self.get_data(path_test, is_test=True) 193 | 194 | return train_loader, dev_loader, test_loader 195 | 196 | 197 | class Dataset(torch.utils.data.Dataset): 198 | def __init__(self, data): 199 | self.data = copy.deepcopy(data) 200 | self.is_test = False 201 | 202 | def __getitem__(self, index): 203 | text_tokened = self.data[index]['text_tokened'] 204 | token_type_list = self.data[index]['token_type_list'] 205 | predict_rel_list = self.data[index]['predict_rel_list'] 206 | predict_location_list = self.data[index]['predict_location_list'] 207 | 208 | data_info = {} 209 | for key in self.data[0].keys(): 210 | # try: 211 | # data_info[key] = locals()[key] 212 | # except KeyError: 213 | # print('{} cannot be found in locals()'.format(key)) 214 | if key in locals(): 215 | data_info[key] = locals()[key] 216 | 217 | data_info['text'] = self.data[index]['text'] 218 | data_info['spo_list'] = self.data[index]['spo_list'] 219 | data_info['token_type_origin'] = self.data[index]['token_type_origin'] 220 | return data_info 221 | 222 | def __len__(self): 223 | return len(self.data) 224 | 225 | def _get_multiple_predicate_matrix(self, predict_rel_list_batch, predict_location_list_batch, max_seq_length): 226 | batch_size = len(predict_rel_list_batch) 227 | predict_rel_matrix = torch.zeros((batch_size, max_seq_length, max_seq_length), dtype=torch.int64) 228 | for i, predict_rel_list in 
enumerate(predict_rel_list_batch): 229 | for xi, predict_rels in enumerate(predict_rel_list): 230 | if 0 in predict_rels: # 0 代表是 关系 N,就是没有关系 231 | continue 232 | for xj, predict_rel in enumerate(predict_rels): 233 | object_loc = predict_location_list_batch[i][xi][xj] 234 | predict_rel_matrix[i][xi][object_loc] = predict_rel 235 | 236 | return predict_rel_matrix 237 | 238 | def collate_fn(self, data_batch): 239 | 240 | def merge(sequences, is_two=False): 241 | lengths = [len(seq) for seq in sequences] 242 | max_length = max(lengths) 243 | # padded_seqs = torch.zeros(len(sequences), max_length) 244 | if is_two: 245 | max_len = 0 246 | for i, seq in enumerate(sequences): 247 | for x in seq: 248 | max_len = max(max_len, len(x)) 249 | padded_seqs = torch.zeros(len(sequences), max_length, max_len) 250 | mask_tokens = None 251 | else: 252 | padded_seqs = torch.zeros(len(sequences), max_length) 253 | tmp_pad = torch.ones(1, max_length) 254 | mask_tokens = torch.zeros(len(sequences), max_length) 255 | for i, seq in enumerate(sequences): 256 | end = lengths[i] 257 | # seq = np.array(seq) 258 | # seq = seq.astype(float) 259 | if is_two: 260 | # max_len = 0 261 | # for x in seq: 262 | # max_len = max(max_len, len(x)) 263 | # padded_seqs = torch.zeros(len(sequences), max_length, max_len) 264 | for j, x in enumerate(seq): 265 | lenx = len(x) 266 | padded_seqs[i, j, :lenx] = torch.Tensor(x)[:lenx] 267 | 268 | else: 269 | 270 | # padded_seqs = torch.zeros(len(sequences), max_length) 271 | seq = torch.LongTensor(seq) 272 | if len(seq) != 0: 273 | padded_seqs[i, :end] = seq[:end] 274 | mask_tokens[i, :end] = tmp_pad[0, :end] 275 | 276 | # seq = torch.from_numpy(seq) 277 | # if len(seq) != 0: 278 | # padded_seqs[i, :end, :] = seq 279 | return padded_seqs, mask_tokens 280 | item_info = {} 281 | for key in data_batch[0].keys(): 282 | item_info[key] = [d[key] for d in data_batch] 283 | token_type_list, predict_rel_list, pred_rel_matrix, predict_location_list = None, None, None, None 284 | text_tokened, mask_tokens = merge(item_info['text_tokened']) 285 | if not self.is_test: 286 | token_type_list, _ = merge(item_info['token_type_list']) 287 | predict_rel_list, _ = merge(item_info['predict_rel_list'], is_two=True) 288 | predict_location_list, _ = merge(item_info['predict_location_list'], is_two=True) 289 | max_seq_length = max([len(x) for x in text_tokened]) 290 | pred_rel_matrix = self._get_multiple_predicate_matrix(item_info['predict_rel_list'], 291 | item_info['predict_location_list'], 292 | max_seq_length) 293 | # tmp = np.array(pred_rel_matrix) 294 | # np.savetxt('rel_matrix.txt', pred_rel_matrix[0]) 295 | # convert to contiguous and cuda 296 | if USE_CUDA: 297 | text_tokened = text_tokened.contiguous().cuda() 298 | mask_tokens = mask_tokens.contiguous().cuda() 299 | else: 300 | text_tokened = text_tokened.contiguous() 301 | mask_tokens = mask_tokens.contiguous() 302 | 303 | if not self.is_test: 304 | if USE_CUDA: 305 | token_type_list = token_type_list.contiguous().cuda() 306 | predict_rel_list = predict_rel_list.contiguous().cuda() 307 | predict_location_list = predict_location_list.contiguous().cuda() 308 | pred_rel_matrix = pred_rel_matrix.contiguous().cuda() 309 | 310 | else: 311 | token_type_list = token_type_list.contiguous() 312 | predict_rel_list = predict_rel_list.contiguous() 313 | predict_location_list = predict_location_list.contiguous() 314 | pred_rel_matrix = pred_rel_matrix.contiguous() 315 | 316 | data_info = {'pred_rel_matrix': pred_rel_matrix, "mask_tokens": mask_tokens.to(torch.uint8)} 
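# Shapes of the two tensors packed just above (as produced by merge() and _get_multiple_predicate_matrix()):
#   mask_tokens     : [batch_size, max_seq_len] tensor with 1 for real tokens and 0 for padding,
#                     cast to uint8 in the line above so it can be used directly as a boolean mask.
#   pred_rel_matrix : [batch_size, max_seq_len, max_seq_len] int64 tensor; cell (b, i, j) holds the id of
#                     the relation whose subject starts at token i and whose object starts at token j in
#                     sample b, with 0 standing for the "N" (no-relation) label.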
317 | data_info['text'] = item_info['text'] 318 | data_info['spo_list'] = item_info['spo_list'] 319 | data_info['token_type_origin'] = item_info['token_type_origin'] 320 | for key in item_info.keys(): 321 | # try: 322 | # data_info[key] = locals()[key] 323 | # except KeyError: 324 | # print('{} cannot be found in locals()'.format(key)) 325 | if key in locals(): 326 | data_info[key] = locals()[key] 327 | 328 | return data_info 329 | 330 | # def _cuda(x): 331 | # if torch.cuda.is_available(): 332 | # return x.cuda() 333 | # else: 334 | # return x 335 | 336 | 337 | if __name__ == '__main__': 338 | config = Config() 339 | process = ModelDataPreparation(config) 340 | train_loader, dev_loader, test_loader = process.get_train_dev_data('../data/small.json') 341 | # train_loader, dev_loader, test_loader = process.get_train_dev_data('../data/train_data_small.json') 342 | print(train_loader) 343 | for item in train_loader: # sanity check: print every non-zero cell of the relation matrix 344 | # print(type(item['token_type_list'][0])) 345 | # print(item['token_type_list'][0].shape) 346 | # 347 | # print(item['predict_rel_list'][0]) 348 | # print(type(item['predict_rel_list'][0][0])) 349 | # print(item['predict_rel_list'][0].shape) 350 | for i in range(item['pred_rel_matrix'].shape[0]): 351 | for j in range(item['pred_rel_matrix'].shape[1]): 352 | for k in range(item['pred_rel_matrix'].shape[2]): 353 | if item['pred_rel_matrix'][i, j, k] != 0: # the matrix holds int64 relation ids, so a plain non-zero test suffices 354 | print(i, j, k) 355 | print(item['pred_rel_matrix'][i, j, k]) 356 | print("pred_rel_matrix check finished") --------------------------------------------------------------------------------
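A note on the learning-rate schedule: both Trainer classes above create a torch.optim.lr_scheduler.ReduceLROnPlateau but never call scheduler.step(), so the learning rate stays at config.lr for the whole run. Below is a minimal sketch of how the dev loss returned by evaluate() could drive the scheduler; the wiring uses stand-in objects and is an illustration, not code from this repository.

import torch

# Stand-ins for the real model/config; only the scheduler wiring is the point here.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=8, min_lr=1e-5)  # same hyper-parameters as the trainers above

for epoch in range(20):
    # ... run the training batches and optimizer.step() as in Trainer.train() ...
    dev_loss = 1.0 / (epoch + 1)           # placeholder for the value returned by Trainer.evaluate()
    scheduler.step(dev_loss)               # ReduceLROnPlateau monitors this metric and halves the lr on plateaus
    print(epoch, optimizer.param_groups[0]['lr'])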