├── src
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   └── dataset_el.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── entity_detection.py
│   │   ├── entity_linking.py
│   │   └── efficient_el.py
│   ├── utils.py
│   └── beam_search.py
├── requirement.txt
├── LICENSE
├── scripts
│   └── train.py
├── .gitignore
├── README.md
└── notebooks
    ├── README.md
    └── Test.ipynb
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7
2 | pytorch_lightning>=1.3
3 | transformers>=4.0
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Nicola De Cao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | from pprint import pprint
4 |
5 | from pytorch_lightning import LightningModule, Trainer
6 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
7 | from pytorch_lightning.loggers import TensorBoardLogger
8 | from pytorch_lightning.utilities.seed import seed_everything
9 |
10 | from src.model.efficient_el import EfficientEL
11 |
12 | if __name__ == "__main__":
13 | parser = ArgumentParser()
14 |
15 | parser.add_argument("--dirpath", type=str, default="models")
16 | parser.add_argument("--save_top_k", type=int, default=10)
17 | parser.add_argument("--seed", type=int, default=0)
18 |
19 | parser = EfficientEL.add_model_specific_args(parser)
20 | parser = Trainer.add_argparse_args(parser)
21 |
22 | args, _ = parser.parse_known_args()
23 | pprint(args.__dict__)
24 |
25 | seed_everything(seed=args.seed)
26 |
27 | logger = TensorBoardLogger(args.dirpath, name=None)
28 |
29 | callbacks = [
30 | ModelCheckpoint(
31 | mode="max",
32 | monitor="micro_f1",
33 | dirpath=os.path.join(logger.log_dir, "checkpoints"),
34 | save_top_k=args.save_top_k,
35 | filename="model-epoch={epoch:02d}-micro_f1={micro_f1:.4f}-ed_micro_f1={ed_micro_f1:.4f}",
36 | ),
37 | LearningRateMonitor(
38 | logging_interval="step",
39 | ),
40 | ]
41 |
42 | trainer = Trainer.from_argparse_args(args, logger=logger, callbacks=callbacks)
43 |
44 | model = EfficientEL(**vars(args))
45 |
46 | trainer.fit(model)
47 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Highly Parallel Autoregressive Entity Linking with Discriminative Correction
2 |
3 | ## Overview
4 |
5 | This repository contains the PyTorch implementation of [[1]](#citation) ([arXiv](https://arxiv.org/abs/2109.03792)).
6 |
7 | Here is the [link](https://mega.nz/folder/l4RhnIxL#_oYvidq2qyDIw1sT-KeMQA) to the **pre-processed data** used in this work (i.e., the training, validation and test splits of AIDA, as well as the KB with the entities) and the **released model**.
8 |
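The usage example below and the notebooks load the checkpoint from `../models/model.ckpt`, and `scripts/train.py` saves checkpoints under `models/` by default (`--dirpath`), so place the downloaded model at `models/model.ckpt` relative to the repository root.
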
9 | ## Dependencies
10 |
11 | * **python>=3.8**
12 | * **pytorch>=1.7**
13 | * **pytorch_lightning>=1.3**
14 | * **transformers>=4.0**
15 |
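A minimal install sketch (note: on PyPI the packages are named `torch` and `pytorch-lightning`; `jsonlines` is additionally imported by `src/data/dataset_el.py`, and `pytorch-lightning` is pinned below 1.5 here because `src/utils.py` imports `pytorch_lightning.metrics`, which later releases moved to `torchmetrics`):

```bash
pip install "torch>=1.7" "pytorch-lightning>=1.3,<1.5" "transformers>=4.0" jsonlines
```
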
16 | ## Structure
17 | * [src](https://github.com/nicola-decao/efficient-autoregressive-EL/tree/master/src): The source code of the model. [src/data](https://github.com/nicola-decao/efficient-autoregressive-EL/tree/master/src/data) contains a dataset class for Entity Linking, and [src/model](https://github.com/nicola-decao/efficient-autoregressive-EL/tree/master/src/model) contains the three classes that implement our EL model: one for the Entity Detection part, one for the (autoregressive) Entity Linking part, and one for the entire model (which also contains the training and validation loops).
18 | * [notebooks](https://github.com/nicola-decao/efficient-autoregressive-EL/tree/master/notebooks): Example code for loading our Entity Linking model, evaluating it on AIDA, and running inference on a test document.
19 |
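The corresponding classes are `EntityDetectionFactor` in `entity_detection.py`, `EntityLinkingLSTM` in `entity_linking.py`, and `EfficientEL` in `efficient_el.py`.
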
20 | ## Usage
21 | Please have a look at the [notebooks](https://github.com/nicola-decao/efficient-autoregressive-EL/tree/master/notebooks) folder to see how to load our Entity Linking model, evaluate it on AIDA, and run inference on a test document.
22 |
23 | Here is a minimal example that demonstrates how to use our model:
24 | ```python
25 | from src.model.efficient_el import EfficientEL
26 | from IPython.display import Markdown
27 | from src.utils import get_markdown
28 |
29 | # loading the model on GPU and setting the threshold to the
30 | # optimal value (based on the AIDA validation set)
31 | model = EfficientEL.load_from_checkpoint("../models/model.ckpt").eval().cuda()
32 | model.hparams.threshold = -3.2
33 |
34 | # loading the KB with the entities
35 | model.generate_global_trie()
36 |
37 | # document which we want to apply EL on
38 | s = """CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY . LONDON 1996-08-30 \
39 | West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset \
40 | by an innings and 39 runs in two days to take over at the head of the county championship ."""
41 |
42 | # getting spans from the model and converting the result into Markdown for visualization
43 | Markdown(
44 | get_markdown(
45 | [s],
46 | [[(s[0], s[1], s[2][0][0]) for s in spans]
47 | for spans in model.sample([s])]
48 | )[0]
49 | )
50 | ```
51 | This will generate:
52 |
53 | > CRICKET - [LEICESTERSHIRE](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) TAKE OVER AT TOP AFTER INNINGS VICTORY . [LONDON](https://en.wikipedia.org/wiki/London) 1996-08-30 [West Indian](https://en.wikipedia.org/wiki/West_Indies) all-rounder [Phil Simmons](https://en.wikipedia.org/wiki/Philip_Walton) took four for 38 on Friday as [Leicestershire](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) beat [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) by an innings and 39 runs in two days to take over at the head of the county championship .
54 |
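The indexing `(s[0], s[1], s[2][0][0])` above suggests that each span returned by `model.sample` is a `(start, end, candidates)` triple, with candidates ranked by score. A hedged illustration of that structure (the exact return format is defined in `EfficientEL.sample`, not shown here):

```python
# Hypothetical unpacking of the structure implied by s[0], s[1], s[2][0][0] above.
spans = model.sample([s])[0]            # spans predicted for the first (and only) document
start, end, candidates = spans[0]       # character offsets plus ranked candidates
top_entity, top_score = candidates[0]   # candidates appear to be (title, score) pairs
print(s[start:end], "->", top_entity)
```
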
55 |
56 | Please cite [[1]](#citation) in your work when using this library in your experiments.
57 |
58 | ### Training
59 |
60 | To train our model you can run the following command:
61 | ```bash
62 | python scripts/train.py --gpus ${NUM_GPUS} --accelerator ddp --batch_size 32
63 | ```
64 |
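Additional flags are exposed automatically: `Trainer.add_argparse_args` adds the standard PyTorch Lightning Trainer options, and `EfficientEL.add_model_specific_args` adds the model hyperparameters (see `scripts/train.py`). Checkpoints are written to a per-run directory created by `TensorBoardLogger` inside `--dirpath` and are ranked by validation `micro_f1` via the `ModelCheckpoint` callback.
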
65 | ## Feedback
66 | For questions and comments, feel free to contact [Nicola De Cao](mailto:nicola.decao@gmail.com).
67 |
68 | ## License
69 | MIT
70 |
71 | ## Citation
72 | ```
73 | [1] De Cao Nicola, Aziz Wilker, & Titov Ivan. (2021).
74 | Highly parallel autoregressive entity linking with discriminative correction.
75 | Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, 7662–7669.
76 | https://doi.org/10.18653/v1/2021.emnlp-main.604
77 | ```
78 |
79 | BibTeX format:
80 | ```
81 | @inproceedings{de-cao-etal-2021-highly,
82 | title = "Highly Parallel Autoregressive Entity Linking with Discriminative Correction",
83 | author = "De Cao, Nicola and
84 | Aziz, Wilker and
85 | Titov, Ivan",
86 | booktitle = "Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing",
87 | month = nov,
88 | year = "2021",
89 | address = "Online and Punta Cana, Dominican Republic",
90 | publisher = "Association for Computational Linguistics",
91 | url = "https://aclanthology.org/2021.emnlp-main.604",
92 | doi = "10.18653/v1/2021.emnlp-main.604",
93 | pages = "7662--7669",
94 | }
95 | ```
96 |
97 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_lightning.metrics import Metric  # note: moved to torchmetrics in pytorch_lightning >= 1.5
3 |
4 |
5 | class MicroF1(Metric):
6 | def __init__(self, dist_sync_on_step=False):
7 | super().__init__(dist_sync_on_step=dist_sync_on_step)
8 |
9 | self.add_state("n", default=torch.tensor(0), dist_reduce_fx="sum")
10 | self.add_state("prec_d", default=torch.tensor(0), dist_reduce_fx="sum")
11 | self.add_state("rec_d", default=torch.tensor(0), dist_reduce_fx="sum")
12 |
13 | def update(self, p, g):
14 |
15 | self.n += len(g.intersection(p))
16 | self.prec_d += len(p)
17 | self.rec_d += len(g)
18 |
19 | def compute(self):
20 | p = self.n.float() / self.prec_d
21 | r = self.n.float() / self.rec_d
22 | return (2 * p * r / (p + r)) if (p + r) > 0 else (p + r)
23 |
24 |
25 | class MacroF1(Metric):
26 | def __init__(self, dist_sync_on_step=False):
27 | super().__init__(dist_sync_on_step=dist_sync_on_step)
28 |
29 | self.add_state("n", default=torch.tensor(0.0), dist_reduce_fx="sum")
30 | self.add_state("d", default=torch.tensor(0), dist_reduce_fx="sum")
31 |
32 | def update(self, p, g):
33 |
34 | prec = len(g.intersection(p)) / len(p)
35 | rec = len(g.intersection(p)) / len(g)
36 |
37 | self.n += (2 * prec * rec / (prec + rec)) if (prec + rec) > 0 else (prec + rec)
38 | self.d += 1
39 |
40 | def compute(self):
41 | return (self.n / self.d) if self.d > 0 else self.d
42 |
43 |
44 | class MicroPrecision(Metric):
45 | def __init__(self, dist_sync_on_step=False):
46 | super().__init__(dist_sync_on_step=dist_sync_on_step)
47 |
48 | self.add_state("n", default=torch.tensor(0), dist_reduce_fx="sum")
49 | self.add_state("d", default=torch.tensor(0), dist_reduce_fx="sum")
50 |
51 | def update(self, p, g):
52 | self.n += len(g.intersection(p))
53 | self.d += len(p)
54 |
55 | def compute(self):
56 | return (self.n.float() / self.d) if self.d > 0 else self.d
57 |
58 |
59 | class MacroPrecision(Metric):
60 | def __init__(self, dist_sync_on_step=False):
61 | super().__init__(dist_sync_on_step=dist_sync_on_step)
62 |
63 | self.add_state("n", default=torch.tensor(0.0), dist_reduce_fx="sum")
64 | self.add_state("d", default=torch.tensor(0), dist_reduce_fx="sum")
65 |
66 | def update(self, p, g):
67 | self.n += len(g.intersection(p)) / len(p)
68 | self.d += 1
69 |
70 | def compute(self):
71 | return (self.n / self.d) if self.d > 0 else self.d
72 |
73 |
74 | class MicroRecall(Metric):
75 | def __init__(self, dist_sync_on_step=False):
76 | super().__init__(dist_sync_on_step=dist_sync_on_step)
77 |
78 | self.add_state("n", default=torch.tensor(0), dist_reduce_fx="sum")
79 | self.add_state("d", default=torch.tensor(0), dist_reduce_fx="sum")
80 |
81 | def update(self, p, g):
82 | self.n += len(g.intersection(p))
83 | self.d += len(g)
84 |
85 | def compute(self):
86 | return (self.n.float() / self.d) if self.d > 0 else self.d
87 |
88 |
89 | class MacroRecall(Metric):
90 | def __init__(self, dist_sync_on_step=False):
91 | super().__init__(dist_sync_on_step=dist_sync_on_step)
92 |
93 | self.add_state("n", default=torch.tensor(0.0), dist_reduce_fx="sum")
94 | self.add_state("d", default=torch.tensor(0), dist_reduce_fx="sum")
95 |
96 | def update(self, p, g):
97 | self.n += len(g.intersection(p)) / len(g)
98 | self.d += 1
99 |
100 | def compute(self):
101 | return (self.n / self.d) if self.d > 0 else self.d
102 |
103 |
104 | def get_markdown(sentences, entity_spans):
105 | return_outputs = []
106 | for sent, entities in zip(sentences, entity_spans):
107 | text = ""
108 | last_end = 0
109 | for begin, end, href in entities:
110 | text += sent[last_end:begin]
111 | text += "[{}](https://en.wikipedia.org/wiki/{})".format(
112 | sent[begin:end], href.replace(" ", "_")
113 | )
114 | last_end = end
115 |
116 | text += sent[last_end:]
117 | return_outputs.append(text)
118 |
119 | return return_outputs
120 |
121 |
122 | def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
123 | if target.dim() == lprobs.dim() - 1:
124 | target = target.unsqueeze(-1)
125 | nll_loss = -lprobs.gather(dim=-1, index=target)
126 | smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
127 | if ignore_index is not None:
128 | pad_mask = target.eq(ignore_index)
129 | nll_loss.masked_fill_(pad_mask, 0.0)
130 | smooth_loss.masked_fill_(pad_mask, 0.0)
131 | else:
132 | nll_loss = nll_loss.squeeze(-1)
133 | smooth_loss = smooth_loss.squeeze(-1)
134 | if reduce:
135 | nll_loss = nll_loss.sum()
136 | smooth_loss = smooth_loss.sum()
137 | eps_i = epsilon / (lprobs.size(-1) - 1)  # spread the smoothing mass over the remaining vocabulary entries
138 | loss = (1.0 - epsilon - eps_i) * nll_loss + eps_i * smooth_loss
139 | return loss, nll_loss
140 |
--------------------------------------------------------------------------------
/src/data/dataset_el.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import jsonlines
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class DatasetEL(Dataset):
10 | def __init__(
11 | self,
12 | tokenizer,
13 | data_path,
14 | max_length=32,
15 | max_length_span=15,
16 | test=False,
17 | ):
18 | super().__init__()
19 | self.tokenizer = tokenizer
20 |
21 | with jsonlines.open(data_path) as f:
22 | self.data = list(f)
23 |
24 | self.max_length = max_length
25 | self.max_length_span = max_length_span
26 | self.test = test
27 |
28 | def __len__(self):
29 | return len(self.data)
30 |
31 | def __getitem__(self, item):
32 | return self.data[item]
33 |
34 | def collate_fn(self, batch):
35 |
36 | batch = {
37 | **{
38 | f"src_{k}": v
39 | for k, v in self.tokenizer(
40 | [b["input"] for b in batch],
41 | return_tensors="pt",
42 | padding=True,
43 | max_length=self.max_length,
44 | truncation=True,
45 | return_offsets_mapping=True,
46 | ).items()
47 | },
48 | "offsets_start": (
49 | [
50 | i
51 | for i, b in enumerate(batch)
52 | for a in b["anchors"]
53 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span
54 | ],
55 | [
56 | a[0]
57 | for i, b in enumerate(batch)
58 | for a in b["anchors"]
59 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span
60 | ],
61 | ),
62 | "offsets_end": (
63 | [
64 | i
65 | for i, b in enumerate(batch)
66 | for a in b["anchors"]
67 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span
68 | ],
69 | [
70 | a[1]
71 | for i, b in enumerate(batch)
72 | for a in b["anchors"]
73 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span
74 | ],
75 | ),
76 | "offsets_inside": (
77 | [
78 | i
79 | for i, b in enumerate(batch)
80 | for a in b["anchors"]
81 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span
82 | for j in range(a[0] + 1, a[1] + 1)
83 | ],
84 | [
85 | j
86 | for i, b in enumerate(batch)
87 | for a in b["anchors"]
88 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span
89 | for j in range(a[0] + 1, a[1] + 1)
90 | ],
91 | ),
92 | "raw": batch,
93 | }
94 |
95 | if not self.test:
96 |
97 | negatives = [
98 | np.random.choice([e for e in cands if e != a[2]])
99 | if len([e for e in cands if e != a[2]]) > 0
100 | else None
101 | for b in batch["raw"]
102 | for a, cands in zip(b["anchors"], b["candidates"])
103 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span
104 | ]
105 |
106 | targets = [
107 | a[2]
108 | for b in batch["raw"]
109 | for a in b["anchors"]
110 | if a[1] < self.max_length and a[1] - a[0] < self.max_length_span
111 | ]
112 |
113 | assert len(targets) == len(negatives)
114 |
115 | batch_upd = {
116 | **(
117 | {
118 | f"trg_{k}": v
119 | for k, v in self.tokenizer(
120 | targets,
121 | return_tensors="pt",
122 | padding=True,
123 | max_length=self.max_length,
124 | truncation=True,
125 | ).items()
126 | }
127 | if not self.test
128 | else {}
129 | ),
130 | **(
131 | {
132 | f"neg_{k}": v
133 | for k, v in self.tokenizer(
134 | [e for e in negatives if e],
135 | return_tensors="pt",
136 | padding=True,
137 | max_length=self.max_length,
138 | truncation=True,
139 | ).items()
140 | }
141 | if not self.test
142 | else {}
143 | ),
144 | "neg_mask": torch.tensor([e is not None for e in negatives]),
145 | }
146 |
147 | batch = {**batch, **batch_upd}
148 |
149 | return batch
150 |
--------------------------------------------------------------------------------
/src/model/entity_detection.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import torch
4 |
5 |
6 | class EntityDetectionFactor(torch.nn.Module):
7 | def __init__(self, max_length_span, dropout=0, mentions_filename=None):
8 | super().__init__()
9 |
10 | self.max_length_span = max_length_span
11 | self.classifier_start = torch.nn.Sequential(
12 | torch.nn.LayerNorm(768),
13 | torch.nn.Dropout(dropout),
14 | torch.nn.Linear(768, 128),
15 | torch.nn.ReLU(),
16 | torch.nn.LayerNorm(128),
17 | torch.nn.Dropout(dropout),
18 | torch.nn.Linear(128, 1),
19 | )
20 | self.classifier_end = torch.nn.Sequential(
21 | torch.nn.LayerNorm(768 * 2),
22 | torch.nn.Dropout(dropout),
23 | torch.nn.Linear(768 * 2, 128),
24 | torch.nn.ReLU(),
25 | torch.nn.LayerNorm(128),
26 | torch.nn.Dropout(dropout),
27 | torch.nn.Linear(128, 1),
28 | )
29 |
30 | self.mentions = None
31 | if mentions_filename:
32 | with open(mentions_filename) as f:
33 | self.mentions = set(json.load(f))
34 |
35 | def _forward_start(self, batch, hidden_states):
36 | return self.classifier_start(hidden_states).squeeze(-1)
37 |
38 | def _forward_end(self, batch, hidden_states, offsets_start):
39 |
40 | classifier_end_input = torch.nn.functional.pad(
41 | hidden_states, (0, 0, 0, self.max_length_span - 1)
42 | )
43 |
44 | classifier_end_input = torch.cat(
45 | (
46 | hidden_states[offsets_start]
47 | .unsqueeze(-1)
48 | .repeat(1, 1, self.max_length_span),
49 | classifier_end_input.unfold(1, self.max_length_span, 1)[offsets_start],
50 | ),
51 | dim=1,
52 | ).permute(0, 2, 1)
53 |
54 | logits_classifier_end = self.classifier_end(classifier_end_input).squeeze(-1)
55 |
56 | mask = torch.cat(
57 | (
58 | batch["src_attention_mask"],
59 | torch.zeros(
60 | (
61 | batch["src_attention_mask"].shape[0],
62 | self.max_length_span - 1,
63 | ),
64 | dtype=torch.float,
65 | device=hidden_states.device,
66 | ),
67 | ),
68 | dim=1,
69 | )
70 | mask = torch.where(
71 | mask.bool(),
72 | torch.zeros_like(mask),
73 | -torch.full_like(mask, float("inf")),
74 | ).unfold(1, self.max_length_span, 1)[offsets_start]
75 |
76 | return logits_classifier_end + mask
77 |
78 | def forward(self, batch, hidden_states):
79 |
80 | logits_classifier_start = self._forward_start(batch, hidden_states)
81 | offsets_start = batch["offsets_start"]
82 | logits_classifier_end = self._forward_end(batch, hidden_states, offsets_start)
83 |
84 | return logits_classifier_start, logits_classifier_end
85 |
86 | def forward_hard(self, batch, hidden_states, threshold=0):
87 |
88 | logits_classifier_start = self._forward_start(batch, hidden_states)
89 | offsets_start = logits_classifier_start > threshold
90 | logits_classifier_end = self._forward_end(batch, hidden_states, offsets_start)
91 |
92 | start = offsets_start.nonzero()
93 | end = start.clone()
94 |
95 | scores = None
96 | if logits_classifier_end.shape[0] > 0:
97 | end[:, 1] += logits_classifier_end.argmax(-1)
98 | scores = (
99 | logits_classifier_start[offsets_start]
100 | + logits_classifier_end.max(-1).values
101 | )
102 | if self.mentions:
103 | mention_mask = torch.tensor(
104 | [
105 | (
106 | batch["raw"][i]["input"][
107 | batch["src_offset_mapping"][i][s][0]
108 | .item() : batch["src_offset_mapping"][i][e][1]
109 | .item()
110 | ]
111 | in self.mentions
112 | )
113 | for (i, s), (_, e) in zip(start, end)
114 | ],
115 | device=start.device,
116 | )
117 |
118 | start = start[mention_mask]
119 | end = end[mention_mask]
120 |
121 | return (start, end, scores), (
122 | logits_classifier_start,
123 | logits_classifier_end,
124 | )
125 |
126 | def forward_loss(self, batch, hidden_states):
127 | logits_classifier_start, logits_classifier_end = self.forward(
128 | batch, hidden_states
129 | )
130 |
131 | batch["labels_start"] = torch.zeros_like(batch["src_input_ids"])
132 | batch["labels_start"][batch["offsets_start"]] = 1
133 |
134 | loss_start = torch.nn.functional.binary_cross_entropy_with_logits(
135 | logits_classifier_start,
136 | batch["labels_start"].float(),
137 | weight=batch["src_attention_mask"],
138 | )
139 |
140 | batch["labels_end"] = torch.tensor(
141 | [b - a for a, b in zip(batch["offsets_start"][1], batch["offsets_end"][1])],
142 | device=logits_classifier_start.device,
143 | )
144 |
145 | loss_end = torch.nn.functional.cross_entropy(
146 | logits_classifier_end,
147 | batch["labels_end"],
148 | )
149 |
150 | return loss_start, loss_end
151 |
--------------------------------------------------------------------------------
/notebooks/README.md:
--------------------------------------------------------------------------------
1 | ```python
2 | %load_ext autoreload
3 | %autoreload 2
4 |
5 | import sys
6 | sys.path.append("../")
7 | ```
8 |
9 |
10 | ```python
11 | from argparse import ArgumentParser
12 | from pytorch_lightning import Trainer
13 | from src.model.efficient_el import EfficientEL
14 | from src.data.dataset_el import DatasetEL
15 | from IPython.display import Markdown
16 | from src.utils import get_markdown
17 | ```
18 |
19 |
20 | ```python
21 | parser = ArgumentParser()
22 |
23 | parser = Trainer.add_argparse_args(parser)
24 |
25 | args, _ = parser.parse_known_args()
26 | args.gpus = 1
27 | args.precision = 16
28 |
29 | trainer = Trainer.from_argparse_args(args)
30 | ```
31 |
32 | GPU available: True, used: True
33 | TPU available: False, using: 0 TPU cores
34 | Using native 16bit precision.
35 |
36 |
37 |
38 | ```python
39 | model = EfficientEL.load_from_checkpoint("../models/model.ckpt").eval()
40 | ```
41 |
42 |
43 | ```python
44 | model.hparams.threshold = -3.2
45 | model.hparams.test_with_beam_search = False
46 | model.hparams.test_with_beam_search_no_candidates = False
47 | trainer.test(model, test_dataloaders=model.test_dataloader(), ckpt_path=None)
48 | ```
49 |
50 | --------------------------------------------------------------------------------
51 | DATALOADER:0 TEST RESULTS
52 | {'ed_macro_f1': 0.9203808307647705,
53 | 'ed_macro_prec': 0.9131189584732056,
54 | 'ed_macro_rec': 0.9390283226966858,
55 | 'ed_micro_f1': 0.9348137378692627,
56 | 'ed_micro_prec': 0.9219427704811096,
57 | 'ed_micro_rec': 0.9480490684509277,
58 | 'macro_f1': 0.8363054394721985,
59 | 'macro_prec': 0.8289670348167419,
60 | 'macro_rec': 0.8539509773254395,
61 | 'micro_f1': 0.8550071120262146,
62 | 'micro_prec': 0.8432350158691406,
63 | 'micro_rec': 0.8671125769615173}
64 | --------------------------------------------------------------------------------
65 |
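(The unprefixed metrics score full entity linking, while the `ed_*` entries appear to score the entity detection step alone; both are computed with the micro/macro metric classes in `src/utils.py`.)
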
66 |
67 |
68 | ```python
69 | model.generate_global_trie()
70 | ```
71 |
72 |
73 | ```python
74 | s = """CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY . LONDON 1996-08-30 \
75 | West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset \
76 | by an innings and 39 runs in two days to take over at the head of the county championship . Their \
77 | stay on top , though , may be short-lived as title rivals Essex , Derbyshire and Surrey all closed \
78 | in on victory while Kent made up for lost time in their rain-affected match against Nottinghamshire . \
79 | After bowling Somerset out for 83 on the opening morning at Grace Road , Leicestershire extended their \
80 | first innings by 94 runs before being bowled out for 296 with England discard Andy Caddick taking three \
81 | for 83 . Trailing by 213 , Somerset got a solid start to their second innings before Simmons stepped in \
82 | to bundle them out for 174 . Essex , however , look certain to regain their top spot after Nasser Hussain \
83 | and Peter Such gave them a firm grip on their match against Yorkshire at Headingley . Hussain , \
84 | considered surplus to England 's one-day requirements , struck 158 , his first championship century of \
85 | the season , as Essex reached 372 and took a first innings lead of 82 . By the close Yorkshire had turned \
86 | that into a 37-run advantage but off-spinner Such had scuttled their hopes , taking four for 24 in 48 balls \
87 | and leaving them hanging on 119 for five and praying for rain . At the Oval , Surrey captain Chris Lewis , \
88 | another man dumped by England , continued to silence his critics as he followed his four for 45 on Thursday \
89 | with 80 not out on Friday in the match against Warwickshire . He was well backed by England hopeful Mark \
90 | Butcher who made 70 as Surrey closed on 429 for seven , a lead of 234 . Derbyshire kept up the hunt for \
91 | their first championship title since 1936 by reducing Worcestershire to 133 for five in their second \
92 | innings , still 100 runs away from avoiding an innings defeat . Australian Tom Moody took six for 82 but \
93 | Chris Adams , 123 , and Tim O'Gorman , 109 , took Derbyshire to 471 and a first innings lead of 233 . \
94 | After the frustration of seeing the opening day of their match badly affected by the weather , Kent stepped \
95 | up a gear to dismiss Nottinghamshire for 214 . They were held up by a gritty 84 from Paul Johnson but \
96 | ex-England fast bowler Martin McCague took four for 55 . By stumps Kent had reached 108 for three ."""
97 |
98 | Markdown(get_markdown([s], [[(s[0], s[1], s[2][0][0]) for s in spans] for spans in model.sample([s])])[0])
99 | ```
100 |
101 |
102 |
103 |
104 | CRICKET - [LEICESTERSHIRE](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) TAKE OVER AT TOP AFTER INNINGS VICTORY . [LONDON](https://en.wikipedia.org/wiki/London) 1996-08-30 [West Indian](https://en.wikipedia.org/wiki/West_Indies) all-rounder [Phil Simmons](https://en.wikipedia.org/wiki/Philip_Walton) took four for 38 on Friday as [Leicestershire](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) beat [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) by an innings and 39 runs in two days to take over at the head of the county championship . Their stay on top , though , may be short-lived as title rivals [Essex](https://en.wikipedia.org/wiki/Essex_County_Cricket_Club) , [Derbyshire](https://en.wikipedia.org/wiki/Derbyshire_County_Cricket_Club) and [Surrey](https://en.wikipedia.org/wiki/Surrey_County_Cricket_Club) all closed in on victory while [Kent](https://en.wikipedia.org/wiki/Kent_County_Cricket_Club) made up for lost time in their rain-affected match against [Nottinghamshire](https://en.wikipedia.org/wiki/Nottinghamshire_County_Cricket_Club) . After bowling [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) out for 83 on the opening morning at [Grace Road](https://en.wikipedia.org/wiki/Grace_Road) , [Leicestershire](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) extended their first innings by 94 runs before being bowled out for 296 with [England](https://en.wikipedia.org/wiki/England_cricket_team) discard [Andy Caddick](https://en.wikipedia.org/wiki/Andrew_Caddick) taking three for 83 . Trailing by 213 , [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) got a solid start to their second innings before [Simmons](https://en.wikipedia.org/wiki/Singapore) stepped in to bundle them out for 174 . [Essex](https://en.wikipedia.org/wiki/Essex_County_Cricket_Club) , however , look certain to regain their top spot after [Nasser Hussain](https://en.wikipedia.org/wiki/Nasser_Hussain) and [Peter Such](https://en.wikipedia.org/wiki/Peter_Thomson_(golfer)) gave them a firm grip on their match against [Yorkshire](https://en.wikipedia.org/wiki/Yorkshire_County_Cricket_Club) at [Headingley](https://en.wikipedia.org/wiki/Headingley_Stadium) . [Hussain](https://en.wikipedia.org/wiki/Nasser_Hussain) , considered surplus to [England](https://en.wikipedia.org/wiki/England_cricket_team) 's one-day requirements , struck 158 , his first championship century of the season , as [Essex](https://en.wikipedia.org/wiki/Essex_County_Cricket_Club) reached 372 and took a first innings lead of 82 . By the close [Yorkshire](https://en.wikipedia.org/wiki/Yorkshire_County_Cricket_Club) had turned that into a 37-run advantage but off-spinner [Such](https://en.wikipedia.org/wiki/Mark_Broadhurst) had scuttled their hopes , taking four for 24 in 48 balls
105 | and leaving them hanging on 119 for five and praying for rain . At the [Oval](https://en.wikipedia.org/wiki/The_Oval) , [Surrey](https://en.wikipedia.org/wiki/Surrey_County_Cricket_Club) captain [Chris Lewis](https://en.wikipedia.org/wiki/Chris_Lewis_(cricketer)) , another man dumped by [England](https://en.wikipedia.org/wiki/England_cricket_team) , continued to silence his critics as he followed his four for 45 on Thursday with 80 not out on Friday in the match against [Warwickshire](https://en.wikipedia.org/wiki/Warwickshire_County_Cricket_Club) . He was well backed by [England](https://en.wikipedia.org/wiki/England_cricket_team) hopeful [Mark Butcher](https://en.wikipedia.org/wiki/Mark_Butcher) who made 70 as [Surrey](https://en.wikipedia.org/wiki/Surrey_County_Cricket_Club) closed on 429 for seven , a lead of 234 . [Derbyshire](https://en.wikipedia.org/wiki/Derbyshire_County_Cricket_Club) kept up the hunt for their first championship title since 1936 by reducing [Worcestershire](https://en.wikipedia.org/wiki/Worcestershire_County_Cricket_Club) to 133 for five in their second innings , still 100 runs away from avoiding an innings defeat . [Australian](https://en.wikipedia.org/wiki/Australia) [Tom Moody](https://en.wikipedia.org/wiki/Tommy_Haas) took six for 82 but [Chris Adams](https://en.wikipedia.org/wiki/Chris_Walker_(squash_player)) , 123 , and Tim O'Gorman , 109 , took [Derbyshire](https://en.wikipedia.org/wiki/Derbyshire_County_Cricket_Club) to 471 and a first innings lead of 233 . After the frustration of seeing the opening day of their match badly affected by the weather , [Kent](https://en.wikipedia.org/wiki/Kent_County_Cricket_Club) stepped up a gear to dismiss [Nottinghamshire](https://en.wikipedia.org/wiki/Nottinghamshire_County_Cricket_Club) for 214 . They were held up by a gritty 84 from [Paul Johnson](https://en.wikipedia.org/wiki/Paul_Johnson_(squash_player)) but ex-England fast bowler [Martin McCague](https://en.wikipedia.org/wiki/Martin_McCague) took four for 55 . By stumps [Kent](https://en.wikipedia.org/wiki/Kent_County_Cricket_Club) had reached 108 for three .
106 |
107 |
108 |
--------------------------------------------------------------------------------
/src/beam_search.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from: https://github.com/probabll/bayeseq/blob/master/aevnmt/components/beamsearch.py
3 | """
4 | import torch
5 | import torch.nn.functional as F
6 | from packaging import version
7 |
8 |
9 | # from onmt
10 | def tile(x, count, dim=0):
11 | """
12 | Tiles x on dimension dim count times. From OpenNMT. Used for beam search.
13 | :param x: tensor to tile
14 | :param count: number of tiles
15 | :param dim: dimension along which the tensor is tiled
16 | :return: tiled tensor
17 | """
18 | if isinstance(x, tuple):
19 | return [tile(e, count, dim=dim) for e in x]
20 |
21 | perm = list(range(len(x.size())))
22 | if dim != 0:
23 | perm[0], perm[dim] = perm[dim], perm[0]
24 | x = x.permute(perm).contiguous()
25 | out_size = list(x.size())
26 | out_size[0] *= count
27 | batch = x.size(0)
28 | x = (
29 | x.view(batch, -1)
30 | .transpose(0, 1)
31 | .repeat(count, 1)
32 | .transpose(0, 1)
33 | .contiguous()
34 | .view(*out_size)
35 | )
36 | if dim != 0:
37 | x = x.permute(perm).contiguous()
38 | return x
39 |
40 |
41 | def beam_search(
42 | decoder,
43 | tgt_vocab_size,
44 | hidden,
45 | bos_idx,
46 | eos_idx,
47 | pad_idx,
48 | beam_width,
49 | alpha=1,
50 | max_len=15,
51 | batch_trie_dict=None,
52 | ):
53 | """
54 | Beam search with size beam_width. Follows OpenNMT-py implementation.
55 | In each decoding step, find the k most likely partial hypotheses.
56 |
57 | :param decoder: an initialized decoder
58 | """
59 |
60 | decoder.eval()
61 | with torch.no_grad():
62 |
63 | # Initialize the hidden state and create the initial input.
64 | batch_size = (
65 | hidden[0].shape[0] if isinstance(hidden, tuple) else hidden.shape[0]
66 | )
67 | device = hidden[0].device if isinstance(hidden, tuple) else hidden.device
68 |
69 | prev_y = torch.full(
70 | size=[batch_size],
71 | fill_value=bos_idx,
72 | dtype=torch.long,
73 | device=device,
74 | )
75 |
76 | # Tile hidden decoder states and encoder outputs beam_width times
77 | hidden = tile(hidden, beam_width, dim=0) # [layers, B*beam_width, H_dec]
78 |
79 | batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
80 | beam_offset = torch.arange(
81 | 0,
82 | batch_size * beam_width,
83 | step=beam_width,
84 | dtype=torch.long,
85 | device=device,
86 | )
87 | alive_seq = torch.full(
88 | [batch_size * beam_width, 1],
89 | bos_idx,
90 | dtype=torch.long,
91 | device=device,
92 | )
93 |
94 | # Give full probability to the first beam on the first step.
95 | topk_log_probs = torch.tensor(
96 | [0.0] + [float("-inf")] * (beam_width - 1), device=device
97 | ).repeat(batch_size)
98 |
99 | # Structure that holds finished hypotheses.
100 | hypotheses = [[] for _ in range(batch_size)]
101 |
102 | results = {}
103 | results["predictions"] = [[] for _ in range(batch_size)]
104 | results["scores"] = [[] for _ in range(batch_size)]
105 | results["gold_score"] = [0] * batch_size
106 | results["contexts"] = [[] for _ in range(batch_size)]
107 |
108 | done = torch.full(
109 | [batch_size, beam_width],
110 | False,
111 | dtype=torch.bool,
112 | device=device,
113 | )
114 | trie_idx = (
115 | torch.arange(0, batch_size, device=device)
116 | .unsqueeze(-1)
117 | .repeat(1, beam_width)
118 | .view(-1)
119 | )
120 | for step in range(max_len):
121 | prev_y = alive_seq[:, -1].view(-1)
122 |
123 | # expand current hypotheses, decode one single step
124 | log_probs, hidden = decoder.step_beam_search(prev_y, hidden)
125 |
126 | if batch_trie_dict is not None:
127 | mask = torch.full_like(log_probs, -float("inf"))  # disallow all tokens by default
128 | for i, (b_idx, tokens) in enumerate(
129 | zip(trie_idx.tolist(), alive_seq.tolist())
130 | ):
131 | idx = batch_trie_dict[b_idx].get(tuple(tokens), [])  # continuations allowed by the prefix trie
132 | mask[[i] * len(idx), idx] = 0  # re-enable only those tokens
133 |
134 | log_probs += mask
135 |
136 | # multiply probs by the beam probability (=add logprobs)
137 | log_probs += topk_log_probs.view(-1).unsqueeze(1)
138 | curr_scores = log_probs
139 |
140 | # compute length penalty
141 | if alpha > -1:
142 | length_penalty = (step + 1) ** alpha
143 | curr_scores /= length_penalty
144 |
145 | # flatten log_probs into a list of possibilities
146 | curr_scores = curr_scores.reshape(-1, beam_width * tgt_vocab_size)
147 |
148 | # pick currently best top beam_width hypotheses (flattened order)
149 | topk_scores, topk_ids = curr_scores.topk(beam_width, dim=-1)
150 |
151 | if alpha > -1:
152 | # recover original log probs
153 | topk_log_probs = topk_scores * length_penalty
154 |
155 | # reconstruct beam origin and true word ids from flattened order
156 | if version.parse(torch.__version__) >= version.parse("1.5.0"):
157 | topk_beam_index = topk_ids.floor_divide(tgt_vocab_size)
158 | else:
159 | topk_beam_index = topk_ids.div(tgt_vocab_size)
160 | topk_ids = topk_ids.fmod(tgt_vocab_size)
161 |
162 | # map beam_index to batch_index in the flat representation
163 | batch_index = topk_beam_index + beam_offset[
164 | : topk_beam_index.size(0)
165 | ].unsqueeze(1)
166 | select_indices = batch_index.view(-1)
167 |
168 | # append latest prediction
169 | alive_seq = torch.cat(
170 | [alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1
171 | ) # batch_size*k x hyp_len
172 |
173 | is_finished = (
174 | topk_ids.eq(eos_idx) & ~topk_scores.eq(-float("inf"))
175 | ) | topk_scores.eq(-float("inf")).all(-1, keepdim=True)
176 |
177 | if step + 1 == max_len:
178 | is_finished.fill_(1)
179 |
180 | done |= is_finished
181 |
182 | # end condition is whether the top beam is finished
183 | end_condition = done.all(-1)
184 |
185 | # for LSTMs, states are tuples of tensors
186 | hidden = [e.index_select(0, select_indices) for e in hidden]
187 | trie_idx = trie_idx.index_select(0, select_indices)
188 |
189 | # save finished hypotheses
190 | if is_finished.any():
191 | predictions = alive_seq.view(-1, beam_width, alive_seq.size(-1))
192 | contexts = hidden[1].view(-1, beam_width, hidden[1].shape[-1])
193 |
194 | for i in range(is_finished.size(0)):
195 | b = batch_offset[i]
196 | finished_hyp = is_finished[i].nonzero(as_tuple=False).view(-1)
197 |
198 | # store finished hypotheses for this batch
199 | for j in finished_hyp:
200 |
201 | hypotheses[b].append(
202 | (
203 | topk_scores[i, j],
204 | predictions[i, j],
205 | contexts[i, j].clone(),
206 | ) # ignore start_token
207 | )
208 | # if the batch reached the end, save the beam_width hypotheses
209 | if end_condition[i]:
210 | best_hyp = sorted(
211 | hypotheses[b], key=lambda x: x[0], reverse=True
212 | )
213 | for n, (score, pred, cont) in enumerate(best_hyp):
214 | if n >= beam_width:
215 | break
216 | results["scores"][b].append(score)
217 | results["predictions"][b].append(pred)
218 | results["contexts"][b].append(cont)
219 |
220 | if end_condition.any():
221 | non_finished = end_condition.eq(0).nonzero(as_tuple=False).view(-1)
222 |
223 | # if all sentences are translated, no need to go further
224 | if len(non_finished) == 0:
225 | break
226 |
227 | # remove finished batches for the next step
228 | topk_log_probs = topk_log_probs.index_select(0, non_finished)
229 | batch_index = batch_index.index_select(0, non_finished)
230 | batch_offset = batch_offset.index_select(0, non_finished)
231 | alive_seq = predictions.index_select(0, non_finished).view(
232 | -1, alive_seq.size(-1)
233 | )
234 | done = done.index_select(0, non_finished)
235 |
236 | # reorder indices, outputs and masks, and trie
237 | select_indices = batch_index.view(-1)
238 |
239 | # for LSTMs, states are tuples of tensors
240 | hidden = [e.index_select(0, select_indices) for e in hidden]
241 | trie_idx = trie_idx.index_select(0, select_indices)
242 |
243 | return (
244 | torch.nn.utils.rnn.pad_sequence(
245 | [
246 | torch.nn.utils.rnn.pad_sequence(
247 | e, batch_first=True, padding_value=pad_idx
248 | ).T
249 | for e in results["predictions"]
250 | ],
251 | batch_first=True,
252 | padding_value=pad_idx,
253 | ).permute(0, 2, 1),
254 | torch.stack([torch.stack(e) for e in results["scores"]]),
255 | torch.stack([torch.stack(e) for e in results["contexts"]]),
256 | )
257 |
--------------------------------------------------------------------------------
/src/model/entity_linking.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from src.beam_search import beam_search
4 | from src.utils import label_smoothed_nll_loss
5 |
6 |
7 | class LSTM(torch.nn.Module):
8 | def __init__(
9 | self, bos_token_id, pad_token_id, eos_token_id, embeddings, lm_head, dropout=0
10 | ):
11 | super().__init__()
12 |
13 | self.bos_token_id = bos_token_id
14 | self.pad_token_id = pad_token_id
15 | self.eos_token_id = eos_token_id
16 | self.embeddings = embeddings
17 | self.lm_head = lm_head
18 | self.dropout = dropout
19 |
20 | self.lstm_cell = torch.nn.LSTMCell(
21 | input_size=2 * 768,
22 | hidden_size=768,
23 | )
24 |
25 | def _roll(
26 | self,
27 | input_ids,
28 | attention_mask,
29 | decoder_hidden,
30 | decoder_context,
31 | decoder_append,
32 | return_lprob=False,
33 | return_dict=False,
34 | ):
35 |
36 | dropout_mask = 1
37 | if self.training:
38 | dropout_mask = (torch.rand_like(decoder_hidden) > self.dropout).float()
39 |
40 | all_hiddens = []
41 | all_contexts = []
42 |
43 | emb = self.embeddings(input_ids)
44 | for t in range(emb.shape[1]):
45 | decoder_hidden, decoder_context = self.lstm_cell(
46 | torch.cat((emb[:, t], decoder_append), dim=-1),
47 | (decoder_hidden, decoder_context),
48 | )
49 | decoder_hidden *= dropout_mask
50 | all_hiddens.append(decoder_hidden)
51 | all_contexts.append(decoder_context)
52 |
53 | all_hiddens = torch.stack(all_hiddens, dim=1)
54 | all_contexts = torch.stack(all_contexts, dim=1)
55 |
56 | all_contexts = all_contexts[
57 | [e for e in range(attention_mask.shape[0])], attention_mask.sum(-1) - 1
58 | ]
59 |
60 | if return_dict:
61 | outputs = {
62 | "all_hiddens": all_hiddens,
63 | "all_contexts": all_contexts,
64 | }
65 | else:
66 | outputs = (all_hiddens, all_contexts)
67 | if return_lprob:
68 | logits = self.lm_head(all_hiddens)
69 |
70 | if self.training:
71 | logits = logits.log_softmax(-1)
72 | else:
73 | logits.sub_(logits.max(-1, keepdim=True).values)
74 | logits.exp_()
75 | logits.div_(logits.sum(-1, keepdim=True))
76 | logits.log_()
77 |
78 | if return_dict:
79 | outputs = {
80 | "logits": logits,
81 | }
82 | else:
83 | outputs += (logits,)
84 |
85 | return outputs
86 |
87 | def step_beam_search(self, previous_tokens, hidden):
88 | decoder_hidden, decoder_context, decoder_append = hidden
89 |
90 | emb = self.embeddings(previous_tokens)
91 | decoder_hidden, decoder_context = self.lstm_cell(
92 | torch.cat((emb, decoder_append), dim=-1),
93 | (decoder_hidden, decoder_context),
94 | )
95 |
96 | logits = self.lm_head(decoder_hidden)
97 | logits.sub_(logits.max(-1, keepdim=True).values)
98 | logits.exp_()
99 | logits.div_(logits.sum(-1, keepdim=True))
100 | logits.log_()
101 |
102 | return logits, (decoder_hidden, decoder_context, decoder_append)
103 |
104 | def forward(self, batch, decoder_hidden, decoder_context, decoder_append):
105 |
106 | _, all_contexts_positive, lprobs = self._roll(
107 | batch["trg_input_ids"][:, :-1],
108 | batch["trg_attention_mask"][:, 1:],
109 | decoder_hidden,
110 | decoder_context,
111 | decoder_append,
112 | return_lprob=True,
113 | )
114 |
115 | _, all_contexts_negative = self._roll(
116 | batch["neg_input_ids"][:, :-1],
117 | batch["neg_attention_mask"][:, 1:],
118 | decoder_hidden[batch["neg_mask"]],
119 | decoder_context[batch["neg_mask"]],
120 | decoder_append[batch["neg_mask"]],
121 | return_lprob=False,
122 | )
123 |
124 | return all_contexts_positive, all_contexts_negative, lprobs
125 |
126 | def forward_all_targets(
127 | self, batch, decoder_hidden, decoder_context, decoder_append
128 | ):
129 |
130 | _, all_contexts, lprobs = self._roll(
131 | batch["cand_input_ids"][:, :-1],
132 | batch["cand_attention_mask"][:, 1:],
133 | decoder_hidden,
134 | decoder_context,
135 | decoder_append,
136 | return_lprob=True,
137 | )
138 |
139 | scores = (
140 | lprobs.gather(
141 | dim=-1, index=batch["cand_input_ids"][:, 1:].unsqueeze(-1)
142 | ).squeeze(-1)
143 | * batch["cand_attention_mask"][:, 1:]
144 | )
145 |
146 | return all_contexts, scores.sum(-1) / (batch["cand_attention_mask"].sum(-1) - 1)
147 |
148 | def forward_beam_search(self, batch, hidden_states):
149 | raise NotImplementedError
150 |
151 |
152 | class EntityLinkingLSTM(torch.nn.Module):
153 | def __init__(
154 | self, bos_token_id, pad_token_id, eos_token_id, embeddings, lm_head, dropout=0
155 | ):
156 | super().__init__()
157 |
158 | self.bos_token_id = bos_token_id
159 | self.pad_token_id = pad_token_id
160 | self.eos_token_id = eos_token_id
161 |
162 | self.prj = torch.nn.Sequential(
163 | torch.nn.LayerNorm(768 * 2),
164 | torch.nn.Dropout(dropout),
165 | torch.nn.Linear(768 * 2, 768),
166 | torch.nn.ReLU(),
167 | torch.nn.LayerNorm(768),
168 | torch.nn.Dropout(dropout),
169 | torch.nn.Linear(768, 768 * 3),
170 | )
171 |
172 | self.lstm = LSTM(
173 | bos_token_id,
174 | pad_token_id,
175 | eos_token_id,
176 | embeddings,
177 | lm_head,
178 | dropout=dropout,
179 | )
180 |
181 | self.classifier = torch.nn.Sequential(
182 | torch.nn.LayerNorm(768 * 2),
183 | torch.nn.Dropout(dropout),
184 | torch.nn.Linear(768 * 2, 768),
185 | torch.nn.ReLU(),
186 | torch.nn.LayerNorm(768),
187 | torch.nn.Dropout(dropout),
188 | torch.nn.Linear(768, 1),
189 | )
190 |
191 | def _get_hidden_context_append_vectors(self, batch, hidden_states):
192 | return self.prj(
193 | torch.cat(
194 | (
195 | hidden_states[batch["offsets_start"]],
196 | hidden_states[batch["offsets_end"]],
197 | ),
198 | dim=-1,
199 | )
200 | ).split([768, 768, 768], dim=-1)
201 |
202 | def forward(self, batch, hidden_states):
203 | (
204 | decoder_hidden,
205 | decoder_context,
206 | decoder_append,
207 | ) = self._get_hidden_context_append_vectors(batch, hidden_states)
208 |
209 | (
210 | all_contexts_positive,
211 | all_contexts_negative,
212 | lprobs_lm,
213 | ) = self.lstm.forward(batch, decoder_hidden, decoder_context, decoder_append)
214 |
215 | logits_classifier = self.classifier(
216 | torch.cat(
217 | (
218 | decoder_append[batch["neg_mask"]].unsqueeze(1).repeat(1, 2, 1),
219 | torch.stack(
220 | (
221 | all_contexts_positive[batch["neg_mask"]],
222 | all_contexts_negative,
223 | ),
224 | dim=1,
225 | ),
226 | ),
227 | dim=-1,
228 | )
229 | ).squeeze(-1)
230 |
231 | return lprobs_lm, logits_classifier
232 |
233 | def forward_loss(self, batch, hidden_states, epsilon=0):
234 |
235 | lprobs_lm, logits_classifier = self.forward(batch, hidden_states)
236 |
237 | loss_generation, _ = label_smoothed_nll_loss(
238 | lprobs_lm,
239 | batch["trg_input_ids"][:, 1:],
240 | epsilon=epsilon,
241 | ignore_index=self.pad_token_id,
242 | )
243 | loss_generation = loss_generation / batch["trg_attention_mask"][:, 1:].sum()
244 |
245 | loss_classifier = torch.nn.functional.cross_entropy(
246 | logits_classifier,
247 | torch.zeros(
248 | (logits_classifier.shape[0]),
249 | dtype=torch.long,
250 | device=logits_classifier.device,
251 | ),
252 | )
253 |
254 | return loss_generation, loss_classifier
255 |
256 | def forward_all_targets(self, batch, hidden_states):
257 | (
258 | decoder_hidden,
259 | decoder_context,
260 | decoder_append,
261 | ) = self._get_hidden_context_append_vectors(batch, hidden_states)
262 |
263 | all_contexts, lm_scores = self.lstm.forward_all_targets(
264 | batch,
265 | decoder_hidden[batch["offsets_candidates"]],
266 | decoder_context[batch["offsets_candidates"]],
267 | decoder_append[batch["offsets_candidates"]],
268 | )
269 |
270 | classifier_scores = self.classifier(
271 | torch.cat(
272 | (
273 | decoder_append[batch["offsets_candidates"]],
274 | all_contexts,
275 | ),
276 | dim=-1,
277 | )
278 | ).squeeze(-1)
279 |
280 | scores = lm_scores + classifier_scores
281 |
282 | classifier_scores = torch.cat(
283 | [
284 | e.log_softmax(-1)
285 | for e in classifier_scores.split(batch["split_candidates"])
286 | ]
287 | )
288 |
289 | tokens = [
290 | t[s.argsort(descending=True)]
291 | for s, t in zip(
292 | scores.split(batch["split_candidates"]),
293 | batch["cand_input_ids"].split(batch["split_candidates"]),
294 | )
295 | ]
296 |
297 | tokens = torch.nn.utils.rnn.pad_sequence(
298 | tokens, batch_first=True, padding_value=self.pad_token_id
299 | )
300 |
301 | scores = torch.nn.utils.rnn.pad_sequence(
302 | [
303 | e.sort(descending=True).values
304 | for e in scores.split(batch["split_candidates"])
305 | ],
306 | batch_first=True,
307 | padding_value=-float("inf"),
308 | )
309 |
310 | return tokens, scores
311 |
312 | def forward_beam_search(
313 | self, batch, hidden_states, batch_trie_dict=None, beams=5, alpha=1, max_len=15
314 | ):
315 | (
316 | decoder_hidden,
317 | decoder_context,
318 | decoder_append,
319 | ) = self._get_hidden_context_append_vectors(batch, hidden_states)
320 |
321 | tokens, lm_scores, all_contexts = beam_search(
322 | self.lstm,
323 | self.lstm.lm_head.decoder.out_features,
324 | (decoder_hidden, decoder_context, decoder_append),
325 | self.bos_token_id,
326 | self.eos_token_id,
327 | self.pad_token_id,
328 | beam_width=beams,
329 | alpha=alpha,
330 | max_len=max_len,
331 | batch_trie_dict=batch_trie_dict,
332 | )
333 |
334 | classifier_scores = self.classifier(
335 | torch.cat(
336 | (
337 | decoder_append.unsqueeze(1).repeat(1, beams, 1),
338 | all_contexts,
339 | ),
340 | dim=-1,
341 | )
342 | ).squeeze(-1)
343 |
344 | classifier_scores[lm_scores == -float("inf")] = -float("inf")
345 | classifier_scores = classifier_scores.log_softmax(-1)
346 |
347 | scores = (classifier_scores + lm_scores).sort(-1, descending=True)
348 | tokens = tokens[
349 | torch.arange(scores.indices.shape[0], device=tokens.device)
350 | .unsqueeze(-1)
351 | .repeat(1, beams),
352 | scores.indices,
353 | ]
354 | scores = scores.values
355 |
356 | return tokens, scores
357 |
--------------------------------------------------------------------------------
/notebooks/Test.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "7c056d36",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "%load_ext autoreload\n",
11 | "%autoreload 2\n",
12 | "\n",
13 | "import sys\n",
14 | "sys.path.append(\"../\")"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "id": "759b3485",
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "from argparse import ArgumentParser\n",
25 | "from pytorch_lightning import Trainer\n",
26 | "from src.model.efficient_el import EfficientEL\n",
27 | "from src.data.dataset_el import DatasetEL\n",
28 | "from IPython.display import Markdown\n",
29 | "from src.utils import get_markdown"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 3,
35 | "id": "f7a51eca",
36 | "metadata": {},
37 | "outputs": [
38 | {
39 | "name": "stderr",
40 | "output_type": "stream",
41 | "text": [
42 | "GPU available: True, used: True\n",
43 | "TPU available: False, using: 0 TPU cores\n",
44 | "Using native 16bit precision.\n"
45 | ]
46 | }
47 | ],
48 | "source": [
49 | "parser = ArgumentParser()\n",
50 | "\n",
51 | "parser = Trainer.add_argparse_args(parser)\n",
52 | "\n",
53 | "args, _ = parser.parse_known_args()\n",
54 | "args.gpus = 1\n",
55 | "args.precision = 16\n",
56 | "\n",
57 | "trainer = Trainer.from_argparse_args(args)"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 4,
63 | "id": "2e625161",
64 | "metadata": {
65 | "scrolled": true
66 | },
67 | "outputs": [
68 | {
69 | "name": "stderr",
70 | "output_type": "stream",
71 | "text": [
72 | "Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerForMaskedLM: ['longformer.encoder.layer.8.attention.self.query.weight', 'longformer.encoder.layer.8.attention.self.query.bias', 'longformer.encoder.layer.8.attention.self.key.weight', 'longformer.encoder.layer.8.attention.self.key.bias', 'longformer.encoder.layer.8.attention.self.value.weight', 'longformer.encoder.layer.8.attention.self.value.bias', 'longformer.encoder.layer.8.attention.self.query_global.weight', 'longformer.encoder.layer.8.attention.self.query_global.bias', 'longformer.encoder.layer.8.attention.self.key_global.weight', 'longformer.encoder.layer.8.attention.self.key_global.bias', 'longformer.encoder.layer.8.attention.self.value_global.weight', 'longformer.encoder.layer.8.attention.self.value_global.bias', 'longformer.encoder.layer.8.attention.output.dense.weight', 'longformer.encoder.layer.8.attention.output.dense.bias', 'longformer.encoder.layer.8.attention.output.LayerNorm.weight', 'longformer.encoder.layer.8.attention.output.LayerNorm.bias', 'longformer.encoder.layer.8.intermediate.dense.weight', 'longformer.encoder.layer.8.intermediate.dense.bias', 'longformer.encoder.layer.8.output.dense.weight', 'longformer.encoder.layer.8.output.dense.bias', 'longformer.encoder.layer.8.output.LayerNorm.weight', 'longformer.encoder.layer.8.output.LayerNorm.bias', 'longformer.encoder.layer.9.attention.self.query.weight', 'longformer.encoder.layer.9.attention.self.query.bias', 'longformer.encoder.layer.9.attention.self.key.weight', 'longformer.encoder.layer.9.attention.self.key.bias', 'longformer.encoder.layer.9.attention.self.value.weight', 'longformer.encoder.layer.9.attention.self.value.bias', 'longformer.encoder.layer.9.attention.self.query_global.weight', 'longformer.encoder.layer.9.attention.self.query_global.bias', 'longformer.encoder.layer.9.attention.self.key_global.weight', 'longformer.encoder.layer.9.attention.self.key_global.bias', 'longformer.encoder.layer.9.attention.self.value_global.weight', 'longformer.encoder.layer.9.attention.self.value_global.bias', 'longformer.encoder.layer.9.attention.output.dense.weight', 'longformer.encoder.layer.9.attention.output.dense.bias', 'longformer.encoder.layer.9.attention.output.LayerNorm.weight', 'longformer.encoder.layer.9.attention.output.LayerNorm.bias', 'longformer.encoder.layer.9.intermediate.dense.weight', 'longformer.encoder.layer.9.intermediate.dense.bias', 'longformer.encoder.layer.9.output.dense.weight', 'longformer.encoder.layer.9.output.dense.bias', 'longformer.encoder.layer.9.output.LayerNorm.weight', 'longformer.encoder.layer.9.output.LayerNorm.bias', 'longformer.encoder.layer.10.attention.self.query.weight', 'longformer.encoder.layer.10.attention.self.query.bias', 'longformer.encoder.layer.10.attention.self.key.weight', 'longformer.encoder.layer.10.attention.self.key.bias', 'longformer.encoder.layer.10.attention.self.value.weight', 'longformer.encoder.layer.10.attention.self.value.bias', 'longformer.encoder.layer.10.attention.self.query_global.weight', 'longformer.encoder.layer.10.attention.self.query_global.bias', 'longformer.encoder.layer.10.attention.self.key_global.weight', 'longformer.encoder.layer.10.attention.self.key_global.bias', 'longformer.encoder.layer.10.attention.self.value_global.weight', 'longformer.encoder.layer.10.attention.self.value_global.bias', 'longformer.encoder.layer.10.attention.output.dense.weight', 'longformer.encoder.layer.10.attention.output.dense.bias', 
'longformer.encoder.layer.10.attention.output.LayerNorm.weight', 'longformer.encoder.layer.10.attention.output.LayerNorm.bias', 'longformer.encoder.layer.10.intermediate.dense.weight', 'longformer.encoder.layer.10.intermediate.dense.bias', 'longformer.encoder.layer.10.output.dense.weight', 'longformer.encoder.layer.10.output.dense.bias', 'longformer.encoder.layer.10.output.LayerNorm.weight', 'longformer.encoder.layer.10.output.LayerNorm.bias', 'longformer.encoder.layer.11.attention.self.query.weight', 'longformer.encoder.layer.11.attention.self.query.bias', 'longformer.encoder.layer.11.attention.self.key.weight', 'longformer.encoder.layer.11.attention.self.key.bias', 'longformer.encoder.layer.11.attention.self.value.weight', 'longformer.encoder.layer.11.attention.self.value.bias', 'longformer.encoder.layer.11.attention.self.query_global.weight', 'longformer.encoder.layer.11.attention.self.query_global.bias', 'longformer.encoder.layer.11.attention.self.key_global.weight', 'longformer.encoder.layer.11.attention.self.key_global.bias', 'longformer.encoder.layer.11.attention.self.value_global.weight', 'longformer.encoder.layer.11.attention.self.value_global.bias', 'longformer.encoder.layer.11.attention.output.dense.weight', 'longformer.encoder.layer.11.attention.output.dense.bias', 'longformer.encoder.layer.11.attention.output.LayerNorm.weight', 'longformer.encoder.layer.11.attention.output.LayerNorm.bias', 'longformer.encoder.layer.11.intermediate.dense.weight', 'longformer.encoder.layer.11.intermediate.dense.bias', 'longformer.encoder.layer.11.output.dense.weight', 'longformer.encoder.layer.11.output.dense.bias', 'longformer.encoder.layer.11.output.LayerNorm.weight', 'longformer.encoder.layer.11.output.LayerNorm.bias']\n",
73 | "- This IS expected if you are initializing LongformerForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
74 | "- This IS NOT expected if you are initializing LongformerForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
75 | "Some weights of LongformerForMaskedLM were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['lm_head.decoder.bias']\n",
76 | "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
77 | ]
78 | }
79 | ],
80 | "source": [
81 | "model = EfficientEL.load_from_checkpoint(\"../models/model.ckpt\").eval()"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 5,
87 | "id": "ea367df0",
88 | "metadata": {},
89 | "outputs": [
90 | {
91 | "name": "stderr",
92 | "output_type": "stream",
93 | "text": [
94 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]\n",
95 | "/home/ndecao/.anaconda3/envs/nlp38/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:68: UserWarning: The dataloader, test dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n",
96 | " warnings.warn(*args, **kwargs)\n"
97 | ]
98 | },
99 | {
100 | "data": {
101 | "application/vnd.jupyter.widget-view+json": {
102 | "model_id": "71b95647f19c4d428c96e0fa3e8603d9",
103 | "version_major": 2,
104 | "version_minor": 0
105 | },
106 | "text/plain": [
107 | "Testing: 0it [00:00, ?it/s]"
108 | ]
109 | },
110 | "metadata": {},
111 | "output_type": "display_data"
112 | },
113 | {
114 | "name": "stdout",
115 | "output_type": "stream",
116 | "text": [
117 | "--------------------------------------------------------------------------------\n",
118 | "DATALOADER:0 TEST RESULTS\n",
119 | "{'ed_macro_f1': 0.9203808307647705,\n",
120 | " 'ed_macro_prec': 0.9131189584732056,\n",
121 | " 'ed_macro_rec': 0.9390283226966858,\n",
122 | " 'ed_micro_f1': 0.9348137378692627,\n",
123 | " 'ed_micro_prec': 0.9219427704811096,\n",
124 | " 'ed_micro_rec': 0.9480490684509277,\n",
125 | " 'macro_f1': 0.8363054394721985,\n",
126 | " 'macro_prec': 0.8289670348167419,\n",
127 | " 'macro_rec': 0.8539509773254395,\n",
128 | " 'micro_f1': 0.8550071120262146,\n",
129 | " 'micro_prec': 0.8432350158691406,\n",
130 | " 'micro_rec': 0.8671125769615173}\n",
131 | "--------------------------------------------------------------------------------\n"
132 | ]
133 | },
134 | {
135 | "data": {
136 | "text/plain": [
137 | "[{'micro_f1': 0.8550071120262146,\n",
138 | " 'ed_micro_f1': 0.9348137378692627,\n",
139 | " 'micro_prec': 0.8432350158691406,\n",
140 | " 'macro_rec': 0.8539509773254395,\n",
141 | " 'macro_f1': 0.8363054394721985,\n",
142 | " 'macro_prec': 0.8289670348167419,\n",
143 | " 'micro_rec': 0.8671125769615173,\n",
144 | " 'ed_micro_prec': 0.9219427704811096,\n",
145 | " 'ed_micro_rec': 0.9480490684509277,\n",
146 | " 'ed_macro_f1': 0.9203808307647705,\n",
147 | " 'ed_macro_prec': 0.9131189584732056,\n",
148 | " 'ed_macro_rec': 0.9390283226966858}]"
149 | ]
150 | },
151 | "execution_count": 5,
152 | "metadata": {},
153 | "output_type": "execute_result"
154 | }
155 | ],
156 | "source": [
157 | "model.hparams.threshold = -3.2\n",
158 | "model.hparams.test_with_beam_search = False\n",
159 | "model.hparams.test_with_beam_search_no_candidates = False\n",
160 | "trainer.test(model, test_dataloaders=model.test_dataloader(), ckpt_path=None)"
161 | ]
162 | },
163 | {
164 | "cell_type": "code",
165 | "execution_count": 6,
166 | "id": "ec5eeeb5",
167 | "metadata": {},
168 | "outputs": [
169 | {
170 | "data": {
171 | "application/vnd.jupyter.widget-view+json": {
172 | "model_id": "a0701f715e594adf98ab3f00452b64b2",
173 | "version_major": 2,
174 | "version_minor": 0
175 | },
176 | "text/plain": [
177 | "Loading ..: 0%| | 0/470105 [00:00, ?it/s]"
178 | ]
179 | },
180 | "metadata": {},
181 | "output_type": "display_data"
182 | }
183 | ],
184 | "source": [
185 | "model.generate_global_trie()"
186 | ]
187 | },
188 | {
189 | "cell_type": "code",
190 | "execution_count": 7,
191 | "id": "e259e626",
192 | "metadata": {},
193 | "outputs": [],
194 | "source": [
195 | "s = \"\"\"CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY . LONDON 1996-08-30 \\\n",
196 | "West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset \\\n",
197 | "by an innings and 39 runs in two days to take over at the head of the county championship . Their \\\n",
198 | "stay on top , though , may be short-lived as title rivals Essex , Derbyshire and Surrey all closed \\\n",
199 | "in on victory while Kent made up for lost time in their rain-affected match against Nottinghamshire . \\\n",
200 | "After bowling Somerset out for 83 on the opening morning at Grace Road , Leicestershire extended their \\\n",
201 | "first innings by 94 runs before being bowled out for 296 with England discard Andy Caddick taking three \\\n",
202 | "for 83 . Trailing by 213 , Somerset got a solid start to their second innings before Simmons stepped in \\\n",
203 | "to bundle them out for 174 . Essex , however , look certain to regain their top spot after Nasser Hussain \\\n",
204 | "and Peter Such gave them a firm grip on their match against Yorkshire at Headingley . Hussain , \\\n",
205 | "considered surplus to England 's one-day requirements , struck 158 , his first championship century of \\\n",
206 | "the season , as Essex reached 372 and took a first innings lead of 82 . By the close Yorkshire had turned \\\n",
207 | "that into a 37-run advantage but off-spinner Such had scuttled their hopes , taking four for 24 in 48 balls \n",
208 | "\\and leaving them hanging on 119 for five and praying for rain . At the Oval , Surrey captain Chris Lewis , \\\n",
209 | "another man dumped by England , continued to silence his critics as he followed his four for 45 on Thursday \\\n",
210 | "with 80 not out on Friday in the match against Warwickshire . He was well backed by England hopeful Mark \\\n",
211 | "Butcher who made 70 as Surrey closed on 429 for seven , a lead of 234 . Derbyshire kept up the hunt for \\\n",
212 | "their first championship title since 1936 by reducing Worcestershire to 133 for five in their second \\\n",
213 | "innings , still 100 runs away from avoiding an innings defeat . Australian Tom Moody took six for 82 but \\\n",
214 | "Chris Adams , 123 , and Tim O'Gorman , 109 , took Derbyshire to 471 and a first innings lead of 233 . \\\n",
215 | "After the frustration of seeing the opening day of their match badly affected by the weather , Kent stepped \\\n",
216 | "up a gear to dismiss Nottinghamshire for 214 . They were held up by a gritty 84 from Paul Johnson but \\\n",
217 | "ex-England fast bowler Martin McCague took four for 55 . By stumps Kent had reached 108 for three .\"\"\""
218 | ]
219 | },
220 | {
221 | "cell_type": "code",
222 | "execution_count": 8,
223 | "id": "908a75ea",
224 | "metadata": {},
225 | "outputs": [
226 | {
227 | "data": {
228 | "text/markdown": [
229 | "CRICKET - [LEICESTERSHIRE](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) TAKE OVER AT TOP AFTER INNINGS VICTORY . [LONDON](https://en.wikipedia.org/wiki/London) 1996-08-30 [West Indian](https://en.wikipedia.org/wiki/West_Indies) all-rounder [Phil Simmons](https://en.wikipedia.org/wiki/Philip_Walton) took four for 38 on Friday as [Leicestershire](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) beat [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) by an innings and 39 runs in two days to take over at the head of the county championship . Their stay on top , though , may be short-lived as title rivals [Essex](https://en.wikipedia.org/wiki/Essex_County_Cricket_Club) , [Derbyshire](https://en.wikipedia.org/wiki/Derbyshire_County_Cricket_Club) and [Surrey](https://en.wikipedia.org/wiki/Surrey_County_Cricket_Club) all closed in on victory while [Kent](https://en.wikipedia.org/wiki/Kent_County_Cricket_Club) made up for lost time in their rain-affected match against [Nottinghamshire](https://en.wikipedia.org/wiki/Nottinghamshire_County_Cricket_Club) . After bowling [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) out for 83 on the opening morning at [Grace Road](https://en.wikipedia.org/wiki/Grace_Road) , [Leicestershire](https://en.wikipedia.org/wiki/Leicestershire_County_Cricket_Club) extended their first innings by 94 runs before being bowled out for 296 with [England](https://en.wikipedia.org/wiki/England_cricket_team) discard [Andy Caddick](https://en.wikipedia.org/wiki/Andrew_Caddick) taking three for 83 . Trailing by 213 , [Somerset](https://en.wikipedia.org/wiki/Somerset_County_Cricket_Club) got a solid start to their second innings before [Simmons](https://en.wikipedia.org/wiki/Singapore) stepped in to bundle them out for 174 . [Essex](https://en.wikipedia.org/wiki/Essex_County_Cricket_Club) , however , look certain to regain their top spot after [Nasser Hussain](https://en.wikipedia.org/wiki/Nasser_Hussain) and [Peter Such](https://en.wikipedia.org/wiki/Peter_Thomson_(golfer)) gave them a firm grip on their match against [Yorkshire](https://en.wikipedia.org/wiki/Yorkshire_County_Cricket_Club) at [Headingley](https://en.wikipedia.org/wiki/Headingley_Stadium) . [Hussain](https://en.wikipedia.org/wiki/Nasser_Hussain) , considered surplus to [England](https://en.wikipedia.org/wiki/England_cricket_team) 's one-day requirements , struck 158 , his first championship century of the season , as [Essex](https://en.wikipedia.org/wiki/Essex_County_Cricket_Club) reached 372 and took a first innings lead of 82 . By the close [Yorkshire](https://en.wikipedia.org/wiki/Yorkshire_County_Cricket_Club) had turned that into a 37-run advantage but off-spinner [Such](https://en.wikipedia.org/wiki/Mark_Broadhurst) had scuttled their hopes , taking four for 24 in 48 balls \n",
230 | "\u0007nd leaving them hanging on 119 for five and praying for rain . At the [Oval](https://en.wikipedia.org/wiki/The_Oval) , [Surrey](https://en.wikipedia.org/wiki/Surrey_County_Cricket_Club) captain [Chris Lewis](https://en.wikipedia.org/wiki/Chris_Lewis_(cricketer)) , another man dumped by [England](https://en.wikipedia.org/wiki/England_cricket_team) , continued to silence his critics as he followed his four for 45 on Thursday with 80 not out on Friday in the match against [Warwickshire](https://en.wikipedia.org/wiki/Warwickshire_County_Cricket_Club) . He was well backed by [England](https://en.wikipedia.org/wiki/England_cricket_team) hopeful [Mark Butcher](https://en.wikipedia.org/wiki/Mark_Butcher) who made 70 as [Surrey](https://en.wikipedia.org/wiki/Surrey_County_Cricket_Club) closed on 429 for seven , a lead of 234 . [Derbyshire](https://en.wikipedia.org/wiki/Derbyshire_County_Cricket_Club) kept up the hunt for their first championship title since 1936 by reducing [Worcestershire](https://en.wikipedia.org/wiki/Worcestershire_County_Cricket_Club) to 133 for five in their second innings , still 100 runs away from avoiding an innings defeat . [Australian](https://en.wikipedia.org/wiki/Australia) [Tom Moody](https://en.wikipedia.org/wiki/Tommy_Haas) took six for 82 but [Chris Adams](https://en.wikipedia.org/wiki/Chris_Walker_(squash_player)) , 123 , and Tim O'Gorman , 109 , took [Derbyshire](https://en.wikipedia.org/wiki/Derbyshire_County_Cricket_Club) to 471 and a first innings lead of 233 . After the frustration of seeing the opening day of their match badly affected by the weather , [Kent](https://en.wikipedia.org/wiki/Kent_County_Cricket_Club) stepped up a gear to dismiss [Nottinghamshire](https://en.wikipedia.org/wiki/Nottinghamshire_County_Cricket_Club) for 214 . They were held up by a gritty 84 from [Paul Johnson](https://en.wikipedia.org/wiki/Paul_Johnson_(squash_player)) but ex-England fast bowler [Martin McCague](https://en.wikipedia.org/wiki/Martin_McCague) took four for 55 . By stumps [Kent](https://en.wikipedia.org/wiki/Kent_County_Cricket_Club) had reached 108 for three ."
231 | ],
232 | "text/plain": [
233 | ""
234 | ]
235 | },
236 | "execution_count": 8,
237 | "metadata": {},
238 | "output_type": "execute_result"
239 | }
240 | ],
241 | "source": [
242 | "Markdown(get_markdown([s], [[(s[0], s[1], s[2][0][0]) for s in spans] for spans in model.sample([s])])[0])"
243 | ]
244 | }
245 | ],
246 | "metadata": {
247 | "kernelspec": {
248 | "display_name": "Python 3",
249 | "language": "python",
250 | "name": "python3"
251 | },
252 | "language_info": {
253 | "codemirror_mode": {
254 | "name": "ipython",
255 | "version": 3
256 | },
257 | "file_extension": ".py",
258 | "mimetype": "text/x-python",
259 | "name": "python",
260 | "nbconvert_exporter": "python",
261 | "pygments_lexer": "ipython3",
262 | "version": "3.8.8"
263 | }
264 | },
265 | "nbformat": 4,
266 | "nbformat_minor": 5
267 | }
268 |
--------------------------------------------------------------------------------
/src/model/efficient_el.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | from argparse import ArgumentParser
4 | from collections import defaultdict
5 |
6 | import pytorch_lightning as pl
7 | import torch
8 | from pytorch_lightning import LightningModule
9 | from torch.utils.data import DataLoader
10 | from tqdm.auto import tqdm
11 | from transformers import (
12 | AutoTokenizer,
13 | LongformerForMaskedLM,
14 | get_linear_schedule_with_warmup,
15 | )
16 |
17 | from src.data.dataset_el import DatasetEL
18 | from src.model.entity_detection import EntityDetectionFactor
19 | from src.model.entity_linking import EntityLinkingLSTM
20 | from src.utils import (
21 | MacroF1,
22 | MacroPrecision,
23 | MacroRecall,
24 | MicroF1,
25 | MicroPrecision,
26 | MicroRecall,
27 | )
28 |
29 |
30 | class EfficientEL(LightningModule):
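    # End-to-end entity linker: a truncated (8-layer) Longformer encoder shared
    # by an entity-detection head that scores span starts/ends and an
    # LSTM-based linking head that generates entity names token by token.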
31 | @staticmethod
32 | def add_model_specific_args(parent_parser):
33 | parser = ArgumentParser(parents=[parent_parser], add_help=False)
34 | parser.add_argument(
35 | "--train_data_path",
36 | type=str,
37 | default="../data/aida_train_dataset.jsonl",
38 | )
39 | parser.add_argument(
40 | "--dev_data_path",
41 | type=str,
42 | default="../data/aida_val_dataset.jsonl",
43 | )
44 | parser.add_argument(
45 | "--test_data_path",
46 | type=str,
47 | default="../data/aida_test_dataset.jsonl",
48 | )
49 | parser.add_argument("--batch_size", type=int, default=2)
50 | parser.add_argument("--lr_transformer", type=float, default=1e-4)
51 | parser.add_argument("--lr", type=float, default=1e-3)
52 | parser.add_argument("--max_length_train", type=int, default=1024)
53 | parser.add_argument("--max_length", type=int, default=4096)
54 |         parser.add_argument("--weight_decay", type=float, default=0.01)
55 | parser.add_argument("--total_num_updates", type=int, default=10000)
56 | parser.add_argument("--warmup_updates", type=int, default=500)
57 | parser.add_argument("--num_workers", type=int, default=0)
58 | parser.add_argument("--dropout", type=float, default=0.1)
59 | parser.add_argument("--max_length_span", type=int, default=15)
60 |         parser.add_argument("--threshold", type=float, default=0.0)
61 | parser.add_argument("--test_with_beam_search", action="store_true")
62 | parser.add_argument(
63 | "--test_with_beam_search_no_candidates", action="store_true"
64 | )
65 | parser.add_argument(
66 | "--model_name", type=str, default="allenai/longformer-base-4096"
67 | )
68 | parser.add_argument(
69 | "--mentions_filename",
70 | type=str,
71 | default="../data/mentions.json",
72 | )
73 | parser.add_argument(
74 | "--entities_filename",
75 | type=str,
76 | default="../data/entities.json",
77 | )
78 | parser.add_argument("--epsilon", type=float, default=0.1)
79 | return parser
80 |
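    # Builds the tokenizer, the 8-layer Longformer encoder with frozen word
    # embeddings, the detection and linking heads, and micro/macro metrics for
    # entity linking and for entity detection alone (the "ed_" variants).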
81 | def __init__(self, *args, **kwargs):
82 | super().__init__()
83 | self.save_hyperparameters()
84 |
85 | self.tokenizer = AutoTokenizer.from_pretrained(self.hparams.model_name)
86 |
87 | longformer = LongformerForMaskedLM.from_pretrained(
88 | self.hparams.model_name,
89 | num_hidden_layers=8,
90 | attention_window=[128] * 8,
91 | )
92 |
93 | self.encoder = longformer.longformer
94 |
95 | self.encoder.embeddings.word_embeddings.weight.requires_grad_(False)
96 |
97 | self.entity_detection = EntityDetectionFactor(
98 | self.hparams.max_length_span,
99 | self.hparams.dropout,
100 | mentions_filename=self.hparams.mentions_filename,
101 | )
102 |
103 | self.entity_linking = EntityLinkingLSTM(
104 | self.tokenizer.bos_token_id,
105 | self.tokenizer.pad_token_id,
106 | self.tokenizer.eos_token_id,
107 | self.encoder.embeddings.word_embeddings,
108 | longformer.lm_head,
109 | self.hparams.dropout,
110 | )
111 |
112 | self.micro_f1 = MicroF1()
113 | self.micro_prec = MicroPrecision()
114 | self.micro_rec = MicroRecall()
115 |
116 | self.macro_f1 = MacroF1()
117 | self.macro_prec = MacroPrecision()
118 | self.macro_rec = MacroRecall()
119 |
120 | self.ed_micro_f1 = MicroF1()
121 | self.ed_micro_prec = MicroPrecision()
122 | self.ed_micro_rec = MicroRecall()
123 |
124 | self.ed_macro_f1 = MacroF1()
125 | self.ed_macro_prec = MacroPrecision()
126 | self.ed_macro_rec = MacroRecall()
127 |
128 | def train_dataloader(self, shuffle=True):
129 |         if not hasattr(self, "train_dataset") or getattr(self.hparams, "sharded", False):
130 | self.train_dataset = DatasetEL(
131 | tokenizer=self.tokenizer,
132 | data_path=self.hparams.train_data_path,
133 | max_length=self.hparams.max_length_train,
134 | max_length_span=self.hparams.max_length_span,
135 | )
136 | return DataLoader(
137 | self.train_dataset,
138 | batch_size=self.hparams.batch_size,
139 | collate_fn=self.train_dataset.collate_fn,
140 | num_workers=self.hparams.num_workers,
141 | shuffle=shuffle,
142 | )
143 |
144 | def val_dataloader(self):
145 | if not hasattr(self, "val_dataset"):
146 | self.val_dataset = DatasetEL(
147 | tokenizer=self.tokenizer,
148 | data_path=self.hparams.dev_data_path,
149 | max_length=self.hparams.max_length,
150 | max_length_span=self.hparams.max_length_span,
151 | test=True,
152 | )
153 | return DataLoader(
154 | self.val_dataset,
155 | batch_size=1,
156 | collate_fn=self.val_dataset.collate_fn,
157 | num_workers=self.hparams.num_workers,
158 | )
159 |
160 | def test_dataloader(self):
161 | if not hasattr(self, "test_dataset"):
162 | self.test_dataset = DatasetEL(
163 | tokenizer=self.tokenizer,
164 | data_path=self.hparams.test_data_path,
165 | max_length=self.hparams.max_length,
166 | max_length_span=self.hparams.max_length_span,
167 | test=True,
168 | )
169 |
170 | return DataLoader(
171 | self.test_dataset,
172 | batch_size=1,
173 | collate_fn=self.test_dataset.collate_fn,
174 | num_workers=self.hparams.num_workers,
175 | )
176 |
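    # Used when candidate sets are available: detect spans above the score
    # threshold, look up each predicted span's candidate list from the raw
    # annotations (falling back to "NIL"), and score all candidates at once.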
177 | def forward_all_targets(self, batch, return_dict=False):
178 |
179 | hidden_states = self.encoder(
180 | input_ids=batch["src_input_ids"], attention_mask=batch["src_attention_mask"]
181 | ).last_hidden_state
182 |
183 | (
184 | (start, end, scores_ed),
185 | (
186 | logits_classifier_start,
187 | logits_classifier_end,
188 | ),
189 | ) = self.entity_detection.forward_hard(
190 | batch, hidden_states, threshold=self.hparams.threshold
191 | )
192 |
193 | if start.shape[0] == 0:
194 | return []
195 |
196 | batch["offsets_start"] = start.T.tolist()
197 | batch["offsets_end"] = end.T.tolist()
198 |
199 | batch_candidates = [
200 | {
201 | (s, e): c
202 | for (s, e, _), c in zip(
203 | batch["raw"][i]["anchors"], batch["raw"][i]["candidates"]
204 | )
205 | }.get(tuple((s, e)), ["NIL"])
206 | for (i, s), (_, e) in zip(
207 | zip(*batch["offsets_start"]), zip(*batch["offsets_end"])
208 | )
209 | ]
210 |
211 | try:
212 | for k, v in self.tokenizer(
213 | [c for candidates in batch_candidates for c in candidates],
214 | return_tensors="pt",
215 | padding=True,
216 | ).items():
217 | batch[f"cand_{k}"] = v.to(self.device)
218 |
219 | batch["offsets_candidates"] = [
220 | i
221 | for i, candidates in enumerate(batch_candidates)
222 | for _ in range(len(candidates))
223 | ]
224 | batch["split_candidates"] = [
225 | len(candidates) for candidates in batch_candidates
226 | ]
227 |
228 | tokens, scores_el = self.entity_linking.forward_all_targets(
229 | batch, hidden_states
230 | )
231 |
232 |         except Exception:
233 | if not self.training:
234 | print("error on generation")
235 |
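        # If generation failed above, `tokens`/`scores_el` are undefined, so the
        # conversion below raises and the handler emits NIL fallback spans.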
236 | try:
237 | spans = self._tokens_scores_to_spans(batch, start, end, tokens, scores_el)
238 |         except Exception:
239 | if not self.training:
240 | print("error on _tokens_scores_to_spans")
241 |
242 |             spans = [[[0, 0, [("NIL", 0)]]] for _ in range(len(batch["src_input_ids"]))]
243 |
244 | if return_dict:
245 | return {
246 | "spans": spans,
247 | "start": start,
248 | "end": end,
249 | "scores_ed": scores_ed,
250 | "scores_el": scores_el,
251 | "logits_classifier_start": logits_classifier_start,
252 | "logits_classifier_end": logits_classifier_end,
253 | }
254 | else:
255 | return spans
256 |
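    # Generates entity names with trie-constrained beam search: per-span tries
    # built from the provided candidate sets when `candidates=True`, otherwise
    # the global trie over all known entities (see generate_global_trie).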
257 | def forward_beam_search(self, batch, candidates=False):
258 |
259 | hidden_states = self.encoder(
260 | input_ids=batch["src_input_ids"], attention_mask=batch["src_attention_mask"]
261 | ).last_hidden_state
262 |
263 | (
264 | (start, end, scores_ed),
265 | (
266 | logits_classifier_start,
267 | logits_classifier_end,
268 | ),
269 | ) = self.entity_detection.forward_hard(
270 | batch, hidden_states, threshold=self.hparams.threshold
271 | )
272 |
273 | if start.shape[0] == 0:
274 | return []
275 |
279 | batch["offsets_start"] = start.T.tolist()
280 | batch["offsets_end"] = end.T.tolist()
281 |
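        # A trie maps a tuple of token ids (a prefix) to the token ids allowed
        # to follow it, so the beam can only produce valid entity names.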
282 | batch_trie_dict = None
283 | if candidates:
284 | batch_candidates = [
285 | {
286 | (s, e): c
287 | for (s, e, _), c in zip(
288 | batch["raw"][i]["anchors"], batch["raw"][i]["candidates"]
289 | )
290 | }.get(tuple((s, e)), ["NIL"])
291 | for (i, s), (_, e) in zip(
292 | zip(*batch["offsets_start"]), zip(*batch["offsets_end"])
293 | )
294 | ]
295 |
296 | batch_trie_dict = []
297 | for candidates in batch_candidates:
298 | trie_dict = defaultdict(set)
299 | for c in self.tokenizer(candidates)["input_ids"]:
300 | for i in range(1, len(c)):
301 | trie_dict[tuple(c[:i])].add(c[i])
302 |
303 | batch_trie_dict.append({k: list(v) for k, v in trie_dict.items()})
304 | else:
305 | batch_trie_dict = [self.global_trie] * start.shape[0]
306 |
307 | tokens, scores_el = self.entity_linking.forward_beam_search(
308 | batch,
309 | hidden_states,
310 | batch_trie_dict,
311 | )
312 |
313 | return self._tokens_scores_to_spans(batch, start, end, tokens, scores_el)
314 |
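    # Regroups the flat predictions by batch element into
    # [start, end, [(entity_name, score), ...]] triples.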
315 | def _tokens_scores_to_spans(self, batch, start, end, tokens, scores_el):
316 |
317 | spans = [
318 | [
319 | [
320 | s,
321 | e,
322 | list(
323 | zip(
324 | self.tokenizer.batch_decode(t, skip_special_tokens=True),
325 | l.tolist(),
326 | )
327 | ),
328 | ]
329 | for s, e, t, l in zip(
330 | start[start[:, 0] == i][:, 1].tolist(),
331 | end[end[:, 0] == i][:, 1].tolist(),
332 | tokens[start[:, 0] == i],
333 | scores_el[start[:, 0] == i],
334 | )
335 | ]
336 | for i in range(len(batch["src_input_ids"]))
337 | ]
338 |
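        # Resolve overlaps greedily: for every pair of overlapping spans, keep
        # only the longest one.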
339 | for spans_ in spans:
340 | for e in [
341 | [x, y]
342 | for x in spans_
343 | for y in spans_
344 | if x is not y and x[1] >= y[0] and x[0] <= y[0]
345 | ]:
346 | for x in sorted(e, key=lambda x: x[1] - x[0])[:-1]:
347 |                     if x in spans_: spans_.remove(x)  # guard: overlapping pairs can appear in both orders
348 |
349 | return spans
350 |
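    # The training loss is the sum of the span-start and span-end detection
    # losses and the linking head's generation and classification losses.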
351 | def training_step(self, batch, batch_idx=None):
352 |
353 | hidden_states = self.encoder(
354 | input_ids=batch["src_input_ids"], attention_mask=batch["src_attention_mask"]
355 | ).last_hidden_state
356 |
357 | loss_start, loss_end = self.entity_detection.forward_loss(batch, hidden_states)
358 |
359 | loss_generation, loss_classifier = self.entity_linking.forward_loss(
360 | batch, hidden_states, epsilon=self.hparams.epsilon
361 | )
362 |
363 | self.log("loss_s", loss_start, on_step=True, on_epoch=False, prog_bar=True)
364 | self.log("loss_e", loss_end, on_step=True, on_epoch=False, prog_bar=True)
365 | self.log("loss_g", loss_generation, on_step=True, on_epoch=False, prog_bar=True)
366 | self.log("loss_c", loss_classifier, on_step=True, on_epoch=False, prog_bar=True)
367 |
368 | return {"loss": loss_start + loss_end + loss_generation + loss_classifier}
369 |
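    # Shared by validation and test: run the decoding mode selected via
    # hparams, then update linking metrics on (start, end, top entity) triples
    # and detection metrics on (start, end) pairs alone.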
370 | def _inference_step(self, batch, batch_idx=None):
371 | if self.hparams.test_with_beam_search_no_candidates:
372 | spans = self.forward_beam_search(batch)
373 | elif self.hparams.test_with_beam_search:
374 | spans = self.forward_beam_search(batch, candidates=True)
375 | else:
376 | spans = self.forward_all_targets(batch)
377 |
378 | for p, g in zip(spans, batch["raw"]):
379 |
380 | p_ = set((e[0], e[1], e[2][0][0]) for e in p)
381 | g_ = set((e[0], e[1], e[2]) for e in g["anchors"])
382 |
383 | self.micro_f1(p_, g_)
384 | self.micro_prec(p_, g_)
385 | self.micro_rec(p_, g_)
386 |
387 | self.macro_f1(p_, g_)
388 | self.macro_prec(p_, g_)
389 | self.macro_rec(p_, g_)
390 |
391 | p_ = set((e[0], e[1]) for e in p)
392 | g_ = set((e[0], e[1]) for e in g["anchors"])
393 |
394 | self.ed_micro_f1(p_, g_)
395 | self.ed_micro_prec(p_, g_)
396 | self.ed_micro_rec(p_, g_)
397 |
398 | self.ed_macro_f1(p_, g_)
399 | self.ed_macro_prec(p_, g_)
400 | self.ed_macro_rec(p_, g_)
401 |
402 | return {
403 | "micro_f1": self.micro_f1,
404 | "micro_prec": self.micro_prec,
405 | "macro_rec": self.macro_rec,
406 | "macro_f1": self.macro_f1,
407 | "macro_prec": self.macro_prec,
408 | "micro_rec": self.micro_rec,
409 | "ed_micro_f1": self.ed_micro_f1,
410 | "ed_micro_prec": self.ed_micro_prec,
411 | "ed_micro_rec": self.ed_micro_rec,
412 | "ed_macro_f1": self.ed_macro_f1,
413 | "ed_macro_prec": self.ed_macro_prec,
414 | "ed_macro_rec": self.ed_macro_rec,
415 | }
416 |
417 | def validation_step(self, batch, batch_idx=None):
418 | metrics = self._inference_step(batch, batch_idx)
419 | self.log_dict(
420 | {k: v for k, v in metrics.items() if k in ("micro_f1", "ed_micro_f1")},
421 | prog_bar=True,
422 | )
423 |
424 | def test_step(self, batch, batch_idx=None):
425 | metrics = self._inference_step(batch, batch_idx)
426 | self.log_dict(metrics)
427 |
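    # Builds the global prefix trie over every entity name in
    # `entities_filename`; must be called before beam search without
    # candidate sets.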
428 | def generate_global_trie(self):
429 |
430 | with open(self.hparams.entities_filename) as f:
431 | entities = json.load(f)
432 |
433 | trie_dict = defaultdict(set)
434 | for e in tqdm(entities, desc="Loading .."):
435 | c = self.tokenizer(e)["input_ids"]
436 | for i in range(1, len(c)):
437 | trie_dict[tuple(c[:i])].add(c[i])
438 |
439 | self.global_trie = {k: list(v) for k, v in trie_dict.items()}
440 |
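    # Inference on raw strings: tokenize the sentences, decode with the mode
    # implied by the arguments, and map token-level spans back to character
    # offsets via the tokenizer's offset_mapping.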
441 | def sample(self, sentences, anchors=None, candidates=None, all_targets=False):
442 | self.eval()
443 | with torch.no_grad():
444 | batch = {
445 | f"src_{k}": v.to(self.device)
446 | for k, v in self.tokenizer(
447 | sentences,
448 | return_offsets_mapping=True,
449 | return_tensors="pt",
450 | padding=True,
451 | max_length=self.hparams.max_length,
452 | truncation=True,
453 | ).items()
454 | }
455 |
456 | batch["raw"] = [
457 | {
458 | "input": sentence,
459 | "anchors": anchors[i] if anchors else None,
460 | "candidates": candidates[i] if candidates else None,
461 | }
462 | for i, sentence in enumerate(sentences)
463 | ]
464 |
465 | if anchors and candidates and all_targets:
466 | spans = self.forward_all_targets(batch)
467 | elif candidates and all_targets:
468 | spans = self.forward_beam_search(batch, candidates=True)
469 | else:
470 | spans = self.forward_beam_search(batch)
471 |
472 | return [
473 | [(eo[s][0].item(), eo[e][1].item(), l) for s, e, l in es]
474 | for es, eo in zip(spans, batch["src_offset_mapping"])
475 | ]
476 |
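    # AdamW with linear warmup. Parameter groups give the encoder its own
    # learning rate (lr_transformer), skip embedding parameters (the word
    # embeddings are frozen in __init__), and disable weight decay for biases
    # and LayerNorm weights.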
477 | def configure_optimizers(self):
478 | no_decay = ["bias", "LayerNorm.weight"]
479 | optimizer_grouped_parameters = [
480 | {
481 | "params": [
482 | p
483 | for n, p in self.encoder.named_parameters()
484 |                     if "embedding" not in n and not any(nd in n for nd in no_decay)
485 | ],
486 | "weight_decay": self.hparams.weight_decay,
487 | "lr": self.hparams.lr_transformer,
488 | },
489 | {
490 | "params": [
491 | p
492 | for n, p in self.encoder.named_parameters()
493 |                     if "embedding" not in n and any(nd in n for nd in no_decay)
494 | ],
495 | "weight_decay": 0,
496 | "lr": self.hparams.lr_transformer,
497 | },
498 | {
499 | "params": [
500 | p
501 | for n, p in self.entity_detection.named_parameters()
502 | if not any(nd in n for nd in no_decay)
503 | ],
504 | "weight_decay": self.hparams.weight_decay,
505 | },
506 | {
507 | "params": [
508 | p
509 | for n, p in self.entity_detection.named_parameters()
510 | if any(nd in n for nd in no_decay)
511 | ],
512 | "weight_decay": 0,
513 | },
514 | {
515 | "params": [
516 | p
517 | for n, p in self.entity_linking.named_parameters()
518 | if "embedding" not in n and not any(nd in n for nd in no_decay)
519 | ],
520 | "weight_decay": self.hparams.weight_decay,
521 | },
522 | {
523 | "params": [
524 | p
525 | for n, p in self.entity_linking.named_parameters()
526 |                     if "embedding" not in n and any(nd in n for nd in no_decay)
527 | ],
528 | "weight_decay": 0,
529 | },
530 | ]
531 |
532 | optimizer = torch.optim.AdamW(
533 | optimizer_grouped_parameters,
534 | lr=self.hparams.lr,
535 | weight_decay=self.hparams.weight_decay,
536 | amsgrad=True,
537 | )
538 |
539 | scheduler = get_linear_schedule_with_warmup(
540 | optimizer,
541 | num_warmup_steps=self.hparams.warmup_updates,
542 | num_training_steps=self.hparams.total_num_updates,
543 | )
544 |
545 | return (
546 | [optimizer],
547 | [{"scheduler": scheduler, "interval": "step", "frequency": 1}],
548 | )
549 |
--------------------------------------------------------------------------------