├── README.md └── tutor.ipynb /README.md: -------------------------------------------------------------------------------- 1 | # 基于BERT和机器阅读理解实现的嵌套命名实体识别任务的详细讲解 2 | 3 | 包括 Nested NER 任务的介绍和 MRC 框架的详细讲解, 4 | 以及 MRC 框架的代码实现(附注释)。 5 | 6 | 原论文: 7 | ``` 8 | @article{li2019unified, 9 | title={A Unified MRC Framework for Named Entity Recognition}, 10 | author={Li, Xiaoya and Feng, Jingrong and Meng, Yuxian and Han, Qinghong and Wu, Fei and Li, Jiwei}, 11 | journal={arXiv preprint arXiv:1910.11476}, 12 | year={2019} 13 | } 14 | ``` 15 | -------------------------------------------------------------------------------- /tutor.ipynb: -------------------------------------------------------------------------------- 1 | {"cells":[{"cell_type":"markdown","source":["# 一 任务描述及方法设计\n","\n","## 1 Flat & Nested NER 任务的区别\n","\n","嵌套命名实体识别(Nested NER)是NER中的一种特殊情况。\n","\n","解决Flat NER(常规NER)任务的方法通常是序列标注,\n","即让模型给每一个token打上标签,\n","常用的打标签方法有两种(BIO和BIOES),\n","例如使用BIOES给“我住在广州市海珠区。”这句话中的每一个token打上标签就是:\n","\n","> Example 1: \n","> 我住在广州市海珠区。 \n","> O, O, O, B-Loc,I-Loc,E-Loc,B-Loc,I-Loc,E-Loc, O。 \n","\n","其中O表示“该token不属于实体”,B表示“实体的开头token”,I表示“实体的中间token”,E表示“实体的结尾token”,S表示“只有一个token的实体”,Loc表示实体的标签类型为地名。 \n","\n","那 Nested NER 任务是什么?\n","以我们接下来要使用的ACE2004数据集中的一个句子为例(暂时找不到用于验证Nested NER任务的中文数据集,MSRA-ch和onto04-ch都是不存在嵌套实体的):\n","\n","> Example 2: \n","> The Chinese government and the Australian government signed an agreement today , wherein the Australian party would provide China with a preferential financial loan of 150 million Australian dollars . \n","\n","> [[0, 2], [1, 1], [4, 6], [5, 5], [13, 15], [14, 14], [18, 18], [27, 27]]\n","\n","其中3处地方存在嵌套实体:\n","> 实体\"The Chinese government\"\\[0,2\\]嵌套了实体\"Chinese\"\\[1,1\\]. \n","> 实体\"The Australian government\"\\[4,6\\]嵌套了实体\"Australian\"\\[5,5\\]. \n","> 实体\"The Australian party\"\\[13,15\\]嵌套了实体\"Australian\"\\[14,14\\]. \n","\n","对于这种情况,\n","如果继续使用BIOES等方法让序列标注模型为句子打上标签,\n","那么任务就从单标签任务变成了多标签任务,\n","如\"Chinese\"同时具有\"S-GPE\"和\"I-GPE\"两种标签,\n","但相关工作表明,\n","这种处理方法比较麻烦且效果也不理想。\n","\n","所以在2020年,[基于MRC(机器阅读理解)思想来解决Nested NER任务的框架](https://arxiv.org/pdf/1910.11476.pdf)被提出。\n","\n","## 2 基于 MRC 框架解决 Nest NER 任务\n","\n","MRC 是指给模型一段文本,然后指定一个问题,让模型在文本中找出该问题的答案。\n","比如在 Nested NER 中,对于`Example 2`的这句话,可以设置以下的问题:\n","\n","> Input1: Find an organization such as company, agency and insititution in the context. \n","\n","模型返回的是实体所在位置(span),如下所示。\n","\n","> Output2: [0, 2], [4, 6], [13, 15] \n","\n","接着对于其它实体,可以再问:\n","\n","> Input2: Find an country... \n","> Output2: [1, 1], [5, 5], [14, 14], [18, 18], [27, 27] \n","\n","所以 Nested NER 任务就转换成了多轮次的问答任务(A multi-turn QA task)。\n","\n","## 3 模型设计\n","\n","所以我们现在的目标是设计这样一个模型:**输入**一个问题和一段文本,\n","**返回**实体在文本中的具体位置。\n","\n","基于 MRC 的数据格式与常规 NER 的数据格式不太一样,\n","从常规 NER 转换到 MRT 的格式需要一些特殊处理\n","(在下面的代码实现中将会看到 MRC 格式数据的具体样子)。\n","\n","简单来说,每一个样本需要转换成如下形式的三元组表示:\n","\n","$$\n","(\\mathrm{QUESTION},\\mathrm{ANSWER},\\mathrm{CONTEXT})\n","$$\n","\n","$\\mathrm{QUESTION}$ 就是将原来的token标签$y \\in Y$转换成对应的自然语言问题描述$q_y = \\{q_1, q_2, \\cdots, q_m\\}$,如上述的*Input1*一样。\n","\n","$\\mathrm{CONTEXT}$ 是文本描述,记为 $ X = \\{x_1, x_2, \\cdots, x_n \\} $。\n","\n","$\\mathrm{ANSWER}$ 是实体所在位置,记为 $x_{start,end}=\\{x_{start}, x_{start+1}, \\cdots, x_{end_1}, x_{end}\\}$,实际上是$X$的一个子串。\n","\n","因此该单个样本的三元组表示为:\n","\n","$$(q_y, x_{start,end}, X)$$\n","\n","其中 $q_y$ 和 $X$ 作为模型的输入,$x_{start,end}$ 作为模型的输出。\n","\n","因为模型是基于 BERT 来实现,所以输入 BERT 的 token 是:\n","\n","$$[CLS], q_1, q_2, \\cdots, q_m, [SEP], x_1, x_2, \\cdots, x_n, [SEP]$$\n","\n","经过 BERT 得到词嵌入 $E \\in \\mathbb{R}^{n \\times d}$ 之后($d$ 是词嵌入维度),\n","需要将所有的词嵌入丢进 3 个 token 的二分类器中,分别计算实体的开始位置(start index)、结束位置(end index)和它们的位置匹配(span match)。\n","\n","### 3.1 确定所有实体的开始和结束位置\n","\n","实体的开始和结束位置各使用一个二分器来确定,它们接收token的嵌入表示,然后判断该token是否是实体的开始(结束)位置。\n","\n","判断 token 是否是开始位置的二分器的形式化描述如下:\n","$$\n","P_{start} = \\mathrm{softmax}_{each row} (E T_{start}) \\in \\mathbf{R}^{n \\times 2}\n","$$\n","\n","判断 token 是否是结束位置的二分器的形式化描述如下:\n","$$\n","P_{end} = \\mathrm{softmax}_{each row} (E T_{end}) \\in \\mathbf{R}^{n \\times 2}\n","$$\n","\n","其中$T_{start}, T_{end} \\in \\mathbb{R}^{d \\times 2}$是可学习的参数矩阵。\n","$P_{start}^i, P_{end}^j \\in \\mathbb{R}^{1 \\times 2}$ 分别表示第 $i(j)$ 个 token 是 start(end) index 和不是 start(end) index 的概率。\n","\n","### 3.2 匹配对应的开始和结束位置\n","\n","因为存在嵌套实体,因此不能按照先后顺序去匹配得到的开始和结束位置。\n","\n","所以使用第3个二分类器来判断所选的$E_{i_{start}}$和$E_{j_{end}}$的匹配概率:\n","\n","$$P_{i_{start}, j_{end}} = \\mathrm{sigmoid}(\\mathbf{m} [E_{i_{start}} \\Vert E_{j_{end}}])$$\n","其中$m \\in \\mathbb{R}^{1 \\times 2d}$。\n","\n","**注意**:\n","\n","在实际的代码实现中,$T_{start}, T_{end}, \\in \\mathbb{R}^{d \\times 1}$,\n","\n","$m$ 是两层的 MLP 实现。\n","\n","### 3.3 计算损失函数\n","\n","最终的损失值是3个二分器各自损失值的加权累加,形式化描述如下:\n","\n","$$\n","\\mathcal{L}_{start} = \\mathrm{CE}(P_{start}, Y_{start})\n","$$\n","\n","$$\n","\\mathcal{L}_{end} = \\mathrm{CE}(P_{end}, Y_{end})\n","$$\n","\n","$$\n","\\mathcal{L}_{span} = \\mathrm{CE}(P_{start, end}, Y_{start, end})\n","$$\n","$$\n","\\mathcal{L} = \\alpha \\mathcal{L}_{start} + \\beta \\mathcal{L}_{end} + \\gamma \\mathcal{L}_{span}\n","$$\n","\n","其中$\\alpha,\\beta,\\gamma$是超参数,原论文均设为1。\n","\n","最终损失值的计算在代码实现中会更复杂,因为要考虑负样本和正样本的数量不平衡问题。"],"metadata":{"id":"PSklQvcW1Fzu"}},{"cell_type":"markdown","source":["# 二 代码实现"],"metadata":{"id":"vliL5HAJ1Fz5"}},{"cell_type":"code","source":["from google.colab import drive\n","drive.mount('/content/drive')\n","\n","! pip install transformers"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":0},"id":"PoFq9HeV9Sra","executionInfo":{"status":"ok","timestamp":1646996737735,"user_tz":-480,"elapsed":33215,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}},"outputId":"c1d7a31b-9270-4732-d1e2-7c007d54d679"},"execution_count":1,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/drive\n","Collecting transformers\n"," Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)\n","\u001b[K |████████████████████████████████| 3.8 MB 4.2 MB/s \n","\u001b[?25hCollecting sacremoses\n"," Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)\n","\u001b[K |████████████████████████████████| 895 kB 42.5 MB/s \n","\u001b[?25hRequirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.7/dist-packages (from transformers) (21.3)\n","Collecting pyyaml>=5.1\n"," Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)\n","\u001b[K |████████████████████████████████| 596 kB 46.8 MB/s \n","\u001b[?25hRequirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from transformers) (2.23.0)\n","Requirement already satisfied: importlib-metadata in /usr/local/lib/python3.7/dist-packages (from transformers) (4.11.2)\n","Collecting tokenizers!=0.11.3,>=0.11.1\n"," Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)\n","\u001b[K |████████████████████████████████| 6.5 MB 39.1 MB/s \n","\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.7/dist-packages (from transformers) (4.63.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (2019.12.20)\n","Collecting huggingface-hub<1.0,>=0.1.0\n"," Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)\n","\u001b[K |████████████████████████████████| 67 kB 5.1 MB/s \n","\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from transformers) (1.21.5)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.7/dist-packages (from transformers) (3.6.0)\n","Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.7/dist-packages (from huggingface-hub<1.0,>=0.1.0->transformers) (3.10.0.2)\n","Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.7/dist-packages (from packaging>=20.0->transformers) (3.0.7)\n","Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.7/dist-packages (from importlib-metadata->transformers) (3.7.0)\n","Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2.10)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (2021.10.8)\n","Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (1.24.3)\n","Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->transformers) (3.0.4)\n","Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (7.1.2)\n","Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.1.0)\n","Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from sacremoses->transformers) (1.15.0)\n","Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers\n"," Attempting uninstall: pyyaml\n"," Found existing installation: PyYAML 3.13\n"," Uninstalling PyYAML-3.13:\n"," Successfully uninstalled PyYAML-3.13\n","Successfully installed huggingface-hub-0.4.0 pyyaml-6.0 sacremoses-0.0.47 tokenizers-0.11.6 transformers-4.17.0\n"]}]},{"cell_type":"code","execution_count":2,"source":["import json\n","import torch\n","import os\n","from tokenizers import BertWordPieceTokenizer\n","from torch.utils.data import Dataset, DataLoader\n","from typing import List\n","import numpy as np\n","import torch.nn as nn\n","from torch.nn import functional as F\n","from transformers import BertModel, BertPreTrainedModel, BertConfig, AdamW\n","from torch.nn.modules import BCEWithLogitsLoss"],"outputs":[],"metadata":{"id":"4EFfgZmo1Fz6","executionInfo":{"status":"ok","timestamp":1646996765745,"user_tz":-480,"elapsed":8794,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}}}},{"cell_type":"markdown","source":["## 1 数据集介绍\n","\n","从原论文中下载了 6 个数据集,\n","其中ACE2004、ACE2005和CONII03存在重叠实体,\n","GENIA、MSRA-ch和ONTO4-ch不存在重叠实体,\n","它们均已经被处理成MRC-NER的格式,\n","各数据集MRC格式的下载链接如下:\n","\n","- [MSRA-ch](https://drive.google.com/file/d/1bAoSJfT1IBdpbQWSrZPjQPPbAsDGlN2D/view?usp=sharing)\n","- [ONTO4-ch](https://drive.google.com/file/d/1CRVgZJDDGuj0O1NLK5DgujQBTLKyMR-g/view?usp=sharing)\n","- [CONLL03](https://drive.google.com/file/d/1mGO9CYkgXsV-Et-hSZpOmS0m9G8A5mau/view?usp=sharing)\n","- [ACE2004](https://drive.google.com/file/d/1U-hGOgLmdqudsRdKIGles1-QrNJ7SSg6/view?usp=sharing)\n","- [ACE2005](https://drive.google.com/file/d/1iodaJ92dTAjUWnkMyYm8aLEi5hj3cseY/view?usp=sharing)\n","- [GENIA](https://drive.google.com/file/d/1oF1P8s-0MN9X1M1PlKB2c5aBtxhmoxXb/view?usp=sharing)\n","\n","因为暂时找不到存在嵌套实体的中文数据集,\n","所以接下来以英文数据集 ACE2004 作为例子。\n","\n","---\n","\n","首先来看一下该数据集具体长什么样子。"],"metadata":{"id":"a05F1Ht_1Fz7"}},{"cell_type":"code","execution_count":3,"source":["project_path = '/content/drive/MyDrive/mrc-bert-ner/'\n","data_name = ['/ACE2004', '/ACE2005', '/CONII03', '/GENIA', '/MSRA-ch', '/ONTO4-ch', '/debug'][0] # 选择所要验证的数据集\n","file_path = [f'{project_path}data{data_name}/mrc-ner.train',f'{project_path}data{data_name}/mrc-ner.dev', f'{project_path}data{data_name}/mrc-ner.test']\n","with open(file_path[0]) as f: data = json.load(f)\n","data[7]"],"outputs":[{"output_type":"execute_result","data":{"text/plain":["{'context': 'The Chinese government and the Australian government signed an agreement today , wherein the Australian party would provide China with a preferential financial loan of 150 million Australian dollars .',\n"," 'end_position': [2, 1, 6, 5, 15, 14, 18, 27],\n"," 'entity_label': 'GPE',\n"," 'impossible': False,\n"," 'qas_id': '1.1',\n"," 'query': 'geographical political entities are geographical regions defined by political and or social groups such as countries, nations, regions, cities, states, government and its people.',\n"," 'span_position': ['0;2',\n"," '1;1',\n"," '4;6',\n"," '5;5',\n"," '13;15',\n"," '14;14',\n"," '18;18',\n"," '27;27'],\n"," 'start_position': [0, 1, 4, 5, 13, 14, 18, 27]}"]},"metadata":{},"execution_count":3}],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":0},"id":"vvLiwi7T1Fz9","executionInfo":{"status":"ok","timestamp":1646996771645,"user_tz":-480,"elapsed":1991,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}},"outputId":"82dd8542-8c6d-40fb-9dad-7a3d249a17f4"}},{"cell_type":"markdown","source":["这是在Example2中的那句话,可以看到一个样本用字典来保存信息,其中各key的含义如下:\n","* context:句子的自然语言内容。\n","* end_position:各实体结束的位置。\n","* entity_label:实体的标签,这里的GPE指地理政治实体\n","* impossible:False代表该句子存在实体,True代表该句子不存在实体。\n","* qas_id:由于一条句子可以根据具体问题划分成多个样本,所以上面的1.1表示句子1中的第1个样本。\n","* query:指输入到模型的自然语言问题。\n","* span_position:具体实体所在的边界范围,比如'0;2'表示\"The Chinese government\"是一个GPE。\n","* start_position:各实体开始的位置。"],"metadata":{"id":"_0HmkexF1Fz_"}},{"cell_type":"code","execution_count":4,"source":["\"\"\"数据集中有哪些entity label,以及对应的query是什么,以及属于它们的样本的数量情况。\"\"\"\n","queries = {}\n","queries_count = {}\n","for d in data:\n"," if d['entity_label'] not in queries:\n"," queries[d['entity_label']] = d['query']\n"," queries_count[d['entity_label']] = 1\n"," if d['entity_label'] in queries_count:\n"," queries_count[d['entity_label']] += 1\n","print('entity label, count, query:')\n","for k in queries:\n"," print(k, queries_count[k], queries[k])"],"outputs":[{"output_type":"stream","name":"stdout","text":["entity label, count, query:\n","GPE 6203 geographical political entities are geographical regions defined by political and or social groups such as countries, nations, regions, cities, states, government and its people.\n","ORG 6203 organization entities are limited to companies, corporations, agencies, institutions and other groups of people.\n","PER 6203 a person entity is limited to human including a single individual or a group.\n","FAC 6203 facility entities are limited to buildings and other permanent man-made structures such as buildings, airports, highways, bridges.\n","VEH 6203 vehicle entities are physical devices primarily designed to move, carry, pull or push the transported object such as helicopters, trains, ship and motorcycles.\n","LOC 6203 location entities are limited to geographical entities such as geographical areas and landmasses, mountains, bodies of water, and geological formations.\n","WEA 6203 weapon entities are limited to physical devices such as instruments for physically harming such as guns, arms and gunpowder.\n"]}],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":0},"id":"x1gAOLJr1F0D","executionInfo":{"status":"ok","timestamp":1646996775506,"user_tz":-480,"elapsed":572,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}},"outputId":"ce8b0aa6-ebe8-4993-c929-0b253bc110b3"}},{"cell_type":"markdown","source":["可以看到ACE2004的训练集中一共有7种类型的NER(GPE、ORG、PER、FAC、VEH、LOC和WEA)。\n","且各类样本的总数均是6203个。\n","\n","---\n","\n","接下来统计整个数据集中存在实体的句子有多少条,以及存在重叠实体的句子又有多少条。"],"metadata":{"id":"T61QPXfLTsCT"}},{"cell_type":"code","execution_count":5,"source":["\"\"\"验证数据集中的重叠实体情况\"\"\"\n","# \"\"\"step1:把从训练集、验证集和测试集中读出的数据都累加到一起。\"\"\"\n","with open(file_path[0]) as f: data = json.load(f)\n","with open(file_path[1]) as f: data += json.load(f)\n","with open(file_path[2]) as f: data += json.load(f)\n","\n","# \"\"\"step2:遍历整个数据集,建立以context作为key,span_position作为value的字典,并且只有存在实体的句子才会被加入字典之中。\"\"\"\n","context2spans = {}\n","for d in data:\n"," tmp_context = d['context']\n"," tmp_span = d['span_position']\n"," if not tmp_span: # 如果没有实体就跳过 \n"," continue \n"," else: # 如果有实体,就将span的字符串形式转换为列表形式,例如将[\"1;3\", \"2;4\"]转换为[[1,3], [2,4]]。目的是为了方便排序之后检测重叠实体。\n"," tmp_span = [[int(s.split(';')[0]), int(s.split(';')[1])] for s in tmp_span] \n","\n"," # 因为在MRC的格式中,一条相同的句子根据Query的不同可能被分成多个样本,所以需要把指向同一条句子的样本都放到同一个key中。\n"," if tmp_context in context2spans: \n"," context2spans[tmp_context].extend(tmp_span)\n"," else: \n"," context2spans[tmp_context] = tmp_span\n","\n","# \"\"\"step3:检测重叠实体,方法是将列表形式化后的span排序,然后检查第i个span的start是否小于等于第i-1个span的end,如果是,则把相应的key和value存入nested_example字典中。\"\"\"\n","nested_example = {}\n","for k in context2spans:\n"," span = context2spans[k]\n"," span.sort()\n","\n"," for i in range(1, len(span)):\n"," if span[i-1][1] >= span[i][0]:\n"," nested_example[k] = span\n"," break\n","\n","print(f'数据集:{data_name[1:]}中有实体的句子有{len(context2spans)}条,但有嵌套实体的句子有{len(nested_example)}条,例如下面就是一些嵌套句子:')\n","\n","breakpoint = 5\n","for k in nested_example:\n"," print(k, '\\n' ,nested_example[k])\n"," breakpoint -= 1\n"," if breakpoint == 0:\n"," break"],"outputs":[{"output_type":"stream","name":"stdout","text":["数据集:ACE2004中有实体的句子有6933条,但有嵌套实体的句子有3408条,例如下面就是一些嵌套句子:\n","The Chinese government and the Australian government signed an agreement today , wherein the Australian party would provide China with a preferential financial loan of 150 million Australian dollars . \n"," [[0, 2], [1, 1], [4, 6], [5, 5], [13, 15], [14, 14], [18, 18], [27, 27]]\n","Lasting for two days , the ' 94 Development Assistance Cooperation Annual Meeting between the Chinese government and the Australian government concluded today in Melbourne . \n"," [[14, 16], [15, 15], [18, 20], [19, 19], [24, 24]]\n","The Chinese delegation with Yongtu Long , assistant minister of the Ministry of Foreign Economy and Trade , as the delegation leader , and the Australian delegation with Flad , director general of Australian International Development Bureau Assistance Department as the delegation leader , chaired the meeting . \n"," [[0, 21], [1, 1], [4, 5], [7, 16], [10, 16], [19, 21], [20, 20], [24, 42], [25, 25], [28, 28], [30, 38], [33, 38], [40, 42], [41, 41]]\n","At the same time , the Australia side will provide China with a technical cooperation grant of 20 million Australian dollars , which will be mainly used in projects such as personnel training , supporting the poor , medical treatment , sanitation , etc . \n"," [[5, 7], [6, 6], [10, 10], [19, 19], [31, 31], [35, 36]]\n","Xinhua News Agency , Shanghai , December 27 , by wire ( reporter Kangxiong Luo ) \n"," [[0, 2], [4, 4], [12, 12], [12, 14]]\n"]}],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":0},"id":"L5VMzfIY1F0E","executionInfo":{"status":"ok","timestamp":1646996780540,"user_tz":-480,"elapsed":2355,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}},"outputId":"4526c8a8-c4a5-4260-ea4a-4905a46b9de9"}},{"cell_type":"markdown","source":["## 2 数据预处理\n","\n","在模型训练开始之前,\n","需要将样本中的自然语言转换成模型可以计算的编码表示。\n","\n","`torch`提供了`DataLoader`来处理数据集,\n","但NLP的样本处理比较麻烦,\n","所以在将数据丢进`DataLoader`之前,\n","还需要对数据集做一些预处理,\n","具体的预处理过程如下面的`MRCNERDataset(Dataset)`类所示。\n","\n","`MRCNERDataset(Dataset)`类的主要作用是:\n","* 将原数据转换成 BERT 需要的输入格式,\n","* 将实体开始位置、结束位置和span match转换成向量表示。\n","\n","`MRCNERDataset(Dataset)`类的具体解释见代码注释。"],"metadata":{"id":"K-eVlcCy1F0G"}},{"cell_type":"code","execution_count":6,"source":["class MRCNERDataset(Dataset):\n"," \"\"\"\n"," Args:\n"," json_path: mrc-ner格式的json文件路径。\n"," tokenizer: Bert的分词器。\n"," max_length: 输入Bert的最大序列长度,指word-piece之后的query + context + 3的最大长度\n"," possible_only: 如果设置为True,则只使用存在答案的样本(即存在实体的样本)。\n"," is_chinese: 是否是中文数据集,中文数据集需要去掉在token之间的空格。\n"," pad_to_maxlen: 是否填充至最大长度。\n"," Returns:\n"," tokens: tokens of query + context, shape is [seq_len].\n"," token_type_ids: token type ids, 0 for query, 1 for context, shape is [seq_len].\n"," start_labels: start labels of NER in tokens, shape is [seq_len].\n"," end_labels: end labels of NER in tokens, shape is [seq_len].\n"," label_mask: label mask, 1 for counting into loss, 0 for ignoring. shape is [seq_len].\n"," match_labels: match labels, shape is [seq_len, seq_len]. match_labels[i][j]==1意味着从i到j存在实体。\n"," sample_idx: sample id\n"," label_idx: label id\n"," \"\"\"\n"," def __init__(self, json_path, tokenizer: BertWordPieceTokenizer, max_length: int = 512, possible_only=False, \n"," is_chinese=False, pad_to_maxlen=False):\n"," self.all_data = json.load(open(json_path, encoding='utf-8'))\n"," self.tokenizer = tokenizer\n"," self.max_length = max_length\n"," self.possible_only = possible_only\n"," if self.possible_only:\n"," self.all_data = [x for x in self.all_data if x['start_position']]\n"," self.is_chinese = is_chinese\n"," self.pad_to_maxlen = pad_to_maxlen\n","\n"," def __len__(self):\n"," return len(self.all_data)\n"," \n"," def __getitem__(self, item):\n"," \n"," tokenizer = self.tokenizer\n","\n"," # Step 1: 读取样本存储在字典中的相关值\n"," data = self.all_data[item]\n"," qas_id = data.get('qas_id', '0.0') \n"," sample_idx, label_idx = qas_id.split('.')\n"," sample_idx = torch.LongTensor([int(sample_idx)])\n"," label_idx = torch.LongTensor([int(label_idx)])\n"," query = data['query']\n"," context = data['context']\n"," start_positions = data['start_position']\n"," end_positions = data['end_position']\n","\n"," # Step 2: 区分处理中文和英文数据集,原因如下:\n"," # 中文数据集中context的字符是用空格隔开的(query不存在这个问题),如\"上 海 浦 东 开 发 与 法 制 建 设 同 步\"。\n"," # 并且中文所用到预训练模型会将空格也当做一个token处理,所以需要去掉中文context中的空格。\n"," # 另外,英文数据集的分词考虑了空格,而原来的开始位置和结束位置没有考虑空格,所以它们需要重新映射。\n"," if self.is_chinese:\n"," context == ''.join(context.split())\n"," end_positions = [x+1 for x in end_positions]\n"," else:\n"," words = context.split()\n"," start_positions = [x + sum([len(w) for w in words[:x]]) for x in start_positions]\n"," end_positions = [x + sum([len(w) for w in words[:x + 1]]) for x in end_positions]\n","\n"," # Step 3: 用 BERT 分词器处理 query + context。\n"," query_context_tokens = tokenizer.encode(query, context, add_special_tokens=True)\n"," tokens = query_context_tokens.ids\n"," type_ids = query_context_tokens.type_ids\n"," offsets = query_context_tokens.offsets\n"," \n"," # Step 4: 重新修正开始位置和结束位置,原因如下:\n"," # - 因为添加了query在前面。\n"," # - 英文的分词是根据词根词缀划分的,即word-piece tokenize。\n"," origin_offset2token_idx_start = {}\n"," origin_offset2token_idx_end = {}\n"," for token_idx in range(len(tokens)):\n"," if type_ids[token_idx] == 0: # 跳过query的token\n"," continue\n"," token_start, token_end = offsets[token_idx]\n"," if token_start == token_end == 0: # 跳过[CLS]和[SEP]等特殊token\n"," continue\n"," origin_offset2token_idx_start[token_start] = token_idx\n"," origin_offset2token_idx_end[token_end] = token_idx\n"," new_start_positions = [origin_offset2token_idx_start[start] for start in start_positions]\n"," new_end_positions = [origin_offset2token_idx_end[end] for end in end_positions]\n","\n"," # Step 5: 建立开始位置、结束位置的mask以及向量表示。\n"," label_mask = [\n"," (0 if type_ids[token_idx] == 0 or offsets[token_idx] == (0, 0) else 1)\n"," for token_idx in range(len(tokens))\n"," ]\n"," start_label_mask = label_mask.copy()\n"," end_label_mask = label_mask.copy()\n"," # 同样是由于work-piece的问题,英文的mask需要特殊处理(中文不需要,中文的start_label_mask和end_label_mask是一样的)。\n"," # 比如对于'xinhua'这个单词,它被work-piece分成'xi ##nh ##ua'三个token,\n"," # 其对应的 start_label_mask 和 end_label_mask 分别是 [1,0,0] 和 [0,0,1]。\n"," if not self.is_chinese: \n"," for token_idx in range(len(tokens)):\n"," current_word_idx = query_context_tokens.word_ids[token_idx]\n"," next_word_idx = query_context_tokens.word_ids[token_idx+1] if token_idx+1 < len(tokens) else None\n"," prev_word_idx = query_context_tokens.word_ids[token_idx-1] if token_idx-1 > 0 else None\n"," if prev_word_idx is not None and current_word_idx == prev_word_idx:\n"," start_label_mask[token_idx] = 0\n"," if next_word_idx is not None and current_word_idx == next_word_idx:\n"," end_label_mask[token_idx] = 0\n"," assert all(start_label_mask[p] != 0 for p in new_start_positions)\n"," assert all(end_label_mask[p] != 0 for p in new_end_positions)\n"," assert len(new_start_positions) == len(new_end_positions) == len(start_positions)\n"," assert len(label_mask) == len(tokens)\n"," start_labels = [(1 if idx in new_start_positions else 0) for idx in range(len(tokens))] # 开始位置的向量表示\n"," end_labels = [(1 if idx in new_end_positions else 0) for idx in range(len(tokens))] # 结束位置的向量表示\n","\n"," # Step 6: 按照句子最大长度截断(如果超出)、并保证最后一个token是[SEP]。\n"," tokens = tokens[: self.max_length]\n"," type_ids = type_ids[: self.max_length]\n"," start_labels = start_labels[: self.max_length]\n"," end_labels = end_labels[: self.max_length]\n"," start_label_mask = start_label_mask[: self.max_length]\n"," end_label_mask = end_label_mask[: self.max_length]\n"," sep_token = tokenizer.token_to_id('[SEP]')\n"," if tokens[-1] != sep_token:\n"," assert len(tokens) == self.max_length\n"," tokens = tokens[: -1] + [sep_token]\n"," start_labels[-1] = 0\n"," end_labels[-1] = 0\n"," start_label_mask[-1] = 0\n"," end_label_mask[-1] = 0\n","\n"," # Step 7: 填充句子\n"," if self.pad_to_maxlen:\n"," tokens = self.pad(tokens, 0)\n"," type_ids = self.pad(type_ids, 1)\n"," start_labels = self.pad(start_labels)\n"," end_labels = self.pad(end_labels)\n"," start_label_mask = self.pad(start_label_mask)\n"," end_label_mask = self.pad(end_label_mask)\n"," \n"," # Step 8: 生成 span match 矩阵\n"," seq_len = len(tokens)\n"," match_labels = torch.zeros([seq_len, seq_len], dtype=torch.long)\n"," for start, end in zip(new_start_positions, new_end_positions):\n"," if start >= seq_len or end >= seq_len:\n"," continue\n"," match_labels[start, end] = 1\n"," \n"," return [\n"," torch.LongTensor(tokens),\n"," torch.LongTensor(type_ids),\n"," torch.LongTensor(start_labels),\n"," torch.LongTensor(end_labels),\n"," torch.LongTensor(start_label_mask),\n"," torch.LongTensor(end_label_mask),\n"," match_labels,\n"," sample_idx,\n"," label_idx\n"," ]\n"," \n"," def pad(self, lst, value=0, max_length=None):\n"," max_length = max_length or self.max_length\n"," while len(lst) < max_length:\n"," lst.append(value)\n"," return lst\n","\n","def collate_to_max_length(batch: List[List[torch.Tensor]]) -> List[torch.Tensor]:\n"," \"\"\"\n"," 若当前batch中样本的最大句子长度为batch_max_seq_length < max_seq_length,\n"," 则当前batch的其余句子均填充到batch_max_seq_length即可。\n","\n"," pad to maximum length of this batch\n"," Args:\n"," batch: a batch of samples, each contains a list of field data(Tensor):\n"," tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx\n"," Returns:\n"," output: list of field batched data, which shape is [batch, max_length]\n"," \"\"\"\n"," batch_size = len(batch)\n"," max_length = max(x[0].shape[0] for x in batch)\n"," output = []\n","\n"," for field_idx in range(6):\n"," pad_output = torch.full([batch_size, max_length], 0, dtype=batch[0][field_idx].dtype)\n"," for sample_idx in range(batch_size):\n"," data = batch[sample_idx][field_idx]\n"," pad_output[sample_idx][: data.shape[0]] = data\n"," output.append(pad_output)\n","\n"," pad_match_labels = torch.zeros([batch_size, max_length, max_length], dtype=torch.long)\n"," for sample_idx in range(batch_size):\n"," data = batch[sample_idx][6]\n"," pad_match_labels[sample_idx, : data.shape[1], : data.shape[1]] = data\n"," output.append(pad_match_labels)\n","\n"," output.append(torch.stack([x[-2] for x in batch]))\n"," output.append(torch.stack([x[-1] for x in batch]))\n","\n"," return output\n"],"outputs":[],"metadata":{"id":"JDpUNi6b1F0G","executionInfo":{"status":"ok","timestamp":1646996787572,"user_tz":-480,"elapsed":486,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}}}},{"cell_type":"markdown","source":["下面函数显示了`MRCNERDataset(Dataset)`类对具体数据的处理结果。"],"metadata":{"id":"YmChklXlKbKI"}},{"cell_type":"code","execution_count":7,"source":["def show_data():\n"," \"\"\"查看MERNERDataset类对原数据处理后的结果\"\"\"\n","\n"," # 中文数据集\n"," bert_path = project_path + \"bert-base-chinese\"\n"," vocab_file = os.path.join(bert_path, \"vocab.txt\")\n"," json_path = file_path[0] # 训练集、验证集、测试集分别对应下标0,1,2\n"," is_chinese = True\n","\n"," # 若该数据集是英文数据集,请把下面4语句注释掉。\n"," bert_path = project_path + \"bert-base-english\"\n"," json_path = file_path[0]\n"," vocab_file = os.path.join(bert_path, \"vocab.txt\")\n"," is_chinese = False\n","\n"," tokenizer = BertWordPieceTokenizer(vocab_file)\n"," dataset = MRCNERDataset(json_path=json_path, tokenizer=tokenizer, is_chinese=is_chinese)\n","\n"," dataloader = DataLoader(dataset, batch_size=1, collate_fn=collate_to_max_length)\n"," \n"," for batch in dataloader:\n"," for tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx in zip(*batch): # 让数据逐条析出\n"," \n"," if sample_idx != 1 or label_idx != 1: # 是为了只显示Example2所在的那个样本。\n"," continue\n","\n"," print('-*-' * 5, 'MRCNERDataset对Example2中那条句子处理后的结果如下', '-*-' * 5)\n"," print('sample_idx:',sample_idx.item())\n"," print('label_idx:', label_idx.item())\n"," print('word-piece之后的长度:', len(tokens))\n","\n"," tokens = tokens.tolist()\n"," print('=='*3, 'query+contenxt编码前的表示:','=='*3, '\\n', tokenizer.decode(tokens, skip_special_tokens=False))\n"," print('=='*3, 'query+contenxt编码后的表示:','=='*3, '\\n', tokens)\n"," print('=='*3,'token_type_ids:','=='*3, '\\n', token_type_ids.tolist())\n"," print('=='*3,'start_labels:','=='*3, '\\n', start_labels.tolist())\n"," print('=='*3,'end_labels:','=='*3,'\\n', end_labels.tolist())\n"," print('=='*3,'start_label_mask:','=='*3, '\\n', start_label_mask.tolist())\n"," print('=='*3,'end_label_mask:','=='*3,'\\n', end_label_mask.tolist())\n"," print('=='*3, 'match_labels中等于1的位置', '=='*3)\n"," print(np.argwhere(match_labels.numpy() == 1).tolist())\n"," return\n","\n","show_data()"],"outputs":[{"output_type":"stream","name":"stdout","text":["-*--*--*--*--*- MRCNERDataset对Example2中那条句子处理后的结果如下 -*--*--*--*--*-\n","sample_idx: 1\n","label_idx: 1\n","word-piece之后的长度: 64\n","====== query+contenxt编码前的表示: ====== \n"," [CLS] geographical political entities are geographical regions defined by political and or social groups such as countries, nations, regions, cities, states, government and its people. [SEP] the chinese government and the australian government signed an agreement today, wherein the australian party would provide china with a preferential financial loan of 150 million australian dollars. [SEP]\n","====== query+contenxt编码后的表示: ====== \n"," [101, 10056, 2576, 11422, 2024, 10056, 4655, 4225, 2011, 2576, 1998, 2030, 2591, 2967, 2107, 2004, 3032, 1010, 3741, 1010, 4655, 1010, 3655, 1010, 2163, 1010, 2231, 1998, 2049, 2111, 1012, 102, 1996, 2822, 2231, 1998, 1996, 2827, 2231, 2772, 2019, 3820, 2651, 1010, 16726, 1996, 2827, 2283, 2052, 3073, 2859, 2007, 1037, 9544, 24271, 3361, 5414, 1997, 5018, 2454, 2827, 6363, 1012, 102]\n","====== token_type_ids: ====== \n"," [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n","====== start_labels: ====== \n"," [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]\n","====== end_labels: ====== \n"," [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]\n","====== start_label_mask: ====== \n"," [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0]\n","====== end_label_mask: ====== \n"," [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]\n","====== match_labels中等于1的位置 ======\n","[[32, 34], [33, 33], [36, 38], [37, 37], [45, 47], [46, 46], [50, 50], [60, 60]]\n"]}],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":0},"id":"MsXAzR791F0I","executionInfo":{"status":"ok","timestamp":1646996795882,"user_tz":-480,"elapsed":1668,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}},"outputId":"a57fb995-c6c2-4282-c56f-3a4d56203251"}},{"cell_type":"markdown","source":["## 3 模型实现\n","\n","如 一.3 模型设计中所介绍,模型主要包括BERT和三个二分类器。\n","\n","在下面的代码实现中,模型具体包括的部分如下:\n","- 作为词嵌入的 BERT。(对应`BertQueryNER(BertPreTrainedModel)`类中的`self.bert`)\n","- 二分类 $token_i$ 是否是开始位置的 Start Outpus Classifier。(对应`BertQueryNER(BertPreTrainedModel)`类中的`self.start_outputs`)\n","- 二分类 $token_j$ 是否是结束位置的 End Outpus Classifier。(对应`BertQueryNER(BertPreTrainedModel)`类中的`self.end_outputs`)\n","- 二分类 $tokens_{i,j}$(start token和end token的词嵌入拼接表示)是否匹配的 Span Match Output Classifier。(对应`BertQueryNER(BertPreTrainedModel)`类中的`self.span_embedding`,具体的实现在`MultiNonLinearClassifier`类)\n","- 3 个Classifier输出的损失计算(对于函数`compute_loss`)。\n","\n","**细节见代码注释。**"],"metadata":{"id":"UgzoQPNo1F0J"}},{"cell_type":"code","execution_count":8,"source":["class BertQueryNER(BertPreTrainedModel):\n"," def __init__(self, config):\n"," super(BertQueryNER, self).__init__(config)\n","\n"," self.bert = BertModel(config) # Bert\n"," self.start_outputs = nn.Linear(config.hidden_size, 1) # 开始位置分类器 Start Output Classifier\n"," self.end_outputs = nn.Linear(config.hidden_size, 1) # 结束位置分类器 End Output Classifier\n"," self.span_embedding = MultiNonLinearClassifier(config.hidden_size * 2, 1, # 边界匹配分类器 Span Match Output Classifier\n"," config.mrc_dropout, \n"," intermediate_hidden_size=config.classifier_intermediate_hidden_size)\n","\n"," self.hidden_size = config.hidden_size\n","\n"," self.init_weights() # 权重初始化\n","\n"," def forward(self, input_ids, token_type_ids=None, attention_mask=None):\n"," \"\"\"\n"," Args:\n"," input_ids: bert input tokens, tensor of shape [seq_len]\n"," token_type_ids: 0 for query, 1 for context, tensor of shape [seq_len]\n"," attention_mask: attention mask, tensor of shape [seq_len]\n"," Returns:\n"," start_logits: start/non-start probs of shape [seq_len]\n"," end_logits: end/non-end probs of shape [seq_len]\n"," match_logits: start-end-match probs of shape [seq_len, 1]\n"," \"\"\"\n","\n"," # Step 1: 获得所有token的BERT词嵌入\n"," bert_outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)\n"," sequence_heatmap = bert_outputs[0] # [batch, seq_len, hidden]\n"," batch_size, seq_len, hid_size = sequence_heatmap.size()\n","\n"," # Step 2: 计算 tokens 是开始位置的 logits 预测值\n"," start_logits = self.start_outputs(sequence_heatmap).squeeze(-1) # [batch, seq_len, 1]\n","\n"," # Step 3: 计算 tokens 是结束位置的 logits 预测值\n"," end_logits = self.end_outputs(sequence_heatmap).squeeze(-1) # [batch, seq_len, 1]\n","\n"," # Step 4: 逐个拼接句子中所有的 tokens 为 Span Match Output Classifier 做准备,\n"," # 最后做成一个shape为[batch, seq_len, seq_len, hidden*2]的张量span_matrix。\n"," start_extend = sequence_heatmap.unsqueeze(2).expand(-1, -1, seq_len, -1) # [batch, seq_len, seq_len, hidden]\n"," end_extend = sequence_heatmap.unsqueeze(1).expand(-1, seq_len, -1, -1) # [batch, seq_len, seq_len, hidden]\n"," span_matrix = torch.cat([start_extend, end_extend], 3) # [batch, seq_len, seq_len, hidden*2]\n"," \n"," # Step 5: 计算 span matrix 中的 logits 预测值,\n"," # 从 span matrix 中的 [batch, seq_len, seq_len, hidden*2] 变成 span_logits 中的[batch, seq_len, seq_len]\n"," # 其中 span_logits[i][j] 表示第 i 个 tokens 作为开始位置,第 j 个 tokens 作为结束位置的匹配预测值。\n"," span_logits = self.span_embedding(span_matrix).squeeze(-1) # [batch, seq_len, seq_len]\n"," \n","\n"," return start_logits, end_logits, span_logits\n","\n","class MultiNonLinearClassifier(nn.Module):\n"," def __init__(self, hidden_size, num_label, dropout_rate, act_func=\"gelu\", intermediate_hidden_size=None):\n"," super(MultiNonLinearClassifier, self).__init__()\n"," \n"," self.num_label = num_label\n"," self.intermediate_hidden_size = hidden_size if intermediate_hidden_size is None else intermediate_hidden_size\n"," self.classifier1 = nn.Linear(hidden_size, self.intermediate_hidden_size)\n"," self.classifier2 = nn.Linear(self.intermediate_hidden_size, self.num_label)\n"," self.dropout = nn.Dropout(dropout_rate)\n"," self.act_func = act_func\n","\n"," def forward(self, input_features):\n"," \"\"\"\n"," 令 input_features 为 X,output_features 为 O,则forward的逻辑就是一个MLP:\n"," O = W2 \\cdot dropout(activete(W1 \\cdot X))\n"," \"\"\"\n"," features_output1 = self.classifier1(input_features)\n","\n"," if self.act_func == \"gelu\":\n"," features_output1 = F.gelu(features_output1)\n"," elif self.act_func == \"relu\":\n"," features_output1 = F.relu(features_output1)\n"," elif self.act_func == \"tanh\":\n"," features_output1 = F.tanh(features_output1)\n"," else:\n"," raise ValueError\n"," features_output1 = self.dropout(features_output1)\n"," features_output2 = self.classifier2(features_output1)\n"," return features_output2\n","\n","class BertQueryNerConfig(BertConfig):\n"," def __init__(self, **kwargs):\n"," super(BertQueryNerConfig, self).__init__(**kwargs)\n"," self.mrc_dropout = kwargs.get(\"mrc_dropout\", 0.1)\n"," self.classifier_intermediate_hidden_size = kwargs.get(\"classifier_intermediate_hidden_size\", 1024)\n"," self.classifier_act_func = kwargs.get(\"classifier_act_func\", \"gelu\")\n","\n","def compute_loss(start_logits, end_logits, span_logits, \n"," start_labels, end_labels, match_labels, \n"," start_label_mask, end_label_mask):\n"," \"\"\"\n"," 分别计算start_logits、end_logits和span_logits的损失值,这里是计算所有位置上的损失,\n"," 但考虑到负样本是多数类,正样本是少数类,所以实际上还有更好的处理方法。\n"," \"\"\"\n","\n"," batch_size, seq_len = start_logits.size()\n"," bce_loss = BCEWithLogitsLoss(reduction='none')\n","\n"," start_float_label_mask = start_label_mask.view(-1).float() # shape=batch x n\n"," end_float_label_mask = end_label_mask.view(-1).float()\n"," match_label_row_mask = start_label_mask.bool().unsqueeze(-1).expand(-1, -1, seq_len)\n"," match_label_col_mask = end_label_mask.bool().unsqueeze(-2).expand(-1, seq_len, -1)\n"," match_label_mask = match_label_row_mask & match_label_col_mask\n"," match_label_mask = torch.triu(match_label_mask, 0) # start should be less equal to end\n","\n"," float_match_label_mask = match_label_mask.view(batch_size, -1).float()\n","\n"," start_loss = bce_loss(start_logits.view(-1), start_labels.view(-1).float())\n"," start_loss = (start_loss * start_float_label_mask).sum() / start_float_label_mask.sum()\n"," end_loss = bce_loss(end_logits.view(-1), end_labels.view(-1).float())\n"," end_loss = (end_loss * end_float_label_mask).sum() / end_float_label_mask.sum()\n"," match_loss = bce_loss(span_logits.view(batch_size, -1), match_labels.view(batch_size, -1).float())\n"," match_loss = match_loss * float_match_label_mask\n"," match_loss = match_loss.sum() / (float_match_label_mask.sum() + 1e-10)\n","\n"," return start_loss, end_loss, match_loss"],"outputs":[],"metadata":{"id":"M5nv9G8a1F0J","executionInfo":{"status":"ok","timestamp":1646996799374,"user_tz":-480,"elapsed":1,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}}}},{"cell_type":"markdown","source":["## 4 模型训练\n","\n","英文所用预训练模型的下载链接是:\n","https://huggingface.co/bert-base-uncased\n","\n","中文所用预训练模型的下载链接是:\n","https://huggingface.co/bert-base-chinese"],"metadata":{"id":"HZuc3KPJ1F0L"}},{"cell_type":"code","source":["def query_span_f1(start_preds, end_preds, match_logits, start_label_mask, end_label_mask, match_labels, flat=False):\n"," \"\"\"\n"," 根据模型的输出,计算span的F1值。\n"," Args:\n"," start_preds: [bsz, seq_len]\n"," end_preds: [bsz, seq_len]\n"," match_logits: [bsz, seq_len, seq_len]\n"," start_label_mask: [bsz, seq_len]\n"," end_label_mask: [bsz, seq_len]\n"," match_labels: [bsz, seq_len, seq_len]\n"," flat: if True, decode as flat-ner\n"," Returns:\n"," span-f1 counts, tensor of shape [3]: tp, fp, fn\n"," \"\"\"\n"," # 将0或1值转换成布尔值\n"," start_label_mask = start_label_mask.bool()\n"," end_label_mask = end_label_mask.bool()\n"," match_labels = match_labels.bool()\n"," \n"," bsz, seq_len = start_label_mask.size()\n"," \n"," match_preds = match_logits > 0 # [bsz, seq_len, seq_len]\n"," start_preds = start_preds.bool() # [bsz, seq_len]\n"," end_preds = end_preds.bool() # [bsz, seq_len]\n","\n"," match_preds = (match_preds & start_preds.unsqueeze(-1).expand(-1, -1, seq_len) & end_preds.unsqueeze(1).expand(-1, seq_len, -1)) # 让start、end(expand之后)和match对应位置进行与运算\n"," match_label_mask = (start_label_mask.unsqueeze(-1).expand(-1, -1, seq_len) & end_label_mask.unsqueeze(1).expand(-1, seq_len, -1)) # 根据start和end的mask算出match的mask\n"," match_label_mask = torch.triu(match_label_mask, 0) # 保证实体开始的位置小于等于结束的位置\n"," match_preds = match_label_mask & match_preds\n","\n"," tp = (match_labels & match_preds).long().sum() # TRUE POSITIVE\n"," fp = (~match_labels & match_preds).long().sum() # FALSE POSITIVE\n"," fn = (match_labels & ~match_preds).long().sum() # FALSE NEGETIVE\n"," return torch.stack([tp, fp, fn])"],"metadata":{"id":"jj2R9Vyl2W8b","executionInfo":{"status":"ok","timestamp":1646971304431,"user_tz":-480,"elapsed":3,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}}},"execution_count":11,"outputs":[]},{"cell_type":"code","execution_count":null,"source":["# 超参数设置\n","alpha, beta, gamma = 1, 1, 1\n","EPOCHS = 4\n","device = 'cuda' if torch.cuda.is_available() else 'cpu'\n","lr = 2e-5\n","adam_eps = 1e-8\n","wd = 0.01\n","best_dev_f1 = -1\n","bs = 8\n","\n","# model\n","bert_config_dir = project_path + 'bert-base-english'\n","bert_config = BertQueryNerConfig.from_pretrained(bert_config_dir,\n"," hidden_dropout_prob=0.1,\n"," attention_probs_dropout_prob=0.1,\n"," mrc_dropout=0.1,\n"," classifier_act_func = 'gelu',\n"," classifier_intermediate_hidden_size=1024)\n","model = BertQueryNER.from_pretrained(bert_config_dir, config=bert_config)\n","model.to(device)\n","\n","# optimizer\n","no_decay = [\"bias\", \"LayerNorm.weight\"]\n","optimizer_grouped_parameters = [\n"," {\n"," \"params\": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],\n"," \"weight_decay\": wd,\n"," },\n"," {\n"," \"params\": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],\n"," \"weight_decay\": 0.0,\n"," },\n","]\n","optimizer = AdamW(optimizer_grouped_parameters, betas=(0.9, 0.98), lr=lr, eps=adam_eps,)\n","\n","# dataloader\n","json_path = file_path[0]\n","vocab_file = os.path.join(bert_config_dir, \"vocab.txt\")\n","is_chinese = False\n","tokenizer = BertWordPieceTokenizer(vocab_file)\n","train_dataset = MRCNERDataset(json_path=file_path[0], tokenizer=tokenizer, is_chinese=is_chinese)\n","train_dataloader = DataLoader(train_dataset, batch_size=bs, collate_fn=collate_to_max_length, shuffle=True)\n","dev_dataset = MRCNERDataset(json_path=file_path[1], tokenizer=tokenizer, is_chinese=is_chinese)\n","dev_dataloader = DataLoader(dev_dataset, batch_size=bs, collate_fn=collate_to_max_length, shuffle=False)\n","\n","# 训练模型\n","for epoch in range(EPOCHS):\n","\n"," model.train()\n","\n"," for i, batch in enumerate(train_dataloader):\n"," batch = tuple(b.to(device) for b in batch)\n"," tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, _, _ = batch\n"," attention_mask = (tokens != 0).long() # 只要token不是[PAD]就置1\n"," start_logits, end_logits, span_logits = model(tokens, token_type_ids, attention_mask)\n"," start_loss, end_loss, match_loss = compute_loss(start_logits=start_logits,\n"," end_logits=end_logits,\n"," span_logits=span_logits,\n"," start_labels=start_labels,\n"," end_labels=end_labels,\n"," match_labels=match_labels,\n"," start_label_mask=start_label_mask,\n"," end_label_mask=end_label_mask)\n"," total_loss = alpha * start_loss + beta * end_loss + gamma * match_loss\n"," model.zero_grad()\n"," total_loss.backward()\n"," optimizer.step()\n","\n"," # 定点评估验证集并保存模型\n"," if i > 3 and i % 2000 == 0:\n"," model.eval()\n"," dev_total_loss = 0\n"," dev_total_span_f1 = 0\n"," count_batch = 0\n"," \n"," for j, dev_batch in enumerate(dev_dataloader):\n"," dev_batch = tuple(dev_b.to(device) for dev_b in dev_batch)\n"," dev_tokens, dev_token_type_ids, dev_start_labels, dev_end_labels, dev_start_label_mask, dev_end_label_mask, dev_match_labels, _, _ = dev_batch\n"," dev_attention_mask = (dev_tokens != 0).long() # 只要token不是[PAD]就置1\n"," with torch.no_grad():\n"," \n"," # 计算验证集当前batch的loss\n"," dev_start_logits, dev_end_logits, dev_span_logits = model(dev_tokens, dev_token_type_ids, dev_attention_mask)\n"," dev_start_loss, dev_end_loss, dev_match_loss = compute_loss(start_logits=dev_start_logits,\n"," end_logits=dev_end_logits,\n"," span_logits=dev_span_logits,\n"," start_labels=dev_start_labels,\n"," end_labels=dev_end_labels,\n"," match_labels=dev_match_labels,\n"," start_label_mask=dev_start_label_mask,\n"," end_label_mask=dev_end_label_mask)\n"," dev_total_loss += alpha * dev_start_loss + beta * dev_end_loss + gamma * dev_match_loss\n","\n"," # 计算验证集当前batch的F1值 \n"," span_f1_state = query_span_f1(dev_start_logits, dev_end_logits, dev_span_logits, dev_start_label_mask, dev_end_label_mask, dev_match_labels)\n"," all_counts = torch.stack([x for x in span_f1_state]).view(-1, 3).sum(0)\n"," span_tp, span_fp, span_fn = all_counts\n"," span_recall = span_tp / (span_tp + span_fn + 1e-10)\n"," span_precision = span_tp / (span_tp + span_fp + 1e-10)\n"," dev_total_span_f1 += span_precision * span_recall * 2 / (span_recall + span_precision + 1e-10)\n","\n"," count_batch += 1\n","\n"," dev_loss = dev_total_loss / count_batch # 验证集最终的loss\n"," dev_span_f1 = dev_total_span_f1 / count_batch # 验证集最终的span f1 值\n"," print(f'epoch:{epoch}, batch:{i}, train loss:{total_loss.item()}, dev loss:{dev_loss.item()}, dev span f1:{dev_span_f1.item()}')\n"," \n"," # 如果此次验证集的F1大于当前最好的F1,则保存模型\n"," if best_dev_f1 < dev_span_f1:\n"," best_dev_f1 = dev_span_f1\n"," torch.save(model, project_path + 'model_save_new.pth')\n"," print('SAVE!')"],"outputs":[{"metadata":{"tags":null},"name":"stderr","output_type":"stream","text":["Some weights of the model checkpoint at /content/drive/MyDrive/mrc-bert-ner/bert-base-english were not used when initializing BertQueryNER: ['cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']\n","- This IS expected if you are initializing BertQueryNER from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n","- This IS NOT expected if you are initializing BertQueryNER from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n","Some weights of BertQueryNER were not initialized from the model checkpoint at /content/drive/MyDrive/mrc-bert-ner/bert-base-english and are newly initialized: ['start_outputs.bias', 'end_outputs.bias', 'span_embedding.classifier1.weight', 'end_outputs.weight', 'span_embedding.classifier1.bias', 'start_outputs.weight', 'span_embedding.classifier2.weight', 'span_embedding.classifier2.bias']\n","You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n","/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n"," FutureWarning,\n"]},{"output_type":"stream","name":"stdout","text":["epoch:0, batch:2000, train loss:0.030432697385549545, dev loss:0.0378371886909008, dev span f1:0.503181517124176\n","SAVE!\n","epoch:0, batch:4000, train loss:0.029311735183000565, dev loss:0.03624585270881653, dev span f1:0.6129595041275024\n","SAVE!\n","epoch:1, batch:2000, train loss:0.004096935503184795, dev loss:0.03656528517603874, dev span f1:0.7015897631645203\n","SAVE!\n","epoch:1, batch:4000, train loss:0.008193296380341053, dev loss:0.032151203602552414, dev span f1:0.6943182945251465\n","epoch:2, batch:2000, train loss:0.00279993936419487, dev loss:0.04054645821452141, dev span f1:0.716867983341217\n","SAVE!\n","epoch:2, batch:4000, train loss:0.042344387620687485, dev loss:0.03349826857447624, dev span f1:0.7225610613822937\n","SAVE!\n","epoch:3, batch:2000, train loss:0.0008626565686427057, dev loss:0.04692533612251282, dev span f1:0.73361736536026\n","SAVE!\n","epoch:3, batch:4000, train loss:0.001133581972680986, dev loss:0.03900938108563423, dev span f1:0.7114824652671814\n"]}],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":0},"id":"p0hKAbd51F0M","outputId":"96fe61d0-154e-46d5-e40f-2d5cebd2cdac"}},{"cell_type":"markdown","source":["## 5 推理演示"],"metadata":{"id":"WtgXP2iCQIad"}},{"cell_type":"code","source":["def extract_nested_spans(start_preds, end_preds, match_preds, start_label_mask, end_label_mask):\n"," \"\"\"根据模型给出的预测,返回实体所在的位置\"\"\"\n"," start_label_mask = start_label_mask.bool()\n"," end_label_mask = end_label_mask.bool()\n"," bsz, seq_len = start_label_mask.size()\n"," start_preds = start_preds.bool()\n"," end_preds = end_preds.bool()\n","\n"," match_preds = (match_preds & start_preds.unsqueeze(-1).expand(-1, -1, seq_len) & end_preds.unsqueeze(1).expand(-1, seq_len, -1))\n"," match_label_mask = (start_label_mask.unsqueeze(-1).expand(-1, -1, seq_len) & end_label_mask.unsqueeze(1).expand(-1, seq_len, -1))\n"," match_label_mask = torch.triu(match_label_mask, 0) # start should be less or equal to end\n"," match_preds = match_label_mask & match_preds\n"," match_pos_pairs = np.transpose(np.nonzero(match_preds.numpy())).tolist()\n","\n"," return [(pos[1], pos[2]) for pos in match_pos_pairs]"],"metadata":{"id":"BdR13IO2LFSp","executionInfo":{"status":"ok","timestamp":1646997916001,"user_tz":-480,"elapsed":585,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}}},"execution_count":24,"outputs":[]},{"cell_type":"code","source":["trained_model = torch.load(project_path + 'trained_model.pth', map_location=torch.device('cpu'))\n","\n","bert_config_dir = project_path + 'bert-base-english'\n","json_path = file_path[2]\n","vocab_file = os.path.join(bert_config_dir, \"vocab.txt\")\n","is_chinese = False\n","tokenizer = BertWordPieceTokenizer(vocab_file)\n","\n","test_dataset = MRCNERDataset(json_path=json_path, tokenizer=tokenizer, is_chinese=False, possible_only=True)\n","test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=collate_to_max_length, shuffle=False)\n","\n","check_ignore = 100\n","check_dur = 10\n","print('\\n注意:以下位置是指query+context之后并经过word-piece处理之后的绝对位置\\n')\n","for i, batch in enumerate(test_dataloader):\n"," if i > check_ignore:\n"," tokens, token_type_ids, start_labels, end_labels, start_label_mask, end_label_mask, match_labels, sample_idx, label_idx = batch\n"," attention_mask = (tokens != 0).long()\n"," start_logits, end_logits, span_logits = trained_model(tokens, attention_mask=attention_mask, token_type_ids=token_type_ids)\n"," start_preds, end_preds, span_preds = start_logits > 0, end_logits > 0, span_logits > 0\n"," match_preds = span_logits > 0\n"," sentence = tokenizer.decode(tokens[0].tolist(), skip_special_tokens=False).split('[SEP]')\n"," print(f'example{i-check_ignore}:')\n"," print('\\tQUERY 是 :', sentence[0][6:])\n"," print('\\tCONTEXT是 :', sentence[1][1:])\n"," \n"," infer_pos = extract_nested_spans(start_preds, end_preds, match_preds, start_label_mask, end_label_mask)\n"," real_pos = [(pos[0], pos[1]) for pos in np.argwhere(match_labels[0].numpy() == 1)]\n"," \n"," print('\\t推测的NER位置:', infer_pos)\n"," print('\\t真实的NER位置:', real_pos)\n"," \n"," print('------' * 20)\n"," if i == check_ignore + check_dur:\n"," break\n"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":0},"id":"KQGrdLfmNtzP","executionInfo":{"status":"ok","timestamp":1647000088163,"user_tz":-480,"elapsed":4663,"user":{"displayName":"龙泳潮","photoUrl":"https://lh3.googleusercontent.com/a/default-user=s64","userId":"16304534726243609269"}},"outputId":"2d510e2a-de2f-4af0-8fa4-dc16dae46365"},"execution_count":61,"outputs":[{"output_type":"stream","name":"stdout","text":["\n","注意:以下位置是指query+context之后并经过word-piece处理之后的绝对位置\n","\n","example1:\n","\tQUERY 是 : organization entities are limited to companies, corporations, agencies, institutions and other groups of people. \n","\tCONTEXT是 : indonesia's reformist - minded president abdurrahman wahid has blamed the army and police for the deaths. \n","\t推测的NER位置: [(36, 39)]\n","\t真实的NER位置: [(36, 39)]\n","------------------------------------------------------------------------------------------------------------------------\n","example2:\n","\tQUERY 是 : a person entity is limited to human including a single individual or a group. \n","\tCONTEXT是 : indonesia's reformist - minded president abdurrahman wahid has blamed the army and police for the deaths. \n","\t推测的NER位置: [(17, 24), (25, 30)]\n","\t真实的NER位置: [(17, 24), (17, 30)]\n","------------------------------------------------------------------------------------------------------------------------\n","example3:\n","\tQUERY 是 : geographical political entities are geographical regions defined by political and or social groups such as countries, nations, regions, cities, states, government and its people. \n","\tCONTEXT是 : in the provincial capital banda aceh, indonesian security forces opened fire at two cars after the rally saturday night, injuring six local residents. \n","\t推测的NER位置: [(33, 35), (33, 38), (36, 38), (40, 40)]\n","\t真实的NER位置: [(33, 35), (36, 38), (40, 40)]\n","------------------------------------------------------------------------------------------------------------------------\n","example4:\n","\tQUERY 是 : a person entity is limited to human including a single individual or a group. \n","\tCONTEXT是 : in the provincial capital banda aceh, indonesian security forces opened fire at two cars after the rally saturday night, injuring six local residents. \n","\t推测的NER位置: [(25, 27), (40, 42)]\n","\t真实的NER位置: [(25, 27), (40, 42)]\n","------------------------------------------------------------------------------------------------------------------------\n","example5:\n","\tQUERY 是 : vehicle entities are physical devices primarily designed to move, carry, pull or push the transported object such as helicopters, trains, ship and motorcycles. \n","\tCONTEXT是 : in the provincial capital banda aceh, indonesian security forces opened fire at two cars after the rally saturday night, injuring six local residents. \n","\t推测的NER位置: [(44, 45)]\n","\t真实的NER位置: [(44, 45)]\n","------------------------------------------------------------------------------------------------------------------------\n","example6:\n","\tQUERY 是 : geographical political entities are geographical regions defined by political and or social groups such as countries, nations, regions, cities, states, government and its people. \n","\tCONTEXT是 : security forces patrolled the city in armored vehicles, and officers carried out random identification checks, looking for separatist leaders. \n","\t推测的NER位置: [(36, 37)]\n","\t真实的NER位置: [(36, 37)]\n","------------------------------------------------------------------------------------------------------------------------\n","example7:\n","\tQUERY 是 : a person entity is limited to human including a single individual or a group. \n","\tCONTEXT是 : security forces patrolled the city in armored vehicles, and officers carried out random identification checks, looking for separatist leaders. \n","\t推测的NER位置: [(17, 18), (28, 28), (37, 40)]\n","\t真实的NER位置: [(17, 18), (28, 28), (37, 40)]\n","------------------------------------------------------------------------------------------------------------------------\n","example8:\n","\tQUERY 是 : vehicle entities are physical devices primarily designed to move, carry, pull or push the transported object such as helicopters, trains, ship and motorcycles. \n","\tCONTEXT是 : security forces patrolled the city in armored vehicles, and officers carried out random identification checks, looking for separatist leaders. \n","\t推测的NER位置: [(37, 38)]\n","\t真实的NER位置: [(37, 38)]\n","------------------------------------------------------------------------------------------------------------------------\n","example9:\n","\tQUERY 是 : a person entity is limited to human including a single individual or a group. \n","\tCONTEXT是 : police blocked off roads and prevented thousands from returning home. \n","\t推测的NER位置: [(17, 17), (23, 23)]\n","\t真实的NER位置: [(17, 17), (23, 23)]\n","------------------------------------------------------------------------------------------------------------------------\n","example10:\n","\tQUERY 是 : facility entities are limited to buildings and other permanent man - made structures such as buildings, airports, highways, bridges. \n","\tCONTEXT是 : police blocked off roads and prevented thousands from returning home. \n","\t推测的NER位置: [(28, 28)]\n","\t真实的NER位置: [(28, 28)]\n","------------------------------------------------------------------------------------------------------------------------\n"]}]}],"metadata":{"interpreter":{"hash":"9c2bd64e6d98de038c014b9413779704f896a074a414ccc0c66d16846377136f"},"kernelspec":{"name":"python3","display_name":"Python 3.8.12 64-bit ('lab': conda)"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.12"},"orig_nbformat":4,"colab":{"name":"tutorial.ipynb","provenance":[],"collapsed_sections":[]},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} --------------------------------------------------------------------------------