├── wsl
│ ├── common
│ │ ├── __init__.py
│ │ ├── torch_utils.py
│ │ ├── upload.py
│ │ └── log.py
│ ├── inference
│ │ ├── __init__.py
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── window
│ │ │ │ └── __init__.py
│ │ │ ├── splitters
│ │ │ │ ├── __init__.py
│ │ │ │ ├── blank_sentence_splitter.py
│ │ │ │ ├── base_sentence_splitter.py
│ │ │ │ ├── window_based_splitter.py
│ │ │ │ └── spacy_sentence_splitter.py
│ │ │ ├── tokenizers
│ │ │ │ ├── base_tokenizer.py
│ │ │ │ ├── __init__.py
│ │ │ │ └── spacy_tokenizer.py
│ │ │ └── objects.py
│ │ └── utils.py
│ ├── reader
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── wsl_reader_data_utils.py
│ │ │ ├── patches.py
│ │ │ └── wsl_reader_sample.py
│ │ ├── utils
│ │ │ ├── __init__.py
│ │ │ ├── special_symbols.py
│ │ │ ├── metrics.py
│ │ │ ├── strong_matching_eval.py
│ │ │ └── relik_reader_predictor.py
│ │ ├── trainer
│ │ │ ├── __init__.py
│ │ │ └── predict.py
│ │ ├── pytorch_modules
│ │ │ ├── hf
│ │ │ │ ├── __init__.py
│ │ │ │ ├── configuration_wsl.py
│ │ │ │ └── modeling_wsl.py
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ └── span.py
│ │ └── __init__.py
│ ├── retriever
│ │ ├── common
│ │ │ ├── __init__.py
│ │ │ └── model_inputs.py
│ │ ├── data
│ │ │ ├── __init__.py
│ │ │ ├── base
│ │ │ │ ├── __init__.py
│ │ │ │ └── datasets.py
│ │ │ ├── utils.py
│ │ │ └── labels.py
│ │ ├── indexers
│ │ │ ├── __init__.py
│ │ │ ├── document.py
│ │ │ └── inmemory.py
│ │ ├── __init__.py
│ │ └── pytorch_modules
│ │   ├── __init__.py
│ │   └── hf.py
│ ├── __init__.py
│ └── version.py
├── MANIFEST.in
├── constraints.cpu.txt
├── SETUP.cfg
├── assets
│ └── Sapienza_Babelscape.png
├── .flake8
├── pyproject.toml
├── requirements.txt
├── .pre-commit-config.yaml
├── README.md
├── setup.py
├── .gitignore
└── wsl_data_license.txt
/wsl/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/inference/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/reader/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/reader/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/inference/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/reader/trainer/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/retriever/common/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/retriever/data/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/inference/data/window/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/retriever/data/base/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/retriever/indexers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 |
--------------------------------------------------------------------------------
/wsl/inference/data/splitters/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/wsl/reader/pytorch_modules/hf/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_wsl import WSLReaderConfig
2 |
--------------------------------------------------------------------------------
/wsl/retriever/__init__.py:
--------------------------------------------------------------------------------
1 | from wsl.retriever.pytorch_modules.model import WSLRetriever
2 |
--------------------------------------------------------------------------------
/constraints.cpu.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cpu
2 | torch==2.1.0
3 |
--------------------------------------------------------------------------------
/SETUP.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description_file = README.md
3 |
4 | [build]
5 | build-base = /tmp/build
6 |
--------------------------------------------------------------------------------
/assets/Sapienza_Babelscape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Babelscape/WSL/HEAD/assets/Sapienza_Babelscape.png
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203, E266, E501, W503, F403, F401, E402, C901
3 | max-line-length = 88
4 | max-complexity = 18
5 | select = B,C,E,F,W,T4,B9
6 |
--------------------------------------------------------------------------------
/wsl/reader/pytorch_modules/__init__.py:
--------------------------------------------------------------------------------
1 | WSL_READER_CLASS_MAP = {
2 | "WSLReaderSpanModel": "wsl.reader.pytorch_modules.span.WSLReaderForSpanExtraction",
3 | }
4 |
--------------------------------------------------------------------------------
/wsl/reader/__init__.py:
--------------------------------------------------------------------------------
1 | # from wsl.reader.pytorch_modules.base import RelikReaderBase
2 | # from wsl.reader.pytorch_modules.span import RelikReaderForSpanExtraction
3 | # from wsl.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
4 |
--------------------------------------------------------------------------------
/wsl/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | from wsl.inference.annotator import WSL
4 |
5 | VERSION = {} # type: ignore
6 | with open(Path(__file__).parent / "version.py", "r") as version_file:
7 | exec(version_file.read(), VERSION)
8 |
9 | __version__ = VERSION["VERSION"]
10 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | include = '\.pyi?$'
3 | exclude = '''
4 | /(
5 | \.git
6 | | \.hg
7 | | \.mypy_cache
8 | | \.tox
9 | | \.venv
10 | | _build
11 | | buck-out
12 | | build
13 | | dist
14 | )/
15 | '''
16 | [tool.isort]
17 | profile = 'black'
18 | line_length = 120
19 | known_third_party = ["numpy", "pytest", "wandb", "torch"]
20 | known_local_folder = "wsl"
21 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #------- Core dependencies -------
2 | --extra-index-url https://download.pytorch.org/whl/cu121
3 | torch==2.3.1
4 |
5 | transformers[sentencepiece]>=4.41,<4.42
6 | rich>=13.0.0,<14.0.0
7 | scikit-learn>=1.3,<1.4
8 | overrides>=7.4,<7.9
9 | art==6.2
10 | pprintpp==0.4.0
11 | colorama==0.4.6
12 | termcolor==2.4.0
13 | spacy>=3.7,<3.8
14 | typer>=0.12,<0.13
15 | hydra-core
16 |
17 | lightning>=2.0,<2.1
18 |
--------------------------------------------------------------------------------
/wsl/reader/utils/special_symbols.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | NME_SYMBOL = "--NME--"
4 |
5 |
6 | def get_special_symbols(num_entities: int) -> List[str]:
7 | return [NME_SYMBOL] + [f"[E-{i}]" for i in range(num_entities)]
8 |
9 |
10 | def get_special_symbols_re(num_entities: int, use_nme: bool = False) -> List[str]:
11 | if use_nme:
12 | return [NME_SYMBOL] + [f"[R-{i}]" for i in range(num_entities)]
13 | else:
14 | return [f"[R-{i}]" for i in range(num_entities)]
15 |
--------------------------------------------------------------------------------
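A minimal usage sketch for the helpers above (assumes the `wsl` package is installed):

```python
from wsl.reader.utils.special_symbols import get_special_symbols, get_special_symbols_re

print(get_special_symbols(3))
# ['--NME--', '[E-0]', '[E-1]', '[E-2]']
print(get_special_symbols_re(2))
# ['[R-0]', '[R-1]']
```
--------------------------------------------------------------------------------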
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/PyCQA/isort.git
3 | rev: "5.12.0"
4 | hooks:
5 | - id: isort
6 | - repo: https://github.com/ambv/black
7 | rev: '22.3.0'
8 | hooks:
9 | - id: black
10 | - repo: https://github.com/pycqa/flake8
11 | rev: '6.1.0'
12 | hooks:
13 | - id: flake8
14 | - repo: https://github.com/PyCQA/autoflake
15 | rev: v2.3.1
16 | hooks:
17 | - id: autoflake
18 |
19 | default_language_version:
20 | python: python3
21 |
--------------------------------------------------------------------------------
/wsl/version.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | _MAJOR = "1"
4 | _MINOR = "0"
5 | # On main and in a nightly release the patch should be one ahead of the last
6 | # released build.
7 | _PATCH = "0"
8 | # This is mainly for nightly builds which have the suffix ".dev$DATE". See
9 | # https://semver.org/#is-v123-a-semantic-version for the semantics.
10 | _SUFFIX = os.environ.get("WSL_VERSION_SUFFIX", "")
11 |
12 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR)
13 | VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX)
14 |
--------------------------------------------------------------------------------
/wsl/reader/utils/metrics.py:
--------------------------------------------------------------------------------
1 | def safe_divide(num: float, den: float) -> float:
2 | if den == 0:
3 | return 0
4 | else:
5 | return num / den
6 |
7 |
8 | def f1_measure(precision: float, recall: float) -> float:
9 | if precision == 0 or recall == 0:
10 | return 0.0
11 | return safe_divide(2 * precision * recall, (precision + recall))
12 |
13 |
14 | def compute_metrics(total_correct, total_preds, total_gold):
15 | precision = safe_divide(total_correct, total_preds)
16 | recall = safe_divide(total_correct, total_gold)
17 | f1 = f1_measure(precision, recall)
18 | return precision, recall, f1
19 |
--------------------------------------------------------------------------------
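A worked example for the metrics helpers above (assumes the `wsl` package is installed):

```python
from wsl.reader.utils.metrics import compute_metrics

# 8 correct predictions out of 10 predicted spans and 16 gold spans
precision, recall, f1 = compute_metrics(total_correct=8, total_preds=10, total_gold=16)
print(precision, recall, round(f1, 3))  # 0.8 0.5 0.615
```
--------------------------------------------------------------------------------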
/wsl/retriever/pytorch_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 |
5 | from wsl.retriever.indexers.document import Document
6 |
7 | PRECISION_MAP = {
8 | None: torch.float32,
9 | 32: torch.float32,
10 | 16: torch.float16,
11 | torch.float32: torch.float32,
12 | torch.float16: torch.float16,
13 | torch.bfloat16: torch.bfloat16,
14 | "float32": torch.float32,
15 | "float16": torch.float16,
16 | "bfloat16": torch.bfloat16,
17 | "float": torch.float32,
18 | "half": torch.float16,
19 | "32": torch.float32,
20 | "16": torch.float16,
21 | "fp32": torch.float32,
22 | "fp16": torch.float16,
23 | "bf16": torch.bfloat16,
24 | }
25 |
26 |
27 | @dataclass
28 | class RetrievedSample:
29 | """
30 | Dataclass for the output of the GoldenRetriever model.
31 | """
32 |
33 | score: float
34 | document: Document
35 |
--------------------------------------------------------------------------------
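A small sketch of how `PRECISION_MAP` normalizes the many equivalent precision spellings to a torch dtype (assumes the `wsl` package is installed):

```python
import torch

from wsl.retriever.pytorch_modules import PRECISION_MAP

assert PRECISION_MAP["fp16"] is torch.float16
assert PRECISION_MAP[16] is torch.float16
assert PRECISION_MAP[None] is torch.float32  # default when no precision is given
```
--------------------------------------------------------------------------------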
/wsl/inference/data/splitters/blank_sentence_splitter.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 |
4 | class BlankSentenceSplitter:
5 | """
6 |     A `BlankSentenceSplitter` is a no-op splitter: it returns the input text as a single sentence.
7 | """
8 |
9 | def __call__(self, *args, **kwargs):
10 | """
11 | Calls :meth:`split_sentences`.
12 | """
13 | return self.split_sentences(*args, **kwargs)
14 |
15 | def split_sentences(
16 | self, text: str, max_len: int = 0, *args, **kwargs
17 | ) -> List[str]:
18 | """
19 |         Returns the input `text` unchanged, wrapped in a single-element list.
20 | """
21 | return [text]
22 |
23 | def split_sentences_batch(
24 | self, texts: List[str], *args, **kwargs
25 | ) -> List[List[str]]:
26 | """
27 | Default implementation is to just iterate over the texts and call `split_sentences`.
28 | """
29 | return [self.split_sentences(text) for text in texts]
30 |
--------------------------------------------------------------------------------
/wsl/retriever/common/model_inputs.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections import UserDict
4 | from typing import Any, Union
5 |
6 | import torch
7 | from lightning.fabric.utilities import move_data_to_device
8 |
9 | from wsl.common.log import get_logger
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | class ModelInputs(UserDict):
15 | """Model input dictionary wrapper."""
16 |
17 | def __getattr__(self, item: str):
18 | try:
19 | return self.data[item]
20 | except KeyError:
21 | raise AttributeError(f"`ModelInputs` has no attribute `{item}`")
22 |
23 | def __getitem__(self, item: str) -> Any:
24 | return self.data[item]
25 |
26 | def __getstate__(self):
27 | return {"data": self.data}
28 |
29 | def __setstate__(self, state):
30 | if "data" in state:
31 | self.data = state["data"]
32 |
33 | def keys(self):
34 | """A set-like object providing a view on D's keys."""
35 | return self.data.keys()
36 |
37 | def values(self):
38 | """An object providing a view on D's values."""
39 | return self.data.values()
40 |
41 | def items(self):
42 | """A set-like object providing a view on D's items."""
43 | return self.data.items()
44 |
45 | def to(self, device: Union[str, torch.device]) -> ModelInputs:
46 | """
47 |         Move all tensor values to the given device.
48 | Args:
49 | device (`str` or `torch.device`): The device to put the tensors on.
50 | Returns:
51 |             :class:`ModelInputs`: The same instance of :class:`ModelInputs`
52 | after modification.
53 | """
54 | self.data = move_data_to_device(self.data, device)
55 | return self
56 |
--------------------------------------------------------------------------------
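A short sketch of how `ModelInputs` behaves (assumes `wsl` and `lightning` are installed, as per requirements.txt):

```python
import torch

from wsl.retriever.common.model_inputs import ModelInputs

inputs = ModelInputs(
    input_ids=torch.tensor([[101, 2023, 102]]),
    attention_mask=torch.ones(1, 3, dtype=torch.long),
)
print(inputs.input_ids.shape)  # keys are also reachable as attributes
inputs = inputs.to("cpu")      # moves every tensor value to the target device
```
--------------------------------------------------------------------------------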
/wsl/reader/data/wsl_reader_data_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def flatten(lsts: List[list]) -> list:
8 | acc_lst = list()
9 | for lst in lsts:
10 | acc_lst.extend(lst)
11 | return acc_lst
12 |
13 |
14 | def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
15 | return torch.nn.utils.rnn.pad_sequence(
16 | tensors, batch_first=True, padding_value=padding_value
17 | )
18 |
19 |
20 | def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
21 | x = max([t.shape[0] for t in tensors])
22 | y = max([t.shape[1] for t in tensors])
23 | out_matrix = torch.zeros((len(tensors), x, y))
24 | out_matrix += padding_value
25 | for i, tensor in enumerate(tensors):
26 | out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor
27 | return out_matrix
28 |
29 |
30 | def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
31 | x = max([t.shape[0] for t in tensors])
32 | y = max([t.shape[1] for t in tensors])
33 | rest = tensors[0].shape[2]
34 | out_matrix = torch.zeros((len(tensors), x, y, rest))
35 | out_matrix += padding_value
36 | for i, tensor in enumerate(tensors):
37 | out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor
38 | return out_matrix
39 |
40 |
41 | def chunks(lst: list, chunk_size: int) -> List[list]:
42 | chunks_acc = list()
43 | for i in range(0, len(lst), chunk_size):
44 | chunks_acc.append(lst[i : i + chunk_size])
45 | return chunks_acc
46 |
47 |
48 | def add_noise_to_value(value: int, noise_param: float):
49 | noise_value = value * noise_param
50 | noise = np.random.uniform(-noise_value, noise_value)
51 | return max(1, value + noise)
52 |
--------------------------------------------------------------------------------
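A quick demonstration of the padding and chunking helpers above (assumes the `wsl` package is installed):

```python
import torch

from wsl.reader.data.wsl_reader_data_utils import batchify, chunks, flatten

batch = batchify([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], padding_value=0)
print(batch.tolist())  # [[1, 2, 3], [4, 5, 0]]

print(chunks([1, 2, 3, 4, 5], chunk_size=2))  # [[1, 2], [3, 4], [5]]
print(flatten([[1, 2], [3]]))                 # [1, 2, 3]
```
--------------------------------------------------------------------------------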
/wsl/reader/pytorch_modules/hf/configuration_wsl.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from transformers import AutoConfig
4 | from transformers.configuration_utils import PretrainedConfig
5 |
6 |
7 | class WSLReaderConfig(PretrainedConfig):
8 | model_type = "wsl-reader"
9 |
10 | def __init__(
11 | self,
12 | transformer_model: str = "microsoft/deberta-v3-base",
13 | additional_special_symbols: int = 101,
14 | additional_special_symbols_types: Optional[int] = 0,
15 | num_layers: Optional[int] = None,
16 | activation: str = "gelu",
17 | linears_hidden_size: Optional[int] = 512,
18 | use_last_k_layers: int = 1,
19 | entity_type_loss: bool = False,
20 |         add_entity_embedding: Optional[bool] = None,
21 | binary_end_logits: bool = False,
22 | training: bool = False,
23 | default_reader_class: Optional[str] = None,
24 | threshold: Optional[float] = 0.5,
25 | **kwargs
26 | ) -> None:
27 | # TODO: add name_or_path to kwargs
28 | self.transformer_model = transformer_model
29 | self.additional_special_symbols = additional_special_symbols
30 | self.additional_special_symbols_types = additional_special_symbols_types
31 | self.num_layers = num_layers
32 | self.activation = activation
33 | self.linears_hidden_size = linears_hidden_size
34 | self.use_last_k_layers = use_last_k_layers
35 | self.entity_type_loss = entity_type_loss
36 | self.add_entity_embedding = (
37 | True
38 | if add_entity_embedding is None and entity_type_loss
39 | else add_entity_embedding
40 | )
41 | self.threshold = threshold
42 | self.binary_end_logits = binary_end_logits
43 | self.training = training
44 | self.default_reader_class = default_reader_class
45 | super().__init__(**kwargs)
46 |
--------------------------------------------------------------------------------
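Instantiating the config above with its defaults (a minimal sketch; assumes `wsl` and `transformers` are installed):

```python
from wsl.reader.pytorch_modules.hf.configuration_wsl import WSLReaderConfig

config = WSLReaderConfig(additional_special_symbols=101, threshold=0.5)
print(config.transformer_model)  # microsoft/deberta-v3-base
print(config.model_type)         # wsl-reader
```
--------------------------------------------------------------------------------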
/wsl/reader/trainer/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pprint import pprint
3 | from typing import Optional
4 |
5 | from wsl.reader.data.wsl_reader_sample import load_wsl_reader_samples
6 | from wsl.reader.pytorch_modules.span import WSLReaderForSpanExtraction
7 | from wsl.reader.utils.strong_matching_eval import StrongMatching
8 |
9 |
10 | def predict(
11 | model_path: str,
12 | dataset_path: str,
13 | token_batch_size: int,
14 | is_eval: bool,
15 | output_path: Optional[str],
16 | ) -> None:
17 | wsl_reader = WSLReaderForSpanExtraction(
18 | model_path, dataset_kwargs={"use_nme": True}, device="cuda"
19 | )
20 | samples = list(load_wsl_reader_samples(dataset_path))
21 | predicted_samples = wsl_reader.read(
22 | samples=samples, token_batch_size=token_batch_size, progress_bar=True
23 | )
24 | if is_eval:
25 | eval_dict = StrongMatching()(predicted_samples)
26 | pprint(eval_dict)
27 | if output_path is not None:
28 | with open(output_path, "w") as f:
29 | for sample in predicted_samples:
30 | f.write(sample.to_jsons() + "\n")
31 |
32 |
33 | def parse_arg() -> argparse.Namespace:
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "--model-path",
37 | required=True,
38 | )
39 | parser.add_argument("--dataset-path", "-i", required=True)
40 | parser.add_argument("--is-eval", action="store_true")
41 | parser.add_argument(
42 | "--output-path",
43 | "-o",
44 | )
45 | parser.add_argument("--token-batch-size", default=4096)
46 | return parser.parse_args()
47 |
48 |
49 | def main():
50 | args = parse_arg()
51 | predict(
52 | args.model_path,
53 | args.dataset_path,
54 | token_batch_size=args.token_batch_size,
55 | is_eval=args.is_eval,
56 | output_path=args.output_path,
57 | )
58 |
59 |
60 | if __name__ == "__main__":
61 | main()
62 |
--------------------------------------------------------------------------------
/wsl/inference/data/splitters/base_sentence_splitter.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 |
4 | class BaseSentenceSplitter:
5 | """
6 | A `BaseSentenceSplitter` splits strings into sentences.
7 | """
8 |
9 | def __call__(self, *args, **kwargs):
10 | """
11 | Calls :meth:`split_sentences`.
12 | """
13 | return self.split_sentences(*args, **kwargs)
14 |
15 | def split_sentences(
16 | self, text: str, max_len: int = 0, *args, **kwargs
17 | ) -> List[str]:
18 | """
19 | Splits a `text` :class:`str` paragraph into a list of :class:`str`, where each is a sentence.
20 | """
21 | raise NotImplementedError
22 |
23 | def split_sentences_batch(
24 | self, texts: List[str], *args, **kwargs
25 | ) -> List[List[str]]:
26 | """
27 | Default implementation is to just iterate over the texts and call `split_sentences`.
28 | """
29 | return [self.split_sentences(text) for text in texts]
30 |
31 | @staticmethod
32 | def check_is_batched(
33 | texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
34 | ):
35 | """
36 | Check if input is batched or a single sample.
37 |
38 | Args:
39 | texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
40 | Text to check.
41 | is_split_into_words (:obj:`bool`):
42 | If :obj:`True` and the input is a string, the input is split on spaces.
43 |
44 | Returns:
45 | :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
46 | """
47 | return bool(
48 | (not is_split_into_words and isinstance(texts, (list, tuple)))
49 | or (
50 | is_split_into_words
51 | and isinstance(texts, (list, tuple))
52 | and texts
53 | and isinstance(texts[0], (list, tuple))
54 | )
55 | )
56 |
--------------------------------------------------------------------------------
/wsl/reader/data/patches.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from wsl.reader.data.wsl_reader_sample import WSLReaderSample
4 | from wsl.reader.utils.special_symbols import NME_SYMBOL
5 |
6 |
7 | def merge_patches_predictions(sample) -> None:
8 | sample._d["predicted_window_labels"] = dict()
9 | predicted_window_labels = sample._d["predicted_window_labels"]
10 |
11 | sample._d["span_title_probabilities"] = dict()
12 | span_title_probabilities = sample._d["span_title_probabilities"]
13 |
14 | span2title = dict()
15 | for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]):
16 | # selecting span predictions
17 | for predicted_title, predicted_spans in patch_info[
18 | "predicted_window_labels"
19 | ].items():
20 | for pred_span in predicted_spans:
21 | pred_span = tuple(pred_span)
22 | curr_title = span2title.get(pred_span)
23 | if curr_title is None or curr_title == NME_SYMBOL:
24 | span2title[pred_span] = predicted_title
25 | # else:
26 | # print("Merging at patch level")
27 |
28 | # selecting span predictions probability
29 | for predicted_span, titles_probabilities in patch_info[
30 | "span_title_probabilities"
31 | ].items():
32 | if predicted_span not in span_title_probabilities:
33 | span_title_probabilities[predicted_span] = titles_probabilities
34 |
35 | for span, title in span2title.items():
36 | if title not in predicted_window_labels:
37 | predicted_window_labels[title] = list()
38 | predicted_window_labels[title].append(span)
39 |
40 |
41 | def remove_duplicate_samples(
42 | samples: List[WSLReaderSample],
43 | ) -> List[WSLReaderSample]:
44 | seen_sample = set()
45 | samples_store = []
46 | for sample in samples:
47 | sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}"
48 | if sample_id not in seen_sample:
49 | seen_sample.add(sample_id)
50 | samples_store.append(sample)
51 | return samples_store
52 |
--------------------------------------------------------------------------------
/wsl/reader/data/wsl_reader_sample.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Iterable
3 |
4 | import numpy as np
5 |
6 |
7 | class NpEncoder(json.JSONEncoder):
8 | def default(self, obj):
9 | if isinstance(obj, np.integer):
10 | return int(obj)
11 | if isinstance(obj, np.floating):
12 | return float(obj)
13 | if isinstance(obj, np.ndarray):
14 | return obj.tolist()
15 | return super(NpEncoder, self).default(obj)
16 |
17 |
18 | class WSLReaderSample:
19 | def __init__(self, **kwargs):
20 | super().__setattr__("_d", {})
21 | self._d = kwargs
22 |
23 | def __getattribute__(self, item):
24 | return super(WSLReaderSample, self).__getattribute__(item)
25 |
26 | def __getattr__(self, item):
27 | if item.startswith("__") and item.endswith("__"):
28 | # this is likely some python library-specific variable (such as __deepcopy__ for copy)
29 | # better follow standard behavior here
30 | raise AttributeError(item)
31 | elif item in self._d:
32 | return self._d[item]
33 | else:
34 | return None
35 |
36 | def __setattr__(self, key, value):
37 | if key in self._d:
38 | self._d[key] = value
39 | else:
40 | super().__setattr__(key, value)
41 |
42 | def to_jsons(self) -> str:
43 | if "predicted_window_labels" in self._d:
44 | new_obj = {
45 | k: v
46 | for k, v in self._d.items()
47 | if k != "predicted_window_labels" and k != "span_title_probabilities"
48 | }
49 | new_obj["predicted_window_labels"] = [
50 | [ss, se, pred_title]
51 | for (ss, se), pred_title in self.predicted_window_labels_chars
52 | ]
53 | else:
54 | return json.dumps(self._d, cls=NpEncoder)
55 |
56 | def to_dict(self) -> dict:
57 | return self._d
58 |
59 |
60 | def load_wsl_reader_samples(path: str) -> Iterable[WSLReaderSample]:
61 | with open(path) as f:
62 | for line in f:
63 | jsonl_line = json.loads(line.strip())
64 | wsl_reader_sample = WSLReaderSample(**jsonl_line)
65 | yield wsl_reader_sample
66 |
--------------------------------------------------------------------------------
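A sketch of the attribute semantics of `WSLReaderSample` (assumes the `wsl` package is installed):

```python
from wsl.reader.data.wsl_reader_sample import WSLReaderSample

sample = WSLReaderSample(doc_id=3, text="Bus drivers drive busses.")
print(sample.text)     # kwargs are exposed as attributes
print(sample.missing)  # unknown non-dunder attributes resolve to None
sample.text = "updated text"  # existing keys are updated inside `_d`
print(sample.to_jsons())
```
--------------------------------------------------------------------------------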
/wsl/inference/data/splitters/window_based_splitter.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | from wsl.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
4 |
5 |
6 | class WindowSentenceSplitter(BaseSentenceSplitter):
7 | """
8 | A :obj:`WindowSentenceSplitter` that splits a text into windows of a given size.
9 | """
10 |
11 | def __init__(self, window_size: int, window_stride: int, *args, **kwargs) -> None:
12 | super(WindowSentenceSplitter, self).__init__()
13 | self.window_size = window_size
14 | self.window_stride = window_stride
15 |
16 | def __call__(
17 | self,
18 | texts: Union[str, List[str], List[List[str]]],
19 | is_split_into_words: bool = False,
20 | **kwargs,
21 | ) -> Union[List[str], List[List[str]]]:
22 | """
23 |         Split the input text into (possibly overlapping) windows of tokens.
24 |
25 | Args:
26 | texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
27 |                 Text to split. It can be a single string, a batch of strings, or pre-tokenized strings.
28 |
29 |         Returns:
30 |             :obj:`List[List[str]]`: The input text split into token windows.
31 | """
32 | return self.split_sentences(texts)
33 |
34 | def split_sentences(self, text: str | List, *args, **kwargs) -> List[List]:
35 | """
36 |         Splits a `text` into fixed-size windows of tokens.
37 |
38 | Args:
39 | text (:obj:`str`):
40 | Text to split.
41 |
42 | Returns:
43 |             :obj:`List[List]`: The input text split into windows.
44 | """
45 |
46 | if isinstance(text, str):
47 | text = text.split()
48 | sentences = []
49 |         # if window_stride is zero we don't need overlapping windows;
50 |         # a local variable avoids mutating the instance state
51 |         stride = self.window_stride if self.window_stride != 0 else self.window_size
52 |
53 |         for i in range(0, len(text), stride):
54 | # if the last stride is smaller than the window size, then we can
55 |             # include more tokens from the previous window.
56 | if i != 0 and i + self.window_size > len(text):
57 | overflowing_tokens = i + self.window_size - len(text)
58 |                 if overflowing_tokens >= stride:
59 | break
60 | i -= overflowing_tokens
61 | involved_token_indices = list(
62 | range(i, min(i + self.window_size, len(text)))
63 | )
64 | window_tokens = [text[j] for j in involved_token_indices]
65 | sentences.append(window_tokens)
66 | return sentences
67 |
--------------------------------------------------------------------------------
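Overlapping windows produced by the splitter above (a minimal sketch; assumes the `wsl` package is installed):

```python
from wsl.inference.data.splitters.window_based_splitter import WindowSentenceSplitter

splitter = WindowSentenceSplitter(window_size=4, window_stride=2)
for window in splitter.split_sentences("a b c d e f g"):
    print(window)
# ['a', 'b', 'c', 'd']
# ['c', 'd', 'e', 'f']
# ['d', 'e', 'f', 'g']  # last window shifted back to keep a full size
```
--------------------------------------------------------------------------------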
/wsl/retriever/data/base/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
4 |
5 | import torch
6 | from torch.utils.data import Dataset, IterableDataset
7 |
8 | from wsl.common.log import get_logger
9 |
10 | logger = get_logger(__name__)
11 |
12 |
13 | class BaseDataset(Dataset):
14 | def __init__(
15 | self,
16 | name: str,
17 | path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
18 | data: Any = None,
19 | **kwargs,
20 | ):
21 | super().__init__()
22 | self.name = name
23 | if path is None and data is None:
24 | raise ValueError("Either `path` or `data` must be provided")
25 | self.path = path
26 | self.project_folder = Path(__file__).parent.parent.parent
27 | self.data = data
28 |
29 | def __len__(self) -> int:
30 | return len(self.data)
31 |
32 | def __getitem__(
33 | self, index
34 | ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
35 | return self.data[index]
36 |
37 | def __repr__(self) -> str:
38 | return f"Dataset({self.name=}, {self.path=})"
39 |
40 | def load(
41 | self,
42 | paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
43 | *args,
44 | **kwargs,
45 | ) -> Any:
46 | # load data from single or multiple paths in one single dataset
47 | raise NotImplementedError
48 |
49 | @staticmethod
50 | def collate_fn(batch: Any, *args, **kwargs) -> Any:
51 | raise NotImplementedError
52 |
53 |
54 | class IterableBaseDataset(IterableDataset):
55 | def __init__(
56 | self,
57 | name: str,
58 | path: Optional[Union[str, Path, List[str], List[Path]]] = None,
59 | data: Any = None,
60 | *args,
61 | **kwargs,
62 | ):
63 | super().__init__()
64 | self.name = name
65 | if path is None and data is None:
66 | raise ValueError("Either `path` or `data` must be provided")
67 | self.path = path
68 | self.project_folder = Path(__file__).parent.parent.parent
69 | self.data = data
70 |
71 | def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
72 | for sample in self.data:
73 | yield sample
74 |
75 | def __repr__(self) -> str:
76 | return f"Dataset({self.name=}, {self.path=})"
77 |
78 | def load(
79 | self,
80 | paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
81 | *args,
82 | **kwargs,
83 | ) -> Any:
84 | # load data from single or multiple paths in one single dataset
85 | raise NotImplementedError
86 |
87 | @staticmethod
88 | def collate_fn(batch: Any, *args, **kwargs) -> Any:
89 | raise NotImplementedError
90 |
--------------------------------------------------------------------------------
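A minimal concrete subclass of `BaseDataset` with in-memory data (names here are illustrative; assumes the `wsl` package is installed):

```python
import torch

from wsl.retriever.data.base.datasets import BaseDataset


class ToyDataset(BaseDataset):
    @staticmethod
    def collate_fn(batch, *args, **kwargs):
        return torch.stack(batch)


dataset = ToyDataset(name="toy", data=[torch.tensor([1.0]), torch.tensor([2.0])])
print(len(dataset), dataset[0])  # 2 tensor([1.])
```
--------------------------------------------------------------------------------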
/wsl/inference/data/tokenizers/base_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | from wsl.inference.data.objects import Word
4 |
5 |
6 | class BaseTokenizer:
7 | """
8 | A :obj:`Tokenizer` splits strings of text into single words, optionally adds
9 |     POS tags and performs lemmatization.
10 | """
11 |
12 | def __call__(
13 | self,
14 | texts: Union[str, List[str], List[List[str]]],
15 | is_split_into_words: bool = False,
16 | **kwargs
17 | ) -> List[List[Word]]:
18 | """
19 | Tokenize the input into single words.
20 |
21 | Args:
22 | texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
23 | Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
24 | is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
25 | If :obj:`True` and the input is a string, the input is split on spaces.
26 |
27 | Returns:
28 | :obj:`List[List[Word]]`: The input text tokenized in single words.
29 | """
30 | raise NotImplementedError
31 |
32 | def tokenize(self, text: str) -> List[Word]:
33 | """
34 | Implements splitting words into tokens.
35 |
36 | Args:
37 | text (:obj:`str`):
38 | Text to tokenize.
39 |
40 | Returns:
41 | :obj:`List[Word]`: The input text tokenized in single words.
42 |
43 | """
44 | raise NotImplementedError
45 |
46 | def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
47 | """
48 | Implements batch splitting words into tokens.
49 |
50 | Args:
51 | texts (:obj:`List[str]`):
52 | Batch of text to tokenize.
53 |
54 | Returns:
55 | :obj:`List[List[Word]]`: The input batch tokenized in single words.
56 |
57 | """
58 | return [self.tokenize(text) for text in texts]
59 |
60 | @staticmethod
61 | def check_is_batched(
62 | texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
63 | ):
64 | """
65 | Check if input is batched or a single sample.
66 |
67 | Args:
68 | texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
69 | Text to check.
70 | is_split_into_words (:obj:`bool`):
71 | If :obj:`True` and the input is a string, the input is split on spaces.
72 |
73 | Returns:
74 | :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
75 | """
76 | return bool(
77 | (not is_split_into_words and isinstance(texts, (list, tuple)))
78 | or (
79 | is_split_into_words
80 | and isinstance(texts, (list, tuple))
81 | and texts
82 | and isinstance(texts[0], (list, tuple))
83 | )
84 | )
85 |
--------------------------------------------------------------------------------
/wsl/common/torch_utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 |
3 | import torch
4 | import transformers as tr
5 |
6 | from wsl.common.utils import is_package_available
7 |
8 | # check if ORT is available
9 | if is_package_available("onnxruntime"):
10 | from optimum.onnxruntime import (
11 | ORTModel,
12 | ORTModelForCustomTasks,
13 | ORTModelForSequenceClassification,
14 | ORTOptimizer,
15 | )
16 | from optimum.onnxruntime.configuration import AutoOptimizationConfig
17 |
18 | # from wsl.retriever.pytorch_modules import PRECISION_MAP
19 |
20 |
21 | def get_autocast_context(
22 | device: str | torch.device, precision: str
23 | ) -> contextlib.AbstractContextManager:
24 |     # autocast only accepts bare device-type strings like 'cpu' or 'cuda',
25 |     # so we strip any device index from the model device
26 | device_type_for_autocast = str(device).split(":")[0]
27 |
28 | from wsl.retriever.pytorch_modules import PRECISION_MAP
29 |
30 |     # autocast on CPU/MPS only supports bfloat16; use a null context for other precisions there
31 | autocast_manager = (
32 | contextlib.nullcontext()
33 | if device_type_for_autocast in ["cpu", "mps"]
34 | and PRECISION_MAP[precision] != torch.bfloat16
35 | else (
36 | torch.autocast(
37 | device_type=device_type_for_autocast,
38 | dtype=PRECISION_MAP[precision],
39 | )
40 | )
41 | )
42 | return autocast_manager
43 |
44 |
45 | # def load_ort_optimized_hf_model(
46 | # hf_model: tr.PreTrainedModel,
47 | # provider: str = "CPUExecutionProvider",
48 | # ort_model_type: callable = "ORTModelForCustomTasks",
49 | # ) -> ORTModel:
50 | # """
51 | # Load an optimized ONNX Runtime HF model.
52 | #
53 | # Args:
54 | # hf_model (`tr.PreTrainedModel`):
55 | # The HF model to optimize.
56 | # provider (`str`, optional):
57 | # The ONNX Runtime provider to use. Defaults to "CPUExecutionProvider".
58 | #
59 | # Returns:
60 | # `ORTModel`: The optimized HF model.
61 | # """
62 | # if isinstance(hf_model, ORTModel):
63 | # return hf_model
64 | # temp_dir = tempfile.mkdtemp()
65 | # hf_model.save_pretrained(temp_dir)
66 | # ort_model = ort_model_type.from_pretrained(
67 | # temp_dir, export=True, provider=provider, use_io_binding=True
68 | # )
69 | # if is_package_available("onnxruntime"):
70 | # optimizer = ORTOptimizer.from_pretrained(ort_model)
71 | # optimization_config = AutoOptimizationConfig.O4()
72 | # optimizer.optimize(save_dir=temp_dir, optimization_config=optimization_config)
73 | # ort_model = ort_model_type.from_pretrained(
74 | # temp_dir,
75 | # export=True,
76 | # provider=provider,
77 | # use_io_binding=bool(provider == "CUDAExecutionProvider"),
78 | # )
79 | # return ort_model
80 | # else:
81 | # raise ValueError("onnxruntime is not installed. Please install Ray with `pip install wsl[serve]`.")
82 |
--------------------------------------------------------------------------------
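Example use of `get_autocast_context` above: fp16 autocast on CUDA, a null context on CPU (assumes the `wsl` package is installed):

```python
import torch

from wsl.common.torch_utils import get_autocast_context

device = "cuda" if torch.cuda.is_available() else "cpu"
with get_autocast_context(device, "fp16"):
    x = torch.randn(8, 8, device=device)
    y = x @ x  # runs in float16 under CUDA autocast, float32 on CPU
```
--------------------------------------------------------------------------------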
/wsl/retriever/pytorch_modules/hf.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 |
3 | import torch
4 | from transformers import PretrainedConfig
5 | from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
6 | from transformers.models.bert.modeling_bert import BertModel
7 |
8 |
9 | class WSLRetrieverConfig(PretrainedConfig):
10 | model_type = "bert"
11 |
12 | def __init__(
13 | self,
14 | vocab_size=30522,
15 | hidden_size=768,
16 | num_hidden_layers=12,
17 | num_attention_heads=12,
18 | intermediate_size=3072,
19 | hidden_act="gelu",
20 | hidden_dropout_prob=0.1,
21 | attention_probs_dropout_prob=0.1,
22 | max_position_embeddings=512,
23 | type_vocab_size=2,
24 | initializer_range=0.02,
25 | layer_norm_eps=1e-12,
26 | pad_token_id=0,
27 | position_embedding_type="absolute",
28 | use_cache=True,
29 | classifier_dropout=None,
30 | **kwargs,
31 | ):
32 | super().__init__(pad_token_id=pad_token_id, **kwargs)
33 |
34 | self.vocab_size = vocab_size
35 | self.hidden_size = hidden_size
36 | self.num_hidden_layers = num_hidden_layers
37 | self.num_attention_heads = num_attention_heads
38 | self.hidden_act = hidden_act
39 | self.intermediate_size = intermediate_size
40 | self.hidden_dropout_prob = hidden_dropout_prob
41 | self.attention_probs_dropout_prob = attention_probs_dropout_prob
42 | self.max_position_embeddings = max_position_embeddings
43 | self.type_vocab_size = type_vocab_size
44 | self.initializer_range = initializer_range
45 | self.layer_norm_eps = layer_norm_eps
46 | self.position_embedding_type = position_embedding_type
47 | self.use_cache = use_cache
48 | self.classifier_dropout = classifier_dropout
49 |
50 |
51 | class WSLRetrieverModel(BertModel):
52 | config_class = WSLRetrieverConfig
53 |
54 | def __init__(self, config, *args, **kwargs):
55 | super().__init__(config)
56 |
57 | def forward(
58 | self, **kwargs
59 | ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
60 | attention_mask = kwargs.get("attention_mask", None)
61 | model_outputs = super().forward(**kwargs)
62 | if attention_mask is None:
63 | pooler_output = model_outputs.pooler_output
64 | else:
65 | token_embeddings = model_outputs.last_hidden_state
66 | input_mask_expanded = (
67 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
68 | )
69 | pooler_output = torch.sum(
70 | token_embeddings * input_mask_expanded, 1
71 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
72 |
73 | if not kwargs.get("return_dict", True):
74 | return (model_outputs[0], pooler_output) + model_outputs[2:]
75 |
76 | return BaseModelOutputWithPoolingAndCrossAttentions(
77 | last_hidden_state=model_outputs.last_hidden_state,
78 | pooler_output=pooler_output,
79 | past_key_values=model_outputs.past_key_values,
80 | hidden_states=model_outputs.hidden_states,
81 | attentions=model_outputs.attentions,
82 | cross_attentions=model_outputs.cross_attentions,
83 | )
84 |
--------------------------------------------------------------------------------
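The masked mean pooling performed in `WSLRetrieverModel.forward`, reproduced on dummy tensors (a self-contained sketch):

```python
import torch

token_embeddings = torch.randn(1, 4, 8)        # (batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
print(pooled.shape)  # torch.Size([1, 8]): one embedding per input sequence
```
--------------------------------------------------------------------------------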
/wsl/inference/data/tokenizers/__init__.py:
--------------------------------------------------------------------------------
1 | SPACY_LANGUAGE_MAPPER = {
2 | "ca": "ca_core_news_sm",
3 | "da": "da_core_news_sm",
4 | "de": "de_core_news_sm",
5 | "el": "el_core_news_sm",
6 | "en": "en_core_web_sm",
7 | "es": "es_core_news_sm",
8 | "fr": "fr_core_news_sm",
9 | "it": "it_core_news_sm",
10 | "ja": "ja_core_news_sm",
11 | "lt": "lt_core_news_sm",
12 | "mk": "mk_core_news_sm",
13 | "nb": "nb_core_news_sm",
14 | "nl": "nl_core_news_sm",
15 | "pl": "pl_core_news_sm",
16 | "pt": "pt_core_news_sm",
17 | "ro": "ro_core_news_sm",
18 | "ru": "ru_core_news_sm",
19 | "xx": "xx_sent_ud_sm",
20 | "zh": "zh_core_web_sm",
21 | "ca_core_news_sm": "ca_core_news_sm",
22 | "ca_core_news_md": "ca_core_news_md",
23 | "ca_core_news_lg": "ca_core_news_lg",
24 | "ca_core_news_trf": "ca_core_news_trf",
25 | "da_core_news_sm": "da_core_news_sm",
26 | "da_core_news_md": "da_core_news_md",
27 | "da_core_news_lg": "da_core_news_lg",
28 | "da_core_news_trf": "da_core_news_trf",
29 | "de_core_news_sm": "de_core_news_sm",
30 | "de_core_news_md": "de_core_news_md",
31 | "de_core_news_lg": "de_core_news_lg",
32 | "de_dep_news_trf": "de_dep_news_trf",
33 | "el_core_news_sm": "el_core_news_sm",
34 | "el_core_news_md": "el_core_news_md",
35 | "el_core_news_lg": "el_core_news_lg",
36 | "en_core_web_sm": "en_core_web_sm",
37 | "en_core_web_md": "en_core_web_md",
38 | "en_core_web_lg": "en_core_web_lg",
39 | "en_core_web_trf": "en_core_web_trf",
40 | "es_core_news_sm": "es_core_news_sm",
41 | "es_core_news_md": "es_core_news_md",
42 | "es_core_news_lg": "es_core_news_lg",
43 | "es_dep_news_trf": "es_dep_news_trf",
44 | "fr_core_news_sm": "fr_core_news_sm",
45 | "fr_core_news_md": "fr_core_news_md",
46 | "fr_core_news_lg": "fr_core_news_lg",
47 | "fr_dep_news_trf": "fr_dep_news_trf",
48 | "it_core_news_sm": "it_core_news_sm",
49 | "it_core_news_md": "it_core_news_md",
50 | "it_core_news_lg": "it_core_news_lg",
51 | "ja_core_news_sm": "ja_core_news_sm",
52 | "ja_core_news_md": "ja_core_news_md",
53 | "ja_core_news_lg": "ja_core_news_lg",
54 | "ja_dep_news_trf": "ja_dep_news_trf",
55 | "lt_core_news_sm": "lt_core_news_sm",
56 | "lt_core_news_md": "lt_core_news_md",
57 | "lt_core_news_lg": "lt_core_news_lg",
58 | "mk_core_news_sm": "mk_core_news_sm",
59 | "mk_core_news_md": "mk_core_news_md",
60 | "mk_core_news_lg": "mk_core_news_lg",
61 | "nb_core_news_sm": "nb_core_news_sm",
62 | "nb_core_news_md": "nb_core_news_md",
63 | "nb_core_news_lg": "nb_core_news_lg",
64 | "nl_core_news_sm": "nl_core_news_sm",
65 | "nl_core_news_md": "nl_core_news_md",
66 | "nl_core_news_lg": "nl_core_news_lg",
67 | "pl_core_news_sm": "pl_core_news_sm",
68 | "pl_core_news_md": "pl_core_news_md",
69 | "pl_core_news_lg": "pl_core_news_lg",
70 | "pt_core_news_sm": "pt_core_news_sm",
71 | "pt_core_news_md": "pt_core_news_md",
72 | "pt_core_news_lg": "pt_core_news_lg",
73 | "ro_core_news_sm": "ro_core_news_sm",
74 | "ro_core_news_md": "ro_core_news_md",
75 | "ro_core_news_lg": "ro_core_news_lg",
76 | "ru_core_news_sm": "ru_core_news_sm",
77 | "ru_core_news_md": "ru_core_news_md",
78 | "ru_core_news_lg": "ru_core_news_lg",
79 | "xx_ent_wiki_sm": "xx_ent_wiki_sm",
80 | "xx_sent_ud_sm": "xx_sent_ud_sm",
81 | "zh_core_web_sm": "zh_core_web_sm",
82 | "zh_core_web_md": "zh_core_web_md",
83 | "zh_core_web_lg": "zh_core_web_lg",
84 | "zh_core_web_trf": "zh_core_web_trf",
85 | }
86 |
87 | from wsl.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
88 |
--------------------------------------------------------------------------------
/wsl/inference/data/objects.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from enum import Enum
5 | from typing import Dict, List, NamedTuple, Optional
6 |
7 | from wsl.reader.pytorch_modules.hf.modeling_wsl import WSLReaderSample
8 | from wsl.retriever.indexers.document import Document
9 |
10 |
11 | @dataclass
12 | class Word:
13 | """
14 | A word representation that includes text, index in the sentence, POS tag, lemma,
15 | dependency relation, and similar information.
16 |
17 |     # Parameters
18 |     text : `str`
19 |         The text representation.
20 |     i : `int`
21 |         The word offset in the sentence.
22 |     idx : `int`, optional
23 |         The start character offset of the word in the original text.
24 |     idx_end : `int`, optional
25 |         The end character offset of the word in the original text.
26 |
27 |     lemma : `str`, optional
28 |         The lemma of this word.
29 |     pos : `str`, optional
30 |         The coarse-grained part of speech of this word.
31 |     dep : `str`, optional
32 |         The dependency relation for this word.
33 |     head : `int`, optional
34 |         The index of the syntactic head of this word.
35 |
36 |     """
37 |
38 | text: str
39 | i: int
40 | idx: Optional[int] = None
41 | idx_end: Optional[int] = None
42 | # preprocessing fields
43 | lemma: Optional[str] = None
44 | pos: Optional[str] = None
45 | dep: Optional[str] = None
46 | head: Optional[int] = None
47 |
48 | def __str__(self):
49 | return self.text
50 |
51 | def __repr__(self):
52 | return self.__str__()
53 |
54 |
55 | class Span(NamedTuple):
56 | start: int
57 | end: int
58 | label: str
59 | text: str
60 |
61 |
62 | class Candidates(NamedTuple):
63 |     candidates: List[List[Document]]
64 |
65 |
66 | @dataclass
67 | class WSLOutput:
68 | """
69 |     Represents the output of the WSL model.
70 |
71 | Attributes:
72 | text (str):
73 | The original input text.
74 | tokens (List[str]):
75 | The list of tokens generated from the input text.
76 | spans (List[Span]):
77 |                 The list of spans generated for the input text.
78 | candidates (Candidates):
79 |                 The span candidates generated by the retriever.
80 |                 The documents are stored in a list of lists: the outer list
81 |                 represents the windows, and the inner list represents the documents in that window.
82 |                 If only one window is used, the outer list will have only one element.
83 | windows (Optional[List[WSLReaderSample]]):
84 | The list of windows used for processing the input text.
85 | """
86 |
87 | text: str
88 | tokens: List[str]
89 | id: str | int
90 | spans: List[Span]
91 |     candidates: Optional[Candidates] = None
92 | windows: Optional[List[WSLReaderSample]] = None
93 |
94 | # convert to dict
95 | def to_dict(self):
96 | self_dict = {
97 | "text": self.text,
98 | "tokens": self.tokens,
99 | "spans": self.spans,
100 | "candidates": {
101 | "span": [
102 | [[doc.to_dict() for doc in documents] for documents in window]
103 | for window in self.candidates.candidates
104 | ],
105 | "triplet": [
106 | [[doc.to_dict() for doc in documents] for documents in window]
107 | for window in self.candidates.triplet
108 | ],
109 | },
110 | }
111 | if self.windows is not None:
112 | self_dict["windows"] = [window.to_dict() for window in self.windows]
113 | return self_dict
114 |
115 |
116 | class AnnotationType(Enum):
117 | CHAR = "char"
118 | WORD = "word"
119 |
120 |
121 | class TaskType(Enum):
122 | SPAN = "span"
123 | BOTH = "both"
124 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Word Sense Linking: Disambiguating Outside the Sandbox
6 |
7 |
8 | [](https://2024.aclweb.org/)
9 | [](https://aclanthology.org/2024.findings-acl.851/)
10 | [](https://huggingface.co/collections/Babelscape/word-sense-linking-66ace2182bc45680964cefcb)
11 |
12 | 
13 |
14 |
15 |
16 |
17 |
18 |
19 | With this work we introduce a new task: **Word Sense Linking (WSL)**. WSL enhances Word Sense Disambiguation by carrying out both candidate identification (new!) and candidate disambiguation. Our Word Sense Linking model is designed to identify and disambiguate spans of text to their most suitable senses from a reference inventory.
20 |
21 | ## Installation
22 |
23 | Installation from source:
24 |
25 | ```bash
26 | git clone https://github.com/Babelscape/WSL
27 | cd WSL
28 | pip install .
29 | ```
30 |
31 |
32 | ## Usage
33 |
34 | WSL is composed of two main components: a retriever and a reader.
35 | The retriever is responsible for retrieving relevant senses from a sense inventory (e.g., WordNet),
36 | while the reader is responsible for extracting spans from the input text and linking them to the retrieved senses.
37 | WSL can be used with the `from_pretrained` method to load a pre-trained pipeline.
38 |
39 | ```python
40 | from wsl import WSL
41 | from wsl.inference.data.objects import WSLOutput
42 |
43 | wsl_model = WSL.from_pretrained("Babelscape/wsl-base")
44 | wsl_output = wsl_model("Bus drivers drive busses for a living.")
45 | ```
46 | ```text
47 | WSLOutput(
48 | text='Bus drivers drive busses for a living.',
49 | tokens=['Bus', 'drivers', 'drive', 'busses', 'for', 'a', 'living', '.'],
50 | id=0,
51 | spans=[
52 | Span(start=0, end=11, label='bus driver: someone who drives a bus', text='Bus drivers'),
53 | Span(start=12, end=17, label='drive: operate or control a vehicle', text='drive'),
54 | Span(start=18, end=24, label='bus: a vehicle carrying many passengers; used for public transport', text='busses'),
55 | Span(start=31, end=37, label='living: the financial means whereby one lives', text='living')
56 | ],
57 | candidates=Candidates(
58 | candidates=[
59 | {"text": "bus driver: someone who drives a bus", "id": "bus_driver%1:18:00::", "metadata": {}},
60 | {"text": "driver: the operator of a motor vehicle", "id": "driver%1:18:00::", "metadata": {}},
61 | {"text": "driver: someone who drives animals that pull a vehicle", "id": "driver%1:18:02::", "metadata": {}},
62 | {"text": "bus: a vehicle carrying many passengers; used for public transport", "id": "bus%1:06:00::", "metadata": {}},
63 | {"text": "living: the financial means whereby one lives", "id": "living%1:26:00::", "metadata": {}}
64 | ]
65 | ),
66 | )
67 | ```
68 |
69 |
70 | ## Model Performance
71 |
72 | Here you can find the performance of our model on the [WSL evaluation dataset](https://huggingface.co/datasets/Babelscape/wsl).
73 |
74 | ### Validation (SE07)
75 |
76 | | Models | P | R | F1 |
77 | |--------------|------|--------|--------|
78 | | BEM_SUP | 67.6 | 40.9 | 51.0 |
79 | | BEM_HEU | 70.8 | 51.2 | 59.4 |
80 | | ConSeC_SUP | 76.4 | 46.5 | 57.8 |
81 | | ConSeC_HEU | **76.7** | 55.4 | 64.3 |
82 | | **Our Model**| 73.8 | **74.9** | **74.4** |
83 |
84 | ### Test (ALL_FULL)
85 |
86 | | Models | P | R | F1 |
87 | |--------------|------|--------|--------|
88 | | BEM_SUP | 74.8 | 50.7 | 60.4 |
89 | | BEM_HEU | 76.6 | 61.2 | 68.0 |
90 | | ConSeC_SUP | 78.9 | 53.1 | 63.5 |
91 | | ConSeC_HEU | **80.4** | 64.3 | 71.5 |
92 | | **Our Model**| 75.2 | **76.7** | **75.9** |
93 |
94 |
95 | ## Cite this work
96 |
97 | If you use any part of this work, please consider citing the paper as follows:
98 |
99 | ```bibtex
100 | @inproceedings{bejgu-etal-2024-wsl,
101 | title = "Word Sense Linking: Disambiguating Outside the Sandbox",
102 | author = "Bejgu, Andrei Stefan and Barba, Edoardo and Procopio, Luigi and Fern{\'a}ndez-Castro, Alberte and Navigli, Roberto",
103 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2024",
104 | month = aug,
105 | year = "2024",
106 | address = "Bangkok, Thailand",
107 | publisher = "Association for Computational Linguistics",
108 | }
109 | ```
110 |
111 | ## License
112 |
113 | The data and software are licensed under CC BY-NC-SA 4.0; you can read the full license here: [Creative Commons Attribution-NonCommercial-ShareAlike 4.0](./wsl_data_license.txt).
114 |
115 |
--------------------------------------------------------------------------------
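A small follow-up to the README usage example above: each `Span` carries character offsets into the original text, so the annotations can be consumed directly (a minimal sketch, continuing that snippet):

```python
# `wsl_output` is the object returned by `wsl_model(...)` in the README example
for span in wsl_output.spans:
    print(f"{span.text!r} [{span.start}:{span.end}] -> {span.label}")
```
--------------------------------------------------------------------------------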
/wsl/common/upload.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import tempfile
6 | import zipfile
7 | from datetime import datetime
8 | from pathlib import Path
9 | from typing import Optional, Union
10 |
11 | import huggingface_hub
12 |
13 | from wsl.common.log import get_logger
14 | from wsl.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5
15 |
16 | logger = get_logger(__name__, level=logging.DEBUG)
17 |
18 |
19 | def create_info_file(tmpdir: Path):
20 | logger.debug("Computing md5 of model.zip")
21 | md5 = get_md5(tmpdir / "model.zip")
22 | date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT)
23 |
24 | logger.debug("Dumping info.json file")
25 | with (tmpdir / "info.json").open("w") as f:
26 | json.dump(dict(md5=md5, upload_date=date), f, indent=2)
27 |
28 |
29 | def zip_run(
30 | dir_path: Union[str, os.PathLike],
31 | tmpdir: Union[str, os.PathLike],
32 | zip_name: str = "model.zip",
33 | ) -> Path:
34 | logger.debug(f"zipping {dir_path} to {tmpdir}")
35 | # creates a zip version of the provided dir_path
36 | run_dir = Path(dir_path)
37 | zip_path = tmpdir / zip_name
38 |
39 | with zipfile.ZipFile(zip_path, "w") as zip_file:
40 | # fully zip the run directory maintaining its structure
41 |         for file in run_dir.rglob("*"):
42 | if file.is_dir():
43 | continue
44 |
45 | zip_file.write(file, arcname=file.relative_to(run_dir))
46 |
47 | return zip_path
48 |
49 |
50 | def get_logged_in_username():
51 | token = huggingface_hub.HfFolder.get_token()
52 | if token is None:
53 | raise ValueError(
54 | "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
55 | )
56 | api = huggingface_hub.HfApi()
57 | user = api.whoami(token=token)
58 | return user["name"]
59 |
60 |
61 | def upload(
62 | model_dir: Union[str, os.PathLike],
63 | model_name: str,
64 | filenames: Optional[list[str]] = None,
65 | organization: Optional[str] = None,
66 | repo_name: Optional[str] = None,
67 | commit: Optional[str] = None,
68 | archive: bool = False,
69 | ):
70 | token = huggingface_hub.HfFolder.get_token()
71 | if token is None:
72 | raise ValueError(
73 | "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
74 | )
75 |
76 | repo_id = repo_name or model_name
77 | if organization is not None:
78 | repo_id = f"{organization}/{repo_id}"
79 | with tempfile.TemporaryDirectory() as tmpdir:
80 | api = huggingface_hub.HfApi()
81 | repo_url = api.create_repo(
82 | token=token,
83 | repo_id=repo_id,
84 | exist_ok=True,
85 | )
86 | repo = huggingface_hub.Repository(
87 | str(tmpdir), clone_from=repo_url, use_auth_token=token
88 | )
89 |
90 | tmp_path = Path(tmpdir)
91 | if archive:
92 |             # zip the model_dir and add an info file with its md5
93 | logger.debug(f"Zipping {model_dir} to {tmp_path}")
94 | zip_run(model_dir, tmp_path)
95 | create_info_file(tmp_path)
96 | else:
97 | # if the user wants to upload a transformers model, we don't need to zip it
98 | # we just need to copy the files to the tmpdir
99 | logger.debug(f"Copying {model_dir} to {tmpdir}")
100 | # copy only the files that are needed
101 | if filenames is not None:
102 | for filename in filenames:
103 | os.system(f"cp {model_dir}/{filename} {tmpdir}")
104 | else:
105 | os.system(f"cp -r {model_dir}/* {tmpdir}")
106 |
107 | # this method automatically puts large files (>10MB) into git lfs
108 | repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")
109 |
110 |
111 | def parse_args() -> argparse.Namespace:
112 | parser = argparse.ArgumentParser()
113 | parser.add_argument(
114 | "model_dir", help="The directory of the model you want to upload"
115 | )
116 | parser.add_argument("model_name", help="The model you want to upload")
117 | parser.add_argument(
118 | "--organization",
119 | help="the name of the organization where you want to upload the model",
120 | )
121 | parser.add_argument(
122 | "--repo_name",
123 | help="Optional name to use when uploading to the HuggingFace repository",
124 | )
125 | parser.add_argument(
126 | "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
127 | )
128 | parser.add_argument(
129 | "--archive",
130 | action="store_true",
131 | help="""
132 | Whether to compress the model directory before uploading it.
133 | If True, the model directory will be zipped and the zip file will be uploaded.
134 | If False, the model directory will be uploaded as is.""",
135 | )
136 | return parser.parse_args()
137 |
138 |
139 | def main():
140 | upload(**vars(parse_args()))
141 |
142 |
143 | if __name__ == "__main__":
144 | main()
145 |
--------------------------------------------------------------------------------
/wsl/common/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import threading
5 | from logging.config import dictConfig
6 | from typing import Any, Dict, Optional
7 |
8 | from art import text2art, tprint
9 | from colorama import Fore, Style, init
10 | from rich import get_console
11 | from termcolor import colored, cprint
12 |
13 | _lock = threading.Lock()
14 | _default_handler: Optional[logging.Handler] = None
15 |
16 | _default_log_level = logging.WARNING
17 |
18 | # fancy logger
19 | _console = get_console()
20 |
21 |
22 | class ColorfulFormatter(logging.Formatter):
23 | """
24 | Formatter to add coloring to log messages by log type
25 | """
26 |
27 | COLORS = {
28 | "WARNING": Fore.YELLOW,
29 | "ERROR": Fore.RED,
30 | "CRITICAL": Fore.RED + Style.BRIGHT,
31 | "DEBUG": Fore.CYAN,
32 | # "INFO": Fore.GREEN,
33 | }
34 |
35 | def format(self, record):
36 | record.rank = int(os.getenv("LOCAL_RANK", "0"))
37 | log_message = super().format(record)
38 | return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
39 |
40 |
41 | DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
42 | "version": 1,
43 | "formatters": {
44 | "simple": {
45 | "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
46 | },
47 | "colorful": {
48 | "()": ColorfulFormatter,
49 | "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
50 | },
51 | },
52 | "filters": {},
53 | "handlers": {
54 | "console": {
55 | "class": "logging.StreamHandler",
56 | "formatter": "simple",
57 | "filters": [],
58 | "stream": sys.stdout,
59 | },
60 | "color_console": {
61 | "class": "logging.StreamHandler",
62 | "formatter": "colorful",
63 | "filters": [],
64 | "stream": sys.stdout,
65 | },
66 | },
67 | "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
68 | "loggers": {
69 | "wsl": {
70 | "handlers": ["color_console"],
71 | "level": "DEBUG",
72 | "propagate": False,
73 | },
74 | },
75 | }
76 |
77 |
78 | def configure_logging(**kwargs):
79 | """Configure with default logging"""
80 | init() # Initialize colorama
81 | # merge DEFAULT_LOGGING_CONFIG with kwargs
82 |     logger_config = dict(DEFAULT_LOGGING_CONFIG)  # copy so the default is not mutated
83 | if kwargs:
84 | logger_config.update(kwargs)
85 | dictConfig(logger_config)
86 |
87 |
88 | def _get_library_name() -> str:
89 | return __name__.split(".")[0]
90 |
91 |
92 | def _get_library_root_logger() -> logging.Logger:
93 | return logging.getLogger(_get_library_name())
94 |
95 |
96 | def _configure_library_root_logger() -> None:
97 | global _default_handler
98 |
99 | with _lock:
100 | if _default_handler:
101 | # This library has already configured the library root logger.
102 | return
103 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream.
104 | _default_handler.flush = sys.stderr.flush
105 |
106 | # Apply our default configuration to the library root logger.
107 | library_root_logger = _get_library_root_logger()
108 | library_root_logger.addHandler(_default_handler)
109 | library_root_logger.setLevel(_default_log_level)
110 | library_root_logger.propagate = False
111 |
112 |
113 | def _reset_library_root_logger() -> None:
114 | global _default_handler
115 |
116 | with _lock:
117 | if not _default_handler:
118 | return
119 |
120 | library_root_logger = _get_library_root_logger()
121 | library_root_logger.removeHandler(_default_handler)
122 | library_root_logger.setLevel(logging.NOTSET)
123 | _default_handler = None
124 |
125 |
126 | def set_log_level(level: int, logger: Optional[logging.Logger] = None) -> None:
127 | """
128 | Set the log level.
129 | Args:
130 | level (:obj:`int`):
131 | Logging level.
132 | logger (:obj:`logging.Logger`):
133 | Logger to set the log level.
134 | """
135 | if not logger:
136 | _configure_library_root_logger()
137 | logger = _get_library_root_logger()
138 | logger.setLevel(level)
139 |
140 |
141 | def get_logger(
142 | name: Optional[str] = None,
143 | level: Optional[int] = None,
144 | formatter: Optional[str] = None,
145 | **kwargs,
146 | ) -> logging.Logger:
147 | """
148 | Return a logger with the specified name.
149 | """
150 |
151 | configure_logging(**kwargs)
152 |
153 | if name is None:
154 | name = _get_library_name()
155 |
156 | _configure_library_root_logger()
157 |
158 | if level is not None:
159 | set_log_level(level)
160 |
161 |     if formatter is None:
162 |         formatter = logging.Formatter(
163 |             "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
164 |         )
165 |     elif isinstance(formatter, str):
166 |         # a format string may be passed directly
167 |         formatter = logging.Formatter(formatter)
168 |     _default_handler.setFormatter(formatter)
169 | 
170 |     return logging.getLogger(name)
171 | 
172 | 
173 | def get_console_logger():
174 |     return _console
175 | 
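176 | # A minimal usage sketch:
177 | #
178 | #     import logging
179 | #     from wsl.common.log import get_logger
180 | #
181 | #     logger = get_logger(__name__, level=logging.INFO)
182 | #     logger.info("ready")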
--------------------------------------------------------------------------------
/wsl/reader/utils/strong_matching_eval.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 | from lightning.pytorch.callbacks import Callback
4 |
5 | from wsl.reader.data.wsl_reader_sample import WSLReaderSample
6 | from wsl.reader.utils.metrics import f1_measure, safe_divide
7 | from wsl.reader.utils.relik_reader_predictor import WSLReaderPredictor
8 | from wsl.reader.utils.special_symbols import NME_SYMBOL
9 |
10 |
11 | class StrongMatching:
12 | def __call__(self, predicted_samples: List[WSLReaderSample]) -> Dict:
13 | # accumulators
14 | correct_predictions = 0
15 | correct_predictions_at_k = 0
16 | total_predictions = 0
17 | total_gold = 0
18 | correct_span_predictions = 0
19 | miss_due_to_candidates = 0
20 |
21 | # prediction index stats
22 | avg_correct_predicted_index = []
23 | avg_wrong_predicted_index = []
24 | less_index_predictions = []
25 |
26 | # collect data from samples
27 | for sample in predicted_samples:
28 | predicted_annotations = sample.predicted_window_labels_chars
29 | predicted_annotations_probabilities = sample.probs_window_labels_chars
30 | gold_annotations = {
31 | (ss, se, entity)
32 | for ss, se, entity in sample.window_labels
33 | if entity != NME_SYMBOL
34 | }
35 | total_predictions += len(predicted_annotations)
36 | total_gold += len(gold_annotations)
37 |
38 | # correct named entity detection
39 | predicted_spans = {(s, e) for s, e, _ in predicted_annotations}
40 | gold_spans = {(s, e) for s, e, _ in gold_annotations}
41 | correct_span_predictions += len(predicted_spans.intersection(gold_spans))
42 |
43 | # correct entity linking
44 | correct_predictions += len(
45 | predicted_annotations.intersection(gold_annotations)
46 | )
47 |
48 | for ss, se, ge in gold_annotations.difference(predicted_annotations):
49 | if ge not in sample.span_candidates:
50 | miss_due_to_candidates += 1
51 | if ge in predicted_annotations_probabilities.get((ss, se), set()):
52 | correct_predictions_at_k += 1
53 |
54 | # indices metrics
55 | predicted_spans_index = {
56 | (ss, se): ent for ss, se, ent in predicted_annotations
57 | }
58 | gold_spans_index = {(ss, se): ent for ss, se, ent in gold_annotations}
59 |
60 | for pred_span, pred_ent in predicted_spans_index.items():
61 | gold_ent = gold_spans_index.get(pred_span)
62 |
63 | if pred_span not in gold_spans_index:
64 | continue
65 |
66 | # missing candidate
67 | if gold_ent not in sample.span_candidates:
68 | continue
69 |
70 |                 # `list.index` raises ValueError rather than returning None;
71 |                 # the membership check above guarantees `gold_ent` is present
72 |                 gold_idx = sample.span_candidates.index(gold_ent)
73 |                 pred_idx = sample.span_candidates.index(pred_ent)
74 | 
75 |                 if gold_ent != pred_ent:
76 |                     avg_wrong_predicted_index.append(pred_idx)
77 |                     # did the model favour a smaller candidate index?
78 |                     if pred_idx > gold_idx:
79 |                         less_index_predictions.append(0)
80 |                     else:
81 |                         less_index_predictions.append(1)
82 | 
83 |                 else:
84 |                     avg_correct_predicted_index.append(pred_idx)
85 | 
86 |
87 | # compute NED metrics
88 | span_precision = safe_divide(correct_span_predictions, total_predictions)
89 | span_recall = safe_divide(correct_span_predictions, total_gold)
90 | span_f1 = f1_measure(span_precision, span_recall)
91 |
92 | # compute EL metrics
93 | precision = safe_divide(correct_predictions, total_predictions)
94 | recall = safe_divide(correct_predictions, total_gold)
95 | recall_at_k = safe_divide(
96 | (correct_predictions + correct_predictions_at_k), total_gold
97 | )
98 |
99 | f1 = f1_measure(precision, recall)
100 |
101 | wrong_for_candidates = safe_divide(miss_due_to_candidates, total_gold)
102 |
103 | out_dict = {
104 | "span_precision": span_precision,
105 | "span_recall": span_recall,
106 | "span_f1": span_f1,
107 | "core_precision": precision,
108 | "core_recall": recall,
109 | "core_recall-at-k": recall_at_k,
110 |             "core_f1": f1,
111 | "wrong-for-candidates": wrong_for_candidates,
112 | "index_errors_avg-index": safe_divide(
113 | sum(avg_wrong_predicted_index), len(avg_wrong_predicted_index)
114 | ),
115 | "index_correct_avg-index": safe_divide(
116 | sum(avg_correct_predicted_index), len(avg_correct_predicted_index)
117 | ),
118 | "index_avg-index": safe_divide(
119 | sum(avg_correct_predicted_index + avg_wrong_predicted_index),
120 | len(avg_correct_predicted_index + avg_wrong_predicted_index),
121 | ),
122 | "index_percentage-favoured-smaller-idx": safe_divide(
123 | sum(less_index_predictions), len(less_index_predictions)
124 | ),
125 | }
126 |
127 | return {k: round(v, 5) for k, v in out_dict.items()}
128 |
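129 | # Worked example (hypothetical counts, assuming `f1_measure` is the standard
130 | # harmonic mean): with 8 exact span matches out of 10 predicted and 10 gold
131 | # spans, span_precision = span_recall = 0.8 and
132 | # span_f1 = 2 * 0.8 * 0.8 / (0.8 + 0.8) = 0.8.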
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
3 | To create the package for pypi.
4 | 1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
5 | documentation.
6 | 2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`; if they fail, contact @philschmid.
7 | 3. Unpin specific versions from setup.py that use a git install.
8 | 4. Commit these changes with the message: "Release: VERSION"
9 | 5. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' "
10 | Push the tag to git: git push --tags origin master
11 | 6. Build both the sources and the wheel. Do not change anything in setup.py between
12 | creating the wheel and the source distribution (obviously).
13 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory.
14 | (this will build a wheel for the python version you use to build it).
15 | For the sources, run: "python setup.py sdist"
16 | You should now have a /dist directory with both .whl and .tar.gz source versions.
17 | 7. Check that everything looks correct by uploading the package to the pypi test server:
18 | twine upload dist/* -r pypitest
19 | (pypi suggest using twine as other methods upload files via plaintext.)
20 | You may have to specify the repository url, use the following command then:
21 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/
22 | Check that you can install it in a virtualenv by running:
23 | pip install -i https://testpypi.python.org/pypi transformers
24 | 8. Upload the final version to actual pypi:
25 | twine upload dist/* -r pypi
26 | 9. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
27 | 10. Run `make post-release` (or `make post-patch` for a patch release).
28 | """
29 | from collections import defaultdict
30 |
31 | import setuptools
32 |
33 |
34 | def parse_requirements_file(
35 |     path, allowed_extras: set | None = None, include_all_extra: bool = True
36 | ):
37 | requirements = []
38 | extras = defaultdict(list)
39 | find_links = []
40 | with open(path) as requirements_file:
41 | import re
42 |
43 | def fix_url_dependencies(req: str) -> str:
44 | """Pip and setuptools disagree about how URL dependencies should be handled."""
45 | m = re.match(
46 |             r"^(git\+)?(https|ssh)://(git@)?github\.com/([\w-]+)/(?P<name>[\w-]+)\.git",
47 | req,
48 | )
49 | if m is None:
50 | return req
51 | else:
52 | return f"{m.group('name')} @ {req}"
53 |
54 | for line in requirements_file:
55 | line = line.strip()
56 | if line.startswith("#") or len(line) <= 0:
57 | continue
58 | if (
59 | line.startswith("-f")
60 | or line.startswith("--find-links")
61 | or line.startswith("--index-url")
62 | or line.startswith("--extra-index-url")
63 | ):
64 | find_links.append(line.split(" ", maxsplit=1)[-1].strip())
65 | continue
66 |
67 | req, *needed_by = line.split("# needed by:")
68 | req = fix_url_dependencies(req.strip())
69 | if needed_by:
70 | for extra in needed_by[0].strip().split(","):
71 | extra = extra.strip()
72 | if allowed_extras is not None and extra not in allowed_extras:
73 | raise ValueError(f"invalid extra '{extra}' in {path}")
74 | extras[extra].append(req)
75 | if include_all_extra and req not in extras["all"]:
76 | # if "gpu" in extra:
77 | # extras["all-gpu"].append(req)
78 | # else:
79 | extras["all"].append(req)
80 |
81 | else:
82 | requirements.append(req)
83 | return requirements, extras, find_links
84 |
85 |
86 | allowed_extras = {
87 | "serve",
88 | "ray",
89 | "train",
90 | "retriever",
91 | "reader",
92 | "all",
93 | "faiss",
94 | "faiss-gpu",
95 | "dev",
96 | }
97 |
98 | # Load requirements.
99 | install_requirements, extras, find_links = parse_requirements_file(
100 | "requirements.txt", allowed_extras=allowed_extras
101 | )
102 |
103 | # version.py defines the VERSION and VERSION_SHORT variables.
104 | # We use exec here, so we don't import wsl whilst setting up.
105 | VERSION = {} # type: ignore
106 | with open("wsl/version.py", "r") as version_file:
107 | exec(version_file.read(), VERSION)
108 |
109 | with open("README.md", "r") as fh:
110 | long_description = fh.read()
111 |
112 | setuptools.setup(
113 | name="wsl",
114 | version=VERSION["VERSION"],
115 | author="Andrei Stefan Bejgu, Edoardo Barba, Riccardo Orlando, Pere-Lluís Huguet Cabot, Roberto Navigli",
116 | author_email="info@babelscape.com",
117 |     description="Word Sense Linking",
118 | long_description=long_description,
119 | long_description_content_type="text/markdown",
120 | url="https://github.com/Babelscape/wsl",
121 | keywords="NLP Sapienza sapienzanlp babelscape deep learning transformer pytorch retriever word sense linking word sense disambiguation reader",
122 | packages=setuptools.find_packages(),
123 | include_package_data=True,
124 | license="CC BY-NC-SA 4.0",
125 | classifiers=[
126 | "Intended Audience :: Science/Research",
127 | "Programming Language :: Python :: 3",
128 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
129 | ],
130 | install_requires=install_requirements,
131 | extras_require=extras,
132 | python_requires=">=3.10",
133 | find_links=find_links,
134 | )
135 | 
136 |
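137 | # Example of the `# needed by:` convention parsed above (hypothetical
138 | # requirements.txt line):
139 | #
140 | #     faiss-cpu>=1.7.2  # needed by: faiss,retriever
141 | #
142 | # `faiss-cpu>=1.7.2` ends up in extras["faiss"] and extras["retriever"] and,
143 | # since include_all_extra defaults to True, in extras["all"] as well.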
--------------------------------------------------------------------------------
/wsl/retriever/data/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 | from typing import Dict, List, Optional, Union
5 |
6 | import numpy as np
7 | import transformers as tr
8 | from tqdm import tqdm
9 |
10 |
11 | class HardNegativesManager:
12 | def __init__(
13 | self,
14 | tokenizer: tr.PreTrainedTokenizer,
15 | data: Union[List[Dict], os.PathLike, Dict[int, List]] = None,
16 | max_length: int = 64,
17 | batch_size: int = 1000,
18 | lazy: bool = False,
19 | ) -> None:
20 | self._db: dict = None
21 | self.tokenizer = tokenizer
22 |
23 | if data is None:
24 | self._db = {}
25 | else:
26 | if isinstance(data, Dict):
27 | self._db = data
28 |             elif isinstance(data, (str, os.PathLike)):
29 | with open(data) as f:
30 | self._db = json.load(f)
31 | else:
32 | raise ValueError(
33 |                     f"Data type {type(data)} not supported, only Dict, str and os.PathLike are supported."
34 | )
35 |         # keep the tokenization length for lazy tokenization in `_tokenize`
36 |         self.max_length = max_length
37 |
38 | # invert the db to have a passage -> sample_idx mapping
39 | self._passage_db = defaultdict(set)
40 | for sample_idx, passages in self._db.items():
41 | for passage in passages:
42 | self._passage_db[passage].add(sample_idx)
43 |
44 | self._passage_hard_negatives = {}
45 | if not lazy:
46 | # create a dictionary of passage -> hard_negative mapping
47 | batch_size = min(batch_size, len(self._passage_db))
48 | unique_passages = list(self._passage_db.keys())
49 | for i in tqdm(
50 | range(0, len(unique_passages), batch_size),
51 | desc="Tokenizing Hard Negatives",
52 | ):
53 | batch = unique_passages[i : i + batch_size]
54 | tokenized_passages = self.tokenizer(
55 | batch,
56 | max_length=max_length,
57 | truncation=True,
58 | )
59 |             for j, passage in enumerate(batch):
60 |                 self._passage_hard_negatives[passage] = {
61 |                     k: tokenized_passages[k][j] for k in tokenized_passages.keys()
62 | }
63 |
64 | def __len__(self) -> int:
65 | return len(self._db)
66 |
67 | def __getitem__(self, idx: int) -> Dict:
68 | return self._db[idx]
69 |
70 | def __iter__(self):
71 | for sample in self._db:
72 | yield sample
73 |
74 | def __contains__(self, idx: int) -> bool:
75 | return idx in self._db
76 |
77 | def get(self, idx: int) -> List[str]:
78 | """Get the hard negatives for a given sample index."""
79 | if idx not in self._db:
80 | raise ValueError(f"Sample index {idx} not in the database.")
81 |
82 | passages = self._db[idx]
83 |
84 | output = []
85 | for passage in passages:
86 | if passage not in self._passage_hard_negatives:
87 | self._passage_hard_negatives[passage] = self._tokenize(passage)
88 | output.append(self._passage_hard_negatives[passage])
89 |
90 | return output
91 |
92 | def _tokenize(self, passage: str) -> Dict:
93 | return self.tokenizer(passage, max_length=self.max_length, truncation=True)
94 |
95 |
96 | class NegativeSampler:
97 | def __init__(
98 | self, num_elements: int, probabilities: Optional[Union[List, np.ndarray]] = None
99 | ):
100 |         # check for None first: np.array(None) is a 0-d array, not None
101 |         if probabilities is None:
102 |             # random probabilities, normalized to sum to 1
103 |             probabilities = np.random.random(num_elements)
104 |             probabilities /= np.sum(probabilities)
105 |         elif not isinstance(probabilities, np.ndarray):
106 |             probabilities = np.array(probabilities)
107 |         self.probabilities = probabilities
108 |
109 | def __call__(
110 | self,
111 | sample_size: int,
112 | num_samples: int = 1,
113 |         probabilities: Optional[np.ndarray] = None,
114 |         exclude: Optional[List[int]] = None,
115 |     ) -> np.ndarray:
116 | """
117 | Fast sampling of `sample_size` elements from `num_elements` elements.
118 | The sampling is done by randomly shifting the probabilities and then
119 | finding the smallest of the negative numbers. This is much faster than
120 | sampling from a multinomial distribution.
121 |
122 | Args:
123 | sample_size (`int`):
124 | number of elements to sample
125 | num_samples (`int`, optional):
126 | number of samples to draw. Defaults to 1.
127 |             probabilities (`np.ndarray`, optional):
128 | probabilities of each element. Defaults to None.
129 | exclude (`List[int]`, optional):
130 | indices of elements to exclude. Defaults to None.
131 |
132 | Returns:
133 |             `np.ndarray`: array of sampled indices
134 | """
135 | if probabilities is None:
136 | probabilities = self.probabilities
137 |
138 |         if exclude is not None:
139 |             # copy first so `self.probabilities` is not zeroed in place
140 |             probabilities = probabilities.copy()
141 |             probabilities[exclude] = 0  # note: not re-normalized
142 |
143 | # replicate probabilities as many times as `num_samples`
144 | replicated_probabilities = np.tile(probabilities, (num_samples, 1))
145 | # get random shifting numbers & scale them correctly
146 | random_shifts = np.random.random(replicated_probabilities.shape)
147 | random_shifts /= random_shifts.sum(axis=1)[:, np.newaxis]
148 | # shift by numbers & find largest (by finding the smallest of the negative)
149 | shifted_probabilities = random_shifts - replicated_probabilities
150 | sampled_indices = np.argpartition(shifted_probabilities, sample_size, axis=1)[
151 | :, :sample_size
152 | ]
153 | return sampled_indices
154 |
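155 | # A minimal usage sketch for `NegativeSampler` (hypothetical sizes):
156 | #
157 | #     sampler = NegativeSampler(num_elements=1000)
158 | #     negatives = sampler(sample_size=16, num_samples=4, exclude=[0, 1])
159 | #     negatives.shape  # -> (4, 16); excluded ids are given probability 0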
--------------------------------------------------------------------------------
/wsl/inference/data/splitters/spacy_sentence_splitter.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Iterable, List, Optional, Union
2 |
3 | import spacy
4 |
5 | from wsl.inference.data.objects import Word
6 | from wsl.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
7 | from wsl.inference.data.tokenizers.spacy_tokenizer import load_spacy
8 |
9 | SPACY_LANGUAGE_MAPPER = {
10 | "cs": "xx_sent_ud_sm",
11 | "da": "xx_sent_ud_sm",
12 | "de": "xx_sent_ud_sm",
13 | "fa": "xx_sent_ud_sm",
14 | "fi": "xx_sent_ud_sm",
15 | "fr": "xx_sent_ud_sm",
16 | "el": "el_core_news_sm",
17 | "en": "xx_sent_ud_sm",
18 | "es": "xx_sent_ud_sm",
19 | "ga": "xx_sent_ud_sm",
20 | "hr": "xx_sent_ud_sm",
21 | "id": "xx_sent_ud_sm",
22 | "it": "xx_sent_ud_sm",
23 | "ja": "ja_core_news_sm",
24 | "lv": "xx_sent_ud_sm",
25 | "lt": "xx_sent_ud_sm",
26 | "mr": "xx_sent_ud_sm",
27 | "nb": "xx_sent_ud_sm",
28 | "nl": "xx_sent_ud_sm",
29 | "no": "xx_sent_ud_sm",
30 | "pl": "pl_core_news_sm",
31 | "pt": "xx_sent_ud_sm",
32 | "ro": "xx_sent_ud_sm",
33 | "ru": "xx_sent_ud_sm",
34 | "sk": "xx_sent_ud_sm",
35 | "sr": "xx_sent_ud_sm",
36 | "sv": "xx_sent_ud_sm",
37 | "te": "xx_sent_ud_sm",
38 | "vi": "xx_sent_ud_sm",
39 | "zh": "zh_core_web_sm",
40 | }
41 |
42 |
43 | class SpacySentenceSplitter(BaseSentenceSplitter):
44 | """
45 | A :obj:`SentenceSplitter` that uses spaCy's built-in sentence boundary detection.
46 |
47 | Args:
48 | language (:obj:`str`, optional, defaults to :obj:`en`):
49 | Language of the text to tokenize.
50 | model_type (:obj:`str`, optional, defaults to :obj:`statistical`):
51 | Three different type of sentence splitter:
52 | - ``dependency``: sentence splitter uses a dependency parse to detect sentence boundaries,
53 | slow, but accurate.
54 |             - ``statistical``: uses the trained ``senter`` component; a good speed/accuracy trade-off.
55 | - ``rule_based``: It's fast and has a small memory footprint, since it uses punctuation to detect
56 | sentence boundaries.
57 | """
58 |
59 | def __init__(self, language: str = "en", model_type: str = "statistical") -> None:
60 | # we need spacy's dependency parser if we're not using rule-based sentence boundary detection.
61 | # self.spacy = get_spacy_model(language, parse=not rule_based, ner=False)
62 |         dep = model_type == "dependency"
63 | if language in SPACY_LANGUAGE_MAPPER:
64 | self.spacy = load_spacy(SPACY_LANGUAGE_MAPPER[language], parse=dep)
65 | else:
66 | self.spacy = spacy.blank(language)
67 | # force type to rule_based since there is no pre-trained model
68 | model_type = "rule_based"
69 | if model_type == "dependency":
70 |             # the dependency parser must be declared at model init (done above)
71 | pass
72 | elif model_type == "statistical":
73 | if not self.spacy.has_pipe("senter"):
74 | self.spacy.enable_pipe("senter")
75 | elif model_type == "rule_based":
76 | # we use `sentencizer`, a built-in spacy module for rule-based sentence boundary detection.
77 | # depending on the spacy version, it could be called 'sentencizer' or 'sbd'
78 | if not self.spacy.has_pipe("sentencizer"):
79 | self.spacy.add_pipe("sentencizer")
80 | else:
81 | raise ValueError(
82 | f"type {model_type} not supported. Choose between `dependency`, `statistical` or `rule_based`"
83 | )
84 |
85 | def __call__(
86 | self,
87 | texts: Union[str, List[str], List[List[str]]],
88 | max_length: Optional[int] = None,
89 | is_split_into_words: bool = False,
90 | **kwargs,
91 | ) -> Union[List[str], List[List[str]]]:
92 | """
93 |         Split the input into sentences using spaCy models.
94 |
95 | Args:
96 | texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
97 |                 Text to split. It can be a single string, a batch of strings or pre-tokenized strings.
98 |             max_length (:obj:`int`, optional):
99 |                 Maximum length of a single sentence. Longer sentences are chunked
100 |                 into multiple pieces.
101 |
102 | Returns:
103 | :obj:`List[List[str]]`: The input doc split into sentences.
104 | """
105 | # check if input is batched or a single sample
106 | is_batched = self.check_is_batched(texts, is_split_into_words)
107 |
108 | if is_batched:
109 | sents = self.split_sentences_batch(texts)
110 | else:
111 | sents = self.split_sentences(texts, max_length)
112 | return sents
113 |
114 | @staticmethod
115 | def chunked(iterable, n: int) -> Iterable[List[Any]]:
116 | """
117 | Chunks a list into n sized chunks.
118 |
119 | Args:
120 | iterable (:obj:`List[Any]`):
121 | List to chunk.
122 | n (:obj:`int`):
123 | Size of the chunks.
124 |
125 | Returns:
126 | :obj:`Iterable[List[Any]]`: The input list chunked into n sized chunks.
127 | """
128 | return [iterable[i : i + n] for i in range(0, len(iterable), n)]
129 |
130 | def split_sentences(
131 | self, text: str | List[Word], max_length: Optional[int] = None, *args, **kwargs
132 | ) -> List[str]:
133 | """
134 | Splits a `text` into smaller sentences.
135 |
136 | Args:
137 | text (:obj:`str`):
138 | Text to split.
139 |             max_length (:obj:`int`, optional):
140 |                 Maximum length of a single sentence. Longer sentences are chunked
141 |                 into multiple pieces.
142 |
143 | Returns:
144 | :obj:`List[str]`: The input text split into sentences.
145 | """
146 |         sentences = list(self.spacy(text).sents)
147 | if max_length is not None and max_length > 0:
148 | sentences = [
149 | chunk
150 | for sentence in sentences
151 | for chunk in self.chunked(sentence, max_length)
152 | ]
153 | return sentences
154 |
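155 | # A minimal usage sketch ("en" maps to the multilingual `xx_sent_ud_sm`
156 | # entry of SPACY_LANGUAGE_MAPPER; the model is downloaded on first use):
157 | #
158 | #     splitter = SpacySentenceSplitter(language="en", model_type="statistical")
159 | #     splitter("First sentence. Second one.")  # -> one spaCy span per sentence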
--------------------------------------------------------------------------------
/wsl/reader/utils/relik_reader_predictor.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Iterable, Iterator, List, Optional
3 |
4 | import hydra
5 | import torch
6 | from lightning.pytorch.utilities import move_data_to_device
7 | from torch.utils.data import DataLoader, IterableDataset
8 | from tqdm import tqdm
9 |
10 | from wsl.reader.data.patches import merge_patches_predictions
11 | from wsl.reader.data.wsl_reader_sample import WSLReaderSample, load_wsl_reader_samples
12 | from wsl.reader.pytorch_modules.base import WSLReaderBase
13 | from wsl.reader.utils.special_symbols import NME_SYMBOL
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def convert_tokens_to_char_annotations(
19 | sample: WSLReaderSample, remove_nmes: bool = False
20 | ):
21 | char_annotations = set()
22 |
23 | for (
24 | predicted_entity,
25 | predicted_spans,
26 | ) in sample.predicted_window_labels.items():
27 | if predicted_entity == NME_SYMBOL and remove_nmes:
28 | continue
29 |
30 | for span_start, span_end in predicted_spans:
31 | span_start = sample.token2char_start[str(span_start)]
32 | span_end = sample.token2char_end[str(span_end)]
33 |
34 | char_annotations.add((span_start, span_end, predicted_entity))
35 |
36 | char_probs_annotations = dict()
37 | for (
38 | span_start,
39 | span_end,
40 | ), candidates_probs in sample.span_title_probabilities.items():
41 | span_start = sample.token2char_start[str(span_start)]
42 | span_end = sample.token2char_end[str(span_end)]
43 | char_probs_annotations[(span_start, span_end)] = {
44 | title for title, _ in candidates_probs
45 | }
46 |
47 | sample.predicted_window_labels_chars = char_annotations
48 | sample.probs_window_labels_chars = char_probs_annotations
49 |
50 |
51 | class WSLReaderPredictor:
52 | def __init__(
53 | self,
54 | wsl_reader_core: WSLReaderBase,
55 | dataset_conf: Optional[dict] = None,
56 | predict_nmes: bool = False,
57 | dataloader: Optional[DataLoader] = None,
58 | ) -> None:
59 | self.wsl_reader_core = wsl_reader_core
60 | self.dataset_conf = dataset_conf
61 | self.predict_nmes = predict_nmes
62 | self.dataloader: DataLoader | None = dataloader
63 |
64 |         if self.dataset_conf is not None and getattr(self, "dataset", None) is None:
65 | # instantiate dataset
66 | self.dataset = hydra.utils.instantiate(
67 | dataset_conf,
68 | dataset_path=None,
69 | samples=None,
70 | )
71 |
72 | def predict(
73 | self,
74 | path: Optional[str],
75 | samples: Optional[Iterable[WSLReaderSample]],
76 | dataset_conf: Optional[dict],
77 | token_batch_size: int = 1024,
78 | progress_bar: bool = False,
79 | **kwargs,
80 | ) -> List[WSLReaderSample]:
81 | annotated_samples = list(
82 | self._predict(path, samples, dataset_conf, token_batch_size, progress_bar)
83 | )
84 | for sample in annotated_samples:
85 | merge_patches_predictions(sample)
86 | convert_tokens_to_char_annotations(
87 | sample, remove_nmes=not self.predict_nmes
88 | )
89 | return annotated_samples
90 |
91 | def _predict(
92 | self,
93 | path: Optional[str],
94 | samples: Optional[Iterable[WSLReaderSample]],
95 | dataset_conf: dict,
96 | token_batch_size: int = 1024,
97 | progress_bar: bool = False,
98 | **kwargs,
99 | ) -> Iterator[WSLReaderSample]:
100 | assert (
101 | path is not None or samples is not None
102 | ), "Either predict on a path or on an iterable of samples"
103 |
104 | next_prediction_position = 0
105 | position2predicted_sample = {}
106 |
107 | if self.dataloader is not None:
108 | iterator = self.dataloader
109 | for i, sample in enumerate(self.dataloader.dataset.samples):
110 | sample._mixin_prediction_position = i
111 | else:
112 | samples = load_wsl_reader_samples(path) if samples is None else samples
113 |
114 | # setup infrastructure to re-yield in order
115 | def samples_it():
116 | for i, sample in enumerate(samples):
117 | assert sample._mixin_prediction_position is None
118 | sample._mixin_prediction_position = i
119 | yield sample
120 |
121 | # instantiate dataset
122 | if getattr(self, "dataset", None) is not None:
123 | dataset = self.dataset
124 | dataset.samples = samples_it()
125 | dataset.tokens_per_batch = token_batch_size
126 | else:
127 | dataset = hydra.utils.instantiate(
128 | dataset_conf,
129 | dataset_path=None,
130 | samples=samples_it(),
131 | tokens_per_batch=token_batch_size,
132 | )
133 |
134 | # instantiate dataloader
135 | iterator = DataLoader(
136 | dataset, batch_size=None, num_workers=0, shuffle=False
137 | )
138 | if progress_bar:
139 | iterator = tqdm(iterator, desc="Predicting")
140 |
141 | model_device = next(self.wsl_reader_core.parameters()).device
142 |
143 | with torch.inference_mode():
144 | for batch in iterator:
145 | # do batch predict
146 | with torch.autocast(
147 | "cpu" if model_device == torch.device("cpu") else "cuda"
148 | ):
149 | batch = move_data_to_device(batch, model_device)
150 | batch_out = self.wsl_reader_core._batch_predict(**batch)
151 |                 # store predicted samples by prediction position
152 | for sample in batch_out:
153 | if sample._mixin_prediction_position >= next_prediction_position:
154 | position2predicted_sample[
155 | sample._mixin_prediction_position
156 | ] = sample
157 |
158 | # yield
159 | while next_prediction_position in position2predicted_sample:
160 | yield position2predicted_sample[next_prediction_position]
161 | del position2predicted_sample[next_prediction_position]
162 | next_prediction_position += 1
163 |
164 | if len(position2predicted_sample) > 0:
165 | logger.warning(
166 | "It seems samples have been discarded in your dataset. "
167 | "This means that you WON'T have a prediction for each input sample. "
168 | "Prediction order will also be partially disrupted"
169 | )
170 |         for _, v in sorted(position2predicted_sample.items(), key=lambda x: x[0]):
171 | yield v
172 |
173 | if progress_bar:
174 | iterator.close()
175 |
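176 | # A minimal usage sketch (hedged: `reader` is a loaded WSLReaderBase and
177 | # `dataset_conf` a hydra-instantiable dataset config):
178 | #
179 | #     predictor = WSLReaderPredictor(reader, dataset_conf=dataset_conf)
180 | #     annotated = predictor.predict(
181 | #         path=None, samples=samples, dataset_conf=dataset_conf, progress_bar=True
182 | #     )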
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # custom
2 | wsl/test.py
3 |
4 | data/*
5 | experiments/*
6 | models
7 | retrievers
8 | outputs
9 | wandb
10 | pretrained_configs
11 | lightning_logs
12 | hf_weights
13 |
14 | # Created by https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
15 | # Edit at https://www.toptal.com/developers/gitignore?templates=jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
16 |
17 | ### JetBrains+all ###
18 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
19 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
20 |
21 | # User-specific stuff
22 | .idea/**/workspace.xml
23 | .idea/**/tasks.xml
24 | .idea/**/usage.statistics.xml
25 | .idea/**/dictionaries
26 | .idea/**/shelf
27 |
28 | # Generated files
29 | .idea/**/contentModel.xml
30 |
31 | # Sensitive or high-churn files
32 | .idea/**/dataSources/
33 | .idea/**/dataSources.ids
34 | .idea/**/dataSources.local.xml
35 | .idea/**/sqlDataSources.xml
36 | .idea/**/dynamic.xml
37 | .idea/**/uiDesigner.xml
38 | .idea/**/dbnavigator.xml
39 |
40 | # Gradle
41 | .idea/**/gradle.xml
42 | .idea/**/libraries
43 |
44 | # Gradle and Maven with auto-import
45 | # When using Gradle or Maven with auto-import, you should exclude module files,
46 | # since they will be recreated, and may cause churn. Uncomment if using
47 | # auto-import.
48 | # .idea/artifacts
49 | # .idea/compiler.xml
50 | # .idea/jarRepositories.xml
51 | # .idea/modules.xml
52 | # .idea/*.iml
53 | # .idea/modules
54 | # *.iml
55 | # *.ipr
56 |
57 | # CMake
58 | cmake-build-*/
59 |
60 | # Mongo Explorer plugin
61 | .idea/**/mongoSettings.xml
62 |
63 | # File-based project format
64 | *.iws
65 |
66 | # IntelliJ
67 | out/
68 |
69 | # mpeltonen/sbt-idea plugin
70 | .idea_modules/
71 |
72 | # JIRA plugin
73 | atlassian-ide-plugin.xml
74 |
75 | # Cursive Clojure plugin
76 | .idea/replstate.xml
77 |
78 | # Crashlytics plugin (for Android Studio and IntelliJ)
79 | com_crashlytics_export_strings.xml
80 | crashlytics.properties
81 | crashlytics-build.properties
82 | fabric.properties
83 |
84 | # Editor-based Rest Client
85 | .idea/httpRequests
86 |
87 | # Android studio 3.1+ serialized cache file
88 | .idea/caches/build_file_checksums.ser
89 |
90 | ### JetBrains+all Patch ###
91 | # Ignores the whole .idea folder and all .iml files
92 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360
93 |
94 | .idea/
95 |
96 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023
97 |
98 | *.iml
99 | modules.xml
100 | .idea/misc.xml
101 | *.ipr
102 |
103 | # Sonarlint plugin
104 | .idea/sonarlint
105 |
106 | ### JupyterNotebooks ###
107 | # gitignore template for Jupyter Notebooks
108 | # website: http://jupyter.org/
109 |
110 | .ipynb_checkpoints
111 | */.ipynb_checkpoints/*
112 |
113 | # IPython
114 | profile_default/
115 | ipython_config.py
116 |
117 | # Remove previous ipynb_checkpoints
118 | # git rm -r .ipynb_checkpoints/
119 |
120 | ### Linux ###
121 | *~
122 |
123 | # temporary files which can be created if a process still has a handle open of a deleted file
124 | .fuse_hidden*
125 |
126 | # KDE directory preferences
127 | .directory
128 |
129 | # Linux trash folder which might appear on any partition or disk
130 | .Trash-*
131 |
132 | # .nfs files are created when an open file is removed but is still being accessed
133 | .nfs*
134 |
135 | ### macOS ###
136 | # General
137 | .DS_Store
138 | .AppleDouble
139 | .LSOverride
140 |
141 | # Icon must end with two \r
142 | Icon
143 |
144 |
145 | # Thumbnails
146 | ._*
147 |
148 | # Files that might appear in the root of a volume
149 | .DocumentRevisions-V100
150 | .fseventsd
151 | .Spotlight-V100
152 | .TemporaryItems
153 | .Trashes
154 | .VolumeIcon.icns
155 | .com.apple.timemachine.donotpresent
156 |
157 | # Directories potentially created on remote AFP share
158 | .AppleDB
159 | .AppleDesktop
160 | Network Trash Folder
161 | Temporary Items
162 | .apdisk
163 |
164 | ### Python ###
165 | # Byte-compiled / optimized / DLL files
166 | __pycache__/
167 | *.py[cod]
168 | *$py.class
169 |
170 | # C extensions
171 | *.so
172 |
173 | # Distribution / packaging
174 | .Python
175 | build/
176 | develop-eggs/
177 | dist/
178 | downloads/
179 | eggs/
180 | .eggs/
181 | lib/
182 | lib64/
183 | parts/
184 | sdist/
185 | var/
186 | wheels/
187 | pip-wheel-metadata/
188 | share/python-wheels/
189 | *.egg-info/
190 | .installed.cfg
191 | *.egg
192 | MANIFEST
193 |
194 | # PyInstaller
195 | # Usually these files are written by a python script from a template
196 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
197 | *.manifest
198 | *.spec
199 |
200 | # Installer logs
201 | pip-log.txt
202 | pip-delete-this-directory.txt
203 |
204 | # Unit test / coverage reports
205 | htmlcov/
206 | .tox/
207 | .nox/
208 | .coverage
209 | .coverage.*
210 | .cache
211 | nosetests.xml
212 | coverage.xml
213 | *.cover
214 | *.py,cover
215 | .hypothesis/
216 | .pytest_cache/
217 | pytestdebug.log
218 |
219 | # Translations
220 | *.mo
221 | *.pot
222 |
223 | # Django stuff:
224 | *.log
225 | local_settings.py
226 | db.sqlite3
227 | db.sqlite3-journal
228 |
229 | # Flask stuff:
230 | instance/
231 | .webassets-cache
232 |
233 | # Scrapy stuff:
234 | .scrapy
235 |
236 | # Sphinx documentation
237 | docs/_build/
238 | doc/_build/
239 |
240 | # PyBuilder
241 | target/
242 |
243 | # Jupyter Notebook
244 |
245 | # IPython
246 |
247 | # pyenv
248 | .python-version
249 |
250 | # pipenv
251 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
252 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
253 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
254 | # install all needed dependencies.
255 | #Pipfile.lock
256 |
257 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
258 | __pypackages__/
259 |
260 | # Celery stuff
261 | celerybeat-schedule
262 | celerybeat.pid
263 |
264 | # SageMath parsed files
265 | *.sage.py
266 |
267 | # Environments
268 | .env
269 | .venv
270 | env/
271 | venv/
272 | ENV/
273 | env.bak/
274 | venv.bak/
275 | pythonenv*
276 |
277 | # Spyder project settings
278 | .spyderproject
279 | .spyproject
280 |
281 | # Rope project settings
282 | .ropeproject
283 |
284 | # mkdocs documentation
285 | /site
286 |
287 | # mypy
288 | .mypy_cache/
289 | .dmypy.json
290 | dmypy.json
291 |
292 | # Pyre type checker
293 | .pyre/
294 |
295 | # pytype static type analyzer
296 | .pytype/
297 |
298 | # profiling data
299 | .prof
300 |
301 | ### vscode ###
302 | .vscode
303 | .vscode/*
304 | !.vscode/settings.json
305 | !.vscode/tasks.json
306 | !.vscode/launch.json
307 | !.vscode/extensions.json
308 | *.code-workspace
309 |
310 | ### Windows ###
311 | # Windows thumbnail cache files
312 | Thumbs.db
313 | Thumbs.db:encryptable
314 | ehthumbs.db
315 | ehthumbs_vista.db
316 |
317 | # Dump file
318 | *.stackdump
319 |
320 | # Folder config file
321 | [Dd]esktop.ini
322 |
323 | # Recycle Bin used on file shares
324 | $RECYCLE.BIN/
325 |
326 | # Windows Installer files
327 | *.cab
328 | *.msi
329 | *.msix
330 | *.msm
331 | *.msp
332 |
333 | # Windows shortcuts
334 | *.lnk
335 |
336 | # End of https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos
--------------------------------------------------------------------------------
/wsl/retriever/data/labels.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import Dict, List, Set, Union
4 |
5 |
6 | class Labels:
7 | """
8 | Class that contains the labels for a model.
9 |
10 | Args:
11 |         _labels_to_index (:obj:`Dict[str, Dict[str, int]]`):
12 |             A dictionary from namespace to a label-to-index mapping.
13 |         _index_to_labels (:obj:`Dict[str, Dict[int, str]]`):
14 |             A dictionary from namespace to an index-to-label mapping.
15 | """
16 |
17 | def __init__(
18 | self,
19 | _labels_to_index: Dict[str, Dict[str, int]] = None,
20 | _index_to_labels: Dict[str, Dict[int, str]] = None,
21 | **kwargs,
22 | ):
23 | self._labels_to_index = _labels_to_index or {"labels": {}}
24 | self._index_to_labels = _index_to_labels or {"labels": {}}
25 | # if _labels_to_index is not empty and _index_to_labels is not provided
26 | # to the constructor, build the inverted label dictionary
27 | if not _index_to_labels and _labels_to_index:
28 | for namespace in self._labels_to_index:
29 | self._index_to_labels[namespace] = {
30 | v: k for k, v in self._labels_to_index[namespace].items()
31 | }
32 |
33 | def get_index_from_label(self, label: str, namespace: str = "labels") -> int:
34 | """
35 | Returns the index of a literal label.
36 |
37 | Args:
38 | label (:obj:`str`):
39 | The string representation of the label.
40 | namespace (:obj:`str`, optional, defaults to ``labels``):
41 | The namespace where the label belongs, e.g. ``roles`` for a SRL task.
42 |
43 | Returns:
44 | :obj:`int`: The index of the label.
45 | """
46 | if namespace not in self._labels_to_index:
47 | raise ValueError(
48 | f"Provided namespace `{namespace}` is not in the label dictionary."
49 | )
50 |
51 | if label not in self._labels_to_index[namespace]:
52 | raise ValueError(f"Provided label {label} is not in the label dictionary.")
53 |
54 | return self._labels_to_index[namespace][label]
55 |
56 | def get_label_from_index(self, index: int, namespace: str = "labels") -> str:
57 | """
58 | Returns the string representation of the label index.
59 |
60 | Args:
61 | index (:obj:`int`):
62 | The index of the label.
63 | namespace (:obj:`str`, optional, defaults to ``labels``):
64 | The namespace where the label belongs, e.g. ``roles`` for a SRL task.
65 |
66 | Returns:
67 | :obj:`str`: The string representation of the label.
68 | """
69 | if namespace not in self._index_to_labels:
70 | raise ValueError(
71 | f"Provided namespace `{namespace}` is not in the label dictionary."
72 | )
73 |
74 | if index not in self._index_to_labels[namespace]:
75 | raise ValueError(
76 | f"Provided label `{index}` is not in the label dictionary."
77 | )
78 |
79 | return self._index_to_labels[namespace][index]
80 |
81 | def add_labels(
82 | self,
83 | labels: Union[str, List[str], Set[str], Dict[str, int]],
84 | namespace: str = "labels",
85 | ) -> List[int]:
86 | """
87 | Adds the labels in input in the label dictionary.
88 |
89 | Args:
90 | labels (:obj:`str`, :obj:`List[str]`, :obj:`Set[str]`):
91 | The labels (single label, list of labels or set of labels) to add to the dictionary.
92 | namespace (:obj:`str`, optional, defaults to ``labels``):
93 |                 Namespace where the labels belong.
94 |
95 | Returns:
96 | :obj:`List[int]`: The index of the labels just inserted.
97 | """
98 | if isinstance(labels, dict):
99 | self._labels_to_index[namespace] = labels
100 | self._index_to_labels[namespace] = {
101 | v: k for k, v in self._labels_to_index[namespace].items()
102 | }
103 |         # normalize input: wrap a single label so that `set()` does not
104 |         # split the string into characters
105 |         labels = {labels} if isinstance(labels, str) else set(labels)
106 | # if new namespace, add to the dictionaries
107 | if namespace not in self._labels_to_index:
108 | self._labels_to_index[namespace] = {}
109 | self._index_to_labels[namespace] = {}
110 | # returns the new indices
111 | return [self._add_label(label, namespace) for label in labels]
112 |
113 | def _add_label(self, label: str, namespace: str = "labels") -> int:
114 | """
115 | Adds the label in input in the label dictionary.
116 |
117 | Args:
118 | label (:obj:`str`):
119 | The label to add to the dictionary.
120 | namespace (:obj:`str`, optional, defaults to ``labels``):
121 | Namespace where the label belongs.
122 |
123 | Returns:
124 |             :obj:`int`: The index of the label just inserted.
125 | """
126 | if label not in self._labels_to_index[namespace]:
127 | index = len(self._labels_to_index[namespace])
128 | self._labels_to_index[namespace][label] = index
129 | self._index_to_labels[namespace][index] = label
130 | return index
131 | else:
132 | return self._labels_to_index[namespace][label]
133 |
134 | def get_labels(self, namespace: str = "labels") -> Dict[str, int]:
135 | """
136 |         Returns all the labels that belong to the input namespace.
137 |
138 | Args:
139 | namespace (:obj:`str`, optional, defaults to ``labels``):
140 | Labels namespace to retrieve.
141 |
142 | Returns:
143 | :obj:`Dict[str, int]`: The label dictionary, from ``str`` to ``int``.
144 | """
145 | if namespace not in self._labels_to_index:
146 | raise ValueError(
147 | f"Provided namespace `{namespace}` is not in the label dictionary."
148 | )
149 | return self._labels_to_index[namespace]
150 |
151 | def get_label_size(self, namespace: str = "labels") -> int:
152 | """
153 | Returns the number of the labels in the namespace dictionary.
154 |
155 | Args:
156 | namespace (:obj:`str`, optional, defaults to ``labels``):
157 | Labels namespace to retrieve.
158 |
159 | Returns:
160 | :obj:`int`: Number of labels.
161 | """
162 | if namespace not in self._labels_to_index:
163 | raise ValueError(
164 | f"Provided namespace `{namespace}` is not in the label dictionary."
165 | )
166 | return len(self._labels_to_index[namespace])
167 |
168 | def get_namespaces(self) -> List[str]:
169 | """
170 | Returns all the namespaces in the label dictionary.
171 |
172 | Returns:
173 | :obj:`List[str]`: The namespaces in the label dictionary.
174 | """
175 | return list(self._labels_to_index.keys())
176 |
177 | @classmethod
178 |     def from_file(cls, file_path: Union[str, Path], **kwargs):
179 | with open(file_path, "r") as f:
180 | labels_to_index = json.load(f)
181 | return cls(labels_to_index, **kwargs)
182 |
183 |     def save(self, file_path: Union[str, Path], indent: int = 2, **kwargs):
184 | with open(file_path, "w") as f:
185 | json.dump(self._labels_to_index, f, indent=indent)
186 |
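187 | # A minimal usage sketch:
188 | #
189 | #     labels = Labels()
190 | #     labels.add_labels(["PERSON", "LOCATION"], namespace="ner")
191 | #     labels.get_label_size("ner")                  # -> 2
192 | #     labels.get_index_from_label("PERSON", "ner")  # -> 0 or 1 (set order)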
--------------------------------------------------------------------------------
/wsl/inference/data/tokenizers/spacy_tokenizer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any, Dict, List, Tuple, Union
3 |
4 | import spacy
5 |
6 | # from ipa.common.utils import load_spacy
7 | from spacy.cli.download import download as spacy_download
8 | from spacy.tokens import Doc
9 |
10 | from wsl.common.log import get_logger
11 | from wsl.inference.data.objects import Word
12 | from wsl.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER
13 | from wsl.inference.data.tokenizers.base_tokenizer import BaseTokenizer
14 |
15 | logger = get_logger(level=logging.DEBUG)
16 |
17 | # Spacy and Stanza stuff
18 |
19 | LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {}
20 |
21 |
22 | def load_spacy(
23 | language: str,
24 | pos_tags: bool = False,
25 | lemma: bool = False,
26 | parse: bool = False,
27 | split_on_spaces: bool = False,
28 | ) -> spacy.Language:
29 | """
30 | Download and load spacy model.
31 |
32 | Args:
33 | language (:obj:`str`, defaults to :obj:`en`):
34 | Language of the text to tokenize.
35 | pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
36 | If :obj:`True`, performs POS tagging with spacy model.
37 | lemma (:obj:`bool`, optional, defaults to :obj:`False`):
38 | If :obj:`True`, performs lemmatization with spacy model.
39 | parse (:obj:`bool`, optional, defaults to :obj:`False`):
40 | If :obj:`True`, performs dependency parsing with spacy model.
41 | split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
42 | If :obj:`True`, will split by spaces without performing tokenization.
43 |
44 | Returns:
45 | :obj:`spacy.Language`: The spacy model loaded.
46 | """
47 | exclude = ["vectors", "textcat", "ner"]
48 | if not pos_tags:
49 | exclude.append("tagger")
50 | if not lemma:
51 | exclude.append("lemmatizer")
52 | if not parse:
53 | exclude.append("parser")
54 |
55 | # check if the model is already loaded
56 | # if so, there is no need to reload it
57 | spacy_params = (language, pos_tags, lemma, parse, split_on_spaces)
58 | if spacy_params not in LOADED_SPACY_MODELS:
59 | try:
60 | spacy_tagger = spacy.load(language, exclude=exclude)
61 | except OSError:
62 | logger.warning(
63 | "Spacy model '%s' not found. Downloading and installing.", language
64 | )
65 | spacy_download(language)
66 | spacy_tagger = spacy.load(language, exclude=exclude)
67 |
68 | # if everything is disabled, return only the tokenizer
69 | # for faster tokenization
70 | # TODO: is it really faster?
71 | # TODO: check split_on_spaces behaviour if we don't do this if
72 | if len(exclude) >= 6 and split_on_spaces:
73 | spacy_tagger = spacy_tagger.tokenizer
74 | LOADED_SPACY_MODELS[spacy_params] = spacy_tagger
75 |
76 | return LOADED_SPACY_MODELS[spacy_params]
77 |
78 |
79 | class SpacyTokenizer(BaseTokenizer):
80 | """
81 |     A :obj:`Tokenizer` that uses spaCy to tokenize and preprocess the text. It returns :obj:`Word` objects.
82 |
83 | Args:
84 | language (:obj:`str`, optional, defaults to :obj:`en`):
85 | Language of the text to tokenize.
86 | return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
87 | If :obj:`True`, performs POS tagging with spacy model.
88 | return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`):
89 | If :obj:`True`, performs lemmatization with spacy model.
90 | return_deps (:obj:`bool`, optional, defaults to :obj:`False`):
91 | If :obj:`True`, performs dependency parsing with spacy model.
92 | use_gpu (:obj:`bool`, optional, defaults to :obj:`False`):
93 |             If :obj:`True`, will load the spaCy model on GPU.
94 | """
95 |
96 | def __init__(
97 | self,
98 | language: str = "en",
99 | return_pos_tags: bool = False,
100 | return_lemmas: bool = False,
101 | return_deps: bool = False,
102 | use_gpu: bool = False,
103 | ):
104 | super().__init__()
105 | if language not in SPACY_LANGUAGE_MAPPER:
106 | raise ValueError(
107 | f"`{language}` language not supported. The supported "
108 | f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}."
109 | )
110 | if use_gpu:
111 | # load the model on GPU
112 | # if the GPU is not available or not correctly configured,
113 |             # it will raise an error
114 | spacy.require_gpu()
115 | self.spacy = load_spacy(
116 | SPACY_LANGUAGE_MAPPER[language],
117 | return_pos_tags,
118 | return_lemmas,
119 | return_deps,
120 | )
121 |
122 | def __call__(
123 | self,
124 | texts: Union[str, List[str], List[List[str]]],
125 | is_split_into_words: bool = False,
126 | **kwargs,
127 | ) -> Union[List[Word], List[List[Word]]]:
128 | """
129 | Tokenize the input into single words using SpaCy models.
130 |
131 | Args:
132 | texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
133 |                 Text to tokenize. It can be a single string, a batch of strings or pre-tokenized strings.
134 | is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
135 | If :obj:`True` and the input is a string, the input is split on spaces.
136 |
137 | Returns:
138 | :obj:`List[List[Word]]`: The input text tokenized in single words.
139 |
140 | Example::
141 |
142 | >>> from wsl.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
143 |
144 |         >>> spacy_tokenizer = SpacyTokenizer(language="en", return_pos_tags=True, return_lemmas=True)
145 | >>> spacy_tokenizer("Mary sold the car to John.")
146 |
147 | """
148 | # check if input is batched or a single sample
149 | is_batched = self.check_is_batched(texts, is_split_into_words)
150 |
151 | if is_batched:
152 | tokenized = self.tokenize_batch(texts, is_split_into_words)
153 | else:
154 | tokenized = self.tokenize(texts, is_split_into_words)
155 |
156 | return tokenized
157 |
158 | def tokenize(self, text: Union[str, List[str]], is_split_into_words: bool) -> Doc:
159 | if is_split_into_words:
160 | if isinstance(text, str):
161 | text = text.split(" ")
162 | elif isinstance(text, list):
163 | text = text
164 | else:
165 | raise ValueError(
166 | f"text must be either `str` or `list`, found: `{type(text)}`"
167 | )
168 | spaces = [True] * len(text)
169 | return self.spacy(Doc(self.spacy.vocab, words=text, spaces=spaces))
170 | return self.spacy(text)
171 |
172 | def tokenize_batch(
173 | self, texts: Union[List[str], List[List[str]]], is_split_into_words: bool
174 | ) -> list[Any] | list[Doc]:
175 | try:
176 | if is_split_into_words:
177 | if isinstance(texts[0], str):
178 | texts = [text.split(" ") for text in texts]
179 | elif isinstance(texts[0], list):
180 | texts = texts
181 | else:
182 | raise ValueError(
183 | f"text must be either `str` or `list`, found: `{type(texts[0])}`"
184 | )
185 | spaces = [[True] * len(text) for text in texts]
186 | texts = [
187 | Doc(self.spacy.vocab, words=text, spaces=space)
188 | for text, space in zip(texts, spaces)
189 | ]
190 | return list(self.spacy.pipe(texts))
191 | except AttributeError:
192 | # a WhitespaceSpacyTokenizer has no `pipe()` method, we use simple for loop
193 | return [self.spacy(tokens) for tokens in texts]
194 |
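195 | # A minimal sketch of the module-level cache (the model is downloaded on
196 | # first use):
197 | #
198 | #     nlp = load_spacy("xx_sent_ud_sm")
199 | #     load_spacy("xx_sent_ud_sm") is nlp  # -> True, served from LOADED_SPACY_MODELS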
--------------------------------------------------------------------------------
/wsl/reader/pytorch_modules/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path, PosixPath
4 | from typing import Any, Dict, List
5 |
6 | import torch
7 | import transformers as tr
8 | from torch.utils.data import IterableDataset
9 | from transformers import AutoConfig, AutoModel
10 |
11 | from wsl.common.log import get_logger
12 |
13 | # from wsl.common.torch_utils import load_ort_optimized_hf_model
14 | from wsl.common.utils import get_callable_from_string
15 | from wsl.inference.data.objects import AnnotationType
16 | from wsl.reader.pytorch_modules import WSL_READER_CLASS_MAP
17 | from wsl.reader.pytorch_modules.hf.modeling_wsl import WSLReaderConfig, WSLReaderSample
18 | from wsl.retriever.pytorch_modules import PRECISION_MAP
19 |
20 | logger = get_logger(__name__, level=logging.INFO)
21 |
22 |
23 | class WSLReaderBase(torch.nn.Module):
24 | default_reader_class: str | None = None
25 | default_data_class: str | None = None
26 |
27 | def __init__(
28 | self,
29 | transformer_model: str | tr.PreTrainedModel | None = None,
30 | additional_special_symbols: int = 0,
31 | num_layers: int | None = None,
32 | activation: str = "gelu",
33 | linears_hidden_size: int | None = 512,
34 | use_last_k_layers: int = 1,
35 | training: bool = False,
36 | device: str | torch.device | None = None,
37 | precision: int = 32,
38 | tokenizer: str | tr.PreTrainedTokenizer | None = None,
39 | dataset: IterableDataset | str | None = None,
40 | default_reader_class: tr.PreTrainedModel | str | None = None,
41 | **kwargs,
42 | ) -> None:
43 | super().__init__()
44 |
45 | self.default_reader_class = default_reader_class or self.default_reader_class
46 |
47 | if self.default_reader_class is None:
48 | raise ValueError("You must specify a default reader class.")
49 |
50 | # get the callable for the default reader class
51 | self.default_reader_class: tr.PreTrainedModel = get_callable_from_string(
52 | self.default_reader_class
53 | )
54 | if isinstance(transformer_model, PosixPath):
55 | transformer_model = str(transformer_model)
56 | if isinstance(transformer_model, str):
57 | self.name_or_path = transformer_model
58 | config = AutoConfig.from_pretrained(
59 | transformer_model, trust_remote_code=True
60 | )
61 | if "wsl-reader" in config.model_type:
62 | transformer_model = self.default_reader_class.from_pretrained(
63 | transformer_model,
64 | config=config,
65 | ignore_mismatched_sizes=True,
66 | trust_remote_code=True,
67 | **kwargs,
68 | )
69 | else:
70 | reader_config = WSLReaderConfig(
71 | transformer_model=transformer_model,
72 | additional_special_symbols=additional_special_symbols,
73 | num_layers=num_layers,
74 | activation=activation,
75 | linears_hidden_size=linears_hidden_size,
76 | use_last_k_layers=use_last_k_layers,
77 | training=training,
78 | **kwargs,
79 | )
80 | transformer_model = self.default_reader_class(reader_config)
81 | self.name_or_path = transformer_model.config.transformer_model
82 | else:
83 | self.name_or_path = transformer_model.config.transformer_model
84 |
85 | self.wsl_reader_model = transformer_model
86 |
87 | self.wsl_reader_model_config = self.wsl_reader_model.config
88 |
89 | # get the tokenizer
90 | self._tokenizer = tokenizer
91 |
92 | # and instantiate the dataset class
93 | self.dataset: IterableDataset | None = dataset
94 |
95 | # move the model to the device
96 | self.to(device or torch.device("cpu"))
97 |
98 | # set the precision
99 | self.precision = precision
100 | self.to(PRECISION_MAP[precision])
101 |
102 | def forward(self, **kwargs) -> Dict[str, Any]:
103 | return self.wsl_reader_model(**kwargs)
104 |
105 | def _read(self, *args, **kwargs) -> Any:
106 | raise NotImplementedError
107 |
108 | @torch.no_grad()
109 | @torch.inference_mode()
110 | def read(
111 | self,
112 | text: List[str] | List[List[str]] | None = None,
113 | samples: List[WSLReaderSample] | None = None,
114 | input_ids: torch.Tensor | None = None,
115 | attention_mask: torch.Tensor | None = None,
116 | token_type_ids: torch.Tensor | None = None,
117 | prediction_mask: torch.Tensor | None = None,
118 | special_symbols_mask: torch.Tensor | None = None,
119 | candidates: List[List[str]] | None = None,
120 | max_length: int = 1000,
121 | max_batch_size: int = 128,
122 | token_batch_size: int = 2048,
123 | precision: int | str | None = None,
124 | annotation_type: str | AnnotationType = AnnotationType.CHAR,
125 | progress_bar: bool = False,
126 | *args,
127 | **kwargs,
128 | ) -> List[WSLReaderSample] | List[List[WSLReaderSample]]:
129 | """
130 | Reads the given text.
131 |
132 | Args:
133 |             text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`):
134 |                 The tokenized text to read. If a list of lists of tokens is
135 |                 provided, each inner list is treated as a separate sentence.
136 |             samples (:obj:`List[WSLReaderSample]`, `optional`):
137 | The samples to read. If provided, `text` and `candidates` are ignored.
138 | input_ids (:obj:`torch.Tensor`, `optional`):
139 | The input ids of the text.
140 | attention_mask (:obj:`torch.Tensor`, `optional`):
141 | The attention mask of the text.
142 | token_type_ids (:obj:`torch.Tensor`, `optional`):
143 | The token type ids of the text.
144 | prediction_mask (:obj:`torch.Tensor`, `optional`):
145 | The prediction mask of the text.
146 | special_symbols_mask (:obj:`torch.Tensor`, `optional`):
147 | The special symbols mask of the text.
148 | candidates (:obj:`List[List[str]]`, `optional`):
149 | The candidates of the text.
150 |             max_length (:obj:`int`, `optional`, defaults to 1000):
151 | The maximum length of the text.
152 | max_batch_size (:obj:`int`, `optional`, defaults to 128):
153 | The maximum batch size.
154 |             token_batch_size (:obj:`int`, `optional`, defaults to 2048):
155 | The maximum number of tokens per batch.
156 | precision (:obj:`int` or :obj:`str`, `optional`):
157 |                 The precision to use. If not provided, the model's own precision is used (32 bit by default).
158 | annotation_type (`str` or `AnnotationType`, `optional`, defaults to `char`):
159 | The type of annotation to return. If `char`, the spans will be in terms of
160 | character offsets. If `word`, the spans will be in terms of word offsets.
161 | progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`):
162 | Whether to show a progress bar.
163 |
164 | Returns:
165 | The predicted labels for each sample.
166 | """
167 | if isinstance(annotation_type, str):
168 | try:
169 | annotation_type = AnnotationType(annotation_type)
170 | except ValueError:
171 | raise ValueError(
172 | f"Annotation type `{annotation_type}` not recognized. "
173 | f"Please choose one of {list(AnnotationType)}."
174 | )
175 |
176 | if text is None and input_ids is None and samples is None:
177 | raise ValueError(
178 | "Either `text` or `input_ids` or `samples` must be provided."
179 | )
180 | if (input_ids is None and samples is None) and (
181 | text is None or candidates is None
182 | ):
183 | raise ValueError(
184 | "`text` and `candidates` must be provided to return the predictions when "
185 | "`input_ids` and `samples` is not provided."
186 | )
187 | if text is not None and samples is None:
188 | if len(text) != len(candidates):
189 | raise ValueError("`text` and `candidates` must have the same length.")
190 |             if isinstance(text[0], str):  # a single token sequence: wrap it into a batch of one
191 | text = [text]
192 | candidates = [candidates]
193 |
194 | samples = [
195 | WSLReaderSample(tokens=t, candidates=c)
196 | for t, c in zip(text, candidates)
197 | ]
198 |
199 | return self._read(
200 | samples,
201 | input_ids,
202 | attention_mask,
203 | token_type_ids,
204 | prediction_mask,
205 | special_symbols_mask,
206 | max_length,
207 | max_batch_size,
208 | token_batch_size,
209 | precision or self.precision,
210 | annotation_type,
211 | progress_bar,
212 | *args,
213 | **kwargs,
214 | )
215 |
216 | @property
217 | def device(self) -> torch.device:
218 | """
219 | The device of the model.
220 | """
221 | return next(self.parameters()).device
222 |
223 | @property
224 | def tokenizer(self) -> tr.PreTrainedTokenizer:
225 | """
226 | The tokenizer.
227 | """
228 | if self._tokenizer:
229 | return self._tokenizer
230 |
231 | self._tokenizer = tr.AutoTokenizer.from_pretrained(
232 | self.wsl_reader_model.config.name_or_path
233 | if self.wsl_reader_model.config.name_or_path
234 | else self.wsl_reader_model.config.transformer_model
235 | )
236 | return self._tokenizer
237 |
238 | @classmethod
239 | def from_pretrained(
240 | cls,
241 | model_name_or_dir: str | os.PathLike,
242 | **kwargs,
243 | ):
244 | transformer_model = AutoModel.from_pretrained(
245 | model_name_or_dir, trust_remote_code=True, **kwargs
246 | )
247 | if transformer_model.__class__.__name__ not in WSL_READER_CLASS_MAP:
248 | raise ValueError(f"Model type {type(transformer_model)} not recognized.")
249 |
250 | reader_class = WSL_READER_CLASS_MAP[transformer_model.__class__.__name__]
251 | reader_class = get_callable_from_string(reader_class)
252 | return reader_class(transformer_model=transformer_model, **kwargs)
253 |
254 | def save_pretrained(
255 | self,
256 | output_dir: str | os.PathLike,
257 | model_name: str | None = None,
258 | push_to_hub: bool = False,
259 | **kwargs,
260 | ) -> None:
261 | """
262 | Saves the model to the given path.
263 |
264 | Args:
265 | output_dir (`str` or :obj:`os.PathLike`):
266 | The path to save the model to.
267 |             model_name (`str`, `optional`):
268 |                 The name of the model. If not provided, the name of the output
269 |                 directory is used.
270 | push_to_hub (`bool`, `optional`, defaults to `False`):
271 | Whether to push the model to the HuggingFace Hub.
272 | **kwargs:
273 |                 Additional keyword arguments to pass to the `save_pretrained` method of the underlying model.
274 | """
275 | # create the output directory
276 | output_dir = Path(output_dir)
277 | output_dir.mkdir(parents=True, exist_ok=True)
278 |
279 | model_name = model_name or output_dir.name
280 |
281 | logger.info(f"Saving reader to {output_dir / model_name}")
282 |
283 | # save the model
284 | self.wsl_reader_model.config.register_for_auto_class()
285 | self.wsl_reader_model.register_for_auto_class()
286 | self.wsl_reader_model.save_pretrained(
287 | str(output_dir / model_name), push_to_hub=push_to_hub, **kwargs
288 | )
289 |
290 | if self.tokenizer:
291 | logger.info("Saving also the tokenizer")
292 | self.tokenizer.save_pretrained(
293 | str(output_dir / model_name), push_to_hub=push_to_hub, **kwargs
294 | )
295 |
--------------------------------------------------------------------------------
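For orientation, here is a minimal usage sketch of the `WSLReaderBase` API defined above. The checkpoint identifier, the example tokens, and the candidate sense strings are placeholders for illustration, not artifacts shipped with WSL:

```python
from wsl.reader.pytorch_modules.base import WSLReaderBase

# `from_pretrained` resolves the concrete reader class via WSL_READER_CLASS_MAP,
# so the returned object implements `_read` (the base class itself does not)
reader = WSLReaderBase.from_pretrained("path/to/a-wsl-reader-checkpoint")

# `read` expects pre-tokenized text plus one candidate list per sample
tokens = [["The", "bank", "raised", "interest", "rates", "."]]
candidates = [["bank -- financial institution", "bank -- sloping land"]]

samples = reader.read(text=tokens, candidates=candidates)
print(samples)  # a list of WSLReaderSample objects with the predicted spans
```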
/wsl/retriever/indexers/document.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import pickle
4 | import sys
5 | from pathlib import Path
6 | from typing import Dict, List, Union
7 |
8 | from wsl.common.log import get_logger
9 | from wsl.common.utils import JsonSerializable
10 |
11 | csv.field_size_limit(sys.maxsize)
12 |
13 | logger = get_logger(__name__)
14 |
15 |
16 | class Document:
17 | def __init__(
18 | self,
19 | text: str,
20 |         id: int | None = None,
21 |         metadata: Dict | None = None,
22 | **kwargs,
23 | ):
24 | self.text = text
25 | # if id is not provided, we use the hash of the text
26 | self.id = id if id is not None else hash(text)
27 | # if metadata is not provided, we use an empty dictionary
28 | self.metadata = metadata or {}
29 |
30 | def __str__(self):
31 | return f"{self.id}:{self.text}"
32 |
33 | def __repr__(self):
34 | return json.dumps(self.to_dict())
35 |
36 | def __eq__(self, other):
37 | if isinstance(other, Document):
38 | return self.id == other.id
39 | elif isinstance(other, int):
40 | return self.id == other
41 | elif isinstance(other, str):
42 | return self.text == other
43 | else:
44 | raise ValueError(
45 | f"Document must be compared with a Document, an int or a str, got `{type(other)}`"
46 | )
47 |
48 | def to_dict(self):
49 | return {"text": self.text, "id": self.id, "metadata": self.metadata}
50 |
51 | def to_json(self):
52 | return json.dumps(self.to_dict())
53 |
54 | @classmethod
55 | def from_dict(cls, d: Dict):
56 | return cls(**d)
57 |
58 | @classmethod
59 | def from_file(cls, file_path: Union[str, Path], **kwargs):
60 | with open(file_path, "r") as f:
61 | d = json.load(f)
62 | return cls.from_dict(d)
63 |
64 | def save(self, file_path: Union[str, Path], **kwargs):
65 | with open(file_path, "w") as f:
66 | json.dump(self.to_dict(), f, indent=2)
67 |
68 |
69 | class DocumentStore:
70 | """
71 | A document store is a collection of documents.
72 |
73 | Args:
74 | documents (:obj:`List[Document]`):
75 | The documents to store.
76 | """
77 |
78 |     def __init__(self, documents: List[Document] | None = None) -> None:
79 | if documents is None:
80 | documents = []
81 |         # if self.ignore_case:
82 | # documents = [doc.lower() for doc in documents]
83 | self._documents = documents
84 | # build an index for the documents
85 | self._documents_index = {doc.id: doc for doc in self._documents}
86 | # build a reverse index for the documents
87 | self._documents_reverse_index = {doc.text: doc for doc in self._documents}
88 |
89 | def __len__(self):
90 | return len(self._documents)
91 |
92 | def __getitem__(self, index):
93 | return self._documents[index]
94 |
95 | def __iter__(self):
96 | return iter(self._documents)
97 |
98 | def __contains__(self, item):
99 | if isinstance(item, int):
100 | return item in self._documents_index
101 | elif isinstance(item, str):
102 | return item in self._documents_reverse_index
103 | elif isinstance(item, Document):
104 | return item.id in self._documents_index
105 |         return False
106 |
107 | def __str__(self):
108 | return f"DocumentStore with {len(self)} documents"
109 |
110 | def __repr__(self):
111 | return self.__str__()
112 |
113 | def get_document_from_id(self, id: int) -> Document:
114 | """
115 | Retrieve a document by its ID.
116 |
117 | Args:
118 | id (`int`):
119 | The ID of the document to retrieve.
120 |
121 | Returns:
122 | Optional[Document]: The document with the given ID, or None if it does not exist.
123 | """
124 | if id not in self._documents_index:
125 | logger.warning(f"Document with id `{id}` does not exist, skipping")
126 | return self._documents_index.get(id, None)
127 |
128 | def get_document_from_text(self, text: str) -> Document:
129 | """
130 | Retrieve the document by its text.
131 |
132 | Args:
133 | text (`str`):
134 | The text of the document to retrieve.
135 |
136 | Returns:
137 | Optional[Document]: The document with the given text, or None if it does not exist.
138 | """
139 | if text not in self._documents_reverse_index:
140 | logger.warning(f"Document with text `{text}` does not exist, skipping")
141 | return self._documents_reverse_index.get(text, None)
142 |
143 | def get_document_from_index(self, index: int) -> Document:
144 | """
145 | Retrieve the document by its index.
146 |
147 | Args:
148 | index (`int`):
149 | The index of the document to retrieve.
150 |
151 | Returns:
152 | Optional[Document]: The document with the given index, or None if it does not exist.
153 | """
154 |         if index >= len(self._documents):
155 |             logger.warning(f"Document with index `{index}` does not exist, skipping")
156 |         return self._documents[index] if index < len(self._documents) else None
157 |
158 | def add_documents(
159 | self, documents: List[Document] | List[str] | List[Dict]
160 | ) -> List[Document]:
161 | """
162 | Add a list of documents to the document store.
163 |
164 | Args:
165 | documents (`List[Document]`):
166 | The documents to add.
167 |
168 | Returns:
169 | List[Document]: The documents just added.
170 | """
171 | return [
172 | (
173 | self.add_document(Document.from_dict(doc))
174 | if isinstance(doc, Dict)
175 | else self.add_document(doc)
176 | )
177 | for doc in documents
178 | ]
179 |
180 | def add_document(
181 | self,
182 | text: str | Document,
183 | id: int | None = None,
184 | metadata: Dict | None = None,
185 | ) -> Document:
186 | """
187 | Add a document to the document store.
188 |
189 | Args:
190 | text (`str`):
191 | The text of the document to add.
192 | id (`int`, optional, defaults to None):
193 | The ID of the document to add.
194 | metadata (`Dict`, optional, defaults to None):
195 | The metadata of the document to add.
196 |
197 | Returns:
198 | Document: The document just added.
199 | """
200 | if isinstance(text, str):
201 | # check if the document already exists
202 | if text in self:
203 | logger.warning(f"Document `{text}` already exists, skipping")
204 | return self._documents_reverse_index[text]
205 | if id is None:
206 | # get the len of the documents and add 1
207 | id = len(self._documents) # + 1
208 | text = Document(text, id, metadata)
209 |
210 | if text in self:
211 | logger.warning(f"Document `{text}` already exists, skipping")
212 | return self._documents_index[text.id]
213 |
214 | self._documents.append(text)
215 | self._documents_index[text.id] = text
216 | self._documents_reverse_index[text.text] = text
217 | return text
218 | # if id in self._documents_index:
219 | # logger.warning(f"Document with id `{id}` already exists, skipping")
220 | # return self._documents_index[id]
221 | # if text_or_document in self._documents_reverse_index:
222 | # logger.warning(f"Document with text `{text_or_document}` already exists, skipping")
223 | # return self._documents_reverse_index[text_or_document]
224 | # self._documents.append(Document(text_or_document, id, metadata))
225 | # self._documents_index[id] = self._documents[-1]
226 | # self._documents_reverse_index[text_or_document] = self._documents[-1]
227 | # return self._documents_index[id]
228 |
229 | def delete_document(self, document: int | str | Document) -> bool:
230 | """
231 | Delete a document from the document store.
232 |
233 | Args:
234 | document (`int`, `str` or `Document`):
235 | The document to delete.
236 |
237 | Returns:
238 | bool: True if the document has been deleted, False otherwise.
239 | """
240 | if isinstance(document, int):
241 | return self.delete_by_id(document)
242 | elif isinstance(document, str):
243 | return self.delete_by_text(document)
244 | elif isinstance(document, Document):
245 | return self.delete_by_document(document)
246 | else:
247 | raise ValueError(
248 | f"Document must be an int, a str or a Document, got `{type(document)}`"
249 | )
250 |
251 | def delete_by_id(self, id: int) -> bool:
252 | """
253 | Delete a document by its ID.
254 |
255 | Args:
256 | id (`int`):
257 | The ID of the document to delete.
258 |
259 | Returns:
260 | bool: True if the document has been deleted, False otherwise.
261 | """
262 | if id not in self._documents_index:
263 | logger.warning(f"Document with id `{id}` does not exist, skipping")
264 | return False
265 |         del self._documents_reverse_index[self._documents_index[id].text]
266 | del self._documents_index[id]
267 | return True
268 |
269 | def delete_by_text(self, text: str) -> bool:
270 | """
271 | Delete a document by its text.
272 |
273 | Args:
274 | text (`str`):
275 | The text of the document to delete.
276 |
277 | Returns:
278 | bool: True if the document has been deleted, False otherwise.
279 | """
280 | if text not in self._documents_reverse_index:
281 | logger.warning(f"Document with text `{text}` does not exist, skipping")
282 | return False
283 |         document = self._documents_reverse_index.pop(text)
284 |         del self._documents_index[document.id]
285 | return True
286 |
287 | def delete_by_document(self, document: Document) -> bool:
288 | """
289 |         Delete a document given its Document object.
290 |
291 | Args:
292 | document (:obj:`Document`):
293 | The document to delete.
294 |
295 | Returns:
296 | bool: True if the document has been deleted, False otherwise.
297 | """
298 | if document.id not in self._documents_index:
299 | logger.warning(f"Document {document} does not exist, skipping")
300 | return False
301 |         self._documents.remove(document)
302 |         del self._documents_reverse_index[document.text]
303 |         del self._documents_index[document.id]
304 |         return True
305 | def to_dict(self):
306 | return [doc.to_dict() for doc in self._documents]
307 |
308 | @classmethod
309 | def from_dict(cls, d):
310 | return cls([Document.from_dict(doc) for doc in d])
311 |
312 | @classmethod
313 | def from_file(cls, file_path: Union[str, Path], **kwargs):
314 | with open(file_path, "r") as f:
315 | # load a json lines file
316 | d = [Document.from_dict(json.loads(line)) for line in f]
317 | return cls(d)
318 |
319 | @classmethod
320 | def from_pickle(cls, file_path: Union[str, Path], **kwargs):
321 | with open(file_path, "rb") as handle:
322 | d = pickle.load(handle)
323 | return cls(d)
324 |
325 | @classmethod
326 | def from_tsv(
327 | cls,
328 | file_path: Union[str, Path],
329 |         ignore_case: bool = False,
330 | delimiter: str = "\t",
331 | **kwargs,
332 | ):
333 | d = []
334 | # load a tsv/csv file and take the header into account
335 | # the header must be `id\ttext\t[list of metadata keys]`
336 | with open(file_path, "r", encoding="utf8") as f:
337 | csv_reader = csv.reader(f, delimiter=delimiter, **kwargs)
338 | header = next(csv_reader)
339 | id, text, *metadata_keys = header
340 | for i, row in enumerate(csv_reader):
341 | # check if id can be casted to int
342 | # if not, we add it to the metadata and use `i` as id
343 | try:
344 | s_id = int(row[header.index(id)])
345 | row_metadata_keys = metadata_keys
346 | except ValueError:
347 | row_metadata_keys = [id] + metadata_keys
348 | s_id = i
349 |
350 | d.append(
351 | Document(
352 | text=(
353 | row[header.index(text)].strip().lower()
354 |                         if ignore_case
355 | else row[header.index(text)].strip()
356 | ),
357 | id=s_id, # row[header.index(id)],
358 | metadata={
359 | key: row[header.index(key)] for key in row_metadata_keys
360 | },
361 | )
362 | )
363 | return cls(d)
364 |
365 | def save(self, file_path: Union[str, Path], **kwargs):
366 | with open(file_path, "w") as f:
367 | for doc in self._documents:
368 | # save as json lines
369 | f.write(json.dumps(doc.to_dict()) + "\n")
370 |
--------------------------------------------------------------------------------
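A short, self-contained sketch of how `Document` and `DocumentStore` interact; the text and metadata below are illustrative only:

```python
from wsl.retriever.indexers.document import Document, DocumentStore

store = DocumentStore()
doc = store.add_document("bank -- a financial institution", metadata={"pos": "NOUN"})

# membership checks work by id, by text, or by Document
assert doc.id in store and doc.text in store and doc in store

# the three lookup methods mirror the three internal indexes
assert store.get_document_from_id(doc.id) is doc
assert store.get_document_from_text(doc.text) is doc
assert store.get_document_from_index(0) is doc

store.delete_document(doc)
print(len(store))  # 0
```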
/wsl/retriever/indexers/inmemory.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import logging
3 | import os
4 | from typing import Callable, List, Optional, Union
5 |
6 | import torch
7 | import transformers as tr
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 |
11 | from wsl.common.log import get_logger
12 | from wsl.common.torch_utils import get_autocast_context
13 | from wsl.retriever.common.model_inputs import ModelInputs
14 | from wsl.retriever.data.base.datasets import BaseDataset
15 | from wsl.retriever.indexers.base import BaseDocumentIndex
16 | from wsl.retriever.indexers.document import Document, DocumentStore
17 | from wsl.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
18 |
19 | # check if ORT is available
20 | # if is_package_available("onnxruntime"):
21 |
22 | logger = get_logger(__name__, level=logging.INFO)
23 |
24 |
25 | class MatrixMultiplicationModule(torch.nn.Module):
26 | def __init__(self, embeddings):
27 | super().__init__()
28 | self.embeddings = torch.nn.Parameter(embeddings, requires_grad=False)
29 |
30 | def forward(self, query):
31 | return torch.matmul(query, self.embeddings.T)
32 |
33 |
34 | class InMemoryDocumentIndex(BaseDocumentIndex):
35 | def __init__(
36 | self,
37 | documents: Union[str, List[str], List[Document]] = None,
38 | embeddings: Optional[torch.Tensor] = None,
39 | metadata_fields: Optional[List[str]] = None,
40 | separator: Optional[str] = None,
41 | name_or_path: Union[str, os.PathLike, None] = None,
42 | device: str = "cpu",
43 | precision: Union[str, int, torch.dtype] = 32,
44 | *args,
45 | **kwargs,
46 | ) -> None:
47 | """
48 | An in-memory indexer based on PyTorch.
49 |
50 | Args:
51 |             documents (:obj:`Union[str, List[str], List[Document]]`):
52 | The documents to be indexed.
53 | embeddings (:obj:`Optional[torch.Tensor]`, `optional`, defaults to :obj:`None`):
54 | The embeddings of the documents.
55 | device (:obj:`str`, `optional`, defaults to "cpu"):
56 | The device to be used for storing the embeddings.
57 | """
58 |
59 | super().__init__(
60 | documents, embeddings, metadata_fields, separator, name_or_path, device
61 | )
62 |
63 | if embeddings is not None and documents is not None:
64 | logger.info("Both documents and embeddings are provided.")
65 | if len(documents) != embeddings.shape[0]:
66 | raise ValueError(
67 | "The number of documents and embeddings must be the same. "
68 | f"Got {len(documents)} documents and {embeddings.shape[0]} embeddings."
69 | )
70 |
71 |         # the parent constructor has already stored the embeddings as
72 |         # `self.embeddings`, so drop the local reference to avoid keeping
73 |         # a second handle on a potentially large tensor
74 |         del embeddings
75 | # convert the embeddings to the desired precision
76 | if precision is not None:
77 | if self.embeddings is not None and self.device == "cpu":
78 | if PRECISION_MAP[precision] == PRECISION_MAP[16]:
79 | logger.info(
80 | f"Precision `{precision}` is not supported on CPU. "
81 | f"Using `{PRECISION_MAP[32]}` instead."
82 | )
83 | precision = 32
84 |
85 | if (
86 | self.embeddings is not None
87 | and self.embeddings.dtype != PRECISION_MAP[precision]
88 | ):
89 | logger.info(
90 | f"Index vectors are of type {self.embeddings.dtype}. "
91 | f"Converting to {PRECISION_MAP[precision]}."
92 | )
93 | self.embeddings = self.embeddings.to(PRECISION_MAP[precision])
94 | else:
95 | # TODO: a bit redundant, fix this eventually
96 | if (
97 | # here we trust the device_in_init, since we don't know yet
98 | # the device of the embeddings
99 | (
100 | self.device_in_init == "cpu"
101 | or self.device_in_init == torch.device("cpu")
102 | )
103 | and self.embeddings is not None
104 | and self.embeddings.dtype != torch.float32
105 | ):
106 | logger.info(
107 | f"Index vectors are of type {self.embeddings.dtype} but the device is CPU. "
108 | f"Converting to {PRECISION_MAP[32]}."
109 | )
110 | self.embeddings = self.embeddings.to(PRECISION_MAP[32])
111 |
112 | # move the embeddings to the desired device
113 | if (
114 | self.embeddings is not None
115 | and not self.embeddings.device == self.device_in_init
116 | ):
117 | self.embeddings = self.embeddings.to(self.device_in_init)
118 |
119 | # TODO: check interactions with the embeddings
120 | # self.mm = MatrixMultiplicationModule(embeddings=self.embeddings)
121 | # self.mm.eval()
122 |
123 | # precision to be used for the embeddings
124 | self.precision = precision
125 |
126 | @torch.no_grad()
127 | @torch.inference_mode()
128 | def index(
129 | self,
130 | retriever,
131 | documents: Optional[List[Document]] = None,
132 | batch_size: int = 32,
133 | num_workers: int = 4,
134 | max_length: Optional[int] = None,
135 | collate_fn: Optional[Callable] = None,
136 | encoder_precision: Optional[Union[str, int]] = None,
137 | compute_on_cpu: bool = False,
138 | force_reindex: bool = False,
139 | ) -> "InMemoryDocumentIndex":
140 | """
141 | Index the documents using the encoder.
142 |
143 | Args:
144 |             retriever (:obj:`torch.nn.Module`):
145 |                 The retriever whose passage encoder is used for indexing.
146 | documents (:obj:`List[Document]`, `optional`, defaults to :obj:`None`):
147 | The documents to be indexed. If not provided, the documents provided at the initialization will be used.
148 | batch_size (:obj:`int`, `optional`, defaults to 32):
149 | The batch size to be used for indexing.
150 | num_workers (:obj:`int`, `optional`, defaults to 4):
151 | The number of workers to be used for indexing.
152 | max_length (:obj:`int`, `optional`, defaults to None):
153 | The maximum length of the input to the encoder.
154 | collate_fn (:obj:`Callable`, `optional`, defaults to None):
155 | The collate function to be used for batching.
156 | encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None):
157 | The precision to be used for the encoder.
158 | compute_on_cpu (:obj:`bool`, `optional`, defaults to False):
159 | Whether to compute the embeddings on CPU.
160 | force_reindex (:obj:`bool`, `optional`, defaults to False):
161 | Whether to force reindexing.
162 |
163 | Returns:
164 |             :obj:`InMemoryDocumentIndex`: The indexer object.
165 | """
166 |
167 | if documents is None and self.documents is None:
168 | raise ValueError("Documents must be provided.")
169 |
170 | if self.embeddings is not None and not force_reindex and documents is None:
171 | logger.info(
172 | "Embeddings are already present and `force_reindex` is `False`. Skipping indexing."
173 | )
174 | return self
175 |
176 | if force_reindex:
177 | if documents is not None:
178 | self.documents.add_documents(documents)
179 |             data = list(self.get_passages())
180 |
181 | else:
182 | if documents is not None:
183 |                 data = list(self.get_passages(DocumentStore(documents)))
184 | # add the documents to the actual document store
185 | self.documents.add_documents(documents)
186 | else:
187 | if self.embeddings is None:
188 |                     data = list(self.get_passages())
189 |
190 | if collate_fn is None:
191 | tokenizer = retriever.passage_tokenizer
192 |
193 | def collate_fn(x):
194 | return ModelInputs(
195 | tokenizer(
196 | x,
197 | padding=True,
198 | return_tensors="pt",
199 | truncation=True,
200 | max_length=max_length or tokenizer.model_max_length,
201 | )
202 | )
203 |
204 |         # prepend the retriever's passage prefix to every passage
205 | data = [retriever.passage_prefix + p for p in data]
206 | dataloader = DataLoader(
207 | BaseDataset(name="passage", data=data),
208 | batch_size=batch_size,
209 | shuffle=False,
210 | num_workers=num_workers,
211 | pin_memory=False,
212 | collate_fn=collate_fn,
213 | )
214 |
215 | encoder = retriever.passage_encoder
216 |
217 | # Create empty lists to store the passage embeddings and passage index
218 | passage_embeddings: List[torch.Tensor] = []
219 |
220 | encoder_device = "cpu" if compute_on_cpu else encoder.device
221 |
222 |         # torch.autocast only accepts bare device types such as "cpu" or "cuda",
223 |         # so we strip any device index from the model device
224 |         device_type_for_autocast = str(encoder_device).split(":")[0]
225 |         # autocast on CPU supports only bfloat16, so we skip it there
226 | autocast_pssg_mngr = (
227 | contextlib.nullcontext()
228 | if device_type_for_autocast == "cpu"
229 | else (
230 | torch.autocast(
231 | device_type=device_type_for_autocast,
232 | dtype=PRECISION_MAP[encoder_precision],
233 | )
234 | )
235 | )
236 | with autocast_pssg_mngr:
237 | # Iterate through each batch in the dataloader
238 | for batch in tqdm(dataloader, desc="Indexing"):
239 | # Move the batch to the device
240 | batch: ModelInputs = batch.to(encoder_device)
241 | # Compute the passage embeddings
242 | passage_outs = encoder(**batch).pooler_output
243 | # Append the passage embeddings to the list
244 | if self.device == "cpu":
245 | passage_embeddings.extend([c.detach().cpu() for c in passage_outs])
246 | else:
247 | passage_embeddings.extend([c for c in passage_outs])
248 |
249 | # move the passage embeddings to the CPU if not already done
250 | # the move to cpu and then to gpu is needed to avoid OOM when using mixed precision
251 | if not self.device == "cpu": # this if is to avoid unnecessary moves
252 | passage_embeddings = [c.detach().cpu() for c in passage_embeddings]
253 | # stack it
254 | passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0)
255 | # move the passage embeddings to the gpu if needed
256 | if not self.device == "cpu":
257 | passage_embeddings = passage_embeddings.to(PRECISION_MAP[self.precision])
258 | passage_embeddings = passage_embeddings.to(self.device)
259 | self.embeddings = passage_embeddings
260 | # update the matrix multiplication module
261 | # self.mm = MatrixMultiplicationModule(embeddings=self.embeddings)
262 |
263 | # free up memory from the unused variable
264 | del passage_embeddings
265 |
266 | return self
267 |
268 | @torch.no_grad()
269 | @torch.inference_mode()
270 |     def search(self, query: torch.Tensor, k: int = 1) -> List[List[RetrievedSample]]:
271 | """
272 | Search the documents using the query.
273 |
274 | Args:
275 | query (:obj:`torch.Tensor`):
276 | The query to be used for searching.
277 | k (:obj:`int`, `optional`, defaults to 1):
278 | The number of documents to be retrieved.
279 |
280 | Returns:
281 | :obj:`List[RetrievedSample]`: The retrieved documents.
282 | """
283 |
284 | with get_autocast_context(self.device, self.embeddings.dtype):
285 | # move query to the same device as embeddings
286 | query = query.to(self.embeddings.device)
287 | if query.dtype != self.embeddings.dtype:
288 | query = query.to(self.embeddings.dtype)
289 | similarity = torch.matmul(query, self.embeddings.T)
290 | # similarity = self.mm(query)
291 | # Retrieve the indices of the top k passage embeddings
292 | retriever_out: torch.return_types.topk = torch.topk(
293 | similarity, k=min(k, similarity.shape[-1]), dim=1
294 | )
295 |
296 | # get int values
297 | batch_top_k: List[List[int]] = retriever_out.indices.detach().cpu().tolist()
298 | # get float values
299 | batch_scores: List[List[float]] = retriever_out.values.detach().cpu().tolist()
300 | # Retrieve the passages corresponding to the indices
301 | batch_docs = [
302 | [self.documents.get_document_from_index(i) for i in indices]
303 | for indices in batch_top_k
304 | ]
305 | # build the output object
306 | batch_retrieved_samples = [
307 | [
308 | RetrievedSample(document=doc, score=score)
309 | for doc, score in zip(docs, scores)
310 | ]
311 | for docs, scores in zip(batch_docs, batch_scores)
312 | ]
313 | return batch_retrieved_samples
314 |
--------------------------------------------------------------------------------
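To make the search path concrete, a toy sketch with random vectors standing in for real encoder outputs; normally `index()` populates `self.embeddings` through the retriever's passage encoder, and the embedding size below is made up:

```python
import torch

from wsl.retriever.indexers.document import Document
from wsl.retriever.indexers.inmemory import InMemoryDocumentIndex

# two fake 4-dimensional passage embeddings, one row per document
docs = [Document("sense A", id=0), Document("sense B", id=1)]
embeddings = torch.randn(2, 4)

index = InMemoryDocumentIndex(documents=docs, embeddings=embeddings, device="cpu")

# a single fake query vector; in practice it comes from the query encoder
query = torch.randn(1, 4)
for sample in index.search(query, k=2)[0]:
    print(sample.document.text, sample.score)
```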
/wsl/inference/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | import hydra
4 | import torch
5 | from omegaconf import DictConfig, OmegaConf
6 |
7 | from wsl.common.log import get_logger
8 | from wsl.inference.data.objects import TaskType
9 | from wsl.reader.pytorch_modules.base import WSLReaderBase
10 | from wsl.retriever.indexers.base import BaseDocumentIndex
11 | from wsl.retriever.pytorch_modules import PRECISION_MAP
12 | from wsl.retriever.pytorch_modules.model import WSLRetriever
13 |
14 | logger = get_logger(__name__)
15 |
16 |
17 | def _instantiate_retriever(
18 | retriever: WSLRetriever | DictConfig | Dict,
19 | device: str | None,
20 | precision: int | str | torch.dtype | None,
21 | **kwargs: Any,
22 | ) -> WSLRetriever:
23 | """
24 | Instantiate a retriever.
25 |
26 | Args:
27 |         retriever (`WSLRetriever`, `DictConfig` or `Dict`):
28 |             The retriever to instantiate.
29 |         device (`str`, `optional`):
30 |             The device to use for the retriever.
31 |         precision (`int`, `str` or `torch.dtype`, `optional`):
32 |             The precision to use for the retriever.
33 |         **kwargs (`Dict[str, Any]`, `optional`):
34 |             Additional keyword arguments to pass to the retriever.
35 |
36 |     Returns:
37 |         `WSLRetriever`:
38 |             The instantiated retriever.
39 | """
40 | if not isinstance(retriever, WSLRetriever):
41 | # convert to DictConfig
42 | retriever = hydra.utils.instantiate(
43 | OmegaConf.create(retriever),
44 | device=device,
45 | precision=precision,
46 | **kwargs,
47 | )
48 | else:
49 | if device is not None:
50 | logger.info(f"Moving retriever to `{device}`.")
51 | retriever.to(device)
52 | if precision is not None:
53 | logger.info(
54 | f"Setting precision of retriever to `{PRECISION_MAP[precision]}`."
55 | )
56 | retriever.to(PRECISION_MAP[precision])
57 | retriever.training = False
58 | retriever.eval()
59 | return retriever
60 |
61 |
62 | def load_retriever(
63 | retriever: WSLRetriever | DictConfig | Dict | str,
64 | device: str | None | torch.device | int = None,
65 | precision: int | str | torch.dtype | None = None,
66 | task: TaskType | str | None = None,
67 | compile: bool = False,
68 | **kwargs,
69 | ) -> Dict[TaskType, WSLRetriever]:
70 | """
71 | Load and instantiate retrievers for a given task.
72 |
73 | Args:
74 |         retriever (WSLRetriever | DictConfig | Dict | str):
75 | The retriever object or configuration.
76 | device (str | None):
77 | The device to load the retriever on.
78 | precision (int | str | torch.dtype | None):
79 | The precision of the retriever.
80 | task (TaskType):
81 | The task type for the retriever.
82 | compile (bool, optional):
83 | Whether to compile the retriever. Defaults to False.
84 | **kwargs:
85 | Additional keyword arguments to be passed to the retriever instantiation.
86 |
87 | Returns:
88 |         Dict[TaskType, WSLRetriever]:
89 | A dictionary containing the instantiated retrievers.
90 |
91 | Raises:
92 |         ValueError: If the `retriever` argument is not of type `WSLRetriever`, `DictConfig`, `Dict`, or `str`.
93 | ValueError: If the `retriever` argument is a `DictConfig` without the `_target_` key.
94 | ValueError: If the task type is not valid for each retriever in the `DictConfig`.
95 |
96 | """
97 |
98 | # retriever section
99 | _retriever: Dict[TaskType, WSLRetriever] = {
100 | TaskType.SPAN: None,
101 | }
102 |
103 |     # check retriever type: it can be a WSLRetriever, a DictConfig, a Dict or a str
104 | if not isinstance(retriever, (WSLRetriever, DictConfig, Dict, str)):
105 | raise ValueError(
106 | f"`retriever` must be a `GoldenRetriever`, a `DictConfig`, "
107 | f"a `Dict`, or a `str`, got `{type(retriever)}`."
108 | )
109 | if isinstance(retriever, str):
110 | logger.warning(
111 | "Using a string to instantiate the retriever. "
112 | f"We will use the same model `{retriever}` for both query and passage encoder. "
113 | "If you want to use different models, please provide a dictionary with keys `question_encoder` and `passage_encoder`."
114 | )
115 | retriever = {"question_encoder": retriever}
116 |     # we need to check whether the DictConfig describes an instance of WSLRetriever
117 | # or a primitive Dict
118 | if isinstance(retriever, DictConfig):
119 | # then it is probably a primitive Dict
120 | if "_target_" not in retriever:
121 | retriever = OmegaConf.to_container(retriever, resolve=True)
122 | # convert the key to TaskType
123 | try:
124 | retriever = {TaskType(k.lower()): v for k, v in retriever.items()}
125 | except ValueError as e:
126 | raise ValueError(
127 | f"Please choose a valid task type (one of {list(TaskType)}) for each retriever."
128 | ) from e
129 |
130 | if isinstance(retriever, Dict):
131 | # convert the key to TaskType
132 | retriever = {TaskType(k): v for k, v in retriever.items()}
133 | else:
134 | retriever = {task: retriever}
135 |
136 | # instantiate each retriever
137 | if task in [TaskType.SPAN, TaskType.BOTH]:
138 | _retriever[TaskType.SPAN] = _instantiate_retriever(
139 | retriever[TaskType.SPAN],
140 | device,
141 | precision,
142 | **kwargs,
143 | )
144 |
145 | # clean up None retrievers from the dictionary
146 | _retriever = {task_type: r for task_type, r in _retriever.items() if r is not None}
147 | if compile:
148 | # torch compile
149 | _retriever = {
150 | task_type: torch.compile(r) for task_type, r in _retriever.items()
151 | }
152 |
153 | return _retriever
154 |
155 |
156 | def _instantiate_index(
157 | index: BaseDocumentIndex | DictConfig | Dict,
158 | device: str | None | torch.device | int = None,
159 | precision: int | str | torch.dtype | None = None,
160 | **kwargs: Dict[str, Any],
161 | ) -> BaseDocumentIndex:
162 | """
163 | Instantiate a document index.
164 |
165 | Args:
166 | index (`BaseDocumentIndex`, `DictConfig` or `Dict`):
167 | The document index to instantiate.
168 | device (`str`, `optional`):
169 | The device to use for the document index.
170 | precision (`int`, `str` or `torch.dtype`, `optional`):
171 | The precision to use for the document index.
172 | kwargs (`Dict[str, Any]`, `optional`):
173 | Additional keyword arguments to pass to the document index.
174 |
175 | Returns:
176 | `BaseDocumentIndex`:
177 | The instantiated document index.
178 | """
179 | if not isinstance(index, BaseDocumentIndex):
180 | index = OmegaConf.create(index)
181 | use_faiss = kwargs.get("use_faiss", False)
182 | if use_faiss:
183 | index = OmegaConf.merge(
184 | index,
185 | {
186 | "_target_": "wsl.retriever.indexers.faissindex.FaissDocumentIndex.from_pretrained",
187 | },
188 | )
189 | if device is not None:
190 | kwargs["device"] = device
191 | if precision is not None:
192 | kwargs["precision"] = precision
193 |
194 | # merge the kwargs
195 | index = OmegaConf.merge(index, OmegaConf.create(kwargs))
196 | index: BaseDocumentIndex = hydra.utils.instantiate(index)
197 | else:
198 |         # the index is already instantiated; just move it to the right device/precision
199 | if device is not None:
200 | logger.info(f"Moving index to `{device}`.")
201 | index.to(device)
202 | if precision is not None:
203 | logger.info(f"Setting precision of index to `{PRECISION_MAP[precision]}`.")
204 | index.to(PRECISION_MAP[precision])
205 | return index
206 |
207 |
208 | def load_index(
209 | index: BaseDocumentIndex | DictConfig | Dict | str,
210 | device: str | None,
211 | precision: int | str | torch.dtype | None,
212 | task: TaskType,
213 | **kwargs,
214 | ) -> Dict[TaskType, BaseDocumentIndex]:
215 | """
216 | Load the document index based on the specified parameters.
217 |
218 | Args:
219 |         index (BaseDocumentIndex | DictConfig | Dict | str):
220 |             The document index to load. It can be an instance of `BaseDocumentIndex`, a `DictConfig`, a `Dict`, or a `str`.
221 | device (str | None):
222 | The device to use for loading the index. If `None`, the default device will be used.
223 | precision (int | str | torch.dtype | None):
224 | The precision of the index. If `None`, the default precision will be used.
225 | task (TaskType):
226 | The type of task for the index.
227 | **kwargs:
228 | Additional keyword arguments to be passed to the index instantiation.
229 |
230 | Returns:
231 | Dict[TaskType, BaseDocumentIndex]:
232 | A dictionary containing the loaded document index for each task type.
233 |
234 | Raises:
235 | ValueError: If the `index` parameter is not of type `BaseDocumentIndex`, `DictConfig`, or `Dict`.
236 | ValueError: If the `index` parameter is a `DictConfig` without a `_target_` key.
237 | ValueError: If the task type specified in the `index` parameter is not valid.
238 | """
239 |
240 | # index
241 | _index: Dict[TaskType, BaseDocumentIndex] = {
242 | TaskType.SPAN: None,
243 | }
244 |
245 |     # check index type: it can be a BaseDocumentIndex, a DictConfig, a Dict or a str
246 | if not isinstance(index, (BaseDocumentIndex, DictConfig, Dict, str)):
247 | raise ValueError(
248 | f"`index` must be a `BaseDocumentIndex`, a `DictConfig`, "
249 | f"a `Dict`, or a `str`, got `{type(index)}`."
250 | )
251 |     # we need to check whether the DictConfig describes an instance of BaseDocumentIndex
252 | # or a primitive Dict
253 | if isinstance(index, str):
254 | index = {"name_or_path": index}
255 | if isinstance(index, DictConfig):
256 | # then it is probably a primitive Dict
257 | if "_target_" not in index:
258 | index = OmegaConf.to_container(index, resolve=True)
259 | # convert the key to TaskType
260 | try:
261 | index = {TaskType(k.lower()): v for k, v in index.items()}
262 | except ValueError as e:
263 | raise ValueError(
264 | f"Please choose a valid task type (one of {list(TaskType)}) for each index."
265 | ) from e
266 |
267 | if isinstance(index, Dict):
268 | # convert the key to TaskType
269 | index = {TaskType(k): v for k, v in index.items()}
270 | else:
271 | index = {task: index}
272 |
273 |     # instantiate each index
274 | if task in [TaskType.SPAN, TaskType.BOTH]:
275 | _index[TaskType.SPAN] = _instantiate_index(
276 | index[TaskType.SPAN],
277 | device,
278 | precision,
279 | **kwargs,
280 | )
281 |
282 |     # clean up None indexes from the dictionary
283 | _index = {task_type: i for task_type, i in _index.items() if i is not None}
284 | return _index
285 |
286 |
287 | def load_reader(
288 | reader: WSLReaderBase,
289 | device: str | None,
290 | precision: int | str | torch.dtype | None,
291 | compile: bool = False,
292 | **kwargs: Dict[str, Any],
293 | ) -> WSLReaderBase:
294 | """
295 | Load a reader model for inference.
296 |
297 | Args:
298 | reader (WSLReaderBase):
299 | The reader model to load.
300 | device (str | None):
301 | The device to move the reader model to.
302 | precision (int | str | torch.dtype | None):
303 | The precision to set for the reader model.
304 | compile (bool, optional):
305 | Whether to compile the reader model. Defaults to False.
306 | **kwargs (Dict[str, Any]):
307 | Additional keyword arguments to pass to the reader model.
308 |
309 | Returns:
310 | WSLReaderBase: The loaded reader model.
311 | """
312 |
313 | if not isinstance(reader, (WSLReaderBase, DictConfig, Dict, str)):
314 | raise ValueError(
315 | f"`reader` must be a `WSLReaderBase`, a `DictConfig`, "
316 | f"a `Dict`, or a `str`, got `{type(reader)}`."
317 | )
318 |
319 | if isinstance(reader, str):
320 | reader = {
321 | "_target_": "wsl.reader.pytorch_modules.base.WSLReaderBase.from_pretrained",
322 | "model_name_or_dir": reader,
323 | }
324 |
325 |     if isinstance(reader, Dict):
326 |         # a primitive Dict (e.g. built from the string branch above) is
327 |         # converted to a DictConfig so that hydra can instantiate it;
328 |         # an already-instantiated WSLReaderBase or DictConfig passes
329 |         # through unchanged and is handled by the instantiation logic
330 |         # below
331 |         reader = OmegaConf.create(reader)
332 |
333 | reader = (
334 | hydra.utils.instantiate(
335 | reader,
336 | device=device,
337 | precision=precision,
338 | **kwargs,
339 | )
340 | if isinstance(reader, DictConfig)
341 | else reader
342 | )
343 | reader.training = False
344 | reader.eval()
345 | if device is not None:
346 | logger.info(f"Moving reader to `{device}`.")
347 | reader.to(device)
348 |     if precision is not None and PRECISION_MAP[reader.precision] != PRECISION_MAP[precision]:
349 | logger.info(f"Setting precision of reader to `{PRECISION_MAP[precision]}`.")
350 | reader.to(PRECISION_MAP[precision])
351 |
352 | if compile:
353 | reader = torch.compile(reader)
354 | return reader
355 |
--------------------------------------------------------------------------------
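Tying the loaders together, a hedged sketch of end-to-end loading. The checkpoint names are placeholders, and the assumption that `WSLRetriever` accepts a `question_encoder` argument is inferred from the warning message in `load_retriever` above, not confirmed elsewhere in this dump:

```python
from wsl.inference.data.objects import TaskType
from wsl.inference.utils import load_reader, load_retriever

# placeholder identifiers; substitute real WSL checkpoints
reader = load_reader("org/wsl-reader", device="cpu", precision=32)

retrievers = load_retriever(
    {
        "span": {  # keys are converted to TaskType internally
            "_target_": "wsl.retriever.pytorch_modules.model.WSLRetriever",
            "question_encoder": "org/wsl-retriever-question-encoder",
        }
    },
    device="cpu",
    precision=32,
    task=TaskType.SPAN,
)
span_retriever = retrievers[TaskType.SPAN]
```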
/wsl_data_license.txt:
--------------------------------------------------------------------------------
1 | WSL Dataset Non-Commercial license
2 |
3 | 1. Definitions
4 |
5 | "Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original.
6 | "Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or other works or subject matter which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined above) for the purposes of this License.
7 | "Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership.
8 | "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, Noncommercial, ShareAlike.
9 | "Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License.
10 | "Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i)in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast.
11 | "Work" means the word-sense disambiguated part of SemCor provided by Babelscape under this license.
12 | "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation.
13 | "Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images.
14 | "Reproduce" means to make digital or paper copies of the Work by any means including without limitation by textual, sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage in digital form or other electronic medium.
15 |
16 | 2. Fair Dealing Rights
17 |
18 | Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws.
19 | 3. License Grant
20 |
21 | Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below:
22 |
23 | to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections provided the Work is clearly identifiable, linked to https://github.com/Babelscape/WSL and made accessible only to research institutions;
24 | to create, Reproduce, Distribute and Publicly Perform Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, an Adaptation could be marked "This data is a processed version of the WSL dataset downloaded from https://huggingface.co/datasets/Babelscape/wsl, made available by Babelscape with the WSL Non-Commercial License" and made accessible only to research institutions. Alternatively, a link to the WSL website can be provided for download of the official data and code for the creation of the Adaptation can be provided to any user.
25 |
26 | The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved, including but not limited to the rights described in Section 4(e).
27 | 4. Restrictions
28 |
29 | The license granted in Section 3 above is expressly made subject to and limited by the following restrictions:
30 |
31 | You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) https://github.com/Babelscape/WSL/wsl_data_license.txt for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(d), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(d), as requested.
32 | You may Distribute or Publicly Perform an Adaptation only under: (i) the terms of this License; (ii) a later version of this License with the same License Elements as this License. You must include a copy of, or the URI https://github.com/Babelscape/WSL/wsl_data_license.txt, for Applicable License with every copy of each Adaptation You Distribute or Publicly Perform. You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License. You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License.
33 | You may not exercise any of the rights granted to You in Section 3 above if You are not a research institution or if in any manner that is primarily intended for or directed toward commercial advantage or private monetary compensation.
34 | If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work (Word Sense Linking); (iii) the URI (https://huggingface.co/datasets/Babelscape/wsl); and, (iv) consistent with Section 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "This data is a processed version of the WSL dataset downloaded from https://huggingface.co/datasets/Babelscape/wsl, made available by Babelscape with the WSL Non-Commercial License"). The credit required by this Section 4(d) may be implemented in any reasonable manner; provided, however, that in the case of a Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties.
35 |
36 | For the avoidance of doubt:
37 | Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License;
38 | Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License if Your exercise of such rights is for a purpose or use which is otherwise than noncommercial as permitted under Section 4(c) and otherwise waives the right to collect royalties through any statutory or compulsory licensing scheme; and,
39 | Voluntary License Schemes. The Licensor reserves the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License that is for a purpose or use which is otherwise than noncommercial as permitted under Section 4(c).
40 | Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise.
41 |
42 | 5. Representations, Warranties and Disclaimer
43 |
44 | UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING AND TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO THIS EXCLUSION MAY NOT APPLY TO YOU.
45 | 6. Limitation on Liability
46 |
47 | EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
48 | 7. Termination
49 |
50 | This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License.
51 | Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above.
52 |
53 | 8. Miscellaneous
54 |
55 | Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License.
56 | Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License.
57 | If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable.
58 | No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent.
59 | This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You.
60 | The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law.
61 |
--------------------------------------------------------------------------------
/wsl/reader/pytorch_modules/span.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import logging
3 | from typing import Any, Dict, Iterator, List
4 |
5 | import torch
6 | import transformers as tr
7 | from lightning_fabric.utilities import move_data_to_device
8 | from torch.utils.data import DataLoader, IterableDataset
9 | from tqdm import tqdm
10 |
11 | from wsl.common.log import get_logger
12 | from wsl.common.torch_utils import get_autocast_context
13 | from wsl.common.utils import get_callable_from_string
14 | from wsl.inference.data.objects import AnnotationType
15 | from wsl.reader.data.wsl_reader_sample import WSLReaderSample
16 | from wsl.reader.pytorch_modules.base import WSLReaderBase
17 |
18 | logger = get_logger(__name__, level=logging.INFO)
19 |
20 |
21 | class WSLReaderForSpanExtraction(WSLReaderBase):
22 | """
23 | The WSL Reader model for span extraction.
24 |
25 | Args:
26 | transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
27 | The transformer model to use. If `None`, the default model is used.
28 | additional_special_symbols (:obj:`int`, `optional`, defaults to 0):
29 | The number of additional special symbols to add to the tokenizer.
30 | num_layers (:obj:`int`, `optional`):
31 | The number of layers to use. If `None`, all layers are used.
32 | activation (:obj:`str`, `optional`, defaults to "gelu"):
33 | The activation function to use.
34 | linears_hidden_size (:obj:`int`, `optional`, defaults to 512):
35 | The hidden size of the projection (linear) layers.
36 | use_last_k_layers (:obj:`int`, `optional`, defaults to 1):
37 | How many of the transformer's final hidden layers to use.
38 | training (:obj:`bool`, `optional`, defaults to False):
39 | Whether the model is in training mode.
40 | device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`):
41 | The device to use. If `None`, the default device is used.
42 | tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`):
43 | The tokenizer to use. If `None`, the default tokenizer is used.
44 | dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`):
45 | The dataset to use. If `None`, the default dataset is used.
46 | dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`):
47 | The keyword arguments to pass to the dataset class.
48 | default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`):
49 | The default reader class to use. If `None`, the default reader class is used.
50 | **kwargs:
51 | Keyword arguments.
52 | """
53 |
54 | default_reader_class: str = (
55 | "wsl.reader.pytorch_modules.hf.modeling_wsl.WSLReaderSpanModel"
56 | )
57 | default_data_class: str = "wsl.reader.data.wsl_reader_data_utils.WSLDataset"
58 |
59 | def __init__(
60 | self,
61 | transformer_model: str | tr.PreTrainedModel | None = None,
62 | additional_special_symbols: int = 0,
63 | num_layers: int | None = None,
64 | activation: str = "gelu",
65 | linears_hidden_size: int | None = 512,
66 | use_last_k_layers: int = 1,
67 | training: bool = False,
68 | device: str | torch.device | None = None,
69 | tokenizer: str | tr.PreTrainedTokenizer | None = None,
70 | dataset: IterableDataset | str | None = None,
71 | dataset_kwargs: Dict[str, Any] | None = None,
72 | default_reader_class: tr.PreTrainedModel | str | None = None,
73 | **kwargs,
74 | ):
75 | super().__init__(
76 | transformer_model=transformer_model,
77 | additional_special_symbols=additional_special_symbols,
78 | num_layers=num_layers,
79 | activation=activation,
80 | linears_hidden_size=linears_hidden_size,
81 | use_last_k_layers=use_last_k_layers,
82 | training=training,
83 | device=device,
84 | tokenizer=tokenizer,
85 | dataset=dataset,
86 | default_reader_class=default_reader_class,
87 | **kwargs,
88 | )
89 | # instantiate the dataset class if none was provided
90 | self.dataset = dataset
91 | if self.dataset is None:
92 | self.default_data_class = get_callable_from_string(self.default_data_class)
93 | default_data_kwargs = dict(
94 | dataset_path=None,
95 | materialize_samples=False,
96 | transformer_model=self.tokenizer,
97 | special_symbols=self.default_data_class.get_special_symbols(
98 | self.wsl_reader_model.config.additional_special_symbols
99 | ),
100 | for_inference=True,
101 | use_nme=kwargs.get("use_nme", True),
102 | )
103 | # merge the default data kwargs with the ones passed to the model
104 | default_data_kwargs.update(dataset_kwargs or {})
105 | self.dataset = self.default_data_class(**default_data_kwargs)
106 |
107 | # inference only: no gradient tracking needed
108 | @torch.inference_mode()
109 | def _read(
110 | self,
111 | samples: List[WSLReaderSample] | None = None,
112 | input_ids: torch.Tensor | None = None,
113 | attention_mask: torch.Tensor | None = None,
114 | token_type_ids: torch.Tensor | None = None,
115 | prediction_mask: torch.Tensor | None = None,
116 | special_symbols_mask: torch.Tensor | None = None,
117 | max_length: int = 1000,
118 | max_batch_size: int = 128,
119 | token_batch_size: int = 2048,
120 | precision: int | str = 32,
121 | annotation_type: AnnotationType = AnnotationType.CHAR,
122 | progress_bar: bool = False,
123 | remove_nmes: bool = True,
124 | *args: object,
125 | **kwargs: object,
126 | ) -> List[WSLReaderSample] | List[List[WSLReaderSample]]:
127 | """
128 | Batches the given samples (or raw tensors) through the model and returns the predicted labels for each sample.
129 |
130 | Args:
131 | samples (:obj:`List[WSLReaderSample]`, `optional`):
132 | The samples to read. If provided, the tensor inputs below are ignored.
133 | input_ids (:obj:`torch.Tensor`, `optional`):
134 | The input ids of the text. If `samples` is provided, this is ignored.
135 | attention_mask (:obj:`torch.Tensor`, `optional`):
136 | The attention mask of the text. If `samples` is provided, this is ignored.
137 | token_type_ids (:obj:`torch.Tensor`, `optional`):
138 | The token type ids of the text. If `samples` is provided, this is ignored.
139 | prediction_mask (:obj:`torch.Tensor`, `optional`):
140 | The prediction mask of the text. If `samples` is provided, this is ignored.
141 | special_symbols_mask (:obj:`torch.Tensor`, `optional`):
142 | The special symbols mask of the text. If `samples` is provided, this is ignored.
143 | max_length (:obj:`int`, `optional`, defaults to 1000):
144 | The maximum length of the text.
145 | max_batch_size (:obj:`int`, `optional`, defaults to 128):
146 | The maximum batch size.
147 | token_batch_size (:obj:`int`, `optional`, defaults to 2048):
148 | The maximum number of tokens per batch.
149 | progress_bar (:obj:`bool`, `optional`, defaults to False):
150 | Whether to show a progress bar.
151 | precision (:obj:`int` or :obj:`str`, `optional`, defaults to 32):
152 | The precision to use for the model.
153 | annotation_type (`AnnotationType`, `optional`, defaults to `AnnotationType.CHAR`):
154 | The type of annotation to return. If `char`, the spans will be in terms of
155 | character offsets. If `word`, the spans will be in terms of word offsets.
156 | *args:
157 | Positional arguments.
158 | **kwargs:
159 | Keyword arguments.
160 |
161 | Returns:
162 | :obj:`List[WSLReaderSample]` or :obj:`List[List[WSLReaderSample]]`:
163 | The predicted labels for each sample.
164 | """
165 |
166 | precision = precision or self.precision
167 | if samples is not None:
168 |
169 | def _read_iterator():
170 | def samples_it():
171 | for i, sample in enumerate(samples):
172 | assert sample._mixin_prediction_position is None
173 | sample._mixin_prediction_position = i
174 | if sample.spans is not None and len(sample.spans) > 0:
175 | sample.window_labels = [
176 | [s[0], s[1], ""] for s in sample.spans
177 | ]
178 | yield sample
179 |
180 | next_prediction_position = 0
181 | position2predicted_sample = {}
182 |
183 | # instantiate dataset
184 | if self.dataset is None:
185 | raise ValueError(
186 | "You need to pass a dataset to the model in order to predict"
187 | )
188 | self.dataset.samples = samples_it()
189 | self.dataset.model_max_length = max_length
190 | self.dataset.tokens_per_batch = token_batch_size
191 | self.dataset.max_batch_size = max_batch_size
192 | # self.dataset.batch_size = max_batch_size
193 |
194 | # instantiate dataloader
195 | iterator = DataLoader(
196 | self.dataset, batch_size=None, num_workers=0, shuffle=False
197 | )
198 | if progress_bar:
199 | iterator = tqdm(iterator, desc="Predicting with WSLReader")
200 |
201 | with get_autocast_context(self.device, precision):
202 | for batch in iterator:
203 | batch = move_data_to_device(batch, self.device)
204 | batch.update(kwargs)
205 | batch_out = self._batch_predict(**batch)
206 |
207 | for sample in batch_out:
208 | if (
209 | sample.spans is not None
210 | and len(sample.spans) > 0
211 | and sample.window_labels
212 | ):
213 | # remove window labels
214 | sample.window_labels = None
215 | if (
216 | sample._mixin_prediction_position
217 | >= next_prediction_position
218 | ):
219 | position2predicted_sample[
220 | sample._mixin_prediction_position
221 | ] = sample
222 |
223 | # yield completed samples back in their original input order
224 | while next_prediction_position in position2predicted_sample:
225 | yield position2predicted_sample[next_prediction_position]
226 | del position2predicted_sample[next_prediction_position]
227 | next_prediction_position += 1
228 |
229 | outputs = list(_read_iterator())
230 | for sample in outputs:
231 | self.dataset.merge_patches_predictions(sample)
232 | if annotation_type == AnnotationType.CHAR:
233 | self.dataset.convert_to_char_annotations(sample, remove_nmes)
234 | elif annotation_type == AnnotationType.WORD:
235 | self.dataset.convert_to_word_annotations(sample, remove_nmes)
236 | else:
237 | raise ValueError(
238 | f"Annotation type {annotation_type} not recognized. "
239 | f"Please choose one of {list(AnnotationType)}."
240 | )
241 |
242 | else:
243 | outputs = list(
244 | self._batch_predict(
245 | input_ids,
246 | attention_mask,
247 | token_type_ids,
248 | prediction_mask,
249 | special_symbols_mask,
250 | *args,
251 | **kwargs,
252 | )
253 | )
254 | return outputs
255 |
256 | def _batch_predict(
257 | self,
258 | input_ids: torch.Tensor,
259 | attention_mask: torch.Tensor,
260 | token_type_ids: torch.Tensor | None = None,
261 | prediction_mask: torch.Tensor | None = None,
262 | special_symbols_mask: torch.Tensor | None = None,
263 | sample: List[WSLReaderSample] | None = None,
264 | top_k: int = 5, # number of most probable entities to keep per span (-1 keeps all)
265 | *args,
266 | **kwargs,
267 | ) -> Iterator[WSLReaderSample]:
268 | """
269 | A wrapper around the forward method that returns the predicted labels for each sample.
270 | It also adds the predicted labels to the samples.
271 |
272 | Args:
273 | input_ids (:obj:`torch.Tensor`):
274 | The input ids of the text.
275 | attention_mask (:obj:`torch.Tensor`):
276 | The attention mask of the text.
277 | token_type_ids (:obj:`torch.Tensor`, `optional`):
278 | The token type ids of the text.
279 | prediction_mask (:obj:`torch.Tensor`, `optional`):
280 | The prediction mask of the text.
281 | special_symbols_mask (:obj:`torch.Tensor`, `optional`):
282 | The special symbols mask of the text.
284 | sample (:obj:`List[WSLReaderSample]`, `optional`):
285 | The samples corresponding to the batch; predictions are attached to them.
286 | top_k (:obj:`int`, `optional`, defaults to 5):
287 | The number of most probable entities to keep per span (-1 keeps all).
287 | *args:
288 | Positional arguments.
289 | **kwargs:
290 | Keyword arguments.
291 |
292 | Returns:
293 | The predicted labels for each sample.
294 | """
295 | forward_output = self.forward(
296 | input_ids=input_ids,
297 | attention_mask=attention_mask,
298 | token_type_ids=token_type_ids,
299 | prediction_mask=prediction_mask,
300 | special_symbols_mask=special_symbols_mask,
301 | *args,
302 | **kwargs,
303 | )
304 |
305 | ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy()
306 | ned_end_predictions = forward_output["ned_end_predictions"] # moved to CPU per sample below
307 | ed_predictions = forward_output["ed_predictions"].cpu().numpy()
308 | ed_probabilities = forward_output["ed_probabilities"].cpu().numpy()
309 |
310 | batch_predictable_candidates = kwargs["predictable_candidates"]
311 | patch_offset = kwargs["patch_offset"]
312 | for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip(
313 | sample,
314 | ned_start_predictions,
315 | ned_end_predictions,
316 | ed_predictions,
317 | ed_probabilities,
318 | batch_predictable_candidates,
319 | patch_offset,
320 | ):
321 | ent_count = 0
322 | ne_ep = ne_ep.cpu().numpy()
323 | ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0]
324 | # each row of ne_ep holds the end-token predictions for one start index
325 | final_class2predicted_spans = collections.defaultdict(list)
326 | spans2predicted_probabilities = dict()
327 | for start_token_index, end_token_row in zip(ne_start_indices, ne_ep):
328 | for end_token_index in [
329 | ti for ti, c in enumerate(end_token_row[1:]) if c > 0
330 | ]:
331 | # predicted candidate
332 | token_class = edp[ent_count] - 1
333 | predicted_candidate_title = pred_cands[token_class]
334 | final_class2predicted_spans[predicted_candidate_title].append(
335 | [start_token_index, end_token_index]
336 | )
337 |
338 | # candidates probabilities
339 | classes_probabilities = edpr[ent_count]
340 | classes_probabilities_best_indices = (
341 | classes_probabilities.argsort()[::-1]
342 | )
343 | titles_2_probs = []
344 | k = ( # avoid clobbering the top_k argument across iterations
345 | min(
346 | top_k,
347 | len(classes_probabilities_best_indices),
348 | )
349 | if top_k != -1
350 | else len(classes_probabilities_best_indices)
351 | )
352 | for i in range(k):
353 | titles_2_probs.append(
354 | (
355 | pred_cands[classes_probabilities_best_indices[i] - 1],
356 | classes_probabilities[
357 | classes_probabilities_best_indices[i]
358 | ].item(),
359 | )
360 | )
361 | spans2predicted_probabilities[
362 | (start_token_index, end_token_index)
363 | ] = titles_2_probs
364 | ent_count += 1
365 |
366 | if "patches" not in ts._d:
367 | ts._d["patches"] = dict()
368 |
369 | ts._d["patches"][po] = dict()
370 | sample_patch = ts._d["patches"][po]
371 |
372 | sample_patch["predicted_window_labels"] = final_class2predicted_spans
373 | sample_patch["span_title_probabilities"] = spans2predicted_probabilities
374 |
375 | # additional info
376 | sample_patch["predictable_candidates"] = pred_cands
377 |
378 | # the same predictions are duplicated under the new-format keys below
379 | sample_patch["predicted_spans"] = final_class2predicted_spans
380 | sample_patch[
381 | "predicted_spans_probabilities"
382 | ] = spans2predicted_probabilities
383 |
384 | yield ts
385 |
--------------------------------------------------------------------------------
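Usage note: `WSLReaderForSpanExtraction` is normally driven through a `read`-style entry point on `WSLReaderBase` (see `base.py`), with the `_read` method above as the internal workhorse. A minimal sketch of that flow; the checkpoint path and the sample field names here are illustrative only, since `WSLReaderSample` stores arbitrary keyword attributes (see its implementation in `modeling_wsl.py` below):

from wsl.reader.data.wsl_reader_sample import WSLReaderSample
from wsl.reader.pytorch_modules.span import WSLReaderForSpanExtraction

# hypothetical checkpoint path, shown for illustration only
reader = WSLReaderForSpanExtraction(
    "path/to/span-reader-checkpoint", device="cpu"
)

# one window of text plus the candidate titles to disambiguate against
sample = WSLReaderSample(
    tokens="Rome is the capital of Italy .".split(),
    span_candidates=["Rome", "Italy"],  # illustrative field name
)

# assuming the base class exposes a public `read` wrapper around `_read`
predictions = reader.read(samples=[sample], progress_bar=True)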
/wsl/reader/pytorch_modules/hf/modeling_wsl.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 |
3 | import torch
4 | from transformers import AutoModel, PreTrainedModel
5 | from transformers.activations import ClippedGELUActivation, GELUActivation
6 | from transformers.configuration_utils import PretrainedConfig
7 | from transformers.modeling_utils import PoolerEndLogits
8 |
9 | from .configuration_wsl import WSLReaderConfig
10 |
11 |
12 | class WSLReaderSample:
13 | def __init__(self, **kwargs):
14 | super().__setattr__("_d", {})
15 | self._d = kwargs
16 |
17 | def __getattribute__(self, item):
18 | return super(WSLReaderSample, self).__getattribute__(item)
19 |
20 | def __getattr__(self, item):
21 | if item.startswith("__") and item.endswith("__"):
22 | # this is likely some python library-specific variable (such as __deepcopy__ for copy)
23 | # better follow standard behavior here
24 | raise AttributeError(item)
25 | elif item in self._d:
26 | return self._d[item]
27 | else:
28 | return None
29 |
30 | def __setattr__(self, key, value):
31 | if key in self._d:
32 | self._d[key] = value
33 | else:
34 | super().__setattr__(key, value)
35 | self._d[key] = value
36 |
37 |
38 | activation2functions = {
39 | "relu": torch.nn.ReLU(),
40 | "gelu": GELUActivation(),
41 | "gelu_10": ClippedGELUActivation(-10, 10),
42 | }
43 |
44 |
45 | class PoolerEndLogitsBi(PoolerEndLogits):
46 | def __init__(self, config: PretrainedConfig):
47 | super().__init__(config)
48 | self.dense_1 = torch.nn.Linear(config.hidden_size, 2)
49 |
50 | def forward(
51 | self,
52 | hidden_states: torch.FloatTensor,
53 | start_states: Optional[torch.FloatTensor] = None,
54 | start_positions: Optional[torch.LongTensor] = None,
55 | p_mask: Optional[torch.FloatTensor] = None,
56 | ) -> torch.FloatTensor:
57 | if p_mask is not None:
58 | p_mask = p_mask.unsqueeze(-1)
59 | logits = super().forward(
60 | hidden_states,
61 | start_states,
62 | start_positions,
63 | p_mask,
64 | )
65 | return logits
66 |
67 |
68 | class WSLReaderSpanModel(PreTrainedModel):
69 | config_class = WSLReaderConfig
70 |
71 | def __init__(self, config: WSLReaderConfig, *args, **kwargs):
72 | super().__init__(config)
73 | # Transformer model declaration
74 | self.config = config
75 | self.transformer_model = (
76 | AutoModel.from_pretrained(self.config.transformer_model)
77 | if self.config.num_layers is None
78 | else AutoModel.from_pretrained(
79 | self.config.transformer_model, num_hidden_layers=self.config.num_layers
80 | )
81 | )
82 | self.transformer_model.resize_token_embeddings(
83 | self.transformer_model.config.vocab_size
84 | + self.config.additional_special_symbols
85 | )
86 |
87 | self.activation = self.config.activation
88 | self.linears_hidden_size = self.config.linears_hidden_size
89 | self.use_last_k_layers = self.config.use_last_k_layers
90 |
91 | # named entity detection layers
92 | self.ned_start_classifier = self._get_projection_layer(
93 | self.activation, last_hidden=2, layer_norm=False
94 | )
95 | if self.config.binary_end_logits:
96 | self.ned_end_classifier = PoolerEndLogitsBi(self.transformer_model.config)
97 | else:
98 | self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config)
99 |
100 | # entity disambiguation (ED) projection layers
101 | self.ed_start_projector = self._get_projection_layer(self.activation)
102 | self.ed_end_projector = self._get_projection_layer(self.activation)
103 |
104 | self.training = self.config.training
105 |
106 | # criterion
107 | self.criterion = torch.nn.CrossEntropyLoss()
108 |
109 | def _get_projection_layer(
110 | self,
111 | activation: str,
112 | last_hidden: Optional[int] = None,
113 | input_hidden=None,
114 | layer_norm: bool = True,
115 | ) -> torch.nn.Sequential:
116 | head_components = [
117 | torch.nn.Dropout(0.1),
118 | torch.nn.Linear(
119 | (
120 | self.transformer_model.config.hidden_size * self.use_last_k_layers
121 | if input_hidden is None
122 | else input_hidden
123 | ),
124 | self.linears_hidden_size,
125 | ),
126 | activation2functions[activation],
127 | torch.nn.Dropout(0.1),
128 | torch.nn.Linear(
129 | self.linears_hidden_size,
130 | self.linears_hidden_size if last_hidden is None else last_hidden,
131 | ),
132 | ]
133 |
134 | if layer_norm:
135 | head_components.append(
136 | torch.nn.LayerNorm(
137 | self.linears_hidden_size if last_hidden is None else last_hidden,
138 | self.transformer_model.config.layer_norm_eps,
139 | )
140 | )
141 |
142 | return torch.nn.Sequential(*head_components)
143 |
144 | def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
145 | mask = mask.unsqueeze(-1)
146 | if next(self.parameters()).dtype == torch.float16:
147 | logits = logits * (1 - mask) - 65500 * mask
148 | else:
149 | logits = logits * (1 - mask) - 1e30 * mask
150 | return logits
151 |
152 | def _get_model_features(
153 | self,
154 | input_ids: torch.Tensor,
155 | attention_mask: torch.Tensor,
156 | token_type_ids: Optional[torch.Tensor],
157 | ):
158 | model_input = {
159 | "input_ids": input_ids,
160 | "attention_mask": attention_mask,
161 | "output_hidden_states": self.use_last_k_layers > 1,
162 | }
163 |
164 | if token_type_ids is not None:
165 | model_input["token_type_ids"] = token_type_ids
166 |
167 | model_output = self.transformer_model(**model_input)
168 |
169 | if self.use_last_k_layers > 1:
170 | model_features = torch.cat(
171 | model_output[1][-self.use_last_k_layers :], dim=-1
172 | )
173 | else:
174 | model_features = model_output[0]
175 |
176 | return model_features
177 |
178 | def compute_ned_end_logits(
179 | self,
180 | start_predictions,
181 | start_labels,
182 | model_features,
183 | prediction_mask,
184 | batch_size,
185 | ) -> Optional[torch.Tensor]:
186 | # TODO: when constraining on predefined spans, we may not want to
187 | # apply the prediction_mask to the end tokens,
188 | # at least not during training.
189 | start_positions = start_labels if self.training else start_predictions
190 | start_positions_indices = (
191 | torch.arange(start_positions.size(1), device=start_positions.device)
192 | .unsqueeze(0)
193 | .expand(batch_size, -1)[start_positions > 0]
194 | ).to(start_positions.device)
195 |
196 | if len(start_positions_indices) > 0:
197 | expanded_features = model_features.repeat_interleave(
198 | torch.sum(start_positions > 0, dim=-1), dim=0
199 | )
200 | expanded_prediction_mask = prediction_mask.repeat_interleave(
201 | torch.sum(start_positions > 0, dim=-1), dim=0
202 | )
203 | end_logits = self.ned_end_classifier(
204 | hidden_states=expanded_features,
205 | start_positions=start_positions_indices,
206 | p_mask=expanded_prediction_mask,
207 | )
208 |
209 | return end_logits
210 |
211 | return None
212 |
213 | def compute_classification_logits(
214 | self,
215 | model_features_start,
216 | model_features_end,
217 | special_symbols_features,
218 | ) -> torch.Tensor:
219 | model_start_features = self.ed_start_projector(model_features_start)
220 | model_end_features = self.ed_end_projector(model_features_end)
221 | model_start_features_symbols = self.ed_start_projector(special_symbols_features)
222 | model_end_features_symbols = self.ed_end_projector(special_symbols_features)
223 |
224 | model_ed_features = torch.cat(
225 | [model_start_features, model_end_features], dim=-1
226 | )
227 | special_symbols_representation = torch.cat(
228 | [model_start_features_symbols, model_end_features_symbols], dim=-1
229 | )
230 |
231 | logits = torch.bmm(
232 | model_ed_features,
233 | torch.permute(special_symbols_representation, (0, 2, 1)),
234 | )
235 |
236 | logits = self._mask_logits(logits, (model_features_start == -100).all(2).long())
237 | return logits
238 |
239 | def forward(
240 | self,
241 | input_ids: torch.Tensor,
242 | attention_mask: torch.Tensor,
243 | token_type_ids: Optional[torch.Tensor] = None,
244 | prediction_mask: Optional[torch.Tensor] = None,
245 | special_symbols_mask: Optional[torch.Tensor] = None,
246 | start_labels: Optional[torch.Tensor] = None,
247 | end_labels: Optional[torch.Tensor] = None,
248 | use_predefined_spans: bool = False,
249 | *args,
250 | **kwargs,
251 | ) -> Dict[str, Any]:
252 | batch_size, seq_len = input_ids.shape
253 |
254 | model_features = self._get_model_features(
255 | input_ids, attention_mask, token_type_ids
256 | )
257 |
258 | ned_start_labels = None
259 |
260 | # named entity detection if required
261 | if use_predefined_spans: # no need to compute spans
262 | ned_start_logits, ned_start_probabilities, ned_start_predictions = (
263 | None,
264 | None,
265 | (
266 | torch.clone(start_labels)
267 | if start_labels is not None
268 | else torch.zeros_like(input_ids)
269 | ),
270 | )
271 | ned_end_logits, ned_end_probabilities, ned_end_predictions = (
272 | None,
273 | None,
274 | (
275 | torch.clone(end_labels)
276 | if end_labels is not None
277 | else torch.zeros_like(input_ids)
278 | ),
279 | )
280 | ned_start_predictions[ned_start_predictions > 0] = 1
281 | ned_end_predictions[end_labels > 0] = 1
282 | ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
283 |
284 | else: # compute spans
285 | # start boundary prediction
286 | ned_start_logits = self.ned_start_classifier(model_features)
287 | ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask)
288 | ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1)
289 | ned_start_predictions = ned_start_probabilities.argmax(dim=-1)
290 |
291 | # end boundary prediction
292 | ned_start_labels = (
293 | torch.zeros_like(start_labels) if start_labels is not None else None
294 | )
295 |
296 | if ned_start_labels is not None:
297 | ned_start_labels[start_labels == -100] = -100
298 | ned_start_labels[start_labels > 0] = 1
299 |
300 | ned_end_logits = self.compute_ned_end_logits(
301 | ned_start_predictions,
302 | ned_start_labels,
303 | model_features,
304 | prediction_mask,
305 | batch_size,
306 | )
307 |
308 | if ned_end_logits is not None:
309 | ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
310 | if not self.config.binary_end_logits:
311 | ned_end_predictions = torch.argmax(
312 | ned_end_probabilities, dim=-1, keepdim=True
313 | )
314 | ned_end_predictions = torch.zeros_like(
315 | ned_end_probabilities
316 | ).scatter_(1, ned_end_predictions, 1)
317 | else:
318 | ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
319 | else:
320 | ned_end_logits, ned_end_probabilities = None, None
321 | ned_end_predictions = ned_start_predictions.new_zeros(
322 | batch_size, seq_len
323 | )
324 |
325 | if not self.training:
326 | # ned_end_predictions holds one row per predicted start token
327 | # (see compute_ned_end_logits above)
328 | end_preds_count = ned_end_predictions.sum(1)
329 | # If there are no end predictions for a start prediction, remove the start prediction
330 | if (end_preds_count == 0).any() and (ned_start_predictions > 0).any():
331 | ned_start_predictions[ned_start_predictions == 1] = (
332 | end_preds_count != 0
333 | ).long()
334 | ned_end_predictions = ned_end_predictions[end_preds_count != 0]
335 |
336 | if end_labels is not None:
337 | end_labels = end_labels[~(end_labels == -100).all(2)]
338 |
339 | start_position, end_position = (
340 | (start_labels, end_labels)
341 | if self.training
342 | else (ned_start_predictions, ned_end_predictions)
343 | )
344 | start_counts = (start_position > 0).sum(1)
345 | if (start_counts > 0).any():
346 | ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
347 | # Entity disambiguation
348 | if (end_position > 0).sum() > 0:
349 | ends_count = (end_position > 0).sum(1)
350 | model_entity_start = torch.repeat_interleave(
351 | model_features[start_position > 0], ends_count, dim=0
352 | )
353 | model_entity_end = torch.repeat_interleave(
354 | model_features, start_counts, dim=0
355 | )[end_position > 0]
356 | ents_count = torch.nn.utils.rnn.pad_sequence(
357 | torch.split(ends_count, start_counts.tolist()),
358 | batch_first=True,
359 | padding_value=0,
360 | ).sum(1)
361 |
362 | model_entity_start = torch.nn.utils.rnn.pad_sequence(
363 | torch.split(model_entity_start, ents_count.tolist()),
364 | batch_first=True,
365 | padding_value=-100,
366 | )
367 |
368 | model_entity_end = torch.nn.utils.rnn.pad_sequence(
369 | torch.split(model_entity_end, ents_count.tolist()),
370 | batch_first=True,
371 | padding_value=-100,
372 | )
373 |
374 | ed_logits = self.compute_classification_logits(
375 | model_entity_start,
376 | model_entity_end,
377 | model_features[special_symbols_mask].view(
378 | batch_size, -1, model_features.shape[-1]
379 | ),
380 | )
381 | ed_probabilities = torch.softmax(ed_logits, dim=-1)
382 | ed_predictions = torch.argmax(ed_probabilities, dim=-1)
383 | else:
384 | ed_logits, ed_probabilities, ed_predictions = (
385 | None,
386 | ned_start_predictions.new_zeros(batch_size, seq_len),
387 | ned_start_predictions.new_zeros(batch_size),
388 | )
389 | # output build
390 | output_dict = dict(
391 | batch_size=batch_size,
392 | ned_start_logits=ned_start_logits,
393 | ned_start_probabilities=ned_start_probabilities,
394 | ned_start_predictions=ned_start_predictions,
395 | ned_end_logits=ned_end_logits,
396 | ned_end_probabilities=ned_end_probabilities,
397 | ned_end_predictions=ned_end_predictions,
398 | ed_logits=ed_logits,
399 | ed_probabilities=ed_probabilities,
400 | ed_predictions=ed_predictions,
401 | )
402 |
403 | # compute loss if labels
404 | if start_labels is not None and end_labels is not None and self.training:
405 | # named entity detection loss
406 |
407 | # start
408 | if ned_start_logits is not None:
409 | ned_start_loss = self.criterion(
410 | ned_start_logits.view(-1, ned_start_logits.shape[-1]),
411 | ned_start_labels.view(-1),
412 | )
413 | else:
414 | ned_start_loss = 0
415 |
416 | # end
417 | # use ents_count to gather the gold entity ids per span: end_labels rows like [[0,0,4,0], [0,0,0,2]] yield [4, 2] for one sample; across the batch the gathered ids are split by ents_count and padded with -100, e.g. [[4, 2, -100, -100], [3, 1, 2, -100], [1, 3, 2, 5]]
418 |
419 | if ned_end_logits is not None:
420 | ed_labels = end_labels.clone()
421 | ed_labels = torch.nn.utils.rnn.pad_sequence(
422 | torch.split(ed_labels[ed_labels > 0], ents_count.tolist()),
423 | batch_first=True,
424 | padding_value=-100,
425 | )
426 | end_labels[end_labels > 0] = 1
427 | if not self.config.binary_end_logits:
428 | # transform label to position in the sequence
429 | end_labels = end_labels.argmax(dim=-1)
430 | ned_end_loss = self.criterion(
431 | ned_end_logits.view(-1, ned_end_logits.shape[-1]),
432 | end_labels.view(-1),
433 | )
434 | else:
435 | ned_end_loss = self.criterion(
436 | ned_end_logits.reshape(-1, ned_end_logits.shape[-1]),
437 | end_labels.reshape(-1).long(),
438 | )
439 |
440 | # entity disambiguation loss
441 | ed_loss = self.criterion(
442 | ed_logits.view(-1, ed_logits.shape[-1]),
443 | ed_labels.view(-1).long(),
444 | )
445 |
446 | else:
447 | ned_end_loss = 0
448 | ed_loss = 0
449 |
450 | output_dict["ned_start_loss"] = ned_start_loss
451 | output_dict["ned_end_loss"] = ned_end_loss
452 | output_dict["ed_loss"] = ed_loss
453 |
454 | output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss
455 |
456 | return output_dict
457 |
--------------------------------------------------------------------------------
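A note on `_mask_logits` above: it applies the standard additive-mask trick, zeroing the masked positions and then subtracting a huge constant so that they receive numerically zero probability after the softmax; `-65500` replaces `-1e30` under float16 because `-1e30` overflows half precision. A self-contained toy reproduction of the effect (the real method additionally unsqueezes the mask to match a trailing class dimension):

import torch

logits = torch.tensor([[2.0, 0.5, 1.0]])
mask = torch.tensor([[0.0, 1.0, 0.0]])  # 1 marks a position that must not be predicted

masked = logits * (1 - mask) - 1e30 * mask
probs = torch.softmax(masked, dim=-1)
print(probs)  # tensor([[0.7311, 0.0000, 0.2689]]) -- the masked logit is suppressed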