├── src
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   └── dataset_el.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── entity_detection.py
│   │   ├── entity_linking.py
│   │   └── efficient_el.py
│   ├── utils.py
│   └── beam_search.py
├── requirement.txt
├── LICENSE
├── scripts
│   └── train.py
├── .gitignore
├── README.md
└── notebooks
    ├── README.md
    └── Test.ipynb
/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7
2 | pytorch_lightning>=1.3
3 | transformers>=4.0
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Nicola De Cao
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | from pprint import pprint 4 | 5 | from pytorch_lightning import LightningModule, Trainer 6 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 7 | from pytorch_lightning.loggers import TensorBoardLogger 8 | from pytorch_lightning.utilities.seed import seed_everything 9 | 10 | from src.model.efficient_el import EfficientEL 11 | 12 | if __name__ == "__main__": 13 | parser = ArgumentParser() 14 | 15 | parser.add_argument("--dirpath", type=str, default="models") 16 | parser.add_argument("--save_top_k", type=int, default=10) 17 | parser.add_argument("--seed", type=int, default=0) 18 | 19 | parser = EfficientEL.add_model_specific_args(parser) 20 | parser = Trainer.add_argparse_args(parser) 21 | 22 | args, _ = parser.parse_known_args() 23 | pprint(args.__dict__) 24 | 25 | seed_everything(seed=args.seed) 26 | 27 | logger = TensorBoardLogger(args.dirpath, name=None) 28 | 29 | callbacks = [ 30 | ModelCheckpoint( 31 | mode="max", 32 | monitor="micro_f1", 33 | dirpath=os.path.join(logger.log_dir, "checkpoints"), 34 | save_top_k=args.save_top_k, 35 | filename="model-epoch={epoch:02d}-micro_f1={micro_f1:.4f}-ed_micro_f1={ed_micro_f1:.4f}", 36 | ), 37 | LearningRateMonitor( 38 | logging_interval="step", 39 | ), 40 | ] 41 | 42 | trainer = Trainer.from_argparse_args(args, logger=logger, callbacks=callbacks) 43 | 44 | model = EfficientEL(**vars(args)) 45 | 46 | trainer.fit(model) 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Highly Parallel Autoregressive Entity Linking
with Discriminative Correction
2 | 
3 | ## Overview
4 | 
5 | This repository contains the PyTorch implementation of [[1]](#citation) ([arXiv](https://arxiv.org/abs/2109.03792)).
6 | 
7 | Here is the [link](https://mega.nz/folder/l4RhnIxL#_oYvidq2qyDIw1sT-KeMQA) to the **pre-processed data** used in this work (i.e., the training, validation and test splits of AIDA, as well as the KB with the entities) and to the **released model**.
8 | 
9 | ## Dependencies
10 | 
11 | * **python>=3.8**
12 | * **pytorch>=1.7**
13 | * **pytorch_lightning>=1.3**
14 | * **transformers>=4.0**
15 | 
16 | ## Structure
17 | * [src](https://github.com/nicola-decao/efficient-autoregressive-EL/tree/master/src): The source code of the model. In [src/data](https://github.com/nicola-decao/efficient-autoregressive-EL/tree/master/src/data) there is a dataset class for Entity Linking. In [src/model](https://github.com/nicola-decao/efficient-autoregressive-EL/tree/master/src/model) there are three classes that implement our EL model: one for the Entity Detection part, one for the (autoregressive) Entity Linking part, and one for the entire model (which also contains the training and validation loops).
18 | * [notebooks](https://github.com/nicola-decao/efficient-autoregressive-EL/tree/master/notebooks): Example code for loading our Entity Linking model, evaluating it on AIDA, and running inference on a test document.
19 | 
20 | ## Usage
21 | Please have a look at the [notebooks](https://github.com/nicola-decao/efficient-autoregressive-EL/tree/master/notebooks) folder to see how to load our Entity Linking model, evaluate it on AIDA, and run inference on a test document.
22 | 
23 | Here is a minimal example that demonstrates how to use our model:
24 | ```python
25 | from src.model.efficient_el import EfficientEL
26 | from IPython.display import Markdown
27 | from src.utils import get_markdown
28 | 
29 | # loading the model on GPU and setting the threshold to the
30 | # optimal value (based on the AIDA validation set)
31 | model = EfficientEL.load_from_checkpoint("../models/model.ckpt").eval().cuda()
32 | model.hparams.threshold = -3.2
33 | 
34 | # loading the KB with the entities
35 | model.generate_global_trie()
36 | 
37 | # document to which we want to apply EL
38 | s = """CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY . LONDON 1996-08-30 \
39 | West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset \
40 | by an innings and 39 runs in two days to take over at the head of the county championship ."""
41 | 
42 | # getting spans from the model and converting the result into Markdown for visualization
43 | Markdown(
44 |     get_markdown(
45 |         [s],
46 |         [[(s[0], s[1], s[2][0][0]) for s in spans]
47 |          for spans in model.sample([s])]
48 |     )[0]
49 | )
50 | ```
51 | This will generate:
52 | 
53 | > CRICKET - [LEICESTERSHIRE](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) TAKE OVER AT TOP AFTER INNINGS VICTORY . [LONDON](https://en.wikipedia.org/wiki/London) 1996-08-30 [West Indian](https://en.wikipedia.org/wiki/West_Indies) all-rounder [Phil Simmons](https://en.wikipedia.org/wiki/Philip_Walton) took four for 38 on Friday as [Leicestershire](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) beat [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) by an innings and 39 runs in two days to take over at the head of the county championship .
54 | 
55 | 
56 | Please cite [[1]](#citation) in your work when using this library in your experiments.
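
If you need the raw predictions instead of the Markdown rendering, you can consume the output of `model.sample` directly. Below is a minimal sketch; the exact return format is an assumption inferred from the indexing above (each span appears to be a `(start_char, end_char, candidates)` tuple, with `candidates` a best-first list of `(entity_title, score)` pairs):

```python
# Hedged sketch of consuming `model.sample` output; the span layout
# (start_char, end_char, [(entity_title, score), ...]) is inferred from
# the `(s[0], s[1], s[2][0][0])` indexing in the example above.
for start, end, candidates in model.sample([s])[0]:
    mention = s[start:end]                 # surface form of the mention
    top_entity, top_score = candidates[0]  # best-scoring KB entity
    print(f"{mention!r} -> {top_entity} (score={top_score:.2f})")
```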
57 | 
58 | ### Training
59 | 
60 | To train our model you can run the following command:
61 | ```bash
62 | python scripts/train.py --gpus ${NUM_GPUS} --accelerator ddp --batch_size 32
63 | ```
64 | 
65 | ## Feedback
66 | For questions and comments, feel free to contact [Nicola De Cao](mailto:nicola.decao@gmail.com).
67 | 
68 | ## License
69 | MIT
70 | 
71 | ## Citation
72 | ```
73 | [1] De Cao, N., Aziz, W., & Titov, I. (2021).
74 | Highly parallel autoregressive entity linking with discriminative correction.
75 | Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, 7662–7669.
76 | https://doi.org/10.18653/v1/2021.emnlp-main.604
77 | ```
78 | 
79 | BibTeX format:
80 | ```
81 | @inproceedings{de-cao-etal-2021-highly,
82 |     title = "Highly Parallel Autoregressive Entity Linking with Discriminative Correction",
83 |     author = "De Cao, Nicola and
84 |       Aziz, Wilker and
85 |       Titov, Ivan",
86 |     booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
87 |     month = nov,
88 |     year = "2021",
89 |     address = "Online and Punta Cana, Dominican Republic",
90 |     publisher = "Association for Computational Linguistics",
91 |     url = "https://aclanthology.org/2021.emnlp-main.604",
92 |     doi = "10.18653/v1/2021.emnlp-main.604",
93 |     pages = "7662--7669",
94 | }
95 | ```
96 | 
97 | 
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_lightning.metrics import Metric
3 | 
4 | 
5 | class MicroF1(Metric):
6 |     def __init__(self, dist_sync_on_step=False):
7 |         super().__init__(dist_sync_on_step=dist_sync_on_step)
8 | 
9 |         self.add_state("n", default=torch.tensor(0), dist_reduce_fx="sum")
10 |         self.add_state("prec_d", default=torch.tensor(0), dist_reduce_fx="sum")
11 |         self.add_state("rec_d", default=torch.tensor(0), dist_reduce_fx="sum")
12 | 
13 |     def update(self, p, g):
14 |         # p: set of predicted spans, g: set of gold spans
15 |         self.n += len(g.intersection(p))
16 |         self.prec_d += len(p)
17 |         self.rec_d += len(g)
18 | 
19 |     def compute(self):
20 |         p = self.n.float() / self.prec_d
21 |         r = self.n.float() / self.rec_d
22 |         return (2 * p * r / (p + r)) if (p + r) > 0 else (p + r)
23 | 
24 | 
25 | class MacroF1(Metric):
26 |     def __init__(self, dist_sync_on_step=False):
27 |         super().__init__(dist_sync_on_step=dist_sync_on_step)
28 | 
29 |         self.add_state("n", default=torch.tensor(0.0), dist_reduce_fx="sum")
30 |         self.add_state("d", default=torch.tensor(0), dist_reduce_fx="sum")
31 | 
32 |     def update(self, p, g):
33 |         # per-document F1, averaged over documents in compute()
34 |         prec = len(g.intersection(p)) / len(p)
35 |         rec = len(g.intersection(p)) / len(g)
36 | 
37 |         self.n += (2 * prec * rec / (prec + rec)) if (prec + rec) > 0 else (prec + rec)
38 |         self.d += 1
39 | 
40 |     def compute(self):
41 |         return (self.n / self.d) if self.d > 0 else self.d
42 | 
43 | 
44 | class MicroPrecision(Metric):
45 |     def __init__(self, dist_sync_on_step=False):
46 |         super().__init__(dist_sync_on_step=dist_sync_on_step)
47 | 
48 |         self.add_state("n", default=torch.tensor(0), dist_reduce_fx="sum")
49 |         self.add_state("d", default=torch.tensor(0), dist_reduce_fx="sum")
50 | 
51 |     def update(self, p, g):
52 |         self.n += len(g.intersection(p))
53 |         self.d += len(p)
54 | 
55 |     def compute(self):
56 |         return (self.n.float() / self.d) if self.d > 0 else self.d
57 | 
58 | 
59 | class MacroPrecision(Metric):
60 |     def __init__(self, dist_sync_on_step=False):
61 |         super().__init__(dist_sync_on_step=dist_sync_on_step)
62 | 
63 |         self.add_state("n",
default=torch.tensor(0.0), dist_reduce_fx="sum") 64 | self.add_state("d", default=torch.tensor(0), dist_reduce_fx="sum") 65 | 66 | def update(self, p, g): 67 | self.n += len(g.intersection(p)) / len(p) 68 | self.d += 1 69 | 70 | def compute(self): 71 | return (self.n / self.d) if self.d > 0 else self.d 72 | 73 | 74 | class MicroRecall(Metric): 75 | def __init__(self, dist_sync_on_step=False): 76 | super().__init__(dist_sync_on_step=dist_sync_on_step) 77 | 78 | self.add_state("n", default=torch.tensor(0), dist_reduce_fx="sum") 79 | self.add_state("d", default=torch.tensor(0), dist_reduce_fx="sum") 80 | 81 | def update(self, p, g): 82 | self.n += len(g.intersection(p)) 83 | self.d += len(g) 84 | 85 | def compute(self): 86 | return (self.n.float() / self.d) if self.d > 0 else self.d 87 | 88 | 89 | class MacroRecall(Metric): 90 | def __init__(self, dist_sync_on_step=False): 91 | super().__init__(dist_sync_on_step=dist_sync_on_step) 92 | 93 | self.add_state("n", default=torch.tensor(0.0), dist_reduce_fx="sum") 94 | self.add_state("d", default=torch.tensor(0), dist_reduce_fx="sum") 95 | 96 | def update(self, p, g): 97 | self.n += len(g.intersection(p)) / len(g) 98 | self.d += 1 99 | 100 | def compute(self): 101 | return (self.n / self.d) if self.d > 0 else self.d 102 | 103 | 104 | def get_markdown(sentences, entity_spans): 105 | return_outputs = [] 106 | for sent, entities in zip(sentences, entity_spans): 107 | text = "" 108 | last_end = 0 109 | for begin, end, href in entities: 110 | text += sent[last_end:begin] 111 | text += "[{}](https://en.wikipedia.org/wiki/{})".format( 112 | sent[begin:end], href.replace(" ", "_") 113 | ) 114 | last_end = end 115 | 116 | text += sent[last_end:] 117 | return_outputs.append(text) 118 | 119 | return return_outputs 120 | 121 | 122 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True): 123 | if target.dim() == lprobs.dim() - 1: 124 | target = target.unsqueeze(-1) 125 | nll_loss = -lprobs.gather(dim=-1, index=target) 126 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True) 127 | if ignore_index is not None: 128 | pad_mask = target.eq(ignore_index) 129 | nll_loss.masked_fill_(pad_mask, 0.0) 130 | smooth_loss.masked_fill_(pad_mask, 0.0) 131 | else: 132 | nll_loss = nll_loss.squeeze(-1) 133 | smooth_loss = smooth_loss.squeeze(-1) 134 | if reduce: 135 | nll_loss = nll_loss.sum() 136 | smooth_loss = smooth_loss.sum() 137 | eps_i = epsilon / (lprobs.size(-1) - 1) 138 | loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss 139 | return loss, nll_loss 140 | -------------------------------------------------------------------------------- /src/data/dataset_el.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import jsonlines 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class DatasetEL(Dataset): 10 | def __init__( 11 | self, 12 | tokenizer, 13 | data_path, 14 | max_length=32, 15 | max_length_span=15, 16 | test=False, 17 | ): 18 | super().__init__() 19 | self.tokenizer = tokenizer 20 | 21 | with jsonlines.open(data_path) as f: 22 | self.data = list(f) 23 | 24 | self.max_length = max_length 25 | self.max_length_span = max_length_span 26 | self.test = test 27 | 28 | def __len__(self): 29 | return len(self.data) 30 | 31 | def __getitem__(self, item): 32 | return self.data[item] 33 | 34 | def collate_fn(self, batch): 35 | 36 | batch = { 37 | **{ 38 | f"src_{k}": v 39 | for k, v in self.tokenizer( 40 | [b["input"] for b in batch], 41 | 
return_tensors="pt", 42 | padding=True, 43 | max_length=self.max_length, 44 | truncation=True, 45 | return_offsets_mapping=True, 46 | ).items() 47 | }, 48 | "offsets_start": ( 49 | [ 50 | i 51 | for i, b in enumerate(batch) 52 | for a in b["anchors"] 53 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span 54 | ], 55 | [ 56 | a[0] 57 | for i, b in enumerate(batch) 58 | for a in b["anchors"] 59 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span 60 | ], 61 | ), 62 | "offsets_end": ( 63 | [ 64 | i 65 | for i, b in enumerate(batch) 66 | for a in b["anchors"] 67 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span 68 | ], 69 | [ 70 | a[1] 71 | for i, b in enumerate(batch) 72 | for a in b["anchors"] 73 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span 74 | ], 75 | ), 76 | "offsets_inside": ( 77 | [ 78 | i 79 | for i, b in enumerate(batch) 80 | for a in b["anchors"] 81 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span 82 | for j in range(a[0] + 1, a[1] + 1) 83 | ], 84 | [ 85 | j 86 | for i, b in enumerate(batch) 87 | for a in b["anchors"] 88 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span 89 | for j in range(a[0] + 1, a[1] + 1) 90 | ], 91 | ), 92 | "raw": batch, 93 | } 94 | 95 | if not self.test: 96 | 97 | negatives = [ 98 | np.random.choice([e for e in cands if e != a[2]]) 99 | if len([e for e in cands if e != a[2]]) > 0 100 | else None 101 | for b in batch["raw"] 102 | for a, cands in zip(b["anchors"], b["candidates"]) 103 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span 104 | ] 105 | 106 | targets = [ 107 | a[2] 108 | for b in batch["raw"] 109 | for a in b["anchors"] 110 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span 111 | ] 112 | 113 | assert len(targets) == len(negatives) 114 | 115 | batch_upd = { 116 | **( 117 | { 118 | f"trg_{k}": v 119 | for k, v in self.tokenizer( 120 | targets, 121 | return_tensors="pt", 122 | padding=True, 123 | max_length=self.max_length, 124 | truncation=True, 125 | ).items() 126 | } 127 | if not self.test 128 | else {} 129 | ), 130 | **( 131 | { 132 | f"neg_{k}": v 133 | for k, v in self.tokenizer( 134 | [e for e in negatives if e], 135 | return_tensors="pt", 136 | padding=True, 137 | max_length=self.max_length, 138 | truncation=True, 139 | ).items() 140 | } 141 | if not self.test 142 | else {} 143 | ), 144 | "neg_mask": torch.tensor([e is not None for e in negatives]), 145 | } 146 | 147 | batch = {**batch, **batch_upd} 148 | 149 | return batch 150 | -------------------------------------------------------------------------------- /src/model/entity_detection.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import torch 4 | 5 | 6 | class EntityDetectionFactor(torch.nn.Module): 7 | def __init__(self, max_length_span, dropout=0, mentions_filename=None): 8 | super().__init__() 9 | 10 | self.max_length_span = max_length_span 11 | self.classifier_start = torch.nn.Sequential( 12 | torch.nn.LayerNorm(768), 13 | torch.nn.Dropout(dropout), 14 | torch.nn.Linear(768, 128), 15 | torch.nn.ReLU(), 16 | torch.nn.LayerNorm(128), 17 | torch.nn.Dropout(dropout), 18 | torch.nn.Linear(128, 1), 19 | ) 20 | self.classifier_end = torch.nn.Sequential( 21 | torch.nn.LayerNorm(768 * 2), 22 | torch.nn.Dropout(dropout), 23 | torch.nn.Linear(768 * 2, 128), 24 | torch.nn.ReLU(), 25 | torch.nn.LayerNorm(128), 26 | torch.nn.Dropout(dropout), 27 | torch.nn.Linear(128, 1), 28 | ) 29 | 30 | self.mentions = None 
31 | if mentions_filename: 32 | with open(mentions_filename) as f: 33 | self.mentions = set(json.load(f)) 34 | 35 | def _forward_start(self, batch, hidden_states): 36 | return self.classifier_start(hidden_states).squeeze(-1) 37 | 38 | def _forward_end(self, batch, hidden_states, offsets_start): 39 | 40 | classifier_end_input = torch.nn.functional.pad( 41 | hidden_states, (0, 0, 0, self.max_length_span - 1) 42 | ) 43 | 44 | classifier_end_input = torch.cat( 45 | ( 46 | hidden_states[offsets_start] 47 | .unsqueeze(-1) 48 | .repeat(1, 1, self.max_length_span), 49 | classifier_end_input.unfold(1, self.max_length_span, 1)[offsets_start], 50 | ), 51 | dim=1, 52 | ).permute(0, 2, 1) 53 | 54 | logits_classifier_end = self.classifier_end(classifier_end_input).squeeze(-1) 55 | 56 | mask = torch.cat( 57 | ( 58 | batch["src_attention_mask"], 59 | torch.zeros( 60 | ( 61 | batch["src_attention_mask"].shape[0], 62 | self.max_length_span - 1, 63 | ), 64 | dtype=torch.float, 65 | device=hidden_states.device, 66 | ), 67 | ), 68 | dim=1, 69 | ) 70 | mask = torch.where( 71 | mask.bool(), 72 | torch.zeros_like(mask), 73 | -torch.full_like(mask, float("inf")), 74 | ).unfold(1, self.max_length_span, 1)[offsets_start] 75 | 76 | return logits_classifier_end + mask 77 | 78 | def forward(self, batch, hidden_states): 79 | 80 | logits_classifier_start = self._forward_start(batch, hidden_states) 81 | offsets_start = batch["offsets_start"] 82 | logits_classifier_end = self._forward_end(batch, hidden_states, offsets_start) 83 | 84 | return logits_classifier_start, logits_classifier_end 85 | 86 | def forward_hard(self, batch, hidden_states, threshold=0): 87 | 88 | logits_classifier_start = self._forward_start(batch, hidden_states) 89 | offsets_start = logits_classifier_start > threshold 90 | logits_classifier_end = self._forward_end(batch, hidden_states, offsets_start) 91 | 92 | start = offsets_start.nonzero() 93 | end = start.clone() 94 | 95 | scores = None 96 | if logits_classifier_end.shape[0] > 0: 97 | end[:, 1] += logits_classifier_end.argmax(-1) 98 | scores = ( 99 | logits_classifier_start[offsets_start] 100 | + logits_classifier_end.max(-1).values 101 | ) 102 | if self.mentions: 103 | mention_mask = torch.tensor( 104 | [ 105 | ( 106 | batch["raw"][i]["input"][ 107 | batch["src_offset_mapping"][i][s][0] 108 | .item() : batch["src_offset_mapping"][i][e][1] 109 | .item() 110 | ] 111 | in self.mentions 112 | ) 113 | for (i, s), (_, e) in zip(start, end) 114 | ], 115 | device=start.device, 116 | ) 117 | 118 | start = start[mention_mask] 119 | end = end[mention_mask] 120 | 121 | return (start, end, scores), ( 122 | logits_classifier_start, 123 | logits_classifier_end, 124 | ) 125 | 126 | def forward_loss(self, batch, hidden_states): 127 | logits_classifier_start, logits_classifier_end = self.forward( 128 | batch, hidden_states 129 | ) 130 | 131 | batch["labels_start"] = torch.zeros_like(batch["src_input_ids"]) 132 | batch["labels_start"][batch["offsets_start"]] = 1 133 | 134 | loss_start = torch.nn.functional.binary_cross_entropy_with_logits( 135 | logits_classifier_start, 136 | batch["labels_start"].float(), 137 | weight=batch["src_attention_mask"], 138 | ) 139 | 140 | batch["labels_end"] = torch.tensor( 141 | [b - a for a, b in zip(batch["offsets_start"][1], batch["offsets_end"][1])], 142 | device=logits_classifier_start.device, 143 | ) 144 | 145 | loss_end = torch.nn.functional.cross_entropy( 146 | logits_classifier_end, 147 | batch["labels_end"], 148 | ) 149 | 150 | return loss_start, loss_end 151 | 
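# ------------------------------------------------------------------
# Usage sketch (illustrative, hedged): how this factor is typically
# driven at inference time. It assumes 768-dim encoder hidden states
# (the size hard-coded above) and a `batch` built by
# DatasetEL.collate_fn; the hyper-parameter values are placeholders.
#
#   detector = EntityDetectionFactor(max_length_span=15, dropout=0.1)
#   (start, end, scores), _ = detector.forward_hard(
#       batch, hidden_states, threshold=0
#   )
#
# `start` and `end` are (batch_idx, token_idx) pairs marking predicted
# mention boundaries; `scores` adds each span's start and end logits.
# ------------------------------------------------------------------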
-------------------------------------------------------------------------------- /notebooks/README.md: -------------------------------------------------------------------------------- 1 | ```python 2 | %load_ext autoreload 3 | %autoreload 2 4 | 5 | import sys 6 | sys.path.append("../") 7 | ``` 8 | 9 | 10 | ```python 11 | from argparse import ArgumentParser 12 | from pytorch_lightning import Trainer 13 | from src.model.efficient_el import EfficientEL 14 | from src.data.dataset_el import DatasetEL 15 | from IPython.display import Markdown 16 | from src.utils import get_markdown 17 | ``` 18 | 19 | 20 | ```python 21 | parser = ArgumentParser() 22 | 23 | parser = Trainer.add_argparse_args(parser) 24 | 25 | args, _ = parser.parse_known_args() 26 | args.gpus = 1 27 | args.precision = 16 28 | 29 | trainer = Trainer.from_argparse_args(args) 30 | ``` 31 | 32 | GPU available: True, used: True 33 | TPU available: False, using: 0 TPU cores 34 | Using native 16bit precision. 35 | 36 | 37 | 38 | ```python 39 | model = EfficientEL.load_from_checkpoint("../models/model.ckpt").eval() 40 | ``` 41 | 42 | 43 | ```python 44 | model.hparams.threshold = -3.2 45 | model.hparams.test_with_beam_search = False 46 | model.hparams.test_with_beam_search_no_candidates = False 47 | trainer.test(model, test_dataloaders=model.test_dataloader(), ckpt_path=None) 48 | ``` 49 | 50 | -------------------------------------------------------------------------------- 51 | DATALOADER:0 TEST RESULTS 52 | {'ed_macro_f1': 0.9203808307647705, 53 | 'ed_macro_prec': 0.9131189584732056, 54 | 'ed_macro_rec': 0.9390283226966858, 55 | 'ed_micro_f1': 0.9348137378692627, 56 | 'ed_micro_prec': 0.9219427704811096, 57 | 'ed_micro_rec': 0.9480490684509277, 58 | 'macro_f1': 0.8363054394721985, 59 | 'macro_prec': 0.8289670348167419, 60 | 'macro_rec': 0.8539509773254395, 61 | 'micro_f1': 0.8550071120262146, 62 | 'micro_prec': 0.8432350158691406, 63 | 'micro_rec': 0.8671125769615173} 64 | -------------------------------------------------------------------------------- 65 | 66 | 67 | 68 | ```python 69 | model.generate_global_trie() 70 | ``` 71 | 72 | 73 | ```python 74 | s = """CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY . LONDON 1996-08-30 \ 75 | West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset \ 76 | by an innings and 39 runs in two days to take over at the head of the county championship . Their \ 77 | stay on top , though , may be short-lived as title rivals Essex , Derbyshire and Surrey all closed \ 78 | in on victory while Kent made up for lost time in their rain-affected match against Nottinghamshire . \ 79 | After bowling Somerset out for 83 on the opening morning at Grace Road , Leicestershire extended their \ 80 | first innings by 94 runs before being bowled out for 296 with England discard Andy Caddick taking three \ 81 | for 83 . Trailing by 213 , Somerset got a solid start to their second innings before Simmons stepped in \ 82 | to bundle them out for 174 . Essex , however , look certain to regain their top spot after Nasser Hussain \ 83 | and Peter Such gave them a firm grip on their match against Yorkshire at Headingley . Hussain , \ 84 | considered surplus to England 's one-day requirements , struck 158 , his first championship century of \ 85 | the season , as Essex reached 372 and took a first innings lead of 82 . 
By the close Yorkshire had turned \
86 | that into a 37-run advantage but off-spinner Such had scuttled their hopes , taking four for 24 in 48 balls \
87 | and leaving them hanging on 119 for five and praying for rain . At the Oval , Surrey captain Chris Lewis , \
88 | another man dumped by England , continued to silence his critics as he followed his four for 45 on Thursday \
89 | with 80 not out on Friday in the match against Warwickshire . He was well backed by England hopeful Mark \
90 | Butcher who made 70 as Surrey closed on 429 for seven , a lead of 234 . Derbyshire kept up the hunt for \
91 | their first championship title since 1936 by reducing Worcestershire to 133 for five in their second \
92 | innings , still 100 runs away from avoiding an innings defeat . Australian Tom Moody took six for 82 but \
93 | Chris Adams , 123 , and Tim O'Gorman , 109 , took Derbyshire to 471 and a first innings lead of 233 . \
94 | After the frustration of seeing the opening day of their match badly affected by the weather , Kent stepped \
95 | up a gear to dismiss Nottinghamshire for 214 . They were held up by a gritty 84 from Paul Johnson but \
96 | ex-England fast bowler Martin McCague took four for 55 . By stumps Kent had reached 108 for three ."""
97 | 
98 | Markdown(get_markdown([s], [[(s[0], s[1], s[2][0][0]) for s in spans] for spans in model.sample([s])])[0])
99 | ```
100 | 
101 | 
102 | 
103 | 
104 | CRICKET - [LEICESTERSHIRE](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) TAKE OVER AT TOP AFTER INNINGS VICTORY . [LONDON](https://en.wikipedia.org/wiki/London) 1996-08-30 [West Indian](https://en.wikipedia.org/wiki/West_Indies) all-rounder [Phil Simmons](https://en.wikipedia.org/wiki/Philip_Walton) took four for 38 on Friday as [Leicestershire](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) beat [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) by an innings and 39 runs in two days to take over at the head of the county championship . Their stay on top , though , may be short-lived as title rivals [Essex](https://en.wikipedia.org/wiki/Essex_County_Cricket_Club) , [Derbyshire](https://en.wikipedia.org/wiki/Derbyshire_County_Cricket_Club) and [Surrey](https://en.wikipedia.org/wiki/Surrey_County_Cricket_Club) all closed in on victory while [Kent](https://en.wikipedia.org/wiki/Kent_County_Cricket_Club) made up for lost time in their rain-affected match against [Nottinghamshire](https://en.wikipedia.org/wiki/Nottinghamshire_County_Cricket_Club) . After bowling [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) out for 83 on the opening morning at [Grace Road](https://en.wikipedia.org/wiki/Grace_Road) , [Leicestershire](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) extended their first innings by 94 runs before being bowled out for 296 with [England](https://en.wikipedia.org/wiki/England_cricket_team) discard [Andy Caddick](https://en.wikipedia.org/wiki/Andrew_Caddick) taking three for 83 . Trailing by 213 , [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) got a solid start to their second innings before [Simmons](https://en.wikipedia.org/wiki/Singapore) stepped in to bundle them out for 174 .
[Essex](https://en.wikipedia.org/wiki/Essex_County_Cricket_Club) , however , look certain to regain their top spot after [Nasser Hussain](https://en.wikipedia.org/wiki/Nasser_Hussain) and [Peter Such](https://en.wikipedia.org/wiki/Peter_Thomson_(golfer)) gave them a firm grip on their match against [Yorkshire](https://en.wikipedia.org/wiki/Yorkshire_County_Cricket_Club) at [Headingley](https://en.wikipedia.org/wiki/Headingley_Stadium) . [Hussain](https://en.wikipedia.org/wiki/Nasser_Hussain) , considered surplus to [England](https://en.wikipedia.org/wiki/England_cricket_team) 's one-day requirements , struck 158 , his first championship century of the season , as [Essex](https://en.wikipedia.org/wiki/Essex_County_Cricket_Club) reached 372 and took a first innings lead of 82 . By the close [Yorkshire](https://en.wikipedia.org/wiki/Yorkshire_County_Cricket_Club) had turned that into a 37-run advantage but off-spinner [Such](https://en.wikipedia.org/wiki/Mark_Broadhurst) had scuttled their hopes , taking four for 24 in 48 balls and leaving them hanging on 119 for five and praying for rain .
105 | At the [Oval](https://en.wikipedia.org/wiki/The_Oval) , [Surrey](https://en.wikipedia.org/wiki/Surrey_County_Cricket_Club) captain [Chris Lewis](https://en.wikipedia.org/wiki/Chris_Lewis_(cricketer)) , another man dumped by [England](https://en.wikipedia.org/wiki/England_cricket_team) , continued to silence his critics as he followed his four for 45 on Thursday with 80 not out on Friday in the match against [Warwickshire](https://en.wikipedia.org/wiki/Warwickshire_County_Cricket_Club) . He was well backed by [England](https://en.wikipedia.org/wiki/England_cricket_team) hopeful [Mark Butcher](https://en.wikipedia.org/wiki/Mark_Butcher) who made 70 as [Surrey](https://en.wikipedia.org/wiki/Surrey_County_Cricket_Club) closed on 429 for seven , a lead of 234 . [Derbyshire](https://en.wikipedia.org/wiki/Derbyshire_County_Cricket_Club) kept up the hunt for their first championship title since 1936 by reducing [Worcestershire](https://en.wikipedia.org/wiki/Worcestershire_County_Cricket_Club) to 133 for five in their second innings , still 100 runs away from avoiding an innings defeat . [Australian](https://en.wikipedia.org/wiki/Australia) [Tom Moody](https://en.wikipedia.org/wiki/Tommy_Haas) took six for 82 but [Chris Adams](https://en.wikipedia.org/wiki/Chris_Walker_(squash_player)) , 123 , and Tim O'Gorman , 109 , took [Derbyshire](https://en.wikipedia.org/wiki/Derbyshire_County_Cricket_Club) to 471 and a first innings lead of 233 . After the frustration of seeing the opening day of their match badly affected by the weather , [Kent](https://en.wikipedia.org/wiki/Kent_County_Cricket_Club) stepped up a gear to dismiss [Nottinghamshire](https://en.wikipedia.org/wiki/Nottinghamshire_County_Cricket_Club) for 214 . They were held up by a gritty 84 from [Paul Johnson](https://en.wikipedia.org/wiki/Paul_Johnson_(squash_player)) but ex-England fast bowler [Martin McCague](https://en.wikipedia.org/wiki/Martin_McCague) took four for 55 . By stumps [Kent](https://en.wikipedia.org/wiki/Kent_County_Cricket_Club) had reached 108 for three .
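
The run above scores the provided candidate sets for each detected mention. The checkpoint also exposes beam-search flags; a hedged variant is shown below (whether `test_with_beam_search_no_candidates` decodes against the global KB trie rather than the per-mention candidate sets is an assumption based on the flag names and the `generate_global_trie` call above):

```python
# Assumption: these hparams switch test-time decoding to beam search;
# they are set to False in the evaluation cell above.
model.hparams.test_with_beam_search = True
model.hparams.test_with_beam_search_no_candidates = True
trainer.test(model, test_dataloaders=model.test_dataloader(), ckpt_path=None)
```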
106 | 107 | 108 | -------------------------------------------------------------------------------- /src/beam_search.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/probabll/bayeseq/blob/master/aevnmt/components/beamsearch.py 3 | """ 4 | import torch 5 | import torch.nn.functional as F 6 | from packaging import version 7 | 8 | 9 | # from onmt 10 | def tile(x, count, dim=0): 11 | """ 12 | Tiles x on dimension dim count times. From OpenNMT. Used for beam search. 13 | :param x: tensor to tile 14 | :param count: number of tiles 15 | :param dim: dimension along which the tensor is tiled 16 | :return: tiled tensor 17 | """ 18 | if isinstance(x, tuple): 19 | return [tile(e, count, dim=dim) for e in x] 20 | 21 | perm = list(range(len(x.size()))) 22 | if dim != 0: 23 | perm[0], perm[dim] = perm[dim], perm[0] 24 | x = x.permute(perm).contiguous() 25 | out_size = list(x.size()) 26 | out_size[0] *= count 27 | batch = x.size(0) 28 | x = ( 29 | x.view(batch, -1) 30 | .transpose(0, 1) 31 | .repeat(count, 1) 32 | .transpose(0, 1) 33 | .contiguous() 34 | .view(*out_size) 35 | ) 36 | if dim != 0: 37 | x = x.permute(perm).contiguous() 38 | return x 39 | 40 | 41 | def beam_search( 42 | decoder, 43 | tgt_vocab_size, 44 | hidden, 45 | bos_idx, 46 | eos_idx, 47 | pad_idx, 48 | beam_width, 49 | alpha=1, 50 | max_len=15, 51 | batch_trie_dict=None, 52 | ): 53 | """ 54 | Beam search with size beam_width. Follows OpenNMT-py implementation. 55 | In each decoding step, find the k most likely partial hypotheses. 56 | 57 | :param decoder: an initialized decoder 58 | """ 59 | 60 | decoder.eval() 61 | with torch.no_grad(): 62 | 63 | # Initialize the hidden state and create the initial input. 64 | batch_size = ( 65 | hidden[0].shape[0] if isinstance(hidden, tuple) else hidden.shape[0] 66 | ) 67 | device = hidden[0].device if isinstance(hidden, tuple) else hidden.device 68 | 69 | prev_y = torch.full( 70 | size=[batch_size], 71 | fill_value=bos_idx, 72 | dtype=torch.long, 73 | device=device, 74 | ) 75 | 76 | # Tile hidden decoder states and encoder outputs beam_width times 77 | hidden = tile(hidden, beam_width, dim=0) # [layers, B*beam_width, H_dec] 78 | 79 | batch_offset = torch.arange(batch_size, dtype=torch.long, device=device) 80 | beam_offset = torch.arange( 81 | 0, 82 | batch_size * beam_width, 83 | step=beam_width, 84 | dtype=torch.long, 85 | device=device, 86 | ) 87 | alive_seq = torch.full( 88 | [batch_size * beam_width, 1], 89 | bos_idx, 90 | dtype=torch.long, 91 | device=device, 92 | ) 93 | 94 | # Give full probability to the first beam on the first step. 95 | topk_log_probs = torch.tensor( 96 | [0.0] + [float("-inf")] * (beam_width - 1), device=device 97 | ).repeat(batch_size) 98 | 99 | # Structure that holds finished hypotheses. 
100 | hypotheses = [[] for _ in range(batch_size)] 101 | 102 | results = {} 103 | results["predictions"] = [[] for _ in range(batch_size)] 104 | results["scores"] = [[] for _ in range(batch_size)] 105 | results["gold_score"] = [0] * batch_size 106 | results["contexts"] = [[] for _ in range(batch_size)] 107 | 108 | done = torch.full( 109 | [batch_size, beam_width], 110 | False, 111 | dtype=torch.bool, 112 | device=device, 113 | ) 114 | trie_idx = ( 115 | torch.arange(0, batch_size, device=device) 116 | .unsqueeze(-1) 117 | .repeat(1, beam_width) 118 | .view(-1) 119 | ) 120 | for step in range(max_len): 121 | prev_y = alive_seq[:, -1].view(-1) 122 | 123 | # expand current hypotheses, decode one single step 124 | log_probs, hidden = decoder.step_beam_search(prev_y, hidden) 125 | 126 | if batch_trie_dict is not None: 127 | mask = torch.full_like(log_probs, -float("inf")) 128 | for i, (b_idx, tokens) in enumerate( 129 | zip(trie_idx.tolist(), alive_seq.tolist()) 130 | ): 131 | idx = batch_trie_dict[b_idx].get(tuple(tokens), []) 132 | mask[[i] * len(idx), idx] = 0 133 | 134 | log_probs += mask 135 | 136 | # multiply probs by the beam probability (=add logprobs) 137 | log_probs += topk_log_probs.view(-1).unsqueeze(1) 138 | curr_scores = log_probs 139 | 140 | # compute length penalty 141 | if alpha > -1: 142 | length_penalty = (step + 1) ** alpha 143 | curr_scores /= length_penalty 144 | 145 | # flatten log_probs into a list of possibilities 146 | curr_scores = curr_scores.reshape(-1, beam_width * tgt_vocab_size) 147 | 148 | # pick currently best top beam_width hypotheses (flattened order) 149 | topk_scores, topk_ids = curr_scores.topk(beam_width, dim=-1) 150 | 151 | if alpha > -1: 152 | # recover original log probs 153 | topk_log_probs = topk_scores * length_penalty 154 | 155 | # reconstruct beam origin and true word ids from flattened order 156 | if version.parse(torch.__version__) >= version.parse("1.5.0"): 157 | topk_beam_index = topk_ids.floor_divide(tgt_vocab_size) 158 | else: 159 | topk_beam_index = topk_ids.div(tgt_vocab_size) 160 | topk_ids = topk_ids.fmod(tgt_vocab_size) 161 | 162 | # map beam_index to batch_index in the flat representation 163 | batch_index = topk_beam_index + beam_offset[ 164 | : topk_beam_index.size(0) 165 | ].unsqueeze(1) 166 | select_indices = batch_index.view(-1) 167 | 168 | # append latest prediction 169 | alive_seq = torch.cat( 170 | [alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1 171 | ) # batch_size*k x hyp_len 172 | 173 | is_finished = ( 174 | topk_ids.eq(eos_idx) & ~topk_scores.eq(-float("inf")) 175 | ) | topk_scores.eq(-float("inf")).all(-1, keepdim=True) 176 | 177 | if step + 1 == max_len: 178 | is_finished.fill_(1) 179 | 180 | done |= is_finished 181 | 182 | # end condition is whether the top beam is finished 183 | end_condition = done.all(-1) 184 | 185 | # for LSTMs, states are tuples of tensors 186 | hidden = [e.index_select(0, select_indices) for e in hidden] 187 | trie_idx = trie_idx.index_select(0, select_indices) 188 | 189 | # save finished hypotheses 190 | if is_finished.any(): 191 | predictions = alive_seq.view(-1, beam_width, alive_seq.size(-1)) 192 | contexts = hidden[1].view(-1, beam_width, hidden[1].shape[-1]) 193 | 194 | for i in range(is_finished.size(0)): 195 | b = batch_offset[i] 196 | finished_hyp = is_finished[i].nonzero(as_tuple=False).view(-1) 197 | 198 | # store finished hypotheses for this batch 199 | for j in finished_hyp: 200 | 201 | hypotheses[b].append( 202 | ( 203 | topk_scores[i, j], 204 | predictions[i, 
j], 205 | contexts[i, j].clone(), 206 | ) # ignore start_token 207 | ) 208 | # if the batch reached the end, save the beam_width hypotheses 209 | if end_condition[i]: 210 | best_hyp = sorted( 211 | hypotheses[b], key=lambda x: x[0], reverse=True 212 | ) 213 | for n, (score, pred, cont) in enumerate(best_hyp): 214 | if n >= beam_width: 215 | break 216 | results["scores"][b].append(score) 217 | results["predictions"][b].append(pred) 218 | results["contexts"][b].append(cont) 219 | 220 | if end_condition.any(): 221 | non_finished = end_condition.eq(0).nonzero(as_tuple=False).view(-1) 222 | 223 | # if all sentences are translated, no need to go further 224 | if len(non_finished) == 0: 225 | break 226 | 227 | # remove finished batches for the next step 228 | topk_log_probs = topk_log_probs.index_select(0, non_finished) 229 | batch_index = batch_index.index_select(0, non_finished) 230 | batch_offset = batch_offset.index_select(0, non_finished) 231 | alive_seq = predictions.index_select(0, non_finished).view( 232 | -1, alive_seq.size(-1) 233 | ) 234 | done = done.index_select(0, non_finished) 235 | 236 | # reorder indices, outputs and masks, and trie 237 | select_indices = batch_index.view(-1) 238 | 239 | # for LSTMs, states are tuples of tensors 240 | hidden = [e.index_select(0, select_indices) for e in hidden] 241 | trie_idx = trie_idx.index_select(0, select_indices) 242 | 243 | return ( 244 | torch.nn.utils.rnn.pad_sequence( 245 | [ 246 | torch.nn.utils.rnn.pad_sequence( 247 | e, batch_first=True, padding_value=pad_idx 248 | ).T 249 | for e in results["predictions"] 250 | ], 251 | batch_first=True, 252 | padding_value=pad_idx, 253 | ).permute(0, 2, 1), 254 | torch.stack([torch.stack(e) for e in results["scores"]]), 255 | torch.stack([torch.stack(e) for e in results["contexts"]]), 256 | ) 257 | -------------------------------------------------------------------------------- /src/model/entity_linking.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from src.beam_search import beam_search 4 | from src.utils import label_smoothed_nll_loss 5 | 6 | 7 | class LSTM(torch.nn.Module): 8 | def __init__( 9 | self, bos_token_id, pad_token_id, eos_token_id, embeddings, lm_head, dropout=0 10 | ): 11 | super().__init__() 12 | 13 | self.bos_token_id = bos_token_id 14 | self.pad_token_id = pad_token_id 15 | self.eos_token_id = eos_token_id 16 | self.embeddings = embeddings 17 | self.lm_head = lm_head 18 | self.dropout = dropout 19 | 20 | self.lstm_cell = torch.nn.LSTMCell( 21 | input_size=2 * 768, 22 | hidden_size=768, 23 | ) 24 | 25 | def _roll( 26 | self, 27 | input_ids, 28 | attention_mask, 29 | decoder_hidden, 30 | decoder_context, 31 | decoder_append, 32 | return_lprob=False, 33 | return_dict=False, 34 | ): 35 | 36 | dropout_mask = 1 37 | if self.training: 38 | dropout_mask = (torch.rand_like(decoder_hidden) > self.dropout).float() 39 | 40 | all_hiddens = [] 41 | all_contexts = [] 42 | 43 | emb = self.embeddings(input_ids) 44 | for t in range(emb.shape[1]): 45 | decoder_hidden, decoder_context = self.lstm_cell( 46 | torch.cat((emb[:, t], decoder_append), dim=-1), 47 | (decoder_hidden, decoder_context), 48 | ) 49 | decoder_hidden *= dropout_mask 50 | all_hiddens.append(decoder_hidden) 51 | all_contexts.append(decoder_context) 52 | 53 | all_hiddens = torch.stack(all_hiddens, dim=1) 54 | all_contexts = torch.stack(all_contexts, dim=1) 55 | 56 | all_contexts = all_contexts[ 57 | [e for e in range(attention_mask.shape[0])], attention_mask.sum(-1) - 1 
58 |         ]
59 | 
60 |         if return_dict:
61 |             outputs = {
62 |                 "all_hiddens": all_hiddens,
63 |                 "all_contexts": all_contexts,
64 |             }
65 |         else:
66 |             outputs = (all_hiddens, all_contexts)
67 |         if return_lprob:
68 |             logits = self.lm_head(all_hiddens)
69 | 
70 |             if self.training:
71 |                 logits = logits.log_softmax(-1)
72 |             else:
73 |                 logits.sub_(logits.max(-1, keepdim=True).values)
74 |                 logits.exp_()
75 |                 logits.div_(logits.sum(-1, keepdim=True))
76 |                 logits.log_()
77 | 
78 |             if return_dict:
79 |                 outputs = {
80 |                     **outputs, "logits": logits,  # keep hiddens/contexts alongside logits
81 |                 }
82 |             else:
83 |                 outputs += (logits,)
84 | 
85 |         return outputs
86 | 
87 |     def step_beam_search(self, previous_tokens, hidden):
88 |         decoder_hidden, decoder_context, decoder_append = hidden
89 | 
90 |         emb = self.embeddings(previous_tokens)
91 |         decoder_hidden, decoder_context = self.lstm_cell(
92 |             torch.cat((emb, decoder_append), dim=-1),
93 |             (decoder_hidden, decoder_context),
94 |         )
95 | 
96 |         logits = self.lm_head(decoder_hidden)
97 |         logits.sub_(logits.max(-1, keepdim=True).values)
98 |         logits.exp_()
99 |         logits.div_(logits.sum(-1, keepdim=True))
100 |         logits.log_()
101 | 
102 |         return logits, (decoder_hidden, decoder_context, decoder_append)
103 | 
104 |     def forward(self, batch, decoder_hidden, decoder_context, decoder_append):
105 | 
106 |         _, all_contexts_positive, lprobs = self._roll(
107 |             batch["trg_input_ids"][:, :-1],
108 |             batch["trg_attention_mask"][:, 1:],
109 |             decoder_hidden,
110 |             decoder_context,
111 |             decoder_append,
112 |             return_lprob=True,
113 |         )
114 | 
115 |         _, all_contexts_negative = self._roll(
116 |             batch["neg_input_ids"][:, :-1],
117 |             batch["neg_attention_mask"][:, 1:],
118 |             decoder_hidden[batch["neg_mask"]],
119 |             decoder_context[batch["neg_mask"]],
120 |             decoder_append[batch["neg_mask"]],
121 |             return_lprob=False,
122 |         )
123 | 
124 |         return all_contexts_positive, all_contexts_negative, lprobs
125 | 
126 |     def forward_all_targets(
127 |         self, batch, decoder_hidden, decoder_context, decoder_append
128 |     ):
129 | 
130 |         _, all_contexts, lprobs = self._roll(
131 |             batch["cand_input_ids"][:, :-1],
132 |             batch["cand_attention_mask"][:, 1:],
133 |             decoder_hidden,
134 |             decoder_context,
135 |             decoder_append,
136 |             return_lprob=True,
137 |         )
138 | 
139 |         scores = (
140 |             lprobs.gather(
141 |                 dim=-1, index=batch["cand_input_ids"][:, 1:].unsqueeze(-1)
142 |             ).squeeze(-1)
143 |             * batch["cand_attention_mask"][:, 1:]
144 |         )
145 | 
146 |         return all_contexts, scores.sum(-1) / (batch["cand_attention_mask"].sum(-1) - 1)
147 | 
148 |     def forward_beam_search(self, batch, hidden_states):
149 |         raise NotImplementedError  # beam search is driven by EntityLinkingLSTM.forward_beam_search
150 | 
151 | 
152 | class EntityLinkingLSTM(torch.nn.Module):
153 |     def __init__(
154 |         self, bos_token_id, pad_token_id, eos_token_id, embeddings, lm_head, dropout=0
155 |     ):
156 |         super().__init__()
157 | 
158 |         self.bos_token_id = bos_token_id
159 |         self.pad_token_id = pad_token_id
160 |         self.eos_token_id = eos_token_id
161 | 
162 |         self.prj = torch.nn.Sequential(
163 |             torch.nn.LayerNorm(768 * 2),
164 |             torch.nn.Dropout(dropout),
165 |             torch.nn.Linear(768 * 2, 768),
166 |             torch.nn.ReLU(),
167 |             torch.nn.LayerNorm(768),
168 |             torch.nn.Dropout(dropout),
169 |             torch.nn.Linear(768, 768 * 3),
170 |         )
171 | 
172 |         self.lstm = LSTM(
173 |             bos_token_id,
174 |             pad_token_id,
175 |             eos_token_id,
176 |             embeddings,
177 |             lm_head,
178 |             dropout=dropout,
179 |         )
180 | 
181 |         self.classifier = torch.nn.Sequential(
182 |             torch.nn.LayerNorm(768 * 2),
183 |             torch.nn.Dropout(dropout),
184 |             torch.nn.Linear(768 * 2, 768),
185 |             torch.nn.ReLU(),
186 |             torch.nn.LayerNorm(768),
187 | 
torch.nn.Dropout(dropout), 188 | torch.nn.Linear(768, 1), 189 | ) 190 | 191 | def _get_hidden_context_append_vectors(self, batch, hidden_states): 192 | return self.prj( 193 | torch.cat( 194 | ( 195 | hidden_states[batch["offsets_start"]], 196 | hidden_states[batch["offsets_end"]], 197 | ), 198 | dim=-1, 199 | ) 200 | ).split([768, 768, 768], dim=-1) 201 | 202 | def forward(self, batch, hidden_states): 203 | ( 204 | decoder_hidden, 205 | decoder_context, 206 | decoder_append, 207 | ) = self._get_hidden_context_append_vectors(batch, hidden_states) 208 | 209 | ( 210 | all_contexts_positive, 211 | all_contexts_negative, 212 | lprobs_lm, 213 | ) = self.lstm.forward(batch, decoder_hidden, decoder_context, decoder_append) 214 | 215 | logits_classifier = self.classifier( 216 | torch.cat( 217 | ( 218 | decoder_append[batch["neg_mask"]].unsqueeze(1).repeat(1, 2, 1), 219 | torch.stack( 220 | ( 221 | all_contexts_positive[batch["neg_mask"]], 222 | all_contexts_negative, 223 | ), 224 | dim=1, 225 | ), 226 | ), 227 | dim=-1, 228 | ) 229 | ).squeeze(-1) 230 | 231 | return lprobs_lm, logits_classifier 232 | 233 | def forward_loss(self, batch, hidden_states, epsilon=0): 234 | 235 | lprobs_lm, logits_classifier = self.forward(batch, hidden_states) 236 | 237 | loss_generation, _ = label_smoothed_nll_loss( 238 | lprobs_lm, 239 | batch["trg_input_ids"][:, 1:], 240 | epsilon=epsilon, 241 | ignore_index=self.pad_token_id, 242 | ) 243 | loss_generation = loss_generation / batch["trg_attention_mask"][:, 1:].sum() 244 | 245 | loss_classifier = torch.nn.functional.cross_entropy( 246 | logits_classifier, 247 | torch.zeros( 248 | (logits_classifier.shape[0]), 249 | dtype=torch.long, 250 | device=logits_classifier.device, 251 | ), 252 | ) 253 | 254 | return loss_generation, loss_classifier 255 | 256 | def forward_all_targets(self, batch, hidden_states): 257 | ( 258 | decoder_hidden, 259 | decoder_context, 260 | decoder_append, 261 | ) = self._get_hidden_context_append_vectors(batch, hidden_states) 262 | 263 | all_contexts, lm_scores = self.lstm.forward_all_targets( 264 | batch, 265 | decoder_hidden[batch["offsets_candidates"]], 266 | decoder_context[batch["offsets_candidates"]], 267 | decoder_append[batch["offsets_candidates"]], 268 | ) 269 | 270 | classifier_scores = self.classifier( 271 | torch.cat( 272 | ( 273 | decoder_append[batch["offsets_candidates"]], 274 | all_contexts, 275 | ), 276 | dim=-1, 277 | ) 278 | ).squeeze(-1) 279 | 280 | scores = lm_scores + classifier_scores 281 | 282 | classifier_scores = torch.cat( 283 | [ 284 | e.log_softmax(-1) 285 | for e in classifier_scores.split(batch["split_candidates"]) 286 | ] 287 | ) 288 | 289 | tokens = [ 290 | t[s.argsort(descending=True)] 291 | for s, t in zip( 292 | scores.split(batch["split_candidates"]), 293 | batch["cand_input_ids"].split(batch["split_candidates"]), 294 | ) 295 | ] 296 | 297 | tokens = torch.nn.utils.rnn.pad_sequence( 298 | tokens, batch_first=True, padding_value=self.pad_token_id 299 | ) 300 | 301 | scores = torch.nn.utils.rnn.pad_sequence( 302 | [ 303 | e.sort(descending=True).values 304 | for e in scores.split(batch["split_candidates"]) 305 | ], 306 | batch_first=True, 307 | padding_value=-float("inf"), 308 | ) 309 | 310 | return tokens, scores 311 | 312 | def forward_beam_search( 313 | self, batch, hidden_states, batch_trie_dict=None, beams=5, alpha=1, max_len=15 314 | ): 315 | ( 316 | decoder_hidden, 317 | decoder_context, 318 | decoder_append, 319 | ) = self._get_hidden_context_append_vectors(batch, hidden_states) 320 | 321 | tokens, 
lm_scores, all_contexts = beam_search( 322 | self.lstm, 323 | self.lstm.lm_head.decoder.out_features, 324 | (decoder_hidden, decoder_context, decoder_append), 325 | self.bos_token_id, 326 | self.eos_token_id, 327 | self.pad_token_id, 328 | beam_width=beams, 329 | alpha=alpha, 330 | max_len=max_len, 331 | batch_trie_dict=batch_trie_dict, 332 | ) 333 | 334 | classifier_scores = self.classifier( 335 | torch.cat( 336 | ( 337 | decoder_append.unsqueeze(1).repeat(1, beams, 1), 338 | all_contexts, 339 | ), 340 | dim=-1, 341 | ) 342 | ).squeeze(-1) 343 | 344 | classifier_scores[lm_scores == -float("inf")] = -float("inf") 345 | classifier_scores = classifier_scores.log_softmax(-1) 346 | 347 | scores = (classifier_scores + lm_scores).sort(-1, descending=True) 348 | tokens = tokens[ 349 | torch.arange(scores.indices.shape[0], device=tokens.device) 350 | .unsqueeze(-1) 351 | .repeat(1, beams), 352 | scores.indices, 353 | ] 354 | scores = scores.values 355 | 356 | return tokens, scores 357 | -------------------------------------------------------------------------------- /notebooks/Test.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "7c056d36", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2\n", 12 | "\n", 13 | "import sys\n", 14 | "sys.path.append(\"../\")" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "759b3485", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "from argparse import ArgumentParser\n", 25 | "from pytorch_lightning import Trainer\n", 26 | "from src.model.efficient_el import EfficientEL\n", 27 | "from src.data.dataset_el import DatasetEL\n", 28 | "from IPython.display import Markdown\n", 29 | "from src.utils import get_markdown" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "id": "f7a51eca", 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stderr", 40 | "output_type": "stream", 41 | "text": [ 42 | "GPU available: True, used: True\n", 43 | "TPU available: False, using: 0 TPU cores\n", 44 | "Using native 16bit precision.\n" 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "parser = ArgumentParser()\n", 50 | "\n", 51 | "parser = Trainer.add_argparse_args(parser)\n", 52 | "\n", 53 | "args, _ = parser.parse_known_args()\n", 54 | "args.gpus = 1\n", 55 | "args.precision = 16\n", 56 | "\n", 57 | "trainer = Trainer.from_argparse_args(args)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "id": "2e625161", 64 | "metadata": { 65 | "scrolled": true 66 | }, 67 | "outputs": [ 68 | { 69 | "name": "stderr", 70 | "output_type": "stream", 71 | "text": [ 72 | "Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerForMaskedLM: ['longformer.encoder.layer.8.attention.self.query.weight', 'longformer.encoder.layer.8.attention.self.query.bias', 'longformer.encoder.layer.8.attention.self.key.weight', 'longformer.encoder.layer.8.attention.self.key.bias', 'longformer.encoder.layer.8.attention.self.value.weight', 'longformer.encoder.layer.8.attention.self.value.bias', 'longformer.encoder.layer.8.attention.self.query_global.weight', 'longformer.encoder.layer.8.attention.self.query_global.bias', 'longformer.encoder.layer.8.attention.self.key_global.weight', 'longformer.encoder.layer.8.attention.self.key_global.bias', 
'longformer.encoder.layer.8.attention.self.value_global.weight', 'longformer.encoder.layer.8.attention.self.value_global.bias', 'longformer.encoder.layer.8.attention.output.dense.weight', 'longformer.encoder.layer.8.attention.output.dense.bias', 'longformer.encoder.layer.8.attention.output.LayerNorm.weight', 'longformer.encoder.layer.8.attention.output.LayerNorm.bias', 'longformer.encoder.layer.8.intermediate.dense.weight', 'longformer.encoder.layer.8.intermediate.dense.bias', 'longformer.encoder.layer.8.output.dense.weight', 'longformer.encoder.layer.8.output.dense.bias', 'longformer.encoder.layer.8.output.LayerNorm.weight', 'longformer.encoder.layer.8.output.LayerNorm.bias', 'longformer.encoder.layer.9.attention.self.query.weight', 'longformer.encoder.layer.9.attention.self.query.bias', 'longformer.encoder.layer.9.attention.self.key.weight', 'longformer.encoder.layer.9.attention.self.key.bias', 'longformer.encoder.layer.9.attention.self.value.weight', 'longformer.encoder.layer.9.attention.self.value.bias', 'longformer.encoder.layer.9.attention.self.query_global.weight', 'longformer.encoder.layer.9.attention.self.query_global.bias', 'longformer.encoder.layer.9.attention.self.key_global.weight', 'longformer.encoder.layer.9.attention.self.key_global.bias', 'longformer.encoder.layer.9.attention.self.value_global.weight', 'longformer.encoder.layer.9.attention.self.value_global.bias', 'longformer.encoder.layer.9.attention.output.dense.weight', 'longformer.encoder.layer.9.attention.output.dense.bias', 'longformer.encoder.layer.9.attention.output.LayerNorm.weight', 'longformer.encoder.layer.9.attention.output.LayerNorm.bias', 'longformer.encoder.layer.9.intermediate.dense.weight', 'longformer.encoder.layer.9.intermediate.dense.bias', 'longformer.encoder.layer.9.output.dense.weight', 'longformer.encoder.layer.9.output.dense.bias', 'longformer.encoder.layer.9.output.LayerNorm.weight', 'longformer.encoder.layer.9.output.LayerNorm.bias', 'longformer.encoder.layer.10.attention.self.query.weight', 'longformer.encoder.layer.10.attention.self.query.bias', 'longformer.encoder.layer.10.attention.self.key.weight', 'longformer.encoder.layer.10.attention.self.key.bias', 'longformer.encoder.layer.10.attention.self.value.weight', 'longformer.encoder.layer.10.attention.self.value.bias', 'longformer.encoder.layer.10.attention.self.query_global.weight', 'longformer.encoder.layer.10.attention.self.query_global.bias', 'longformer.encoder.layer.10.attention.self.key_global.weight', 'longformer.encoder.layer.10.attention.self.key_global.bias', 'longformer.encoder.layer.10.attention.self.value_global.weight', 'longformer.encoder.layer.10.attention.self.value_global.bias', 'longformer.encoder.layer.10.attention.output.dense.weight', 'longformer.encoder.layer.10.attention.output.dense.bias', 'longformer.encoder.layer.10.attention.output.LayerNorm.weight', 'longformer.encoder.layer.10.attention.output.LayerNorm.bias', 'longformer.encoder.layer.10.intermediate.dense.weight', 'longformer.encoder.layer.10.intermediate.dense.bias', 'longformer.encoder.layer.10.output.dense.weight', 'longformer.encoder.layer.10.output.dense.bias', 'longformer.encoder.layer.10.output.LayerNorm.weight', 'longformer.encoder.layer.10.output.LayerNorm.bias', 'longformer.encoder.layer.11.attention.self.query.weight', 'longformer.encoder.layer.11.attention.self.query.bias', 'longformer.encoder.layer.11.attention.self.key.weight', 'longformer.encoder.layer.11.attention.self.key.bias', 'longformer.encoder.layer.11.attention.self.value.weight', 
'longformer.encoder.layer.11.attention.self.value.bias', 'longformer.encoder.layer.11.attention.self.query_global.weight', 'longformer.encoder.layer.11.attention.self.query_global.bias', 'longformer.encoder.layer.11.attention.self.key_global.weight', 'longformer.encoder.layer.11.attention.self.key_global.bias', 'longformer.encoder.layer.11.attention.self.value_global.weight', 'longformer.encoder.layer.11.attention.self.value_global.bias', 'longformer.encoder.layer.11.attention.output.dense.weight', 'longformer.encoder.layer.11.attention.output.dense.bias', 'longformer.encoder.layer.11.attention.output.LayerNorm.weight', 'longformer.encoder.layer.11.attention.output.LayerNorm.bias', 'longformer.encoder.layer.11.intermediate.dense.weight', 'longformer.encoder.layer.11.intermediate.dense.bias', 'longformer.encoder.layer.11.output.dense.weight', 'longformer.encoder.layer.11.output.dense.bias', 'longformer.encoder.layer.11.output.LayerNorm.weight', 'longformer.encoder.layer.11.output.LayerNorm.bias']\n", 73 | "- This IS expected if you are initializing LongformerForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", 74 | "- This IS NOT expected if you are initializing LongformerForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", 75 | "Some weights of LongformerForMaskedLM were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['lm_head.decoder.bias']\n", 76 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "model = EfficientEL.load_from_checkpoint(\"../models/model.ckpt\").eval()" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 5, 87 | "id": "ea367df0", 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stderr", 92 | "output_type": "stream", 93 | "text": [ 94 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n", 95 | "/home/ndecao/.anaconda3/envs/nlp38/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:68: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", 96 |        "  warnings.warn(*args, **kwargs)\n" 97 |       ] 98 |      }, 99 |      { 100 |       "data": { 101 |        "application/vnd.jupyter.widget-view+json": { 102 |         "model_id": "71b95647f19c4d428c96e0fa3e8603d9", 103 |         "version_major": 2, 104 |         "version_minor": 0 105 |        }, 106 |        "text/plain": [ 107 |         "Testing: 0it [00:00, ?it/s]" 108 |        ] 109 |       }, 110 |       "metadata": {}, 111 |       "output_type": "display_data" 112 |      }, 113 |      { 114 |       "name": "stdout", 115 |       "output_type": "stream", 116 |       "text": [ 117 |        "--------------------------------------------------------------------------------\n", 118 |        "DATALOADER:0 TEST RESULTS\n", 119 |        "{'ed_macro_f1': 0.9203808307647705,\n", 120 |        " 'ed_macro_prec': 0.9131189584732056,\n", 121 |        " 'ed_macro_rec': 0.9390283226966858,\n", 122 |        " 'ed_micro_f1': 0.9348137378692627,\n", 123 |        " 'ed_micro_prec': 0.9219427704811096,\n", 124 |        " 'ed_micro_rec': 0.9480490684509277,\n", 125 |        " 'macro_f1': 0.8363054394721985,\n", 126 |        " 'macro_prec': 0.8289670348167419,\n", 127 |        " 'macro_rec': 0.8539509773254395,\n", 128 |        " 'micro_f1': 0.8550071120262146,\n", 129 |        " 'micro_prec': 0.8432350158691406,\n", 130 |        " 'micro_rec': 0.8671125769615173}\n", 131 |        "--------------------------------------------------------------------------------\n" 132 |       ] 133 |      }, 134 |      { 135 |       "data": { 136 |        "text/plain": [ 137 |         "[{'micro_f1': 0.8550071120262146,\n", 138 |         " 'ed_micro_f1': 0.9348137378692627,\n", 139 |         " 'micro_prec': 0.8432350158691406,\n", 140 |         " 'macro_rec': 0.8539509773254395,\n", 141 |         " 'macro_f1': 0.8363054394721985,\n", 142 |         " 'macro_prec': 0.8289670348167419,\n", 143 |         " 'micro_rec': 0.8671125769615173,\n", 144 |         " 'ed_micro_prec': 0.9219427704811096,\n", 145 |         " 'ed_micro_rec': 0.9480490684509277,\n", 146 |         " 'ed_macro_f1': 0.9203808307647705,\n", 147 |         " 'ed_macro_prec': 0.9131189584732056,\n", 148 |         " 'ed_macro_rec': 0.9390283226966858}]" 149 |        ] 150 |       }, 151 |       "execution_count": 5, 152 |       "metadata": {}, 153 |       "output_type": "execute_result" 154 |      } 155 |     ], 156 |     "source": [ 157 |      "model.hparams.threshold = -3.2\n", 158 |      "model.hparams.test_with_beam_search = False\n", 159 |      "model.hparams.test_with_beam_search_no_candidates = False\n", 160 |      "trainer.test(model, test_dataloaders=model.test_dataloader(), ckpt_path=None)" 161 |     ] 162 |    }, 163 |    { 164 |     "cell_type": "code", 165 |     "execution_count": 6, 166 |     "id": "ec5eeeb5", 167 |     "metadata": {}, 168 |     "outputs": [ 169 |      { 170 |       "data": { 171 |        "application/vnd.jupyter.widget-view+json": { 172 |         "model_id": "a0701f715e594adf98ab3f00452b64b2", 173 |         "version_major": 2, 174 |         "version_minor": 0 175 |        }, 176 |        "text/plain": [ 177 |         "Loading ..: 0%|          | 0/470105 [00:00<?, ?it/s]" 233 |         "<IPython.core.display.Markdown object>" 234 |        ] 235 |       }, 236 |       "execution_count": 8, 237 |       "metadata": {}, 238 |       "output_type": "execute_result" 239 |      } 240 |     ], 241 |     "source": [ 242 |      "Markdown(get_markdown([s], [[(s[0], s[1], s[2][0][0]) for s in spans] for spans in model.sample([s])])[0])" 243 |     ] 244 |    } 245 |   ], 246 |   "metadata": { 247 |    "kernelspec": { 248 |     "display_name": "Python 3", 249 |     "language": "python", 250 |     "name": "python3" 251 |    }, 252 |    "language_info": { 253 |     "codemirror_mode": { 254 |      "name": "ipython", 255 |      "version": 3 256 |     }, 257 |     "file_extension": ".py", 258 |     "mimetype": "text/x-python", 259 |     "name": "python", 260 |     "nbconvert_exporter": "python", 261 |     "pygments_lexer": "ipython3", 262 |     "version": "3.8.8" 263 |    } 264 |   }, 265
| "nbformat": 4, 266 | "nbformat_minor": 5 267 | } 268 | -------------------------------------------------------------------------------- /src/model/efficient_el.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | from argparse import ArgumentParser 4 | from collections import defaultdict 5 | 6 | import pytorch_lightning as pl 7 | import torch 8 | from pytorch_lightning import LightningModule 9 | from torch.utils.data import DataLoader 10 | from tqdm.auto import tqdm 11 | from transformers import ( 12 | AutoTokenizer, 13 | LongformerForMaskedLM, 14 | get_linear_schedule_with_warmup, 15 | ) 16 | 17 | from src.data.dataset_el import DatasetEL 18 | from src.model.entity_detection import EntityDetectionFactor 19 | from src.model.entity_linking import EntityLinkingLSTM 20 | from src.utils import ( 21 | MacroF1, 22 | MacroPrecision, 23 | MacroRecall, 24 | MicroF1, 25 | MicroPrecision, 26 | MicroRecall, 27 | ) 28 | 29 | 30 | class EfficientEL(LightningModule): 31 | @staticmethod 32 | def add_model_specific_args(parent_parser): 33 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 34 | parser.add_argument( 35 | "--train_data_path", 36 | type=str, 37 | default="../data/aida_train_dataset.jsonl", 38 | ) 39 | parser.add_argument( 40 | "--dev_data_path", 41 | type=str, 42 | default="../data/aida_val_dataset.jsonl", 43 | ) 44 | parser.add_argument( 45 | "--test_data_path", 46 | type=str, 47 | default="../data/aida_test_dataset.jsonl", 48 | ) 49 | parser.add_argument("--batch_size", type=int, default=2) 50 | parser.add_argument("--lr_transformer", type=float, default=1e-4) 51 | parser.add_argument("--lr", type=float, default=1e-3) 52 | parser.add_argument("--max_length_train", type=int, default=1024) 53 | parser.add_argument("--max_length", type=int, default=4096) 54 | parser.add_argument("--weight_decay", type=int, default=0.01) 55 | parser.add_argument("--total_num_updates", type=int, default=10000) 56 | parser.add_argument("--warmup_updates", type=int, default=500) 57 | parser.add_argument("--num_workers", type=int, default=0) 58 | parser.add_argument("--dropout", type=float, default=0.1) 59 | parser.add_argument("--max_length_span", type=int, default=15) 60 | parser.add_argument("--threshold", type=int, default=0) 61 | parser.add_argument("--test_with_beam_search", action="store_true") 62 | parser.add_argument( 63 | "--test_with_beam_search_no_candidates", action="store_true" 64 | ) 65 | parser.add_argument( 66 | "--model_name", type=str, default="allenai/longformer-base-4096" 67 | ) 68 | parser.add_argument( 69 | "--mentions_filename", 70 | type=str, 71 | default="../data/mentions.json", 72 | ) 73 | parser.add_argument( 74 | "--entities_filename", 75 | type=str, 76 | default="../data/entities.json", 77 | ) 78 | parser.add_argument("--epsilon", type=float, default=0.1) 79 | return parser 80 | 81 | def __init__(self, *args, **kwargs): 82 | super().__init__() 83 | self.save_hyperparameters() 84 | 85 | self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.model_name) 86 | 87 | longformer = LongformerForMaskedLM.from_pretrained( 88 | self.hparams.model_name, 89 | num_hidden_layers=8, 90 | attention_window=[128] * 8, 91 | ) 92 | 93 | self.encoder = longformer.longformer 94 | 95 | self.encoder.embeddings.word_embeddings.weight.requires_grad_(False) 96 | 97 | self.entity_detection = EntityDetectionFactor( 98 | self.hparams.max_length_span, 99 | self.hparams.dropout, 100 | 
mentions_filename=self.hparams.mentions_filename, 101 |         ) 102 | 103 |         self.entity_linking = EntityLinkingLSTM( 104 |             self.tokenizer.bos_token_id, 105 |             self.tokenizer.pad_token_id, 106 |             self.tokenizer.eos_token_id, 107 |             self.encoder.embeddings.word_embeddings, 108 |             longformer.lm_head, 109 |             self.hparams.dropout, 110 |         ) 111 | 112 |         self.micro_f1 = MicroF1() 113 |         self.micro_prec = MicroPrecision() 114 |         self.micro_rec = MicroRecall() 115 | 116 |         self.macro_f1 = MacroF1() 117 |         self.macro_prec = MacroPrecision() 118 |         self.macro_rec = MacroRecall() 119 | 120 |         self.ed_micro_f1 = MicroF1() 121 |         self.ed_micro_prec = MicroPrecision() 122 |         self.ed_micro_rec = MicroRecall() 123 | 124 |         self.ed_macro_f1 = MacroF1() 125 |         self.ed_macro_prec = MacroPrecision() 126 |         self.ed_macro_rec = MacroRecall() 127 | 128 |     def train_dataloader(self, shuffle=True): 129 |         if not hasattr(self, "train_dataset") or getattr(self.hparams, "sharded", False):  # "sharded" is not always present in saved hparams 130 |             self.train_dataset = DatasetEL( 131 |                 tokenizer=self.tokenizer, 132 |                 data_path=self.hparams.train_data_path, 133 |                 max_length=self.hparams.max_length_train, 134 |                 max_length_span=self.hparams.max_length_span, 135 |             ) 136 |         return DataLoader( 137 |             self.train_dataset, 138 |             batch_size=self.hparams.batch_size, 139 |             collate_fn=self.train_dataset.collate_fn, 140 |             num_workers=self.hparams.num_workers, 141 |             shuffle=shuffle, 142 |         ) 143 | 144 |     def val_dataloader(self): 145 |         if not hasattr(self, "val_dataset"): 146 |             self.val_dataset = DatasetEL( 147 |                 tokenizer=self.tokenizer, 148 |                 data_path=self.hparams.dev_data_path, 149 |                 max_length=self.hparams.max_length, 150 |                 max_length_span=self.hparams.max_length_span, 151 |                 test=True, 152 |             ) 153 |         return DataLoader( 154 |             self.val_dataset, 155 |             batch_size=1, 156 |             collate_fn=self.val_dataset.collate_fn, 157 |             num_workers=self.hparams.num_workers, 158 |         ) 159 | 160 |     def test_dataloader(self): 161 |         if not hasattr(self, "test_dataset"): 162 |             self.test_dataset = DatasetEL( 163 |                 tokenizer=self.tokenizer, 164 |                 data_path=self.hparams.test_data_path, 165 |                 max_length=self.hparams.max_length, 166 |                 max_length_span=self.hparams.max_length_span, 167 |                 test=True, 168 |             ) 169 | 170 |         return DataLoader( 171 |             self.test_dataset, 172 |             batch_size=1, 173 |             collate_fn=self.test_dataset.collate_fn, 174 |             num_workers=self.hparams.num_workers, 175 |         ) 176 | 177 |     def forward_all_targets(self, batch, return_dict=False): 178 | 179 |         hidden_states = self.encoder( 180 |             input_ids=batch["src_input_ids"], attention_mask=batch["src_attention_mask"] 181 |         ).last_hidden_state 182 | 183 |         ( 184 |             (start, end, scores_ed), 185 |             ( 186 |                 logits_classifier_start, 187 |                 logits_classifier_end, 188 |             ), 189 |         ) = self.entity_detection.forward_hard( 190 |             batch, hidden_states, threshold=self.hparams.threshold 191 |         ) 192 | 193 |         if start.shape[0] == 0: 194 |             return [] 195 | 196 |         batch["offsets_start"] = start.T.tolist() 197 |         batch["offsets_end"] = end.T.tolist() 198 | 199 |         batch_candidates = [ 200 |             { 201 |                 (s, e): c 202 |                 for (s, e, _), c in zip( 203 |                     batch["raw"][i]["anchors"], batch["raw"][i]["candidates"] 204 |                 ) 205 |             }.get(tuple((s, e)), ["NIL"])  # candidates stored for this span, or NIL if unseen 206 |             for (i, s), (_, e) in zip( 207 |                 zip(*batch["offsets_start"]), zip(*batch["offsets_end"]) 208 |             ) 209 |         ] 210 | 211 |         try: 212 |             for k, v in self.tokenizer( 213 |                 [c for candidates in batch_candidates for c in candidates], 214 |                 return_tensors="pt", 215 |                 padding=True, 216 |             ).items(): 217 |                 batch[f"cand_{k}"] = v.to(self.device) 218 | 219 |             batch["offsets_candidates"] = [ 220 |                 i 221 |                 for i,
candidates in enumerate(batch_candidates) 222 |                 for _ in range(len(candidates)) 223 |             ] 224 |             batch["split_candidates"] = [ 225 |                 len(candidates) for candidates in batch_candidates 226 |             ] 227 | 228 |             tokens, scores_el = self.entity_linking.forward_all_targets( 229 |                 batch, hidden_states 230 |             ) 231 | 232 |         except Exception:  # generation failed; the fallback below fills in NIL spans 233 |             if not self.training: 234 |                 print("error on generation") 235 | 236 |         try: 237 |             spans = self._tokens_scores_to_spans(batch, start, end, tokens, scores_el) 238 |         except Exception: 239 |             if not self.training: 240 |                 print("error on _tokens_scores_to_spans") 241 | 242 |             spans = [[[0, 0, [("NIL", 0)]]] for _ in range(len(batch["src_input_ids"]))] 243 | 244 |         if return_dict: 245 |             return { 246 |                 "spans": spans, 247 |                 "start": start, 248 |                 "end": end, 249 |                 "scores_ed": scores_ed, 250 |                 "scores_el": scores_el, 251 |                 "logits_classifier_start": logits_classifier_start, 252 |                 "logits_classifier_end": logits_classifier_end, 253 |             } 254 |         else: 255 |             return spans 256 | 257 |     def forward_beam_search(self, batch, candidates=False): 258 | 259 |         hidden_states = self.encoder( 260 |             input_ids=batch["src_input_ids"], attention_mask=batch["src_attention_mask"] 261 |         ).last_hidden_state 262 | 263 |         ( 264 |             (start, end, scores_ed), 265 |             ( 266 |                 logits_classifier_start, 267 |                 logits_classifier_end, 268 |             ), 269 |         ) = self.entity_detection.forward_hard( 270 |             batch, hidden_states, threshold=self.hparams.threshold 271 |         ) 272 | 273 |         if start.shape[0] == 0: 274 |             return [] 275 | 276 |         # start/end hold one (batch_idx, token_position) pair per detected span; 277 |         # transpose them into parallel lists before building the tries 278 | 279 |         batch["offsets_start"] = start.T.tolist() 280 |         batch["offsets_end"] = end.T.tolist() 281 | 282 |         batch_trie_dict = None 283 |         if candidates: 284 |             batch_candidates = [ 285 |                 { 286 |                     (s, e): c 287 |                     for (s, e, _), c in zip( 288 |                         batch["raw"][i]["anchors"], batch["raw"][i]["candidates"] 289 |                     ) 290 |                 }.get(tuple((s, e)), ["NIL"]) 291 |                 for (i, s), (_, e) in zip( 292 |                     zip(*batch["offsets_start"]), zip(*batch["offsets_end"]) 293 |                 ) 294 |             ] 295 | 296 |             batch_trie_dict = [] 297 |             for candidates in batch_candidates: 298 |                 trie_dict = defaultdict(set) 299 |                 for c in self.tokenizer(candidates)["input_ids"]: 300 |                     for i in range(1, len(c)): 301 |                         trie_dict[tuple(c[:i])].add(c[i])  # prefix trie: token prefix -> allowed next tokens 302 | 303 |                 batch_trie_dict.append({k: list(v) for k, v in trie_dict.items()}) 304 |         else: 305 |             batch_trie_dict = [self.global_trie] * start.shape[0] 306 | 307 |         tokens, scores_el = self.entity_linking.forward_beam_search( 308 |             batch, 309 |             hidden_states, 310 |             batch_trie_dict, 311 |         ) 312 | 313 |         return self._tokens_scores_to_spans(batch, start, end, tokens, scores_el) 314 | 315 |     def _tokens_scores_to_spans(self, batch, start, end, tokens, scores_el): 316 | 317 |         spans = [ 318 |             [ 319 |                 [ 320 |                     s, 321 |                     e, 322 |                     list( 323 |                         zip( 324 |                             self.tokenizer.batch_decode(t, skip_special_tokens=True), 325 |                             l.tolist(), 326 |                         ) 327 |                     ), 328 |                 ] 329 |                 for s, e, t, l in zip( 330 |                     start[start[:, 0] == i][:, 1].tolist(), 331 |                     end[end[:, 0] == i][:, 1].tolist(), 332 |                     tokens[start[:, 0] == i], 333 |                     scores_el[start[:, 0] == i], 334 |                 ) 335 |             ] 336 |             for i in range(len(batch["src_input_ids"])) 337 |         ] 338 | 339 |         for spans_ in spans: 340 |             for e in [ 341 |                 [x, y] 342 |                 for x in spans_ 343 |                 for y in spans_ 344 |                 if x is not y and x[1] >= y[0] and x[0] <= y[0]  # x overlaps the start of y 345 |             ]: 346 |                 for x in sorted(e, key=lambda x: x[1] - x[0])[:-1]:  # keep only the longest span 347 |                     spans_.remove(x) 348 | 349 |         return spans 350 | 351 |     def training_step(self, batch, batch_idx=None): 352 | 353 |         hidden_states = self.encoder( 354 |             input_ids=batch["src_input_ids"],
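            # NOTE: one shared encoder pass feeds both the entity-detection and entity-linking loss heads below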
attention_mask=batch["src_attention_mask"] 355 | ).last_hidden_state 356 | 357 | loss_start, loss_end = self.entity_detection.forward_loss(batch, hidden_states) 358 | 359 | loss_generation, loss_classifier = self.entity_linking.forward_loss( 360 | batch, hidden_states, epsilon=self.hparams.epsilon 361 | ) 362 | 363 | self.log("loss_s", loss_start, on_step=True, on_epoch=False, prog_bar=True) 364 | self.log("loss_e", loss_end, on_step=True, on_epoch=False, prog_bar=True) 365 | self.log("loss_g", loss_generation, on_step=True, on_epoch=False, prog_bar=True) 366 | self.log("loss_c", loss_classifier, on_step=True, on_epoch=False, prog_bar=True) 367 | 368 | return {"loss": loss_start + loss_end + loss_generation + loss_classifier} 369 | 370 | def _inference_step(self, batch, batch_idx=None): 371 | if self.hparams.test_with_beam_search_no_candidates: 372 | spans = self.forward_beam_search(batch) 373 | elif self.hparams.test_with_beam_search: 374 | spans = self.forward_beam_search(batch, candidates=True) 375 | else: 376 | spans = self.forward_all_targets(batch) 377 | 378 | for p, g in zip(spans, batch["raw"]): 379 | 380 | p_ = set((e[0], e[1], e[2][0][0]) for e in p) 381 | g_ = set((e[0], e[1], e[2]) for e in g["anchors"]) 382 | 383 | self.micro_f1(p_, g_) 384 | self.micro_prec(p_, g_) 385 | self.micro_rec(p_, g_) 386 | 387 | self.macro_f1(p_, g_) 388 | self.macro_prec(p_, g_) 389 | self.macro_rec(p_, g_) 390 | 391 | p_ = set((e[0], e[1]) for e in p) 392 | g_ = set((e[0], e[1]) for e in g["anchors"]) 393 | 394 | self.ed_micro_f1(p_, g_) 395 | self.ed_micro_prec(p_, g_) 396 | self.ed_micro_rec(p_, g_) 397 | 398 | self.ed_macro_f1(p_, g_) 399 | self.ed_macro_prec(p_, g_) 400 | self.ed_macro_rec(p_, g_) 401 | 402 | return { 403 | "micro_f1": self.micro_f1, 404 | "micro_prec": self.micro_prec, 405 | "macro_rec": self.macro_rec, 406 | "macro_f1": self.macro_f1, 407 | "macro_prec": self.macro_prec, 408 | "micro_rec": self.micro_rec, 409 | "ed_micro_f1": self.ed_micro_f1, 410 | "ed_micro_prec": self.ed_micro_prec, 411 | "ed_micro_rec": self.ed_micro_rec, 412 | "ed_macro_f1": self.ed_macro_f1, 413 | "ed_macro_prec": self.ed_macro_prec, 414 | "ed_macro_rec": self.ed_macro_rec, 415 | } 416 | 417 | def validation_step(self, batch, batch_idx=None): 418 | metrics = self._inference_step(batch, batch_idx) 419 | self.log_dict( 420 | {k: v for k, v in metrics.items() if k in ("micro_f1", "ed_micro_f1")}, 421 | prog_bar=True, 422 | ) 423 | 424 | def test_step(self, batch, batch_idx=None): 425 | metrics = self._inference_step(batch, batch_idx) 426 | self.log_dict(metrics) 427 | 428 | def generate_global_trie(self): 429 | 430 | with open(self.hparams.entities_filename) as f: 431 | entities = json.load(f) 432 | 433 | trie_dict = defaultdict(set) 434 | for e in tqdm(entities, desc="Loading .."): 435 | c = self.tokenizer(e)["input_ids"] 436 | for i in range(1, len(c)): 437 | trie_dict[tuple(c[:i])].add(c[i]) 438 | 439 | self.global_trie = {k: list(v) for k, v in trie_dict.items()} 440 | 441 | def sample(self, sentences, anchors=None, candidates=None, all_targets=False): 442 | self.eval() 443 | with torch.no_grad(): 444 | batch = { 445 | f"src_{k}": v.to(self.device) 446 | for k, v in self.tokenizer( 447 | sentences, 448 | return_offsets_mapping=True, 449 | return_tensors="pt", 450 | padding=True, 451 | max_length=self.hparams.max_length, 452 | truncation=True, 453 | ).items() 454 | } 455 | 456 | batch["raw"] = [ 457 | { 458 | "input": sentence, 459 | "anchors": anchors[i] if anchors else None, 460 | "candidates": 
candidates[i] if candidates else None, 461 |                 } 462 |                 for i, sentence in enumerate(sentences) 463 |             ] 464 | 465 |             if anchors and candidates and all_targets: 466 |                 spans = self.forward_all_targets(batch) 467 |             elif candidates and all_targets: 468 |                 spans = self.forward_beam_search(batch, candidates=True) 469 |             else: 470 |                 spans = self.forward_beam_search(batch) 471 | 472 |             return [ 473 |                 [(eo[s][0].item(), eo[e][1].item(), l) for s, e, l in es] 474 |                 for es, eo in zip(spans, batch["src_offset_mapping"]) 475 |             ] 476 | 477 |     def configure_optimizers(self): 478 |         no_decay = ["bias", "LayerNorm.weight"] 479 |         optimizer_grouped_parameters = [ 480 |             { 481 |                 "params": [ 482 |                     p 483 |                     for n, p in self.encoder.named_parameters() 484 |                     if "embedding" not in n and not any(nd in n for nd in no_decay)  # word embeddings are frozen 485 |                 ], 486 |                 "weight_decay": self.hparams.weight_decay, 487 |                 "lr": self.hparams.lr_transformer, 488 |             }, 489 |             { 490 |                 "params": [ 491 |                     p 492 |                     for n, p in self.encoder.named_parameters() 493 |                     if "embedding" not in n and any(nd in n for nd in no_decay) 494 |                 ], 495 |                 "weight_decay": 0, 496 |                 "lr": self.hparams.lr_transformer, 497 |             }, 498 |             { 499 |                 "params": [ 500 |                     p 501 |                     for n, p in self.entity_detection.named_parameters() 502 |                     if not any(nd in n for nd in no_decay) 503 |                 ], 504 |                 "weight_decay": self.hparams.weight_decay, 505 |             }, 506 |             { 507 |                 "params": [ 508 |                     p 509 |                     for n, p in self.entity_detection.named_parameters() 510 |                     if any(nd in n for nd in no_decay) 511 |                 ], 512 |                 "weight_decay": 0, 513 |             }, 514 |             { 515 |                 "params": [ 516 |                     p 517 |                     for n, p in self.entity_linking.named_parameters() 518 |                     if "embedding" not in n and not any(nd in n for nd in no_decay) 519 |                 ], 520 |                 "weight_decay": self.hparams.weight_decay, 521 |             }, 522 |             { 523 |                 "params": [ 524 |                     p 525 |                     for n, p in self.entity_linking.named_parameters() 526 |                     if "embedding" not in n and any(nd in n for nd in no_decay) 527 |                 ], 528 |                 "weight_decay": 0, 529 |             }, 530 |         ] 531 | 532 |         optimizer = torch.optim.AdamW( 533 |             optimizer_grouped_parameters, 534 |             lr=self.hparams.lr, 535 |             weight_decay=self.hparams.weight_decay, 536 |             amsgrad=True, 537 |         ) 538 | 539 |         scheduler = get_linear_schedule_with_warmup( 540 |             optimizer, 541 |             num_warmup_steps=self.hparams.warmup_updates, 542 |             num_training_steps=self.hparams.total_num_updates, 543 |         ) 544 | 545 |         return ( 546 |             [optimizer], 547 |             [{"scheduler": scheduler, "interval": "step", "frequency": 1}], 548 |         ) 549 | --------------------------------------------------------------------------------