├── .dockerignore
├── .flake8
├── .github
└── workflows
│ ├── black.yaml
│ └── python-publish-pypi.yml
├── .gitignore
├── .pre-commit-config.yaml
├── MANIFEST.in
├── README.md
├── SETUP.cfg
├── Sapienza_Babelscape.png
├── constraints.cpu.txt
├── dockerfiles
├── fastapi
│ ├── Dockerfile.cpu
│ └── Dockerfile.cuda
└── ray
│ ├── Dockerfile.cpu
│ └── Dockerfile.cuda
├── examples
├── cie.py
└── langchain.py
├── pyproject.toml
├── relik.png
├── relik
├── __init__.py
├── cli
│ ├── __init__.py
│ ├── cli.py
│ ├── data.py
│ ├── reader.py
│ ├── retriever.py
│ └── utils.py
├── common
│ ├── __init__.py
│ ├── log.py
│ ├── torch_utils.py
│ ├── upload.py
│ └── utils.py
├── inference
│ ├── __init__.py
│ ├── annotator.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── objects.py
│ │ ├── splitters
│ │ │ ├── __init__.py
│ │ │ ├── base_sentence_splitter.py
│ │ │ ├── blank_sentence_splitter.py
│ │ │ ├── spacy_sentence_splitter.py
│ │ │ └── window_based_splitter.py
│ │ ├── tokenizers
│ │ │ ├── __init__.py
│ │ │ ├── base_tokenizer.py
│ │ │ └── spacy_tokenizer.py
│ │ └── window
│ │ │ ├── __init__.py
│ │ │ └── manager.py
│ ├── serve
│ │ ├── __init__.py
│ │ ├── backend
│ │ │ ├── __init__.py
│ │ │ ├── fastapi_be.py
│ │ │ ├── ray.py
│ │ │ └── utils.py
│ │ └── frontend
│ │ │ ├── __init__.py
│ │ │ ├── gradio_fe.py
│ │ │ ├── relik_front.py
│ │ │ ├── relik_re_front.py
│ │ │ ├── style.css
│ │ │ └── utils.py
│ └── utils.py
├── reader
│ ├── __init__.py
│ ├── conf
│ │ ├── base.yaml
│ │ ├── base_nyt.yaml
│ │ ├── cie.yaml
│ │ ├── config.yaml
│ │ ├── data
│ │ │ ├── base.yaml
│ │ │ ├── cie.yaml
│ │ │ ├── large.yaml
│ │ │ └── nyt.yaml
│ │ ├── large.yaml
│ │ ├── large_nyt.yaml
│ │ ├── model
│ │ │ ├── base.yaml
│ │ │ ├── cie.yaml
│ │ │ ├── large.yaml
│ │ │ ├── nyt.yaml
│ │ │ ├── nyt_base.yaml
│ │ │ ├── nyt_large.yaml
│ │ │ ├── nyt_small.yaml
│ │ │ └── small.yaml
│ │ ├── small.yaml
│ │ ├── small_nyt.yaml
│ │ └── training
│ │ │ ├── base.yaml
│ │ │ ├── cie.yaml
│ │ │ ├── large.yaml
│ │ │ └── nyt.yaml
│ ├── data
│ │ ├── __init__.py
│ │ ├── patches.py
│ │ ├── relik_reader_data.py
│ │ ├── relik_reader_data_utils.py
│ │ ├── relik_reader_re_data.py
│ │ └── relik_reader_sample.py
│ ├── lightning_modules
│ │ ├── __init__.py
│ │ ├── relik_reader_pl_module.py
│ │ └── relik_reader_re_pl_module.py
│ ├── pytorch_modules
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── hf
│ │ │ ├── __init__.py
│ │ │ ├── configuration_relik.py
│ │ │ └── modeling_relik.py
│ │ ├── optim
│ │ │ ├── __init__.py
│ │ │ ├── adamw_with_warmup.py
│ │ │ └── layer_wise_lr_decay.py
│ │ ├── span.py
│ │ └── triplet.py
│ ├── trainer
│ │ ├── __init__.py
│ │ ├── predict.py
│ │ ├── predict_cie.py
│ │ ├── predict_re.py
│ │ ├── train.py
│ │ ├── train_cie.py
│ │ └── train_re.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── gerbil.py
│ │ ├── metrics.py
│ │ ├── relation_matching_eval.py
│ │ ├── relik_reader_predictor.py
│ │ ├── save_load_utilities.py
│ │ ├── shuffle_train_callback.py
│ │ ├── special_symbols.py
│ │ └── strong_matching_eval.py
├── retriever
│ ├── __init__.py
│ ├── callbacks
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── evaluation_callbacks.py
│ │ ├── prediction_callbacks.py
│ │ ├── training_callbacks.py
│ │ └── utils_callbacks.py
│ ├── common
│ │ ├── __init__.py
│ │ ├── model_inputs.py
│ │ └── sampler.py
│ ├── conf
│ │ ├── data
│ │ │ ├── aida_dataset.yaml
│ │ │ └── blink_dataset.yaml
│ │ ├── finetune_iterable_in_batch.yaml
│ │ ├── index
│ │ │ └── inmemory.yaml
│ │ ├── logging
│ │ │ └── wandb_logging.yaml
│ │ ├── loss
│ │ │ ├── nce_loss.yaml
│ │ │ └── nll_loss.yaml
│ │ ├── model
│ │ │ └── golden_retriever.yaml
│ │ ├── optimizer
│ │ │ ├── adamw.yaml
│ │ │ ├── radam.yaml
│ │ │ └── radamw.yaml
│ │ ├── pretrain_iterable_in_batch.yaml
│ │ └── scheduler
│ │ │ ├── linear_scheduler.yaml
│ │ │ ├── linear_scheduler_with_warmup.yaml
│ │ │ └── none.yaml
│ ├── data
│ │ ├── __init__.py
│ │ ├── base
│ │ │ ├── __init__.py
│ │ │ └── datasets.py
│ │ ├── datasets.py
│ │ ├── labels.py
│ │ └── utils.py
│ ├── indexers
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── document.py
│ │ ├── faissindex.py
│ │ └── inmemory.py
│ ├── lightning_modules
│ │ ├── __init__.py
│ │ ├── pl_data_modules.py
│ │ └── pl_modules.py
│ ├── pytorch_modules
│ │ ├── __init__.py
│ │ ├── hf.py
│ │ ├── loss.py
│ │ ├── model.py
│ │ ├── optim.py
│ │ └── scheduler.py
│ └── trainer
│ │ ├── __init__.py
│ │ └── train.py
└── version.py
├── requirements.txt
├── scripts
├── build_all.sh
├── build_docker_with_weights.sh
├── data
│ ├── blink
│ │ └── preprocess_genre_blink.py
│ ├── create_windows.py
│ ├── nyt
│ │ └── preprocess_nyt.py
│ └── retriever
│ │ ├── add_candidates.py
│ │ ├── convert_to_dpr.py
│ │ └── create_index.py
├── docker
│ ├── gunicorn_conf.py
│ ├── pre-start.sh
│ ├── start-gunic.sh
│ └── start.sh
├── old-scripts
│ ├── data
│ │ ├── add_candidates.py
│ │ ├── add_re_candidates.py
│ │ ├── create_windows.py
│ │ ├── create_windows_re.py
│ │ ├── debug.py
│ │ ├── reader
│ │ │ └── add_candidates_from_retriever.py
│ │ ├── retriever
│ │ │ ├── aida_to_dpr.py
│ │ │ ├── blink
│ │ │ │ ├── create_random_sample_coverage.py
│ │ │ │ ├── create_windows.py
│ │ │ │ └── sample_from_data_coverate.py
│ │ │ ├── create_index.py
│ │ │ ├── explore_blink.py
│ │ │ ├── save_retriever_from_checkpoint.py
│ │ │ └── triplets_to_dpr.py
│ │ └── split_aida.py
│ ├── evaluate
│ │ ├── evaluate_re.py
│ │ └── evaluate_re_bio.py
│ ├── predict
│ │ └── predict_aida.py
│ └── retriever
│ │ ├── test_aida.py
│ │ ├── train_aida.py
│ │ └── train_blink.py
└── setup.sh
└── setup.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | .git
2 | data
3 | benchmark
4 | resources
5 | outputs
6 | retrievers
7 | experiments
8 | build
9 | dist
10 | models
11 | outputs
12 | pretrained_configs
13 | .idea
14 | .vscode
15 | *.egg-info
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203, E266, E501, W503, F403, F401, E402, C901
3 | max-line-length = 88
4 | max-complexity = 18
5 | select = B,C,E,F,W,T4,B9
6 |
--------------------------------------------------------------------------------
/.github/workflows/black.yaml:
--------------------------------------------------------------------------------
1 | name: Check Code Quality
2 |
3 | on: pull_request
4 |
5 | jobs:
6 | lint:
7 | runs-on: ubuntu-latest
8 | steps:
9 | - uses: actions/checkout@v2
10 | - uses: psf/black@stable
11 | with:
12 | options: --check .
13 | - uses: actions/checkout@v2
14 | - uses: actions/setup-python@v2
15 | with:
16 | python-version: "3.10"
17 | - name: Run flake8
18 | uses: julianwachholz/flake8-action@v2
19 | with:
20 | checkName: "Python Lint"
21 | path: ./relik
22 | plugins: "pep8-naming==0.13.3 flake8-comprehensions==3.14.0"
23 | config: .flake8
24 | env:
25 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
--------------------------------------------------------------------------------
/.github/workflows/python-publish-pypi.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package to PyPi
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | jobs:
8 | publish:
9 | runs-on: ubuntu-latest
10 |
11 | steps:
12 | - uses: actions/checkout@v2
13 | - name: Set up Python
14 | uses: actions/setup-python@v2
15 | with:
16 | python-version: "3.x"
17 |
18 | - name: Install dependencies
19 | run: |
20 | python -m pip install --upgrade pip
21 | pip install build
22 |
23 | - name: Extract version
24 | run: echo "RELIK_VERSION=`python setup.py --version`" >> $GITHUB_ENV
25 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/ambv/black
3 | rev: '23.10.1'
4 | hooks:
5 | - id: black
6 | - repo: https://github.com/pycqa/flake8
7 | rev: '6.1.0'
8 | hooks:
9 | - id: flake8
10 |
11 | default_language_version:
12 | python: python3
13 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include requirements.txt
2 |
--------------------------------------------------------------------------------
/SETUP.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
3 |
4 | [build]
5 | build-base = /tmp/build
6 |
--------------------------------------------------------------------------------
/Sapienza_Babelscape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/Sapienza_Babelscape.png
--------------------------------------------------------------------------------
/constraints.cpu.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cpu
2 | torch==2.3.1
3 |
--------------------------------------------------------------------------------
/dockerfiles/fastapi/Dockerfile.cpu:
--------------------------------------------------------------------------------
1 | FROM python:3.11.9-slim-bullseye
2 |
3 | ARG DEBIAN_FRONTEND=noninteractive
4 |
5 | RUN adduser --disabled-password --gecos '' relik-user
6 | USER relik-user
7 | ENV PATH=${PATH}:/home/relik-user/.local/bin
8 |
9 | # Set the working directory
10 | COPY --chown=relik-user:relik-user . /home/relik-user/relik
11 | WORKDIR /home/relik-user/relik
12 |
13 | # mount huggingface cache dir
14 | RUN mkdir -p /home/relik-user/.cache/huggingface
15 | # ENV HF_HOME=/home/relik-user/.cache/huggingface
16 | # mount huggingface
17 |
18 | RUN pip install --upgrade --no-cache-dir .[serve] -c constraints.cpu.txt \
19 | && chmod +x scripts/docker/start-gunic.sh
20 |
21 | EXPOSE 8000 8001
22 |
23 | ENTRYPOINT ["scripts/docker/start-gunic.sh"]
24 |
--------------------------------------------------------------------------------
/dockerfiles/fastapi/Dockerfile.cuda:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.0.0-base-ubuntu22.04
2 |
3 | ARG DEBIAN_FRONTEND=noninteractive
4 |
5 | RUN adduser --disabled-password --gecos '' relik-user \
6 | && apt-get update \
7 | && apt-get install -y --no-install-recommends \
8 | curl wget python3.11 python3-distutils python3-pip \
9 | && rm -rf /var/lib/apt/lists/*
10 |
11 | USER relik-user
12 | ENV PATH=${PATH}:/home/relik-user/.local/bin
13 |
14 | # Set the working directory
15 | COPY --chown=relik-user:relik-user . /home/relik-user/relik
16 | WORKDIR /home/relik-user/relik
17 |
18 | RUN mkdir -p /home/relik-user/.cache/huggingface
19 |
20 |
21 | RUN pip install --upgrade --no-cache-dir .[serve] \
22 | && chmod +x scripts/docker/start-gunic.sh
23 |
24 |
25 | EXPOSE 8000 8001
26 |
27 | ENTRYPOINT ["scripts/docker/start-gunic.sh"]
28 |
--------------------------------------------------------------------------------
/dockerfiles/ray/Dockerfile.cpu:
--------------------------------------------------------------------------------
1 | FROM python:3.10.13-slim-bullseye
2 |
3 | ARG DEBIAN_FRONTEND=noninteractive
4 |
5 | RUN adduser --disabled-password --gecos '' relik-user
6 | USER relik-user
7 | ENV PATH=${PATH}:/home/relik-user/.local/bin
8 |
9 | # Set the working directory
10 | COPY --chown=relik-user:relik-user . /home/relik-user/relik
11 | WORKDIR /home/relik-user/relik
12 |
13 | RUN pip install --upgrade --no-cache-dir .[serve,ray] -c constraints.cpu.txt \
14 | && chmod +x scripts/docker/start.sh
15 |
16 | EXPOSE 8000
17 |
18 | ENTRYPOINT ["scripts/docker/start.sh"]
19 |
20 | # FROM mambaorg/micromamba:bullseye-slim
21 |
22 | # ARG DEBIAN_FRONTEND=noninteractive
23 | # ARG MAMBA_DOCKERFILE_ACTIVATE=1
24 |
25 | # # Set the working directory
26 | # COPY --chown=mambauser:mambauser . /home/mambauser/relik
27 | # WORKDIR /home/mambauser/relik
28 |
29 | # RUN micromamba install -y -n base python==3.10 pytorch==2.1.0 cpuonly -c pytorch -c conda-forge \
30 | # && which pip \
31 | # && pip install --upgrade --no-cache-dir .[serve,ray] \
32 | # && micromamba clean --all --yes \
33 | # && chmod +x scripts/docker/start.sh
34 |
35 | # EXPOSE 8000
36 |
37 | # ENTRYPOINT ["scripts/docker/start.sh"]
38 |
--------------------------------------------------------------------------------
/dockerfiles/ray/Dockerfile.cuda:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.0.0-base-ubuntu22.04
2 |
3 | ARG DEBIAN_FRONTEND=noninteractive
4 |
5 | RUN adduser --disabled-password --gecos '' relik-user \
6 | && apt-get update \
7 | && apt-get install -y --no-install-recommends \
8 | curl wget python3.10 python3-distutils python3-pip \
9 | && rm -rf /var/lib/apt/lists/*
10 |
11 | USER relik-user
12 | ENV PATH=${PATH}:/home/relik-user/.local/bin
13 | # Set the working directory
14 | COPY --chown=relik-user:relik-user . /home/relik-user/relik
15 | WORKDIR /home/relik-user/relik
16 |
17 | # ENVS
18 | # ENV PATH="/root/conda/bin:${PATH}"
19 |
20 | RUN pip install --upgrade --no-cache-dir .[serve,ray] \
21 | && chmod +x scripts/docker/start.sh
22 |
23 | EXPOSE 8000
24 |
25 | ENTRYPOINT ["scripts/docker/start.sh"]
26 |
--------------------------------------------------------------------------------
/examples/cie.py:
--------------------------------------------------------------------------------
1 | from relik import Relik
2 |
3 | relik = Relik.from_pretrained("relik-ie/relik-cie-small", device="cuda")
4 |
5 | text = """When Noah Lyles put his spike into Stade de France’s purple track for his first stride Sunday night of the Paris Olympics 100-meter final, he was already behind. In an event in which margin for error is slimmest, his reaction time to the starting gun was the slowest.
6 |
7 | Halfway through, Lyles, 27, of the U.S., was still in seventh place in an eight-man field, trying to chase down Jamaica’s Kishane Thompson, who owned not only this season’s fastest time but also the fastest time in the semifinal round contested earlier Sunday.
8 |
9 | By the final steps Lyles had caught up so much to Thompson, American Fred Kerley and South Africa’s Akani Simbine that he did something he rarely practices — dipping his shoulder at the finish.
10 |
11 | Even then, Lyles was unconvinced he had won the gold medal he had so boldly predicted, and so badly wanted, for three years. The scoreboard offered no indication of who had won gold, silver or bronze as it processed a photo finish, a sold-out, raucous stadium sharing in the uncertainty.
12 |
13 | “I think you got that one, big dog,” Lyles told Thompson.
14 |
15 | “I’m not even sure,” Thompson replied. “It was that close.”"""
16 |
17 | output = relik(text)
18 |
19 | # Entities
20 | print(output.spans)
21 | # Relations
22 | print(output.triplets)
23 |
--------------------------------------------------------------------------------
/examples/langchain.py:
--------------------------------------------------------------------------------
1 | # pip install langchain-core langchain-experimental
2 |
3 | from langchain_experimental.graph_transformers import RelikGraphTransformer
4 | from langchain_core.documents import Document
5 |
6 | relik = RelikGraphTransformer("relik-ie/relik-relation-extraction-small-wikipedia")
7 |
8 | text = """When Noah Lyles put his spike into Stade de France’s purple track for his first stride Sunday night of the Paris Olympics 100-meter final, he was already behind. In an event in which margin for error is slimmest, his reaction time to the starting gun was the slowest.
9 |
10 | Halfway through, Lyles, 27, of the U.S., was still in seventh place in an eight-man field, trying to chase down Jamaica’s Kishane Thompson, who owned not only this season’s fastest time but also the fastest time in the semifinal round contested earlier Sunday.
11 |
12 | By the final steps Lyles had caught up so much to Thompson, American Fred Kerley and South Africa’s Akani Simbine that he did something he rarely practices — dipping his shoulder at the finish.
13 |
14 | Even then, Lyles was unconvinced he had won the gold medal he had so boldly predicted, and so badly wanted, for three years. The scoreboard offered no indication of who had won gold, silver or bronze as it processed a photo finish, a sold-out, raucous stadium sharing in the uncertainty.
15 |
16 | “I think you got that one, big dog,” Lyles told Thompson.
17 |
18 | “I’m not even sure,” Thompson replied. “It was that close.”"""
19 |
20 | documents = [Document(page_content=text)]
21 | output = relik.convert_to_graph_documents(documents)
22 | # triplets
23 | print(output)
24 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | include = '\.pyi?$'
3 | exclude = '''
4 | /(
5 | \.git
6 | | \.hg
7 | | \.mypy_cache
8 | | \.tox
9 | | \.venv
10 | | _build
11 | | buck-out
12 | | build
13 | | dist
14 | )/
15 | '''
16 |
--------------------------------------------------------------------------------
/relik.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik.png
--------------------------------------------------------------------------------
/relik/__init__.py:
--------------------------------------------------------------------------------
1 | from relik.inference.annotator import Relik
2 | from pathlib import Path
3 |
# namespace that receives every name defined by version.py
VERSION = {}  # type: ignore
# execute version.py (the file sitting next to this one) inside that namespace
with open(Path(__file__).parent / "version.py", "r") as version_file:
    exec(version_file.read(), VERSION)

# version.py must define a module-level `VERSION`; it is re-exported as the package version
__version__ = VERSION["VERSION"]
9 |
--------------------------------------------------------------------------------
/relik/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/cli/__init__.py
--------------------------------------------------------------------------------
/relik/cli/reader.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import hydra
4 | import typer
5 | from omegaconf import OmegaConf
6 |
7 | from relik.cli.utils import resolve_config
8 | from relik.common.log import get_logger, print_relik_text_art
9 | from relik.reader.trainer.train import train as reader_train
10 | from relik.reader.trainer.train_cie import train as reader_train_cie
11 |
12 | logger = get_logger(__name__)
13 |
14 | app = typer.Typer(no_args_is_help=True, pretty_exceptions_show_locals=False)
15 |
16 |
@app.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
def train():
    """
    Trains the reader model.

    This function prints the Relik text art, resolves the configuration file path,
    and then calls the `_reader_train` function to train the reader model.

    Args:
        None

    Returns:
        None
    """
    # NOTE(review): the docstring above is presumably what Typer shows as the
    # command help, so its wording is kept as-is - confirm before rephrasing.
    print_relik_text_art()
    config_dir, config_name, overrides = resolve_config("reader")

    def _reader_train(conf):
        # forward the config composed by Hydra to the actual training routine
        reader_train(conf)

    # wrap the task function explicitly instead of using the decorator syntax
    hydra_entry_point = hydra.main(
        config_path=str(config_dir),
        config_name=str(config_name),
        version_base="1.3",
    )(_reader_train)

    # Hydra reads sys.argv on its own: keep only the program name,
    # followed by the overrides returned by resolve_config
    sys.argv = sys.argv[:1] + list(overrides)

    hydra_entry_point()
48 |
49 |
@app.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
def train_cie():
    """
    Trains the reader model with the CIE training routine.

    This function prints the Relik text art, resolves the configuration file path,
    and then calls the `_reader_train_cie` function, which runs
    `relik.reader.trainer.train_cie.train` on the resolved configuration.

    Args:
        None

    Returns:
        None
    """
    print_relik_text_art()
    # the CIE command shares the reader configuration directory
    config_dir, config_name, overrides = resolve_config("reader")

    @hydra.main(
        config_path=str(config_dir),
        config_name=str(config_name),
        version_base="1.3",
    )
    def _reader_train_cie(conf):
        reader_train_cie(conf)

    # clean sys.argv for hydra
    sys.argv = sys.argv[:1]
    # add the overrides to sys.argv
    sys.argv.extend(overrides)

    _reader_train_cie()
81 |
--------------------------------------------------------------------------------
/relik/cli/utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | import typer
5 | from hydra import compose, initialize, initialize_config_dir
6 | from omegaconf import OmegaConf
7 |
8 | from relik.common.log import get_logger
9 |
10 | logger = get_logger(__name__)
11 |
12 |
13 | def resolve_config(type: str | None = None) -> OmegaConf:
14 | """
15 | Resolve the config file and return the OmegaConf object.
16 |
17 | Args:
18 | config_path (`str`):
19 | The path to the config file.
20 |
21 | Returns:
22 | `OmegaConf`:
23 | The OmegaConf object.
24 | """
25 | # first arg is the entry point
26 | # second arg is the command
27 | # third arg is the subcommand
28 | # fourth arg is the config path/name
29 | # fifth arg is the overrides
30 | _, _, _, config_path, *overrides = sys.argv
31 | config_path = Path(config_path)
32 | # TODO: do checks
33 | # if not config_path.exists():
34 | # raise ValueError(f"File {config_path} does not exist!")
35 | # get path and name
36 | config_dir, config_name = config_path.parent, config_path.stem
37 | # logger.debug(f"config_path: {config_path}")
38 | # logger.debug(f"config_name: {config_name}")
39 | # check if config_path is absolute or relative
40 | # if config_path.is_absolute():
41 | # context = initialize_config_dir(config_dir=str(config_path), version_base="1.3")
42 | # else:
43 | if not config_dir.is_absolute():
44 | base_path = Path(__file__).parent.parent
45 | if type == "reader":
46 | config_dir = base_path / "reader" / "conf"
47 | elif type == "retriever":
48 | config_dir = base_path / "retriever" / "conf"
49 | else:
50 | raise ValueError(
51 | "Please provide the type (`reader` or `retriever`) or provide an absolute path."
52 | )
53 | logger.debug(f"config_dir: {config_dir}")
54 | # logger.debug(f"config_name: {config_name}")
55 |
56 | # print(OmegaConf.load(config_dir / f"{config_name}.yaml"))
57 |
58 | # with initialize_config_dir(config_dir=str(config_dir), version_base="1.3"):
59 | # cfg = compose(config_name=config_name, overrides=overrides)
60 |
61 | return config_dir, config_name, overrides
62 |
63 |
64 | def int_or_str_typer(value: str) -> int | None:
65 | """
66 | Converts a string value to an integer or None.
67 |
68 | Args:
69 | value (str): The string value to be converted.
70 |
71 | Returns:
72 | int | None: The converted integer value or None if the input is "None".
73 | """
74 | if value == "None":
75 | return None
76 | return int(value)
77 |
--------------------------------------------------------------------------------
/relik/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/common/__init__.py
--------------------------------------------------------------------------------
/relik/common/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import sys
4 | import threading
5 | from logging.config import dictConfig
6 | from typing import Any, Dict, Optional
7 |
8 | from art import text2art, tprint
9 | from colorama import Fore, Style, init
10 | from rich import get_console
11 | from termcolor import colored, cprint
12 |
13 |
# guards the lazy creation and removal of `_default_handler`
_lock = threading.Lock()
# handler attached to the library root logger; stays None until it is first configured
_default_handler: Optional[logging.Handler] = None

# level given to the library root logger when it is first configured
_default_log_level = logging.WARNING

# fancy logger: the console object returned by `rich.get_console()`
_console = get_console()
21 |
22 |
class ColorfulFormatter(logging.Formatter):
    """
    Log formatter that colors every formatted message according to the level
    name of its record.
    """

    # level name -> colorama prefix; levels that are not listed (e.g. INFO) get no prefix
    COLORS = {
        "WARNING": Fore.YELLOW,
        "ERROR": Fore.RED,
        "CRITICAL": Fore.RED + Style.BRIGHT,
        "DEBUG": Fore.CYAN,
    }

    def format(self, record):
        # make the LOCAL_RANK environment variable (0 when unset) available
        # to format strings through the `rank` attribute of the record
        record.rank = int(os.getenv("LOCAL_RANK", "0"))
        formatted = super().format(record)
        color_prefix = self.COLORS.get(record.levelname, "")
        # the color reset is always appended, even when no prefix was added
        return f"{color_prefix}{formatted}{Fore.RESET}"
40 |
41 |
# Default `logging.config.dictConfig` schema used by `configure_logging`.
DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
    "version": 1,
    "formatters": {
        # plain format, used by the root logger
        "simple": {
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
        },
        # colored format; %(rank)d is filled in by ColorfulFormatter.format
        "colorful": {
            "()": ColorfulFormatter,
            "format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
        },
    },
    "filters": {},
    "handlers": {
        # both handlers write to stdout and differ only in the formatter
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "simple",
            "filters": [],
            "stream": sys.stdout,
        },
        "color_console": {
            "class": "logging.StreamHandler",
            "formatter": "colorful",
            "filters": [],
            "stream": sys.stdout,
        },
    },
    # the root level is read from the LOG_LEVEL environment variable ("INFO" when unset)
    "root": {"handlers": ["console"], "level": os.getenv("LOG_LEVEL", "INFO")},
    "loggers": {
        # the "relik" logger uses the colored handler and does not propagate to the root logger
        "relik": {
            "handlers": ["color_console"],
            "level": "DEBUG",
            "propagate": False,
        },
    },
}
77 |
78 |
def configure_logging(**kwargs):
    """
    Apply the default logging configuration with `kwargs` merged on top.

    NOTE(review): `kwargs` are merged into the module-level
    `DEFAULT_LOGGING_CONFIG` dict itself (top-level keys only), so an override
    passed once also affects every later call - confirm this is intended.
    """
    init()  # Initialize colorama
    # updating with an empty mapping is a no-op, so no emptiness check is needed
    DEFAULT_LOGGING_CONFIG.update(kwargs)
    dictConfig(DEFAULT_LOGGING_CONFIG)
87 |
88 |
89 | def _get_library_name() -> str:
90 | return __name__.split(".")[0]
91 |
92 |
93 | def _get_library_root_logger() -> logging.Logger:
94 | return logging.getLogger(_get_library_name())
95 |
96 |
def _configure_library_root_logger() -> None:
    """
    Attach the default stream handler to the library root logger, once.

    Calls made after `_default_handler` has been set return without doing anything.
    """
    global _default_handler

    # NOTE(review): the nesting under the lock is reconstructed - confirm against the original file
    with _lock:
        if _default_handler:
            # This library has already configured the library root logger.
            return
        _default_handler = logging.StreamHandler()  # Set sys.stderr as stream.
        # flushing the handler flushes sys.stderr directly
        _default_handler.flush = sys.stderr.flush

        # Apply our default configuration to the library root logger.
        library_root_logger = _get_library_root_logger()
        library_root_logger.addHandler(_default_handler)
        library_root_logger.setLevel(_default_log_level)
        # keep library records away from the handlers of the root logger
        library_root_logger.propagate = False
112 |
113 |
def _reset_library_root_logger() -> None:
    """
    Undo `_configure_library_root_logger`: detach the default handler from the
    library root logger, set its level back to NOTSET and forget the handler.

    Does nothing when no default handler is currently set.
    """
    global _default_handler

    # NOTE(review): the nesting under the lock is reconstructed - confirm against the original file
    with _lock:
        if not _default_handler:
            return

        library_root_logger = _get_library_root_logger()
        library_root_logger.removeHandler(_default_handler)
        library_root_logger.setLevel(logging.NOTSET)
        # allow a later _configure_library_root_logger call to create a new handler
        _default_handler = None
125 |
126 |
def set_log_level(level: int, logger: logging.Logger = None) -> None:
    """
    Change the level of a logger.
    Args:
        level (:obj:`int`):
            Logging level to apply.
        logger (:obj:`logging.Logger`):
            Logger whose level is changed. When omitted, the library root
            logger is configured (if needed) and used instead.
    """
    if logger:
        logger.setLevel(level)
        return
    # no logger given: fall back to the library root logger
    _configure_library_root_logger()
    _get_library_root_logger().setLevel(level)
140 |
141 |
def get_logger(
    name: Optional[str] = None,
    level: Optional[int] = None,
    formatter: Optional[str] = None,
    **kwargs,
) -> logging.Logger:
    """
    Return a logger with the specified name.

    Every call re-applies the logging configuration through `configure_logging`
    and makes sure the library root logger has its default handler.

    Args:
        name (:obj:`str`, optional):
            Name of the logger. Defaults to the library name.
        level (:obj:`int`, optional):
            When given, it is applied to the library root logger through
            `set_log_level` (not to the returned logger).
        formatter (:obj:`str` or :obj:`logging.Formatter`, optional):
            Format string, or ready-made formatter, installed on the default
            handler. Defaults to
            ``"%(asctime)s - %(levelname)s - %(name)s - %(message)s"``.
        **kwargs:
            Forwarded to `configure_logging`.

    Returns:
        :obj:`logging.Logger`: The logger registered under `name`.
    """

    configure_logging(**kwargs)

    if name is None:
        name = _get_library_name()

    _configure_library_root_logger()

    if level is not None:
        set_log_level(level)

    if formatter is None:
        formatter = logging.Formatter(
            "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
        )
    elif isinstance(formatter, str):
        # the signature advertises a format string, but Handler.setFormatter
        # needs a Formatter instance: build one from the string
        formatter = logging.Formatter(formatter)
    _default_handler.setFormatter(formatter)

    return logging.getLogger(name)
169 |
170 |
def get_console_logger():
    """Return the module-level console object obtained from `rich.get_console()`."""
    return _console
173 |
174 |
def print_relik_text_art(
    text: str = "relik", font: str = "larry3d", color: str = "magenta", **kwargs
):
    """
    Print `text` as bold, colored ASCII art.

    Args:
        text (str): The text to render.
        font (str): Font name handed to `text2art`.
        color (str): Color name handed to `cprint`.
        **kwargs: Extra keyword arguments forwarded to `text2art`.
    """
    # render first, then print in bold with the requested color
    rendered_art = text2art(text, font=font, **kwargs)
    cprint(rendered_art, color, attrs=["bold"])
183 |
--------------------------------------------------------------------------------
/relik/common/torch_utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import tempfile
3 |
4 | import torch
5 | import transformers as tr
6 |
7 |
8 | def get_autocast_context(
9 | device: str | torch.device, precision: str
10 | ) -> contextlib.AbstractContextManager:
11 | # fucking autocast only wants pure strings like 'cpu' or 'cuda'
12 | # we need to convert the model device to that
13 | device_type_for_autocast = str(device).split(":")[0]
14 |
15 | from relik.retriever.pytorch_modules import PRECISION_MAP
16 |
17 | # autocast doesn't work with CPU and stuff different from bfloat16
18 | autocast_manager = (
19 | contextlib.nullcontext()
20 | if device_type_for_autocast in ["cpu", "mps"]
21 | and PRECISION_MAP[precision] != torch.bfloat16
22 | else (
23 | torch.autocast(
24 | device_type=device_type_for_autocast,
25 | dtype=PRECISION_MAP[precision],
26 | )
27 | )
28 | )
29 | return autocast_manager
30 |
--------------------------------------------------------------------------------
/relik/common/upload.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import tempfile
6 | import zipfile
7 | from datetime import datetime
8 | from pathlib import Path
9 | from typing import Optional, Union
10 |
11 | import huggingface_hub
12 |
13 | from relik.common.log import get_logger
14 | from relik.common.utils import SAPIENZANLP_DATE_FORMAT, get_md5
15 |
16 | logger = get_logger(__name__, level=logging.DEBUG)
17 |
18 |
def create_info_file(tmpdir: Path):
    """
    Write `info.json` next to `model.zip` inside `tmpdir`.

    The file stores the md5 of `model.zip` (as returned by `get_md5`) under
    `md5` and the current date, formatted with `SAPIENZANLP_DATE_FORMAT`,
    under `upload_date`.

    Args:
        tmpdir (Path): Directory that already contains `model.zip`.
    """
    logger.debug("Computing md5 of model.zip")
    info = dict(
        md5=get_md5(tmpdir / "model.zip"),
        upload_date=datetime.now().strftime(SAPIENZANLP_DATE_FORMAT),
    )

    logger.debug("Dumping info.json file")
    info_path = tmpdir / "info.json"
    with info_path.open("w") as info_file:
        json.dump(info, info_file, indent=2)
27 |
28 |
def zip_run(
    dir_path: Union[str, os.PathLike],
    tmpdir: Union[str, os.PathLike],
    zip_name: str = "model.zip",
) -> Path:
    """
    Zip every file found under `dir_path` into `tmpdir / zip_name`.

    Args:
        dir_path (`str | os.PathLike`): Directory to archive.
        tmpdir (`str | os.PathLike`): Directory in which the archive is created.
        zip_name (`str`): File name of the archive. Defaults to `"model.zip"`.

    Returns:
        `Path`: The path of the created archive.
    """
    logger.debug(f"zipping {dir_path} to {tmpdir}")
    # creates a zip version of the provided dir_path
    run_dir = Path(dir_path)
    # tmpdir may be a plain string: convert it before joining
    zip_path = Path(tmpdir) / zip_name

    with zipfile.ZipFile(zip_path, "w") as zip_file:
        # fully zip the run directory maintaining its structure;
        # "*" (rather than "*.*") also picks up files without a dot in their name
        for file in run_dir.rglob("*"):
            if file.is_dir():
                continue

            zip_file.write(file, arcname=file.relative_to(run_dir))

    return zip_path
48 |
49 |
def get_logged_in_username():
    """
    Return the `name` field reported by the HuggingFace Hub for the stored token.

    Raises:
        ValueError: If no HuggingFace token is stored locally.
    """
    hf_token = huggingface_hub.HfFolder.get_token()
    if hf_token is not None:
        # ask the Hub who owns the token and keep only the user name
        return huggingface_hub.HfApi().whoami(token=hf_token)["name"]
    raise ValueError(
        "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
    )
59 |
60 |
def upload(
    model_dir: Union[str, os.PathLike],
    model_name: str,
    filenames: Optional[list[str]] = None,
    organization: Optional[str] = None,
    repo_name: Optional[str] = None,
    commit: Optional[str] = None,
    archive: bool = False,
):
    """
    Push the content of `model_dir` to a HuggingFace Hub repository.

    The repository is created if needed and cloned into a temporary directory;
    the model files are then added to that clone and pushed.

    Args:
        model_dir (`str | os.PathLike`): Directory containing the model files.
        model_name (`str`): Name of the model; used as repository name when
            `repo_name` is not given.
        filenames (`list[str]`, optional): When given (and `archive` is False),
            only these entries of `model_dir` are copied. Each entry is
            expanded as a glob pattern relative to `model_dir`.
        organization (`str`, optional): Organization prefix of the repository id.
        repo_name (`str`, optional): Repository name overriding `model_name`.
        commit (`str`, optional): Commit message. Defaults to
            "Automatic push from sapienzanlp".
        archive (`bool`): When True, upload `model.zip` plus `info.json`
            instead of the raw files.

    Raises:
        ValueError: If no HuggingFace token is stored locally.
        FileNotFoundError: If an entry of `filenames` matches nothing in `model_dir`.
    """
    # local import: shutil is only needed here
    import shutil

    token = huggingface_hub.HfFolder.get_token()
    if token is None:
        raise ValueError(
            "No HuggingFace token found. You need to execute `huggingface-cli login` first!"
        )

    repo_id = repo_name or model_name
    if organization is not None:
        repo_id = f"{organization}/{repo_id}"
    with tempfile.TemporaryDirectory() as tmpdir:
        api = huggingface_hub.HfApi()
        repo_url = api.create_repo(
            token=token,
            repo_id=repo_id,
            exist_ok=True,
        )
        repo = huggingface_hub.Repository(
            str(tmpdir), clone_from=repo_url, use_auth_token=token
        )

        tmp_path = Path(tmpdir)
        source_dir = Path(model_dir)
        if archive:
            # archive mode: upload a zip of the model_dir together with its info file
            logger.debug(f"Zipping {model_dir} to {tmp_path}")
            zip_run(model_dir, tmp_path)
            create_info_file(tmp_path)
        else:
            # if the user wants to upload a transformers model, we don't need to zip it
            # we just need to copy the files to the tmpdir
            logger.debug(f"Copying {model_dir} to {tmpdir}")
            # Copies go through shutil instead of a shell `cp`, so paths with
            # spaces or shell metacharacters are handled and failures raise.
            if filenames is not None:
                # copy only the files that are needed
                for filename in filenames:
                    # keep shell-like pattern expansion (e.g. "*.json")
                    matches = sorted(source_dir.glob(filename))
                    if not matches:
                        raise FileNotFoundError(
                            f"`{filename}` does not match any file in `{model_dir}`"
                        )
                    for match in matches:
                        shutil.copy(match, tmp_path)
            else:
                for entry in source_dir.iterdir():
                    if entry.name.startswith("."):
                        # like the shell glob `*`, leave hidden top-level entries out
                        continue
                    if entry.is_dir():
                        shutil.copytree(
                            entry, tmp_path / entry.name, dirs_exist_ok=True
                        )
                    else:
                        shutil.copy(entry, tmp_path)

        # this method automatically puts large files (>10MB) into git lfs
        repo.push_to_hub(commit_message=commit or "Automatic push from sapienzanlp")
109 |
110 |
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """
    Parse the command line arguments of the upload script.

    The names of the parsed options match the keyword arguments of ``upload``, so the
    resulting namespace can be forwarded with ``upload(**vars(namespace))``.

    Args:
        argv (:obj:`list[str]`, optional):
            Arguments to parse. When omitted, ``sys.argv[1:]`` is used (standard
            ``argparse`` behaviour).

    Returns:
        :obj:`argparse.Namespace`: The parsed arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "model_dir", help="The directory of the model you want to upload"
    )
    parser.add_argument("model_name", help="The model you want to upload")
    parser.add_argument(
        "--filenames",
        nargs="+",
        default=None,
        help="Optional names of the files inside model_dir to upload. "
        "If omitted, the whole directory is uploaded",
    )
    parser.add_argument(
        "--organization",
        help="the name of the organization where you want to upload the model",
    )
    parser.add_argument(
        "--repo_name",
        help="Optional name to use when uploading to the HuggingFace repository",
    )
    parser.add_argument(
        "--commit", help="Commit message to use when pushing to the HuggingFace Hub"
    )
    parser.add_argument(
        "--archive",
        action="store_true",
        help="""
        Whether to compress the model directory before uploading it.
        If True, the model directory will be zipped and the zip file will be uploaded.
        If False, the model directory will be uploaded as is.""",
    )
    return parser.parse_args(argv)
137 |
138 |
def main():
    """Command line entry point: parse the arguments and forward them to ``upload``."""
    cli_arguments = parse_args()
    upload(**vars(cli_arguments))


if __name__ == "__main__":
    main()
145 |
--------------------------------------------------------------------------------
/relik/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/__init__.py
--------------------------------------------------------------------------------
/relik/inference/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/data/__init__.py
--------------------------------------------------------------------------------
/relik/inference/data/objects.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from dataclasses import dataclass
4 | from enum import Enum
5 | from typing import Dict, List, NamedTuple, Optional
6 |
7 | from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSample
8 | from relik.retriever.indexers.document import Document
9 |
10 |
@dataclass
class Word:
    """
    A word representation that includes text, index in the sentence, POS tag, lemma,
    dependency relation, and similar information.

    # Parameters
    text : `str`
        The text representation.
    i : `int`
        The word offset in the sentence.
    idx : `int`, optional
        NOTE(review): presumably the character offset at which the word starts
        (spaCy-style naming) - confirm against the tokenizer that fills it.
    idx_end : `int`, optional
        NOTE(review): presumably the character offset at which the word ends - confirm.
    lemma : `str`, optional
        The lemma of this word.
    pos : `str`, optional
        The coarse-grained part of speech of this word.
    dep : `str`, optional
        The dependency relation for this word.
    head : `int`, optional
        NOTE(review): presumably the index of the syntactic head of this word - confirm.

    `str(word)` and `repr(word)` both return `text`.
    """

    text: str
    i: int
    idx: Optional[int] = None
    idx_end: Optional[int] = None
    # preprocessing fields
    lemma: Optional[str] = None
    pos: Optional[str] = None
    dep: Optional[str] = None
    head: Optional[int] = None

    def __str__(self):
        return self.text

    def __repr__(self):
        return self.__str__()
53 |
54 |
class Span(NamedTuple):
    """A labelled span of the input text."""

    # boundaries of the span
    # NOTE(review): whether these are character or word offsets (see `AnnotationType`)
    # and whether `end` is exclusive is not visible from this class - confirm
    start: int
    end: int
    # label assigned to the span
    label: str
    # surface text covered by the span
    text: str
60 |
61 |
class Triplets(NamedTuple):
    """A relation between two spans: `subject` - `label` -> `object`."""

    # span acting as the subject of the relation
    subject: Span
    # label of the relation
    label: str
    # span acting as the object of the relation
    object: Span
    # score attached to the triplet
    # NOTE(review): range and meaning are defined by the producer - confirm
    confidence: float
67 |
68 |
class Candidates(NamedTuple):
    """Candidate documents for spans and for triplets."""

    # NOTE(review): `Dict[List[Document]]` is not a well-formed `Dict` annotation (it is
    # never evaluated because of `from __future__ import annotations`).
    # `RelikOutput.to_dict` iterates both fields as nested lists whose innermost items
    # expose `to_dict()` - confirm the intended type and fix the annotation.
    span: Dict[List[Document]]
    triplet: Dict[List[Document]]
72 |
73 |
@dataclass
class RelikOutput:
    """
    Represents the output of the Relik model.

    Attributes:
        text (str):
            The original input text.
        tokens (List[str]):
            The list of tokens generated from the input text.
        id (str | int):
            The identifier of the input. It is not included in `to_dict`.
        spans (List[Span]):
            The list of spans generated for the input text.
        triplets (List[Triplets]):
            The list of triplets generated for the input text.
        candidates (Candidates):
            The candidates for spans and triplets. The candidates are generated by the retriever.
            For each type of candidate, the documents are stored in a list of lists. The outer list
            represents the windows, and the inner list represents the documents in that window.
            If only one window is used, the outer list will have only one element.
        windows (Optional[List[RelikReaderSample]]):
            The list of windows used for processing the input text.
    """

    text: str
    tokens: List[str]
    id: str | int
    spans: List[Span]
    triplets: List[Triplets]
    candidates: Candidates = None
    windows: Optional[List[RelikReaderSample]] = None

    # convert to dict
    def to_dict(self):
        """
        Convert the output to a plain dictionary.

        The result has the keys `text`, `tokens`, `spans`, `triplets` and `candidates`
        (itself a dictionary with the keys `span` and `triplet`); `windows` is added
        only when windows are available. Documents and windows are converted through
        their own `to_dict` method.
        """
        # `candidates` defaults to None: export empty lists instead of failing
        span_candidates = [] if self.candidates is None else self.candidates.span
        triplet_candidates = [] if self.candidates is None else self.candidates.triplet

        self_dict = {
            "text": self.text,
            # tokens can be plain strings (as annotated) or objects exposing `.text`
            "tokens": [getattr(tok, "text", tok) for tok in self.tokens],
            "spans": self.spans,
            "triplets": self.triplets,
            "candidates": {
                "span": [
                    [[doc.to_dict() for doc in documents] for documents in window]
                    for window in span_candidates
                ],
                "triplet": [
                    [[doc.to_dict() for doc in documents] for documents in window]
                    for window in triplet_candidates
                ],
            },
        }
        if self.windows is not None:
            self_dict["windows"] = [window.to_dict() for window in self.windows]
        return self_dict
126 |
127 |
class AnnotationType(Enum):
    """
    Kind of annotation, identified by its string value (`"char"` or `"word"`).

    NOTE(review): presumably selects whether annotation offsets are expressed in
    characters or in words - confirm against the code that consumes it.
    """

    CHAR = "char"
    WORD = "word"
131 |
132 |
class TaskType(Enum):
    """
    Kind of task, identified by its string value (`"span"`, `"triplet"` or `"both"`).

    NOTE(review): presumably span extraction, triplet extraction, or the two together -
    confirm against the code that consumes it.
    """

    SPAN = "span"
    TRIPLET = "triplet"
    BOTH = "both"
137 |
--------------------------------------------------------------------------------
/relik/inference/data/splitters/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/data/splitters/__init__.py
--------------------------------------------------------------------------------
/relik/inference/data/splitters/base_sentence_splitter.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 |
class BaseSentenceSplitter:
    """
    A `BaseSentenceSplitter` splits strings into sentences.

    Subclasses implement :meth:`split_sentences`; calling an instance is a shortcut
    for that method.
    """

    def __call__(self, *args, **kwargs):
        """
        Calls :meth:`split_sentences`.
        """
        return self.split_sentences(*args, **kwargs)

    def split_sentences(
        self, text: str, max_len: int = 0, *args, **kwargs
    ) -> List[str]:
        """
        Splits a `text` :class:`str` paragraph into a list of :class:`str`, where each is a sentence.

        Raises:
            :obj:`NotImplementedError`: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def split_sentences_batch(
        self, texts: List[str], *args, **kwargs
    ) -> List[List[str]]:
        """
        Default implementation is to just iterate over the texts and call `split_sentences`.

        Extra positional and keyword arguments are forwarded unchanged to every
        `split_sentences` call.
        """
        return [self.split_sentences(text, *args, **kwargs) for text in texts]

    @staticmethod
    def check_is_batched(
        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
    ):
        """
        Check if input is batched or a single sample.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to check.
            is_split_into_words (:obj:`bool`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
        """
        # a bare string is never a batch
        if not isinstance(texts, (list, tuple)):
            return False
        # a list of raw texts is a batch
        if not is_split_into_words:
            return True
        # pre-tokenized input: a batch is a non-empty list of token lists
        return bool(texts and isinstance(texts[0], (list, tuple)))
56 |
--------------------------------------------------------------------------------
/relik/inference/data/splitters/blank_sentence_splitter.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 |
class BlankSentenceSplitter:
    """
    A `BlankSentenceSplitter` does not split at all: every input text is returned
    as a single sentence.
    """

    def __call__(self, *args, **kwargs):
        """
        Calls :meth:`split_sentences`.
        """
        return self.split_sentences(*args, **kwargs)

    def split_sentences(
        self, text: str, max_len: int = 0, *args, **kwargs
    ) -> List[str]:
        """
        Returns the `text` unchanged as the only element of a list. `max_len` and any
        extra argument are accepted for interface compatibility and ignored.
        """
        return [text]

    def split_sentences_batch(
        self, texts: List[str], *args, **kwargs
    ) -> List[List[str]]:
        """
        Default implementation is to just iterate over the texts and call `split_sentences`.

        Extra positional and keyword arguments are forwarded unchanged to every
        `split_sentences` call.
        """
        return [self.split_sentences(text, *args, **kwargs) for text in texts]
30 |
--------------------------------------------------------------------------------
/relik/inference/data/splitters/window_based_splitter.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | from relik.inference.data.splitters.base_sentence_splitter import BaseSentenceSplitter
4 |
5 |
class WindowSentenceSplitter(BaseSentenceSplitter):
    """
    A :obj:`WindowSentenceSplitter` that splits a text into windows of a given size.

    Args:
        window_size (:obj:`int`):
            Maximum number of tokens in a window.
        window_stride (:obj:`int`):
            Number of tokens between the start of two consecutive windows. ``0`` means
            that windows do not overlap (the stride equals ``window_size``).
    """

    def __init__(self, window_size: int, window_stride: int, *args, **kwargs) -> None:
        super(WindowSentenceSplitter, self).__init__()
        self.window_size = window_size
        self.window_stride = window_stride

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs,
    ) -> Union[List[str], List[List[str]]]:
        """
        Split the input into windows. Calls :meth:`split_sentences`.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to split. A string is split on whitespace, a list is used as the
                sequence of tokens.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                Accepted for interface compatibility; it is not used by this splitter.

        Returns:
            :obj:`List[List[str]]`: The input split into windows of tokens.
        """
        return self.split_sentences(texts, **kwargs)

    def split_sentences(
        self, text: Union[str, List], *args, **kwargs
    ) -> List[List]:
        """
        Splits a `text` into windows of at most `window_size` tokens.

        Args:
            text (:obj:`str`, :obj:`List`):
                Text to split. A string is split on whitespace, a list is used as is.

        Returns:
            :obj:`List[List]`: The windows, each one a list of tokens.
        """
        tokens = text.split() if isinstance(text, str) else text
        # if window_stride is zero, we don't need overlapping windows. The effective
        # stride is kept local so that the configured value is not overwritten
        stride = self.window_stride if self.window_stride != 0 else self.window_size
        total = len(tokens)

        windows = []
        for start in range(0, total, stride):
            # if the last stride is smaller than the window size, then we can
            # include more tokens from the previous window.
            if start != 0 and start + self.window_size > total:
                overflowing_tokens = start + self.window_size - total
                # the previous window already reached the end of the text
                if overflowing_tokens >= stride:
                    break
                start -= overflowing_tokens
            end = min(start + self.window_size, total)
            windows.append([tokens[j] for j in range(start, end)])
        return windows
67 |
--------------------------------------------------------------------------------
/relik/inference/data/tokenizers/__init__.py:
--------------------------------------------------------------------------------
# Maps a language code or a spaCy pipeline name to a spaCy pipeline name.
# Two-letter codes resolve to a small (`_sm`) pipeline; full pipeline names map to
# themselves, so both spellings can be looked up in the same table.
# NOTE(review): presumably used by `SpacyTokenizer` to pick the pipeline to load - confirm.
SPACY_LANGUAGE_MAPPER = {
    "ca": "ca_core_news_sm",
    "da": "da_core_news_sm",
    "de": "de_core_news_sm",
    "el": "el_core_news_sm",
    "en": "en_core_web_sm",
    "es": "es_core_news_sm",
    "fr": "fr_core_news_sm",
    "it": "it_core_news_sm",
    "ja": "ja_core_news_sm",
    "lt": "lt_core_news_sm",
    "mk": "mk_core_news_sm",
    "nb": "nb_core_news_sm",
    "nl": "nl_core_news_sm",
    "pl": "pl_core_news_sm",
    "pt": "pt_core_news_sm",
    "ro": "ro_core_news_sm",
    "ru": "ru_core_news_sm",
    "xx": "xx_sent_ud_sm",
    "zh": "zh_core_web_sm",
    "ca_core_news_sm": "ca_core_news_sm",
    "ca_core_news_md": "ca_core_news_md",
    "ca_core_news_lg": "ca_core_news_lg",
    "ca_core_news_trf": "ca_core_news_trf",
    "da_core_news_sm": "da_core_news_sm",
    "da_core_news_md": "da_core_news_md",
    "da_core_news_lg": "da_core_news_lg",
    "da_core_news_trf": "da_core_news_trf",
    "de_core_news_sm": "de_core_news_sm",
    "de_core_news_md": "de_core_news_md",
    "de_core_news_lg": "de_core_news_lg",
    "de_dep_news_trf": "de_dep_news_trf",
    "el_core_news_sm": "el_core_news_sm",
    "el_core_news_md": "el_core_news_md",
    "el_core_news_lg": "el_core_news_lg",
    "en_core_web_sm": "en_core_web_sm",
    "en_core_web_md": "en_core_web_md",
    "en_core_web_lg": "en_core_web_lg",
    "en_core_web_trf": "en_core_web_trf",
    "es_core_news_sm": "es_core_news_sm",
    "es_core_news_md": "es_core_news_md",
    "es_core_news_lg": "es_core_news_lg",
    "es_dep_news_trf": "es_dep_news_trf",
    "fr_core_news_sm": "fr_core_news_sm",
    "fr_core_news_md": "fr_core_news_md",
    "fr_core_news_lg": "fr_core_news_lg",
    "fr_dep_news_trf": "fr_dep_news_trf",
    "it_core_news_sm": "it_core_news_sm",
    "it_core_news_md": "it_core_news_md",
    "it_core_news_lg": "it_core_news_lg",
    "ja_core_news_sm": "ja_core_news_sm",
    "ja_core_news_md": "ja_core_news_md",
    "ja_core_news_lg": "ja_core_news_lg",
    "ja_dep_news_trf": "ja_dep_news_trf",
    "lt_core_news_sm": "lt_core_news_sm",
    "lt_core_news_md": "lt_core_news_md",
    "lt_core_news_lg": "lt_core_news_lg",
    "mk_core_news_sm": "mk_core_news_sm",
    "mk_core_news_md": "mk_core_news_md",
    "mk_core_news_lg": "mk_core_news_lg",
    "nb_core_news_sm": "nb_core_news_sm",
    "nb_core_news_md": "nb_core_news_md",
    "nb_core_news_lg": "nb_core_news_lg",
    "nl_core_news_sm": "nl_core_news_sm",
    "nl_core_news_md": "nl_core_news_md",
    "nl_core_news_lg": "nl_core_news_lg",
    "pl_core_news_sm": "pl_core_news_sm",
    "pl_core_news_md": "pl_core_news_md",
    "pl_core_news_lg": "pl_core_news_lg",
    "pt_core_news_sm": "pt_core_news_sm",
    "pt_core_news_md": "pt_core_news_md",
    "pt_core_news_lg": "pt_core_news_lg",
    "ro_core_news_sm": "ro_core_news_sm",
    "ro_core_news_md": "ro_core_news_md",
    "ro_core_news_lg": "ro_core_news_lg",
    "ru_core_news_sm": "ru_core_news_sm",
    "ru_core_news_md": "ru_core_news_md",
    "ru_core_news_lg": "ru_core_news_lg",
    "xx_ent_wiki_sm": "xx_ent_wiki_sm",
    "xx_sent_ud_sm": "xx_sent_ud_sm",
    "zh_core_web_sm": "zh_core_web_sm",
    "zh_core_web_md": "zh_core_web_md",
    "zh_core_web_lg": "zh_core_web_lg",
    "zh_core_web_trf": "zh_core_web_trf",
}
86 |
87 | from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
88 |
--------------------------------------------------------------------------------
/relik/inference/data/tokenizers/base_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union
2 |
3 | from relik.inference.data.objects import Word
4 |
5 |
class BaseTokenizer:
    """
    Interface of a tokenizer: an object that turns text into :obj:`Word` items.

    :meth:`__call__` and :meth:`tokenize` must be provided by subclasses;
    :meth:`tokenize_batch` and :meth:`check_is_batched` are ready to use.
    """

    def __call__(
        self,
        texts: Union[str, List[str], List[List[str]]],
        is_split_into_words: bool = False,
        **kwargs
    ) -> List[List[Word]]:
        """
        Tokenize the input into single words.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                A single string, a batch of strings, or pre-tokenized strings.
            is_split_into_words (:obj:`bool`, optional, defaults to :obj:`False`):
                If :obj:`True` and the input is a string, the input is split on spaces.

        Returns:
            :obj:`List[List[Word]]`: The tokenized input.

        Raises:
            :obj:`NotImplementedError`: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def tokenize(self, text: str) -> List[Word]:
        """
        Tokenize a single text.

        Args:
            text (:obj:`str`):
                Text to tokenize.

        Returns:
            :obj:`List[Word]`: The words of the text.

        Raises:
            :obj:`NotImplementedError`: Always; subclasses must override this method.
        """
        raise NotImplementedError

    def tokenize_batch(self, texts: List[str]) -> List[List[Word]]:
        """
        Tokenize several texts by calling :meth:`tokenize` on each of them, in order.

        Args:
            texts (:obj:`List[str]`):
                Batch of text to tokenize.

        Returns:
            :obj:`List[List[Word]]`: One list of words per input text.
        """
        tokenized = []
        for text in texts:
            tokenized.append(self.tokenize(text))
        return tokenized

    @staticmethod
    def check_is_batched(
        texts: Union[str, List[str], List[List[str]]], is_split_into_words: bool
    ):
        """
        Tell a batch apart from a single sample.

        Args:
            texts (:obj:`str`, :obj:`List[str]`, :obj:`List[List[str]]`):
                Text to check.
            is_split_into_words (:obj:`bool`):
                Whether the input is pre-tokenized.

        Returns:
            :obj:`bool`: ``True`` if ``texts`` is batched, ``False`` otherwise.
        """
        # a bare string is never a batch
        if not isinstance(texts, (list, tuple)):
            return False
        # a list of raw texts is a batch
        if not is_split_into_words:
            return True
        # pre-tokenized input: a batch is a non-empty list of token lists
        return bool(texts and isinstance(texts[0], (list, tuple)))
85 |
--------------------------------------------------------------------------------
/relik/inference/data/window/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/data/window/__init__.py
--------------------------------------------------------------------------------
/relik/inference/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/serve/__init__.py
--------------------------------------------------------------------------------
/relik/inference/serve/backend/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/serve/backend/__init__.py
--------------------------------------------------------------------------------
/relik/inference/serve/backend/utils.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import os
3 | from dataclasses import dataclass
4 |
5 |
6 | @dataclass
7 | class ServerParameterManager:
8 | relik_pretrained: str = os.environ.get("RELIK_PRETRAINED", None)
9 | device: str = os.environ.get("DEVICE", "cpu")
10 | retriever_device: str | None = os.environ.get("RETRIEVER_DEVICE", None)
11 | document_index_device: str | None = os.environ.get("INDEX_DEVICE", None)
12 | reader_device: str | None = os.environ.get("READER_DEVICE", None)
13 | precision: int | str | None = os.environ.get("PRECISION", "fp32")
14 | retriever_precision: int | str | None = os.environ.get("RETRIEVER_PRECISION", None)
15 | document_index_precision: int | str | None = os.environ.get("INDEX_PRECISION", None)
16 | reader_precision: int | str | None = os.environ.get("READER_PRECISION", None)
17 | annotation_type: str = os.environ.get("ANNOTATION_TYPE", "char")
18 | question_encoder: str = os.environ.get("QUESTION_ENCODER", None)
19 | passage_encoder: str = os.environ.get("PASSAGE_ENCODER", None)
20 | document_index: str = os.environ.get("DOCUMENT_INDEX", None)
21 | reader_encoder: str = os.environ.get("READER_ENCODER", None)
22 | top_k: int = int(os.environ.get("TOP_K", 100))
23 | use_faiss: bool = os.environ.get("USE_FAISS", False)
24 | retriever_batch_size: int = int(os.environ.get("RETRIEVER_BATCH_SIZE", 32))
25 | reader_batch_size: int = int(os.environ.get("READER_BATCH_SIZE", 32))
26 | window_size: int = int(os.environ.get("WINDOW_SIZE", 32))
27 | window_stride: int = int(os.environ.get("WINDOW_SIZE", 16))
28 | split_on_spaces: bool = os.environ.get("SPLIT_ON_SPACES", False)
29 | # relik_config_override: dict = ast.literal_eval(
30 | # os.environ.get("RELIK_CONFIG_OVERRIDE", None)
31 | # )
32 |
33 |
class RayParameterManager:
    """
    Settings read from the environment when an instance is created: `NUM_GPUS`,
    `MIN_REPLICAS` and `MAX_REPLICAS`, each converted to `int` and defaulting to 1.
    """

    def __init__(self) -> None:
        settings = (
            ("num_gpus", "NUM_GPUS"),
            ("min_replicas", "MIN_REPLICAS"),
            ("max_replicas", "MAX_REPLICAS"),
        )
        for attribute, variable in settings:
            setattr(self, attribute, int(os.environ.get(variable, 1)))
39 |
--------------------------------------------------------------------------------
/relik/inference/serve/frontend/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/inference/serve/frontend/__init__.py
--------------------------------------------------------------------------------
/relik/inference/serve/frontend/style.css:
--------------------------------------------------------------------------------
/* Sidebar */
/* NOTE(review): `.eczjsme11` and `.st-emotion-cache-10oheav` look like auto-generated
   class names; they may stop matching after a UI framework upgrade - confirm. */

/* dark red background */
.eczjsme11 {
    background-color: #802433;
}

/* white headings and list items, readable on the dark background */
.st-emotion-cache-10oheav h2 {
    color: white;
}

.st-emotion-cache-10oheav li {
    color: white;
}

/* Main */
/* Links: never underlined, white in every state and slightly translucent on hover.
   The rules keep the link / visited / hover / active order so that, with equal
   specificity, the later states win over the earlier ones. */
a:link {
    text-decoration: none;
    color: white;
}

a:visited {
    text-decoration: none;
    color: white;
}

a:hover {
    text-decoration: none;
    color: rgba(255, 255, 255, 0.871);
}

a:active {
    text-decoration: none;
    color: white;
}
--------------------------------------------------------------------------------
/relik/inference/serve/frontend/utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import random
3 | from typing import Dict, List, Optional, Union
4 |
5 | import spacy
6 | import streamlit as st
7 | from spacy import displacy
8 |
9 |
def get_html(html: str):
    """Wrap an HTML snippet so it can be rendered: newlines inside the snippet are
    replaced by spaces and the result is placed inside the wrapper."""
    # NOTE(review): the wrapper holds nothing but the placeholder surrounded by
    # newlines; a wrapping element may have been lost from this literal - confirm.
    wrapper = "\n{}\n"
    # Newlines seem to mess with the rendering
    flattened = " ".join(html.split("\n"))
    return wrapper.format(flattened)
16 |
17 |
def get_svg(svg: str, style: str = "", wrap: bool = True):
    """
    Convert an SVG to a base64-encoded image.

    Args:
        svg (:obj:`str`):
            The SVG markup.
        style (:obj:`str`, optional, defaults to ``""``):
            Inline CSS placed in the ``style`` attribute of the image tag.
        wrap (:obj:`bool`, optional, defaults to :obj:`True`):
            Whether to pass the image tag through ``get_html`` before returning it.

    Returns:
        :obj:`str`: An ``<img>`` tag whose source is the SVG as a base64 data URI,
        wrapped by ``get_html`` when ``wrap`` is :obj:`True`.
    """
    b64 = base64.b64encode(svg.encode("utf-8")).decode("utf-8")
    # the SVG is embedded as a data URI, so no file has to be served
    html = f'<img src="data:image/svg+xml;base64,{b64}" style="{style}"/>'
    return get_html(html) if wrap else html
23 |
24 |
def visualize_parser(
    doc: Union[spacy.tokens.Doc, List[Dict[str, str]]],
    *,
    title: Optional[str] = None,
    key: Optional[str] = None,
    manual: bool = False,
    displacy_options: Optional[Dict] = None,
) -> None:
    """Render a dependency parse in the streamlit page.

    doc (Doc, List): The document to visualize.
    key (str): Accepted for interface compatibility; it is not used by this function.
    title (str): The title displayed at the top of the parser visualization.
    manual (bool): Flag signifying whether the doc argument is a Doc object or a List of Dicts containing parse information.
    displacy_options (Dict): Dictionary of options to be passed to the displacy render method for generating the HTML to be rendered.
      See: https://spacy.io/api/top-level#options-dep
    """
    render_options = dict() if displacy_options is None else displacy_options
    if title:
        st.header(title)

    # the document is rendered as a single unit
    markup = displacy.render(doc, options=render_options, style="dep", manual=manual)
    # Double newlines seem to mess with the rendering
    markup = markup.replace("\n\n", "\n")
    st.write(get_svg(markup), unsafe_allow_html=True)
57 |
58 |
def get_random_color(ents):
    """
    Assign a pastel colour to every item of `ents`.

    One colour per item is generated with `generate_pastel_colors`; each item then
    receives a colour drawn at random from those not handed out yet.

    Returns a dictionary mapping each item to its colour.
    """
    palette = generate_pastel_colors(len(ents))
    return {
        ent: palette.pop(random.randint(0, len(palette) - 1)) for ent in ents
    }
65 |
66 |
def floatrange(start, stop, steps):
    """
    Return `steps` evenly spaced values going from `start` to `stop`, both included.

    When `steps` is 1 the only value returned is `stop`.
    """
    if int(steps) == 1:
        return [stop]

    values = []
    for i in range(steps):
        values.append(start + float(i) * (stop - start) / (float(steps) - 1))
    return values
73 |
74 |
def hsl_to_rgb(h, s, l):
    """
    Convert an HSL colour to RGB.

    The constants below treat `h`, `s` and `l` as fractions (a hue of 1/3 is an offset
    of one channel); the range of `s` and `l` is not validated.

    Returns a tuple of three integers `(r, g, b)`, each channel being the rounded
    product of its fraction and 255.
    """

    def channel(low, high, hue):
        # bring the hue back into [0, 1]
        while hue < 0.0:
            hue += 1.0
        while hue > 1.0:
            hue -= 1.0
        if 6 * hue < 1.0:
            return low + (high - low) * 6.0 * hue
        if 2 * hue < 1.0:
            return high
        if 3 * hue < 2.0:
            return low + (high - low) * ((2.0 / 3.0) - hue) * 6.0
        return low

    # no saturation: the three channels share the same grey level
    if s == 0.0:
        grey = int(round(l * 255))
        return grey, grey, grey

    if l < 0.5:
        high = l * (1.0 + s)
    else:
        high = (l + s) - (s * l)
    low = 2.0 * l - high

    red = 255 * channel(low, high, h + (1.0 / 3.0))
    green = 255 * channel(low, high, h)
    blue = 255 * channel(low, high, h - (1.0 / 3.0))
    return int(round(red)), int(round(green)), int(round(blue))
104 |
105 |
def generate_pastel_colors(n):
    """Return different pastel colours.

    Input:
        n (integer) : The number of colors to return

    Output:
        A list of `n` colors in HTML notation (`#rrggbb`), with hues evenly spread
        around the chromatic circle starting from red.

    Example:
        >>> print(generate_pastel_colors(1))
        ['#ffcccc']
    """
    if n == 0:
        return []

    # To generate colors, we use the HSL colorspace (see http://en.wikipedia.org/wiki/HSL_color_space)
    start_hue = 0.0  # 0=red 1/3=0.333=green 2/3=0.666=blue
    saturation = 1.0
    lightness = 0.9

    # We take n + 1 points around the chromatic circle (hue) and drop the last one
    # because it equals the first one (hue 0 = hue 1)
    hues = floatrange(start_hue, start_hue + 1, n + 1)[:-1]

    colors = []
    for hue in hues:
        red, green, blue = hsl_to_rgb(hue, saturation, lightness)
        colors.append("#{:02x}{:02x}{:02x}".format(red, green, blue))
    return colors
133 |
--------------------------------------------------------------------------------
/relik/reader/__init__.py:
--------------------------------------------------------------------------------
1 | # from relik.reader.pytorch_modules.base import RelikReaderBase
2 | # from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
3 | # from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
4 |
--------------------------------------------------------------------------------
/relik/reader/conf/base.yaml:
--------------------------------------------------------------------------------
1 | # Required to make the "experiments" dir the default one for the output of the models
2 | hydra:
3 | job:
4 | chdir: True
5 | run:
6 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
7 |
8 | model_name: relik-reader-deberta-base-04062024-lrd08-1x4096-seed42 # -start-end-mask-0.001 # used to name the model in wandb and output dir
9 | project_name: relik-reader # used to name the project in wandb
10 | offline: false # if true, wandb will not be used
11 |
12 | defaults:
13 | - _self_
14 | - training: base
15 | - model: base
16 | - data: base
17 |
--------------------------------------------------------------------------------
/relik/reader/conf/base_nyt.yaml:
--------------------------------------------------------------------------------
1 | # Required to make the "experiments" dir the default one for the output of the models
2 | hydra:
3 | run:
4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5 |
6 | model_name: relik-reader-deberta-base
7 | project_name: relik-reader-nyt # used to name the project in wandb
8 | offline: false # if true, wandb will not be used
9 |
10 | defaults:
11 | - _self_
12 | - training: nyt
13 | - model: nyt_base
14 | - data: nyt
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/cie.yaml:
--------------------------------------------------------------------------------
1 | # Required to make the "experiments" dir the default one for the output of the models
2 | hydra:
3 | run:
4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5 |
6 | model_name: relik-reader-deberta-large # no-proj-special-token # -start-end-mask-0.001 # used to name the model in wandb and output dir
7 | project_name: relik-reader-cie # used to name the project in wandb
8 | offline: false # if true, wandb will not be used
9 |
10 | defaults:
11 | - _self_
12 | - training: cie
13 | - model: cie
14 | - data: cie
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/config.yaml:
--------------------------------------------------------------------------------
1 | # Required to make the "experiments" dir the default one for the output of the models
2 | hydra:
3 | run:
4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5 |
6 | model_name: relik-reader-deberta-base-retriever-relik-entity-linking-aida-wikipedia-twin-no-pere # -start-end-mask-0.001 # used to name the model in wandb and output dir
7 | project_name: relik-reader # used to name the project in wandb
8 | offline: false # if true, wandb will not be used
9 |
10 | defaults:
11 | - _self_
12 | - training: base
13 | - model: base
14 | - data: base
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/data/base.yaml:
--------------------------------------------------------------------------------
1 | train_dataset_path: "/root/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/train_windowed_candidates.jsonl"
2 | val_dataset_path: "/root/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/testa_windowed_candidates.jsonl"
3 |
4 | train_dataset:
5 | _target_: "relik.reader.data.relik_reader_data.RelikDataset"
6 | transformer_model: "${model.model.transformer_model}"
7 | materialize_samples: False
8 | shuffle_candidates: 0.5
9 | random_drop_gold_candidates: 0.05
10 | noise_param: 0.0
11 | for_inference: False
12 | tokens_per_batch: 4096
13 | special_symbols: null
14 |
15 | val_dataset:
16 | _target_: "relik.reader.data.relik_reader_data.RelikDataset"
17 | transformer_model: "${model.model.transformer_model}"
18 | materialize_samples: False
19 | shuffle_candidates: False
20 | for_inference: True
21 | special_symbols: null
22 |
--------------------------------------------------------------------------------
/relik/reader/conf/data/cie.yaml:
--------------------------------------------------------------------------------
1 | train_dataset_path:
2 | val_dataset_path:
3 | test_dataset_path:
4 |
5 | train_dataset:
6 | _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset"
7 | transformer_model: "${model.model.transformer_model}"
8 | materialize_samples: False
9 | shuffle_candidates: False
10 | flip_candidates: 1.0
11 | noise_param: 0.0
12 | for_inference: False
13 | tokens_per_batch: 4096
14 | max_length: 1024
15 | max_triplets: 25
16 | max_spans: 75
17 | min_length: -1
18 | special_symbols: null
19 | special_symbols_re: null
20 | section_size: null
21 | use_nme: True
22 | sorting_fields:
23 | - "predictable_candidates"
24 | val_dataset:
25 | _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset"
26 | transformer_model: "${model.model.transformer_model}"
27 | materialize_samples: False
28 | shuffle_candidates: False
29 | flip_candidates: False
30 | for_inference: True
31 | use_nme: True
32 | max_triplets: 25
33 | max_spans: 75
34 | min_length: -1
35 | special_symbols: null
36 | special_symbols_re: null
37 |
--------------------------------------------------------------------------------
/relik/reader/conf/data/large.yaml:
--------------------------------------------------------------------------------
1 | train_dataset_path: "/root/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/train_windowed_candidates.jsonl"
2 | val_dataset_path: "/root/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/testa_windowed_candidates.jsonl"
3 |
4 | train_dataset:
5 | _target_: "relik.reader.data.relik_reader_data.RelikDataset"
6 | transformer_model: "${model.model.transformer_model}"
7 | materialize_samples: False
8 | shuffle_candidates: 0.5
9 | random_drop_gold_candidates: 0.05
10 | noise_param: 0.0
11 | for_inference: False
12 | tokens_per_batch: 2048
13 | special_symbols: null
14 |
15 | val_dataset:
16 | _target_: "relik.reader.data.relik_reader_data.RelikDataset"
17 | transformer_model: "${model.model.transformer_model}"
18 | materialize_samples: False
19 | shuffle_candidates: False
20 | for_inference: True
21 | special_symbols: null
22 |
--------------------------------------------------------------------------------
/relik/reader/conf/data/nyt.yaml:
--------------------------------------------------------------------------------
1 | train_dataset_path: "data/reader/nyt/train.relik.candidates.jsonl"
2 | val_dataset_path: "data/reader/nyt/valid.relik.candidates.jsonl"
3 | test_dataset_path: "data/reader/nyt/test.relik.candidates.jsonl"
4 |
5 | train_dataset:
6 | _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset"
7 | transformer_model: "${model.model.transformer_model}"
8 | materialize_samples: False
9 | shuffle_candidates: False
10 | flip_candidates: 1.0
11 | noise_param: 0.0
12 | for_inference: False
13 | tokens_per_batch: 2048
14 | max_length: 1024
15 | max_triplets: 24
16 | max_spans:
17 | min_length: -1
18 | special_symbols: null
19 | special_symbols_re: null
20 | section_size: null
21 | use_nme: False
22 | sorting_fields:
23 | - "predictable_candidates"
24 | val_dataset:
25 | _target_: "relik.reader.data.relik_reader_re_data.RelikREDataset"
26 | transformer_model: "${model.model.transformer_model}"
27 | materialize_samples: False
28 | shuffle_candidates: False
29 | flip_candidates: False
30 | for_inference: True
31 | use_nme: False
32 | max_triplets: 24
33 | max_spans:
34 | min_length: -1
35 | special_symbols: null
36 | special_symbols_re: null
37 |
--------------------------------------------------------------------------------
/relik/reader/conf/large.yaml:
--------------------------------------------------------------------------------
1 | # Required to make the "experiments" dir the default one for the output of the models
2 | hydra:
3 | run:
4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5 |
6 | model_name: relik-reader-deberta-large-29052024-lrd09-8x2048-onelayer # no-proj-special-token # -start-end-mask-0.001 # used to name the model in wandb and output dir
7 | project_name: relik-reader # used to name the project in wandb
8 | offline: false # if true, wandb will not be used
9 |
10 | defaults:
11 | - _self_
12 | - training: large
13 | - model: large
14 | - data: large
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/large_nyt.yaml:
--------------------------------------------------------------------------------
1 | # Required to make the "experiments" dir the default one for the output of the models
2 | hydra:
3 | run:
4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5 |
6 | model_name: relik-reader-deberta-large
7 | project_name: relik-reader-nyt # used to name the project in wandb
8 | offline: false # if true, wandb will not be used
9 |
10 | defaults:
11 | - _self_
12 | - training: nyt
13 | - model: nyt_large
14 | - data: nyt
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/model/base.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | transformer_model: "microsoft/deberta-v3-base"
3 |
4 | optimizer:
5 | lr: 0.0001
6 | warmup_steps: 5000
7 | total_steps: ${training.trainer.max_steps}
8 | total_reset: 1
9 | weight_decay: 0.0
10 | lr_decay: 0.8
11 | no_decay_params:
12 | - "bias"
13 | - LayerNorm.weight
14 |
15 | entities_per_forward: 100
16 |
--------------------------------------------------------------------------------
/relik/reader/conf/model/cie.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | transformer_model: "microsoft/deberta-v3-large"
3 |
4 | optimizer:
5 | lr: 1.0e-05
6 | warmup_steps: 5000
7 | total_steps: ${training.trainer.max_steps}
8 | total_reset: 1
9 | weight_decay: 0.01
10 | lr_decay: 0.9
11 | no_decay_params:
12 | - "bias"
13 | - LayerNorm.weight
14 |
15 | entities_per_forward: 75
16 | relations_per_forward: 25
17 |
--------------------------------------------------------------------------------
/relik/reader/conf/model/large.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | transformer_model: "microsoft/deberta-v3-large"
3 |
4 | optimizer:
5 | lr: 1.0e-05
6 | warmup_steps: 5000
7 | total_steps: ${training.trainer.max_steps}
8 | total_reset: 1
9 | weight_decay: 0.01
10 | lr_decay: 0.9
11 | no_decay_params:
12 | - "bias"
13 | - LayerNorm.weight
14 |
15 | entities_per_forward: 100
16 |
--------------------------------------------------------------------------------
/relik/reader/conf/model/nyt.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | transformer_model: "microsoft/deberta-v3-small"
3 |
4 | optimizer:
5 | lr: 0.00002
6 | warmup_steps: 10000
7 | total_steps: ${training.trainer.max_steps}
8 | weight_decay: 0.01
9 | no_decay_params:
10 | - "bias"
11 | - LayerNorm.weight
12 |
13 | relations_per_forward: 24
14 | entities_per_forward:
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/model/nyt_base.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | transformer_model: "microsoft/deberta-v3-base"
3 |
4 | optimizer:
5 | lr: 0.00002
6 | warmup_steps: 75000
7 | total_steps: ${training.trainer.max_steps}
8 | weight_decay: 0.01
9 | no_decay_params:
10 | - "bias"
11 | - LayerNorm.weight
12 |
13 | relations_per_forward: 24
14 | entities_per_forward:
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/model/nyt_large.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | transformer_model: "microsoft/deberta-v3-large"
3 |
4 | optimizer:
5 | lr: 0.00002
6 | warmup_steps: 75000
7 | total_steps: ${training.trainer.max_steps}
8 | weight_decay: 0.01
9 | no_decay_params:
10 | - "bias"
11 | - LayerNorm.weight
12 |
13 | relations_per_forward: 24
14 | entities_per_forward:
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/model/nyt_small.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | transformer_model: "microsoft/deberta-v3-small"
3 |
4 | optimizer:
5 | lr: 0.00002
6 | warmup_steps: 75000
7 | total_steps: ${training.trainer.max_steps}
8 | weight_decay: 0.01
9 | no_decay_params:
10 | - "bias"
11 | - LayerNorm.weight
12 |
13 | relations_per_forward: 24
14 | entities_per_forward:
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/model/small.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | transformer_model: "microsoft/deberta-v3-small"
3 |
4 | optimizer:
5 | lr: 0.0001
6 | warmup_steps: 5000
7 | total_steps: ${training.trainer.max_steps}
8 | total_reset: 1
9 | weight_decay: 0.0
10 | lr_decay: 0.8
11 | no_decay_params:
12 | - "bias"
13 | - LayerNorm.weight
14 |
15 | entities_per_forward: 100
16 |
--------------------------------------------------------------------------------
/relik/reader/conf/small.yaml:
--------------------------------------------------------------------------------
1 | # Required to make the "experiments" dir the default one for the output of the models
2 | hydra:
3 | run:
4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5 |
6 | model_name: relik-reader-deberta-small-retriever-relik-entity-linking-aida-wikipedia # -start-end-mask-0.001 # used to name the model in wandb and output dir
7 | project_name: relik-reader # used to name the project in wandb
8 | offline: false # if true, wandb will not be used
9 |
10 | defaults:
11 | - _self_
12 | - training: base
13 | - model: small
14 | - data: base
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/small_nyt.yaml:
--------------------------------------------------------------------------------
1 | # Required to make the "experiments" dir the default one for the output of the models
2 | hydra:
3 | run:
4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5 |
6 | model_name: relik-reader-deberta-small
7 | project_name: relik-reader-nyt # used to name the project in wandb
8 | offline: false # if true, wandb will not be used
9 |
10 | defaults:
11 | - _self_
12 | - training: nyt
13 | - model: nyt_small
14 | - data: nyt
15 |
--------------------------------------------------------------------------------
/relik/reader/conf/training/base.yaml:
--------------------------------------------------------------------------------
1 | seed: 94
2 |
3 | trainer:
4 | _target_: lightning.Trainer
5 | devices:
6 | - 0
7 | precision: "16-mixed"
8 | max_steps: 50000
9 | val_check_interval: 1.0
10 | num_sanity_val_steps: 0
11 | limit_val_batches: 1
12 | gradient_clip_val: 1.0
13 | accumulate_grad_batches: 1
14 |
--------------------------------------------------------------------------------
/relik/reader/conf/training/cie.yaml:
--------------------------------------------------------------------------------
1 | seed: 15
2 |
3 | trainer:
4 | _target_: lightning.Trainer
5 | devices:
6 | - 0
7 | precision: "16-mixed"
8 | max_steps: 100000
9 | val_check_interval: 1.0
10 | num_sanity_val_steps: 0
11 | limit_val_batches: 1
12 | gradient_clip_val: 1.0
13 | accumulate_grad_batches: 2
14 |
--------------------------------------------------------------------------------
/relik/reader/conf/training/large.yaml:
--------------------------------------------------------------------------------
1 | seed: 94
2 |
3 | trainer:
4 | _target_: lightning.Trainer
5 | devices:
6 | - 0
7 | precision: "16-mixed"
8 | max_steps: 50000
9 | val_check_interval: 1.0
10 | num_sanity_val_steps: 0
11 | limit_val_batches: 1
12 | gradient_clip_val: 1.0
13 | accumulate_grad_batches: 4
14 |
--------------------------------------------------------------------------------
/relik/reader/conf/training/nyt.yaml:
--------------------------------------------------------------------------------
1 | seed: 15
2 |
3 | trainer:
4 | _target_: lightning.Trainer
5 | devices:
6 | - 0
7 | precision: "16-mixed"
8 | max_steps: 100000
9 | val_check_interval: 1.0
10 | num_sanity_val_steps: 0
11 | limit_val_batches: 1
12 | gradient_clip_val: 1.0
13 | accumulate_grad_batches: 2
14 |
15 | ckpt_path:
16 |
--------------------------------------------------------------------------------
/relik/reader/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/reader/data/__init__.py
--------------------------------------------------------------------------------
/relik/reader/data/patches.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from relik.reader.data.relik_reader_sample import RelikReaderSample
4 | from relik.reader.utils.special_symbols import NME_SYMBOL
5 |
6 |
def merge_patches_predictions(sample) -> None:
    """Merge the per-patch predictions of ``sample`` into sample-level results.

    Two fresh dictionaries are stored in ``sample._d``:

    * ``"predicted_window_labels"``: title -> list of spans (as tuples);
    * ``"span_title_probabilities"``: span -> title probabilities.

    Patches are visited in ascending key order. For a given span the first
    title seen wins, except that a previously stored ``NME_SYMBOL`` title is
    overwritten by a later prediction. For span probabilities, the first patch
    that reports a span wins.

    Args:
        sample: object exposing a ``_d`` dict and a ``patches`` mapping whose
            values contain ``"predicted_window_labels"`` and
            ``"span_title_probabilities"`` mappings.
    """
    merged_labels = dict()
    sample._d["predicted_window_labels"] = merged_labels

    merged_probabilities = dict()
    sample._d["span_title_probabilities"] = merged_probabilities

    title_by_span = dict()
    for patch_key in sorted(sample.patches):
        patch = sample.patches[patch_key]

        # Resolve which title each span ends up with.
        for title, spans in patch["predicted_window_labels"].items():
            for raw_span in spans:
                span_key = tuple(raw_span)
                existing = title_by_span.get(span_key)
                if existing is not None and existing != NME_SYMBOL:
                    # A real title was already chosen by an earlier patch.
                    continue
                title_by_span[span_key] = title

        # Keep the probabilities of the first patch that mentions a span.
        for span, probabilities in patch["span_title_probabilities"].items():
            merged_probabilities.setdefault(span, probabilities)

    # Invert span -> title into title -> [spans].
    for span_key, title in title_by_span.items():
        merged_labels.setdefault(title, list()).append(span_key)
39 |
40 |
def remove_duplicate_samples(
    samples: List[RelikReaderSample],
) -> List[RelikReaderSample]:
    """Return ``samples`` without duplicates, keeping first occurrences in order.

    Two samples are duplicates when their ``doc_id``, ``sent_id`` and
    ``offset`` attributes format to the same ``doc_id#sent_id#offset`` string.
    """
    known_ids = set()
    unique_samples = []
    for candidate in samples:
        identifier = f"{candidate.doc_id}#{candidate.sent_id}#{candidate.offset}"
        if identifier in known_ids:
            continue
        known_ids.add(identifier)
        unique_samples.append(candidate)
    return unique_samples
52 |
--------------------------------------------------------------------------------
/relik/reader/data/relik_reader_data_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
def flatten(lsts: List[list]) -> list:
    """Concatenate the inner lists of ``lsts``, in order, into one new list."""
    return [element for inner in lsts for element in inner]
12 |
13 |
def batchify(tensors: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
    """Stack variable-length tensors into one batch-first padded tensor.

    Thin wrapper over ``torch.nn.utils.rnn.pad_sequence`` with
    ``batch_first=True``; shorter tensors are filled with ``padding_value``.
    """
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    return pad_sequence(tensors, batch_first=True, padding_value=padding_value)
18 |
19 |
def batchify_matrices(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
    """Pad 2-D tensors to a common size and stack them along a new first axis.

    The result has shape ``(len(tensors), max_rows, max_cols)`` and is created
    with ``torch.zeros`` (so it uses the default dtype); every cell not covered
    by an input matrix holds ``padding_value``.
    """
    max_rows = max(matrix.shape[0] for matrix in tensors)
    max_cols = max(matrix.shape[1] for matrix in tensors)
    padded = torch.zeros((len(tensors), max_rows, max_cols))
    padded.add_(padding_value)
    for index, matrix in enumerate(tensors):
        n_rows, n_cols = matrix.shape[0], matrix.shape[1]
        # Copy each matrix into the top-left corner of its slot.
        padded[index, :n_rows, :n_cols] = matrix
    return padded
28 |
29 |
def batchify_tensor(tensors: List[torch.Tensor], padding_value: int) -> torch.Tensor:
    """Pad 3-D tensors on their first two axes and stack them into a batch.

    The result has shape ``(len(tensors), max_dim0, max_dim1, depth)`` where
    ``depth`` is read from the third axis of the first tensor. Cells not
    covered by an input tensor hold ``padding_value``.
    """
    max_dim0 = max(item.shape[0] for item in tensors)
    max_dim1 = max(item.shape[1] for item in tensors)
    depth = tensors[0].shape[2]
    padded = torch.zeros((len(tensors), max_dim0, max_dim1, depth))
    padded.add_(padding_value)
    for index, item in enumerate(tensors):
        size0, size1 = item.shape[0], item.shape[1]
        # Copy each tensor into the leading corner of its slot.
        padded[index, :size0, :size1, :] = item
    return padded
39 |
40 |
def chunks(lst: list, chunk_size: int) -> List[list]:
    """Split ``lst`` into consecutive slices of at most ``chunk_size`` items.

    The last slice may be shorter; an empty input yields an empty list.
    """
    return [
        lst[start : start + chunk_size] for start in range(0, len(lst), chunk_size)
    ]
46 |
47 |
def add_noise_to_value(value: int, noise_param: float):
    """Perturb ``value`` by uniform noise and clamp the result to at least 1.

    The noise is drawn with ``np.random.uniform`` from
    ``[-value * noise_param, value * noise_param)``. The returned number is
    ``max(1, value + noise)``, so it is a float unless the clamp to ``1`` wins.
    """
    spread = value * noise_param
    return max(1, value + np.random.uniform(-spread, spread))
52 |
--------------------------------------------------------------------------------
/relik/reader/data/relik_reader_sample.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | from typing import Iterable
4 |
5 |
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to plain Python values."""

    def default(self, obj):
        """Convert NumPy integers, floats and arrays; defer everything else."""
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NpEncoder, self).default(obj)


class RelikReaderSample:
    """Attribute-style view over a dictionary of sample fields.

    All keyword arguments given to the constructor are kept in ``self._d``.
    Reading an attribute that is neither a real attribute nor a key of ``_d``
    yields ``None`` instead of raising. Assigning to a name that is already a
    key of ``_d`` updates the dictionary; assigning to any other name creates a
    regular instance attribute (which is therefore not part of ``_d``).
    """

    def __init__(self, **kwargs):
        # `_d` must exist before the custom __setattr__ runs, because that
        # method looks keys up in `self._d`.
        super().__setattr__("_d", {})
        self._d = kwargs

    def __getattribute__(self, item):
        # Plain pass-through to the default lookup.
        return super(RelikReaderSample, self).__getattribute__(item)

    def __getattr__(self, item):
        # Only reached when the normal attribute lookup fails.
        if item.startswith("__") and item.endswith("__"):
            # this is likely some python library-specific variable (such as __deepcopy__ for copy)
            # better follow standard behavior here
            raise AttributeError(item)
        elif item in self._d:
            return self._d[item]
        else:
            return None

    def __setattr__(self, key, value):
        if key in self._d:
            self._d[key] = value
        else:
            super().__setattr__(key, value)

    def to_jsons(self) -> str:
        """Serialize the sample to a JSON string.

        When the sample holds ``predicted_window_labels``, that field is
        replaced by a list of ``[start, end, title]`` entries built from
        ``predicted_window_labels_chars`` and ``span_title_probabilities`` is
        left out. Otherwise the whole field dictionary is dumped as-is.

        Returns:
            The JSON document as a string (NumPy values are converted through
            ``NpEncoder``).

        Raises:
            TypeError: if ``predicted_window_labels`` is present but
                ``predicted_window_labels_chars`` is not (it reads as ``None``
                and cannot be iterated).
        """
        if "predicted_window_labels" in self._d:
            new_obj = {
                k: v
                for k, v in self._d.items()
                if k != "predicted_window_labels" and k != "span_title_probabilities"
            }
            new_obj["predicted_window_labels"] = [
                [ss, se, pred_title]
                for (ss, se), pred_title in self.predicted_window_labels_chars
            ]
            # This branch used to fall through and implicitly return None.
            return json.dumps(new_obj, cls=NpEncoder)
        return json.dumps(self._d, cls=NpEncoder)

    def to_dict(self) -> dict:
        """Return the underlying field dictionary (not a copy)."""
        return self._d
57 |
58 |
def load_relik_reader_samples(path: str) -> Iterable[RelikReaderSample]:
    """Lazily read a JSON Lines file and yield one sample per record.

    Each non-empty line is parsed as a JSON object whose keys become the
    fields of a ``RelikReaderSample``. Blank lines are skipped instead of
    aborting the whole read with a JSON decoding error, and the file is
    decoded as UTF-8 regardless of the platform's default encoding.

    Args:
        path: location of the ``.jsonl`` file.

    Yields:
        One ``RelikReaderSample`` per JSON record, in file order.
    """
    with open(path, encoding="utf-8") as f:
        for line in f:
            record = line.strip()
            if not record:
                # Tolerate empty/whitespace-only lines (e.g. at end of file).
                continue
            yield RelikReaderSample(**json.loads(record))
65 |
--------------------------------------------------------------------------------
/relik/reader/lightning_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/reader/lightning_modules/__init__.py
--------------------------------------------------------------------------------
/relik/reader/lightning_modules/relik_reader_pl_module.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 |
3 | import lightning
4 | from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
5 |
6 | from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
7 |
8 |
class RelikReaderPLModule(lightning.LightningModule):
    """Lightning module that trains a ``RelikReaderForSpanExtraction`` model.

    The wrapped model is stored in ``relik_reader_core_model``. The optimizer
    is not created here: a factory must be supplied through
    ``set_optimizer_factory`` before ``configure_optimizers`` is invoked.
    """

    def __init__(
        self,
        cfg: dict,
        transformer_model: str,
        additional_special_symbols: int,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        *args: Any,
        **kwargs: Any
    ):
        # `cfg` is not read in this constructor; it is only visible to
        # save_hyperparameters() below.
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()
        # The first six arguments are forwarded positionally, in this order.
        self.relik_reader_core_model = RelikReaderForSpanExtraction(
            transformer_model,
            additional_special_symbols,
            num_layers,
            activation,
            linears_hidden_size,
            use_last_k_layers,
            training=training,
        )
        # Set later via set_optimizer_factory().
        self.optimizer_factory = None

    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Run the model on ``batch`` (expanded as keyword arguments), log and return its loss."""
        relik_output = self.relik_reader_core_model(**batch)
        self.log("train-loss", relik_output["loss"])
        return relik_output["loss"]

    def validation_step(
        self, batch: dict, *args: Any, **kwargs: Any
    ) -> Optional[STEP_OUTPUT]:
        """No-op: nothing is computed here and ``None`` is returned."""
        return

    def set_optimizer_factory(self, optimizer_factory) -> None:
        """Store the callable used by ``configure_optimizers`` to build the optimizer."""
        self.optimizer_factory = optimizer_factory

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Return whatever the optimizer factory builds for the wrapped model.

        If ``set_optimizer_factory`` was never called, ``optimizer_factory`` is
        still ``None`` and this call fails.
        """
        return self.optimizer_factory(self.relik_reader_core_model)
51 |
--------------------------------------------------------------------------------
/relik/reader/lightning_modules/relik_reader_re_pl_module.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Optional
2 |
3 | import lightning
4 | from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
5 |
6 | from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
7 |
8 |
class RelikReaderREPLModule(lightning.LightningModule):
    """Lightning module that trains a ``RelikReaderForTripletExtraction`` model.

    The wrapped model is stored in ``relik_reader_re_model``. The optimizer is
    not created here: a factory must be supplied through
    ``set_optimizer_factory`` before ``configure_optimizers`` is invoked.
    """

    def __init__(
        self,
        cfg: dict,
        transformer_model: str,
        additional_special_symbols: int,
        additional_special_symbols_types: Optional[int] = 0,
        entity_type_loss: Optional[bool] = None,
        add_entity_embedding: Optional[bool] = None,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        training: bool = False,
        default_reader_class: str = "relik.reader.pytorch_modules.hf.modeling_relik.RelikReaderREModel",
        *args: Any,
        **kwargs: Any
    ):
        # `cfg` is not read in this constructor; it is only visible to
        # save_hyperparameters() below.
        # NOTE(review): the same **kwargs are passed both to the Lightning
        # base constructor here and to the reader model below — confirm that
        # both accept whatever extra keywords callers provide.
        super().__init__(*args, **kwargs)
        self.save_hyperparameters()

        # The first nine arguments are forwarded positionally, in this order.
        self.relik_reader_re_model = RelikReaderForTripletExtraction(
            transformer_model,
            additional_special_symbols,
            additional_special_symbols_types,
            entity_type_loss,
            add_entity_embedding,
            num_layers,
            activation,
            linears_hidden_size,
            use_last_k_layers,
            training=training,
            default_reader_class=default_reader_class,
            **kwargs,
        )
        # Set later via set_optimizer_factory().
        self.optimizer_factory = None

    def training_step(self, batch: dict, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
        """Run the model on ``batch``, log its individual losses and return the total loss.

        The output must contain ``loss``, ``ned_start_loss``, ``ned_end_loss``
        and ``re_loss``; ``ned_type_loss`` is logged only when present.
        """
        relik_output = self.relik_reader_re_model(**batch)
        self.log("train-loss", relik_output["loss"])
        self.log("train-start_loss", relik_output["ned_start_loss"])
        self.log("train-end_loss", relik_output["ned_end_loss"])
        self.log("train-relation_loss", relik_output["re_loss"])
        if "ned_type_loss" in relik_output:
            self.log("train-ned_type_loss", relik_output["ned_type_loss"])
        return relik_output["loss"]

    def validation_step(
        self, batch: dict, *args: Any, **kwargs: Any
    ) -> Optional[STEP_OUTPUT]:
        """No-op: nothing is computed here and ``None`` is returned."""
        return

    def set_optimizer_factory(self, optimizer_factory) -> None:
        """Store the callable used by ``configure_optimizers`` to build the optimizer."""
        self.optimizer_factory = optimizer_factory

    def configure_optimizers(self) -> OptimizerLRScheduler:
        """Return whatever the optimizer factory builds for the wrapped model.

        If ``set_optimizer_factory`` was never called, ``optimizer_factory`` is
        still ``None`` and this call fails.
        """
        return self.optimizer_factory(self.relik_reader_re_model)
66 |
--------------------------------------------------------------------------------
/relik/reader/pytorch_modules/__init__.py:
--------------------------------------------------------------------------------
1 | # from relik.reader.pytorch_modules.hf.modeling_relik import RelikReaderSpanModel
2 | # from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
3 | # from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
4 |
5 |
# Maps a reader model class name to the dotted import path of the reader
# wrapper class in this package. The paths are kept as strings, so nothing is
# imported when this module loads (the direct imports above are commented out).
# NOTE(review): presumably resolved lazily by the code that loads a reader —
# confirm against the callers.
RELIK_READER_CLASS_MAP = {
    "RelikReaderSpanModel": "relik.reader.pytorch_modules.span.RelikReaderForSpanExtraction",
    "RelikReaderREModel": "relik.reader.pytorch_modules.triplet.RelikReaderForTripletExtraction",
}
10 |
--------------------------------------------------------------------------------
/relik/reader/pytorch_modules/hf/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_relik import RelikReaderConfig
2 | from .modeling_relik import RelikReaderREModel
3 |
--------------------------------------------------------------------------------
/relik/reader/pytorch_modules/hf/configuration_relik.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from transformers import AutoConfig
4 | from transformers.configuration_utils import PretrainedConfig
5 |
6 |
class RelikReaderConfig(PretrainedConfig):
    """Configuration object for ReLiK reader models (``model_type = "relik-reader"``).

    Every constructor argument is stored unchanged as an attribute of the same
    name, with one exception: ``add_entity_embedding`` becomes ``True`` when it
    is left as ``None`` while ``entity_type_loss`` is truthy. Remaining keyword
    arguments are handed to ``PretrainedConfig.__init__``.

    How each value is interpreted is decided by the model code, which is not
    part of this class.
    """

    model_type = "relik-reader"

    def __init__(
        self,
        transformer_model: str = "microsoft/deberta-v3-base",
        additional_special_symbols: int = 101,
        additional_special_symbols_types: Optional[int] = 0,
        num_layers: Optional[int] = None,
        activation: str = "gelu",
        linears_hidden_size: Optional[int] = 512,
        use_last_k_layers: int = 1,
        entity_type_loss: bool = False,
        add_entity_embedding: Optional[bool] = None,
        binary_end_logits: bool = False,
        training: bool = False,
        default_reader_class: Optional[str] = None,
        threshold: Optional[float] = 0.5,
        **kwargs
    ) -> None:
        # TODO: add name_or_path to kwargs
        self.transformer_model = transformer_model
        self.additional_special_symbols = additional_special_symbols
        self.additional_special_symbols_types = additional_special_symbols_types
        self.num_layers = num_layers
        self.activation = activation
        self.linears_hidden_size = linears_hidden_size
        self.use_last_k_layers = use_last_k_layers
        self.entity_type_loss = entity_type_loss
        # Default to True only when the caller did not choose a value and the
        # entity type loss is enabled; otherwise keep the given value
        # (which may remain None).
        self.add_entity_embedding = (
            True
            if add_entity_embedding is None and entity_type_loss
            else add_entity_embedding
        )
        self.threshold = threshold
        self.binary_end_logits = binary_end_logits
        self.training = training
        self.default_reader_class = default_reader_class
        # The base class is initialized last, after all custom fields are set.
        super().__init__(**kwargs)
46 |
--------------------------------------------------------------------------------
/relik/reader/pytorch_modules/optim/__init__.py:
--------------------------------------------------------------------------------
1 | from relik.reader.pytorch_modules.optim.adamw_with_warmup import (
2 | AdamWWithWarmupOptimizer,
3 | )
4 | from relik.reader.pytorch_modules.optim.layer_wise_lr_decay import (
5 | LayerWiseLRDecayOptimizer,
6 | )
7 |
--------------------------------------------------------------------------------
/relik/reader/pytorch_modules/optim/adamw_with_warmup.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch
4 | import transformers
5 | from torch.optim import AdamW
6 |
7 |
8 | class AdamWWithWarmupOptimizer:
9 | def __init__(
10 | self,
11 | lr: float | List[float],
12 | warmup_steps: int,
13 | total_steps: int,
14 | weight_decay: float,
15 | no_decay_params: List[str],
16 | other_lr_params: List[str] = None,
17 | ):
18 | self.lr = lr[0] if isinstance(lr, list) else lr
19 | self.lr2 = lr[1] if isinstance(lr, list) else lr
20 | self.warmup_steps = warmup_steps
21 | self.total_steps = total_steps
22 | self.weight_decay = weight_decay
23 | self.no_decay_params = no_decay_params
24 | self.other_lr_params = other_lr_params or [] # Ensure it's a list
25 |
26 | def group_params(self, module: torch.nn.Module) -> list:
27 | decay_params = set()
28 | no_decay_params = set()
29 | other_lr_params = set()
30 | # Populate parameter sets
31 | for n, p in module.named_parameters():
32 | if any(nd in n for nd in self.no_decay_params):
33 | no_decay_params.add(p)
34 | elif any(olr in n for olr in self.other_lr_params):
35 | other_lr_params.add(p)
36 | else:
37 | decay_params.add(p)
38 | # Group parameters
39 | optimizer_grouped_parameters = []
40 | if decay_params:
41 | optimizer_grouped_parameters.append(
42 | {"params": list(decay_params), "weight_decay": self.weight_decay}
43 | )
44 | if no_decay_params:
45 | optimizer_grouped_parameters.append(
46 | {"params": list(no_decay_params), "weight_decay": 0.0}
47 | )
48 | if other_lr_params:
49 | optimizer_grouped_parameters.append(
50 | {
51 | "params": list(other_lr_params),
52 | "lr": self.lr2,
53 | "weight_decay": self.weight_decay,
54 | }
55 | )
56 |
57 | return optimizer_grouped_parameters
58 |
59 | def __call__(self, module: torch.nn.Module):
60 | optimizer_grouped_parameters = self.group_params(module)
61 | optimizer = AdamW(
62 | optimizer_grouped_parameters, lr=self.lr, weight_decay=self.weight_decay
63 | )
64 | scheduler = transformers.get_linear_schedule_with_warmup(
65 | optimizer, self.warmup_steps, self.total_steps
66 | )
67 | return {
68 | "optimizer": optimizer,
69 | "lr_scheduler": {
70 | "scheduler": scheduler,
71 | "interval": "step",
72 | "frequency": 1,
73 | },
74 | }
75 |
--------------------------------------------------------------------------------
/relik/reader/trainer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/reader/trainer/__init__.py
--------------------------------------------------------------------------------
/relik/reader/trainer/predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pprint import pprint
3 | from typing import Optional
4 |
5 | from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
6 | from relik.reader.utils.strong_matching_eval import StrongMatching
7 | from relik.reader.data.relik_reader_sample import load_relik_reader_samples
8 |
9 |
def predict(
    model_path: str,
    dataset_path: str,
    token_batch_size: int,
    is_eval: bool,
    output_path: Optional[str],
) -> None:
    """Run a span-extraction reader over a dataset, optionally scoring and saving.

    Args:
        model_path: model identifier/path given to ``RelikReaderForSpanExtraction``.
        dataset_path: JSON Lines file read with ``load_relik_reader_samples``.
        token_batch_size: forwarded to the reader's ``read`` call.
        is_eval: when true, the ``StrongMatching`` result for the predictions
            is pretty-printed.
        output_path: when not ``None``, each predicted sample is written there
            as one ``to_jsons()`` line.
    """
    # The reader is always created on the "cuda" device.
    reader = RelikReaderForSpanExtraction(
        model_path, dataset_kwargs={"use_nme": True}, device="cuda"
    )
    input_samples = list(load_relik_reader_samples(dataset_path))
    predictions = reader.read(
        samples=input_samples, token_batch_size=token_batch_size, progress_bar=True
    )

    if is_eval:
        pprint(StrongMatching()(predictions))

    if output_path is None:
        return
    with open(output_path, "w") as out_file:
        for prediction in predictions:
            out_file.write(prediction.to_jsons() + "\n")
31 |
32 |
def parse_arg() -> argparse.Namespace:
    """Parse the command-line arguments of the prediction script.

    Options:
        --model-path: required.
        --dataset-path / -i: required.
        --is-eval: flag, false unless given.
        --output-path / -o: optional, ``None`` when omitted.
        --token-batch-size: integer, 4096 by default.

    Returns:
        The parsed arguments namespace.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-path",
        required=True,
    )
    parser.add_argument("--dataset-path", "-i", required=True)
    parser.add_argument("--is-eval", action="store_true")
    parser.add_argument(
        "--output-path",
        "-o",
    )
    # type=int is required: without it a value passed on the command line
    # stays a string, while the default is an integer.
    parser.add_argument("--token-batch-size", type=int, default=4096)
    return parser.parse_args()
47 |
48 |
def main():
    """Command-line entry point: parse the arguments and run ``predict``."""
    cli_args = parse_arg()
    predict(
        model_path=cli_args.model_path,
        dataset_path=cli_args.dataset_path,
        token_batch_size=cli_args.token_batch_size,
        is_eval=cli_args.is_eval,
        output_path=cli_args.output_path,
    )
58 |
59 |
60 | if __name__ == "__main__":
61 | main()
62 |
--------------------------------------------------------------------------------
/relik/reader/trainer/predict_cie.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from relik.reader.data.relik_reader_sample import load_relik_reader_samples
3 | from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
4 | from relik.reader.utils.relation_matching_eval import StrongMatching
5 | from relik.inference.data.objects import AnnotationType
6 |
7 | import torch
8 |
9 | import numpy as np
10 | from sklearn.metrics import precision_recall_curve
11 |
12 |
def find_optimal_threshold(scores, labels):
    """Pick the score threshold that maximizes F1 on a precision-recall curve.

    Args:
        scores: Predicted scores, passed to ``precision_recall_curve`` as the
            score argument.
        labels: Ground-truth labels, passed as the label argument.

    Returns:
        A ``(optimal_threshold, best_f1)`` pair.
    """
    # Calculate precision-recall pairs for various threshold values
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    # precision/recall carry one more entry than thresholds; pad thresholds
    # with (last threshold + 1) so the arrays can be indexed together.
    thresholds = np.append(thresholds, thresholds[-1] + 1)
    # Points where precision + recall == 0 yield 0/0. Silence the numpy
    # warning for that case only; the resulting NaN is mapped to 0 below.
    with np.errstate(divide="ignore", invalid="ignore"):
        f1_scores = 2 * (precision * recall) / (precision + recall)
    f1_scores = np.nan_to_num(f1_scores)
    # Index of the best F1, then read off the matching threshold
    max_index = np.argmax(f1_scores)
    optimal_threshold = thresholds[max_index]
    best_f1 = f1_scores[max_index]

    return optimal_threshold, best_f1
29 |
30 |
def eval(
    model_path,
    data_path,
    is_eval,
    output_path=None,
    compute_threshold=False,
    save_threshold=False,
):
    """Run the triplet-extraction reader on a dataset, optionally tuning its threshold.

    Note: the function name shadows the ``eval`` builtin inside this module;
    it is kept unchanged for callers.

    Args:
        model_path: Model location given to ``RelikReaderForTripletExtraction``;
            also the target of ``save_pretrained`` when ``save_threshold`` is set.
        data_path: File read through ``load_relik_reader_samples``.
        is_eval: When true, print every ``StrongMatching`` metric prefixed
            with ``test_``.
        output_path: When not ``None``, predictions are written there, one
            ``to_jsons()`` line per sample.
        compute_threshold: When true, run a first pass with
            ``return_threshold_utils=True`` and derive the F1-optimal
            relation threshold from the collected probabilities and labels.
        save_threshold: Only honored together with ``compute_threshold``;
            stores the threshold in the model config and saves the model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "mps" if torch.backends.mps.is_available() else device
    print(f"Device: {device}")
    reader = RelikReaderForTripletExtraction(model_path, training=False, device=device)
    samples = list(load_relik_reader_samples(data_path))
    optimal_threshold = None
    if compute_threshold:
        predicted_samples = reader.read(
            samples=samples,
            progress_bar=True,
            annotation_type=AnnotationType.WORD,
            return_threshold_utils=True,
        )
        re_probabilities, re_labels = [], []
        for sample in predicted_samples:
            re_probabilities.extend(sample.re_probabilities.flatten())
            re_labels.extend(sample.re_labels.flatten())
        optimal_threshold, best_f1 = find_optimal_threshold(re_probabilities, re_labels)
        print(f"Optimal threshold: {optimal_threshold}")
        print(f"Best F1: {best_f1}")
        # Convert the numpy scalar to a builtin float so it can be stored in
        # the model config and serialized (a float32 scalar is not
        # JSON-serializable).
        optimal_threshold = float(optimal_threshold)
        # reload the samples for the second prediction pass
        samples = list(load_relik_reader_samples(data_path))
        if save_threshold:
            reader.relik_reader_model.config.threshold = optimal_threshold
            reader.relik_reader_model.save_pretrained(model_path)

    predicted_samples = reader.read(
        samples=samples,
        progress_bar=True,
        annotation_type=AnnotationType.WORD,
        relation_threshold=optimal_threshold,
    )

    if is_eval:
        strong_matching_metric = StrongMatching()
        # materialize so the samples can still be written out afterwards
        predicted_samples = list(predicted_samples)
        for k, v in strong_matching_metric(predicted_samples).items():
            print(f"test_{k}", v)
    if output_path is not None:
        with open(output_path, "w") as f:
            for sample in predicted_samples:
                f.write(sample.to_jsons() + "\n")
81 |
82 |
def main():
    """Command-line entry point: build the parser and dispatch to ``eval``."""
    cli_parser = argparse.ArgumentParser()
    # required string options
    for path_option in ("--model_path", "--data_path"):
        cli_parser.add_argument(path_option, type=str, required=True)
    cli_parser.add_argument("--is-eval", action="store_true")
    cli_parser.add_argument("--output_path", type=str, default=None)
    # boolean switches controlling the threshold search
    for switch in ("--compute-threshold", "--save-threshold"):
        cli_parser.add_argument(switch, action="store_true")

    options = cli_parser.parse_args()
    eval(
        model_path=options.model_path,
        data_path=options.data_path,
        is_eval=options.is_eval,
        output_path=options.output_path,
        compute_threshold=options.compute_threshold,
        save_threshold=options.save_threshold,
    )
108 |
109 |
110 | if __name__ == "__main__":
111 | main()
112 |
--------------------------------------------------------------------------------
/relik/reader/trainer/predict_re.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from relik.reader.data.relik_reader_sample import load_relik_reader_samples
3 | from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
4 | from relik.reader.utils.relation_matching_eval import StrongMatching
5 | from relik.inference.data.objects import AnnotationType
6 |
7 | import torch
8 |
9 | import numpy as np
10 | from sklearn.metrics import precision_recall_curve
11 |
12 |
def find_optimal_threshold(scores, labels):
    """Pick the score threshold that maximizes F1 on a precision-recall curve.

    Args:
        scores: Predicted scores, passed to ``precision_recall_curve`` as the
            score argument.
        labels: Ground-truth labels, passed as the label argument.

    Returns:
        A ``(optimal_threshold, best_f1)`` pair.
    """
    # Calculate precision-recall pairs for various threshold values
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    # precision/recall carry one more entry than thresholds; pad thresholds
    # with (last threshold + 1) so the arrays can be indexed together.
    thresholds = np.append(thresholds, thresholds[-1] + 1)
    # Points where precision + recall == 0 yield 0/0. Silence the numpy
    # warning for that case only; the resulting NaN is mapped to 0 below.
    with np.errstate(divide="ignore", invalid="ignore"):
        f1_scores = 2 * (precision * recall) / (precision + recall)
    f1_scores = np.nan_to_num(f1_scores)
    # Index of the best F1, then read off the matching threshold
    max_index = np.argmax(f1_scores)
    optimal_threshold = thresholds[max_index]
    best_f1 = f1_scores[max_index]

    return optimal_threshold, best_f1
29 |
30 |
def eval(
    model_path,
    data_path,
    is_eval,
    output_path=None,
    compute_threshold=False,
    save_threshold=False,
):
    """Run the triplet-extraction reader on a dataset, optionally tuning its threshold.

    Note: the function name shadows the ``eval`` builtin inside this module;
    it is kept unchanged for callers.

    Args:
        model_path: Model location given to ``RelikReaderForTripletExtraction``;
            also the target of ``save_pretrained`` when ``save_threshold`` is set.
        data_path: File read through ``load_relik_reader_samples``.
        is_eval: When true, print every ``StrongMatching`` metric prefixed
            with ``test_``.
        output_path: When not ``None``, predictions are written there, one
            ``to_jsons()`` line per sample.
        compute_threshold: When true, run a first pass with
            ``return_threshold_utils=True`` and derive the F1-optimal
            relation threshold from the collected probabilities and labels.
        save_threshold: Only honored together with ``compute_threshold``;
            stores the threshold in the model config and saves the model.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = "mps" if torch.backends.mps.is_available() else device
    print(f"Device: {device}")
    reader = RelikReaderForTripletExtraction(model_path, training=False, device=device)
    samples = list(load_relik_reader_samples(data_path))
    optimal_threshold = None
    if compute_threshold:
        predicted_samples = reader.read(
            samples=samples,
            progress_bar=True,
            annotation_type=AnnotationType.WORD,
            return_threshold_utils=True,
        )
        re_probabilities, re_labels = [], []
        for sample in predicted_samples:
            re_probabilities.extend(sample.re_probabilities.flatten())
            re_labels.extend(sample.re_labels.flatten())
        optimal_threshold, best_f1 = find_optimal_threshold(re_probabilities, re_labels)
        print(f"Optimal threshold: {optimal_threshold}")
        print(f"Best F1: {best_f1}")
        # Convert the numpy scalar to a builtin float so it can be stored in
        # the model config and serialized (a float32 scalar is not
        # JSON-serializable).
        optimal_threshold = float(optimal_threshold)
        # reload the samples for the second prediction pass
        samples = list(load_relik_reader_samples(data_path))
        if save_threshold:
            reader.relik_reader_model.config.threshold = optimal_threshold
            reader.relik_reader_model.save_pretrained(model_path)

    predicted_samples = reader.read(
        samples=samples,
        progress_bar=True,
        annotation_type=AnnotationType.WORD,
        relation_threshold=optimal_threshold,
    )

    if is_eval:
        strong_matching_metric = StrongMatching()
        # materialize so the samples can still be written out afterwards
        predicted_samples = list(predicted_samples)
        for k, v in strong_matching_metric(predicted_samples).items():
            print(f"test_{k}", v)
    if output_path is not None:
        with open(output_path, "w") as f:
            for sample in predicted_samples:
                f.write(sample.to_jsons() + "\n")
81 |
82 |
def main():
    """Command-line entry point: build the parser and dispatch to ``eval``."""
    cli_parser = argparse.ArgumentParser()
    # required string options
    for path_option in ("--model_path", "--data_path"):
        cli_parser.add_argument(path_option, type=str, required=True)
    cli_parser.add_argument("--is-eval", action="store_true")
    cli_parser.add_argument("--output_path", type=str, default=None)
    # boolean switches controlling the threshold search
    for switch in ("--compute-threshold", "--save-threshold"):
        cli_parser.add_argument(switch, action="store_true")

    options = cli_parser.parse_args()
    eval(
        model_path=options.model_path,
        data_path=options.data_path,
        is_eval=options.is_eval,
        output_path=options.output_path,
        compute_threshold=options.compute_threshold,
        save_threshold=options.save_threshold,
    )
108 |
109 |
110 | if __name__ == "__main__":
111 | main()
112 |
--------------------------------------------------------------------------------
/relik/reader/trainer/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from pprint import pprint
4 | import hydra
5 | import lightning
6 | from hydra.utils import to_absolute_path, get_original_cwd
7 | from lightning import Trainer
8 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
9 | from lightning.pytorch.loggers.wandb import WandbLogger
10 | from omegaconf import DictConfig, OmegaConf, open_dict
11 | import omegaconf
12 | import torch
13 | from torch.utils.data import DataLoader
14 |
15 | from relik.reader.data.relik_reader_data import RelikDataset
16 | from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule
17 | from relik.reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer
18 | from relik.reader.utils.special_symbols import get_special_symbols
19 | from relik.reader.utils.strong_matching_eval import ELStrongMatchingCallback
20 | from relik.reader.utils.shuffle_train_callback import ShuffleTrainCallback
21 |
22 |
def train(cfg: DictConfig) -> None:
    """Train the reader described by ``cfg`` and export the best checkpoint.

    The function seeds the run, builds the Lightning module with a
    layer-wise LR decay optimizer factory, instantiates the train/validation
    datasets through Hydra, fits the trainer and finally reloads the best
    checkpoint to save the core model into the wandb experiment directory.

    Args:
        cfg: Hydra/OmegaConf configuration; the ``training``, ``model`` and
            ``data`` groups plus ``model_name``, ``project_name`` and
            ``offline`` are read here.
    """
    lightning.seed_everything(cfg.training.seed)
    # check if deterministic algorithms are available
    if "deterministic" in cfg and cfg.deterministic:
        torch.use_deterministic_algorithms(True, warn_only=True)

    # log the configuration
    pprint(OmegaConf.to_container(cfg, resolve=True))

    special_symbols = get_special_symbols(cfg.model.entities_per_forward)

    # model declaration
    model = RelikReaderPLModule(
        cfg=OmegaConf.to_container(cfg),
        transformer_model=cfg.model.model.transformer_model,
        additional_special_symbols=len(special_symbols),
        training=True,
    )

    # optimizer declaration
    opt_conf = cfg.model.optimizer
    electra_optimizer_factory = LayerWiseLRDecayOptimizer(
        lr=opt_conf.lr,
        warmup_steps=opt_conf.warmup_steps,
        total_steps=opt_conf.total_steps,
        total_reset=opt_conf.total_reset,
        no_decay_params=opt_conf.no_decay_params,
        weight_decay=opt_conf.weight_decay,
        lr_decay=opt_conf.lr_decay,
    )

    model.set_optimizer_factory(electra_optimizer_factory)

    # datasets declaration
    train_dataset_path = to_absolute_path(cfg.data.train_dataset_path)
    train_dataset: RelikDataset = hydra.utils.instantiate(
        cfg.data.train_dataset,
        dataset_path=train_dataset_path,
        special_symbols=special_symbols,
    )

    # update of validation dataset config with special_symbols since they
    # are required even from the EvaluationCallback dataset_config
    with open_dict(cfg):
        cfg.data.val_dataset.special_symbols = special_symbols

    val_dataset: RelikDataset = hydra.utils.instantiate(
        cfg.data.val_dataset,
        dataset_path=to_absolute_path(cfg.data.val_dataset_path),
    )

    # callbacks declaration
    callbacks = [
        ELStrongMatchingCallback(
            to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset
        ),
        ModelCheckpoint(
            "model",
            filename="{epoch}-{val_core_f1:.2f}",
            monitor="val_core_f1",
            mode="max",
        ),
        LearningRateMonitor(),
    ]
    # If section_size is None, we shuffle the dataset. This speeds things up a
    # lot for bigger datasets, but be careful: it shuffles the file itself at
    # the end of each epoch.
    if cfg.data.train_dataset.section_size is None:
        # The callback opens `data_path` at epoch end, so the path must be
        # supplied; without it the callback would try to open `None`.
        callbacks.append(ShuffleTrainCallback(data_path=train_dataset_path))

    wandb_logger = WandbLogger(
        cfg.model_name, project=cfg.project_name, offline=cfg.offline
    )

    # trainer declaration
    trainer: Trainer = hydra.utils.instantiate(
        cfg.training.trainer,
        callbacks=callbacks,
        logger=wandb_logger,
    )

    # model.relik_reader_core_model._tokenizer = train_dataset.tokenizer

    # Trainer fit
    trainer.fit(
        model=model,
        train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0),
        val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0),
    )

    # if cfg.training.save_model_path:
    experiment_path = Path(wandb_logger.experiment.dir)
    # reload the best checkpoint and export its core model with the tokenizer
    model = RelikReaderPLModule.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path
    )
    model.relik_reader_core_model._tokenizer = train_dataset.tokenizer
    model.relik_reader_core_model.save_pretrained(experiment_path)
118 |
119 |
@hydra.main(config_path="../conf", config_name="config", version_base="1.3")
def main(conf: omegaconf.DictConfig):
    """Hydra entry point: receive the composed config and start training."""
    train(cfg=conf)
123 |
124 |
125 | if __name__ == "__main__":
126 | main()
127 |
--------------------------------------------------------------------------------
/relik/reader/trainer/train_re.py:
--------------------------------------------------------------------------------
1 | import hydra
2 | from pathlib import Path
3 | import lightning
4 | from hydra.utils import to_absolute_path
5 | from lightning import Trainer
6 | from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
7 | from lightning.pytorch.loggers.wandb import WandbLogger
8 | from omegaconf import DictConfig, OmegaConf, open_dict
9 | from torch.utils.data import DataLoader
10 |
11 | from relik.reader.data.relik_reader_re_data import RelikREDataset
12 | from relik.reader.lightning_modules.relik_reader_re_pl_module import (
13 | RelikReaderREPLModule,
14 | )
15 | from relik.reader.pytorch_modules.optim import (
16 | AdamWWithWarmupOptimizer,
17 | LayerWiseLRDecayOptimizer,
18 | )
19 | from relik.reader.utils.relation_matching_eval import REStrongMatchingCallback
20 | from relik.reader.utils.special_symbols import get_special_symbols_re
21 | from relik.reader.utils.shuffle_train_callback import ShuffleTrainCallback
22 |
23 |
@hydra.main(config_path="../conf", config_name="config")
def train(cfg: DictConfig) -> None:
    """Train the relation-extraction reader described by ``cfg``.

    Builds the train/validation datasets, the Lightning module and its
    optimizer factory, fits the trainer (optionally resuming from
    ``cfg.training.ckpt_path``) and exports the best checkpoint's model to
    ``<wandb experiment dir>/hf_model``.

    Args:
        cfg: Hydra/OmegaConf configuration; the ``training``, ``model`` and
            ``data`` groups plus ``model_name``, ``project_name`` and
            ``offline`` are read here.
    """
    lightning.seed_everything(cfg.training.seed)

    special_symbols = get_special_symbols_re(cfg.model.relations_per_forward)

    # datasets declaration
    train_dataset: RelikREDataset = hydra.utils.instantiate(
        cfg.data.train_dataset,
        dataset_path=to_absolute_path(cfg.data.train_dataset_path),
        special_symbols_re=special_symbols,
    )

    # update of validation dataset config with special_symbols since they
    # are required even from the EvaluationCallback dataset_config
    with open_dict(cfg):
        cfg.data.val_dataset.special_symbols_re = special_symbols

    val_dataset: RelikREDataset = hydra.utils.instantiate(
        cfg.data.val_dataset,
        dataset_path=to_absolute_path(cfg.data.val_dataset_path),
    )

    if val_dataset.materialize_samples:
        # exhaust the iterator once up front
        list(val_dataset.dataset_iterator_func())
    # model declaration
    model = RelikReaderREPLModule(
        cfg=OmegaConf.to_container(cfg),
        transformer_model=cfg.model.model.transformer_model,
        additional_special_symbols=len(special_symbols),
        training=True,
    )
    model.relik_reader_re_model._tokenizer = train_dataset.tokenizer
    # optimizer declaration
    opt_conf = cfg.model.optimizer

    # the presence of `total_reset` selects the layer-wise LR decay optimizer
    if "total_reset" not in opt_conf:
        optimizer_factory = AdamWWithWarmupOptimizer(
            lr=opt_conf.lr,
            warmup_steps=opt_conf.warmup_steps,
            total_steps=opt_conf.total_steps,
            no_decay_params=opt_conf.no_decay_params,
            weight_decay=opt_conf.weight_decay,
        )
    else:
        optimizer_factory = LayerWiseLRDecayOptimizer(
            lr=opt_conf.lr,
            warmup_steps=opt_conf.warmup_steps,
            total_steps=opt_conf.total_steps,
            total_reset=opt_conf.total_reset,
            no_decay_params=opt_conf.no_decay_params,
            weight_decay=opt_conf.weight_decay,
            lr_decay=opt_conf.lr_decay,
        )

    model.set_optimizer_factory(optimizer_factory)
    # callbacks declaration
    callbacks = [
        REStrongMatchingCallback(
            to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset
        ),
        ModelCheckpoint(
            "model",
            filename="{epoch}-{val_f1:.2f}",
            monitor="val_f1",
            mode="max",
        ),
        LearningRateMonitor(),
    ]

    # If section_size is None, we shuffle the dataset. This speeds things up a
    # lot for bigger datasets, but be careful: it shuffles the file itself at
    # the end of each epoch.
    if cfg.data.train_dataset.section_size is None:
        callbacks.append(
            ShuffleTrainCallback(
                data_path=to_absolute_path(cfg.data.train_dataset_path)
            )
        )

    wandb_logger = WandbLogger(
        cfg.model_name, project=cfg.project_name, offline=cfg.offline
    )

    # trainer declaration
    trainer: Trainer = hydra.utils.instantiate(
        cfg.training.trainer,
        callbacks=callbacks,
        logger=wandb_logger,
    )

    # Trainer fit
    trainer.fit(
        model=model,
        train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=0),
        val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0),
        ckpt_path=(
            cfg.training.ckpt_path
            if "ckpt_path" in cfg.training and cfg.training.ckpt_path
            else None
        ),
    )

    # Load best checkpoint
    # if cfg.training.save_model_path:
    model = RelikReaderREPLModule.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path
    )
    experiment_path = Path(wandb_logger.experiment.dir)
    model.relik_reader_re_model._tokenizer = train_dataset.tokenizer
    model.relik_reader_re_model.save_pretrained(experiment_path / "hf_model")
134 |
135 |
def main():
    """Entry point wrapper around the Hydra-decorated ``train``."""
    # `train` receives its config from the hydra decorator, so no args here
    train()
138 |
139 |
140 | if __name__ == "__main__":
141 | main()
142 |
--------------------------------------------------------------------------------
/relik/reader/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/reader/utils/__init__.py
--------------------------------------------------------------------------------
/relik/reader/utils/metrics.py:
--------------------------------------------------------------------------------
def safe_divide(num: float, den: float) -> float:
    """Divide ``num`` by ``den``, returning ``0.0`` when ``den`` is zero.

    Args:
        num: Numerator.
        den: Denominator.

    Returns:
        ``num / den``, or ``0.0`` if ``den == 0``.
    """
    if den == 0:
        # float zero keeps the return type consistent with the annotation
        return 0.0
    return num / den
6 |
7 |
def f1_measure(precision: float, recall: float) -> float:
    """Return the harmonic mean of ``precision`` and ``recall``.

    The result is ``0.0`` as soon as either input equals zero.
    """
    if precision != 0 and recall != 0:
        numerator = 2 * precision * recall
        denominator = precision + recall
        return safe_divide(numerator, denominator)
    return 0.0
12 |
13 |
def compute_metrics(total_correct, total_preds, total_gold):
    """Compute precision, recall and F1 from raw counts.

    Args:
        total_correct: Count used as numerator for both ratios.
        total_preds: Denominator of precision.
        total_gold: Denominator of recall.

    Returns:
        A ``(precision, recall, f1)`` tuple.
    """
    precision, recall = (
        safe_divide(total_correct, denominator)
        for denominator in (total_preds, total_gold)
    )
    return precision, recall, f1_measure(precision, recall)
19 |
--------------------------------------------------------------------------------
/relik/reader/utils/save_load_utilities.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from typing import Tuple
4 |
5 | import omegaconf
6 | import torch
7 |
8 | from relik.common.utils import from_cache
9 | from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule
10 |
11 | CKPT_FILE_NAME = "model.ckpt"
12 | CONFIG_FILE_NAME = "cfg.yaml"
13 |
14 |
def convert_pl_module(pl_module_ckpt_path: str, output_dir: str) -> None:
    """Export a Lightning checkpoint as a bare model plus its config file.

    Args:
        pl_module_ckpt_path: Checkpoint loaded with
            ``RelikReaderPLModule.load_from_checkpoint``.
        output_dir: Directory to create; it receives ``CKPT_FILE_NAME`` (the
            core model saved with ``torch.save``) and ``CONFIG_FILE_NAME``
            (the ``cfg`` hyper-parameter as YAML).

    Raises:
        SystemExit: With status 1 when ``output_dir`` already exists.
    """
    if os.path.exists(output_dir):
        print(f"{output_dir} already exists, aborting operation")
        # `exit()` is an interactive helper installed by the `site` module and
        # is not guaranteed to exist; raising SystemExit is equivalent.
        raise SystemExit(1)
    os.makedirs(output_dir)

    relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint(
        pl_module_ckpt_path
    )
    torch.save(
        relik_pl_module.relik_reader_core_model, f"{output_dir}/{CKPT_FILE_NAME}"
    )
    with open(f"{output_dir}/{CONFIG_FILE_NAME}", "w") as f:
        omegaconf.OmegaConf.save(
            omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f
        )
32 |
33 |
def load_model_and_conf(
    model_dir_path: str,
) -> Tuple[torch.nn.Module, omegaconf.DictConfig]:
    """Load the saved model object and its OmegaConf configuration.

    Args:
        model_dir_path: Identifier handed to ``from_cache`` to resolve the
            directory holding ``CKPT_FILE_NAME`` and ``CONFIG_FILE_NAME``.

    Returns:
        The ``(model, config)`` pair; the model is mapped to CPU on load.
    """
    # TODO: quick workaround to load the model from HF hub
    resolved_dir = from_cache(
        model_dir_path,
        filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME],
        cache_dir=None,
        force_download=False,
    )

    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    model = torch.load(
        f"{resolved_dir}/{CKPT_FILE_NAME}", map_location=torch.device("cpu")
    )
    model_conf = omegaconf.OmegaConf.load(f"{resolved_dir}/{CONFIG_FILE_NAME}")
    return model, model_conf
51 |
52 |
def parse_arg() -> argparse.Namespace:
    """Parse the checkpoint-conversion command line.

    Returns:
        Namespace with the required ``ckpt`` and ``output_dir`` options.
    """
    arg_parser = argparse.ArgumentParser()
    # (flags, help text) for each required option, in declaration order
    option_specs = (
        (("--ckpt",), "Path to the pytorch lightning ckpt you want to convert."),
        (
            ("--output-dir", "-o"),
            "The output dir to store the bare models and the config.",
        ),
    )
    for flags, help_text in option_specs:
        arg_parser.add_argument(*flags, help=help_text, required=True)
    return arg_parser.parse_args()
67 |
68 |
def main():
    """Script entry point: convert the checkpoint named on the command line."""
    cli = parse_arg()
    convert_pl_module(pl_module_ckpt_path=cli.ckpt, output_dir=cli.output_dir)
72 |
73 |
74 | if __name__ == "__main__":
75 | main()
76 |
--------------------------------------------------------------------------------
/relik/reader/utils/shuffle_train_callback.py:
--------------------------------------------------------------------------------
1 | from lightning.pytorch.callbacks import Callback
2 |
3 | from relik.common.log import get_logger
4 |
5 | import os
6 |
7 | import random
8 |
9 | logger = get_logger()
10 |
11 |
class ShuffleTrainCallback(Callback):
    """Lightning callback that shuffles the lines of the training file in place.

    Args:
        shuffle_every: Shuffle after every ``shuffle_every``-th epoch.
        data_path: Path of the file whose lines are shuffled and rewritten.
    """

    def __init__(self, shuffle_every: int = 1, data_path: str = None):
        self.shuffle_every = shuffle_every
        self.data_path = data_path

    def on_train_epoch_end(self, trainer, pl_module):
        """Rewrite ``data_path`` with its lines in random order when due."""
        if (trainer.current_epoch + 1) % self.shuffle_every == 0:
            logger.info("Shuffling train dataset")
            # Pure-Python replacement for `shuf file > file.shuf && mv ...`.
            # Context managers make sure both handles are closed (and the
            # rewritten content flushed) before training continues.
            with open(self.data_path) as src:
                lines = src.readlines()
            random.shuffle(lines)
            with open(self.data_path, "w") as dst:
                dst.writelines(lines)
25 |
--------------------------------------------------------------------------------
/relik/reader/utils/special_symbols.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | NME_SYMBOL = "--NME--"
4 |
5 |
def get_special_symbols(num_entities: int) -> List[str]:
    """Return ``NME_SYMBOL`` followed by ``[E-0]`` ... ``[E-{num_entities-1}]``."""
    symbols = [NME_SYMBOL]
    symbols.extend(f"[E-{i}]" for i in range(num_entities))
    return symbols
8 |
9 |
def get_special_symbols_re(num_entities: int, use_nme: bool = False) -> List[str]:
    """Return the ``[R-i]`` markers, optionally preceded by ``NME_SYMBOL``.

    Args:
        num_entities: Number of ``[R-i]`` markers to generate.
        use_nme: When true, ``NME_SYMBOL`` is placed first in the list.
    """
    relation_markers = [f"[R-{i}]" for i in range(num_entities)]
    if not use_nme:
        return relation_markers
    return [NME_SYMBOL] + relation_markers
15 |
--------------------------------------------------------------------------------
/relik/retriever/__init__.py:
--------------------------------------------------------------------------------
1 | from relik.retriever.pytorch_modules.model import GoldenRetriever
2 |
--------------------------------------------------------------------------------
/relik/retriever/callbacks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/callbacks/__init__.py
--------------------------------------------------------------------------------
/relik/retriever/common/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/common/__init__.py
--------------------------------------------------------------------------------
/relik/retriever/common/model_inputs.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from collections import UserDict
4 | from typing import Any, Union
5 |
6 | import torch
7 | from lightning.fabric.utilities import move_data_to_device
8 |
9 | from relik.common.log import get_logger
10 |
11 | logger = get_logger(__name__)
12 |
13 |
class ModelInputs(UserDict):
    """Model input dictionary wrapper.

    Entries of the underlying ``data`` dict can be read both by key
    (``inputs["x"]``) and as attributes (``inputs.x``).
    """

    def __getattr__(self, item: str):
        # __getattr__ only runs when normal lookup fails. If `data` itself is
        # missing (e.g. an instance created with __new__ and not yet
        # initialized), touching self.data below would re-enter this method
        # forever, so fail fast with AttributeError instead.
        if item == "data":
            raise AttributeError(f"`ModelInputs` has no attribute `{item}`")
        try:
            return self.data[item]
        except KeyError:
            raise AttributeError(f"`ModelInputs` has no attribute `{item}`")

    def __getitem__(self, item: str) -> Any:
        return self.data[item]

    def __getstate__(self):
        # only the wrapped dict is serialized
        return {"data": self.data}

    def __setstate__(self, state):
        if "data" in state:
            self.data = state["data"]

    def keys(self):
        """A set-like object providing a view on D's keys."""
        return self.data.keys()

    def values(self):
        """An object providing a view on D's values."""
        return self.data.values()

    def items(self):
        """A set-like object providing a view on D's items."""
        return self.data.items()

    def to(self, device: Union[str, torch.device]) -> "ModelInputs":
        """
        Send the wrapped data to ``device`` via ``move_data_to_device``.
        Args:
            device (`str` or `torch.device`): The device to put the tensors on.
        Returns:
            :class:`ModelInputs`: The same instance after modification.
        """
        self.data = move_data_to_device(self.data, device)
        return self
56 |
--------------------------------------------------------------------------------
/relik/retriever/common/sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from torch.utils.data.sampler import BatchSampler, Sampler, SubsetRandomSampler
4 |
5 |
def identity(x):
    """Return the argument unchanged (default sort key for the samplers)."""
    unchanged = x
    return unchanged
8 |
9 |
class SortedSampler(Sampler):
    """
    Yields the indexes of ``data`` ordered by ``sort_key``; the order is
    computed once at construction and is the same on every iteration.

    Args:
        data (`obj`: `Iterable`):
            Iterable data.
        sort_key (`obj`: `Callable`):
            Function of one argument that extracts the comparison key
            from each element.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

    """

    def __init__(self, data, sort_key=identity):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        # compute each key exactly once, then order positions by their key
        # (sorted() is stable, so ties keep their original order)
        keys = [self.sort_key(row) for row in self.data]
        self.sorted_indexes = sorted(range(len(keys)), key=keys.__getitem__)

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)
40 |
41 |
class BucketBatchSampler(BatchSampler):
    """
    Batch sampler that groups elements of similar sort key.

    Indexes drawn from `sampler` are first collected into buckets of
    `batch_size * bucket_size_multiplier` elements. Each bucket is sorted by
    `sort_key`, cut into batches, and the batches of the bucket are emitted in
    random order. A larger `bucket_size_multiplier` gives more sorted batches
    and vice versa.

    Background:
        Similar in spirit to the ``BucketIterator`` of ``AllenNLP`` and
        ``torchtext``, which pools examples of similar length to reduce
        padding while keeping some noise through bucketing.
        **AllenNLP Implementation:**
        https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py
        **torchtext Implementation:**
        https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225

    Args:
        sampler (`obj`: `torch.utils.data.sampler.Sampler`):
            Source of the indexes to batch.
        batch_size (`int`):
            Size of mini-batch.
        drop_last (`bool`, optional, defaults to `False`):
            If `True` the sampler will drop the last batch if its size would be less than `batch_size`.
        sort_key (`obj`: `Callable`, optional, defaults to `identity`):
            Callable to specify a comparison key for sorting.
        bucket_size_multiplier (`int`, optional, defaults to `100`):
            Buckets are of size `batch_size * bucket_size_multiplier`.
    Example:
        >>> from torchnlp.random import set_seed
        >>> set_seed(123)
        >>>
        >>> from torch.utils.data.sampler import SequentialSampler
        >>> sampler = SequentialSampler(list(range(10)))
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False))
        [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]]
        >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    """

    def __init__(
        self,
        sampler,
        batch_size,
        drop_last: bool = False,
        sort_key=identity,
        bucket_size_multiplier=100,
    ):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        bucket_size = batch_size * bucket_size_multiplier
        # never ask for a bucket larger than the sampler when its size is known
        if hasattr(sampler, "__len__"):
            bucket_size = min(bucket_size, len(sampler))
        self.bucket_sampler = BatchSampler(sampler, bucket_size, False)

    def __iter__(self):
        for bucket in self.bucket_sampler:
            # positions inside the bucket, ordered by sort key
            in_bucket_order = SortedSampler(bucket, self.sort_key)
            position_batches = list(
                BatchSampler(in_bucket_order, self.batch_size, self.drop_last)
            )
            # SubsetRandomSampler yields the given batches in random order
            for positions in SubsetRandomSampler(position_batches):
                yield [bucket[pos] for pos in positions]

    def __len__(self):
        total = len(self.sampler)
        if not self.drop_last:
            return math.ceil(total / self.batch_size)
        return total // self.batch_size
109 |
--------------------------------------------------------------------------------
/relik/retriever/conf/data/aida_dataset.yaml:
--------------------------------------------------------------------------------
# Data configuration for the AIDA in-batch-negatives datasets.
# Dataset locations are null here and are referenced below through
# ${data.*_dataset_path}; presumably they are supplied by an override.
train_dataset_path: null
val_dataset_path: null
test_dataset_path: null

# Values reused by the dataset entries below via ${data.shared_params.*}.
shared_params:
  documents_path: null
  max_passage_length: 64
  passage_batch_size: 64
  question_batch_size: 64
  use_topics: True

datamodule:
  _target_: relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule
  datasets:
    # single training dataset
    train:
      _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset
      name: "train"
      path: ${data.train_dataset_path}
      tokenizer: ${model.language_model}
      max_passage_length: ${data.shared_params.max_passage_length}
      question_batch_size: ${data.shared_params.question_batch_size}
      passage_batch_size: ${data.shared_params.passage_batch_size}
      subsample_strategy: null
      subsample_portion: 0.1
      shuffle: True
      metadata_fields: ['definition']
      metadata_separator: ' '
      use_topics: ${data.shared_params.use_topics}

    # val/test are lists, so several datasets can be declared per split
    val:
      - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset
        name: "val"
        path: ${data.val_dataset_path}
        tokenizer: ${model.language_model}
        max_passage_length: ${data.shared_params.max_passage_length}
        question_batch_size: ${data.shared_params.question_batch_size}
        passage_batch_size: ${data.shared_params.passage_batch_size}
        metadata_fields: ['definition']
        metadata_separator: ' '
        use_topics: ${data.shared_params.use_topics}

    test:
      - _target_: relik.retriever.data.datasets.AidaInBatchNegativesDataset
        name: "test"
        path: ${data.test_dataset_path}
        tokenizer: ${model.language_model}
        max_passage_length: ${data.shared_params.max_passage_length}
        question_batch_size: ${data.shared_params.question_batch_size}
        passage_batch_size: ${data.shared_params.passage_batch_size}
        metadata_fields: ['definition']
        metadata_separator: ' '
        use_topics: ${data.shared_params.use_topics}

  # worker count per split
  num_workers:
    train: 4
    val: 4
    test: 4
58 |
--------------------------------------------------------------------------------
/relik/retriever/conf/data/blink_dataset.yaml:
--------------------------------------------------------------------------------
1 | train_dataset_path: null
2 | val_dataset_path: null
3 | test_dataset_path: null
4 |
5 | shared_params:
6 | documents_path: null
7 | max_passage_length: 64
8 | passage_batch_size: 64
9 | question_batch_size: 64
10 |
11 | datamodule:
12 | _target_: relik.retriever.lightning_modules.pl_data_modules.GoldenRetrieverPLDataModule
13 | datasets:
14 | train:
15 | _target_: relik.retriever.data.datasets.InBatchNegativesDataset
16 | name: "train"
17 | path: ${data.train_dataset_path}
18 | tokenizer: ${model.language_model}
19 | max_passage_length: ${data.shared_params.max_passage_length}
20 | question_batch_size: ${data.shared_params.question_batch_size}
21 | passage_batch_size: ${data.shared_params.passage_batch_size}
22 | subsample_strategy: random
23 | subsample_portion: 0.1
24 | metadata_fields: ['definition']
25 | metadata_separator: ' '
26 | shuffle: True
27 |
28 | val:
29 | - _target_: relik.retriever.data.datasets.InBatchNegativesDataset
30 | name: "val"
31 | path: ${data.val_dataset_path}
32 | tokenizer: ${model.language_model}
33 | max_passage_length: ${data.shared_params.max_passage_length}
34 | question_batch_size: ${data.shared_params.question_batch_size}
35 | metadata_fields: ['definition']
36 | metadata_separator: ' '
37 | passage_batch_size: ${data.shared_params.passage_batch_size}
38 |
39 | test:
40 | - _target_: relik.retriever.data.datasets.InBatchNegativesDataset
41 | name: "test"
42 | path: ${data.test_dataset_path}
43 | tokenizer: ${model.language_model}
44 | max_passage_length: ${data.shared_params.max_passage_length}
45 | question_batch_size: ${data.shared_params.question_batch_size}
46 | passage_batch_size: ${data.shared_params.passage_batch_size}
47 | metadata_fields: ['definition']
48 | metadata_separator: ' '
49 |
50 | num_workers:
51 | train: 0
52 | val: 0
53 | test: 0
54 |
--------------------------------------------------------------------------------
/relik/retriever/conf/finetune_iterable_in_batch.yaml:
--------------------------------------------------------------------------------
1 | # Required to make the "experiments" dir the default one for the output of the models
2 | hydra:
3 | run:
4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5 |
6 | model_name: ${model.language_model} # used to name the model in wandb
7 | project_name: relik-retriever-aida # used to name the project in wandb
8 |
9 | defaults:
10 | - _self_
11 | - model: golden_retriever
12 | - index: inmemory
13 | - loss: nce_loss
14 | - optimizer: radamw
15 | - scheduler: linear_scheduler
16 | - data: aida_dataset
17 | - logging: wandb_logging
18 | - override hydra/job_logging: colorlog
19 | - override hydra/hydra_logging: colorlog
20 |
21 | train:
22 | # reproducibility
23 | seed: 42
24 | set_determinism_the_old_way: False
25 | # torch parameters
26 | float32_matmul_precision: "high"
27 | # if true, only test the model
28 | only_test: False
29 | # if provided, initialize the model with the weights from the checkpoint
30 | pretrain_ckpt_path: null
31 | # if provided, start training from the checkpoint
32 | checkpoint_path: null
33 |
34 | # task specific parameter
35 | top_k: 100
36 |
37 | # pl_trainer
38 | pl_trainer:
39 | _target_: lightning.Trainer
40 | accelerator: gpu
41 | devices: 1
42 | num_nodes: 1
43 | strategy: auto
44 | accumulate_grad_batches: 1
45 | gradient_clip_val: 1.0
46 | val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps
47 | check_val_every_n_epoch: 1
48 | max_epochs: 0
49 | max_steps: 25_000
50 | deterministic: True
51 | fast_dev_run: False
52 | precision: 16
53 | reload_dataloaders_every_n_epochs: 1
54 |
55 | early_stopping_callback:
56 | # null
57 | _target_: lightning.pytorch.callbacks.EarlyStopping
58 | monitor: validate_recall@${train.top_k}
59 | mode: max
60 | patience: 3
61 |
62 | model_checkpoint_callback:
63 | _target_: lightning.pytorch.callbacks.ModelCheckpoint
64 | monitor: validate_recall@${train.top_k}
65 | mode: max
66 | verbose: True
67 | save_top_k: 1
68 | save_last: False
69 | filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}"
70 | auto_insert_metric_name: False
71 |
72 | callbacks:
73 | prediction_callback:
74 | _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback
75 | k: ${train.top_k}
76 | batch_size: 64
77 | precision: 16
78 | index_precision: 16
79 | other_callbacks:
80 | - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
81 | k: ${train.top_k}
82 | verbose: True
83 | - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
84 | k: 50
85 | verbose: True
86 | prog_bar: False
87 | - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
88 | k: ${train.top_k}
89 | verbose: True
90 | - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback
91 |
92 | hard_negatives_callback:
93 | _target_: relik.retriever.callbacks.training_callbacks.NegativeAugmentationCallback
94 | k: ${train.top_k}
95 | batch_size: 64
96 | precision: 16
97 | index_precision: 16
98 | stages: [validate] #[validate, sanity_check]
99 | metrics_to_monitor:
100 | validate_recall@${train.top_k}
101 | # - sanity_check_recall@${train.top_k}
102 | threshold: 0.0
103 | max_negatives: 20
104 | add_with_probability: 1.0
105 | refresh_every_n_epochs: 1
106 | other_callbacks:
107 | - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
108 | k: ${train.top_k}
109 | verbose: True
110 | prefix: "train"
111 |
112 | utils_callbacks:
113 | - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback
114 | - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback
115 | # - _target_: relik.retriever.callbacks.utils_callbacks.ResetModelCallback
116 | # question_encoder: ${model.pl_module.model.question_encoder}
117 | # passage_encoder: ${model.pl_module.model.passage_encoder}
118 |
--------------------------------------------------------------------------------
/relik/retriever/conf/index/inmemory.yaml:
--------------------------------------------------------------------------------
1 | _target_: relik.retriever.indexers.inmemory.InMemoryDocumentIndex
2 | documents:
3 | _target_: relik.retriever.indexers.document.DocumentStore.from_file
4 | file_path: ${data.shared_params.documents_path}
5 | device: "cuda"
6 | precision: "16"
7 |
--------------------------------------------------------------------------------
/relik/retriever/conf/logging/wandb_logging.yaml:
--------------------------------------------------------------------------------
1 | # Remember to authenticate with Weights & Biases (e.g. `wandb login`) before the first run.
2 |
3 | log: True # set to False to avoid the logging
4 |
5 | wandb_arg:
6 | _target_: lightning.pytorch.loggers.WandbLogger
7 | name: ${model_name}
8 | project: ${project_name}
9 | save_dir: ./
10 | log_model: True
11 | mode: "online"
12 | entity: null
13 |
14 | watch:
15 | log: "all"
16 | log_freq: 100
17 |
--------------------------------------------------------------------------------
/relik/retriever/conf/loss/nce_loss.yaml:
--------------------------------------------------------------------------------
1 | _target_: relik.retriever.pytorch_modules.loss.MultiLabelNCELoss
2 |
--------------------------------------------------------------------------------
/relik/retriever/conf/loss/nll_loss.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.nn.NLLLoss
2 |
--------------------------------------------------------------------------------
/relik/retriever/conf/model/golden_retriever.yaml:
--------------------------------------------------------------------------------
1 | language_model: "intfloat/e5-small-v2"
2 |
3 | pl_module:
4 | _target_: relik.retriever.lightning_modules.pl_modules.GoldenRetrieverPLModule
5 | model:
6 | _target_: relik.retriever.pytorch_modules.model.GoldenRetriever
7 | question_encoder: ${model.language_model}
8 | passage_encoder: ${model.language_model}
9 | document_index: ${index}
10 | loss_type: ${loss}
11 | optimizer: ${optimizer}
12 | lr_scheduler: ${scheduler}
13 |
--------------------------------------------------------------------------------
/relik/retriever/conf/optimizer/adamw.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.AdamW
2 | lr: 1e-5
3 | weight_decay: 0.01
4 | fused: False
5 |
--------------------------------------------------------------------------------
/relik/retriever/conf/optimizer/radam.yaml:
--------------------------------------------------------------------------------
1 | _target_: torch.optim.RAdam
2 | lr: 1e-5
3 | weight_decay: 0
4 |
--------------------------------------------------------------------------------
/relik/retriever/conf/optimizer/radamw.yaml:
--------------------------------------------------------------------------------
1 | _target_: relik.retriever.pytorch_modules.optim.RAdamW
2 | lr: 1e-5
3 | weight_decay: 0.01
4 |
--------------------------------------------------------------------------------
/relik/retriever/conf/pretrain_iterable_in_batch.yaml:
--------------------------------------------------------------------------------
1 | # Required to make the "experiments" dir the default one for the output of the models
2 | hydra:
3 | run:
4 | dir: ./experiments/${model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S}
5 |
6 | model_name: ${model.language_model} # used to name the model in wandb
7 | project_name: relik-retriever # used to name the project in wandb
8 |
9 | defaults:
10 | - _self_
11 | - model: golden_retriever
12 | - index: inmemory
13 | - loss: nce_loss
14 | - optimizer: radamw
15 | - scheduler: linear_scheduler
16 | - data: blink_dataset
17 | - logging: wandb_logging
18 | - override hydra/job_logging: colorlog
19 | - override hydra/hydra_logging: colorlog
20 |
21 | train:
22 | # reproducibility
23 | seed: 42
24 | set_determinism_the_old_way: False
25 | # torch parameters
26 | float32_matmul_precision: "medium"
27 | # if true, only test the model
28 | only_test: False
29 | # if provided, initialize the model with the weights from the checkpoint
30 | pretrain_ckpt_path: null
31 | # if provided, start training from the checkpoint
32 | checkpoint_path: null
33 |
34 | # task specific parameter
35 | top_k: 100
36 |
37 | # pl_trainer
38 | pl_trainer:
39 | _target_: lightning.Trainer
40 | accelerator: gpu
41 | devices: 1
42 | num_nodes: 1
43 | strategy: auto
44 | accumulate_grad_batches: 1
45 | gradient_clip_val: 1.0
46 | val_check_interval: 1.0 # you can specify an int "n" here => validation every "n" steps
47 | check_val_every_n_epoch: 1
48 | max_epochs: 0
49 | max_steps: 220_000
50 | deterministic: True
51 | fast_dev_run: False
52 | precision: 16
53 | reload_dataloaders_every_n_epochs: 1
54 |
55 | early_stopping_callback:
56 | null
57 | # _target_: lightning.pytorch.callbacks.EarlyStopping
58 | # monitor: validate_recall@${train.top_k}
59 | # mode: max
60 | # patience: 15
61 |
62 | model_checkpoint_callback:
63 | _target_: lightning.pytorch.callbacks.ModelCheckpoint
64 | monitor: validate_recall@${train.top_k}
65 | mode: max
66 | verbose: True
67 | save_top_k: 1
68 | save_last: True
69 | filename: "checkpoint-validate_recall@${train.top_k}_{validate_recall@${train.top_k}:.4f}-epoch_{epoch:02d}"
70 | auto_insert_metric_name: False
71 |
72 | callbacks:
73 | prediction_callback:
74 | _target_: relik.retriever.callbacks.prediction_callbacks.GoldenRetrieverPredictionCallback
75 | k: ${train.top_k}
76 | batch_size: 128
77 | precision: 16
78 | index_precision: 16
79 | other_callbacks:
80 | - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
81 | k: ${train.top_k}
82 | verbose: True
83 | - _target_: relik.retriever.callbacks.evaluation_callbacks.RecallAtKEvaluationCallback
84 | k: 50
85 | verbose: True
86 | prog_bar: False
87 | - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
88 | k: ${train.top_k}
89 | verbose: True
90 | - _target_: relik.retriever.callbacks.utils_callbacks.SavePredictionsCallback
91 |
hard_negatives_callback:
  # `_target_` tells Hydra which class to instantiate; without it this entry
  # is returned as a plain mapping instead of a callback object. The value
  # mirrors the one used in finetune_iterable_in_batch.yaml.
  _target_: relik.retriever.callbacks.training_callbacks.NegativeAugmentationCallback
  k: ${train.top_k}
  batch_size: 128
  precision: 16
  index_precision: 16
  stages: [validate] #[validate, sanity_check]
  metrics_to_monitor:
    validate_recall@${train.top_k}
    # - sanity_check_recall@${train.top_k}
  threshold: 0.0
  max_negatives: 15
  add_with_probability: 0.2
  refresh_every_n_epochs: 1
  other_callbacks:
    - _target_: relik.retriever.callbacks.evaluation_callbacks.AvgRankingEvaluationCallback
      k: ${train.top_k}
      verbose: True
      prefix: "train"
110 |
111 | utils_callbacks:
112 | - _target_: relik.retriever.callbacks.utils_callbacks.SaveRetrieverCallback
113 | - _target_: relik.retriever.callbacks.utils_callbacks.FreeUpIndexerVRAMCallback
114 |
--------------------------------------------------------------------------------
/relik/retriever/conf/scheduler/linear_scheduler.yaml:
--------------------------------------------------------------------------------
1 | _target_: transformers.get_linear_schedule_with_warmup
2 | num_warmup_steps: 0
3 | num_training_steps: ${train.pl_trainer.max_steps}
4 |
--------------------------------------------------------------------------------
/relik/retriever/conf/scheduler/linear_scheduler_with_warmup.yaml:
--------------------------------------------------------------------------------
1 | _target_: transformers.get_linear_schedule_with_warmup
2 | num_warmup_steps: 5_000
3 | num_training_steps: ${train.pl_trainer.max_steps}
4 |
--------------------------------------------------------------------------------
/relik/retriever/conf/scheduler/none.yaml:
--------------------------------------------------------------------------------
1 | null
--------------------------------------------------------------------------------
/relik/retriever/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/data/__init__.py
--------------------------------------------------------------------------------
/relik/retriever/data/base/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/data/base/__init__.py
--------------------------------------------------------------------------------
/relik/retriever/data/base/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
4 |
5 | import torch
6 | from torch.utils.data import Dataset, IterableDataset
7 |
8 | from relik.common.log import get_logger
9 |
10 | logger = get_logger(__name__)
11 |
12 |
class BaseDataset(Dataset):
    """Map-style dataset skeleton holding a name, an optional path and the samples.

    ``load`` and ``collate_fn`` are placeholders that raise
    ``NotImplementedError``; concrete datasets are expected to override them.

    Args:
        name: Identifier of the dataset (e.g. the split name).
        path: One or more locations the data can be read from.
        data: Already materialised samples.
        **kwargs: Accepted and ignored.

    Raises:
        ValueError: If neither ``path`` nor ``data`` is given.
    """

    def __init__(
        self,
        name: str,
        path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
        data: Any = None,
        **kwargs,
    ):
        # A dataset without any source of samples is unusable: fail early.
        if data is None and path is None:
            raise ValueError("Either `path` or `data` must be provided")
        super().__init__()
        self.name = name
        self.path = path
        # Directory three levels above the one containing this file.
        here = Path(__file__)
        self.project_folder = here.parent.parent.parent
        self.data = data

    def __len__(self) -> int:
        """Return the number of stored samples."""
        samples = self.data
        return len(samples)

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """Return the sample stored at ``index``."""
        samples = self.data
        return samples[index]

    def __repr__(self) -> str:
        return f"Dataset({self.name=}, {self.path=})"

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        *args,
        **kwargs,
    ) -> Any:
        """Read one or several paths into a single dataset (to be overridden)."""
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        """Assemble a batch from individual samples (to be overridden)."""
        raise NotImplementedError
52 |
53 |
class IterableBaseDataset(IterableDataset):
    """Iterable-style dataset skeleton holding a name, an optional path and the samples.

    ``load`` and ``collate_fn`` are placeholders that raise
    ``NotImplementedError``; concrete datasets are expected to override them.

    Args:
        name: Identifier of the dataset (e.g. the split name).
        path: One or more locations the data can be read from.
        data: An iterable of already materialised samples.
        *args, **kwargs: Accepted and ignored.

    Raises:
        ValueError: If neither ``path`` nor ``data`` is given.
    """

    def __init__(
        self,
        name: str,
        path: Optional[Union[str, Path, List[str], List[Path]]] = None,
        data: Any = None,
        *args,
        **kwargs,
    ):
        # A dataset without any source of samples is unusable: fail early.
        if data is None and path is None:
            raise ValueError("Either `path` or `data` must be provided")
        super().__init__()
        self.name = name
        self.path = path
        # Directory three levels above the one containing this file.
        here = Path(__file__)
        self.project_folder = here.parent.parent.parent
        self.data = data

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Yield the stored samples one by one, in order."""
        yield from self.data

    def __repr__(self) -> str:
        return f"Dataset({self.name=}, {self.path=})"

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        *args,
        **kwargs,
    ) -> Any:
        """Read one or several paths into a single dataset (to be overridden)."""
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        """Assemble a batch from individual samples (to be overridden)."""
        raise NotImplementedError
90 |
--------------------------------------------------------------------------------
/relik/retriever/indexers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/indexers/__init__.py
--------------------------------------------------------------------------------
/relik/retriever/lightning_modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/relik/retriever/lightning_modules/__init__.py
--------------------------------------------------------------------------------
/relik/retriever/lightning_modules/pl_data_modules.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional, Sequence, Union
2 |
3 | import hydra
4 | import lightning as pl
5 | import torch
6 | from lightning.pytorch.utilities.types import EVAL_DATALOADERS
7 | from omegaconf import DictConfig
8 | from torch.utils.data import DataLoader
9 |
10 | from relik.common.log import get_logger
11 | from relik.retriever.data.datasets import GoldenRetrieverDataset
12 |
13 | logger = get_logger(__name__)
14 |
15 |
class GoldenRetrieverPLDataModule(pl.LightningDataModule):
    """Lightning data module that serves the retriever's train/val/test datasets.

    Datasets can be passed ready-made (``train_dataset``, ``val_datasets``,
    ``test_datasets``) or described by a Hydra config (``datasets`` with
    ``train``, ``val`` and ``test`` entries) that is instantiated in
    :meth:`setup`. Ready-made datasets always take precedence over the config.

    Args:
        train_dataset: Training dataset, if already built.
        val_datasets: Validation datasets, if already built.
        test_datasets: Test datasets, if already built.
        num_workers: Either a single int used for every split or a config
            with ``train``, ``val`` and ``test`` entries. ``None`` means 0.
        datasets: Hydra config used to build whatever dataset is missing.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        train_dataset: Optional[GoldenRetrieverDataset] = None,
        val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None,
        test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = None,
        num_workers: Optional[Union[DictConfig, int]] = None,
        datasets: Optional[DictConfig] = None,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.datasets = datasets
        if num_workers is None:
            num_workers = 0
        if isinstance(num_workers, int):
            # Normalise to the per-split form so the dataloaders can always
            # read `.train`, `.val` and `.test`.
            num_workers = DictConfig(
                {"train": num_workers, "val": num_workers, "test": num_workers}
            )
        self.num_workers = num_workers
        # data
        self.train_dataset: Optional[GoldenRetrieverDataset] = train_dataset
        self.val_datasets: Optional[Sequence[GoldenRetrieverDataset]] = val_datasets
        self.test_datasets: Optional[Sequence[GoldenRetrieverDataset]] = test_datasets

    def prepare_data(self, *args, **kwargs):
        """
        Method for preparing the data before the training. This method is called only once.
        It is used to download the data, tokenize the data, etc.
        """
        pass

    def setup(self, stage: Optional[str] = None):
        """Instantiate from ``self.datasets`` every dataset that was not supplied.

        Each split is checked independently, so datasets passed to the
        constructor are never overwritten and a missing validation set is
        still built when a training set was given explicitly.
        """
        if stage == "fit" or stage is None:
            # usually there is only one dataset for train
            # if you need more train loader, you can follow
            # the same logic as val and test datasets
            if self.train_dataset is None:
                self.train_dataset = hydra.utils.instantiate(self.datasets.train)
            if self.val_datasets is None:
                self.val_datasets = [
                    hydra.utils.instantiate(dataset_cfg)
                    for dataset_cfg in self.datasets.val
                ]
        if stage == "test":
            if self.test_datasets is None:
                self.test_datasets = [
                    hydra.utils.instantiate(dataset_cfg)
                    for dataset_cfg in self.datasets.test
                ]

    @staticmethod
    def _identity_collate(batch: Any) -> Any:
        """Return the batch untouched.

        A named function (rather than a ``lambda``) so the dataloader can be
        pickled when worker processes are started with the ``spawn`` method.
        """
        return batch

    def _build_dataloader(
        self, dataset: GoldenRetrieverDataset, num_workers: int
    ) -> DataLoader:
        """Wrap ``dataset`` in a ``DataLoader`` with automatic batching disabled.

        ``batch_size=None`` makes the loader forward whatever the torch
        dataset yields, passed through the identity collate function.
        """
        return DataLoader(
            dataset.to_torch_dataset(),
            shuffle=False,
            batch_size=None,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=self._identity_collate,
        )

    def train_dataloader(self, *args, **kwargs) -> DataLoader:
        """Return the dataloader over the training dataset."""
        return self._build_dataloader(self.train_dataset, self.num_workers.train)

    def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        """Return one dataloader per validation dataset."""
        return [
            self._build_dataloader(dataset, self.num_workers.val)
            for dataset in self.val_datasets
        ]

    def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]:
        """Return one dataloader per test dataset."""
        return [
            self._build_dataloader(dataset, self.num_workers.test)
            for dataset in self.test_datasets
        ]

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        raise NotImplementedError

    def transfer_batch_to_device(
        self, batch: Any, device: torch.device, dataloader_idx: int
    ) -> Any:
        return super().transfer_batch_to_device(batch, device, dataloader_idx)

    def __repr__(self) -> str:
        # Close the parenthesis so the representation is well formed.
        return (
            f"{self.__class__.__name__}("
            f"{self.datasets=}, "
            f"{self.num_workers=})"
        )
122 |
--------------------------------------------------------------------------------
/relik/retriever/lightning_modules/pl_modules.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Union
2 |
3 | import hydra
4 | import lightning as pl
5 | import torch
6 | from omegaconf import DictConfig
7 |
8 | from relik.retriever.common.model_inputs import ModelInputs
9 |
10 |
class GoldenRetrieverPLModule(pl.LightningModule):
    """Lightning wrapper around the retriever model.

    Args:
        model: The model itself, or a Hydra config that instantiates it.
        optimizer: An optimizer instance, or a Hydra config that builds one
            (in which case parameter groups are created automatically).
        lr_scheduler: Optional scheduler instance or Hydra config.
        *args, **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        model: Union[torch.nn.Module, DictConfig],
        optimizer: Union[torch.optim.Optimizer, DictConfig],
        lr_scheduler: Union[torch.optim.lr_scheduler.LRScheduler, DictConfig] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        # The model is excluded from the saved hyper-parameters.
        self.save_hyperparameters(ignore=["model"])
        self.model = (
            hydra.utils.instantiate(model) if isinstance(model, DictConfig) else model
        )
        self.optimizer_config = optimizer
        self.lr_scheduler_config = lr_scheduler

    def forward(self, **kwargs) -> dict:
        """
        Run the wrapped model.

        'training_step', 'validation_step' and 'test_step' call this method
        to obtain the predictions and the loss.

        Returns:
            output_dict: forward output containing the predictions (output logits etc.) and the loss if any.

        """
        return self.model(**kwargs)

    def training_step(self, batch: ModelInputs, batch_idx: int) -> torch.Tensor:
        """Compute, log (on the progress bar) and return the training loss."""
        outputs = self.forward(**batch, return_loss=True)
        n_questions = batch["questions"]["input_ids"].size(0)
        self.log("loss", outputs["loss"], batch_size=n_questions, prog_bar=True)
        return outputs["loss"]

    def validation_step(self, batch: ModelInputs, batch_idx: int) -> None:
        """Compute and log the validation loss."""
        outputs = self.forward(**batch, return_loss=True)
        n_questions = batch["questions"]["input_ids"].size(0)
        self.log("val_loss", outputs["loss"], batch_size=n_questions)

    def test_step(self, batch: ModelInputs, batch_idx: int) -> Any:
        """Compute and log the test loss."""
        outputs = self.forward(**batch, return_loss=True)
        n_questions = batch["questions"]["input_ids"].size(0)
        self.log("test_loss", outputs["loss"], batch_size=n_questions)

    def configure_optimizers(self):
        """Build the optimizer and, when configured, a per-step LR scheduler.

        When the optimizer is given as a config, parameters are split into
        three groups, in this order:

        1. parameters whose name contains ``layer_norm_layer``
           (configured weight decay, learning rate fixed to ``1e-4``);
        2. remaining parameters not matching the no-decay patterns
           (configured weight decay);
        3. remaining parameters matching the no-decay patterns
           (weight decay ``0.0``).
        """
        optimizer = self.optimizer_config
        if isinstance(optimizer, DictConfig):
            no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
            extra_norm_params, decayed_params, undecayed_params = [], [], []
            # Single pass over the parameters; every parameter lands in
            # exactly one bucket and keeps its relative order.
            for param_name, param in list(self.named_parameters()):
                if "layer_norm_layer" in param_name:
                    extra_norm_params.append(param)
                elif any(pattern in param_name for pattern in no_decay):
                    undecayed_params.append(param)
                else:
                    decayed_params.append(param)

            configured_decay = self.hparams.optimizer.weight_decay
            optimizer_grouped_parameters = [
                {
                    "params": extra_norm_params,
                    "weight_decay": configured_decay,
                    "lr": 1e-4,
                },
                {"params": decayed_params, "weight_decay": configured_decay},
                {"params": undecayed_params, "weight_decay": 0.0},
            ]
            optimizer = hydra.utils.instantiate(
                self.optimizer_config,
                params=optimizer_grouped_parameters,
                _convert_="partial",
            )

        scheduler = self.lr_scheduler_config
        if scheduler is None:
            return optimizer

        if isinstance(scheduler, DictConfig):
            scheduler = hydra.utils.instantiate(scheduler, optimizer=optimizer)

        # The scheduler is stepped after every optimisation step.
        return [optimizer], [
            {
                "scheduler": scheduler,
                "interval": "step",
                "frequency": 1,
            }
        ]
124 |
--------------------------------------------------------------------------------
/relik/retriever/pytorch_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | import torch
4 |
5 | from relik.retriever.indexers.document import Document
6 |
# Short handles for the three supported dtypes.
_FP32 = torch.float32
_FP16 = torch.float16
_BF16 = torch.bfloat16

# Lookup table from every accepted precision spelling to a torch dtype.
PRECISION_MAP = {
    # missing value / bit widths given as integers
    None: _FP32,
    32: _FP32,
    16: _FP16,
    # torch dtypes map to themselves
    _FP32: _FP32,
    _FP16: _FP16,
    _BF16: _BF16,
    # full dtype names
    "float32": _FP32,
    "float16": _FP16,
    "bfloat16": _BF16,
    # common aliases
    "float": _FP32,
    "half": _FP16,
    # bit widths given as strings
    "32": _FP32,
    "16": _FP16,
    # abbreviated names
    "fp32": _FP32,
    "fp16": _FP16,
    "bf16": _BF16,
}
25 |
26 |
@dataclass
class RetrievedSample:
    """
    A single retrieval hit: a document paired with its score.
    """

    # Score assigned to the document by the retriever.
    score: float
    # The retrieved document.
    document: Document
35 |
--------------------------------------------------------------------------------
/relik/retriever/pytorch_modules/hf.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple, Union
2 |
3 | import torch
4 | from transformers import PretrainedConfig
5 | from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
6 | from transformers.models.bert.modeling_bert import BertModel
7 |
8 |
class GoldenRetrieverConfig(PretrainedConfig):
    """Configuration for ``GoldenRetrieverModel``.

    Stores the encoder hyper-parameters passed to the constructor as
    attributes of the same name. ``pad_token_id`` and any extra keyword
    arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Attributes are set in a fixed order, one per constructor argument.
        encoder_settings = {
            "vocab_size": vocab_size,
            "hidden_size": hidden_size,
            "num_hidden_layers": num_hidden_layers,
            "num_attention_heads": num_attention_heads,
            "hidden_act": hidden_act,
            "intermediate_size": intermediate_size,
            "hidden_dropout_prob": hidden_dropout_prob,
            "attention_probs_dropout_prob": attention_probs_dropout_prob,
            "max_position_embeddings": max_position_embeddings,
            "type_vocab_size": type_vocab_size,
            "initializer_range": initializer_range,
            "layer_norm_eps": layer_norm_eps,
            "position_embedding_type": position_embedding_type,
            "use_cache": use_cache,
            "classifier_dropout": classifier_dropout,
        }
        for setting_name, setting_value in encoder_settings.items():
            setattr(self, setting_name, setting_value)
49 |
50 |
class GoldenRetrieverModel(BertModel):
    """BERT encoder whose pooled output is a layer-normalised sentence vector.

    When an attention mask is supplied, the pooled output is the
    mask-weighted mean of the last hidden states; otherwise BERT's own
    pooler output is used. In both cases the vector then goes through an
    extra ``LayerNorm`` (``layer_norm_layer``).
    """

    config_class = GoldenRetrieverConfig

    def __init__(self, config, *args, **kwargs):
        super().__init__(config)
        self.layer_norm_layer = torch.nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )

    def forward(
        self, **kwargs
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Encode the inputs and replace the pooled output.

        Args:
            **kwargs: Keyword arguments forwarded to ``BertModel.forward``.
                ``attention_mask`` (optional) selects mean pooling;
                ``return_dict`` (optional, default ``True``) selects the
                return type, with ``None`` deferring to the model config.

        Returns:
            A ``BaseModelOutputWithPoolingAndCrossAttentions`` or, when
            ``return_dict`` is false, the equivalent tuple with the pooled
            output in second position.
        """
        return_dict = kwargs.get("return_dict", True)
        if return_dict is None:
            # In the Transformers convention `None` means "use the config default".
            return_dict = getattr(self.config, "use_return_dict", True)

        attention_mask = kwargs.get("attention_mask", None)

        # Always request a ModelOutput from BERT so the fields below can be
        # read by name; the tuple form, if requested, is rebuilt at the end.
        # Previously `return_dict=False` made BERT return a tuple and the
        # attribute accesses below raised AttributeError.
        bert_kwargs = dict(kwargs)
        bert_kwargs["return_dict"] = True
        model_outputs = super().forward(**bert_kwargs)

        if attention_mask is None:
            pooler_output = model_outputs.pooler_output
        else:
            token_embeddings = model_outputs.last_hidden_state
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            )
            # Mean over dimension 1, counting only unmasked positions; the
            # clamp avoids a division by zero for fully-masked rows.
            pooler_output = torch.sum(
                token_embeddings * input_mask_expanded, 1
            ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        pooler_output = self.layer_norm_layer(pooler_output)

        if not return_dict:
            # Same layout as BERT's tuple output, with the pooled vector swapped in.
            return (
                model_outputs.last_hidden_state,
                pooler_output,
            ) + model_outputs.to_tuple()[2:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=model_outputs.last_hidden_state,
            pooler_output=pooler_output,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
            cross_attentions=model_outputs.cross_attentions,
        )
89 |
--------------------------------------------------------------------------------
/relik/retriever/pytorch_modules/loss.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch.nn.modules.loss import _WeightedLoss
5 |
6 |
class MultiLabelNCELoss(_WeightedLoss):
    """Multi-label NCE-style loss over score vectors.

    Each positive label (non-zero entry of ``target``) is scored against
    the negatives of the same row only: its log-probability is
    ``score - logaddexp(score, logsumexp(negative scores))``. The loss is
    the negated sum of these log-probabilities.

    Args:
        weight: Stored by the parent class; not used in the computation.
        size_average: Legacy argument forwarded to the parent class.
        reduction: ``"mean"`` divides the total by ``input.size(0)``,
            ``"none"`` returns one value per row (the last dimension is
            reduced), any other value returns the total sum.
    """

    __constants__ = ["reduction"]

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        size_average=None,
        reduction: Optional[str] = "mean",
    ) -> None:
        super(MultiLabelNCELoss, self).__init__(weight, size_average, None, reduction)

    def forward(
        self, input: torch.Tensor, target: torch.Tensor, ignore_index: int = -100
    ) -> torch.Tensor:
        """Compute the loss.

        Args:
            input: Scores; labels are laid out along the last dimension.
            target: Same shape as ``input``; non-zero entries mark positives.
            ignore_index: Accepted for signature compatibility; not used.

        Returns:
            A scalar tensor, or a tensor of shape ``input.shape[:-1]`` when
            ``reduction == "none"``.
        """
        # Computed once and reused for every mask below.
        positives = target.bool()

        # Sum of the scores of the positive labels of each row.
        gold_scores_sum = input.masked_fill(~positives, 0).sum(-1)

        # Log-sum-exp over the negatives only (positives are pushed to -inf).
        neg_logits = input.masked_fill(positives, float("-inf"))
        neg_log_sum_exp = torch.logsumexp(neg_logits, -1, keepdim=True)

        # Per-positive normaliser: log(exp(score) + sum(exp(negatives))).
        norm_term = (
            torch.logaddexp(input, neg_log_sum_exp).masked_fill(~positives, 0).sum(-1)
        )

        per_row_loss = norm_term - gold_scores_sum
        if self.reduction == "none":
            # Previously "none" silently fell through to the summed loss.
            return per_row_loss

        loss = per_row_loss.sum()
        if self.reduction == "mean":
            loss = loss / input.size(0)
        return loss
35 |
--------------------------------------------------------------------------------
/relik/retriever/pytorch_modules/optim.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.optim import Optimizer
5 |
6 |
class RAdamW(Optimizer):
    r"""Rectified Adam with decoupled weight decay (RAdamW).

    RAdam is described in `On the Variance of the Adaptive Learning Rate and
    Beyond <https://arxiv.org/abs/1908.03265>`_.

    Related work:

    * `Adam: A Method for Stochastic Optimization
      <https://arxiv.org/abs/1412.6980>`_
    * `Decoupled Weight Decay Regularization
      <https://arxiv.org/abs/1711.05101>`_
    * `On the Convergence of Adam and Beyond
      <https://openreview.net/forum?id=ryQu7f-RZ>`_

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient (default: 1e-2)
    """

    def __init__(
        self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2
    ):
        # Reject hyper-parameters outside their valid range, in a fixed order.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        super().__init__(
            params, dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        )

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            for param in group["params"]:
                if param.grad is None:
                    continue

                grad = param.grad.data
                if grad.is_sparse:
                    raise RuntimeError(
                        "Adam does not support sparse gradients, please consider SparseAdam instead"
                    )

                state = self.state[param]
                if len(state) == 0:
                    # Lazy per-parameter state: step counter plus running
                    # averages of the gradient and of its square.
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(param.data)
                    state["exp_avg_sq"] = torch.zeros_like(param.data)

                first_moment = state["exp_avg"]
                second_moment = state["exp_avg_sq"]
                beta1, beta2 = group["betas"]
                eps = group["eps"]
                lr = group["lr"]
                # rho_inf is computed once and cached on the parameter group.
                if "rho_inf" not in group:
                    group["rho_inf"] = 2 / (1 - beta2) - 1
                rho_inf = group["rho_inf"]

                state["step"] += 1
                t = state["step"]

                # Update the running averages in place.
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
                second_moment.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                beta2_pow = beta2**t
                rho_t = rho_inf - ((2 * t * beta2_pow) / (1 - beta2_pow))

                # Decoupled weight decay, applied directly to the weights.
                param.data.mul_(1 - lr * group["weight_decay"])

                bias1 = 1 - beta1**t
                if rho_t < 5:
                    # Rectification not yet applied: bias-corrected momentum step.
                    param.data.add_(first_moment, alpha=-lr / bias1)
                    continue

                denom = second_moment.sqrt().add_(eps)
                numerator = (rho_t - 4) * (rho_t - 2) * rho_inf
                denominator = (rho_inf - 4) * (rho_inf - 2) * rho_t
                rectifier = math.sqrt((1 - beta2_pow) * numerator / denominator)
                param.data.addcdiv_(first_moment, denom, value=-lr * rectifier / bias1)

        return loss
112 |
--------------------------------------------------------------------------------
/relik/retriever/pytorch_modules/scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import LRScheduler
3 |
4 |
class LinearSchedulerWithWarmup(LRScheduler):
    """Linear warm-up from 0 to the base LR, then linear decay down to 0.

    Args:
        optimizer: the wrapped optimizer.
        num_warmup_steps: number of steps over which the LR rises linearly.
        num_training_steps: step at which the LR reaches 0.
        last_epoch: index of the last step already taken (``-1`` starts fresh).
        verbose: forwarded to ``LRScheduler`` only when truthy (see below).
        **kwargs: accepted and ignored, so shared config dicts can be passed.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_warmup_steps: int,
        num_training_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
        **kwargs,
    ):
        # These must be set before the base constructor runs: it performs an
        # initial step, which calls get_lr().
        self.num_warmup_steps = num_warmup_steps
        self.num_training_steps = num_training_steps
        # `verbose` is deprecated in recent PyTorch releases; passing it at all
        # (even as False) emits a warning there. Only forward it on request.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per parameter group for the current step."""

        def scheduler_fn(current_step):
            if current_step < self.num_warmup_steps:
                return current_step / max(1, self.num_warmup_steps)
            # Linear decay over the steps remaining after warm-up, floored at 0.
            return max(
                0.0,
                float(self.num_training_steps - current_step)
                / float(max(1, self.num_training_steps - self.num_warmup_steps)),
            )

        return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs]
30 |
31 |
class LinearScheduler(LRScheduler):
    """Linear decay from the base LR at step 0 down to 0, without warm-up.

    Args:
        optimizer: the wrapped optimizer.
        num_training_steps: step at which the LR reaches 0.
        last_epoch: index of the last step already taken (``-1`` starts fresh).
        verbose: forwarded to ``LRScheduler`` only when truthy (see below).
        **kwargs: accepted and ignored, so shared config dicts can be passed.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_training_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
        **kwargs,
    ):
        # Must be set before the base constructor runs: it performs an initial
        # step, which calls get_lr().
        self.num_training_steps = num_training_steps
        # `verbose` is deprecated in recent PyTorch releases; passing it at all
        # (even as False) emits a warning there. Only forward it on request.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per parameter group for the current step."""

        def scheduler_fn(current_step):
            # Fraction of the training budget still remaining, floored at 0.
            return max(
                0.0,
                float(self.num_training_steps - current_step)
                / float(max(1, self.num_training_steps)),
            )

        return [base_lr * scheduler_fn(self.last_epoch) for base_lr in self.base_lrs]
55 |
--------------------------------------------------------------------------------
/relik/retriever/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from relik.retriever.trainer.train import RetrieverTrainer
2 |
--------------------------------------------------------------------------------
/relik/version.py:
--------------------------------------------------------------------------------
1 | import os
2 |
# Version components, kept as strings so they can be joined directly.
_MAJOR = "1"
_MINOR = "0"
# On main and in a nightly release the patch number should be one ahead of the
# last released build.
_PATCH = "7"
# Optional suffix, mainly for nightly builds (".dev$DATE"). See
# https://semver.org/#is-v123-a-semantic-version for the semantics.
_SUFFIX = os.environ.get("RELIK_VERSION_SUFFIX", "")

# "MAJOR.MINOR"
VERSION_SHORT = f"{_MAJOR}.{_MINOR}"
# "MAJOR.MINOR.PATCH" followed by the optional suffix
VERSION = f"{_MAJOR}.{_MINOR}.{_PATCH}{_SUFFIX}"
14 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | #------- Core dependencies -------
2 | --extra-index-url https://download.pytorch.org/whl/cu121
3 | torch==2.3.1
4 |
5 | transformers[sentencepiece]>=4.41,<4.42
6 | rich>=13.0.0,<14.0.0
7 | scikit-learn>=1.5,<1.6
8 | overrides>=7.4,<7.9
9 | art==6.2
10 | pprintpp==0.4.0
11 | colorama==0.4.6
12 | termcolor==2.4.0
13 | spacy>=3.7,<3.8
14 | typer>=0.12,<0.13
15 |
16 | #------- Optional dependencies -------
17 |
18 | # train
19 | lightning>=2.3,<2.4
20 | datasets>=2.13,<2.15
21 | hydra-core>=1.3,<1.4
22 | hydra_colorlog
23 | wandb>=0.15,<0.18
24 |
25 | # faiss
26 | faiss-cpu==1.8.0 # needed by: faiss
27 |
28 | # serve
29 | fastapi>=0.112,<0.113 # needed by: serve, ray
30 | uvicorn[standard]==0.23.2 # needed by: serve, ray
31 | gunicorn==22.0.0 # needed by: serve, ray
32 | streamlit>=1.28,<1.29 # needed by: serve, ray
33 | streamlit_extras>=0.3,<0.4 # needed by: serve, ray
34 | gradio>=4.37,<4.38 # needed by: serve, ray
35 | pyvis # needed by: serve, ray
36 | ray[serve]>=2.34,<=2.35 # needed by: ray
37 |
38 | # dev
39 | pre-commit # needed by: dev
40 | black[d] # needed by: dev
41 | isort # needed by: dev
42 |
--------------------------------------------------------------------------------
/scripts/build_all.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
# Read the package version from the installed relik package. Only the last
# line of output is kept (presumably because importing relik can print other
# lines first - confirm).
VERSION=$(python -c "import relik; print(relik.__version__)")
LATEST_VERSION=$(echo "$VERSION" | tail -n 1)
if [ -z "$LATEST_VERSION" ]; then
  # Without a version the image tags below would be malformed.
  echo "Could not determine the relik version; is the package installed?" >&2
  exit 1
fi
echo "Building version: $LATEST_VERSION"

echo "==== Building CPU images ===="
# docker build -f dockerfiles/ray/Dockerfile.cpu -t relik:$VERSION-cpu-ray .
docker build -f dockerfiles/fastapi/Dockerfile.cpu -t "relik:${LATEST_VERSION}-cpu-fastapi" .

echo "==== Building GPU images ===="
# docker build -f dockerfiles/ray/Dockerfile.cuda -t relik:$VERSION-cuda-ray .
docker build -f dockerfiles/fastapi/Dockerfile.cuda -t "relik:${LATEST_VERSION}-cuda-fastapi" .
16 |
--------------------------------------------------------------------------------
/scripts/build_docker_with_weights.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
# Read the package version from the installed relik package; keep only the
# last line of output, as scripts/build_all.sh does.
VERSION=$(python -c "import relik; print(relik.__version__)" | tail -n 1)
if [ -z "$VERSION" ]; then
  echo "Could not determine the relik version; is the package installed?" >&2
  exit 1
fi
echo "Building version: $VERSION"

# MODEL_PATH must be provided by the caller. If it were empty, the copy below
# would expand to `cp -r /* ...` and copy the whole filesystem root.
if [ -z "${MODEL_PATH:-}" ] || [ ! -d "$MODEL_PATH" ]; then
  echo "MODEL_PATH must be set to an existing directory containing the model files" >&2
  exit 1
fi

# if relik-model exists, delete it
REPO_MODEL_PATH="relik-model"
if [ -d "$REPO_MODEL_PATH" ]; then
  echo "Deleting $REPO_MODEL_PATH"
  rm -r "$REPO_MODEL_PATH"
fi

# create relik-model directory and copy model files
echo "Copying model files to $REPO_MODEL_PATH"
mkdir -p "$REPO_MODEL_PATH"
# copy model files
cp -r "$MODEL_PATH"/* "$REPO_MODEL_PATH"

docker build -f dockerfiles/ray/Dockerfile.cuda -t "relik:${VERSION}-bsc-cuda-ray" .

# clean up the temporary copy of the weights, whether or not the build succeeded
if [ -d "$REPO_MODEL_PATH" ]; then
  echo "Deleting $REPO_MODEL_PATH"
  rm -r "$REPO_MODEL_PATH"
fi
27 |
--------------------------------------------------------------------------------
/scripts/data/blink/preprocess_genre_blink.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from pathlib import Path
4 | import re
5 | from typing import List, Tuple
6 |
7 | from tqdm import tqdm
8 |
# Matches one inline annotation of the form "{mention} [label]".
# Raw string: "\[" and "\]" are regex escapes, not Python string escapes (in a
# normal string literal they are invalid escape sequences and trigger a warning).
LABELS_REGEX = r"[{][^}]+[}] \[[^]]+\]"
10 | REPLACE_PATTERNS = [
11 | ("BULLET::::- ", ""),
12 | ("( )", ""),
13 | (" ", " "),
14 | (" ", " "),
15 | (" ", " "),
16 | ]
17 | LABELS_FILTERING_FUNCTIONS = [lambda x: x.startswith("List of")]
18 | SUBSTITUTION_PATTERN = "#$#"
19 |
20 |
def process_annotation(ann_surface_text: str) -> Tuple[str, str]:
    """Split one ``"{mention} [label]"`` annotation into its two parts.

    Args:
        ann_surface_text: the annotation text, as matched by ``LABELS_REGEX``.

    Returns:
        ``(mention, label)`` with the surrounding markup and outer whitespace
        removed.

    Raises:
        ValueError: if the text does not contain the ``"} ["`` separator.
    """
    # LABELS_REGEX forbids "}" inside the mention, so the first "} [" is always
    # the separator. Splitting only once keeps labels that themselves contain
    # "} [" from producing three parts and failing the unpacking.
    mention, label = ann_surface_text.split("} [", 1)
    mention = mention.replace("{", "").strip()
    label = label.replace("]", "").strip()
    return mention, label
26 |
27 |
def substitute_annotations(
    annotations: List[str], sub_line: str
) -> Tuple[str, List[Tuple[int, int, str]]]:
    """Put the mentions back into a placeholder-bearing line and record spans.

    Args:
        annotations: raw ``"{mention} [label]"`` strings, in order of appearance.
        sub_line: the line with each annotation replaced by ``SUBSTITUTION_PATTERN``.

    Returns:
        The line with every placeholder replaced by its mention, and the list
        of ``(start_char, end_char, label)`` spans whose label passed all
        ``LABELS_FILTERING_FUNCTIONS``.
    """
    kept_spans = []
    for raw_annotation in annotations:
        mention, label = process_annotation(raw_annotation)

        # The next placeholder marks where this mention starts in the final text.
        span_start = sub_line.index(SUBSTITUTION_PATTERN)
        span_end = span_start + len(mention)
        sub_line = sub_line.replace(SUBSTITUTION_PATTERN, mention, 1)
        assert sub_line[span_start:span_end] == mention

        # The mention is always restored; only the span record is dropped
        # when one of the label filters matches.
        verdicts = [is_filtered(label) for is_filtered in LABELS_FILTERING_FUNCTIONS]
        if not any(verdicts):
            kept_spans.append((span_start, span_end, label))
    return sub_line, kept_spans
42 |
43 |
def preprocess_line(line: str) -> Tuple[str, List[Tuple[int, int, str]]]:
    """Clean one input line and extract its inline annotations.

    Applies every ``REPLACE_PATTERNS`` substitution in order, then swaps each
    ``LABELS_REGEX`` match for ``SUBSTITUTION_PATTERN`` and hands the matches
    and the templated line to ``substitute_annotations``.

    Returns:
        Whatever ``substitute_annotations`` returns: the plain text and the
        list of ``(start_char, end_char, label)`` spans.
    """
    cleaned = line
    for old, new in REPLACE_PATTERNS:
        cleaned = cleaned.replace(old, new)

    annotation_pattern = re.compile(LABELS_REGEX)
    found = annotation_pattern.findall(cleaned)
    templated = annotation_pattern.sub(SUBSTITUTION_PATTERN, cleaned)
    return substitute_annotations(found, templated)
51 |
52 |
def preprocess_genre_el_file(
    file_path: str, output_path: str, limit_lines: int = -1
) -> None:
    """Convert an annotated text file into JSON lines.

    Each input line becomes one JSON object with the keys ``doc_id`` (the
    0-based line index), ``doc_text`` and ``doc_span_annotations``.

    Args:
        file_path: input file, one annotated document per line.
        output_path: output JSONL file; parent directories are created.
        limit_lines: maximum number of lines to convert; a negative value
            (the default) converts the whole file.
    """
    # Create output directory
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8 so the result does not depend on the platform's locale.
    with open(file_path, encoding="utf-8") as fi, open(
        output_path, "w", encoding="utf-8"
    ) as fo:
        for i, line in tqdm(enumerate(fi)):
            # Stop before converting line `limit_lines`, so exactly
            # `limit_lines` records are written (not one more).
            if 0 <= limit_lines <= i:
                break
            text, annotations = preprocess_line(line.strip())
            fo.write(
                json.dumps(dict(doc_id=i, doc_text=text, doc_span_annotations=annotations))
                + "\n"
            )
67 |
68 |
def main():
    """Command-line entry point: parse the arguments and run the conversion."""
    parser = argparse.ArgumentParser("Preprocess Genre BLINK file.")
    for positional in ("input_file", "output_file"):
        parser.add_argument(positional, type=str)
    parser.add_argument("--limit-lines", type=int, default=-1)
    parsed = parser.parse_args()

    preprocess_genre_el_file(
        parsed.input_file,
        parsed.output_file,
        parsed.limit_lines,
    )


if __name__ == "__main__":
    main()
82 |
--------------------------------------------------------------------------------
/scripts/data/retriever/create_index.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | from pathlib import Path
5 | from typing import Optional, Union
6 |
7 | import torch
8 |
9 | from relik.retriever import GoldenRetriever
10 | from relik.common.utils import get_logger, get_callable_from_string
11 | from relik.retriever.indexers.document import DocumentStore
12 |
13 | logger = get_logger(__name__)
14 |
15 |
@torch.no_grad()
def build_index(
    question_encoder_name_or_path: Union[str, os.PathLike],
    document_path: Union[str, os.PathLike],
    output_folder: Union[str, os.PathLike],
    document_file_type: str = "jsonl",
    passage_encoder_name_or_path: Optional[Union[str, os.PathLike]] = None,
    indexer_class: str = "relik.retriever.indexers.inmemory.InMemoryDocumentIndex",
    batch_size: int = 512,
    num_workers: int = 4,
    passage_max_length: int = 64,
    device: str = "cuda",
    index_device: str = "cpu",
    precision: str = "fp32",
    push_to_hub: bool = False,
    repo_id: Optional[str] = None,
):
    """Encode a document collection with a retriever and save the index.

    Loads the documents, wraps them in the index class named by
    ``indexer_class``, builds a ``GoldenRetriever`` around it, indexes the
    documents and saves the index to ``output_folder``.

    Args:
        question_encoder_name_or_path: passed to ``GoldenRetriever`` as ``question_encoder``.
        document_path: file holding the documents.
        output_folder: destination folder, created if missing.
        document_file_type: ``"jsonl"``, ``"csv"`` or ``"tsv"``.
        passage_encoder_name_or_path: passed to ``GoldenRetriever`` as ``passage_encoder``.
        indexer_class: dotted path resolved with ``get_callable_from_string``.
        batch_size, num_workers, passage_max_length: forwarded to ``retriever.index``.
        device: device for the retriever.
        index_device: device for the document index.
        precision: used for the index, the retriever and the indexing call.
        push_to_hub: forwarded to ``save_pretrained``.
        repo_id: forwarded to ``save_pretrained`` as ``model_id``; required
            when ``push_to_hub`` is true.

    Raises:
        ValueError: if ``push_to_hub`` is set without ``repo_id``, or if
            ``document_file_type`` is not one of the supported values.
    """
    if push_to_hub and not repo_id:
        raise ValueError("`repo_id` must be provided when `push_to_hub=True`")

    def _load_delimited(delimiter: str):
        # NOTE(review): the keyword is spelled `ingore_case` in the existing
        # call; presumably it matches DocumentStore.from_tsv - verify before renaming.
        return DocumentStore.from_tsv(
            document_path,
            delimiter=delimiter,
            quoting=csv.QUOTE_NONE,
            ingore_case=True,
        )

    logger.info("Loading documents")
    if document_file_type == "jsonl":
        documents = DocumentStore.from_file(document_path)
    elif document_file_type == "csv":
        documents = _load_delimited(",")
    elif document_file_type == "tsv":
        documents = _load_delimited("\t")
    else:
        raise ValueError(
            f"Unknown document file type: {document_file_type}, must be one of jsonl, csv, tsv"
        )

    logger.info("Loading document index")
    logger.info(f"Loaded {len(documents)} documents")
    index_factory = get_callable_from_string(indexer_class)
    indexer = index_factory(documents, device=index_device, precision=precision)

    retriever = GoldenRetriever(
        question_encoder=question_encoder_name_or_path,
        passage_encoder=passage_encoder_name_or_path,
        document_index=indexer,
        device=device,
        precision=precision,
    )
    retriever.eval()

    # Always rebuild the index from scratch.
    retriever.index(
        batch_size=batch_size,
        num_workers=num_workers,
        max_length=passage_max_length,
        force_reindex=True,
        precision=precision,
    )

    output_folder = Path(output_folder)
    output_folder.mkdir(exist_ok=True, parents=True)
    retriever.document_index.save_pretrained(
        output_folder, push_to_hub=push_to_hub, model_id=repo_id
    )
81 |
82 |
if __name__ == "__main__":
    # Command-line front end: every option maps onto the build_index()
    # parameter of the same name (dashes become underscores).
    arg_parser = argparse.ArgumentParser("Create retriever index.")
    arg_parser.add_argument("--question-encoder-name-or-path", type=str, required=True)
    arg_parser.add_argument("--document-path", type=str, required=True)
    arg_parser.add_argument("--passage-encoder-name-or-path", type=str)
    # The dash spelling matches every other option; the original underscore
    # spelling stays as an alias. Both store into `indexer_class`.
    arg_parser.add_argument(
        "--indexer-class",
        "--indexer_class",
        type=str,
        default="relik.retriever.indexers.inmemory.InMemoryDocumentIndex",
    )
    arg_parser.add_argument("--document-file-type", type=str, default="jsonl")
    arg_parser.add_argument("--output-folder", type=str, required=True)
    arg_parser.add_argument("--batch-size", type=int, default=128)
    arg_parser.add_argument("--passage-max-length", type=int, default=64)
    arg_parser.add_argument("--device", type=str, default="cuda")
    arg_parser.add_argument("--index-device", type=str, default="cpu")
    arg_parser.add_argument("--precision", type=str, default="fp32")
    arg_parser.add_argument("--num-workers", type=int, default=4)
    arg_parser.add_argument("--push-to-hub", action="store_true")
    arg_parser.add_argument("--repo-id", type=str)

    build_index(**vars(arg_parser.parse_args()))
105 |
--------------------------------------------------------------------------------
/scripts/docker/gunicorn_conf.py:
--------------------------------------------------------------------------------
1 | import json
2 | import multiprocessing
3 | import os
4 |
# ---- Raw settings read from the environment ----
max_cores_str = os.getenv("MAX_CORES", "1")
workers_per_core_str = os.getenv("WORKERS_PER_CORE", "1")
max_workers_str = os.getenv("MAX_WORKERS", "1")
# Upper bound on the worker count; None when MAX_WORKERS is set to "".
use_max_workers = int(max_workers_str) if max_workers_str else None
web_concurrency_str = os.getenv("WEB_CONCURRENCY", None)

host = os.getenv("HOST", "0.0.0.0")
port = os.getenv("PORT", "80")
bind_env = os.getenv("BIND", None)
use_loglevel = os.getenv("LOG_LEVEL", "info")
# An explicit BIND wins over HOST/PORT.
use_bind = bind_env if bind_env else f"{host}:{port}"

# ---- Worker count ----
cores = int(max_cores_str)
workers_per_core = float(workers_per_core_str)
default_web_concurrency = workers_per_core * cores
if not web_concurrency_str:
    # Derived count: at least one worker, capped by MAX_WORKERS when set.
    web_concurrency = max(int(default_web_concurrency), 1)
    if use_max_workers:
        web_concurrency = min(web_concurrency, use_max_workers)
else:
    # An explicit WEB_CONCURRENCY is taken as-is.
    web_concurrency = int(web_concurrency_str)
    assert web_concurrency > 0

# ---- Logging targets and timeouts ----
accesslog_var = os.getenv("ACCESS_LOG", "-")
errorlog_var = os.getenv("ERROR_LOG", "-")
# An empty value becomes None.
use_accesslog = accesslog_var or None
use_errorlog = errorlog_var or None
graceful_timeout_str = os.getenv("GRACEFUL_TIMEOUT", "500")
timeout_str = os.getenv("TIMEOUT", "500")
keepalive_str = os.getenv("KEEP_ALIVE", "5")

# Gunicorn config variables
loglevel = use_loglevel
workers = web_concurrency
bind = use_bind
errorlog = use_errorlog
worker_tmp_dir = "/dev/shm"
accesslog = use_accesslog
graceful_timeout = int(graceful_timeout_str)
timeout = int(timeout_str)
keepalive = int(keepalive_str)


# For debugging and testing: the effective configuration, printed on load.
log_data = dict(
    loglevel=loglevel,
    workers=workers,
    bind=bind,
    graceful_timeout=graceful_timeout,
    timeout=timeout,
    keepalive=keepalive,
    errorlog=errorlog,
    accesslog=accesslog,
    # Additional, non-gunicorn variables
    workers_per_core=workers_per_core,
    use_max_workers=use_max_workers,
    host=host,
    port=port,
)
# pretty print the log data
print(json.dumps(log_data, indent=2))
70 |
--------------------------------------------------------------------------------
/scripts/docker/pre-start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # cat <<"EOF"
4 |
5 | # _/_/_/ _/_/_/_/ _/ _/_/_/ _/ _/
6 | # _/ _/ _/ _/ _/ _/ _/
7 | # _/_/_/ _/_/_/ _/ _/ _/_/
8 | # _/ _/ _/ _/ _/ _/ _/
9 | # _/ _/ _/_/_/_/ _/_/_/_/ _/_/_/ _/ _/
10 |
11 | # ReLiK Inference API
12 |
13 | # EOF
14 |
# pre-download the model if provided in input
if [ -n "${1:-}" ]; then
  # micromamba run -n base python -c "from relik import Relik; Relik.from_pretrained('$1')"
  # The model name is passed as a program argument instead of being spliced
  # into the Python source, so names containing quotes cannot break (or
  # inject code into) the one-liner.
  python3 -c "import sys; from relik import Relik; Relik.from_pretrained(sys.argv[1])" "$1"
fi
20 |
--------------------------------------------------------------------------------
/scripts/docker/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # Pre-start
5 | # checkmark font for fancy log
6 | CHECK_MARK="\033[0;32m\xE2\x9C\x94\033[0m"
7 | # usage text
8 | USAGE="$(basename "$0") [-h --help] [-c --config] [-p --precision] [-d --device] [--retriever] [--retriever-device]
9 | [--retriever-precision] [--index-device] [--index-precision] [--reader] [--reader-device] [--reader-precision]
10 |
11 | where:
12 | -h --help Show this help text
13 | -c --config Config name (from HuggingFace) or path
14 | -p --precision Training precision, default '32'.
15 | -d --device Device to use, default 'cpu'.
16 | --retriever Override retriever model name.
17 | --retriever-device Override retriever device.
18 | --retriever-precision Override retriever precision.
19 | --index-device Override index device.
20 | --index-precision Override index precision.
21 | --reader Override reader model name.
22 | --reader-device Override reader device.
23 | --reader-precision Override reader precision.
24 | --annotation-type Annotation type ('char', 'word'), default 'char'.
25 | "
26 |
# Transform long options to short ones
# getopts (used below) only understands single-letter options, so each long
# option is rewritten to a private short letter. The word list of the `for`
# is expanded once, before the first iteration; inside the loop `shift` drops
# the oldest remaining original argument and `set -- "$@" X` appends its
# translation, so after the last iteration "$@" holds the translated
# arguments in their original order.
for arg in "$@"; do
  shift
  case "$arg" in
    '--help') set -- "$@" '-h' ;;
    '--config') set -- "$@" '-c' ;;
    '--precision') set -- "$@" '-p' ;;
    '--device') set -- "$@" '-d' ;;
    '--retriever') set -- "$@" '-q' ;;
    '--retriever-device') set -- "$@" '-w' ;;
    '--retriever-precision') set -- "$@" '-e' ;;
    '--index-device') set -- "$@" '-r' ;;
    '--index-precision') set -- "$@" '-t' ;;
    '--reader') set -- "$@" '-y' ;;
    '--reader-device') set -- "$@" '-a' ;;
    '--reader-precision') set -- "$@" '-b' ;;
    '--annotation-type') set -- "$@" '-f' ;;
    # Anything else (short options and option values) is kept unchanged.
    *) set -- "$@" "$arg" ;;
  esac
done
47 |
# check for named params
# Each recognised option is exported as an environment variable. The leading
# ":" in the option string puts getopts in silent mode: unknown options
# arrive as "?" and options missing their argument arrive as ":".
#while [ $OPTIND -le "$#" ]; do
while getopts ":hc:p:d:q:w:e:r:t:y:a:b:f:" opt; do
  case $opt in
    h)
      # Print the usage text as data, never as a printf format string.
      printf '%s' "$USAGE"
      exit 0
      ;;
    c)
      export RELIK_PRETRAINED=$OPTARG
      ;;
    p)
      export PRECISION=$OPTARG
      ;;
    d)
      export DEVICE=$OPTARG
      ;;
    q)
      export RETRIEVER_MODEL_NAME=$OPTARG
      ;;
    w)
      export RETRIEVER_DEVICE=$OPTARG
      ;;
    e)
      export RETRIEVER_PRECISION=$OPTARG
      ;;
    r)
      export INDEX_DEVICE=$OPTARG
      ;;
    t)
      export INDEX_PRECISION=$OPTARG
      ;;
    y)
      export READER_MODEL_NAME=$OPTARG
      ;;
    a)
      export READER_DEVICE=$OPTARG
      ;;
    b)
      export READER_PRECISION=$OPTARG
      ;;
    f)
      export ANNOTATION_TYPE=$OPTARG
      ;;
    :)
      # An option that needs a value was given without one.
      echo "Option -$OPTARG requires an argument" >&2
      echo "$USAGE" >&2
      exit 1
      ;;
    \?)
      # Unknown option: report it and fail instead of exiting successfully.
      echo "Invalid option -$OPTARG" >&2
      echo "$USAGE" >&2
      exit 1
      ;;
  esac
done
96 |
# Serve application location (import path handed to `serve run` below)
if [ -z "$APP_MODULE" ]; then
  # echo "APP_MODULE not set, using default"
  export APP_MODULE=relik.inference.serve.backend.ray:server
fi
# echo "APP_MODULE set to $APP_MODULE"

# Pre-start hook: sourced before the server starts when the file exists
if [ -z "$PRE_START_PATH" ]; then
  # echo "PRE_START_PATH not set, using default"
  PRE_START_PATH=scripts/docker/pre-start.sh
fi
# echo "PRE_START_PATH set to $PRE_START_PATH"

if [ -f "$PRE_START_PATH" ]; then
  # Always pass exactly one (possibly empty) argument. When `.` is given no
  # arguments the sourced script sees this script's own positional
  # parameters, so with RELIK_PRETRAINED unset it would mistake a leftover
  # option such as "-d" for the model name.
  . "$PRE_START_PATH" "${RELIK_PRETRAINED:-}"
else
  echo "There is no script $PRE_START_PATH"
fi

# Start Ray Serve with the app
exec serve run "$APP_MODULE" --host 0.0.0.0 --port 8000
# micromamba run -n base serve run "$APP_MODULE" --host 0.0.0.0 --port 8000
119 | # micromamba run -n base serve run "$APP_MODULE" --host 0.0.0.0 --port 8000
120 |
--------------------------------------------------------------------------------
/scripts/old-scripts/data/debug.py:
--------------------------------------------------------------------------------
1 | from relik.inference.data.splitters.window_based_splitter import WindowSentenceSplitter
2 | from relik.inference.data.tokenizers.spacy_tokenizer import SpacyTokenizer
3 | from relik.inference.data.window.manager import WindowManager
4 |
5 |
6 | document = {
7 | "doc_id": "-DOCSTART- (956testa SOCCER)",
8 | "doc_text": "SOCCER - RESULTS OF SOUTH KOREAN PRO-SOCCER GAMES . SEOUL 1996-08-30 Results of South Korean pro-soccer games played on Thursday . Pohang 3 Ulsan 2 ( halftime 1-0 ) Puchon 2 Chonbuk 1 ( halftime 1-1 ) Standings after games played on Thursday ( tabulate under - won , drawn , lost , goals for , goals against , points ) : W D L G / F G / A P Puchon 3 1 0 6 1 10 Chonan 3 0 1 13 10 9 Pohang 2 1 1 11 10 7 Suwan 1 3 0 7 3 6 Ulsan 1 0 2 8 9 3 Anyang 0 3 1 6 9 3 Chonnam 0 2 1 4 5 2 Pusan 0 2 1 3 7 2 Chonbuk 0 0 3 3 7 0",
9 | "doc_annotations": [
10 | [20, 32, "South Korea"],
11 | [52, 57, "Seoul"],
12 | [80, 92, "South Korea"],
13 | [131, 137, "Pohang Steelers"],
14 | [140, 145, "Ulsan Hyundai FC"],
15 | [165, 171, "--NME--"],
16 | [174, 181, "--NME--"],
17 | [341, 347, "--NME--"],
18 | [361, 367, "--NME--"],
19 | [382, 388, "Pohang Steelers"],
20 | [403, 408, "--NME--"],
21 | [421, 426, "Ulsan Hyundai FC"],
22 | [439, 445, "Anyang LG Cheetahs"],
23 | [458, 465, "--NME--"],
24 | [478, 483, "--NME--"],
25 | [496, 503, "--NME--"],
26 | ],
27 | }
28 |
# Windowing pipeline: spaCy tokenizer plus a fixed-size window splitter.
tokenizer = SpacyTokenizer(language="en")
sentence_splitter = WindowSentenceSplitter(window_size=32, window_stride=16)

window_manager = WindowManager(splitter=sentence_splitter, tokenizer=tokenizer)

# Strip the "-DOCSTART-" marker and the parentheses from the raw id, leaving
# "<id> <topic>".
doc_info = document["doc_id"]
for marker in ("-DOCSTART-", "(", ")"):
    doc_info = doc_info.replace(marker, "")
doc_info = doc_info.strip()
doc_id, doc_topic = doc_info.split(" ")

# The id suffix tells which split the document belongs to.
split = "dev" if "testa" in doc_id else ("test" if "testb" in doc_id else "train")

# Drop the split suffix to obtain the numeric document id.
doc_id = int(doc_id.replace("testa", "").replace("testb", "").strip())

windowized_document = window_manager.create_windows(
    document["doc_text"],
    32,
    16,
    doc_ids=doc_id,
    doc_topic=doc_topic,
)

print(windowized_document)
57 |
--------------------------------------------------------------------------------
/scripts/old-scripts/data/retriever/aida_to_dpr.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 | from typing import Any, Dict, List, Optional, Union
5 |
6 | from tqdm import tqdm
7 |
8 | from relik.common.log import get_logger
9 |
10 | logger = get_logger()
11 |
12 |
def aida_to_dpr(
    conll_path: Union[str, os.PathLike],
    output_path: Union[str, os.PathLike],
    documents_path: Optional[Union[str, os.PathLike]] = None,
    title_map: Optional[Union[str, os.PathLike]] = None,
) -> List[Dict[str, Any]]:
    """Convert windowed AIDA data (JSON lines) into DPR-style training records.

    Args:
        conll_path: JSONL input; each record provides ``text``,
            ``window_labels``, ``doc_id``, ``offset`` and ``doc_topic``.
        output_path: JSONL output file; parent directories are created.
        documents_path: JSONL file with the entity definitions (``text`` is
            the title, ``metadata.definition`` the definition). Although the
            default is ``None``, the file is always opened, so a path is
            required in practice.
        title_map: optional JSON file; it is loaded but not used further.

    Returns:
        The list of records written to ``output_path``. Sentences with no
        resolvable entity are skipped.
    """
    documents = {}
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # read entities definitions
    logger.info(f"Loading documents from {documents_path}")
    with open(documents_path, "r") as f:
        for line in f:
            line_data = json.loads(line)
            title = line_data["text"].strip()
            definition = line_data["metadata"]["definition"].strip()
            documents[title] = definition

    if title_map is not None:
        with open(title_map, "r") as f:
            title_map = json.load(f)

    # store dpr data
    dpr = []
    # lower-cased title -> original title, for case-insensitive lookup
    title_to_lower_map = {title.lower(): title for title in documents.keys()}
    # store missing entities
    missing = set()
    # Read AIDA file
    with open(conll_path, "r") as f, open(output_path, "w") as f_out:
        for line in tqdm(f, desc="Processing AIDA data"):
            sentence = json.loads(line)
            question = sentence["text"]
            positive_pssgs = []
            for idx, entity in enumerate(sentence["window_labels"]):
                # the label is the third element of each window annotation
                entity = entity[2]
                if not entity:
                    continue
                entity = entity.strip().lower().replace("_", " ")
                # map the normalised name back to its original-cased title
                entity = title_to_lower_map.get(entity, entity)
                if entity in documents:
                    def_text = documents[entity]
                    positive_pssgs.append(
                        {
                            "title": title_to_lower_map[entity.lower()],
                            "text": f"{title_to_lower_map[entity.lower()]} {def_text}",
                            "passage_id": f"{sentence['doc_id']}_{sentence['offset']}_{idx}",
                        }
                    )
                else:
                    missing.add(entity)
                    print(f"Entity {entity} not found in definitions")

            if len(positive_pssgs) == 0:
                continue

            dpr_sentence = {
                "id": f"{sentence['doc_id']}_{sentence['offset']}",
                "doc_topic": sentence["doc_topic"],
                "question": question,
                "answers": "",
                "positive_ctxs": positive_pssgs,
                "negative_ctxs": "",
                "hard_negative_ctxs": "",
            }
            f_out.write(json.dumps(dpr_sentence) + "\n")
            # Keep the record so the returned list matches what was written
            # (previously the list was returned empty).
            dpr.append(dpr_sentence)

    for e in missing:
        print(e)
    print(f"Number of missing entities: {len(missing)}")

    return dpr
87 |
88 |
if __name__ == "__main__":
    import argparse

    # Three positional paths plus an optional title map.
    cli = argparse.ArgumentParser()
    cli.add_argument("input", type=str, help="Path to AIDA file")
    cli.add_argument("output", type=str, help="Path to output file")
    cli.add_argument("documents", type=str, help="Path to entities definitions file")
    cli.add_argument("--title_map", type=str, help="Path to title map file")
    options = cli.parse_args()

    # Convert to DPR
    aida_to_dpr(
        conll_path=options.input,
        output_path=options.output,
        documents_path=options.documents,
        title_map=options.title_map,
    )
101 |
--------------------------------------------------------------------------------
/scripts/old-scripts/data/retriever/blink/create_random_sample_coverage.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from collections import defaultdict
5 | from pathlib import Path
6 | from typing import Union
7 |
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from relik.common.log import get_logger
12 |
13 | logger = get_logger()
14 |
15 |
def sample(
    input_file: Union[str, os.PathLike],
    output_file: Union[str, os.PathLike],
    n_samples: int = 5,
    seed: int = 42,
):
    """Pick, for every label, up to ``n_samples`` line indices that mention it.

    Args:
        input_file: JSONL file whose records carry ``window_labels``; the last
            element of each window label is used as the label.
        output_file: JSON file receiving ``{label: [line indices]}``; parent
            directories are created.
        n_samples: maximum number of line indices kept per label.
        seed: seed for NumPy's global random generator.
    """
    label_to_lines = defaultdict(list)

    logger.info(f"Loading data from {input_file}")
    with open(input_file) as f:
        for line_no, raw_line in tqdm(enumerate(f)):
            try:
                record = json.loads(raw_line)
            except json.JSONDecodeError:
                # Unparseable lines are reported and skipped.
                logger.error(f"Error parsing line {line_no}")
                continue
            line_labels = [window_label[-1] for window_label in record["window_labels"]]
            for label in line_labels:
                label_to_lines[label].append(line_no)

    logger.info("Sampling data")
    # Random sample from in-distribution documents
    np.random.seed(seed)
    sampled = {}
    for label, line_ids in tqdm(label_to_lines.items()):
        keep = min(len(line_ids), n_samples)
        sampled[label] = np.random.choice(line_ids, keep, replace=False).tolist()

    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving sampled data to {output_file}")
    with open(output_file, "w") as f:
        json.dump(sampled, f, indent=2)
50 |
51 |
if __name__ == "__main__":
    # Command-line front end: option names match the parameters of sample().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("input_file", type=str)
    arg_parser.add_argument("output_file", type=str)
    arg_parser.add_argument("--n_samples", type=int, required=False, default=5)
    # Expose the random seed that sample() already accepts; the default is
    # the same as the function's own default.
    arg_parser.add_argument("--seed", type=int, required=False, default=42)

    sample(**vars(arg_parser.parse_args()))
59 |
--------------------------------------------------------------------------------
/scripts/old-scripts/data/retriever/blink/sample_from_data_coverate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from collections import defaultdict
5 | from pathlib import Path
6 | from typing import Union
7 |
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | from relik.common.log import get_logger
12 |
13 | logger = get_logger()
14 |
15 |
def sample(
    input_file: Union[str, os.PathLike],
    output_file: Union[str, os.PathLike],
    sample_index_file: Union[str, os.PathLike],
):
    """Copy the lines of ``input_file`` whose index appears in a sample index.

    Args:
        input_file: JSONL file to filter.
        output_file: JSONL file receiving the selected lines (re-serialised);
            parent directories are created.
        sample_index_file: JSON file mapping keys to lists of 0-based line
            indices; the union of all lists is selected.
    """
    logger.info(f"Loading sample index from {sample_index_file}")
    with open(sample_index_file) as f:
        index_by_key = json.load(f)
    # Union of all listed line indices.
    selected = {line_no for line_nos in index_by_key.values() for line_no in line_nos}

    Path(output_file).parent.mkdir(parents=True, exist_ok=True)
    logger.info(f"Loading data from {input_file} and sampling")
    logger.info(f"Saving sampled data to {output_file}")
    with open(input_file) as f, open(output_file, "w") as f_out:
        for i, line in tqdm(enumerate(f)):
            if int(i) not in selected:
                continue
            try:
                f_out.write(json.dumps(json.loads(line)) + "\n")
            except json.JSONDecodeError:
                # Selected but unparseable lines are reported and skipped.
                logger.error(f"Error parsing line {i}")
                continue
42 |
43 |
if __name__ == "__main__":
    # Three mandatory positional paths; their names match sample()'s
    # parameters so the parsed namespace can be forwarded by keyword.
    cli = argparse.ArgumentParser()
    for positional in ("input_file", "output_file", "sample_index_file"):
        cli.add_argument(positional, type=str)
    parsed = cli.parse_args()

    sample(**vars(parsed))
51 |
--------------------------------------------------------------------------------
/scripts/old-scripts/data/retriever/create_index.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import csv
3 | import os
4 | from pathlib import Path
5 | from typing import Optional, Union
6 |
7 | import torch
8 |
9 | from relik.retriever import GoldenRetriever
10 | from relik.common.utils import get_logger, get_callable_from_string
11 | from relik.retriever.indexers.document import DocumentStore
12 |
13 | logger = get_logger(__name__)
14 |
15 |
@torch.no_grad()
def build_index(
    question_encoder_name_or_path: Union[str, os.PathLike],
    document_path: Union[str, os.PathLike],
    output_folder: Union[str, os.PathLike],
    document_file_type: str = "jsonl",
    passage_encoder_name_or_path: Optional[Union[str, os.PathLike]] = None,
    indexer_class: str = "relik.retriever.indexers.inmemory.InMemoryDocumentIndex",
    batch_size: int = 512,
    num_workers: int = 4,
    passage_max_length: int = 64,
    device: str = "cuda",
    index_device: str = "cpu",
    precision: str = "fp32",
):
    """
    Load a document collection, index it with a retriever and save the result.

    Args:
        question_encoder_name_or_path: Passed to ``GoldenRetriever`` as
            ``question_encoder``.
        document_path: File holding the documents.
        output_folder: Directory handed to ``retriever.save_pretrained``;
            created (with parents) if missing.
        document_file_type: One of ``"jsonl"``, ``"csv"`` or ``"tsv"``.
        passage_encoder_name_or_path: Passed to ``GoldenRetriever`` as
            ``passage_encoder``.
        indexer_class: Dotted path resolved with ``get_callable_from_string``
            and called to build the document index.
        batch_size: Forwarded to ``retriever.index``.
        num_workers: Forwarded to ``retriever.index``.
        passage_max_length: Forwarded to ``retriever.index`` as ``max_length``.
        device: Device given to the retriever.
        index_device: Device given to the document index.
        precision: Given to the index, the retriever and ``retriever.index``.

    Raises:
        ValueError: If ``document_file_type`` is not one of the supported values.
    """
    logger.info("Loading documents")
    if document_file_type == "jsonl":
        documents = DocumentStore.from_file(document_path)
    else:
        # csv and tsv go through the same loader and differ only in delimiter.
        if document_file_type == "csv":
            field_delimiter = ","
        elif document_file_type == "tsv":
            field_delimiter = "\t"
        else:
            raise ValueError(
                f"Unknown document file type: {document_file_type}, must be one of jsonl, csv, tsv"
            )
        # NOTE(review): the keyword is spelled "ingore_case" exactly as in the
        # original call; confirm it matches DocumentStore.from_tsv's signature.
        documents = DocumentStore.from_tsv(
            document_path,
            delimiter=field_delimiter,
            quoting=csv.QUOTE_NONE,
            ingore_case=True,
        )

    logger.info("Loading document index")
    logger.info(f"Loaded {len(documents)} documents")
    index_factory = get_callable_from_string(indexer_class)
    indexer = index_factory(documents, device=index_device, precision=precision)

    retriever = GoldenRetriever(
        question_encoder=question_encoder_name_or_path,
        passage_encoder=passage_encoder_name_or_path,
        document_index=indexer,
        device=device,
        precision=precision,
    )
    retriever.eval()

    # force_reindex=True: embeddings are always recomputed by this script.
    indexing_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        max_length=passage_max_length,
        force_reindex=True,
        precision=precision,
    )
    retriever.index(**indexing_options)

    target_dir = Path(output_folder)
    target_dir.mkdir(exist_ok=True, parents=True)
    retriever.save_pretrained(target_dir)
73 |
74 |
if __name__ == "__main__":
    # Options are declared as data and registered in the original order (which
    # is also the order shown by --help). Names match build_index() parameters
    # because the parsed namespace is forwarded by keyword.
    # Note: the CLI default for --batch_size (128) differs from the function
    # default (512), and num_workers is not exposed here.
    option_specs = [
        ("--question_encoder_name_or_path", {"type": str, "required": True}),
        ("--document_path", {"type": str, "required": True}),
        ("--passage_encoder_name_or_path", {"type": str}),
        (
            "--indexer_class",
            {
                "type": str,
                "default": "relik.retriever.indexers.inmemory.InMemoryDocumentIndex",
            },
        ),
        ("--document_file_type", {"type": str, "default": "jsonl"}),
        ("--output_folder", {"type": str, "required": True}),
        ("--batch_size", {"type": int, "default": 128}),
        ("--passage_max_length", {"type": int, "default": 64}),
        ("--device", {"type": str, "default": "cuda"}),
        ("--index_device", {"type": str, "default": "cpu"}),
        ("--precision", {"type": str, "default": "fp32"}),
    ]

    cli = argparse.ArgumentParser()
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    build_index(**vars(cli.parse_args()))
94 |
--------------------------------------------------------------------------------
/scripts/old-scripts/data/retriever/explore_blink.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 |
if __name__ == "__main__":
    # Scratch script: reads one hard-coded JSONL file fully into memory as a
    # list of parsed records and does nothing else with it.
    #
    # A commented-out argparse block (same options as create_index.py,
    # ending in a build_index(...) call) used to sit here; it never executed.
    blink_path = "/media/data/EL/blink/window_32_tokens/random_1M/dpr-like/first_1M.jsonl"
    with open(blink_path) as blink_file:
        data = list(map(json.loads, blink_file))
27 |
28 |
--------------------------------------------------------------------------------
/scripts/old-scripts/data/retriever/save_retriever_from_checkpoint.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SapienzaNLP/relik/999baf657a9df095ac138fac61bff944dff3d8ea/scripts/old-scripts/data/retriever/save_retriever_from_checkpoint.py
--------------------------------------------------------------------------------
/scripts/old-scripts/data/retriever/triplets_to_dpr.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from pathlib import Path
4 | from typing import Union, Dict, List, Optional, Any
5 |
6 | from tqdm import tqdm
7 |
8 | # from transformers import AutoTokenizer, BertTokenizer
9 |
10 |
def aida_to_dpr(
    conll_path: Union[str, os.PathLike],
    output_path: Union[str, os.PathLike],
    definitions_path: Optional[Union[str, os.PathLike]] = None,
) -> List[Dict[str, Any]]:
    """
    Convert a JSONL file of triplet-annotated sentences into DPR-style records.

    Each input line is a JSON object with ``"text"``, ``"triplets"``,
    ``"doc_id"`` and ``"offset"`` (and optionally ``"doc_topic"``). For every
    triplet whose ``relation.name`` has a definition, a positive passage
    ``"<relation> <definition>"`` is emitted. Sentences with no positive
    passage are skipped. Relation names without a definition are collected and
    printed at the end.

    Args:
        conll_path: Path of the input JSONL file.
        output_path: Path of the JSONL file to write; missing parent
            directories are created.
        definitions_path: Optional JSONL file with one object per line holding
            the relation name in ``"text"`` and its description in
            ``metadata.definition``. When omitted no definitions are loaded,
            so every relation is reported as missing and nothing is written.

    Returns:
        The list of records that were written to ``output_path``.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Relation name -> definition. The parameter is optional, so only open the
    # file when a path was actually supplied (open(None) raises TypeError).
    definitions = {}
    if definitions_path is not None:
        with open(definitions_path, "r") as f:
            for line in f:
                line_data = json.loads(line)
                title = line_data["text"].strip()
                definitions[title] = line_data["metadata"]["definition"].strip()

    dpr = []
    missing = set()

    with open(conll_path, "r") as f, open(output_path, "w") as f_out:
        for line in tqdm(f):
            sentence = json.loads(line)
            question = sentence["text"]
            positive_pssgs = []
            for idx, triplet in enumerate(sentence["triplets"]):
                relation = triplet["relation"]["name"]
                if not relation:
                    continue
                if relation not in definitions:
                    missing.add(relation)
                    continue
                # The lookup is an exact match, so the relation name itself is
                # the title. Going through a lower-cased title map could pair
                # this definition with a differently-cased title.
                positive_pssgs.append(
                    {
                        "title": relation,
                        "text": f"{relation} {definitions[relation]}",
                        "passage_id": f"{sentence['doc_id']}_{sentence['offset']}_{idx}",
                    }
                )

            if not positive_pssgs:
                continue

            dpr_sentence = {
                "id": f"{sentence['doc_id']}_{sentence['offset']}",
                "doc_topic": sentence.get("doc_topic", ""),
                "question": question,
                "answers": "",
                "positive_ctxs": positive_pssgs,
                "negative_ctxs": "",
                "hard_negative_ctxs": "",
            }
            f_out.write(json.dumps(dpr_sentence) + "\n")
            # Keep the record so the declared return value is actually filled.
            dpr.append(dpr_sentence)

    for e in missing:
        print(e)
    print(f"Number of missing entities: {len(missing)}")

    return dpr
80 |
81 |
if __name__ == "__main__":
    import argparse

    # Two positional paths plus an optional definitions file.
    cli = argparse.ArgumentParser()
    cli.add_argument("input", type=str, help="Path to AIDA file")
    cli.add_argument("output", type=str, help="Path to output file")
    cli.add_argument(
        "--definitions", type=str, help="Path to entities definitions file"
    )
    options = cli.parse_args()

    # Run the conversion to DPR format
    aida_to_dpr(
        conll_path=options.input,
        output_path=options.output,
        definitions_path=options.definitions,
    )
--------------------------------------------------------------------------------
/scripts/old-scripts/data/split_aida.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from tqdm import tqdm
4 |
# Split a single AIDA JSONL dump into train/dev/test files based on the
# "testa"/"testb" marker embedded in each document header.
with open("data/processed/aida_from_edo.jsonl", "r") as f:
    aida = [json.loads(l) for l in f]

aida_train, aida_dev, aida_test = [], [], []

for document in tqdm(aida):
    doc_info = document["doc_id"]

    # clean doc_info, e.g. "-DOCSTART- (1 EU)"
    doc_info = (
        doc_info.replace("-DOCSTART-", "").replace("(", "").replace(")", "").strip()
    )
    # Only the first token (the id, possibly suffixed with testa/testb) is
    # needed. partition() tolerates headers whose topic has several words or
    # is absent, where unpacking split(" ") into two names raises ValueError.
    doc_id = doc_info.partition(" ")[0]

    # "testa" marks the dev split, "testb" the test split, anything else train.
    if "testa" in doc_id:
        split_documents = aida_dev
    elif "testb" in doc_id:
        split_documents = aida_test
    else:
        split_documents = aida_train

    # Strip the split marker and store the numeric id back on the document.
    doc_id = doc_id.replace("testa", "").replace("testb", "").strip()
    document["doc_id"] = int(doc_id)

    split_documents.append(document)

for split_path, split_documents in (
    ("data/processed/aida_train.jsonl", aida_train),
    ("data/processed/aida_dev.jsonl", aida_dev),
    ("data/processed/aida_test.jsonl", aida_test),
):
    with open(split_path, "w") as f:
        for document in split_documents:
            f.write(json.dumps(document) + "\n")
49 |
--------------------------------------------------------------------------------
/scripts/old-scripts/evaluate/evaluate_re.py:
--------------------------------------------------------------------------------
1 | from relik import Relik
2 | from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
3 | from relik.retriever import GoldenRetriever
4 | from relik.inference.data.objects import TaskType
5 | from relik.reader.utils.relation_matching_eval import StrongMatchingPerRelation, StrongMatching
6 | from relik.inference.data.objects import AnnotationType
7 |
8 | import json
9 | import argparse
10 | from relik.reader.data.relik_reader_sample import load_relik_reader_samples
11 |
def evaluate(
    reader_path,
    question_encoder_path,
    document_index_path,
    input_file,
    output_file,
    max_triplets=8,
    use_predefined_spans=False,
):
    """
    Run a triplet-extraction pipeline on ``input_file`` and print its scores.

    A ``Relik`` pipeline is assembled from a triplet reader and a
    ``GoldenRetriever`` registered under the ``"triplet"`` key, and is run on
    ``"cuda"`` over the samples loaded from ``input_file``. All predicted
    windows are written to ``output_file`` (one ``to_jsons()`` string per
    line), then the ``StrongMatchingPerRelation`` and ``StrongMatching``
    results are printed with a ``test_`` prefix.

    Args:
        reader_path: Passed to ``RelikReaderForTripletExtraction``.
        question_encoder_path: ``question_encoder`` of the retriever.
        document_index_path: ``document_index`` of the retriever.
        input_file: File read with ``load_relik_reader_samples``.
        output_file: File the predicted windows are written to.
        max_triplets: Used both as the reader's ``max_triplets`` and as the
            pipeline's ``top_k``.
        use_predefined_spans: Forwarded to the pipeline call.
    """
    reader = RelikReaderForTripletExtraction(
        reader_path,
        dataset_kwargs={"use_nme": False, "max_triplets": max_triplets},
    )

    triplet_retriever = GoldenRetriever(
        question_encoder=question_encoder_path,
        document_index=document_index_path,
    )

    relik = Relik(
        reader=reader,
        retriever={"triplet": triplet_retriever},
        top_k=max_triplets,
        task=TaskType.TRIPLET,
        device="cuda",
        window_size="sentence",
    )

    samples = list(load_relik_reader_samples(input_file))
    for position, reader_sample in enumerate(samples):
        # Samples without a "text" field get one by joining their words.
        if reader_sample._d.get("text") is None:
            reader_sample._d["text"] = " ".join(reader_sample.words)
        # Each sample's position in the file becomes its doc_id.
        reader_sample.doc_id = position

    predictions = relik(
        windows=samples,
        num_workers=4,
        device="cuda",
        progress_bar=True,
        annotation_type=AnnotationType.WORD,
        return_also_windows=True,
        use_predefined_spans=use_predefined_spans,
        relation_threshold=0.5,
    )

    # Flatten the per-sample window lists into one list.
    windows = []
    for prediction in predictions:
        windows.extend(prediction.windows)

    with open(output_file, "w") as f:
        f.writelines(window.to_jsons() + "\n" for window in windows)

    scored = list(windows)
    for metric_class in (StrongMatchingPerRelation, StrongMatching):
        metric = metric_class()
        for k, v in metric(scored).items():
            print(f"test_{k}", v)
50 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The five path-like options are all mandatory strings.
    for required_option in (
        "--reader_path",
        "--question_encoder_path",
        "--document_index_path",
        "--input_file",
        "--output_file",
    ):
        parser.add_argument(required_option, type=str, required=True)
    parser.add_argument("--max_triplets", type=int, default=8)
    parser.add_argument("--use_predefined_spans", action="store_true")
    args = parser.parse_args()

    evaluate(
        reader_path=args.reader_path,
        question_encoder_path=args.question_encoder_path,
        document_index_path=args.document_index_path,
        input_file=args.input_file,
        output_file=args.output_file,
        max_triplets=args.max_triplets,
        use_predefined_spans=args.use_predefined_spans,
    )
--------------------------------------------------------------------------------
/scripts/old-scripts/evaluate/evaluate_re_bio.py:
--------------------------------------------------------------------------------
1 | from relik import Relik
2 | from relik.reader.pytorch_modules.triplet import RelikReaderForTripletExtraction
3 | from relik.retriever import GoldenRetriever
4 | from relik.inference.data.objects import TaskType
5 | from relik.reader.utils.relation_matching_eval import StrongMatchingPerRelation, StrongMatching
6 | from relik.inference.data.objects import AnnotationType
7 |
8 | import json
9 | import pandas as pd
10 | import argparse
11 | import os
12 | from relik.reader.data.relik_reader_sample import load_relik_reader_samples
13 |
def evaluate(
    reader_path,
    question_encoder_path,
    document_index_path,
    input_file,
    output_file,
    max_triplets=8,
    use_predefined_spans=False,
):
    """
    Run a triplet-extraction pipeline and score it after mapping relation names.

    Same flow as the plain relation-extraction evaluation, with one extra
    step: a ``bio-map.tsv`` file located in the same directory as
    ``input_file`` provides a ``label -> hierarchy`` mapping, and every
    predicted relation name is replaced by its mapped value before scoring.
    The predictions written to ``output_file`` are saved *before* that
    replacement, so the file holds the unmapped names.

    Args:
        reader_path: Passed to ``RelikReaderForTripletExtraction``.
        question_encoder_path: ``question_encoder`` of the retriever.
        document_index_path: ``document_index`` of the retriever.
        input_file: File read with ``load_relik_reader_samples``; its
            directory must also contain ``bio-map.tsv``.
        output_file: File the predicted windows are written to.
        max_triplets: Used both as the reader's ``max_triplets`` and as the
            pipeline's ``top_k``.
        use_predefined_spans: Forwarded to the pipeline call.

    Raises:
        KeyError: If a predicted relation name is absent from the mapping.
    """
    # The mapping file is a tab-separated table with "label" and "hierarchy" columns.
    mapping_path = os.path.join(os.path.dirname(input_file), "bio-map.tsv")
    mapping_classes = pd.read_csv(mapping_path, sep="\t")
    mapping_dict = dict(zip(mapping_classes.label, mapping_classes.hierarchy))

    reader = RelikReaderForTripletExtraction(
        reader_path,
        dataset_kwargs={"use_nme": False, "max_triplets": max_triplets},
    )

    triplet_retriever = GoldenRetriever(
        question_encoder=question_encoder_path,
        document_index=document_index_path,
    )

    relik = Relik(
        reader=reader,
        retriever={"triplet": triplet_retriever},
        top_k=max_triplets,
        task=TaskType.TRIPLET,
        device="cuda",
        window_size="sentence",
    )

    samples = list(load_relik_reader_samples(input_file))
    for position, reader_sample in enumerate(samples):
        # Samples without a "text" field get one by joining their words.
        if reader_sample._d.get("text") is None:
            reader_sample._d["text"] = " ".join(reader_sample.words)
        # Each sample's position in the file becomes its doc_id.
        reader_sample.doc_id = position

    predictions = relik(
        windows=samples,
        num_workers=4,
        device="cuda",
        progress_bar=True,
        annotation_type=AnnotationType.WORD,
        return_also_windows=True,
        use_predefined_spans=use_predefined_spans,
        relation_threshold=0.5,
    )

    # Flatten the per-sample window lists into one list.
    windows = []
    for prediction in predictions:
        windows.extend(prediction.windows)

    with open(output_file, "w") as f:
        f.writelines(window.to_jsons() + "\n" for window in windows)

    # Rewrite predicted relation names in place using the mapping.
    for window in windows:
        for triplet in window.predicted_relations:
            relation = triplet["relation"]
            relation["name"] = mapping_dict[relation["name"]]

    scored = list(windows)
    for metric_class in (StrongMatchingPerRelation, StrongMatching):
        metric = metric_class()
        for k, v in metric(scored).items():
            print(f"test_{k}", v)
58 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The five path-like options are all mandatory strings.
    for required_option in (
        "--reader_path",
        "--question_encoder_path",
        "--document_index_path",
        "--input_file",
        "--output_file",
    ):
        parser.add_argument(required_option, type=str, required=True)
    parser.add_argument("--max_triplets", type=int, default=8)
    parser.add_argument("--use_predefined_spans", action="store_true")
    args = parser.parse_args()

    evaluate(
        reader_path=args.reader_path,
        question_encoder_path=args.question_encoder_path,
        document_index_path=args.document_index_path,
        input_file=args.input_file,
        output_file=args.output_file,
        max_triplets=args.max_triplets,
        use_predefined_spans=args.use_predefined_spans,
    )
--------------------------------------------------------------------------------
/scripts/old-scripts/predict/predict_aida.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pprintpp import pprint
3 |
4 | from relik.inference.annotator import Relik
5 | from relik.inference.data.objects import TaskType
6 | from relik.reader.pytorch_modules.span import RelikReaderForSpanExtraction
7 | from relik.retriever.pytorch_modules.model import GoldenRetriever
8 |
9 |
def main():
    """
    Predict spans for pre-windowed samples with a reader-only Relik pipeline.

    Reads windows (with their candidate lists) from a hard-coded JSONL file,
    runs the reader on them, and writes one JSON object per window with the
    gold labels shifted to window-relative offsets and the predicted spans.
    """
    # A commented-out earlier variant lived here: it built Relik from a
    # GoldenRetriever plus a span reader (top_k=100, window_size=32,
    # window_stride=16, task=TaskType.SPAN) and called save_pretrained(...,
    # save_weights=False, push_to_hub=True). It never executed.

    reader = RelikReaderForSpanExtraction(
        "riccorl/relik-reader-deberta-base-retriever-relik-entity-linking-aida-wikipedia-large",
        device="cuda",
        precision="fp16",
        dataset_kwargs={"use_nme": True},
    )

    # Reader only: candidates are supplied explicitly below.
    relik = Relik(reader=reader)

    with open(
        "/home/ric/Projects/relik-sapienzanlp/data/reader/retriever-relik-entity-linking-aida-wikipedia-base-question-encoder/testa_windowed_candidates.jsonl"
    ) as f:
        windows = [json.loads(line) for line in f]

    texts = [window["text"] for window in windows]
    candidates = [window["span_candidates"] for window in windows]

    predictions = relik(
        texts,
        candidates=candidates,
        window_size="none",
        annotation_type="char",
        progress_bar=True,
        reader_batch_size=32,
    )

    # Gold spans are shifted by the window's "offset" so they become relative
    # to the window text.
    output = [
        {
            "doc_id": window["doc_id"],
            "window_id": window["window_id"],
            "text": window["text"],
            "window_labels": [
                [span[0] - window["offset"], span[1] - window["offset"], span[2]]
                for span in window["window_labels"]
            ],
            "predictions": [
                [span.start, span.end, span.label] for span in prediction.spans
            ],
        }
        for prediction, window in zip(predictions, windows)
    ]

    with open(
        "/home/ric/Projects/relik-sapienzanlp/experiments/predictions/deberta-large/testa.jsonl",
        "w",
    ) as f:
        f.writelines(json.dumps(record) + "\n" for record in output)


if __name__ == "__main__":
    main()
91 |
--------------------------------------------------------------------------------
/scripts/old-scripts/retriever/test_aida.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from relik.common.log import get_logger
3 | from relik.retriever import GoldenRetriever
4 | from relik.retriever.data.datasets import AidaInBatchNegativesDataset
5 | from relik.retriever.indexers.document import DocumentStore
6 | from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
7 | from relik.retriever.trainer import RetrieverTrainer
8 |
9 | logger = get_logger(__name__)
10 |
if __name__ == "__main__":
    # Evaluate an existing retriever on the AIDA test split.
    #
    # "encoder" and "index" are positional arguments, which argparse already
    # treats as mandatory. Passing required=True for a positional makes
    # add_argument raise TypeError, so the keyword is omitted.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("encoder", type=str)
    arg_parser.add_argument("index", type=str)
    args = arg_parser.parse_args()

    # instantiate retriever
    retriever = GoldenRetriever(
        question_encoder=args.encoder,
        document_index=args.index,
        device="cuda",
    )

    # Only the test split is loaded. A validation dataset with the same
    # settings (name="aida_val", .../val.jsonl) was commented out in the
    # original script.
    test_dataset = AidaInBatchNegativesDataset(
        name="aida_test",
        path="/root/relik-sapienzanlp/data/retriever/el/aida_32_tokens_topic_relik/test.jsonl",
        tokenizer=retriever.question_tokenizer,
        question_batch_size=64,
        passage_batch_size=400,
        max_passage_length=64,
        use_topics=True,
    )

    # No train/val datasets and no wandb logging: this trainer is only used
    # to run test().
    trainer = RetrieverTrainer(
        retriever=retriever,
        test_dataset=test_dataset,
        num_workers=4,
        max_steps=25_000,
        log_to_wandb=False,
        max_hard_negatives_to_mine=15,
        resume_from_checkpoint_path=None,  # path to lightning checkpoint
        trainer_kwargs={"logger": False},
    )

    trainer.test()
62 |
--------------------------------------------------------------------------------
/scripts/old-scripts/retriever/train_aida.py:
--------------------------------------------------------------------------------
1 | from relik.common.log import get_logger
2 | from relik.retriever import GoldenRetriever
3 | from relik.retriever.data.datasets import AidaInBatchNegativesDataset
4 | from relik.retriever.indexers.document import DocumentStore
5 | from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
6 | from relik.retriever.trainer import RetrieverTrainer
7 |
8 | logger = get_logger(__name__)
9 |
if __name__ == "__main__":
    # Fine-tune a retriever on AIDA, then evaluate it on the test split.

    # instantiate retriever
    # (previous question encoder, kept for reference:
    #  /root/golden-retriever/wandb/blink-first1M-e5-base-topics/files/retriever/question_encoder)
    retriever = GoldenRetriever(
        question_encoder="riccorl/golden-retriever-base-blink-before-hf",
        document_index=InMemoryDocumentIndex(
            documents=DocumentStore.from_file(
                "/root/relik-sapienzanlp/data/retriever/el/documents.jsonl"
            ),
            metadata_fields=["definition"],
            separator=" ",
            device="cuda",
            precision="16",
        ),
    )

    # Settings shared by the three dataset splits.
    batching = dict(
        question_batch_size=64,
        passage_batch_size=400,
        max_passage_length=64,
        use_topics=True,
    )

    # (previous train path, kept for reference:
    #  /root/golden-retriever/data/entitylinking/aida_32_tokens_topic/train.jsonl)
    train_dataset = AidaInBatchNegativesDataset(
        name="aida_train",
        path="/root/relik-sapienzanlp/data/retriever/el/aida_32_tokens_topic_relik/train.jsonl",
        tokenizer=retriever.question_tokenizer,
        shuffle=True,  # only the training split is shuffled
        **batching,
    )
    val_dataset = AidaInBatchNegativesDataset(
        name="aida_val",
        path="/root/relik-sapienzanlp/data/retriever/el/aida_32_tokens_topic_relik/val.jsonl",
        tokenizer=retriever.question_tokenizer,
        **batching,
    )
    test_dataset = AidaInBatchNegativesDataset(
        name="aida_test",
        path="/root/relik-sapienzanlp/data/retriever/el/aida_32_tokens_topic_relik/test.jsonl",
        tokenizer=retriever.question_tokenizer,
        **batching,
    )

    trainer = RetrieverTrainer(
        retriever=retriever,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        num_workers=4,
        max_steps=25_000,
        wandb_online_mode=True,
        wandb_project_name="relik-retriever-aida",
        wandb_experiment_name="aida-e5-base-topics-from-blink-new-data",
        max_hard_negatives_to_mine=15,
        resume_from_checkpoint_path=None,  # path to lightning checkpoint
    )

    trainer.train()
    trainer.test()
72 |
--------------------------------------------------------------------------------
/scripts/old-scripts/retriever/train_blink.py:
--------------------------------------------------------------------------------
1 | from relik.common.log import get_logger
2 | from relik.retriever import GoldenRetriever
3 | from relik.retriever.data.datasets import (
4 | AidaInBatchNegativesDataset,
5 | SubsampleStrategyEnum,
6 | )
7 | from relik.retriever.indexers.document import DocumentStore
8 | from relik.retriever.indexers.inmemory import InMemoryDocumentIndex
9 | from relik.retriever.trainer import Trainer
10 |
11 | logger = get_logger(__name__)
12 |
if __name__ == "__main__":
    # Train a retriever on the BLINK windows, starting from intfloat/e5-base-v2.

    # instantiate retriever
    retriever = GoldenRetriever(question_encoder="intfloat/e5-base-v2")

    # Settings shared by the train and validation datasets
    # (use_topics=True was commented out in the original script).
    batching = dict(
        question_batch_size=64,
        passage_batch_size=400,
        max_passage_length=64,
    )

    train_dataset = AidaInBatchNegativesDataset(
        name="aida_train",
        path="/media/data/EL/blink/window_32_tokens/random_1M/dpr-like/first_1M.jsonl",
        tokenizer=retriever.question_tokenizer,
        shuffle=True,
        subsample_strategy=SubsampleStrategyEnum.RANDOM,
        **batching,
    )
    val_dataset = AidaInBatchNegativesDataset(
        name="aida_val",
        path="/media/data/EL/blink/window_32_tokens/random_1M/dpr-like/val.jsonl",
        tokenizer=retriever.question_tokenizer,
        **batching,
    )
    # No test dataset is built: a commented-out "aida_test" dataset
    # (.../aida_32_tokens_topic/test.jsonl) was left disabled in the original.

    # The document index is attached after the retriever has been created.
    logger.info("Loading document index")
    retriever.document_index = InMemoryDocumentIndex(
        documents=DocumentStore.from_file(
            "/root/golden-retriever/data/entitylinking/documents.jsonl"
        ),
        metadata_fields=["definition"],
        separator=" ",
        device="cuda",
        precision="16",
    )

    # NOTE(review): sibling scripts import RetrieverTrainer from
    # relik.retriever.trainer, while this one uses Trainer — confirm that the
    # name is still exported.
    trainer = Trainer(
        retriever=retriever,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=None,
        num_workers=0,
        max_steps=400_000,
        wandb_online_mode=True,
        wandb_project_name="relik-retriever-blink",
        wandb_experiment_name="blink-first1M-e5-base-topics-recheck",
        max_hard_negatives_to_mine=15,
        mine_hard_negatives_with_probability=0.2,
        save_last=True,
        resume_from_checkpoint_path=None,  # path to lightning checkpoint
    )

    trainer.train()
    # NOTE(review): test() is invoked although test_dataset is None — verify
    # how the trainer handles that.
    trainer.test()
77 |
--------------------------------------------------------------------------------
/scripts/setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Interactive environment bootstrap: creates a conda environment, installs
# PyTorch (CPU-only or with the requested CUDA version) and then installs this
# project in editable mode with the "all" extra. Run it from the repository root.

# setup conda
# check if conda is installed
if ! command -v conda >/dev/null 2>&1; then
  echo "Conda is not installed. Please install conda first."
  exit 1
fi
CONDA_BASE=$(conda info --base)
if [ -z "$CONDA_BASE" ]; then
  echo "Conda is not installed. Please install conda first."
  exit 1
fi
source "$CONDA_BASE"/etc/profile.d/conda.sh

# create conda env
read -rp "Enter environment name or prefix: " ENV_NAME
if [ -z "$ENV_NAME" ]; then
  echo "Environment name or prefix must not be empty."
  exit 1
fi
read -rp "Enter python version (default 3.10): " PYTHON_VERSION
if [ -z "$PYTHON_VERSION" ]; then
  PYTHON_VERSION="3.10"
fi

# Anything containing a slash is a path (absolute or relative) and must be
# passed as --prefix; everything else is treated as an environment name.
if [[ "$ENV_NAME" == */* ]]; then
  CONDA_NEW_ARG="--prefix"
else
  CONDA_NEW_ARG="--name"
fi

# Stop if the environment cannot be created or activated; otherwise the
# installs below would land in whichever environment is currently active.
if ! conda create -y "$CONDA_NEW_ARG" "$ENV_NAME" python="$PYTHON_VERSION"; then
  echo "Failed to create conda environment '$ENV_NAME'."
  exit 1
fi
if ! conda activate "$ENV_NAME"; then
  echo "Failed to activate conda environment '$ENV_NAME'."
  exit 1
fi

# replace placeholder env with $ENV_NAME in scripts/train.sh
# NEW_CONDA_LINE="source \$CONDA_BASE/bin/activate $ENV_NAME"
# sed -i.bak -e "s,.*bin/activate.*,$NEW_CONDA_LINE,g" scripts/train.sh

# install torch
read -rp "Enter cuda version (e.g. '11.8', default no cuda support): " CUDA_VERSION
read -rp "Enter PyTorch version (e.g. '2.1', default latest): " PYTORCH_VERSION
# Turn a bare version into a conda version spec ("=2.1"); empty means latest.
if [ -n "$PYTORCH_VERSION" ]; then
  PYTORCH_VERSION="=$PYTORCH_VERSION"
fi
if [ -z "$CUDA_VERSION" ]; then
  conda install -y pytorch"$PYTORCH_VERSION" cpuonly -c pytorch
else
  conda install -y pytorch"$PYTORCH_VERSION" pytorch-cuda="$CUDA_VERSION" -c pytorch -c nvidia
fi

# install python requirements
# Quoted so the shell cannot interpret ".[all]" as a glob pattern.
pip install -e ".[all]"
47 |
--------------------------------------------------------------------------------