├── .dockerignore
├── .flake8
├── .github
│   └── workflows
│       ├── black.yaml
│       └── python-publish-pypi.yml
├── .gitignore
├── .pre-commit-config.yaml
├── MANIFEST.in
├── README.md
├── SETUP.cfg
├── Sapienza_Babelscape.png
├── constraints.cpu.txt
├── dockerfiles
│   ├── fastapi
│   │   ├── Dockerfile.cpu
│   │   └── Dockerfile.cuda
│   └── ray
│       ├── Dockerfile.cpu
│       └── Dockerfile.cuda
├── examples
│   ├── cie.py
│   └── langchain.py
├── pyproject.toml
├── relik.png
├── relik
│   ├── __init__.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── cli.py
│   │   ├── data.py
│   │   ├── reader.py
│   │   ├── retriever.py
│   │   └── utils.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── log.py
│   │   ├── torch_utils.py
│   │   ├── upload.py
│   │   └── utils.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── annotator.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── objects.py
│   │   │   ├── splitters
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_sentence_splitter.py
│   │   │   │   ├── blank_sentence_splitter.py
│   │   │   │   ├── spacy_sentence_splitter.py
│   │   │   │   └── window_based_splitter.py
│   │   │   ├── tokenizers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_tokenizer.py
│   │   │   │   └── spacy_tokenizer.py
│   │   │   └── window
│   │   │       ├── __init__.py
│   │   │       └── manager.py
│   │   ├── serve
│   │   │   ├── __init__.py
│   │   │   ├── backend
│   │   │   │   ├── __init__.py
│   │   │   │   ├── fastapi_be.py
│   │   │   │   ├── ray.py
│   │   │   │   └── utils.py
│   │   │   └── frontend
│   │   │       ├── __init__.py
│   │   │       ├── gradio_fe.py
│   │   │       ├── relik_front.py
│   │   │       ├── relik_re_front.py
│   │   │       ├── style.css
│   │   │       └── utils.py
│   │   └── utils.py
│   ├── reader
│   │   ├── __init__.py
│   │   ├── conf
│   │   │   ├── base.yaml
│   │   │   ├── base_nyt.yaml
│   │   │   ├── cie.yaml
│   │   │   ├── config.yaml
│   │   │   ├── data
│   │   │   │   ├── base.yaml
│   │   │   │   ├── cie.yaml
│   │   │   │   ├── large.yaml
│   │   │   │   └── nyt.yaml
│   │   │   ├── large.yaml
│   │   │   ├── large_nyt.yaml
│   │   │   ├── model
│   │   │   │   ├── base.yaml
│   │   │   │   ├── cie.yaml
│   │   │   │   ├── large.yaml
│   │   │   │   ├── nyt.yaml
│   │   │   │   ├── nyt_base.yaml
│   │   │   │   ├── nyt_large.yaml
│   │   │   │   ├── nyt_small.yaml
│   │   │   │   └── small.yaml
│   │   │   ├── small.yaml
│   │   │   ├── small_nyt.yaml
│   │   │   └── training
│   │   │       ├── base.yaml
│   │   │       ├── cie.yaml
│   │   │       ├── large.yaml
│   │   │       └── nyt.yaml
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── patches.py
│   │   │   ├── relik_reader_data.py
│   │   │   ├── relik_reader_data_utils.py
│   │   │   ├── relik_reader_re_data.py
│   │   │   └── relik_reader_sample.py
│   │   ├── lightning_modules
│   │   │   ├── __init__.py
│   │   │   ├── relik_reader_pl_module.py
│   │   │   └── relik_reader_re_pl_module.py
│   │   ├── pytorch_modules
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── hf
│   │   │   │   ├── __init__.py
│   │   │   │   ├── configuration_relik.py
│   │   │   │   └── modeling_relik.py
│   │   │   ├── optim
│   │   │   │   ├── __init__.py
│   │   │   │   ├── adamw_with_warmup.py
│   │   │   │   └── layer_wise_lr_decay.py
│   │   │   ├── span.py
│   │   │   └── triplet.py
│   │   ├── trainer
│   │   │   ├── __init__.py
│   │   │   ├── predict.py
│   │   │   ├── predict_cie.py
│   │   │   ├── predict_re.py
│   │   │   ├── train.py
│   │   │   ├── train_cie.py
│   │   │   └── train_re.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── gerbil.py
│   │       ├── metrics.py
│   │       ├── relation_matching_eval.py
│   │       ├── relik_reader_predictor.py
│   │       ├── save_load_utilities.py
│   │       ├── shuffle_train_callback.py
│   │       ├── special_symbols.py
│   │       └── strong_matching_eval.py
│   ├── retriever
│   │   ├── __init__.py
│   │   ├── callbacks
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── evaluation_callbacks.py
│   │   │   ├── prediction_callbacks.py
│   │   │   ├── training_callbacks.py
│   │   │   └── utils_callbacks.py
│   │   ├── common
│   │   │   ├── __init__.py
│   │   │   ├── model_inputs.py
│   │   │   └── sampler.py
│   │   ├── conf
│   │   │   ├── data
│   │   │   │   ├── aida_dataset.yaml
│   │   │   │   └── blink_dataset.yaml
│   │   │   ├── finetune_iterable_in_batch.yaml
│   │   │   ├── index
│   │   │   │   └── inmemory.yaml
│   │   │   ├── logging
│   │   │   │   └── wandb_logging.yaml
│   │   │   ├── loss
│   │   │   │   ├── nce_loss.yaml
│   │   │   │   └── nll_loss.yaml
│   │   │   ├── model
│   │   │   │   └── golden_retriever.yaml
│   │   │   ├── optimizer
│   │   │   │   ├── adamw.yaml
│   │   │   │   ├── radam.yaml
│   │   │   │   └── radamw.yaml
│   │   │   ├── pretrain_iterable_in_batch.yaml
│   │   │   └── scheduler
│   │   │       ├── linear_scheduler.yaml
│   │   │       ├── linear_scheduler_with_warmup.yaml
│   │   │       └── none.yaml
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── base
│   │   │   │   ├── __init__.py
│   │   │   │   └── datasets.py
│   │   │   ├── datasets.py
│   │   │   ├── labels.py
│   │   │   └── utils.py
│   │   ├── indexers
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── document.py
│   │   │   ├── faissindex.py
│   │   │   └── inmemory.py
│   │   ├── lightning_modules
│   │   │   ├── __init__.py
│   │   │   ├── pl_data_modules.py
│   │   │   └── pl_modules.py
│   │   ├── pytorch_modules
│   │   │   ├── __init__.py
│   │   │   ├── hf.py
│   │   │   ├── loss.py
│   │   │   ├── model.py
│   │   │   ├── optim.py
│   │   │   └── scheduler.py
│   │   └── trainer
│   │       ├── __init__.py
│   │       └── train.py
│   └── version.py
├── requirements.txt
├── scripts
│   ├── build_all.sh
│   ├── build_docker_with_weights.sh
│   ├── data
│   │   ├── blink
│   │   │   └── preprocess_genre_blink.py
│   │   ├── create_windows.py
│   │   ├── nyt
│   │   │   └── preprocess_nyt.py
│   │   └── retriever
│   │       ├── add_candidates.py
│   │       ├── convert_to_dpr.py
│   │       └── create_index.py
│   ├── docker
│   │   ├── gunicorn_conf.py
│   │   ├── pre-start.sh
│   │   ├── start-gunic.sh
│   │   └── start.sh
│   ├── old-scripts
│   │   ├── data
│   │   │   ├── add_candidates.py
│   │   │   ├── add_re_candidates.py
│   │   │   ├── create_windows.py
│   │   │   ├── create_windows_re.py
│   │   │   ├── debug.py
│   │   │   ├── reader
│   │   │   │   └── add_candidates_from_retriever.py
│   │   │   ├── retriever
│   │   │   │   ├── aida_to_dpr.py
│   │   │   │   ├── blink
│   │   │   │   │   ├── create_random_sample_coverage.py
│   │   │   │   │   ├── create_windows.py
│   │   │   │   │   └── sample_from_data_coverate.py
│   │   │   │   ├── create_index.py
│   │   │   │   ├── explore_blink.py
│   │   │   │   ├── save_retriever_from_checkpoint.py
│   │   │   │   └── triplets_to_dpr.py
│   │   │   └── split_aida.py
│   │   ├── evaluate
│   │   │   ├── evaluate_re.py
│   │   │   └── evaluate_re_bio.py
│   │   ├── predict
│   │   │   └── predict_aida.py
│   │   └── retriever
│   │       ├── test_aida.py
│   │       ├── train_aida.py
│   │       └── train_blink.py
│   └── setup.sh
└── setup.py
/.dockerignore:
--------------------------------------------------------------------------------
.git
data
benchmark
resources
outputs
retrievers
experiments
build
dist
models
outputs
pretrained_configs
.idea
.vscode
*.egg-info
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
[flake8]
ignore = E203, E266, E501, W503, F403, F401, E402, C901
max-line-length = 88
max-complexity = 18
select = B,C,E,F,W,T4,B9
--------------------------------------------------------------------------------
/.github/workflows/black.yaml:
--------------------------------------------------------------------------------
name: Check Code Quality

on: pull_request

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: psf/black@stable
        with:
          options: --check .
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with:
          python-version: "3.10"
      - name: Run flake8
        uses: julianwachholz/flake8-action@v2
        with:
          checkName: "Python Lint"
          path: ./relik
          plugins: "pep8-naming==0.13.3 flake8-comprehensions==3.14.0"
          config: .flake8
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
--------------------------------------------------------------------------------
/.github/workflows/python-publish-pypi.yml:
--------------------------------------------------------------------------------
name: Upload Python Package to PyPi

on:
  release:
    types: [published]

jobs:
  publish:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: "3.x"

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build

      - name: Extract version
        run: echo "RELIK_VERSION=`python setup.py --version`" >> $GITHUB_ENV
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
repos:
  - repo: https://github.com/ambv/black
    rev: '23.10.1'
    hooks:
      - id: black
  - repo: https://github.com/pycqa/flake8
    rev: '6.1.0'
    hooks:
      - id: flake8

default_language_version:
  python: python3
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include requirements.txt
--------------------------------------------------------------------------------
/SETUP.cfg:
--------------------------------------------------------------------------------
[metadata]
description-file = README.md

[build]
build-base = /tmp/build
--------------------------------------------------------------------------------
/Sapienza_Babelscape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/Sapienza_Babelscape.png
--------------------------------------------------------------------------------
/constraints.cpu.txt:
--------------------------------------------------------------------------------
--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.3.1
--------------------------------------------------------------------------------
/dockerfiles/fastapi/Dockerfile.cpu:
--------------------------------------------------------------------------------
FROM python:3.11.9-slim-bullseye

ARG DEBIAN_FRONTEND=noninteractive

RUN adduser --disabled-password --gecos '' relik-user
USER relik-user
ENV PATH=${PATH}:/home/relik-user/.local/bin

# Set the working directory
COPY --chown=relik-user:relik-user . /home/relik-user/relik
WORKDIR /home/relik-user/relik

# mount huggingface cache dir
RUN mkdir -p /home/relik-user/.cache/huggingface
# ENV HF_HOME=/home/relik-user/.cache/huggingface
# mount huggingface

RUN pip install --upgrade --no-cache-dir .[serve] -c constraints.cpu.txt \
    && chmod +x scripts/docker/start-gunic.sh

EXPOSE 8000 8001

ENTRYPOINT ["scripts/docker/start-gunic.sh"]
--------------------------------------------------------------------------------
/dockerfiles/fastapi/Dockerfile.cuda:
--------------------------------------------------------------------------------
FROM nvidia/cuda:12.0.0-base-ubuntu22.04

ARG DEBIAN_FRONTEND=noninteractive

RUN adduser --disabled-password --gecos '' relik-user \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
        curl wget python3.11 python3-distutils python3-pip \
    && rm -rf /var/lib/apt/lists/*

USER relik-user
ENV PATH=${PATH}:/home/relik-user/.local/bin

# Set the working directory
COPY --chown=relik-user:relik-user . /home/relik-user/relik
WORKDIR /home/relik-user/relik

RUN mkdir -p /home/relik-user/.cache/huggingface

RUN pip install --upgrade --no-cache-dir .[serve] \
    && chmod +x scripts/docker/start-gunic.sh

EXPOSE 8000 8001

ENTRYPOINT ["scripts/docker/start-gunic.sh"]
--------------------------------------------------------------------------------
/dockerfiles/ray/Dockerfile.cpu:
--------------------------------------------------------------------------------
FROM python:3.10.13-slim-bullseye

ARG DEBIAN_FRONTEND=noninteractive

RUN adduser --disabled-password --gecos '' relik-user
USER relik-user
ENV PATH=${PATH}:/home/relik-user/.local/bin

# Set the working directory
COPY --chown=relik-user:relik-user . /home/relik-user/relik
WORKDIR /home/relik-user/relik

RUN pip install --upgrade --no-cache-dir .[serve,ray] -c constraints.cpu.txt \
    && chmod +x scripts/docker/start.sh

EXPOSE 8000

ENTRYPOINT ["scripts/docker/start.sh"]

# FROM mambaorg/micromamba:bullseye-slim

# ARG DEBIAN_FRONTEND=noninteractive
# ARG MAMBA_DOCKERFILE_ACTIVATE=1

# # Set the working directory
# COPY --chown=mambauser:mambauser . /home/mambauser/relik
# WORKDIR /home/mambauser/relik

# RUN micromamba install -y -n base python==3.10 pytorch==2.1.0 cpuonly -c pytorch -c conda-forge \
#     && which pip \
#     && pip install --upgrade --no-cache-dir .[serve,ray] \
#     && micromamba clean --all --yes \
#     && chmod +x scripts/docker/start.sh

# EXPOSE 8000

# ENTRYPOINT ["scripts/docker/start.sh"]
--------------------------------------------------------------------------------
/dockerfiles/ray/Dockerfile.cuda:
--------------------------------------------------------------------------------
FROM nvidia/cuda:12.0.0-base-ubuntu22.04

ARG DEBIAN_FRONTEND=noninteractive

RUN adduser --disabled-password --gecos '' relik-user \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
        curl wget python3.10 python3-distutils python3-pip \
    && rm -rf /var/lib/apt/lists/*

USER relik-user
ENV PATH=${PATH}:/home/relik-user/.local/bin
# Set the working directory
COPY --chown=relik-user:relik-user . /home/relik-user/relik
WORKDIR /home/relik-user/relik

# ENVS
# ENV PATH="/root/conda/bin:${PATH}"

RUN pip install --upgrade --no-cache-dir .[serve,ray] \
    && chmod +x scripts/docker/start.sh

EXPOSE 8000

ENTRYPOINT ["scripts/docker/start.sh"]
--------------------------------------------------------------------------------
/examples/cie.py:
--------------------------------------------------------------------------------
from relik import Relik

relik = Relik.from_pretrained("relik-ie/relik-cie-small", device="cuda")

text = """When Noah Lyles put his spike into Stade de France’s purple track for his first stride Sunday night of the Paris Olympics 100-meter final, he was already behind. In an event in which margin for error is slimmest, his reaction time to the starting gun was the slowest.

Halfway through, Lyles, 27, of the U.S., was still in seventh place in an eight-man field, trying to chase down Jamaica’s Kishane Thompson, who owned not only this season’s fastest time but also the fastest time in the semifinal round contested earlier Sunday.

By the final steps Lyles had caught up so much to Thompson, American Fred Kerley and South Africa’s Akani Simbine that he did something he rarely practices — dipping his shoulder at the finish.

Even then, Lyles was unconvinced he had won the gold medal he had so boldly predicted, and so badly wanted, for three years. The scoreboard offered no indication of who had won gold, silver or bronze as it processed a photo finish, a sold-out, raucous stadium sharing in the uncertainty.

“I think you got that one, big dog,” Lyles told Thompson.

“I’m not even sure,” Thompson replied. “It was that close.”"""

output = relik(text)

# Entities
print(output.spans)
# Relations
print(output.triplets)
--------------------------------------------------------------------------------
/examples/langchain.py:
--------------------------------------------------------------------------------
# pip install langchain-core langchain-experimental

from langchain_experimental.graph_transformers import RelikGraphTransformer
from langchain_core.documents import Document

relik = RelikGraphTransformer("relik-ie/relik-relation-extraction-small-wikipedia")

text = """When Noah Lyles put his spike into Stade de France’s purple track for his first stride Sunday night of the Paris Olympics 100-meter final, he was already behind. In an event in which margin for error is slimmest, his reaction time to the starting gun was the slowest.

Halfway through, Lyles, 27, of the U.S., was still in seventh place in an eight-man field, trying to chase down Jamaica’s Kishane Thompson, who owned not only this season’s fastest time but also the fastest time in the semifinal round contested earlier Sunday.

By the final steps Lyles had caught up so much to Thompson, American Fred Kerley and South Africa’s Akani Simbine that he did something he rarely practices — dipping his shoulder at the finish.

Even then, Lyles was unconvinced he had won the gold medal he had so boldly predicted, and so badly wanted, for three years. The scoreboard offered no indication of who had won gold, silver or bronze as it processed a photo finish, a sold-out, raucous stadium sharing in the uncertainty.

“I think you got that one, big dog,” Lyles told Thompson.
“I’m not even sure,” Thompson replied. “It was that close.”"""

documents = [Document(page_content=text)]
output = relik.convert_to_graph_documents(documents)
# triplets
print(output)
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[tool.black]
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.hg
  | \.mypy_cache
  | \.tox
  | \.venv
  | _build
  | buck-out
  | build
  | dist
)/
'''
--------------------------------------------------------------------------------
/relik.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik.png
--------------------------------------------------------------------------------
/relik/__init__.py:
--------------------------------------------------------------------------------
from relik.inference.annotator import Relik
from pathlib import Path

VERSION = {}  # type: ignore
with open(Path(__file__).parent / "version.py", "r") as version_file:
    exec(version_file.read(), VERSION)

__version__ = VERSION["VERSION"]
--------------------------------------------------------------------------------
/relik/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/cli/__init__.py
--------------------------------------------------------------------------------
/relik/cli/reader.py:
--------------------------------------------------------------------------------
import sys

import hydra
import typer
from omegaconf import OmegaConf

from relik.cli.utils import resolve_config
from relik.common.log import get_logger, print_relik_text_art
from relik.reader.trainer.train import train as reader_train
from relik.reader.trainer.train_cie import train as reader_train_cie

logger = get_logger(__name__)

app = typer.Typer(no_args_is_help=True, pretty_exceptions_show_locals=False)


@app.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
def train():
    """
    Trains the reader model.

    This function prints the Relik text art, resolves the configuration file path,
    and then calls the `_reader_train` function to train the reader model.

    Args:
        None

    Returns:
        None
    """
    print_relik_text_art()
    config_dir, config_name, overrides = resolve_config("reader")

    @hydra.main(
        config_path=str(config_dir),
        config_name=str(config_name),
        version_base="1.3",
    )
    def _reader_train(conf):
        reader_train(conf)

    # clean sys.argv for hydra
    sys.argv = sys.argv[:1]
    # add the overrides to sys.argv
    sys.argv.extend(overrides)

    _reader_train()


@app.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
def train_cie():
    """
    Trains the reader model for closed information extraction (CIE).

    This function prints the Relik text art, resolves the configuration file path,
    and then calls the `_reader_train_cie` function to train the reader model.
    Args:
        None

    Returns:
        None
    """
    print_relik_text_art()
    config_dir, config_name, overrides = resolve_config("reader")

    @hydra.main(
        config_path=str(config_dir),
        config_name=str(config_name),
        version_base="1.3",
    )
    def _reader_train_cie(conf):
        reader_train_cie(conf)

    # clean sys.argv for hydra
    sys.argv = sys.argv[:1]
    # add the overrides to sys.argv
    sys.argv.extend(overrides)

    _reader_train_cie()
--------------------------------------------------------------------------------
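Both commands above hand control from Typer to Hydra by rewriting `sys.argv`: the entry point is kept, the config path is consumed by `resolve_config`, and only the overrides are passed on. A minimal, self-contained sketch of that hand-off, assuming an empty starting config (the override keys below are invented for illustration):

import sys

import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path=None, config_name=None, version_base="1.3")
def _train(conf: DictConfig) -> None:
    # Hydra has already parsed the overrides it found in sys.argv.
    print(OmegaConf.to_yaml(conf))


if __name__ == "__main__":
    # Mimic what `train` does: keep only the entry point, then append overrides.
    sys.argv = sys.argv[:1] + ["+training.seed=42", "+model.name=my-reader"]
    _train()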
/relik/cli/utils.py:
--------------------------------------------------------------------------------
import sys
from pathlib import Path

import typer
from hydra import compose, initialize, initialize_config_dir
from omegaconf import OmegaConf

from relik.common.log import get_logger

logger = get_logger(__name__)


def resolve_config(type: str | None = None) -> tuple[Path, str, list[str]]:
    """
    Resolve the config file passed on the command line.

    Args:
        type (`str`, optional):
            The component the config belongs to (`reader` or `retriever`), used to
            locate the default config directory when a relative path is given.

    Returns:
        `tuple`:
            The config directory, the config name, and the list of Hydra overrides.
    """
    # first arg is the entry point
    # second arg is the command
    # third arg is the subcommand
    # fourth arg is the config path/name
    # fifth arg is the overrides
    _, _, _, config_path, *overrides = sys.argv
    config_path = Path(config_path)
    # TODO: do checks
    # if not config_path.exists():
    #     raise ValueError(f"File {config_path} does not exist!")
    # get path and name
    config_dir, config_name = config_path.parent, config_path.stem
    # logger.debug(f"config_path: {config_path}")
    # logger.debug(f"config_name: {config_name}")
    # check if config_path is absolute or relative
    # if config_path.is_absolute():
    #     context = initialize_config_dir(config_dir=str(config_path), version_base="1.3")
    # else:
    if not config_dir.is_absolute():
        base_path = Path(__file__).parent.parent
        if type == "reader":
            config_dir = base_path / "reader" / "conf"
        elif type == "retriever":
            config_dir = base_path / "retriever" / "conf"
        else:
            raise ValueError(
                "Please provide the type (`reader` or `retriever`) or provide an absolute path."
            )
    logger.debug(f"config_dir: {config_dir}")
    # logger.debug(f"config_name: {config_name}")

    # print(OmegaConf.load(config_dir / f"{config_name}.yaml"))

    # with initialize_config_dir(config_dir=str(config_dir), version_base="1.3"):
    #     cfg = compose(config_name=config_name, overrides=overrides)

    return config_dir, config_name, overrides


def int_or_str_typer(value: str) -> int | None:
    """
    Converts a string value to an integer or None.

    Args:
        value (str): The string value to be converted.

    Returns:
        int | None: The converted integer value or None if the input is "None".
    """
    if value == "None":
        return None
    return int(value)
--------------------------------------------------------------------------------
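For reference, `resolve_config` expects the argv layout documented in its comments: entry point, command, subcommand, config path, then overrides. A small sketch of how it behaves (the override value is illustrative):

import sys

from relik.cli.utils import resolve_config

# `relik reader train base training.seed=42` arrives as five argv entries;
# everything after the config path is treated as a Hydra override.
sys.argv = ["relik", "reader", "train", "base", "training.seed=42"]

config_dir, config_name, overrides = resolve_config("reader")
print(config_dir)   # .../relik/reader/conf, resolved relative to the package
print(config_name)  # "base"
print(overrides)    # ["training.seed=42"]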
/relik/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/common/__init__.py
--------------------------------------------------------------------------------
/relik/common/log.py:
--------------------------------------------------------------------------------
import logging
import os
import sys
import threading
from logging.config import dictConfig
from typing import Any, Dict, Optional

from art import text2art, tprint
from colorama import Fore, Style, init
from rich import get_console
from termcolor import colored, cprint


_lock = threading.Lock()
_default_handler: Optional[logging.Handler] = None

_default_log_level = logging.WARNING

# fancy logger
_console = get_console()


class ColorfulFormatter(logging.Formatter):
    """
    Formatter to add coloring to log messages by log type
    """

    COLORS = {
        "WARNING": Fore.YELLOW,
        "ERROR": Fore.RED,
        "CRITICAL": Fore.RED + Style.BRIGHT,
        "DEBUG": Fore.CYAN,
        # "INFO": Fore.GREEN,
    }

    def format(self, record):
        record.rank = int(os.getenv("LOCAL_RANK", "0"))
        log_message = super().format(record)
        return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET


DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
    "version": 1,
    "formatters": {
        "simple": {
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
        },
        "colorful": {
            "()": ColorfulFormatter,
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
        },
    },
    "filters": {},
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "simple",
            "filters": [],
            "stream": sys.stdout,
        },
        "color_console": {
            "class": "logging.StreamHandler",
            "formatter": "colorful",
            "filters": [],
            "stream": sys.stdout,
        },
    },
    "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
    "loggers": {
        "relik": {
            "handlers": ["color_console"],
            "level": "DEBUG",
            "propagate": False,
        },
    },
}


def configure_logging(**kwargs):
    """Configure with default logging"""
    init()  # Initialize colorama
    # merge DEFAULT_LOGGING_CONFIG with kwargs
    logger_config = DEFAULT_LOGGING_CONFIG
    if kwargs:
        logger_config.update(kwargs)
    dictConfig(logger_config)


def _get_library_name() -> str:
    return __name__.split(".")[0]


def _get_library_root_logger() -> logging.Logger:
    return logging.getLogger(_get_library_name())


def _configure_library_root_logger() -> None:
    global _default_handler

    with _lock:
        if _default_handler:
            # This library has already configured the library root logger.
            return
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        _default_handler.flush = sys.stderr.flush

        # Apply our default configuration to the library root logger.
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_default_log_level)
        library_root_logger.propagate = False


def _reset_library_root_logger() -> None:
    global _default_handler

    with _lock:
        if not _default_handler:
            return

        library_root_logger = _get_library_root_logger()
        library_root_logger.removeHandler(_default_handler)
        library_root_logger.setLevel(logging.NOTSET)
        _default_handler = None


def set_log_level(level: int, logger: logging.Logger = None) -> None:
    """
    Set the log level.

    Args:
        level (:obj:`int`):
            Logging level.
        logger (:obj:`logging.Logger`):
            Logger to set the log level.
    """
    if not logger:
        _configure_library_root_logger()
        logger = _get_library_root_logger()
    logger.setLevel(level)


def get_logger(
    name: Optional[str] = None,
    level: Optional[int] = None,
    formatter: Optional[logging.Formatter] = None,
    **kwargs,
) -> logging.Logger:
    """
    Return a logger with the specified name.
    """

    configure_logging(**kwargs)

    if name is None:
        name = _get_library_name()

    _configure_library_root_logger()

    if level is not None:
        set_log_level(level)

    if formatter is None:
        formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
        )
    _default_handler.setFormatter(formatter)

    return logging.getLogger(name)


def get_console_logger():
    return _console


def print_relik_text_art(
    text: str = "relik", font: str = "larry3d", color: str = "magenta", **kwargs
):
    # tprint(text, font=font, **kwargs)
    art = text2art(text, font=font, **kwargs)  # .rstrip()
    # art += "\n\n      Retrieve, Read, and Link"
    # art += "\nA fast and lightweight Information Extraction framework"
    cprint(art, color, attrs=["bold"])
--------------------------------------------------------------------------------
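A short sketch of how the module above is meant to be used from library code (the logger name and messages are arbitrary):

import logging

from relik.common.log import get_logger, set_log_level

logger = get_logger(__name__)           # configures the "relik" logger tree once
logger.info("loading model")            # routed through the configured handlers
set_log_level(logging.DEBUG)            # raise verbosity on the library root logger
logger.debug("detailed trace enabled")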
/relik/common/torch_utils.py:
--------------------------------------------------------------------------------
import contextlib
import tempfile

import torch
import transformers as tr


def get_autocast_context(
    device: str | torch.device, precision: str
) -> contextlib.AbstractContextManager:
    # autocast only accepts plain device-type strings like 'cpu' or 'cuda',
    # so we need to convert the model device to that
    device_type_for_autocast = str(device).split(":")[0]

    from relik.retriever.pytorch_modules import PRECISION_MAP

    # autocast doesn't work with CPU and stuff different from bfloat16
    autocast_manager = (
        contextlib.nullcontext()
        if device_type_for_autocast in ["cpu", "mps"]
        and PRECISION_MAP[precision] != torch.bfloat16
        else (
            torch.autocast(
                device_type=device_type_for_autocast,
                dtype=PRECISION_MAP[precision],
            )
        )
    )
    return autocast_manager
--------------------------------------------------------------------------------
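The helper returns either a real `torch.autocast` context or a no-op, so call sites can wrap inference uniformly. A usage sketch (the model and inputs are placeholders, and "fp32" is assumed to be a key of `PRECISION_MAP`):

import torch

from relik.common.torch_utils import get_autocast_context

model = torch.nn.Linear(4, 2)   # placeholder model
inputs = torch.randn(1, 4)

# On CPU with fp32 this resolves to a nullcontext; on CUDA with "fp16"
# it becomes torch.autocast(device_type="cuda", dtype=torch.float16).
with get_autocast_context(device="cpu", precision="fp32"), torch.no_grad():
    outputs = model(inputs)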
/relik/common/upload.py:
--------------------------------------------------------------------------------
import argparse
import json
import logging
import os
import tempfile
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Optional, Union

import huggingface_hub

from relik.common.log import get_logger
from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5

logger = get_logger(__name__, level=logging.DEBUG)


def create_info_file(tmpdir: Path):
    logger.debug("Computing md5 of model.zip")
    md5 = get_md5(tmpdir / "model.zip")
    date = datetime.now().strftime(SAPIENZANLP_DATE_FORMAT)

    logger.debug("Dumping info.json file")
    with (tmpdir / "info.json").open("w") as f:
        json.dump(dict(md5=md5, upload_date=date), f, indent=2)


def zip_run(
    dir_path: Union[str, os.PathLike],
    tmpdir: Union[str, os.PathLike],
    zip_name: str = "model.zip",
) -> Path:
    logger.debug(f"zipping {dir_path} to {tmpdir}")
    # creates a zip version of the provided dir_path
    run_dir = Path(dir_path)
    zip_path = tmpdir / zip_name

    with zipfile.ZipFile(zip_path, "w") as zip_file:
        # fully zip the run directory maintaining its structure
        for file in run_dir.rglob("*.*"):
            if file.is_dir():
                continue

            zip_file.write(file, arcname=file.relative_to(run_dir))

    return zip_path


def get_logged_in_username():
    token = huggingface_hub.HfFolder.get_token()
    if token is None:
        raise ValueError(
            "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
        )
    api = huggingface_hub.HfApi()
    user = api.whoami(token=token)
    return user["name"]


def upload(
    model_dir: Union[str, os.PathLike],
    model_name: str,
    filenames: Optional[list[str]] = None,
    organization: Optional[str] = None,
    repo_name: Optional[str] = None,
    commit: Optional[str] = None,
    archive: bool = False,
):
    token = huggingface_hub.HfFolder.get_token()
    if token is None:
        raise ValueError(
            "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
        )

    repo_id = repo_name or model_name
    if organization is not None:
        repo_id = f"{organization}/{repo_id}"
    with tempfile.TemporaryDirectory() as tmpdir:
        api = huggingface_hub.HfApi()
        repo_url = api.create_repo(
            token=token,
            repo_id=repo_id,
            exist_ok=True,
        )
        repo = huggingface_hub.Repository(
            str(tmpdir), clone_from=repo_url, use_auth_token=token
        )

        tmp_path = Path(tmpdir)
        if archive:
            # otherwise we zip the model_dir
            logger.debug(f"Zipping {model_dir} to {tmp_path}")
            zip_run(model_dir, tmp_path)
            create_info_file(tmp_path)
        else:
            # if the user wants to upload a transformers model, we don't need to zip it
            # we just need to copy the files to the tmpdir
            logger.debug(f"Copying {model_dir} to {tmpdir}")
            # copy only the files that are needed
            if filenames is not None:
                for filename in filenames:
                    os.system(f"cp {model_dir}/{filename} {tmpdir}")
            else:
                os.system(f"cp -r {model_dir}/* {tmpdir}")

        # this method automatically puts large files (>10MB) into git lfs
        repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_dir", help="The directory of the model you want to upload"
    )
    parser.add_argument("model_name", help="The model you want to upload")
    parser.add_argument(
        "--organization",
        help="the name of the organization where you want to upload the model",
    )
    parser.add_argument(
        "--repo_name",
        help="Optional name to use when uploading to the HuggingFace repository",
    )
    parser.add_argument(
        "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
    )
    parser.add_argument(
        "--archive",
        action="store_true",
        help="""
            Whether to compress the model directory before uploading it.
            If True, the model directory will be zipped and the zip file will be uploaded.
            If False, the model directory will be uploaded as is.""",
    )
    return parser.parse_args()


def main():
    upload(**vars(parse_args()))


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
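Invoked as a script, the module maps its argparse flags straight onto the `upload` function; calling it programmatically looks like the following sketch (the paths, model name, and organization are illustrative):

# Equivalent CLI call:
# python -m relik.common.upload ./experiments/my-model my-model \
#     --organization my-org --commit "first release" --archive
from relik.common.upload import upload

upload(
    model_dir="./experiments/my-model",  # illustrative local path
    model_name="my-model",
    organization="my-org",               # optional HF organization
    commit="first release",
    archive=True,                        # zip the directory and add info.json
)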
/relik/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/__init__.py
--------------------------------------------------------------------------------
/relik/inference/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/data/__init__.py
--------------------------------------------------------------------------------
/relik/inference/data/objects.py:
--------------------------------------------------------------------------------
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, NamedTuple, Optional

from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample
from relik.retriever.indexers.document import Document


@dataclass
class Word:
    """
    A word representation that includes text, index in the sentence, POS tag, lemma,
    dependency relation, and similar information.

    # Parameters
    text : `str`
        The text representation.
    i : `int`
        The word offset in the sentence.
    idx : `int`, optional
        The character offset of the word start in the text.
    idx_end : `int`, optional
        The character offset of the word end in the text.
    lemma : `str`, optional
        The lemma of this word.
    pos : `str`, optional
        The coarse-grained part of speech of this word.
    dep : `str`, optional
        The dependency relation for this word.
    head : `int`, optional
        The offset of this word's syntactic head in the sentence.
    """

    text: str
    i: int
    idx: Optional[int] = None
    idx_end: Optional[int] = None
    # preprocessing fields
    lemma: Optional[str] = None
    pos: Optional[str] = None
    dep: Optional[str] = None
    head: Optional[int] = None

    def __str__(self):
        return self.text

    def __repr__(self):
        return self.__str__()


class Span(NamedTuple):
    start: int
    end: int
    label: str
    text: str


class Triplets(NamedTuple):
    subject: Span
    label: str
    object: Span
    confidence: float


class Candidates(NamedTuple):
    span: List[List[List[Document]]]
    triplet: List[List[List[Document]]]


@dataclass
class RelikOutput:
    """
    Represents the output of the Relik model.

    Attributes:
        text (str):
            The original input text.
        tokens (List[str]):
            The list of tokens generated from the input text.
        id (str | int):
            The identifier of the input text.
        spans (List[Span]):
            The list of spans generated for the input text.
        triplets (List[Triplets]):
            The list of triplets generated for the input text.
        candidates (Candidates):
            The candidates for spans and triplets. The candidates are generated by the retriever.
            For each type of candidate, the documents are stored in nested lists. The outer list
            represents the windows, and the inner list represents the documents in that window.
            If only one window is used, the outer list will have only one element.
        windows (Optional[List[RelikReaderSample]]):
            The list of windows used for processing the input text.
    """

    text: str
    tokens: List[str]
    id: str | int
    spans: List[Span]
    triplets: List[Triplets]
    candidates: Candidates = None
    windows: Optional[List[RelikReaderSample]] = None

    # convert to dict
    def to_dict(self):
        self_dict = {
            "text": self.text,
            "tokens": [tok.text for tok in self.tokens],
            "spans": self.spans,
            "triplets": self.triplets,
            "candidates": {
                "span": [
                    [[doc.to_dict() for doc in documents] for documents in window]
                    for window in self.candidates.span
                ],
                "triplet": [
                    [[doc.to_dict() for doc in documents] for documents in window]
                    for window in self.candidates.triplet
                ],
            },
        }
        if self.windows is not None:
            self_dict["windows"] = [window.to_dict() for window in self.windows]
        return self_dict


class AnnotationType(Enum):
    CHAR = "char"
    WORD = "word"


class TaskType(Enum):
    SPAN = "span"
    TRIPLET = "triplet"
    BOTH = "both"
--------------------------------------------------------------------------------
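Since `Span` and `Triplets` are plain named tuples, the annotation objects above can be built by hand; a small sketch with made-up offsets and labels:

from relik.inference.data.objects import Span, Triplets

# Offsets are character positions in the source text; all values are illustrative.
subject = Span(start=5, end=15, label="Noah Lyles", text="Noah Lyles")
obj = Span(start=36, end=51, label="Stade de France", text="Stade de France")
triplet = Triplets(subject=subject, label="venue", object=obj, confidence=0.9)

print(subject.text, triplet.label, obj.text)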
/relik/inference/data/splitters/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/data/splitters/__init__.py
--------------------------------------------------------------------------------
/relik/inference/data/splitters/base_sentence_splitter.py:
--------------------------------------------------------------------------------
from typing import List, Union


class BaseSentenceSplitter:
    """
    A `BaseSentenceSplitter` splits strings into sentences.
    """

    def __call__(self, *args, **kwargs):
        """
        Calls :meth:`split_sentences`.
        """
        return self.split_sentences(*args, **kwargs)

    def split_sentences(
        self, text: str, max_len: int = 0, *args, **kwargs
    ) -> List[str]:
        """
        Splits a `text` :class:`str` paragraph into a list of :class:`str`, where each is a sentence.
        """
        raise NotImplementedError

    def split_sentences_batch(
        self, texts: List[str], *args, **kwargs
    ) -> List[List[str]]:
        """
        Default implementation is to just iterate over the texts and call `split_sentences`.
        """
        return [self.split_sentences(text) for text in texts]

    @staticmethod
    def check_is_batched(
        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
    ):
        """
        Check if input is batched or a single sample.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to check.
            is_split_into_words (:obj:`bool`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
        """
        return bool(
            (not is_split_into_words and isinstance(texts, (list, tuple)))
            or (
                is_split_into_words
                and isinstance(texts, (list, tuple))
                and texts
                and isinstance(texts[0], (list, tuple))
            )
        )
--------------------------------------------------------------------------------
/relik/inference/data/splitters/blank_sentence_splitter.py:
--------------------------------------------------------------------------------
from typing import List, Union


class BlankSentenceSplitter:
    """
    A `BlankSentenceSplitter` splits strings into sentences.
    """

    def __call__(self, *args, **kwargs):
        """
        Calls :meth:`split_sentences`.
        """
        return self.split_sentences(*args, **kwargs)

    def split_sentences(
        self, text: str, max_len: int = 0, *args, **kwargs
    ) -> List[str]:
        """
        Splits a `text` :class:`str` paragraph into a list of :class:`str`, where each is a sentence.
        """
        return [text]

    def split_sentences_batch(
        self, texts: List[str], *args, **kwargs
    ) -> List[List[str]]:
        """
        Default implementation is to just iterate over the texts and call `split_sentences`.
        """
        return [self.split_sentences(text) for text in texts]
--------------------------------------------------------------------------------
/relik/inference/data/splitters/window_based_splitter.py:
--------------------------------------------------------------------------------
from typing import List, Union

from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter


class WindowSentenceSplitter(BaseSentenceSplitter):
    """
    A :obj:`WindowSentenceSplitter` that splits a text into windows of a given size.
    """

    def __init__(self, window_size: int, window_stride: int, *args, **kwargs) -> None:
        super(WindowSentenceSplitter, self).__init__()
        self.window_size = window_size
        self.window_stride = window_stride

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> Union[List[str], List[List[str]]]:
        """
        Splits the input text into fixed-size, possibly overlapping windows of tokens.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to split. It can be a single string, a batch of strings or pre-tokenized strings.

        Returns:
            :obj:`List[List[str]]`: The input doc split into windows.
        """
        return self.split_sentences(texts)

    def split_sentences(self, text: str | List, *args, **kwargs) -> List[List]:
        """
        Splits a `text` into windows.

        Args:
            text (:obj:`str`):
                Text to split.

        Returns:
            :obj:`List[str]`: The input text split into windows.
        """

        if isinstance(text, str):
            text = text.split()
        sentences = []
        # if window_stride is zero, we don't need overlapping windows
        self.window_stride = (
            self.window_stride if self.window_stride != 0 else self.window_size
        )
        for i in range(0, len(text), self.window_stride):
            # if the last stride is smaller than the window size, then we can
            # include more tokens from the previous window.
            if i != 0 and i + self.window_size > len(text):
                overflowing_tokens = i + self.window_size - len(text)
                if overflowing_tokens >= self.window_stride:
                    break
                i -= overflowing_tokens
            involved_token_indices = list(
                range(i, min(i + self.window_size, len(text)))
            )
            window_tokens = [text[j] for j in involved_token_indices]
            sentences.append(window_tokens)
        return sentences
--------------------------------------------------------------------------------
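Concretely, with a window of 5 tokens and a stride of 3, the splitter yields overlapping windows and pulls the last window back so it stays full-length instead of being truncated:

from relik.inference.data.splitters.window_based_splitter import WindowSentenceSplitter

splitter = WindowSentenceSplitter(window_size=5, window_stride=3)
tokens = ["t0", "t1", "t2", "t3", "t4", "t5", "t6", "t7", "t8", "t9"]

for window in splitter(tokens):
    print(window)
# ['t0', 't1', 't2', 't3', 't4']
# ['t3', 't4', 't5', 't6', 't7']
# ['t5', 't6', 't7', 't8', 't9']   <- shifted back instead of truncated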
/relik/inference/data/tokenizers/__init__.py:
--------------------------------------------------------------------------------
SPACY_LANGUAGE_MAPPER = {
    "ca": "ca_core_news_sm",
    "da": "da_core_news_sm",
    "de": "de_core_news_sm",
    "el": "el_core_news_sm",
    "en": "en_core_web_sm",
    "es": "es_core_news_sm",
    "fr": "fr_core_news_sm",
    "it": "it_core_news_sm",
    "ja": "ja_core_news_sm",
    "lt": "lt_core_news_sm",
    "mk": "mk_core_news_sm",
    "nb": "nb_core_news_sm",
    "nl": "nl_core_news_sm",
    "pl": "pl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ro": "ro_core_news_sm",
    "ru": "ru_core_news_sm",
    "xx": "xx_sent_ud_sm",
    "zh": "zh_core_web_sm",
    "ca_core_news_sm": "ca_core_news_sm",
    "ca_core_news_md": "ca_core_news_md",
    "ca_core_news_lg": "ca_core_news_lg",
    "ca_core_news_trf": "ca_core_news_trf",
    "da_core_news_sm": "da_core_news_sm",
    "da_core_news_md": "da_core_news_md",
    "da_core_news_lg": "da_core_news_lg",
    "da_core_news_trf": "da_core_news_trf",
    "de_core_news_sm": "de_core_news_sm",
    "de_core_news_md": "de_core_news_md",
    "de_core_news_lg": "de_core_news_lg",
    "de_dep_news_trf": "de_dep_news_trf",
    "el_core_news_sm": "el_core_news_sm",
    "el_core_news_md": "el_core_news_md",
    "el_core_news_lg": "el_core_news_lg",
    "en_core_web_sm": "en_core_web_sm",
    "en_core_web_md": "en_core_web_md",
    "en_core_web_lg": "en_core_web_lg",
    "en_core_web_trf": "en_core_web_trf",
    "es_core_news_sm": "es_core_news_sm",
    "es_core_news_md": "es_core_news_md",
    "es_core_news_lg": "es_core_news_lg",
    "es_dep_news_trf": "es_dep_news_trf",
    "fr_core_news_sm": "fr_core_news_sm",
    "fr_core_news_md": "fr_core_news_md",
    "fr_core_news_lg": "fr_core_news_lg",
    "fr_dep_news_trf": "fr_dep_news_trf",
    "it_core_news_sm": "it_core_news_sm",
    "it_core_news_md": "it_core_news_md",
    "it_core_news_lg": "it_core_news_lg",
    "ja_core_news_sm": "ja_core_news_sm",
    "ja_core_news_md": "ja_core_news_md",
    "ja_core_news_lg": "ja_core_news_lg",
    "ja_dep_news_trf": "ja_dep_news_trf",
    "lt_core_news_sm": "lt_core_news_sm",
    "lt_core_news_md": "lt_core_news_md",
    "lt_core_news_lg": "lt_core_news_lg",
    "mk_core_news_sm": "mk_core_news_sm",
    "mk_core_news_md": "mk_core_news_md",
    "mk_core_news_lg": "mk_core_news_lg",
    "nb_core_news_sm": "nb_core_news_sm",
    "nb_core_news_md": "nb_core_news_md",
    "nb_core_news_lg": "nb_core_news_lg",
    "nl_core_news_sm": "nl_core_news_sm",
    "nl_core_news_md": "nl_core_news_md",
    "nl_core_news_lg": "nl_core_news_lg",
    "pl_core_news_sm": "pl_core_news_sm",
    "pl_core_news_md": "pl_core_news_md",
    "pl_core_news_lg": "pl_core_news_lg",
    "pt_core_news_sm": "pt_core_news_sm",
    "pt_core_news_md": "pt_core_news_md",
    "pt_core_news_lg": "pt_core_news_lg",
    "ro_core_news_sm": "ro_core_news_sm",
    "ro_core_news_md": "ro_core_news_md",
    "ro_core_news_lg": "ro_core_news_lg",
    "ru_core_news_sm": "ru_core_news_sm",
    "ru_core_news_md": "ru_core_news_md",
    "ru_core_news_lg": "ru_core_news_lg",
    "xx_ent_wiki_sm": "xx_ent_wiki_sm",
    "xx_sent_ud_sm": "xx_sent_ud_sm",
    "zh_core_web_sm": "zh_core_web_sm",
    "zh_core_web_md": "zh_core_web_md",
    "zh_core_web_lg": "zh_core_web_lg",
    "zh_core_web_trf": "zh_core_web_trf",
}

from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
--------------------------------------------------------------------------------
/relik/inference/data/tokenizers/base_tokenizer.py:
--------------------------------------------------------------------------------
from typing import List, Union

from relik.inference.data.objects import Word


class BaseTokenizer:
    """
    A :obj:`Tokenizer` splits strings of text into single words, optionally adds
    pos tags and perform lemmatization.
    """

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs
    ) -> List[List[Word]]:
        """
        Tokenize the input into single words.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to tag. It can be a single string, a batch of string and pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The input text tokenized in single words.
        """
        raise NotImplementedError

    def tokenize(self, text: str) -> List[Word]:
        """
        Implements splitting words into tokens.

        Args:
            text (:obj:`str`):
                Text to tokenize.

        Returns:
            :obj:`List[Word]`: The input text tokenized in single words.

        """
        raise NotImplementedError

    def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
        """
        Implements batch splitting words into tokens.

        Args:
            texts (:obj:`List[str]`):
                Batch of text to tokenize.

        Returns:
            :obj:`List[List[Word]]`: The input batch tokenized in single words.

        """
        return [self.tokenize(text) for text in texts]

    @staticmethod
    def check_is_batched(
        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
    ):
        """
        Check if input is batched or a single sample.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to check.
            is_split_into_words (:obj:`bool`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
        """
        return bool(
            (not is_split_into_words and isinstance(texts, (list, tuple)))
            or (
                is_split_into_words
                and isinstance(texts, (list, tuple))
                and texts
                and isinstance(texts[0], (list, tuple))
            )
        )
--------------------------------------------------------------------------------
/relik/inference/data/window/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/data/window/__init__.py
--------------------------------------------------------------------------------
/relik/inference/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/serve/__init__.py
--------------------------------------------------------------------------------
/relik/inference/serve/backend/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/serve/backend/__init__.py
--------------------------------------------------------------------------------
/relik/inference/serve/backend/utils.py:
--------------------------------------------------------------------------------
import ast
import os
from dataclasses import dataclass


@dataclass
class ServerParameterManager:
    relik_pretrained: str = os.environ.get("RELIK_PRETRAINED", None)
    device: str = os.environ.get("DEVICE", "cpu")
    retriever_device: str | None = os.environ.get("RETRIEVER_DEVICE", None)
    document_index_device: str | None = os.environ.get("INDEX_DEVICE", None)
    reader_device: str | None = os.environ.get("READER_DEVICE", None)
    precision: int | str | None = os.environ.get("PRECISION", "fp32")
    retriever_precision: int | str | None = os.environ.get("RETRIEVER_PRECISION", None)
    document_index_precision: int | str | None = os.environ.get("INDEX_PRECISION", None)
    reader_precision: int | str | None = os.environ.get("READER_PRECISION", None)
    annotation_type: str = os.environ.get("ANNOTATION_TYPE", "char")
    question_encoder: str = os.environ.get("QUESTION_ENCODER", None)
    passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None)
    document_index: str = os.environ.get("DOCUMENT_INDEX", None)
    reader_encoder: str = os.environ.get("READER_ENCODER", None)
    top_k: int = int(os.environ.get("TOP_K", 100))
    use_faiss: bool = os.environ.get("USE_FAISS", False)
    retriever_batch_size: int = int(os.environ.get("RETRIEVER_BATCH_SIZE", 32))
    reader_batch_size: int = int(os.environ.get("READER_BATCH_SIZE", 32))
    window_size: int = int(os.environ.get("WINDOW_SIZE", 32))
    window_stride: int = int(os.environ.get("WINDOW_STRIDE", 16))
    split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False)
    # relik_config_override: dict = ast.literal_eval(
    #     os.environ.get("RELIK_CONFIG_OVERRIDE", None)
    # )


class RayParameterManager:
    def __init__(self) -> None:
        self.num_gpus = int(os.environ.get("NUM_GPUS", 1))
        self.min_replicas = int(os.environ.get("MIN_REPLICAS", 1))
        self.max_replicas = int(os.environ.get("MAX_REPLICAS", 1))
--------------------------------------------------------------------------------
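Every field of `ServerParameterManager` falls back to an environment variable, so a deployment only needs to export what it overrides. A sketch of reading the resulting settings (the env values are examples; note the defaults are captured when the module is first imported):

import os

os.environ["RELIK_PRETRAINED"] = "relik-ie/relik-cie-small"  # example model id
os.environ["DEVICE"] = "cpu"
os.environ["TOP_K"] = "50"

# Import after the environment is set, so the dataclass defaults pick it up.
from relik.inference.serve.backend.utils import ServerParameterManager

params = ServerParameterManager()
print(params.relik_pretrained, params.device, params.top_k)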
/relik/inference/serve/frontend/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/serve/frontend/__init__.py
--------------------------------------------------------------------------------
/relik/inference/serve/frontend/style.css:
--------------------------------------------------------------------------------
/* Sidebar */
.eczjsme11 {
    background-color: #802433;
}

.st-emotion-cache-10oheav h2 {
    color: white;
}

.st-emotion-cache-10oheav li {
    color: white;
}

/* Main */
a:link {
    text-decoration: none;
    color: white;
}

a:visited {
    text-decoration: none;
    color: white;
}

a:hover {
    text-decoration: none;
    color: rgba(255, 255, 255, 0.871);
}

a:active {
    text-decoration: none;
    color: white;
}
--------------------------------------------------------------------------------
/relik/inference/serve/frontend/utils.py:
--------------------------------------------------------------------------------
import base64
import random
from typing import Dict, List, Optional, Union

import spacy
import streamlit as st
from spacy import displacy


def get_html(html: str):
    """Convert HTML so it can be rendered."""
    WRAPPER = """<div style="overflow-x: auto; border: 1px solid #e6e9ef; border-radius: 0.25rem; padding: 1rem; margin-bottom: 2.5rem">{}</div>"""
    # Newlines seem to mess with the rendering
    html = html.replace("\n", " ")
    return WRAPPER.format(html)


def get_svg(svg: str, style: str = "", wrap: bool = True):
    """Convert an SVG to a base64-encoded image."""
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    html = f'<img src="data:image/svg+xml;base64,{b64}" style="{style}"/>'
    return get_html(html) if wrap else html


def visualize_parser(
    doc: Union[spacy.tokens.Doc, List[Dict[str, str]]],
    *,
    title: Optional[str] = None,
    key: Optional[str] = None,
    manual: bool = False,
    displacy_options: Optional[Dict] = None,
) -> None:
    """Visualizer for dependency parses.

    doc (Doc, List): The document to visualize.
    key (str): Key used for the streamlit component for selecting labels.
    title (str): The title displayed at the top of the parser visualization.
    manual (bool): Flag signifying whether the doc argument is a Doc object or a List of Dicts containing parse information.
    displacy_options (Dict): Dictionary of options to be passed to the displacy render method for generating the HTML to be rendered.
        See: https://spacy.io/api/top-level#options-dep
    """
    if displacy_options is None:
        displacy_options = dict()
    if title:
        st.header(title)
    docs = [doc]
    # add selected options to options provided by user
    # `options` from `displacy_options` are overwritten by user provided
    # options from the checkboxes
    for sent in docs:
        html = displacy.render(
            sent, options=displacy_options, style="dep", manual=manual
        )
        # Double newlines seem to mess with the rendering
        html = html.replace("\n\n", "\n")
        st.write(get_svg(html), unsafe_allow_html=True)


def get_random_color(ents):
    colors = {}
    random_colors = generate_pastel_colors(len(ents))
    for ent in ents:
        colors[ent] = random_colors.pop(random.randint(0, len(random_colors) - 1))
    return colors


def floatrange(start, stop, steps):
    if int(steps) == 1:
        return [stop]
    return [
        start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)
    ]


def hsl_to_rgb(h, s, l):
    def hue_2_rgb(v1, v2, v_h):
        while v_h < 0.0:
            v_h += 1.0
        while v_h > 1.0:
            v_h -= 1.0
        if 6 * v_h < 1.0:
            return v1 + (v2 - v1) * 6.0 * v_h
        if 2 * v_h < 1.0:
            return v2
        if 3 * v_h < 2.0:
            return v1 + (v2 - v1) * ((2.0 / 3.0) - v_h) * 6.0
        return v1

    # if not (0 <= s <= 1): raise ValueError, "s (saturation) parameter must be between 0 and 1."
    # if not (0 <= l <= 1): raise ValueError, "l (lightness) parameter must be between 0 and 1."

    r, b, g = (l * 255,) * 3
    if s != 0.0:
        if l < 0.5:
            var_2 = l * (1.0 + s)
        else:
            var_2 = (l + s) - (s * l)
        var_1 = 2.0 * l - var_2
        r = 255 * hue_2_rgb(var_1, var_2, h + (1.0 / 3.0))
        g = 255 * hue_2_rgb(var_1, var_2, h)
        b = 255 * hue_2_rgb(var_1, var_2, h - (1.0 / 3.0))

    return int(round(r)), int(round(g)), int(round(b))


def generate_pastel_colors(n):
    """Return different pastel colours.
108 | 109 | Input: 110 | n (integer) : The number of colors to return 111 | 112 | Output: 113 | A list of colors in HTML notation (e.g. ['#cce0ff', '#ffcccc', '#ccffe0', '#f5ccff', '#f5ffcc']) 114 | 115 | Example: 116 | >>> print(generate_pastel_colors(5)) 117 | ['#cce0ff', '#f5ccff', '#ffcccc', '#f5ffcc', '#ccffe0'] 118 | """ 119 | if n == 0: 120 | return [] 121 | 122 | # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space) 123 | start_hue = 0.0 # 0=red 1/3=0.333=green 2/3=0.666=blue 124 | saturation = 1.0 125 | lightness = 0.9 126 | # We take points around the chromatic circle (hue): 127 | # (Note: we generate n+1 colors, then drop the last one ([:-1]) because 128 | # it equals the first one (hue 0 = hue 1)) 129 | return [ 130 | "#%02x%02x%02x" % hsl_to_rgb(hue, saturation, lightness) 131 | for hue in floatrange(start_hue, start_hue + 1, n + 1) 132 | ][:-1] 133 | -------------------------------------------------------------------------------- /relik/reader/__init__.py: -------------------------------------------------------------------------------- 1 | # from relik.reader.pytorch_modules.base import RelikReaderBase 2 | # from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction 3 | # from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction 4 | -------------------------------------------------------------------------------- /relik/reader/conf/base.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | job: 4 | chdir: True 5 | run: 6 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 7 | 8 | model_name: relik-reader-deberta-base-04062024-lrd08-1x4096-seed42 # -start-end-mask-0.001 # used to name the model in wandb and output dir 9 | project_name: relik-reader # used to name the project in wandb 10 | offline: false # if true, wandb will not be used 11 | 12 | defaults: 13 | - _self_ 14 | - training: base 15 | - model: base 16 | - data: base 17 | -------------------------------------------------------------------------------- /relik/reader/conf/base_nyt.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | run: 4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | 6 | model_name: relik-reader-deberta-base 7 | project_name: relik-reader-nyt # used to name the project in wandb 8 | offline: false # if true, wandb will not be used 9 | 10 | defaults: 11 | - _self_ 12 | - training: nyt 13 | - model: nyt_base 14 | - data: nyt 15 | -------------------------------------------------------------------------------- /relik/reader/conf/cie.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | run: 4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | 6 | model_name: relik-reader-deberta-large # no-proj-special-token # -start-end-mask-0.001 # used to name the model in wandb and output dir 7 | project_name: relik-reader-cie # used to name the project in wandb 8 | offline: false # if true, wandb will not be used 9 | 10 | defaults: 11 | - _self_ 12 | - training: cie 13 | - model: cie 14 | - data: cie 15 |
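A note on the reader configs above: each top-level YAML (base.yaml, base_nyt.yaml, cie.yaml, and the variants that follow) is a Hydra entry point whose `defaults` list pulls one file from each of the `training/`, `model/` and `data/` groups, so a run is assembled by mixing groups rather than by duplicating whole configs. A minimal sketch of composing one of these configs outside of `@hydra.main`, assuming Hydra >= 1.2 and the repository layout shown here (the override strings are illustrative examples, not a documented CLI):

from hydra import compose, initialize
from omegaconf import OmegaConf

# config_path is resolved relative to the calling module, per Hydra's docs,
# so this assumes the snippet lives next to the relik/ package root.
with initialize(config_path="relik/reader/conf", version_base="1.3"):
    # roughly equivalent to `python train.py model=large data=large`
    cfg = compose(config_name="base", overrides=["model=large", "data=large"])
print(OmegaConf.to_yaml(cfg.model))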
-------------------------------------------------------------------------------- /relik/reader/conf/config.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | run: 4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | 6 | model_name: relik-reader-deberta-base-retriever-relik-entity-linking-aida-wikipedia-twin-no-pere # -start-end-mask-0.001 # used to name the model in wandb and output dir 7 | project_name: relik-reader # used to name the project in wandb 8 | offline: false # if true, wandb will not be used 9 | 10 | defaults: 11 | - _self_ 12 | - training: base 13 | - model: base 14 | - data: base 15 | -------------------------------------------------------------------------------- /relik/reader/conf/data/base.yaml: -------------------------------------------------------------------------------- 1 | train_dataset_path: "/root/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/train_windowed_candidates.jsonl" 2 | val_dataset_path: "/root/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/testa_windowed_candidates.jsonl" 3 | 4 | train_dataset: 5 | _target_: "relik.reader.data.relik_reader_data.RelikDataset" 6 | transformer_model: "${model.model.transformer_model}" 7 | materialize_samples: False 8 | shuffle_candidates: 0.5 9 | random_drop_gold_candidates: 0.05 10 | noise_param: 0.0 11 | for_inference: False 12 | tokens_per_batch: 4096 13 | special_symbols: null 14 | 15 | val_dataset: 16 | _target_: "relik.reader.data.relik_reader_data.RelikDataset" 17 | transformer_model: "${model.model.transformer_model}" 18 | materialize_samples: False 19 | shuffle_candidates: False 20 | for_inference: True 21 | special_symbols: null 22 | -------------------------------------------------------------------------------- /relik/reader/conf/data/cie.yaml: -------------------------------------------------------------------------------- 1 | train_dataset_path: 2 | val_dataset_path: 3 | test_dataset_path: 4 | 5 | train_dataset: 6 | _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" 7 | transformer_model: "${model.model.transformer_model}" 8 | materialize_samples: False 9 | shuffle_candidates: False 10 | flip_candidates: 1.0 11 | noise_param: 0.0 12 | for_inference: False 13 | tokens_per_batch: 4096 14 | max_length: 1024 15 | max_triplets: 25 16 | max_spans: 75 17 | min_length: -1 18 | special_symbols: null 19 | special_symbols_re: null 20 | section_size: null 21 | use_nme: True 22 | sorting_fields: 23 | - "predictable_candidates" 24 | val_dataset: 25 | _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" 26 | transformer_model: "${model.model.transformer_model}" 27 | materialize_samples: False 28 | shuffle_candidates: False 29 | flip_candidates: False 30 | for_inference: True 31 | use_nme: True 32 | max_triplets: 25 33 | max_spans: 75 34 | min_length: -1 35 | special_symbols: null 36 | special_symbols_re: null 37 | -------------------------------------------------------------------------------- /relik/reader/conf/data/large.yaml: -------------------------------------------------------------------------------- 1 | train_dataset_path: "/root/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/train_windowed_candidates.jsonl" 2 | val_dataset_path: 
"/root/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/testa_windowed_candidates.jsonl" 3 | 4 | train_dataset: 5 | _target_: "relik.reader.data.relik_reader_data.RelikDataset" 6 | transformer_model: "${model.model.transformer_model}" 7 | materialize_samples: False 8 | shuffle_candidates: 0.5 9 | random_drop_gold_candidates: 0.05 10 | noise_param: 0.0 11 | for_inference: False 12 | tokens_per_batch: 2048 13 | special_symbols: null 14 | 15 | val_dataset: 16 | _target_: "relik.reader.data.relik_reader_data.RelikDataset" 17 | transformer_model: "${model.model.transformer_model}" 18 | materialize_samples: False 19 | shuffle_candidates: False 20 | for_inference: True 21 | special_symbols: null 22 | -------------------------------------------------------------------------------- /relik/reader/conf/data/nyt.yaml: -------------------------------------------------------------------------------- 1 | train_dataset_path: "data/reader/nyt/train.relik.candidates.jsonl" 2 | val_dataset_path: "data/reader/nyt/valid.relik.candidates.jsonl" 3 | test_dataset_path: "data/reader/nyt/test.relik.candidates.jsonl" 4 | 5 | train_dataset: 6 | _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" 7 | transformer_model: "${model.model.transformer_model}" 8 | materialize_samples: False 9 | shuffle_candidates: False 10 | flip_candidates: 1.0 11 | noise_param: 0.0 12 | for_inference: False 13 | tokens_per_batch: 2048 14 | max_length: 1024 15 | max_triplets: 24 16 | max_spans: 17 | min_length: -1 18 | special_symbols: null 19 | special_symbols_re: null 20 | section_size: null 21 | use_nme: False 22 | sorting_fields: 23 | - "predictable_candidates" 24 | val_dataset: 25 | _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset" 26 | transformer_model: "${model.model.transformer_model}" 27 | materialize_samples: False 28 | shuffle_candidates: False 29 | flip_candidates: False 30 | for_inference: True 31 | use_nme: False 32 | max_triplets: 24 33 | max_spans: 34 | min_length: -1 35 | special_symbols: null 36 | special_symbols_re: null 37 | -------------------------------------------------------------------------------- /relik/reader/conf/large.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | run: 4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | 6 | model_name: relik-reader-deberta-large-29052024-lrd09-8x2048-onelayer # no-proj-special-token # -start-end-mask-0.001 # used to name the model in wandb and output dir 7 | project_name: relik-reader # used to name the project in wandb 8 | offline: false # if true, wandb will not be used 9 | 10 | defaults: 11 | - _self_ 12 | - training: large 13 | - model: large 14 | - data: large 15 | -------------------------------------------------------------------------------- /relik/reader/conf/large_nyt.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | run: 4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | 6 | model_name: relik-reader-deberta-large 7 | project_name: relik-reader-nyt # used to name the project in wandb 8 | offline: false # if true, wandb will not be used 9 | 10 | defaults: 11 | - _self_ 12 | - training: nyt 13 | - model: nyt_large 14 | - data: nyt 15 | 
-------------------------------------------------------------------------------- /relik/reader/conf/model/base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | transformer_model: "microsoft/deberta-v3-base" 3 | 4 | optimizer: 5 | lr: 0.0001 6 | warmup_steps: 5000 7 | total_steps: ${training.trainer.max_steps} 8 | total_reset: 1 9 | weight_decay: 0.0 10 | lr_decay: 0.8 11 | no_decay_params: 12 | - "bias" 13 | - LayerNorm.weight 14 | 15 | entities_per_forward: 100 16 | -------------------------------------------------------------------------------- /relik/reader/conf/model/cie.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | transformer_model: "microsoft/deberta-v3-large" 3 | 4 | optimizer: 5 | lr: 1.0e-05 6 | warmup_steps: 5000 7 | total_steps: ${training.trainer.max_steps} 8 | total_reset: 1 9 | weight_decay: 0.01 10 | lr_decay: 0.9 11 | no_decay_params: 12 | - "bias" 13 | - LayerNorm.weight 14 | 15 | entities_per_forward: 75 16 | relations_per_forward: 25 17 | -------------------------------------------------------------------------------- /relik/reader/conf/model/large.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | transformer_model: "microsoft/deberta-v3-large" 3 | 4 | optimizer: 5 | lr: 1.0e-05 6 | warmup_steps: 5000 7 | total_steps: ${training.trainer.max_steps} 8 | total_reset: 1 9 | weight_decay: 0.01 10 | lr_decay: 0.9 11 | no_decay_params: 12 | - "bias" 13 | - LayerNorm.weight 14 | 15 | entities_per_forward: 100 16 | -------------------------------------------------------------------------------- /relik/reader/conf/model/nyt.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | transformer_model: "microsoft/deberta-v3-small" 3 | 4 | optimizer: 5 | lr: 0.00002 6 | warmup_steps: 10000 7 | total_steps: ${training.trainer.max_steps} 8 | weight_decay: 0.01 9 | no_decay_params: 10 | - "bias" 11 | - LayerNorm.weight 12 | 13 | relations_per_forward: 24 14 | entities_per_forward: 15 | -------------------------------------------------------------------------------- /relik/reader/conf/model/nyt_base.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | transformer_model: "microsoft/deberta-v3-base" 3 | 4 | optimizer: 5 | lr: 0.00002 6 | warmup_steps: 75000 7 | total_steps: ${training.trainer.max_steps} 8 | weight_decay: 0.01 9 | no_decay_params: 10 | - "bias" 11 | - LayerNorm.weight 12 | 13 | relations_per_forward: 24 14 | entities_per_forward: 15 | -------------------------------------------------------------------------------- /relik/reader/conf/model/nyt_large.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | transformer_model: "microsoft/deberta-v3-large" 3 | 4 | optimizer: 5 | lr: 0.00002 6 | warmup_steps: 75000 7 | total_steps: ${training.trainer.max_steps} 8 | weight_decay: 0.01 9 | no_decay_params: 10 | - "bias" 11 | - LayerNorm.weight 12 | 13 | relations_per_forward: 24 14 | entities_per_forward: 15 | -------------------------------------------------------------------------------- /relik/reader/conf/model/nyt_small.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | transformer_model: "microsoft/deberta-v3-small" 3 | 4 | optimizer: 5 | lr: 0.00002 6 | warmup_steps: 75000 7 | total_steps: ${training.trainer.max_steps} 8 | 
weight_decay: 0.01 9 | no_decay_params: 10 | - "bias" 11 | - LayerNorm.weight 12 | 13 | relations_per_forward: 24 14 | entities_per_forward: 15 | -------------------------------------------------------------------------------- /relik/reader/conf/model/small.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | transformer_model: "microsoft/deberta-v3-small" 3 | 4 | optimizer: 5 | lr: 0.0001 6 | warmup_steps: 5000 7 | total_steps: ${training.trainer.max_steps} 8 | total_reset: 1 9 | weight_decay: 0.0 10 | lr_decay: 0.8 11 | no_decay_params: 12 | - "bias" 13 | - LayerNorm.weight 14 | 15 | entities_per_forward: 100 16 | -------------------------------------------------------------------------------- /relik/reader/conf/small.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | run: 4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | 6 | model_name: relik-reader-deberta-small-retriever-relik-entity-linking-aida-wikipedia # -start-end-mask-0.001 # used to name the model in wandb and output dir 7 | project_name: relik-reader # used to name the project in wandb 8 | offline: false # if true, wandb will not be used 9 | 10 | defaults: 11 | - _self_ 12 | - training: base 13 | - model: small 14 | - data: base 15 | -------------------------------------------------------------------------------- /relik/reader/conf/small_nyt.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | run: 4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | 6 | model_name: relik-reader-deberta-small 7 | project_name: relik-reader-nyt # used to name the project in wandb 8 | offline: false # if true, wandb will not be used 9 | 10 | defaults: 11 | - _self_ 12 | - training: nyt 13 | - model: nyt_small 14 | - data: nyt 15 | -------------------------------------------------------------------------------- /relik/reader/conf/training/base.yaml: -------------------------------------------------------------------------------- 1 | seed: 94 2 | 3 | trainer: 4 | _target_: lightning.Trainer 5 | devices: 6 | - 0 7 | precision: "16-mixed" 8 | max_steps: 50000 9 | val_check_interval: 1.0 10 | num_sanity_val_steps: 0 11 | limit_val_batches: 1 12 | gradient_clip_val: 1.0 13 | accumulate_grad_batches: 1 14 | -------------------------------------------------------------------------------- /relik/reader/conf/training/cie.yaml: -------------------------------------------------------------------------------- 1 | seed: 15 2 | 3 | trainer: 4 | _target_: lightning.Trainer 5 | devices: 6 | - 0 7 | precision: "16-mixed" 8 | max_steps: 100000 9 | val_check_interval: 1.0 10 | num_sanity_val_steps: 0 11 | limit_val_batches: 1 12 | gradient_clip_val: 1.0 13 | accumulate_grad_batches: 2 14 | -------------------------------------------------------------------------------- /relik/reader/conf/training/large.yaml: -------------------------------------------------------------------------------- 1 | seed: 94 2 | 3 | trainer: 4 | _target_: lightning.Trainer 5 | devices: 6 | - 0 7 | precision: "16-mixed" 8 | max_steps: 50000 9 | val_check_interval: 1.0 10 | num_sanity_val_steps: 0 11 | limit_val_batches: 1 12 | gradient_clip_val: 1.0 13 | accumulate_grad_batches: 4 14 | 
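One detail worth spelling out: the `total_steps: ${training.trainer.max_steps}` entries in the model configs above are OmegaConf interpolations. They are resolved against the final composed config, not the file they are written in, so the optimizer schedule automatically tracks whichever training group is selected. A self-contained sketch with toy values:

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "training": {"trainer": {"max_steps": 50000}},
        "model": {
            "optimizer": {
                "warmup_steps": 5000,
                "total_steps": "${training.trainer.max_steps}",
            }
        },
    }
)
# Interpolations are resolved lazily, at access time:
assert cfg.model.optimizer.total_steps == 50000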
-------------------------------------------------------------------------------- /relik/reader/conf/training/nyt.yaml: -------------------------------------------------------------------------------- 1 | seed: 15 2 | 3 | trainer: 4 | _target_: lightning.Trainer 5 | devices: 6 | - 0 7 | precision: "16-mixed" 8 | max_steps: 100000 9 | val_check_interval: 1.0 10 | num_sanity_val_steps: 0 11 | limit_val_batches: 1 12 | gradient_clip_val: 1.0 13 | accumulate_grad_batches: 2 14 | 15 | ckpt_path: 16 | -------------------------------------------------------------------------------- /relik/reader/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/reader/data/__init__.py -------------------------------------------------------------------------------- /relik/reader/data/patches.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from relik.reader.data.relik_reader_sample import RelikReaderSample 4 | from relik.reader.utils.special_symbols import NME_SYMBOL 5 | 6 | 7 | def merge_patches_predictions(sample) -> None: 8 | sample._d["predicted_window_labels"] = dict() 9 | predicted_window_labels = sample._d["predicted_window_labels"] 10 | 11 | sample._d["span_title_probabilities"] = dict() 12 | span_title_probabilities = sample._d["span_title_probabilities"] 13 | 14 | span2title = dict() 15 | for _, patch_info in sorted(sample.patches.items(), key=lambda x: x[0]): 16 | # selecting span predictions 17 | for predicted_title, predicted_spans in patch_info[ 18 | "predicted_window_labels" 19 | ].items(): 20 | for pred_span in predicted_spans: 21 | pred_span = tuple(pred_span) 22 | curr_title = span2title.get(pred_span) 23 | if curr_title is None or curr_title == NME_SYMBOL: 24 | span2title[pred_span] = predicted_title 25 | # else: 26 | # print("Merging at patch level") 27 | 28 | # selecting span predictions probability 29 | for predicted_span, titles_probabilities in patch_info[ 30 | "span_title_probabilities" 31 | ].items(): 32 | if predicted_span not in span_title_probabilities: 33 | span_title_probabilities[predicted_span] = titles_probabilities 34 | 35 | for span, title in span2title.items(): 36 | if title not in predicted_window_labels: 37 | predicted_window_labels[title] = list() 38 | predicted_window_labels[title].append(span) 39 | 40 | 41 | def remove_duplicate_samples( 42 | samples: List[RelikReaderSample], 43 | ) -> List[RelikReaderSample]: 44 | seen_sample = set() 45 | samples_store = [] 46 | for sample in samples: 47 | sample_id = f"{sample.doc_id}#{sample.sent_id}#{sample.offset}" 48 | if sample_id not in seen_sample: 49 | seen_sample.add(sample_id) 50 | samples_store.append(sample) 51 | return samples_store 52 | -------------------------------------------------------------------------------- /relik/reader/data/relik_reader_data_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def flatten(lsts: List[list]) -> list: 8 | acc_lst = list() 9 | for lst in lsts: 10 | acc_lst.extend(lst) 11 | return acc_lst 12 | 13 | 14 | def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor: 15 | return torch.nn.utils.rnn.pad_sequence( 16 | tensors, batch_first=True, padding_value=padding_value 17 | ) 18 | 19 | 20 | def batchify_matrices(tensors: 
List[torch.Tensor], padding_value: int) -> torch.Tensor: 21 | x = max([t.shape[0] for t in tensors]) 22 | y = max([t.shape[1] for t in tensors]) 23 | out_matrix = torch.zeros((len(tensors), x, y)) 24 | out_matrix += padding_value 25 | for i, tensor in enumerate(tensors): 26 | out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1]] = tensor 27 | return out_matrix 28 | 29 | 30 | def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor: 31 | x = max([t.shape[0] for t in tensors]) 32 | y = max([t.shape[1] for t in tensors]) 33 | rest = tensors[0].shape[2] 34 | out_matrix = torch.zeros((len(tensors), x, y, rest)) 35 | out_matrix += padding_value 36 | for i, tensor in enumerate(tensors): 37 | out_matrix[i][0 : tensor.shape[0], 0 : tensor.shape[1], :] = tensor 38 | return out_matrix 39 | 40 | 41 | def chunks(lst: list, chunk_size: int) -> List[list]: 42 | chunks_acc = list() 43 | for i in range(0, len(lst), chunk_size): 44 | chunks_acc.append(lst[i : i + chunk_size]) 45 | return chunks_acc 46 | 47 | 48 | def add_noise_to_value(value: int, noise_param: float): 49 | noise_value = value * noise_param 50 | noise = np.random.uniform(-noise_value, noise_value) 51 | return max(1, value + noise) 52 | -------------------------------------------------------------------------------- /relik/reader/data/relik_reader_sample.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | from typing import Iterable 4 | 5 | 6 | class NpEncoder(json.JSONEncoder): 7 | def default(self, obj): 8 | if isinstance(obj, np.integer): 9 | return int(obj) 10 | if isinstance(obj, np.floating): 11 | return float(obj) 12 | if isinstance(obj, np.ndarray): 13 | return obj.tolist() 14 | return super(NpEncoder, self).default(obj) 15 | 16 | 17 | class RelikReaderSample: 18 | def __init__(self, **kwargs): 19 | super().__setattr__("_d", {}) 20 | self._d = kwargs 21 | 22 | def __getattribute__(self, item): 23 | return super(RelikReaderSample, self).__getattribute__(item) 24 | 25 | def __getattr__(self, item): 26 | if item.startswith("__") and item.endswith("__"): 27 | # this is likely some python library-specific variable (such as __deepcopy__ for copy) 28 | # better follow standard behavior here 29 | raise AttributeError(item) 30 | elif item in self._d: 31 | return self._d[item] 32 | else: 33 | return None 34 | 35 | def __setattr__(self, key, value): 36 | if key in self._d: 37 | self._d[key] = value 38 | else: 39 | super().__setattr__(key, value) 40 | 41 | def to_jsons(self) -> str: 42 | if "predicted_window_labels" in self._d: 43 | new_obj = { 44 | k: v 45 | for k, v in self._d.items() 46 | if k != "predicted_window_labels" and k != "span_title_probabilities" 47 | } 48 | new_obj["predicted_window_labels"] = [ 49 | [ss, se, pred_title] 50 | for (ss, se), pred_title in self.predicted_window_labels_chars 51 | ] 52 | return json.dumps(new_obj, cls=NpEncoder) 53 | else: 54 | return json.dumps(self._d, cls=NpEncoder) 55 | 56 | def to_dict(self) -> dict: 57 | return self._d 58 | 59 | 60 | def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]: 61 | with open(path) as f: 62 | for line in f: 63 | jsonl_line = json.loads(line.strip()) 64 | relik_reader_sample = RelikReaderSample(**jsonl_line) 65 | yield relik_reader_sample 66 | -------------------------------------------------------------------------------- /relik/reader/lightning_modules/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/reader/lightning_modules/__init__.py -------------------------------------------------------------------------------- /relik/reader/lightning_modules/relik_reader_pl_module.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import lightning 4 | from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler 5 | 6 | from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction 7 | 8 | 9 | class RelikReaderPLModule(lightning.LightningModule): 10 | def __init__( 11 | self, 12 | cfg: dict, 13 | transformer_model: str, 14 | additional_special_symbols: int, 15 | num_layers: Optional[int] = None, 16 | activation: str = "gelu", 17 | linears_hidden_size: Optional[int] = 512, 18 | use_last_k_layers: int = 1, 19 | training: bool = False, 20 | *args: Any, 21 | **kwargs: Any 22 | ): 23 | super().__init__(*args, **kwargs) 24 | self.save_hyperparameters() 25 | self.relik_reader_core_model = RelikReaderForSpanExtraction( 26 | transformer_model, 27 | additional_special_symbols, 28 | num_layers, 29 | activation, 30 | linears_hidden_size, 31 | use_last_k_layers, 32 | training=training, 33 | ) 34 | self.optimizer_factory = None 35 | 36 | def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT: 37 | relik_output = self.relik_reader_core_model(**batch) 38 | self.log("train-loss", relik_output["loss"]) 39 | return relik_output["loss"] 40 | 41 | def validation_step( 42 | self, batch: dict, *args: Any, **kwargs: Any 43 | ) -> Optional[STEP_OUTPUT]: 44 | return 45 | 46 | def set_optimizer_factory(self, optimizer_factory) -> None: 47 | self.optimizer_factory = optimizer_factory 48 | 49 | def configure_optimizers(self) -> OptimizerLRScheduler: 50 | return self.optimizer_factory(self.relik_reader_core_model) 51 | -------------------------------------------------------------------------------- /relik/reader/lightning_modules/relik_reader_re_pl_module.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import lightning 4 | from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler 5 | 6 | from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction 7 | 8 | 9 | class RelikReaderREPLModule(lightning.LightningModule): 10 | def __init__( 11 | self, 12 | cfg: dict, 13 | transformer_model: str, 14 | additional_special_symbols: int, 15 | additional_special_symbols_types: Optional[int] = 0, 16 | entity_type_loss: bool = None, 17 | add_entity_embedding: bool = None, 18 | num_layers: Optional[int] = None, 19 | activation: str = "gelu", 20 | linears_hidden_size: Optional[int] = 512, 21 | use_last_k_layers: int = 1, 22 | training: bool = False, 23 | default_reader_class: str = "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderREModel", 24 | *args: Any, 25 | **kwargs: Any 26 | ): 27 | super().__init__(*args, **kwargs) 28 | self.save_hyperparameters() 29 | 30 | self.relik_reader_re_model = RelikReaderForTripletExtraction( 31 | transformer_model, 32 | additional_special_symbols, 33 | additional_special_symbols_types, 34 | entity_type_loss, 35 | add_entity_embedding, 36 | num_layers, 37 | activation, 38 | linears_hidden_size, 39 | use_last_k_layers, 40 | training=training, 41 | default_reader_class=default_reader_class, 42 | **kwargs, 43 | ) 44 | self.optimizer_factory = None 45 | 46 
| def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT: 47 | relik_output = self.relik_reader_re_model(**batch) 48 | self.log("train-loss", relik_output["loss"]) 49 | self.log("train-start_loss", relik_output["ned_start_loss"]) 50 | self.log("train-end_loss", relik_output["ned_end_loss"]) 51 | self.log("train-relation_loss", relik_output["re_loss"]) 52 | if "ned_type_loss" in relik_output: 53 | self.log("train-ned_type_loss", relik_output["ned_type_loss"]) 54 | return relik_output["loss"] 55 | 56 | def validation_step( 57 | self, batch: dict, *args: Any, **kwargs: Any 58 | ) -> Optional[STEP_OUTPUT]: 59 | return 60 | 61 | def set_optimizer_factory(self, optimizer_factory) -> None: 62 | self.optimizer_factory = optimizer_factory 63 | 64 | def configure_optimizers(self) -> OptimizerLRScheduler: 65 | return self.optimizer_factory(self.relik_reader_re_model) 66 | -------------------------------------------------------------------------------- /relik/reader/pytorch_modules/__init__.py: -------------------------------------------------------------------------------- 1 | # from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSpanModel 2 | # from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction 3 | # from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction 4 | 5 | 6 | RELIK_READER_CLASS_MAP = { 7 | "RelikReaderSpanModel": "relik.reader.pytorch_modules.span.RelikReaderForSpanExtraction", 8 | "RelikReaderREModel": "relik.reader.pytorch_modules.triplet.RelikReaderForTripletExtraction", 9 | } 10 | -------------------------------------------------------------------------------- /relik/reader/pytorch_modules/hf/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_relik import RelikReaderConfig 2 | from .modeling_relik import RelikReaderREModel 3 | -------------------------------------------------------------------------------- /relik/reader/pytorch_modules/hf/configuration_relik.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from transformers import AutoConfig 4 | from transformers.configuration_utils import PretrainedConfig 5 | 6 | 7 | class RelikReaderConfig(PretrainedConfig): 8 | model_type = "relik-reader" 9 | 10 | def __init__( 11 | self, 12 | transformer_model: str = "microsoft/deberta-v3-base", 13 | additional_special_symbols: int = 101, 14 | additional_special_symbols_types: Optional[int] = 0, 15 | num_layers: Optional[int] = None, 16 | activation: str = "gelu", 17 | linears_hidden_size: Optional[int] = 512, 18 | use_last_k_layers: int = 1, 19 | entity_type_loss: bool = False, 20 | add_entity_embedding: bool = None, 21 | binary_end_logits: bool = False, 22 | training: bool = False, 23 | default_reader_class: Optional[str] = None, 24 | threshold: Optional[float] = 0.5, 25 | **kwargs 26 | ) -> None: 27 | # TODO: add name_or_path to kwargs 28 | self.transformer_model = transformer_model 29 | self.additional_special_symbols = additional_special_symbols 30 | self.additional_special_symbols_types = additional_special_symbols_types 31 | self.num_layers = num_layers 32 | self.activation = activation 33 | self.linears_hidden_size = linears_hidden_size 34 | self.use_last_k_layers = use_last_k_layers 35 | self.entity_type_loss = entity_type_loss 36 | self.add_entity_embedding = ( 37 | True 38 | if add_entity_embedding is None and entity_type_loss 39 | else add_entity_embedding 40 
| ) 41 | self.threshold = threshold 42 | self.binary_end_logits = binary_end_logits 43 | self.training = training 44 | self.default_reader_class = default_reader_class 45 | super().__init__(**kwargs) 46 | -------------------------------------------------------------------------------- /relik/reader/pytorch_modules/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from relik.reader.pytorch_modules.optim.adamw_with_warmup import ( 2 | AdamWWithWarmupOptimizer, 3 | ) 4 | from relik.reader.pytorch_modules.optim.layer_wise_lr_decay import ( 5 | LayerWiseLRDecayOptimizer, 6 | ) 7 | -------------------------------------------------------------------------------- /relik/reader/pytorch_modules/optim/adamw_with_warmup.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import transformers 5 | from torch.optim import AdamW 6 | 7 | 8 | class AdamWWithWarmupOptimizer: 9 | def __init__( 10 | self, 11 | lr: float | List[float], 12 | warmup_steps: int, 13 | total_steps: int, 14 | weight_decay: float, 15 | no_decay_params: List[str], 16 | other_lr_params: List[str] = None, 17 | ): 18 | self.lr = lr[0] if isinstance(lr, list) else lr 19 | self.lr2 = lr[1] if isinstance(lr, list) else lr 20 | self.warmup_steps = warmup_steps 21 | self.total_steps = total_steps 22 | self.weight_decay = weight_decay 23 | self.no_decay_params = no_decay_params 24 | self.other_lr_params = other_lr_params or [] # Ensure it's a list 25 | 26 | def group_params(self, module: torch.nn.Module) -> list: 27 | decay_params = set() 28 | no_decay_params = set() 29 | other_lr_params = set() 30 | # Populate parameter sets 31 | for n, p in module.named_parameters(): 32 | if any(nd in n for nd in self.no_decay_params): 33 | no_decay_params.add(p) 34 | elif any(olr in n for olr in self.other_lr_params): 35 | other_lr_params.add(p) 36 | else: 37 | decay_params.add(p) 38 | # Group parameters 39 | optimizer_grouped_parameters = [] 40 | if decay_params: 41 | optimizer_grouped_parameters.append( 42 | {"params": list(decay_params), "weight_decay": self.weight_decay} 43 | ) 44 | if no_decay_params: 45 | optimizer_grouped_parameters.append( 46 | {"params": list(no_decay_params), "weight_decay": 0.0} 47 | ) 48 | if other_lr_params: 49 | optimizer_grouped_parameters.append( 50 | { 51 | "params": list(other_lr_params), 52 | "lr": self.lr2, 53 | "weight_decay": self.weight_decay, 54 | } 55 | ) 56 | 57 | return optimizer_grouped_parameters 58 | 59 | def __call__(self, module: torch.nn.Module): 60 | optimizer_grouped_parameters = self.group_params(module) 61 | optimizer = AdamW( 62 | optimizer_grouped_parameters, lr=self.lr, weight_decay=self.weight_decay 63 | ) 64 | scheduler = transformers.get_linear_schedule_with_warmup( 65 | optimizer, self.warmup_steps, self.total_steps 66 | ) 67 | return { 68 | "optimizer": optimizer, 69 | "lr_scheduler": { 70 | "scheduler": scheduler, 71 | "interval": "step", 72 | "frequency": 1, 73 | }, 74 | } 75 | -------------------------------------------------------------------------------- /relik/reader/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/reader/trainer/__init__.py -------------------------------------------------------------------------------- /relik/reader/trainer/predict.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | from pprint import pprint 3 | from typing import Optional 4 | 5 | from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction 6 | from relik.reader.utils.strong_matching_eval import StrongMatching 7 | from relik.reader.data.relik_reader_sample import load_relik_reader_samples 8 | 9 | 10 | def predict( 11 | model_path: str, 12 | dataset_path: str, 13 | token_batch_size: int, 14 | is_eval: bool, 15 | output_path: Optional[str], 16 | ) -> None: 17 | relik_reader = RelikReaderForSpanExtraction( 18 | model_path, dataset_kwargs={"use_nme": True}, device="cuda" 19 | ) 20 | samples = list(load_relik_reader_samples(dataset_path)) 21 | predicted_samples = relik_reader.read( 22 | samples=samples, token_batch_size=token_batch_size, progress_bar=True 23 | ) 24 | if is_eval: 25 | eval_dict = StrongMatching()(predicted_samples) 26 | pprint(eval_dict) 27 | if output_path is not None: 28 | with open(output_path, "w") as f: 29 | for sample in predicted_samples: 30 | f.write(sample.to_jsons() + "\n") 31 | 32 | 33 | def parse_arg() -> argparse.Namespace: 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--model-path", 37 | required=True, 38 | ) 39 | parser.add_argument("--dataset-path", "-i", required=True) 40 | parser.add_argument("--is-eval", action="store_true") 41 | parser.add_argument( 42 | "--output-path", 43 | "-o", 44 | ) 45 | parser.add_argument("--token-batch-size", type=int, default=4096) 46 | return parser.parse_args() 47 | 48 | 49 | def main(): 50 | args = parse_arg() 51 | predict( 52 | args.model_path, 53 | args.dataset_path, 54 | token_batch_size=args.token_batch_size, 55 | is_eval=args.is_eval, 56 | output_path=args.output_path, 57 | ) 58 | 59 | 60 | if __name__ == "__main__": 61 | main() 62 | -------------------------------------------------------------------------------- /relik/reader/trainer/predict_cie.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from relik.reader.data.relik_reader_sample import load_relik_reader_samples 3 | from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction 4 | from relik.reader.utils.relation_matching_eval import StrongMatching 5 | from relik.inference.data.objects import AnnotationType 6 | 7 | import torch 8 | 9 | import numpy as np 10 | from sklearn.metrics import precision_recall_curve 11 | 12 | 13 | def find_optimal_threshold(scores, labels): 14 | # Calculate precision-recall pairs for various threshold values 15 | precision, recall, thresholds = precision_recall_curve(labels, scores) 16 | # Add the end point for thresholds, which is the maximum score + 1 to ensure completeness 17 | thresholds = np.append(thresholds, thresholds[-1] + 1) 18 | # Calculate F1 scores from precision and recall for each threshold 19 | f1_scores = 2 * (precision * recall) / (precision + recall) 20 | # Handle the case where precision + recall equals zero (to avoid division by zero) 21 | f1_scores = np.nan_to_num(f1_scores) 22 | # Find the index of the maximum F1 score 23 | max_index = np.argmax(f1_scores) 24 | # Find the threshold and F1 score corresponding to the maximum F1 score 25 | optimal_threshold = thresholds[max_index] 26 | best_f1 = f1_scores[max_index] 27 | 28 | return optimal_threshold, best_f1 29 | 30 | 31 | def eval( 32 | model_path, 33 | data_path, 34 | is_eval, 35 | output_path=None, 36 | compute_threshold=False, 37 | save_threshold=False, 38 | ): 39 |
device = "cuda" if torch.cuda.is_available() else "cpu" 40 | device = "mps" if torch.backends.mps.is_available() else device 41 | print(f"Device: {device}") 42 | reader = RelikReaderForTripletExtraction(model_path, training=False, device=device) 43 | samples = list(load_relik_reader_samples(data_path)) 44 | optimal_threshold = None 45 | if compute_threshold: 46 | predicted_samples = reader.read( 47 | samples=samples, 48 | progress_bar=True, 49 | annotation_type=AnnotationType.WORD, 50 | return_threshold_utils=True, 51 | ) 52 | re_probabilities, re_labels = [], [] 53 | for sample in predicted_samples: 54 | re_probabilities.extend(sample.re_probabilities.flatten()) 55 | re_labels.extend(sample.re_labels.flatten()) 56 | optimal_threshold, best_f1 = find_optimal_threshold(re_probabilities, re_labels) 57 | print(f"Optimal threshold: {optimal_threshold}") 58 | print(f"Best F1: {best_f1}") 59 | # set the threshold to the optimal threshold 60 | samples = list(load_relik_reader_samples(data_path)) 61 | if save_threshold: 62 | reader.relik_reader_model.config.threshold = optimal_threshold 63 | reader.relik_reader_model.save_pretrained(model_path) 64 | 65 | predicted_samples = reader.read( 66 | samples=samples, 67 | progress_bar=True, 68 | annotation_type=AnnotationType.WORD, 69 | relation_threshold=optimal_threshold, 70 | ) 71 | 72 | if is_eval: 73 | strong_matching_metric = StrongMatching() 74 | predicted_samples = list(predicted_samples) 75 | for k, v in strong_matching_metric(predicted_samples).items(): 76 | print(f"test_{k}", v) 77 | if output_path is not None: 78 | with open(output_path, "w") as f: 79 | for sample in predicted_samples: 80 | f.write(sample.to_jsons() + "\n") 81 | 82 | 83 | def main(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument( 86 | "--model_path", 87 | type=str, 88 | required=True, 89 | ) 90 | parser.add_argument( 91 | "--data_path", 92 | type=str, 93 | required=True, 94 | ) 95 | parser.add_argument("--is-eval", action="store_true") 96 | parser.add_argument("--output_path", type=str, default=None) 97 | parser.add_argument("--compute-threshold", action="store_true") 98 | parser.add_argument("--save-threshold", action="store_true") 99 | args = parser.parse_args() 100 | eval( 101 | args.model_path, 102 | args.data_path, 103 | args.is_eval, 104 | args.output_path, 105 | args.compute_threshold, 106 | args.save_threshold, 107 | ) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /relik/reader/trainer/predict_re.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from relik.reader.data.relik_reader_sample import load_relik_reader_samples 3 | from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction 4 | from relik.reader.utils.relation_matching_eval import StrongMatching 5 | from relik.inference.data.objects import AnnotationType 6 | 7 | import torch 8 | 9 | import numpy as np 10 | from sklearn.metrics import precision_recall_curve 11 | 12 | 13 | def find_optimal_threshold(scores, labels): 14 | # Calculate precision-recall pairs for various threshold values 15 | precision, recall, thresholds = precision_recall_curve(labels, scores) 16 | # Add the end point for thresholds, which is the maximum score + 1 to ensure completeness 17 | thresholds = np.append(thresholds, thresholds[-1] + 1) 18 | # Calculate F1 scores from precision and recall for each threshold 19 | f1_scores = 2 * (precision * 
recall) / (precision + recall) 20 | # Handle the case where precision + recall equals zero (to avoid division by zero) 21 | f1_scores = np.nan_to_num(f1_scores) 22 | # Find the index of the maximum F1 score 23 | max_index = np.argmax(f1_scores) 24 | # Find the threshold and F1 score corresponding to the maximum F1 score 25 | optimal_threshold = thresholds[max_index] 26 | best_f1 = f1_scores[max_index] 27 | 28 | return optimal_threshold, best_f1 29 | 30 | 31 | def eval( 32 | model_path, 33 | data_path, 34 | is_eval, 35 | output_path=None, 36 | compute_threshold=False, 37 | save_threshold=False, 38 | ): 39 | device = "cuda" if torch.cuda.is_available() else "cpu" 40 | device = "mps" if torch.backends.mps.is_available() else device 41 | print(f"Device: {device}") 42 | reader = RelikReaderForTripletExtraction(model_path, training=False, device=device) 43 | samples = list(load_relik_reader_samples(data_path)) 44 | optimal_threshold = None 45 | if compute_threshold: 46 | predicted_samples = reader.read( 47 | samples=samples, 48 | progress_bar=True, 49 | annotation_type=AnnotationType.WORD, 50 | return_threshold_utils=True, 51 | ) 52 | re_probabilities, re_labels = [], [] 53 | for sample in predicted_samples: 54 | re_probabilities.extend(sample.re_probabilities.flatten()) 55 | re_labels.extend(sample.re_labels.flatten()) 56 | optimal_threshold, best_f1 = find_optimal_threshold(re_probabilities, re_labels) 57 | print(f"Optimal threshold: {optimal_threshold}") 58 | print(f"Best F1: {best_f1}") 59 | # set the threshold to the optimal threshold 60 | samples = list(load_relik_reader_samples(data_path)) 61 | if save_threshold: 62 | reader.relik_reader_model.config.threshold = optimal_threshold 63 | reader.relik_reader_model.save_pretrained(model_path) 64 | 65 | predicted_samples = reader.read( 66 | samples=samples, 67 | progress_bar=True, 68 | annotation_type=AnnotationType.WORD, 69 | relation_threshold=optimal_threshold, 70 | ) 71 | 72 | if is_eval: 73 | strong_matching_metric = StrongMatching() 74 | predicted_samples = list(predicted_samples) 75 | for k, v in strong_matching_metric(predicted_samples).items(): 76 | print(f"test_{k}", v) 77 | if output_path is not None: 78 | with open(output_path, "w") as f: 79 | for sample in predicted_samples: 80 | f.write(sample.to_jsons() + "\n") 81 | 82 | 83 | def main(): 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument( 86 | "--model_path", 87 | type=str, 88 | required=True, 89 | ) 90 | parser.add_argument( 91 | "--data_path", 92 | type=str, 93 | required=True, 94 | ) 95 | parser.add_argument("--is-eval", action="store_true") 96 | parser.add_argument("--output_path", type=str, default=None) 97 | parser.add_argument("--compute-threshold", action="store_true") 98 | parser.add_argument("--save-threshold", action="store_true") 99 | args = parser.parse_args() 100 | eval( 101 | args.model_path, 102 | args.data_path, 103 | args.is_eval, 104 | args.output_path, 105 | args.compute_threshold, 106 | args.save_threshold, 107 | ) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /relik/reader/trainer/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from pprint import pprint 4 | import hydra 5 | import lightning 6 | from hydra.utils import to_absolute_path, get_original_cwd 7 | from lightning import Trainer 8 | from lightning.pytorch.callbacks import LearningRateMonitor, 
ModelCheckpoint 9 | from lightning.pytorch.loggers.wandb import WandbLogger 10 | from omegaconf import DictConfig, OmegaConf, open_dict 11 | import omegaconf 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | from relik.reader.data.relik_reader_data import RelikDataset 16 | from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule 17 | from relik.reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer 18 | from relik.reader.utils.special_symbols import get_special_symbols 19 | from relik.reader.utils.strong_matching_eval import ELStrongMatchingCallback 20 | from relik.reader.utils.shuffle_train_callback import ShuffleTrainCallback 21 | 22 | 23 | def train(cfg: DictConfig) -> None: 24 | lightning.seed_everything(cfg.training.seed) 25 | # check if deterministic algorithms are available 26 | if "deterministic" in cfg and cfg.deterministic: 27 | torch.use_deterministic_algorithms(True, warn_only=True) 28 | 29 | # log the configuration 30 | pprint(OmegaConf.to_container(cfg, resolve=True)) 31 | 32 | special_symbols = get_special_symbols(cfg.model.entities_per_forward) 33 | 34 | # model declaration 35 | model = RelikReaderPLModule( 36 | cfg=OmegaConf.to_container(cfg), 37 | transformer_model=cfg.model.model.transformer_model, 38 | additional_special_symbols=len(special_symbols), 39 | training=True, 40 | ) 41 | 42 | # optimizer declaration 43 | opt_conf = cfg.model.optimizer 44 | electra_optimizer_factory = LayerWiseLRDecayOptimizer( 45 | lr=opt_conf.lr, 46 | warmup_steps=opt_conf.warmup_steps, 47 | total_steps=opt_conf.total_steps, 48 | total_reset=opt_conf.total_reset, 49 | no_decay_params=opt_conf.no_decay_params, 50 | weight_decay=opt_conf.weight_decay, 51 | lr_decay=opt_conf.lr_decay, 52 | ) 53 | 54 | model.set_optimizer_factory(electra_optimizer_factory) 55 | 56 | # datasets declaration 57 | train_dataset: RelikDataset = hydra.utils.instantiate( 58 | cfg.data.train_dataset, 59 | dataset_path=to_absolute_path(cfg.data.train_dataset_path), 60 | special_symbols=special_symbols, 61 | ) 62 | 63 | # update of validation dataset config with special_symbols since they 64 | # are required even from the EvaluationCallback dataset_config 65 | with open_dict(cfg): 66 | cfg.data.val_dataset.special_symbols = special_symbols 67 | 68 | val_dataset: RelikDataset = hydra.utils.instantiate( 69 | cfg.data.val_dataset, 70 | dataset_path=to_absolute_path(cfg.data.val_dataset_path), 71 | ) 72 | 73 | # callbacks declaration 74 | callbacks = [ 75 | ELStrongMatchingCallback( 76 | to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset 77 | ), 78 | ModelCheckpoint( 79 | "model", 80 | filename="{epoch}-{val_core_f1:.2f}", 81 | monitor="val_core_f1", 82 | mode="max", 83 | ), 84 | LearningRateMonitor(), 85 | ] 86 | if ( 87 | cfg.data.train_dataset.section_size is None 88 | ): # If section_size is None, we shuffle the dataset.
This greatly speeds things up for bigger datasets, but be careful: the data file itself is shuffled in place at the end of each epoch 89 | callbacks.append(ShuffleTrainCallback(data_path=to_absolute_path(cfg.data.train_dataset_path))) 90 | 91 | wandb_logger = WandbLogger( 92 | cfg.model_name, project=cfg.project_name, offline=cfg.offline 93 | ) 94 | 95 | # trainer declaration 96 | trainer: Trainer = hydra.utils.instantiate( 97 | cfg.training.trainer, 98 | callbacks=callbacks, 99 | logger=wandb_logger, 100 | ) 101 | 102 | # model.relik_reader_core_model._tokenizer = train_dataset.tokenizer 103 | 104 | # Trainer fit 105 | trainer.fit( 106 | model=model, 107 | train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0), 108 | val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0), 109 | ) 110 | 111 | # if cfg.training.save_model_path: 112 | experiment_path = Path(wandb_logger.experiment.dir) 113 | model = RelikReaderPLModule.load_from_checkpoint( 114 | trainer.checkpoint_callback.best_model_path 115 | ) 116 | model.relik_reader_core_model._tokenizer = train_dataset.tokenizer 117 | model.relik_reader_core_model.save_pretrained(experiment_path) 118 | 119 | 120 | @hydra.main(config_path="../conf", config_name="config", version_base="1.3") 121 | def main(conf: omegaconf.DictConfig): 122 | train(conf) 123 | 124 | 125 | if __name__ == "__main__": 126 | main() 127 | -------------------------------------------------------------------------------- /relik/reader/trainer/train_re.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from pathlib import Path 3 | import lightning 4 | from hydra.utils import to_absolute_path 5 | from lightning import Trainer 6 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint 7 | from lightning.pytorch.loggers.wandb import WandbLogger 8 | from omegaconf import DictConfig, OmegaConf, open_dict 9 | from torch.utils.data import DataLoader 10 | 11 | from relik.reader.data.relik_reader_re_data import RelikREDataset 12 | from relik.reader.lightning_modules.relik_reader_re_pl_module import ( 13 | RelikReaderREPLModule, 14 | ) 15 | from relik.reader.pytorch_modules.optim import ( 16 | AdamWWithWarmupOptimizer, 17 | LayerWiseLRDecayOptimizer, 18 | ) 19 | from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback 20 | from relik.reader.utils.special_symbols import get_special_symbols_re 21 | from relik.reader.utils.shuffle_train_callback import ShuffleTrainCallback 22 | 23 | 24 | @hydra.main(config_path="../conf", config_name="config") 25 | def train(cfg: DictConfig) -> None: 26 | lightning.seed_everything(cfg.training.seed) 27 | 28 | special_symbols = get_special_symbols_re(cfg.model.relations_per_forward) 29 | 30 | # datasets declaration 31 | train_dataset: RelikREDataset = hydra.utils.instantiate( 32 | cfg.data.train_dataset, 33 | dataset_path=to_absolute_path(cfg.data.train_dataset_path), 34 | special_symbols_re=special_symbols, 35 | ) 36 | 37 | # update of validation dataset config with special_symbols since they 38 | # are required even from the EvaluationCallback dataset_config 39 | with open_dict(cfg): 40 | cfg.data.val_dataset.special_symbols_re = special_symbols 41 | 42 | val_dataset: RelikREDataset = hydra.utils.instantiate( 43 | cfg.data.val_dataset, 44 | dataset_path=to_absolute_path(cfg.data.val_dataset_path), 45 | ) 46 | 47 | if val_dataset.materialize_samples: 48 | list(val_dataset.dataset_iterator_func()) 49 | # model declaration 50 | model = RelikReaderREPLModule( 51 |
cfg=OmegaConf.to_container(cfg), 52 | transformer_model=cfg.model.model.transformer_model, 53 | additional_special_symbols=len(special_symbols), 54 | training=True, 55 | ) 56 | model.relik_reader_re_model._tokenizer = train_dataset.tokenizer 57 | # optimizer declaration 58 | opt_conf = cfg.model.optimizer 59 | 60 | if "total_reset" not in opt_conf: 61 | optimizer_factory = AdamWWithWarmupOptimizer( 62 | lr=opt_conf.lr, 63 | warmup_steps=opt_conf.warmup_steps, 64 | total_steps=opt_conf.total_steps, 65 | no_decay_params=opt_conf.no_decay_params, 66 | weight_decay=opt_conf.weight_decay, 67 | ) 68 | else: 69 | optimizer_factory = LayerWiseLRDecayOptimizer( 70 | lr=opt_conf.lr, 71 | warmup_steps=opt_conf.warmup_steps, 72 | total_steps=opt_conf.total_steps, 73 | total_reset=opt_conf.total_reset, 74 | no_decay_params=opt_conf.no_decay_params, 75 | weight_decay=opt_conf.weight_decay, 76 | lr_decay=opt_conf.lr_decay, 77 | ) 78 | 79 | model.set_optimizer_factory(optimizer_factory) 80 | # callbacks declaration 81 | callbacks = [ 82 | REStrongMatchingCallback( 83 | to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset 84 | ), 85 | ModelCheckpoint( 86 | "model", 87 | filename="{epoch}-{val_f1:.2f}", 88 | monitor="val_f1", 89 | mode="max", 90 | ), 91 | LearningRateMonitor(), 92 | ] 93 | 94 | if ( 95 | cfg.data.train_dataset.section_size is None 96 | ): # If section_size is None, we shuffle the dataset. This greatly speeds things up for bigger datasets, but be careful: the data file itself is shuffled in place at the end of each epoch 97 | callbacks.append( 98 | ShuffleTrainCallback( 99 | data_path=to_absolute_path(cfg.data.train_dataset_path) 100 | ) 101 | ) 102 | 103 | wandb_logger = WandbLogger( 104 | cfg.model_name, project=cfg.project_name, offline=cfg.offline 105 | ) 106 | 107 | # trainer declaration 108 | trainer: Trainer = hydra.utils.instantiate( 109 | cfg.training.trainer, 110 | callbacks=callbacks, 111 | logger=wandb_logger, 112 | ) 113 | 114 | # Trainer fit 115 | trainer.fit( 116 | model=model, 117 | train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0), 118 | val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0), 119 | ckpt_path=( 120 | cfg.training.ckpt_path 121 | if "ckpt_path" in cfg.training and cfg.training.ckpt_path 122 | else None 123 | ), 124 | ) 125 | 126 | # Load best checkpoint 127 | # if cfg.training.save_model_path: 128 | model = RelikReaderREPLModule.load_from_checkpoint( 129 | trainer.checkpoint_callback.best_model_path 130 | ) 131 | experiment_path = Path(wandb_logger.experiment.dir) 132 | model.relik_reader_re_model._tokenizer = train_dataset.tokenizer 133 | model.relik_reader_re_model.save_pretrained(experiment_path / "hf_model") 134 | 135 | 136 | def main(): 137 | train() 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /relik/reader/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/reader/utils/__init__.py -------------------------------------------------------------------------------- /relik/reader/utils/metrics.py: -------------------------------------------------------------------------------- 1 | def safe_divide(num: float, den: float) -> float: 2 | if den == 0: 3 | return 0 4 | else: 5 | return num / den 6 | 7 | 8 | def f1_measure(precision: float, recall: float) -> float: 9 | if
precision == 0 or recall == 0: 10 | return 0.0 11 | return safe_divide(2 * precision * recall, (precision + recall)) 12 | 13 | 14 | def compute_metrics(total_correct, total_preds, total_gold): 15 | precision = safe_divide(total_correct, total_preds) 16 | recall = safe_divide(total_correct, total_gold) 17 | f1 = f1_measure(precision, recall) 18 | return precision, recall, f1 19 | -------------------------------------------------------------------------------- /relik/reader/utils/save_load_utilities.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import Tuple 4 | 5 | import omegaconf 6 | import torch 7 | 8 | from relik.common.utils import from_cache 9 | from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule 10 | 11 | CKPT_FILE_NAME = "model.ckpt" 12 | CONFIG_FILE_NAME = "cfg.yaml" 13 | 14 | 15 | def convert_pl_module(pl_module_ckpt_path: str, output_dir: str) -> None: 16 | if not os.path.exists(output_dir): 17 | os.makedirs(output_dir) 18 | else: 19 | print(f"{output_dir} already exists, aborting operation") 20 | exit(1) 21 | 22 | relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint( 23 | pl_module_ckpt_path 24 | ) 25 | torch.save( 26 | relik_pl_module.relik_reader_core_model, f"{output_dir}/{CKPT_FILE_NAME}" 27 | ) 28 | with open(f"{output_dir}/{CONFIG_FILE_NAME}", "w") as f: 29 | omegaconf.OmegaConf.save( 30 | omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f 31 | ) 32 | 33 | 34 | def load_model_and_conf( 35 | model_dir_path: str, 36 | ) -> Tuple[torch.nn.Module, omegaconf.DictConfig]: 37 | # TODO: quick workaround to load the model from HF hub 38 | model_dir = from_cache( 39 | model_dir_path, 40 | filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME], 41 | cache_dir=None, 42 | force_download=False, 43 | ) 44 | 45 | ckpt_path = f"{model_dir}/{CKPT_FILE_NAME}" 46 | model = torch.load(ckpt_path, map_location=torch.device("cpu")) 47 | 48 | model_cfg_path = f"{model_dir}/{CONFIG_FILE_NAME}" 49 | model_conf = omegaconf.OmegaConf.load(model_cfg_path) 50 | return model, model_conf 51 | 52 | 53 | def parse_arg() -> argparse.Namespace: 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument( 56 | "--ckpt", 57 | help="Path to the pytorch lightning ckpt you want to convert.", 58 | required=True, 59 | ) 60 | parser.add_argument( 61 | "--output-dir", 62 | "-o", 63 | help="The output dir to store the bare models and the config.", 64 | required=True, 65 | ) 66 | return parser.parse_args() 67 | 68 | 69 | def main(): 70 | args = parse_arg() 71 | convert_pl_module(args.ckpt, args.output_dir) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /relik/reader/utils/shuffle_train_callback.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.callbacks import Callback 2 | 3 | from relik.common.log import get_logger 4 | 5 | import os 6 | 7 | import random 8 | 9 | logger = get_logger() 10 | 11 | 12 | class ShuffleTrainCallback(Callback): 13 | def __init__(self, shuffle_every: int = 1, data_path: str = None): 14 | self.shuffle_every = shuffle_every 15 | self.data_path = data_path 16 | 17 | def on_train_epoch_end(self, trainer, pl_module): 18 | if (trainer.current_epoch + 1) % self.shuffle_every == 0: 19 | logger.info("Shuffling train dataset") 20 | # os.system(f"shuf {self.data_path} > {self.data_path}.shuf") 21 | # 
os.system(f"mv {self.data_path}.shuf {self.data_path}") 22 | lines = open(self.data_path).readlines() 23 | random.shuffle(lines) 24 | open(self.data_path, "w").writelines(lines) 25 | -------------------------------------------------------------------------------- /relik/reader/utils/special_symbols.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | NME_SYMBOL = "--NME--" 4 | 5 | 6 | def get_special_symbols(num_entities: int) -> List[str]: 7 | return [NME_SYMBOL] + [f"[E-{i}]" for i in range(num_entities)] 8 | 9 | 10 | def get_special_symbols_re(num_entities: int, use_nme: bool = False) -> List[str]: 11 | if use_nme: 12 | return [NME_SYMBOL] + [f"[R-{i}]" for i in range(num_entities)] 13 | else: 14 | return [f"[R-{i}]" for i in range(num_entities)] 15 | -------------------------------------------------------------------------------- /relik/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from relik.retriever.pytorch_modules.model import GoldenRetriever 2 | -------------------------------------------------------------------------------- /relik/retriever/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/callbacks/__init__.py -------------------------------------------------------------------------------- /relik/retriever/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/common/__init__.py -------------------------------------------------------------------------------- /relik/retriever/common/model_inputs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from collections import UserDict 4 | from typing import Any, Union 5 | 6 | import torch 7 | from lightning.fabric.utilities import move_data_to_device 8 | 9 | from relik.common.log import get_logger 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | class ModelInputs(UserDict): 15 | """Model input dictionary wrapper.""" 16 | 17 | def __getattr__(self, item: str): 18 | try: 19 | return self.data[item] 20 | except KeyError: 21 | raise AttributeError(f"`ModelInputs` has no attribute `{item}`") 22 | 23 | def __getitem__(self, item: str) -> Any: 24 | return self.data[item] 25 | 26 | def __getstate__(self): 27 | return {"data": self.data} 28 | 29 | def __setstate__(self, state): 30 | if "data" in state: 31 | self.data = state["data"] 32 | 33 | def keys(self): 34 | """A set-like object providing a view on D's keys.""" 35 | return self.data.keys() 36 | 37 | def values(self): 38 | """An object providing a view on D's values.""" 39 | return self.data.values() 40 | 41 | def items(self): 42 | """A set-like object providing a view on D's items.""" 43 | return self.data.items() 44 | 45 | def to(self, device: Union[str, torch.device]) -> ModelInputs: 46 | """ 47 | Send all tensors values to device. 48 | Args: 49 | device (`str` or `torch.device`): The device to put the tensors on. 50 | Returns: 51 | :class:`tokenizers.ModelInputs`: The same instance of :class:`~tokenizers.ModelInputs` 52 | after modification. 
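Example (a minimal sketch; the tensor values are illustrative):
    >>> import torch
    >>> batch = ModelInputs({"input_ids": torch.tensor([[101, 102]])})
    >>> batch = batch.to("cuda")  # moves every tensor in `data`, returns self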
53 | """ 54 | self.data = move_data_to_device(self.data, device) 55 | return self 56 | -------------------------------------------------------------------------------- /relik/retriever/common/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler 4 | 5 | 6 | def identity(x): 7 | return x 8 | 9 | 10 | class SortedSampler(Sampler): 11 | """ 12 | Samples elements sequentially, always in the same order. 13 | 14 | Args: 15 | data (`obj`: `Iterable`): 16 | Iterable data. 17 | sort_key (`obj`: `Callable`): 18 | Specifies a function of one argument that is used to 19 | extract a numerical comparison key from each list element. 20 | 21 | Example: 22 | >>> list(SortedSampler(range(10), sort_key=lambda i: -i)) 23 | [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] 24 | 25 | """ 26 | 27 | def __init__(self, data, sort_key=identity): 28 | super().__init__(data) 29 | self.data = data 30 | self.sort_key = sort_key 31 | zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)] 32 | zip_ = sorted(zip_, key=lambda r: r[1]) 33 | self.sorted_indexes = [item[0] for item in zip_] 34 | 35 | def __iter__(self): 36 | return iter(self.sorted_indexes) 37 | 38 | def __len__(self): 39 | return len(self.data) 40 | 41 | 42 | class BucketBatchSampler(BatchSampler): 43 | """ 44 | `BucketBatchSampler` toggles between `sampler` batches and sorted batches. 45 | Typically, the `sampler` will be a `RandomSampler`, allowing the user to toggle between 46 | random batches and sorted batches. The larger the `bucket_size_multiplier`, the more 47 | sorted the batches. 48 | Background: 49 | ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular libraries like 50 | ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples of similar 51 | length to reduce the padding required for each batch while maintaining some noise 52 | through bucketing. 53 | **AllenNLP Implementation:** 54 | https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py 55 | **torchtext Implementation:** 56 | https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225 57 | 58 | Args: 59 | sampler (`obj`: `torch.utils.data.sampler.Sampler`): The base sampler providing the indices to batch. 60 | batch_size (`int`): 61 | Size of mini-batch. 62 | drop_last (`bool`, optional, defaults to `False`): 63 | If `True` the sampler will drop the last batch if its size would be less than `batch_size`. 64 | sort_key (`obj`: `Callable`, optional, defaults to `identity`): 65 | Callable to specify a comparison key for sorting. 66 | bucket_size_multiplier (`int`, optional, defaults to `100`): 67 | Buckets are of size `batch_size * bucket_size_multiplier`.
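Concretely, ``__iter__`` below draws a bucket of ``batch_size * bucket_size_multiplier`` indices from ``sampler``, sorts it with ``sort_key``, splits it into batches, and yields those batches in random order: examples of similar length land in the same batch while the epoch as a whole stays shuffled.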
68 | Example: 69 | >>> from torchnlp.random import set_seed 70 | >>> set_seed(123) 71 | >>> 72 | >>> from torch.utils.data.sampler import SequentialSampler 73 | >>> sampler = SequentialSampler(list(range(10))) 74 | >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False)) 75 | [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]] 76 | >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True)) 77 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 78 | 79 | """ 80 | 81 | def __init__( 82 | self, 83 | sampler, 84 | batch_size, 85 | drop_last: bool = False, 86 | sort_key=identity, 87 | bucket_size_multiplier=100, 88 | ): 89 | super().__init__(sampler, batch_size, drop_last) 90 | self.sort_key = sort_key 91 | _bucket_size = batch_size * bucket_size_multiplier 92 | if hasattr(sampler, "__len__"): 93 | _bucket_size = min(_bucket_size, len(sampler)) 94 | self.bucket_sampler = BatchSampler(sampler, _bucket_size, False) 95 | 96 | def __iter__(self): 97 | for bucket in self.bucket_sampler: 98 | sorted_sampler = SortedSampler(bucket, self.sort_key) 99 | for batch in SubsetRandomSampler( 100 | list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last)) 101 | ): 102 | yield [bucket[i] for i in batch] 103 | 104 | def __len__(self): 105 | if self.drop_last: 106 | return len(self.sampler) // self.batch_size 107 | else: 108 | return math.ceil(len(self.sampler) / self.batch_size) 109 | -------------------------------------------------------------------------------- /relik/retriever/conf/data/aida_dataset.yaml: -------------------------------------------------------------------------------- 1 | train_dataset_path: null 2 | val_dataset_path: null 3 | test_dataset_path: null 4 | 5 | shared_params: 6 | documents_path: null 7 | max_passage_length: 64 8 | passage_batch_size: 64 9 | question_batch_size: 64 10 | use_topics: True 11 | 12 | datamodule: 13 | _target_: relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule 14 | datasets: 15 | train: 16 | _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset 17 | name: "train" 18 | path: ${data.train_dataset_path} 19 | tokenizer: ${model.language_model} 20 | max_passage_length: ${data.shared_params.max_passage_length} 21 | question_batch_size: ${data.shared_params.question_batch_size} 22 | passage_batch_size: ${data.shared_params.passage_batch_size} 23 | subsample_strategy: null 24 | subsample_portion: 0.1 25 | shuffle: True 26 | metadata_fields: ['definition'] 27 | metadata_separator: ' ' 28 | use_topics: ${data.shared_params.use_topics} 29 | 30 | val: 31 | - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset 32 | name: "val" 33 | path: ${data.val_dataset_path} 34 | tokenizer: ${model.language_model} 35 | max_passage_length: ${data.shared_params.max_passage_length} 36 | question_batch_size: ${data.shared_params.question_batch_size} 37 | passage_batch_size: ${data.shared_params.passage_batch_size} 38 | metadata_fields: ['definition'] 39 | metadata_separator: ' ' 40 | use_topics: ${data.shared_params.use_topics} 41 | 42 | test: 43 | - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset 44 | name: "test" 45 | path: ${data.test_dataset_path} 46 | tokenizer: ${model.language_model} 47 | max_passage_length: ${data.shared_params.max_passage_length} 48 | question_batch_size: ${data.shared_params.question_batch_size} 49 | passage_batch_size: ${data.shared_params.passage_batch_size} 50 | metadata_fields: ['definition'] 51 | metadata_separator: ' ' 52 | use_topics: ${data.shared_params.use_topics} 53 | 54 | 
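# DataLoader workers per split (0 = load the data in the main process)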
num_workers: 55 | train: 4 56 | val: 4 57 | test: 4 58 | -------------------------------------------------------------------------------- /relik/retriever/conf/data/blink_dataset.yaml: -------------------------------------------------------------------------------- 1 | train_dataset_path: null 2 | val_dataset_path: null 3 | test_dataset_path: null 4 | 5 | shared_params: 6 | documents_path: null 7 | max_passage_length: 64 8 | passage_batch_size: 64 9 | question_batch_size: 64 10 | 11 | datamodule: 12 | _target_: relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule 13 | datasets: 14 | train: 15 | _target_: relik.retriever.data.datasets.InBatchNegativesDataset 16 | name: "train" 17 | path: ${data.train_dataset_path} 18 | tokenizer: ${model.language_model} 19 | max_passage_length: ${data.shared_params.max_passage_length} 20 | question_batch_size: ${data.shared_params.question_batch_size} 21 | passage_batch_size: ${data.shared_params.passage_batch_size} 22 | subsample_strategy: random 23 | subsample_portion: 0.1 24 | metadata_fields: ['definition'] 25 | metadata_separator: ' ' 26 | shuffle: True 27 | 28 | val: 29 | - _target_: relik.retriever.data.datasets.InBatchNegativesDataset 30 | name: "val" 31 | path: ${data.val_dataset_path} 32 | tokenizer: ${model.language_model} 33 | max_passage_length: ${data.shared_params.max_passage_length} 34 | question_batch_size: ${data.shared_params.question_batch_size} 35 | metadata_fields: ['definition'] 36 | metadata_separator: ' ' 37 | passage_batch_size: ${data.shared_params.passage_batch_size} 38 | 39 | test: 40 | - _target_: relik.retriever.data.datasets.InBatchNegativesDataset 41 | name: "test" 42 | path: ${data.test_dataset_path} 43 | tokenizer: ${model.language_model} 44 | max_passage_length: ${data.shared_params.max_passage_length} 45 | question_batch_size: ${data.shared_params.question_batch_size} 46 | passage_batch_size: ${data.shared_params.passage_batch_size} 47 | metadata_fields: ['definition'] 48 | metadata_separator: ' ' 49 | 50 | num_workers: 51 | train: 0 52 | val: 0 53 | test: 0 54 | -------------------------------------------------------------------------------- /relik/retriever/conf/finetune_iterable_in_batch.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | run: 4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | 6 | model_name: ${model.language_model} # used to name the model in wandb 7 | project_name: relik-retriever-aida # used to name the project in wandb 8 | 9 | defaults: 10 | - _self_ 11 | - model: golden_retriever 12 | - index: inmemory 13 | - loss: nce_loss 14 | - optimizer: radamw 15 | - scheduler: linear_scheduler 16 | - data: aida_dataset 17 | - logging: wandb_logging 18 | - override hydra/job_logging: colorlog 19 | - override hydra/hydra_logging: colorlog 20 | 21 | train: 22 | # reproducibility 23 | seed: 42 24 | set_determinism_the_old_way: False 25 | # torch parameters 26 | float32_matmul_precision: "high" 27 | # if true, only test the model 28 | only_test: False 29 | # if provided, initialize the model with the weights from the checkpoint 30 | pretrain_ckpt_path: null 31 | # if provided, start training from the checkpoint 32 | checkpoint_path: null 33 | 34 | # task specific parameter 35 | top_k: 100 36 | 37 | # pl_trainer 38 | pl_trainer: 39 | _target_: lightning.Trainer 40 | accelerator: gpu 41 | devices: 1 42 | num_nodes: 1 43 | 
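# "auto" lets Lightning pick the distributed strategy for the accelerator/devices above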
strategy: auto 44 | accumulate_grad_batches: 1 45 | gradient_clip_val: 1.0 46 | val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps 47 | check_val_every_n_epoch: 1 48 | max_epochs: 0 49 | max_steps: 25_000 50 | deterministic: True 51 | fast_dev_run: False 52 | precision: 16 53 | reload_dataloaders_every_n_epochs: 1 54 | 55 | early_stopping_callback: 56 | # null 57 | _target_: lightning.pytorch.callbacks.EarlyStopping 58 | monitor: validate_recall@${train.top_k} 59 | mode: max 60 | patience: 3 61 | 62 | model_checkpoint_callback: 63 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 64 | monitor: validate_recall@${train.top_k} 65 | mode: max 66 | verbose: True 67 | save_top_k: 1 68 | save_last: False 69 | filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}" 70 | auto_insert_metric_name: False 71 | 72 | callbacks: 73 | prediction_callback: 74 | _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback 75 | k: ${train.top_k} 76 | batch_size: 64 77 | precision: 16 78 | index_precision: 16 79 | other_callbacks: 80 | - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback 81 | k: ${train.top_k} 82 | verbose: True 83 | - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback 84 | k: 50 85 | verbose: True 86 | prog_bar: False 87 | - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback 88 | k: ${train.top_k} 89 | verbose: True 90 | - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback 91 | 92 | hard_negatives_callback: 93 | _target_: relik.retriever.callbacks.training_callbacks.NegativeAugmentationCallback 94 | k: ${train.top_k} 95 | batch_size: 64 96 | precision: 16 97 | index_precision: 16 98 | stages: [validate] #[validate, sanity_check] 99 | metrics_to_monitor: 100 | validate_recall@${train.top_k} 101 | # - sanity_check_recall@${train.top_k} 102 | threshold: 0.0 103 | max_negatives: 20 104 | add_with_probability: 1.0 105 | refresh_every_n_epochs: 1 106 | other_callbacks: 107 | - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback 108 | k: ${train.top_k} 109 | verbose: True 110 | prefix: "train" 111 | 112 | utils_callbacks: 113 | - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback 114 | - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback 115 | # - _target_: relik.retriever.callbacks.utils_callbacks.ResetModelCallback 116 | # question_encoder: ${model.pl_module.model.question_encoder} 117 | # passage_encoder: ${model.pl_module.model.passage_encoder} 118 | -------------------------------------------------------------------------------- /relik/retriever/conf/index/inmemory.yaml: -------------------------------------------------------------------------------- 1 | _target_: relik.retriever.indexers.inmemory.InMemoryDocumentIndex 2 | documents: 3 | _target_: relik.retriever.indexers.document.DocumentStore.from_file 4 | file_path: ${data.shared_params.documents_path} 5 | device: "cuda" 6 | precision: "16" 7 | -------------------------------------------------------------------------------- /relik/retriever/conf/logging/wandb_logging.yaml: -------------------------------------------------------------------------------- 1 | # don't forget loggers.login() for the first usage. 
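# i.e. run `wandb login` from the shell (or call `wandb.login()` in Python) once before the first run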
2 | 3 | log: True # set to False to avoid the logging 4 | 5 | wandb_arg: 6 | _target_: lightning.pytorch.loggers.WandbLogger 7 | name: ${model_name} 8 | project: ${project_name} 9 | save_dir: ./ 10 | log_model: True 11 | mode: "online" 12 | entity: null 13 | 14 | watch: 15 | log: "all" 16 | log_freq: 100 17 | -------------------------------------------------------------------------------- /relik/retriever/conf/loss/nce_loss.yaml: -------------------------------------------------------------------------------- 1 | _target_: relik.retriever.pytorch_modules.loss.MultiLabelNCELoss 2 | -------------------------------------------------------------------------------- /relik/retriever/conf/loss/nll_loss.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.nn.NLLLoss 2 | -------------------------------------------------------------------------------- /relik/retriever/conf/model/golden_retriever.yaml: -------------------------------------------------------------------------------- 1 | language_model: "intfloat/e5-small-v2" 2 | 3 | pl_module: 4 | _target_: relik.retriever.lightning_modules.pl_modules.GoldenRetrieverPLModule 5 | model: 6 | _target_: relik.retriever.pytorch_modules.model.GoldenRetriever 7 | question_encoder: ${model.language_model} 8 | passage_encoder: ${model.language_model} 9 | document_index: ${index} 10 | loss_type: ${loss} 11 | optimizer: ${optimizer} 12 | lr_scheduler: ${scheduler} 13 | -------------------------------------------------------------------------------- /relik/retriever/conf/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | lr: 1e-5 3 | weight_decay: 0.01 4 | fused: False 5 | -------------------------------------------------------------------------------- /relik/retriever/conf/optimizer/radam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.RAdam 2 | lr: 1e-5 3 | weight_decay: 0 4 | -------------------------------------------------------------------------------- /relik/retriever/conf/optimizer/radamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: relik.retriever.pytorch_modules.optim.RAdamW 2 | lr: 1e-5 3 | weight_decay: 0.01 4 | -------------------------------------------------------------------------------- /relik/retriever/conf/pretrain_iterable_in_batch.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | run: 4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | 6 | model_name: ${model.language_model} # used to name the model in wandb 7 | project_name: relik-retriever # used to name the project in wandb 8 | 9 | defaults: 10 | - _self_ 11 | - model: golden_retriever 12 | - index: inmemory 13 | - loss: nce_loss 14 | - optimizer: radamw 15 | - scheduler: linear_scheduler 16 | - data: blink_dataset 17 | - logging: wandb_logging 18 | - override hydra/job_logging: colorlog 19 | - override hydra/hydra_logging: colorlog 20 | 21 | train: 22 | # reproducibility 23 | seed: 42 24 | set_determinism_the_old_way: False 25 | # torch parameters 26 | float32_matmul_precision: "medium" 27 | # if true, only test the model 28 | only_test: False 29 | # if provided, initialize the model with the weights from the checkpoint 30 | pretrain_ckpt_path: null 31 | # if provided, 
start training from the checkpoint 32 | checkpoint_path: null 33 | 34 | # task specific parameter 35 | top_k: 100 36 | 37 | # pl_trainer 38 | pl_trainer: 39 | _target_: lightning.Trainer 40 | accelerator: gpu 41 | devices: 1 42 | num_nodes: 1 43 | strategy: auto 44 | accumulate_grad_batches: 1 45 | gradient_clip_val: 1.0 46 | val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps 47 | check_val_every_n_epoch: 1 48 | max_epochs: 0 49 | max_steps: 220_000 50 | deterministic: True 51 | fast_dev_run: False 52 | precision: 16 53 | reload_dataloaders_every_n_epochs: 1 54 | 55 | early_stopping_callback: 56 | null 57 | # _target_: lightning.pytorch.callbacks.EarlyStopping 58 | # monitor: validate_recall@${train.top_k} 59 | # mode: max 60 | # patience: 15 61 | 62 | model_checkpoint_callback: 63 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 64 | monitor: validate_recall@${train.top_k} 65 | mode: max 66 | verbose: True 67 | save_top_k: 1 68 | save_last: True 69 | filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}" 70 | auto_insert_metric_name: False 71 | 72 | callbacks: 73 | prediction_callback: 74 | _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback 75 | k: ${train.top_k} 76 | batch_size: 128 77 | precision: 16 78 | index_precision: 16 79 | other_callbacks: 80 | - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback 81 | k: ${train.top_k} 82 | verbose: True 83 | - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback 84 | k: 50 85 | verbose: True 86 | prog_bar: False 87 | - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback 88 | k: ${train.top_k} 89 | verbose: True 90 | - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback 91 | 92 | hard_negatives_callback: 93 | k: ${train.top_k} 94 | batch_size: 128 95 | precision: 16 96 | index_precision: 16 97 | stages: [validate] #[validate, sanity_check] 98 | metrics_to_monitor: 99 | validate_recall@${train.top_k} 100 | # - sanity_check_recall@${train.top_k} 101 | threshold: 0.0 102 | max_negatives: 15 103 | add_with_probability: 0.2 104 | refresh_every_n_epochs: 1 105 | other_callbacks: 106 | - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback 107 | k: ${train.top_k} 108 | verbose: True 109 | prefix: "train" 110 | 111 | utils_callbacks: 112 | - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback 113 | - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback 114 | -------------------------------------------------------------------------------- /relik/retriever/conf/scheduler/linear_scheduler.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.get_linear_schedule_with_warmup 2 | num_warmup_steps: 0 3 | num_training_steps: ${train.pl_trainer.max_steps} 4 | -------------------------------------------------------------------------------- /relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml: -------------------------------------------------------------------------------- 1 | _target_: transformers.get_linear_schedule_with_warmup 2 | num_warmup_steps: 5_000 3 | num_training_steps: ${train.pl_trainer.max_steps} 4 | -------------------------------------------------------------------------------- /relik/retriever/conf/scheduler/none.yaml: 
-------------------------------------------------------------------------------- 1 | null -------------------------------------------------------------------------------- /relik/retriever/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/data/__init__.py -------------------------------------------------------------------------------- /relik/retriever/data/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/data/base/__init__.py -------------------------------------------------------------------------------- /relik/retriever/data/base/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Any, Dict, Iterator, List, Optional, Tuple, Union 4 | 5 | import torch 6 | from torch.utils.data import Dataset, IterableDataset 7 | 8 | from relik.common.log import get_logger 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | class BaseDataset(Dataset): 14 | def __init__( 15 | self, 16 | name: str, 17 | path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None, 18 | data: Any = None, 19 | **kwargs, 20 | ): 21 | super().__init__() 22 | self.name = name 23 | if path is None and data is None: 24 | raise ValueError("Either `path` or `data` must be provided") 25 | self.path = path 26 | self.project_folder = Path(__file__).parent.parent.parent 27 | self.data = data 28 | 29 | def __len__(self) -> int: 30 | return len(self.data) 31 | 32 | def __getitem__( 33 | self, index 34 | ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: 35 | return self.data[index] 36 | 37 | def __repr__(self) -> str: 38 | return f"Dataset({self.name=}, {self.path=})" 39 | 40 | def load( 41 | self, 42 | paths: Union[str, os.PathLike, List[str], List[os.PathLike]], 43 | *args, 44 | **kwargs, 45 | ) -> Any: 46 | # load data from single or multiple paths in one single dataset 47 | raise NotImplementedError 48 | 49 | @staticmethod 50 | def collate_fn(batch: Any, *args, **kwargs) -> Any: 51 | raise NotImplementedError 52 | 53 | 54 | class IterableBaseDataset(IterableDataset): 55 | def __init__( 56 | self, 57 | name: str, 58 | path: Optional[Union[str, Path, List[str], List[Path]]] = None, 59 | data: Any = None, 60 | *args, 61 | **kwargs, 62 | ): 63 | super().__init__() 64 | self.name = name 65 | if path is None and data is None: 66 | raise ValueError("Either `path` or `data` must be provided") 67 | self.path = path 68 | self.project_folder = Path(__file__).parent.parent.parent 69 | self.data = data 70 | 71 | def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: 72 | for sample in self.data: 73 | yield sample 74 | 75 | def __repr__(self) -> str: 76 | return f"Dataset({self.name=}, {self.path=})" 77 | 78 | def load( 79 | self, 80 | paths: Union[str, os.PathLike, List[str], List[os.PathLike]], 81 | *args, 82 | **kwargs, 83 | ) -> Any: 84 | # load data from single or multiple paths in one single dataset 85 | raise NotImplementedError 86 | 87 | @staticmethod 88 | def collate_fn(batch: Any, *args, **kwargs) -> Any: 89 | raise NotImplementedError 90 | -------------------------------------------------------------------------------- /relik/retriever/indexers/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/indexers/__init__.py -------------------------------------------------------------------------------- /relik/retriever/lightning_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/lightning_modules/__init__.py -------------------------------------------------------------------------------- /relik/retriever/lightning_modules/pl_data_modules.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Sequence, Union 2 | 3 | import hydra 4 | import lightning as pl 5 | import torch 6 | from lightning.pytorch.utilities.types import EVAL_DATALOADERS 7 | from omegaconf import DictConfig 8 | from torch.utils.data import DataLoader 9 | 10 | from relik.common.log import get_logger 11 | from relik.retriever.data.datasets import GoldenRetrieverDataset 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | class GoldenRetrieverPLDataModule(pl.LightningDataModule): 17 | def __init__( 18 | self, 19 | train_dataset: Optional[GoldenRetrieverDataset] = None, 20 | val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None, 21 | test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None, 22 | num_workers: Optional[Union[DictConfig, int]] = None, 23 | datasets: Optional[DictConfig] = None, 24 | *args, 25 | **kwargs, 26 | ): 27 | super().__init__() 28 | self.datasets = datasets 29 | if num_workers is None: 30 | num_workers = 0 31 | if isinstance(num_workers, int): 32 | num_workers = DictConfig( 33 | {"train": num_workers, "val": num_workers, "test": num_workers} 34 | ) 35 | self.num_workers = num_workers 36 | # data 37 | self.train_dataset: Optional[GoldenRetrieverDataset] = train_dataset 38 | self.val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = val_datasets 39 | self.test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = test_datasets 40 | 41 | def prepare_data(self, *args, **kwargs): 42 | """ 43 | Method for preparing the data before the training. This method is called only once. 44 | It is used to download the data, tokenize the data, etc. 
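In this data module the datasets are instantiated lazily in ``setup``, so this hook is intentionally a no-op.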
45 | """ 46 | pass 47 | 48 | def setup(self, stage: Optional[str] = None): 49 | if stage == "fit" or stage is None: 50 | # usually there is only one dataset for train 51 | # if you need more train loader, you can follow 52 | # the same logic as val and test datasets 53 | if self.train_dataset is None: 54 | self.train_dataset = hydra.utils.instantiate(self.datasets.train) 55 | self.val_datasets = [ 56 | hydra.utils.instantiate(dataset_cfg) 57 | for dataset_cfg in self.datasets.val 58 | ] 59 | if stage == "test": 60 | if self.test_datasets is None: 61 | self.test_datasets = [ 62 | hydra.utils.instantiate(dataset_cfg) 63 | for dataset_cfg in self.datasets.test 64 | ] 65 | 66 | def train_dataloader(self, *args, **kwargs) -> DataLoader: 67 | torch_dataset = self.train_dataset.to_torch_dataset() 68 | return DataLoader( 69 | # self.train_dataset.to_torch_dataset(), 70 | torch_dataset, 71 | shuffle=False, 72 | batch_size=None, 73 | num_workers=self.num_workers.train, 74 | pin_memory=True, 75 | collate_fn=lambda x: x, 76 | ) 77 | 78 | def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: 79 | dataloaders = [] 80 | for dataset in self.val_datasets: 81 | torch_dataset = dataset.to_torch_dataset() 82 | dataloaders.append( 83 | DataLoader( 84 | torch_dataset, 85 | shuffle=False, 86 | batch_size=None, 87 | num_workers=self.num_workers.val, 88 | pin_memory=True, 89 | collate_fn=lambda x: x, 90 | ) 91 | ) 92 | return dataloaders 93 | 94 | def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: 95 | dataloaders = [] 96 | for dataset in self.test_datasets: 97 | torch_dataset = dataset.to_torch_dataset() 98 | dataloaders.append( 99 | DataLoader( 100 | torch_dataset, 101 | shuffle=False, 102 | batch_size=None, 103 | num_workers=self.num_workers.test, 104 | pin_memory=True, 105 | collate_fn=lambda x: x, 106 | ) 107 | ) 108 | return dataloaders 109 | 110 | def predict_dataloader(self) -> EVAL_DATALOADERS: 111 | raise NotImplementedError 112 | 113 | def transfer_batch_to_device( 114 | self, batch: Any, device: torch.device, dataloader_idx: int 115 | ) -> Any: 116 | return super().transfer_batch_to_device(batch, device, dataloader_idx) 117 | 118 | def __repr__(self) -> str: 119 | return ( 120 | f"{self.__class__.__name__}(" f"{self.datasets=}, " f"{self.num_workers=}, " 121 | ) 122 | -------------------------------------------------------------------------------- /relik/retriever/lightning_modules/pl_modules.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Union 2 | 3 | import hydra 4 | import lightning as pl 5 | import torch 6 | from omegaconf import DictConfig 7 | 8 | from relik.retriever.common.model_inputs import ModelInputs 9 | 10 | 11 | class GoldenRetrieverPLModule(pl.LightningModule): 12 | def __init__( 13 | self, 14 | model: Union[torch.nn.Module, DictConfig], 15 | optimizer: Union[torch.optim.Optimizer, DictConfig], 16 | lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, DictConfig] = None, 17 | *args, 18 | **kwargs, 19 | ) -> None: 20 | super().__init__() 21 | self.save_hyperparameters(ignore=["model"]) 22 | if isinstance(model, DictConfig): 23 | self.model = hydra.utils.instantiate(model) 24 | else: 25 | self.model = model 26 | 27 | self.optimizer_config = optimizer 28 | self.lr_scheduler_config = lr_scheduler 29 | 30 | def forward(self, **kwargs) -> dict: 31 | """ 32 | Method for the forward pass. 
33 | 'training_step', 'validation_step' and 'test_step' should call 34 | this method in order to compute the output predictions and the loss. 35 | 36 | Returns: 37 | output_dict: forward output containing the predictions (output logits ecc...) and the loss if any. 38 | 39 | """ 40 | return self.model(**kwargs) 41 | 42 | def training_step(self, batch: ModelInputs, batch_idx: int) -> torch.Tensor: 43 | forward_output = self.forward(**batch, return_loss=True) 44 | self.log( 45 | "loss", 46 | forward_output["loss"], 47 | batch_size=batch["questions"]["input_ids"].size(0), 48 | prog_bar=True, 49 | ) 50 | return forward_output["loss"] 51 | 52 | def validation_step(self, batch: ModelInputs, batch_idx: int) -> None: 53 | forward_output = self.forward(**batch, return_loss=True) 54 | self.log( 55 | "val_loss", 56 | forward_output["loss"], 57 | batch_size=batch["questions"]["input_ids"].size(0), 58 | ) 59 | 60 | def test_step(self, batch: ModelInputs, batch_idx: int) -> Any: 61 | forward_output = self.forward(**batch, return_loss=True) 62 | self.log( 63 | "test_loss", 64 | forward_output["loss"], 65 | batch_size=batch["questions"]["input_ids"].size(0), 66 | ) 67 | 68 | def configure_optimizers(self): 69 | if isinstance(self.optimizer_config, DictConfig): 70 | param_optimizer = list(self.named_parameters()) 71 | no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"] 72 | optimizer_grouped_parameters = [ 73 | { 74 | "params": [ 75 | p for n, p in param_optimizer if "layer_norm_layer" in n 76 | ], 77 | "weight_decay": self.hparams.optimizer.weight_decay, 78 | "lr": 1e-4, 79 | }, 80 | { 81 | "params": [ 82 | p 83 | for n, p in param_optimizer 84 | if all(nd not in n for nd in no_decay) 85 | and "layer_norm_layer" not in n 86 | ], 87 | "weight_decay": self.hparams.optimizer.weight_decay, 88 | }, 89 | { 90 | "params": [ 91 | p 92 | for n, p in param_optimizer 93 | if "layer_norm_layer" not in n 94 | and any(nd in n for nd in no_decay) 95 | ], 96 | "weight_decay": 0.0, 97 | }, 98 | ] 99 | optimizer = hydra.utils.instantiate( 100 | self.optimizer_config, 101 | # params=self.parameters(), 102 | params=optimizer_grouped_parameters, 103 | _convert_="partial", 104 | ) 105 | else: 106 | optimizer = self.optimizer_config 107 | 108 | if self.lr_scheduler_config is None: 109 | return optimizer 110 | 111 | if isinstance(self.lr_scheduler_config, DictConfig): 112 | lr_scheduler = hydra.utils.instantiate( 113 | self.lr_scheduler_config, optimizer=optimizer 114 | ) 115 | else: 116 | lr_scheduler = self.lr_scheduler_config 117 | 118 | lr_scheduler_config = { 119 | "scheduler": lr_scheduler, 120 | "interval": "step", 121 | "frequency": 1, 122 | } 123 | return [optimizer], [lr_scheduler_config] 124 | -------------------------------------------------------------------------------- /relik/retriever/pytorch_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | 5 | from relik.retriever.indexers.document import Document 6 | 7 | PRECISION_MAP = { 8 | None: torch.float32, 9 | 32: torch.float32, 10 | 16: torch.float16, 11 | torch.float32: torch.float32, 12 | torch.float16: torch.float16, 13 | torch.bfloat16: torch.bfloat16, 14 | "float32": torch.float32, 15 | "float16": torch.float16, 16 | "bfloat16": torch.bfloat16, 17 | "float": torch.float32, 18 | "half": torch.float16, 19 | "32": torch.float32, 20 | "16": torch.float16, 21 | "fp32": torch.float32, 22 | "fp16": torch.float16, 23 | "bf16": torch.bfloat16, 24 | } 25 | 
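# A quick sketch of how PRECISION_MAP is meant to be used: callers may pass the
# precision as an int, a string, or a torch.dtype, and every alias resolves to
# the same canonical dtype. For example, 16, "16", "fp16", "half" and
# torch.float16 all map to torch.float16:
#
# >>> PRECISION_MAP["bf16"] is PRECISION_MAP[torch.bfloat16]
# True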
26 | 27 | @dataclass 28 | class RetrievedSample: 29 | """ 30 | Dataclass for the output of the GoldenRetriever model. 31 | """ 32 | 33 | score: float 34 | document: Document 35 | -------------------------------------------------------------------------------- /relik/retriever/pytorch_modules/hf.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import torch 4 | from transformers import PretrainedConfig 5 | from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions 6 | from transformers.models.bert.modeling_bert import BertModel 7 | 8 | 9 | class GoldenRetrieverConfig(PretrainedConfig): 10 | model_type = "bert" 11 | 12 | def __init__( 13 | self, 14 | vocab_size=30522, 15 | hidden_size=768, 16 | num_hidden_layers=12, 17 | num_attention_heads=12, 18 | intermediate_size=3072, 19 | hidden_act="gelu", 20 | hidden_dropout_prob=0.1, 21 | attention_probs_dropout_prob=0.1, 22 | max_position_embeddings=512, 23 | type_vocab_size=2, 24 | initializer_range=0.02, 25 | layer_norm_eps=1e-12, 26 | pad_token_id=0, 27 | position_embedding_type="absolute", 28 | use_cache=True, 29 | classifier_dropout=None, 30 | **kwargs, 31 | ): 32 | super().__init__(pad_token_id=pad_token_id, **kwargs) 33 | 34 | self.vocab_size = vocab_size 35 | self.hidden_size = hidden_size 36 | self.num_hidden_layers = num_hidden_layers 37 | self.num_attention_heads = num_attention_heads 38 | self.hidden_act = hidden_act 39 | self.intermediate_size = intermediate_size 40 | self.hidden_dropout_prob = hidden_dropout_prob 41 | self.attention_probs_dropout_prob = attention_probs_dropout_prob 42 | self.max_position_embeddings = max_position_embeddings 43 | self.type_vocab_size = type_vocab_size 44 | self.initializer_range = initializer_range 45 | self.layer_norm_eps = layer_norm_eps 46 | self.position_embedding_type = position_embedding_type 47 | self.use_cache = use_cache 48 | self.classifier_dropout = classifier_dropout 49 | 50 | 51 | class GoldenRetrieverModel(BertModel): 52 | config_class = GoldenRetrieverConfig 53 | 54 | def __init__(self, config, *args, **kwargs): 55 | super().__init__(config) 56 | self.layer_norm_layer = torch.nn.LayerNorm( 57 | config.hidden_size, eps=config.layer_norm_eps 58 | ) 59 | 60 | def forward( 61 | self, **kwargs 62 | ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: 63 | attention_mask = kwargs.get("attention_mask", None) 64 | model_outputs = super().forward(**kwargs) 65 | if attention_mask is None: 66 | pooler_output = model_outputs.pooler_output 67 | else: 68 | token_embeddings = model_outputs.last_hidden_state 69 | input_mask_expanded = ( 70 | attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() 71 | ) 72 | pooler_output = torch.sum( 73 | token_embeddings * input_mask_expanded, 1 74 | ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) 75 | 76 | pooler_output = self.layer_norm_layer(pooler_output) 77 | 78 | if not kwargs.get("return_dict", True): 79 | return (model_outputs[0], pooler_output) + model_outputs[2:] 80 | 81 | return BaseModelOutputWithPoolingAndCrossAttentions( 82 | last_hidden_state=model_outputs.last_hidden_state, 83 | pooler_output=pooler_output, 84 | past_key_values=model_outputs.past_key_values, 85 | hidden_states=model_outputs.hidden_states, 86 | attentions=model_outputs.attentions, 87 | cross_attentions=model_outputs.cross_attentions, 88 | ) 89 | -------------------------------------------------------------------------------- 
/relik/retriever/pytorch_modules/loss.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch.nn.modules.loss import _WeightedLoss 5 | 6 | 7 | class MultiLabelNCELoss(_WeightedLoss): 8 | __constants__ = ["reduction"] 9 | 10 | def __init__( 11 | self, 12 | weight: Optional[torch.Tensor] = None, 13 | size_average=None, 14 | reduction: Optional[str] = "mean", 15 | ) -> None: 16 | super(MultiLabelNCELoss, self).__init__(weight, size_average, None, reduction) 17 | 18 | def forward( 19 | self, input: torch.Tensor, target: torch.Tensor, ignore_index: int = -100 20 | ) -> torch.Tensor: 21 | gold_scores = input.masked_fill(~(target.bool()), 0) 22 | gold_scores_sum = gold_scores.sum(-1) # B x C 23 | neg_logits = input.masked_fill(target.bool(), float("-inf")) # B x C x L 24 | neg_log_sum_exp = torch.logsumexp(neg_logits, -1, keepdim=True) # B x C x 1 25 | norm_term = ( 26 | torch.logaddexp(input, neg_log_sum_exp) 27 | .masked_fill(~(target.bool()), 0) 28 | .sum(-1) 29 | ) 30 | gold_log_probs = gold_scores_sum - norm_term 31 | loss = -gold_log_probs.sum() 32 | if self.reduction == "mean": 33 | loss /= input.size(0) 34 | return loss 35 | -------------------------------------------------------------------------------- /relik/retriever/pytorch_modules/optim.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.optim import Optimizer 5 | 6 | 7 | class RAdamW(Optimizer): 8 | r"""Implements RAdamW algorithm. 9 | 10 | RAdam from `On the Variance of the Adaptive Learning Rate and Beyond 11 | `_ 12 | 13 | * `Adam: A Method for Stochastic Optimization 14 | `_ 15 | * `Decoupled Weight Decay Regularization 16 | `_ 17 | * `On the Convergence of Adam and Beyond 18 | `_ 19 | * `On the Variance of the Adaptive Learning Rate and Beyond 20 | `_ 21 | 22 | Arguments: 23 | params (iterable): iterable of parameters to optimize or dicts defining 24 | parameter groups 25 | lr (float, optional): learning rate (default: 1e-3) 26 | betas (Tuple[float, float], optional): coefficients used for computing 27 | running averages of gradient and its square (default: (0.9, 0.999)) 28 | eps (float, optional): term added to the denominator to improve 29 | numerical stability (default: 1e-8) 30 | weight_decay (float, optional): weight decay coefficient (default: 1e-2) 31 | """ 32 | 33 | def __init__( 34 | self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2 35 | ): 36 | if not 0.0 <= lr: 37 | raise ValueError("Invalid learning rate: {}".format(lr)) 38 | if not 0.0 <= eps: 39 | raise ValueError("Invalid epsilon value: {}".format(eps)) 40 | if not 0.0 <= betas[0] < 1.0: 41 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 42 | if not 0.0 <= betas[1] < 1.0: 43 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 44 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 45 | super(RAdamW, self).__init__(params, defaults) 46 | 47 | def step(self, closure=None): 48 | """Performs a single optimization step. 49 | 50 | Arguments: 51 | closure (callable, optional): A closure that reevaluates the model 52 | and returns the loss. 
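Returns:
    The loss computed by ``closure`` if one is given, otherwise ``None``.

Example (illustrative; assumes ``model``, ``criterion``, ``x`` and ``y`` exist):
    >>> optimizer = RAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    >>> criterion(model(x), y).backward()
    >>> optimizer.step()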
53 | """ 54 | loss = None 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | for p in group["params"]: 60 | if p.grad is None: 61 | continue 62 | 63 | # Perform optimization step 64 | grad = p.grad.data 65 | if grad.is_sparse: 66 | raise RuntimeError( 67 | "Adam does not support sparse gradients, please consider SparseAdam instead" 68 | ) 69 | 70 | state = self.state[p] 71 | 72 | # State initialization 73 | if len(state) == 0: 74 | state["step"] = 0 75 | # Exponential moving average of gradient values 76 | state["exp_avg"] = torch.zeros_like(p.data) 77 | # Exponential moving average of squared gradient values 78 | state["exp_avg_sq"] = torch.zeros_like(p.data) 79 | 80 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 81 | beta1, beta2 = group["betas"] 82 | eps = group["eps"] 83 | lr = group["lr"] 84 | if "rho_inf" not in group: 85 | group["rho_inf"] = 2 / (1 - beta2) - 1 86 | rho_inf = group["rho_inf"] 87 | 88 | state["step"] += 1 89 | t = state["step"] 90 | 91 | # Decay the first and second moment running average coefficient 92 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 93 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 94 | rho_t = rho_inf - ((2 * t * (beta2**t)) / (1 - beta2**t)) 95 | 96 | # Perform stepweight decay 97 | p.data.mul_(1 - lr * group["weight_decay"]) 98 | 99 | if rho_t >= 5: 100 | var = exp_avg_sq.sqrt().add_(eps) 101 | r = math.sqrt( 102 | (1 - beta2**t) 103 | * ((rho_t - 4) * (rho_t - 2) * rho_inf) 104 | / ((rho_inf - 4) * (rho_inf - 2) * rho_t) 105 | ) 106 | 107 | p.data.addcdiv_(exp_avg, var, value=-lr * r / (1 - beta1**t)) 108 | else: 109 | p.data.add_(exp_avg, alpha=-lr / (1 - beta1**t)) 110 | 111 | return loss 112 | -------------------------------------------------------------------------------- /relik/retriever/pytorch_modules/scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import LRScheduler 3 | 4 | 5 | class LinearSchedulerWithWarmup(LRScheduler): 6 | def __init__( 7 | self, 8 | optimizer: torch.optim.Optimizer, 9 | num_warmup_steps: int, 10 | num_training_steps: int, 11 | last_epoch: int = -1, 12 | verbose: bool = False, 13 | **kwargs, 14 | ): 15 | self.num_warmup_steps = num_warmup_steps 16 | self.num_training_steps = num_training_steps 17 | super().__init__(optimizer, last_epoch, verbose) 18 | 19 | def get_lr(self): 20 | def scheduler_fn(current_step): 21 | if current_step < self.num_warmup_steps: 22 | return current_step / max(1, self.num_warmup_steps) 23 | return max( 24 | 0.0, 25 | float(self.num_training_steps - current_step) 26 | / float(max(1, self.num_training_steps - self.num_warmup_steps)), 27 | ) 28 | 29 | return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs] 30 | 31 | 32 | class LinearScheduler(LRScheduler): 33 | def __init__( 34 | self, 35 | optimizer: torch.optim.Optimizer, 36 | num_training_steps: int, 37 | last_epoch: int = -1, 38 | verbose: bool = False, 39 | **kwargs, 40 | ): 41 | self.num_training_steps = num_training_steps 42 | super().__init__(optimizer, last_epoch, verbose) 43 | 44 | def get_lr(self): 45 | def scheduler_fn(current_step): 46 | # if current_step < self.num_warmup_steps: 47 | # return current_step / max(1, self.num_warmup_steps) 48 | return max( 49 | 0.0, 50 | float(self.num_training_steps - current_step) 51 | / float(max(1, self.num_training_steps)), 52 | ) 53 | 54 | return [base_lr * scheduler_fn(self.last_epoch) for base_lr 
in self.base_lrs] 55 | -------------------------------------------------------------------------------- /relik/retriever/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from relik.retriever.trainer.train import RetrieverTrainer 2 | -------------------------------------------------------------------------------- /relik/version.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | _MAJOR = "1" 4 | _MINOR = "0" 5 | # On main and in a nightly release the patch should be one ahead of the last 6 | # released build. 7 | _PATCH = "7" 8 | # This is mainly for nightly builds which have the suffix ".dev$DATE". See 9 | # https://semver.org/#is-v123-a-semantic-version for the semantics. 10 | _SUFFIX = os.environ.get("RELIK_VERSION_SUFFIX", "") 11 | 12 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) 13 | VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX) 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | #------- Core dependencies ------- 2 | --extra-index-url https://download.pytorch.org/whl/cu121 3 | torch==2.3.1 4 | 5 | transformers[sentencepiece]>=4.41,<4.42 6 | rich>=13.0.0,<14.0.0 7 | scikit-learn>=1.5,<1.6 8 | overrides>=7.4,<7.9 9 | art==6.2 10 | pprintpp==0.4.0 11 | colorama==0.4.6 12 | termcolor==2.4.0 13 | spacy>=3.7,<3.8 14 | typer>=0.12,<0.13 15 | 16 | #------- Optional dependencies ------- 17 | 18 | # train 19 | lightning>=2.3,<2.4 20 | datasets>=2.13,<2.15 21 | hydra-core>=1.3,<1.4 22 | hydra_colorlog 23 | wandb>=0.15,<0.18 24 | 25 | # faiss 26 | faiss-cpu==1.8.0 # needed by: faiss 27 | 28 | # serve 29 | fastapi>=0.112,<0.113 # needed by: serve, ray 30 | uvicorn[standard]==0.23.2 # needed by: serve, ray 31 | gunicorn==22.0.0 # needed by: serve, ray 32 | streamlit>=1.28,<1.29 # needed by: serve, ray 33 | streamlit_extras>=0.3,<0.4 # needed by: serve, ray 34 | gradio>=4.37,<4.38 # needed by: serve, ray 35 | pyvis # needed by: serve, ray 36 | ray[serve]>=2.34,<=2.35 # needed by: ray 37 | 38 | # dev 39 | pre-commit # needed by: dev 40 | black[d] # needed by: dev 41 | isort # needed by: dev 42 | -------------------------------------------------------------------------------- /scripts/build_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | # get version from version.py file 5 | VERSION=$(python -c "import relik; print(relik.__version__)") 6 | LATEST_VERSION=$(echo "$VERSION" | tail -n 1) 7 | echo "Building version: $LATEST_VERSION" 8 | 9 | echo "==== Building CPU images ====" 10 | # docker build -f dockerfiles/ray/Dockerfile.cpu -t relik:$VERSION-cpu-ray . 11 | docker build -f dockerfiles/fastapi/Dockerfile.cpu -t relik:$LATEST_VERSION-cpu-fastapi . 12 | 13 | echo "==== Building GPU images ====" 14 | # docker build -f dockerfiles/ray/Dockerfile.cuda -t relik:$VERSION-cuda-ray . 15 | docker build -f dockerfiles/fastapi/Dockerfile.cuda -t relik:$LATEST_VERSION-cuda-fastapi .
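# e.g. smoke-test the freshly built CPU image (host port is illustrative;
# the container listens on $PORT, which defaults to 80 in gunicorn_conf.py):
#   docker run --rm -p 8000:80 relik:$LATEST_VERSION-cpu-fastapi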
16 | -------------------------------------------------------------------------------- /scripts/build_docker_with_weights.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # get version from version.py file 4 | VERSION=$(python -c "import relik; print(relik.__version__)") 5 | echo "Building version: $VERSION" 6 | 7 | # if relik-model exists, delete it 8 | REPO_MODEL_PATH="relik-model" 9 | if [ -d "$REPO_MODEL_PATH" ]; then 10 | echo "Deleting $REPO_MODEL_PATH" 11 | rm -r "$REPO_MODEL_PATH" 12 | fi 13 | 14 | # create relik-model directory and copy model files 15 | echo "Copying model files to $REPO_MODEL_PATH" 16 | mkdir -p "$REPO_MODEL_PATH" 17 | # copy model files 18 | cp -r "$MODEL_PATH"/* "$REPO_MODEL_PATH" 19 | 20 | docker build -f dockerfiles/ray/Dockerfile.cuda -t relik:$VERSION-bsc-cuda-ray . 21 | 22 | # clean up 23 | if [ -d "$REPO_MODEL_PATH" ]; then 24 | echo "Deleting $REPO_MODEL_PATH" 25 | rm -r "$REPO_MODEL_PATH" 26 | fi 27 | -------------------------------------------------------------------------------- /scripts/data/blink/preprocess_genre_blink.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | import re 5 | from typing import List, Tuple 6 | 7 | from tqdm import tqdm 8 | 9 | LABELS_REGEX = "[{][^}]+[}] \[[^]]+\]" 10 | REPLACE_PATTERNS = [ 11 | ("BULLET::::- ", ""), 12 | ("( )", ""), 13 | (" ", " "), 14 | (" ", " "), 15 | (" ", " "), 16 | ] 17 | LABELS_FILTERING_FUNCTIONS = [lambda x: x.startswith("List of")] 18 | SUBSTITUTION_PATTERN = "#$#" 19 | 20 | 21 | def process_annotation(ann_surface_text: str) -> Tuple[str, str]: 22 | mention, label = ann_surface_text.split("} [") 23 | mention = mention.replace("{", "").strip() 24 | label = label.replace("]", "").strip() 25 | return mention, label 26 | 27 | 28 | def substitute_annotations( 29 | annotations: List[str], sub_line: str 30 | ) -> Tuple[str, List[Tuple[int, int, str]]]: 31 | final_annotations_store = [] 32 | for annotation in annotations: 33 | mention, label = process_annotation(annotation) 34 | start_char = sub_line.index(SUBSTITUTION_PATTERN) 35 | end_char = start_char + len(mention) 36 | sub_line = sub_line.replace(SUBSTITUTION_PATTERN, mention, 1) 37 | assert sub_line[start_char:end_char] == mention 38 | if any([fl(label) for fl in LABELS_FILTERING_FUNCTIONS]): 39 | continue 40 | final_annotations_store.append((start_char, end_char, label)) 41 | return sub_line, final_annotations_store 42 | 43 | 44 | def preprocess_line(line: str) -> Tuple[str, List[Tuple[int, int, str]]]: 45 | for rps, rpe in REPLACE_PATTERNS: 46 | line = line.replace(rps, rpe) 47 | 48 | annotations = re.findall(LABELS_REGEX, line) 49 | sub_line = re.sub(LABELS_REGEX, SUBSTITUTION_PATTERN, line) 50 | return substitute_annotations(annotations, sub_line) 51 | 52 | 53 | def preprocess_genre_el_file( 54 | file_path: str, output_path: str, limit_lines: int = -1 55 | ) -> None: 56 | # Create output directory 57 | Path(output_path).parent.mkdir(parents=True, exist_ok=True) 58 | with open(file_path) as fi, open(output_path, "w") as fo: 59 | for i, line in tqdm(enumerate(fi)): 60 | text, annotations = preprocess_line(line.strip()) 61 | fo.write( 62 | json.dumps(dict(doc_id=i, doc_text=text, doc_span_annotations=annotations)) 63 | + "\n" 64 | ) 65 | if limit_lines == i: 66 | break 67 | 68 | 69 | def main(): 70 | 71 | arg_parser = argparse.ArgumentParser("Preprocess Genre BLINK file.") 72 | 
arg_parser.add_argument("input_file", type=str) 73 | arg_parser.add_argument("output_file", type=str) 74 | arg_parser.add_argument("--limit-lines", type=int, default=-1) 75 | args = arg_parser.parse_args() 76 | 77 | preprocess_genre_el_file(args.input_file, args.output_file, args.limit_lines) 78 | 79 | 80 | if __name__ == "__main__": 81 | main() 82 | -------------------------------------------------------------------------------- /scripts/data/retriever/create_index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | from pathlib import Path 5 | from typing import Optional, Union 6 | 7 | import torch 8 | 9 | from relik.retriever import GoldenRetriever 10 | from relik.common.utils import get_logger, get_callable_from_string 11 | from relik.retriever.indexers.document import DocumentStore 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | @torch.no_grad() 17 | def build_index( 18 | question_encoder_name_or_path: Union[str, os.PathLike], 19 | document_path: Union[str, os.PathLike], 20 | output_folder: Union[str, os.PathLike], 21 | document_file_type: str = "jsonl", 22 | passage_encoder_name_or_path: Optional[Union[str, os.PathLike]] = None, 23 | indexer_class: str = "relik.retriever.indexers.inmemory.InMemoryDocumentIndex", 24 | batch_size: int = 512, 25 | num_workers: int = 4, 26 | passage_max_length: int = 64, 27 | device: str = "cuda", 28 | index_device: str = "cpu", 29 | precision: str = "fp32", 30 | push_to_hub: bool = False, 31 | repo_id: Optional[str] = None, 32 | ): 33 | if push_to_hub: 34 | if not repo_id: 35 | raise ValueError("`repo_id` must be provided when `push_to_hub=True`") 36 | 37 | logger.info("Loading documents") 38 | if document_file_type == "jsonl": 39 | documents = DocumentStore.from_file(document_path) 40 | elif document_file_type == "csv": 41 | documents = DocumentStore.from_tsv( 42 | document_path, delimiter=",", quoting=csv.QUOTE_NONE, ingore_case=True 43 | ) 44 | elif document_file_type == "tsv": 45 | documents = DocumentStore.from_tsv( 46 | document_path, delimiter="\t", quoting=csv.QUOTE_NONE, ingore_case=True 47 | ) 48 | else: 49 | raise ValueError( 50 | f"Unknown document file type: {document_file_type}, must be one of jsonl, csv, tsv" 51 | ) 52 | 53 | logger.info("Loading document index") 54 | logger.info(f"Loaded {len(documents)} documents") 55 | indexer = get_callable_from_string(indexer_class)( 56 | documents, device=index_device, precision=precision 57 | ) 58 | 59 | retriever = GoldenRetriever( 60 | question_encoder=question_encoder_name_or_path, 61 | passage_encoder=passage_encoder_name_or_path, 62 | document_index=indexer, 63 | device=device, 64 | precision=precision, 65 | ) 66 | retriever.eval() 67 | 68 | retriever.index( 69 | batch_size=batch_size, 70 | num_workers=num_workers, 71 | max_length=passage_max_length, 72 | force_reindex=True, 73 | precision=precision, 74 | ) 75 | 76 | output_folder = Path(output_folder) 77 | output_folder.mkdir(exist_ok=True, parents=True) 78 | retriever.document_index.save_pretrained( 79 | output_folder, push_to_hub=push_to_hub, model_id=repo_id 80 | ) 81 | 82 | 83 | if __name__ == "__main__": 84 | arg_parser = argparse.ArgumentParser("Create retriever index.") 85 | arg_parser.add_argument("--question-encoder-name-or-path", type=str, required=True) 86 | arg_parser.add_argument("--document-path", type=str, required=True) 87 | arg_parser.add_argument("--passage-encoder-name-or-path", type=str) 88 | arg_parser.add_argument( 89 | 
"--indexer_class", 90 | type=str, 91 | default="relik.retriever.indexers.inmemory.InMemoryDocumentIndex", 92 | ) 93 | arg_parser.add_argument("--document-file-type", type=str, default="jsonl") 94 | arg_parser.add_argument("--output-folder", type=str, required=True) 95 | arg_parser.add_argument("--batch-size", type=int, default=128) 96 | arg_parser.add_argument("--passage-max-length", type=int, default=64) 97 | arg_parser.add_argument("--device", type=str, default="cuda") 98 | arg_parser.add_argument("--index-device", type=str, default="cpu") 99 | arg_parser.add_argument("--precision", type=str, default="fp32") 100 | arg_parser.add_argument("--num-workers", type=int, default=4) 101 | arg_parser.add_argument("--push-to-hub", action="store_true") 102 | arg_parser.add_argument("--repo-id", type=str) 103 | 104 | build_index(**vars(arg_parser.parse_args())) 105 | -------------------------------------------------------------------------------- /scripts/docker/gunicorn_conf.py: -------------------------------------------------------------------------------- 1 | import json 2 | import multiprocessing 3 | import os 4 | 5 | max_cores_str = os.getenv("MAX_CORES", "1") 6 | workers_per_core_str = os.getenv("WORKERS_PER_CORE", "1") 7 | max_workers_str = os.getenv("MAX_WORKERS", "1") 8 | use_max_workers = None 9 | if max_workers_str: 10 | use_max_workers = int(max_workers_str) 11 | web_concurrency_str = os.getenv("WEB_CONCURRENCY", None) 12 | 13 | host = os.getenv("HOST", "0.0.0.0") 14 | port = os.getenv("PORT", "80") 15 | bind_env = os.getenv("BIND", None) 16 | use_loglevel = os.getenv("LOG_LEVEL", "info") 17 | if bind_env: 18 | use_bind = bind_env 19 | else: 20 | use_bind = f"{host}:{port}" 21 | 22 | cores = int(max_cores_str) 23 | workers_per_core = float(workers_per_core_str) 24 | default_web_concurrency = workers_per_core * cores 25 | if web_concurrency_str: 26 | web_concurrency = int(web_concurrency_str) 27 | assert web_concurrency > 0 28 | else: 29 | web_concurrency = max(int(default_web_concurrency), 1) 30 | if use_max_workers: 31 | web_concurrency = min(web_concurrency, use_max_workers) 32 | accesslog_var = os.getenv("ACCESS_LOG", "-") 33 | use_accesslog = accesslog_var or None 34 | errorlog_var = os.getenv("ERROR_LOG", "-") 35 | use_errorlog = errorlog_var or None 36 | graceful_timeout_str = os.getenv("GRACEFUL_TIMEOUT", "500") 37 | timeout_str = os.getenv("TIMEOUT", "500") 38 | keepalive_str = os.getenv("KEEP_ALIVE", "5") 39 | 40 | # Gunicorn config variables 41 | loglevel = use_loglevel 42 | workers = web_concurrency 43 | bind = use_bind 44 | errorlog = use_errorlog 45 | worker_tmp_dir = "/dev/shm" 46 | accesslog = use_accesslog 47 | graceful_timeout = int(graceful_timeout_str) 48 | timeout = int(timeout_str) 49 | keepalive = int(keepalive_str) 50 | 51 | 52 | # For debugging and testing 53 | log_data = { 54 | "loglevel": loglevel, 55 | "workers": workers, 56 | "bind": bind, 57 | "graceful_timeout": graceful_timeout, 58 | "timeout": timeout, 59 | "keepalive": keepalive, 60 | "errorlog": errorlog, 61 | "accesslog": accesslog, 62 | # Additional, non-gunicorn variables 63 | "workers_per_core": workers_per_core, 64 | "use_max_workers": use_max_workers, 65 | "host": host, 66 | "port": port, 67 | } 68 | # pretty print the log data 69 | print(json.dumps(log_data, indent=2)) 70 | -------------------------------------------------------------------------------- /scripts/docker/pre-start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cat 
<<"EOF" 4 | 5 | # _/_/_/ _/_/_/_/ _/ _/_/_/ _/ _/ 6 | # _/ _/ _/ _/ _/ _/ _/ 7 | # _/_/_/ _/_/_/ _/ _/ _/_/ 8 | # _/ _/ _/ _/ _/ _/ _/ 9 | # _/ _/ _/_/_/_/ _/_/_/_/ _/_/_/ _/ _/ 10 | 11 | # ReLiK Inference API 12 | 13 | # EOF 14 | 15 | # pre-download the model if provided in input 16 | if [ "$1" ]; then 17 | # micromamba run -n base python -c "from relik import Relik; Relik.from_pretrained('$1')" 18 | python3 -c "from relik import Relik; Relik.from_pretrained('$1')" 19 | fi 20 | -------------------------------------------------------------------------------- /scripts/docker/start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # Pre-start 5 | # checkmark font for fancy log 6 | CHECK_MARK="\033[0;32m\xE2\x9C\x94\033[0m" 7 | # usage text 8 | USAGE="$(basename "$0") [-h --help] [-c --config] [-p --precision] [-d --device] [--retriever] [--retriever-device] 9 | [--retriever-precision] [--index-device] [--index-precision] [--reader] [--reader-device] [--reader-precision] 10 | 11 | where: 12 | -h --help Show this help text 13 | -c --config Config name (from HuggingFace) or path 14 | -p --precision Training precision, default '32'. 15 | -d --device Device to use, default 'cpu'. 16 | --retriever Override retriever model name. 17 | --retriever-device Override retriever device. 18 | --retriever-precision Override retriever precision. 19 | --index-device Override index device. 20 | --index-precision Override index precision. 21 | --reader Override reader model name. 22 | --reader-device Override reader device. 23 | --reader-precision Override reader precision. 24 | --annotation-type Annotation type ('char', 'word'), default 'char'. 25 | " 26 | 27 | # Transform long options to short ones 28 | for arg in "$@"; do 29 | shift 30 | case "$arg" in 31 | '--help') set -- "$@" '-h' ;; 32 | '--config') set -- "$@" '-c' ;; 33 | '--precision') set -- "$@" '-p' ;; 34 | '--device') set -- "$@" '-d' ;; 35 | '--retriever') set -- "$@" '-q' ;; 36 | '--retriever-device') set -- "$@" '-w' ;; 37 | '--retriever-precision') set -- "$@" '-e' ;; 38 | '--index-device') set -- "$@" '-r' ;; 39 | '--index-precision') set -- "$@" '-t' ;; 40 | '--reader') set -- "$@" '-y' ;; 41 | '--reader-device') set -- "$@" '-a' ;; 42 | '--reader-precision') set -- "$@" '-b' ;; 43 | '--annotation-type') set -- "$@" '-f' ;; 44 | *) set -- "$@" "$arg" ;; 45 | esac 46 | done 47 | 48 | # check for named params 49 | #while [ $OPTIND -le "$#" ]; do 50 | while getopts ":hc:p:d:q:w:e:r:t:y:a:b:f:" opt; do 51 | case $opt in 52 | h) 53 | printf "%s$USAGE" && exit 0 54 | ;; 55 | c) 56 | export RELIK_PRETRAINED=$OPTARG 57 | ;; 58 | p) 59 | export PRECISION=$OPTARG 60 | ;; 61 | d) 62 | export DEVICE=$OPTARG 63 | ;; 64 | q) 65 | export RETRIEVER_MODEL_NAME=$OPTARG 66 | ;; 67 | w) 68 | export RETRIEVER_DEVICE=$OPTARG 69 | ;; 70 | e) 71 | export RETRIEVER_PRECISION=$OPTARG 72 | ;; 73 | r) 74 | export INDEX_DEVICE=$OPTARG 75 | ;; 76 | t) 77 | export INDEX_PRECISION=$OPTARG 78 | ;; 79 | y) 80 | export READER_MODEL_NAME=$OPTARG 81 | ;; 82 | a) 83 | export READER_DEVICE=$OPTARG 84 | ;; 85 | b) 86 | export READER_PRECISION=$OPTARG 87 | ;; 88 | f) 89 | export ANNOTATION_TYPE=$OPTARG 90 | ;; 91 | \?) 
92 | echo "Invalid option -$OPTARG" >&2 && echo "$USAGE" && exit 0 93 | ;; 94 | esac 95 | done 96 | 97 | # FastAPI app location 98 | if [ -z "$APP_MODULE" ]; then 99 | # echo "APP_MODULE not set, using default" 100 | export APP_MODULE=relik.inference.serve.backend.ray:server 101 | fi 102 | # echo "APP_MODULE set to $APP_MODULE" 103 | 104 | # If there's a prestart.sh script in the /app directory, run it before starting 105 | if [ -z "$PRE_START_PATH" ]; then 106 | # echo "PRE_START_PATH not set, using default" 107 | PRE_START_PATH=scripts/docker/pre-start.sh 108 | fi 109 | # echo "PRE_START_PATH set to $PRE_START_PATH" 110 | 111 | if [ -f $PRE_START_PATH ]; then 112 | . "$PRE_START_PATH" $RELIK_PRETRAINED 113 | else 114 | echo "There is no script $PRE_START_PATH" 115 | fi 116 | 117 | # Start Ray Serve with the app 118 | exec serve run "$APP_MODULE" --host 0.0.0.0 --port 8000 119 | # micromamba run -n base serve run "$APP_MODULE" --host 0.0.0.0 --port 8000 120 | -------------------------------------------------------------------------------- /scripts/old-scripts/data/debug.py: -------------------------------------------------------------------------------- 1 | from relik.inference.data.splitters.window_based_splitter import WindowSentenceSplitter 2 | from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer 3 | from relik.inference.data.window.manager import WindowManager 4 | 5 | 6 | document = { 7 | "doc_id": "-DOCSTART- (956testa SOCCER)", 8 | "doc_text": "SOCCER - RESULTS OF SOUTH KOREAN PRO-SOCCER GAMES . SEOUL 1996-08-30 Results of South Korean pro-soccer games played on Thursday . Pohang 3 Ulsan 2 ( halftime 1-0 ) Puchon 2 Chonbuk 1 ( halftime 1-1 ) Standings after games played on Thursday ( tabulate under - won , drawn , lost , goals for , goals against , points ) : W D L G / F G / A P Puchon 3 1 0 6 1 10 Chonan 3 0 1 13 10 9 Pohang 2 1 1 11 10 7 Suwan 1 3 0 7 3 6 Ulsan 1 0 2 8 9 3 Anyang 0 3 1 6 9 3 Chonnam 0 2 1 4 5 2 Pusan 0 2 1 3 7 2 Chonbuk 0 0 3 3 7 0", 9 | "doc_annotations": [ 10 | [20, 32, "South Korea"], 11 | [52, 57, "Seoul"], 12 | [80, 92, "South Korea"], 13 | [131, 137, "Pohang Steelers"], 14 | [140, 145, "Ulsan Hyundai FC"], 15 | [165, 171, "--NME--"], 16 | [174, 181, "--NME--"], 17 | [341, 347, "--NME--"], 18 | [361, 367, "--NME--"], 19 | [382, 388, "Pohang Steelers"], 20 | [403, 408, "--NME--"], 21 | [421, 426, "Ulsan Hyundai FC"], 22 | [439, 445, "Anyang LG Cheetahs"], 23 | [458, 465, "--NME--"], 24 | [478, 483, "--NME--"], 25 | [496, 503, "--NME--"], 26 | ], 27 | } 28 | 29 | tokenizer = SpacyTokenizer(language="en") 30 | sentence_splitter = WindowSentenceSplitter(window_size=32, window_stride=16) 31 | 32 | window_manager = WindowManager(splitter=sentence_splitter, tokenizer=tokenizer) 33 | 34 | doc_info = document["doc_id"] 35 | doc_info = doc_info.replace("-DOCSTART-", "").replace("(", "").replace(")", "").strip() 36 | doc_id, doc_topic = doc_info.split(" ") 37 | 38 | if "testa" in doc_id: 39 | split = "dev" 40 | elif "testb" in doc_id: 41 | split = "test" 42 | else: 43 | split = "train" 44 | 45 | doc_id = doc_id.replace("testa", "").replace("testb", "").strip() 46 | doc_id = int(doc_id) 47 | 48 | windowized_document = window_manager.create_windows( 49 | document["doc_text"], 50 | 32, 51 | 16, 52 | doc_ids=doc_id, 53 | doc_topic=doc_topic, 54 | ) 55 | 56 | print(windowized_document) 57 | -------------------------------------------------------------------------------- /scripts/old-scripts/data/retriever/aida_to_dpr.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | from typing import Any, Dict, List, Optional, Union 5 | 6 | from tqdm import tqdm 7 | 8 | from relik.common.log import get_logger 9 | 10 | logger = get_logger() 11 | 12 | 13 | def aida_to_dpr( 14 | conll_path: Union[str, os.PathLike], 15 | output_path: Union[str, os.PathLike], 16 | documents_path: Optional[Union[str, os.PathLike]] = None, 17 | title_map: Optional[Union[str, os.PathLike]] = None, 18 | ) -> List[Dict[str, Any]]: 19 | documents = {} 20 | output_path = Path(output_path) 21 | output_path.parent.mkdir(parents=True, exist_ok=True) 22 | # read entities definitions 23 | logger.info(f"Loading documents from {documents_path}") 24 | with open(documents_path, "r") as f: 25 | for line in f: 26 | line_data = json.loads(line) 27 | title = line_data["text"].strip() 28 | definition = line_data["metadata"]["definition"].strip() 29 | documents[title] = definition 30 | 31 | if title_map is not None: 32 | with open(title_map, "r") as f: 33 | title_map = json.load(f) 34 | 35 | # store dpr data 36 | dpr = [] 37 | # lower case titles 38 | title_to_lower_map = {title.lower(): title for title in documents.keys()} 39 | # store missing entities 40 | missing = set() 41 | # Read AIDA file 42 | with open(conll_path, "r") as f, open(output_path, "w") as f_out: 43 | for line in tqdm(f, desc="Processing AIDA data"): 44 | sentence = json.loads(line) 45 | # for sentence in aida_data: 46 | question = sentence["text"] 47 | positive_pssgs = [] 48 | for idx, entity in enumerate(sentence["window_labels"]): 49 | entity = entity[2] 50 | if not entity: 51 | continue 52 | entity = entity.strip().lower().replace("_", " ") 53 | # if title_map and entity in title_to_lower_map: 54 | entity = title_to_lower_map.get(entity, entity) 55 | if entity in documents: 56 | def_text = documents[entity] 57 | positive_pssgs.append( 58 | { 59 | "title": title_to_lower_map[entity.lower()], 60 | "text": f"{title_to_lower_map[entity.lower()]} {def_text}", 61 | "passage_id": f"{sentence['doc_id']}_{sentence['offset']}_{idx}", 62 | } 63 | ) 64 | else: 65 | missing.add(entity) 66 | print(f"Entity {entity} not found in definitions") 67 | 68 | if len(positive_pssgs) == 0: 69 | continue 70 | 71 | dpr_sentence = { 72 | "id": f"{sentence['doc_id']}_{sentence['offset']}", 73 | "doc_topic": sentence["doc_topic"], 74 | "question": question, 75 | "answers": "", 76 | "positive_ctxs": positive_pssgs, 77 | "negative_ctxs": "", 78 | "hard_negative_ctxs": "", 79 | } 80 | f_out.write(json.dumps(dpr_sentence) + "\n") 81 | 82 | for e in missing: 83 | print(e) 84 | print(f"Number of missing entities: {len(missing)}") 85 | 86 | return dpr 87 | 88 | 89 | if __name__ == "__main__": 90 | import argparse 91 | 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("input", type=str, help="Path to AIDA file") 94 | parser.add_argument("output", type=str, help="Path to output file") 95 | parser.add_argument("documents", type=str, help="Path to entities definitions file") 96 | parser.add_argument("--title_map", type=str, help="Path to title map file") 97 | args = parser.parse_args() 98 | 99 | # Convert to DPR 100 | aida_to_dpr(args.input, args.output, args.documents, args.title_map) 101 | -------------------------------------------------------------------------------- /scripts/old-scripts/data/retriever/blink/create_random_sample_coverage.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from collections import defaultdict 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from relik.common.log import get_logger 12 | 13 | logger = get_logger() 14 | 15 | 16 | def sample( 17 | input_file: Union[str, os.PathLike], 18 | output_file: Union[str, os.PathLike], 19 | n_samples: int = 5, 20 | seed: int = 42, 21 | ): 22 | documents = defaultdict(list) 23 | 24 | logger.info(f"Loading data from {input_file}") 25 | with open(input_file) as f: 26 | for i, line in tqdm(enumerate(f)): 27 | try: 28 | sample = json.loads(line) 29 | # data.append(sample) 30 | labels = [l[-1] for l in sample["window_labels"]] 31 | for label in labels: 32 | documents[label].append(i) 33 | except json.JSONDecodeError: 34 | logger.error(f"Error parsing line {i}") 35 | continue 36 | 37 | logger.info("Sampling data") 38 | # Random sample from in-distribution documents 39 | np.random.seed(seed) 40 | documents = { 41 | k: np.random.choice(v, min(len(v), n_samples), replace=False).tolist() 42 | for k, v in tqdm(documents.items()) 43 | } 44 | 45 | output_file_path = Path(output_file) 46 | output_file_path.parent.mkdir(parents=True, exist_ok=True) 47 | logger.info(f"Saving sampled data to {output_file}") 48 | with open(output_file, "w") as f: 49 | json.dump(documents, f, indent=2) 50 | 51 | 52 | if __name__ == "__main__": 53 | arg_parser = argparse.ArgumentParser() 54 | arg_parser.add_argument("input_file", type=str) 55 | arg_parser.add_argument("output_file", type=str) 56 | arg_parser.add_argument("--n_samples", type=int, required=False, default=5) 57 | 58 | sample(**vars(arg_parser.parse_args())) 59 | -------------------------------------------------------------------------------- /scripts/old-scripts/data/retriever/blink/sample_from_data_coverate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from collections import defaultdict 5 | from pathlib import Path 6 | from typing import Union 7 | 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | from relik.common.log import get_logger 12 | 13 | logger = get_logger() 14 | 15 | 16 | def sample( 17 | input_file: Union[str, os.PathLike], 18 | output_file: Union[str, os.PathLike], 19 | sample_index_file: Union[str, os.PathLike], 20 | ): 21 | 22 | logger.info(f"Loading sample index from {sample_index_file}") 23 | with open(sample_index_file) as f: 24 | sample_index = json.load(f) 25 | # get all unique values 26 | sample_index = set( 27 | [item for sublist in sample_index.values() for item in sublist] 28 | ) 29 | 30 | output_file_path = Path(output_file) 31 | output_file_path.parent.mkdir(parents=True, exist_ok=True) 32 | logger.info(f"Loading data from {input_file} and sampling") 33 | logger.info(f"Saving sampled data to {output_file}") 34 | with open(input_file) as f, open(output_file, "w") as f_out: 35 | for i, line in tqdm(enumerate(f)): 36 | if int(i) in sample_index: 37 | try: 38 | f_out.write(json.dumps(json.loads(line)) + "\n") 39 | except json.JSONDecodeError: 40 | logger.error(f"Error parsing line {i}") 41 | continue 42 | 43 | 44 | if __name__ == "__main__": 45 | arg_parser = argparse.ArgumentParser() 46 | arg_parser.add_argument("input_file", type=str) 47 | arg_parser.add_argument("output_file", type=str) 48 | arg_parser.add_argument("sample_index_file", type=str) 49 | 
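# Hypothetical example invocation (file names are illustrative, not from the repo):
#   python sample_from_data_coverate.py windows.jsonl windows_sampled.jsonl sample_index.json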
50 | sample(**vars(arg_parser.parse_args())) 51 | -------------------------------------------------------------------------------- /scripts/old-scripts/data/retriever/create_index.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import os 4 | from pathlib import Path 5 | from typing import Optional, Union 6 | 7 | import torch 8 | 9 | from relik.retriever import GoldenRetriever 10 | from relik.common.utils import get_logger, get_callable_from_string 11 | from relik.retriever.indexers.document import DocumentStore 12 | 13 | logger = get_logger(__name__) 14 | 15 | 16 | @torch.no_grad() 17 | def build_index( 18 | question_encoder_name_or_path: Union[str, os.PathLike], 19 | document_path: Union[str, os.PathLike], 20 | output_folder: Union[str, os.PathLike], 21 | document_file_type: str = "jsonl", 22 | passage_encoder_name_or_path: Optional[Union[str, os.PathLike]] = None, 23 | indexer_class: str = "relik.retriever.indexers.inmemory.InMemoryDocumentIndex", 24 | batch_size: int = 512, 25 | num_workers: int = 4, 26 | passage_max_length: int = 64, 27 | device: str = "cuda", 28 | index_device: str = "cpu", 29 | precision: str = "fp32", 30 | ): 31 | logger.info("Loading documents") 32 | if document_file_type == "jsonl": 33 | documents = DocumentStore.from_file(document_path) 34 | elif document_file_type == "csv": 35 | documents = DocumentStore.from_tsv( 36 | document_path, delimiter=",", quoting=csv.QUOTE_NONE, ingore_case=True 37 | ) 38 | elif document_file_type == "tsv": 39 | documents = DocumentStore.from_tsv( 40 | document_path, delimiter="\t", quoting=csv.QUOTE_NONE, ingore_case=True 41 | ) 42 | else: 43 | raise ValueError( 44 | f"Unknown document file type: {document_file_type}, must be one of jsonl, csv, tsv" 45 | ) 46 | 47 | logger.info("Loading document index") 48 | logger.info(f"Loaded {len(documents)} documents") 49 | indexer = get_callable_from_string(indexer_class)( 50 | documents, device=index_device, precision=precision 51 | ) 52 | 53 | retriever = GoldenRetriever( 54 | question_encoder=question_encoder_name_or_path, 55 | passage_encoder=passage_encoder_name_or_path, 56 | document_index=indexer, 57 | device=device, 58 | precision=precision, 59 | ) 60 | retriever.eval() 61 | 62 | retriever.index( 63 | batch_size=batch_size, 64 | num_workers=num_workers, 65 | max_length=passage_max_length, 66 | force_reindex=True, 67 | precision=precision, 68 | ) 69 | 70 | output_folder = Path(output_folder) 71 | output_folder.mkdir(exist_ok=True, parents=True) 72 | retriever.save_pretrained(output_folder) 73 | 74 | 75 | if __name__ == "__main__": 76 | arg_parser = argparse.ArgumentParser() 77 | arg_parser.add_argument("--question_encoder_name_or_path", type=str, required=True) 78 | arg_parser.add_argument("--document_path", type=str, required=True) 79 | arg_parser.add_argument("--passage_encoder_name_or_path", type=str) 80 | arg_parser.add_argument( 81 | "--indexer_class", 82 | type=str, 83 | default="relik.retriever.indexers.inmemory.InMemoryDocumentIndex", 84 | ) 85 | arg_parser.add_argument("--document_file_type", type=str, default="jsonl") 86 | arg_parser.add_argument("--output_folder", type=str, required=True) 87 | arg_parser.add_argument("--batch_size", type=int, default=128) 88 | arg_parser.add_argument("--passage_max_length", type=int, default=64) 89 | arg_parser.add_argument("--device", type=str, default="cuda") 90 | arg_parser.add_argument("--index_device", type=str, default="cpu") 91 | 
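# assumed addition: build_index() above already accepts num_workers (default 4),
# so expose it on the CLI as the newer scripts/data/retriever/create_index.py does
arg_parser.add_argument("--num_workers", type=int, default=4)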
arg_parser.add_argument("--precision", type=str, default="fp32") 92 | 93 | build_index(**vars(arg_parser.parse_args())) 94 | -------------------------------------------------------------------------------- /scripts/old-scripts/data/retriever/explore_blink.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | 5 | if __name__ == "__main__": 6 | # arg_parser = argparse.ArgumentParser() 7 | # arg_parser.add_argument("--question_encoder_name_or_path", type=str, required=True) 8 | # arg_parser.add_argument("--document_path", type=str, required=True) 9 | # arg_parser.add_argument("--passage_encoder_name_or_path", type=str) 10 | # arg_parser.add_argument( 11 | # "--indexer_class", 12 | # type=str, 13 | # default="relik.retriever.indexers.inmemory.InMemoryDocumentIndex", 14 | # ) 15 | # arg_parser.add_argument("--document_file_type", type=str, default="jsonl") 16 | # arg_parser.add_argument("--output_folder", type=str, required=True) 17 | # arg_parser.add_argument("--batch_size", type=int, default=128) 18 | # arg_parser.add_argument("--passage_max_length", type=int, default=64) 19 | # arg_parser.add_argument("--device", type=str, default="cuda") 20 | # arg_parser.add_argument("--index_device", type=str, default="cpu") 21 | # arg_parser.add_argument("--precision", type=str, default="fp32") 22 | 23 | # build_index(**vars(arg_parser.parse_args())) 24 | 25 | with open("/media/data/EL/blink/window_32_tokens/random_1M/dpr-like/first_1M.jsonl") as f: 26 | data = [json.loads(line) for line in f] 27 | 28 | -------------------------------------------------------------------------------- /scripts/old-scripts/data/retriever/save_retriever_from_checkpoint.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/scripts/old-scripts/data/retriever/save_retriever_from_checkpoint.py -------------------------------------------------------------------------------- /scripts/old-scripts/data/retriever/triplets_to_dpr.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | from typing import Union, Dict, List, Optional, Any 5 | 6 | from tqdm import tqdm 7 | 8 | # from transformers import AutoTokenizer, BertTokenizer 9 | 10 | 11 | def aida_to_dpr( 12 | conll_path: Union[str, os.PathLike], 13 | output_path: Union[str, os.PathLike], 14 | definitions_path: Optional[Union[str, os.PathLike]] = None, 15 | ) -> List[Dict[str, Any]]: 16 | definitions = {} 17 | output_path = Path(output_path) 18 | output_path.parent.mkdir(parents=True, exist_ok=True) 19 | # read entities definitions 20 | with open(definitions_path, "r") as f: 21 | for line in f: 22 | line_data = json.loads(line) 23 | title = line_data["text"].strip() 24 | definition = line_data["metadata"]["definition"].strip() 25 | definitions[title] = definition 26 | # title, definition = line.split(" ") 27 | # title = title.strip() 28 | # definition = definition.strip() 29 | # definitions[title] = definition 30 | 31 | dpr = [] 32 | 33 | title_to_lower_map = {title.lower(): title for title in definitions.keys()} 34 | 35 | missing = set() 36 | 37 | # Read AIDA file 38 | with open(conll_path, "r") as f, open(output_path, "w") as f_out: 39 | for line in tqdm(f): 40 | sentence = json.loads(line) 41 | # for sentence in aida_data: 42 | question = sentence["text"] 43 | positive_pssgs = [] 44 | for idx, 
triplet in enumerate(sentence["triplets"]): 45 | relation = triplet["relation"]["name"] 46 | if not relation: 47 | continue 48 | if relation in definitions: 49 | def_text = definitions[relation] 50 | positive_pssgs.append( 51 | { 52 | "title": title_to_lower_map[relation.lower()], 53 | "text": f"{title_to_lower_map[relation.lower()]} {def_text}", 54 | "passage_id": f"{sentence['doc_id']}_{sentence['offset']}_{idx}", 55 | } 56 | ) 57 | else: 58 | missing.add(relation) 59 | # print(f"Entity {entity} not found in definitions") 60 | 61 | if len(positive_pssgs) == 0: 62 | continue 63 | 64 | dpr_sentence = { 65 | "id": f"{sentence['doc_id']}_{sentence['offset']}", 66 | "doc_topic": sentence["doc_topic"] if "doc_topic" in sentence else "", 67 | "question": question, 68 | "answers": "", 69 | "positive_ctxs": positive_pssgs, 70 | "negative_ctxs": "", 71 | "hard_negative_ctxs": "", 72 | } 73 | f_out.write(json.dumps(dpr_sentence) + "\n") 74 | 75 | for e in missing: 76 | print(e) 77 | print(f"Number of missing entities: {len(missing)}") 78 | 79 | return dpr 80 | 81 | 82 | if __name__ == "__main__": 83 | import argparse 84 | 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("input", type=str, help="Path to AIDA file") 87 | parser.add_argument("output", type=str, help="Path to output file") 88 | parser.add_argument( 89 | "--definitions", type=str, help="Path to entities definitions file" 90 | ) 91 | args = parser.parse_args() 92 | 93 | # Convert to DPR 94 | aida_to_dpr(args.input, args.output, args.definitions) -------------------------------------------------------------------------------- /scripts/old-scripts/data/split_aida.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from tqdm import tqdm 4 | 5 | with open("data/processed/aida_from_edo.jsonl", "r") as f: 6 | aida = [json.loads(l) for l in f] 7 | 8 | aida_train, aida_dev, aida_test = [], [], [] 9 | 10 | for document in tqdm(aida): 11 | doc_info = document["doc_id"] 12 | 13 | # clean doc_info, e.g. 
"-DOCSTART- (1 EU)" 14 | doc_info = ( 15 | doc_info.replace("-DOCSTART-", "").replace("(", "").replace(")", "").strip() 16 | ) 17 | doc_id, doc_topic = doc_info.split(" ") 18 | 19 | if "testa" in doc_id: 20 | split = "dev" 21 | elif "testb" in doc_id: 22 | split = "test" 23 | else: 24 | split = "train" 25 | 26 | doc_id = doc_id.replace("testa", "").replace("testb", "").strip() 27 | doc_id = int(doc_id) 28 | 29 | document["doc_id"] = doc_id 30 | 31 | if split == "train": 32 | aida_train.append(document) 33 | elif split == "dev": 34 | aida_dev.append(document) 35 | else: 36 | aida_test.append(document) 37 | 38 | with open("data/processed/aida_train.jsonl", "w") as f: 39 | for document in aida_train: 40 | f.write(json.dumps(document) + "\n") 41 | 42 | with open("data/processed/aida_dev.jsonl", "w") as f: 43 | for document in aida_dev: 44 | f.write(json.dumps(document) + "\n") 45 | 46 | with open("data/processed/aida_test.jsonl", "w") as f: 47 | for document in aida_test: 48 | f.write(json.dumps(document) + "\n") 49 | -------------------------------------------------------------------------------- /scripts/old-scripts/evaluate/evaluate_re.py: -------------------------------------------------------------------------------- 1 | from relik import Relik 2 | from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction 3 | from relik.retriever import GoldenRetriever 4 | from relik.inference.data.objects import TaskType 5 | from relik.reader.utils.relation_matching_eval import StrongMatchingPerRelation, StrongMatching 6 | from relik.inference.data.objects import AnnotationType 7 | 8 | import json 9 | import argparse 10 | from relik.reader.data.relik_reader_sample import load_relik_reader_samples 11 | 12 | def evaluate(reader_path, question_encoder_path, document_index_path, input_file, output_file, max_triplets=8, use_predefined_spans=False): 13 | 14 | reader = RelikReaderForTripletExtraction(reader_path, 15 | dataset_kwargs={"use_nme": False, "max_triplets": max_triplets}) 16 | 17 | 18 | retriever = { 19 | "triplet": GoldenRetriever( 20 | question_encoder=question_encoder_path, 21 | document_index=document_index_path, 22 | ), 23 | } 24 | 25 | relik = Relik(reader=reader, retriever=retriever, top_k=max_triplets, task=TaskType.TRIPLET, device="cuda", window_size="sentence") 26 | 27 | samples = list(load_relik_reader_samples(input_file)) 28 | # add text field to samples from joining words if there is no text field 29 | for id, sample in enumerate(samples): 30 | # del sample._d["span_candidates"] 31 | if sample._d.get("text") is None: 32 | sample._d["text"] = " ".join(sample.words) 33 | sample.doc_id = id 34 | results = relik(windows=samples, num_workers=4, device="cuda", progress_bar=True, annotation_type=AnnotationType.WORD, return_also_windows=True, use_predefined_spans=use_predefined_spans, relation_threshold=0.5) 35 | windows = [] 36 | for sample in results: 37 | windows.extend(sample.windows) 38 | with open(output_file, "w") as f: 39 | for sample in windows: 40 | f.write(sample.to_jsons() + "\n") 41 | 42 | strong_matching_metric = StrongMatchingPerRelation() 43 | results = list(windows) 44 | for k, v in strong_matching_metric(results).items(): 45 | print(f"test_{k}", v) 46 | 47 | strong_matching_metric = StrongMatching() 48 | for k, v in strong_matching_metric(results).items(): 49 | print(f"test_{k}", v) 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("--reader_path", type=str, required=True) 54 | 
parser.add_argument("--question_encoder_path", type=str, required=True) 55 | parser.add_argument("--document_index_path", type=str, required=True) 56 | parser.add_argument("--input_file", type=str, required=True) 57 | parser.add_argument("--output_file", type=str, required=True) 58 | parser.add_argument("--max_triplets", type=int, default=8) 59 | parser.add_argument("--use_predefined_spans", action="store_true") 60 | args = parser.parse_args() 61 | evaluate(args.reader_path, args.question_encoder_path, args.document_index_path, args.input_file, args.output_file, args.max_triplets, args.use_predefined_spans) -------------------------------------------------------------------------------- /scripts/old-scripts/evaluate/evaluate_re_bio.py: -------------------------------------------------------------------------------- 1 | from relik import Relik 2 | from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction 3 | from relik.retriever import GoldenRetriever 4 | from relik.inference.data.objects import TaskType 5 | from relik.reader.utils.relation_matching_eval import StrongMatchingPerRelation, StrongMatching 6 | from relik.inference.data.objects import AnnotationType 7 | 8 | import json 9 | import pandas as pd 10 | import argparse 11 | import os 12 | from relik.reader.data.relik_reader_sample import load_relik_reader_samples 13 | 14 | def evaluate(reader_path, question_encoder_path, document_index_path, input_file, output_file, max_triplets=8, use_predefined_spans=False): 15 | # input_file path fileame bio-map.tsv 16 | mapping_classes = pd.read_csv(os.path.join(os.path.dirname(input_file), "bio-map.tsv"), sep="\t") 17 | mapping_dict = dict(zip(mapping_classes.label, mapping_classes.hierarchy)) 18 | 19 | reader = RelikReaderForTripletExtraction(reader_path, 20 | dataset_kwargs={"use_nme": False, "max_triplets": max_triplets}) 21 | 22 | 23 | retriever = { 24 | "triplet": GoldenRetriever( 25 | question_encoder=question_encoder_path, 26 | document_index=document_index_path, 27 | ), 28 | } 29 | 30 | relik = Relik(reader=reader, retriever=retriever, top_k=max_triplets, task=TaskType.TRIPLET, device="cuda", window_size="sentence") 31 | 32 | samples = list(load_relik_reader_samples(input_file)) 33 | # add text field to samples from joining words if there is no text field 34 | for id, sample in enumerate(samples): 35 | if sample._d.get("text") is None: 36 | sample._d["text"] = " ".join(sample.words) 37 | sample.doc_id = id 38 | results = relik(windows=samples, num_workers=4, device="cuda", progress_bar=True, annotation_type=AnnotationType.WORD, return_also_windows=True, use_predefined_spans=use_predefined_spans, relation_threshold=0.5) 39 | windows = [] 40 | for sample in results: 41 | windows.extend(sample.windows) 42 | with open(output_file, "w") as f: 43 | for sample in windows: 44 | f.write(sample.to_jsons() + "\n") 45 | 46 | for sample in windows: 47 | for triplet in sample.predicted_relations: 48 | triplet["relation"]["name"] = mapping_dict[triplet["relation"]["name"]] 49 | 50 | strong_matching_metric = StrongMatchingPerRelation() 51 | results = list(windows) 52 | for k, v in strong_matching_metric(results).items(): 53 | print(f"test_{k}", v) 54 | 55 | strong_matching_metric = StrongMatching() 56 | for k, v in strong_matching_metric(results).items(): 57 | print(f"test_{k}", v) 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--reader_path", type=str, required=True) 62 | parser.add_argument("--question_encoder_path", 
type=str, required=True) 63 | parser.add_argument("--document_index_path", type=str, required=True) 64 | parser.add_argument("--input_file", type=str, required=True) 65 | parser.add_argument("--output_file", type=str, required=True) 66 | parser.add_argument("--max_triplets", type=int, default=8) 67 | parser.add_argument("--use_predefined_spans", action="store_true") 68 | args = parser.parse_args() 69 | evaluate(args.reader_path, args.question_encoder_path, args.document_index_path, args.input_file, args.output_file, args.max_triplets, args.use_predefined_spans) -------------------------------------------------------------------------------- /scripts/old-scripts/predict/predict_aida.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pprintpp import pprint 3 | 4 | from relik.inference.annotator import Relik 5 | from relik.inference.data.objects import TaskType 6 | from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction 7 | from relik.retriever.pytorch_modules.model import GoldenRetriever 8 | 9 | 10 | def main(): 11 | # retriever = GoldenRetriever( 12 | # question_encoder="riccorl/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder", 13 | # document_index="riccorl/retriever-relik-entity-linking-aida-wikipedia-base-index", 14 | # device="cuda", 15 | # index_device="cpu", 16 | # precision=16, 17 | # index_precision=32, 18 | # ) 19 | # reader = RelikReaderForSpanExtraction( 20 | # "riccorl/reader-relik-entity-linking-aida-wikipedia-small" 21 | # ) 22 | 23 | # relik = Relik( 24 | # retriever=retriever, 25 | # reader=reader, 26 | # top_k=100, 27 | # window_size=32, 28 | # window_stride=16, 29 | # task=TaskType.SPAN, 30 | # ) 31 | # relik.save_pretrained( 32 | # "relik-entity-linking-aida-wikipedia-tiny", 33 | # save_weights=False, 34 | # push_to_hub=True, 35 | # # reader_model_id="reader-relik-entity-linking-aida-wikipedia-small", 36 | # # retriever_model_id="retriever-relik-entity-linking-aida-wikipedia-base", 37 | # ) 38 | 39 | reader = RelikReaderForSpanExtraction( 40 | "riccorl/relik-reader-deberta-base-retriever-relik-entity-linking-aida-wikipedia-large", 41 | device="cuda", 42 | precision="fp16", # , reader_device="cpu", reader_precision="fp32" 43 | dataset_kwargs={"use_nme": True}, 44 | ) 45 | 46 | relik = Relik(reader=reader) 47 | 48 | with open( 49 | "/home/ric/Projects/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/testa_windowed_candidates.jsonl" 50 | ) as f: 51 | data = [json.loads(line) for line in f] 52 | 53 | text = [data["text"] for data in data] 54 | candidates = [data["span_candidates"] for data in data] 55 | 56 | predictions = relik( 57 | text, 58 | candidates=candidates, 59 | window_size="none", 60 | annotation_type="char", 61 | progress_bar=True, 62 | reader_batch_size=32, 63 | ) 64 | 65 | output = [] 66 | 67 | for p, s in zip(predictions, data): 68 | output.append( 69 | { 70 | "doc_id": s["doc_id"], 71 | "window_id": s["window_id"], 72 | "text": s["text"], 73 | "window_labels": [ 74 | [span[0] - s["offset"], span[1] - s["offset"], span[2]] 75 | for span in s["window_labels"] 76 | ], 77 | "predictions": [[span.start, span.end, span.label] for span in p.spans], 78 | } 79 | ) 80 | 81 | with open( 82 | "/home/ric/Projects/relik-sapienzanlp/experiments/predictions/deberta-large/testa.jsonl", 83 | "w", 84 | ) as f: 85 | for line in output: 86 | f.write(json.dumps(line) + "\n") 87 | 88 | 89 | if __name__ == "__main__": 90 | main() 91 | 
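# note: the input JSONL and output paths above are machine-specific absolute paths;
# adapt them before re-running this prediction script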
-------------------------------------------------------------------------------- /scripts/old-scripts/retriever/test_aida.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from relik.common.log import get_logger 3 | from relik.retriever import GoldenRetriever 4 | from relik.retriever.data.datasets import AidaInBatchNegativesDataset 5 | from relik.retriever.indexers.document import DocumentStore 6 | from relik.retriever.indexers.inmemory import InMemoryDocumentIndex 7 | from relik.retriever.trainer import RetrieverTrainer 8 | 9 | logger = get_logger(__name__) 10 | 11 | if __name__ == "__main__": 12 | 13 | arg_parser = argparse.ArgumentParser() 14 | arg_parser.add_argument("encoder", type=str)  # positional args are required by default 15 | arg_parser.add_argument("index", type=str) 16 | args = arg_parser.parse_args() 17 | 18 | # instantiate retriever 19 | retriever = GoldenRetriever( 20 | question_encoder=args.encoder, 21 | document_index=args.index, 22 | device="cuda", 23 | ) 24 | 25 | # val_dataset = AidaInBatchNegativesDataset( 26 | # name="aida_val", 27 | # path="/root/relik-sapienzanlp/data/retriever/el/aida_32_tokens_topic_relik/val.jsonl", 28 | # tokenizer=retriever.question_tokenizer, 29 | # question_batch_size=64, 30 | # passage_batch_size=400, 31 | # max_passage_length=64, 32 | # use_topics=True, 33 | # ) 34 | test_dataset = AidaInBatchNegativesDataset( 35 | name="aida_test", 36 | path="/root/relik-sapienzanlp/data/retriever/el/aida_32_tokens_topic_relik/test.jsonl", 37 | tokenizer=retriever.question_tokenizer, 38 | question_batch_size=64, 39 | passage_batch_size=400, 40 | max_passage_length=64, 41 | use_topics=True, 42 | ) 43 | 44 | trainer = RetrieverTrainer( 45 | retriever=retriever, 46 | # train_dataset=train_dataset, 47 | # val_dataset=val_dataset, 48 | test_dataset=test_dataset, 49 | num_workers=4, 50 | max_steps=25_000, 51 | log_to_wandb=False, 52 | # wandb_online_mode=False, 53 | # wandb_project_name="relik-retriever-aida", 54 | # wandb_experiment_name="aida-e5-base-topics-from-blink-new-data", 55 | max_hard_negatives_to_mine=15, 56 | resume_from_checkpoint_path=None, # path to lightning checkpoint 57 | trainer_kwargs={"logger": False}, 58 | ) 59 | 60 | # trainer.train() 61 | trainer.test() 62 | -------------------------------------------------------------------------------- /scripts/old-scripts/retriever/train_aida.py: -------------------------------------------------------------------------------- 1 | from relik.common.log import get_logger 2 | from relik.retriever import GoldenRetriever 3 | from relik.retriever.data.datasets import AidaInBatchNegativesDataset 4 | from relik.retriever.indexers.document import DocumentStore 5 | from relik.retriever.indexers.inmemory import InMemoryDocumentIndex 6 | from relik.retriever.trainer import RetrieverTrainer 7 | 8 | logger = get_logger(__name__) 9 | 10 | if __name__ == "__main__": 11 | # instantiate retriever 12 | retriever = GoldenRetriever( 13 | # question_encoder="/root/golden-retriever/wandb/blink-first1M-e5-base-topics/files/retriever/question_encoder", 14 | question_encoder="riccorl/golden-retriever-base-blink-before-hf", 15 | document_index=InMemoryDocumentIndex( 16 | documents=DocumentStore.from_file( 17 | "/root/relik-sapienzanlp/data/retriever/el/documents.jsonl" 18 | ), 19 | metadata_fields=["definition"], 20 | separator=" ", 21 | device="cuda", 22 | precision="16", 23 | ), 24 | ) 25 | 26 | train_dataset = AidaInBatchNegativesDataset( 27 | name="aida_train", 28 | # 
path="/root/golden-retriever/data/entitylinking/aida_32_tokens_topic/train.jsonl", 29 | path="/root/relik-sapienzanlp/data/retriever/el/aida_32_tokens_topic_relik/train.jsonl", 30 | tokenizer=retriever.question_tokenizer, 31 | question_batch_size=64, 32 | passage_batch_size=400, 33 | max_passage_length=64, 34 | shuffle=True, 35 | use_topics=True, 36 | ) 37 | val_dataset = AidaInBatchNegativesDataset( 38 | name="aida_val", 39 | path="/root/relik-sapienzanlp/data/retriever/el/aida_32_tokens_topic_relik/val.jsonl", 40 | tokenizer=retriever.question_tokenizer, 41 | question_batch_size=64, 42 | passage_batch_size=400, 43 | max_passage_length=64, 44 | use_topics=True, 45 | ) 46 | test_dataset = AidaInBatchNegativesDataset( 47 | name="aida_test", 48 | path="/root/relik-sapienzanlp/data/retriever/el/aida_32_tokens_topic_relik/test.jsonl", 49 | tokenizer=retriever.question_tokenizer, 50 | question_batch_size=64, 51 | passage_batch_size=400, 52 | max_passage_length=64, 53 | use_topics=True, 54 | ) 55 | 56 | trainer = RetrieverTrainer( 57 | retriever=retriever, 58 | train_dataset=train_dataset, 59 | val_dataset=val_dataset, 60 | test_dataset=test_dataset, 61 | num_workers=4, 62 | max_steps=25_000, 63 | wandb_online_mode=True, 64 | wandb_project_name="relik-retriever-aida", 65 | wandb_experiment_name="aida-e5-base-topics-from-blink-new-data", 66 | max_hard_negatives_to_mine=15, 67 | resume_from_checkpoint_path=None, # path to lightning checkpoint 68 | ) 69 | 70 | trainer.train() 71 | trainer.test() 72 | -------------------------------------------------------------------------------- /scripts/old-scripts/retriever/train_blink.py: -------------------------------------------------------------------------------- 1 | from relik.common.log import get_logger 2 | from relik.retriever import GoldenRetriever 3 | from relik.retriever.data.datasets import ( 4 | AidaInBatchNegativesDataset, 5 | SubsampleStrategyEnum, 6 | ) 7 | from relik.retriever.indexers.document import DocumentStore 8 | from relik.retriever.indexers.inmemory import InMemoryDocumentIndex 9 | from relik.retriever.trainer import Trainer 10 | 11 | logger = get_logger(__name__) 12 | 13 | if __name__ == "__main__": 14 | # instantiate retriever 15 | retriever = GoldenRetriever(question_encoder="intfloat/e5-base-v2") 16 | 17 | train_dataset = AidaInBatchNegativesDataset( 18 | name="aida_train", 19 | path="/media/data/EL/blink/window_32_tokens/random_1M/dpr-like/first_1M.jsonl", 20 | tokenizer=retriever.question_tokenizer, 21 | question_batch_size=64, 22 | passage_batch_size=400, 23 | max_passage_length=64, 24 | shuffle=True, 25 | subsample_strategy=SubsampleStrategyEnum.RANDOM, 26 | # use_topics=True, 27 | ) 28 | val_dataset = AidaInBatchNegativesDataset( 29 | name="aida_val", 30 | path="/media/data/EL/blink/window_32_tokens/random_1M/dpr-like/val.jsonl", 31 | tokenizer=retriever.question_tokenizer, 32 | question_batch_size=64, 33 | passage_batch_size=400, 34 | max_passage_length=64, 35 | # use_topics=True, 36 | ) 37 | # test_dataset = AidaInBatchNegativesDataset( 38 | # name="aida_test", 39 | # path="/root/golden-retriever/data/entitylinking/aida_32_tokens_topic/test.jsonl", 40 | # tokenizer=retriever.question_tokenizer, 41 | # question_batch_size=64, 42 | # passage_batch_size=400, 43 | # max_passage_length=64, 44 | # use_topics=True, 45 | # ) 46 | 47 | logger.info("Loading document index") 48 | document_index = InMemoryDocumentIndex( 49 | documents=DocumentStore.from_file( 50 | "/root/golden-retriever/data/entitylinking/documents.jsonl" 51 | ), 52 | 
metadata_fields=["definition"], 53 | separator=" ", 54 | device="cuda", 55 | precision="16", 56 | ) 57 | retriever.document_index = document_index 58 | 59 | trainer = Trainer( 60 | retriever=retriever, 61 | train_dataset=train_dataset, 62 | val_dataset=val_dataset, 63 | test_dataset=None, 64 | num_workers=0, 65 | max_steps=400_000, 66 | wandb_online_mode=True, 67 | wandb_project_name="relik-retriever-blink", 68 | wandb_experiment_name="blink-first1M-e5-base-topics-recheck", 69 | max_hard_negatives_to_mine=15, 70 | mine_hard_negatives_with_probability=0.2, 71 | save_last=True, 72 | resume_from_checkpoint_path=None, # path to lightning checkpoint 73 | ) 74 | 75 | trainer.train() 76 | trainer.test() 77 | -------------------------------------------------------------------------------- /scripts/setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # setup conda 4 | CONDA_BASE=$(conda info --base) 5 | # check if conda is installed 6 | if [ -z "$CONDA_BASE" ]; then 7 | echo "Conda is not installed. Please install conda first." 8 | exit 1 9 | fi 10 | source "$CONDA_BASE"/etc/profile.d/conda.sh 11 | 12 | # create conda env 13 | read -rp "Enter environment name or prefix: " ENV_NAME 14 | read -rp "Enter python version (default 3.10): " PYTHON_VERSION 15 | if [ -z "$PYTHON_VERSION" ]; then 16 | PYTHON_VERSION="3.10" 17 | fi 18 | 19 | # check if ENV_NAME is a full path 20 | if [[ "$ENV_NAME" == /* ]]; then 21 | CONDA_NEW_ARG="--prefix" 22 | else 23 | CONDA_NEW_ARG="--name" 24 | fi 25 | 26 | conda create -y "$CONDA_NEW_ARG" "$ENV_NAME" python="$PYTHON_VERSION" 27 | conda activate "$ENV_NAME" 28 | 29 | # replace placeholder env with $ENV_NAME in scripts/train.sh 30 | # NEW_CONDA_LINE="source \$CONDA_BASE/bin/activate $ENV_NAME" 31 | # sed -i.bak -e "s,.*bin/activate.*,$NEW_CONDA_LINE,g" scripts/train.sh 32 | 33 | # install torch 34 | read -rp "Enter cuda version (e.g. '11.8', default no cuda support): " CUDA_VERSION 35 | read -rp "Enter PyTorch version (e.g. '2.1', default latest): " PYTORCH_VERSION 36 | if [ -n "$PYTORCH_VERSION" ]; then 37 | PYTORCH_VERSION="=$PYTORCH_VERSION" 38 | fi 39 | if [ -z "$CUDA_VERSION" ]; then 40 | conda install -y pytorch"$PYTORCH_VERSION" cpuonly -c pytorch 41 | else 42 | conda install -y pytorch"$PYTORCH_VERSION" pytorch-cuda="$CUDA_VERSION" -c pytorch -c nvidia 43 | fi 44 | 45 | # install python requirements 46 | pip install -e .[all] 47 | --------------------------------------------------------------------------------