├── .gitignore
├── README.md
├── imgs
│   ├── re_brat_preview.png
│   ├── re_gen_train_samples.png
│   ├── re_gen_x_y.png
│   └── re_network.png
└── run.ipynb

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | logs
3 | .idea
4 | *.pyc
5 | data
6 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 瑞金医院MMC人工智能辅助构建知识图谱大赛复赛 (Ruijin Hospital MMC AI-Assisted Knowledge Graph Construction Competition, Round 2)
2 | 
3 | [Competition link](https://tianchi.aliyun.com/competition/introduction.htm?raceId=231687)
4 | 
5 | >:warning: Due to possible copyright issues, please contact the competition organizers yourself to request the data. Requests for data posted in the Issues will no longer be answered, thank you!
6 | 
7 | ## Background
8 | 
9 | The task of the second round is relation extraction, with the named entities already given.
10 | 
11 | Code for the first round is at [beader/ruijin_round1](https://github.com/beader/ruijin_round1)
12 | 
13 | 
14 | Relation types and the entity types they connect:
15 | 
16 | |From Entity Type|To Entity Type|Relation Type|
17 | |:---|:---|:---|
18 | |Test method|Disease|Test_Disease|
19 | |Clinical manifestation|Disease|Symptom_Disease|
20 | |Non-drug treatment|Disease|Treatment_Disease|
21 | |Drug name|Disease|Drug_Disease|
22 | |Anatomical site|Disease|Anatomy_Disease|
23 | |Dosing frequency|Drug name|Frequency_Drug|
24 | |Duration|Drug name|Duration_Drug|
25 | |Dose amount|Drug name|Amount_Drug|
26 | |Administration method|Drug name|Method_Drug|
27 | |Adverse reaction|Drug name|SideEff-Drug|
28 | 
29 | ## Data Sample
30 | 
31 | `0.txt`
32 | 
33 | ```
34 | 中国成人2型糖尿病HBA1C c控制目标的专家共识
35 | 目前,2型糖尿病及其并发症已经成为危害公众
36 | 健康的主要疾病之一,控制血糖是延缓糖尿病进展及
37 | 其并发症发生的重要措施之一。虽然HBA1C 。是评价血
38 | 糖控制水平的公认指标,但应该控制的理想水平即目
39 | 标值究竟是多少还存在争议。糖尿病控制与并发症试
40 | 验(DCCT,1993)、熊本(Kumamoto,1995)、英国前瞻性
41 | 糖尿病研究(UKPDS,1998)等高质量临床研究已经证
42 | 实,对新诊断的糖尿病患者或病情较轻的患者进行严
43 | 格的血糖控制会延缓糖尿病微血管病变的发生、发展,
44 | ```
45 | 
46 | `0.ann`
47 | 
48 | ```
49 | T1 Disease 1845 1850 1型糖尿病
50 | T2 Disease 1983 1988 1型糖尿病
51 | T4 Disease 30 35 2型糖尿病
52 | T5 Disease 1822 1827 2型糖尿病
53 | ...
54 | R206 Symptom_Disease Arg1:T329 Arg2:T325
55 | R207 Symptom_Disease Arg1:T331 Arg2:T325
56 | R208 Test_Disease Arg1:T337 Arg2:T338
57 | R209 Treatment_Disease Arg1:T343 Arg2:T345
58 | R210 Treatment_Disease Arg1:T344 Arg2:T345
59 | ```
60 | 
61 | The data is annotated with [brat](http://brat.nlplab.org/); each .txt file has a corresponding .ann annotation file.
62 | 
63 | ![](imgs/re_brat_preview.png)
64 | 
65 | ## Model
66 | 
67 | ### Building training samples
68 | 
69 | Having no prior experience with relation extraction, the most intuitive idea is to treat it as a binary classification problem: first generate candidate entity pairs, apply some simple filtering, and then use the relation annotations in the training set to label each candidate pair with 0 or 1.
70 | 
71 | ![](imgs/re_gen_train_samples.png)
72 | 
73 | In the competition, text is split into sentences on the Chinese full stop (。), and a sliding window with `size=2`, `step=1` over these sentences is used to generate samples, so each generated sentence contains 2 consecutive sentences of the original article. The entities appearing in each generated sentence are then combined into ordered pairs, and pairs whose type combination does not belong to the 10 relation types required by the competition are filtered out; the remaining pairs are the candidate entity pairs.
74 | 
75 | ### Vectorization
76 | 
77 | ![](imgs/re_gen_x_y.png)
78 | 
79 | Each sample is vectorized into 5 vectors that serve as the model inputs (a toy sketch of this step is given at the end of this README):
80 | 
81 | - `char id sequence` is the sentence text converted into a sequence of character ids
82 | 
83 | - `entity labels vector` encodes the entity category at each character position
84 | 
85 | - `from entity mask` marks the positions of the from_entity with \[1\] and fills the remaining positions with \[0\]
86 | 
87 | - `to entity mask` marks the positions of the to_entity with \[1\] and fills the remaining positions with \[0\]
88 | 
89 | - `entity distance` is a signed real number representing the distance between the two entities
90 | 
91 | ### Network architecture
92 | 
93 | ![](imgs/re_network.png)
94 | 
95 | ## Evaluation
96 | 
97 | The second round is scored with F1-score. This baseline model achieves a final online score of 0.733.
98 | 
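99 | ### Appendix: input construction sketch
100 | 
101 | The snippet below is a minimal, illustrative sketch of how one candidate entity pair is turned into the five inputs plus a 0/1 label described above. The sentence, vocabulary, entity spans and label are toy values made up for illustration (they are not taken from the competition data); the logic mirrors `Dataset.vectorize` in `run.ipynb`.
102 | 
103 | ```python
104 | import numpy as np
105 | 
106 | # Toy lookup tables; the real ones are built from the corpus in run.ipynb.
107 | word2idx = {'<pad>': 0, '<unk>': 1, '二': 2, '甲': 3, '双': 4, '胍': 5,
108 |             '治': 6, '疗': 7, '糖': 8, '尿': 9, '病': 10}
109 | cate2idx = {'Disease': 1, 'Drug': 2}
110 | max_len = 20
111 | 
112 | sent = '二甲双胍治疗糖尿病'   # toy candidate sentence
113 | from_ent = ('Drug', 0, 4)      # (category, start, end) of the from-entity
114 | to_ent = ('Disease', 6, 9)     # (category, start, end) of the to-entity
115 | 
116 | # 1. char id sequence: characters mapped to ids, zero-padded to max_len
117 | char_ids = np.zeros(max_len, dtype=int)
118 | for i, ch in enumerate(sent[:max_len]):
119 |     char_ids[i] = word2idx.get(ch, word2idx['<unk>'])
120 | 
121 | # 2. entity labels vector: entity-category id at each character position
122 | ent_labels = np.zeros(max_len, dtype=int)
123 | for cate, start, end in (from_ent, to_ent):
124 |     ent_labels[start:end] = cate2idx[cate]
125 | 
126 | # 3 & 4. from / to entity masks: 1 on the entity span, 0 elsewhere
127 | from_mask = np.zeros(max_len, dtype=int)
128 | from_mask[from_ent[1]:from_ent[2]] = 1
129 | to_mask = np.zeros(max_len, dtype=int)
130 | to_mask[to_ent[1]:to_ent[2]] = 1
131 | 
132 | # 5. entity distance: signed gap between the end of the from-entity
133 | #    and the start of the to-entity
134 | ent_dist = to_ent[1] - from_ent[2]
135 | 
136 | # Label: 1 if this (from, to) pair is annotated as a relation, else 0.
137 | label = 1
138 | 
139 | print(char_ids, ent_labels, from_mask, to_mask, ent_dist, label)
140 | ```
141 | 
--------------------------------------------------------------------------------
/imgs/re_brat_preview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/beader/ruijin_round2/32e00dbecb1885d187bc3f2f90a639c9ef00cde1/imgs/re_brat_preview.png
--------------------------------------------------------------------------------
/imgs/re_gen_train_samples.png: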
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/beader/ruijin_round2/32e00dbecb1885d187bc3f2f90a639c9ef00cde1/imgs/re_gen_train_samples.png -------------------------------------------------------------------------------- /imgs/re_gen_x_y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/beader/ruijin_round2/32e00dbecb1885d187bc3f2f90a639c9ef00cde1/imgs/re_gen_x_y.png -------------------------------------------------------------------------------- /imgs/re_network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/beader/ruijin_round2/32e00dbecb1885d187bc3f2f90a639c9ef00cde1/imgs/re_network.png -------------------------------------------------------------------------------- /run.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import re\n", 11 | "import math\n", 12 | "import zipfile\n", 13 | "import numpy as np\n", 14 | "from collections import Counter, defaultdict\n", 15 | "from itertools import permutations, chain\n", 16 | "from gensim.models import Word2Vec" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "train_data_dir = '/workspace/data/tianchi_ruijin/ruijin_round2_train/'\n", 26 | "test_a_data_dir = '/workspace/data/tianchi_ruijin/ruijin_round2_test_a/'\n", 27 | "test_b_data_dir = '/workspace/data/tianchi_ruijin/ruijin_round2_test_b/'" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": {}, 33 | "source": [ 34 | "## 一、类与函数定义" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 3, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "ENTITIES = [\n", 44 | " \"Amount\", \"Anatomy\", \"Disease\", \"Drug\",\n", 45 | " \"Duration\", \"Frequency\", \"Level\", \"Method\",\n", 46 | " \"Operation\", \"Reason\", \"SideEff\", \"Symptom\",\n", 47 | " \"Test\", \"Test_Value\", \"Treatment\"\n", 48 | "]\n", 49 | "\n", 50 | "RELATIONS = [\n", 51 | " \"Test_Disease\",\"Symptom_Disease\",\"Treatment_Disease\",\n", 52 | " \"Drug_Disease\",\"Anatomy_Disease\",\"Frequency_Drug\",\n", 53 | " \"Duration_Drug\",\"Amount_Drug\",\"Method_Drug\",\"SideEff-Drug\"\n", 54 | "]\n", 55 | "\n", 56 | "class Entity(object):\n", 57 | " def __init__(self, ent_id, category, start_pos, end_pos, text):\n", 58 | " self.ent_id = ent_id\n", 59 | " self.category = category\n", 60 | " self.start_pos = start_pos\n", 61 | " self.end_pos = end_pos\n", 62 | " self.text = text\n", 63 | " \n", 64 | " def __gt__(self, other):\n", 65 | " return self.start_pos > other.start_pos\n", 66 | " \n", 67 | " def offset(self, offset_val):\n", 68 | " return Entity(self.ent_id, \n", 69 | " self.category, \n", 70 | " self.start_pos + offset_val,\n", 71 | " self.end_pos + offset_val,\n", 72 | " self.text)\n", 73 | " \n", 74 | " def __repr__(self):\n", 75 | " fmt = '({ent_id}, {category}, ({start_pos}, {end_pos}), {text})'\n", 76 | " return fmt.format(**self.__dict__)\n", 77 | "\n", 78 | "class Entities(object):\n", 79 | " def __init__(self, ents):\n", 80 | " self.ents = sorted(ents)\n", 81 | " self.ent_dict = dict(zip([ent.ent_id for ent in ents], ents))\n", 82 | " \n", 83 | " def __getitem__(self, key):\n", 84 | " if 
isinstance(key, int) or isinstance(key, slice):\n", 85 | " return self.ents[key]\n", 86 | " else:\n", 87 | " return self.ent_dict.get(key, None)\n", 88 | " \n", 89 | " def __len__(self):\n", 90 | " return len(self.ents)\n", 91 | " \n", 92 | " def offset(self, offset_val):\n", 93 | " ents = [ent.offset(offset_val) for ent in self.ents]\n", 94 | " return Entities(ents)\n", 95 | " \n", 96 | " def vectorize(self, vec_len, cate2idx):\n", 97 | " res_vec = np.zeros(vec_len, dtype=int)\n", 98 | " for ent in self.ents:\n", 99 | " res_vec[ent.start_pos: ent.end_pos] = cate2idx[ent.category]\n", 100 | " return res_vec\n", 101 | " \n", 102 | " def find_entities(self, start_pos, end_pos):\n", 103 | " res = []\n", 104 | " for ent in self.ents:\n", 105 | " if ent.start_pos > end_pos:\n", 106 | " break\n", 107 | " sp, ep = (max(start_pos, ent.start_pos), min(end_pos, ent.end_pos))\n", 108 | " if ep > sp:\n", 109 | " new_ent = Entity(ent.ent_id, ent.category, sp, ep, ent.text[:(ep - sp)])\n", 110 | " res.append(new_ent)\n", 111 | " return Entities(res)\n", 112 | " \n", 113 | " def __add__(self, other):\n", 114 | " ents = self.ents + other.ents\n", 115 | " return Entities(ents)\n", 116 | " \n", 117 | " def merge(self):\n", 118 | " merged_ents = []\n", 119 | " for ent in self.ents:\n", 120 | " if len(merged_ents) == 0:\n", 121 | " merged_ents.append(ent)\n", 122 | " elif (merged_ents[-1].end_pos == ent.start_pos and \n", 123 | " merged_ents[-1].category == ent.category):\n", 124 | " merged_ent = Entity(ent_id=merged_ents[-1].ent_id, \n", 125 | " category=ent.category,\n", 126 | " start_pos=merged_ents[-1].start_pos, \n", 127 | " end_pos=ent.end_pos, \n", 128 | " text=merged_ents[-1].text + ent.text)\n", 129 | " merged_ents[-1] = merged_ent\n", 130 | " else:\n", 131 | " merged_ents.append(ent)\n", 132 | " return Entities(merged_ents)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 4, 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "class Relation(object):\n", 142 | " def __init__(self, rel_id, category, ent1, ent2):\n", 143 | " self.rel_id = rel_id\n", 144 | " self.category = category\n", 145 | " self.ent1 = ent1\n", 146 | " self.ent2 = ent2\n", 147 | " \n", 148 | " @property\n", 149 | " def is_valid(self):\n", 150 | " return (isinstance(self.ent1, Entity) and\n", 151 | " isinstance(self.ent2, Entity) and\n", 152 | " [self.ent1.category, self.ent2.category] == re.split('[-_]', self.category))\n", 153 | " \n", 154 | " @property\n", 155 | " def start_pos(self):\n", 156 | " return min(self.ent1.start_pos, self.ent2.start_pos)\n", 157 | " \n", 158 | " @property\n", 159 | " def end_pos(self):\n", 160 | " return max(self.ent1.end_pos, self.ent2.end_pos)\n", 161 | " \n", 162 | " def offset(self, offset_val):\n", 163 | " return Relation(self.rel_id, \n", 164 | " self.category, \n", 165 | " self.ent1.offset(offset_val),\n", 166 | " self.ent2.offset(offset_val))\n", 167 | " \n", 168 | " def __gt__(self, other_rel):\n", 169 | " return self.ent1.start_pos > other_rel.ent1.start_pos\n", 170 | " \n", 171 | " def __repr__(self):\n", 172 | " fmt = '({rel_id}, {category} Arg1:{ent1} Arg2:{ent2})'\n", 173 | " return fmt.format(**self.__dict__)\n", 174 | "\n", 175 | "class Relations(object):\n", 176 | " def __init__(self, rels):\n", 177 | " self.rels = rels\n", 178 | " \n", 179 | " def __getitem__(self, key):\n", 180 | " if isinstance(key, int):\n", 181 | " return self.rels[key]\n", 182 | " elif isinstance(key, slice):\n", 183 | " return Relations(self.rels[key])\n", 184 | " \n", 
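" # __add__ concatenates two Relations collections, and find_relations keeps\n",
" # only the relations whose spans lie entirely inside [start_pos, end_pos].\n",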
185 | " def __add__(self, other):\n", 186 | " rels = self.rels + other.rels\n", 187 | " return Relations(rels)\n", 188 | " \n", 189 | " def find_relations(self, start_pos, end_pos):\n", 190 | " res = []\n", 191 | " for rel in self.rels:\n", 192 | " if start_pos <= rel.start_pos and end_pos >= rel.end_pos:\n", 193 | " res.append(rel)\n", 194 | " return Relations(res)\n", 195 | " \n", 196 | " def offset(self, offset_val): \n", 197 | " return Relations([rel.offset(offset_val) for rel in self.rels])\n", 198 | " \n", 199 | " @property\n", 200 | " def start_pos(self):\n", 201 | " return min([rel.start_pos for rel in self.rels])\n", 202 | "\n", 203 | " @property\n", 204 | " def end_pos(self):\n", 205 | " return max([rel.end_pos for rel in self.rels])\n", 206 | " \n", 207 | " def __len__(self):\n", 208 | " return len(self.rels)\n", 209 | " \n", 210 | " def __repr__(self):\n", 211 | " return self.rels.__repr__()" 212 | ] 213 | }, 214 | { 215 | "cell_type": "code", 216 | "execution_count": 5, 217 | "metadata": {}, 218 | "outputs": [], 219 | "source": [ 220 | "class TextSpan(object):\n", 221 | " def __init__(self, text, ents, rels, **kwargs):\n", 222 | " self.text = text\n", 223 | " self.ents = ents\n", 224 | " self.rels = rels\n", 225 | " \n", 226 | " def __getitem__(self, key):\n", 227 | " if isinstance(key, int):\n", 228 | " start, stop = key, key + 1\n", 229 | " elif isinstance(key, slice):\n", 230 | " start = key.start if key.start is not None else 0\n", 231 | " stop = key.stop if key.stop is not None else len(self.text)\n", 232 | " else:\n", 233 | " raise ValueError('parameter should be int or slice')\n", 234 | " if start < 0:\n", 235 | " start += len(self.text)\n", 236 | " if stop < 0:\n", 237 | " stop += len(self.text)\n", 238 | " text = self.text[key]\n", 239 | " ents = self.ents.find_entities(start, stop).offset(-start)\n", 240 | " rels = self.rels.find_relations(start, stop).offset(-start)\n", 241 | " return TextSpan(text, ents, rels)\n", 242 | " \n", 243 | " def __len__(self):\n", 244 | " return len(self.text)\n", 245 | " \n", 246 | "\n", 247 | "class Sentence(object):\n", 248 | " def __init__(self, doc_id, offset, text='', ents=[], rels=[], textspan=None):\n", 249 | " self.doc_id = doc_id\n", 250 | " self.offset = offset\n", 251 | " if isinstance(textspan, TextSpan):\n", 252 | " self.textspan = textspan\n", 253 | " else:\n", 254 | " self.textspan = TextSpan(text, ents, rels)\n", 255 | " \n", 256 | " @property\n", 257 | " def text(self):\n", 258 | " return self.textspan.text\n", 259 | " \n", 260 | " @property\n", 261 | " def ents(self):\n", 262 | " return self.textspan.ents\n", 263 | " \n", 264 | " @property\n", 265 | " def rels(self):\n", 266 | " return self.textspan.rels\n", 267 | " \n", 268 | " def abbreviate(self, max_len, ellipse_chars='$$'):\n", 269 | " if max_len <= len(ellipse_chars):\n", 270 | " return ''\n", 271 | " left_trim = (max_len - len(ellipse_chars)) // 2\n", 272 | " right_trim = max_len - len(ellipse_chars) - left_trim\n", 273 | " return self[:left_trim] + ellipse_chars + self[-right_trim:]\n", 274 | " \n", 275 | " def __getitem__(self, key):\n", 276 | " if isinstance(key, int):\n", 277 | " start, stop = key, key + 1\n", 278 | " elif isinstance(key, slice):\n", 279 | " start = key.start if key.start is not None else 0\n", 280 | " stop = key.stop if key.stop is not None else len(self.text)\n", 281 | " else:\n", 282 | " raise ValueError('parameter should be int or slice')\n", 283 | " if start < 0:\n", 284 | " start += len(self.text)\n", 285 | " if stop < 0:\n", 286 | " 
stop += len(self.text)\n",
287 | " offset = self.offset + start\n",
288 | " textspan = self.textspan[start: stop]\n",
289 | " return Sentence(self.doc_id, offset, textspan=textspan)\n",
290 | " \n",
291 | " def __gt__(self, other):\n",
292 | " return self.offset > other.offset\n",
293 | " \n",
294 | " def __add__(self, other):\n",
295 | " if isinstance(other, str):\n",
296 | " return Sentence(doc_id=self.doc_id, offset=self.offset, text=self.text + other, \n",
297 | " ents=self.ents, rels=self.rels)\n",
298 | " assert self.doc_id == other.doc_id, 'sentences should be from the same document'\n",
299 | " assert self.offset + len(self) <= other.offset, 'sentences should not have overlap'\n",
300 | " doc_id = self.doc_id\n",
301 | " text = self.text + other.text\n",
302 | " offset = self.offset\n",
303 | " ents = self.ents + other.ents.offset(len(self.text))\n",
304 | " rels = self.rels + other.rels.offset(len(self.text))\n",
305 | " return Sentence(doc_id=doc_id, offset=offset, text=text, ents=ents, rels=rels)\n",
306 | " \n",
307 | " def __len__(self):\n",
308 | " return len(self.textspan)\n",
309 | " \n",
310 | "\n",
311 | "class Document(object):\n",
312 | " def __init__(self, doc_id, text, ents, rels):\n",
313 | " self.doc_id = doc_id\n",
314 | " self.textspan = TextSpan(text, ents, rels)\n",
315 | " \n",
316 | " @property\n",
317 | " def text(self):\n",
318 | " return self.textspan.text\n",
319 | " \n",
320 | " @property\n",
321 | " def ents(self):\n",
322 | " return self.textspan.ents\n",
323 | " \n",
324 | " @property\n",
325 | " def rels(self):\n",
326 | " return self.textspan.rels\n",
327 | " \n",
328 | "\n",
329 | "class Documents(object):\n",
330 | " def __init__(self, data_dir, doc_ids=None):\n",
331 | " self.data_dir = data_dir\n",
332 | " self.doc_ids = doc_ids\n",
333 | " if self.doc_ids is None:\n",
334 | " self.doc_ids = self.scan_doc_ids()\n",
335 | " \n",
336 | " def scan_doc_ids(self):\n",
337 | " doc_ids = [fname.split('.')[0] for fname in os.listdir(self.data_dir)]\n",
338 | " doc_ids = [doc_id for doc_id in doc_ids if len(doc_id) > 0]\n",
339 | " return np.unique(doc_ids)\n",
340 | " \n",
341 | " def read_txt_file(self, doc_id):\n",
342 | " fname = os.path.join(self.data_dir, doc_id + '.txt')\n",
343 | " with open(fname, encoding='utf-8') as f:\n",
344 | " text = f.read()\n",
345 | " return text\n",
346 | " \n",
347 | " def parse_entity_line(self, raw_str):\n",
348 | " ent_id, label, text = raw_str.strip().split('\\t')\n",
349 | " category, pos = label.split(' ', 1)\n",
350 | " pos = pos.split(' ')\n",
351 | " ent = Entity(ent_id, category, int(pos[0]), int(pos[-1]), text)\n",
352 | " return ent\n",
353 | " \n",
354 | " def parse_relation_line(self, raw_str, ents):\n",
355 | " rel_id, label = raw_str.strip().split('\\t')\n",
356 | " category, arg1, arg2 = label.split(' ')\n",
357 | " arg1 = arg1.split(':')[1]\n",
358 | " arg2 = arg2.split(':')[1]\n",
359 | " ent1 = ents[arg1]\n",
360 | " ent2 = ents[arg2]\n",
361 | " return Relation(rel_id, category, ent1, ent2)\n",
362 | " \n",
363 | " def read_anno_file(self, doc_id):\n",
364 | " ents = []\n",
365 | " rels = []\n",
366 | " fname = os.path.join(self.data_dir, doc_id + '.ann')\n",
367 | " with open(fname, encoding='utf-8') as f:\n",
368 | " lines = f.readlines()\n",
369 | " \n",
370 | " for line in lines:\n",
371 | " if line.startswith('T'):\n",
372 | " ent = self.parse_entity_line(line)\n",
373 | " ents.append(ent)\n",
374 | " ents = Entities(ents)\n",
375 | " \n",
376 | " for line in lines:\n",
377 | " if 
line.startswith('R'):\n", 378 | " rel = self.parse_relation_line(line, ents)\n", 379 | " if rel.is_valid:\n", 380 | " rels.append(rel)\n", 381 | " rels = Relations(rels)\n", 382 | " return ents, rels\n", 383 | " \n", 384 | " def __len__(self):\n", 385 | " return len(self.doc_ids)\n", 386 | " \n", 387 | " def get_doc(self, doc_id):\n", 388 | " text = self.read_txt_file(doc_id)\n", 389 | " ents, rels = self.read_anno_file(doc_id)\n", 390 | " doc = Document(doc_id, text, ents, rels)\n", 391 | " return doc\n", 392 | " \n", 393 | " def __getitem__(self, key):\n", 394 | " if isinstance(key, int):\n", 395 | " doc_id = self.doc_ids[key]\n", 396 | " return self.get_doc(doc_id)\n", 397 | " if isinstance(key, str):\n", 398 | " doc_id = key\n", 399 | " return self.get_doc(doc_id)\n", 400 | " if isinstance(key, np.ndarray) and key.dtype == int:\n", 401 | " doc_ids = self.doc_ids[key]\n", 402 | " return Documents(self.data_dir, doc_ids=doc_ids)" 403 | ] 404 | }, 405 | { 406 | "cell_type": "code", 407 | "execution_count": 6, 408 | "metadata": {}, 409 | "outputs": [], 410 | "source": [ 411 | "class SentenceExtractor(object):\n", 412 | " def __init__(self, sent_split_char, window_size, rel_types, filter_no_rel_candidates_sents=True):\n", 413 | " self.sent_split_char = sent_split_char\n", 414 | " self.window_size = window_size\n", 415 | " self.filter_no_rel_candidates_sents = filter_no_rel_candidates_sents\n", 416 | " self.rels_type_set = set()\n", 417 | " for rel_type in rel_types:\n", 418 | " self.rels_type_set.add(tuple(re.split('[-_]', rel_type)))\n", 419 | " \n", 420 | " def get_sent_boundaries(self, text):\n", 421 | " dot_indices = []\n", 422 | " for i, ch in enumerate(text):\n", 423 | " if ch == self.sent_split_char:\n", 424 | " dot_indices.append(i + 1)\n", 425 | " \n", 426 | " if len(dot_indices) <= self.window_size - 1:\n", 427 | " return [(0, len(text))]\n", 428 | " \n", 429 | " dot_indices = [0] + dot_indices\n", 430 | " if text[-1] != self.sent_split_char:\n", 431 | " dot_indices += [len(text)]\n", 432 | "\n", 433 | " boundries = []\n", 434 | " for i in range(len(dot_indices) - self.window_size):\n", 435 | " start_stop = (\n", 436 | " dot_indices[i],\n", 437 | " dot_indices[i + self.window_size]\n", 438 | " )\n", 439 | " boundries.append(start_stop)\n", 440 | " return boundries\n", 441 | " \n", 442 | " def has_rels_candidates(self, ents):\n", 443 | " ent_cates = set([ent.category for ent in ents])\n", 444 | " for pos_rel in permutations(ent_cates, 2):\n", 445 | " if pos_rel in self.rels_type_set:\n", 446 | " return True\n", 447 | " return False\n", 448 | " \n", 449 | " def extract_doc(self, doc):\n", 450 | " sents = []\n", 451 | " for start_pos, end_pos in self.get_sent_boundaries(doc.text):\n", 452 | " ents = []\n", 453 | " sent_text = doc.text[start_pos: end_pos]\n", 454 | " for ent in doc.ents.find_entities(start_pos=start_pos, end_pos=end_pos):\n", 455 | " ents.append(ent.offset(-start_pos))\n", 456 | " if self.filter_no_rel_candidates_sents and not self.has_rels_candidates(ents):\n", 457 | " continue\n", 458 | " rels = []\n", 459 | " for rel in doc.rels.find_relations(start_pos=start_pos, end_pos=end_pos):\n", 460 | " rels.append(rel.offset(-start_pos))\n", 461 | " sent = Sentence(doc.doc_id, \n", 462 | " offset=start_pos, \n", 463 | " text=sent_text, \n", 464 | " ents=Entities(ents),\n", 465 | " rels=Relations(rels))\n", 466 | " sents.append(sent)\n", 467 | " return sents\n", 468 | " \n", 469 | " def __call__(self, docs):\n", 470 | " sents = []\n", 471 | " for doc in docs:\n", 472 | " 
sents += self.extract_doc(doc)\n",
473 | " return sents"
474 | ]
475 | },
476 | {
477 | "cell_type": "code",
478 | "execution_count": 7,
479 | "metadata": {},
480 | "outputs": [],
481 | "source": [
482 | "class EntityPair(object):\n",
483 | " def __init__(self, doc_id, sent, from_ent, to_ent):\n",
484 | " self.doc_id = doc_id\n",
485 | " self.sent = sent\n",
486 | " self.from_ent = from_ent\n",
487 | " self.to_ent = to_ent\n",
488 | " \n",
489 | " def __repr__(self):\n",
490 | " fmt = 'doc {}, sent {}, {} -> {}'\n",
491 | " return fmt.format(self.doc_id, self.sent.text, self.from_ent, self.to_ent)\n",
492 | "\n",
493 | "class EntityPairsExtractor(object):\n",
494 | " def __init__(self, allow_rel_types, max_len=150, ellipse_chars='$$', pad=10):\n",
495 | " self.allow_rel_types = allow_rel_types\n",
496 | " self.max_len = max_len\n",
497 | " self.pad = pad\n",
498 | " self.ellipse_chars = ellipse_chars\n",
499 | " \n",
500 | " def extract_candidate_rels(self, sent):\n",
501 | " candidate_rels = []\n",
502 | " for f_ent, t_ent in permutations(sent.ents, 2):\n",
503 | " rel_cate = (f_ent.category, t_ent.category)\n",
504 | " if rel_cate in self.allow_rel_types:\n",
505 | " candidate_rels.append((f_ent, t_ent))\n",
506 | " return candidate_rels\n",
507 | " \n",
508 | " def make_entity_pair(self, sent, f_ent, t_ent):\n",
509 | " doc_id = sent.doc_id\n",
510 | " if f_ent.start_pos < t_ent.start_pos:\n",
511 | " left_ent, right_ent = f_ent, t_ent\n",
512 | " else:\n",
513 | " left_ent, right_ent = t_ent, f_ent\n",
514 | " start_pos = max(0, left_ent.start_pos - self.pad)\n",
515 | " end_pos = min(len(sent), right_ent.end_pos + self.pad)\n",
516 | " res_sent = sent[start_pos: end_pos]\n",
517 | " \n",
518 | " if len(res_sent) > self.max_len:\n",
519 | " res_sent = res_sent.abbreviate(self.max_len)\n",
520 | " f_ent = res_sent.ents[f_ent.ent_id]\n",
521 | " t_ent = res_sent.ents[t_ent.ent_id]\n",
522 | " return EntityPair(doc_id, res_sent, f_ent, t_ent)\n",
523 | " \n",
524 | " def __call__(self, sents):\n",
525 | " samples = []\n",
526 | " for sent in sents:\n",
527 | " for f_ent, t_ent in self.extract_candidate_rels(sent.ents):\n",
528 | " entity_pair = self.make_entity_pair(sent, f_ent, t_ent)\n",
529 | " samples.append(entity_pair)\n",
530 | " return samples"
531 | ]
532 | },
533 | {
534 | "cell_type": "code",
535 | "execution_count": 8,
536 | "metadata": {},
537 | "outputs": [],
538 | "source": [
539 | "class Dataset(object):\n",
540 | " def __init__(self, entity_pairs, doc_ent_pair_ids=set(), word2idx=None, cate2idx=None, max_len=150):\n",
541 | " self.entity_pairs = entity_pairs\n",
542 | " self.doc_ent_pair_ids = doc_ent_pair_ids\n",
543 | " self.max_len = max_len\n",
544 | " self.word2idx = word2idx\n",
545 | " self.cate2idx = cate2idx\n",
546 | " \n",
547 | " def __len__(self):\n",
548 | " return len(self.entity_pairs)\n",
549 | " \n",
550 | " def build_vocab_dict(self, vocab_size=2000):\n",
551 | " counter = Counter()\n",
552 | " for ent_pair in self.entity_pairs:\n",
553 | " for char in ent_pair.sent.text:\n",
554 | " counter[char] += 1\n",
555 | " word2idx = dict()\n",
556 | " word2idx['<pad>'] = 0\n",
557 | " word2idx['<unk>'] = 1\n",
558 | " if vocab_size > 0:\n",
559 | " num_most_common = vocab_size - len(word2idx)\n",
560 | " else:\n",
561 | " num_most_common = len(counter)\n",
562 | " for char, _ in counter.most_common(num_most_common):\n",
563 | " word2idx[char] = word2idx.get(char, len(word2idx))\n",
564 | " self.word2idx = word2idx\n",
565 | " \n",
566 | " def vectorize(self, 
ent_pair):\n", 567 | " sent_vec = np.zeros(self.max_len, dtype='int')\n", 568 | " for i, c in enumerate(ent_pair.sent.text):\n", 569 | " sent_vec[i] = self.word2idx.get(c, 1) \n", 570 | " ents_vec = ent_pair.sent.ents.vectorize(vec_len=self.max_len, cate2idx=self.cate2idx)\n", 571 | " from_ent_vec = np.zeros(self.max_len, dtype='int')\n", 572 | " from_ent_vec[ent_pair.from_ent.start_pos: ent_pair.from_ent.end_pos] = 1\n", 573 | " to_ent_vec = np.zeros(self.max_len, dtype='int')\n", 574 | " to_ent_vec[ent_pair.to_ent.start_pos: ent_pair.to_ent.end_pos] = 1\n", 575 | " \n", 576 | " if (ent_pair.sent.doc_id, ent_pair.from_ent.ent_id, ent_pair.to_ent.ent_id) in self.doc_ent_pair_ids:\n", 577 | " label = 1\n", 578 | " else:\n", 579 | " label = 0\n", 580 | " return sent_vec, ents_vec, from_ent_vec, to_ent_vec, label\n", 581 | " \n", 582 | " def __getitem__(self, idx):\n", 583 | " sent_vecs, ents_vecs, from_ent_vecs, to_ent_vecs, ent_dists, labels = [], [], [], [], [], []\n", 584 | " entity_pairs = self.entity_pairs[idx]\n", 585 | " if not isinstance(entity_pairs, list):\n", 586 | " entity_pairs = [entity_pairs]\n", 587 | " for ent_pair in entity_pairs:\n", 588 | " sent_vec, ents_vec, from_ent_vec, to_ent_vec, label = self.vectorize(ent_pair)\n", 589 | " sent_vecs.append(sent_vec)\n", 590 | " ents_vecs.append(ents_vec)\n", 591 | " from_ent_vecs.append(from_ent_vec)\n", 592 | " to_ent_vecs.append(to_ent_vec)\n", 593 | " ent_dists.append(ent_pair.to_ent.start_pos - ent_pair.from_ent.end_pos)\n", 594 | " labels.append(label)\n", 595 | " \n", 596 | " sent_vecs = np.array(sent_vecs)\n", 597 | " ents_vecs = np.array(ents_vecs)\n", 598 | " from_ent_vecs = np.array(from_ent_vecs)\n", 599 | " to_ent_vecs = np.array(to_ent_vecs)\n", 600 | " ent_dists = np.array(ent_dists)\n", 601 | " labels = np.array(labels)\n", 602 | " \n", 603 | " return sent_vecs, ents_vecs, from_ent_vecs, to_ent_vecs, ent_dists, labels" 604 | ] 605 | }, 606 | { 607 | "cell_type": "code", 608 | "execution_count": 9, 609 | "metadata": {}, 610 | "outputs": [], 611 | "source": [ 612 | "def train_word_embeddings(entity_pairs, word2idx, *args, **kwargs):\n", 613 | " w2v_train_sents = []\n", 614 | " for ent_pair in entity_pairs:\n", 615 | " w2v_train_sents.append(list(ent_pair.sent.text))\n", 616 | " w2v_model = Word2Vec(w2v_train_sents, *args, **kwargs)\n", 617 | " word2idx.update({w: i for i, w in enumerate(w2v_model.wv.index2word, start=len(word2idx))})\n", 618 | " idx2word = {v: k for k, v in word2idx.items()}\n", 619 | " vocab_size = len(word2idx)\n", 620 | " w2v_embeddings = np.zeros((len(word2idx), w2v_model.vector_size))\n", 621 | " for char, char_idx in word2idx.items():\n", 622 | " if char in w2v_model.wv:\n", 623 | " w2v_embeddings[char_idx] = w2v_model.wv[char]\n", 624 | " return word2idx, idx2word, w2v_embeddings" 625 | ] 626 | }, 627 | { 628 | "cell_type": "markdown", 629 | "metadata": {}, 630 | "source": [ 631 | "## 二、生成训练与测试样本" 632 | ] 633 | }, 634 | { 635 | "cell_type": "code", 636 | "execution_count": 10, 637 | "metadata": {}, 638 | "outputs": [], 639 | "source": [ 640 | "all_rel_types = set([tuple(re.split('[-_]' ,rel)) for rel in RELATIONS])\n", 641 | "ent2idx = dict(zip(ENTITIES, range(1, len(ENTITIES) + 1)))" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 11, 647 | "metadata": {}, 648 | "outputs": [], 649 | "source": [ 650 | "train_docs = Documents(train_data_dir)\n", 651 | "test_docs = Documents(test_b_data_dir)" 652 | ] 653 | }, 654 | { 655 | "cell_type": "markdown", 656 | "metadata": {}, 
657 | "source": [ 658 | "### 1. 提取训练集中所有的 relation" 659 | ] 660 | }, 661 | { 662 | "cell_type": "code", 663 | "execution_count": 12, 664 | "metadata": {}, 665 | "outputs": [], 666 | "source": [ 667 | "doc_ent_pair_ids = set()\n", 668 | "for doc in train_docs:\n", 669 | " for rel in doc.rels:\n", 670 | " doc_ent_pair_id = (doc.doc_id, rel.ent1.ent_id, rel.ent2.ent_id)\n", 671 | " doc_ent_pair_ids.add(doc_ent_pair_id)" 672 | ] 673 | }, 674 | { 675 | "cell_type": "markdown", 676 | "metadata": {}, 677 | "source": [ 678 | "### 2. 从文档中抽取句子" 679 | ] 680 | }, 681 | { 682 | "cell_type": "code", 683 | "execution_count": 13, 684 | "metadata": {}, 685 | "outputs": [], 686 | "source": [ 687 | "sent_extractor = SentenceExtractor(sent_split_char='。', window_size=2, rel_types=RELATIONS, \n", 688 | " filter_no_rel_candidates_sents=True)\n", 689 | "train_sents = sent_extractor(train_docs)\n", 690 | "test_sents = sent_extractor(test_docs)" 691 | ] 692 | }, 693 | { 694 | "cell_type": "code", 695 | "execution_count": 16, 696 | "metadata": {}, 697 | "outputs": [ 698 | { 699 | "data": { 700 | "text/plain": [ 701 | "'中国成人2型糖尿病HBA1C c控制目标的专家共识\\n目前,2型糖尿病及其并发症已经成为危害公众\\n健康的主要疾病之一,控制血糖是延缓糖尿病进展及\\n其并发症发生的重要措施之一。虽然HBA1C 。'" 702 | ] 703 | }, 704 | "execution_count": 16, 705 | "metadata": {}, 706 | "output_type": "execute_result" 707 | } 708 | ], 709 | "source": [ 710 | "train_sents[0].text" 711 | ] 712 | }, 713 | { 714 | "cell_type": "markdown", 715 | "metadata": {}, 716 | "source": [ 717 | "### 3. 从句子中提取 relation 候选集" 718 | ] 719 | }, 720 | { 721 | "cell_type": "code", 722 | "execution_count": 14, 723 | "metadata": {}, 724 | "outputs": [], 725 | "source": [ 726 | "max_len = 150\n", 727 | "ent_pair_extractor = EntityPairsExtractor(all_rel_types, max_len=max_len)\n", 728 | "train_entity_pairs = ent_pair_extractor(train_sents)\n", 729 | "test_entity_pairs = ent_pair_extractor(test_sents)" 730 | ] 731 | }, 732 | { 733 | "cell_type": "markdown", 734 | "metadata": {}, 735 | "source": [ 736 | "### 4.利用候选 relation 所在的句子训练字符级别的字向量" 737 | ] 738 | }, 739 | { 740 | "cell_type": "code", 741 | "execution_count": 15, 742 | "metadata": {}, 743 | "outputs": [], 744 | "source": [ 745 | "word2idx = {'': 0, '': 1}\n", 746 | "word2idx, idx2word, w2v_embeddings = train_word_embeddings(\n", 747 | " entity_pairs=chain(train_entity_pairs, test_entity_pairs),\n", 748 | " word2idx=word2idx,\n", 749 | " size=100,\n", 750 | " iter=10\n", 751 | ")" 752 | ] 753 | }, 754 | { 755 | "cell_type": "markdown", 756 | "metadata": {}, 757 | "source": [ 758 | "### 5.生成训练集和测试集" 759 | ] 760 | }, 761 | { 762 | "cell_type": "code", 763 | "execution_count": 16, 764 | "metadata": {}, 765 | "outputs": [], 766 | "source": [ 767 | "train_data = Dataset(train_entity_pairs, doc_ent_pair_ids, word2idx=word2idx, max_len=max_len, cate2idx=ent2idx)\n", 768 | "test_data = Dataset(test_entity_pairs, word2idx=word2idx, max_len=max_len, cate2idx=ent2idx)" 769 | ] 770 | }, 771 | { 772 | "cell_type": "markdown", 773 | "metadata": {}, 774 | "source": [ 775 | "## 三、 建立深度学习模型" 776 | ] 777 | }, 778 | { 779 | "cell_type": "code", 780 | "execution_count": 17, 781 | "metadata": {}, 782 | "outputs": [ 783 | { 784 | "name": "stderr", 785 | "output_type": "stream", 786 | "text": [ 787 | "Using TensorFlow backend.\n" 788 | ] 789 | } 790 | ], 791 | "source": [ 792 | "from keras import layers\n", 793 | "from keras import backend as K\n", 794 | "from keras.layers import Input, Embedding, Lambda\n", 795 | "from keras.layers import Concatenate, Dense\n", 796 | "from keras.layers import Conv1D, 
MaxPool1D, Flatten\n", 797 | "from keras.models import Model" 798 | ] 799 | }, 800 | { 801 | "cell_type": "code", 802 | "execution_count": 18, 803 | "metadata": {}, 804 | "outputs": [], 805 | "source": [ 806 | "num_ent_classes = len(ENTITIES) + 1\n", 807 | "ent_emb_size = 2\n", 808 | "emb_size = w2v_embeddings.shape[-1]\n", 809 | "vocab_size = len(word2idx)\n", 810 | "\n", 811 | "\n", 812 | "def build_model():\n", 813 | " inp_sent = Input(shape=(max_len,), dtype='int32')\n", 814 | " inp_ent = Input(shape=(max_len,), dtype='int32')\n", 815 | " inp_f_ent = Input(shape=(max_len,), dtype='float32')\n", 816 | " inp_t_ent = Input(shape=(max_len,), dtype='float32')\n", 817 | " inp_ent_dist = Input(shape=(1,), dtype='float32')\n", 818 | " f_ent = Lambda(lambda x: K.expand_dims(x))(inp_f_ent)\n", 819 | " t_ent = Lambda(lambda x: K.expand_dims(x))(inp_t_ent)\n", 820 | " \n", 821 | " ent_embed = Embedding(num_ent_classes, ent_emb_size)(inp_ent)\n", 822 | " sent_embed = Embedding(vocab_size, emb_size, weights=[w2v_embeddings], trainable=False)(inp_sent)\n", 823 | " \n", 824 | " x = Concatenate()([sent_embed, ent_embed])\n", 825 | " x = Conv1D(64, 1, padding='same', activation='relu')(x)\n", 826 | " \n", 827 | " f_res = layers.multiply([f_ent, x])\n", 828 | " t_res = layers.multiply([t_ent, x])\n", 829 | " \n", 830 | " conv = Conv1D(64, 3, padding='same', activation='relu')\n", 831 | " f_x = conv(x)\n", 832 | " t_x = conv(x)\n", 833 | " f_x = layers.add([f_x, f_res])\n", 834 | " t_x = layers.add([t_x, t_res])\n", 835 | " \n", 836 | " f_res = layers.multiply([f_ent, f_x])\n", 837 | " t_res = layers.multiply([t_ent, t_x])\n", 838 | " conv = Conv1D(64, 3, padding='same', activation='relu')\n", 839 | " f_x = conv(x)\n", 840 | " t_x = conv(x)\n", 841 | " f_x = layers.add([f_x, f_res])\n", 842 | " t_x = layers.add([t_x, t_res])\n", 843 | "\n", 844 | " f_res = layers.multiply([f_ent, f_x])\n", 845 | " t_res = layers.multiply([t_ent, t_x])\n", 846 | " conv = Conv1D(64, 3, padding='same', activation='relu')\n", 847 | " f_x = conv(x)\n", 848 | " t_x = conv(x)\n", 849 | " f_x = layers.add([f_x, f_res])\n", 850 | " t_x = layers.add([t_x, t_res])\n", 851 | " \n", 852 | " conv = Conv1D(64, 3, activation='relu')\n", 853 | " f_x = MaxPool1D(3)(conv(f_x))\n", 854 | " t_x = MaxPool1D(3)(conv(t_x))\n", 855 | " \n", 856 | " conv = Conv1D(64, 3, activation='relu')\n", 857 | " f_x = MaxPool1D(3)(conv(f_x))\n", 858 | " t_x = MaxPool1D(3)(conv(t_x))\n", 859 | " \n", 860 | " f_x = Flatten()(f_x)\n", 861 | " t_x = Flatten()(t_x)\n", 862 | " \n", 863 | " x = Concatenate()([f_x, t_x, inp_ent_dist])\n", 864 | " x = Dense(256, activation='relu')(x)\n", 865 | " x = Dense(1, activation='sigmoid')(x)\n", 866 | "\n", 867 | " model = Model([inp_sent, inp_ent, inp_f_ent, inp_t_ent, inp_ent_dist], x)\n", 868 | " return model" 869 | ] 870 | }, 871 | { 872 | "cell_type": "markdown", 873 | "metadata": {}, 874 | "source": [ 875 | "### 1. 
模型训练" 876 | ] 877 | }, 878 | { 879 | "cell_type": "code", 880 | "execution_count": 19, 881 | "metadata": {}, 882 | "outputs": [], 883 | "source": [ 884 | "tr_sent, tr_ent, tr_f_ent, tr_t_ent, tr_ent_dist, tr_y = train_data[:]" 885 | ] 886 | }, 887 | { 888 | "cell_type": "code", 889 | "execution_count": 20, 890 | "metadata": {}, 891 | "outputs": [], 892 | "source": [ 893 | "K.clear_session()\n", 894 | "model = build_model()" 895 | ] 896 | }, 897 | { 898 | "cell_type": "code", 899 | "execution_count": 21, 900 | "metadata": {}, 901 | "outputs": [], 902 | "source": [ 903 | "model.compile('adam', loss='binary_crossentropy', metrics=['acc'])" 904 | ] 905 | }, 906 | { 907 | "cell_type": "code", 908 | "execution_count": 22, 909 | "metadata": {}, 910 | "outputs": [ 911 | { 912 | "name": "stdout", 913 | "output_type": "stream", 914 | "text": [ 915 | "Epoch 1/2\n", 916 | "553074/553074 [==============================] - 94s 170us/step - loss: 0.2427 - acc: 0.8959\n", 917 | "Epoch 2/2\n", 918 | "553074/553074 [==============================] - 89s 161us/step - loss: 0.1874 - acc: 0.9211\n" 919 | ] 920 | }, 921 | { 922 | "data": { 923 | "text/plain": [ 924 | "" 925 | ] 926 | }, 927 | "execution_count": 22, 928 | "metadata": {}, 929 | "output_type": "execute_result" 930 | } 931 | ], 932 | "source": [ 933 | "model.fit(x=[tr_sent, tr_ent, tr_f_ent, tr_t_ent, tr_ent_dist], \n", 934 | " y=tr_y, batch_size=64, epochs=2)" 935 | ] 936 | }, 937 | { 938 | "cell_type": "markdown", 939 | "metadata": {}, 940 | "source": [ 941 | "### 2.模型预测" 942 | ] 943 | }, 944 | { 945 | "cell_type": "code", 946 | "execution_count": 23, 947 | "metadata": {}, 948 | "outputs": [ 949 | { 950 | "name": "stdout", 951 | "output_type": "stream", 952 | "text": [ 953 | "63173/63173 [==============================] - 4s 69us/step\n" 954 | ] 955 | } 956 | ], 957 | "source": [ 958 | "te_sent, te_ent, te_f_ent, te_t_ent, te_ent_dist, te_y = test_data[:]\n", 959 | "preds = model.predict(x=[te_sent, te_ent, te_f_ent, te_t_ent, te_ent_dist], verbose=1)" 960 | ] 961 | }, 962 | { 963 | "cell_type": "markdown", 964 | "metadata": {}, 965 | "source": [ 966 | "## 四、结果输出" 967 | ] 968 | }, 969 | { 970 | "cell_type": "code", 971 | "execution_count": 24, 972 | "metadata": {}, 973 | "outputs": [], 974 | "source": [ 975 | "def generate_submission(preds, entity_pairs, threshold):\n", 976 | " doc_rels = defaultdict(set)\n", 977 | " for p, ent_pair in zip(preds, entity_pairs):\n", 978 | " if p >= threshold:\n", 979 | " doc_id = ent_pair.doc_id\n", 980 | " f_ent_id = ent_pair.from_ent.ent_id\n", 981 | " t_ent_id = ent_pair.to_ent.ent_id\n", 982 | " category = ent_pair.from_ent.category + '_' + ent_pair.to_ent.category\n", 983 | " category = category.replace('SideEff_Drug', 'SideEff-Drug')\n", 984 | " doc_rels[doc_id].add((f_ent_id, t_ent_id, category))\n", 985 | " submits = dict()\n", 986 | " tot_num_rels = 0\n", 987 | " for doc_id, rels in doc_rels.items():\n", 988 | " output_str = ''\n", 989 | " for i, rel in enumerate(rels):\n", 990 | " tot_num_rels += 1\n", 991 | " line = 'R{}\\t{} Arg1:{} Arg2:{}\\n'.format(i + 1, rel[2], rel[0], rel[1])\n", 992 | " output_str += line\n", 993 | " submits[doc_id] = output_str\n", 994 | " print('Total number of relations: {}. 
In average {} relations per doc.'.format(tot_num_rels, tot_num_rels / len(submits)))\n", 995 | " return submits\n", 996 | "\n", 997 | "def output_submission(submit_file, submits, test_dir):\n", 998 | " dir_name = os.path.basename(submit_file)\n", 999 | " dir_name, _ = os.path.splitext(dir_name)\n", 1000 | " with zipfile.ZipFile(submit_file, 'w') as zf:\n", 1001 | " for doc_id, rels_str in submits.items():\n", 1002 | " fname = '{}.ann'.format(doc_id)\n", 1003 | " test_file = os.path.join(test_dir, fname)\n", 1004 | " content = open(test_file, encoding='utf-8').read()\n", 1005 | " content += rels_str\n", 1006 | " zf.writestr(os.path.join(dir_name, fname), content)" 1007 | ] 1008 | }, 1009 | { 1010 | "cell_type": "code", 1011 | "execution_count": 25, 1012 | "metadata": {}, 1013 | "outputs": [ 1014 | { 1015 | "name": "stdout", 1016 | "output_type": "stream", 1017 | "text": [ 1018 | "Total number of relations: 9185. In average 170.09259259259258 relations per doc.\n" 1019 | ] 1020 | } 1021 | ], 1022 | "source": [ 1023 | "submits = generate_submission(preds, test_entity_pairs, 0.5)" 1024 | ] 1025 | }, 1026 | { 1027 | "cell_type": "code", 1028 | "execution_count": 26, 1029 | "metadata": {}, 1030 | "outputs": [], 1031 | "source": [ 1032 | "submit_file = 'submit_181205.zip'\n", 1033 | "output_submission(submit_file, submits, test_b_data_dir)" 1034 | ] 1035 | } 1036 | ], 1037 | "metadata": { 1038 | "kernelspec": { 1039 | "display_name": "Python 3", 1040 | "language": "python", 1041 | "name": "python3" 1042 | }, 1043 | "language_info": { 1044 | "codemirror_mode": { 1045 | "name": "ipython", 1046 | "version": 3 1047 | }, 1048 | "file_extension": ".py", 1049 | "mimetype": "text/x-python", 1050 | "name": "python", 1051 | "nbconvert_exporter": "python", 1052 | "pygments_lexer": "ipython3", 1053 | "version": "3.5.2" 1054 | } 1055 | }, 1056 | "nbformat": 4, 1057 | "nbformat_minor": 2 1058 | } 1059 | --------------------------------------------------------------------------------