├── .gitignore ├── README.md ├── aux ├── build.sbt ├── project │ ├── build.properties │ └── plugins.sbt ├── scripts │ ├── aida │ │ ├── dataset-split.py │ │ ├── get-aida-ontology.py │ │ ├── merge-tsvs.py │ │ └── resolve-token-index.py │ └── bbn │ │ ├── bbn-gen-dev.py │ │ ├── bbn-jsonlines-to-tsv.py │ │ └── bbn-ontology.py └── src │ └── main │ └── scala │ └── hiertype │ ├── GetHierarchy.scala │ └── PreprocessShimaokaData.scala ├── hiertype.tape ├── hiertype ├── __init__.py ├── commands │ ├── __init__.py │ ├── aggregate_metric.py │ ├── cache_repr.py │ ├── run.py │ └── train.py ├── contextualizers │ ├── __init__.py │ ├── contextualizer.py │ ├── contextualizer_test.py │ ├── elmo_contextualizer.py │ ├── get_contextualizer.py │ └── hugging_face_contextualizer.py ├── data │ ├── __init__.py │ ├── alphabet.py │ ├── bdb_storage.py │ ├── cached_mention_reader.py │ ├── hierarchy.py │ └── str_ndarray_bdb_storage.py ├── decoders │ ├── __init__.py │ ├── beam_decoder.py │ └── hierarchical_decoder.py ├── fields │ ├── __init__.py │ ├── int_field.py │ ├── real_field.py │ └── tensor_field.py ├── metrics │ ├── __init__.py │ ├── hierarchical_metric.py │ └── set_metric.py ├── models │ ├── __init__.py │ └── hierarchical_typer.py ├── modules │ ├── __init__.py │ ├── compl_ex.py │ ├── indexed_hinge_loss.py │ ├── mention_feature_extractor.py │ ├── relation_constraint_loss.py │ └── type_scorer.py ├── training │ ├── __init__.py │ └── my_trainer.py └── util │ ├── __init__.py │ ├── compact.py │ └── sample.py ├── requirements.txt └── tapes ├── aida-data.tape ├── bbn-data.tape ├── cache.tape ├── env.tape ├── params.tape ├── sge.tape ├── shimaoka-data.tape ├── test.tape └── train.tape /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### JetBrains template 3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 5 | 6 | # User-specific stuff 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | .idea/**/usage.statistics.xml 10 | .idea/**/dictionaries 11 | .idea/**/shelf 12 | 13 | # Generated files 14 | .idea/**/contentModel.xml 15 | 16 | # Sensitive or high-churn files 17 | .idea/**/dataSources/ 18 | .idea/**/dataSources.ids 19 | .idea/**/dataSources.local.xml 20 | .idea/**/sqlDataSources.xml 21 | .idea/**/dynamic.xml 22 | .idea/**/uiDesigner.xml 23 | .idea/**/dbnavigator.xml 24 | 25 | # Gradle 26 | .idea/**/gradle.xml 27 | .idea/**/libraries 28 | 29 | # Gradle and Maven with auto-import 30 | # When using Gradle or Maven with auto-import, you should exclude module files, 31 | # since they will be recreated, and may cause churn. Uncomment if using 32 | # auto-import. 
33 | # .idea/modules.xml 34 | # .idea/*.iml 35 | # .idea/modules 36 | # *.iml 37 | # *.ipr 38 | 39 | # CMake 40 | cmake-build-*/ 41 | 42 | # Mongo Explorer plugin 43 | .idea/**/mongoSettings.xml 44 | 45 | # File-based project format 46 | *.iws 47 | 48 | # IntelliJ 49 | out/ 50 | 51 | # mpeltonen/sbt-idea plugin 52 | .idea_modules/ 53 | 54 | # JIRA plugin 55 | atlassian-ide-plugin.xml 56 | 57 | # Cursive Clojure plugin 58 | .idea/replstate.xml 59 | 60 | # Crashlytics plugin (for Android Studio and IntelliJ) 61 | com_crashlytics_export_strings.xml 62 | crashlytics.properties 63 | crashlytics-build.properties 64 | fabric.properties 65 | 66 | # Editor-based Rest Client 67 | .idea/httpRequests 68 | 69 | # Android studio 3.1+ serialized cache file 70 | .idea/caches/build_file_checksums.ser 71 | 72 | ### Python template 73 | # Byte-compiled / optimized / DLL files 74 | __pycache__/ 75 | *.py[cod] 76 | *$py.class 77 | 78 | # C extensions 79 | *.so 80 | 81 | # Distribution / packaging 82 | .Python 83 | build/ 84 | develop-eggs/ 85 | dist/ 86 | downloads/ 87 | eggs/ 88 | .eggs/ 89 | lib/ 90 | lib64/ 91 | parts/ 92 | sdist/ 93 | var/ 94 | wheels/ 95 | pip-wheel-metadata/ 96 | share/python-wheels/ 97 | *.egg-info/ 98 | .installed.cfg 99 | *.egg 100 | MANIFEST 101 | 102 | # PyInstaller 103 | # Usually these files are written by a python script from a template 104 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 105 | *.manifest 106 | *.spec 107 | 108 | # Installer logs 109 | pip-log.txt 110 | pip-delete-this-directory.txt 111 | 112 | # Unit test / coverage reports 113 | htmlcov/ 114 | .tox/ 115 | .nox/ 116 | .coverage 117 | .coverage.* 118 | .cache 119 | nosetests.xml 120 | coverage.xml 121 | *.cover 122 | .hypothesis/ 123 | .pytest_cache/ 124 | 125 | # Translations 126 | *.mo 127 | *.pot 128 | 129 | # Django stuff: 130 | *.log 131 | local_settings.py 132 | db.sqlite3 133 | 134 | # Flask stuff: 135 | instance/ 136 | .webassets-cache 137 | 138 | # Scrapy stuff: 139 | .scrapy 140 | 141 | # Sphinx documentation 142 | docs/_build/ 143 | 144 | # PyBuilder 145 | target/ 146 | 147 | # Jupyter Notebook 148 | .ipynb_checkpoints 149 | 150 | # IPython 151 | profile_default/ 152 | ipython_config.py 153 | 154 | # pyenv 155 | .python-version 156 | 157 | # pipenv 158 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 159 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 160 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 161 | # install all needed dependencies. 
162 | #Pipfile.lock 163 | 164 | # celery beat schedule file 165 | celerybeat-schedule 166 | 167 | # SageMath parsed files 168 | *.sage.py 169 | 170 | # Environments 171 | .env 172 | .venv 173 | env/ 174 | venv/ 175 | ENV/ 176 | env.bak/ 177 | venv.bak/ 178 | 179 | # Spyder project settings 180 | .spyderproject 181 | .spyproject 182 | 183 | # Rope project settings 184 | .ropeproject 185 | 186 | # mkdocs documentation 187 | /site 188 | 189 | # mypy 190 | .mypy_cache/ 191 | .dmypy.json 192 | dmypy.json 193 | 194 | # Pyre type checker 195 | .pyre/ 196 | 197 | # Script output 198 | aux/target 199 | aux/project/target 200 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | This repository contains code for the following paper: 2 | - Tongfei Chen, Yunmo Chen, Benjamin Van Durme (2020): [Hierarchical Entity Typing via Multi-level Learning to Rank](https://www.aclweb.org/anthology/2020.acl-main.749/). In _Proceedings of ACL_. 3 | 4 | ```bibtex 5 | @inproceedings{ChenCD20, 6 | author = {Tongfei Chen and 7 | Yunmo Chen and 8 | Benjamin {Van Durme}}, 9 | title = {Hierarchical Entity Typing via Multi-level Learning to Rank}, 10 | booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational 11 | Linguistics, {ACL} 2020, Online, July 5-10, 2020}, 12 | pages = {8465--8475}, 13 | year = {2020}, 14 | url = {https://www.aclweb.org/anthology/2020.acl-main.749/} 15 | } 16 | ``` 17 | 18 | ### Setup 19 | 20 | This repository uses [Ducttape](https://github.com/jhclark/ducttape) to manage intermediate results 21 | of the experiment pipeline. 22 | 23 | To run a portion of the pipeline, first clone this repository, then edit the paths in `tapes/env.tape` 24 | so that they point to your local copies of the relevant datasets and packages. 25 | 26 | Then use the following command: 27 | 28 | ```bash 29 | ducttape hiertype.tape -p <plan> 30 | ``` 31 | where `<plan>` is any of the plans defined in the Ducttape scripts. 32 | 33 | One can easily execute different tasks by modifying the plans in the tape files.
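A plan names the tasks (and the branch-point values) to execute; a minimal sketch, with hypothetical task and branch names (the actual ones are defined in `tapes/*.tape`), looks like:

```
plan train_bbn {
  reach Train via (Dataset: bbn)
}
```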
34 | -------------------------------------------------------------------------------- /aux/build.sbt: -------------------------------------------------------------------------------- 1 | name := "hiertype-aux" 2 | 3 | version := "0.1" 4 | 5 | scalaVersion := "2.12.9" 6 | 7 | libraryDependencies ++= Seq( 8 | "me.tongfei" %% "poly-io" % "0.3.2", 9 | "me.tongfei" % "progressbar" % "0.7.4" 10 | ) 11 | -------------------------------------------------------------------------------- /aux/project/build.properties: -------------------------------------------------------------------------------- 1 | sbt.version = 1.2.8 -------------------------------------------------------------------------------- /aux/project/plugins.sbt: -------------------------------------------------------------------------------- 1 | addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.10") 2 | -------------------------------------------------------------------------------- /aux/scripts/aida/dataset-split.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import random 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--all", type=str, default="") 7 | parser.add_argument("--train", type=str, default="") 8 | parser.add_argument("--dev", type=str, default="") 9 | parser.add_argument("--test", type=str, default="") 10 | parser.add_argument("--seed", type=int, default=42) 11 | ARGS = parser.parse_args() 12 | 13 | random.seed(ARGS.seed) 14 | k = 0.8 # fraction of documents used for training 15 | k1 = 0.9 # train + dev fraction; the remainder is the test set 16 | data = itertools.groupby(open(ARGS.all), key=lambda l: l.split(':')[0]) 17 | data = list(map(lambda t: (t[0], list(map(lambda s: s.strip(), t[1]))), data)) 18 | random.shuffle(data) 19 | 20 | n = len(data) 21 | 22 | out_train = open(ARGS.train, mode='w') 23 | out_dev = open(ARGS.dev, mode='w') 24 | out_test = open(ARGS.test, mode='w') 25 | 26 | for _, lines in data[0:int(k * n)]: 27 | for l in lines: 28 | print(l, file=out_train) 29 | for _, lines in data[int(k * n):int(k1 * n)]: 30 | for l in lines: 31 | print(l, file=out_dev) 32 | for _, lines in data[int(k1 * n):n]: 33 | for l in lines: 34 | print(l, file=out_test) 35 | 36 | out_train.close() 37 | out_dev.close() 38 | out_test.close() 39 | -------------------------------------------------------------------------------- /aux/scripts/aida/get-aida-ontology.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import argparse 3 | 4 | parser = argparse.ArgumentParser(description="Extract AIDA ontology from Excel spreadsheet") 5 | parser.add_argument("--sheet", type=str, default="", help="{events / entities}") 6 | parser.add_argument("--path", type=str, default="", help="Path to the LDC resource directory") 7 | ARGS = parser.parse_args() 8 | 9 | sheet = pd.read_excel(f"{ARGS.path}/docs/LDC_AIDAAnnotationOntologyWithMapping_V8.xlsx", sheet_name=ARGS.sheet) 10 | 11 | ldc_ontology = ( 12 | ( 13 | row['AnnotIndexID'], 14 | row['Output Value for Type'], 15 | row['Output Value for Subtype'], 16 | row['Output Value for Sub-subtype'] if 'Output Value for Sub-subtype' in row else row['Output Value for Sub-Subtype'], 17 | row['Definition'] 18 | ) 19 | for i, row in sheet.iterrows() if row['AnnotIndexID'].startswith('LDC_') 20 | ) 21 | 22 | for cid, ct1, ct2, ct3, cdef in ldc_ontology: 23 | if ct2 == "unspecified": 24 | print(f"/{ct1}") 25 | elif ct3 == "unspecified": 26 | print(f"/{ct1}/{ct2}") 27 | else: 28 | print(f"/{ct1}/{ct2}/{ct3}") 29 | 
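This script emits each ontology node as a slash-delimited path, truncated at the first "unspecified" level. A minimal sketch of that convention (the type names below are made up for illustration):

```python
def type_path(t1: str, t2: str, t3: str) -> str:
    # An "unspecified" level truncates the path there, as in the script above.
    if t2 == "unspecified":
        return f"/{t1}"
    if t3 == "unspecified":
        return f"/{t1}/{t2}"
    return f"/{t1}/{t2}/{t3}"


assert type_path("Conflict", "Attack", "unspecified") == "/Conflict/Attack"
```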
-------------------------------------------------------------------------------- /aux/scripts/aida/merge-tsvs.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import sys 3 | 4 | files = sys.argv[1:] 5 | 6 | rows = pd.concat( 7 | pd.read_csv(f, delimiter='\t') 8 | for f in files 9 | ) 10 | 11 | 12 | def get_type(t1: str, t2: str, t3: str) -> str: 13 | assert t1 != "unspecified" 14 | if t2 == "unspecified": 15 | return f"/{t1}" 16 | elif t3 == "unspecified": 17 | return f"/{t1}/{t2}" 18 | else: 19 | return f"/{t1}/{t2}/{t3}" 20 | 21 | 22 | for _, r in rows.iterrows(): 23 | if r['text_string'] != "EMPTY_NA": 24 | t1 = r['type'] 25 | t2 = r['subtype'] 26 | t3 = r['subsubtype'] 27 | 28 | print(f"{r['child_uid']}:{r['textoffset_startchar']}:{r['textoffset_endchar']}\t{r['text_string']}\t{get_type(t1, t2, t3)}") 29 | -------------------------------------------------------------------------------- /aux/scripts/aida/resolve-token-index.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import concrete 3 | import sys 4 | import concrete.util.file_io as cio 5 | import numpy as np 6 | import argparse 7 | import csv 8 | import itertools 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--tsv", type=str, default="", help="") 12 | parser.add_argument("--concrete_dir", type=str, default="", help="") 13 | parser.add_argument("--lang", type=str, default="eng") 14 | ARGS = parser.parse_args() 15 | 16 | tsv = csv.reader(open(ARGS.tsv), delimiter='\t') 17 | 18 | for comm_id, rows in itertools.groupby(tsv, key=lambda r: r[0].split(':')[0]): 19 | 20 | try: 21 | comm = cio.read_communication_from_file(f"{ARGS.concrete_dir}/{comm_id}.comm") 22 | 23 | # remove non-English documents 24 | lang_dist = comm.lidList[0].languageToProbabilityMap 25 | lang = max(lang_dist.items(), key=lambda t: t[1])[0] 26 | if ARGS.lang != "all" and lang != ARGS.lang: # if ARGS.lang == "all", retain all language samples 27 | continue 28 | 29 | sentences = [ 30 | sentence 31 | for section in comm.sectionList 32 | for sentence in section.sentenceList 33 | ] 34 | 35 | sentence_indices: np.ndarray = np.array([sentence.textSpan.start for sentence in sentences]) 36 | token_indices: List[np.ndarray] = [ 37 | np.array([token.textSpan.start for token in sentence.tokenization.tokenList.tokenList]) 38 | for sentence in sentences 39 | ] 40 | 41 | for row in rows: 42 | _, lidx, ridx = row[0].split(":") 43 | text = row[1] 44 | label = row[2] 45 | lidx = int(lidx) 46 | ridx = int(ridx) 47 | sentence_index = np.digitize([lidx], sentence_indices).item() - 1 48 | left_token_index = np.digitize([lidx], token_indices[sentence_index]).item() - 1 49 | right_token_index = np.digitize([ridx], token_indices[sentence_index]).item() 50 | 51 | sentence = ' '.join(t.text for t in sentences[sentence_index].tokenization.tokenList.tokenList) 52 | 53 | print(f"{text}\t{' '.join(map(lambda t: t.text, sentences[sentence_index].tokenization.tokenList.tokenList[left_token_index:right_token_index]))}", file=sys.stderr) 54 | print(f"{sentence}\t{left_token_index}:{right_token_index}\t{label}") 55 | 56 | except FileNotFoundError: 57 | pass # some communications are not found in the LTF files 58 | -------------------------------------------------------------------------------- /aux/scripts/bbn/bbn-gen-dev.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import random 3 | 4 | 5 | 
random.seed(0xCAFEBABE) 6 | 7 | dev_size = 2000 8 | 9 | lines = [l.strip() for l in open(sys.argv[1])] 10 | indices = set(random.sample(range(len(lines)), dev_size)) 11 | 12 | train = open(sys.argv[2], mode='w') 13 | dev = open(sys.argv[3], mode='w') 14 | 15 | for i, l in enumerate(lines): 16 | if i in indices: 17 | print(l, file=dev) 18 | else: 19 | print(l, file=train) 20 | 21 | train.close() 22 | dev.close() 23 | -------------------------------------------------------------------------------- /aux/scripts/bbn/bbn-jsonlines-to-tsv.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | 5 | def normalize(t: str) -> str: 6 | if t == "``": 7 | return "\"" 8 | elif t == "''": 9 | return "\"" 10 | else: 11 | return t 12 | 13 | 14 | for line in sys.stdin: 15 | j = json.loads(line) 16 | sentence = ' '.join(normalize(t) for t in j["tokens"]) 17 | 18 | for m in j["mentions"]: 19 | left = m["start"] 20 | right = m["end"] 21 | types = [t.lower() for t in m["labels"]] 22 | 23 | print(f"{sentence}\t{left}:{right}\t{' '.join(types)}") 24 | -------------------------------------------------------------------------------- /aux/scripts/bbn/bbn-ontology.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | all_types = set() 4 | 5 | for line in sys.stdin: 6 | _, _, types = line.strip().split('\t') 7 | for t in types.split(' '): 8 | all_types.add(t) 9 | 10 | for t in sorted(list(all_types)): 11 | print(t) 12 | -------------------------------------------------------------------------------- /aux/src/main/scala/hiertype/GetHierarchy.scala: -------------------------------------------------------------------------------- 1 | package hiertype 2 | 3 | import scala.collection._ 4 | import scala.collection.JavaConverters._ 5 | 6 | import me.tongfei.progressbar._ 7 | import poly.io.Local._ 8 | 9 | object GetHierarchy extends App { 10 | 11 | val types = mutable.HashSet[String]() 12 | 13 | val path = args(0) 14 | for (line <- ProgressBar.wrap(File(path).lines.view.asJava, "Processing").asScala) { 15 | val Array(_, _, strTypes) = line.split("\t", 3) 16 | val ts = strTypes.split(" ") 17 | ts foreach types.add 18 | } 19 | 20 | types.toArray.sorted foreach println 21 | 22 | } 23 | -------------------------------------------------------------------------------- /aux/src/main/scala/hiertype/PreprocessShimaokaData.scala: -------------------------------------------------------------------------------- 1 | package hiertype 2 | 3 | import scala.collection.JavaConverters._ 4 | 5 | import poly.io.Local._ 6 | import me.tongfei.progressbar._ 7 | 8 | object PreprocessShimaokaData extends App { 9 | 10 | def normalize(s: String) = s match { 11 | case "''" => "\"" 12 | case "``" => "\"" 13 | case "-LRB-" => "(" 14 | case "-RRB-" => ")" 15 | case "-LSB-" => "[" 16 | case "-RSB-" => "]" 17 | case "-LCB-" => "{" 18 | case "-RCB-" => "}" 19 | case _ => s 20 | } 21 | 22 | val path = args(0) 23 | for (line <- ProgressBar.wrap(File(path).lines.view.asJava, "Preprocessing").asScala) { 24 | val Array(strL, strR, strSentence, strTypes, _*) = line.split("\t", 5) 25 | val l = strL.toInt 26 | val r = strR.toInt 27 | val s = strSentence.split(" ").map(normalize) 28 | val types = strTypes.split(" ").foldLeft(Set[String]()) { (ts, t) => 29 | if (ts.exists(t.startsWith)) ts.filterNot(t.startsWith) + t 30 | else if (ts.exists(_ startsWith t)) ts 31 | else ts + t 32 | } 33 | 34 | println(s"${s.mkString(" 
").trim}\t${l}:${r}\t${types.mkString(" ")}") 35 | } 36 | 37 | } 38 | -------------------------------------------------------------------------------- /hiertype.tape: -------------------------------------------------------------------------------- 1 | global { 2 | ducttape_output="out" 3 | ducttape_experimental_packages=true 4 | ducttape_experimental_submitters=true 5 | ducttape_experimental_imports=true 6 | ducttape_experimental_multiproc=true 7 | } 8 | 9 | import "tapes/sge.tape" 10 | import "tapes/env.tape" 11 | 12 | import "tapes/params.tape" 13 | 14 | import "tapes/shimaoka-data.tape" 15 | import "tapes/aida-data.tape" 16 | import "tapes/bbn-data.tape" 17 | 18 | import "tapes/cache.tape" 19 | import "tapes/train.tape" 20 | import "tapes/test.tape" 21 | -------------------------------------------------------------------------------- /hiertype/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ctongfei/hierarchical-typing/fc85a8dbbe03c89f51778c09239e732564d2c33b/hiertype/__init__.py -------------------------------------------------------------------------------- /hiertype/commands/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ctongfei/hierarchical-typing/fc85a8dbbe03c89f51778c09239e732564d2c33b/hiertype/commands/__init__.py -------------------------------------------------------------------------------- /hiertype/commands/aggregate_metric.py: -------------------------------------------------------------------------------- 1 | import typing 2 | import os 3 | import sys 4 | import re 5 | import json 6 | from colors import blue 7 | 8 | 9 | metric = sys.argv[1] 10 | glob = sys.argv[2:] 11 | 12 | r = re.compile(r"metrics_epoch_(.*)\.json") 13 | 14 | for path in glob: 15 | try: 16 | files = { 17 | int(r.findall(f)[0]): f 18 | for f in os.listdir(f"{path}/out/") if r.match(f) 19 | } 20 | 21 | max_epoch = max(files.keys()) 22 | j = json.load(open(f"{path}/out/metrics_epoch_{max_epoch}.json")) 23 | 24 | print(f"{blue(path)}: {j[f'best_validation_{metric}']}") 25 | 26 | except: 27 | pass 28 | -------------------------------------------------------------------------------- /hiertype/commands/cache_repr.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | from bisect import bisect_left, bisect_right 4 | from tqdm import tqdm 5 | import numpy as np 6 | from fire import Fire 7 | import sys 8 | from colors import blue 9 | 10 | from hiertype.contextualizers import Contextualizer, get_contextualizer 11 | from hiertype.data import StringNdArrayBerkeleyDBStorage 12 | 13 | T = TypeVar('T', covariant=True) 14 | 15 | 16 | def get_spans(lines: Iterator[str]) -> Iterator[Tuple[List[str], int, int]]: 17 | for line in lines: 18 | sentence, span, *_ = line.split('\t') 19 | tokens = sentence.split(' ') 20 | l_str, r_str = span.split(':') 21 | l = int(l_str) 22 | r = int(r_str) 23 | yield (tokens, l, r) 24 | 25 | 26 | def batched(xs: Iterator[T], batch_size: int = 128) -> Iterator[List[T]]: 27 | buf = [] 28 | for x in xs: 29 | buf.append(x) 30 | if len(buf) == batch_size: 31 | yield buf 32 | buf = [] 33 | if len(buf) > 0: 34 | yield buf 35 | 36 | 37 | def select_embeddings( 38 | encoded: torch.Tensor, # R[Batch, Layer, Word, Emb] 39 | mappings: List[List[int]], 40 | layers: List[int], 41 | unit: str = "subword" 42 | ) -> Iterator[torch.Tensor]: # R[Layer, Word, Emb] 43 | 44 | bsz = 
encoded.size(0) 45 | 46 | for i in range(bsz): 47 | x = encoded[i, layers, :, :] # R [Layer, Word, Emb] 48 | mapping = mappings[i] 49 | word_l = min(mapping) 50 | word_r = max(mapping) + 1 51 | subword_l = bisect_left(mapping, word_l) 52 | subword_r = bisect_right(mapping, word_r - 1) 53 | if subword_r == subword_l: 54 | subword_r += 1 55 | 56 | if unit == "subword": 57 | yield x[:, subword_l:subword_r, :] 58 | elif unit == "word": 59 | yield torch.cat( 60 | [ 61 | x[:, bisect_left(mapping, j):bisect_right(mapping, j), :].mean(dim=1, keepdim=True) 62 | for j in range(word_l, word_r) 63 | ], 64 | dim=1 65 | ) 66 | else: 67 | raise AssertionError("`unit` must be either `word` or `subword`") 68 | 69 | 70 | def main(*, 71 | input: str, 72 | output: str, 73 | model: str = "elmo-original", 74 | unit: str = "subword", 75 | batch_size: int = 64, 76 | layers: List[int], 77 | gpuid: int = 0 78 | ): 79 | 80 | for k, v in reversed(list(locals().items())): # seems that `locals()` stores the args in reverse order 81 | print(f"{blue('--' + k)} \"{v}\"", file=sys.stderr) 82 | 83 | if gpuid >= 0: 84 | torch.cuda.set_device(gpuid) 85 | 86 | contextualizer: Contextualizer = get_contextualizer( 87 | model, 88 | device="cpu" if gpuid < 0 else f"cuda:{gpuid}", 89 | tokenizer_only=False 90 | ) 91 | dump = StringNdArrayBerkeleyDBStorage.open(output, mode='w') 92 | 93 | lines: Iterator[str] = tqdm(open(input, mode='r')) 94 | spans: Iterator[Tuple[List[str], int, int]] = get_spans(lines) 95 | 96 | i = 0 97 | for batch in batched(spans, batch_size=batch_size): 98 | sentences, ls, rs = zip(*batch) 99 | 100 | tokenized_sentences, mappings = zip(*[ 101 | contextualizer.tokenize_with_mapping(sentence) 102 | for sentence in sentences 103 | ]) 104 | encoded = contextualizer.encode(tokenized_sentences, frozen=True) 105 | 106 | for emb in select_embeddings(encoded, mappings, layers, unit): 107 | x: np.ndarray = emb.detach().cpu().numpy() 108 | dump[str(i)] = x.astype(np.float32) 109 | i += 1 110 | 111 | dump.close() 112 | print("Job complete.", file=sys.stderr) 113 | 114 | 115 | if __name__ == "__main__": 116 | Fire(main) 117 | -------------------------------------------------------------------------------- /hiertype/commands/run.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import argparse 3 | import logging 4 | import json 5 | import sys 6 | import torch 7 | import tqdm 8 | from allennlp.data.iterators import BasicIterator 9 | from hiertype.data import Hierarchy, CachedMentionReader 10 | from hiertype.models import HierarchicalTyper 11 | from hiertype.modules import MentionFeatureExtractor, TypeScorer 12 | from hiertype.decoders import BeamDecoder 13 | from fire import Fire 14 | 15 | 16 | def main(*, 17 | model: str, 18 | model_file: str = "best.th", 19 | test: str, 20 | out: str, 21 | max_branching_factors: List[int], 22 | delta: List[float], 23 | strategies: List[str], 24 | other_delta: float = 0.0, 25 | seed: int = 0xDEADBEEF, 26 | batch_size: int = 256, 27 | gpuid: int = 0 28 | ): 29 | TEST_ARGS = argparse.Namespace(**locals().copy()) 30 | ARGS = argparse.Namespace(**json.load(open(f"{TEST_ARGS.model}/args.json", mode='r'))) 31 | 32 | for key, val in ARGS.__dict__.items(): 33 | print(f"ARG {key}: {val}", file=sys.stderr) 34 | for key, val in TEST_ARGS.__dict__.items(): 35 | print(f"TEST_ARG {key}: {val}", file=sys.stderr) 36 | 37 | torch.cuda.set_device(gpuid) 38 | torch.manual_seed(seed) 39 | 40 | if TEST_ARGS.max_branching_factors is None: 41 | 
TEST_ARGS.max_branching_factors = ARGS.max_branching_factors 42 | if TEST_ARGS.delta is None: 43 | TEST_ARGS.delta = ARGS.delta 44 | if TEST_ARGS.strategies is None: 45 | TEST_ARGS.strategies = ARGS.strategies 46 | 47 | hierarchy: Hierarchy = Hierarchy.from_tree_file(filename=ARGS.ontology, with_other=ARGS.with_other) 48 | 49 | model = HierarchicalTyper( 50 | hierarchy=hierarchy, 51 | input_dim=ARGS.input_dim, 52 | type_dim=ARGS.type_dim, 53 | bottleneck_dim=ARGS.bottleneck_dim, 54 | mention_pooling=ARGS.mention_pooling, 55 | with_context=True, 56 | dropout_rate=ARGS.dropout_rate, 57 | emb_dropout_rate=ARGS.emb_dropout_rate, 58 | margins_per_level=ARGS.margins, 59 | num_negative_samples=ARGS.num_negative_samples, 60 | threshold_ratio=ARGS.threshold_ratio, 61 | lift_other=ARGS.lift_other, 62 | relation_constraint_coef=ARGS.relation_constraint_coef, 63 | compute_metric_when_training=True, 64 | decoder=BeamDecoder( 65 | hierarchy=hierarchy, 66 | strategies=TEST_ARGS.strategies, 67 | max_branching_factors=TEST_ARGS.max_branching_factors, 68 | delta=TEST_ARGS.delta, 69 | top_other_delta=TEST_ARGS.other_delta 70 | ) 71 | ) 72 | 73 | model_state = torch.load(f"{ARGS.out}/{TEST_ARGS.model_file}", map_location=lambda storage, loc: storage) 74 | model.load_state_dict(model_state) 75 | model.cuda() 76 | model.eval() 77 | 78 | model.metric.set_serialization_dir(TEST_ARGS.out) 79 | print("Model loaded.", file=sys.stderr) 80 | 81 | test_reader = CachedMentionReader(hierarchy=hierarchy, model=ARGS.contextualizer) 82 | iterator = BasicIterator(batch_size=TEST_ARGS.batch_size) 83 | 84 | with torch.no_grad(): 85 | for batch in tqdm.tqdm(iterator(instances=test_reader.read(TEST_ARGS.test), num_epochs=1, shuffle=False)): 86 | for k, v in batch.items(): 87 | if hasattr(v, 'cuda'): 88 | batch[k] = v.cuda() 89 | model(**batch) 90 | 91 | for m, v in model.metric.get_metric(reset=False).items(): 92 | print(f"METRIC {m}: {v}") 93 | 94 | 95 | Fire(main) 96 | -------------------------------------------------------------------------------- /hiertype/commands/train.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import argparse 3 | import logging 4 | import torch 5 | import json 6 | import sys 7 | import numpy as np 8 | import random 9 | from colors import blue 10 | from fire import Fire 11 | 12 | from torch.optim.optimizer import Optimizer 13 | from torch.optim.adamw import AdamW 14 | from allennlp.data.iterators import BasicIterator 15 | 16 | from hiertype.training import MyTrainer 17 | from hiertype.data import Hierarchy, CachedMentionReader 18 | from hiertype.decoders import BeamDecoder 19 | from hiertype.models import HierarchicalTyper 20 | 21 | 22 | def main(*, 23 | ontology: str, 24 | train: str, 25 | dev: str, 26 | out: str, 27 | 28 | contextualizer: str = "elmo-original", 29 | input_dim: int = 3072, 30 | type_dim: int = 1024, 31 | bottleneck_dim: int = 0, 32 | with_other: bool = False, 33 | lift_other: bool = False, 34 | mention_pooling: str = "max", 35 | emb_dropout_rate: float = 0.3, 36 | dropout_rate: float = 0.3, 37 | 38 | margins: List[float] = [], 39 | threshold_ratio: float = 0.1, 40 | relation_constraint_coef: float = 0.1, 41 | num_negative_samples: int = 0, 42 | 43 | max_branching_factors: List[int] = [], 44 | delta: List[float] = [], 45 | strategies: List[str] = [], 46 | 47 | seed: int = 0xDEADBEEF, 48 | batch_size: int = 256, 49 | dev_batch_size: int = 256, 50 | num_epochs: int = 5, 51 | dev_metric: str = "+O_MiF", 52 | patience: int 
= 4, 53 | lr: float = 1e-5, 54 | regularizer: float = 0.1, 55 | gpuid: int = 0 56 | ): 57 | 58 | args = locals().copy() 59 | 60 | with open(f"{out}/args.json", mode='w') as args_out: 61 | for k, v in reversed(list(args.items())): # seems that `locals()` stores the args in reverse order 62 | print(f"{blue('--' + k)} \"{v}\"", file=sys.stderr) 63 | print(json.dumps(args, indent=2), file=args_out) 64 | 65 | torch.cuda.set_device(gpuid) 66 | 67 | # Ensure deterministic behavior 68 | torch.manual_seed(seed) 69 | np.random.seed(seed) 70 | random.seed(seed) 71 | torch.backends.cudnn.deterministic = True 72 | torch.backends.cudnn.benchmark = False 73 | 74 | logging.basicConfig(level=logging.INFO) 75 | 76 | hierarchy: Hierarchy = Hierarchy.from_tree_file(ontology, with_other=with_other) 77 | print(hierarchy, file=sys.stderr) 78 | 79 | reader = CachedMentionReader(hierarchy, model=contextualizer) 80 | 81 | model = HierarchicalTyper( 82 | hierarchy=hierarchy, 83 | input_dim=input_dim, 84 | type_dim=type_dim, 85 | bottleneck_dim=bottleneck_dim, 86 | mention_pooling=mention_pooling, 87 | with_context=True, 88 | dropout_rate=dropout_rate, 89 | emb_dropout_rate=emb_dropout_rate, 90 | margins_per_level=margins, 91 | num_negative_samples=num_negative_samples, 92 | threshold_ratio=threshold_ratio, 93 | relation_constraint_coef=relation_constraint_coef, 94 | lift_other=lift_other, 95 | compute_metric_when_training=True, 96 | decoder=BeamDecoder( 97 | hierarchy=hierarchy, 98 | strategies=strategies, 99 | max_branching_factors=max_branching_factors, 100 | delta=delta 101 | ) 102 | ) 103 | model.cuda() 104 | 105 | optimizer: Optimizer = AdamW( 106 | params=model.parameters(), 107 | lr=lr, 108 | weight_decay=regularizer 109 | ) 110 | 111 | trainer = MyTrainer( 112 | model=model, 113 | optimizer=optimizer, 114 | iterator=BasicIterator(batch_size=batch_size), 115 | validation_iterator=BasicIterator(batch_size=dev_batch_size), 116 | train_dataset=reader.read(train), 117 | validation_dataset=reader.read(dev), 118 | validation_metric=dev_metric, 119 | patience=patience, 120 | num_epochs=num_epochs, 121 | grad_norm=1.0, 122 | serialization_dir=out, 123 | num_serialized_models_to_keep=1, 124 | cuda_device=gpuid 125 | ) 126 | 127 | model.set_trainer(trainer) 128 | model.metric.set_serialization_dir(trainer._serialization_dir) 129 | # hook into the trainer to set the metric serialization path 130 | trainer.train() 131 | 132 | 133 | if __name__ == "__main__": 134 | Fire(main) 135 | -------------------------------------------------------------------------------- /hiertype/contextualizers/__init__.py: -------------------------------------------------------------------------------- 1 | from hiertype.contextualizers.contextualizer import Contextualizer 2 | from hiertype.contextualizers.hugging_face_contextualizer import HuggingFaceContextualizer, \ 3 | BERTContextualizer, XLMRobertaContextualizer 4 | from hiertype.contextualizers.elmo_contextualizer import ELMoContextualizer 5 | 6 | from hiertype.contextualizers.get_contextualizer import get_contextualizer 7 | -------------------------------------------------------------------------------- /hiertype/contextualizers/contextualizer.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from abc import abstractmethod 3 | import torch 4 | 5 | 6 | class Contextualizer: 7 | """ 8 | Wraps around any contextualizer with optional subword tokenization. 9 | This abstracts over the following 3 cases: 10 | - GloVe et al. 
(no contextualization); 11 | - ELMo et al. (character-based); 12 | - BERT et al. (subword-based). 13 | 14 | These are encapsulated in the following 3 steps, as represented 15 | by the 3 abstract methods. 16 | This abstraction does not support methods requiring a language input, 17 | e.g. XLM (but it can process XLM-Roberta). 18 | """ 19 | 20 | @abstractmethod 21 | def tokenize_with_mapping(self, 22 | sentence: List[str] 23 | ) -> Tuple[Union[List[int], List[str]], List[int]]: 24 | """ 25 | Given a sentence tokenized into words, 26 | tokenizes it into subword units, 27 | optionally indexes these subword units into IDs (ELMo does not do this), 28 | and returns the tokenized symbols together with a mapping back to the original token indices. 29 | :param sentence: List of original tokens 30 | :return: (subword tokens or token indices, index mapping) 31 | """ 32 | raise NotImplementedError 33 | 34 | @abstractmethod 35 | def encode(self, 36 | sentences: List[Union[List[int], List[str]]], 37 | frozen: bool = True 38 | ) -> torch.Tensor: # R[Batch, Layer, Word, Emb] 39 | """ 40 | Encodes a batch of tokenized (or indexed) sentences. 41 | :param sentences: Batch of sentences, as produced by tokenize_with_mapping 42 | :param frozen: Whether the encoder is frozen 43 | :return: Embedding tensor of shape R[Batch, Layer, Word, Emb] 44 | """ 45 | raise NotImplementedError 46 | -------------------------------------------------------------------------------- /hiertype/contextualizers/contextualizer_test.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from hiertype.contextualizers import get_contextualizer 3 | 4 | s = "He found a leprechaun in his walnut shell .".split(' ') 5 | 6 | contextualizer = get_contextualizer("xlm-roberta-base", device='cuda:0') 7 | 8 | t, m = contextualizer.tokenize_with_mapping(s) 9 | encoded = contextualizer.encode([t], frozen=True) 10 | pass 11 | -------------------------------------------------------------------------------- /hiertype/contextualizers/elmo_contextualizer.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | 5 | from hiertype.contextualizers import Contextualizer 6 | from allennlp.commands.elmo import ElmoEmbedder 7 | from allennlp.modules.elmo import _ElmoBiLm 8 | 9 | 10 | class ELMoContextualizer(Contextualizer): 11 | 12 | def __init__(self, elmo: ElmoEmbedder, device: str): 13 | self.elmo = elmo 14 | self.device = device 15 | 16 | @classmethod 17 | def from_model(cls, 18 | elmo_weights_path: str, 19 | elmo_options_path: str, 20 | device: str, 21 | tokenizer_only: bool = False 22 | ): 23 | if device == "cpu": 24 | device_id = -1 25 | else: 26 | device_id = int(device[5:]) # "cuda:1" 27 | 28 | if tokenizer_only: 29 | return cls(elmo=None, device=device) 30 | 31 | elmo = ElmoEmbedder( 32 | options_file=elmo_options_path, 33 | weight_file=elmo_weights_path, 34 | cuda_device=device_id 35 | ) 36 | return cls(elmo=elmo, device=device) 37 | 38 | def tokenize_with_mapping(self, sentence: List[str]) -> Tuple[List[str], List[int]]: 39 | # Doesn't do anything -- retain original tokenization 40 | n = len(sentence) 41 | return sentence, [i for i in range(n)] 42 | 43 | def encode(self, 44 | sentences: List[List[str]], 45 | frozen: bool = True 46 | ) -> torch.Tensor: 47 | 48 | with torch.no_grad() if frozen else torch.enable_grad(): 49 | embs, _ = self.elmo.batch_to_embeddings(sentences) # R[Batch, Layer, Word, Emb] 50 | return embs 51 | 
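Downstream code recovers subword spans from the `(tokens, mapping)` pair via binary search over the mapping. A minimal sketch with a made-up mapping, mirroring the `bisect` logic used in `cache_repr.py` and `CachedMentionReader`:

```python
from bisect import bisect_left, bisect_right

# Hypothetical subword-to-word mapping for "He visited New York":
# subwords: <s>  He  visit  ed  New  York  </s>
mapping = [-1, 0, 1, 1, 2, 3, 4]  # -1 marks the leading special token

# Word span [2, 4) ("New York") -> subword span [4, 6).
sub_l = bisect_left(mapping, 2)
sub_r = bisect_right(mapping, 4 - 1)
assert (sub_l, sub_r) == (4, 6)
```
--------------------------------------------------------------------------------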
/hiertype/contextualizers/get_contextualizer.py: -------------------------------------------------------------------------------- 1 | from hiertype.contextualizers.contextualizer import Contextualizer 2 | from hiertype.contextualizers.hugging_face_contextualizer import BERTContextualizer, XLMRobertaContextualizer 3 | from hiertype.contextualizers.elmo_contextualizer import ELMoContextualizer 4 | 5 | 6 | def get_contextualizer( 7 | model_name: str, 8 | device: str, 9 | tokenizer_only: bool = False 10 | ) -> Contextualizer: 11 | """ 12 | Returns a contextualizer by pre-trained model name. 13 | 14 | :param model_name: Model identifier. 15 | :param device: 16 | :param tokenizer_only: if True, only loads tokenizer (not actual model) 17 | """ 18 | 19 | if model_name.startswith("bert"): 20 | return BERTContextualizer.from_model(model_name, device, tokenizer_only=tokenizer_only) 21 | 22 | elif model_name.startswith("xlm-roberta"): 23 | return XLMRobertaContextualizer.from_model(model_name, device, tokenizer_only=tokenizer_only) 24 | 25 | elif model_name.startswith("elmo"): 26 | elmo_path_prefix = "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo" 27 | elmo_path_id = { 28 | "elmo-small": "2x1024_128_2048cnn_1xhighway", 29 | "elmo-medium": "2x2048_256_2048cnn_1xhighway", 30 | "elmo-original": "2x4096_512_2048cnn_2xhighway", 31 | "elmo-original-5.5B": "2x4096_512_2048cnn_2xhighway_5.5B" 32 | }[model_name] 33 | elmo_path = f"{elmo_path_prefix}/{elmo_path_id}/elmo_{elmo_path_id}" 34 | elmo_weights_path = f"{elmo_path}_weights.hdf5" 35 | elmo_options_path = f"{elmo_path}_options.json" 36 | return ELMoContextualizer.from_model( 37 | elmo_weights_path, elmo_options_path, device, tokenizer_only=tokenizer_only 38 | ) 39 | 40 | elif model_name.startswith("glove"): 41 | glove_path = { 42 | "glove-6B-300d": "http://nlp.stanford.edu/data/glove.6B.zip", 43 | "glove-42B-300d": "http://nlp.stanford.edu/data/glove.42B.300d.zip", 44 | "glove-840B-300d": "http://nlp.stanford.edu/data/glove.840B.300d.zip" 45 | } 46 | # TODO: read GloVe files 47 | raise NotImplementedError 48 | -------------------------------------------------------------------------------- /hiertype/contextualizers/hugging_face_contextualizer.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from abc import abstractmethod 3 | import torch 4 | import transformers 5 | 6 | from hiertype.contextualizers import Contextualizer 7 | 8 | 9 | class HuggingFaceContextualizer(Contextualizer): 10 | """ 11 | Wraps around any contextualizer in the HuggingFace Transformers package. 
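Subclasses supply the model-specific pieces: `postprocess_mapping` accounts for the special tokens added around the sentence, and `select_output` extracts the stacked per-layer hidden states from the model output.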
12 | """ 13 | 14 | def __init__(self, 15 | hf_tokenizer: transformers.PreTrainedTokenizer, 16 | hf_model: transformers.PreTrainedModel, 17 | device: str 18 | ): 19 | self.hf_tokenizer = hf_tokenizer 20 | self.hf_model = hf_model 21 | self.device = device 22 | 23 | @abstractmethod 24 | def postprocess_mapping(self, 25 | mapping: List[int] 26 | ) -> List[int]: 27 | raise NotImplementedError 28 | 29 | @abstractmethod 30 | def select_output(self, output: Any) -> torch.Tensor: 31 | raise NotImplementedError 32 | 33 | def tokenize_with_mapping(self, 34 | sentence: List[str] 35 | ) -> Tuple[Union[List[int]], List[int]]: 36 | tokens = [] 37 | mapping = [] 38 | for i, t in enumerate(sentence): 39 | for wp in self.hf_tokenizer.tokenize(t): 40 | tokens.append(wp) 41 | mapping.append(i) 42 | 43 | token_indices = self.hf_tokenizer.convert_tokens_to_ids(tokens) 44 | token_indices_with_special_symbols = self.hf_tokenizer.build_inputs_with_special_tokens(token_indices) 45 | mapping_with_special_symbols = self.postprocess_mapping(mapping) 46 | 47 | return token_indices_with_special_symbols, mapping_with_special_symbols 48 | 49 | def encode(self, 50 | indexed_sentences: List[List[int]], 51 | frozen: bool = True, 52 | ) -> torch.Tensor: # R[Batch, Layer, Length, Emb] 53 | 54 | bsz = len(indexed_sentences) 55 | lengths = [len(s) for s in indexed_sentences] 56 | indices_tensor = torch.zeros(bsz, max(lengths), dtype=torch.int64) 57 | input_mask = torch.zeros(bsz, max(lengths), dtype=torch.int64) 58 | 59 | for i in range(bsz): 60 | for j in range(lengths[i]): 61 | indices_tensor[i, j] = indexed_sentences[i][j] 62 | input_mask[i, j] = 1 63 | 64 | indices_tensor = indices_tensor.to(device=self.device) 65 | input_mask = input_mask.to(device=self.device) 66 | 67 | if frozen: 68 | self.hf_model.eval() 69 | else: 70 | self.hf_model.train() 71 | 72 | with torch.no_grad() if frozen else torch.enable_grad(): 73 | model_output = self.hf_model( 74 | input_ids=indices_tensor, 75 | attention_mask=input_mask 76 | ) 77 | embs = self.select_output(model_output) 78 | return embs 79 | 80 | 81 | class BERTContextualizer(HuggingFaceContextualizer): 82 | 83 | def __init__(self, 84 | hf_tokenizer: transformers.BertTokenizer, 85 | hf_model: transformers.BertModel, 86 | device: str 87 | ): 88 | super(BERTContextualizer, self).__init__(hf_tokenizer, hf_model, device) 89 | 90 | @classmethod 91 | def from_model(cls, model_name: str, device: str, tokenizer_only: bool = False): 92 | hf_tokenizer = transformers.BertTokenizer.from_pretrained(model_name) 93 | hf_model = None if tokenizer_only \ 94 | else transformers.BertModel.from_pretrained(model_name, output_hidden_states=True) 95 | if not tokenizer_only and device != "cpu": 96 | hf_model.cuda(device=device) 97 | return cls( 98 | hf_tokenizer=hf_tokenizer, 99 | hf_model=hf_model, 100 | device=device 101 | ) 102 | 103 | def postprocess_mapping(self, mapping: List[int]) -> List[int]: 104 | # account for [CLS] and [SEP] 105 | return [-1] + mapping + [max(mapping) + 1] 106 | 107 | def select_output(self, output: Any) -> torch.Tensor: 108 | encoded = output[2] # List_Layer[R[Batch, Word, Emb]] 109 | stacked = torch.stack(encoded, dim=1) # R[Batch, Layer, Word, Emb] 110 | return stacked 111 | 112 | 113 | class XLMRobertaContextualizer(HuggingFaceContextualizer): 114 | 115 | def __init__(self, 116 | hf_tokenizer: transformers.XLMRobertaTokenizer, 117 | hf_model: transformers.XLMRobertaModel, 118 | device: str 119 | ): 120 | super(XLMRobertaContextualizer, self).__init__(hf_tokenizer, hf_model, 
device) 121 | 122 | @classmethod 123 | def from_model(cls, model_name: str, device: str, tokenizer_only: bool = False): 124 | hf_tokenizer = transformers.XLMRobertaTokenizer.from_pretrained(model_name) 125 | hf_model = None if tokenizer_only \ 126 | else transformers.XLMRobertaModel.from_pretrained(model_name, output_hidden_states=True) 127 | if not tokenizer_only and device != "cpu": 128 | hf_model.cuda(device=device) 129 | return cls( 130 | hf_tokenizer=hf_tokenizer, 131 | hf_model=hf_model, 132 | device=device 133 | ) 134 | 135 | def postprocess_mapping(self, mapping: List[int]) -> List[int]: 136 | # account for <s> and </s> 137 | return [-1] + mapping + [max(mapping) + 1] 138 | 139 | def select_output(self, output: Any) -> torch.Tensor: 140 | encoded = output[2] # List_Layer[R[Batch, Word, Emb]] 141 | stacked = torch.stack(encoded, dim=1) # R[Batch, Layer, Word, Emb] 142 | return stacked 143 | -------------------------------------------------------------------------------- /hiertype/data/__init__.py: -------------------------------------------------------------------------------- 1 | from hiertype.data.alphabet import Alphabet 2 | from hiertype.data.hierarchy import Hierarchy 3 | 4 | from hiertype.data.bdb_storage import BerkeleyDBStorage 5 | from hiertype.data.str_ndarray_bdb_storage import StringNdArrayBerkeleyDBStorage 6 | 7 | from hiertype.data.cached_mention_reader import CachedMentionReader 8 | -------------------------------------------------------------------------------- /hiertype/data/alphabet.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | 4 | class Alphabet: 5 | """ 6 | Maintains a bijection between symbols (tokens / labels) and unique indices. 7 | """ 8 | 9 | def __init__(self, 10 | sym_to_idx: Dict[str, int], 11 | idx_to_sym: List[str] 12 | ): 13 | self.sym_to_idx: Dict[str, int] = sym_to_idx 14 | self.idx_to_sym: List[str] = idx_to_sym 15 | 16 | def size(self) -> int: 17 | return len(self.idx_to_sym) 18 | 19 | def index(self, sym: str) -> int: 20 | if sym in self.sym_to_idx: 21 | return self.sym_to_idx[sym] 22 | else: 23 | idx = self.size() 24 | self.idx_to_sym.append(sym) 25 | self.sym_to_idx[sym] = idx 26 | return idx 27 | 28 | @classmethod 29 | def with_special_symbols(cls, special_symbols: List[str]) -> 'Alphabet': 30 | 31 | sym_to_idx: Dict[str, int] = {} 32 | idx_to_sym: List[str] = [] 33 | 34 | for i, sym in enumerate(special_symbols): 35 | sym_to_idx[sym] = i 36 | idx_to_sym.append(sym) 37 | 38 | return cls(sym_to_idx, idx_to_sym) 39 | -------------------------------------------------------------------------------- /hiertype/data/bdb_storage.py: -------------------------------------------------------------------------------- 1 | from bsddb3 import db 2 | import numpy as np 3 | from typing import * 4 | from abc import abstractmethod 5 | 6 | K = TypeVar('K') 7 | V = TypeVar('V') 8 | 9 | 10 | class BerkeleyDBStorage(Generic[K, V], MutableMapping[K, V]): 11 | """ 12 | A high-performance key-value storage on disk, powered by BerkeleyDB.
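Subclasses define how keys and values are encoded to, and decoded from, raw bytes.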
13 | """ 14 | 15 | def __init__(self, kvs: db.DB): 16 | self.kvs = kvs 17 | 18 | @abstractmethod 19 | def encode_key(self, k: K) -> bytes: 20 | pass 21 | 22 | @abstractmethod 23 | def encode_value(self, v: V) -> bytes: 24 | pass 25 | 26 | @abstractmethod 27 | def decode_key(self, k: bytes) -> K: 28 | pass 29 | 30 | @abstractmethod 31 | def decode_value(self, v: bytes) -> V: 32 | pass 33 | 34 | def __getitem__(self, k: K) -> V: 35 | return self.decode_value(self.kvs.get(self.encode_key(k))) 36 | 37 | def __setitem__(self, k: K, v: V) -> None: 38 | self.kvs.put(self.encode_key(k), self.encode_value(v)) 39 | 40 | def __delitem__(self, k: K) -> None: 41 | self.kvs.delete(self.encode_key(k)) 42 | 43 | def __len__(self) -> int: 44 | return self.kvs.stat()['ndata'] 45 | 46 | def items(self) -> Iterator[Tuple[K, V]]: 47 | cursor = self.kvs.cursor() 48 | entry = cursor.first() 49 | while entry: 50 | raw_key, raw_value = entry 51 | yield self.decode_key(raw_key), self.decode_value(raw_value) 52 | entry = cursor.next() 53 | 54 | def __iter__(self) -> Iterator[V]: 55 | for _, v in self.items(): 56 | yield v 57 | 58 | def close(self) -> None: 59 | self.kvs.close() 60 | 61 | def __exit__(self, exc_type, exc_val, exc_tb) -> None: 62 | self.close() 63 | 64 | def __enter__(self): 65 | return self 66 | -------------------------------------------------------------------------------- /hiertype/data/cached_mention_reader.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from functools import reduce 3 | import torch 4 | import numpy as np 5 | from bisect import bisect_left, bisect_right 6 | 7 | from allennlp.data import Instance, DatasetReader 8 | from allennlp.data.fields import MetadataField 9 | from hiertype.contextualizers import Contextualizer, get_contextualizer 10 | from hiertype.data import Hierarchy, StringNdArrayBerkeleyDBStorage 11 | from hiertype.fields import TensorField, IntField 12 | 13 | 14 | class CachedMentionReader(DatasetReader): 15 | 16 | def __init__(self, 17 | hierarchy: Hierarchy, 18 | delimiter: str = "/", 19 | model: str = "elmo-original" 20 | ): 21 | super(CachedMentionReader, self).__init__(lazy=True) 22 | self.hierarchy = hierarchy 23 | self.delimiter = delimiter 24 | self.contextualizer: Contextualizer = get_contextualizer(model, device="cpu", tokenizer_only=True) 25 | 26 | def _read(self, file_path: str) -> Iterable[Instance]: 27 | 28 | dump_path, file_path = file_path.split(':') 29 | with StringNdArrayBerkeleyDBStorage.open(dump_path, mode='r') as dump, \ 30 | open(file_path, mode='r') as lines: 31 | 32 | for i, line in enumerate(lines): 33 | 34 | sentence_str, span_str, type_list_str = line.strip().split('\t') 35 | token_left_str, token_right_str = span_str.split(':') 36 | token_left, token_right = int(token_left_str), int(token_right_str) 37 | token_list = sentence_str.split(' ') 38 | span_text = token_list[token_left:token_right] 39 | subtoken_list, mapping = self.contextualizer.tokenize_with_mapping(token_list) 40 | 41 | subtoken_left = bisect_left(mapping, token_left) 42 | subtoken_right = bisect_right(mapping, token_right - 1) 43 | if subtoken_right == subtoken_left: 44 | subtoken_right += 1 45 | 46 | sentence_repr_np: np.ndarray = dump[str(i)] 47 | sentence_repr = torch.from_numpy(sentence_repr_np).permute(1, 0, 2) # R[Length, Layer, Emb] 48 | sentence_len = sentence_repr.size(0) 49 | sentence_repr = sentence_repr.reshape(sentence_len, -1) # R[Length, Emb] 50 | 51 | assert subtoken_right <= 
sentence_repr.size(0) 52 | 53 | span_repr = sentence_repr[subtoken_left:subtoken_right, :] # R[SpanLength, Emb] 54 | 55 | x_fields = { 56 | "id": MetadataField(metadata=i), 57 | "sentence_text": MetadataField(metadata=token_list), 58 | "span_text": MetadataField(metadata=span_text), 59 | "span_left": IntField(value=subtoken_left), 60 | "span_right": IntField(value=subtoken_right), 61 | "sentence": TensorField(tensor=sentence_repr, pad_dim=0), 62 | "sentence_length": IntField(value=sentence_len), 63 | "span": TensorField(tensor=span_repr, pad_dim=0), 64 | "span_length": IntField(value=subtoken_right - subtoken_left) 65 | } 66 | 67 | all_types: List[int] = sorted(list(reduce( 68 | lambda u, v: u.union(v), 69 | (set(self.hierarchy.index_of_nodes_on_path(t)) for t in type_list_str.split(' ')) 70 | ))) 71 | 72 | y_fields = { 73 | "labels": MetadataField(metadata=all_types) 74 | } 75 | 76 | yield Instance({**x_fields, **y_fields}) 77 | 78 | def text_to_instance(self, *inputs) -> Instance: 79 | raise NotImplementedError 80 | 81 | def debug(self, file_path: str): 82 | for instance in self._read(file_path): 83 | print(f"{' '.join(instance['span_text'].metadata)}\t{' '.join(self.hierarchy.type_str(t) for t in instance['labels'].metadata)}") 84 | 85 | -------------------------------------------------------------------------------- /hiertype/data/hierarchy.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import bisect 4 | from colors import green 5 | from functools import reduce 6 | from hiertype.data.alphabet import Alphabet 7 | 8 | 9 | class Hierarchy: 10 | """ 11 | Stores a hierarchical type ontology for hierarchical classification. 12 | """ 13 | OTHER = "<OTHER>" 14 | 15 | def __init__(self, 16 | alphabet: Alphabet, 17 | num_level: int, 18 | level: List[range], 19 | parent: List[int], 20 | children: List[range], 21 | dummies: Set[int], 22 | with_other: bool, 23 | delimiter: str = "/", 24 | ): 25 | self.num_level = num_level 26 | self._alphabet = alphabet 27 | self._level = level 28 | self._parent = parent 29 | self._children = children 30 | self._dummies = dummies 31 | self.delimiter = delimiter 32 | self.with_other = with_other 33 | 34 | def type_str(self, i: int) -> str: 35 | """ 36 | Returns the name of a type indexed by the specified id. 37 | :param i: ID 38 | :return: Type name 39 | """ 40 | return self._alphabet.idx_to_sym[i] 41 | 42 | def index(self, t: str) -> int: 43 | """ 44 | Returns the index of the specified type. If t is out of vocabulary, returns 0 (the root). 45 | :param t: Type name 46 | :return: Index to that type 47 | """ 48 | return self._alphabet.sym_to_idx.get(t) or 0 49 | 50 | def index_of_nodes_on_path(self, t: str) -> List[int]: 51 | """ 52 | Returns the path of types from Any (the supertype of all types) to the given type.
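For example, with the default delimiter, "/person/artist" yields the indices of "" (the root), "/person", and "/person/artist".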
53 | :param t: Type name 54 | :return: List of node indices, starting from 0 (Any) 55 | """ 56 | arcs = t.split(self.delimiter) # ["", a, b] 57 | prefixes = [self.delimiter.join(arcs[:i + 1]) for i in range(len(arcs))] 58 | nodes = [] 59 | for prefix in prefixes: 60 | node = self._alphabet.sym_to_idx.get(prefix) 61 | if node is not None: 62 | nodes.append(node) 63 | 64 | if self.with_other: 65 | last_node = nodes[-1] 66 | if len(self.children(last_node)) != 0: # not leaf 67 | other_node = self.other_child(last_node) 68 | if other_node is not None: 69 | nodes.append(other_node) # add the <OTHER> child node 70 | return nodes 71 | 72 | def other_child(self, x: int) -> Optional[int]: 73 | """ 74 | Given type t indexed by x, return the index of "t/<OTHER>". 75 | :param x: Index of parent node 76 | :return: Index of OTHER child node (None if no child exists) 77 | """ 78 | if self.with_other: 79 | path = f"{self.type_str(x)}{self.delimiter}{Hierarchy.OTHER}" 80 | return self._alphabet.sym_to_idx.get(path) # could be None if no other found 81 | else: 82 | return None 83 | 84 | def parent(self, x: int) -> int: 85 | """Returns the parent node of this node.""" 86 | return self._parent[x] 87 | 88 | def path_to_root(self, x: int) -> List[int]: 89 | """Returns the path from the current node to the root.""" 90 | path = [x] 91 | p = x 92 | while p != 0: 93 | p = self._parent[p] 94 | path.append(p) 95 | return path 96 | 97 | def children(self, x: int) -> range: 98 | """Returns the span of children nodes of this node.""" 99 | c = self._children[x] 100 | return c if c is not None else range(0, 0) # else branch: empty range 101 | 102 | def is_dummy(self, x: int) -> bool: 103 | return x in self._dummies 104 | 105 | def sibling(self, x: int) -> Set[int]: 106 | if x == 0: # root 107 | return {0} 108 | else: 109 | return set(self._children[self._parent[x]]) 110 | 111 | def level(self, x: int) -> int: 112 | for i in range(self.num_level): 113 | if self._level[i].start <= x < self._level[i].stop: 114 | return i 115 | 116 | def level_range(self, i: int) -> range: 117 | """Returns all nodes at level i.""" 118 | return self._level[i] 119 | 120 | def size(self) -> int: 121 | return self._alphabet.size() 122 | 123 | def __str__(self): 124 | """ 125 | Prints the hierarchy in a terminal-friendly way.
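Each node occupies one line, showing box-drawing connectors, the node's integer index, and its type name.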
126 | """ 127 | import io 128 | buf = io.StringIO() 129 | stack = [(0, " ")] 130 | 131 | def next_prefix(prefix: str, last: bool) -> str: 132 | if prefix.endswith("├─"): 133 | return f"{prefix[:-3]} │ {' └─' if last else ' ├─'}" 134 | else: 135 | return f"{prefix[:-3]} {' └─' if last else ' ├─'}" 136 | 137 | while len(stack) > 0: 138 | c, prefix = stack.pop() 139 | l = self.level(c) 140 | p = self.parent(c) 141 | print(f"{prefix} {green(str(c))} {self.type_str(c)}", file=buf) 142 | cc = self.children(c) 143 | if len(cc) != 0: 144 | cc = list(reversed(cc)) 145 | stack.append((cc[0], next_prefix(prefix, True))) 146 | for d in cc[1:]: 147 | stack.append((d, next_prefix(prefix, False))) 148 | 149 | return buf.getvalue() 150 | 151 | @classmethod 152 | def from_tree_file(cls, filename: str, with_other: bool = False, delimiter: str = '/'): 153 | 154 | def parent_type(t: str): 155 | p = delimiter.join(t.split(delimiter)[:-1]) 156 | if all(c == delimiter for c in p): # "//", "/" => "" 157 | p = "" 158 | return p 159 | 160 | types_per_level: List[Set[str]] = [] 161 | with open(filename, mode='r') as file: 162 | for line in file: 163 | arcs = line.strip().split(delimiter) # ["", "a", "b"] 164 | for i in range(1, len(arcs) + 1): 165 | prefix = delimiter.join(arcs[:i]) 166 | if prefix.endswith(delimiter): 167 | continue # skip types like "/" or "//" that is the apparent parent of "//person" 168 | if len(types_per_level) < i: 169 | types_per_level.append(set()) 170 | types_per_level[i - 1].add(prefix) 171 | 172 | if with_other and i < len(arcs): # non-leaf 173 | while len(types_per_level) <= i: 174 | types_per_level.append(set()) 175 | types_per_level[i].add(f"{prefix}{delimiter}{Hierarchy.OTHER}") 176 | 177 | num_level = len(types_per_level) 178 | alphabet: Alphabet = Alphabet.with_special_symbols([]) 179 | num_all_types = sum(len(types) for types in types_per_level) 180 | parent: List[int] = [0] * num_all_types 181 | children: List[range] = [None] * num_all_types 182 | 183 | level_start = 0 184 | level = [] 185 | for l in range(num_level): 186 | types = sorted(types_per_level[l]) 187 | level_size = len(types) 188 | level.append(range(level_start, level_start + level_size)) 189 | level_start += level_size 190 | 191 | for k in range(level_size): 192 | i = alphabet.size() 193 | t = types[k] 194 | alphabet.index(t) 195 | if i != 0: # not root 196 | p = alphabet.sym_to_idx[parent_type(t)] 197 | parent[i] = p 198 | if children[p] is None: 199 | children[p] = range(i, i + 1) 200 | else: 201 | r = children[p] 202 | children[p] = range(min(i, r.start), max(i + 1, r.stop)) 203 | 204 | dummies: Set[int] = {0} 205 | for i, t in enumerate(alphabet.idx_to_sym): 206 | if t.endswith(Hierarchy.OTHER): 207 | dummies.add(i) 208 | 209 | return cls(alphabet, num_level, level, parent, children, dummies, with_other, delimiter) 210 | -------------------------------------------------------------------------------- /hiertype/data/str_ndarray_bdb_storage.py: -------------------------------------------------------------------------------- 1 | from hiertype.data.bdb_storage import BerkeleyDBStorage 2 | import numpy as np 3 | from bsddb3 import db 4 | import msgpack 5 | import msgpack_numpy 6 | 7 | 8 | class StringNdArrayBerkeleyDBStorage(BerkeleyDBStorage[str, np.ndarray]): 9 | 10 | def __init__(self, kvs): 11 | super(StringNdArrayBerkeleyDBStorage, self).__init__(kvs) 12 | 13 | def encode_key(self, k: str) -> bytes: 14 | return k.encode() 15 | 16 | def encode_value(self, v: np.ndarray) -> bytes: 17 | return msgpack.packb(v, 
default=msgpack_numpy.encode) 18 | 19 | def decode_key(self, k: bytes) -> str: 20 | return k.decode() 21 | 22 | def decode_value(self, v: bytes) -> np.ndarray: 23 | return msgpack.unpackb(v, object_hook=msgpack_numpy.decode) 24 | 25 | @classmethod 26 | def open(cls, file: str, db_kind=db.DB_BTREE, mode: str = 'r'): 27 | kvs = db.DB() 28 | db_mode = { 29 | 'r': db.DB_DIRTY_READ, 30 | 'w': db.DB_CREATE 31 | }[mode] 32 | kvs.open(file, None, db_kind, db_mode) 33 | return cls(kvs) 34 | -------------------------------------------------------------------------------- /hiertype/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from hiertype.decoders.hierarchical_decoder import HierarchyDecoder 2 | from hiertype.decoders.beam_decoder import BeamDecoder 3 | -------------------------------------------------------------------------------- /hiertype/decoders/beam_decoder.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | 4 | from hiertype.data import Hierarchy 5 | from hiertype.decoders import HierarchyDecoder 6 | 7 | 8 | class BeamDecoder(HierarchyDecoder): 9 | 10 | def __init__(self, 11 | hierarchy: Hierarchy, 12 | strategies: List[str], # top / other 13 | max_branching_factors: List[int], 14 | delta: List[float], 15 | top_other_delta: float = 0.0 16 | ): 17 | super(BeamDecoder, self).__init__(hierarchy) 18 | self.max_branching_factors = max_branching_factors 19 | self.delta = delta 20 | self.strategies = strategies 21 | self.top_other_delta = top_other_delta 22 | 23 | def decode(self, 24 | predictions: torch.Tensor # R[Batch, Type] 25 | ) -> List[Set[int]]: 26 | 27 | def step_weak(y: torch.Tensor, l: int) -> Set[int]: # decode without adhering to the hierarchy 28 | threshold = y[0].item() + self.delta[l] # 0: root threshold 29 | children = [(s, y[s].item()) for s in self.hierarchy.level_range(l)] 30 | children.sort(key=lambda p: p[1], reverse=True) 31 | 32 | accepted_children: List[Tuple[int, float]] = [] 33 | for i in range(min(self.max_branching_factors[l], len(children))): 34 | s, ys = children[i] 35 | if ys < threshold: 36 | break 37 | accepted_children.append(children[i]) 38 | return {s for s, _ in accepted_children} 39 | 40 | def step(y: torch.Tensor, # R[Type] 41 | l: int, 42 | t: int) -> Tuple[Set[int], bool]: 43 | 44 | threshold = y[t].item() + self.delta[l] 45 | if self.hierarchy.type_str(t) == '/other': 46 | threshold += self.top_other_delta 47 | 48 | children: List[Tuple[int, float]] = [(s, y[s].item()) for s in self.hierarchy.children(t)] 49 | children.sort(key=lambda p: p[1], reverse=True) 50 | 51 | accepted_children: List[Tuple[int, float]] = [] 52 | 53 | for i in range(min(self.max_branching_factors[l], len(children))): 54 | s, ys = children[i] 55 | if ys < threshold: 56 | break 57 | if not self.hierarchy.is_dummy(s): 58 | accepted_children.append(children[i]) 59 | 60 | if len(accepted_children) == 0 and len(children) != 0: # has children, but none accepted 61 | if self.strategies[l] == "top": 62 | return ({children[0][0]}, True) if len(children) > 0 else (set(), True) # enforces the top type 63 | elif self.strategies[l] == "other": 64 | return {self.hierarchy.index("/other")}, True # for OntoNotes 65 | else: 66 | return set(), True 67 | 68 | else: 69 | return {s for s, _ in accepted_children}, True 70 | 71 | def decode_instance(y: torch.Tensor) -> Set[int]: 72 | beam: List[Tuple[int, int, bool]] = [(0, 0, True)] # root (level, type, explore?) 
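            # The beam is scanned left to right like a BFS frontier: each entry is
            # (level, node index, explore flag). Children accepted by step/step_weak
            # are appended at level l + 1, so nodes are visited level by level, and
            # expansion stops at the last level or at a node flagged as unexplorable.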
73 |             i = 0
74 |             while i < len(beam):
75 |                 l, t, flag = beam[i]
76 |                 if l >= self.hierarchy.num_level - 1 or (not flag):
77 |                     break
78 |                 if self.strategies[l] == "weak":
79 |                     children = step_weak(y, l + 1)
80 |                     for s in children:
81 |                         beam.append((l + 1, s, True))
82 |                     while i < len(beam) and beam[i][0] == l:
83 |                         i += 1  # skip all children at this level
84 |                 else:  # strict decoding on the hierarchy
85 |                     children, flag = step(y, l, t)
86 |                     for s in children:
87 |                         beam.append((l + 1, s, flag))
88 |                     i += 1
89 | 
90 |             return set(t for l, t, _ in beam if l != 0)
91 | 
92 |         return [
93 |             decode_instance(predictions[i, :])
94 |             for i in range(predictions.size(0))
95 |         ]
96 | 
-------------------------------------------------------------------------------- /hiertype/decoders/hierarchical_decoder.py: --------------------------------------------------------------------------------
1 | from typing import *
2 | from abc import ABC, abstractmethod
3 | import torch
4 | 
5 | from hiertype.data import Hierarchy
6 | 
7 | 
8 | class HierarchyDecoder(ABC):
9 | 
10 |     def __init__(self, hierarchy: Hierarchy):
11 |         self.hierarchy = hierarchy
12 | 
13 |     @abstractmethod
14 |     def decode(self,
15 |                y: torch.Tensor  # R[Batch, Type]
16 |                ) -> List[Set[int]]:
17 |         pass
18 | 
-------------------------------------------------------------------------------- /hiertype/fields/__init__.py: --------------------------------------------------------------------------------
1 | from hiertype.fields.int_field import IntField
2 | from hiertype.fields.real_field import RealField
3 | from hiertype.fields.tensor_field import TensorField
4 | 
-------------------------------------------------------------------------------- /hiertype/fields/int_field.py: --------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import math
4 | 
5 | from allennlp.data.fields import Field
6 | 
7 | 
8 | class IntField(Field[torch.Tensor]):
9 |     """
10 |     An `IntField` contains a single integer value.
11 |     This field will be converted into a batched long tensor.
12 |     This is different from an `allennlp.data.fields.LabelField`, whose semantics are categorical.
13 |     """
14 | 
15 |     def __init__(self, value: int):
16 |         self.value = value
17 | 
18 |     def get_padding_lengths(self) -> Dict[str, int]:
19 |         return {}
20 | 
21 |     def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor:
22 |         return torch.tensor(self.value, dtype=torch.int64)
23 | 
24 |     def empty_field(self) -> 'Field':
25 |         return IntField(0)
26 | 
27 |     def __str__(self) -> str:
28 |         return f"IntField with value: {self.value}"
-------------------------------------------------------------------------------- /hiertype/fields/real_field.py: --------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import math
4 | 
5 | from allennlp.data.fields import Field
6 | 
7 | 
8 | class RealField(Field[torch.Tensor]):
9 |     """
10 |     A `RealField` contains a real-valued number.
11 |     This field will be converted into a batched float tensor.
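    For example, `RealField(0.5).as_tensor({})` yields `tensor(0.5000)`;
    an empty field holds NaN (see `empty_field` below).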
12 | """ 13 | 14 | def __init__(self, value: float): 15 | self.value = value 16 | 17 | def get_padding_lengths(self) -> Dict[str, int]: 18 | return {} 19 | 20 | def as_tensor(self, padding_lengths: Dict[str, int]) -> torch.Tensor: 21 | return torch.tensor(self.value, dtype=torch.float32) 22 | 23 | def empty_field(self) -> 'Field': 24 | return RealField(math.nan) 25 | 26 | def __str__(self) -> str: 27 | return f"RealField with value: {self.value}" 28 | -------------------------------------------------------------------------------- /hiertype/fields/tensor_field.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | from allennlp.data import DataArray 4 | 5 | from allennlp.data.fields import Field 6 | 7 | 8 | class TensorField(Field[torch.Tensor]): 9 | """ 10 | A field that contains tensors (from other input sources, e.g. HDF5). 11 | A tensor could have 1 dimension that is the sequence length: 12 | When being batched, that dimension will automatically be padded. 13 | """ 14 | 15 | def __init__(self, tensor: torch.Tensor, pad_dim: int, pad_element=0): 16 | self.tensor = tensor 17 | self.pad_dim = pad_dim 18 | self.pad_element = pad_element 19 | 20 | def get_padding_lengths(self) -> Dict[str, int]: 21 | return {"length": self.tensor.size(self.pad_dim)} 22 | 23 | def as_tensor(self, padding_lengths: Dict[str, int]) -> DataArray: 24 | pad_shape = list(self.tensor.size()) 25 | pad_shape[self.pad_dim] = padding_lengths["length"] - self.tensor.size(self.pad_dim) 26 | 27 | pad = torch.full(pad_shape, self.pad_element).type_as(self.tensor) 28 | return torch.cat([self.tensor, pad], dim=self.pad_dim) 29 | 30 | def empty_field(self) -> 'Field': 31 | raise NotImplementedError 32 | 33 | def __str__(self) -> str: 34 | return f"TensorField with shape: {self.tensor.size()}" 35 | -------------------------------------------------------------------------------- /hiertype/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from hiertype.metrics.set_metric import SetMetric 2 | from hiertype.metrics.hierarchical_metric import HierarchicalMetric 3 | -------------------------------------------------------------------------------- /hiertype/metrics/hierarchical_metric.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | from functools import reduce 4 | import collections 5 | from allennlp.training.metrics import Metric 6 | 7 | from hiertype.data import Hierarchy 8 | from hiertype.metrics import SetMetric 9 | 10 | 11 | class HierarchicalMetric(Metric): 12 | 13 | def __init__(self, hierarchy: Hierarchy): 14 | self.hierarchy = hierarchy 15 | self.num_level = hierarchy.num_level 16 | self.level_sets = [ 17 | set(self.hierarchy.level_range(l)) 18 | for l in range(self.num_level) 19 | ] 20 | 21 | self.metrics_by_level = [SetMetric() for l in range(self.num_level)] 22 | self.overall_metric = SetMetric() 23 | 24 | self.serialization_dir: Optional[str] = None 25 | self.output: Optional[IO] = None 26 | 27 | def set_serialization_dir(self, serialization_dir: str): 28 | self.serialization_dir = serialization_dir 29 | 30 | def init_output(self, partition: str, epoch: int): 31 | if self.output is None: 32 | self.output = open(f"{self.serialization_dir}/{partition}-{epoch}.res", mode='w') 33 | 34 | def __call__(self, pred: List[Set[int]], gold: List[Set[int]]): 35 | 36 | # Remove root node 0 and OTHER 37 | pred: List[Set[int]] = [ 38 | {t 
for t in ts if not self.hierarchy.is_dummy(t)} 39 | for ts in pred 40 | ] 41 | gold: List[Set[int]] = [ 42 | {t for t in ts if not self.hierarchy.is_dummy(t)} 43 | for ts in gold 44 | ] 45 | 46 | batch_size = len(pred) 47 | for b in range(batch_size): 48 | 49 | # self.n += 1 50 | # metrics by level 51 | instance_pred_by_level = [{0}] 52 | instance_gold_by_level = [{0}] 53 | for l in range(1, self.num_level): 54 | instance_pred_l = pred[b] & self.level_sets[l] 55 | instance_pred_l |= instance_pred_by_level[l - 1] - \ 56 | {self.hierarchy.parent(t) for t in instance_pred_l} 57 | # add those in the upper layer with no children 58 | instance_gold_l = gold[b] & self.level_sets[l] 59 | instance_gold_l |= instance_gold_by_level[l - 1] - \ 60 | {self.hierarchy.parent(t) for t in instance_gold_l} 61 | 62 | pred_str = ' '.join(sorted(self.hierarchy.type_str(t) for t in instance_pred_l)) 63 | gold_str = ' '.join(sorted(self.hierarchy.type_str(t) for t in instance_gold_l)) 64 | print(f"[{l}]\t{pred_str}\t|\t{gold_str}", file=self.output) 65 | 66 | self.metrics_by_level[l]([instance_pred_l], [instance_gold_l]) 67 | 68 | instance_pred_by_level.append(instance_pred_l) 69 | instance_gold_by_level.append(instance_gold_l) 70 | 71 | pred_str = ' '.join(sorted(self.hierarchy.type_str(t) for t in pred[b])) 72 | gold_str = ' '.join(sorted(self.hierarchy.type_str(t) for t in gold[b])) 73 | print(f"[+]\t{pred_str}\t|\t{gold_str}", file=self.output) 74 | 75 | self.overall_metric([pred[b]], [gold[b]]) 76 | 77 | def get_metric(self, reset: bool) -> Dict[str, float]: 78 | 79 | level_metrics = { 80 | f"L{l}_{k}": v 81 | for l in range(1, self.num_level) 82 | for k, v in self.metrics_by_level[l].get_metric(reset).items() 83 | } 84 | 85 | overall_metrics = { 86 | f"O_{k}": v for k, v in self.overall_metric.get_metric(reset).items() 87 | } 88 | if reset: 89 | self.reset() 90 | 91 | return collections.OrderedDict(sorted({**level_metrics, **overall_metrics}.items())) 92 | 93 | def reset(self) -> None: 94 | """ 95 | Flushes the output file that stores prediction results. 
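        The file is reopened lazily by `init_output` on the next evaluation.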
96 |         """
97 |         if self.output is not None:
98 |             self.output.close()
99 |             self.output = None
100 | 
-------------------------------------------------------------------------------- /hiertype/metrics/set_metric.py: --------------------------------------------------------------------------------
1 | from typing import *
2 | 
3 | import torch
4 | from allennlp.training.metrics import Metric
5 | 
6 | 
7 | class SetMetric(Metric):
8 | 
9 |     def __init__(self):
10 |         self.n = 0
11 |         self.strict = 0
12 |         self.pred = 0
13 |         self.gold = 0
14 |         self.both = 0
15 |         self.sum_p = 0.0
16 |         self.sum_r = 0.0
17 |         self.sum_f = 0.0
18 | 
19 |     def __call__(self, pred: List[Set[int]], gold: List[Set[int]]):
20 | 
21 |         batch_size = len(pred)
22 |         for b in range(batch_size):
23 | 
24 |             self.n += 1
25 | 
26 |             pred_count = len(pred[b])
27 |             gold_count = len(gold[b])
28 |             both_count = len(pred[b] & gold[b])
29 | 
30 |             self.strict += 1 if pred[b] == gold[b] else 0
31 |             self.pred += pred_count
32 |             self.gold += gold_count
33 |             self.both += both_count
34 | 
35 |             p = _safe_div(both_count, pred_count)
36 |             r = _safe_div(both_count, gold_count)
37 |             f = _f1(p, r)
38 | 
39 |             self.sum_p += p
40 |             self.sum_r += r
41 |             self.sum_f += f
42 | 
43 |     def get_metric(self, reset: bool) -> Dict[str, float]:
44 |         mi_p = _safe_div(self.both, self.pred)
45 |         mi_r = _safe_div(self.both, self.gold)
46 |         mi_f = _f1(mi_p, mi_r)
47 |         m = {
48 |             "Acc": _safe_div(self.strict, self.n),
49 |             "MaP": _safe_div(self.sum_p, self.n),
50 |             "MaR": _safe_div(self.sum_r, self.n),
51 |             "MaF": _safe_div(self.sum_f, self.n),
52 |             "MiP": mi_p,
53 |             "MiR": mi_r,
54 |             "MiF": mi_f
55 |         }
56 | 
57 |         if reset:
58 |             self.reset()
59 | 
60 |         return m
61 | 
62 |     def reset(self) -> None:
63 |         self.n = 0
64 |         self.strict = 0
65 |         self.pred = 0
66 |         self.gold = 0
67 |         self.both = 0
68 |         self.sum_p = 0.0
69 |         self.sum_r = 0.0
70 |         self.sum_f = 0.0
71 | 
72 | 
73 | def _safe_div(numerator: Union[int, float], denominator: Union[int, float]) -> float:
74 |     return 0.0 if numerator == 0 or denominator == 0 else float(numerator) / float(denominator)
75 | 
76 | 
77 | def _f1(p: float, r: float) -> float:
78 |     return 0.0 if p == 0.0 or r == 0.0 else 2 * p * r / (p + r)
79 | 
-------------------------------------------------------------------------------- /hiertype/models/__init__.py: --------------------------------------------------------------------------------
1 | from hiertype.models.hierarchical_typer import HierarchicalTyper
2 | 
-------------------------------------------------------------------------------- /hiertype/models/hierarchical_typer.py: --------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import numpy as np
4 | import json
5 | from allennlp.models import Model
6 | 
7 | from hiertype.data import Hierarchy
8 | from hiertype.modules import MentionFeatureExtractor, TypeScorer, IndexedHingeLoss, ComplEx, RelationConstraintLoss
9 | from hiertype.decoders import HierarchyDecoder, BeamDecoder
10 | from hiertype.metrics import HierarchicalMetric
11 | from hiertype.training import MyTrainer
12 | from hiertype.util import compact2, compact3, sample_multiple_from, sample_in_range_except
13 | 
14 | 
15 | class HierarchicalTyper(Model):
16 | 
17 |     def __init__(self,
18 |                  hierarchy: Hierarchy,
19 |                  input_dim: int,
20 |                  type_dim: int,
21 |                  bottleneck_dim: int,
22 |                  mention_pooling: str,
23 |                  with_context: bool,
24 |                  dropout_rate: float,
25 |                  emb_dropout_rate: float,
26 |                  margins_per_level: List[float],
27 |                  num_negative_samples: int,
28 | 
threshold_ratio: float, 29 | relation_constraint_coef: float, 30 | lift_other: bool, 31 | compute_metric_when_training: bool, 32 | decoder: HierarchyDecoder 33 | ): 34 | super(HierarchicalTyper, self).__init__(vocab=None) 35 | 36 | self.hierarchy = hierarchy 37 | self.threshold_ratio = threshold_ratio 38 | self.relation_constraint_coef = relation_constraint_coef 39 | self.num_negative_samples = num_negative_samples 40 | self.lift_other = lift_other 41 | 42 | self.mention_feature_extractor = MentionFeatureExtractor( 43 | hierarchy=hierarchy, 44 | dim=input_dim, 45 | dropout_rate=emb_dropout_rate, 46 | mention_pooling=mention_pooling, 47 | with_context=with_context 48 | ) 49 | self.type_scorer = TypeScorer( 50 | type_embeddings=torch.nn.Embedding(hierarchy.size(), type_dim), 51 | input_dim=(input_dim * 2) if with_context else input_dim, 52 | type_dim=type_dim, 53 | bottleneck_dim=bottleneck_dim, 54 | dropout_rate=dropout_rate 55 | ) 56 | self.loss = IndexedHingeLoss(torch.tensor([0.0] + margins_per_level, dtype=torch.float32)) 57 | self.rel_loss = RelationConstraintLoss( 58 | self.type_scorer.type_embeddings, 59 | ComplEx(self.type_scorer.type_embeddings.embedding_dim) 60 | ) 61 | 62 | self.decoder = decoder 63 | self.compute_metric_when_training = compute_metric_when_training 64 | self.metric = HierarchicalMetric(hierarchy) 65 | 66 | self.trainer: MyTrainer = None 67 | self.current_epoch = 0 68 | 69 | def set_trainer(self, trainer: MyTrainer): 70 | self.trainer = trainer 71 | 72 | def scores(self, 73 | sentence: torch.Tensor, 74 | sentence_length: torch.Tensor, 75 | span: torch.Tensor, 76 | span_length: torch.Tensor, 77 | span_left: torch.Tensor, # Z[Batch] 78 | span_right: torch.Tensor, # Z[Batch] 79 | **kwargs 80 | ) -> torch.Tensor: # R[Batch, Type] 81 | 82 | mention_features = self.mention_feature_extractor( 83 | sentence, sentence_length, span, span_length, span_left, span_right 84 | ) # R[Batch, Emb] 85 | scores = self.type_scorer(mention_features) # R[Batch, Type] 86 | return scores 87 | 88 | def forward(self, 89 | id: List[int], 90 | span_text: List[List[str]], 91 | sentence_text: List[List[str]], 92 | sentence: torch.Tensor, 93 | sentence_length: torch.Tensor, 94 | span: torch.Tensor, 95 | span_length: torch.Tensor, 96 | span_left: torch.Tensor, # Z[Batch] 97 | span_right: torch.Tensor, # Z[Batch] 98 | labels: List[List[int]] 99 | ) -> Dict[str, torch.Tensor]: 100 | 101 | device = sentence.device 102 | 103 | scores = self.scores(sentence, sentence_length, span, span_length, span_left, span_right) 104 | 105 | pos_type_ids = torch.from_numpy(self.get_pos_indices(labels, lift_other=self.lift_other)).to(device=device) 106 | thr_type_ids = torch.from_numpy(self.get_parent_indices(labels)).to(device=device) 107 | neg_type_ids = torch.from_numpy(self.get_neg_sibling_indices(labels)).to(device=device) 108 | neg_parent_type_ids = torch.from_numpy(self.get_parent_sibling_indices(labels)).to(device=device) 109 | levels = torch.from_numpy(self.levels(labels)).to(device=device) 110 | 111 | loss_above = self.loss(scores, pos_type_ids, thr_type_ids.unsqueeze(dim=2), levels, self.threshold_ratio) 112 | loss_below = self.loss(scores, thr_type_ids, neg_type_ids, levels, 1.0 - self.threshold_ratio) 113 | loss_both = self.loss(scores, pos_type_ids, neg_type_ids, levels, 1.0) 114 | 115 | rel_sibling_loss = self.rel_loss(pos_type_ids, thr_type_ids, neg_type_ids) # siblings are not parent 116 | rel_parent_loss = self.rel_loss(pos_type_ids, thr_type_ids, neg_parent_type_ids) # siblings of parent are not 
parent 117 | all_rel_loss = rel_sibling_loss + rel_parent_loss 118 | 119 | return_dict = { 120 | "scores": scores, 121 | "loss_above": loss_above, 122 | "loss_below": loss_below, 123 | "loss_both": loss_both, 124 | "loss_rel": all_rel_loss, 125 | "loss": loss_above + loss_below + loss_both + all_rel_loss * self.relation_constraint_coef 126 | } 127 | 128 | if self.training and not self.compute_metric_when_training: 129 | return return_dict 130 | 131 | predicted_types = self.decoder.decode(scores) 132 | self.current_epoch = 0 if not self.trainer else self.trainer.current_epoch 133 | if self.training: 134 | self.metric.init_output("train", self.current_epoch) 135 | else: 136 | self.metric.init_output("dev", self.current_epoch) 137 | self.metric(predicted_types, labels) 138 | 139 | return return_dict 140 | 141 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 142 | return self.metric.get_metric(reset) 143 | 144 | def levels(self, labels: List[List[int]]) -> np.ndarray: 145 | _, l = compact2( 146 | [ 147 | [self.hierarchy.level(x) for x in xs if x != 0] 148 | for xs in labels 149 | ], 150 | pad=0 151 | ) 152 | return l 153 | 154 | def get_pos_indices(self, labels: List[List[int]], lift_other: bool = False) -> np.ndarray: 155 | def lift(x): 156 | if self.hierarchy.is_dummy(x): 157 | return self.hierarchy.parent(x) 158 | else: 159 | return x 160 | 161 | _, indices = compact2( 162 | [ 163 | [lift(x) if lift_other else x for x in xs if x != 0] # remove root, which has id 0 164 | for xs in labels 165 | ], 166 | pad=-1 167 | ) 168 | return indices # Z_Type[Batch, PosCand] 169 | 170 | def get_parent_indices(self, labels: List[List[int]]) -> np.ndarray: 171 | _, indices = compact2( 172 | [ 173 | [self.hierarchy.parent(x) for x in xs if x != 0] # remove root, which has id 0 174 | for xs in labels 175 | ], 176 | pad=-1 177 | ) 178 | return indices 179 | 180 | def get_neg_sibling_indices(self, labels: List[List[int]]) -> np.ndarray: 181 | 182 | label_sets: List[Set[int]] = [set(xs) for xs in labels] 183 | 184 | # use all negative siblings of this positive root if num_negative_samples <= 0 185 | def choose_except(x: int, excluded: Set[int]) -> List[int]: 186 | if self.num_negative_samples <= 0: 187 | return list(self.hierarchy.sibling(x).difference(excluded)) 188 | else: 189 | r = self.hierarchy.level_range(self.hierarchy.level(x)) 190 | return sample_in_range_except(r.start, r.stop, self.num_negative_samples, excluded) 191 | 192 | l = [ 193 | [ 194 | choose_except(x, label_sets[i]) 195 | for x in xs if x != 0 196 | ] 197 | for i, xs in enumerate(labels) 198 | ] 199 | _, _, indices = compact3(l, pad=-1) 200 | return indices 201 | 202 | def get_parent_sibling_indices(self, labels: List[List[int]]) -> np.ndarray: 203 | 204 | label_sets: List[Set[int]] = [set(xs) for xs in labels] 205 | 206 | # use all negative siblings of this positive root if num_negative_samples <= 0 207 | def choose_except(x: int, excluded: Set[int]) -> List[int]: 208 | if self.num_negative_samples <= 0: 209 | return list(self.hierarchy.sibling(self.hierarchy.parent(x)).difference(excluded)) 210 | else: 211 | r = self.hierarchy.level_range(self.hierarchy.level(x) - 1) 212 | return sample_in_range_except(r.start, r.stop, self.num_negative_samples, excluded) 213 | 214 | l = [ 215 | [ 216 | choose_except(x, label_sets[i]) 217 | for x in xs if x != 0 218 | ] 219 | for i, xs in enumerate(labels) 220 | ] 221 | 222 | _, _, indices = compact3(l, pad=-1) 223 | return indices 224 | 225 | @classmethod 226 | def from_args(cls, 
args_path: str): 227 | args = json.load(open(args_path)) 228 | hierarchy = Hierarchy.from_tree_file(args["ontology"], with_other=args["with_other"]) 229 | return cls( 230 | hierarchy=hierarchy, 231 | input_dim=args["input_dim"], 232 | type_dim=args["type_dim"], 233 | bottleneck_dim=args["bottleneck_dim"], 234 | mention_pooling=args["mention_pooling"], 235 | with_context=True, 236 | dropout_rate=args["dropout_rate"], 237 | emb_dropout_rate=args["emb_dropout_rate"], 238 | margins_per_level=args["margins"], 239 | num_negative_samples=args["num_negative_samples"], 240 | threshold_ratio=args["threshold_ratio"], 241 | relation_constraint_coef=args["relation_constraint_coef"], 242 | lift_other=args["lift_other"], 243 | compute_metric_when_training=True, 244 | decoder=BeamDecoder( 245 | hierarchy=hierarchy, 246 | strategies=args["strategies"], 247 | max_branching_factors=args["max_branching_factors"], 248 | delta=args["delta"] 249 | ) 250 | ) 251 | 252 | @classmethod 253 | def from_model_path(cls, model_path: str, device: str = None): 254 | model = cls.from_args(f"{model_path}/args.json") 255 | model.load_state_dict(torch.load(f"{model_path}/best.th")) 256 | if device is not None: 257 | model.cuda(device) 258 | return model 259 | -------------------------------------------------------------------------------- /hiertype/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from hiertype.modules.mention_feature_extractor import MentionFeatureExtractor 2 | from hiertype.modules.type_scorer import TypeScorer 3 | from hiertype.modules.indexed_hinge_loss import IndexedHingeLoss 4 | 5 | from hiertype.modules.compl_ex import ComplEx 6 | from hiertype.modules.relation_constraint_loss import RelationConstraintLoss 7 | -------------------------------------------------------------------------------- /hiertype/modules/compl_ex.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | 4 | 5 | class ComplEx(torch.nn.Module): 6 | 7 | def __init__(self, dim: int): 8 | super(ComplEx, self).__init__() 9 | self.dim = dim 10 | self.rel_emb = torch.nn.Parameter(torch.zeros(dim, dtype=torch.float32)) 11 | torch.nn.init.normal_(self.rel_emb, mean=0.0, std=0.01) 12 | 13 | @staticmethod 14 | def real_as_complex(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 15 | re, im = x.chunk(chunks=2, dim=-1) 16 | return re, im 17 | 18 | def forward(self, 19 | s: torch.Tensor, # R[Batch, Emb] 20 | t: torch.Tensor # R[Batch, Emb] 21 | ) -> torch.Tensor: # R[Batch] 22 | s_re, s_im = ComplEx.real_as_complex(s) 23 | t_re, t_im = ComplEx.real_as_complex(t) 24 | r_re, r_im = ComplEx.real_as_complex(self.rel_emb) 25 | 26 | y = s_re * t_re * r_re + s_im * t_im * r_re + s_re * t_im * r_im - s_im * t_re * r_im 27 | return torch.sum(y, dim=-1) 28 | -------------------------------------------------------------------------------- /hiertype/modules/indexed_hinge_loss.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class IndexedHingeLoss(torch.nn.Module): 8 | 9 | def __init__(self, 10 | margins: torch.Tensor, # R[Level] 11 | ): 12 | super(IndexedHingeLoss, self).__init__() 13 | self.margins = torch.nn.Parameter(margins) 14 | self.margins.requires_grad = False # freeze! 
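        # Keeping the frozen margins as a Parameter (rather than a plain tensor)
        # ensures they are saved in the state_dict and moved along with the module
        # by .to()/.cuda(), while requires_grad = False keeps them fixed during training.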
15 | 
16 |     def forward(self,
17 |                 scores: torch.Tensor,  # R[Batch, Type]
18 |                 pos_type_ids: torch.Tensor,  # Z_Type[Batch, PosCand]
19 |                 neg_type_ids: torch.Tensor,  # Z_Type[Batch, PosCand, NegCand]
20 |                 levels: torch.Tensor,  # Z_Level[Batch, PosCand]
21 |                 margin_ratio: float
22 |                 ):
23 | 
24 |         batch_size = scores.size(0)
25 |         max_pos_type_size = pos_type_ids.size(1)
26 |         max_neg_type_size = neg_type_ids.size(2)
27 |         device = scores.device
28 |         batch_range = torch.arange(0, batch_size, dtype=torch.int64, device=device)  # Z_Batch[Batch]
29 |         zero = torch.tensor(0, dtype=torch.int64, device=device)
30 | 
31 |         neg_mask = (neg_type_ids != -1)  # B[Batch, PosCand, NegCand]
32 |         neg_mask_r = neg_mask.float()  # R[Batch, PosCand, NegCand]
33 |         pos_mask = pos_type_ids != -1
34 | 
35 |         pos_type_ids = torch.where(pos_mask, pos_type_ids, zero)
36 |         neg_type_ids = torch.where(neg_mask, neg_type_ids, zero)
37 | 
38 |         pos_scores = scores[
39 |             batch_range.unsqueeze(dim=1).expand(batch_size, max_pos_type_size),
40 |             pos_type_ids
41 |         ]  # R[Batch, PosCand]
42 | 
43 |         neg_scores = scores[
44 |             batch_range.unsqueeze(dim=1).unsqueeze(dim=2).expand(batch_size, max_pos_type_size, max_neg_type_size),
45 |             neg_type_ids
46 |         ]  # R[Batch, PosCand, NegCand]
47 | 
48 |         level_margins = self.margins[levels] * margin_ratio  # R[Batch, PosCand]
49 | 
50 |         diff = F.relu(
51 |             level_margins.unsqueeze(dim=2).expand_as(neg_scores)
52 |             - pos_scores.unsqueeze(dim=2).expand_as(neg_scores)
53 |             + neg_scores
54 |         )
55 | 
56 |         return (diff * neg_mask_r).sum() / neg_mask_r.sum()
57 | 
-------------------------------------------------------------------------------- /hiertype/modules/mention_feature_extractor.py: --------------------------------------------------------------------------------
1 | from typing import *
2 | import torch
3 | import torch.nn.functional as F
4 | 
5 | from hiertype.data import Hierarchy
6 | from allennlp.nn.util import get_mask_from_sequence_lengths
7 | 
8 | 
9 | class MentionFeatureExtractor(torch.nn.Module):
10 | 
11 |     def __init__(self,
12 |                  hierarchy: Hierarchy,
13 |                  dim: int,
14 |                  dropout_rate: float,
15 |                  mention_pooling: str = "max",  # max / mean / attention
16 |                  with_context: bool = False
17 |                  ):
18 |         super(MentionFeatureExtractor, self).__init__()
19 |         self.hierarchy = hierarchy
20 |         self.mention_pooling = mention_pooling
21 |         self.dim = dim
22 |         self.dropout = torch.nn.Dropout(dropout_rate)
23 |         self.projection = torch.nn.Linear(self.dim, self.dim)
24 |         torch.nn.init.eye_(self.projection.weight)  # start as an identity map
25 |         torch.nn.init.zeros_(self.projection.bias)
26 | 
27 |         if self.mention_pooling == "attention":
28 |             self.query = torch.nn.Parameter(torch.zeros(dim, dtype=torch.float32))
29 |             torch.nn.init.normal_(self.query, mean=0.0, std=0.02)
30 |         self.with_context = with_context
31 |         if self.with_context:
32 |             self.mention_query_transform = torch.nn.Linear(self.dim, self.dim)
33 |             torch.nn.init.normal_(self.mention_query_transform.weight, mean=0.0, std=0.02)
34 |             torch.nn.init.zeros_(self.mention_query_transform.bias)
35 | 
36 |     def forward(self,
37 |                 sentence: torch.Tensor,  # R[Batch, Word, Emb]
38 |                 sentence_lengths: torch.Tensor,  # Z_Word[Batch]
39 |                 span: torch.Tensor,  # R[Batch, Word, Emb]
40 |                 span_lengths: torch.Tensor,  # Z_Word[Batch]
41 |                 span_left: torch.Tensor,  # Z_Word[Batch]
42 |                 span_right: torch.Tensor  # Z_Word[Batch]
43 |                 ) -> torch.Tensor:  # R[Batch, Feature]
44 | 
45 |         batch_size = sentence.size(0)
46 |         sentence_max_len = sentence.size(1)
47 |         emb_size = sentence.size(2)
48 |         span_max_len = span.size(1)
49 | 
device = sentence.device 50 | neg_inf = torch.tensor(-10000, dtype=torch.float32, device=device) 51 | zero = torch.tensor(0, dtype=torch.float32, device=device) 52 | 53 | span = self.projection(self.dropout(span)) 54 | sentence = self.projection(self.dropout(sentence)) 55 | 56 | span_mask = get_mask_from_sequence_lengths(span_lengths, span_lengths.max().item()).byte() # Z[Batch, Word] 57 | 58 | def attention_pool(): 59 | span_attn_scores = torch.einsum('e,bwe->bw', self.query, span) 60 | masked_span_attn_scores = torch.where(span_mask, span_attn_scores, neg_inf) 61 | normalized_span_attn_scores = F.softmax(masked_span_attn_scores, dim=1) 62 | span_pooled = torch.einsum('bwe,bw->be', span, normalized_span_attn_scores) 63 | return span_pooled 64 | 65 | span_pooled = { 66 | "max": lambda: torch.max(torch.where(span_mask.unsqueeze(dim=2).expand_as(span), span, neg_inf), dim=1)[0], 67 | "mean": lambda: torch.sum( 68 | torch.where(span_mask.unsqueeze(dim=2).expand_as(span), span, zero), dim=1 69 | ) / span_lengths.unsqueeze(dim=1).expand(batch_size, emb_size), 70 | "attention": lambda: attention_pool() 71 | }[self.mention_pooling]() # R[Batch, Emb] 72 | 73 | features = span_pooled 74 | 75 | if self.with_context: 76 | sentence_mask = get_mask_from_sequence_lengths(sentence_lengths, sentence_max_len).bool() # B[B, L] 77 | 78 | length_range = torch.arange(0, sentence_max_len, device=device) \ 79 | .unsqueeze(dim=0).expand(batch_size, sentence_max_len) 80 | span_mask = (length_range >= (span_left.unsqueeze(dim=1).expand_as(length_range))) \ 81 | & (length_range < (span_right.unsqueeze(dim=1).expand_as(length_range))) # B[Batch, Length] 82 | 83 | span_queries = self.mention_query_transform(span_pooled) 84 | attn_scores = torch.einsum('be,bwe->bw', span_queries, sentence) # R[Batch, Word] 85 | masked_attn_scores = torch.where(sentence_mask, attn_scores, neg_inf) # R[Batch, Word] & ~span_mask 86 | normalized_attn_scores = F.softmax(masked_attn_scores, dim=1) 87 | context_pooled = torch.einsum('bwe,bw->be', sentence, normalized_attn_scores) # R[Batch, Emb] 88 | 89 | features = torch.cat([span_pooled, context_pooled], dim=1) # R[Batch, Emb*2] 90 | 91 | return features # R[Batch, Emb] 92 | -------------------------------------------------------------------------------- /hiertype/modules/relation_constraint_loss.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from hiertype.modules import ComplEx 6 | 7 | 8 | class RelationConstraintLoss(torch.nn.Module): 9 | 10 | def __init__(self, 11 | type_embeddings: torch.nn.Embedding, 12 | scorer: torch.nn.Module, 13 | ): 14 | super(RelationConstraintLoss, self).__init__() 15 | self.scorer = scorer 16 | self.type_embeddings = type_embeddings 17 | 18 | def forward(self, 19 | subtype_ids: torch.Tensor, # Z_Type[Batch, PosCand] 20 | pos_supertype_ids: torch.Tensor, # Z_Type[Batch, PosCand] 21 | neg_supertype_ids: torch.Tensor, # Z_Type[Batch, PosCand, NegCand] 22 | negative_samples_coef: float = 1.0 23 | ): 24 | 25 | batch_size = subtype_ids.size(0) 26 | zero = torch.tensor(0, dtype=torch.int64, device=subtype_ids.device) 27 | 28 | subtype_mask = subtype_ids != -1 # B[Batch, PosCand] 29 | pos_mask = pos_supertype_ids != -1 # B[Batch, PosCand] 30 | pos_mask_r = pos_mask.float() # R[Batch, PosCand] 31 | neg_mask = neg_supertype_ids != -1 # B[Batch, PosCand, NegCand] 32 | neg_mask_r = neg_mask.float() # R[Batch, PosCand, NegCand] 33 | 34 | subtype_ids = 
torch.where(subtype_mask, subtype_ids, zero) 35 | pos_supertype_ids = torch.where(pos_mask, pos_supertype_ids, zero) 36 | neg_supertype_ids = torch.where(neg_mask, neg_supertype_ids, zero) 37 | 38 | subtype_embs = self.type_embeddings(subtype_ids) # R[Batch, Emb] 39 | pos_supertype_embs = self.type_embeddings(pos_supertype_ids) # R[Batch, PosCand, Emb] 40 | neg_supertype_embs = self.type_embeddings(neg_supertype_ids) # R[Batch, PosCand, NegCand, Emb] 41 | 42 | pos_scores = self.scorer( 43 | subtype_embs, 44 | pos_supertype_embs 45 | ) # R[Batch, PosCand] 46 | 47 | neg_scores = self.scorer( 48 | subtype_embs.unsqueeze(dim=2).expand_as(neg_supertype_embs), 49 | neg_supertype_embs 50 | ) # R[Batch, PosCand, NegCand] 51 | 52 | pos_diff = F.relu(-pos_scores + 1.0) 53 | pos_loss = (pos_diff * pos_mask_r).sum() / pos_mask_r.sum() 54 | 55 | neg_diff = F.relu(neg_scores + 1.0) 56 | neg_loss = (neg_diff * neg_mask_r).sum() / neg_mask_r.sum() * negative_samples_coef 57 | 58 | return pos_loss + neg_loss 59 | -------------------------------------------------------------------------------- /hiertype/modules/type_scorer.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import torch 3 | 4 | 5 | class TypeScorer(torch.nn.Module): 6 | 7 | def __init__(self, 8 | type_embeddings: torch.nn.Embedding, 9 | input_dim: int, 10 | type_dim: int, 11 | bottleneck_dim: int, 12 | dropout_rate: float 13 | ): 14 | super(TypeScorer, self).__init__() 15 | self.type_embeddings = type_embeddings 16 | self.ffnn = torch.nn.Sequential( 17 | torch.nn.Dropout(dropout_rate), 18 | torch.nn.Linear(input_dim, input_dim // 2), 19 | torch.nn.Tanh(), 20 | torch.nn.Linear(input_dim // 2, input_dim // 2), 21 | torch.nn.Tanh(), 22 | torch.nn.Linear(input_dim // 2, type_dim), 23 | torch.nn.Tanh() 24 | ) 25 | self.linear = torch.nn.Linear( 26 | in_features=type_embeddings.embedding_dim, 27 | out_features=type_embeddings.num_embeddings, 28 | bias=True 29 | ) 30 | self.linear.weight = type_embeddings.weight # Put the embeddings into the last layer 31 | self.bottleneck_dim = bottleneck_dim 32 | 33 | if self.bottleneck_dim > 0: 34 | self.bottleneck_weight = torch.nn.Parameter(torch.tensor(0.1)) 35 | self.bottleneck = torch.nn.Sequential( 36 | torch.nn.Linear(type_embeddings.embedding_dim, bottleneck_dim), 37 | torch.nn.Linear(bottleneck_dim, type_embeddings.num_embeddings) 38 | ) 39 | 40 | def forward(self, 41 | features: torch.Tensor 42 | ) -> torch.Tensor: 43 | 44 | mapped_mentions = self.ffnn(features) # R[Batch, Emb] 45 | scores = self.linear(mapped_mentions) # R[Batch, Type] 46 | 47 | if self.bottleneck_dim > 0: 48 | bottleneck_scores = self.bottleneck(mapped_mentions) # R[Batch, Type] 49 | scores = scores + self.bottleneck_weight * bottleneck_scores 50 | 51 | return scores 52 | 53 | -------------------------------------------------------------------------------- /hiertype/training/__init__.py: -------------------------------------------------------------------------------- 1 | from hiertype.training.my_trainer import MyTrainer 2 | -------------------------------------------------------------------------------- /hiertype/training/my_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from allennlp.training import Trainer 3 | 4 | from typing import Any, Dict 5 | import traceback 6 | import torch 7 | import time 8 | import datetime 9 | import logging 10 | import os 11 | from allennlp.training import Trainer 
12 | from allennlp.training import util as training_util 13 | from allennlp.common.util import dump_metrics 14 | from allennlp.common.checks import ConfigurationError 15 | 16 | 17 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 18 | 19 | 20 | class MyTrainer(Trainer): 21 | """ 22 | A modified trainer on top of the original AllenNLP trainer. 23 | Modifications include: 24 | - added a `current_epoch` field to record the current training epoch 25 | - Runs dev before train to 26 | - Capture dev bugs early 27 | - Get dev performance on random model 28 | """ 29 | 30 | def __init__(self, **kwargs): 31 | super(MyTrainer, self).__init__(**kwargs) 32 | self.current_epoch = 0 33 | 34 | def _train_epoch(self, epoch: int) -> Dict[str, float]: 35 | self.current_epoch = epoch # record the current epoch 36 | return super()._train_epoch(epoch) 37 | 38 | def train(self) -> Dict[str, Any]: 39 | """ 40 | Trains the supplied model with the supplied parameters. 41 | """ 42 | try: 43 | epoch_counter = self._restore_checkpoint() 44 | except RuntimeError: 45 | traceback.print_exc() 46 | raise ConfigurationError("Could not recover training from the checkpoint. Did you mean to output to " 47 | "a different serialization directory or delete the existing serialization " 48 | "directory?") 49 | 50 | training_util.enable_gradient_clipping(self.model, self._grad_clipping) 51 | 52 | logger.info("Beginning training.") 53 | 54 | train_metrics: Dict[str, float] = {} 55 | val_metrics: Dict[str, float] = {} 56 | this_epoch_val_metric: float = None 57 | metrics: Dict[str, Any] = {} 58 | epochs_trained = 0 59 | training_start_time = time.time() 60 | 61 | metrics['best_epoch'] = self._metric_tracker.best_epoch 62 | for key, value in self._metric_tracker.best_epoch_metrics.items(): 63 | metrics["best_validation_" + key] = value 64 | 65 | # Now run dev once 66 | if self._validation_data is not None: 67 | with torch.no_grad(): 68 | val_loss, num_batches = self._validation_loss() 69 | val_metrics = training_util.get_metrics(self.model, val_loss, num_batches, reset=True) 70 | self._tensorboard.log_metrics( 71 | train_metrics, 72 | val_metrics=val_metrics, 73 | log_to_console=True, 74 | epoch=0 75 | ) 76 | if self._serialization_dir: 77 | dump_metrics(os.path.join(self._serialization_dir, f'metrics_start.json'), metrics) 78 | # Finish initial dev run 79 | 80 | for epoch in range(epoch_counter, self._num_epochs): 81 | epoch_start_time = time.time() 82 | train_metrics = self._train_epoch(epoch) 83 | 84 | # get peak of memory usage 85 | if 'cpu_memory_MB' in train_metrics: 86 | metrics['peak_cpu_memory_MB'] = max(metrics.get('peak_cpu_memory_MB', 0), 87 | train_metrics['cpu_memory_MB']) 88 | for key, value in train_metrics.items(): 89 | if key.startswith('gpu_'): 90 | metrics["peak_"+key] = max(metrics.get("peak_"+key, 0), value) 91 | 92 | if self._validation_data is not None: 93 | with torch.no_grad(): 94 | # We have a validation set, so compute all the metrics on it. 95 | val_loss, num_batches = self._validation_loss() 96 | val_metrics = training_util.get_metrics(self.model, val_loss, num_batches, reset=True) 97 | 98 | # Check validation metric for early stopping 99 | this_epoch_val_metric = val_metrics[self._validation_metric] 100 | self._metric_tracker.add_metric(this_epoch_val_metric) 101 | 102 | if self._metric_tracker.should_stop_early(): 103 | logger.info("Ran out of patience. 
Stopping training.") 104 | break 105 | 106 | self._tensorboard.log_metrics(train_metrics, 107 | val_metrics=val_metrics, 108 | log_to_console=True, 109 | epoch=epoch + 1) # +1 because tensorboard doesn't like 0 110 | 111 | # Create overall metrics dict 112 | training_elapsed_time = time.time() - training_start_time 113 | metrics["training_duration"] = str(datetime.timedelta(seconds=training_elapsed_time)) 114 | metrics["training_start_epoch"] = epoch_counter 115 | metrics["training_epochs"] = epochs_trained 116 | metrics["epoch"] = epoch 117 | 118 | for key, value in train_metrics.items(): 119 | metrics["training_" + key] = value 120 | for key, value in val_metrics.items(): 121 | metrics["validation_" + key] = value 122 | 123 | if self._metric_tracker.is_best_so_far(): 124 | # Update all the best_ metrics. 125 | # (Otherwise they just stay the same as they were.) 126 | metrics['best_epoch'] = epoch 127 | for key, value in val_metrics.items(): 128 | metrics["best_validation_" + key] = value 129 | 130 | self._metric_tracker.best_epoch_metrics = val_metrics 131 | 132 | if self._serialization_dir: 133 | dump_metrics(os.path.join(self._serialization_dir, f'metrics_epoch_{epoch}.json'), metrics) 134 | 135 | # The Scheduler API is agnostic to whether your schedule requires a validation metric - 136 | # if it doesn't, the validation metric passed here is ignored. 137 | if self._learning_rate_scheduler: 138 | self._learning_rate_scheduler.step(this_epoch_val_metric, epoch) 139 | if self._momentum_scheduler: 140 | self._momentum_scheduler.step(this_epoch_val_metric, epoch) 141 | 142 | self._save_checkpoint(epoch) 143 | 144 | epoch_elapsed_time = time.time() - epoch_start_time 145 | logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time)) 146 | 147 | if epoch < self._num_epochs - 1: 148 | training_elapsed_time = time.time() - training_start_time 149 | estimated_time_remaining = training_elapsed_time * \ 150 | ((self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1) 151 | formatted_time = str(datetime.timedelta(seconds=int(estimated_time_remaining))) 152 | logger.info("Estimated training time remaining: %s", formatted_time) 153 | 154 | epochs_trained += 1 155 | 156 | # Load the best model state before returning 157 | best_model_state = self._checkpointer.best_model_state() 158 | if best_model_state: 159 | self.model.load_state_dict(best_model_state) 160 | 161 | return metrics 162 | -------------------------------------------------------------------------------- /hiertype/util/__init__.py: -------------------------------------------------------------------------------- 1 | from hiertype.util.sample import sample_except, sample_multiple_from, sample_in_range_except 2 | from hiertype.util.compact import compact2, compact3 3 | -------------------------------------------------------------------------------- /hiertype/util/compact.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import numpy as np 3 | 4 | 5 | def compact2( 6 | xss: List[List[int]], 7 | pad: int 8 | ) -> Tuple[np.ndarray, np.ndarray]: 9 | """ 10 | Pads and constructs a compact representation for a List[List[int]]. 
11 |     """
12 |     len1 = [len(xs) for xs in xss]
13 |     max_len1 = max(len1)
14 | 
15 |     c = np.stack([
16 |         np.pad(
17 |             np.array(xs, dtype=np.int64),
18 |             (0, max_len1 - len(xs)),
19 |             mode='constant',
20 |             constant_values=pad
21 |         )
22 |         for xs in xss
23 |     ])
24 | 
25 |     l1 = np.array(len1)
26 |     return l1, c
27 | 
28 | 
29 | def compact3(
30 |         xsss: List[List[List[int]]],
31 |         pad: int
32 | ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
33 |     """
34 |     Pads and constructs a compact representation for a List[List[List[int]]].
35 |     """
36 |     for xss in xsss:
37 |         if len(xss) == 0:
38 |             xss.append([pad])  # ensure at least one row so np.stack succeeds
39 | 
40 |     len1 = [len(xss) for xss in xsss]
41 |     max_len1 = max(len1)
42 |     len2 = [[len(xs) for xs in xss] for xss in xsss]
43 |     max_len2 = max(max(l, default=0) for l in len2)
44 | 
45 |     c = np.stack([
46 |         np.pad(
47 |             np.stack([
48 |                 np.pad(
49 |                     np.array(xs, dtype=np.int64),
50 |                     (0, max_len2 - len(xs)),
51 |                     mode='constant',
52 |                     constant_values=pad
53 |                 )
54 |                 for xs in xss
55 |             ]),
56 |             [(0, max_len1 - len(xss)), (0, 0)],
57 |             mode='constant',
58 |             constant_values=pad
59 |         )
60 |         for xss in xsss
61 |     ])
62 | 
63 |     l1, l2 = compact2(len2, 0)
64 | 
65 |     return l1, l2, c
66 | 
-------------------------------------------------------------------------------- /hiertype/util/sample.py: --------------------------------------------------------------------------------
1 | from typing import *
2 | import numpy as np
3 | 
4 | 
5 | def sample_except(n: int, k: int, excluded: Set[int]) -> np.ndarray:
6 |     """Sample k numbers (with replacement) from {0 ... n - 1}, excluding `excluded`."""
7 | 
8 |     samples = np.random.choice(n - len(excluded), k)
9 |     for x in sorted(excluded):
10 |         samples[samples >= x] += 1  # shift past each excluded value
11 | 
12 |     return samples
13 | 
14 | 
15 | def sample_in_range_except(lo: int, hi: int, k: int, excluded: Set[int]) -> List[int]:
16 |     exc = {x for x in excluded if lo <= x < hi}
17 |     try:
18 |         return list(sample_except(hi - lo, k, {x - lo for x in exc}) + lo)
19 |     except ValueError:  # np.random.choice raises if the candidate range is empty
20 |         return []
21 | 
22 | 
23 | def sample_multiple_from(xs: List[int], n: int) -> List[int]:
24 |     indices = np.random.choice(len(xs), n)
25 |     return [xs[i] for i in indices]
26 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | torch>=1.0.0
2 | h5py
3 | tqdm~=4.43.0
4 | numpy~=1.18.4
5 | allennlp>=0.8.4
6 | pandas~=1.0.4
7 | msgpack~=1.0.0
8 | msgpack-numpy
9 | bsddb3~=6.2.7
10 | spacy>=2.0.0
11 | ansicolors
12 | transformers~=2.5.1
13 | fire~=0.2.1
14 | 
-------------------------------------------------------------------------------- /tapes/aida-data.tape: --------------------------------------------------------------------------------
1 | global {
2 |     aidaPartition=(AidaPartition: practice eval)
3 |     aidaLang=(AidaLang: eng rus ukr all)
4 |     aidaTask=(AidaTask: entity event relation)
5 |     aidaSheetName=(AidaTask: entity=entities event=events relation=relations)
6 |     aidaLdcDataIdentifier=(AidaTask: entity=arg event=evt relation=rel)
7 | }
8 | 
9 | task aidaLtfToConcrete : concreteAida
10 |     < ltf=(AidaPartition: practice=$aidaPracticeSourcePath eval=$aidaEvalSourcePath)
11 |     > concrete
12 | {
13 |     for z in $(ls $ltf); do
14 |         unzip $ltf/$z -d .
15 | done 16 | ls -1 ltf/* > ltf-filelist 17 | mkdir -p $concrete 18 | PYTHONPATH=$concreteAida python $concreteAida/aida2concrete.py \ 19 | --filelist ltf-filelist --output-dir $concrete 20 | } 21 | 22 | plan AidaLtfToConcrete { 23 | reach aidaLtfToConcrete via (AidaPartition: practice eval) 24 | } 25 | 26 | 27 | task aidaOntology : hiertype 28 | < data=$aidaPracticeAnnotationsPath 29 | > ontology="ontology.txt" 30 | :: aidaSheetName=@ 31 | { 32 | python $hiertype/aux/scripts/aida/get-aida-ontology.py \ 33 | --path $data --sheet $aidaSheetName > $ontology 34 | } 35 | 36 | plan AidaOntology { 37 | reach aidaOntology via (AidaTask: entity) 38 | } 39 | 40 | 41 | task aidaMergeLdcTsvs : hiertype 42 | < aidaPracticeAnnotationsPath=@ 43 | < aidaEvalAnnotationsPath=@ 44 | > out 45 | :: aidaLdcDataIdentifier=@ 46 | { 47 | mkdir -p $out 48 | python $hiertype/aux/scripts/aida/merge-tsvs.py \ 49 | $aidaPracticeAnnotationsPath/data/R103/R103_${aidaLdcDataIdentifier}_mentions.tab \ 50 | $aidaPracticeAnnotationsPath/data/R107/R107_${aidaLdcDataIdentifier}_mentions.tab > $out/train.tsv 51 | 52 | python $hiertype/aux/scripts/aida/merge-tsvs.py \ 53 | $aidaPracticeAnnotationsPath/data/R105/R105_${aidaLdcDataIdentifier}_mentions.tab > $out/dev.tsv 54 | 55 | python $hiertype/aux/scripts/aida/merge-tsvs.py \ 56 | $aidaEvalAnnotationsPath/data/E101/E101_${aidaLdcDataIdentifier}_mentions.tab \ 57 | $aidaEvalAnnotationsPath/data/E102/E102_${aidaLdcDataIdentifier}_mentions.tab \ 58 | $aidaEvalAnnotationsPath/data/E103/E103_${aidaLdcDataIdentifier}_mentions.tab > $out/test.tsv 59 | 60 | } 61 | 62 | plan AidaMergeLdcTsvs { 63 | reach aidaMergeLdcTsvs via (AidaLang: eng) * (AidaTask: entity) 64 | } 65 | 66 | 67 | task aidaData : hiertype 68 | < tsv=$out@aidaMergeLdcTsvs 69 | < practiceConcrete=$concrete@aidaLtfToConcrete[AidaPartition:practice] 70 | < evalConcrete=$concrete@aidaLtfToConcrete[AidaPartition:eval] 71 | > out 72 | :: aidaLang=@ 73 | { 74 | mkdir -p $out 75 | # python $hiertype/aux/scripts/aida/resolve-token-index.py \ 76 | # --tsv $tsv/train.tsv --lang $aidaLang --concrete_dir $practiceConcrete | grep -Pv "[ЁёА-я]" > $out/train.tsv 77 | # python $hiertype/aux/scripts/aida/resolve-token-index.py \ 78 | # --tsv $tsv/dev.tsv --lang $aidaLang --concrete_dir $practiceConcrete | grep -Pv "[ЁёА-я]" > $out/dev.tsv 79 | # python $hiertype/aux/scripts/aida/resolve-token-index.py \ 80 | # --tsv $tsv/test.tsv --lang $aidaLang --concrete_dir $evalConcrete | grep -Pv "[ЁёА-я]" > $out/test.tsv 81 | python $hiertype/aux/scripts/aida/resolve-token-index.py \ 82 | --tsv $tsv/train.tsv --lang $aidaLang --concrete_dir $practiceConcrete > $out/train.tsv 83 | python $hiertype/aux/scripts/aida/resolve-token-index.py \ 84 | --tsv $tsv/dev.tsv --lang $aidaLang --concrete_dir $practiceConcrete > $out/dev.tsv 85 | python $hiertype/aux/scripts/aida/resolve-token-index.py \ 86 | --tsv $tsv/test.tsv --lang $aidaLang --concrete_dir $evalConcrete > $out/test.tsv 87 | } 88 | 89 | plan PreprocessAidaData { 90 | reach aidaData via (AidaLang: eng all) * (AidaTask: entity) 91 | reach aidaOntology via (AidaTask: entity) 92 | } 93 | -------------------------------------------------------------------------------- /tapes/bbn-data.tape: -------------------------------------------------------------------------------- 1 | global { 2 | bbnDataUrl="https://drive.google.com/file/d/0B2ke42d0kYFfdVk2ZkJ6TGRzR2M/view" 3 | } 4 | 5 | task bbnData : hiertype 6 | < bbnPath=@ 7 | > out 8 | { 9 | mkdir -p $out 10 | cat $bbnPath/train.json | python 
$hiertype/aux/scripts/bbn/bbn-jsonlines-to-tsv.py > $out/train-dev.tsv 11 | cat $bbnPath/test.json | python $hiertype/aux/scripts/bbn/bbn-jsonlines-to-tsv.py > $out/test.tsv 12 | python $hiertype/aux/scripts/bbn/bbn-gen-dev.py $out/train-dev.tsv $out/train.tsv $out/dev.tsv 13 | } 14 | 15 | plan BBNData { 16 | reach bbnData 17 | } 18 | 19 | task bbnOntology : hiertype 20 | < data=$out@bbnData 21 | > ontology="ontology.txt" 22 | { 23 | cat $data/train.tsv $data/dev.tsv $data/test.tsv | python $hiertype/aux/scripts/bbn/bbn-ontology.py > $ontology 24 | } 25 | 26 | plan BBNOntology { 27 | reach bbnOntology 28 | } 29 | 30 | -------------------------------------------------------------------------------- /tapes/cache.tape: -------------------------------------------------------------------------------- 1 | task cache : hiertype 2 | < data=@ 3 | > dump 4 | :: contextualizer=@ :: unit=@ :: layers=@ 5 | :: grid=@ :: .submitter=$grid .action_flags=@ .resource_flags=@ 6 | :: gpuId=@ 7 | { 8 | mkdir -p $dump 9 | for partition in train dev test; do 10 | PYTHONPATH=$hiertype CUDA_VISIBLE_DEVICES=$gpuId \ 11 | python $hiertype/hiertype/commands/cache_repr.py \ 12 | --input $data/${partition}.tsv --output $dump/${partition}.db \ 13 | --model $contextualizer \ 14 | --unit $unit \ 15 | --layers $layers 16 | done 17 | } 18 | -------------------------------------------------------------------------------- /tapes/env.tape: -------------------------------------------------------------------------------- 1 | package hiertype :: .versioner=disk .path="/home/tongfei/proj/hierarchical-typing-release/" .rev=NULL { 2 | cd aux 3 | sbt clean compile assembly 4 | } 5 | 6 | package concreteAida :: .versioner=disk .path="/home/tongfei/proj/concrete-aida" .rev=NULL {} 7 | 8 | global { 9 | 10 | grid="shell" 11 | gpuId="0" 12 | 13 | .action_flags=(Device: 14 | GPU="" 15 | CPU="" 16 | ) 17 | 18 | .resource_flags=(Device: 19 | GPU="" 20 | CPU="" 21 | ) 22 | 23 | } 24 | 25 | 26 | global { 27 | 28 | aidaLtfPath="/media/tongfei/dump/data/aida2019/Phase_1_Eval_Practice" 29 | aidaLdcData="/media/tongfei/dump/data/aida2019/LDC2019E07_AIDA_Phase_1_Evaluation_Practice_Topic_Annotations_V8.0" 30 | 31 | aidaPracticeSourcePath="/media/tongfei/dump/data/aida/aida-practice-source" 32 | aidaPracticeAnnotationsPath="/media/tongfei/dump/data/aida/aida-practice-annotations" 33 | aidaEvalSourcePath="/media/tongfei/dump/data/aida/aida-eval-source" 34 | aidaEvalAnnotationsPath="/media/tongfei/dump/data/aida/aida-eval-annotations" 35 | 36 | ultrafineCleanedPath="/media/tongfei/dump/data/cleaned-ultrafine/cleaned_data" 37 | 38 | bbnPath="/media/tongfei/dump/data/bbn" 39 | 40 | } -------------------------------------------------------------------------------- /tapes/params.tape: -------------------------------------------------------------------------------- 1 | global { 2 | 3 | dataset=(Dataset: Aida Figer OntoNotes BBN) 4 | partition=(Partition: train dev test) 5 | 6 | unit=(Unit: subword word span sentence) 7 | 8 | data=(Dataset: 9 | Aida=$out@aidaData 10 | Figer=$out@shimaokaData[ShimaokaDataset:Figer] 11 | OntoNotes=$out@shimaokaData[ShimaokaDataset:OntoNotes] 12 | BBN=$out@bbnData 13 | ) 14 | 15 | ontology=(Dataset: 16 | Aida=$ontology@aidaOntology 17 | Figer=$ontology@shimaokaOntology[ShimaokaDataset:Figer] 18 | OntoNotes=$ontology@shimaokaOntology[ShimaokaDataset:OntoNotes] 19 | BBN=$ontology@bbnOntology 20 | ) 21 | 22 | numLevels=(Dataset: Aida=3 Figer=2 OntoNotes=3 BBN=2) 23 | multiLabel=(Dataset: Aida=False Figer=True OntoNotes=True 
BBN=True) 24 | forceOther=(Dataset: Aida=False Figer=False OntoNotes=True BBN=False) 25 | 26 | thresholdRatio=(ThresholdRatio: 0.05 0.1 0.15 0.2 0.25) 27 | metric="+O_MiF" 28 | 29 | contextualizer=(Contextualizer: elmo-original bert-base-cased xlm-roberta-base) 30 | 31 | layers=(Contextualizer: 32 | elmo-original="[0,1,2]" 33 | bert-base-cased="[0,10,11,12]" 34 | xlm-roberta-base="[0,10,11,12]" 35 | ) 36 | 37 | inputDim=3072 38 | typeDim=1024 39 | 40 | embDropoutRate=(EmbDropout: 0.0 0.1 0.2 0.3 0.4 0.5 0.6) 41 | dropoutRate=(Dropout: 0.0 0.1 0.2 0.3) 42 | 43 | margins=(MarginMode: 44 | graded=(Dataset: 45 | Aida="[3,2,1]" Figer="[2,1]" OntoNotes="[3,2,1]" BBN="[2,1]" 46 | ) 47 | flat=(Dataset: 48 | Aida="[1,1,1]" Figer="[1,1]" OntoNotes="[1,1,1]" BBN="[1,1]" 49 | ) 50 | ) 51 | 52 | relationConstraintCoef=(RelConsCoef: 0.0 0.01 0.05 0.1 0.2 0.3) 53 | 54 | bottleneckDim=(BottleneckDim: 0 64 128 256) 55 | 56 | maxBranchingFactors=(Dataset: 57 | Aida="[1,1,1]" 58 | Figer="[2,1]" 59 | OntoNotes="[1,1,1]" 60 | BBN="[1,1]" 61 | ) 62 | 63 | batchSize=256 64 | devBatchSize=256 65 | 66 | numNegativeSamples=0 67 | 68 | numEpochs=(Dataset: Aida=20 Figer=5 OntoNotes=5 BBN=5) 69 | 70 | strategies=(Dataset: 71 | Aida="[top,none,none]" 72 | Figer="[top,none]" 73 | OntoNotes="[other,none,none]" 74 | BBN="[other,none]" 75 | ) 76 | 77 | regularizer=(Regularizer: 0.0001 0.0003 0.001 0.003 0.01 0.03 0.1 0.3) 78 | 79 | delta=(Dataset: 80 | Aida="[0,0,0]" Figer="[3,1]" OntoNotes="[0,1,2.5]" BBN="[0,0]" 81 | ) 82 | otherDelta=(Dataset: 83 | Aida=0 Figer=0 OntoNotes=2 BBN=0 84 | ) 85 | 86 | withOther=(WithOther: False True) 87 | liftOther=(LiftOther: False True) 88 | 89 | } 90 | -------------------------------------------------------------------------------- /tapes/sge.tape: -------------------------------------------------------------------------------- 1 | # Credit to Shuoyang Ding 2 | # https://github.com/shuoyangd/tape4nmt/blob/master/tapes/submitters.tape 3 | # 4 | # COMMANDS: the bash commands from some task 5 | # TASK, REALIZATION, CONFIGURATION: variables passed by ducttape 6 | submitter sge :: action_flags 7 | :: COMMANDS 8 | :: TASK REALIZATION TASK_VARIABLES CONFIGURATION { 9 | action run { 10 | wrapper="ducttape_sge_job.sh" 11 | echo "#!/usr/bin/env bash" >> $wrapper 12 | echo "" >> $wrapper 13 | echo "#$ $resource_flags" >> $wrapper 14 | echo "#$ $action_flags" >> $wrapper 15 | echo "#$ -j y" >> $wrapper 16 | echo "#$ -o $PWD/job.out" >> $wrapper 17 | echo "#$ -e $PWD/job.err" >> $wrapper 18 | echo "#$ -N $TASK[$REALIZATION]$CONFIGURATION" >> $wrapper 19 | echo "" >> $wrapper 20 | 21 | # Bash flags aren't necessarily passed into the scheduler 22 | # so we must re-initialize them 23 | 24 | echo "set -euo pipefail" >> $wrapper 25 | echo "" >> $wrapper 26 | echo "$TASK_VARIABLES" | perl -pe 's/=/="/; s/$/"/' >> $wrapper 27 | 28 | # Setup the virtual environment 29 | cat >> $wrapper <> $wrapper 51 | 52 | echo >> $wrapper 53 | echo "echo \"HOSTNAME: \$(hostname)\"" >> $wrapper 54 | echo "echo" >> $wrapper 55 | echo "echo CUDA in ENV:" >> $wrapper 56 | echo "env | grep CUDA" >> $wrapper 57 | echo "env | grep SGE" >> $wrapper 58 | echo >> $wrapper 59 | 60 | echo "$COMMANDS" >> $wrapper 61 | echo "echo \$? 
62 |
63 | # qsub returns immediately, so capture the SGE job ID from its output
64 | qsub -V -S /bin/bash $wrapper | grep -Eo "Your job [0-9]+" | grep -Eo "[0-9]+" > $PWD/job_id
65 | job_id=`cat $PWD/job_id`
66 |
67 | # async job killer
68 | exitfn () {
69 | trap SIGINT # reset the INT trap to its default
70 | echo "interrupt received; killing job $job_id"
71 | qdel $job_id
72 | exit
73 | }
74 |
75 | trap "exitfn" INT
76 |
77 | # don't use -sync y; instead, poll until the job leaves the queue, then read exitcode
78 | while [ ! -z "`qstat -u $USER | grep $job_id`" ]
79 | do
80 | sleep 15
81 | done
82 |
83 | trap SIGINT # restore default INT handling
84 |
85 | # propagate the exit code saved by the inner process
86 | EXITCODE=$(cat $PWD/exitcode)
87 | [ "$EXITCODE" = "0" ]
88 | }
89 | }
90 |
-------------------------------------------------------------------------------- /tapes/shimaoka-data.tape: --------------------------------------------------------------------------------
1 | global {
2 | shimaokaDataUrl="http://www.cl.ecei.tohoku.ac.jp/~shimaoka/corpus.zip"
3 | shimaokaDataset=(ShimaokaDataset: Figer OntoNotes)
4 | }
5 |
6 | task shimaokaDataDownload : hiertype
7 | :: shimaokaDataUrl=@
8 | > corpus
9 | {
10 | wget $shimaokaDataUrl
11 | unzip corpus.zip
12 | rm corpus.zip
13 | }
14 |
15 | plan ShimaokaDataDownload {
16 | reach shimaokaDataDownload
17 | }
18 |
19 | task shimaokaData : hiertype
20 | < data=$corpus@shimaokaDataDownload
21 | > out
22 | :: dir=(ShimaokaDataset: Figer=Wiki OntoNotes=OntoNotes)
23 | {
24 | mkdir -p $out
25 | java -cp $hiertype/aux/target/*/*.jar hiertype.PreprocessShimaokaData $data/$dir/train.txt \
26 | | sed 's/geograpy/geography/g' | sed 's|livingthing|living_thing|g' > $out/train.tsv
27 | java -cp $hiertype/aux/target/*/*.jar hiertype.PreprocessShimaokaData $data/$dir/dev.txt \
28 | | sed 's/geograpy/geography/g' | sed 's|livingthing|living_thing|g' > $out/dev.tsv
29 | java -cp $hiertype/aux/target/*/*.jar hiertype.PreprocessShimaokaData $data/$dir/test.txt \
30 | | sed 's/geograpy/geography/g' | sed 's|livingthing|living_thing|g' > $out/test.tsv
31 | }
32 |
33 | task shimaokaOntology : hiertype
34 | < data=$out@shimaokaData
35 | > ontology="ontology.txt"
36 | {
37 | java -cp $hiertype/aux/target/*/*.jar hiertype.GetHierarchy $data/train.tsv > $ontology
38 | }
39 |
40 |
41 | plan PreprocessShimaokaData {
42 | reach shimaokaData via (ShimaokaDataset: *)
43 | reach shimaokaOntology via (ShimaokaDataset: *)
44 | }
45 |
46 |
47 | task ontoNotesOtherRemovedData : hiertype
48 | < data=$out@shimaokaData[ShimaokaDataset:OntoNotes]
49 | > out
50 | {
51 | mkdir -p $out
52 | cat $data/train.tsv | python $hiertype/aux/scripts/ontonotes/remove_other.py > $out/train.tsv
53 | cat $data/dev.tsv | python $hiertype/aux/scripts/ontonotes/remove_other.py > $out/dev.tsv
54 | cat $data/test.tsv | python $hiertype/aux/scripts/ontonotes/remove_other.py > $out/test.tsv
55 | }
56 |
57 | task ontoNotesOtherRemovedOntology
58 | < original=$ontology@shimaokaOntology[ShimaokaDataset:OntoNotes]
59 | > ontology="ontology.txt"
60 | {
61 | cat $original | sed -e 's|/other||g' > $ontology
62 | }
-------------------------------------------------------------------------------- /tapes/test.tape: --------------------------------------------------------------------------------
1 |
2 | task test : hiertype
3 | < data=@
4 | < dump=$dump@cache
5 | < ontology=@
6 | < model=$out@train
7 | > out
8 | > metrics="metrics.txt"
9 | :: maxBranchingFactors=@ :: delta=@ :: strategies=@ :: otherDelta=@
10 | :: grid=@
11 | :: .submitter=$grid
.action_flags=@ .resource_flags=@ 12 | :: gpuId=@ 13 | { 14 | mkdir -p $out 15 | PYTHONPATH=$hiertype CUDA_VISIBLE_DEVICES=$gpuId \ 16 | python $hiertype/hiertype/commands/run.py \ 17 | --test $dump/test.db:$data/test.tsv \ 18 | --model_dir $model \ 19 | --out $out \ 20 | --gpuid 0 \ 21 | --strategies $strategies \ 22 | --max_branching_factors $maxBranchingFactors \ 23 | --delta $delta \ 24 | --other_delta $otherDelta \ 25 | > $metrics 26 | } 27 | -------------------------------------------------------------------------------- /tapes/train.tape: -------------------------------------------------------------------------------- 1 | task train : hiertype 2 | < data=@ 3 | < dump=$dump@cache 4 | < ontology=@ 5 | > out 6 | :: contextualizer=@ :: inputDim=@ :: typeDim=@ :: bottleneckDim=@ :: embDropoutRate=@ :: dropoutRate=@ 7 | :: margins=@ :: thresholdRatio=@ :: relationConstraintCoef=@ :: numNegativeSamples=@ 8 | :: withOther=@ :: liftOther=@ 9 | :: maxBranchingFactors=@ :: strategies=@ :: delta=@ :: otherDelta=@ 10 | :: regularizer=@ :: batchSize=@ :: devBatchSize=@ :: numEpochs=@ 11 | :: grid=@ 12 | :: .submitter=$grid .action_flags=@ .resource_flags=@ 13 | :: gpuId=@ 14 | { 15 | mkdir -p $out 16 | PYTHONPATH=$hiertype CUDA_VISIBLE_DEVICES=$gpuId \ 17 | python $hiertype/hiertype/commands/train.py \ 18 | --train $dump/train.db:$data/train.tsv \ 19 | --dev $dump/dev.db:$data/dev.tsv \ 20 | --ontology $ontology \ 21 | --out $out \ 22 | --contextualizer $contextualizer --input_dim $inputDim --type_dim $typeDim --bottleneck_dim $bottleneckDim \ 23 | --emb_dropout_rate $embDropoutRate --dropout_rate $dropoutRate --margins $margins --threshold_ratio $thresholdRatio \ 24 | --with_other $withOther --lift_other $liftOther \ 25 | --relation_constraint_coef $relationConstraintCoef \ 26 | --num_negative_samples $numNegativeSamples \ 27 | --max_branching_factors $maxBranchingFactors --delta $delta --other_delta $otherDelta --strategies $strategies \ 28 | --regularizer $regularizer \ 29 | --dev_metric +O_MiF \ 30 | --batch_size $batchSize --dev_batch_size $devBatchSize --num_epochs $numEpochs 31 | } 32 | 33 | plan TrainBBN { 34 | reach train via (Dataset: BBN) * (Contextualizer: elmo-original) * (EmbDropout: 0.6) * (ThresholdRatio: 0.2) 35 | * (RelConsCoef: 0.1) * (Regularizer: 0.003) * (MarginMode: graded) * (WithOther: True) * (LiftOther: True) 36 | } 37 | 38 | plan TrainAida { 39 | reach train via (Dataset: Aida) * (Contextualizer: elmo-original) * (EmbDropout: 0.6) * (ThresholdRatio: 0.1) 40 | * (RelConsCoef: 0.3) * (Regularizer: 0.1) * (MarginMode: graded) * (WithOther: True) * (LiftOther: False) 41 | } 42 | 43 | plan TrainFiger { 44 | reach train via (Dataset: Figer) * (Contextualizer: elmo-original) * (EmbDropout: 0.5) * (Dropout: 0.1 0.2 0.3) * 45 | (ThresholdRatio: 0.2) * (RelConsCoef: 0.1) * (Regularizer: 0.0001) * (WithOther: True) * (LiftOther: False) 46 | } 47 | 48 | plan TrainOntoNotes { 49 | reach train via (Dataset: OntoNotes) * (Contextualizer: elmo-original) * (EmbDropout: 0.6) * (Dropout: 0.1) * 50 | (ThresholdRatio: 0.15) * (RelConsCoef: 0.1) * (Regularizer: 0.001) * (MarginMode: graded) * (WithOther: False) 51 | } 52 | --------------------------------------------------------------------------------
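Usage sketch: the plans declared in these tapes are executed with ducttape. The
invocation below is a hypothetical example rather than a command taken from the
repository: it assumes ducttape is on PATH, that hiertype.tape at the
repository root is the top-level workflow importing the files under tapes/, and
that the hard-coded absolute paths in tapes/env.tape have first been adapted to
the local machine.

    # Preprocess the BBN data, then train with the tuned BBN hyperparameters.
    # -p selects one of the plans declared above; -y skips the confirmation prompt.
    ducttape hiertype.tape -p BBNData -y
    ducttape hiertype.tape -p TrainBBN -y

With grid="shell" (the default set in tapes/env.tape) every task runs locally;
setting grid to sge instead routes each task through the sge submitter above,
so it is wrapped in a generated script and dispatched via qsub.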