├── wsl
│   ├── common
│   │   ├── __init__.py
│   │   ├── torch_utils.py
│   │   ├── upload.py
│   │   └── log.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── window
│   │   │   │   └── __init__.py
│   │   │   ├── splitters
│   │   │   │   ├── __init__.py
│   │   │   │   ├── blank_sentence_splitter.py
│   │   │   │   ├── base_sentence_splitter.py
│   │   │   │   ├── window_based_splitter.py
│   │   │   │   └── spacy_sentence_splitter.py
│   │   │   ├── tokenizers
│   │   │   │   ├── base_tokenizer.py
│   │   │   │   ├── __init__.py
│   │   │   │   └── spacy_tokenizer.py
│   │   │   └── objects.py
│   │   └── utils.py
│   ├── reader
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── wsl_reader_data_utils.py
│   │   │   ├── patches.py
│   │   │   └── wsl_reader_sample.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── special_symbols.py
│   │   │   ├── metrics.py
│   │   │   ├── strong_matching_eval.py
│   │   │   └── relik_reader_predictor.py
│   │   ├── trainer
│   │   │   ├── __init__.py
│   │   │   └── predict.py
│   │   ├── pytorch_modules
│   │   │   ├── hf
│   │   │   │   ├── __init__.py
│   │   │   │   ├── configuration_wsl.py
│   │   │   │   └── modeling_wsl.py
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── span.py
│   │   └── __init__.py
│   ├── retriever
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   └── model_inputs.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── base
│   │   │   │   ├── __init__.py
│   │   │   │   └── datasets.py
│   │   │   ├── utils.py
│   │   │   └── labels.py
│   │   ├── indexers
│   │   │   ├── __init__.py
│   │   │   ├── document.py
│   │   │   └── inmemory.py
│   │   ├── __init__.py
│   │   └── pytorch_modules
│   │       ├── __init__.py
│   │       └── hf.py
│   ├── __init__.py
│   └── version.py
├── MANIFEST.in
├── constraints.cpu.txt
├── SETUP.cfg
├── assets
│   └── Sapienza_Babelscape.png
├── .flake8
├── pyproject.toml
├── requirements.txt
├── .pre-commit-config.yaml
├── README.md
├── setup.py
├── .gitignore
└── wsl_data_license.txt

--------------------------------------------------------------------------------
/wsl/common/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/inference/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/reader/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/reader/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/inference/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/reader/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/retriever/common/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/retriever/data/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/inference/data/window/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/retriever/data/base/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/retriever/indexers/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 | 
--------------------------------------------------------------------------------
/wsl/inference/data/splitters/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/wsl/reader/pytorch_modules/hf/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_wsl import WSLReaderConfig
2 | 
--------------------------------------------------------------------------------
/wsl/retriever/__init__.py:
--------------------------------------------------------------------------------
1 | from wsl.retriever.pytorch_modules.model import WSLRetriever
2 | 
--------------------------------------------------------------------------------
/constraints.cpu.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cpu
2 | torch==2.1.0
3 | 
--------------------------------------------------------------------------------
/SETUP.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
3 | 
4 | [build]
5 | build-base = /tmp/build
6 | 
--------------------------------------------------------------------------------
/assets/Sapienza_Babelscape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Babelscape/WSL/HEAD/assets/Sapienza_Babelscape.png
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203, E266, E501, W503, F403, F401, E402, C901
3 | max-line-length = 88
4 | max-complexity = 18
5 | select = B,C,E,F,W,T4,B9
6 | 
--------------------------------------------------------------------------------
/wsl/reader/pytorch_modules/__init__.py:
--------------------------------------------------------------------------------
1 | WSL_READER_CLASS_MAP = {
2 |     "WSLReaderSpanModel": "wsl.reader.pytorch_modules.span.WSLReaderForSpanExtraction",
3 | }
4 | 
--------------------------------------------------------------------------------
/wsl/reader/__init__.py:
--------------------------------------------------------------------------------
1 | # from wsl.reader.pytorch_modules.base import RelikReaderBase
2 | # from wsl.reader.pytorch_modules.span import RelikReaderForSpanExtraction
3 | # from wsl.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
4 | 
--------------------------------------------------------------------------------
/wsl/__init__.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | from wsl.inference.annotator import WSL
4 | 
5 | VERSION = {}  # type: ignore
6 | with open(Path(__file__).parent / "version.py", "r") as version_file:
7 |     exec(version_file.read(), VERSION)
8 | 
9 | __version__ = VERSION["VERSION"]
10 | 
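
# For reference: wsl/version.py (shown later in this dump) builds
# VERSION = "1.0.0" plus an optional suffix taken from the WSL_VERSION_SUFFIX
# environment variable, so after the exec() above:
#
#   >>> import wsl
#   >>> wsl.__version__
#   '1.0.0'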
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | include = '\.pyi?$'
3 | exclude = '''
4 | /(
5 |     \.git
6 |   | \.hg
7 |   | \.mypy_cache
8 |   | \.tox
9 |   | \.venv
10 |   | _build
11 |   | buck-out
12 |   | build
13 |   | dist
14 | )/
15 | '''
16 | [tool.isort]
17 | profile = 'black'
18 | line_length = 120
19 | known_third_party = ["numpy", "pytest", "wandb", "torch"]
20 | known_local_folder = "wsl"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #------- Core dependencies -------
2 | --extra-index-url https://download.pytorch.org/whl/cu121
3 | torch==2.3.1
4 | 
5 | transformers[sentencepiece]>=4.41,<4.42
6 | rich>=13.0.0,<14.0.0
7 | scikit-learn>=1.3,<1.4
8 | overrides>=7.4,<7.9
9 | art==6.2
10 | pprintpp==0.4.0
11 | colorama==0.4.6
12 | termcolor==2.4.0
13 | spacy>=3.7,<3.8
14 | typer>=0.12,<0.13
15 | hydra-core
16 | 
17 | lightning>=2.0,<2.1
--------------------------------------------------------------------------------
/wsl/reader/utils/special_symbols.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | 
3 | NME_SYMBOL = "--NME--"
4 | 
5 | 
6 | def get_special_symbols(num_entities: int) -> List[str]:
7 |     return [NME_SYMBOL] + [f"[E-{i}]" for i in range(num_entities)]
8 | 
9 | 
10 | def get_special_symbols_re(num_entities: int, use_nme: bool = False) -> List[str]:
11 |     if use_nme:
12 |         return [NME_SYMBOL] + [f"[R-{i}]" for i in range(num_entities)]
13 |     else:
14 |         return [f"[R-{i}]" for i in range(num_entities)]
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 |   - repo: https://github.com/PyCQA/isort.git
3 |     rev: "5.12.0"
4 |     hooks:
5 |       - id: isort
6 |   - repo: https://github.com/ambv/black
7 |     rev: '22.3.0'
8 |     hooks:
9 |       - id: black
10 |   - repo: https://github.com/pycqa/flake8
11 |     rev: '6.1.0'
12 |     hooks:
13 |       - id: flake8
14 |   - repo: https://github.com/PyCQA/autoflake
15 |     rev: v2.3.1
16 |     hooks:
17 |       - id: autoflake
18 | 
19 | default_language_version:
20 |   python: python3
--------------------------------------------------------------------------------
/wsl/version.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | _MAJOR = "1"
4 | _MINOR = "0"
5 | # On main and in a nightly release the patch should be one ahead of the last
6 | # released build.
7 | _PATCH = "0"
8 | # This is mainly for nightly builds which have the suffix ".dev$DATE". See
9 | # https://semver.org/#is-v123-a-semantic-version for the semantics.
10 | _SUFFIX = os.environ.get("WSL_VERSION_SUFFIX", "") 11 | 12 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) 13 | VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX) 14 | -------------------------------------------------------------------------------- /wsl/reader/utils/metrics.py: -------------------------------------------------------------------------------- 1 | def safe_divide(num: float, den: float) -> float: 2 | if den == 0: 3 | return 0 4 | else: 5 | return num / den 6 | 7 | 8 | def f1_measure(precision: float, recall: float) -> float: 9 | if precision == 0 or recall == 0: 10 | return 0.0 11 | return safe_divide(2 * precision * recall, (precision + recall)) 12 | 13 | 14 | def compute_metrics(total_correct, total_preds, total_gold): 15 | precision = safe_divide(total_correct, total_preds) 16 | recall = safe_divide(total_correct, total_gold) 17 | f1 = f1_measure(precision, recall) 18 | return precision, recall, f1 19 | -------------------------------------------------------------------------------- /wsl/retriever/pytorch_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | 5 | from wsl.retriever.indexers.document import Document 6 | 7 | PRECISION_MAP = { 8 | None: torch.float32, 9 | 32: torch.float32, 10 | 16: torch.float16, 11 | torch.float32: torch.float32, 12 | torch.float16: torch.float16, 13 | torch.bfloat16: torch.bfloat16, 14 | "float32": torch.float32, 15 | "float16": torch.float16, 16 | "bfloat16": torch.bfloat16, 17 | "float": torch.float32, 18 | "half": torch.float16, 19 | "32": torch.float32, 20 | "16": torch.float16, 21 | "fp32": torch.float32, 22 | "fp16": torch.float16, 23 | "bf16": torch.bfloat16, 24 | } 25 | 26 | 27 | @dataclass 28 | class RetrievedSample: 29 | """ 30 | Dataclass for the output of the GoldenRetriever model. 31 | """ 32 | 33 | score: float 34 | document: Document 35 | -------------------------------------------------------------------------------- /wsl/inference/data/splitters/blank_sentence_splitter.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | 4 | class BlankSentenceSplitter: 5 | """ 6 | A `BlankSentenceSplitter` splits strings into sentences. 7 | """ 8 | 9 | def __call__(self, *args, **kwargs): 10 | """ 11 | Calls :meth:`split_sentences`. 12 | """ 13 | return self.split_sentences(*args, **kwargs) 14 | 15 | def split_sentences( 16 | self, text: str, max_len: int = 0, *args, **kwargs 17 | ) -> List[str]: 18 | """ 19 | Splits a `text` :class:`str` paragraph into a list of :class:`str`, where each is a sentence. 20 | """ 21 | return [text] 22 | 23 | def split_sentences_batch( 24 | self, texts: List[str], *args, **kwargs 25 | ) -> List[List[str]]: 26 | """ 27 | Default implementation is to just iterate over the texts and call `split_sentences`. 
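
        Example (illustrative): the blank splitter leaves each input text
        unchanged, as a single "sentence":

            >>> BlankSentenceSplitter().split_sentences_batch(["First text.", "Second text."])
            [['First text.'], ['Second text.']]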
28 | """ 29 | return [self.split_sentences(text) for text in texts] 30 | -------------------------------------------------------------------------------- /wsl/retriever/common/model_inputs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import UserDict 4 | from typing import Any, Union 5 | 6 | import torch 7 | from lightning.fabric.utilities import move_data_to_device 8 | 9 | from wsl.common.log import get_logger 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | class ModelInputs(UserDict): 15 | """Model input dictionary wrapper.""" 16 | 17 | def __getattr__(self, item: str): 18 | try: 19 | return self.data[item] 20 | except KeyError: 21 | raise AttributeError(f"`ModelInputs` has no attribute `{item}`") 22 | 23 | def __getitem__(self, item: str) -> Any: 24 | return self.data[item] 25 | 26 | def __getstate__(self): 27 | return {"data": self.data} 28 | 29 | def __setstate__(self, state): 30 | if "data" in state: 31 | self.data = state["data"] 32 | 33 | def keys(self): 34 | """A set-like object providing a view on D's keys.""" 35 | return self.data.keys() 36 | 37 | def values(self): 38 | """An object providing a view on D's values.""" 39 | return self.data.values() 40 | 41 | def items(self): 42 | """A set-like object providing a view on D's items.""" 43 | return self.data.items() 44 | 45 | def to(self, device: Union[str, torch.device]) -> ModelInputs: 46 | """ 47 | Send all tensors values to device. 48 | Args: 49 | device (`str` or `torch.device`): The device to put the tensors on. 50 | Returns: 51 | :class:`tokenizers.ModelInputs`: The same instance of :class:`~tokenizers.ModelInputs` 52 | after modification. 53 | """ 54 | self.data = move_data_to_device(self.data, device) 55 | return self 56 | -------------------------------------------------------------------------------- /wsl/reader/data/wsl_reader_data_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def flatten(lsts: List[list]) -> list: 8 | acc_lst = list() 9 | for lst in lsts: 10 | acc_lst.extend(lst) 11 | return acc_lst 12 | 13 | 14 | def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor: 15 | return torch.nn.utils.rnn.pad_sequence( 16 | tensors, batch_first=True, padding_value=padding_value 17 | ) 18 | 19 | 20 | def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor: 21 | x = max([t.shape[0] for t in tensors]) 22 | y = max([t.shape[1] for t in tensors]) 23 | out_matrix = torch.zeros((len(tensors), x, y)) 24 | out_matrix += padding_value 25 | for i, tensor in enumerate(tensors): 26 | out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor 27 | return out_matrix 28 | 29 | 30 | def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor: 31 | x = max([t.shape[0] for t in tensors]) 32 | y = max([t.shape[1] for t in tensors]) 33 | rest = tensors[0].shape[2] 34 | out_matrix = torch.zeros((len(tensors), x, y, rest)) 35 | out_matrix += padding_value 36 | for i, tensor in enumerate(tensors): 37 | out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor 38 | return out_matrix 39 | 40 | 41 | def chunks(lst: list, chunk_size: int) -> List[list]: 42 | chunks_acc = list() 43 | for i in range(0, len(lst), chunk_size): 44 | chunks_acc.append(lst[i : i + chunk_size]) 45 | return chunks_acc 46 | 47 | 48 | def 
add_noise_to_value(value: int, noise_param: float): 49 | noise_value = value * noise_param 50 | noise = np.random.uniform(-noise_value, noise_value) 51 | return max(1, value + noise) 52 | -------------------------------------------------------------------------------- /wsl/reader/pytorch_modules/hf/configuration_wsl.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from transformers import AutoConfig 4 | from transformers.configuration_utils import PretrainedConfig 5 | 6 | 7 | class WSLReaderConfig(PretrainedConfig): 8 | model_type = "wsl-reader" 9 | 10 | def __init__( 11 | self, 12 | transformer_model: str = "microsoft/deberta-v3-base", 13 | additional_special_symbols: int = 101, 14 | additional_special_symbols_types: Optional[int] = 0, 15 | num_layers: Optional[int] = None, 16 | activation: str = "gelu", 17 | linears_hidden_size: Optional[int] = 512, 18 | use_last_k_layers: int = 1, 19 | entity_type_loss: bool = False, 20 | add_entity_embedding: bool = None, 21 | binary_end_logits: bool = False, 22 | training: bool = False, 23 | default_reader_class: Optional[str] = None, 24 | threshold: Optional[float] = 0.5, 25 | **kwargs 26 | ) -> None: 27 | # TODO: add name_or_path to kwargs 28 | self.transformer_model = transformer_model 29 | self.additional_special_symbols = additional_special_symbols 30 | self.additional_special_symbols_types = additional_special_symbols_types 31 | self.num_layers = num_layers 32 | self.activation = activation 33 | self.linears_hidden_size = linears_hidden_size 34 | self.use_last_k_layers = use_last_k_layers 35 | self.entity_type_loss = entity_type_loss 36 | self.add_entity_embedding = ( 37 | True 38 | if add_entity_embedding is None and entity_type_loss 39 | else add_entity_embedding 40 | ) 41 | self.threshold = threshold 42 | self.binary_end_logits = binary_end_logits 43 | self.training = training 44 | self.default_reader_class = default_reader_class 45 | super().__init__(**kwargs) 46 | -------------------------------------------------------------------------------- /wsl/reader/trainer/predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pprint import pprint 3 | from typing import Optional 4 | 5 | from wsl.reader.data.wsl_reader_sample import load_wsl_reader_samples 6 | from wsl.reader.pytorch_modules.span import WSLReaderForSpanExtraction 7 | from wsl.reader.utils.strong_matching_eval import StrongMatching 8 | 9 | 10 | def predict( 11 | model_path: str, 12 | dataset_path: str, 13 | token_batch_size: int, 14 | is_eval: bool, 15 | output_path: Optional[str], 16 | ) -> None: 17 | wsl_reader = WSLReaderForSpanExtraction( 18 | model_path, dataset_kwargs={"use_nme": True}, device="cuda" 19 | ) 20 | samples = list(load_wsl_reader_samples(dataset_path)) 21 | predicted_samples = wsl_reader.read( 22 | samples=samples, token_batch_size=token_batch_size, progress_bar=True 23 | ) 24 | if is_eval: 25 | eval_dict = StrongMatching()(predicted_samples) 26 | pprint(eval_dict) 27 | if output_path is not None: 28 | with open(output_path, "w") as f: 29 | for sample in predicted_samples: 30 | f.write(sample.to_jsons() + "\n") 31 | 32 | 33 | def parse_arg() -> argparse.Namespace: 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--model-path", 37 | required=True, 38 | ) 39 | parser.add_argument("--dataset-path", "-i", required=True) 40 | parser.add_argument("--is-eval", action="store_true") 41 | parser.add_argument( 42 | 
"--output-path", 43 | "-o", 44 | ) 45 | parser.add_argument("--token-batch-size", default=4096) 46 | return parser.parse_args() 47 | 48 | 49 | def main(): 50 | args = parse_arg() 51 | predict( 52 | args.model_path, 53 | args.dataset_path, 54 | token_batch_size=args.token_batch_size, 55 | is_eval=args.is_eval, 56 | output_path=args.output_path, 57 | ) 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /wsl/inference/data/splitters/base_sentence_splitter.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | 4 | class BaseSentenceSplitter: 5 | """ 6 | A `BaseSentenceSplitter` splits strings into sentences. 7 | """ 8 | 9 | def __call__(self, *args, **kwargs): 10 | """ 11 | Calls :meth:`split_sentences`. 12 | """ 13 | return self.split_sentences(*args, **kwargs) 14 | 15 | def split_sentences( 16 | self, text: str, max_len: int = 0, *args, **kwargs 17 | ) -> List[str]: 18 | """ 19 | Splits a `text` :class:`str` paragraph into a list of :class:`str`, where each is a sentence. 20 | """ 21 | raise NotImplementedError 22 | 23 | def split_sentences_batch( 24 | self, texts: List[str], *args, **kwargs 25 | ) -> List[List[str]]: 26 | """ 27 | Default implementation is to just iterate over the texts and call `split_sentences`. 28 | """ 29 | return [self.split_sentences(text) for text in texts] 30 | 31 | @staticmethod 32 | def check_is_batched( 33 | texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool 34 | ): 35 | """ 36 | Check if input is batched or a single sample. 37 | 38 | Args: 39 | texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): 40 | Text to check. 41 | is_split_into_words (:obj:`bool`): 42 | If :obj:`True` and the input is a string, the input is split on spaces. 43 | 44 | Returns: 45 | :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise. 
46 | """ 47 | return bool( 48 | (not is_split_into_words and isinstance(texts, (list, tuple))) 49 | or ( 50 | is_split_into_words 51 | and isinstance(texts, (list, tuple)) 52 | and texts 53 | and isinstance(texts[0], (list, tuple)) 54 | ) 55 | ) 56 | -------------------------------------------------------------------------------- /wsl/reader/data/patches.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from wsl.reader.data.wsl_reader_sample import WSLReaderSample 4 | from wsl.reader.utils.special_symbols import NME_SYMBOL 5 | 6 | 7 | def merge_patches_predictions(sample) -> None: 8 | sample._d["predicted_window_labels"] = dict() 9 | predicted_window_labels = sample._d["predicted_window_labels"] 10 | 11 | sample._d["span_title_probabilities"] = dict() 12 | span_title_probabilities = sample._d["span_title_probabilities"] 13 | 14 | span2title = dict() 15 | for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]): 16 | # selecting span predictions 17 | for predicted_title, predicted_spans in patch_info[ 18 | "predicted_window_labels" 19 | ].items(): 20 | for pred_span in predicted_spans: 21 | pred_span = tuple(pred_span) 22 | curr_title = span2title.get(pred_span) 23 | if curr_title is None or curr_title == NME_SYMBOL: 24 | span2title[pred_span] = predicted_title 25 | # else: 26 | # print("Merging at patch level") 27 | 28 | # selecting span predictions probability 29 | for predicted_span, titles_probabilities in patch_info[ 30 | "span_title_probabilities" 31 | ].items(): 32 | if predicted_span not in span_title_probabilities: 33 | span_title_probabilities[predicted_span] = titles_probabilities 34 | 35 | for span, title in span2title.items(): 36 | if title not in predicted_window_labels: 37 | predicted_window_labels[title] = list() 38 | predicted_window_labels[title].append(span) 39 | 40 | 41 | def remove_duplicate_samples( 42 | samples: List[WSLReaderSample], 43 | ) -> List[WSLReaderSample]: 44 | seen_sample = set() 45 | samples_store = [] 46 | for sample in samples: 47 | sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}" 48 | if sample_id not in seen_sample: 49 | seen_sample.add(sample_id) 50 | samples_store.append(sample) 51 | return samples_store 52 | -------------------------------------------------------------------------------- /wsl/reader/data/wsl_reader_sample.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Iterable 3 | 4 | import numpy as np 5 | 6 | 7 | class NpEncoder(json.JSONEncoder): 8 | def default(self, obj): 9 | if isinstance(obj, np.integer): 10 | return int(obj) 11 | if isinstance(obj, np.floating): 12 | return float(obj) 13 | if isinstance(obj, np.ndarray): 14 | return obj.tolist() 15 | return super(NpEncoder, self).default(obj) 16 | 17 | 18 | class WSLReaderSample: 19 | def __init__(self, **kwargs): 20 | super().__setattr__("_d", {}) 21 | self._d = kwargs 22 | 23 | def __getattribute__(self, item): 24 | return super(WSLReaderSample, self).__getattribute__(item) 25 | 26 | def __getattr__(self, item): 27 | if item.startswith("__") and item.endswith("__"): 28 | # this is likely some python library-specific variable (such as __deepcopy__ for copy) 29 | # better follow standard behavior here 30 | raise AttributeError(item) 31 | elif item in self._d: 32 | return self._d[item] 33 | else: 34 | return None 35 | 36 | def __setattr__(self, key, value): 37 | if key in self._d: 38 | self._d[key] = value 39 | else: 
40 |             super().__setattr__(key, value)
41 | 
42 |     def to_jsons(self) -> str:
43 |         if "predicted_window_labels" in self._d:
44 |             new_obj = {
45 |                 k: v
46 |                 for k, v in self._d.items()
47 |                 if k != "predicted_window_labels" and k != "span_title_probabilities"
48 |             }
49 |             new_obj["predicted_window_labels"] = [
50 |                 [ss, se, pred_title]
51 |                 for (ss, se), pred_title in self.predicted_window_labels_chars
52 |             ]
53 |             return json.dumps(new_obj, cls=NpEncoder)
54 |         else:
55 |             return json.dumps(self._d, cls=NpEncoder)
56 | 
57 |     def to_dict(self) -> dict:
58 |         return self._d
59 | 
60 | 
61 | def load_wsl_reader_samples(path: str) -> Iterable[WSLReaderSample]:
62 |     with open(path) as f:
63 |         for line in f:
64 |             jsonl_line = json.loads(line.strip())
65 |             wsl_reader_sample = WSLReaderSample(**jsonl_line)
66 |             yield wsl_reader_sample
67 | 
--------------------------------------------------------------------------------
/wsl/inference/data/splitters/window_based_splitter.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 | 
3 | from wsl.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
4 | 
5 | 
6 | class WindowSentenceSplitter(BaseSentenceSplitter):
7 |     """
8 |     A :obj:`WindowSentenceSplitter` that splits a text into windows of a given size.
9 |     """
10 | 
11 |     def __init__(self, window_size: int, window_stride: int, *args, **kwargs) -> None:
12 |         super(WindowSentenceSplitter, self).__init__()
13 |         self.window_size = window_size
14 |         self.window_stride = window_stride
15 | 
16 |     def __call__(
17 |         self,
18 |         texts: Union[str, List[str], List[List[str]]],
19 |         is_split_into_words: bool = False,
20 |         **kwargs,
21 |     ) -> Union[List[str], List[List[str]]]:
22 |         """
23 |         Split the input text into fixed-size windows of tokens.
24 | 
25 |         Args:
26 |             texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
27 |                 Text to split. It can be a single string, a batch of strings or pre-tokenized strings.
28 | 
29 |         Returns:
30 |             :obj:`List[List[str]]`: The input text split into windows.
31 |         """
32 |         return self.split_sentences(texts)
33 | 
34 |     def split_sentences(self, text: str | List, *args, **kwargs) -> List[List]:
35 |         """
36 |         Splits a `text` into windows of tokens.
37 | 
38 |         Args:
39 |             text (:obj:`str`):
40 |                 Text to split.
41 | 
42 |         Returns:
43 |             :obj:`List[str]`: The input text split into windows.
44 |         """
45 | 
46 |         if isinstance(text, str):
47 |             text = text.split()
48 |         sentences = []
49 |         # if window_stride is zero, we don't need overlapping windows
50 |         self.window_stride = (
51 |             self.window_stride if self.window_stride != 0 else self.window_size
52 |         )
53 |         for i in range(0, len(text), self.window_stride):
54 |             # if the last stride is smaller than the window size, then we can
55 |             # include more tokens from the previous window.
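            # Illustrative walk-through of the adjustment below: with
            # window_size=5, window_stride=3 and a 7-token text, the second
            # window would start at i=3 and run 1 token past the end, so i is
            # moved back to 2 and the final window covers tokens 2..6 instead
            # of a short window 3..6.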
56 | if i != 0 and i + self.window_size > len(text): 57 | overflowing_tokens = i + self.window_size - len(text) 58 | if overflowing_tokens >= self.window_stride: 59 | break 60 | i -= overflowing_tokens 61 | involved_token_indices = list( 62 | range(i, min(i + self.window_size, len(text))) 63 | ) 64 | window_tokens = [text[j] for j in involved_token_indices] 65 | sentences.append(window_tokens) 66 | return sentences 67 | -------------------------------------------------------------------------------- /wsl/retriever/data/base/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any, Dict, Iterator, List, Optional, Tuple, Union 4 | 5 | import torch 6 | from torch.utils.data import Dataset, IterableDataset 7 | 8 | from wsl.common.log import get_logger 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | class BaseDataset(Dataset): 14 | def __init__( 15 | self, 16 | name: str, 17 | path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None, 18 | data: Any = None, 19 | **kwargs, 20 | ): 21 | super().__init__() 22 | self.name = name 23 | if path is None and data is None: 24 | raise ValueError("Either `path` or `data` must be provided") 25 | self.path = path 26 | self.project_folder = Path(__file__).parent.parent.parent 27 | self.data = data 28 | 29 | def __len__(self) -> int: 30 | return len(self.data) 31 | 32 | def __getitem__( 33 | self, index 34 | ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: 35 | return self.data[index] 36 | 37 | def __repr__(self) -> str: 38 | return f"Dataset({self.name=}, {self.path=})" 39 | 40 | def load( 41 | self, 42 | paths: Union[str, os.PathLike, List[str], List[os.PathLike]], 43 | *args, 44 | **kwargs, 45 | ) -> Any: 46 | # load data from single or multiple paths in one single dataset 47 | raise NotImplementedError 48 | 49 | @staticmethod 50 | def collate_fn(batch: Any, *args, **kwargs) -> Any: 51 | raise NotImplementedError 52 | 53 | 54 | class IterableBaseDataset(IterableDataset): 55 | def __init__( 56 | self, 57 | name: str, 58 | path: Optional[Union[str, Path, List[str], List[Path]]] = None, 59 | data: Any = None, 60 | *args, 61 | **kwargs, 62 | ): 63 | super().__init__() 64 | self.name = name 65 | if path is None and data is None: 66 | raise ValueError("Either `path` or `data` must be provided") 67 | self.path = path 68 | self.project_folder = Path(__file__).parent.parent.parent 69 | self.data = data 70 | 71 | def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: 72 | for sample in self.data: 73 | yield sample 74 | 75 | def __repr__(self) -> str: 76 | return f"Dataset({self.name=}, {self.path=})" 77 | 78 | def load( 79 | self, 80 | paths: Union[str, os.PathLike, List[str], List[os.PathLike]], 81 | *args, 82 | **kwargs, 83 | ) -> Any: 84 | # load data from single or multiple paths in one single dataset 85 | raise NotImplementedError 86 | 87 | @staticmethod 88 | def collate_fn(batch: Any, *args, **kwargs) -> Any: 89 | raise NotImplementedError 90 | -------------------------------------------------------------------------------- /wsl/inference/data/tokenizers/base_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union 2 | 3 | from wsl.inference.data.objects import Word 4 | 5 | 6 | class BaseTokenizer: 7 | """ 8 | A :obj:`Tokenizer` splits strings of text into single words, optionally adds 9 | pos tags and perform lemmatization. 
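
    Example (illustrative subclass; ``SpacyTokenizer`` in this package is the
    real implementation):

        >>> class WhitespaceTokenizer(BaseTokenizer):
        ...     def tokenize(self, text):
        ...         return [Word(t, i) for i, t in enumerate(text.split())]
        >>> WhitespaceTokenizer().tokenize_batch(["Bus drivers", "drive"])
        [[Bus, drivers], [drive]]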
10 | """ 11 | 12 | def __call__( 13 | self, 14 | texts: Union[str, List[str], List[List[str]]], 15 | is_split_into_words: bool = False, 16 | **kwargs 17 | ) -> List[List[Word]]: 18 | """ 19 | Tokenize the input into single words. 20 | 21 | Args: 22 | texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): 23 | Text to tag. It can be a single string, a batch of string and pre-tokenized strings. 24 | is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`): 25 | If :obj:`True` and the input is a string, the input is split on spaces. 26 | 27 | Returns: 28 | :obj:`List[List[Word]]`: The input text tokenized in single words. 29 | """ 30 | raise NotImplementedError 31 | 32 | def tokenize(self, text: str) -> List[Word]: 33 | """ 34 | Implements splitting words into tokens. 35 | 36 | Args: 37 | text (:obj:`str`): 38 | Text to tokenize. 39 | 40 | Returns: 41 | :obj:`List[Word]`: The input text tokenized in single words. 42 | 43 | """ 44 | raise NotImplementedError 45 | 46 | def tokenize_batch(self, texts: List[str]) -> List[List[Word]]: 47 | """ 48 | Implements batch splitting words into tokens. 49 | 50 | Args: 51 | texts (:obj:`List[str]`): 52 | Batch of text to tokenize. 53 | 54 | Returns: 55 | :obj:`List[List[Word]]`: The input batch tokenized in single words. 56 | 57 | """ 58 | return [self.tokenize(text) for text in texts] 59 | 60 | @staticmethod 61 | def check_is_batched( 62 | texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool 63 | ): 64 | """ 65 | Check if input is batched or a single sample. 66 | 67 | Args: 68 | texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`): 69 | Text to check. 70 | is_split_into_words (:obj:`bool`): 71 | If :obj:`True` and the input is a string, the input is split on spaces. 72 | 73 | Returns: 74 | :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise. 
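
        Example (illustrative):

            >>> BaseTokenizer.check_is_batched(["One text.", "Another text."], is_split_into_words=False)
            True
            >>> BaseTokenizer.check_is_batched("One text.", is_split_into_words=False)
            False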
75 | """ 76 | return bool( 77 | (not is_split_into_words and isinstance(texts, (list, tuple))) 78 | or ( 79 | is_split_into_words 80 | and isinstance(texts, (list, tuple)) 81 | and texts 82 | and isinstance(texts[0], (list, tuple)) 83 | ) 84 | ) 85 | -------------------------------------------------------------------------------- /wsl/common/torch_utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import torch 4 | import transformers as tr 5 | 6 | from wsl.common.utils import is_package_available 7 | 8 | # check if ORT is available 9 | if is_package_available("onnxruntime"): 10 | from optimum.onnxruntime import ( 11 | ORTModel, 12 | ORTModelForCustomTasks, 13 | ORTModelForSequenceClassification, 14 | ORTOptimizer, 15 | ) 16 | from optimum.onnxruntime.configuration import AutoOptimizationConfig 17 | 18 | # from wsl.retriever.pytorch_modules import PRECISION_MAP 19 | 20 | 21 | def get_autocast_context( 22 | device: str | torch.device, precision: str 23 | ) -> contextlib.AbstractContextManager: 24 | # fucking autocast only wants pure strings like 'cpu' or 'cuda' 25 | # we need to convert the model device to that 26 | device_type_for_autocast = str(device).split(":")[0] 27 | 28 | from wsl.retriever.pytorch_modules import PRECISION_MAP 29 | 30 | # autocast doesn't work with CPU and stuff different from bfloat16 31 | autocast_manager = ( 32 | contextlib.nullcontext() 33 | if device_type_for_autocast in ["cpu", "mps"] 34 | and PRECISION_MAP[precision] != torch.bfloat16 35 | else ( 36 | torch.autocast( 37 | device_type=device_type_for_autocast, 38 | dtype=PRECISION_MAP[precision], 39 | ) 40 | ) 41 | ) 42 | return autocast_manager 43 | 44 | 45 | # def load_ort_optimized_hf_model( 46 | # hf_model: tr.PreTrainedModel, 47 | # provider: str = "CPUExecutionProvider", 48 | # ort_model_type: callable = "ORTModelForCustomTasks", 49 | # ) -> ORTModel: 50 | # """ 51 | # Load an optimized ONNX Runtime HF model. 52 | # 53 | # Args: 54 | # hf_model (`tr.PreTrainedModel`): 55 | # The HF model to optimize. 56 | # provider (`str`, optional): 57 | # The ONNX Runtime provider to use. Defaults to "CPUExecutionProvider". 58 | # 59 | # Returns: 60 | # `ORTModel`: The optimized HF model. 61 | # """ 62 | # if isinstance(hf_model, ORTModel): 63 | # return hf_model 64 | # temp_dir = tempfile.mkdtemp() 65 | # hf_model.save_pretrained(temp_dir) 66 | # ort_model = ort_model_type.from_pretrained( 67 | # temp_dir, export=True, provider=provider, use_io_binding=True 68 | # ) 69 | # if is_package_available("onnxruntime"): 70 | # optimizer = ORTOptimizer.from_pretrained(ort_model) 71 | # optimization_config = AutoOptimizationConfig.O4() 72 | # optimizer.optimize(save_dir=temp_dir, optimization_config=optimization_config) 73 | # ort_model = ort_model_type.from_pretrained( 74 | # temp_dir, 75 | # export=True, 76 | # provider=provider, 77 | # use_io_binding=bool(provider == "CUDAExecutionProvider"), 78 | # ) 79 | # return ort_model 80 | # else: 81 | # raise ValueError("onnxruntime is not installed. 
Please install Ray with `pip install wsl[serve]`.") 82 | -------------------------------------------------------------------------------- /wsl/retriever/pytorch_modules/hf.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import torch 4 | from transformers import PretrainedConfig 5 | from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions 6 | from transformers.models.bert.modeling_bert import BertModel 7 | 8 | 9 | class WSLRetrieverConfig(PretrainedConfig): 10 | model_type = "bert" 11 | 12 | def __init__( 13 | self, 14 | vocab_size=30522, 15 | hidden_size=768, 16 | num_hidden_layers=12, 17 | num_attention_heads=12, 18 | intermediate_size=3072, 19 | hidden_act="gelu", 20 | hidden_dropout_prob=0.1, 21 | attention_probs_dropout_prob=0.1, 22 | max_position_embeddings=512, 23 | type_vocab_size=2, 24 | initializer_range=0.02, 25 | layer_norm_eps=1e-12, 26 | pad_token_id=0, 27 | position_embedding_type="absolute", 28 | use_cache=True, 29 | classifier_dropout=None, 30 | **kwargs, 31 | ): 32 | super().__init__(pad_token_id=pad_token_id, **kwargs) 33 | 34 | self.vocab_size = vocab_size 35 | self.hidden_size = hidden_size 36 | self.num_hidden_layers = num_hidden_layers 37 | self.num_attention_heads = num_attention_heads 38 | self.hidden_act = hidden_act 39 | self.intermediate_size = intermediate_size 40 | self.hidden_dropout_prob = hidden_dropout_prob 41 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 42 | self.max_position_embeddings = max_position_embeddings 43 | self.type_vocab_size = type_vocab_size 44 | self.initializer_range = initializer_range 45 | self.layer_norm_eps = layer_norm_eps 46 | self.position_embedding_type = position_embedding_type 47 | self.use_cache = use_cache 48 | self.classifier_dropout = classifier_dropout 49 | 50 | 51 | class WSLRetrieverModel(BertModel): 52 | config_class = WSLRetrieverConfig 53 | 54 | def __init__(self, config, *args, **kwargs): 55 | super().__init__(config) 56 | 57 | def forward( 58 | self, **kwargs 59 | ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: 60 | attention_mask = kwargs.get("attention_mask", None) 61 | model_outputs = super().forward(**kwargs) 62 | if attention_mask is None: 63 | pooler_output = model_outputs.pooler_output 64 | else: 65 | token_embeddings = model_outputs.last_hidden_state 66 | input_mask_expanded = ( 67 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 68 | ) 69 | pooler_output = torch.sum( 70 | token_embeddings * input_mask_expanded, 1 71 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 72 | 73 | if not kwargs.get("return_dict", True): 74 | return (model_outputs[0], pooler_output) + model_outputs[2:] 75 | 76 | return BaseModelOutputWithPoolingAndCrossAttentions( 77 | last_hidden_state=model_outputs.last_hidden_state, 78 | pooler_output=pooler_output, 79 | past_key_values=model_outputs.past_key_values, 80 | hidden_states=model_outputs.hidden_states, 81 | attentions=model_outputs.attentions, 82 | cross_attentions=model_outputs.cross_attentions, 83 | ) 84 | -------------------------------------------------------------------------------- /wsl/inference/data/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | SPACY_LANGUAGE_MAPPER = { 2 | "ca": "ca_core_news_sm", 3 | "da": "da_core_news_sm", 4 | "de": "de_core_news_sm", 5 | "el": "el_core_news_sm", 6 | "en": "en_core_web_sm", 7 | "es": 
"es_core_news_sm", 8 | "fr": "fr_core_news_sm", 9 | "it": "it_core_news_sm", 10 | "ja": "ja_core_news_sm", 11 | "lt": "lt_core_news_sm", 12 | "mk": "mk_core_news_sm", 13 | "nb": "nb_core_news_sm", 14 | "nl": "nl_core_news_sm", 15 | "pl": "pl_core_news_sm", 16 | "pt": "pt_core_news_sm", 17 | "ro": "ro_core_news_sm", 18 | "ru": "ru_core_news_sm", 19 | "xx": "xx_sent_ud_sm", 20 | "zh": "zh_core_web_sm", 21 | "ca_core_news_sm": "ca_core_news_sm", 22 | "ca_core_news_md": "ca_core_news_md", 23 | "ca_core_news_lg": "ca_core_news_lg", 24 | "ca_core_news_trf": "ca_core_news_trf", 25 | "da_core_news_sm": "da_core_news_sm", 26 | "da_core_news_md": "da_core_news_md", 27 | "da_core_news_lg": "da_core_news_lg", 28 | "da_core_news_trf": "da_core_news_trf", 29 | "de_core_news_sm": "de_core_news_sm", 30 | "de_core_news_md": "de_core_news_md", 31 | "de_core_news_lg": "de_core_news_lg", 32 | "de_dep_news_trf": "de_dep_news_trf", 33 | "el_core_news_sm": "el_core_news_sm", 34 | "el_core_news_md": "el_core_news_md", 35 | "el_core_news_lg": "el_core_news_lg", 36 | "en_core_web_sm": "en_core_web_sm", 37 | "en_core_web_md": "en_core_web_md", 38 | "en_core_web_lg": "en_core_web_lg", 39 | "en_core_web_trf": "en_core_web_trf", 40 | "es_core_news_sm": "es_core_news_sm", 41 | "es_core_news_md": "es_core_news_md", 42 | "es_core_news_lg": "es_core_news_lg", 43 | "es_dep_news_trf": "es_dep_news_trf", 44 | "fr_core_news_sm": "fr_core_news_sm", 45 | "fr_core_news_md": "fr_core_news_md", 46 | "fr_core_news_lg": "fr_core_news_lg", 47 | "fr_dep_news_trf": "fr_dep_news_trf", 48 | "it_core_news_sm": "it_core_news_sm", 49 | "it_core_news_md": "it_core_news_md", 50 | "it_core_news_lg": "it_core_news_lg", 51 | "ja_core_news_sm": "ja_core_news_sm", 52 | "ja_core_news_md": "ja_core_news_md", 53 | "ja_core_news_lg": "ja_core_news_lg", 54 | "ja_dep_news_trf": "ja_dep_news_trf", 55 | "lt_core_news_sm": "lt_core_news_sm", 56 | "lt_core_news_md": "lt_core_news_md", 57 | "lt_core_news_lg": "lt_core_news_lg", 58 | "mk_core_news_sm": "mk_core_news_sm", 59 | "mk_core_news_md": "mk_core_news_md", 60 | "mk_core_news_lg": "mk_core_news_lg", 61 | "nb_core_news_sm": "nb_core_news_sm", 62 | "nb_core_news_md": "nb_core_news_md", 63 | "nb_core_news_lg": "nb_core_news_lg", 64 | "nl_core_news_sm": "nl_core_news_sm", 65 | "nl_core_news_md": "nl_core_news_md", 66 | "nl_core_news_lg": "nl_core_news_lg", 67 | "pl_core_news_sm": "pl_core_news_sm", 68 | "pl_core_news_md": "pl_core_news_md", 69 | "pl_core_news_lg": "pl_core_news_lg", 70 | "pt_core_news_sm": "pt_core_news_sm", 71 | "pt_core_news_md": "pt_core_news_md", 72 | "pt_core_news_lg": "pt_core_news_lg", 73 | "ro_core_news_sm": "ro_core_news_sm", 74 | "ro_core_news_md": "ro_core_news_md", 75 | "ro_core_news_lg": "ro_core_news_lg", 76 | "ru_core_news_sm": "ru_core_news_sm", 77 | "ru_core_news_md": "ru_core_news_md", 78 | "ru_core_news_lg": "ru_core_news_lg", 79 | "xx_ent_wiki_sm": "xx_ent_wiki_sm", 80 | "xx_sent_ud_sm": "xx_sent_ud_sm", 81 | "zh_core_web_sm": "zh_core_web_sm", 82 | "zh_core_web_md": "zh_core_web_md", 83 | "zh_core_web_lg": "zh_core_web_lg", 84 | "zh_core_web_trf": "zh_core_web_trf", 85 | } 86 | 87 | from wsl.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer 88 | -------------------------------------------------------------------------------- /wsl/inference/data/objects.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from dataclasses import dataclass 4 | from enum import Enum 5 | from typing 
import Dict, List, NamedTuple, Optional
6 | 
7 | from wsl.reader.pytorch_modules.hf.modeling_wsl import WSLReaderSample
8 | from wsl.retriever.indexers.document import Document
9 | 
10 | 
11 | @dataclass
12 | class Word:
13 |     """
14 |     A word representation that includes text, index in the sentence, POS tag, lemma,
15 |     dependency relation, and similar information.
16 | 
17 |     # Parameters
18 |     text : `str`, optional
19 |         The text representation.
20 |     i : `int`, optional
21 |         The word offset in the sentence.
22 |     lemma : `str`, optional
23 |         The lemma of this word.
24 |     pos : `str`, optional
25 |         The coarse-grained part of speech of this word.
26 |     dep : `str`, optional
27 |         The dependency relation for this word.
28 | 
29 |     input_id : `int`, optional
30 |         Integer representation of the word, used to pass it to a model.
31 |     token_type_id : `int`, optional
32 |         Token type id used by some transformers.
33 |     attention_mask: `int`, optional
34 |         Attention mask used by transformers, indicates to the model which tokens should
35 |         be attended to, and which should not.
36 |     """
37 | 
38 |     text: str
39 |     i: int
40 |     idx: Optional[int] = None
41 |     idx_end: Optional[int] = None
42 |     # preprocessing fields
43 |     lemma: Optional[str] = None
44 |     pos: Optional[str] = None
45 |     dep: Optional[str] = None
46 |     head: Optional[int] = None
47 | 
48 |     def __str__(self):
49 |         return self.text
50 | 
51 |     def __repr__(self):
52 |         return self.__str__()
53 | 
54 | 
55 | class Span(NamedTuple):
56 |     start: int
57 |     end: int
58 |     label: str
59 |     text: str
60 | 
61 | 
62 | class Candidates(NamedTuple):
63 |     candidates: List[List[Document]]
64 | 
65 | 
66 | @dataclass
67 | class WSLOutput:
68 |     """
69 |     Represents the output of the WSL pipeline.
70 | 
71 |     Attributes:
72 |         text (str):
73 |             The original input text.
74 |         tokens (List[str]):
75 |             The list of tokens generated from the input text.
76 |         spans (List[Span]):
77 |             The list of spans extracted from the input text.
78 |         candidates (Candidates):
79 |             The candidates for the extracted spans, generated by the retriever.
80 |             The documents are stored in a list of lists: the outer list
81 |             represents the windows, and the inner list represents the documents in that window.
82 |             If only one window is used, the outer list will have only one element.
83 |         windows (Optional[List[WSLReaderSample]]):
84 |             The list of windows used for processing the input text.
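
    Example (illustrative; mirrors the README):

        >>> wsl_output = WSL.from_pretrained("Babelscape/wsl-base")(
        ...     "Bus drivers drive busses for a living."
        ... )
        >>> [span.text for span in wsl_output.spans]
        ['Bus drivers', 'drive', 'busses', 'living']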
85 |     """
86 | 
87 |     text: str
88 |     tokens: List[str]
89 |     id: str | int
90 |     spans: List[Span]
91 |     candidates: Optional[Candidates] = None
92 |     windows: Optional[List[WSLReaderSample]] = None
93 | 
94 |     # convert to dict
95 |     def to_dict(self):
96 |         self_dict = {
97 |             "text": self.text,
98 |             "tokens": self.tokens,
99 |             "spans": self.spans,
100 |             "candidates": {
101 |                 # one list of candidate documents per window
102 |                 "span": [
103 |                     [doc.to_dict() for doc in window]
104 |                     for window in self.candidates.candidates
105 |                 ],
106 |             },
107 |         }
108 |         if self.windows is not None:
109 |             self_dict["windows"] = [window.to_dict() for window in self.windows]
110 |         return self_dict
111 | 
112 | 
113 | class AnnotationType(Enum):
114 |     CHAR = "char"
115 |     WORD = "word"
116 | 
117 | 
118 | class TaskType(Enum):
119 |     SPAN = "span"
120 |     BOTH = "both"
121 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 
5 | # Word Sense Linking: Disambiguating Outside the Sandbox
6 | 
7 | 
8 | [![Conference](http://img.shields.io/badge/ACL-2024-4b44ce.svg)](https://2024.aclweb.org/)
9 | [![Paper](http://img.shields.io/badge/paper-ACL--anthology-B31B1B.svg)](https://aclanthology.org/2024.findings-acl.851/)
10 | [![Hugging Face Collection](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-FCD21D)](https://huggingface.co/collections/Babelscape/word-sense-linking-66ace2182bc45680964cefcb)
11 | 
12 | ![Sapienza NLP and Babelscape logos](./assets/Sapienza_Babelscape.png)
13 | 
14 | 
15 |
16 | 17 |
18 | 
19 | With this work we introduce a new task: **Word Sense Linking (WSL)**. WSL enhances Word Sense Disambiguation by carrying out both candidate identification (new!) and candidate disambiguation. Our Word Sense Linking model identifies spans of text and disambiguates them against their most suitable senses from a reference inventory.
20 | 
21 | ## Installation
22 | 
23 | Installation from source:
24 | 
25 | ```bash
26 | git clone https://github.com/Babelscape/WSL
27 | cd WSL
28 | pip install .
29 | ```
30 | 
31 | 
32 | ## Usage
33 | 
34 | WSL is composed of two main components: a retriever and a reader.
35 | The retriever is responsible for retrieving relevant senses from a sense inventory (e.g., WordNet),
36 | while the reader is responsible for extracting spans from the input text and linking them to the retrieved documents.
37 | WSL can be used with the `from_pretrained` method to load a pre-trained pipeline.
38 | 
39 | ```python
40 | from wsl import WSL
41 | from wsl.inference.data.objects import WSLOutput
42 | 
43 | wsl_model = WSL.from_pretrained("Babelscape/wsl-base")
44 | wsl_output: WSLOutput = wsl_model("Bus drivers drive busses for a living.")
45 | ```
46 | 
47 |     WSLOutput(
48 |         text='Bus drivers drive busses for a living.',
49 |         tokens=['Bus', 'drivers', 'drive', 'busses', 'for', 'a', 'living', '.'],
50 |         id=0,
51 |         spans=[
52 |             Span(start=0, end=11, label='bus driver: someone who drives a bus', text='Bus drivers'),
53 |             Span(start=12, end=17, label='drive: operate or control a vehicle', text='drive'),
54 |             Span(start=18, end=24, label='bus: a vehicle carrying many passengers; used for public transport', text='busses'),
55 |             Span(start=31, end=37, label='living: the financial means whereby one lives', text='living')
56 |         ],
57 |         candidates=Candidates(
58 |             candidates=[
59 |                 {"text": "bus driver: someone who drives a bus", "id": "bus_driver%1:18:00::", "metadata": {}},
60 |                 {"text": "driver: the operator of a motor vehicle", "id": "driver%1:18:00::", "metadata": {}},
61 |                 {"text": "driver: someone who drives animals that pull a vehicle", "id": "driver%1:18:02::", "metadata": {}},
62 |                 {"text": "bus: a vehicle carrying many passengers; used for public transport", "id": "bus%1:06:00::", "metadata": {}},
63 |                 {"text": "living: the financial means whereby one lives", "id": "living%1:26:00::", "metadata": {}}
64 |             ]
65 |         ),
66 |     )
67 | 
68 | 
69 | 
70 | ## Model Performance
71 | 
72 | Here you can find the performance of our model on the [WSL evaluation dataset](https://huggingface.co/datasets/Babelscape/wsl).
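
The scores reported below are span-level precision (P), recall (R) and F1, where a prediction counts as correct only if both the span and the linked sense match the gold annotation. For reference, a minimal sketch of how these scores are computed, mirroring `wsl/reader/utils/metrics.py` from this repository:

```python
def safe_divide(num: float, den: float) -> float:
    # avoid division by zero when there are no predictions or no gold spans
    return num / den if den != 0 else 0


def compute_metrics(total_correct: int, total_preds: int, total_gold: int):
    precision = safe_divide(total_correct, total_preds)  # correct / predicted
    recall = safe_divide(total_correct, total_gold)      # correct / gold
    f1 = safe_divide(2 * precision * recall, precision + recall)
    return precision, recall, f1
```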
73 | 74 | ### Validation (SE07) 75 | 76 | | Models | P | R | F1 | 77 | |--------------|------|--------|--------| 78 | | BEM_SUP | 67.6 | 40.9 | 51.0 | 79 | | BEM_HEU | 70.8 | 51.2 | 59.4 | 80 | | ConSeC_SUP | 76.4 | 46.5 | 57.8 | 81 | | ConSeC_HEU | **76.7** | 55.4 | 64.3 | 82 | | **Our Model**| 73.8 | **74.9** | **74.4** | 83 | 84 | ### Test (ALL_FULL) 85 | 86 | | Models | P | R | F1 | 87 | |--------------|------|--------|--------| 88 | | BEM_SUP | 74.8 | 50.7 | 60.4 | 89 | | BEM_HEU | 76.6 | 61.2 | 68.0 | 90 | | ConSeC_SUP | 78.9 | 53.1 | 63.5 | 91 | | ConSeC_HEU | **80.4** | 64.3 | 71.5 | 92 | | **Our Model**| 75.2 | **76.7** | **75.9** | 93 | 94 | 95 | ## Cite this work 96 | 97 | If you use any part of this work, please consider citing the paper as follows: 98 | 99 | ```bibtex 100 | @inproceedings{bejgu-etal-2024-wsl, 101 | title = "Word Sense Linking: Disambiguating Outside the Sandbox", 102 | author = "Bejgu, Andrei Stefan and Barba, Edoardo and Procopio, Luigi and Fern{\'a}ndez-Castro, Alberte and Navigli, Roberto", 103 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2024", 104 | month = aug, 105 | year = "2024", 106 | address = "Bangkok, Thailand", 107 | publisher = "Association for Computational Linguistics", 108 | } 109 | ``` 110 | 111 | ## License 112 | 113 | The data and software are licensed under cc-by-nc-sa-4.0 you can read it here [Creative Commons Attribution-NonCommercial-ShareAlike 4.0](./wsl_data_license.txt). 114 | 115 | -------------------------------------------------------------------------------- /wsl/common/upload.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import tempfile 6 | import zipfile 7 | from datetime import datetime 8 | from pathlib import Path 9 | from typing import Optional, Union 10 | 11 | import huggingface_hub 12 | 13 | from wsl.common.log import get_logger 14 | from wsl.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5 15 | 16 | logger = get_logger(__name__, level=logging.DEBUG) 17 | 18 | 19 | def create_info_file(tmpdir: Path): 20 | logger.debug("Computing md5 of model.zip") 21 | md5 = get_md5(tmpdir / "model.zip") 22 | date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT) 23 | 24 | logger.debug("Dumping info.json file") 25 | with (tmpdir / "info.json").open("w") as f: 26 | json.dump(dict(md5=md5, upload_date=date), f, indent=2) 27 | 28 | 29 | def zip_run( 30 | dir_path: Union[str, os.PathLike], 31 | tmpdir: Union[str, os.PathLike], 32 | zip_name: str = "model.zip", 33 | ) -> Path: 34 | logger.debug(f"zipping {dir_path} to {tmpdir}") 35 | # creates a zip version of the provided dir_path 36 | run_dir = Path(dir_path) 37 | zip_path = tmpdir / zip_name 38 | 39 | with zipfile.ZipFile(zip_path, "w") as zip_file: 40 | # fully zip the run directory maintaining its structure 41 | for file in run_dir.rglob("*.*"): 42 | if file.is_dir(): 43 | continue 44 | 45 | zip_file.write(file, arcname=file.relative_to(run_dir)) 46 | 47 | return zip_path 48 | 49 | 50 | def get_logged_in_username(): 51 | token = huggingface_hub.HfFolder.get_token() 52 | if token is None: 53 | raise ValueError( 54 | "No HuggingFace token found. You need to execute `huggingface-cli login` first!" 
55 | ) 56 | api = huggingface_hub.HfApi() 57 | user = api.whoami(token=token) 58 | return user["name"] 59 | 60 | 61 | def upload( 62 | model_dir: Union[str, os.PathLike], 63 | model_name: str, 64 | filenames: Optional[list[str]] = None, 65 | organization: Optional[str] = None, 66 | repo_name: Optional[str] = None, 67 | commit: Optional[str] = None, 68 | archive: bool = False, 69 | ): 70 | token = huggingface_hub.HfFolder.get_token() 71 | if token is None: 72 | raise ValueError( 73 | "No HuggingFace token found. You need to execute `huggingface-cli login` first!" 74 | ) 75 | 76 | repo_id = repo_name or model_name 77 | if organization is not None: 78 | repo_id = f"{organization}/{repo_id}" 79 | with tempfile.TemporaryDirectory() as tmpdir: 80 | api = huggingface_hub.HfApi() 81 | repo_url = api.create_repo( 82 | token=token, 83 | repo_id=repo_id, 84 | exist_ok=True, 85 | ) 86 | repo = huggingface_hub.Repository( 87 | str(tmpdir), clone_from=repo_url, use_auth_token=token 88 | ) 89 | 90 | tmp_path = Path(tmpdir) 91 | if archive: 92 | # otherwise we zip the model_dir 93 | logger.debug(f"Zipping {model_dir} to {tmp_path}") 94 | zip_run(model_dir, tmp_path) 95 | create_info_file(tmp_path) 96 | else: 97 | # if the user wants to upload a transformers model, we don't need to zip it 98 | # we just need to copy the files to the tmpdir 99 | logger.debug(f"Copying {model_dir} to {tmpdir}") 100 | # copy only the files that are needed 101 | if filenames is not None: 102 | for filename in filenames: 103 | os.system(f"cp {model_dir}/{filename} {tmpdir}") 104 | else: 105 | os.system(f"cp -r {model_dir}/* {tmpdir}") 106 | 107 | # this method automatically puts large files (>10MB) into git lfs 108 | repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp") 109 | 110 | 111 | def parse_args() -> argparse.Namespace: 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument( 114 | "model_dir", help="The directory of the model you want to upload" 115 | ) 116 | parser.add_argument("model_name", help="The model you want to upload") 117 | parser.add_argument( 118 | "--organization", 119 | help="the name of the organization where you want to upload the model", 120 | ) 121 | parser.add_argument( 122 | "--repo_name", 123 | help="Optional name to use when uploading to the HuggingFace repository", 124 | ) 125 | parser.add_argument( 126 | "--commit", help="Commit message to use when pushing to the HuggingFace Hub" 127 | ) 128 | parser.add_argument( 129 | "--archive", 130 | action="store_true", 131 | help=""" 132 | Whether to compress the model directory before uploading it. 133 | If True, the model directory will be zipped and the zip file will be uploaded. 
134 | If False, the model directory will be uploaded as is.""", 135 | ) 136 | return parser.parse_args() 137 | 138 | 139 | def main(): 140 | upload(**vars(parse_args())) 141 | 142 | 143 | if __name__ == "__main__": 144 | main() 145 | -------------------------------------------------------------------------------- /wsl/common/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | import threading 5 | from logging.config import dictConfig 6 | from typing import Any, Dict, Optional 7 | 8 | from art import text2art, tprint 9 | from colorama import Fore, Style, init 10 | from rich import get_console 11 | from termcolor import colored, cprint 12 | 13 | _lock = threading.Lock() 14 | _default_handler: Optional[logging.Handler] = None 15 | 16 | _default_log_level = logging.WARNING 17 | 18 | # fancy logger 19 | _console = get_console() 20 | 21 | 22 | class ColorfulFormatter(logging.Formatter): 23 | """ 24 | Formatter to add coloring to log messages by log type 25 | """ 26 | 27 | COLORS = { 28 | "WARNING": Fore.YELLOW, 29 | "ERROR": Fore.RED, 30 | "CRITICAL": Fore.RED + Style.BRIGHT, 31 | "DEBUG": Fore.CYAN, 32 | # "INFO": Fore.GREEN, 33 | } 34 | 35 | def format(self, record): 36 | record.rank = int(os.getenv("LOCAL_RANK", "0")) 37 | log_message = super().format(record) 38 | return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET 39 | 40 | 41 | DEFAULT_LOGGING_CONFIG: Dict[str, Any] = { 42 | "version": 1, 43 | "formatters": { 44 | "simple": { 45 | "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s", 46 | }, 47 | "colorful": { 48 | "()": ColorfulFormatter, 49 | "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s", 50 | }, 51 | }, 52 | "filters": {}, 53 | "handlers": { 54 | "console": { 55 | "class": "logging.StreamHandler", 56 | "formatter": "simple", 57 | "filters": [], 58 | "stream": sys.stdout, 59 | }, 60 | "color_console": { 61 | "class": "logging.StreamHandler", 62 | "formatter": "colorful", 63 | "filters": [], 64 | "stream": sys.stdout, 65 | }, 66 | }, 67 | "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")}, 68 | "loggers": { 69 | "wsl": { 70 | "handlers": ["color_console"], 71 | "level": "DEBUG", 72 | "propagate": False, 73 | }, 74 | }, 75 | } 76 | 77 | 78 | def configure_logging(**kwargs): 79 | """Configure with default logging""" 80 | init() # Initialize colorama 81 | # merge DEFAULT_LOGGING_CONFIG with kwargs 82 | logger_config = DEFAULT_LOGGING_CONFIG 83 | if kwargs: 84 | logger_config.update(kwargs) 85 | dictConfig(logger_config) 86 | 87 | 88 | def _get_library_name() -> str: 89 | return __name__.split(".")[0] 90 | 91 | 92 | def _get_library_root_logger() -> logging.Logger: 93 | return logging.getLogger(_get_library_name()) 94 | 95 | 96 | def _configure_library_root_logger() -> None: 97 | global _default_handler 98 | 99 | with _lock: 100 | if _default_handler: 101 | # This library has already configured the library root logger. 102 | return 103 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 104 | _default_handler.flush = sys.stderr.flush 105 | 106 | # Apply our default configuration to the library root logger. 
107 |         library_root_logger = _get_library_root_logger()
108 |         library_root_logger.addHandler(_default_handler)
109 |         library_root_logger.setLevel(_default_log_level)
110 |         library_root_logger.propagate = False
111 | 
112 | 
113 | def _reset_library_root_logger() -> None:
114 |     global _default_handler
115 | 
116 |     with _lock:
117 |         if not _default_handler:
118 |             return
119 | 
120 |         library_root_logger = _get_library_root_logger()
121 |         library_root_logger.removeHandler(_default_handler)
122 |         library_root_logger.setLevel(logging.NOTSET)
123 |         _default_handler = None
124 | 
125 | 
126 | def set_log_level(level: int, logger: Optional[logging.Logger] = None) -> None:
127 |     """
128 |     Set the log level.
129 |     Args:
130 |         level (:obj:`int`):
131 |             Logging level.
132 |         logger (:obj:`logging.Logger`, optional):
133 |             Logger whose level should be set. Defaults to the library root logger.
134 |     """
135 |     if not logger:
136 |         _configure_library_root_logger()
137 |         logger = _get_library_root_logger()
138 |     logger.setLevel(level)
139 | 
140 | 
141 | def get_logger(
142 |     name: Optional[str] = None,
143 |     level: Optional[int] = None,
144 |     formatter: Optional[str] = None,
145 |     **kwargs,
146 | ) -> logging.Logger:
147 |     """
148 |     Return a logger with the specified name.
149 |     """
150 | 
151 |     configure_logging(**kwargs)
152 | 
153 |     if name is None:
154 |         name = _get_library_name()
155 | 
156 |     _configure_library_root_logger()
157 | 
158 |     if level is not None:
159 |         set_log_level(level)
160 | 
161 |     if formatter is None:
162 |         formatter = logging.Formatter(
163 |             "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
164 |         )
165 |     elif isinstance(formatter, str):
166 |         # a plain format string may be passed in; wrap it into a Formatter
167 |         formatter = logging.Formatter(formatter)
168 |     _default_handler.setFormatter(formatter)
169 | 
170 |     return logging.getLogger(name)
171 | 
172 | 
173 | def get_console_logger():
174 |     return _console
175 | 
--------------------------------------------------------------------------------
/wsl/reader/utils/strong_matching_eval.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | 
3 | from lightning.pytorch.callbacks import Callback
4 | 
5 | from wsl.reader.data.wsl_reader_sample import WSLReaderSample
6 | from wsl.reader.utils.metrics import f1_measure, safe_divide
7 | from wsl.reader.utils.relik_reader_predictor import WSLReaderPredictor
8 | from wsl.reader.utils.special_symbols import NME_SYMBOL
9 | 
10 | 
11 | class StrongMatching:
12 |     def __call__(self, predicted_samples: List[WSLReaderSample]) -> Dict:
13 |         # accumulators
14 |         correct_predictions = 0
15 |         correct_predictions_at_k = 0
16 |         total_predictions = 0
17 |         total_gold = 0
18 |         correct_span_predictions = 0
19 |         miss_due_to_candidates = 0
20 | 
21 |         # prediction index stats
22 |         avg_correct_predicted_index = []
23 |         avg_wrong_predicted_index = []
24 |         less_index_predictions = []
25 | 
26 |         # collect data from samples
27 |         for sample in predicted_samples:
28 |             predicted_annotations = sample.predicted_window_labels_chars
29 |             predicted_annotations_probabilities = sample.probs_window_labels_chars
30 |             gold_annotations = {
31 |                 (ss, se, entity)
32 |                 for ss, se, entity in sample.window_labels
33 |                 if entity != NME_SYMBOL
34 |             }
35 |             total_predictions += len(predicted_annotations)
36 |             total_gold += len(gold_annotations)
37 | 
38 |             # correct named entity detection
39 |             predicted_spans = {(s, e) for s, e, _ in predicted_annotations}
40 |             gold_spans = {(s, e) for s, e, _ in gold_annotations}
41 |             correct_span_predictions += len(predicted_spans.intersection(gold_spans))
42 | 
43 |             # correct entity linking
44 |             correct_predictions += len(
45 |                 predicted_annotations.intersection(gold_annotations)
46 |             )
47 | 
48 |             for ss, se, ge in gold_annotations.difference(predicted_annotations):
49 |                 if ge not in sample.span_candidates:
50 |                     miss_due_to_candidates += 1
51 |                 if ge in predicted_annotations_probabilities.get((ss, se), set()):
52 |                     correct_predictions_at_k += 1
53 | 
54 |             # indices metrics
55 |             predicted_spans_index = {
56 |                 (ss, se): ent for ss, se, ent in predicted_annotations
57 |             }
58 |             gold_spans_index = {(ss, se): ent for ss, se, ent in gold_annotations}
59 | 
60 |             for pred_span, pred_ent in predicted_spans_index.items():
61 |                 gold_ent = gold_spans_index.get(pred_span)
62 | 
63 |                 if pred_span not in gold_spans_index:
64 |                     continue
65 | 
66 |                 # missing candidate
67 |                 if gold_ent not in sample.span_candidates:
68 |                     continue
69 | 
70 |                 # `list.index` raises instead of returning None; `gold_ent` was
71 |                 # checked above and `pred_ent` always comes from the candidate list
72 |                 gold_idx = sample.span_candidates.index(gold_ent)
73 |                 pred_idx = sample.span_candidates.index(pred_ent)
74 | 
75 |                 if gold_ent != pred_ent:
76 |                     avg_wrong_predicted_index.append(pred_idx)
77 |                     if pred_idx > gold_idx:
78 |                         less_index_predictions.append(0)
79 |                     else:
80 |                         less_index_predictions.append(1)
81 |                 else:
82 |                     avg_correct_predicted_index.append(pred_idx)
83 | 
84 |         # compute NED metrics
85 |         span_precision = safe_divide(correct_span_predictions, total_predictions)
86 |         span_recall = safe_divide(correct_span_predictions, total_gold)
87 |         span_f1 = f1_measure(span_precision, span_recall)
88 | 
89 |         # compute EL metrics
90 |         precision = safe_divide(correct_predictions, total_predictions)
91 |         recall = safe_divide(correct_predictions, total_gold)
92 |         recall_at_k = safe_divide(
93 |             (correct_predictions + correct_predictions_at_k), total_gold
94 |         )
95 | 
96 |         f1 = f1_measure(precision, recall)
97 | 
98 |         wrong_for_candidates = safe_divide(miss_due_to_candidates, total_gold)
99 | 
100 |         out_dict = {
101 |             "span_precision": span_precision,
102 |             "span_recall": span_recall,
103 |             "span_f1": span_f1,
104 |             "core_precision": precision,
105 |             "core_recall": recall,
106 |             "core_recall-at-k": recall_at_k,
107 |             "core_f1": round(f1, 4),
108 |             "wrong-for-candidates": wrong_for_candidates,
109 |             "index_errors_avg-index": safe_divide(
110 |                 sum(avg_wrong_predicted_index), len(avg_wrong_predicted_index)
111 |             ),
112 |             "index_correct_avg-index": safe_divide(
113 |                 sum(avg_correct_predicted_index), len(avg_correct_predicted_index)
114 |             ),
115 |             "index_avg-index": safe_divide(
116 |                 sum(avg_correct_predicted_index + avg_wrong_predicted_index),
117 |                 len(avg_correct_predicted_index + avg_wrong_predicted_index),
118 |             ),
119 |             "index_percentage-favoured-smaller-idx": safe_divide(
120 |                 sum(less_index_predictions), len(less_index_predictions)
121 |             ),
122 |         }
123 | 
124 |         return {k: round(v, 5) for k, v in out_dict.items()}
125 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/master/setup.py
3 | To create the package for pypi.
4 | 1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the
5 |    documentation.
6 | 2. Run Tests for Amazon Sagemaker. The documentation is located in `./tests/sagemaker/README.md`, otherwise @philschmid.
7 | 3. Unpin specific versions from setup.py that use a git install.
8 | 4. 
Commit these changes with the message: "Release: VERSION" 9 | 5. Add a tag in git to mark the release: "git tag VERSION -m 'Adds tag VERSION for pypi' " 10 | Push the tag to git: git push --tags origin master 11 | 6. Build both the sources and the wheel. Do not change anything in setup.py between 12 | creating the wheel and the source distribution (obviously). 13 | For the wheel, run: "python setup.py bdist_wheel" in the top level directory. 14 | (this will build a wheel for the python version you use to build it). 15 | For the sources, run: "python setup.py sdist" 16 | You should now have a /dist directory with both .whl and .tar.gz source versions. 17 | 7. Check that everything looks correct by uploading the package to the pypi test server: 18 | twine upload dist/* -r pypitest 19 | (pypi suggest using twine as other methods upload files via plaintext.) 20 | You may have to specify the repository url, use the following command then: 21 | twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ 22 | Check that you can install it in a virtualenv by running: 23 | pip install -i https://testpypi.python.org/pypi transformers 24 | 8. Upload the final version to actual pypi: 25 | twine upload dist/* -r pypi 26 | 9. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 27 | 10. Run `make post-release` (or `make post-patch` for a patch release). 28 | """ 29 | from collections import defaultdict 30 | 31 | import setuptools 32 | 33 | 34 | def parse_requirements_file( 35 | path, allowed_extras: set = None, include_all_extra: bool = True 36 | ): 37 | requirements = [] 38 | extras = defaultdict(list) 39 | find_links = [] 40 | with open(path) as requirements_file: 41 | import re 42 | 43 | def fix_url_dependencies(req: str) -> str: 44 | """Pip and setuptools disagree about how URL dependencies should be handled.""" 45 | m = re.match( 46 | r"^(git\+)?(https|ssh)://(git@)?github\.com/([\w-]+)/(?P[\w-]+)\.git", 47 | req, 48 | ) 49 | if m is None: 50 | return req 51 | else: 52 | return f"{m.group('name')} @ {req}" 53 | 54 | for line in requirements_file: 55 | line = line.strip() 56 | if line.startswith("#") or len(line) <= 0: 57 | continue 58 | if ( 59 | line.startswith("-f") 60 | or line.startswith("--find-links") 61 | or line.startswith("--index-url") 62 | or line.startswith("--extra-index-url") 63 | ): 64 | find_links.append(line.split(" ", maxsplit=1)[-1].strip()) 65 | continue 66 | 67 | req, *needed_by = line.split("# needed by:") 68 | req = fix_url_dependencies(req.strip()) 69 | if needed_by: 70 | for extra in needed_by[0].strip().split(","): 71 | extra = extra.strip() 72 | if allowed_extras is not None and extra not in allowed_extras: 73 | raise ValueError(f"invalid extra '{extra}' in {path}") 74 | extras[extra].append(req) 75 | if include_all_extra and req not in extras["all"]: 76 | # if "gpu" in extra: 77 | # extras["all-gpu"].append(req) 78 | # else: 79 | extras["all"].append(req) 80 | 81 | else: 82 | requirements.append(req) 83 | return requirements, extras, find_links 84 | 85 | 86 | allowed_extras = { 87 | "serve", 88 | "ray", 89 | "train", 90 | "retriever", 91 | "reader", 92 | "all", 93 | "faiss", 94 | "faiss-gpu", 95 | "dev", 96 | } 97 | 98 | # Load requirements. 99 | install_requirements, extras, find_links = parse_requirements_file( 100 | "requirements.txt", allowed_extras=allowed_extras 101 | ) 102 | 103 | # version.py defines the VERSION and VERSION_SHORT variables. 
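# For reference, wsl/version.py is expected to define something like the
# following (values illustrative, not the actual release numbers):
#
#     VERSION_SHORT = "1.0"
#     VERSION = "1.0.0"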
104 | # We use exec here, so we don't import wsl whilst setting up.
105 | VERSION = {}  # type: ignore
106 | with open("wsl/version.py", "r") as version_file:
107 |     exec(version_file.read(), VERSION)
108 | 
109 | with open("README.md", "r") as fh:
110 |     long_description = fh.read()
111 | 
112 | setuptools.setup(
113 |     name="wsl",
114 |     version=VERSION["VERSION"],
115 |     author="Andrei Stefan Bejgu, Edoardo Barba, Riccardo Orlando, Pere-Lluís Huguet Cabot, Roberto Navigli",
116 |     author_email="info@babelscape.com",
117 |     description="Word Sense Linking",
118 |     long_description=long_description,
119 |     long_description_content_type="text/markdown",
120 |     url="https://github.com/Babelscape/wsl",
121 |     keywords="NLP Sapienza sapienzanlp babelscape deep learning transformer pytorch retriever word sense linking word sense disambiguation reader",
122 |     packages=setuptools.find_packages(),
123 |     include_package_data=True,
124 |     license="CC BY-NC-SA 4.0",
125 |     classifiers=[
126 |         "Intended Audience :: Science/Research",
127 |         "Programming Language :: Python :: 3",
128 |         "Topic :: Scientific/Engineering :: Artificial Intelligence",
129 |     ],
130 |     install_requires=install_requirements,
131 |     extras_require=extras,
132 |     python_requires=">=3.10",
133 |     find_links=find_links,
134 | )
135 | 
--------------------------------------------------------------------------------
/wsl/retriever/data/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 | from typing import Dict, List, Optional, Union
5 | 
6 | import numpy as np
7 | import transformers as tr
8 | from tqdm import tqdm
9 | 
10 | 
11 | class HardNegativesManager:
12 |     def __init__(
13 |         self,
14 |         tokenizer: tr.PreTrainedTokenizer,
15 |         data: Optional[Union[Dict[int, List], str, os.PathLike]] = None,
16 |         max_length: int = 64,
17 |         batch_size: int = 1000,
18 |         lazy: bool = False,
19 |     ) -> None:
20 |         self._db: Optional[dict] = None
21 |         self.tokenizer = tokenizer
22 | 
23 |         if data is None:
24 |             self._db = {}
25 |         else:
26 |             if isinstance(data, Dict):
27 |                 self._db = data
28 |             elif isinstance(data, (str, os.PathLike)):  # accept plain string paths too
29 |                 with open(data) as f:
30 |                     self._db = json.load(f)
31 |             else:
32 |                 raise ValueError(
33 |                     f"Data type {type(data)} not supported, only dict-like objects and paths (str or os.PathLike) are supported."
34 |                 )
35 |         # store max_length: `_tokenize` needs it for lazy tokenization
36 |         self.max_length = max_length
37 | 
38 |         # invert the db to have a passage -> sample_idx mapping
39 |         self._passage_db = defaultdict(set)
40 |         for sample_idx, passages in self._db.items():
41 |             for passage in passages:
42 |                 self._passage_db[passage].add(sample_idx)
43 | 
44 |         self._passage_hard_negatives = {}
45 |         if not lazy:
46 |             # create a dictionary of passage -> hard_negative mapping
47 |             batch_size = min(batch_size, len(self._passage_db))
48 |             unique_passages = list(self._passage_db.keys())
49 |             for i in tqdm(
50 |                 range(0, len(unique_passages), batch_size),
51 |                 desc="Tokenizing Hard Negatives",
52 |             ):
53 |                 batch = unique_passages[i : i + batch_size]
54 |                 tokenized_passages = self.tokenizer(
55 |                     batch,
56 |                     max_length=max_length,
57 |                     truncation=True,
58 |                 )
59 |                 for j, passage in enumerate(batch):
60 |                     self._passage_hard_negatives[passage] = {
61 |                         k: tokenized_passages[k][j] for k in tokenized_passages.keys()
62 |                     }
63 | 
64 |     def __len__(self) -> int:
65 |         return len(self._db)
66 | 
67 |     def __getitem__(self, idx: int) -> Dict:
68 |         return self._db[idx]
69 | 
70 |     def __iter__(self):
71 |         for sample in self._db:
72 |             yield sample
73 | 
74 |     def __contains__(self, idx: int) -> bool:
75 |         return idx in self._db
76 | 
77 |     def get(self, idx: int) -> List[str]:
78 |         """Get the hard negatives for a given sample index."""
79 |         if idx not in self._db:
80 |             raise ValueError(f"Sample index {idx} not in the database.")
81 | 
82 |         passages = self._db[idx]
83 | 
84 |         output = []
85 |         for passage in passages:
86 |             if passage not in self._passage_hard_negatives:
87 |                 self._passage_hard_negatives[passage] = self._tokenize(passage)
88 |             output.append(self._passage_hard_negatives[passage])
89 | 
90 |         return output
91 | 
92 |     def _tokenize(self, passage: str) -> Dict:
93 |         return self.tokenizer(passage, max_length=self.max_length, truncation=True)
94 | 
95 | 
96 | class NegativeSampler:
97 |     def __init__(
98 |         self, num_elements: int, probabilities: Optional[Union[List, np.ndarray]] = None
99 |     ):
100 |         if probabilities is None:
101 |             # default to random probabilities that sum to 1
102 |             probabilities = np.random.random(num_elements)
103 |             probabilities /= np.sum(probabilities)
104 |         elif not isinstance(probabilities, np.ndarray):
105 |             # the None check must come first: np.array(None) silently builds a 0-d array
106 |             probabilities = np.array(probabilities)
107 |         self.probabilities = probabilities
108 | 
109 |     def __call__(
110 |         self,
111 |         sample_size: int,
112 |         num_samples: int = 1,
113 |         probabilities: Optional[np.ndarray] = None,
114 |         exclude: Optional[List[int]] = None,
115 |     ) -> np.ndarray:
116 |         """
117 |         Fast sampling of `sample_size` elements from `num_elements` elements.
118 |         The sampling is done by randomly shifting the probabilities and then
119 |         finding the smallest of the negative numbers. This is much faster than
120 |         sampling from a multinomial distribution.
121 | 
122 |         Args:
123 |             sample_size (`int`):
124 |                 number of elements to sample
125 |             num_samples (`int`, optional):
126 |                 number of samples to draw. Defaults to 1.
127 |             probabilities (`np.ndarray`, optional):
128 |                 probabilities of each element. Defaults to None.
129 |             exclude (`List[int]`, optional):
130 |                 indices of elements to exclude. Defaults to None.
131 | 
132 |         Returns:
133 |             `np.ndarray`: array of sampled indices
134 |         """
135 |         if probabilities is None:
136 |             probabilities = self.probabilities
137 | 
138 |         if exclude is not None:
139 |             probabilities[exclude] = 0
140 |             # re-normalize?
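            # (left open in the original: after zeroing the excluded entries the
            #  probabilities no longer sum to 1 while the random shifts generated
            #  below do, so selection leans slightly more on the noise;
            #  re-normalizing, as sketched on the next line, restores the scale)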
141 |             # probabilities /= np.sum(probabilities)
142 | 
143 |         # replicate probabilities as many times as `num_samples`
144 |         replicated_probabilities = np.tile(probabilities, (num_samples, 1))
145 |         # get random shifting numbers & scale them correctly
146 |         random_shifts = np.random.random(replicated_probabilities.shape)
147 |         random_shifts /= random_shifts.sum(axis=1)[:, np.newaxis]
148 |         # shift by numbers & find largest (by finding the smallest of the negative)
149 |         shifted_probabilities = random_shifts - replicated_probabilities
150 |         sampled_indices = np.argpartition(shifted_probabilities, sample_size, axis=1)[
151 |             :, :sample_size
152 |         ]
153 |         return sampled_indices
154 | 
--------------------------------------------------------------------------------
/wsl/inference/data/splitters/spacy_sentence_splitter.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Iterable, List, Optional, Union
2 | 
3 | import spacy
4 | 
5 | from wsl.inference.data.objects import Word
6 | from wsl.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
7 | from wsl.inference.data.tokenizers.spacy_tokenizer import load_spacy
8 | 
9 | SPACY_LANGUAGE_MAPPER = {
10 |     "cs": "xx_sent_ud_sm",
11 |     "da": "xx_sent_ud_sm",
12 |     "de": "xx_sent_ud_sm",
13 |     "fa": "xx_sent_ud_sm",
14 |     "fi": "xx_sent_ud_sm",
15 |     "fr": "xx_sent_ud_sm",
16 |     "el": "el_core_news_sm",
17 |     "en": "xx_sent_ud_sm",
18 |     "es": "xx_sent_ud_sm",
19 |     "ga": "xx_sent_ud_sm",
20 |     "hr": "xx_sent_ud_sm",
21 |     "id": "xx_sent_ud_sm",
22 |     "it": "xx_sent_ud_sm",
23 |     "ja": "ja_core_news_sm",
24 |     "lv": "xx_sent_ud_sm",
25 |     "lt": "xx_sent_ud_sm",
26 |     "mr": "xx_sent_ud_sm",
27 |     "nb": "xx_sent_ud_sm",
28 |     "nl": "xx_sent_ud_sm",
29 |     "no": "xx_sent_ud_sm",
30 |     "pl": "pl_core_news_sm",
31 |     "pt": "xx_sent_ud_sm",
32 |     "ro": "xx_sent_ud_sm",
33 |     "ru": "xx_sent_ud_sm",
34 |     "sk": "xx_sent_ud_sm",
35 |     "sr": "xx_sent_ud_sm",
36 |     "sv": "xx_sent_ud_sm",
37 |     "te": "xx_sent_ud_sm",
38 |     "vi": "xx_sent_ud_sm",
39 |     "zh": "zh_core_web_sm",
40 | }
41 | 
42 | 
43 | class SpacySentenceSplitter(BaseSentenceSplitter):
44 |     """
45 |     A :obj:`SentenceSplitter` that uses spaCy's built-in sentence boundary detection.
46 | 
47 |     Args:
48 |         language (:obj:`str`, optional, defaults to :obj:`en`):
49 |             Language of the text to split into sentences.
50 |         model_type (:obj:`str`, optional, defaults to :obj:`statistical`):
51 |             Three different types of sentence splitter are available:
52 |             - ``dependency``: uses a dependency parse to detect sentence boundaries,
53 |             slow, but accurate.
54 |             - ``statistical``: uses spaCy's statistical ``senter`` component; a good trade-off between speed and accuracy.
55 |             - ``rule_based``: it's fast and has a small memory footprint, since it uses punctuation to detect
56 |             sentence boundaries.
57 |     """
58 | 
59 |     def __init__(self, language: str = "en", model_type: str = "statistical") -> None:
60 |         # we need spacy's dependency parser if we're not using rule-based sentence boundary detection.
61 |         # self.spacy = get_spacy_model(language, parse=not rule_based, ner=False)
62 |         dep = bool(model_type == "dependency")
63 |         if language in SPACY_LANGUAGE_MAPPER:
64 |             self.spacy = load_spacy(SPACY_LANGUAGE_MAPPER[language], parse=dep)
65 |         else:
66 |             self.spacy = spacy.blank(language)
67 |             # force type to rule_based since there is no pre-trained model
68 |             model_type = "rule_based"
69 |         if model_type == "dependency":
70 |             # dependency-based splitting must be declared at model load time (parse=dep above)
71 |             pass
72 |         elif model_type == "statistical":
73 |             if not self.spacy.has_pipe("senter"):
74 |                 self.spacy.enable_pipe("senter")
75 |         elif model_type == "rule_based":
76 |             # we use `sentencizer`, a built-in spacy module for rule-based sentence boundary detection.
77 |             # depending on the spacy version, it could be called 'sentencizer' or 'sbd'
78 |             if not self.spacy.has_pipe("sentencizer"):
79 |                 self.spacy.add_pipe("sentencizer")
80 |         else:
81 |             raise ValueError(
82 |                 f"type {model_type} not supported. Choose between `dependency`, `statistical` or `rule_based`"
83 |             )
84 | 
85 |     def __call__(
86 |         self,
87 |         texts: Union[str, List[str], List[List[str]]],
88 |         max_length: Optional[int] = None,
89 |         is_split_into_words: bool = False,
90 |         **kwargs,
91 |     ) -> Union[List[str], List[List[str]]]:
92 |         """
93 |         Split the input text(s) into sentences using spaCy models.
94 | 
95 |         Args:
96 |             texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
97 |                 Text to split. It can be a single string, a batch of strings, or pre-tokenized strings.
98 |             max_length (:obj:`int`, optional):
99 |                 Maximum length of a single text. If the text is longer than `max_length`, it will be split
100 |                 into multiple sentences.
101 | 
102 |         Returns:
103 |             :obj:`List[List[str]]`: The input doc split into sentences.
104 |         """
105 |         # check if input is batched or a single sample
106 |         is_batched = self.check_is_batched(texts, is_split_into_words)
107 | 
108 |         if is_batched:
109 |             sents = self.split_sentences_batch(texts)
110 |         else:
111 |             sents = self.split_sentences(texts, max_length)
112 |         return sents
113 | 
114 |     @staticmethod
115 |     def chunked(iterable, n: int) -> Iterable[List[Any]]:
116 |         """
117 |         Chunks a list into n sized chunks.
118 | 
119 |         Args:
120 |             iterable (:obj:`List[Any]`):
121 |                 List to chunk.
122 |             n (:obj:`int`):
123 |                 Size of the chunks.
124 | 
125 |         Returns:
126 |             :obj:`Iterable[List[Any]]`: The input list chunked into n sized chunks.
127 |         """
128 |         return [iterable[i : i + n] for i in range(0, len(iterable), n)]
129 | 
130 |     def split_sentences(
131 |         self, text: str | List[Word], max_length: Optional[int] = None, *args, **kwargs
132 |     ) -> List[str]:
133 |         """
134 |         Splits a `text` into smaller sentences.
135 | 
136 |         Args:
137 |             text (:obj:`str`):
138 |                 Text to split.
139 |             max_length (:obj:`int`, optional):
140 |                 Maximum length (in tokens) of a single sentence. If a sentence is longer than `max_length`,
141 |                 it will be split into multiple chunks.
142 | 
143 |         Returns:
144 |             :obj:`List[str]`: The input text split into sentences.
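
        Example::

            >>> # a minimal sketch (each returned sentence is a spaCy ``Span``)
            >>> splitter = SpacySentenceSplitter(language="en", model_type="rule_based")
            >>> splitter.split_sentences("Mary sold the car. John bought it.")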
145 | """ 146 | sentences = [sent for sent in self.spacy(text).sents] 147 | if max_length is not None and max_length > 0: 148 | sentences = [ 149 | chunk 150 | for sentence in sentences 151 | for chunk in self.chunked(sentence, max_length) 152 | ] 153 | return sentences 154 | -------------------------------------------------------------------------------- /wsl/reader/utils/relik_reader_predictor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Iterable, Iterator, List, Optional 3 | 4 | import hydra 5 | import torch 6 | from lightning.pytorch.utilities import move_data_to_device 7 | from torch.utils.data import DataLoader, IterableDataset 8 | from tqdm import tqdm 9 | 10 | from wsl.reader.data.patches import merge_patches_predictions 11 | from wsl.reader.data.wsl_reader_sample import WSLReaderSample, load_wsl_reader_samples 12 | from wsl.reader.pytorch_modules.base import WSLReaderBase 13 | from wsl.reader.utils.special_symbols import NME_SYMBOL 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def convert_tokens_to_char_annotations( 19 | sample: WSLReaderSample, remove_nmes: bool = False 20 | ): 21 | char_annotations = set() 22 | 23 | for ( 24 | predicted_entity, 25 | predicted_spans, 26 | ) in sample.predicted_window_labels.items(): 27 | if predicted_entity == NME_SYMBOL and remove_nmes: 28 | continue 29 | 30 | for span_start, span_end in predicted_spans: 31 | span_start = sample.token2char_start[str(span_start)] 32 | span_end = sample.token2char_end[str(span_end)] 33 | 34 | char_annotations.add((span_start, span_end, predicted_entity)) 35 | 36 | char_probs_annotations = dict() 37 | for ( 38 | span_start, 39 | span_end, 40 | ), candidates_probs in sample.span_title_probabilities.items(): 41 | span_start = sample.token2char_start[str(span_start)] 42 | span_end = sample.token2char_end[str(span_end)] 43 | char_probs_annotations[(span_start, span_end)] = { 44 | title for title, _ in candidates_probs 45 | } 46 | 47 | sample.predicted_window_labels_chars = char_annotations 48 | sample.probs_window_labels_chars = char_probs_annotations 49 | 50 | 51 | class WSLReaderPredictor: 52 | def __init__( 53 | self, 54 | wsl_reader_core: WSLReaderBase, 55 | dataset_conf: Optional[dict] = None, 56 | predict_nmes: bool = False, 57 | dataloader: Optional[DataLoader] = None, 58 | ) -> None: 59 | self.wsl_reader_core = wsl_reader_core 60 | self.dataset_conf = dataset_conf 61 | self.predict_nmes = predict_nmes 62 | self.dataloader: DataLoader | None = dataloader 63 | 64 | if self.dataset_conf is not None and self.dataset is None: 65 | # instantiate dataset 66 | self.dataset = hydra.utils.instantiate( 67 | dataset_conf, 68 | dataset_path=None, 69 | samples=None, 70 | ) 71 | 72 | def predict( 73 | self, 74 | path: Optional[str], 75 | samples: Optional[Iterable[WSLReaderSample]], 76 | dataset_conf: Optional[dict], 77 | token_batch_size: int = 1024, 78 | progress_bar: bool = False, 79 | **kwargs, 80 | ) -> List[WSLReaderSample]: 81 | annotated_samples = list( 82 | self._predict(path, samples, dataset_conf, token_batch_size, progress_bar) 83 | ) 84 | for sample in annotated_samples: 85 | merge_patches_predictions(sample) 86 | convert_tokens_to_char_annotations( 87 | sample, remove_nmes=not self.predict_nmes 88 | ) 89 | return annotated_samples 90 | 91 | def _predict( 92 | self, 93 | path: Optional[str], 94 | samples: Optional[Iterable[WSLReaderSample]], 95 | dataset_conf: dict, 96 | token_batch_size: int = 1024, 97 | 
97 |         progress_bar: bool = False,
98 |         **kwargs,
99 |     ) -> Iterator[WSLReaderSample]:
100 |         assert (
101 |             path is not None or samples is not None
102 |         ), "Either predict on a path or on an iterable of samples"
103 | 
104 |         next_prediction_position = 0
105 |         position2predicted_sample = {}
106 | 
107 |         if self.dataloader is not None:
108 |             iterator = self.dataloader
109 |             for i, sample in enumerate(self.dataloader.dataset.samples):
110 |                 sample._mixin_prediction_position = i
111 |         else:
112 |             samples = load_wsl_reader_samples(path) if samples is None else samples
113 | 
114 |             # setup infrastructure to re-yield in order
115 |             def samples_it():
116 |                 for i, sample in enumerate(samples):
117 |                     assert sample._mixin_prediction_position is None
118 |                     sample._mixin_prediction_position = i
119 |                     yield sample
120 | 
121 |             # instantiate dataset
122 |             if getattr(self, "dataset", None) is not None:
123 |                 dataset = self.dataset
124 |                 dataset.samples = samples_it()
125 |                 dataset.tokens_per_batch = token_batch_size
126 |             else:
127 |                 dataset = hydra.utils.instantiate(
128 |                     dataset_conf,
129 |                     dataset_path=None,
130 |                     samples=samples_it(),
131 |                     tokens_per_batch=token_batch_size,
132 |                 )
133 | 
134 |             # instantiate dataloader
135 |             iterator = DataLoader(
136 |                 dataset, batch_size=None, num_workers=0, shuffle=False
137 |             )
138 |             if progress_bar:
139 |                 iterator = tqdm(iterator, desc="Predicting")
140 | 
141 |         model_device = next(self.wsl_reader_core.parameters()).device
142 | 
143 |         with torch.inference_mode():
144 |             for batch in iterator:
145 |                 # do batch predict
146 |                 with torch.autocast(
147 |                     "cpu" if model_device == torch.device("cpu") else "cuda"
148 |                 ):
149 |                     batch = move_data_to_device(batch, model_device)
150 |                     batch_out = self.wsl_reader_core._batch_predict(**batch)
151 |                 # update prediction positions
152 |                 for sample in batch_out:
153 |                     if sample._mixin_prediction_position >= next_prediction_position:
154 |                         position2predicted_sample[
155 |                             sample._mixin_prediction_position
156 |                         ] = sample
157 | 
158 |                 # yield
159 |                 while next_prediction_position in position2predicted_sample:
160 |                     yield position2predicted_sample[next_prediction_position]
161 |                     del position2predicted_sample[next_prediction_position]
162 |                     next_prediction_position += 1
163 | 
164 |         if len(position2predicted_sample) > 0:
165 |             logger.warning(
166 |                 "It seems samples have been discarded in your dataset. "
167 |                 "This means that you WON'T have a prediction for each input sample. 
" 168 | "Prediction order will also be partially disrupted" 169 | ) 170 | for k, v in sorted(position2predicted_sample.items(), key=lambda x: x[0]): 171 | yield v 172 | 173 | if progress_bar: 174 | iterator.close() 175 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # custom 2 | wsl/test.py 3 | 4 | data/* 5 | experiments/* 6 | models 7 | retrievers 8 | outputs 9 | wandb 10 | pretrained_configs 11 | lightning_logs 12 | hf_weights 13 | 14 | # Created by https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos 15 | # Edit at https://www.toptal.com/developers/gitignore?templates=jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos 16 | 17 | ### JetBrains+all ### 18 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 19 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 20 | 21 | # User-specific stuff 22 | .idea/**/workspace.xml 23 | .idea/**/tasks.xml 24 | .idea/**/usage.statistics.xml 25 | .idea/**/dictionaries 26 | .idea/**/shelf 27 | 28 | # Generated files 29 | .idea/**/contentModel.xml 30 | 31 | # Sensitive or high-churn files 32 | .idea/**/dataSources/ 33 | .idea/**/dataSources.ids 34 | .idea/**/dataSources.local.xml 35 | .idea/**/sqlDataSources.xml 36 | .idea/**/dynamic.xml 37 | .idea/**/uiDesigner.xml 38 | .idea/**/dbnavigator.xml 39 | 40 | # Gradle 41 | .idea/**/gradle.xml 42 | .idea/**/libraries 43 | 44 | # Gradle and Maven with auto-import 45 | # When using Gradle or Maven with auto-import, you should exclude module files, 46 | # since they will be recreated, and may cause churn. Uncomment if using 47 | # auto-import. 
48 | # .idea/artifacts 49 | # .idea/compiler.xml 50 | # .idea/jarRepositories.xml 51 | # .idea/modules.xml 52 | # .idea/*.iml 53 | # .idea/modules 54 | # *.iml 55 | # *.ipr 56 | 57 | # CMake 58 | cmake-build-*/ 59 | 60 | # Mongo Explorer plugin 61 | .idea/**/mongoSettings.xml 62 | 63 | # File-based project format 64 | *.iws 65 | 66 | # IntelliJ 67 | out/ 68 | 69 | # mpeltonen/sbt-idea plugin 70 | .idea_modules/ 71 | 72 | # JIRA plugin 73 | atlassian-ide-plugin.xml 74 | 75 | # Cursive Clojure plugin 76 | .idea/replstate.xml 77 | 78 | # Crashlytics plugin (for Android Studio and IntelliJ) 79 | com_crashlytics_export_strings.xml 80 | crashlytics.properties 81 | crashlytics-build.properties 82 | fabric.properties 83 | 84 | # Editor-based Rest Client 85 | .idea/httpRequests 86 | 87 | # Android studio 3.1+ serialized cache file 88 | .idea/caches/build_file_checksums.ser 89 | 90 | ### JetBrains+all Patch ### 91 | # Ignores the whole .idea folder and all .iml files 92 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 93 | 94 | .idea/ 95 | 96 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 97 | 98 | *.iml 99 | modules.xml 100 | .idea/misc.xml 101 | *.ipr 102 | 103 | # Sonarlint plugin 104 | .idea/sonarlint 105 | 106 | ### JupyterNotebooks ### 107 | # gitignore template for Jupyter Notebooks 108 | # website: http://jupyter.org/ 109 | 110 | .ipynb_checkpoints 111 | */.ipynb_checkpoints/* 112 | 113 | # IPython 114 | profile_default/ 115 | ipython_config.py 116 | 117 | # Remove previous ipynb_checkpoints 118 | # git rm -r .ipynb_checkpoints/ 119 | 120 | ### Linux ### 121 | *~ 122 | 123 | # temporary files which can be created if a process still has a handle open of a deleted file 124 | .fuse_hidden* 125 | 126 | # KDE directory preferences 127 | .directory 128 | 129 | # Linux trash folder which might appear on any partition or disk 130 | .Trash-* 131 | 132 | # .nfs files are created when an open file is removed but is still being accessed 133 | .nfs* 134 | 135 | ### macOS ### 136 | # General 137 | .DS_Store 138 | .AppleDouble 139 | .LSOverride 140 | 141 | # Icon must end with two \r 142 | Icon 143 | 144 | 145 | # Thumbnails 146 | ._* 147 | 148 | # Files that might appear in the root of a volume 149 | .DocumentRevisions-V100 150 | .fseventsd 151 | .Spotlight-V100 152 | .TemporaryItems 153 | .Trashes 154 | .VolumeIcon.icns 155 | .com.apple.timemachine.donotpresent 156 | 157 | # Directories potentially created on remote AFP share 158 | .AppleDB 159 | .AppleDesktop 160 | Network Trash Folder 161 | Temporary Items 162 | .apdisk 163 | 164 | ### Python ### 165 | # Byte-compiled / optimized / DLL files 166 | __pycache__/ 167 | *.py[cod] 168 | *$py.class 169 | 170 | # C extensions 171 | *.so 172 | 173 | # Distribution / packaging 174 | .Python 175 | build/ 176 | develop-eggs/ 177 | dist/ 178 | downloads/ 179 | eggs/ 180 | .eggs/ 181 | lib/ 182 | lib64/ 183 | parts/ 184 | sdist/ 185 | var/ 186 | wheels/ 187 | pip-wheel-metadata/ 188 | share/python-wheels/ 189 | *.egg-info/ 190 | .installed.cfg 191 | *.egg 192 | MANIFEST 193 | 194 | # PyInstaller 195 | # Usually these files are written by a python script from a template 196 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
197 | *.manifest 198 | *.spec 199 | 200 | # Installer logs 201 | pip-log.txt 202 | pip-delete-this-directory.txt 203 | 204 | # Unit test / coverage reports 205 | htmlcov/ 206 | .tox/ 207 | .nox/ 208 | .coverage 209 | .coverage.* 210 | .cache 211 | nosetests.xml 212 | coverage.xml 213 | *.cover 214 | *.py,cover 215 | .hypothesis/ 216 | .pytest_cache/ 217 | pytestdebug.log 218 | 219 | # Translations 220 | *.mo 221 | *.pot 222 | 223 | # Django stuff: 224 | *.log 225 | local_settings.py 226 | db.sqlite3 227 | db.sqlite3-journal 228 | 229 | # Flask stuff: 230 | instance/ 231 | .webassets-cache 232 | 233 | # Scrapy stuff: 234 | .scrapy 235 | 236 | # Sphinx documentation 237 | docs/_build/ 238 | doc/_build/ 239 | 240 | # PyBuilder 241 | target/ 242 | 243 | # Jupyter Notebook 244 | 245 | # IPython 246 | 247 | # pyenv 248 | .python-version 249 | 250 | # pipenv 251 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 252 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 253 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 254 | # install all needed dependencies. 255 | #Pipfile.lock 256 | 257 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 258 | __pypackages__/ 259 | 260 | # Celery stuff 261 | celerybeat-schedule 262 | celerybeat.pid 263 | 264 | # SageMath parsed files 265 | *.sage.py 266 | 267 | # Environments 268 | .env 269 | .venv 270 | env/ 271 | venv/ 272 | ENV/ 273 | env.bak/ 274 | venv.bak/ 275 | pythonenv* 276 | 277 | # Spyder project settings 278 | .spyderproject 279 | .spyproject 280 | 281 | # Rope project settings 282 | .ropeproject 283 | 284 | # mkdocs documentation 285 | /site 286 | 287 | # mypy 288 | .mypy_cache/ 289 | .dmypy.json 290 | dmypy.json 291 | 292 | # Pyre type checker 293 | .pyre/ 294 | 295 | # pytype static type analyzer 296 | .pytype/ 297 | 298 | # profiling data 299 | .prof 300 | 301 | ### vscode ### 302 | .vscode 303 | .vscode/* 304 | !.vscode/settings.json 305 | !.vscode/tasks.json 306 | !.vscode/launch.json 307 | !.vscode/extensions.json 308 | *.code-workspace 309 | 310 | ### Windows ### 311 | # Windows thumbnail cache files 312 | Thumbs.db 313 | Thumbs.db:encryptable 314 | ehthumbs.db 315 | ehthumbs_vista.db 316 | 317 | # Dump file 318 | *.stackdump 319 | 320 | # Folder config file 321 | [Dd]esktop.ini 322 | 323 | # Recycle Bin used on file shares 324 | $RECYCLE.BIN/ 325 | 326 | # Windows Installer files 327 | *.cab 328 | *.msi 329 | *.msix 330 | *.msm 331 | *.msp 332 | 333 | # Windows shortcuts 334 | *.lnk 335 | 336 | # End of https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos -------------------------------------------------------------------------------- /wsl/retriever/data/labels.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Dict, List, Set, Union 4 | 5 | 6 | class Labels: 7 | """ 8 | Class that contains the labels for a model. 9 | 10 | Args: 11 | _labels_to_index (:obj:`Dict[str, Dict[str, int]]`): 12 | A dictionary from :obj:`str` to :obj:`int`. 13 | _index_to_labels (:obj:`Dict[str, Dict[int, str]]`): 14 | A dictionary from :obj:`int` to :obj:`str`. 
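
    Example::

        >>> # a minimal usage sketch
        >>> labels = Labels()
        >>> _ = labels.add_labels(["PERSON", "LOCATION"])
        >>> labels.get_label_size()
        2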
15 | """ 16 | 17 | def __init__( 18 | self, 19 | _labels_to_index: Dict[str, Dict[str, int]] = None, 20 | _index_to_labels: Dict[str, Dict[int, str]] = None, 21 | **kwargs, 22 | ): 23 | self._labels_to_index = _labels_to_index or {"labels": {}} 24 | self._index_to_labels = _index_to_labels or {"labels": {}} 25 | # if _labels_to_index is not empty and _index_to_labels is not provided 26 | # to the constructor, build the inverted label dictionary 27 | if not _index_to_labels and _labels_to_index: 28 | for namespace in self._labels_to_index: 29 | self._index_to_labels[namespace] = { 30 | v: k for k, v in self._labels_to_index[namespace].items() 31 | } 32 | 33 | def get_index_from_label(self, label: str, namespace: str = "labels") -> int: 34 | """ 35 | Returns the index of a literal label. 36 | 37 | Args: 38 | label (:obj:`str`): 39 | The string representation of the label. 40 | namespace (:obj:`str`, optional, defaults to ``labels``): 41 | The namespace where the label belongs, e.g. ``roles`` for a SRL task. 42 | 43 | Returns: 44 | :obj:`int`: The index of the label. 45 | """ 46 | if namespace not in self._labels_to_index: 47 | raise ValueError( 48 | f"Provided namespace `{namespace}` is not in the label dictionary." 49 | ) 50 | 51 | if label not in self._labels_to_index[namespace]: 52 | raise ValueError(f"Provided label {label} is not in the label dictionary.") 53 | 54 | return self._labels_to_index[namespace][label] 55 | 56 | def get_label_from_index(self, index: int, namespace: str = "labels") -> str: 57 | """ 58 | Returns the string representation of the label index. 59 | 60 | Args: 61 | index (:obj:`int`): 62 | The index of the label. 63 | namespace (:obj:`str`, optional, defaults to ``labels``): 64 | The namespace where the label belongs, e.g. ``roles`` for a SRL task. 65 | 66 | Returns: 67 | :obj:`str`: The string representation of the label. 68 | """ 69 | if namespace not in self._index_to_labels: 70 | raise ValueError( 71 | f"Provided namespace `{namespace}` is not in the label dictionary." 72 | ) 73 | 74 | if index not in self._index_to_labels[namespace]: 75 | raise ValueError( 76 | f"Provided label `{index}` is not in the label dictionary." 77 | ) 78 | 79 | return self._index_to_labels[namespace][index] 80 | 81 | def add_labels( 82 | self, 83 | labels: Union[str, List[str], Set[str], Dict[str, int]], 84 | namespace: str = "labels", 85 | ) -> List[int]: 86 | """ 87 | Adds the labels in input in the label dictionary. 88 | 89 | Args: 90 | labels (:obj:`str`, :obj:`List[str]`, :obj:`Set[str]`): 91 | The labels (single label, list of labels or set of labels) to add to the dictionary. 92 | namespace (:obj:`str`, optional, defaults to ``labels``): 93 | Namespace where the labels belongs. 94 | 95 | Returns: 96 | :obj:`List[int]`: The index of the labels just inserted. 97 | """ 98 | if isinstance(labels, dict): 99 | self._labels_to_index[namespace] = labels 100 | self._index_to_labels[namespace] = { 101 | v: k for k, v in self._labels_to_index[namespace].items() 102 | } 103 | # normalize input 104 | if isinstance(labels, (str, list)): 105 | labels = set(labels) 106 | # if new namespace, add to the dictionaries 107 | if namespace not in self._labels_to_index: 108 | self._labels_to_index[namespace] = {} 109 | self._index_to_labels[namespace] = {} 110 | # returns the new indices 111 | return [self._add_label(label, namespace) for label in labels] 112 | 113 | def _add_label(self, label: str, namespace: str = "labels") -> int: 114 | """ 115 | Adds the label in input in the label dictionary. 
116 | 117 | Args: 118 | label (:obj:`str`): 119 | The label to add to the dictionary. 120 | namespace (:obj:`str`, optional, defaults to ``labels``): 121 | Namespace where the label belongs. 122 | 123 | Returns: 124 | :obj:`List[int]`: The index of the label just inserted. 125 | """ 126 | if label not in self._labels_to_index[namespace]: 127 | index = len(self._labels_to_index[namespace]) 128 | self._labels_to_index[namespace][label] = index 129 | self._index_to_labels[namespace][index] = label 130 | return index 131 | else: 132 | return self._labels_to_index[namespace][label] 133 | 134 | def get_labels(self, namespace: str = "labels") -> Dict[str, int]: 135 | """ 136 | Returns all the labels that belongs to the input namespace. 137 | 138 | Args: 139 | namespace (:obj:`str`, optional, defaults to ``labels``): 140 | Labels namespace to retrieve. 141 | 142 | Returns: 143 | :obj:`Dict[str, int]`: The label dictionary, from ``str`` to ``int``. 144 | """ 145 | if namespace not in self._labels_to_index: 146 | raise ValueError( 147 | f"Provided namespace `{namespace}` is not in the label dictionary." 148 | ) 149 | return self._labels_to_index[namespace] 150 | 151 | def get_label_size(self, namespace: str = "labels") -> int: 152 | """ 153 | Returns the number of the labels in the namespace dictionary. 154 | 155 | Args: 156 | namespace (:obj:`str`, optional, defaults to ``labels``): 157 | Labels namespace to retrieve. 158 | 159 | Returns: 160 | :obj:`int`: Number of labels. 161 | """ 162 | if namespace not in self._labels_to_index: 163 | raise ValueError( 164 | f"Provided namespace `{namespace}` is not in the label dictionary." 165 | ) 166 | return len(self._labels_to_index[namespace]) 167 | 168 | def get_namespaces(self) -> List[str]: 169 | """ 170 | Returns all the namespaces in the label dictionary. 171 | 172 | Returns: 173 | :obj:`List[str]`: The namespaces in the label dictionary. 174 | """ 175 | return list(self._labels_to_index.keys()) 176 | 177 | @classmethod 178 | def from_file(cls, file_path: Union[str, Path, dict], **kwargs): 179 | with open(file_path, "r") as f: 180 | labels_to_index = json.load(f) 181 | return cls(labels_to_index, **kwargs) 182 | 183 | def save(self, file_path: Union[str, Path, dict], indent: int = 2, **kwargs): 184 | with open(file_path, "w") as f: 185 | json.dump(self._labels_to_index, f, indent=indent) 186 | -------------------------------------------------------------------------------- /wsl/inference/data/tokenizers/spacy_tokenizer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Dict, List, Tuple, Union 3 | 4 | import spacy 5 | 6 | # from ipa.common.utils import load_spacy 7 | from spacy.cli.download import download as spacy_download 8 | from spacy.tokens import Doc 9 | 10 | from wsl.common.log import get_logger 11 | from wsl.inference.data.objects import Word 12 | from wsl.inference.data.tokenizers import SPACY_LANGUAGE_MAPPER 13 | from wsl.inference.data.tokenizers.base_tokenizer import BaseTokenizer 14 | 15 | logger = get_logger(level=logging.DEBUG) 16 | 17 | # Spacy and Stanza stuff 18 | 19 | LOADED_SPACY_MODELS: Dict[Tuple[str, bool, bool, bool, bool], spacy.Language] = {} 20 | 21 | 22 | def load_spacy( 23 | language: str, 24 | pos_tags: bool = False, 25 | lemma: bool = False, 26 | parse: bool = False, 27 | split_on_spaces: bool = False, 28 | ) -> spacy.Language: 29 | """ 30 | Download and load spacy model. 
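
    Loaded pipelines are cached in ``LOADED_SPACY_MODELS``, keyed by the full
    parameter tuple, so repeated calls with the same arguments reuse the same
    ``spacy.Language`` object instead of reloading it.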
31 | 
32 |     Args:
33 |         language (:obj:`str`):
34 |             Language of the text to tokenize.
35 |         pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
36 |             If :obj:`True`, performs POS tagging with spacy model.
37 |         lemma (:obj:`bool`, optional, defaults to :obj:`False`):
38 |             If :obj:`True`, performs lemmatization with spacy model.
39 |         parse (:obj:`bool`, optional, defaults to :obj:`False`):
40 |             If :obj:`True`, performs dependency parsing with spacy model.
41 |         split_on_spaces (:obj:`bool`, optional, defaults to :obj:`False`):
42 |             If :obj:`True`, will split by spaces without performing tokenization.
43 | 
44 |     Returns:
45 |         :obj:`spacy.Language`: The spacy model loaded.
46 |     """
47 |     exclude = ["vectors", "textcat", "ner"]
48 |     if not pos_tags:
49 |         exclude.append("tagger")
50 |     if not lemma:
51 |         exclude.append("lemmatizer")
52 |     if not parse:
53 |         exclude.append("parser")
54 | 
55 |     # check if the model is already loaded
56 |     # if so, there is no need to reload it
57 |     spacy_params = (language, pos_tags, lemma, parse, split_on_spaces)
58 |     if spacy_params not in LOADED_SPACY_MODELS:
59 |         try:
60 |             spacy_tagger = spacy.load(language, exclude=exclude)
61 |         except OSError:
62 |             logger.warning(
63 |                 "Spacy model '%s' not found. Downloading and installing.", language
64 |             )
65 |             spacy_download(language)
66 |             spacy_tagger = spacy.load(language, exclude=exclude)
67 | 
68 |         # if everything is disabled, return only the tokenizer
69 |         # for faster tokenization
70 |         # TODO: is it really faster?
71 |         # TODO: check split_on_spaces behaviour if we don't do this if
72 |         if len(exclude) >= 6 and split_on_spaces:
73 |             spacy_tagger = spacy_tagger.tokenizer
74 |         LOADED_SPACY_MODELS[spacy_params] = spacy_tagger
75 | 
76 |     return LOADED_SPACY_MODELS[spacy_params]
77 | 
78 | 
79 | class SpacyTokenizer(BaseTokenizer):
80 |     """
81 |     A :obj:`Tokenizer` that uses SpaCy to tokenize and preprocess the text. It returns :obj:`Word` objects.
82 | 
83 |     Args:
84 |         language (:obj:`str`, optional, defaults to :obj:`en`):
85 |             Language of the text to tokenize.
86 |         return_pos_tags (:obj:`bool`, optional, defaults to :obj:`False`):
87 |             If :obj:`True`, performs POS tagging with spacy model.
88 |         return_lemmas (:obj:`bool`, optional, defaults to :obj:`False`):
89 |             If :obj:`True`, performs lemmatization with spacy model.
90 |         return_deps (:obj:`bool`, optional, defaults to :obj:`False`):
91 |             If :obj:`True`, performs dependency parsing with spacy model.
92 |         use_gpu (:obj:`bool`, optional, defaults to :obj:`False`):
93 |             If :obj:`True`, will load the spaCy model on GPU.
94 |     """
95 | 
96 |     def __init__(
97 |         self,
98 |         language: str = "en",
99 |         return_pos_tags: bool = False,
100 |         return_lemmas: bool = False,
101 |         return_deps: bool = False,
102 |         use_gpu: bool = False,
103 |     ):
104 |         super().__init__()
105 |         if language not in SPACY_LANGUAGE_MAPPER:
106 |             raise ValueError(
107 |                 f"`{language}` language not supported. The supported "
108 |                 f"languages are: {list(SPACY_LANGUAGE_MAPPER.keys())}."
109 |             )
110 |         if use_gpu:
111 |             # load the model on GPU
112 |             # if the GPU is not available or not correctly configured,
113 |             # it will raise an error
114 |             spacy.require_gpu()
115 |         self.spacy = load_spacy(
116 |             SPACY_LANGUAGE_MAPPER[language],
117 |             return_pos_tags,
118 |             return_lemmas,
119 |             return_deps,
120 |         )
121 | 
122 |     def __call__(
123 |         self,
124 |         texts: Union[str, List[str], List[List[str]]],
125 |         is_split_into_words: bool = False,
126 |         **kwargs,
127 |     ) -> Union[List[Word], List[List[Word]]]:
128 |         """
129 |         Tokenize the input into single words using SpaCy models.
130 | 
131 |         Args:
132 |             texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
133 |                 Text to tag. It can be a single string, a batch of strings, or pre-tokenized strings.
134 |             is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
135 |                 If :obj:`True` and the input is a string, the input is split on spaces.
136 | 
137 |         Returns:
138 |             :obj:`List[List[Word]]`: The input text tokenized in single words.
139 | 
140 |         Example::
141 | 
142 |             >>> from wsl.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
143 | 
144 |             >>> spacy_tokenizer = SpacyTokenizer(language="en", return_pos_tags=True, return_lemmas=True)
145 |             >>> spacy_tokenizer("Mary sold the car to John.")
146 | 
147 |         """
148 |         # check if input is batched or a single sample
149 |         is_batched = self.check_is_batched(texts, is_split_into_words)
150 | 
151 |         if is_batched:
152 |             tokenized = self.tokenize_batch(texts, is_split_into_words)
153 |         else:
154 |             tokenized = self.tokenize(texts, is_split_into_words)
155 | 
156 |         return tokenized
157 | 
158 |     def tokenize(self, text: Union[str, List[str]], is_split_into_words: bool) -> Doc:
159 |         if is_split_into_words:
160 |             if isinstance(text, str):
161 |                 text = text.split(" ")
162 |             elif isinstance(text, list):
163 |                 text = text
164 |             else:
165 |                 raise ValueError(
166 |                     f"text must be either `str` or `list`, found: `{type(text)}`"
167 |                 )
168 |             spaces = [True] * len(text)
169 |             return self.spacy(Doc(self.spacy.vocab, words=text, spaces=spaces))
170 |         return self.spacy(text)
171 | 
172 |     def tokenize_batch(
173 |         self, texts: Union[List[str], List[List[str]]], is_split_into_words: bool
174 |     ) -> list[Any] | list[Doc]:
175 |         try:
176 |             if is_split_into_words:
177 |                 if isinstance(texts[0], str):
178 |                     texts = [text.split(" ") for text in texts]
179 |                 elif isinstance(texts[0], list):
180 |                     texts = texts
181 |                 else:
182 |                     raise ValueError(
183 |                         f"text must be either `str` or `list`, found: `{type(texts[0])}`"
184 |                     )
185 |                 spaces = [[True] * len(text) for text in texts]
186 |                 texts = [
187 |                     Doc(self.spacy.vocab, words=text, spaces=space)
188 |                     for text, space in zip(texts, spaces)
189 |                 ]
190 |             return list(self.spacy.pipe(texts))
191 |         except AttributeError:
192 |             # a WhitespaceSpacyTokenizer has no `pipe()` method, we use simple for loop
193 |             return [self.spacy(tokens) for tokens in texts]
194 | 
--------------------------------------------------------------------------------
/wsl/reader/pytorch_modules/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path, PosixPath
4 | from typing import Any, Dict, List
5 | 
6 | import torch
7 | import transformers as tr
8 | from torch.utils.data import IterableDataset
9 | from transformers import AutoConfig, AutoModel
10 | 
11 | from wsl.common.log import get_logger
12 | 
13 | # from wsl.common.torch_utils import load_ort_optimized_hf_model
14 | from wsl.common.utils import
get_callable_from_string 15 | from wsl.inference.data.objects import AnnotationType 16 | from wsl.reader.pytorch_modules import WSL_READER_CLASS_MAP 17 | from wsl.reader.pytorch_modules.hf.modeling_wsl import WSLReaderConfig, WSLReaderSample 18 | from wsl.retriever.pytorch_modules import PRECISION_MAP 19 | 20 | logger = get_logger(__name__, level=logging.INFO) 21 | 22 | 23 | class WSLReaderBase(torch.nn.Module): 24 | default_reader_class: str | None = None 25 | default_data_class: str | None = None 26 | 27 | def __init__( 28 | self, 29 | transformer_model: str | tr.PreTrainedModel | None = None, 30 | additional_special_symbols: int = 0, 31 | num_layers: int | None = None, 32 | activation: str = "gelu", 33 | linears_hidden_size: int | None = 512, 34 | use_last_k_layers: int = 1, 35 | training: bool = False, 36 | device: str | torch.device | None = None, 37 | precision: int = 32, 38 | tokenizer: str | tr.PreTrainedTokenizer | None = None, 39 | dataset: IterableDataset | str | None = None, 40 | default_reader_class: tr.PreTrainedModel | str | None = None, 41 | **kwargs, 42 | ) -> None: 43 | super().__init__() 44 | 45 | self.default_reader_class = default_reader_class or self.default_reader_class 46 | 47 | if self.default_reader_class is None: 48 | raise ValueError("You must specify a default reader class.") 49 | 50 | # get the callable for the default reader class 51 | self.default_reader_class: tr.PreTrainedModel = get_callable_from_string( 52 | self.default_reader_class 53 | ) 54 | if isinstance(transformer_model, PosixPath): 55 | transformer_model = str(transformer_model) 56 | if isinstance(transformer_model, str): 57 | self.name_or_path = transformer_model 58 | config = AutoConfig.from_pretrained( 59 | transformer_model, trust_remote_code=True 60 | ) 61 | if "wsl-reader" in config.model_type: 62 | transformer_model = self.default_reader_class.from_pretrained( 63 | transformer_model, 64 | config=config, 65 | ignore_mismatched_sizes=True, 66 | trust_remote_code=True, 67 | **kwargs, 68 | ) 69 | else: 70 | reader_config = WSLReaderConfig( 71 | transformer_model=transformer_model, 72 | additional_special_symbols=additional_special_symbols, 73 | num_layers=num_layers, 74 | activation=activation, 75 | linears_hidden_size=linears_hidden_size, 76 | use_last_k_layers=use_last_k_layers, 77 | training=training, 78 | **kwargs, 79 | ) 80 | transformer_model = self.default_reader_class(reader_config) 81 | self.name_or_path = transformer_model.config.transformer_model 82 | else: 83 | self.name_or_path = transformer_model.config.transformer_model 84 | 85 | self.wsl_reader_model = transformer_model 86 | 87 | self.wsl_reader_model_config = self.wsl_reader_model.config 88 | 89 | # get the tokenizer 90 | self._tokenizer = tokenizer 91 | 92 | # and instantiate the dataset class 93 | self.dataset: IterableDataset | None = dataset 94 | 95 | # move the model to the device 96 | self.to(device or torch.device("cpu")) 97 | 98 | # set the precision 99 | self.precision = precision 100 | self.to(PRECISION_MAP[precision]) 101 | 102 | def forward(self, **kwargs) -> Dict[str, Any]: 103 | return self.wsl_reader_model(**kwargs) 104 | 105 | def _read(self, *args, **kwargs) -> Any: 106 | raise NotImplementedError 107 | 108 | @torch.no_grad() 109 | @torch.inference_mode() 110 | def read( 111 | self, 112 | text: List[str] | List[List[str]] | None = None, 113 | samples: List[WSLReaderSample] | None = None, 114 | input_ids: torch.Tensor | None = None, 115 | attention_mask: torch.Tensor | None = None, 116 | token_type_ids: 
torch.Tensor | None = None,
117 |         prediction_mask: torch.Tensor | None = None,
118 |         special_symbols_mask: torch.Tensor | None = None,
119 |         candidates: List[List[str]] | None = None,
120 |         max_length: int = 1000,
121 |         max_batch_size: int = 128,
122 |         token_batch_size: int = 2048,
123 |         precision: int | str | None = None,
124 |         annotation_type: str | AnnotationType = AnnotationType.CHAR,
125 |         progress_bar: bool = False,
126 |         *args,
127 |         **kwargs,
128 |     ) -> List[WSLReaderSample] | List[List[WSLReaderSample]]:
129 |         """
130 |         Reads the given text.
131 | 
132 |         Args:
133 |             text (:obj:`List[str]` or :obj:`List[List[str]]`, `optional`):
134 |                 The text to read, as tokens. If a list of lists of tokens is provided, each
135 |                 inner list is considered a sentence.
136 |             samples (:obj:`List[WSLReaderSample]`, `optional`):
137 |                 The samples to read. If provided, `text` and `candidates` are ignored.
138 |             input_ids (:obj:`torch.Tensor`, `optional`):
139 |                 The input ids of the text.
140 |             attention_mask (:obj:`torch.Tensor`, `optional`):
141 |                 The attention mask of the text.
142 |             token_type_ids (:obj:`torch.Tensor`, `optional`):
143 |                 The token type ids of the text.
144 |             prediction_mask (:obj:`torch.Tensor`, `optional`):
145 |                 The prediction mask of the text.
146 |             special_symbols_mask (:obj:`torch.Tensor`, `optional`):
147 |                 The special symbols mask of the text.
148 |             candidates (:obj:`List[List[str]]`, `optional`):
149 |                 The candidates of the text.
150 |             max_length (:obj:`int`, `optional`, defaults to 1000):
151 |                 The maximum length of the text.
152 |             max_batch_size (:obj:`int`, `optional`, defaults to 128):
153 |                 The maximum batch size.
154 |             token_batch_size (:obj:`int`, `optional`):
155 |                 The maximum number of tokens per batch.
156 |             precision (:obj:`int` or :obj:`str`, `optional`):
157 |                 The precision to use. If not provided, the default is 32 bit.
158 |             annotation_type (`str` or `AnnotationType`, `optional`, defaults to `char`):
159 |                 The type of annotation to return. If `char`, the spans will be in terms of
160 |                 character offsets. If `word`, the spans will be in terms of word offsets.
161 |             progress_bar (:obj:`bool`, `optional`, defaults to :obj:`False`):
162 |                 Whether to show a progress bar.
163 | 
164 |         Returns:
165 |             The predicted labels for each sample.
166 |         """
167 |         if isinstance(annotation_type, str):
168 |             try:
169 |                 annotation_type = AnnotationType(annotation_type)
170 |             except ValueError:
171 |                 raise ValueError(
172 |                     f"Annotation type `{annotation_type}` not recognized. "
173 |                     f"Please choose one of {list(AnnotationType)}."
174 |                 )
175 | 
176 |         if text is None and input_ids is None and samples is None:
177 |             raise ValueError(
178 |                 "Either `text` or `input_ids` or `samples` must be provided."
179 |             )
180 |         if (input_ids is None and samples is None) and (
181 |             text is None or candidates is None
182 |         ):
183 |             raise ValueError(
184 |                 "`text` and `candidates` must be provided to return the predictions when "
185 |                 "`input_ids` and `samples` are not provided."
186 |             )
187 |         if text is not None and samples is None:
188 |             if len(text) != len(candidates):
189 |                 raise ValueError("`text` and `candidates` must have the same length.")
190 |             if isinstance(text[0], str):  # a single example was passed: wrap it into a batch of one
191 |                 text = [text]
192 |                 candidates = [candidates]
193 | 
194 |             samples = [
195 |                 WSLReaderSample(tokens=t, candidates=c)
196 |                 for t, c in zip(text, candidates)
197 |             ]
198 | 
199 |         return self._read(
200 |             samples,
201 |             input_ids,
202 |             attention_mask,
203 |             token_type_ids,
204 |             prediction_mask,
205 |             special_symbols_mask,
206 |             max_length,
207 |             max_batch_size,
208 |             token_batch_size,
209 |             precision or self.precision,
210 |             annotation_type,
211 |             progress_bar,
212 |             *args,
213 |             **kwargs,
214 |         )
215 | 
216 |     @property
217 |     def device(self) -> torch.device:
218 |         """
219 |         The device of the model.
220 |         """
221 |         return next(self.parameters()).device
222 | 
223 |     @property
224 |     def tokenizer(self) -> tr.PreTrainedTokenizer:
225 |         """
226 |         The tokenizer.
227 |         """
228 |         if self._tokenizer:
229 |             return self._tokenizer
230 | 
231 |         self._tokenizer = tr.AutoTokenizer.from_pretrained(
232 |             self.wsl_reader_model.config.name_or_path
233 |             if self.wsl_reader_model.config.name_or_path
234 |             else self.wsl_reader_model.config.transformer_model
235 |         )
236 |         return self._tokenizer
237 | 
238 |     @classmethod
239 |     def from_pretrained(
240 |         cls,
241 |         model_name_or_dir: str | os.PathLike,
242 |         **kwargs,
243 |     ):
244 |         transformer_model = AutoModel.from_pretrained(
245 |             model_name_or_dir, trust_remote_code=True, **kwargs
246 |         )
247 |         if transformer_model.__class__.__name__ not in WSL_READER_CLASS_MAP:
248 |             raise ValueError(f"Model type {type(transformer_model)} not recognized.")
249 | 
250 |         reader_class = WSL_READER_CLASS_MAP[transformer_model.__class__.__name__]
251 |         reader_class = get_callable_from_string(reader_class)
252 |         return reader_class(transformer_model=transformer_model, **kwargs)
253 | 
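    # Editor's note: a minimal, illustrative usage sketch (the checkpoint path
    # and candidate senses below are hypothetical placeholders, not artifacts of
    # this repository):
    #
    #     reader = WSLReaderBase.from_pretrained("path/to/wsl-reader-checkpoint")
    #     predictions = reader.read(
    #         text=[["The", "bank", "raised", "rates", "."]],
    #         candidates=[["bank: a financial institution", "bank: a river side"]],
    #     )
    #
    # `read` wraps each (tokens, candidates) pair into a WSLReaderSample and
    # delegates the actual decoding to the subclass-specific `_read`.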
254 |     def save_pretrained(
255 |         self,
256 |         output_dir: str | os.PathLike,
257 |         model_name: str | None = None,
258 |         push_to_hub: bool = False,
259 |         **kwargs,
260 |     ) -> None:
261 |         """
262 |         Saves the model to the given path.
263 | 
264 |         Args:
265 |             output_dir (`str` or :obj:`os.PathLike`):
266 |                 The path to save the model to.
267 |             model_name (`str`, `optional`):
268 |                 The name of the model. If not provided, the name of the output
269 |                 directory is used.
270 |             push_to_hub (`bool`, `optional`, defaults to `False`):
271 |                 Whether to push the model to the HuggingFace Hub.
272 |             **kwargs:
273 |                 Additional keyword arguments to pass to the `save_pretrained` method.
274 |         """
275 |         # create the output directory
276 |         output_dir = Path(output_dir)
277 |         output_dir.mkdir(parents=True, exist_ok=True)
278 | 
279 |         model_name = model_name or output_dir.name
280 | 
281 |         logger.info(f"Saving reader to {output_dir / model_name}")
282 | 
283 |         # save the model
284 |         self.wsl_reader_model.config.register_for_auto_class()
285 |         self.wsl_reader_model.register_for_auto_class()
286 |         self.wsl_reader_model.save_pretrained(
287 |             str(output_dir / model_name), push_to_hub=push_to_hub, **kwargs
288 |         )
289 | 
290 |         if self.tokenizer:
291 |             logger.info("Also saving the tokenizer")
292 |             self.tokenizer.save_pretrained(
293 |                 str(output_dir / model_name), push_to_hub=push_to_hub, **kwargs
294 |             )
295 | 
--------------------------------------------------------------------------------
/wsl/retriever/indexers/document.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import pickle
4 | import sys
5 | from pathlib import Path
6 | from typing import Dict, List, Union
7 | 
8 | from wsl.common.log import get_logger
9 | from wsl.common.utils import JsonSerializable
10 | 
11 | csv.field_size_limit(sys.maxsize)
12 | 
13 | logger = get_logger(__name__)
14 | 
15 | 
16 | class Document:
17 |     def __init__(
18 |         self,
19 |         text: str,
20 |         id: int = None,
21 |         metadata: Dict = None,
22 |         **kwargs,
23 |     ):
24 |         self.text = text
25 |         # if id is not provided, we use the hash of the text
26 |         self.id = id if id is not None else hash(text)
27 |         # if metadata is not provided, we use an empty dictionary
28 |         self.metadata = metadata or {}
29 | 
30 |     def __str__(self):
31 |         return f"{self.id}:{self.text}"
32 | 
33 |     def __repr__(self):
34 |         return json.dumps(self.to_dict())
35 | 
36 |     def __eq__(self, other):
37 |         if isinstance(other, Document):
38 |             return self.id == other.id
39 |         elif isinstance(other, int):
40 |             return self.id == other
41 |         elif isinstance(other, str):
42 |             return self.text == other
43 |         else:
44 |             raise ValueError(
45 |                 f"Document must be compared with a Document, an int or a str, got `{type(other)}`"
46 |             )
47 | 
48 |     def to_dict(self):
49 |         return {"text": self.text, "id": self.id, "metadata": self.metadata}
50 | 
51 |     def to_json(self):
52 |         return json.dumps(self.to_dict())
53 | 
54 |     @classmethod
55 |     def from_dict(cls, d: Dict):
56 |         return cls(**d)
57 | 
58 |     @classmethod
59 |     def from_file(cls, file_path: Union[str, Path], **kwargs):
60 |         with open(file_path, "r") as f:
61 |             d = json.load(f)
62 |         return cls.from_dict(d)
63 | 
64 |     def save(self, file_path: Union[str, Path], **kwargs):
65 |         with open(file_path, "w") as f:
66 |             json.dump(self.to_dict(), f, indent=2)
67 | 
68 | 
69 | class DocumentStore:
70 |     """
71 |     A document store is a collection of documents.
72 | 
73 |     Args:
74 |         documents (:obj:`List[Document]`):
75 |             The documents to store.
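    Example (editor's sketch; the passage text is a made-up placeholder):
        >>> store = DocumentStore()
        >>> doc = store.add_document("a passage about computational linguistics")
        >>> len(store)
        1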
76 | """ 77 | 78 | def __init__(self, documents: List[Document] = None) -> None: 79 | if documents is None: 80 | documents = [] 81 | # if self.ingore_case: 82 | # documents = [doc.lower() for doc in documents] 83 | self._documents = documents 84 | # build an index for the documents 85 | self._documents_index = {doc.id: doc for doc in self._documents} 86 | # build a reverse index for the documents 87 | self._documents_reverse_index = {doc.text: doc for doc in self._documents} 88 | 89 | def __len__(self): 90 | return len(self._documents) 91 | 92 | def __getitem__(self, index): 93 | return self._documents[index] 94 | 95 | def __iter__(self): 96 | return iter(self._documents) 97 | 98 | def __contains__(self, item): 99 | if isinstance(item, int): 100 | return item in self._documents_index 101 | elif isinstance(item, str): 102 | return item in self._documents_reverse_index 103 | elif isinstance(item, Document): 104 | return item.id in self._documents_index 105 | # return item in self._documents_index 106 | 107 | def __str__(self): 108 | return f"DocumentStore with {len(self)} documents" 109 | 110 | def __repr__(self): 111 | return self.__str__() 112 | 113 | def get_document_from_id(self, id: int) -> Document: 114 | """ 115 | Retrieve a document by its ID. 116 | 117 | Args: 118 | id (`int`): 119 | The ID of the document to retrieve. 120 | 121 | Returns: 122 | Optional[Document]: The document with the given ID, or None if it does not exist. 123 | """ 124 | if id not in self._documents_index: 125 | logger.warning(f"Document with id `{id}` does not exist, skipping") 126 | return self._documents_index.get(id, None) 127 | 128 | def get_document_from_text(self, text: str) -> Document: 129 | """ 130 | Retrieve the document by its text. 131 | 132 | Args: 133 | text (`str`): 134 | The text of the document to retrieve. 135 | 136 | Returns: 137 | Optional[Document]: The document with the given text, or None if it does not exist. 138 | """ 139 | if text not in self._documents_reverse_index: 140 | logger.warning(f"Document with text `{text}` does not exist, skipping") 141 | return self._documents_reverse_index.get(text, None) 142 | 143 | def get_document_from_index(self, index: int) -> Document: 144 | """ 145 | Retrieve the document by its index. 146 | 147 | Args: 148 | index (`int`): 149 | The index of the document to retrieve. 150 | 151 | Returns: 152 | Optional[Document]: The document with the given index, or None if it does not exist. 153 | """ 154 | if index >= len(self._documents): 155 | logger.warning(f"Document with index `{index}` does not exist, skipping") 156 | return self._documents[index] 157 | 158 | def add_documents( 159 | self, documents: List[Document] | List[str] | List[Dict] 160 | ) -> List[Document]: 161 | """ 162 | Add a list of documents to the document store. 163 | 164 | Args: 165 | documents (`List[Document]`): 166 | The documents to add. 167 | 168 | Returns: 169 | List[Document]: The documents just added. 170 | """ 171 | return [ 172 | ( 173 | self.add_document(Document.from_dict(doc)) 174 | if isinstance(doc, Dict) 175 | else self.add_document(doc) 176 | ) 177 | for doc in documents 178 | ] 179 | 180 | def add_document( 181 | self, 182 | text: str | Document, 183 | id: int | None = None, 184 | metadata: Dict | None = None, 185 | ) -> Document: 186 | """ 187 | Add a document to the document store. 188 | 189 | Args: 190 | text (`str`): 191 | The text of the document to add. 192 | id (`int`, optional, defaults to None): 193 | The ID of the document to add. 
194 |             metadata (`Dict`, optional, defaults to None):
195 |                 The metadata of the document to add.
196 | 
197 |         Returns:
198 |             Document: The document just added.
199 |         """
200 |         if isinstance(text, str):
201 |             # check if the document already exists
202 |             if text in self:
203 |                 logger.warning(f"Document `{text}` already exists, skipping")
204 |                 return self._documents_reverse_index[text]
205 |             if id is None:
206 |                 # fall back to the current number of documents as the ID
207 |                 id = len(self._documents)
208 |             text = Document(text, id, metadata)
209 | 
210 |         if text in self:
211 |             logger.warning(f"Document `{text}` already exists, skipping")
212 |             return self._documents_index[text.id]
213 | 
214 |         self._documents.append(text)
215 |         self._documents_index[text.id] = text
216 |         self._documents_reverse_index[text.text] = text
217 |         return text
218 |         # if id in self._documents_index:
219 |         #     logger.warning(f"Document with id `{id}` already exists, skipping")
220 |         #     return self._documents_index[id]
221 |         # if text_or_document in self._documents_reverse_index:
222 |         #     logger.warning(f"Document with text `{text_or_document}` already exists, skipping")
223 |         #     return self._documents_reverse_index[text_or_document]
224 |         # self._documents.append(Document(text_or_document, id, metadata))
225 |         # self._documents_index[id] = self._documents[-1]
226 |         # self._documents_reverse_index[text_or_document] = self._documents[-1]
227 |         # return self._documents_index[id]
228 | 
229 |     def delete_document(self, document: int | str | Document) -> bool:
230 |         """
231 |         Delete a document from the document store.
232 | 
233 |         Args:
234 |             document (`int`, `str` or `Document`):
235 |                 The document to delete.
236 | 
237 |         Returns:
238 |             bool: True if the document has been deleted, False otherwise.
239 |         """
240 |         if isinstance(document, int):
241 |             return self.delete_by_id(document)
242 |         elif isinstance(document, str):
243 |             return self.delete_by_text(document)
244 |         elif isinstance(document, Document):
245 |             return self.delete_by_document(document)
246 |         else:
247 |             raise ValueError(
248 |                 f"Document must be an int, a str or a Document, got `{type(document)}`"
249 |             )
250 | 
251 |     def delete_by_id(self, id: int) -> bool:
252 |         """
253 |         Delete a document by its ID.
254 | 
255 |         Args:
256 |             id (`int`):
257 |                 The ID of the document to delete.
258 | 
259 |         Returns:
260 |             bool: True if the document has been deleted, False otherwise.
261 |         """
262 |         if id not in self._documents_index:
263 |             logger.warning(f"Document with id `{id}` does not exist, skipping")
264 |             return False
265 |         del self._documents_reverse_index[self._documents_index[id].text]
266 |         del self._documents_index[id]
267 |         return True
268 | 
269 |     def delete_by_text(self, text: str) -> bool:
270 |         """
271 |         Delete a document by its text.
272 | 
273 |         Args:
274 |             text (`str`):
275 |                 The text of the document to delete.
276 | 
277 |         Returns:
278 |             bool: True if the document has been deleted, False otherwise.
279 |         """
280 |         if text not in self._documents_reverse_index:
281 |             logger.warning(f"Document with text `{text}` does not exist, skipping")
282 |             return False
283 |         del self._documents_index[self._documents_reverse_index[text].id]
284 |         del self._documents_reverse_index[text]
285 |         return True
286 | 
287 |     def delete_by_document(self, document: Document) -> bool:
288 |         """
289 |         Delete a document from the document store.
290 | 
291 |         Args:
292 |             document (:obj:`Document`):
293 |                 The document to delete.
294 | 
295 |         Returns:
296 |             bool: True if the document has been deleted, False otherwise.
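        Example (editor's sketch; the passage text is a made-up placeholder):
            >>> store = DocumentStore()
            >>> doc = store.add_document("some passage")
            >>> store.delete_document(doc)
            True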
297 | """ 298 | if document.id not in self._documents_index: 299 | logger.warning(f"Document {document} does not exist, skipping") 300 | return False 301 | del self._documents[self._documents.index(document)] 302 | del self._documents_index[document.id] 303 | del self._documents_reverse_index[self._documents_index[document.id]] 304 | 305 | def to_dict(self): 306 | return [doc.to_dict() for doc in self._documents] 307 | 308 | @classmethod 309 | def from_dict(cls, d): 310 | return cls([Document.from_dict(doc) for doc in d]) 311 | 312 | @classmethod 313 | def from_file(cls, file_path: Union[str, Path], **kwargs): 314 | with open(file_path, "r") as f: 315 | # load a json lines file 316 | d = [Document.from_dict(json.loads(line)) for line in f] 317 | return cls(d) 318 | 319 | @classmethod 320 | def from_pickle(cls, file_path: Union[str, Path], **kwargs): 321 | with open(file_path, "rb") as handle: 322 | d = pickle.load(handle) 323 | return cls(d) 324 | 325 | @classmethod 326 | def from_tsv( 327 | cls, 328 | file_path: Union[str, Path], 329 | ingore_case: bool = False, 330 | delimiter: str = "\t", 331 | **kwargs, 332 | ): 333 | d = [] 334 | # load a tsv/csv file and take the header into account 335 | # the header must be `id\ttext\t[list of metadata keys]` 336 | with open(file_path, "r", encoding="utf8") as f: 337 | csv_reader = csv.reader(f, delimiter=delimiter, **kwargs) 338 | header = next(csv_reader) 339 | id, text, *metadata_keys = header 340 | for i, row in enumerate(csv_reader): 341 | # check if id can be casted to int 342 | # if not, we add it to the metadata and use `i` as id 343 | try: 344 | s_id = int(row[header.index(id)]) 345 | row_metadata_keys = metadata_keys 346 | except ValueError: 347 | row_metadata_keys = [id] + metadata_keys 348 | s_id = i 349 | 350 | d.append( 351 | Document( 352 | text=( 353 | row[header.index(text)].strip().lower() 354 | if ingore_case 355 | else row[header.index(text)].strip() 356 | ), 357 | id=s_id, # row[header.index(id)], 358 | metadata={ 359 | key: row[header.index(key)] for key in row_metadata_keys 360 | }, 361 | ) 362 | ) 363 | return cls(d) 364 | 365 | def save(self, file_path: Union[str, Path], **kwargs): 366 | with open(file_path, "w") as f: 367 | for doc in self._documents: 368 | # save as json lines 369 | f.write(json.dumps(doc.to_dict()) + "\n") 370 | -------------------------------------------------------------------------------- /wsl/retriever/indexers/inmemory.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import logging 3 | import os 4 | from typing import Callable, List, Optional, Union 5 | 6 | import torch 7 | import transformers as tr 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | from wsl.common.log import get_logger 12 | from wsl.common.torch_utils import get_autocast_context 13 | from wsl.retriever.common.model_inputs import ModelInputs 14 | from wsl.retriever.data.base.datasets import BaseDataset 15 | from wsl.retriever.indexers.base import BaseDocumentIndex 16 | from wsl.retriever.indexers.document import Document, DocumentStore 17 | from wsl.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample 18 | 19 | # check if ORT is available 20 | # if is_package_available("onnxruntime"): 21 | 22 | logger = get_logger(__name__, level=logging.INFO) 23 | 24 | 25 | class MatrixMultiplicationModule(torch.nn.Module): 26 | def __init__(self, embeddings): 27 | super().__init__() 28 | self.embeddings = torch.nn.Parameter(embeddings, 
/wsl/retriever/indexers/inmemory.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import logging
3 | import os
4 | from typing import Callable, List, Optional, Union
5 | 
6 | import torch
7 | import transformers as tr
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 | 
11 | from wsl.common.log import get_logger
12 | from wsl.common.torch_utils import get_autocast_context
13 | from wsl.retriever.common.model_inputs import ModelInputs
14 | from wsl.retriever.data.base.datasets import BaseDataset
15 | from wsl.retriever.indexers.base import BaseDocumentIndex
16 | from wsl.retriever.indexers.document import Document, DocumentStore
17 | from wsl.retriever.pytorch_modules import PRECISION_MAP, RetrievedSample
18 | 
19 | # check if ORT is available
20 | # if is_package_available("onnxruntime"):
21 | 
22 | logger = get_logger(__name__, level=logging.INFO)
23 | 
24 | 
25 | class MatrixMultiplicationModule(torch.nn.Module):
26 |     def __init__(self, embeddings):
27 |         super().__init__()
28 |         self.embeddings = torch.nn.Parameter(embeddings, requires_grad=False)
29 | 
30 |     def forward(self, query):
31 |         return torch.matmul(query, self.embeddings.T)
32 | 
33 | 
34 | class InMemoryDocumentIndex(BaseDocumentIndex):
35 |     def __init__(
36 |         self,
37 |         documents: Union[str, List[str], List[Document]] = None,
38 |         embeddings: Optional[torch.Tensor] = None,
39 |         metadata_fields: Optional[List[str]] = None,
40 |         separator: Optional[str] = None,
41 |         name_or_path: Union[str, os.PathLike, None] = None,
42 |         device: str = "cpu",
43 |         precision: Union[str, int, torch.dtype] = 32,
44 |         *args,
45 |         **kwargs,
46 |     ) -> None:
47 |         """
48 |         An in-memory indexer based on PyTorch.
49 | 
50 |         Args:
51 |             documents (:obj:`Union[List[str]]`):
52 |                 The documents to be indexed.
53 |             embeddings (:obj:`Optional[torch.Tensor]`, `optional`, defaults to :obj:`None`):
54 |                 The embeddings of the documents.
55 |             device (:obj:`str`, `optional`, defaults to "cpu"):
56 |                 The device to be used for storing the embeddings.
57 |         """
58 | 
59 |         super().__init__(
60 |             documents, embeddings, metadata_fields, separator, name_or_path, device
61 |         )
62 | 
63 |         if embeddings is not None and documents is not None:
64 |             logger.info("Both documents and embeddings are provided.")
65 |             if len(documents) != embeddings.shape[0]:
66 |                 raise ValueError(
67 |                     "The number of documents and embeddings must be the same. "
68 |                     f"Got {len(documents)} documents and {embeddings.shape[0]} embeddings."
69 |                 )
70 | 
71 |         # the base class already stores the embeddings as `self.embeddings`;
72 |         # drop the local reference so the tensor is not kept alive twice
73 |         # while we convert it below
74 |         del embeddings
75 |         # convert the embeddings to the desired precision
76 |         if precision is not None:
77 |             if self.embeddings is not None and self.device == "cpu":
78 |                 if PRECISION_MAP[precision] == PRECISION_MAP[16]:
79 |                     logger.info(
80 |                         f"Precision `{precision}` is not supported on CPU. "
81 |                         f"Using `{PRECISION_MAP[32]}` instead."
82 |                     )
83 |                 precision = 32
84 | 
85 |             if (
86 |                 self.embeddings is not None
87 |                 and self.embeddings.dtype != PRECISION_MAP[precision]
88 |             ):
89 |                 logger.info(
90 |                     f"Index vectors are of type {self.embeddings.dtype}. "
91 |                     f"Converting to {PRECISION_MAP[precision]}."
92 |                 )
93 |                 self.embeddings = self.embeddings.to(PRECISION_MAP[precision])
94 |         else:
95 |             # TODO: a bit redundant, fix this eventually
96 |             if (
97 |                 # here we trust the device_in_init, since we don't know yet
98 |                 # the device of the embeddings
99 |                 (
100 |                     self.device_in_init == "cpu"
101 |                     or self.device_in_init == torch.device("cpu")
102 |                 )
103 |                 and self.embeddings is not None
104 |                 and self.embeddings.dtype != torch.float32
105 |             ):
106 |                 logger.info(
107 |                     f"Index vectors are of type {self.embeddings.dtype} but the device is CPU. "
108 |                     f"Converting to {PRECISION_MAP[32]}."
109 |                 )
110 |                 self.embeddings = self.embeddings.to(PRECISION_MAP[32])
111 | 
112 |         # move the embeddings to the desired device
113 |         if (
114 |             self.embeddings is not None
115 |             and not self.embeddings.device == self.device_in_init
116 |         ):
117 |             self.embeddings = self.embeddings.to(self.device_in_init)
118 | 
119 |         # TODO: check interactions with the embeddings
120 |         # self.mm = MatrixMultiplicationModule(embeddings=self.embeddings)
121 |         # self.mm.eval()
122 | 
123 |         # precision to be used for the embeddings
124 |         self.precision = precision
125 | 
126 |     @torch.no_grad()
127 |     @torch.inference_mode()
128 |     def index(
129 |         self,
130 |         retriever,
131 |         documents: Optional[List[Document]] = None,
132 |         batch_size: int = 32,
133 |         num_workers: int = 4,
134 |         max_length: Optional[int] = None,
135 |         collate_fn: Optional[Callable] = None,
136 |         encoder_precision: Optional[Union[str, int]] = None,
137 |         compute_on_cpu: bool = False,
138 |         force_reindex: bool = False,
139 |     ) -> "InMemoryDocumentIndex":
140 |         """
141 |         Index the documents using the encoder.
142 | 
143 |         Args:
144 |             retriever (:obj:`torch.nn.Module`):
145 |                 The encoder to be used for indexing.
146 |             documents (:obj:`List[Document]`, `optional`, defaults to :obj:`None`):
147 |                 The documents to be indexed. If not provided, the documents provided at the initialization will be used.
148 |             batch_size (:obj:`int`, `optional`, defaults to 32):
149 |                 The batch size to be used for indexing.
150 |             num_workers (:obj:`int`, `optional`, defaults to 4):
151 |                 The number of workers to be used for indexing.
152 |             max_length (:obj:`int`, `optional`, defaults to None):
153 |                 The maximum length of the input to the encoder.
154 |             collate_fn (:obj:`Callable`, `optional`, defaults to None):
155 |                 The collate function to be used for batching.
156 |             encoder_precision (:obj:`Union[str, int]`, `optional`, defaults to None):
157 |                 The precision to be used for the encoder.
158 |             compute_on_cpu (:obj:`bool`, `optional`, defaults to False):
159 |                 Whether to compute the embeddings on CPU.
160 |             force_reindex (:obj:`bool`, `optional`, defaults to False):
161 |                 Whether to force reindexing.
162 | 
163 |         Returns:
164 |             :obj:`InMemoryDocumentIndex`: The index object.
165 |         """
166 | 
167 |         if documents is None and self.documents is None:
168 |             raise ValueError("Documents must be provided.")
169 | 
170 |         if self.embeddings is not None and not force_reindex and documents is None:
171 |             logger.info(
172 |                 "Embeddings are already present and `force_reindex` is `False`. Skipping indexing."
173 |             )
174 |             return self
175 | 
176 |         if force_reindex:
177 |             if documents is not None:
178 |                 self.documents.add_documents(documents)
179 |             data = [k for k in self.get_passages()]
180 | 
181 |         else:
182 |             if documents is not None:
183 |                 data = [k for k in self.get_passages(DocumentStore(documents))]
184 |                 # add the documents to the actual document store
185 |                 self.documents.add_documents(documents)
186 |             else:
187 |                 if self.embeddings is None:
188 |                     data = [k for k in self.get_passages()]
189 | 
190 |         if collate_fn is None:
191 |             tokenizer = retriever.passage_tokenizer
192 | 
193 |             def collate_fn(x):
194 |                 return ModelInputs(
195 |                     tokenizer(
196 |                         x,
197 |                         padding=True,
198 |                         return_tensors="pt",
199 |                         truncation=True,
200 |                         max_length=max_length or tokenizer.model_max_length,
201 |                     )
202 |                 )
203 | 
204 |         # prepend the passage prefix used at retrieval time
205 |         data = [retriever.passage_prefix + p for p in data]
206 |         dataloader = DataLoader(
207 |             BaseDataset(name="passage", data=data),
208 |             batch_size=batch_size,
209 |             shuffle=False,
210 |             num_workers=num_workers,
211 |             pin_memory=False,
212 |             collate_fn=collate_fn,
213 |         )
214 | 
215 |         encoder = retriever.passage_encoder
216 | 
217 |         # Create empty lists to store the passage embeddings and passage index
218 |         passage_embeddings: List[torch.Tensor] = []
219 | 
220 |         encoder_device = "cpu" if compute_on_cpu else encoder.device
221 | 
222 |         # torch.autocast expects a plain device-type string such as 'cpu' or 'cuda',
223 |         # so we derive it from the model device
224 |         device_type_for_autocast = str(encoder_device).split(":")[0]
225 |         # autocast on CPU only supports bfloat16, so we skip it there
226 |         autocast_pssg_mngr = (
227 |             contextlib.nullcontext()
228 |             if device_type_for_autocast == "cpu"
229 |             else (
230 |                 torch.autocast(
231 |                     device_type=device_type_for_autocast,
232 |                     dtype=PRECISION_MAP[encoder_precision],
233 |                 )
234 |             )
235 |         )
236 |         with autocast_pssg_mngr:
237 |             # Iterate through each batch in the dataloader
238 |             for batch in tqdm(dataloader, desc="Indexing"):
239 |                 # Move the batch to the device
240 |                 batch: ModelInputs = batch.to(encoder_device)
241 |                 # Compute the passage embeddings
242 |                 passage_outs = encoder(**batch).pooler_output
243 |                 # Append the passage embeddings to the list
244 |                 if self.device == "cpu":
245 |                     passage_embeddings.extend([c.detach().cpu() for c in passage_outs])
246 |                 else:
247 |                     passage_embeddings.extend([c for c in passage_outs])
248 | 
249 |         # move the passage embeddings to the CPU if not already done
250 |         # the move to cpu and then to gpu is needed to avoid OOM when using mixed precision
251 |         if not self.device == "cpu":  # this if is to avoid unnecessary moves
252 |             passage_embeddings = [c.detach().cpu() for c in passage_embeddings]
253 |         # stack it
254 |         passage_embeddings: torch.Tensor = torch.stack(passage_embeddings, dim=0)
255 |         # move the passage embeddings to the gpu if needed
256 |         if not self.device == "cpu":
257 |             passage_embeddings = passage_embeddings.to(PRECISION_MAP[self.precision])
258 |             passage_embeddings = passage_embeddings.to(self.device)
259 |         self.embeddings = passage_embeddings
260 |         # update the matrix multiplication module
261 |         # self.mm = MatrixMultiplicationModule(embeddings=self.embeddings)
262 | 
263 |         # free up memory from the unused variable
264 |         del passage_embeddings
265 | 
266 |         return self
267 | 
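    # Editor's note: a hypothetical indexing call; it assumes a retriever that
    # exposes `passage_tokenizer`, `passage_encoder` and `passage_prefix`, and a
    # placeholder TSV file:
    #
    #     index = InMemoryDocumentIndex(documents=DocumentStore.from_tsv("passages.tsv"))
    #     index = index.index(retriever, batch_size=32, force_reindex=True)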
268 |     @torch.no_grad()
269 |     @torch.inference_mode()
270 |     def search(self, query: torch.Tensor, k: int = 1) -> list[list[RetrievedSample]]:
271 |         """
272 |         Search the documents using the query.
273 | 
274 |         Args:
275 |             query (:obj:`torch.Tensor`):
276 |                 The query to be used for searching.
277 |             k (:obj:`int`, `optional`, defaults to 1):
278 |                 The number of documents to be retrieved.
279 | 
280 |         Returns:
281 |             :obj:`list[list[RetrievedSample]]`: The retrieved documents for each query.
282 |         """
283 | 
284 |         with get_autocast_context(self.device, self.embeddings.dtype):
285 |             # move query to the same device as embeddings
286 |             query = query.to(self.embeddings.device)
287 |             if query.dtype != self.embeddings.dtype:
288 |                 query = query.to(self.embeddings.dtype)
289 |             similarity = torch.matmul(query, self.embeddings.T)
290 |             # similarity = self.mm(query)
291 |         # Retrieve the indices of the top k passage embeddings
292 |         retriever_out: torch.return_types.topk = torch.topk(
293 |             similarity, k=min(k, similarity.shape[-1]), dim=1
294 |         )
295 | 
296 |         # get int values
297 |         batch_top_k: List[List[int]] = retriever_out.indices.detach().cpu().tolist()
298 |         # get float values
299 |         batch_scores: List[List[float]] = retriever_out.values.detach().cpu().tolist()
300 |         # Retrieve the passages corresponding to the indices
301 |         batch_docs = [
302 |             [self.documents.get_document_from_index(i) for i in indices]
303 |             for indices in batch_top_k
304 |         ]
305 |         # build the output object
306 |         batch_retrieved_samples = [
307 |             [
308 |                 RetrievedSample(document=doc, score=score)
309 |                 for doc, score in zip(docs, scores)
310 |             ]
311 |             for docs, scores in zip(batch_docs, batch_scores)
312 |         ]
313 |         return batch_retrieved_samples
314 | 
--------------------------------------------------------------------------------
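# Editor's note: an illustrative search round-trip; the query-encoder call is a
# placeholder for whatever produces query embeddings in your pipeline.
#
#     query_embeddings = retriever.question_encoder(**model_inputs).pooler_output
#     results = index.search(query_embeddings, k=5)
#     best = results[0][0]  # top-scoring RetrievedSample for the first query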
/wsl/inference/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 | 
3 | import hydra
4 | import torch
5 | from omegaconf import DictConfig, OmegaConf
6 | 
7 | from wsl.common.log import get_logger
8 | from wsl.inference.data.objects import TaskType
9 | from wsl.reader.pytorch_modules.base import WSLReaderBase
10 | from wsl.retriever.indexers.base import BaseDocumentIndex
11 | from wsl.retriever.pytorch_modules import PRECISION_MAP
12 | from wsl.retriever.pytorch_modules.model import WSLRetriever
13 | 
14 | logger = get_logger(__name__)
15 | 
16 | 
17 | def _instantiate_retriever(
18 |     retriever: WSLRetriever | DictConfig | Dict,
19 |     device: str | None,
20 |     precision: int | str | torch.dtype | None,
21 |     **kwargs: Any,
22 | ) -> WSLRetriever:
23 |     """
24 |     Instantiate a retriever.
25 | 
26 |     Args:
27 |         retriever (`WSLRetriever`, `DictConfig` or `Dict`):
28 |             The retriever to instantiate.
29 |         device (`str`, `optional`):
30 |             The device to use for the retriever.
31 |         precision (`int`, `str` or `torch.dtype`, `optional`):
32 |             The precision to use for the retriever.
33 |         **kwargs (`Dict[str, Any]`, `optional`):
34 |             Additional keyword arguments to pass to the retriever.
35 | 
36 |     Returns:
37 |         `WSLRetriever`:
38 |             The instantiated retriever.
39 |     """
40 |     if not isinstance(retriever, WSLRetriever):
41 |         # convert to DictConfig
42 |         retriever = hydra.utils.instantiate(
43 |             OmegaConf.create(retriever),
44 |             device=device,
45 |             precision=precision,
46 |             **kwargs,
47 |         )
48 |     else:
49 |         if device is not None:
50 |             logger.info(f"Moving retriever to `{device}`.")
51 |             retriever.to(device)
52 |         if precision is not None:
53 |             logger.info(
54 |                 f"Setting precision of retriever to `{PRECISION_MAP[precision]}`."
55 |             )
56 |             retriever.to(PRECISION_MAP[precision])
57 |     retriever.training = False
58 |     retriever.eval()
59 |     return retriever
60 | 
61 | 
62 | def load_retriever(
63 |     retriever: WSLRetriever | DictConfig | Dict | str,
64 |     device: str | None | torch.device | int = None,
65 |     precision: int | str | torch.dtype | None = None,
66 |     task: TaskType | str | None = None,
67 |     compile: bool = False,
68 |     **kwargs,
69 | ) -> Dict[TaskType, WSLRetriever]:
70 |     """
71 |     Load and instantiate retrievers for a given task.
72 | 
73 |     Args:
74 |         retriever (WSLRetriever | DictConfig | Dict):
75 |             The retriever object or configuration.
76 |         device (str | None):
77 |             The device to load the retriever on.
78 |         precision (int | str | torch.dtype | None):
79 |             The precision of the retriever.
80 |         task (TaskType):
81 |             The task type for the retriever.
82 |         compile (bool, optional):
83 |             Whether to compile the retriever. Defaults to False.
84 |         **kwargs:
85 |             Additional keyword arguments to be passed to the retriever instantiation.
86 | 
87 |     Returns:
88 |         Dict[TaskType, WSLRetriever]:
89 |             A dictionary containing the instantiated retrievers.
90 | 
91 |     Raises:
92 |         ValueError: If the `retriever` argument is not of type `WSLRetriever`, `DictConfig`, or `Dict`.
93 |         ValueError: If the `retriever` argument is a `DictConfig` without the `_target_` key.
94 |         ValueError: If the task type is not valid for each retriever in the `DictConfig`.
95 | 
96 |     """
97 | 
98 |     # retriever section
99 |     _retriever: Dict[TaskType, WSLRetriever] = {
100 |         TaskType.SPAN: None,
101 |     }
102 | 
103 |     # check retriever type, it can be a WSLRetriever, a DictConfig or a Dict
104 |     if not isinstance(retriever, (WSLRetriever, DictConfig, Dict, str)):
105 |         raise ValueError(
106 |             f"`retriever` must be a `WSLRetriever`, a `DictConfig`, "
107 |             f"a `Dict`, or a `str`, got `{type(retriever)}`."
108 |         )
109 |     if isinstance(retriever, str):
110 |         logger.warning(
111 |             "Using a string to instantiate the retriever. "
112 |             f"We will use the same model `{retriever}` for both query and passage encoder. "
113 |             "If you want to use different models, please provide a dictionary with keys `question_encoder` and `passage_encoder`."
114 |         )
115 |         retriever = {"question_encoder": retriever}
116 |     # we need to check whether the DictConfig describes an instance of WSLRetriever
117 |     # or a primitive Dict
118 |     if isinstance(retriever, DictConfig):
119 |         # then it is probably a primitive Dict
120 |         if "_target_" not in retriever:
121 |             retriever = OmegaConf.to_container(retriever, resolve=True)
122 |             # convert the key to TaskType
123 |             try:
124 |                 retriever = {TaskType(k.lower()): v for k, v in retriever.items()}
125 |             except ValueError as e:
126 |                 raise ValueError(
127 |                     f"Please choose a valid task type (one of {list(TaskType)}) for each retriever."
128 |                 ) from e
129 | 
130 |     if isinstance(retriever, Dict):
131 |         # convert the key to TaskType
132 |         retriever = {TaskType(k): v for k, v in retriever.items()}
133 |     else:
134 |         retriever = {task: retriever}
135 | 
136 |     # instantiate each retriever
137 |     if task in [TaskType.SPAN, TaskType.BOTH]:
138 |         _retriever[TaskType.SPAN] = _instantiate_retriever(
139 |             retriever[TaskType.SPAN],
140 |             device,
141 |             precision,
142 |             **kwargs,
143 |         )
144 | 
145 |     # clean up None retrievers from the dictionary
146 |     _retriever = {task_type: r for task_type, r in _retriever.items() if r is not None}
147 |     if compile:
148 |         # torch compile
149 |         _retriever = {
150 |             task_type: torch.compile(r) for task_type, r in _retriever.items()
151 |         }
152 | 
153 |     return _retriever
154 | 
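# Editor's note: an illustrative call; the encoder name is a placeholder.
#
#     retrievers = load_retriever(
#         "path/to/question-encoder", device="cpu", task=TaskType.SPAN
#     )
#     span_retriever = retrievers[TaskType.SPAN]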
155 | 
156 | def _instantiate_index(
157 |     index: BaseDocumentIndex | DictConfig | Dict,
158 |     device: str | None | torch.device | int = None,
159 |     precision: int | str | torch.dtype | None = None,
160 |     **kwargs: Dict[str, Any],
161 | ) -> BaseDocumentIndex:
162 |     """
163 |     Instantiate a document index.
164 | 
165 |     Args:
166 |         index (`BaseDocumentIndex`, `DictConfig` or `Dict`):
167 |             The document index to instantiate.
168 |         device (`str`, `optional`):
169 |             The device to use for the document index.
170 |         precision (`int`, `str` or `torch.dtype`, `optional`):
171 |             The precision to use for the document index.
172 |         kwargs (`Dict[str, Any]`, `optional`):
173 |             Additional keyword arguments to pass to the document index.
174 | 
175 |     Returns:
176 |         `BaseDocumentIndex`:
177 |             The instantiated document index.
178 |     """
179 |     if not isinstance(index, BaseDocumentIndex):
180 |         index = OmegaConf.create(index)
181 |         use_faiss = kwargs.get("use_faiss", False)
182 |         if use_faiss:
183 |             index = OmegaConf.merge(
184 |                 index,
185 |                 {
186 |                     "_target_": "wsl.retriever.indexers.faissindex.FaissDocumentIndex.from_pretrained",
187 |                 },
188 |             )
189 |         if device is not None:
190 |             kwargs["device"] = device
191 |         if precision is not None:
192 |             kwargs["precision"] = precision
193 | 
194 |         # merge the kwargs
195 |         index = OmegaConf.merge(index, OmegaConf.create(kwargs))
196 |         index: BaseDocumentIndex = hydra.utils.instantiate(index)
197 |     else:
198 |         # the index is already instantiated; just move it if needed
199 |         if device is not None:
200 |             logger.info(f"Moving index to `{device}`.")
201 |             index.to(device)
202 |         if precision is not None:
203 |             logger.info(f"Setting precision of index to `{PRECISION_MAP[precision]}`.")
204 |             index.to(PRECISION_MAP[precision])
205 |     return index
206 | 
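# Editor's note: a hypothetical sketch of loading an index for span linking;
# the index path is a placeholder.
#
#     indexes = load_index(
#         "path/to/document-index", device="cpu", precision=32, task=TaskType.SPAN
#     )
#     span_index = indexes[TaskType.SPAN]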
207 | 
208 | def load_index(
209 |     index: BaseDocumentIndex | DictConfig | Dict | str,
210 |     device: str | None,
211 |     precision: int | str | torch.dtype | None,
212 |     task: TaskType,
213 |     **kwargs,
214 | ) -> Dict[TaskType, BaseDocumentIndex]:
215 |     """
216 |     Load the document index based on the specified parameters.
217 | 
218 |     Args:
219 |         index (BaseDocumentIndex | DictConfig | Dict):
220 |             The document index to load. It can be an instance of `BaseDocumentIndex`, a `DictConfig`, or a `Dict`.
221 |         device (str | None):
222 |             The device to use for loading the index. If `None`, the default device will be used.
223 |         precision (int | str | torch.dtype | None):
224 |             The precision of the index. If `None`, the default precision will be used.
225 |         task (TaskType):
226 |             The type of task for the index.
227 |         **kwargs:
228 |             Additional keyword arguments to be passed to the index instantiation.
229 | 
230 |     Returns:
231 |         Dict[TaskType, BaseDocumentIndex]:
232 |             A dictionary containing the loaded document index for each task type.
233 | 
234 |     Raises:
235 |         ValueError: If the `index` parameter is not of type `BaseDocumentIndex`, `DictConfig`, or `Dict`.
236 |         ValueError: If the `index` parameter is a `DictConfig` without a `_target_` key.
237 |         ValueError: If the task type specified in the `index` parameter is not valid.
238 |     """
239 | 
240 |     # index
241 |     _index: Dict[TaskType, BaseDocumentIndex] = {
242 |         TaskType.SPAN: None,
243 |     }
244 | 
245 |     # check the index type, it can be a BaseDocumentIndex, a DictConfig or a Dict
246 |     if not isinstance(index, (BaseDocumentIndex, DictConfig, Dict, str)):
247 |         raise ValueError(
248 |             f"`index` must be a `BaseDocumentIndex`, a `DictConfig`, "
249 |             f"a `Dict`, or a `str`, got `{type(index)}`."
250 |         )
251 |     # we need to check whether the DictConfig describes an instance of BaseDocumentIndex
252 |     # or a primitive Dict
253 |     if isinstance(index, str):
254 |         index = {"name_or_path": index}
255 |     if isinstance(index, DictConfig):
256 |         # then it is probably a primitive Dict
257 |         if "_target_" not in index:
258 |             index = OmegaConf.to_container(index, resolve=True)
259 |             # convert the key to TaskType
260 |             try:
261 |                 index = {TaskType(k.lower()): v for k, v in index.items()}
262 |             except ValueError as e:
263 |                 raise ValueError(
264 |                     f"Please choose a valid task type (one of {list(TaskType)}) for each index."
265 |                 ) from e
266 | 
267 |     if isinstance(index, Dict):
268 |         # convert the key to TaskType
269 |         index = {TaskType(k): v for k, v in index.items()}
270 |     else:
271 |         index = {task: index}
272 | 
273 |     # instantiate each index
274 |     if task in [TaskType.SPAN, TaskType.BOTH]:
275 |         _index[TaskType.SPAN] = _instantiate_index(
276 |             index[TaskType.SPAN],
277 |             device,
278 |             precision,
279 |             **kwargs,
280 |         )
281 | 
282 |     # clean up None indices from the dictionary
283 |     _index = {task_type: i for task_type, i in _index.items() if i is not None}
284 |     return _index
285 | 
286 | 
287 | def load_reader(
288 |     reader: WSLReaderBase,
289 |     device: str | None,
290 |     precision: int | str | torch.dtype | None,
291 |     compile: bool = False,
292 |     **kwargs: Dict[str, Any],
293 | ) -> WSLReaderBase:
294 |     """
295 |     Load a reader model for inference.
296 | 
297 |     Args:
298 |         reader (WSLReaderBase):
299 |             The reader model to load.
300 |         device (str | None):
301 |             The device to move the reader model to.
302 |         precision (int | str | torch.dtype | None):
303 |             The precision to set for the reader model.
304 |         compile (bool, optional):
305 |             Whether to compile the reader model. Defaults to False.
306 |         **kwargs (Dict[str, Any]):
307 |             Additional keyword arguments to pass to the reader model.
308 | 
309 |     Returns:
310 |         WSLReaderBase: The loaded reader model.
311 |     """
312 | 
313 |     if not isinstance(reader, (WSLReaderBase, DictConfig, Dict, str)):
314 |         raise ValueError(
315 |             f"`reader` must be a `WSLReaderBase`, a `DictConfig`, "
316 |             f"a `Dict`, or a `str`, got `{type(reader)}`."
317 | ) 318 | 319 | if isinstance(reader, str): 320 | reader = { 321 | "_target_": "wsl.reader.pytorch_modules.base.WSLReaderBase.from_pretrained", 322 | "model_name_or_dir": reader, 323 | } 324 | 325 | if not isinstance(reader, DictConfig): 326 | # then it is probably a primitive Dict 327 | # if "_target_" not in reader: 328 | # reader = OmegaConf.to_container(reader, resolve=True) 329 | # reader = OmegaConf.to_container(reader, resolve=True) 330 | # if not isinstance(reader, DictConfig): 331 | reader = OmegaConf.create(reader) 332 | 333 | reader = ( 334 | hydra.utils.instantiate( 335 | reader, 336 | device=device, 337 | precision=precision, 338 | **kwargs, 339 | ) 340 | if isinstance(reader, DictConfig) 341 | else reader 342 | ) 343 | reader.training = False 344 | reader.eval() 345 | if device is not None: 346 | logger.info(f"Moving reader to `{device}`.") 347 | reader.to(device) 348 | if precision is not None and reader.precision != PRECISION_MAP[precision]: 349 | logger.info(f"Setting precision of reader to `{PRECISION_MAP[precision]}`.") 350 | reader.to(PRECISION_MAP[precision]) 351 | 352 | if compile: 353 | reader = torch.compile(reader) 354 | return reader 355 | -------------------------------------------------------------------------------- /wsl_data_license.txt: -------------------------------------------------------------------------------- 1 | WSL Dataset Non-Commercial license 2 | 3 | 1. Definitions 4 | 5 | "Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original. 6 | "Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or other works or subject matter which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined above) for the purposes of this License. 7 | "Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership. 8 | "License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, Noncommercial, ShareAlike. 9 | "Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License. 
10 | "Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i)in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast. 11 | "Work" means the word-sense disambiguated part of SemCor provided by Babelscape under this license. 12 | "You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation. 13 | "Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images. 14 | "Reproduce" means to make digital or paper copies of the Work by any means including without limitation by textual, sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage in digital form or other electronic medium. 15 | 16 | 2. Fair Dealing Rights 17 | 18 | Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws. 19 | 3. License Grant 20 | 21 | Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below: 22 | 23 | to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections provided the Work is clearly identifiable, linked to https://github.com/Babelscape/WSL and made accessible only to research institutions; 24 | to create, Reproduce, Distribute and Publicly Perform Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, an Adaptation could be marked "This data is a processed version of the WSL dataset downloaded from https://huggingface.co/datasets/Babelscape/wsl, made available by Babelscape with the WSL Non-Commercial License" and made accessible only to research institutions. Alternatively, a link to the WSL website can be provided for download of the official data and code for the creation of the Adaptation can be provided to any user. 
25 | 26 | The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved, including but not limited to the rights described in Section 4(e). 27 | 4. Restrictions 28 | 29 | The license granted in Section 3 above is expressly made subject to and limited by the following restrictions: 30 | 31 | You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) https://github.com/Babelscape/WSL/wsl_data_license.txt for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(d), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(d), as requested. 32 | You may Distribute or Publicly Perform an Adaptation only under: (i) the terms of this License; (ii) a later version of this License with the same License Elements as this License. You must include a copy of, or the URI https://github.com/Babelscape/WSL/wsl_data_license.txt, for Applicable License with every copy of each Adaptation You Distribute or Publicly Perform. You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License. You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License. 
33 | You may not exercise any of the rights granted to You in Section 3 above if You are not a research institution or if in any manner that is primarily intended for or directed toward commercial advantage or private monetary compensation. 34 | If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work (Word Sense Linking); (iii) the URI (https://huggingface.co/datasets/Babelscape/wsl); and, (iv) consistent with Section 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "This data is a processed version of the WSL dataset downloaded from https://huggingface.co/datasets/Babelscape/wsl, made available by Babelscape with the WSL Non-Commercial License"). The credit required by this Section 4(d) may be implemented in any reasonable manner; provided, however, that in the case of a Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties. 35 | 36 | For the avoidance of doubt: 37 | Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; 38 | Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License if Your exercise of such rights is for a purpose or use which is otherwise than noncommercial as permitted under Section 4(c) and otherwise waives the right to collect royalties through any statutory or compulsory licensing scheme; and, 39 | Voluntary License Schemes. The Licensor reserves the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License that is for a purpose or use which is otherwise than noncommercial as permitted under Section 4(c). 
40 | Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise. 41 | 42 | 5. Representations, Warranties and Disclaimer 43 | 44 | UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING AND TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, NON-INFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO THIS EXCLUSION MAY NOT APPLY TO YOU. 45 | 6. Limitation on Liability 46 | 47 | EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 48 | 7. Termination 49 | 50 | This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License. 51 | Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above. 52 | 53 | 8. Miscellaneous 54 | 55 | Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License. 56 | Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License. 
57 | If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable. 58 | No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent. 59 | This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You. 60 | The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law. 61 | -------------------------------------------------------------------------------- /wsl/reader/pytorch_modules/span.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | from typing import Any, Dict, Iterator, List 4 | 5 | import torch 6 | import transformers as tr 7 | from lightning_fabric.utilities import move_data_to_device 8 | from torch.utils.data import DataLoader, IterableDataset 9 | from tqdm import tqdm 10 | 11 | from wsl.common.log import get_logger 12 | from wsl.common.torch_utils import get_autocast_context 13 | from wsl.common.utils import get_callable_from_string 14 | from wsl.inference.data.objects import AnnotationType 15 | from wsl.reader.data.wsl_reader_sample import WSLReaderSample 16 | from wsl.reader.pytorch_modules.base import WSLReaderBase 17 | 18 | logger = get_logger(__name__, level=logging.INFO) 19 | 20 | 21 | class WSLReaderForSpanExtraction(WSLReaderBase): 22 | """ 23 | A class for the RelikReader model for span extraction. 24 | 25 | Args: 26 | transformer_model (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`): 27 | The transformer model to use. If `None`, the default model is used. 28 | additional_special_symbols (:obj:`int`, `optional`, defaults to 0): 29 | The number of additional special symbols to add to the tokenizer. 30 | num_layers (:obj:`int`, `optional`): 31 | The number of layers to use. If `None`, all layers are used. 32 | activation (:obj:`str`, `optional`, defaults to "gelu"): 33 | The activation function to use. 
34 | linears_hidden_size (:obj:`int`, `optional`, defaults to 512): 35 | The hidden size of the linears. 36 | use_last_k_layers (:obj:`int`, `optional`, defaults to 1): 37 | The number of last layers to use. 38 | training (:obj:`bool`, `optional`, defaults to False): 39 | Whether the model is in training mode. 40 | device (:obj:`str` or :obj:`torch.device` or :obj:`None`, `optional`): 41 | The device to use. If `None`, the default device is used. 42 | tokenizer (:obj:`str` or :obj:`transformers.PreTrainedTokenizer` or :obj:`None`, `optional`): 43 | The tokenizer to use. If `None`, the default tokenizer is used. 44 | dataset (:obj:`IterableDataset` or :obj:`str` or :obj:`None`, `optional`): 45 | The dataset to use. If `None`, the default dataset is used. 46 | dataset_kwargs (:obj:`Dict[str, Any]` or :obj:`None`, `optional`): 47 | The keyword arguments to pass to the dataset class. 48 | default_reader_class (:obj:`str` or :obj:`transformers.PreTrainedModel` or :obj:`None`, `optional`): 49 | The default reader class to use. If `None`, the default reader class is used. 50 | **kwargs: 51 | Keyword arguments. 52 | """ 53 | 54 | default_reader_class: str = ( 55 | "wsl.reader.pytorch_modules.hf.modeling_wsl.WSLReaderSpanModel" 56 | ) 57 | default_data_class: str = "wsl.reader.data.wsl_reader_data.WSLDataset" 58 | 59 | def __init__( 60 | self, 61 | transformer_model: str | tr.PreTrainedModel | None = None, 62 | additional_special_symbols: int = 0, 63 | num_layers: int | None = None, 64 | activation: str = "gelu", 65 | linears_hidden_size: int | None = 512, 66 | use_last_k_layers: int = 1, 67 | training: bool = False, 68 | device: str | torch.device | None = None, 69 | tokenizer: str | tr.PreTrainedTokenizer | None = None, 70 | dataset: IterableDataset | str | None = None, 71 | dataset_kwargs: Dict[str, Any] | None = None, 72 | default_reader_class: tr.PreTrainedModel | str | None = None, 73 | **kwargs, 74 | ): 75 | super().__init__( 76 | transformer_model=transformer_model, 77 | additional_special_symbols=additional_special_symbols, 78 | num_layers=num_layers, 79 | activation=activation, 80 | linears_hidden_size=linears_hidden_size, 81 | use_last_k_layers=use_last_k_layers, 82 | training=training, 83 | device=device, 84 | tokenizer=tokenizer, 85 | dataset=dataset, 86 | default_reader_class=default_reader_class, 87 | **kwargs, 88 | ) 89 | # and instantiate the dataset class 90 | self.dataset = dataset 91 | if self.dataset is None: 92 | self.default_data_class = get_callable_from_string(self.default_data_class) 93 | default_data_kwargs = dict( 94 | dataset_path=None, 95 | materialize_samples=False, 96 | transformer_model=self.tokenizer, 97 | special_symbols=self.default_data_class.get_special_symbols( 98 | self.wsl_reader_model.config.additional_special_symbols 99 | ), 100 | for_inference=True, 101 | use_nme=kwargs.get("use_nme", True), 102 | ) 103 | # merge the default data kwargs with the ones passed to the model 104 | default_data_kwargs.update(dataset_kwargs or {}) 105 | self.dataset = self.default_data_class(**default_data_kwargs) 106 | 107 | @torch.no_grad() 108 | @torch.inference_mode() 109 | def _read( 110 | self, 111 | samples: List[WSLReaderSample] | None = None, 112 | input_ids: torch.Tensor | None = None, 113 | attention_mask: torch.Tensor | None = None, 114 | token_type_ids: torch.Tensor | None = None, 115 | prediction_mask: torch.Tensor | None = None, 116 | special_symbols_mask: torch.Tensor | None = None, 117 | max_length: int = 1000, 118 | max_batch_size: int = 128, 119 | 
token_batch_size: int = 2048, 120 | precision: int | str = 32, 121 | annotation_type: AnnotationType = AnnotationType.CHAR, 122 | progress_bar: bool = False, 123 | remove_nmes: bool = True, 124 | *args: object, 125 | **kwargs: object, 126 | ) -> List[WSLReaderSample] | List[List[WSLReaderSample]]: 127 | """ 128 | A wrapper around the forward method that returns the predicted labels for each sample. 129 | 130 | Args: 131 | samples (:obj:`List[WSLReaderSample]`, `optional`): 132 | The samples to read. If provided, `text` and `candidates` are ignored. 133 | input_ids (:obj:`torch.Tensor`, `optional`): 134 | The input ids of the text. If `samples` is provided, this is ignored. 135 | attention_mask (:obj:`torch.Tensor`, `optional`): 136 | The attention mask of the text. If `samples` is provided, this is ignored. 137 | token_type_ids (:obj:`torch.Tensor`, `optional`): 138 | The token type ids of the text. If `samples` is provided, this is ignored. 139 | prediction_mask (:obj:`torch.Tensor`, `optional`): 140 | The prediction mask of the text. If `samples` is provided, this is ignored. 141 | special_symbols_mask (:obj:`torch.Tensor`, `optional`): 142 | The special symbols mask of the text. If `samples` is provided, this is ignored. 143 | max_length (:obj:`int`, `optional`, defaults to 1000): 144 | The maximum length of the text. 145 | max_batch_size (:obj:`int`, `optional`, defaults to 128): 146 | The maximum batch size. 147 | token_batch_size (:obj:`int`, `optional`, defaults to 2048): 148 | The token batch size. 149 | progress_bar (:obj:`bool`, `optional`, defaults to False): 150 | Whether to show a progress bar. 151 | precision (:obj:`int` or :obj:`str`, `optional`, defaults to 32): 152 | The precision to use for the model. 153 | annotation_type (`AnnotationType`, `optional`, defaults to `AnnotationType.CHAR`): 154 | The type of annotation to return. If `char`, the spans will be in terms of 155 | character offsets. If `word`, the spans will be in terms of word offsets. 156 | *args: 157 | Positional arguments. 158 | **kwargs: 159 | Keyword arguments. 160 | 161 | Returns: 162 | :obj:`List[WSLReaderSample]` or :obj:`List[List[WSLReaderSample]]`: 163 | The predicted labels for each sample.
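Example (an illustrative sketch, not a verified invocation; the `reader` handle and the sample fields shown here are assumptions):

    >>> sample = WSLReaderSample(text="Rome is the capital of Italy.", spans=[(0, 4)])
    >>> out = reader._read(samples=[sample], annotation_type=AnnotationType.CHAR)
    >>> out[0]  # the same sample, enriched with predicted spans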
164 | """ 165 | 166 | precision = precision or self.precision 167 | if samples is not None: 168 | 169 | def _read_iterator(): 170 | def samples_it(): 171 | for i, sample in enumerate(samples): 172 | assert sample._mixin_prediction_position is None 173 | sample._mixin_prediction_position = i 174 | if sample.spans is not None and len(sample.spans) > 0: 175 | sample.window_labels = [ 176 | [s[0], s[1], ""] for s in sample.spans 177 | ] 178 | yield sample 179 | 180 | next_prediction_position = 0 181 | position2predicted_sample = {} 182 | 183 | # instantiate dataset 184 | if self.dataset is None: 185 | raise ValueError( 186 | "You need to pass a dataset to the model in order to predict" 187 | ) 188 | self.dataset.samples = samples_it() 189 | self.dataset.model_max_length = max_length 190 | self.dataset.tokens_per_batch = token_batch_size 191 | self.dataset.max_batch_size = max_batch_size 192 | # self.dataset.batch_size = max_batch_size 193 | 194 | # instantiate dataloader 195 | iterator = DataLoader( 196 | self.dataset, batch_size=None, num_workers=0, shuffle=False 197 | ) 198 | if progress_bar: 199 | iterator = tqdm(iterator, desc="Predicting with RelikReader") 200 | 201 | with get_autocast_context(self.device, precision): 202 | for batch in iterator: 203 | batch = move_data_to_device(batch, self.device) 204 | batch.update(kwargs) 205 | batch_out = self._batch_predict(**batch) 206 | 207 | for sample in batch_out: 208 | if ( 209 | sample.spans is not None 210 | and len(sample.spans) > 0 211 | and sample.window_labels 212 | ): 213 | # remove window labels 214 | sample.window_labels = None 215 | if ( 216 | sample._mixin_prediction_position 217 | >= next_prediction_position 218 | ): 219 | position2predicted_sample[ 220 | sample._mixin_prediction_position 221 | ] = sample 222 | 223 | # yield 224 | while next_prediction_position in position2predicted_sample: 225 | yield position2predicted_sample[next_prediction_position] 226 | del position2predicted_sample[next_prediction_position] 227 | next_prediction_position += 1 228 | 229 | outputs = list(_read_iterator()) 230 | for sample in outputs: 231 | self.dataset.merge_patches_predictions(sample) 232 | if annotation_type == AnnotationType.CHAR: 233 | self.dataset.convert_to_char_annotations(sample, remove_nmes) 234 | elif annotation_type == AnnotationType.WORD: 235 | self.dataset.convert_to_word_annotations(sample, remove_nmes) 236 | else: 237 | raise ValueError( 238 | f"Annotation type {annotation_type} not recognized. " 239 | f"Please choose one of {list(AnnotationType)}." 240 | ) 241 | 242 | else: 243 | outputs = list( 244 | self._batch_predict( 245 | input_ids, 246 | attention_mask, 247 | token_type_ids, 248 | prediction_mask, 249 | special_symbols_mask, 250 | *args, 251 | **kwargs, 252 | ) 253 | ) 254 | return outputs 255 | 256 | def _batch_predict( 257 | self, 258 | input_ids: torch.Tensor, 259 | attention_mask: torch.Tensor, 260 | token_type_ids: torch.Tensor | None = None, 261 | prediction_mask: torch.Tensor | None = None, 262 | special_symbols_mask: torch.Tensor | None = None, 263 | sample: List[WSLReaderSample] | None = None, 264 | top_k: int = 5, # the amount of top-k most probable entities to predict 265 | *args, 266 | **kwargs, 267 | ) -> Iterator[WSLReaderSample]: 268 | """ 269 | A wrapper around the forward method that returns the predicted labels for each sample. 270 | It also adds the predicted labels to the samples. 271 | 272 | Args: 273 | input_ids (:obj:`torch.Tensor`): 274 | The input ids of the text. 
275 | attention_mask (:obj:`torch.Tensor`): 276 | The attention mask of the text. 277 | token_type_ids (:obj:`torch.Tensor`, `optional`): 278 | The token type ids of the text. 279 | prediction_mask (:obj:`torch.Tensor`, `optional`): 280 | The prediction mask of the text. 281 | special_symbols_mask (:obj:`torch.Tensor`, `optional`): 282 | The special symbols mask of the text. 283 | sample (:obj:`List[WSLReaderSample]`, `optional`): 284 | The samples to read. If provided, `text` and `candidates` are ignored. 285 | top_k (:obj:`int`, `optional`, defaults to 5): 286 | The number of top-k most probable entities to predict. 287 | *args: 288 | Positional arguments. 289 | **kwargs: 290 | Keyword arguments. 291 | 292 | Returns: 293 | The predicted labels for each sample. 294 | """ 295 | forward_output = self.forward( 296 | input_ids=input_ids, 297 | attention_mask=attention_mask, 298 | token_type_ids=token_type_ids, 299 | prediction_mask=prediction_mask, 300 | special_symbols_mask=special_symbols_mask, 301 | *args, 302 | **kwargs, 303 | ) 304 | 305 | ned_start_predictions = forward_output["ned_start_predictions"].cpu().numpy() 306 | ned_end_predictions = forward_output["ned_end_predictions"] # .cpu().numpy() 307 | ed_predictions = forward_output["ed_predictions"].cpu().numpy() 308 | ed_probabilities = forward_output["ed_probabilities"].cpu().numpy() 309 | 310 | batch_predictable_candidates = kwargs["predictable_candidates"] 311 | patch_offset = kwargs["patch_offset"] 312 | for ts, ne_sp, ne_ep, edp, edpr, pred_cands, po in zip( 313 | sample, 314 | ned_start_predictions, 315 | ned_end_predictions, 316 | ed_predictions, 317 | ed_probabilities, 318 | batch_predictable_candidates, 319 | patch_offset, 320 | ): 321 | ent_count = 0 322 | ne_ep = ne_ep.cpu().numpy() 323 | ne_start_indices = [ti for ti, c in enumerate(ne_sp[1:]) if c > 0] 324 | # ne_end_indices = [ti for ti, c in enumerate(ne_ep[1:]) if c > 0] 325 | final_class2predicted_spans = collections.defaultdict(list) 326 | spans2predicted_probabilities = dict() 327 | for start_token_index, end_token_row in zip(ne_start_indices, ne_ep): 328 | for end_token_index in [ 329 | ti for ti, c in enumerate(end_token_row[1:]) if c > 0 330 | ]: 331 | # predicted candidate 332 | token_class = edp[ent_count] - 1 333 | predicted_candidate_title = pred_cands[token_class] 334 | final_class2predicted_spans[predicted_candidate_title].append( 335 | [start_token_index, end_token_index] 336 | ) 337 | 338 | # candidates probabilities 339 | classes_probabilities = edpr[ent_count] 340 | classes_probabilities_best_indices = ( 341 | classes_probabilities.argsort()[::-1] 342 | ) 343 | titles_2_probs = [] 344 | top_k = ( 345 | min( 346 | top_k, 347 | len(classes_probabilities_best_indices), 348 | ) 349 | if top_k != -1 350 | else len(classes_probabilities_best_indices) 351 | ) 352 | for i in range(top_k): 353 | titles_2_probs.append( 354 | ( 355 | pred_cands[classes_probabilities_best_indices[i] - 1], 356 | classes_probabilities[ 357 | classes_probabilities_best_indices[i] 358 | ].item(), 359 | ) 360 | ) 361 | spans2predicted_probabilities[ 362 | (start_token_index, end_token_index) 363 | ] = titles_2_probs 364 | ent_count += 1 365 | 366 | if "patches" not in ts._d: 367 | ts._d["patches"] = dict() 368 | 369 | ts._d["patches"][po] = dict() 370 | sample_patch = ts._d["patches"][po] 371 | 372 | sample_patch["predicted_window_labels"] = final_class2predicted_spans 373 | sample_patch["span_title_probabilities"] = spans2predicted_probabilities 374 | 375 | # additional info 376
| sample_patch["predictable_candidates"] = pred_cands 377 | 378 | # try-out for a new format 379 | sample_patch["predicted_spans"] = final_class2predicted_spans 380 | sample_patch[ 381 | "predicted_spans_probabilities" 382 | ] = spans2predicted_probabilities 383 | 384 | yield ts 385 | -------------------------------------------------------------------------------- /wsl/reader/pytorch_modules/hf/modeling_wsl.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import torch 4 | from transformers import AutoModel, PreTrainedModel 5 | from transformers.activations import ClippedGELUActivation, GELUActivation 6 | from transformers.configuration_utils import PretrainedConfig 7 | from transformers.modeling_utils import PoolerEndLogits 8 | 9 | from .configuration_wsl import WSLReaderConfig 10 | 11 | 12 | class WSLReaderSample: 13 | def __init__(self, **kwargs): 14 | super().__setattr__("_d", {}) 15 | self._d = kwargs 16 | 17 | def __getattribute__(self, item): 18 | return super(WSLReaderSample, self).__getattribute__(item) 19 | 20 | def __getattr__(self, item): 21 | if item.startswith("__") and item.endswith("__"): 22 | # this is likely some python library-specific variable (such as __deepcopy__ for copy) 23 | # better follow standard behavior here 24 | raise AttributeError(item) 25 | elif item in self._d: 26 | return self._d[item] 27 | else: 28 | return None 29 | 30 | def __setattr__(self, key, value): 31 | if key in self._d: 32 | self._d[key] = value 33 | else: 34 | super().__setattr__(key, value) 35 | self._d[key] = value 36 | 37 | 38 | activation2functions = { 39 | "relu": torch.nn.ReLU(), 40 | "gelu": GELUActivation(), 41 | "gelu_10": ClippedGELUActivation(-10, 10), 42 | } 43 | 44 | 45 | class PoolerEndLogitsBi(PoolerEndLogits): 46 | def __init__(self, config: PretrainedConfig): 47 | super().__init__(config) 48 | self.dense_1 = torch.nn.Linear(config.hidden_size, 2) 49 | 50 | def forward( 51 | self, 52 | hidden_states: torch.FloatTensor, 53 | start_states: Optional[torch.FloatTensor] = None, 54 | start_positions: Optional[torch.LongTensor] = None, 55 | p_mask: Optional[torch.FloatTensor] = None, 56 | ) -> torch.FloatTensor: 57 | if p_mask is not None: 58 | p_mask = p_mask.unsqueeze(-1) 59 | logits = super().forward( 60 | hidden_states, 61 | start_states, 62 | start_positions, 63 | p_mask, 64 | ) 65 | return logits 66 | 67 | 68 | class WSLReaderSpanModel(PreTrainedModel): 69 | config_class = WSLReaderConfig 70 | 71 | def __init__(self, config: WSLReaderConfig, *args, **kwargs): 72 | super().__init__(config) 73 | # Transformer model declaration 74 | self.config = config 75 | self.transformer_model = ( 76 | AutoModel.from_pretrained(self.config.transformer_model) 77 | if self.config.num_layers is None 78 | else AutoModel.from_pretrained( 79 | self.config.transformer_model, num_hidden_layers=self.config.num_layers 80 | ) 81 | ) 82 | self.transformer_model.resize_token_embeddings( 83 | self.transformer_model.config.vocab_size 84 | + self.config.additional_special_symbols 85 | ) 86 | 87 | self.activation = self.config.activation 88 | self.linears_hidden_size = self.config.linears_hidden_size 89 | self.use_last_k_layers = self.config.use_last_k_layers 90 | 91 | # named entity detection layers 92 | self.ned_start_classifier = self._get_projection_layer( 93 | self.activation, last_hidden=2, layer_norm=False 94 | ) 95 | if self.config.binary_end_logits: 96 | self.ned_end_classifier = 
PoolerEndLogitsBi(self.transformer_model.config) 97 | else: 98 | self.ned_end_classifier = PoolerEndLogits(self.transformer_model.config) 99 | 100 | # END entity disambiguation layer 101 | self.ed_start_projector = self._get_projection_layer(self.activation) 102 | self.ed_end_projector = self._get_projection_layer(self.activation) 103 | 104 | self.training = self.config.training 105 | 106 | # criterion 107 | self.criterion = torch.nn.CrossEntropyLoss() 108 | 109 | def _get_projection_layer( 110 | self, 111 | activation: str, 112 | last_hidden: Optional[int] = None, 113 | input_hidden=None, 114 | layer_norm: bool = True, 115 | ) -> torch.nn.Sequential: 116 | head_components = [ 117 | torch.nn.Dropout(0.1), 118 | torch.nn.Linear( 119 | ( 120 | self.transformer_model.config.hidden_size * self.use_last_k_layers 121 | if input_hidden is None 122 | else input_hidden 123 | ), 124 | self.linears_hidden_size, 125 | ), 126 | activation2functions[activation], 127 | torch.nn.Dropout(0.1), 128 | torch.nn.Linear( 129 | self.linears_hidden_size, 130 | self.linears_hidden_size if last_hidden is None else last_hidden, 131 | ), 132 | ] 133 | 134 | if layer_norm: 135 | head_components.append( 136 | torch.nn.LayerNorm( 137 | self.linears_hidden_size if last_hidden is None else last_hidden, 138 | self.transformer_model.config.layer_norm_eps, 139 | ) 140 | ) 141 | 142 | return torch.nn.Sequential(*head_components) 143 | 144 | def _mask_logits(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: 145 | mask = mask.unsqueeze(-1) 146 | if next(self.parameters()).dtype == torch.float16: 147 | logits = logits * (1 - mask) - 65500 * mask 148 | else: 149 | logits = logits * (1 - mask) - 1e30 * mask 150 | return logits 151 | 152 | def _get_model_features( 153 | self, 154 | input_ids: torch.Tensor, 155 | attention_mask: torch.Tensor, 156 | token_type_ids: Optional[torch.Tensor], 157 | ): 158 | model_input = { 159 | "input_ids": input_ids, 160 | "attention_mask": attention_mask, 161 | "output_hidden_states": self.use_last_k_layers > 1, 162 | } 163 | 164 | if token_type_ids is not None: 165 | model_input["token_type_ids"] = token_type_ids 166 | 167 | model_output = self.transformer_model(**model_input) 168 | 169 | if self.use_last_k_layers > 1: 170 | model_features = torch.cat( 171 | model_output[1][-self.use_last_k_layers :], dim=-1 172 | ) 173 | else: 174 | model_features = model_output[0] 175 | 176 | return model_features 177 | 178 | def compute_ned_end_logits( 179 | self, 180 | start_predictions, 181 | start_labels, 182 | model_features, 183 | prediction_mask, 184 | batch_size, 185 | ) -> Optional[torch.Tensor]: 186 | # todo: maybe when constraining on the spans, 187 | # we should not use a prediction_mask for the end tokens. 
188 | # at least we should not during training imo 189 | start_positions = start_labels if self.training else start_predictions 190 | start_positions_indices = ( 191 | torch.arange(start_positions.size(1), device=start_positions.device) 192 | .unsqueeze(0) 193 | .expand(batch_size, -1)[start_positions > 0] 194 | ).to(start_positions.device) 195 | 196 | if len(start_positions_indices) > 0: 197 | expanded_features = model_features.repeat_interleave( 198 | torch.sum(start_positions > 0, dim=-1), dim=0 199 | ) 200 | expanded_prediction_mask = prediction_mask.repeat_interleave( 201 | torch.sum(start_positions > 0, dim=-1), dim=0 202 | ) 203 | end_logits = self.ned_end_classifier( 204 | hidden_states=expanded_features, 205 | start_positions=start_positions_indices, 206 | p_mask=expanded_prediction_mask, 207 | ) 208 | 209 | return end_logits 210 | 211 | return None 212 | 213 | def compute_classification_logits( 214 | self, 215 | model_features_start, 216 | model_features_end, 217 | special_symbols_features, 218 | ) -> torch.Tensor: 219 | model_start_features = self.ed_start_projector(model_features_start) 220 | model_end_features = self.ed_end_projector(model_features_end) 221 | model_start_features_symbols = self.ed_start_projector(special_symbols_features) 222 | model_end_features_symbols = self.ed_end_projector(special_symbols_features) 223 | 224 | model_ed_features = torch.cat( 225 | [model_start_features, model_end_features], dim=-1 226 | ) 227 | special_symbols_representation = torch.cat( 228 | [model_start_features_symbols, model_end_features_symbols], dim=-1 229 | ) 230 | 231 | logits = torch.bmm( 232 | model_ed_features, 233 | torch.permute(special_symbols_representation, (0, 2, 1)), 234 | ) 235 | 236 | logits = self._mask_logits(logits, (model_features_start == -100).all(2).long()) 237 | return logits 238 | 239 | def forward( 240 | self, 241 | input_ids: torch.Tensor, 242 | attention_mask: torch.Tensor, 243 | token_type_ids: Optional[torch.Tensor] = None, 244 | prediction_mask: Optional[torch.Tensor] = None, 245 | special_symbols_mask: Optional[torch.Tensor] = None, 246 | start_labels: Optional[torch.Tensor] = None, 247 | end_labels: Optional[torch.Tensor] = None, 248 | use_predefined_spans: bool = False, 249 | *args, 250 | **kwargs, 251 | ) -> Dict[str, Any]: 252 | batch_size, seq_len = input_ids.shape 253 | 254 | model_features = self._get_model_features( 255 | input_ids, attention_mask, token_type_ids 256 | ) 257 | 258 | ned_start_labels = None 259 | 260 | # named entity detection if required 261 | if use_predefined_spans: # no need to compute spans 262 | ned_start_logits, ned_start_probabilities, ned_start_predictions = ( 263 | None, 264 | None, 265 | ( 266 | torch.clone(start_labels) 267 | if start_labels is not None 268 | else torch.zeros_like(input_ids) 269 | ), 270 | ) 271 | ned_end_logits, ned_end_probabilities, ned_end_predictions = ( 272 | None, 273 | None, 274 | ( 275 | torch.clone(end_labels) 276 | if end_labels is not None 277 | else torch.zeros_like(input_ids) 278 | ), 279 | ) 280 | ned_start_predictions[ned_start_predictions > 0] = 1 281 | ned_end_predictions[end_labels > 0] = 1 282 | ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)] 283 | 284 | else: # compute spans 285 | # start boundary prediction 286 | ned_start_logits = self.ned_start_classifier(model_features) 287 | ned_start_logits = self._mask_logits(ned_start_logits, prediction_mask) 288 | ned_start_probabilities = torch.softmax(ned_start_logits, dim=-1) 289 | ned_start_predictions = 
ned_start_probabilities.argmax(dim=-1) 290 | 291 | # end boundary prediction 292 | ned_start_labels = ( 293 | torch.zeros_like(start_labels) if start_labels is not None else None 294 | ) 295 | 296 | if ned_start_labels is not None: 297 | ned_start_labels[start_labels == -100] = -100 298 | ned_start_labels[start_labels > 0] = 1 299 | 300 | ned_end_logits = self.compute_ned_end_logits( 301 | ned_start_predictions, 302 | ned_start_labels, 303 | model_features, 304 | prediction_mask, 305 | batch_size, 306 | ) 307 | 308 | if ned_end_logits is not None: 309 | ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1) 310 | if not self.config.binary_end_logits: 311 | ned_end_predictions = torch.argmax( 312 | ned_end_probabilities, dim=-1, keepdim=True 313 | ) 314 | ned_end_predictions = torch.zeros_like( 315 | ned_end_probabilities 316 | ).scatter_(1, ned_end_predictions, 1) 317 | else: 318 | ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1) 319 | else: 320 | ned_end_logits, ned_end_probabilities = None, None 321 | ned_end_predictions = ned_start_predictions.new_zeros( 322 | batch_size, seq_len 323 | ) 324 | 325 | if not self.training: 326 | # if len(ned_end_predictions.shape) < 2: 327 | # print(ned_end_predictions) 328 | end_preds_count = ned_end_predictions.sum(1) 329 | # If there are no end predictions for a start prediction, remove the start prediction 330 | if (end_preds_count == 0).any() and (ned_start_predictions > 0).any(): 331 | ned_start_predictions[ned_start_predictions == 1] = ( 332 | end_preds_count != 0 333 | ).long() 334 | ned_end_predictions = ned_end_predictions[end_preds_count != 0] 335 | 336 | if end_labels is not None: 337 | end_labels = end_labels[~(end_labels == -100).all(2)] 338 | 339 | start_position, end_position = ( 340 | (start_labels, end_labels) 341 | if self.training 342 | else (ned_start_predictions, ned_end_predictions) 343 | ) 344 | start_counts = (start_position > 0).sum(1) 345 | if (start_counts > 0).any(): 346 | ned_end_predictions = ned_end_predictions.split(start_counts.tolist()) 347 | # Entity disambiguation 348 | if (end_position > 0).sum() > 0: 349 | ends_count = (end_position > 0).sum(1) 350 | model_entity_start = torch.repeat_interleave( 351 | model_features[start_position > 0], ends_count, dim=0 352 | ) 353 | model_entity_end = torch.repeat_interleave( 354 | model_features, start_counts, dim=0 355 | )[end_position > 0] 356 | ents_count = torch.nn.utils.rnn.pad_sequence( 357 | torch.split(ends_count, start_counts.tolist()), 358 | batch_first=True, 359 | padding_value=0, 360 | ).sum(1) 361 | 362 | model_entity_start = torch.nn.utils.rnn.pad_sequence( 363 | torch.split(model_entity_start, ents_count.tolist()), 364 | batch_first=True, 365 | padding_value=-100, 366 | ) 367 | 368 | model_entity_end = torch.nn.utils.rnn.pad_sequence( 369 | torch.split(model_entity_end, ents_count.tolist()), 370 | batch_first=True, 371 | padding_value=-100, 372 | ) 373 | 374 | ed_logits = self.compute_classification_logits( 375 | model_entity_start, 376 | model_entity_end, 377 | model_features[special_symbols_mask].view( 378 | batch_size, -1, model_features.shape[-1] 379 | ), 380 | ) 381 | ed_probabilities = torch.softmax(ed_logits, dim=-1) 382 | ed_predictions = torch.argmax(ed_probabilities, dim=-1) 383 | else: 384 | ed_logits, ed_probabilities, ed_predictions = ( 385 | None, 386 | ned_start_predictions.new_zeros(batch_size, seq_len), 387 | ned_start_predictions.new_zeros(batch_size), 388 | ) 389 | # output build 390 | output_dict = dict( 391 | 
batch_size=batch_size, 392 | ned_start_logits=ned_start_logits, 393 | ned_start_probabilities=ned_start_probabilities, 394 | ned_start_predictions=ned_start_predictions, 395 | ned_end_logits=ned_end_logits, 396 | ned_end_probabilities=ned_end_probabilities, 397 | ned_end_predictions=ned_end_predictions, 398 | ed_logits=ed_logits, 399 | ed_probabilities=ed_probabilities, 400 | ed_predictions=ed_predictions, 401 | ) 402 | 403 | # compute loss if labels 404 | if start_labels is not None and end_labels is not None and self.training: 405 | # named entity detection loss 406 | 407 | # start 408 | if ned_start_logits is not None: 409 | ned_start_loss = self.criterion( 410 | ned_start_logits.view(-1, ned_start_logits.shape[-1]), 411 | ned_start_labels.view(-1), 412 | ) 413 | else: 414 | ned_start_loss = 0 415 | 416 | # end 417 | # use ents_count to assign the labels to the correct positions i.e. using end_labels -> [[0,0,4,0], [0,0,0,2]] -> [4,2] (this is just an element, for batch we need to mask it with ents_count), ie -> [[4,2,-100,-100], [3,1,2,-100], [1,3,2,5]] 418 | 419 | if ned_end_logits is not None: 420 | ed_labels = end_labels.clone() 421 | ed_labels = torch.nn.utils.rnn.pad_sequence( 422 | torch.split(ed_labels[ed_labels > 0], ents_count.tolist()), 423 | batch_first=True, 424 | padding_value=-100, 425 | ) 426 | end_labels[end_labels > 0] = 1 427 | if not self.config.binary_end_logits: 428 | # transform label to position in the sequence 429 | end_labels = end_labels.argmax(dim=-1) 430 | ned_end_loss = self.criterion( 431 | ned_end_logits.view(-1, ned_end_logits.shape[-1]), 432 | end_labels.view(-1), 433 | ) 434 | else: 435 | ned_end_loss = self.criterion( 436 | ned_end_logits.reshape(-1, ned_end_logits.shape[-1]), 437 | end_labels.reshape(-1).long(), 438 | ) 439 | 440 | # entity disambiguation loss 441 | ed_loss = self.criterion( 442 | ed_logits.view(-1, ed_logits.shape[-1]), 443 | ed_labels.view(-1).long(), 444 | ) 445 | 446 | else: 447 | ned_end_loss = 0 448 | ed_loss = 0 449 | 450 | output_dict["ned_start_loss"] = ned_start_loss 451 | output_dict["ned_end_loss"] = ned_end_loss 452 | output_dict["ed_loss"] = ed_loss 453 | 454 | output_dict["loss"] = ned_start_loss + ned_end_loss + ed_loss 455 | 456 | return output_dict 457 | --------------------------------------------------------------------------------
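Putting the two files above together: WSLReaderForSpanExtraction drives WSLReaderSpanModel through its dataset to produce per-sample span predictions. The following is a minimal usage sketch, assuming a trained reader checkpoint is available; the checkpoint path and the sample fields are placeholders, not verified names from this repository.

import torch

from wsl.reader.data.wsl_reader_sample import WSLReaderSample
from wsl.reader.pytorch_modules.span import WSLReaderForSpanExtraction

# Placeholder checkpoint: substitute a trained WSL reader model here.
reader = WSLReaderForSpanExtraction(
    transformer_model="path/to/wsl-reader-checkpoint",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# WSLReaderSample stores arbitrary keyword fields in an internal dict
# (see the class in modeling_wsl.py), so extra fields are carried through
# prediction untouched. Which fields the dataset consumes ("text", "spans",
# ...) is an assumption in this sketch.
sample = WSLReaderSample(text="Rome is the capital of Italy.")

# _read() batches samples through the dataset, restores their original order
# via _mixin_prediction_position, merges windowed patches, and converts
# token-level spans into character- or word-level annotations.
predictions = reader._read(samples=[sample], progress_bar=True)
print(predictions[0])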