├── .coveragerc ├── .flake8 ├── .github └── workflows │ └── ci.yaml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── CHANGELOG.rst ├── LICENSE ├── Makefile ├── README.md ├── bootleg ├── __init__.py ├── _version.py ├── data.py ├── dataset.py ├── end2end │ ├── __init__.py │ ├── annotator_utils.py │ ├── bootleg_annotator.py │ └── extract_mentions.py ├── extract_all_entities.py ├── layers │ ├── __init__.py │ ├── alias_to_ent_encoder.py │ ├── bert_encoder.py │ └── static_entity_embeddings.py ├── run.py ├── scorer.py ├── slicing │ ├── __init__.py │ └── slice_dataset.py ├── symbols │ ├── __init__.py │ ├── constants.py │ ├── entity_profile.py │ ├── entity_symbols.py │ ├── kg_symbols.py │ └── type_symbols.py ├── task_config.py ├── tasks │ ├── __init__.py │ ├── entity_gen_task.py │ └── ned_task.py └── utils │ ├── __init__.py │ ├── classes │ ├── __init__.py │ ├── comment_json.py │ ├── dotted_dict.py │ ├── emmental_data.py │ └── nested_vocab_tries.py │ ├── data_utils.py │ ├── eval_utils.py │ ├── mention_extractor_utils.py │ ├── model_utils.py │ ├── parser │ ├── __init__.py │ ├── bootleg_args.py │ ├── emm_parse_args.py │ └── parser_utils.py │ ├── preprocessing │ ├── __init__.py │ ├── compute_statistics.py │ ├── convert_to_char_spans.py │ ├── get_train_qid_counts.py │ └── sample_eval_data.py │ └── utils.py ├── cand_gen ├── data.py ├── dataset.py ├── eval.py ├── task_config.py ├── tasks │ ├── candgen_task.py │ ├── context_gen_task.py │ └── entity_gen_task.py ├── train.py └── utils │ ├── merge_contextual_cands.py │ └── parser │ ├── candgen_args.py │ └── parser_utils.py ├── configs ├── cand_gen │ └── bi_train.yaml ├── gcp │ ├── bootleg_cand_gen_test.yaml │ ├── bootleg_test.yaml │ ├── bootleg_wiki.yaml │ └── launch_gcp.py ├── standard │ └── train.yaml └── tutorial │ ├── bootleg_wiki.yaml │ └── sample_config.yaml ├── data ├── sample_entity_db │ ├── entity_mappings │ │ ├── alias2id │ │ │ ├── config.json │ │ │ ├── itoexti.npy │ │ │ └── vocabulary_trie.marisa │ │ ├── alias2qids │ │ │ ├── max_value.json │ │ │ ├── record_trie.marisa │ │ │ └── vocabulary_trie.marisa │ │ ├── config.json │ │ ├── qid2desc.json │ │ ├── qid2eid │ │ │ ├── config.json │ │ │ ├── itoexti.npy │ │ │ └── vocabulary_trie.marisa │ │ └── qid2title.json │ ├── kg_mappings │ │ ├── config.json │ │ ├── kg_adj.txt │ │ ├── qid2relations.json │ │ └── relation_vocab.json │ └── type_mappings │ │ ├── hyena │ │ ├── config.json │ │ └── qid2typenames │ │ │ ├── max_value.json │ │ │ ├── record_trie.marisa │ │ │ └── vocabulary_trie.marisa │ │ ├── relations │ │ ├── config.json │ │ └── qid2typenames │ │ │ ├── max_value.json │ │ │ ├── record_trie.marisa │ │ │ └── vocabulary_trie.marisa │ │ └── wiki │ │ ├── config.json │ │ └── qid2typenames │ │ ├── max_value.json │ │ ├── record_trie.marisa │ │ └── vocabulary_trie.marisa ├── sample_raw_entity_data │ └── raw_profile.jsonl └── sample_text_data │ ├── dev.jsonl │ └── train.jsonl ├── docs ├── Makefile ├── make.bat └── source │ ├── advanced │ └── distributed_training.rst │ ├── apidocs │ ├── bootleg.end2end.rst │ ├── bootleg.layers.rst │ ├── bootleg.rst │ ├── bootleg.slicing.rst │ ├── bootleg.symbols.rst │ ├── bootleg.tasks.rst │ ├── bootleg.utils.classes.rst │ ├── bootleg.utils.parser.rst │ ├── bootleg.utils.preprocessing.rst │ ├── bootleg.utils.rst │ └── modules.rst │ ├── conf.py │ ├── dev │ ├── changelog.rst │ ├── codestyle.rst │ ├── install.rst │ └── tests.rst │ ├── gettingstarted │ ├── config.rst │ ├── emmental.rst │ ├── entity_profile.rst │ ├── input_data.rst │ ├── install.rst │ ├── 
model.rst │ ├── quickstart.rst │ └── training.rst │ └── index.rst ├── scripts └── train.zsh ├── setup.py ├── tests ├── data │ ├── data_loader │ │ ├── end2end_dev.jsonl │ │ ├── end2end_train.jsonl │ │ └── end2end_train_not_in_cand.jsonl │ └── entity_loader │ │ └── entity_data │ │ ├── entity_mappings │ │ ├── alias2id │ │ │ ├── config.json │ │ │ ├── itoexti.npy │ │ │ └── vocabulary_trie.marisa │ │ ├── alias2qids │ │ │ ├── max_value.json │ │ │ ├── record_trie.marisa │ │ │ └── vocabulary_trie.marisa │ │ ├── config.json │ │ ├── qid2eid │ │ │ ├── config.json │ │ │ ├── itoexti.npy │ │ │ └── vocabulary_trie.marisa │ │ └── qid2title.json │ │ ├── kg_mappings │ │ ├── config.json │ │ └── qid2relations │ │ │ ├── key_vocabulary_trie.marisa │ │ │ ├── max_value.json │ │ │ ├── record_trie.marisa │ │ │ └── value_vocabulary_trie.marisa │ │ └── type_mappings │ │ └── wiki │ │ ├── config.json │ │ └── qid2typenames │ │ ├── max_value.json │ │ ├── record_trie.marisa │ │ └── vocabulary_trie.marisa ├── run_args │ ├── test_candgen.json │ ├── test_data.json │ ├── test_end2end.json │ └── test_entity_data.json ├── test_cand_gen │ └── test_eval.py ├── test_data │ ├── test_data.py │ ├── test_entity_data.py │ └── test_slice_data.py ├── test_end_to_end │ ├── test_annotator.py │ ├── test_end_to_end.py │ ├── test_gen_entities.py │ └── test_mention_extraction.py ├── test_entity │ ├── test_entity.py │ └── test_entity_profile.py ├── test_scorer │ └── test_scorer.py └── test_utils │ ├── test_eval_utils.py │ ├── test_preprocessing.py │ └── test_util_classes.py ├── tutorials ├── README.md ├── annotation-on-the-fly.ipynb ├── download_data.sh ├── download_model.sh ├── download_wiki.sh ├── end2end_ned_tutorial.ipynb ├── entity_embedding_tutorial.ipynb ├── entity_profile_tutorial.ipynb └── utils.py └── web └── images ├── bootleg-logo.png ├── bootleg-performance.png ├── bootleg-text.png ├── bootleg_architecture.png ├── bootleg_dataflow.png └── full_logo.png /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | branch = True 3 | source = bootleg 4 | 5 | [report] 6 | # Regexes for lines to exclude from consideration 7 | exclude_lines = 8 | # Have to re-enable the standard pragma 9 | pragma: no cover 10 | 11 | # Don't complain about missing debug-only code: 12 | def __repr__ 13 | if self\.debug 14 | 15 | # Don't complain if tests don't hit defensive assertion code: 16 | raise AssertionError 17 | raise NotImplementedError 18 | 19 | # Don't complain if non-runnable code isn't run: 20 | if 0: 21 | if __name__ == .__main__.: 22 | 23 | ignore_errors = True 24 | 25 | omit = 26 | # Omit anything in the test directory 27 | tests/* 28 | cand_gen/* 29 | **/__init__.py 30 | setup.py 31 | bootleg/_version.py 32 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | # This is our code-style check. 
We currently allow the following exceptions: 2 | # - E731: do not assign a lambda expression, use a def 3 | # - W503: line break before binary operator 4 | # - E203: whitespace before ':' 5 | # - N803/N806 name should be lowercase 6 | # - N812 lowercase imported as non lowercase 7 | 8 | [flake8] 9 | exclude = .git 10 | max-line-length = 120 11 | ignore = E731, W503, E203, N803, N806, N812 12 | -------------------------------------------------------------------------------- /.github/workflows/ci.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | branches: 9 | - master 10 | 11 | # Allows you to run this workflow manually from the Actions tab 12 | workflow_dispatch: 13 | 14 | jobs: 15 | test: 16 | runs-on: ${{ matrix.os }} 17 | timeout-minutes: 30 18 | strategy: 19 | matrix: 20 | os: [ubuntu-latest] 21 | python-version: [3.7, 3.8] 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Cache conda 25 | uses: actions/cache@v1 26 | env: 27 | # Increase this value to reset cache 28 | CACHE_NUMBER: 0 29 | with: 30 | path: ~/conda_pkgs_dir 31 | key: ${{ runner.os }}-conda-${{ env.CACHE_NUMBER }} 32 | - name: Install miniconda and python version ${{ matrix.python-version }} 33 | uses: conda-incubator/setup-miniconda@v2 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | channels: conda-forge 37 | channel-priority: strict 38 | use-only-tar-bz2: true # IMPORTANT: This needs to be set for caching to work properly! 39 | - name: Install SSH Key 40 | uses: webfactory/ssh-agent@v0.5.3 41 | with: 42 | ssh-private-key: ${{ secrets.SSH_PRIVATE_KEY }} 43 | - name: Install Bootleg 44 | shell: bash -l {0} 45 | run: | 46 | make dev 47 | pip install -q pytest-cov 48 | - name: Run preliminary checks 49 | shell: bash -l {0} 50 | run: | 51 | isort --version 52 | black --version 53 | flake8 --version 54 | make check 55 | make docs 56 | - name: Test with pytest 57 | shell: bash -l {0} 58 | run: | 59 | pytest --cov=./ --cov-report=xml 60 | - name: Upload to codecov.io 61 | uses: codecov/codecov-action@v1 62 | with: 63 | file: ./coverage.xml 64 | flags: unittests 65 | name: codecov-umbrella 66 | fail_ci_if_error: true 67 | build-and-publish: 68 | name: Build and Publish Bootleg to PyPI 69 | runs-on: ubuntu-latest 70 | needs: test 71 | strategy: 72 | matrix: 73 | python-version: [3.8] 74 | steps: 75 | - uses: actions/checkout@v2 76 | - name: Set up Python ${{ matrix.python-version }} 77 | uses: actions/setup-python@v1 78 | with: 79 | python-version: ${{ matrix.python-version }} 80 | - name: Build Bootleg 81 | run: | 82 | pip install -U setuptools wheel pip 83 | python setup.py sdist bdist_wheel 84 | - name: Publish distribution 📦 to PyPI 85 | if: startsWith(github.event.ref, 'refs/tags') 86 | uses: pypa/gh-action-pypi-publish@master 87 | with: 88 | password: ${{ secrets.pypi_password }} 89 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | runs/* 2 | 3 | *._* 4 | # Pickle Saved 5 | *.pt 6 | *.pk 7 | **/*.pt 8 | **/*.pk 9 | 10 | # PyCharm 11 | *.idea 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | 
var/ 34 | wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | 118 | *.tsv 119 | *.7z 120 | .DS_Store 121 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | multi_line_output=3 3 | include_trailing_comma=True 4 | force_grid_wrap=0 5 | combine_as_imports=True 6 | line_length=88 7 | known_first_party = bootleg,tests 8 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v3.2.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-yaml 8 | - id: check-toml 9 | - id: check-merge-conflict 10 | - id: check-added-large-files 11 | - repo: https://github.com/timothycrosley/isort 12 | rev: 5.9.3 13 | hooks: 14 | - id: isort 15 | - repo: https://github.com/psf/black 16 | rev: 22.3.0 17 | hooks: 18 | - id: black 19 | language_version: python3 20 | - repo: https://gitlab.com/pycqa/flake8 21 | rev: 3.9.2 22 | hooks: 23 | - id: flake8 24 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Build documentation in the docs/ directory with Sphinx 9 | sphinx: 10 | configuration: docs/source/conf.py 11 | 12 | # Optionally build your docs in additional formats such as PDF and ePub 13 | formats: all 14 | 15 | # Optionally set the version of Python and requirements required to build your docs 16 | python: 17 | version: 3.8 18 | install: 19 | - method: pip 20 | path: .[dev] 21 | system_packages: true 22 | -------------------------------------------------------------------------------- /CHANGELOG.rst: -------------------------------------------------------------------------------- 1 | Unreleased 1.1.1dev0 2 | 
--------------------- 3 | Fixed 4 | ^^^^^^^^ 5 | * Corrected training with ``train_in_candidates`` set to False. 6 | 7 | 1.1.0 - 2022-04-12 8 | --------------------- 9 | Changed 10 | ^^^^^^^^^ 11 | * We did an architectural change and switched to a biencoder model. This changes our task flow and data prep. This new model uses less CPU storage and uses the standard BERT architecture. Our entity encoder now takes a textual input of an entity that contains its title, description, KG relationships, and types. 12 | * To support larger files for dumping predictions over, we support adding an ``entity_emb_file`` to the model (extracted from ``extract_all_entities.py``). This will make evaluation faster. Further, we added ``dump_preds_num_data_splits`` to split a file before dumping. As each file pass gets a new dataloader object, this can mitigate any torch dataloader memory issues that happen over large files. 13 | * Renamed ``eval_accumulation_steps`` to ``dump_preds_accumulation_steps``. 14 | * Removed option to ``dump_embs``. Users should use ``dump_preds`` instead. The output file will have an ``entity_ids`` attribute that will index into the extracted entity embeddings. 15 | * Restructured our ``entity_db`` data for faster loading. It uses tries rather than JSON files to store the data for read-only mode. The KG relations are not backwards compatible. 16 | * Moved to character spans for input data. Added utils.preprocessing.convert_to_char_spans as a helper function to convert from word offsets to character offsets. 17 | 18 | Added 19 | ^^^^^^ 20 | * ``BOOTLEG_STRIP`` and ``BOOTLEG_LOWER`` environment variables for ``get_lnrm``. 21 | * ``extract_all_entities.py`` as a way to extract all entity embeddings. These entity embeddings can be used in eval and be used downstream. Users can use ``get_eid`` from the ``EntityProfile`` to extract the row id for a specific entity. 22 | 23 | 1.0.5 - 2021-08-20 24 | --------------------- 25 | Fixed 26 | ^^^^^^^^ 27 | * Fixed -1 command line argparse error 28 | * Adjusted requirements 29 | 30 | 1.0.4 - 2021-07-12 31 | --------------------- 32 | Added 33 | ^^^^^^ 34 | * Tutorial to generate contextualized entity embeddings that perform better downstream 35 | 36 | Fixed 37 | ^^^^^^^^ 38 | * Bumped version of Pydantic to 1.7.4 39 | 40 | 1.0.3 - 2021-06-29 41 | --------------------- 42 | Fixed 43 | ^^^^^^^ 44 | * Corrected how custom candidates were handled in the BootlegAnnotator when using ``extracted_examples`` 45 | * Fixed memory leak in BootlegAnnotator due to missing ``torch.no_grad()`` 46 | 47 | 1.0.2 - 2021-04-28 48 | --------------------- 49 | 50 | Added 51 | ^^^^^^ 52 | * Support for ``min_alias_len`` to ``extract_mentions`` and the ``BootlegAnnotator``. 53 | * ``return_embs`` flag to pass into ``BootlegAnnotator`` that will return the contextualized embeddings of the entity (using key ``embs``) and entity candidates (using key ``cand_embs``). 54 | 55 | Changed 56 | ^^^^^^^^^ 57 | * Removed condition that aliases for eval must appear in candidate lists. We now allow for eval to not have known aliases and always mark these as incorrect. When dumping predictions, these get "-1" candidates and null probabilities. 58 | 59 | Fixed 60 | ^^^^^^^ 61 | * Corrected ``fit_to_profile`` to rebuild the title embeddings for the new entities. 62 | 63 | 1.0.1 - 2021-03-22 64 | ------------------- 65 | 66 | .. note:: 67 | 68 | If upgrading to 1.0.1 from 1.0.0, you will need to re-download our models given the links in the README.md.
We altered what keys were saved in the state dict, but the model weights are unchanged. 69 | 70 | Added 71 | ^^^^^^^ 72 | * ``data_config.print_examples_prep`` flag to toggle data example printing during data prep. 73 | * ``data_config.dump_preds_accumulation_steps`` to support sub-batching the dumping of predictions. We save outputs to separate files of size approximately ``data_config.dump_preds_accumulation_steps*data_config.eval_batch_size`` and merge into a final file at the end. 74 | * Entity Profile API. See the `docs `_. This allows for modifying entity metadata as well as adding and removing entities. We provide methods for refitting a model with a new profile for immediate inference, no finetuning needed. 75 | 76 | Changed 77 | ^^^^^^^^ 78 | * Support for not using multiprocessing if the user sets ``data_config.dataset_threads`` to 1. 79 | * Added better argument parsing to check for arguments that were misspelled or otherwise wouldn't trigger anything. 80 | * Code is now Flake8 compatible. 81 | 82 | Fixed 83 | ^^^^^^^ 84 | * Fixed readthedocs so the BootlegAnnotator was loaded correctly. 85 | * Fixed logging in BootlegAnnotator. 86 | * Fixed ``use_exact_path`` argument in Emmental. 87 | 88 | 1.0.0 - 2021-02-15 89 | ------------------- 90 | We did a major rewrite of our entire codebase and moved to using `Emmental `_ for training. Emmental allows for easy multi-task training, FP16, and support for both DataParallel and DistributedDataParallel. 91 | 92 | The overall functionality of Bootleg remains unchanged. We still support the use of an annotator and bulk mention extraction and evaluation. The core Bootleg model has remained largely unchanged. Check out our `documentation `_ for more information on getting started. We have new models trained as described in our `README `_. 93 | 94 | .. note:: 95 | 96 | This branch is **not** backwards compatible with our old models or code base. 97 | 98 | Some more subtle changes are below. 99 | 100 | Added 101 | ^^^^^ 102 | * Support for data parallel and distributed data parallel training (through Emmental) 103 | * FP16 (through Emmental) 104 | * Easy install with ``BootlegAnnotator`` 105 | 106 | Changed 107 | ^^^^^^^^ 108 | * Mention extraction code and alias map have been updated 109 | * Models trained on October 2020 save of Wikipedia 110 | * Have uncased and cased models 111 | 112 | Removed 113 | ^^^^^^^ 114 | * Support for slice-based learning 115 | * Support for ``batch prepped`` KG embeddings (only use ``batch on the fly``) 116 | 117 | 118 | .. _@lorr1: https://github.com/lorr1 119 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | dev: 2 | pip install -e .[dev] 3 | pre-commit install 4 | 5 | test: dev check docs 6 | pip install -e .
7 | pytest tests 8 | 9 | format: 10 | isort --atomic bootleg/ tests/ 11 | black bootleg/ tests/ 12 | # docformatter --in-place --recursive bootleg tests 13 | 14 | check: 15 | isort -c bootleg/ tests/ 16 | black bootleg/ tests/ --check 17 | flake8 bootleg/ tests/ 18 | 19 | docs: 20 | sphinx-build -b html docs/source/ docs/build/html/ 21 | # sphinx-apidoc -o docs/source/apidocs/ bootleg 22 | 23 | docs-check: 24 | sphinx-build -b html docs/source/ docs/build/html/ -W 25 | 26 | livedocs: 27 | sphinx-autobuild -b html docs/source/ docs/build/html/ 28 | 29 | clean: 30 | pip uninstall -y bootleg 31 | rm -rf src/bootleg.egg-info 32 | rm -rf build/ dist/ 33 | 34 | prune: 35 | @bash -c "git fetch -p"; 36 | @bash -c "for branch in $(git branch -vv | grep ': gone]' | awk '{print $1}'); do git branch -d $branch; done"; 37 | 38 | .PHONY: dev test clean check docs prune 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | ![GitHub Workflow Status](https://img.shields.io/github/workflow/status/HazyResearch/bootleg/CI) 6 | [![codecov](https://codecov.io/gh/HazyResearch/bootleg/branch/master/graph/badge.svg)](https://codecov.io/gh/HazyResearch/bootleg) 7 | [![Documentation Status](https://readthedocs.org/projects/bootleg/badge/?version=latest)](https://bootleg.readthedocs.io/en/latest/?badge=latest) 8 | [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 9 | 10 | # Self-Supervision for Named Entity Disambiguation at the Tail 11 | Bootleg is a self-supervised named entity disambiguation (NED) system for English built to improve disambiguation of entities that occur infrequently, or not at all, in training data. We call these entities *tail* entities. This is a critical task as the majority of entities are rare. The core insight behind Bootleg is that these tail entities can be disambiguated by reasoning over entity types and relations. We give an [overview](#bootleg-overview) of how Bootleg achieves this below. For details, please see our [blog post](https://hazyresearch.stanford.edu/bootleg_blog) and [paper](http://arxiv.org/abs/2010.10363). 12 | 13 | Note that Bootleg is *actively under development* and feedback is welcome. Submit bugs on the Issues page or feel free to submit your contributions as a pull request. 14 | 15 | **Update 9-25-2021**: We changed our architecture to be a biencoder. Our entity textual input still has all the goodness of types and KG relations, but our model now requires less storage space and has improved performance. A secret to getting the biencoder to work over the tail was heavy masking of the mention in the context encoder and entity title in the entity encoder. 16 | 17 | **Update 2-15-2021**: We made a major rewrite of the codebase and moved to using Emmental for training--check out the [changelog](CHANGELOG.rst) for details) 18 | 19 | # Getting Started 20 | 21 | Install via 22 | 23 | ``` 24 | git clone git@github.com:HazyResearch/bootleg bootleg 25 | cd bootleg 26 | python3 setup.py install 27 | ``` 28 | 29 | Checkout out our installation and quickstart guide [here](https://bootleg.readthedocs.io/en/latest/gettingstarted/install.html). 30 | 31 | ## Using a Trained Model 32 | ### Models 33 | Below is the link to download the English Bootleg model. The download comes with the saved model and config to run the model. We show in our [quickstart guide](https://bootleg.readthedocs.io/en/latest/gettingstarted/quickstart.html) and [end-to-end](tutorials/end2end_ned_tutorial.ipynb) tutorial how to load a config and run a model. 34 | 35 | | Model | Description | Number Parameters | Link | 36 | |------------------- |---------------------------------|-------------------|----------| 37 | | BootlegUncased | Uses titles, descriptions, types, and KG relations. Trained on uncased data. | 110M | [Download](https://bootleg-ned-data.s3-us-west-1.amazonaws.com/models/latest/bootleg_uncased.tar.gz) | 38 | 39 | ### Embeddings 40 | Below is the link to download a dump of all entity embeddings from our entity encoder. Follow our entity profile tutorial [here](https://github.com/HazyResearch/bootleg/blob/master/tutorials/entity_profile_tutorial.ipynb) to load our EntityProfile. From there, you can use our ```get_eid``` [method](https://bootleg.readthedocs.io/en/latest/apidocs/bootleg.symbols.html#bootleg.symbols.entity_profile.EntityProfile.get_eid) to access the row id for an entity. 
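For example, once the embedding and metadata dumps linked in the tables below are downloaded and extracted, a minimal lookup sketch might look as follows (the local paths are assumptions, and we assume the `EntityProfile` loader shown in the entity profile tutorial):

```
# Sketch: look up the embedding row for an entity by its Wikidata QID.
# Paths are assumptions: point them at the extracted downloads linked below.
import numpy as np
from bootleg.symbols.entity_profile import EntityProfile

ep = EntityProfile.load_from_cache("entity_db")                  # extracted metadata dump
entity_embs = np.load("bootleg_uncased_entity_embeddings.npy")   # extracted embedding dump

eid = ep.get_eid("Q95")          # row id for an entity (Q95 = Google)
google_emb = entity_embs[eid]    # embedding vector for that entity
```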
41 | 42 | | Embeddings | Description | Number Parameters | Link | 43 | |------------------- |---------------------------------|-------------------|----------| 44 | | 5.8M Wikipedia Entities | Embeddings from BootlegUncased. | 1.2B | [Download](https://bootleg-ned-data.s3-us-west-1.amazonaws.com/models/latest/bootleg_uncased_entity_embeddings.npy.tar.gz) | 45 | 46 | ### Metadata 47 | Below is the link to download a dump of all entity metadata to use in our entity profile tutorial [here](https://github.com/HazyResearch/bootleg/blob/master/tutorials/entity_profile_tutorial.ipynb). 48 | 49 | | Metadata | Description | Link | 50 | |------------------- |---------------------------------|----------| 51 | | 5.8M Wikipedia Entities | Wikidata metadata for entities. | [Download](https://bootleg-data.s3.us-west-2.amazonaws.com/data/latest/entity_db.tar.gz) | 52 | 53 | ## Training 54 | We provide detailed training instructions [here](https://bootleg.readthedocs.io/en/latest/gettingstarted/training.html). We provide a starter config [here](configs/standard/train.yaml). You only need to adjust `data_config.data_dir` and `data_config.entity_dir` to point to your local data. You may need to shrink the model size to fit on your available hardware. Then use the training zsh script [here](scripts/train.zsh). 55 | 56 | ## Tutorials 57 | We provide tutorials to help users get familiar with Bootleg [here](tutorials/). 58 | 59 | # Bootleg Overview 60 | Given an input sentence, Bootleg takes the sentence and outputs a predicted entity for each detected mention. Bootleg first extracts mentions in the 61 | sentence, and for each mention, we extract its set of possible candidate entities 62 | and any structural information about that entity, e.g., type information or knowledge graph (KG) information. Bootleg leverages this information to generate an entity embedding through a Transformer entity encoder. The mention and its surrounding context are encoded in a context encoder. The entity with the highest dot product with the context is selected for each mention. 63 | 64 | ![Dataflow](web/images/bootleg_dataflow.png "Bootleg Dataflow") 65 | 66 | More details can be found [here](https://bootleg.readthedocs.io/en/latest/gettingstarted/input_data.html). 67 | 68 | ## Inference 69 | Given a pretrained model, we support three types of inference: `--mode eval`, `--mode dump_preds`, and `--mode dump_embs`. `Eval` mode is the fastest option and will run the test files through the model and output aggregated quality metrics to the log. `Dump_preds` mode will write the individual predictions and corresponding probabilities to a jsonlines file. This is useful for error analysis. `Dump_embs` mode is the same as `dump_preds`, but will additionally output entity embeddings. These can then be read and processed in a downstream system. See this [notebook](tutorials/end2end_ned_tutorial.ipynb) to see how to do this with a downloaded Bootleg model. 70 | 71 | ## Entity Embedding Extraction 72 | As we have a separate encoder for generating an entity representation, we also support the ability to dump all entities to create a single entity embedding matrix for use downstream. This is done through the ```bootleg.extract_all_entities``` script. See this [notebook](tutorials/entity_embedding_tutorial.ipynb) to see how to do this with a downloaded Bootleg model. 73 | 74 | ## Training 75 | We recommend using GPUs for training Bootleg models.
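As a reference point, a single-GPU run over the starter config might be launched as sketched below; the `--config_script` flag name is an assumption, and the paths refer to the starter config and zsh script mentioned above.

```
# Sketch of a single-GPU training launch (flag name and paths assumed).
# Adjust data_config.data_dir and data_config.entity_dir in the YAML first.
python3 bootleg/run.py --config_script configs/standard/train.yaml
```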
For large datasets, we support distributed training with Pytorch's Distributed DataParallel framework to distribute batches across multiple GPUs. Check out the [Basic Training](https://bootleg.readthedocs.io/en/latest/gettingstarted/training.html) and [Advanced Training](https://bootleg.readthedocs.io/en/latest/advanced/distributed_training.html) tutorials for more information and sample data! 76 | 77 | ## Downstream Tasks 78 | Bootleg produces contextual entity embeddings (as well as learned static embeddings) that can be used in downstream tasks, such as relation extraction and question answering. Check out the [tutorial](tutorials) to see how this is done. 79 | 80 | ## Other Languages 81 | The released Bootleg model only supports English, but we have trained multi-lingual models using Wikipedia and Wikidata. If you have interest in doing this, please let us know with an issue request or email lorr1@cs.stanford.edu. We have data prep code to help prepare multi-lingual data. 82 | -------------------------------------------------------------------------------- /bootleg/__init__.py: -------------------------------------------------------------------------------- 1 | """Print functions for distributed computation.""" 2 | import torch 3 | 4 | 5 | def log_rank_0_info(logger, message): 6 | """If distributed is initialized log info only on rank 0.""" 7 | if torch.distributed.is_initialized(): 8 | if torch.distributed.get_rank() == 0: 9 | logger.info(message) 10 | else: 11 | logger.info(message) 12 | 13 | 14 | def log_rank_0_debug(logger, message): 15 | """If distributed is initialized log debug only on rank 0.""" 16 | if torch.distributed.is_initialized(): 17 | if torch.distributed.get_rank() == 0: 18 | logger.debug(message) 19 | else: 20 | logger.debug(message) 21 | -------------------------------------------------------------------------------- /bootleg/_version.py: -------------------------------------------------------------------------------- 1 | """Bootleg version.""" 2 | __version__ = "1.1.1dev0" 3 | -------------------------------------------------------------------------------- /bootleg/end2end/__init__.py: -------------------------------------------------------------------------------- 1 | """End2End init.""" 2 | -------------------------------------------------------------------------------- /bootleg/end2end/annotator_utils.py: -------------------------------------------------------------------------------- 1 | """Annotator utils.""" 2 | 3 | import progressbar 4 | 5 | 6 | class DownloadProgressBar: 7 | """Progress bar.""" 8 | 9 | def __init__(self): 10 | """Progress bar initializer.""" 11 | self.pbar = None 12 | 13 | def __call__(self, block_num, block_size, total_size): 14 | """Call.""" 15 | if not self.pbar: 16 | self.pbar = progressbar.ProgressBar( 17 | maxval=total_size if total_size > 0 else 1e-2 18 | ) 19 | self.pbar.start() 20 | 21 | downloaded = block_num * block_size 22 | if downloaded < total_size: 23 | self.pbar.update(downloaded) 24 | else: 25 | self.pbar.finish() 26 | -------------------------------------------------------------------------------- /bootleg/layers/__init__.py: -------------------------------------------------------------------------------- 1 | """Layer init.""" 2 | -------------------------------------------------------------------------------- /bootleg/layers/bert_encoder.py: -------------------------------------------------------------------------------- 1 | """BERT encoder.""" 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class 
Encoder(nn.Module): 7 | """ 8 | Encoder module. 9 | 10 | Return the CLS token of Transformer. 11 | 12 | Args: 13 | transformer: transformer 14 | out_dim: out dimension to project to 15 | """ 16 | 17 | def __init__(self, transformer, out_dim): 18 | """BERT Encoder initializer.""" 19 | super(Encoder, self).__init__() 20 | transformer_output_dim = transformer.embeddings.word_embeddings.weight.size(1) 21 | self.linear = nn.Linear(transformer_output_dim, out_dim) 22 | self.activation = nn.Tanh() 23 | self.transformer = transformer 24 | 25 | def forward(self, token_ids, segment_ids=None, attention_mask=None): 26 | """BERT Encoder forward.""" 27 | encoded_layers, pooled_output = self.transformer( 28 | input_ids=token_ids.reshape(-1, token_ids.shape[-1]), 29 | token_type_ids=segment_ids.reshape(-1, segment_ids.shape[-1]), 30 | attention_mask=attention_mask.reshape(-1, attention_mask.shape[-1]), 31 | return_dict=False, 32 | ) 33 | full_embs = pooled_output.reshape(*token_ids.shape[:-1], -1) 34 | embs = self.activation(self.linear(full_embs)) 35 | training_bool = ( 36 | torch.tensor([1], device=token_ids.device) * self.training 37 | ).bool() 38 | return embs, training_bool 39 | -------------------------------------------------------------------------------- /bootleg/layers/static_entity_embeddings.py: -------------------------------------------------------------------------------- 1 | """Entity embeddings.""" 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | class EntityEmbedding(torch.nn.Module): 11 | """Static entity embeddings class. 12 | 13 | Args: 14 | entity_emb_file: numpy file of entity embeddings 15 | """ 16 | 17 | def __init__(self, entity_emb_file): 18 | """Entity embedding initializer.""" 19 | super(EntityEmbedding, self).__init__() 20 | embs = torch.FloatTensor(np.load(entity_emb_file)) 21 | # Add -1 padding row; not required as dump from Bootleg should include PAD entity but as a safety 22 | embs = torch.cat([embs, torch.zeros(1, embs.shape[-1])], dim=0) 23 | self.embeddings = torch.nn.Embedding.from_pretrained(embs, padding_idx=-1) 24 | 25 | def forward(self, entity_cand_eid): 26 | """Model forward. 27 | 28 | Args: 29 | entity_cand_eid: entity candidate EIDs (B x M x K) 30 | 31 | Returns: B x M x K x dim tensor of entity embeddings 32 | """ 33 | training_bool = ( 34 | torch.tensor([1], device=entity_cand_eid.device) * self.training 35 | ).bool() 36 | return self.embeddings(entity_cand_eid), training_bool 37 | -------------------------------------------------------------------------------- /bootleg/scorer.py: -------------------------------------------------------------------------------- 1 | """Bootleg scorer.""" 2 | import logging 3 | from collections import Counter 4 | from typing import Dict, List, Optional 5 | 6 | from numpy import ndarray 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class BootlegSlicedScorer: 12 | """Sliced NED scorer init. 13 | 14 | Args: 15 | train_in_candidates: are we training assuming that all gold qids are in the candidates or not 16 | slices_datasets: slice dataset (see slicing/slice_dataset.py) 17 | """ 18 | 19 | def __init__(self, train_in_candidates, slices_datasets=None): 20 | """Bootleg scorer initializer.""" 21 | self.train_in_candidates = train_in_candidates 22 | self.slices_datasets = slices_datasets 23 | 24 | def get_slices(self, uid): 25 | """ 26 | Get slices incidence matrices. 
27 | 28 | Get slice incidence matrices for the uid Uid is dtype 29 | (np.dtype([('sent_idx', 'i8', 1), ('subsent_idx', 'i8', 1), 30 | ("alias_orig_list_pos", 'i8', max_aliases)]) where alias_orig_list_pos 31 | gives the mentions original positions in the sentence. 32 | 33 | Args: 34 | uid: unique identifier of sentence 35 | 36 | Returns: dictionary of slice_name -> matrix of 0/1 for if alias is in slice or not (-1 for no alias) 37 | """ 38 | if self.slices_datasets is None: 39 | return {} 40 | for split, dataset in self.slices_datasets.items(): 41 | sent_idx = uid["sent_idx"] 42 | alias_orig_list_pos = uid["alias_orig_list_pos"] 43 | if dataset.contains_sentidx(sent_idx): 44 | return dataset.get_slice_incidence_arr(sent_idx, alias_orig_list_pos) 45 | return {} 46 | 47 | def bootleg_score( 48 | self, 49 | golds: ndarray, 50 | probs: ndarray, 51 | preds: Optional[ndarray], 52 | uids: Optional[List[str]] = None, 53 | ) -> Dict[str, float]: 54 | """Scores the predictions using the gold labels and slices. 55 | 56 | Args: 57 | golds: gold labels 58 | probs: probabilities 59 | preds: predictions (max prob candidate) 60 | uids: unique identifiers 61 | 62 | Returns: dictionary of tensorboard compatible keys and metrics 63 | """ 64 | batch = golds.shape[0] 65 | NO_MENTION = -1 66 | NOT_IN_CANDIDATES = -2 if self.train_in_candidates else 0 67 | res = {} 68 | total = Counter() 69 | total_in_cand = Counter() 70 | correct_boot = Counter() 71 | correct_pop_cand = Counter() 72 | correct_boot_in_cand = Counter() 73 | correct_pop_cand_in_cand = Counter() 74 | assert ( 75 | len(uids) == batch 76 | ), f"Length of uids {len(uids)} does not match batch {batch} in scorer" 77 | for row in range(batch): 78 | gold = golds[row] 79 | pred = preds[row] 80 | uid = uids[row] 81 | pop_cand = 0 + int(not self.train_in_candidates) 82 | if gold == NO_MENTION: 83 | continue 84 | # Slices is dictionary of slice_name -> incidence array. 
Each array value is 1/0 for if in slice or not 85 | slices = self.get_slices(uid) 86 | for slice_name in slices: 87 | assert ( 88 | slices[slice_name][0] != -1 89 | ), f"Something went wrong with slices {slices} and uid {uid}" 90 | # Check if alias is in slice 91 | if slices[slice_name][0] == 1: 92 | total[slice_name] += 1 93 | if gold != NOT_IN_CANDIDATES: 94 | total_in_cand[slice_name] += 1 95 | if gold == pred: 96 | correct_boot[slice_name] += 1 97 | if gold != NOT_IN_CANDIDATES: 98 | correct_boot_in_cand[slice_name] += 1 99 | if gold == pop_cand: 100 | correct_pop_cand[slice_name] += 1 101 | if gold != NOT_IN_CANDIDATES: 102 | correct_pop_cand_in_cand[slice_name] += 1 103 | for slice_name in total: 104 | res[f"{slice_name}/total_men"] = total[slice_name] 105 | res[f"{slice_name}/total_notNC_men"] = total_in_cand[slice_name] 106 | res[f"{slice_name}/acc_boot"] = ( 107 | 0 108 | if total[slice_name] == 0 109 | else correct_boot[slice_name] / total[slice_name] 110 | ) 111 | res[f"{slice_name}/acc_notNC_boot"] = ( 112 | 0 113 | if total_in_cand[slice_name] == 0 114 | else correct_boot_in_cand[slice_name] / total_in_cand[slice_name] 115 | ) 116 | res[f"{slice_name}/acc_pop"] = ( 117 | 0 118 | if total[slice_name] == 0 119 | else correct_pop_cand[slice_name] / total[slice_name] 120 | ) 121 | res[f"{slice_name}/acc_notNC_pop"] = ( 122 | 0 123 | if total_in_cand[slice_name] == 0 124 | else correct_pop_cand_in_cand[slice_name] / total_in_cand[slice_name] 125 | ) 126 | return res 127 | -------------------------------------------------------------------------------- /bootleg/slicing/__init__.py: -------------------------------------------------------------------------------- 1 | """Slicing initializer.""" 2 | -------------------------------------------------------------------------------- /bootleg/symbols/__init__.py: -------------------------------------------------------------------------------- 1 | """Symbols init.""" 2 | -------------------------------------------------------------------------------- /bootleg/symbols/constants.py: -------------------------------------------------------------------------------- 1 | """Constants.""" 2 | 3 | import logging 4 | import os 5 | from distutils.util import strtobool 6 | from functools import wraps 7 | 8 | from bootleg import log_rank_0_info 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | USE_STRIP = strtobool(os.environ.get("BOOTLEG_STRIP", "true")) 13 | USE_LOWER = strtobool(os.environ.get("BOOTLEG_LOWER", "true")) 14 | LANG_CODE = os.environ.get("BOOTLEG_LANG_CODE", "en") 15 | 16 | log_rank_0_info( 17 | logger, 18 | f"Setting BOOTLEG_STRIP to {USE_STRIP} and BOOTLEG_LOWER to {USE_LOWER} and BOOTLEG_LANG_CODE to {LANG_CODE}. 
" 19 | f"Set these enviorn variables to change behavior.", 20 | ) 21 | 22 | PAD = "" 23 | UNK_ID = 0 24 | PAD_ID = -1 25 | 26 | CLS_BERT = "[CLS]" 27 | SEP_BERT = "[SEP]" 28 | PAD_BERT = "[PAD]" 29 | BERT_WORD_DIM = 768 30 | MAX_BERT_TOKEN_LEN = 512 31 | 32 | SPECIAL_TOKENS = { 33 | "additional_special_tokens": [ 34 | "[ent_start]", 35 | "[ent_end]", 36 | "[ent_desc]", 37 | "[ent_kg]", 38 | "[ent_type]", 39 | ] 40 | } 41 | FINAL_LOSS = "final_loss" 42 | 43 | TRAIN_SPLIT = "train" 44 | DEV_SPLIT = "dev" 45 | TEST_SPLIT = "test" 46 | 47 | # dataset keys 48 | ANCHOR_KEY = "gold" 49 | 50 | STOP_WORDS = { 51 | "for", 52 | "this", 53 | "haven", 54 | "her", 55 | "are", 56 | "s", 57 | "don't", 58 | "ll", 59 | "isn't", 60 | "been", 61 | "themselves", 62 | "it's", 63 | "needn't", 64 | "haven't", 65 | "shouldn", 66 | "ours", 67 | "d", 68 | "than", 69 | "only", 70 | "ma", 71 | "me", 72 | "after", 73 | "which", 74 | "under", 75 | "then", 76 | "both", 77 | "as", 78 | "can", 79 | "yours", 80 | "hers", 81 | "their", 82 | "hadn't", 83 | "we", 84 | "in", 85 | "off", 86 | "having", 87 | "t", 88 | "up", 89 | "re", 90 | "needn", 91 | "she's", 92 | "below", 93 | "over", 94 | "from", 95 | "all", 96 | "an", 97 | "did", 98 | "most", 99 | "weren't", 100 | "your", 101 | "couldn", 102 | "you've", 103 | "because", 104 | "same", 105 | "didn", 106 | "shouldn't", 107 | "about", 108 | "aren", 109 | "myself", 110 | "while", 111 | "so", 112 | "mightn't", 113 | "very", 114 | "what", 115 | "aren't", 116 | "other", 117 | "won", 118 | "or", 119 | "should've", 120 | "out", 121 | "when", 122 | "doesn", 123 | "of", 124 | "am", 125 | "doing", 126 | "nor", 127 | "above", 128 | "shan't", 129 | "with", 130 | "isn", 131 | "that", 132 | "is", 133 | "yourself", 134 | "him", 135 | "had", 136 | "those", 137 | "just", 138 | "more", 139 | "ain", 140 | "my", 141 | "it", 142 | "won't", 143 | "you", 144 | "yourselves", 145 | "at", 146 | "being", 147 | "between", 148 | "be", 149 | "some", 150 | "o", 151 | "where", 152 | "weren", 153 | "has", 154 | "will", 155 | "wasn't", 156 | "that'll", 157 | "against", 158 | "during", 159 | "ve", 160 | "wouldn't", 161 | "herself", 162 | "such", 163 | "m", 164 | "doesn't", 165 | "itself", 166 | "here", 167 | "and", 168 | "were", 169 | "didn't", 170 | "own", 171 | "through", 172 | "they", 173 | "do", 174 | "you'd", 175 | "once", 176 | "the", 177 | "couldn't", 178 | "hasn't", 179 | "before", 180 | "who", 181 | "any", 182 | "our", 183 | "hadn", 184 | "too", 185 | "no", 186 | "he", 187 | "hasn", 188 | "if", 189 | "why", 190 | "wouldn", 191 | "its", 192 | "on", 193 | "mustn't", 194 | "now", 195 | "again", 196 | "to", 197 | "each", 198 | "whom", 199 | "i", 200 | "by", 201 | "have", 202 | "how", 203 | "theirs", 204 | "not", 205 | "don", 206 | "but", 207 | "there", 208 | "shan", 209 | "ourselves", 210 | "until", 211 | "down", 212 | "mightn", 213 | "wasn", 214 | "few", 215 | "mustn", 216 | "his", 217 | "y", 218 | "you're", 219 | "should", 220 | "does", 221 | "himself", 222 | "was", 223 | "you'll", 224 | "them", 225 | "these", 226 | "she", 227 | "into", 228 | "further", 229 | "a", 230 | } 231 | 232 | 233 | # profile constants/utils wrappers 234 | def edit_op(func): 235 | """Edit op.""" 236 | 237 | @wraps(func) 238 | def wrapper_check_edit_mode(obj, *args, **kwargs): 239 | if obj.edit_mode is False: 240 | raise AttributeError("You must load object in edit_mode=True") 241 | return func(obj, *args, **kwargs) 242 | 243 | return wrapper_check_edit_mode 244 | 245 | 246 | def check_qid_exists(func): 247 | """Check QID 
exists.""" 248 | 249 | @wraps(func) 250 | def wrapper_check_qid(obj, *args, **kwargs): 251 | if len(args) > 0: 252 | qid = args[0] 253 | else: 254 | qid = kwargs["qid"] 255 | if not obj._entity_symbols.qid_exists(qid): 256 | raise ValueError(f"The entity {qid} is not in our dump") 257 | return func(obj, *args, **kwargs) 258 | 259 | return wrapper_check_qid 260 | -------------------------------------------------------------------------------- /bootleg/task_config.py: -------------------------------------------------------------------------------- 1 | """Emmental task constants.""" 2 | 3 | NED_TASK = "NED" 4 | BATCH_CANDS_LABEL = "gold_unq_eid_idx" 5 | CANDS_LABEL = "gold_cand_K_idx" 6 | -------------------------------------------------------------------------------- /bootleg/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """Task init.""" 2 | -------------------------------------------------------------------------------- /bootleg/tasks/entity_gen_task.py: -------------------------------------------------------------------------------- 1 | """Entity gen task definitions.""" 2 | import torch.nn.functional as F 3 | from emmental.scorer import Scorer 4 | from emmental.task import Action, EmmentalTask 5 | from torch import nn 6 | from transformers import AutoModel 7 | 8 | from bootleg.layers.bert_encoder import Encoder 9 | from bootleg.task_config import NED_TASK 10 | 11 | 12 | class EntityGenOutput: 13 | """Entity gen for output.""" 14 | 15 | def __init__(self, normalize): 16 | """Entity gen for output initializer.""" 17 | self.normalize = normalize 18 | 19 | def entity_output_func(self, intermediate_output_dict): 20 | """Entity output func.""" 21 | ent_out = intermediate_output_dict["entity_encoder"][0] 22 | if self.normalize: 23 | ent_out = F.normalize(ent_out, p=2, dim=-1) 24 | return ent_out 25 | 26 | 27 | def create_task(args, len_context_tok): 28 | """Return an EmmentalTask for entity encoder only. 
29 | 30 | Args: 31 | args: args 32 | len_context_tok: number of tokens in the tokenizer 33 | 34 | Returns: EmmentalTask for entity embedding extraction 35 | """ 36 | entity_model = AutoModel.from_pretrained(args.data_config.word_embedding.bert_model) 37 | entity_model.encoder.layer = entity_model.encoder.layer[ 38 | : args.data_config.word_embedding.entity_layers 39 | ] 40 | entity_model.resize_token_embeddings(len_context_tok) 41 | entity_model = Encoder(entity_model, args.model_config.hidden_size) 42 | 43 | # Create module pool and combine with embedding module pool 44 | module_pool = nn.ModuleDict( 45 | { 46 | "entity_encoder": entity_model, 47 | } 48 | ) 49 | 50 | # Create task flow 51 | task_flow = [ 52 | Action( 53 | name="entity_encoder", 54 | module="entity_encoder", 55 | inputs=[ 56 | ("_input_", "entity_input_ids"), 57 | ("_input_", "entity_attention_mask"), 58 | ("_input_", "entity_token_type_ids"), 59 | ], 60 | ), 61 | ] 62 | 63 | return EmmentalTask( 64 | name=NED_TASK, 65 | module_pool=module_pool, 66 | task_flow=task_flow, 67 | loss_func=None, 68 | output_func=EntityGenOutput(args.model_config.normalize).entity_output_func, 69 | require_prob_for_eval=False, 70 | require_pred_for_eval=True, 71 | scorer=Scorer(), 72 | ) 73 | -------------------------------------------------------------------------------- /bootleg/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Util init.""" 2 | -------------------------------------------------------------------------------- /bootleg/utils/classes/__init__.py: -------------------------------------------------------------------------------- 1 | """Classes init.""" 2 | -------------------------------------------------------------------------------- /bootleg/utils/classes/comment_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | JSON with comments class. 3 | 4 | An example of how to remove comments and trailing commas from JSON before 5 | parsing. You only need the two functions below, `remove_comments()` and 6 | `remove_trailing_commas()` to accomplish this. This script serves as an 7 | example of how to use them but feel free to just copy & paste them into your 8 | own code/projects. Usage:: json_cleaner.py some_file.json Alternatively, you 9 | can pipe JSON into this script and it'll clean it up:: cat some_file.json | 10 | json_cleaner.py Why would you do this? So you can have human-generated .json 11 | files (say, for configuration) that include comments and, really, who wants to 12 | deal with catching all those trailing commas that might be present? Here's an 13 | example of a file that will be successfully cleaned up and JSON-parseable: 14 | 15 | .. code-block:: javascript 16 | { 17 | // A comment! You normally can't put these in JSON 18 | "testing": { 19 | "foo": "bar", // <-- A trailing comma! No worries. 20 | }, // <-- Another one! 21 | /* 22 | This style of comments will also be safely removed before parsing 23 | */ 24 | } 25 | FYI: This script will also pretty-print the JSON after it's cleaned up (if 26 | using it from the command line) with an indentation level of 4 (that is, four 27 | spaces). 28 | """ 29 | 30 | __version__ = "1.0.0" 31 | __version_info__ = (1, 0, 0) 32 | __license__ = "Unlicense" 33 | __author__ = "Dan McDougall " 34 | 35 | import re 36 | 37 | 38 | def remove_comments(json_like): 39 | r"""Remove C-style comments from *json_like* and returns the result. 
40 | 41 | Example:: 42 | 43 | >>> test_json = '''\ 44 | { 45 | "foo": "bar", // This is a single-line comment 46 | "baz": "blah" /* Multi-line 47 | Comment */ 48 | }''' 49 | >>> remove_comments('{"foo":"bar","baz":"blah",}') 50 | '{\n "foo":"bar",\n "baz":"blah"\n}' 51 | """ 52 | comments_re = re.compile( 53 | r'//.*?$|/\*.*?\*/|\'(?:\\.|[^\\\'])*\'|"(?:\\.|[^\\"])*"', 54 | re.DOTALL | re.MULTILINE, 55 | ) 56 | 57 | def replacer(match): 58 | s = match.group(0) 59 | if s[0] == "/": 60 | return "" 61 | return s 62 | 63 | return comments_re.sub(replacer, json_like) 64 | 65 | 66 | def remove_trailing_commas(json_like): 67 | """Remove trailing commas from *json_like* and returns the result. 68 | 69 | Example:: 70 | 71 | >>> remove_trailing_commas('{"foo":"bar","baz":["blah",],}') 72 | '{"foo":"bar","baz":["blah"]}' 73 | """ 74 | trailing_object_commas_re = re.compile( 75 | r'(,)\s*}(?=([^"\\]*(\\.|"([^"\\]*\\.)*[^"\\]*"))*[^"]*$)' 76 | ) 77 | trailing_array_commas_re = re.compile( 78 | r'(,)\s*\](?=([^"\\]*(\\.|"([^"\\]*\\.)*[^"\\]*"))*[^"]*$)' 79 | ) 80 | # Fix objects {} first 81 | objects_fixed = trailing_object_commas_re.sub("}", json_like) 82 | # Now fix arrays/lists [] and return the result 83 | return trailing_array_commas_re.sub("]", objects_fixed) 84 | -------------------------------------------------------------------------------- /bootleg/utils/classes/emmental_data.py: -------------------------------------------------------------------------------- 1 | """Emmental dataset and dataloader.""" 2 | import logging 3 | from typing import Any, Dict, Optional, Tuple, Union 4 | 5 | from emmental import EmmentalDataset 6 | from torch import Tensor 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class RangedEmmentalDataset(EmmentalDataset): 12 | """ 13 | RangedEmmentalDataset dataset. 14 | 15 | An advanced dataset class to handle that the input data contains multiple fields 16 | and the output data contains multiple label sets. 17 | 18 | Args: 19 | name: The name of the dataset. 20 | X_dict: The feature dict where key is the feature name and value is the 21 | feature. 22 | Y_dict: The label dict where key is the label name and value is 23 | the label, defaults to None. 24 | uid: The unique id key in the X_dict, defaults to None. 25 | data_range: The range of data to select. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | name: str, 31 | X_dict: Dict[str, Any], 32 | Y_dict: Optional[Dict[str, Tensor]] = None, 33 | uid: Optional[str] = None, 34 | data_range: Optional[list] = None, 35 | ) -> None: 36 | """Initialize RangedEmmentalDataset.""" 37 | super().__init__(name, X_dict, Y_dict, uid) 38 | if data_range is not None: 39 | self.data_range = data_range 40 | else: 41 | self.data_range = list(range(len(next(iter(self.X_dict.values()))))) 42 | 43 | def __getitem__( 44 | self, index: int 45 | ) -> Union[Tuple[Dict[str, Any], Dict[str, Tensor]], Dict[str, Any]]: 46 | """Get item by index after taking range into account. 47 | 48 | Args: 49 | index: The index of the item. 
50 | Returns: 51 | Tuple of x_dict and y_dict 52 | """ 53 | return super().__getitem__(self.data_range[index]) 54 | 55 | def __len__(self) -> int: 56 | """Total number of items in the dataset.""" 57 | return len(self.data_range) 58 | -------------------------------------------------------------------------------- /bootleg/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """Bootleg data utils.""" 2 | import os 3 | 4 | from bootleg.symbols.constants import FINAL_LOSS, SPECIAL_TOKENS 5 | from bootleg.utils import utils 6 | 7 | 8 | def correct_not_augmented_dict_values(gold, dict_values): 9 | """ 10 | Correct gold label dict values in data prep. 11 | 12 | Modifies the dict_values to only contain those mentions that are gold 13 | labels. The new dictionary has the alias indices be corrected to start at 0 14 | and end at the number of gold mentions. 15 | 16 | Args: 17 | gold: List of T/F values if mention is gold label or not 18 | dict_values: Dict of slice_name -> Dict[alias_idx] -> slice probability 19 | 20 | Returns: adjusted dict_values such that only gold = True aliases are kept (dict is reindexed to start at 0) 21 | """ 22 | new_dict_values = {} 23 | gold_idx = [i for i in range(len(gold)) if gold[i] is True] 24 | for slice_name in list(dict_values.keys()): 25 | alias_dict = dict_values[slice_name] 26 | # i will not be in gold_idx if it wasn't an gold to being with 27 | new_dict_values[slice_name] = { 28 | str(gold_idx.index(int(i))): alias_dict[i] 29 | for i in alias_dict 30 | if int(i) in gold_idx 31 | } 32 | if len(new_dict_values[slice_name]) <= 0: 33 | del new_dict_values[slice_name] 34 | return new_dict_values 35 | 36 | 37 | # eval_slices must include FINAL_LOSS 38 | def get_eval_slices(eval_slices): 39 | """ 40 | Get eval slices in data prep. 41 | 42 | Given input eval slices (passed in config), ensure FINAL_LOSS is in the 43 | eval slices. FINAL_LOSS gives overall metrics. 44 | 45 | Args: 46 | eval_slices: list of input eval slices 47 | 48 | Returns: list of eval slices to use in the model 49 | """ 50 | slice_names = eval_slices[:] 51 | # FINAL LOSS is in ALL MODELS for ALL SLICES 52 | if FINAL_LOSS not in slice_names: 53 | slice_names.insert(0, FINAL_LOSS) 54 | return slice_names 55 | 56 | 57 | def get_save_data_folder(data_args, use_weak_label, dataset): 58 | """ 59 | Get save data folder for the prepped data. 60 | 61 | Args: 62 | data_args: data config 63 | use_weak_label: whether to use weak labelling or not 64 | dataset: dataset name 65 | 66 | Returns: folder string path 67 | """ 68 | name = os.path.splitext(os.path.basename(dataset))[0] 69 | direct = os.path.dirname(dataset) 70 | bert_mod = data_args.word_embedding.bert_model.replace("/", "_") 71 | fold_name = ( 72 | f"{name}_{bert_mod}_L{data_args.max_seq_len}_E{data_args.max_ent_len}" 73 | f"_W{data_args.max_seq_window_len}" 74 | f"_T{data_args.entity_type_data.use_entity_types}" 75 | f"_K{data_args.entity_kg_data.use_entity_kg}" 76 | f"_D{data_args.use_entity_desc}" 77 | f"_InC{int(data_args.train_in_candidates)}" 78 | f"_Aug{int(use_weak_label)}" 79 | ) 80 | return os.path.join(direct, data_args.data_prep_dir, fold_name) 81 | 82 | 83 | def get_save_data_folder_candgen(data_args, use_weak_label, dataset): 84 | """Give save data folder for the prepped data. 
85 | 86 | Args: 87 | data_args: data config 88 | use_weak_label: whether to use weak labelling or not 89 | dataset: dataset name 90 | 91 | Returns: folder string path 92 | """ 93 | name = os.path.splitext(os.path.basename(dataset))[0] 94 | direct = os.path.dirname(dataset) 95 | bert_mod = data_args.word_embedding.bert_model.replace("/", "_") 96 | fold_name = ( 97 | f"{name}_{bert_mod}_L{data_args.max_seq_len}_E{data_args.max_ent_len}" 98 | f"_W{data_args.max_seq_window_len}" 99 | f"_A{data_args.use_entity_akas}" 100 | f"_D{data_args.use_entity_desc}" 101 | f"_InC{int(data_args.train_in_candidates)}" 102 | f"_Aug{int(use_weak_label)}" 103 | ) 104 | return os.path.join(direct, data_args.data_prep_dir, fold_name) 105 | 106 | 107 | def generate_slice_name(data_args, slice_names, use_weak_label, dataset): 108 | """ 109 | Generate name for slice datasets, taking into account the config eval slices. 110 | 111 | Args: 112 | data_args: data args 113 | slice_names: slice names 114 | use_weak_label: if using weak labels or not 115 | dataset: dataset name 116 | 117 | Returns: dataset name for saving slice data 118 | """ 119 | dataset_name = os.path.join( 120 | get_save_data_folder(data_args, use_weak_label, dataset), "slices.pt" 121 | ) 122 | names_for_dataset = str(hash(slice_names)) 123 | dataset_name = os.path.splitext(dataset_name)[0] + "_" + names_for_dataset + ".pt" 124 | return dataset_name 125 | 126 | 127 | def get_emb_prep_dir(data_config): 128 | """ 129 | Get embedding prep directory for saving prep files. 130 | 131 | Args: 132 | data_config: data config 133 | 134 | Returns: directory path 135 | """ 136 | prep_dir = os.path.join(data_config.entity_dir, data_config.entity_prep_dir) 137 | utils.ensure_dir(prep_dir) 138 | return prep_dir 139 | 140 | 141 | def get_data_prep_dir(data_config): 142 | """ 143 | Get data prep directory for saving prep files. 144 | 145 | Args: 146 | data_config: data config 147 | 148 | Returns: directory path 149 | """ 150 | prep_dir = os.path.join(data_config.data_dir, data_config.data_prep_dir) 151 | utils.ensure_dir(prep_dir) 152 | return prep_dir 153 | 154 | 155 | def get_chunk_dir(prep_dir): 156 | """ 157 | Get directory for saving data chunks. 158 | 159 | Args: 160 | prep_dir: prep directory 161 | 162 | Returns: directory path 163 | """ 164 | return os.path.join(prep_dir, "chunks") 165 | 166 | 167 | def add_special_tokens(tokenizer): 168 | """ 169 | Add special tokens. 170 | 171 | Args: 172 | tokenizer: tokenizer 173 | data_config: data config 174 | entitysymbols: entity symbols 175 | """ 176 | # Add standard tokens 177 | tokenizer.add_special_tokens(SPECIAL_TOKENS) 178 | 179 | 180 | def read_in_akas(entitysymbols): 181 | """Read in alias to QID mappings and generates a QID to list of alternate names. 
182 | 183 | Args: 184 | entitysymbols: entity symbols 185 | 186 | Returns: dictionary of QID to list of alternate names (aliases) 187 | """ 188 | # Invert the alias -> candidate QIDs mapping into QID -> set of aliases 189 | qid2aliases = {} 190 | for al in entitysymbols.get_all_aliases(): 191 | for qid in entitysymbols.get_qid_cands(al): 192 | if qid not in qid2aliases: 193 | qid2aliases[qid] = set() 194 | qid2aliases[qid].add(al) 195 | # Turn sets into lists for dumping 196 | for qid in qid2aliases: 197 | qid2aliases[qid] = list(qid2aliases[qid]) 198 | return qid2aliases 199 | -------------------------------------------------------------------------------- /bootleg/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | """Model utils.""" 2 | import logging 3 | 4 | from bootleg import log_rank_0_debug 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def count_parameters(model, requires_grad, logger): 10 | """Count the number of parameters. 11 | 12 | Args: 13 | model: model to count 14 | requires_grad: whether to look at grad or no grad params 15 | logger: logger 16 | """ 17 | for p in [ 18 | p for p in model.named_parameters() if p[1].requires_grad is requires_grad 19 | ]: 20 | log_rank_0_debug( 21 | logger, 22 | "{:s} {:d} {:.2f} MB".format( 23 | p[0], p[1].numel(), p[1].numel() * 4 / 1024**2 24 | ), 25 | ) 26 | return sum( 27 | p.numel() for p in model.parameters() if p.requires_grad is requires_grad 28 | ) 29 | 30 | 31 | def get_max_candidates(entity_symbols, data_config): 32 | """ 33 | Get max candidates. 34 | 35 | Returns the maximum number of candidates used in the model, taking into 36 | account train_in_candidates. If train_in_candidates is False, we add a NC 37 | entity candidate (for the null candidate). 38 | 39 | Args: 40 | entity_symbols: entity symbols 41 | data_config: data config 42 | """ 43 | return entity_symbols.max_candidates + int(not data_config.train_in_candidates) 44 | -------------------------------------------------------------------------------- /bootleg/utils/parser/__init__.py: -------------------------------------------------------------------------------- 1 | """Parser init.""" 2 | -------------------------------------------------------------------------------- /bootleg/utils/parser/bootleg_args.py: -------------------------------------------------------------------------------- 1 | """Bootleg default configuration parameters. 2 | 3 | In the json file, everything is a string or number. In this python file, 4 | if the default is a boolean, it will be parsed as such. If the default 5 | is a dictionary, True and False strings will become booleans. Otherwise 6 | they will stay string. 7 | """ 8 | import multiprocessing 9 | 10 | config_args = { 11 | "run_config": { 12 | "spawn_method": ( 13 | "forkserver", 14 | "multiprocessing spawn method.
forkserver will save memory but have slower startup costs.", 15 | ), 16 | "eval_batch_size": (128, "batch size for eval"), 17 | "dump_preds_accumulation_steps": ( 18 | 1000, 19 | "number of eval steps to accumulate the output tensors for before saving results to file", 20 | ), 21 | "dump_preds_num_data_splits": ( 22 | 1, 23 | "number of chunks to split the input file; helps with OOM issues", 24 | ), 25 | "overwrite_eval_dumps": (False, "overwrite dumped eval data"), 26 | "dataloader_threads": (16, "data loader threads to feed gpus"), 27 | "log_level": ("info", "logging level"), 28 | "dataset_threads": ( 29 | int(multiprocessing.cpu_count() * 0.9), 30 | "data set threads for prepping data", 31 | ), 32 | "result_label_file": ( 33 | "bootleg_labels.jsonl", 34 | "file name to save predicted entities in", 35 | ), 36 | "result_emb_file": ( 37 | "bootleg_embs.npy", 38 | "file name to save contextualized embs in", 39 | ), 40 | }, 41 | # Parameters for hyperparameter tuning 42 | "train_config": { 43 | "batch_size": (32, "batch size"), 44 | }, 45 | "model_config": { 46 | "hidden_size": (300, "hidden dimension for the embeddings before scoring"), 47 | "normalize": (False, "normalize embeddings before dot product"), 48 | "temperature": (1.0, "temperature for softmax in loss"), 49 | }, 50 | "data_config": { 51 | "eval_slices": ([], "slices for evaluation"), 52 | "train_in_candidates": ( 53 | True, 54 | "Train in candidates (if False, this means we include NIL entity)", 55 | ), 56 | "data_dir": ("data", "where training, testing, and dev data is stored"), 57 | "data_prep_dir": ( 58 | "prep", 59 | "directory where data prep files are saved inside data_dir", 60 | ), 61 | "entity_dir": ( 62 | "entity_data", 63 | "where entity profile information and prepped embedding data is stored", 64 | ), 65 | "entity_prep_dir": ( 66 | "prep", 67 | "directory where prepped embedding data is saved inside entity_dir", 68 | ), 69 | "entity_map_dir": ( 70 | "entity_mappings", 71 | "directory where entity json mappings are saved inside entity_dir", 72 | ), 73 | "alias_cand_map": ( 74 | "alias2qids", 75 | "name of alias candidate map file, should be saved in entity_dir/entity_map_dir", 76 | ), 77 | "alias_idx_map": ( 78 | "alias2id", 79 | "name of alias index map file, should be saved in entity_dir/entity_map_dir", 80 | ), 81 | "qid_cnt_map": ( 82 | "qid2cnt.json", 83 | "name of alias index map file, should be saved in data_dir", 84 | ), 85 | "max_seq_len": (128, "max token length sentences"), 86 | "max_seq_window_len": (64, "max window around an entity"), 87 | "max_ent_len": (128, "max token length for entire encoded entity"), 88 | "context_mask_perc": ( 89 | 0.0, 90 | "mask percent for context tokens in addition to tail masking", 91 | ), 92 | "popularity_mask": ( 93 | True, 94 | "whether to use popularity masking for training in the entity and context encoders", 95 | ), 96 | "overwrite_preprocessed_data": (False, "overwrite preprocessed data"), 97 | "print_examples_prep": (True, "whether to print examples during prep or not"), 98 | "use_entity_desc": (True, "whether to use entity descriptions or not"), 99 | "entity_type_data": { 100 | "use_entity_types": (False, "whether to use entity type data"), 101 | "type_symbols_dir": ( 102 | "type_mappings/wiki", 103 | "directory to type symbols inside entity_dir", 104 | ), 105 | "max_ent_type_len": (20, "max WORD length for type sequence"), 106 | }, 107 | "entity_kg_data": { 108 | "use_entity_kg": (False, "whether to use entity type data"), 109 | "kg_symbols_dir": ( 110 | 
"kg_mappings", 111 | "directory to kg symbols inside entity_dir", 112 | ), 113 | "max_ent_kg_len": (60, "max WORD length for kg sequence"), 114 | }, 115 | "train_dataset": { 116 | "file": ("train.jsonl", ""), 117 | "use_weak_label": (True, "Use weakly labeled mentions"), 118 | }, 119 | "dev_dataset": { 120 | "file": ("dev.jsonl", ""), 121 | "use_weak_label": (True, "Use weakly labeled mentions"), 122 | }, 123 | "test_dataset": { 124 | "file": ("test.jsonl", ""), 125 | "use_weak_label": (True, "Use weakly labeled mentions"), 126 | }, 127 | "word_embedding": { 128 | "bert_model": ("bert-base-uncased", ""), 129 | "context_layers": (12, ""), 130 | "entity_layers": (12, ""), 131 | "cache_dir": ( 132 | "pretrained_bert_models", 133 | "Directory where word embeddings are cached", 134 | ), 135 | }, 136 | }, 137 | } 138 | -------------------------------------------------------------------------------- /bootleg/utils/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | """Preprocessing init.""" 2 | -------------------------------------------------------------------------------- /bootleg/utils/preprocessing/convert_to_char_spans.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compute QID counts. 3 | 4 | Helper function that computes a dictionary of QID -> count in training data. 5 | 6 | If a QID is not in this dictionary, it has a count of zero. 7 | """ 8 | 9 | import argparse 10 | import multiprocessing 11 | import os 12 | import shutil 13 | import tempfile 14 | from collections import defaultdict 15 | from pathlib import Path 16 | 17 | import ujson 18 | from tqdm.auto import tqdm 19 | 20 | 21 | def parse_args(): 22 | """Parse args.""" 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | "--file", 26 | type=str, 27 | default="train.jsonl", 28 | ) 29 | 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | def get_char_spans(spans, text): 35 | """ 36 | Get character spans instead of default word spans. 37 | 38 | Args: 39 | spans: word spans 40 | text: text 41 | 42 | Returns: character spans 43 | """ 44 | word_i = 0 45 | prev_is_space = True 46 | char2word = {} 47 | word2char = defaultdict(list) 48 | for char_i, c in enumerate(text): 49 | if c.isspace(): 50 | if not prev_is_space: 51 | word_i += 1 52 | prev_is_space = True 53 | else: 54 | prev_is_space = False 55 | char2word[char_i] = word_i 56 | word2char[word_i].append(char_i) 57 | char_spans = [] 58 | for span in spans: 59 | char_l = min(word2char[span[0]]) 60 | char_r = max(word2char[span[1] - 1]) + 1 61 | char_spans.append([char_l, char_r]) 62 | return char_spans 63 | 64 | 65 | def convert_char_spans(num_processes, file): 66 | """Add char spans to jsonl file.""" 67 | pool = multiprocessing.Pool(processes=num_processes) 68 | num_lines = sum([1 for _ in open(file)]) 69 | temp_file = Path(tempfile.gettempdir()) / "_convert_char_spans.jsonl" 70 | with open(file) as in_f, open(temp_file, "wb") as out_f: 71 | for res in tqdm( 72 | pool.imap_unordered(convert_char_spans_helper, in_f, chunksize=100), 73 | total=num_lines, 74 | desc="Adding char spans", 75 | ): 76 | out_f.write(bytes(res, encoding="utf-8")) 77 | out_f.seek(0) 78 | pool.close() 79 | pool.join() 80 | shutil.copy(temp_file, file) 81 | os.remove(temp_file) 82 | return 83 | 84 | 85 | def convert_char_spans_helper(line): 86 | """Get char spans helper. 
87 | 88 | Parses line, adds char spans, and dumps it back again 89 | """ 90 | line = ujson.loads(line) 91 | line["char_spans"] = get_char_spans(line["spans"], line["sentence"]) 92 | to_write = ujson.dumps(line) + "\n" 93 | return to_write 94 | 95 | 96 | def main(): 97 | """Run.""" 98 | args = parse_args() 99 | print(ujson.dumps(vars(args), indent=4)) 100 | num_processes = int(0.8 * multiprocessing.cpu_count()) 101 | print(f"Getting slice counts from {args.file}") 102 | convert_char_spans(num_processes, args.file) 103 | 104 | 105 | if __name__ == "__main__": 106 | main() 107 | -------------------------------------------------------------------------------- /bootleg/utils/preprocessing/get_train_qid_counts.py: -------------------------------------------------------------------------------- 1 | """ 2 | Compute QID counts. 3 | 4 | Helper function that computes a dictionary of QID -> count in training data. 5 | 6 | If a QID is not in this dictionary, it has a count of zero. 7 | """ 8 | 9 | import argparse 10 | import multiprocessing 11 | from collections import defaultdict 12 | 13 | import ujson 14 | from tqdm.auto import tqdm 15 | 16 | from bootleg.utils import utils 17 | 18 | 19 | def parse_args(): 20 | """Parse args.""" 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--train_file", 24 | type=str, 25 | default="/dfs/scratch0/lorr1/projects/bootleg-data/data/wiki_title_0114/train.jsonl", 26 | ) 27 | parser.add_argument( 28 | "--out_file", 29 | type=str, 30 | default="/dfs/scratch0/lorr1/projects/bootleg-data/data/wiki_title_0114/train_qidcnt.json", 31 | help="Regularization of each qid", 32 | ) 33 | 34 | args = parser.parse_args() 35 | return args 36 | 37 | 38 | def get_counts(num_processes, file): 39 | """Get true anchor slice counts.""" 40 | pool = multiprocessing.Pool(processes=num_processes) 41 | num_lines = sum(1 for _ in open(file)) 42 | qid_cnts = defaultdict(int) 43 | for res in tqdm( 44 | pool.imap_unordered(get_counts_hlp, open(file), chunksize=1000), 45 | total=num_lines, 46 | desc="Gathering counts", 47 | ): 48 | for qid in res: 49 | qid_cnts[qid] += res[qid] 50 | pool.close() 51 | pool.join() 52 | return qid_cnts 53 | 54 | 55 | def get_counts_hlp(line): 56 | """Get count helper.""" 57 | res = defaultdict(int) # qid -> cnt 58 | line = ujson.loads(line) 59 | for qid in line["qids"]: 60 | res[qid] += 1 61 | return res 62 | 63 | 64 | def main(): 65 | """Run.""" 66 | args = parse_args() 67 | print(ujson.dumps(vars(args), indent=4)) 68 | num_processes = int(0.8 * multiprocessing.cpu_count()) 69 | print(f"Getting slice counts from {args.train_file}") 70 | qid_cnts = get_counts(num_processes, args.train_file) 71 | utils.dump_json_file(args.out_file, qid_cnts) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /cand_gen/task_config.py: -------------------------------------------------------------------------------- 1 | """Emmental task constants.""" 2 | 3 | CANDGEN_TASK = "CANDGEN" 4 | BATCH_CANDS_LABEL = "gold_unq_eid_idx" 5 | -------------------------------------------------------------------------------- /cand_gen/tasks/candgen_task.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from emmental.scorer import Scorer 4 | from emmental.task import Action, EmmentalTask 5 | from torch import nn 6 | from transformers import AutoModel 7 | 8 | from bootleg.layers.bert_encoder import Encoder 9 
| from bootleg.scorer import BootlegSlicedScorer 10 | from cand_gen.task_config import CANDGEN_TASK 11 | 12 | 13 | class DisambigLoss: 14 | def __init__(self, normalize, temperature): 15 | self.normalize = normalize 16 | self.temperature = temperature 17 | 18 | def batch_cands_disambig_output(self, intermediate_output_dict): 19 | """Function to return the probs for a task in Emmental. 20 | Args: 21 | intermediate_output_dict: output dict from Emmental task flow 22 | Returns: NED probabilities for candidates (B x M x K) 23 | """ 24 | out = intermediate_output_dict["context_encoder"][0] 25 | ent_out = intermediate_output_dict["entity_encoder"][0] 26 | if self.normalize: 27 | out = F.normalize(out, p=2, dim=-1) 28 | ent_out = F.normalize(ent_out, p=2, dim=-1) 29 | score = torch.mm(out, ent_out.t()) / self.temperature 30 | return F.softmax(score, dim=-1) 31 | 32 | def batch_cands_disambig_loss(self, intermediate_output_dict, Y): 33 | """Returns the entity disambiguation loss on prediction heads. 34 | Args: 35 | intermediate_output_dict: output dict from the Emmental task flor 36 | Y: gold labels 37 | Returns: loss 38 | """ 39 | # Grab the first value of training (when doing distributed training, we will have one per process) 40 | training = intermediate_output_dict["context_encoder"][1].item() 41 | assert type(training) is bool 42 | out = intermediate_output_dict["context_encoder"][0] 43 | ent_out = intermediate_output_dict["entity_encoder"][0] 44 | if self.normalize: 45 | out = F.normalize(out, p=2, dim=-1) 46 | ent_out = F.normalize(ent_out, p=2, dim=-1) 47 | score = torch.mm(out, ent_out.t()) / self.temperature 48 | 49 | labels = Y 50 | masked_labels = labels.reshape(out.shape[0]) 51 | if not training: 52 | label_mask = labels == -2 53 | masked_labels = torch.where( 54 | ~label_mask, labels, torch.ones_like(labels) * -1 55 | ) 56 | masked_labels = masked_labels.reshape(out.shape[0]) 57 | temp = nn.CrossEntropyLoss(ignore_index=-1)(score, masked_labels.long()) 58 | return temp 59 | 60 | 61 | def create_task(args, len_context_tok, slice_datasets=None): 62 | """Returns an EmmentalTask for named entity disambiguation (NED). 
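The task runs an entity encoder and a context encoder as independent actions in the Emmental task flow; DisambigLoss above then scores each context embedding against all in-batch entity embeddings with a temperature-scaled (and optionally L2-normalized) dot product.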
63 | 64 | Args: 65 | args: args 66 | entity_symbols: entity symbols (default None) 67 | slice_datasets: slice datasets used in scorer (default None) 68 | 69 | Returns: EmmentalTask for NED 70 | """ 71 | disamig_loss = DisambigLoss( 72 | args.model_config.normalize, args.model_config.temperature 73 | ) 74 | loss_func = disamig_loss.batch_cands_disambig_loss 75 | output_func = disamig_loss.batch_cands_disambig_output 76 | 77 | # Create sentence encoder 78 | context_model = AutoModel.from_pretrained( 79 | args.data_config.word_embedding.bert_model 80 | ) 81 | context_model.encoder.layer = context_model.encoder.layer[ 82 | : args.data_config.word_embedding.context_layers 83 | ] 84 | context_model.resize_token_embeddings(len_context_tok) 85 | context_model = Encoder(context_model, args.model_config.hidden_size) 86 | 87 | entity_model = AutoModel.from_pretrained(args.data_config.word_embedding.bert_model) 88 | entity_model.encoder.layer = entity_model.encoder.layer[ 89 | : args.data_config.word_embedding.entity_layers 90 | ] 91 | entity_model.resize_token_embeddings(len_context_tok) 92 | entity_model = Encoder(entity_model, args.model_config.hidden_size) 93 | 94 | sliced_scorer = BootlegSlicedScorer( 95 | args.data_config.train_in_candidates, slice_datasets 96 | ) 97 | 98 | # Create module pool and combine with embedding module pool 99 | module_pool = nn.ModuleDict( 100 | { 101 | "context_encoder": context_model, 102 | "entity_encoder": entity_model, 103 | } 104 | ) 105 | 106 | # Create task flow 107 | task_flow = [ 108 | Action( 109 | name="entity_encoder", 110 | module="entity_encoder", 111 | inputs=[ 112 | ("_input_", "entity_cand_input_ids"), 113 | ("_input_", "entity_cand_attention_mask"), 114 | ("_input_", "entity_cand_token_type_ids"), 115 | ], 116 | ), 117 | Action( 118 | name="context_encoder", 119 | module="context_encoder", 120 | inputs=[ 121 | ("_input_", "input_ids"), 122 | ("_input_", "token_type_ids"), 123 | ("_input_", "attention_mask"), 124 | ], 125 | ), 126 | ] 127 | 128 | return EmmentalTask( 129 | name=CANDGEN_TASK, 130 | module_pool=module_pool, 131 | task_flow=task_flow, 132 | loss_func=loss_func, 133 | output_func=output_func, 134 | require_prob_for_eval=False, 135 | require_pred_for_eval=True, 136 | # action_outputs are used to stitch together sentence fragments 137 | action_outputs=[ 138 | ("_input_", "sent_idx"), 139 | ("_input_", "subsent_idx"), 140 | ("_input_", "alias_orig_list_pos"), 141 | ("_input_", "for_dump_gold_cand_K_idx_train"), 142 | ("entity_encoder", 0), # entity embeddings 143 | ], 144 | scorer=Scorer( 145 | customize_metric_funcs={ 146 | f"{CANDGEN_TASK}_scorer": sliced_scorer.bootleg_score 147 | } 148 | ), 149 | ) 150 | -------------------------------------------------------------------------------- /cand_gen/tasks/context_gen_task.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from emmental.scorer import Scorer 3 | from emmental.task import Action, EmmentalTask 4 | from torch import nn 5 | from transformers import AutoModel 6 | 7 | from bootleg.layers.bert_encoder import Encoder 8 | from cand_gen.task_config import CANDGEN_TASK 9 | 10 | 11 | class ContextGenOutput: 12 | """Context gen for output.""" 13 | 14 | def __init__(self, normalize): 15 | """Context gen for output initializer.""" 16 | self.normalize = normalize 17 | 18 | def entity_output_func(self, intermediate_output_dict): 19 | """Context output func.""" 20 | ctx_out = intermediate_output_dict["context_encoder"][0] 
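# ctx_out holds the context-encoder output for each example in the batch; when normalize is set (below), the downstream dot-product scoring behaves like cosine similarity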
21 | if self.normalize: 22 | ctx_out = F.normalize(ctx_out, p=2, dim=-1) 23 | return ctx_out 24 | 25 | 26 | def create_task(args, len_context_tok): 27 | """Returns an EmmentalTask for a forward pass through the entity encoder only. 28 | 29 | Args: 30 | args: args 31 | len_context_tok: number of tokens in the tokenizer 32 | 33 | Returns: EmmentalTask for entity embedding extraction 34 | """ 35 | 36 | # Create sentence encoder 37 | context_model = AutoModel.from_pretrained( 38 | args.data_config.word_embedding.bert_model 39 | ) 40 | context_model.encoder.layer = context_model.encoder.layer[ 41 | : args.data_config.word_embedding.context_layers 42 | ] 43 | context_model.resize_token_embeddings(len_context_tok) 44 | context_model = Encoder(context_model, args.model_config.hidden_size) 45 | 46 | # Create module pool and combine with embedding module pool 47 | module_pool = nn.ModuleDict( 48 | { 49 | "context_encoder": context_model, 50 | } 51 | ) 52 | 53 | # Create task flow 54 | task_flow = [ 55 | Action( 56 | name="context_encoder", 57 | module="context_encoder", 58 | inputs=[ 59 | ("_input_", "input_ids"), 60 | ("_input_", "token_type_ids"), 61 | ("_input_", "attention_mask"), 62 | ], 63 | ), 64 | ] 65 | 66 | return EmmentalTask( 67 | name=CANDGEN_TASK, 68 | module_pool=module_pool, 69 | task_flow=task_flow, 70 | loss_func=None, 71 | output_func=ContextGenOutput(args.model_config.normalize).entity_output_func, 72 | require_prob_for_eval=False, 73 | require_pred_for_eval=True, 74 | scorer=Scorer(), 75 | ) 76 | -------------------------------------------------------------------------------- /cand_gen/tasks/entity_gen_task.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from emmental.scorer import Scorer 3 | from emmental.task import Action, EmmentalTask 4 | from torch import nn 5 | from transformers import AutoModel 6 | 7 | from bootleg.layers.bert_encoder import Encoder 8 | from cand_gen.task_config import CANDGEN_TASK 9 | 10 | 11 | class EntityGenOutput: 12 | """Entity gen for output.""" 13 | 14 | def __init__(self, normalize): 15 | """Entity gen for output initializer.""" 16 | self.normalize = normalize 17 | 18 | def entity_output_func(self, intermediate_output_dict): 19 | """Entity output func.""" 20 | ent_out = intermediate_output_dict["entity_encoder"][0] 21 | if self.normalize: 22 | ent_out = F.normalize(ent_out, p=2, dim=-1) 23 | return ent_out 24 | 25 | 26 | def create_task(args, len_context_tok): 27 | """Returns an EmmentalTask for a forward pass through the entity encoder only. 
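No loss is attached (loss_func is None); the output function returns the (optionally L2-normalized) entity encoder embeddings, so running this task amounts to dumping embeddings for the candidate entities.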
28 | 29 | Args: 30 | args: args 31 | len_context_tok: number of tokens in the tokenizer 32 | 33 | Returns: EmmentalTask for entity embedding extraction 34 | """ 35 | 36 | entity_model = AutoModel.from_pretrained(args.data_config.word_embedding.bert_model) 37 | entity_model.encoder.layer = entity_model.encoder.layer[ 38 | : args.data_config.word_embedding.entity_layers 39 | ] 40 | entity_model.resize_token_embeddings(len_context_tok) 41 | entity_model = Encoder(entity_model, args.model_config.hidden_size) 42 | 43 | # Create module pool and combine with embedding module pool 44 | module_pool = nn.ModuleDict( 45 | { 46 | "entity_encoder": entity_model, 47 | } 48 | ) 49 | 50 | # Create task flow 51 | task_flow = [ 52 | Action( 53 | name="entity_encoder", 54 | module="entity_encoder", 55 | inputs=[ 56 | ("_input_", "entity_input_ids"), 57 | ("_input_", "entity_attention_mask"), 58 | ("_input_", "entity_token_type_ids"), 59 | ], 60 | ), 61 | ] 62 | 63 | return EmmentalTask( 64 | name=CANDGEN_TASK, 65 | module_pool=module_pool, 66 | task_flow=task_flow, 67 | loss_func=None, 68 | output_func=EntityGenOutput(args.model_config.normalize).entity_output_func, 69 | require_prob_for_eval=False, 70 | require_pred_for_eval=True, 71 | scorer=Scorer(), 72 | ) 73 | -------------------------------------------------------------------------------- /cand_gen/utils/parser/candgen_args.py: -------------------------------------------------------------------------------- 1 | """Bootleg default configuration parameters. 2 | 3 | In the json file, everything is a string or number. In this python file, 4 | if the default is a boolean, it will be parsed as such. If the default 5 | is a dictionary, True and False strings will become booleans. Otherwise 6 | they will stay string. 7 | """ 8 | import multiprocessing 9 | 10 | config_args = { 11 | "run_config": { 12 | "spawn_method": ( 13 | "forkserver", 14 | "multiprocessing spawn method. 
forkserver will save memory but have slower startup costs.", 15 | ), 16 | "eval_batch_size": (128, "batch size for eval"), 17 | "dump_preds_accumulation_steps": ( 18 | 10000, 19 | "number of eval steps to accumulate the output tensors for before saving results to file", 20 | ), 21 | "dataloader_threads": (16, "data loader threads to feed gpus"), 22 | "log_level": ("info", "logging level"), 23 | "dataset_threads": ( 24 | int(multiprocessing.cpu_count() * 0.9), 25 | "data set threads for prepping data", 26 | ), 27 | }, 28 | # Parameters for hyperparameter tuning 29 | "train_config": { 30 | "batch_size": (32, "batch size"), 31 | }, 32 | "model_config": { 33 | "hidden_size": (300, "hidden dimension for the embeddings before scoring"), 34 | "normalize": (False, "normalize embeddings before dot product"), 35 | "temperature": (1.0, "temperature for softmax in loss"), 36 | }, 37 | "data_config": { 38 | "eval_slices": ([], "slices for evaluation"), 39 | "train_in_candidates": ( 40 | True, 41 | "Train in candidates (if False, this means we include NIL entity)", 42 | ), 43 | "data_dir": ("data", "where training, testing, and dev data is stored"), 44 | "data_prep_dir": ( 45 | "prep", 46 | "directory where data prep files are saved inside data_dir", 47 | ), 48 | "entity_dir": ( 49 | "entity_data", 50 | "where entity profile information and prepped embedding data is stored", 51 | ), 52 | "entity_prep_dir": ( 53 | "prep", 54 | "directory where prepped embedding data is saved inside entity_dir", 55 | ), 56 | "entity_map_dir": ( 57 | "entity_mappings", 58 | "directory where entity json mappings are saved inside entity_dir", 59 | ), 60 | "alias_cand_map": ( 61 | "alias2qids", 62 | "name of alias candidate map file, should be saved in entity_dir/entity_map_dir", 63 | ), 64 | "alias_idx_map": ( 65 | "alias2id", 66 | "name of alias index map file, should be saved in entity_dir/entity_map_dir", 67 | ), 68 | "qid_cnt_map": ( 69 | "qid2cnt.json", 70 | "name of alias index map file, should be saved in data_dir", 71 | ), 72 | "max_seq_len": (128, "max token length sentences"), 73 | "max_seq_window_len": (64, "max window around an entity"), 74 | "max_ent_len": (128, "max token length for entire encoded entity"), 75 | "max_ent_aka_len": (20, "max token length for alternate names"), 76 | "overwrite_preprocessed_data": (False, "overwrite preprocessed data"), 77 | "print_examples_prep": (True, "whether to print examples during prep or not"), 78 | "use_entity_desc": (True, "whether to use entity descriptions or not"), 79 | "use_entity_akas": ( 80 | True, 81 | "whether to use entity alternates names from the candidates or not", 82 | ), 83 | "train_dataset": { 84 | "file": ("train.jsonl", ""), 85 | "use_weak_label": (True, "Use weakly labeled mentions"), 86 | }, 87 | "dev_dataset": { 88 | "file": ("dev.jsonl", ""), 89 | "use_weak_label": (True, "Use weakly labeled mentions"), 90 | }, 91 | "test_dataset": { 92 | "file": ("test.jsonl", ""), 93 | "use_weak_label": (True, "Use weakly labeled mentions"), 94 | }, 95 | "word_embedding": { 96 | "bert_model": ("bert-base-uncased", ""), 97 | "context_layers": (1, ""), 98 | "entity_layers": (1, ""), 99 | "cache_dir": ( 100 | "pretrained_bert_models", 101 | "Directory where word embeddings are cached", 102 | ), 103 | }, 104 | }, 105 | } 106 | -------------------------------------------------------------------------------- /cand_gen/utils/parser/parser_utils.py: -------------------------------------------------------------------------------- 1 | """Parses a Booleg input config into a 
DottedDict of config values (with 2 | defaults filled in) for running a model.""" 3 | 4 | import argparse 5 | import os 6 | 7 | from bootleg.utils.classes.dotted_dict import create_bool_dotted_dict 8 | from bootleg.utils.parser.emm_parse_args import ( 9 | parse_args as emm_parse_args, 10 | parse_args_to_config as emm_parse_args_to_config, 11 | ) 12 | from bootleg.utils.parser.parser_utils import ( 13 | add_nested_flags_from_config, 14 | flatten_nested_args_for_parser, 15 | load_commented_json_file, 16 | merge_configs, 17 | reconstructed_nested_args, 18 | recursive_keys, 19 | ) 20 | from bootleg.utils.utils import load_yaml_file 21 | from cand_gen.utils.parser.candgen_args import config_args 22 | 23 | 24 | def get_boot_config(config, parser_hierarchy=None, parser=None, unknown=None): 25 | """ 26 | Returns a parsed Bootleg config from config. Config can be a path to a config file or an already loaded dictionary. 27 | The high level work flow 28 | 1. Reads Bootleg default config (config_args) and addes params to a arg parser, 29 | flattening all hierarchical values into "." values 30 | E.g., data_config -> word_embeddings -> layers becomes --data_config.word_embedding.layers 31 | 2. Flattens the given config values into the "." format 32 | 3. Adds any unknown values from the first arg parser that parses the config script. 33 | Allows the user to add --data_config.word_embedding.layers to command line that overwrite values in file 34 | 4. Parses the flattened args w.r.t the arg parser 35 | 5. Reconstruct the args back into their hierarchical form 36 | Args: 37 | config: model specific config 38 | parser_hierarchy: Dict of hierarchy of config (or None) 39 | parser: arg parser (or None) 40 | unknown: unknown arg values passed from command line to be added to config and overwrite values in file 41 | 42 | Returns: parsed config 43 | 44 | """ 45 | if unknown is None: 46 | unknown = [] 47 | if parser_hierarchy is None: 48 | parser_hierarchy = {} 49 | if parser is None: 50 | parser = argparse.ArgumentParser() 51 | 52 | add_nested_flags_from_config(parser, config_args, parser_hierarchy, prefix="") 53 | if type(config) is str: 54 | assert os.path.splitext(config)[1] in [ 55 | ".json", 56 | ".yaml", 57 | ], "We only accept json or yaml ending for configs" 58 | if os.path.splitext(config)[1] == ".json": 59 | params = load_commented_json_file(config) 60 | else: 61 | params = load_yaml_file(config) 62 | else: 63 | assert ( 64 | type(config) is dict 65 | ), "We only support loading configs that are paths to json/yaml files or preloaded configs." 
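# The nested config is flattened below into dotted argparse-style pairs, e.g. (illustrative) {"data_config": {"max_seq_len": 64}} becomes ["--data_config.max_seq_len", "64"], so file values and command-line overrides share one parser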
66 | params = config 67 | all_keys = list(recursive_keys(parser_hierarchy)) 68 | new_params = flatten_nested_args_for_parser(params, [], groups=all_keys, prefix="") 69 | # update with new args 70 | # unknown must have ["--arg1", "value1", "--arg2", "value2"] as we don't have any action_true args 71 | assert len(unknown) % 2 == 0 72 | assert all( 73 | unknown[idx].startswith(("-", "--")) for idx in range(0, len(unknown), 2) 74 | ) 75 | for idx in range(1, len(unknown), 2): 76 | # allow passing -1 for emmental.device argument 77 | assert (not unknown[idx].startswith(("-", "--"))) or ( 78 | unknown[idx - 1] == "--emmental.device" and unknown[idx] == "-1" 79 | ) 80 | for idx in range(0, len(unknown), 2): 81 | arg = unknown[idx] 82 | # If override one you already have in json 83 | if arg in new_params: 84 | idx2 = new_params.index(arg) 85 | new_params[idx2 : idx2 + 2] = unknown[idx : idx + 2] 86 | # If override one that is in bootleg_args.py by not in json 87 | else: 88 | new_params.extend(unknown[idx : idx + 2]) 89 | args = parser.parse_args(new_params) 90 | top_names = {} 91 | reconstructed_nested_args(args, top_names, parser_hierarchy, prefix="") 92 | # final_args = argparse.Namespace(**top_names) 93 | final_args = create_bool_dotted_dict(top_names) 94 | # turn_to_dotdicts(final_args) 95 | return final_args 96 | 97 | 98 | def parse_boot_and_emm_args(config_script, unknown=None): 99 | """ 100 | Merges the Emmental config with the Bootleg config. 101 | As we have an emmental: ... level in our config for emmental commands, 102 | we need to parse those with the Emmental parser and then merge the Bootleg only config values 103 | with the Emmental ones. 104 | Args: 105 | config_script: config script for Bootleg and Emmental args 106 | unknown: unknown arg values passed from command line to overwrite file values 107 | 108 | Returns: parsed merged Bootleg and Emmental config 109 | 110 | """ 111 | if unknown is None: 112 | unknown = [] 113 | config_parser = argparse.ArgumentParser( 114 | description="Bootleg Config", 115 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 116 | ) 117 | # Modified parse_args to have 'emmental.group' prefixes. This represents a hierarchy in our parser 118 | config_parser, parser_hierarchy = emm_parse_args(parser=config_parser) 119 | # Add Bootleg args and parse 120 | all_args = get_boot_config(config_script, parser_hierarchy, config_parser, unknown) 121 | # These have emmental -> config group -> arg structure for emmental. 
122 | # Must remove that hierarchy to converte to internal Emmental hierarchy 123 | emm_args = {} 124 | for k, v in all_args["emmental"].items(): 125 | emm_args[k] = v 126 | del all_args["emmental"] 127 | # create and add Emmental hierarchy 128 | config = emm_parse_args_to_config(create_bool_dotted_dict(emm_args)) 129 | # Merge configs back (merge workds on dicts so must convert to dict first) 130 | config = create_bool_dotted_dict(merge_configs(all_args, config)) 131 | return config 132 | -------------------------------------------------------------------------------- /configs/cand_gen/bi_train.yaml: -------------------------------------------------------------------------------- 1 | emmental: 2 | lr: 2e-5 3 | n_epochs: 1 4 | evaluation_freq: 1 5 | warmup_percentage: 0.1 6 | lr_scheduler: linear 7 | log_path: /dfs/scratch0/lorr1/projects/bootleg-data/logs/cand_gen_korealiases 8 | checkpointing: true 9 | checkpoint_all: true 10 | checkpoint_freq: 1 11 | online_eval: false 12 | clear_intermediate_checkpoints: false 13 | checkpoint_metric: NED/Bootleg/dev/final_loss/acc_boot:max 14 | l2: 0.01 15 | grad_clip: 1.0 16 | gradient_accumulation_steps: 4 17 | fp16: true 18 | run_config: 19 | log_level: DEBUG 20 | eval_batch_size: 32 21 | dataloader_threads: 6 22 | dataset_threads: 40 23 | spawn_method: fork 24 | train_config: 25 | batch_size: 8 26 | model_config: 27 | hidden_size: 100 28 | normalize: true 29 | temperature: 0.01 30 | data_config: 31 | data_dir: /dfs/scratch0/lorr1/projects/bootleg-data/data/korealiases_0329 32 | train_in_candidates: false 33 | data_prep_dir: prep 34 | use_entity_desc: true 35 | entity_dir: /dfs/scratch0/lorr1/projects/bootleg-data/data/korealiases_0329/entity_db 36 | max_seq_len: 128 37 | max_seq_window_len: 64 38 | max_ent_len: 128 39 | overwrite_preprocessed_data: false 40 | dev_dataset: 41 | file: dev.jsonl 42 | use_weak_label: true 43 | test_dataset: 44 | file: test_tt.jsonl 45 | use_weak_label: true 46 | train_dataset: 47 | file: train.jsonl 48 | use_weak_label: true 49 | eval_slices: 50 | - unif_all 51 | - unif_NS_all 52 | - unif_HD 53 | - unif_TO 54 | - unif_TL 55 | - unif_TS 56 | word_embedding: 57 | cache_dir: /dfs/scratch0/lorr1/projects/bootleg-data/embs/pretrained_bert_models 58 | bert_model: bert-base-uncased 59 | context_layers: 4 60 | entity_layers: 4 61 | -------------------------------------------------------------------------------- /configs/gcp/bootleg_cand_gen_test.yaml: -------------------------------------------------------------------------------- 1 | emmental: 2 | lr: 2e-5 3 | n_epochs: 1 4 | evaluation_freq: 1 5 | warmup_percentage: 0.1 6 | lr_scheduler: linear 7 | log_path: /home/logs/bootleg_cand_gen_1015 8 | use_exact_log_path: true 9 | checkpointing: true 10 | checkpoint_all: true 11 | checkpoint_freq: 1 12 | counter_unit: batch 13 | online_eval: false 14 | clear_intermediate_checkpoints: false 15 | checkpoint_metric: NED/Bootleg/dev/final_loss/acc_boot:max 16 | l2: 0.01 17 | grad_clip: 1.0 18 | gradient_accumulation_steps: 2 19 | fp16: true 20 | writer: wandb 21 | wandb_project_name: platelet 22 | wandb_run_name: gcp_cand_gen_1015 23 | write_loss_per_step: true 24 | run_config: 25 | log_level: DEBUG 26 | eval_batch_size: 16 27 | dataloader_threads: 6 28 | dataset_threads: 20 29 | spawn_method: fork 30 | train_config: 31 | batch_size: 32 32 | model_config: 33 | hidden_size: 200 34 | normalize: true 35 | temperature: 0.01 36 | data_config: 37 | data_dir: /home/data/bootleg-data/korealiases_0329 38 | data_prep_dir: prep 39 | entity_dir: 
/home/data/bootleg-data/korealiases_0329/entity_db 40 | max_seq_len: 128 41 | max_seq_window_len: 64 42 | max_ent_len: 128 43 | overwrite_preprocessed_data: false 44 | dev_dataset: 45 | file: dev_test.jsonl 46 | use_weak_label: true 47 | test_dataset: 48 | file: test.jsonl 49 | use_weak_label: true 50 | train_dataset: 51 | file: train.jsonl 52 | use_weak_label: true 53 | train_in_candidates: true 54 | eval_slices: 55 | - unif_all 56 | - unif_NS_all 57 | - unif_HD 58 | - unif_TO 59 | - unif_TL 60 | - unif_TS 61 | word_embedding: 62 | cache_dir: /home/data/bootleg-archive/pretrained_bert_models 63 | bert_model: bert-base-uncased 64 | context_layers: 6 65 | entity_layers: 6 66 | -------------------------------------------------------------------------------- /configs/gcp/bootleg_test.yaml: -------------------------------------------------------------------------------- 1 | emmental: 2 | lr: 2e-5 3 | n_steps: 212439 4 | evaluation_freq: 35407 5 | warmup_percentage: 0.1 6 | lr_scheduler: linear 7 | log_path: /home/logs/bootleg_uncased_10tokmask 8 | use_exact_log_path: true 9 | checkpointing: true 10 | checkpoint_all: true 11 | checkpoint_freq: 1 12 | counter_unit: batch 13 | online_eval: false 14 | clear_intermediate_checkpoints: false 15 | checkpoint_metric: NED/Bootleg/dev/final_loss/acc_boot:max 16 | l2: 0.01 17 | grad_clip: 1.0 18 | gradient_accumulation_steps: 2 19 | fp16: true 20 | writer: wandb 21 | wandb_project_name: platelet 22 | wandb_run_name: gcp_korealiases_00tokmask 23 | write_loss_per_step: true 24 | run_config: 25 | log_level: DEBUG 26 | eval_batch_size: 16 27 | dataloader_threads: 3 28 | dataset_threads: 20 29 | spawn_method: fork 30 | train_config: 31 | batch_size: 8 32 | model_config: 33 | hidden_size: 300 34 | normalize: true 35 | temperature: 0.01 36 | data_config: 37 | data_dir: /home/data/bootleg-data/korealiases_0329 38 | data_prep_dir: prep 39 | use_entity_desc: true 40 | context_mask_perc: 0.1 41 | entity_type_data: 42 | use_entity_types: true 43 | type_symbols_dir: type_mappings/wiki 44 | entity_kg_data: 45 | use_entity_kg: true 46 | kg_symbols_dir: kg_mappings 47 | entity_dir: /home/data/bootleg-data/korealiases_0329/entity_db 48 | max_seq_len: 128 49 | max_seq_window_len: 64 50 | max_ent_len: 128 51 | overwrite_preprocessed_data: false 52 | dev_dataset: 53 | file: dev.jsonl 54 | use_weak_label: true 55 | test_dataset: 56 | file: test.jsonl 57 | use_weak_label: true 58 | train_dataset: 59 | file: train.jsonl 60 | use_weak_label: true 61 | train_in_candidates: true 62 | eval_slices: 63 | - unif_all 64 | - unif_NS_all 65 | - unif_HD 66 | - unif_TO 67 | - unif_TL 68 | - unif_TS 69 | word_embedding: 70 | cache_dir: /home/data/bootleg-archive/pretrained_bert_models 71 | bert_model: bert-base-uncased 72 | context_layers: 6 73 | entity_layers: 6 74 | -------------------------------------------------------------------------------- /configs/gcp/bootleg_wiki.yaml: -------------------------------------------------------------------------------- 1 | emmental: 2 | lr: 2e-5 3 | n_steps: 428648 # Total dataset size 109733685 - this gives 2 epochs with batch size 32*16 4 | evaluation_freq: 21432 5 | warmup_percentage: 0.1 6 | lr_scheduler: linear 7 | log_path: /home/logs/bootleg_wiki 8 | use_exact_log_path: true 9 | checkpointing: true 10 | checkpoint_all: true 11 | checkpoint_freq: 1 12 | counter_unit: batch 13 | online_eval: false 14 | clear_intermediate_checkpoints: false 15 | checkpoint_metric: NED/Bootleg/dev/final_loss/acc_boot:max 16 | l2: 0.01 17 | grad_clip: 1.0 18 | 
gradient_accumulation_steps: 1 19 | fp16: true 20 | writer: wandb 21 | wandb_project_name: platelet 22 | wandb_run_name: gcp_bootleg_wiki 23 | write_loss_per_step: true 24 | run_config: 25 | log_level: DEBUG 26 | eval_batch_size: 16 27 | dataloader_threads: 4 28 | dataset_threads: 30 29 | spawn_method: forkserver 30 | train_config: 31 | batch_size: 32 32 | model_config: 33 | hidden_size: 200 34 | normalize: true 35 | temperature: 0.01 36 | data_config: 37 | data_dir: /home/data/bootleg-data/wiki_title_0122 38 | data_prep_dir: prep 39 | use_entity_desc: true 40 | context_mask_perc: 0.0 41 | entity_type_data: 42 | use_entity_types: true 43 | type_symbols_dir: type_mappings/wiki 44 | entity_kg_data: 45 | use_entity_kg: true 46 | kg_symbols_dir: kg_mappings 47 | entity_dir: /home/data/bootleg-data/wiki_title_0122/entity_db 48 | max_seq_len: 128 49 | max_seq_window_len: 64 50 | max_ent_len: 128 51 | overwrite_preprocessed_data: false 52 | dev_dataset: 53 | file: merged_sample.jsonl 54 | use_weak_label: true 55 | test_dataset: 56 | file: merged_sample.jsonl 57 | use_weak_label: true 58 | train_dataset: 59 | file: train.jsonl 60 | use_weak_label: true 61 | train_in_candidates: true 62 | eval_slices: 63 | - unif_all 64 | - unif_NS_all 65 | - unif_HD 66 | - unif_TO 67 | - unif_TL 68 | - unif_TS 69 | word_embedding: 70 | cache_dir: /home/data/bootleg-data/pretrained_bert_models 71 | bert_model: bert-base-uncased 72 | context_layers: 6 73 | entity_layers: 6 74 | -------------------------------------------------------------------------------- /configs/gcp/launch_gcp.py: -------------------------------------------------------------------------------- 1 | import re 2 | import tempfile 3 | from pathlib import Path 4 | from subprocess import call 5 | 6 | import argh 7 | from rich.console import Console 8 | 9 | from bootleg.utils.utils import load_yaml_file 10 | 11 | console = Console(soft_wrap=True) 12 | bert_dir = tempfile.TemporaryDirectory().name 13 | 14 | checkpoint_regex = re.compile(r"checkpoint_(\d+\.{0,1}\d*).model.pth") 15 | 16 | 17 | def find_latest_checkpoint(path): 18 | path = Path(path) 19 | possible_checkpoints = [] 20 | for fld in path.iterdir(): 21 | res = checkpoint_regex.match(fld.name) 22 | if res: 23 | possible_checkpoints.append([res.group(1), fld]) 24 | if len(possible_checkpoints) <= 0: 25 | return possible_checkpoints 26 | newest_sort = sorted(possible_checkpoints, key=lambda x: float(x[0]), reverse=True) 27 | return newest_sort[0][-1] 28 | 29 | 30 | @argh.arg("--config", help="Path for config") 31 | @argh.arg("--num_gpus", help="Num gpus") 32 | @argh.arg("--batch", help="Batch size") 33 | @argh.arg("--grad_accum", help="Grad accum") 34 | @argh.arg("--cand_gen_run", help="Launch cand get") 35 | def main( 36 | config="configs/gcp/bootleg_test.yaml", 37 | num_gpus=4, 38 | batch=None, 39 | grad_accum=None, 40 | cand_gen_run=False, 41 | ): 42 | config = Path(config) 43 | config_d = load_yaml_file(config) 44 | save_path = Path(config_d["emmental"]["log_path"]) 45 | seed = config_d["emmental"].get("seed", 1234) 46 | call_file = "bootleg/run.py" if not cand_gen_run else "cand_gen/train.py" 47 | to_call = [ 48 | "python3", 49 | "-m", 50 | "torch.distributed.run", 51 | f"--nproc_per_node={num_gpus}", 52 | call_file, 53 | "--config", 54 | str(config), 55 | ] 56 | # if this is a second+ run, log path will be {log_path}_{num_steps_trained} 57 | possible_save_paths = save_path.parent.glob(f"{save_path.name}*") 58 | latest_save_path = sorted( 59 | possible_save_paths, 60 | key=lambda x: 
int(x.name.split("_")[-1]) 61 | if x.name.split("_")[-1].isnumeric() 62 | else 0, 63 | reverse=True, 64 | ) 65 | save_path = latest_save_path[0] if len(latest_save_path) > 0 else None 66 | 67 | if save_path is not None and save_path.exists() and save_path.is_dir(): 68 | last_checkpoint = find_latest_checkpoint(save_path) 69 | if last_checkpoint is not None: 70 | to_call.append("--emmental.model_path") 71 | to_call.append(str(save_path / last_checkpoint.name)) 72 | num_steps_trained = int( 73 | checkpoint_regex.match(last_checkpoint.name).group(1) 74 | ) 75 | assert num_steps_trained == int( 76 | float(checkpoint_regex.match(last_checkpoint.name).group(1)) 77 | ) 78 | optimizer_path = str( 79 | save_path / last_checkpoint.name.replace("model", "optimizer") 80 | ) 81 | scheduler_path = str( 82 | save_path / last_checkpoint.name.replace("model", "scheduler") 83 | ) 84 | to_call.append("--emmental.optimizer_path") 85 | to_call.append(optimizer_path) 86 | to_call.append("--emmental.scheduler_path") 87 | to_call.append(scheduler_path) 88 | to_call.append("--emmental.steps_learned") 89 | to_call.append(str(num_steps_trained)) 90 | # In case didn't get through epoch, change seed so that data is reshuffled 91 | to_call.append("--emmental.seed") 92 | to_call.append(str(seed + num_steps_trained)) 93 | to_call.append("--emmental.log_path") 94 | to_call.append( 95 | str(save_path.parent / f"{save_path.name}_{num_steps_trained}") 96 | ) 97 | if batch is not None: 98 | to_call.append("--train_config.batch_size") 99 | to_call.append(str(batch)) 100 | if grad_accum is not None: 101 | to_call.append("--emmental.gradient_accumulation_steps") 102 | to_call.append(str(grad_accum)) 103 | print(f"CALLING...{' '.join(to_call)}") 104 | call(to_call) 105 | 106 | 107 | if __name__ == "__main__": 108 | argh.dispatch_command(main) 109 | -------------------------------------------------------------------------------- /configs/standard/train.yaml: -------------------------------------------------------------------------------- 1 | emmental: 2 | lr: 2e-5 3 | n_steps: 5 4 | evaluation_freq: 6 5 | warmup_percentage: 0.1 6 | lr_scheduler: linear 7 | log_path: bootleg_logs 8 | checkpointing: true 9 | checkpoint_all: true 10 | checkpoint_freq: 1 11 | counter_unit: batch 12 | online_eval: false 13 | clear_intermediate_checkpoints: false 14 | checkpoint_metric: NED/Bootleg/dev/final_loss/acc_boot:max 15 | l2: 0.01 16 | grad_clip: 1.0 17 | gradient_accumulation_steps: 4 18 | fp16: true 19 | run_config: 20 | log_level: DEBUG 21 | eval_batch_size: 32 22 | dataloader_threads: 4 23 | dataset_threads: 30 24 | spawn_method: fork 25 | train_config: 26 | batch_size: 32 27 | model_config: 28 | hidden_size: 200 29 | normalize: true 30 | temperature: 0.01 31 | data_config: 32 | data_dir: /lfs/raiders8/0/lorr1/output 33 | train_in_candidates: true 34 | data_prep_dir: prep 35 | use_entity_desc: true 36 | context_mask_perc: 0.0 37 | entity_type_data: 38 | use_entity_types: true 39 | type_symbols_dir: type_mappings/wiki 40 | entity_kg_data: 41 | use_entity_kg: true 42 | kg_symbols_dir: kg_mappings 43 | entity_dir: /lfs/raiders8/0/lorr1/output/entity_db 44 | max_seq_len: 128 45 | max_seq_window_len: 64 46 | max_ent_len: 128 47 | overwrite_preprocessed_data: false 48 | dev_dataset: 49 | file: dev.jsonl 50 | use_weak_label: true 51 | test_dataset: 52 | file: test.jsonl 53 | use_weak_label: true 54 | train_dataset: 55 | file: train.jsonl 56 | use_weak_label: true 57 | word_embedding: 58 | cache_dir: logs/pretrained_bert_models 59 | bert_model: 
bert-base-chinese 60 | context_layers: 6 61 | entity_layers: 6 62 | -------------------------------------------------------------------------------- /configs/tutorial/bootleg_wiki.yaml: -------------------------------------------------------------------------------- 1 | emmental: 2 | lr: 2e-5 3 | n_steps: 428648 # Total dataset size 109733685 - this gives 2 epochs with batch size 8*16 4 | evaluation_freq: 42864 5 | warmup_percentage: 0.1 6 | lr_scheduler: linear 7 | log_path: /dfs/scratch0/lorr1/projects/bootleg-data/logs/logs_temp 8 | checkpointing: true 9 | checkpoint_all: true 10 | checkpoint_freq: 1 11 | counter_unit: batch 12 | online_eval: false 13 | clear_intermediate_checkpoints: false 14 | checkpoint_metric: NED/Bootleg/dev/final_loss/acc_boot:max 15 | l2: 0.01 16 | grad_clip: 1.0 17 | gradient_accumulation_steps: 1 18 | fp16: true 19 | run_config: 20 | log_level: DEBUG 21 | eval_batch_size: 16 22 | dataloader_threads: 2 23 | dataset_threads: 30 24 | spawn_method: forkserver 25 | train_config: 26 | batch_size: 32 27 | model_config: 28 | hidden_size: 200 29 | normalize: true 30 | temperature: 0.01 31 | data_config: 32 | data_dir: /dfs/scratch0/lorr1/projects/bootleg-data/data/wiki_title_0122 33 | data_prep_dir: prep 34 | use_entity_desc: true 35 | context_mask_perc: 0.0 36 | entity_type_data: 37 | use_entity_types: true 38 | type_symbols_dir: type_mappings/wiki 39 | entity_kg_data: 40 | use_entity_kg: true 41 | kg_symbols_dir: kg_mappings 42 | entity_dir: /dfs/scratch0/lorr1/projects/bootleg-data/data/wiki_title_0122/entity_db 43 | max_seq_len: 128 44 | max_seq_window_len: 64 45 | max_ent_len: 128 46 | overwrite_preprocessed_data: false 47 | eval_slices: 48 | - unif_all 49 | - unif_NS_all 50 | - unif_HD 51 | - unif_TO 52 | - unif_TL 53 | - unif_TS 54 | dev_dataset: 55 | file: merged_sample.jsonl 56 | use_weak_label: true 57 | test_dataset: 58 | file: test.jsonl 59 | use_weak_label: true 60 | train_dataset: 61 | file: train.jsonl 62 | use_weak_label: true 63 | train_in_candidates: true 64 | word_embedding: 65 | cache_dir: /dfs/scratch0/lorr1/projects/bootleg-data/embs/pretrained_bert_models 66 | bert_model: bert-base-uncased 67 | -------------------------------------------------------------------------------- /configs/tutorial/sample_config.yaml: -------------------------------------------------------------------------------- 1 | emmental: 2 | lr: 2e-5 3 | n_steps: 50 4 | steps_learned: 20 5 | evaluation_freq: 10 6 | checkpoint_freq: 1 7 | counter_unit: batch 8 | checkpoint_all: true 9 | clear_intermediate_checkpoints: false 10 | lr_scheduler: linear 11 | grad_clip: 1.0 12 | l2: 0.01 13 | log_path: logs/turtorial 14 | device: 0 15 | model_path: logs/turtorial/2021_09_03/14_09_34/aecbc69a/checkpoint_20.model.pth 16 | skip_learned_data: True 17 | run_config: 18 | eval_batch_size: 16 19 | dataloader_threads: 0 20 | dataset_threads: 2 21 | log_level: info 22 | train_config: 23 | batch_size: 2 24 | model_config: 25 | hidden_size: 64 26 | data_config: 27 | data_dir: data/sample_text_data 28 | data_prep_dir: prep 29 | entity_dir: data/sample_entity_data 30 | overwrite_preprocessed_data: false 31 | use_entity_desc: false 32 | entity_type_data: 33 | use_entity_types: true 34 | type_symbols_dir: type_mappings/wiki 35 | entity_kg_data: 36 | use_entity_kg: true 37 | kg_symbols_dir: kg_mappings 38 | max_seq_len: 64 39 | max_seq_window_len: 32 40 | max_ent_len: 64 41 | dev_dataset: 42 | file: dev.jsonl 43 | use_weak_label: true 44 | test_dataset: 45 | file: dev.jsonl 46 | use_weak_label: true 47 
| train_dataset: 48 | file: train.jsonl 49 | use_weak_label: true 50 | train_in_candidates: true 51 | word_embedding: 52 | cache_dir: data/pretrained_bert_models 53 | context_layers: 1 54 | entity_layers: 1 55 | -------------------------------------------------------------------------------- /data/sample_entity_db/entity_mappings/alias2id/config.json: -------------------------------------------------------------------------------- 1 | {"max_id":13449} 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/entity_mappings/alias2id/itoexti.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/entity_mappings/alias2id/itoexti.npy -------------------------------------------------------------------------------- /data/sample_entity_db/entity_mappings/alias2id/vocabulary_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/entity_mappings/alias2id/vocabulary_trie.marisa -------------------------------------------------------------------------------- /data/sample_entity_db/entity_mappings/alias2qids/max_value.json: -------------------------------------------------------------------------------- 1 | 30 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/entity_mappings/alias2qids/record_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/entity_mappings/alias2qids/record_trie.marisa -------------------------------------------------------------------------------- /data/sample_entity_db/entity_mappings/alias2qids/vocabulary_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/entity_mappings/alias2qids/vocabulary_trie.marisa -------------------------------------------------------------------------------- /data/sample_entity_db/entity_mappings/config.json: -------------------------------------------------------------------------------- 1 | {"max_candidates":30,"datetime":"2021-10-29 20:55:56.694750"} 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/entity_mappings/qid2eid/config.json: -------------------------------------------------------------------------------- 1 | {"max_id":1523} 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/entity_mappings/qid2eid/itoexti.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/entity_mappings/qid2eid/itoexti.npy -------------------------------------------------------------------------------- /data/sample_entity_db/entity_mappings/qid2eid/vocabulary_trie.marisa: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/entity_mappings/qid2eid/vocabulary_trie.marisa -------------------------------------------------------------------------------- /data/sample_entity_db/kg_mappings/config.json: -------------------------------------------------------------------------------- 1 | {"max_connections":100,"datetime":"2021-10-29 20:55:56.899194"} 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/kg_mappings/relation_vocab.json: -------------------------------------------------------------------------------- 1 | {"P127":"P127","P131":"P131","P1336":"P1336","P1365":"P1365","P1366":"P1366","P1376":"P1376","P138":"P138","P144":"P144","P150":"P150","P1532":"P1532","P155":"P155","P156":"P156","P17":"P17","P171":"P171","P186":"P186","P1889":"P1889","P2746":"P2746","P279":"P279","P31":"P31","P36":"P36","P361":"P361","P366":"P366","P3842":"P3842","P460":"P460","P461":"P461","P47":"P47","P495":"P495","P501":"P501","P527":"P527","P530":"P530","P807":"P807"} 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/hyena/config.json: -------------------------------------------------------------------------------- 1 | {"max_types":10} 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/hyena/qid2typenames/max_value.json: -------------------------------------------------------------------------------- 1 | 10 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/hyena/qid2typenames/record_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/type_mappings/hyena/qid2typenames/record_trie.marisa -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/hyena/qid2typenames/vocabulary_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/type_mappings/hyena/qid2typenames/vocabulary_trie.marisa -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/relations/config.json: -------------------------------------------------------------------------------- 1 | {"max_types":10} 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/relations/qid2typenames/max_value.json: -------------------------------------------------------------------------------- 1 | 10 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/relations/qid2typenames/record_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/type_mappings/relations/qid2typenames/record_trie.marisa -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/relations/qid2typenames/vocabulary_trie.marisa: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/type_mappings/relations/qid2typenames/vocabulary_trie.marisa -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/wiki/config.json: -------------------------------------------------------------------------------- 1 | {"max_types":10} 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/wiki/qid2typenames/max_value.json: -------------------------------------------------------------------------------- 1 | 10 2 | -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/wiki/qid2typenames/record_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/type_mappings/wiki/qid2typenames/record_trie.marisa -------------------------------------------------------------------------------- /data/sample_entity_db/type_mappings/wiki/qid2typenames/vocabulary_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/data/sample_entity_db/type_mappings/wiki/qid2typenames/vocabulary_trie.marisa -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/advanced/distributed_training.rst: -------------------------------------------------------------------------------- 1 | 2 | Distributed Training 3 | ==================== 4 | 5 | We discuss how to use distributed training to train a Bootleg model on the full Wikipedia save. This tutorial assumes you have already completed the `Basic Training Tutorial`_. 6 | 7 | As Wikipedia has over 5 million entities and over 50 million sentences, training on the full Wikipedia save is computationally expensive. We recommend using a `p4d.24xlarge `_ instance on AWS to train on Wikipedia. 8 | 9 | We provide a config for training Wikipedia `here `_. Note this config is the config used to train the pretrained model provided in the `End-to-End Tutorial `_. 10 | 11 | 1. Downloading the Data 12 | ----------------------- 13 | 14 | We provide scripts to download: 15 | 16 | 17 | #. Prepped Wikipedia data (training and dev datasets) 18 | #. Wikipedia entity data and embedding data 19 | 20 | To download the Wikipedia data, run the command below with the directory to download the data to. Note that the prepped Wikipedia data will require ~200GB of disk space and will take some time to download and decompress the prepped Wikipedia data (16GB compressed, ~150GB uncompressed). 21 | 22 | .. code-block:: 23 | 24 | bash download_wiki.sh 25 | 26 | To download (2) above, run the command 27 | 28 | .. code-block:: 29 | 30 | bash download_data.sh 31 | 32 | At the end, the directory structure should be 33 | 34 | .. code-block:: 35 | 36 | 37 | wiki_data/ 38 | prep/ 39 | entity_db/ 40 | entity_mappings/ 41 | type_mappings/ 42 | kg_mappings/ 43 | prep/ 44 | 45 | 2. Setting up Distributed Training 46 | ---------------------------------- 47 | 48 | `Emmental `_, the training framework of Bootleg, supports distributed training using PyTorch's `Data Parallel `_ or `Distributed Data Parallel `_ framework. We recommend DDP for training. 49 | 50 | There is nothing that needs to change to get distributed training to work. We do, however, recommend setting the following params 51 | 52 | .. code-block:: 53 | 54 | emmental: 55 | ... 56 | distributed_backend: nccl 57 | fp16: true 58 | 59 | This allows for fp16 and making sure the ``nccl`` backend is used. Note that when training with DDP, the ``batch_size`` is **per gpu**. With standard data parallel, the ``batch_size`` is across all GPUs. 60 | 61 | From the `Basic Training Tutorial`_, recall that the directory paths should be set to where we want to save our models and read the data, including: 62 | 63 | * ``cache_dir`` in ``data_config.word_embedding`` 64 | * ``data_dir`` and ``entity_dir`` in ``data_config`` 65 | 66 | We have already set these directories in the provided Wikipedia config, but you will need to update ``data_dir`` and ``entity_dir`` to where you downloaded the data in step 1 and may want to update ``log_dir`` to where you want to save the model checkpoints and logs. 67 | 68 | 3. Training the Model 69 | --------------------- 70 | 71 | As we provide the Wikipedia data already prepped, we can jump immediately to training. 
To train the model with 8 GPUs using DDP, we simply run: 72 | 73 | .. code-block:: 74 | 75 | python3 -m torch.distributed.run --nproc_per_node=8 bootleg/run.py --config_script configs/tutorial/wiki_uncased_ft.yaml 76 | 77 | To train using DP, simply run 78 | 79 | .. code-block:: 80 | 81 | python3 bootleg/run.py --config_script configs/tutorial/wiki_uncased_ft.yaml 82 | 83 | and Emmental will automatically use distributed training (you can turn this off by setting ``dataparallel: false`` in the ``emmental`` config block). 84 | 85 | Once the training begins, we should see all GPUs being utilized. 86 | 87 | If we want to change the config (e.g., change the maximum number of aliases or the maximum word token length), we would need to re-prep the data before training. Note that it takes several hours to perform Wikipedia pre-processing on a 56-core machine. 88 | 89 | 4. Evaluating with Slices 90 | ------------------------- 91 | 92 | We use evaluation slices to understand the performance of Bootleg on important subsets of the dataset. To use evaluation slices, alias-entity pairs are labelled as belonging to specific slices in the ``slices`` key of the dataset. 93 | 94 | In the Wikipedia data in this tutorial, we provide the following "slices" of the dev dataset in addition to the "final_loss" (all examples) slice. For each of these slices, the alias being scored must have more than one candidate. This filters out trivial examples that all models get correct. 95 | 96 | 97 | * ``unif_NS_TS``: The gold entity does not occur in the training dataset (toes). 98 | * ``unif_NS_TL``: The gold entity occurs globally 10 or fewer times in the training dataset (tail). 99 | * ``unif_NS_TO``: The gold entity occurs globally between 11-1000 times in the training dataset (torso). 100 | * ``unif_NS_HD``: The gold entity occurs globally more than 1000 times in the training dataset (head). 101 | * ``unif_NS_all``: All gold entities. 102 | 103 | To use the slices for evaluation, they must also be specified in the ``eval_slices`` section of the ``run_config`` (see the `Wikipedia config`_ as an example). 104 | 105 | When the dev evaluation occurs during training, we should see the performance on each of the slices that are specified in ``eval_slices``. These slices help us understand how well Bootleg performs on more challenging subsets. The frequency of dev evaluation can be specified by the ``evaluation_freq`` parameter in the ``emmental`` block. 106 | 107 | 108 | .. _Basic Training Tutorial: ../gettingstarted/training.html 109 | .. _Wikipedia config: https://github.com/HazyResearch/bootleg/tree/master/configs/tutorial/wiki_uncased_ft.yaml 110 | -------------------------------------------------------------------------------- /docs/source/apidocs/bootleg.end2end.rst: -------------------------------------------------------------------------------- 1 | bootleg.end2end package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | bootleg.end2end.annotator\_utils module 8 | --------------------------------------- 9 | 10 | .. automodule:: bootleg.end2end.annotator_utils 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bootleg.end2end.bootleg\_annotator module 16 | ----------------------------------------- 17 | 18 | .. automodule:: bootleg.end2end.bootleg_annotator 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bootleg.end2end.extract\_mentions module 24 | ---------------------------------------- 25 | 26 | ..
automodule:: bootleg.end2end.extract_mentions 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: bootleg.end2end 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/apidocs/bootleg.layers.rst: -------------------------------------------------------------------------------- 1 | bootleg.layers package 2 | ====================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | bootleg.layers.alias\_to\_ent\_encoder module 8 | --------------------------------------------- 9 | 10 | .. automodule:: bootleg.layers.alias_to_ent_encoder 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bootleg.layers.bert\_encoder module 16 | ----------------------------------- 17 | 18 | .. automodule:: bootleg.layers.bert_encoder 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: bootleg.layers 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/apidocs/bootleg.rst: -------------------------------------------------------------------------------- 1 | bootleg package 2 | =============== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | bootleg.end2end 11 | bootleg.layers 12 | bootleg.slicing 13 | bootleg.symbols 14 | bootleg.tasks 15 | bootleg.utils 16 | 17 | Submodules 18 | ---------- 19 | 20 | bootleg.data module 21 | ------------------- 22 | 23 | .. automodule:: bootleg.data 24 | :members: 25 | :undoc-members: 26 | :show-inheritance: 27 | 28 | bootleg.dataset module 29 | ---------------------- 30 | 31 | .. automodule:: bootleg.dataset 32 | :members: 33 | :undoc-members: 34 | :show-inheritance: 35 | 36 | bootleg.extract\_all\_entities module 37 | ------------------------------------- 38 | 39 | .. automodule:: bootleg.extract_all_entities 40 | :members: 41 | :undoc-members: 42 | :show-inheritance: 43 | 44 | bootleg.run module 45 | ------------------ 46 | 47 | .. automodule:: bootleg.run 48 | :members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | bootleg.scorer module 53 | --------------------- 54 | 55 | .. automodule:: bootleg.scorer 56 | :members: 57 | :undoc-members: 58 | :show-inheritance: 59 | 60 | bootleg.task\_config module 61 | --------------------------- 62 | 63 | .. automodule:: bootleg.task_config 64 | :members: 65 | :undoc-members: 66 | :show-inheritance: 67 | 68 | Module contents 69 | --------------- 70 | 71 | .. automodule:: bootleg 72 | :members: 73 | :undoc-members: 74 | :show-inheritance: 75 | -------------------------------------------------------------------------------- /docs/source/apidocs/bootleg.slicing.rst: -------------------------------------------------------------------------------- 1 | bootleg.slicing package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | bootleg.slicing.slice\_dataset module 8 | ------------------------------------- 9 | 10 | .. automodule:: bootleg.slicing.slice_dataset 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. 
automodule:: bootleg.slicing 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/apidocs/bootleg.symbols.rst: -------------------------------------------------------------------------------- 1 | bootleg.symbols package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | bootleg.symbols.constants module 8 | -------------------------------- 9 | 10 | .. automodule:: bootleg.symbols.constants 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bootleg.symbols.entity\_profile module 16 | -------------------------------------- 17 | 18 | .. automodule:: bootleg.symbols.entity_profile 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bootleg.symbols.entity\_symbols module 24 | -------------------------------------- 25 | 26 | .. automodule:: bootleg.symbols.entity_symbols 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | bootleg.symbols.kg\_symbols module 32 | ---------------------------------- 33 | 34 | .. automodule:: bootleg.symbols.kg_symbols 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | bootleg.symbols.type\_symbols module 40 | ------------------------------------ 41 | 42 | .. automodule:: bootleg.symbols.type_symbols 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | Module contents 48 | --------------- 49 | 50 | .. automodule:: bootleg.symbols 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | -------------------------------------------------------------------------------- /docs/source/apidocs/bootleg.tasks.rst: -------------------------------------------------------------------------------- 1 | bootleg.tasks package 2 | ===================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | bootleg.tasks.entity\_gen\_task module 8 | -------------------------------------- 9 | 10 | .. automodule:: bootleg.tasks.entity_gen_task 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bootleg.tasks.ned\_task module 16 | ------------------------------ 17 | 18 | .. automodule:: bootleg.tasks.ned_task 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: bootleg.tasks 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/apidocs/bootleg.utils.classes.rst: -------------------------------------------------------------------------------- 1 | bootleg.utils.classes package 2 | ============================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | bootleg.utils.classes.aliasmention\_trie module 8 | ----------------------------------------------- 9 | 10 | .. automodule:: bootleg.utils.classes.aliasmention_trie 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bootleg.utils.classes.comment\_json module 16 | ------------------------------------------ 17 | 18 | .. automodule:: bootleg.utils.classes.comment_json 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bootleg.utils.classes.dotted\_dict module 24 | ----------------------------------------- 25 | 26 | .. automodule:: bootleg.utils.classes.dotted_dict 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. 
automodule:: bootleg.utils.classes 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/apidocs/bootleg.utils.parser.rst: -------------------------------------------------------------------------------- 1 | bootleg.utils.parser package 2 | ============================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | bootleg.utils.parser.bootleg\_args module 8 | ----------------------------------------- 9 | 10 | .. automodule:: bootleg.utils.parser.bootleg_args 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bootleg.utils.parser.emm\_parse\_args module 16 | -------------------------------------------- 17 | 18 | .. automodule:: bootleg.utils.parser.emm_parse_args 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bootleg.utils.parser.parser\_utils module 24 | ----------------------------------------- 25 | 26 | .. automodule:: bootleg.utils.parser.parser_utils 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. automodule:: bootleg.utils.parser 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /docs/source/apidocs/bootleg.utils.preprocessing.rst: -------------------------------------------------------------------------------- 1 | bootleg.utils.preprocessing package 2 | =================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | bootleg.utils.preprocessing.compute\_statistics module 8 | ------------------------------------------------------ 9 | 10 | .. automodule:: bootleg.utils.preprocessing.compute_statistics 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | bootleg.utils.preprocessing.count\_body\_part\_size module 16 | ---------------------------------------------------------- 17 | 18 | .. automodule:: bootleg.utils.preprocessing.count_body_part_size 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | bootleg.utils.preprocessing.gen\_alias\_cand\_map module 24 | -------------------------------------------------------- 25 | 26 | .. automodule:: bootleg.utils.preprocessing.gen_alias_cand_map 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | bootleg.utils.preprocessing.gen\_entity\_mappings module 32 | -------------------------------------------------------- 33 | 34 | .. automodule:: bootleg.utils.preprocessing.gen_entity_mappings 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | 39 | bootleg.utils.preprocessing.get\_train\_qid\_counts module 40 | ---------------------------------------------------------- 41 | 42 | .. automodule:: bootleg.utils.preprocessing.get_train_qid_counts 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | bootleg.utils.preprocessing.sample\_eval\_data module 48 | ----------------------------------------------------- 49 | 50 | .. automodule:: bootleg.utils.preprocessing.sample_eval_data 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | Module contents 56 | --------------- 57 | 58 | .. 
automodule:: bootleg.utils.preprocessing 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | -------------------------------------------------------------------------------- /docs/source/apidocs/bootleg.utils.rst: -------------------------------------------------------------------------------- 1 | bootleg.utils package 2 | ===================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | bootleg.utils.classes 11 | bootleg.utils.parser 12 | bootleg.utils.preprocessing 13 | 14 | Submodules 15 | ---------- 16 | 17 | bootleg.utils.data\_utils module 18 | -------------------------------- 19 | 20 | .. automodule:: bootleg.utils.data_utils 21 | :members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | bootleg.utils.eval\_utils module 26 | -------------------------------- 27 | 28 | .. automodule:: bootleg.utils.eval_utils 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | 33 | bootleg.utils.model\_utils module 34 | --------------------------------- 35 | 36 | .. automodule:: bootleg.utils.model_utils 37 | :members: 38 | :undoc-members: 39 | :show-inheritance: 40 | 41 | bootleg.utils.utils module 42 | -------------------------- 43 | 44 | .. automodule:: bootleg.utils.utils 45 | :members: 46 | :undoc-members: 47 | :show-inheritance: 48 | 49 | Module contents 50 | --------------- 51 | 52 | .. automodule:: bootleg.utils 53 | :members: 54 | :undoc-members: 55 | :show-inheritance: 56 | -------------------------------------------------------------------------------- /docs/source/apidocs/modules.rst: -------------------------------------------------------------------------------- 1 | bootleg 2 | ======= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | bootleg 8 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | # sys.path.insert(0, os.path.abspath("")) 17 | sys.path.insert(0, os.path.abspath(".")) 18 | # sys.path.insert(0, os.path.abspath("..")) 19 | # sys.path.insert(0, os.path.abspath("../..")) 20 | # sys.setrecursionlimit(1500) 21 | 22 | 23 | # -- Project information ----------------------------------------------------- 24 | 25 | project = "Bootleg" 26 | copyright = "2021, Laurel Orr" 27 | author = "Laurel Orr" 28 | 29 | # The full version, including alpha/beta/rc tags 30 | release = "v1.1.0dev1" 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # Add any Sphinx extension module names here, as strings. They can be 36 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 37 | # ones. 
38 | extensions = [ 39 | "sphinx.ext.autodoc", 40 | "sphinx.ext.napoleon", 41 | "sphinx.ext.viewcode", 42 | "sphinx_rtd_theme", 43 | "nbsphinx", 44 | "recommonmark", 45 | ] 46 | 47 | # Add any paths that contain templates here, relative to this directory. 48 | templates_path = ["_templates"] 49 | 50 | # List of patterns, relative to source directory, that match files and 51 | # directories to ignore when looking for source files. 52 | # This pattern also affects html_static_path and html_extra_path. 53 | exclude_patterns = [] 54 | 55 | 56 | # -- Options for HTML output ------------------------------------------------- 57 | 58 | # The theme to use for HTML and HTML Help pages. See the documentation for 59 | # a list of builtin themes. 60 | # 61 | html_theme = "sphinx_rtd_theme" 62 | html_theme_options = {"navigation_depth": 2} 63 | 64 | # Add any paths that contain custom static files (such as style sheets) here, 65 | # relative to this directory. They are copied after the builtin static files, 66 | # so a file named "default.css" will overwrite the builtin "default.css". 67 | html_static_path = ["_static"] 68 | -------------------------------------------------------------------------------- /docs/source/dev/changelog.rst: -------------------------------------------------------------------------------- 1 | Changelog 2 | ========= 3 | 4 | All notable changes to this project will be documented in this file. 5 | 6 | The format is based on `Keep a Changelog`_ and this project adheres to `Semantic Versioning 2.0.0`_ conventions. The maintainers will create a git tag for each release and increment the version number found in `bootleg/\_version.py`_ accordingly. We release tagged versions to `PyPI`_ automatically using `GitHub Actions`_. 7 | 8 | .. note:: 9 | Bootleg is still under active development and APIs may still change rapidly. Until we release v1.0.0, changes in MINOR version indicate backward incompatible changes. 10 | 11 | .. include:: ../../../CHANGELOG.rst 12 | 13 | .. _Keep a Changelog: https://keepachangelog.com/en/1.0.0/ 14 | .. _PyPI: https://pypi.org/project/bootleg/ 15 | .. _Semantic Versioning 2.0.0: https://semver.org/ 16 | .. _GitHub Actions: https://github.com/HazyResearch/bootleg/actions 17 | .. _bootleg/\_version.py: https://github.com/HazyResearch/bootleg/blob/master/bootleg/_version.py 18 | -------------------------------------------------------------------------------- /docs/source/dev/codestyle.rst: -------------------------------------------------------------------------------- 1 | Code Style 2 | ========== 3 | 4 | For code consistency, we have a `pre-commit`_ configuration file so that you can easily install pre-commit hooks to run style checks before you commit your files. You can setup our pre-commit hooks by running:: 5 | 6 | $ pip install .[dev] 7 | $ pre-commit install 8 | 9 | Or, just run:: 10 | 11 | $ make dev 12 | 13 | Now, each time you commit, checks will be run using the packages explained below. 14 | 15 | We use `black`_ as our Python code formatter with its default settings. Black helps minimize the line diffs and allows you to not worry about formatting during your own development. Just run black on each of your files before committing them. 16 | 17 | .. tip:: 18 | 19 | Whatever editor you use, we recommend checking out `black editor integrations`_ to help make the code formatting process just a few keystrokes. 20 | 21 | 22 | For sorting imports, we reply on `isort`_. Our repository already includes a ``.isort.cfg`` that is compatible with black. 
You can run a code style check on your local machine by running our checks:: 23 | 24 | $ make check 25 | 26 | .. _pre-commit: https://pre-commit.com/ 27 | .. _isort: https://github.com/timothycrosley/isort 28 | .. _black editor integrations: https://github.com/ambv/black#editor-integration 29 | .. _black: https://github.com/ambv/black 30 | -------------------------------------------------------------------------------- /docs/source/dev/install.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | To test changes in the package, you can install it in `editable mode`_ locally in your virtualenv by running:: 5 | 6 | $ make dev 7 | 8 | This will also install our pre-commit hooks and local packages needed for style checks. 9 | 10 | .. tip:: 11 | 12 | If you need to install a locally edited version of bootleg in a separate location, such as an application, you can directly install your locally modified version:: 13 | 14 | $ pip install -e path/to/bootleg/ 15 | 16 | in the virtualenv of your application. 17 | 18 | Note, you can test the `pip` downloadable version using `TestPyPI `_. To handle dependencies, run 19 | 20 | .. code-block:: 21 | 22 | pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple bootleg 23 | 24 | .. _editable mode: https://packaging.python.org/tutorials/distributing-packages/#working-in-development-mode 25 | -------------------------------------------------------------------------------- /docs/source/dev/tests.rst: -------------------------------------------------------------------------------- 1 | Testing 2 | ======= 3 | 4 | We use pytest_ to run our tests. Our tests are all located in the ``tests`` directory in the repo, and are meant to be run *after* installing_ Bootleg locally. 5 | 6 | To run our tests, just run:: 7 | 8 | $ make test 9 | 10 | .. _pytest: https://docs.pytest.org/en/latest/ 11 | .. _installing: install.html 12 | -------------------------------------------------------------------------------- /docs/source/gettingstarted/config.rst: -------------------------------------------------------------------------------- 1 | Configuring Bootleg 2 | ==================== 3 | 4 | By default, Bootleg loads the default config from `bootleg/utils/parser/bootleg_args.py <../apidocs/bootleg.utils.parser.html#module-bootleg.utils.parser.bootleg_args>`_. When running a Bootleg model, the user may pass in a custom JSON or YAML config via:: 5 | 6 | python3 bootleg/run.py --config_script <path_to_config> 7 | 8 | This will override all default values. Further, if a user wishes to overwrite a param from the command line, they can pass in the value using the dotted path of the argument. For example, to overwrite the data directory (the param ``data_config.data_dir``), the user can enter:: 9 | 10 | python3 bootleg/run.py --config_script <path_to_config> --data_config.data_dir <path_to_data_dir> 11 | 12 | Bootleg will save the run config (as well as a fully parsed version with all defaults) in the log directory. 13 | 14 | Finally, when evaluating Bootleg using the annotator, Bootleg processes possible mentions in text with three environment flags: ``BOOTLEG_STRIP``, ``BOOTLEG_LOWER``, ``BOOTLEG_LANG_CODE``. The first is if the user wants to strip punctuation on mentions (set to False by default). The second is if the user wants to call ``.lower()`` on mentions (set to True by default). The third sets the language code to use for Spacy.
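For example, you could set these flags before constructing the annotator. This is a minimal sketch; the values mirror ``scripts/train.zsh``, and we assume the flags are read when the annotator processes mentions:

.. code-block:: python

    import os

    # Mention-processing flags read by Bootleg (values mirror scripts/train.zsh)
    os.environ["BOOTLEG_STRIP"] = "true"      # strip punctuation from mentions
    os.environ["BOOTLEG_LOWER"] = "true"      # lowercase mentions
    os.environ["BOOTLEG_LANG_CODE"] = "en"    # Spacy language code

    from bootleg.end2end.bootleg_annotator import BootlegAnnotator

    ann = BootlegAnnotator()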
15 | 16 | Emmental Config 17 | ________________ 18 | 19 | As Bootleg uses Emmental_, the training parameters (e.g., learning rate) are set and handled by Emmental. We provide all Emmental params, as well as our defaults, at `bootleg/utils/parser/emm_parse_args.py <../apidocs/bootleg.utils.parser.html#module-bootleg.utils.parser.emm_parse_args>`_. All Emmental params are under the ``emmental`` configuration group. For example, to change the learning rate and number of epochs in a config, add 20 | 21 | .. code-block:: 22 | 23 | emmental: 24 | lr: 1e-4 25 | n_epochs: 10 26 | run_config: 27 | ... 28 | 29 | You can also change Emmental params from the command line with ``--emmental.<param> <value>``. 30 | 31 | Example Training Config 32 | ________________________ 33 | An example training config is shown below 34 | 35 | .. code-block:: 36 | 37 | emmental: 38 | lr: 2e-5 39 | n_epochs: 3 40 | evaluation_freq: 0.2 41 | warmup_percentage: 0.1 42 | lr_scheduler: linear 43 | log_path: logs/wiki 44 | l2: 0.01 45 | grad_clip: 1.0 46 | fp16: true 47 | run_config: 48 | eval_batch_size: 32 49 | dataloader_threads: 4 50 | dataset_threads: 50 51 | train_config: 52 | batch_size: 32 53 | model_config: 54 | hidden_size: 200 55 | data_config: 56 | data_dir: bootleg-data/data/wiki_title_0122 57 | data_prep_dir: prep 58 | use_entity_desc: true 59 | entity_type_data: 60 | use_entity_types: true 61 | type_symbols_dir: type_mappings/wiki 62 | entity_kg_data: 63 | use_entity_kg: true 64 | kg_symbols_dir: kg_mappings 65 | entity_dir: bootleg-data/data/wiki_title_0122/entity_db 66 | max_seq_len: 128 67 | max_seq_window_len: 64 68 | max_ent_len: 128 69 | overwrite_preprocessed_data: false 70 | dev_dataset: 71 | file: dev.jsonl 72 | use_weak_label: true 73 | test_dataset: 74 | file: test.jsonl 75 | use_weak_label: true 76 | train_dataset: 77 | file: train.jsonl 78 | use_weak_label: true 79 | train_in_candidates: true 80 | word_embedding: 81 | cache_dir: bootleg-data/embs/pretrained_bert_models 82 | bert_model: bert-base-uncased 83 | 84 | Default Config 85 | _______________ 86 | The default Bootleg config is shown below 87 | 88 | .. literalinclude:: ../../../bootleg/utils/parser/bootleg_args.py 89 | 90 | 91 | .. _Emmental: https://github.com/SenWu/Emmental 92 | -------------------------------------------------------------------------------- /docs/source/gettingstarted/emmental.rst: -------------------------------------------------------------------------------- 1 | Emmental 2 | =============== 3 | We use the Emmental_ framework. Emmental is a framework for building multimodal multi-task learning systems. A key feature of Emmental is its task flow design where models are defined by the data flow through modules. By reusing modules in different tasks, you can easily extend your model to a multi-task setting. 4 | 5 | We highly encourage you to check out the `Emmental docs`_ and `Emmental tutorials`_ to understand the framework. 6 | 7 | 8 | .. _Emmental: https://github.com/SenWu/Emmental 9 | .. _Emmental docs: https://emmental.readthedocs.io/en/latest/ 10 | .. _Emmental tutorials: https://github.com/SenWu/emmental-tutorials -------------------------------------------------------------------------------- /docs/source/gettingstarted/entity_profile.rst: -------------------------------------------------------------------------------- 1 | Entity Profiles 2 | ================= 3 | Bootleg uses Wikipedia and Wikidata to collect and generate an entity database of metadata associated with an entity.
We support both non-structural data (e.g., the title of an entity) and structural data (e.g., the type or relationship of an entity). We now describe how to generate entity profile data from scratch to be used for training and the structure of the profile data we already provide. 4 | 5 | Generating Profiles 6 | -------------------- 7 | The database of entity data starts with a simple ``jsonl`` file of data associated with an entity. Specifically, each line is a JSON object 8 | 9 | .. code-block:: JSON 10 | 11 | { 12 | "entity_id": "Q16240866", 13 | "mentions": [["benin national u20 football team",1],["benin national under20 football team",1]], 14 | "title": "Forbidden fruit", 15 | "description": "A fruit that once was considered not to be eaten", 16 | "types": {"hyena": [""], 17 | "wiki": ["national association football team"], 18 | "relations":["country for sport","sport"]}, 19 | "relations": [ 20 | {"relation":"P1532","object":"Q962"}, 21 | ], 22 | } 23 | 24 | The ``entity_id`` gives a unique string identifier of the entity. It does *not* have to start with a ``Q``. As we normalize to Wikidata, our entities are referred to as QIDs. The ``mentions`` provides a list of known aliases to the entity and a prior score associated with that mention indicating the strength of association. The score is used to order the candidates. The ``types`` provides the different types and entity is and supports different type systems. In the example above, the two type systems are ``hyena`` and ``wiki``. We also have a ``relations`` type system which treats the relationships an entity participates in as types. The ``relations`` JSON field provides the actual KG relationship triples where ``entity_id`` is the head. 25 | 26 | .. note:: 27 | 28 | By default, Bootleg assigns the score for each mentions as being the global entity count in Wikipedia. We empirically found this was a better scoring method for incorporating Wikidata "also known as" aliases that did not appear in Wikipedia. This means the scores for the mentions for a single entity will be the same. 29 | 30 | We provide a more complete `sample of raw profile data `_ to look at. 31 | 32 | Once the data is ready, we provide an `EntityProfile <../apidocs/bootleg.symbols.html#module-bootleg.symbols.entity_profile>`_ API to build and interact with the profile data. To create an entity profile for the model from the raw ``jsonl`` data, run 33 | 34 | .. code-block:: python 35 | 36 | from bootleg.symbols.entity_profile import EntityProfile 37 | path_to_file = "data/sample_raw_entity_data/raw_profile.jsonl" 38 | # edit_mode means you are allowed to modify the profile 39 | ep = EntityProfile.load_from_jsonl(path_to_file, edit_mode=True) 40 | 41 | .. note:: 42 | 43 | By default, we assume that each alias can have a maximum of 30 candidates, 10 types, and 100 connections. You can change these by adding ``max_candidates``, ``max_types``, and ``max_connections`` as keyword arguments to ``load_from_jsonl``. Note that increasing the number of maximum candidates increases the memory required for training and inference. 44 | 45 | Profile API 46 | -------------------- 47 | Now that the profile is loaded, you can interact with the metadata and change it. For example, to get the title and add a type mapping, you'd run 48 | 49 | .. 
code-block:: python 50 | 51 | ep.get_title("Q16240866") 52 | # This is adding the type "country" to the "wiki" type system 53 | ep.add_type("Q16240866", "sports team", "wiki") 54 | 55 | Once ready to train or run a model with the profile data, simply save it 56 | 57 | .. code-block:: python 58 | 59 | ep.save("data/sample_entity_db") 60 | 61 | We have already provided the saved dump at ``data/sample_entity_data``. 62 | 63 | See our `entity profile tutorial `_ for a more complete walkthrough notebook of the API. 65 | 66 | Training with a Profile 67 | ------------------------ 68 | Inside the saved folder for the profile, all the mappings needed to run a Bootleg model are provided. There are three subfolders as described below. Note that we use the word ``alias`` and ``mention`` interchangeably. 69 | 70 | * ``entity_mappings``: This folder contains non-structural entity data. 71 | * ``qid2eid``: This is a folder containing a Trie mapping from entity id (we refer to this as QID) to an entity index used internally to extract embeddings. Note that these entity ids start at 1 (0 index is reserved for a "not in candidate list" entity). We use Wikidata QIDs in our tutorials and documentation but any string identifier will work. 72 | * ``qid2title.json``: This is a mapping from entity QID to entity Wikipedia title. 73 | * ``qid2desc.json``: This is a mapping from entity QID to entity Wikipedia description. 74 | * ``alias2qids``: This is a folder containing a RecordTrie mapping from possible mentions (or aliases) to a list possible candidates. We restrict our candidate lists to be a predefined max length, typically 30. Each item in the list is a pair of [QID, QID score] values. The QID score is used for sorting candidates before filtering to the top 30. The scores are otherwise not used in Bootleg. This mapping is mined from both Wikipedia and Wikidata (reach out with a github issue if you want to know more). 75 | * ``alias2id``: This is a folder containing a Trie mapping from alias to alias index used internally by the model. 76 | * ``config.json``: This gives metadata associated with the entity data. Specifically, the maximum number of candidates. 77 | * ``type_mappings``: This folder contains type entity data for each type system subfolder. Inside each subfolder are the following files. 78 | * ``qid2typenames``: Folder containing a RecordTrie mapping from entity QID to a list of type names. 79 | * ``config.json``: Contains metadata of the maximum number of types allowed for an entity. 80 | * ``kg_mappings``: This folder contains relationship entity data. 81 | * ``qid2relations``: Folder containing a RecordTrie mapping from entity QID to relations to list of tail QIDs associated with the entity QID. 82 | * ``config.json``: Contains metadata of the maximum number of tail connections allowed for a particular head entity and relation. 83 | 84 | .. note:: 85 | 86 | In Bootleg, we add types from a selected type system and add KG relationship triples to our entity encoder. 87 | 88 | .. note:: 89 | 90 | In our public ``entity_db`` provided to run Bootleg models, we also provide ``alias2qids_unfiltered.json`` which provides our unfiltered, raw candidate mappings. We filter noisy aliases before running mention extraction. 91 | 92 | Given this metadata, you simply need to specify the types, relation mappings and correct folder structures in a Bootleg training `config `_. Specifically, these are the config parameters that need to be set to be associated with an entity profile. 93 | 94 | .. 
code-block:: 95 | 96 | data_config: 97 | entity_dir: data/sample_entity_data 98 | use_entity_desc: true 99 | entity_type_data: 100 | use_entity_types: true 101 | type_symbols_dir: type_mappings/wiki 102 | entity_kg_data: 103 | use_entity_kg: true 104 | kg_symbols_dir: kg_mappings 105 | 106 | See our `example config `_ 107 | for a full reference, and see our `entity profile tutorial `_ for some methods to help modify 109 | configs to map to the entity profile correctly. 110 | -------------------------------------------------------------------------------- /docs/source/gettingstarted/input_data.rst: -------------------------------------------------------------------------------- 1 | Inputs 2 | ============== 3 | Given an input sentence, Bootleg outputs the entities that participate in the text. For example, given the sentence 4 | 5 | ``Where is Lincoln in Logan County`` 6 | 7 | Bootleg should output that Lincoln refers to Lincoln IL and Logan County to Logan County IL. 8 | 9 | This disambiguation occurs in two parts. The first, described here, is mention extraction and candidate generation, where phrases in the input text are extracted to be disambiguated. For example, in the sentence above, the phrases "Lincoln" and "Logan County" should be extracted. Each phrase to be disambiguated is called a mention (or alias). Instead of disambiguating against all entities in Wikipedia, Bootleg uses predefined candidate maps that provide a small subset of possible entity candidates for each mention. The second step, described in `Bootleg Model`_, is the disambiguation using Bootleg's neural model. 10 | 11 | To understand how we do mention extraction and candidate generation, we first need to describe the profile data we have associated with an entity. Then we will describe how we perform mention extraction. Finally, we will provide details on the input data provided to Bootleg. Take a look at our `tutorials `_ to see it in action. 12 | 13 | Entity Data 14 | -------------------- 15 | Bootleg uses Wikipedia and Wikidata to collect and generate an entity database of metadata associated with an entity. This is all located in ``entity_db`` and contains mappings from entities to structural data and possible mentions. We describe the entity profiles in more detail, and how to generate them, on our `entity profile `_ page. For reference, we have an `EntityProfile <../apidocs/bootleg.symbols.html#module-bootleg.symbols.entity_profile>`_ class that loads and manages this metadata. 16 | 17 | As our profile data gives us the mentions associated with each entity, we now describe how we extract mentions from input text. 18 | 19 | Mention Extraction 20 | ------------------ 21 | Our mention extraction is a simple n-gram search over the input sentence (see `bootleg/end2end/extract_mentions.py <../apidocs/bootleg.end2end.html#module-bootleg.end2end.extract_mentions>`_). Starting from the largest possible n-grams and working towards single word mentions, we iterate over the sentence and see if any n-gram is a hit in our ``alias2qid`` mapping. If it is, we extract that mention. This ensures that each mention has a set of candidates. 22 | 23 | To prevent extracting noisy mentions, like the word "the", we filter our alias maps to only have words that appear approximately more than 1.5% of the time as mentions in our training data. 24 | 25 | The input format is ``jsonl`` where each line is a json object of the form 26 | 27 | * ``sentence``: input sentence.
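For example, a minimal input file for mention extraction could contain a line like the following. This is a small illustration using the sentence from above; only the ``sentence`` field is needed as input here:

.. code-block:: JSON

    {"sentence": "Where is Lincoln in Logan County"}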
28 | 29 | We output a jsonl with 30 | 31 | * ``sentence``: input sentence. 32 | * ``aliases``: list of extracted mentions. 33 | * ``spans``: list of word offsets [inclusive, exclusive) for each alias. 34 | 35 | Textual Input 36 | ------------------ 37 | Once we have mentions and candidates, we are ready to run our Bootleg model. The raw input format is in ``jsonl`` format where each line is a json object. We have one json per sentence in our training data with the following files 38 | 39 | * ``sentence``: input sentence. 40 | * ``sent_idx_unq``: unique sentence index. 41 | * ``aliases``: list of extracted mentions. 42 | * ``qids``: list of gold entity id (if known). We use canonical Wikidata QIDs in our tutorials and documentation, but any id used in the entity metadata will work. The id can be ``Q-1`` if unknown, but you _must_ provide gold QIDs for training data. 43 | * ``spans``: list of word offsets [inclusive, exclusive) for each alias. 44 | * ``gold``: list of booleans if the alias is a gold anchor link from Wikipedia or a weakly labeled link. 45 | * ``slices``: list of json slices for evaluation. See `advanced training <../advanced/distributed_training.html>`_ for details. 46 | 47 | For example, the input for the sentence above is 48 | 49 | .. code-block:: JSON 50 | 51 | { 52 | "sentence": "Where is Lincoln in Logan County", 53 | "sent_idx_unq": 0, 54 | "aliases": ["lincoln", "logan county"], 55 | "qids": ["Q121", "Q???"], 56 | "spans": [[2,3], [4,6]], 57 | "gold": [True, True], 58 | "slices": {} 59 | } 60 | 61 | For more details on training, see our `training tutorial `_. 62 | 63 | .. _Bootleg Model: model.html 64 | .. _tutorials: tutorials.html 65 | .. _Emmental: https://github.com/SenWu/Emmental -------------------------------------------------------------------------------- /docs/source/gettingstarted/install.rst: -------------------------------------------------------------------------------- 1 | Install 2 | ======= 3 | `Bootleg `_ requires Python 3.6 or later:: 4 | 5 | git clone git@github.com:HazyResearch/bootleg bootleg 6 | cd bootleg 7 | python3 setup.py install 8 | 9 | 10 | .. note:: 11 | 12 | You will need at least 40 GB of disk space, 12 GB of GPU memory, and 35 GB of CPU memory to run our model. 13 | -------------------------------------------------------------------------------- /docs/source/gettingstarted/model.rst: -------------------------------------------------------------------------------- 1 | Model Overview 2 | ============== 3 | Given an input sentence, list of mentions to be disambiguated, and list of possible candidates for each mention (described in `Input Data`_), Bootleg outputs the most likely candidate for each mention. Bootleg's model is a biencoder architecture and consists of two components: the entity encoder and context encoder. For each entity candidate, the entity encoder generates an embedding representing this entity from a textual input containing entity information such as the title, description, and types. The context encoder embeds the mention and its surrounded context. The selected candidate is the one with the highest dot product. 4 | 5 | We now describe each step in detail and explain how to add/remove different parts of the entity encoder in our `Bootleg Config`_. 6 | 7 | Entity Encoder 8 | -------------------------- 9 | The entity encoder is a BERT Transformer that takes a textual input for an entity and feeds it through BERT. During training, we take the ``[CLS]`` token as the entity embedding. 
There are four pieces of information we add to the textual input for an entity: 10 | 11 | * ``title``: Entity title. Comes from ``qid2title.json``. This is always used. 12 | * ``description``: Entity description. Comes from ``qid2desc.json``. This is toggled on/off. 13 | * ``type``: Entity type from one of the type systems specified in the config. If the entity has multiple types, we add them to the input as ``<type 1>; <type 2>; ...`` 14 | * ``KG``: Entity KG relations specified in the config. We add KG relations to the input as ``<relation 1>; <relation 2>; ...`` where the head of each triple is the entity in question. 15 | 16 | The final entity input is ``<title> [SEP] <types> [SEP] <relations> [SEP] <description>``. 17 | 18 | You control which inputs are added by the following part of the input config. All the relevant entity encoder code is in `bootleg/dataset.py <../apidocs/bootleg.datasets.html>`_. 19 | 20 | .. code-block:: 21 | 22 | data_config: 23 | ... 24 | use_entity_desc: true 25 | entity_type_data: 26 | use_entity_types: true 27 | type_symbols_dir: type_mappings/wiki 28 | entity_kg_data: 29 | use_entity_kg: true 30 | kg_symbols_dir: kg_mappings 31 | max_seq_len: 128 32 | max_seq_window_len: 64 33 | max_ent_len: 128 34 | 35 | 36 | Context Encoder 37 | ------------------ 38 | Like the entity encoder, our context encoder takes the context of a mention and feeds it through a BERT Transformer. The ``[CLS]`` token is used as the relevant mention embedding. To allow BERT to understand where the mention is, we surround it with ``[ENT_START]`` and ``[ENT_END]`` tokens. As shown above, you can specify the maximum sequence length for the context encoder and the maximum window length. All the relevant context encoder code is in `bootleg/dataset.py <../apidocs/bootleg.datasets.html>`_. 39 | 40 | .. _Input Data: input_data.html 41 | .. _Bootleg Config: config.html 42 | -------------------------------------------------------------------------------- /docs/source/gettingstarted/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart 2 | ============= 3 | 4 | Getting started is easy. Run the following. This will download our default model. 5 | 6 | .. note:: 7 | 8 | You will need at least 40 GB of disk space, 12 GB of GPU memory, and 35 GB of CPU memory to run our model. When running for the first time, it will take 10 plus minutes for everything to download and load correctly, depending on network speeds. 9 | 10 | .. code-block:: 11 | 12 | from bootleg.end2end.bootleg_annotator import BootlegAnnotator 13 | ann = BootlegAnnotator() 14 | ann.label_mentions("How many people are in Lincoln")["titles"] 15 | 16 | You can also pass in multiple sentences:: 17 | 18 | ann.label_mentions(["I am in Lincoln", "I am Lincoln", "I am driving a Lincoln"])["titles"] 19 | 20 | Or, you can decide to use a different model (the choices are bootleg_cased, bootleg_uncased, bootleg_cased_mini, and bootleg_uncased_mini - default is bootleg_uncased):: 21 | 22 | ann = BootlegAnnotator(model_name="bootleg_uncased") 23 | ann.label_mentions("How many people are in Lincoln")["titles"] 24 | 25 | Other initialization parameters are at `bootleg/end2end/bootleg_annotator.py <../apidocs/bootleg.end2end.html#module-bootleg.end2end.bootleg_annotator>`_. 26 | 27 | Check out our `tutorials <https://github.com/HazyResearch/bootleg/tree/master/tutorials>`_ for more help getting started.
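If you have trained your own model (for example with ``scripts/train.zsh``), you can also point the annotator at your run config. This is a minimal sketch; the ``config`` keyword and the sample config path mirror what ``scripts/train.zsh`` echoes at the end of training::

    from bootleg.end2end.bootleg_annotator import BootlegAnnotator

    # Load the annotator from a training run config (path from scripts/train.zsh)
    ann = BootlegAnnotator(config="configs/standard/train.yaml")
    ann.label_mentions("How many people are in Lincoln")["titles"]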
28 | 29 | Faster Inference 30 | -------------------- 31 | For improved speed, you can pass in a static matrix of all entity embeddings downloaded from `here <https://bootleg-ned-data.s3-us-west-1.amazonaws.com/models/latest/bootleg_uncased_entity_embs.npy.tar.gz>`_. 32 | 33 | Then, our annotator can be run as:: 34 | 35 | ann = BootlegAnnotator(entity_embs_path=<PATH TO UNTARRED EMBEDDING FILE>) 36 | ann.label_mentions("How many people are in Lincoln")["titles"] 37 | 38 | 39 | .. tip:: 40 | 41 | If you have a larger amount of data to disambiguate, check out our `end-to-end tutorial <https://github.com/HazyResearch/bootleg/tree/master/tutorials/end2end_ned_tutorial.ipynb>`_ showing a more optimized end-to-end pipeline. 42 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to Bootleg 2 | ==================================== 3 | 4 | Bootleg_ is a named entity disambiguation (NED) system that links mentions in text to entities and produces contextual entity embeddings. 5 | 6 | Bootleg is still *actively under development*, so feedback and 7 | contributions are welcome. Submit bugs in the Issues_ section or feel free to 8 | submit your contributions as a pull request. 9 | 10 | .. _Issues: https://github.com/HazyResearch/bootleg/issues/ 11 | .. _Bootleg: https://github.com/HazyResearch/bootleg 12 | 13 | .. toctree:: 14 | :maxdepth: 2 15 | :caption: Getting Started 16 | 17 | gettingstarted/install 18 | gettingstarted/quickstart 19 | gettingstarted/emmental 20 | gettingstarted/entity_profile 21 | gettingstarted/input_data 22 | gettingstarted/model 23 | gettingstarted/training 24 | gettingstarted/config 25 | 26 | .. toctree:: 27 | :maxdepth: 2 28 | :caption: Advanced 29 | 30 | advanced/distributed_training 31 | 32 | .. toctree:: 33 | :maxdepth: 2 34 | :caption: Developer Documentation 35 | 36 | dev/changelog 37 | dev/install 38 | dev/tests 39 | dev/codestyle 40 | 41 | .. toctree:: 42 | :maxdepth: 2 43 | :caption: API Docs 44 | 45 | apidocs/modules 46 | 47 | .. Indices and tables 48 | ================== 49 | 50 | * :ref:`genindex` 51 | * :ref:`modindex` 52 | * :ref:`search` 53 | -------------------------------------------------------------------------------- /scripts/train.zsh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/zsh 2 | 3 | export BOOTLEG_STRIP="true" 4 | export BOOTLEG_LOWER="true" 5 | export BOOTLEG_LANG_CODE="en" 6 | export CONFIG_PATH="configs/standard/train.yaml" 7 | export NUM_GPUS=1 8 | 9 | # Train Bootleg 10 | if [ $NUM_GPUS -lt 2 ]; then 11 | python3 -m bootleg.run --config_script $CONFIG_PATH 12 | else 13 | python3 -m torch.distributed.run --nproc_per_node $NUM_GPUS bootleg/run.py --config_script $CONFIG_PATH 14 | fi 15 | 16 | echo "To load Bootleg model, run..."
17 | echo "from bootleg.end2end.bootleg_annotator import BootlegAnnotator" 18 | echo "ann = BootlegAnnotator(config=$CONFIG_PATH)" 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.util import convert_path 2 | 3 | from setuptools import find_packages, setup 4 | 5 | main_ns = {} 6 | ver_path = convert_path("bootleg/_version.py") 7 | with open(ver_path) as ver_file: 8 | exec(ver_file.read(), main_ns) 9 | 10 | NAME = "bootleg" 11 | DESCRIPTION = "Bootleg NED System" 12 | URL = "https://github.com/HazyResearch/bootleg" 13 | EMAIL = "lorr1@cs.stanford.edu" 14 | AUTHOR = "Laurel Orr" 15 | VERSION = main_ns["__version__"] 16 | 17 | REQUIRED = [ 18 | "argh>=0.26.2, <1.0.0", 19 | "emmental==0.1.0", 20 | "faiss-cpu>=1.6.8, <1.7.1", 21 | "jsonlines>=2.0.0, <2.4.0", 22 | "marisa_trie>=0.7.7, <0.8", 23 | "mock>=4.0.3, <4.5.0", 24 | "nltk>=3.6.4, <4.0.0", 25 | "notebook>=6.4.1, <7.0.0", 26 | "numba>=0.50.0, <0.55.0", 27 | "numpy>=1.19.0, <=1.20.0", 28 | "pandas>=1.2.3, <1.5.0", 29 | "progressbar>=2.5.0, <2.8.0", 30 | "pydantic>=1.7.1, <1.8.0", 31 | "pyyaml>=5.1, <6.0", 32 | "rich>=10.0.0, <10.20.0", 33 | "scikit_learn>=0.24.0, <0.27.0", 34 | "scipy>=1.6.1, <1.9.0", 35 | "spacy>=3.2.0", 36 | "tagme>=0.1.3, <0.2.0", 37 | "torch>=1.7.0, <1.10.5", 38 | "tqdm>=4.27", 39 | "transformers>=4.0.0, <5.0.0", 40 | "ujson>=4.1.0, <4.2.0", 41 | "wandb>=0.10.0, <0.13.0", 42 | ] 43 | 44 | EXTRAS = { 45 | "dev": [ 46 | "black>=22.3.0", 47 | "docformatter==1.4", 48 | "flake8>=3.9.2", 49 | "isort>=5.9.3", 50 | "nbsphinx==0.8.1", 51 | "pep8_naming==0.12.1", 52 | "pre-commit>=2.14.0", 53 | "pytest==6.2.2", 54 | "python-dotenv==0.15.0", 55 | "recommonmark==0.7.1", 56 | "sphinx-rtd-theme==0.5.1", 57 | ], 58 | "embs-gpu": [ 59 | "faiss-gpu>=1.7.0, <1.7.2", 60 | ], 61 | } 62 | 63 | setup( 64 | name=NAME, 65 | version=VERSION, 66 | description=DESCRIPTION, 67 | packages=find_packages(), 68 | url=URL, 69 | install_requires=REQUIRED, 70 | extras_require=EXTRAS, 71 | ) 72 | -------------------------------------------------------------------------------- /tests/data/data_loader/end2end_dev.jsonl: -------------------------------------------------------------------------------- 1 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":0,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 2 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":1,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 3 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":2,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 4 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q2"],"sent_idx_unq":3,"sentence":"alias1 and multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[11,28]]} 5 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q2"],"sent_idx_unq":4,"sentence":"alias1 and multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[11,28]]} 6 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q2"],"sent_idx_unq":5,"sentence":"alias1 and multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[11,28]]} 7 | {"aliases":["alias1","multi word 
alias2"],"qids":["Q1","Q2"],"sent_idx_unq":6,"sentence":"alias1 and multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[11,28]]} 8 | {"aliases":[],"qids":[],"sent_idx_unq":7,"sentence":"alias1 or multi word alias2","spans":[],"gold":[],"char_spans":[]} 9 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":8,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 10 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":9,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 11 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":10,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 12 | {"aliases":["alias1","multi word alias2","multi word alias2"],"qids":["Q1","Q2","Q2"],"sent_idx_unq":11,"sentence":"alias1 and multi word alias2 and multi word alias2","spans":[[0,1],[2,5],[6,9]],"gold":[true,true,true],"char_spans":[[0,6],[11,28],[33,50]]} 13 | {"aliases":["alias1","multi word alias2","multi word alias2"],"qids":["Q1","Q2","Q2"],"sent_idx_unq":12,"sentence":"alias1 and multi word alias2 and multi word alias2","spans":[[0,1],[2,5],[6,9]],"gold":[true,true,true],"char_spans":[[0,6],[11,28],[33,50]]} 14 | {"aliases":["alias1","multi word alias2","multi word alias2"],"qids":["Q1","Q4","Q4"],"sent_idx_unq":13,"sentence":"alias1 or multi word alias2 or multi word alias2","spans":[[0,1],[2,5],[6,9]],"gold":[true,true,true],"char_spans":[[0,6],[10,27],[31,48]]} 15 | {"aliases":["alias1","multi word alias2","multi word alias2"],"qids":["Q1","Q4","Q4"],"sent_idx_unq":14,"sentence":"alias1 or multi word alias2 or multi word alias2","spans":[[0,1],[2,5],[6,9]],"gold":[true,true,true],"char_spans":[[0,6],[10,27],[31,48]]} 16 | {"aliases":["not in our list","multi word alias2","multi word alias2","alias1","multi word alias2","multi word alias2"],"qids":["Q1","Q4","Q4","Q1","Q4","Q4"],"sent_idx_unq":15,"sentence":"alias1 or multi word alias2 or multi word alias2 alias1 or multi word alias2 or multi word alias2","spans":[[0,1],[2,5],[6,9],[9,10],[11,14],[15,18]],"gold":[true,true,true,true,true,true],"char_spans":[[0,6],[10,27],[31,48],[49,55],[59,76],[80,97]]} 17 | {"sentence":"The high-speed observation deck elevators accelerate to a world-record certified speed of 1,010 metres per minute (61 km\/h) in 16 seconds, and then it slows down for arrival with subtle air pressure sensations. The door opens after 37 seconds from the 5th floor. Special features include aerodynamic car and counterweights, and cabin pressure control to help passengers adapt smoothly to pressure changes. 
The downwards journey is completed at a reduced speed of 600 meters per minute, with the doors opening at the 52nd second.","id":"572f0fb8c246551400ce488a","aliases":["highspeed","observation","deck","elevators","worldrecord","speed","air pressure","sensations","the door","counterweights","cabin","downwards","reduced speed"],"spans":[[1,2],[2,3],[3,4],[4,5],[8,9],[10,11],[30,32],[32,33],[33,35],[49,50],[51,52],[63,64],[69,71]],"qids":["Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3"],"gold":[true,true,true,true,true,true,true,true,true,true,true,true,true],"sent_idx_unq":16,"char_spans":[[4,14],[15,26],[27,31],[32,41],[58,70],[81,86],[186,198],[199,210],[211,219],[308,323],[328,333],[410,419],[446,459]]} 18 | {"aliases":[],"qids":[],"sent_idx_unq":17,"sentence":"alias1 or multi word alias2","spans":[],"gold":[],"char_spans":[]} 19 | {"aliases":["not in our list"],"qids":["Q4"],"sent_idx_unq":18,"sentence":"alias1 or multi word alias2","spans":[[0,1]],"gold":[true],"char_spans":[[0,6]]} 20 | -------------------------------------------------------------------------------- /tests/data/data_loader/end2end_train.jsonl: -------------------------------------------------------------------------------- 1 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":0,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 2 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":1,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 3 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":2,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 4 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q2"],"sent_idx_unq":3,"sentence":"alias1 and and multi word alias2","spans":[[0,1],[3,6]],"gold":[true,true],"char_spans":[[0,6],[15,32]]} 5 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q2"],"sent_idx_unq":4,"sentence":"alias1 and and multi word alias2","spans":[[0,1],[3,6]],"gold":[true,true],"char_spans":[[0,6],[15,32]]} 6 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q2"],"sent_idx_unq":5,"sentence":"alias1 and and multi word alias2","spans":[[0,1],[3,6]],"gold":[true,true],"char_spans":[[0,6],[15,32]]} 7 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q2"],"sent_idx_unq":6,"sentence":"alias1 and and multi word alias2","spans":[[0,1],[3,6]],"gold":[true,true],"char_spans":[[0,6],[15,32]]} 8 | {"aliases":[],"qids":[],"sent_idx_unq":7,"sentence":"alias1 or multi word alias2","spans":[],"gold":[],"char_spans":[]} 9 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":8,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 10 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":9,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 11 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":10,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 12 | {"aliases":["alias1","multi word alias2","multi word alias2"],"qids":["Q1","Q2","Q2"],"sent_idx_unq":11,"sentence":"alias1 and multi word alias2 and multi word 
alias2","spans":[[0,1],[2,5],[6,9]],"gold":[true,true,true],"char_spans":[[0,6],[11,28],[33,50]]} 13 | {"aliases":["alias1","multi word alias2","multi word alias2"],"qids":["Q1","Q2","Q2"],"sent_idx_unq":12,"sentence":"alias1 and multi word alias2 and multi word alias2","spans":[[0,1],[2,5],[6,9]],"gold":[true,true,true],"char_spans":[[0,6],[11,28],[33,50]]} 14 | {"aliases":["alias1","multi word alias2","multi word alias2"],"qids":["Q1","Q4","Q4"],"sent_idx_unq":13,"sentence":"alias1 or multi word alias2 or multi word alias2","spans":[[0,1],[2,5],[6,9]],"gold":[true,true,true],"char_spans":[[0,6],[10,27],[31,48]]} 15 | {"aliases":["alias1","multi word alias2","multi word alias2"],"qids":["Q1","Q4","Q4"],"sent_idx_unq":14,"sentence":"alias1 or multi word alias2 or multi word alias2","spans":[[0,1],[2,5],[6,9]],"gold":[true,true,true],"char_spans":[[0,6],[10,27],[31,48]]} 16 | {"aliases":["alias1","multi word alias2","multi word alias2","alias1","multi word alias2","multi word alias2"],"qids":["Q1","Q4","Q4","Q1","Q4","Q4"],"sent_idx_unq":15,"sentence":"alias1 or multi word alias2 or multi word alias2 alias1 or multi word alias2 or multi word alias2","spans":[[0,1],[2,5],[6,9],[9,10],[11,14],[15,18]],"gold":[true,true,true,true,true,true],"char_spans":[[0,6],[10,27],[31,48],[49,55],[59,76],[80,97]]} 17 | {"sentence":"The high-speed observation deck elevators accelerate to a world-record certified speed of 1,010 metres per minute (61 km\/h) in 16 seconds, and then it slows down for arrival with subtle air pressure sensations. The door opens after 37 seconds from the 5th floor. Special features include aerodynamic car and counterweights, and cabin pressure control to help passengers adapt smoothly to pressure changes. The downwards journey is completed at a reduced speed of 600 meters per minute, with the doors opening at the 52nd second.","id":"572f0fb8c246551400ce488a","aliases":["highspeed","observation","deck","elevators","worldrecord","speed","air pressure","sensations","the door","counterweights","cabin","downwards","reduced speed"],"spans":[[1,2],[2,3],[3,4],[4,5],[8,9],[10,11],[30,32],[32,33],[33,35],[49,50],[51,52],[63,64],[69,71]],"qids":["Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3","Q3"],"gold":[true,true,true,true,true,true,true,true,true,true,true,true,true],"sent_idx_unq":16,"char_spans":[[4,14],[15,26],[27,31],[32,41],[58,70],[81,86],[186,198],[199,210],[211,219],[308,323],[328,333],[410,419],[446,459]]} 18 | {"aliases":[],"qids":[],"sent_idx_unq":17,"sentence":"alias1 or multi word alias2","spans":[],"gold":[],"char_spans":[]} 19 | -------------------------------------------------------------------------------- /tests/data/data_loader/end2end_train_not_in_cand.jsonl: -------------------------------------------------------------------------------- 1 | {"aliases":["alias1","multi word alias2"],"qids":["Q2","Q3"],"sent_idx_unq":0,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 2 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":1,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 3 | {"aliases":["alias1","multi word alias2"],"qids":["Q1","Q4"],"sent_idx_unq":2,"sentence":"alias1 or multi word alias2","spans":[[0,1],[2,5]],"gold":[true,true],"char_spans":[[0,6],[10,27]]} 4 | -------------------------------------------------------------------------------- 
/tests/data/entity_loader/entity_data/entity_mappings/alias2id/config.json: -------------------------------------------------------------------------------- 1 | {"max_id":16} 2 | -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/entity_mappings/alias2id/itoexti.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/entity_mappings/alias2id/itoexti.npy -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/entity_mappings/alias2id/vocabulary_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/entity_mappings/alias2id/vocabulary_trie.marisa -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/entity_mappings/alias2qids/max_value.json: -------------------------------------------------------------------------------- 1 | 3 2 | -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/entity_mappings/alias2qids/record_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/entity_mappings/alias2qids/record_trie.marisa -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/entity_mappings/alias2qids/vocabulary_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/entity_mappings/alias2qids/vocabulary_trie.marisa -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/entity_mappings/config.json: -------------------------------------------------------------------------------- 1 | {"max_candidates":3,"datetime":"2021-11-12 01:17:00.266434"} 2 | -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/entity_mappings/qid2eid/config.json: -------------------------------------------------------------------------------- 1 | {"max_id":4} 2 | -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/entity_mappings/qid2eid/itoexti.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/entity_mappings/qid2eid/itoexti.npy -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/entity_mappings/qid2eid/vocabulary_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/entity_mappings/qid2eid/vocabulary_trie.marisa 
-------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/entity_mappings/qid2title.json: -------------------------------------------------------------------------------- 1 | {"Q1":"alias1","Q2":"multi alias2","Q3":"word alias3","Q4":"nonalias4"} 2 | -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/kg_mappings/config.json: -------------------------------------------------------------------------------- 1 | {"max_connections":2,"datetime":"2021-11-12 01:17:00.271087"} 2 | -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/kg_mappings/qid2relations/key_vocabulary_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/kg_mappings/qid2relations/key_vocabulary_trie.marisa -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/kg_mappings/qid2relations/max_value.json: -------------------------------------------------------------------------------- 1 | 4 2 | -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/kg_mappings/qid2relations/record_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/kg_mappings/qid2relations/record_trie.marisa -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/kg_mappings/qid2relations/value_vocabulary_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/kg_mappings/qid2relations/value_vocabulary_trie.marisa -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/type_mappings/wiki/config.json: -------------------------------------------------------------------------------- 1 | {"max_types":3} 2 | -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/type_mappings/wiki/qid2typenames/max_value.json: -------------------------------------------------------------------------------- 1 | 3 2 | -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/type_mappings/wiki/qid2typenames/record_trie.marisa: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/type_mappings/wiki/qid2typenames/record_trie.marisa -------------------------------------------------------------------------------- /tests/data/entity_loader/entity_data/type_mappings/wiki/qid2typenames/vocabulary_trie.marisa: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/tests/data/entity_loader/entity_data/type_mappings/wiki/qid2typenames/vocabulary_trie.marisa -------------------------------------------------------------------------------- /tests/run_args/test_candgen.json: -------------------------------------------------------------------------------- 1 | { 2 | "emmental": { 3 | "n_steps": 5, 4 | "seed": 1234, 5 | "lr_scheduler": "linear", 6 | "lr": 1e-4 7 | }, 8 | "run_config": { 9 | "dataloader_threads": 0, 10 | "dump_preds_accumulation_steps": 1, 11 | "eval_batch_size": 1, 12 | "dataset_threads": 1, 13 | "spawn_method": "fork", 14 | "log_level": "debug" 15 | }, 16 | "train_config": { 17 | "batch_size": 2, 18 | }, 19 | "model_config": { 20 | "hidden_size": 32, 21 | }, 22 | "data_config": { 23 | "entity_dir": "tests/data/entity_loader/entity_data", 24 | "train_in_candidates": true, 25 | "max_seq_len": 32, 26 | "data_dir": "tests/data/data_loader", 27 | "overwrite_preprocessed_data": true, 28 | "print_examples_prep": false, 29 | "use_entity_desc": false, 30 | "word_embedding": 31 | { 32 | "context_layers": 1, 33 | "entity_layers": 1, 34 | "cache_dir": "tests/data/emb_data/pretrained_bert_models" 35 | }, 36 | "train_dataset": { 37 | "file": "end2end_train.jsonl" 38 | }, 39 | "dev_dataset": { 40 | "file": "end2end_dev.jsonl" 41 | }, 42 | "test_dataset": { 43 | "file": "end2end_dev.jsonl" 44 | } 45 | } 46 | } 47 | -------------------------------------------------------------------------------- /tests/run_args/test_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "run_config": { 3 | "dataloader_threads": 0, 4 | "dataset_threads": 1, 5 | "spawn_method": "fork" 6 | }, 7 | "train_config": { 8 | "batch_size": 1, 9 | }, 10 | "data_config": { 11 | "entity_dir": "tests/data/entity_loader/entity_data", 12 | "train_in_candidates": true, 13 | "max_seq_len": 5, 14 | "data_dir": "tests/data/data_loader", 15 | "overwrite_preprocessed_data": true, 16 | "entity_type_data": { 17 | "use_entity_types": true, 18 | "type_symbols_dir": "type_mappings/wiki", 19 | "max_ent_type_len": 5, 20 | }, 21 | "word_embedding": 22 | { 23 | "context_layers": 1, 24 | "entity_layers": 1, 25 | "cache_dir": "tests/data/emb_data/pretrained_bert_models" 26 | }, 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /tests/run_args/test_end2end.json: -------------------------------------------------------------------------------- 1 | { 2 | "emmental": { 3 | "n_steps": 5, 4 | "seed": 1234, 5 | "lr_scheduler": "linear", 6 | "lr": 1e-4 7 | }, 8 | "run_config": { 9 | "dataloader_threads": 0, 10 | "dump_preds_accumulation_steps": 1, 11 | "eval_batch_size": 1, 12 | "dataset_threads": 1, 13 | "spawn_method": "fork", 14 | "log_level": "debug" 15 | }, 16 | "train_config": { 17 | "batch_size": 2, 18 | }, 19 | "model_config": { 20 | "hidden_size": 32, 21 | }, 22 | "data_config": { 23 | "entity_dir": "tests/data/entity_loader/entity_data", 24 | "train_in_candidates": true, 25 | "max_seq_len": 32, 26 | "data_dir": "tests/data/data_loader", 27 | "overwrite_preprocessed_data": true, 28 | "print_examples_prep": false, 29 | "use_entity_desc": false, 30 | "entity_type_data": { 31 | "use_entity_types": false, 32 | "max_ent_type_len": 10, 33 | "type_symbols_dir": "type_mappings/wiki", 34 | }, 35 | "entity_kg_data": { 36 | "use_entity_kg": false, 37 | "max_ent_kg_len": 10, 38 | "kg_symbols_dir": "kg_mappings" 39 | }, 40 | 
"word_embedding": 41 | { 42 | "context_layers": 1, 43 | "entity_layers": 1, 44 | "cache_dir": "tests/data/emb_data/pretrained_bert_models" 45 | }, 46 | "train_dataset": { 47 | "file": "end2end_train.jsonl" 48 | }, 49 | "dev_dataset": { 50 | "file": "end2end_dev.jsonl" 51 | }, 52 | "test_dataset": { 53 | "file": "end2end_dev.jsonl" 54 | } 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /tests/run_args/test_entity_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "run_config": { 3 | "dataloader_threads": 0, 4 | "dataset_threads": 1, 5 | }, 6 | "train_config": { 7 | "batch_size": 1, 8 | }, 9 | "data_config": { 10 | "entity_dir": "tests/data/entity_loader/entity_data", 11 | "train_in_candidates": true, 12 | "max_seq_len": 5, 13 | "data_dir": "tests/data/data_loader", 14 | "overwrite_preprocessed_data": true, 15 | "use_entity_desc": false, 16 | "entity_type_data": { 17 | "use_entity_types": true, 18 | "type_symbols_dir": "type_mappings/wiki", 19 | "max_ent_type_len": 5, 20 | }, 21 | "entity_kg_data": { 22 | "use_entity_kg": false, 23 | "max_ent_kg_len": 10, 24 | "kg_symbols_dir": "kg_mappings" 25 | }, 26 | "word_embedding": 27 | { 28 | "context_layers": 1, 29 | "entity_layers": 1, 30 | "cache_dir": "tests/data/emb_data/pretrained_bert_models" 31 | }, 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /tests/test_cand_gen/test_eval.py: -------------------------------------------------------------------------------- 1 | """Test entity embedding generation.""" 2 | import os 3 | import shutil 4 | import unittest 5 | 6 | import emmental 7 | import torch 8 | 9 | import cand_gen.eval as eval 10 | import cand_gen.train as train 11 | from bootleg.utils import utils 12 | from cand_gen.utils.parser import parser_utils 13 | 14 | 15 | class TestGenEntities(unittest.TestCase): 16 | """Test entity generation.""" 17 | 18 | def setUp(self) -> None: 19 | """Set up.""" 20 | self.args = parser_utils.parse_boot_and_emm_args( 21 | "tests/run_args/test_candgen.json" 22 | ) 23 | # This _MUST_ get passed the args so it gets a random seed set 24 | emmental.init(log_dir="tests/temp_log", config=self.args) 25 | if not os.path.exists(emmental.Meta.log_path): 26 | os.makedirs(emmental.Meta.log_path) 27 | 28 | def tearDown(self) -> None: 29 | """Tear down.""" 30 | dir = os.path.join( 31 | self.args.data_config.data_dir, self.args.data_config.data_prep_dir 32 | ) 33 | if utils.exists_dir(dir): 34 | shutil.rmtree(dir, ignore_errors=True) 35 | dir = os.path.join( 36 | self.args.data_config.entity_dir, self.args.data_config.entity_prep_dir 37 | ) 38 | if utils.exists_dir(dir): 39 | shutil.rmtree(dir, ignore_errors=True) 40 | dir = os.path.join("tests/temp_log") 41 | if os.path.exists(dir): 42 | shutil.rmtree(dir, ignore_errors=True) 43 | 44 | def test_end2end(self): 45 | """Test end2end entity generation.""" 46 | # For the collate and dataloaders to play nicely, the spawn must be fork (this is set in run.py) 47 | torch.multiprocessing.set_start_method("fork", force=True) 48 | 49 | # Train and save model 50 | train.run_model(config=self.args) 51 | 52 | self.args["model_config"][ 53 | "model_path" 54 | ] = f"{emmental.Meta.log_path}/last_model.pth" 55 | emmental.Meta.config["model_config"][ 56 | "model_path" 57 | ] = f"{emmental.Meta.log_path}/last_model.pth" 58 | 59 | candidates_file, metrics_file = eval.run_model(config=self.args) 60 | assert os.path.exists(candidates_file) 61 | 
assert os.path.exists(candidates_file) 62 | num_sents = len([_ for _ in open(candidates_file)]) 63 | assert num_sents == 17 64 | 65 | 66 | if __name__ == "__main__": 67 | unittest.main() 68 | -------------------------------------------------------------------------------- /tests/test_end_to_end/test_end_to_end.py: -------------------------------------------------------------------------------- 1 | """End2end test.""" 2 | import os 3 | import shutil 4 | import unittest 5 | 6 | import emmental 7 | import ujson 8 | 9 | from bootleg.run import run_model 10 | from bootleg.utils import utils 11 | from bootleg.utils.parser import parser_utils 12 | 13 | 14 | class TestEnd2End(unittest.TestCase): 15 | """Test end to end.""" 16 | 17 | def setUp(self) -> None: 18 | """Set up.""" 19 | self.args = parser_utils.parse_boot_and_emm_args( 20 | "tests/run_args/test_end2end.json" 21 | ) 22 | # This _MUST_ get passed the args so it gets a random seed set 23 | emmental.init(log_dir="tests/temp_log", config=self.args) 24 | if not os.path.exists(emmental.Meta.log_path): 25 | os.makedirs(emmental.Meta.log_path) 26 | 27 | def tearDown(self) -> None: 28 | """Tear down.""" 29 | dir = os.path.join( 30 | self.args.data_config.data_dir, self.args.data_config.data_prep_dir 31 | ) 32 | if utils.exists_dir(dir): 33 | shutil.rmtree(dir, ignore_errors=True) 34 | dir = os.path.join( 35 | self.args.data_config.entity_dir, self.args.data_config.entity_prep_dir 36 | ) 37 | if utils.exists_dir(dir): 38 | shutil.rmtree(dir, ignore_errors=True) 39 | dir = os.path.join("tests/temp_log") 40 | if os.path.exists(dir): 41 | shutil.rmtree(dir, ignore_errors=True) 42 | 43 | def test_end2end(self): 44 | """End2end base test.""" 45 | # Just setting this for testing pipelines 46 | scores = run_model(mode="train", config=self.args) 47 | assert type(scores) is dict 48 | assert len(scores) > 0 49 | assert scores["model/all/dev/loss"] < 1.1 50 | 51 | self.args["model_config"][ 52 | "model_path" 53 | ] = f"{emmental.Meta.log_path}/last_model.pth" 54 | emmental.Meta.config["model_config"][ 55 | "model_path" 56 | ] = f"{emmental.Meta.log_path}/last_model.pth" 57 | 58 | result_file = run_model(mode="dump_preds", config=self.args) 59 | assert os.path.exists(result_file) 60 | results = [ujson.loads(li) for li in open(result_file)] 61 | assert 19 == len(results) # 18 total sentences 62 | assert len([f for li in results for f in li["entity_ids"]]) == 52 63 | 64 | # Doubling up a test here to also test accumulation steps 65 | def test_end2end_accstep(self): 66 | """Test end2end with accumulation steps.""" 67 | # Just setting this for testing pipelines 68 | self.args.data_config.dump_preds_accumulation_steps = 2 69 | self.args.run_config.dataset_threads = 2 70 | scores = run_model(mode="train", config=self.args) 71 | assert type(scores) is dict 72 | assert len(scores) > 0 73 | assert scores["model/all/dev/loss"] < 1.1 74 | 75 | self.args["model_config"][ 76 | "model_path" 77 | ] = f"{emmental.Meta.log_path}/last_model.pth" 78 | emmental.Meta.config["model_config"][ 79 | "model_path" 80 | ] = f"{emmental.Meta.log_path}/last_model.pth" 81 | 82 | result_file = run_model(mode="dump_preds", config=self.args) 83 | assert os.path.exists(result_file) 84 | results = [ujson.loads(li) for li in open(result_file)] 85 | assert 19 == len(results) # 18 total sentences 86 | assert len([f for li in results for f in li["entity_ids"]]) == 52 87 | 88 | # Doubling up a test here to also test greater than 1 eval batch size 89 | def test_end2end_evalbatch(self): 90 | """Test 
end2end with eval batch size.""" 91 | self.args.data_config.dump_preds_accumulation_steps = 2 92 | self.args.run_config.dataset_threads = 2 93 | self.args.run_config.eval_batch_size = 2 94 | 95 | scores = run_model(mode="train", config=self.args) 96 | assert type(scores) is dict 97 | assert len(scores) > 0 98 | assert scores["model/all/dev/loss"] < 1.1 99 | 100 | self.args["model_config"][ 101 | "model_path" 102 | ] = f"{emmental.Meta.log_path}/last_model.pth" 103 | emmental.Meta.config["model_config"][ 104 | "model_path" 105 | ] = f"{emmental.Meta.log_path}/last_model.pth" 106 | 107 | result_file = run_model(mode="dump_preds", config=self.args) 108 | assert os.path.exists(result_file) 109 | results = [ujson.loads(li) for li in open(result_file)] 110 | assert 19 == len(results) # 18 total sentences 111 | assert len([f for li in results for f in li["entity_ids"]]) == 52 112 | 113 | shutil.rmtree("tests/temp", ignore_errors=True) 114 | 115 | def test_end2end_bert_long_context(self): 116 | """Test end2end with longer sentence context.""" 117 | self.args.data_config.max_seq_len = 256 118 | self.args.run_config.dump_preds_num_data_splits = 4 119 | scores = run_model(mode="train", config=self.args) 120 | assert type(scores) is dict 121 | assert len(scores) > 0 122 | assert scores["model/all/dev/loss"] < 1.1 123 | 124 | self.args["model_config"][ 125 | "model_path" 126 | ] = f"{emmental.Meta.log_path}/last_model.pth" 127 | emmental.Meta.config["model_config"][ 128 | "model_path" 129 | ] = f"{emmental.Meta.log_path}/last_model.pth" 130 | 131 | result_file = run_model(mode="dump_preds", config=self.args) 132 | assert os.path.exists(result_file) 133 | results = [ujson.loads(li) for li in open(result_file)] 134 | assert 19 == len(results) # 18 total sentences 135 | assert len([f for li in results for f in li["entity_ids"]]) == 52 136 | 137 | shutil.rmtree("tests/temp", ignore_errors=True) 138 | 139 | def test_end2end_train_in_cands_false(self): 140 | """End2end base test.""" 141 | # Just setting this for testing pipelines 142 | self.args.data_config.train_in_candidates = False 143 | self.args.data_config.train_dataset.file = "end2end_train_not_in_cand.jsonl" 144 | scores = run_model(mode="train", config=self.args) 145 | assert type(scores) is dict 146 | assert len(scores) > 0 147 | assert scores["model/all/dev/loss"] < 1.5 148 | 149 | self.args["model_config"][ 150 | "model_path" 151 | ] = f"{emmental.Meta.log_path}/last_model.pth" 152 | emmental.Meta.config["model_config"][ 153 | "model_path" 154 | ] = f"{emmental.Meta.log_path}/last_model.pth" 155 | 156 | result_file = run_model(mode="dump_preds", config=self.args) 157 | assert os.path.exists(result_file) 158 | results = [ujson.loads(li) for li in open(result_file)] 159 | assert 19 == len(results) # 18 total sentences 160 | assert len([f for li in results for f in li["entity_ids"]]) == 52 161 | 162 | 163 | if __name__ == "__main__": 164 | unittest.main() 165 | -------------------------------------------------------------------------------- /tests/test_end_to_end/test_gen_entities.py: -------------------------------------------------------------------------------- 1 | """Test generate entities.""" 2 | import os 3 | import shutil 4 | import unittest 5 | 6 | import emmental 7 | import numpy as np 8 | import torch 9 | import ujson 10 | 11 | import bootleg.extract_all_entities as extract_all_entities 12 | import bootleg.run as run 13 | from bootleg.utils import utils 14 | from bootleg.utils.parser import parser_utils 15 | 16 | 17 | class 
TestGenEntities(unittest.TestCase): 18 | """Test generate entites.""" 19 | 20 | def setUp(self) -> None: 21 | """Set up.""" 22 | self.args = parser_utils.parse_boot_and_emm_args( 23 | "tests/run_args/test_end2end.json" 24 | ) 25 | # This _MUST_ get passed the args so it gets a random seed set 26 | emmental.init(log_dir="tests/temp_log", config=self.args) 27 | if not os.path.exists(emmental.Meta.log_path): 28 | os.makedirs(emmental.Meta.log_path) 29 | 30 | def tearDown(self) -> None: 31 | """Tear down.""" 32 | dir = os.path.join( 33 | self.args.data_config.data_dir, self.args.data_config.data_prep_dir 34 | ) 35 | if utils.exists_dir(dir): 36 | shutil.rmtree(dir, ignore_errors=True) 37 | dir = os.path.join( 38 | self.args.data_config.entity_dir, self.args.data_config.entity_prep_dir 39 | ) 40 | if utils.exists_dir(dir): 41 | shutil.rmtree(dir, ignore_errors=True) 42 | dir = os.path.join("tests/temp_log") 43 | if os.path.exists(dir): 44 | shutil.rmtree(dir, ignore_errors=True) 45 | 46 | def test_end2end(self): 47 | """Test end to end.""" 48 | # For the collate and dataloaders to play nicely, the spawn must be fork (this is set in run.py) 49 | torch.multiprocessing.set_start_method("fork", force=True) 50 | 51 | # Train and save model 52 | run.run_model(mode="train", config=self.args) 53 | 54 | self.args["model_config"][ 55 | "model_path" 56 | ] = f"{emmental.Meta.log_path}/last_model.pth" 57 | emmental.Meta.config["model_config"][ 58 | "model_path" 59 | ] = f"{emmental.Meta.log_path}/last_model.pth" 60 | 61 | out_emb_file = extract_all_entities.run_model(config=self.args) 62 | assert os.path.exists(out_emb_file) 63 | embs = np.load(out_emb_file) 64 | assert list(embs.shape) == [6, 32] 65 | 66 | final_result_file = run.run_model( 67 | mode="dump_preds", config=self.args, entity_emb_file=out_emb_file 68 | ) 69 | 70 | lines = [ujson.loads(ln) for ln in open(final_result_file)] 71 | 72 | final_result_file = run.run_model( 73 | mode="dump_preds", config=self.args, entity_emb_file=None 74 | ) 75 | lines_no_emb_file = [ujson.loads(ln) for ln in open(final_result_file)] 76 | assert len(lines) == len(lines_no_emb_file) 77 | 78 | 79 | if __name__ == "__main__": 80 | unittest.main() 81 | -------------------------------------------------------------------------------- /tests/test_end_to_end/test_mention_extraction.py: -------------------------------------------------------------------------------- 1 | """Test mention extraction.""" 2 | import os 3 | import tempfile 4 | import unittest 5 | from pathlib import Path 6 | 7 | import ujson 8 | 9 | from bootleg.symbols.entity_symbols import EntitySymbols 10 | 11 | 12 | class MentionExtractionTest(unittest.TestCase): 13 | """Mention extraction test.""" 14 | 15 | def setUp(self) -> None: 16 | """Set up.""" 17 | self.test_dir = tempfile.TemporaryDirectory() 18 | 19 | def tearDown(self) -> None: 20 | """Tear down.""" 21 | self.test_dir.cleanup() 22 | 23 | def write_data(self, file, data): 24 | """Write data.""" 25 | Path(file).parent.mkdir(parents=True, exist_ok=True) 26 | with open(file, "w") as out_f: 27 | for line in data: 28 | out_f.write(ujson.dumps(line) + "\n") 29 | 30 | def test_mention_extraction(self): 31 | """Test that mention extraction runs without crashing.""" 32 | in_file = Path(self.test_dir.name) / "train.jsonl" 33 | out_file = Path(self.test_dir.name) / "train_out.jsonl" 34 | entity_db = Path(self.test_dir.name) / "entity_db" / "entity_mappings" 35 | 36 | alias2qids = { 37 | "happy": [["Q1", 1.0], ["Q2", 1.0], ["Q3", 1.0]], 38 | "cow": [["Q4", 
1.0], ["Q5", 1.0], ["Q6", 1.0]], 39 | "batman": [["Q7", 1.0], ["Q8", 1.0]], 40 | } 41 | 42 | qid2title = { 43 | "Q1": "aack", 44 | "Q2": "back", 45 | "Q3": "cack", 46 | "Q4": "dack", 47 | "Q5": "eack", 48 | "Q6": "fack", 49 | "Q7": "gack", 50 | "Q8": "hack", 51 | } 52 | 53 | mock_entity_db = EntitySymbols(alias2qids, qid2title) 54 | 55 | mock_entity_db.save(entity_db) 56 | 57 | data = [ 58 | { 59 | "sentence": "happy cow batman", 60 | } 61 | ] * 100 62 | 63 | self.write_data(in_file, data) 64 | os.system( 65 | f"python3 bootleg/end2end/extract_mentions.py " 66 | f"--in_file {str(in_file)} " 67 | f"--out_file {str(out_file)} " 68 | f"--entity_db {str(entity_db)} " 69 | f"--num_workers 1 " 70 | f"--num_chunks 10" 71 | ) 72 | 73 | assert out_file.exists() 74 | out_data = [ln for ln in open(out_file)] 75 | assert len(out_data) == 100 76 | 77 | os.system( 78 | f"python3 bootleg/end2end/extract_mentions.py " 79 | f"--in_file {str(in_file)} " 80 | f"--out_file {str(out_file)} " 81 | f"--entity_db {str(entity_db)} " 82 | f"--num_workers 2 " 83 | f"--num_chunks 10" 84 | ) 85 | 86 | assert out_file.exists() 87 | out_data = [ln for ln in open(out_file)] 88 | assert len(out_data) == 100 89 | -------------------------------------------------------------------------------- /tests/test_scorer/test_scorer.py: -------------------------------------------------------------------------------- 1 | """Test scorer.""" 2 | import unittest 3 | 4 | import numpy as np 5 | 6 | from bootleg.scorer import BootlegSlicedScorer 7 | 8 | 9 | class BootlegMockScorer(BootlegSlicedScorer): 10 | """Bootleg mock scorer class.""" 11 | 12 | def __init__(self, train_in_candidates): 13 | """Mock initializer.""" 14 | self.mock_slices = { 15 | 0: {"all": [1], "slice_1": [0]}, 16 | 1: {"all": [1], "slice_1": [1]}, 17 | 2: {"all": [1], "slice_1": [0]}, 18 | 3: {"all": [1], "slice_1": [0]}, 19 | 4: {"all": [1], "slice_1": [1]}, 20 | 5: {"all": [1], "slice_1": [0]}, 21 | } 22 | self.train_in_candidates = train_in_candidates 23 | 24 | def get_slices(self, uid): 25 | """Get slices.""" 26 | return self.mock_slices[uid] 27 | 28 | 29 | class TestScorer(unittest.TestCase): 30 | """Scorer test.""" 31 | 32 | def test_bootleg_scorer(self): 33 | """Test scorer.""" 34 | # batch = 6 35 | scorer = BootlegMockScorer(train_in_candidates=True) 36 | 37 | golds = np.array([0, -2, 1, -1, 0, 3]) 38 | 39 | probs = np.array([]) 40 | 41 | preds = np.array([1, 2, 0, 1, 0, 3]) 42 | 43 | uids = np.array([0, 1, 2, 3, 4, 5]) 44 | 45 | res = scorer.bootleg_score(golds, probs, preds, uids) 46 | 47 | gold_res = {} 48 | slice_name = "all" 49 | gold_res[f"{slice_name}/total_men"] = 5 50 | gold_res[f"{slice_name}/total_notNC_men"] = 4 51 | gold_res[f"{slice_name}/acc_boot"] = 2 / 5 52 | gold_res[f"{slice_name}/acc_notNC_boot"] = 2 / 4 53 | gold_res[f"{slice_name}/acc_pop"] = 2 / 5 54 | gold_res[f"{slice_name}/acc_notNC_pop"] = 2 / 4 55 | 56 | slice_name = "slice_1" 57 | gold_res[f"{slice_name}/total_men"] = 2 58 | gold_res[f"{slice_name}/total_notNC_men"] = 1 59 | gold_res[f"{slice_name}/acc_boot"] = 1 / 2 60 | gold_res[f"{slice_name}/acc_notNC_boot"] = 1 / 1 61 | gold_res[f"{slice_name}/acc_pop"] = 1 / 2 62 | gold_res[f"{slice_name}/acc_notNC_pop"] = 1 / 1 63 | self.assertDictEqual(res, gold_res) 64 | 65 | def test_bootleg_scorer_notincand(self): 66 | """Test scorer non in candidate.""" 67 | # batch = 6 68 | scorer = BootlegMockScorer(train_in_candidates=False) 69 | 70 | golds = np.array([0, 3, 2, -1, 1, 4]) 71 | 72 | probs = np.array([]) 73 | 74 | preds = np.array([0, 
3, 0, 1, 2, 4]) 75 | 76 | uids = np.array([0, 1, 2, 3, 4, 5]) 77 | 78 | res = scorer.bootleg_score(golds, probs, preds, uids) 79 | 80 | gold_res = {} 81 | slice_name = "all" 82 | gold_res[f"{slice_name}/total_men"] = 5 83 | gold_res[f"{slice_name}/total_notNC_men"] = 4 84 | gold_res[f"{slice_name}/acc_boot"] = 3 / 5 85 | gold_res[f"{slice_name}/acc_notNC_boot"] = 2 / 4 86 | gold_res[f"{slice_name}/acc_pop"] = 1 / 5 87 | gold_res[f"{slice_name}/acc_notNC_pop"] = 1 / 4 88 | 89 | slice_name = "slice_1" 90 | gold_res[f"{slice_name}/total_men"] = 2 91 | gold_res[f"{slice_name}/total_notNC_men"] = 2 92 | gold_res[f"{slice_name}/acc_boot"] = 1 / 2 93 | gold_res[f"{slice_name}/acc_notNC_boot"] = 1 / 2 94 | gold_res[f"{slice_name}/acc_pop"] = 1 / 2 95 | gold_res[f"{slice_name}/acc_notNC_pop"] = 1 / 2 96 | self.assertDictEqual(res, gold_res) 97 | 98 | 99 | if __name__ == "__main__": 100 | unittest.main() 101 | -------------------------------------------------------------------------------- /tests/test_utils/test_preprocessing.py: -------------------------------------------------------------------------------- 1 | """Test preprocessing utils.""" 2 | import os 3 | import tempfile 4 | import unittest 5 | from pathlib import Path 6 | 7 | import ujson 8 | 9 | from bootleg.symbols.entity_symbols import EntitySymbols 10 | 11 | 12 | class PreprocessingUtils(unittest.TestCase): 13 | """Preprocessing utils test.""" 14 | 15 | def setUp(self) -> None: 16 | """Set up.""" 17 | self.test_dir = tempfile.TemporaryDirectory() 18 | 19 | def tearDown(self) -> None: 20 | """Tear down.""" 21 | self.test_dir.cleanup() 22 | 23 | def write_data(self, file, data): 24 | """Write data.""" 25 | Path(file).parent.mkdir(parents=True, exist_ok=True) 26 | with open(file, "w") as out_f: 27 | for line in data: 28 | out_f.write(ujson.dumps(line) + "\n") 29 | 30 | def test_get_train_qid_counts(self): 31 | """Test get train qid counts.""" 32 | in_file = Path(self.test_dir.name) / "train.jsonl" 33 | out_file = Path(self.test_dir.name) / "train_counts_out.json" 34 | 35 | data = [{"qids": [f"Q{i}" for i in range(5)]}] * 100 36 | 37 | self.write_data(in_file, data) 38 | 39 | os.system( 40 | f"python3 bootleg/utils/preprocessing/get_train_qid_counts.py " 41 | f"--train_file {in_file} " 42 | f"--out_file {out_file}" 43 | ) 44 | 45 | res = ujson.load(open(out_file, "r")) 46 | 47 | assert len(res) == 5 48 | for k in res: 49 | assert res[k] == 100 50 | 51 | def test_compute_statistics(self): 52 | """Test compute statistics.""" 53 | in_file = Path(self.test_dir.name) / "train.jsonl" 54 | entity_db = Path(self.test_dir.name) / "entity_db" / "entity_mappings" 55 | 56 | alias2qids = { 57 | "happy": [["Q1", 1.0], ["Q2", 1.0], ["Q3", 1.0]], 58 | "cow": [["Q4", 1.0], ["Q5", 1.0], ["Q6", 1.0]], 59 | "batman": [["Q7", 1.0], ["Q8", 1.0]], 60 | } 61 | 62 | qid2title = { 63 | "Q1": "aack", 64 | "Q2": "back", 65 | "Q3": "cack", 66 | "Q4": "dack", 67 | "Q5": "eack", 68 | "Q6": "fack", 69 | "Q7": "gack", 70 | "Q8": "hack", 71 | } 72 | 73 | mock_entity_db = EntitySymbols(alias2qids, qid2title) 74 | 75 | mock_entity_db.save(entity_db) 76 | 77 | data = [ 78 | { 79 | "qids": ["Q1", "Q4", "Q7"], 80 | "unswap_aliases": ["happy", "cow", "batman"], 81 | "sentence": "happy cow batman", 82 | } 83 | ] * 100 84 | 85 | self.write_data(in_file, data) 86 | os.system( 87 | f"python3 bootleg/utils/preprocessing/compute_statistics.py " 88 | f"--data_dir {self.test_dir.name} " 89 | f"--save_dir {self.test_dir.name}" 90 | ) 91 | 92 | out_dir = Path(self.test_dir.name) / "stats" 
93 | assert out_dir.exists() 94 | alias_cnts = ujson.load(open(out_dir / "alias_counts.json")) 95 | assert len(alias_cnts) == 3 96 | assert all(v == 100 for v in alias_cnts.values()) 97 | 98 | def test_sample_eval_data(self): 99 | """Test sample eval data.""" 100 | in_file = Path(self.test_dir.name) / "train.jsonl" 101 | data = [ 102 | { 103 | "qids": ["Q1", "Q4", "Q7"], 104 | "sent_idx_unq": i, 105 | "aliases": ["happy", "cow", "batman"], 106 | "gold": [True, True, False], 107 | "slices": {"slice_1": {"0": 1.0, "1": 1.0, "2": 1.0}}, 108 | "sentence": "happy cow batman", 109 | } 110 | for i in range(100) 111 | ] 112 | self.write_data(in_file, data) 113 | 114 | os.system( 115 | f"python3 bootleg/utils/preprocessing/sample_eval_data.py " 116 | f"--data_dir {self.test_dir.name} " 117 | f"--slice slice_1 --file train.jsonl --out_file_name train_out.jsonl --min_sample_size 10" 118 | ) 119 | 120 | out_file = Path(self.test_dir.name) / "train_out.jsonl" 121 | assert out_file.exists() 122 | alias_out = [ln for ln in open(out_file)] 123 | assert len(alias_out) == 10 124 | -------------------------------------------------------------------------------- /tests/test_utils/test_util_classes.py: -------------------------------------------------------------------------------- 1 | """Test class utils.""" 2 | import tempfile 3 | import unittest 4 | 5 | from bootleg.end2end.annotator_utils import DownloadProgressBar 6 | from bootleg.utils.classes.nested_vocab_tries import ( 7 | ThreeLayerVocabularyTrie, 8 | TwoLayerVocabularyScoreTrie, 9 | VocabularyTrie, 10 | ) 11 | 12 | 13 | class UtilClasses(unittest.TestCase): 14 | """Class util test.""" 15 | 16 | def test_vocab_trie(self): 17 | """Test vocab trie.""" 18 | input_dict = {"a": 2, "b": 3, "c": -1} 19 | tri = VocabularyTrie(input_dict=input_dict) 20 | 21 | self.assertDictEqual(tri.to_dict(), input_dict) 22 | self.assertEqual(tri["b"], 3) 23 | self.assertEqual(tri["c"], -1) 24 | self.assertEqual(tri.get_key(-1), "c") 25 | self.assertEqual(tri.get_key(2), "a") 26 | self.assertTrue(tri.is_key_in_trie("b")) 27 | self.assertFalse(tri.is_key_in_trie("f")) 28 | self.assertTrue("b" in tri) 29 | self.assertTrue("f" not in tri) 30 | self.assertTrue(tri.is_value_in_trie(-1)) 31 | self.assertFalse(tri.is_value_in_trie(6)) 32 | self.assertEqual(tri.get_max_id(), 3) 33 | self.assertEqual(len(tri), 3) 34 | 35 | save_path = tempfile.TemporaryDirectory() 36 | tri.dump(save_path.name) 37 | tri2 = VocabularyTrie(load_dir=save_path.name) 38 | self.assertDictEqual(tri.to_dict(), input_dict) 39 | self.assertEqual(tri2["b"], 3) 40 | self.assertEqual(tri2["c"], -1) 41 | self.assertEqual(tri2.get_key(-1), "c") 42 | self.assertEqual(tri2.get_key(2), "a") 43 | self.assertTrue(tri2.is_key_in_trie("b")) 44 | self.assertFalse(tri2.is_key_in_trie("f")) 45 | self.assertTrue("b" in tri2) 46 | self.assertTrue("f" not in tri2) 47 | self.assertTrue(tri2.is_value_in_trie(-1)) 48 | self.assertFalse(tri2.is_value_in_trie(6)) 49 | self.assertEqual(tri2.get_max_id(), 3) 50 | self.assertEqual(len(tri2), 3) 51 | 52 | save_path.cleanup() 53 | 54 | def test_paired_vocab_trie(self): 55 | """Test paired vocab trie.""" 56 | for with_scores in [True, False]: 57 | raw_input_dict = {"a": ["1", "4", "5"], "b": ["5", "2"], "c": []} 58 | vocabulary = {"1": 1, "2": 2, "4": 3, "5": 4} 59 | input_dict = {} 60 | score = 1.0 if with_scores else 0.0 61 | for k, lst in list(raw_input_dict.items()): 62 | input_dict[k] = [[it, score] for it in lst] 63 | 64 | if with_scores: 65 | tri = TwoLayerVocabularyScoreTrie( 66 | 
input_dict=input_dict, vocabulary=vocabulary, max_value=3 67 | ) 68 | else: 69 | tri = TwoLayerVocabularyScoreTrie( 70 | input_dict=raw_input_dict, vocabulary=vocabulary, max_value=3 71 | ) 72 | 73 | self.assertDictEqual(tri.to_dict(keep_score=True), input_dict) 74 | self.assertDictEqual(tri.to_dict(keep_score=False), raw_input_dict) 75 | self.assertEqual(tri.get_value("b"), [["5", score], ["2", score]]) 76 | self.assertTrue(tri.is_key_in_trie("b")) 77 | self.assertFalse(tri.is_key_in_trie("f")) 78 | self.assertSetEqual(set(input_dict.keys()), set(tri.keys())) 79 | self.assertSetEqual(set(vocabulary.keys()), set(tri.vocab_keys())) 80 | 81 | save_path = tempfile.TemporaryDirectory() 82 | tri.dump(save_path.name) 83 | tri2 = TwoLayerVocabularyScoreTrie(load_dir=save_path.name) 84 | 85 | self.assertDictEqual(tri2.to_dict(keep_score=True), input_dict) 86 | self.assertDictEqual(tri2.to_dict(keep_score=False), raw_input_dict) 87 | self.assertEqual(tri2.get_value("b"), [["5", score], ["2", score]]) 88 | self.assertTrue(tri2.is_key_in_trie("b")) 89 | self.assertFalse(tri2.is_key_in_trie("f")) 90 | self.assertSetEqual(set(input_dict.keys()), set(tri2.keys())) 91 | self.assertSetEqual(set(vocabulary.keys()), set(tri2.vocab_keys())) 92 | 93 | save_path.cleanup() 94 | 95 | def test_dict_vocab_trie(self): 96 | """Test paired vocab trie.""" 97 | raw_input_dict = { 98 | "q1": {"a": ["1", "4", "5"], "b": ["3", "5"]}, 99 | "q2": {"b": ["5", "2"]}, 100 | } 101 | key_vocabulary = {"a": 1, "b": 2, "c": 3} 102 | value_vocabulary = {"1": 1, "2": 2, "3": 3, "4": 4, "5": 5} 103 | 104 | tri = ThreeLayerVocabularyTrie( 105 | input_dict=raw_input_dict, 106 | key_vocabulary=key_vocabulary, 107 | value_vocabulary=value_vocabulary, 108 | max_value=6, 109 | ) 110 | 111 | self.assertDictEqual(tri.to_dict(), raw_input_dict) 112 | self.assertDictEqual(tri.get_value("q1"), raw_input_dict["q1"]) 113 | self.assertTrue(tri.is_key_in_trie("q2")) 114 | self.assertFalse(tri.is_key_in_trie("q3")) 115 | self.assertSetEqual(set(raw_input_dict.keys()), set(tri.keys())) 116 | self.assertSetEqual(set(key_vocabulary.keys()), set(tri.key_vocab_keys())) 117 | self.assertSetEqual(set(value_vocabulary.keys()), set(tri.value_vocab_keys())) 118 | 119 | save_path = tempfile.TemporaryDirectory() 120 | tri.dump(save_path.name) 121 | tri2 = ThreeLayerVocabularyTrie(load_dir=save_path.name) 122 | 123 | self.assertDictEqual(tri2.to_dict(), raw_input_dict) 124 | self.assertDictEqual(tri2.get_value("q1"), raw_input_dict["q1"]) 125 | self.assertTrue(tri2.is_key_in_trie("q2")) 126 | self.assertFalse(tri2.is_key_in_trie("q3")) 127 | self.assertSetEqual(set(raw_input_dict.keys()), set(tri2.keys())) 128 | self.assertSetEqual(set(key_vocabulary.keys()), set(tri2.key_vocab_keys())) 129 | self.assertSetEqual(set(value_vocabulary.keys()), set(tri2.value_vocab_keys())) 130 | 131 | save_path.cleanup() 132 | 133 | def test_download_progress_bar(self): 134 | """Test download progress bar.""" 135 | pbar = DownloadProgressBar() 136 | pbar(1, 5, 10) 137 | assert pbar.pbar is not None 138 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | # Tutorials 2 | We provide several tutorials to help users get familiar with Bootleg. 3 | 4 | ## Introduction 5 | ### End to End 6 | In this [tutorial](end2end_ned_tutorial.ipynb), learn how to use Bootleg for end-to-end inference. 
We start from text data and show how to detect mentions and then link them to entities. We also show how to use Bootleg for "on-the-fly" disambiguation of individual sentences. 7 | 8 | ### On the Fly 9 | In this [tutorial](annotation-on-the-fly.ipynb), we show how to use Bootleg for "on-the-fly" disambiguation of individual sentences. A minimal code sketch also appears at the end of this README. 10 | 11 | ### Entity Profiles 12 | In this [tutorial](entity_profile_tutorial.ipynb), learn how to modify the entity database associated with a Bootleg model. We start from the downloaded entity profile data and show how to add/remove entities and change type and relation mappings. We then show how to fit an existing model to this profile and load it into a new annotator (or use it on your own data!). 13 | 14 | ## Using Bootleg Representations 15 | ### Embeddings Extraction 16 | In this [tutorial](entity_embedding_tutorial.ipynb), we show how to take a pretrained Bootleg model and generate entity representations. The next tutorial shows how to use them in a downstream model. 17 | 18 | ### Bootleg-Enhanced TACRED 19 | In this [tutorial](downstream_tutorial/), we show how to integrate Bootleg embeddings into downstream LSTM and SpanBERT models. 20 | 21 | ## Training 22 | ### Basic Training 23 | In this [tutorial](https://bootleg.readthedocs.io/en/latest/gettingstarted/training.html), learn how to train a Bootleg model on a small dataset. This covers input data formatting, data preprocessing, and training. 24 | 25 | ### Advanced Training 26 | In this [tutorial](https://bootleg.readthedocs.io/en/latest/advanced/distributed_training.html), learn how to use distributed training to train a Bootleg model on the full English Wikipedia dump (over 50 million sentences!). You will need access to GPUs to train this model.
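## Quick Annotator Sketch

The snippet below is a minimal sketch of the "on-the-fly" workflow described above, not a substitute for the notebooks. It assumes that constructing `BootlegAnnotator` with no arguments downloads (or reuses a cached copy of) the default pretrained model and entity database.

```python
from bootleg.end2end.bootleg_annotator import BootlegAnnotator

# Assumption: with no arguments, the annotator fetches (or reuses a cached copy of)
# the default pretrained model and entity database.
ann = BootlegAnnotator()

# Disambiguate a single sentence and print the predicted entity titles.
print(ann.label_mentions("How many people are in Lincoln")["titles"])
```

For disambiguating larger amounts of data, the end-to-end notebook above walks through a more optimized batch pipeline.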
27 | -------------------------------------------------------------------------------- /tutorials/download_data.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=data 2 | 3 | mkdir -p $DATA_DIR 4 | 5 | wget https://bootleg-ned-data.s3-us-west-1.amazonaws.com/data/latest/nq.tar.gz -P $DATA_DIR 6 | wget https://bootleg-ned-data.s3-us-west-1.amazonaws.com/data/latest/entity_db.tar.gz -P $DATA_DIR 7 | 8 | tar -xzvf $DATA_DIR/nq.tar.gz -C $DATA_DIR 9 | tar -xzvf $DATA_DIR/entity_db.tar.gz -C $DATA_DIR 10 | -------------------------------------------------------------------------------- /tutorials/download_model.sh: -------------------------------------------------------------------------------- 1 | MODEL_DIR=models 2 | MODEL=${1:-uncased} 3 | 4 | mkdir -p $MODEL_DIR 5 | 6 | wget https://bootleg-ned-data.s3-us-west-1.amazonaws.com/models/latest/bootleg_$MODEL.tar.gz -P $MODEL_DIR 7 | 8 | tar -xzvf $MODEL_DIR/bootleg_$MODEL.tar.gz -C $MODEL_DIR 9 | -------------------------------------------------------------------------------- /tutorials/download_wiki.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR=$1 2 | 3 | if [ $# -eq 0 ] 4 | then 5 | echo "Need to include data directory as 'bash download_wiki.sh <data_dir>'" 6 | exit 1 7 | fi 8 | 9 | mkdir -p $DATA_DIR 10 | 11 | wget https://bootleg-ned-data.s3-us-west-1.amazonaws.com/data/latest/wiki.tar.gz -P $DATA_DIR 12 | 13 | tar -xzvf $DATA_DIR/wiki.tar.gz -C $DATA_DIR 14 | -------------------------------------------------------------------------------- /web/images/bootleg-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/web/images/bootleg-logo.png -------------------------------------------------------------------------------- /web/images/bootleg-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/web/images/bootleg-performance.png -------------------------------------------------------------------------------- /web/images/bootleg-text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/web/images/bootleg-text.png -------------------------------------------------------------------------------- /web/images/bootleg_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/web/images/bootleg_architecture.png -------------------------------------------------------------------------------- /web/images/bootleg_dataflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/web/images/bootleg_dataflow.png -------------------------------------------------------------------------------- /web/images/full_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/bootleg/20ef63852f30c0f372901772a81b4b88a8aae0e5/web/images/full_logo.png --------------------------------------------------------------------------------