├── .circleci └── config.yml ├── .flake8 ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── stale.yml └── workflows │ └── publish-to-test-pypi.yml ├── .gitignore ├── .gitmodules ├── .pep8speaks.yml ├── .pre-commit-config.yaml ├── .pre-commit-hooks.yaml ├── CODEOWNERS ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets └── simple_api │ └── task_config_templates.json ├── codecov.yml ├── conftest.py ├── examples ├── README.md └── notebooks │ ├── jiant_Basic_Example.ipynb │ ├── jiant_EdgeProbing_Example.ipynb │ ├── jiant_MNLI_Diagnostic_Example.ipynb │ ├── jiant_Multi_Task_Example.ipynb │ ├── jiant_STILTs_Example.ipynb │ ├── jiant_XNLI_Example.ipynb │ └── simple_api_fine_tuning.ipynb ├── guides ├── README.md ├── benchmarks │ ├── glue.md │ ├── superglue.md │ └── xtreme.md ├── experiments │ ├── large_scale_experiments.md │ └── my_experiment_and_me.md ├── general │ ├── in_depth_intro.md │ ├── pipeline_scripts.png │ ├── pipeline_scripts.svg │ └── pipeline_simplified.png ├── models │ └── adding_models.md ├── projects │ └── xstilts.md ├── tasks │ ├── adding_tasks.md │ ├── supported_tasks.md │ └── task_specific.md └── tutorials │ ├── quick_start_main.md │ └── quick_start_simple.md ├── jiant ├── __init__.py ├── ext │ ├── __init__.py │ ├── allennlp.py │ └── radam.py ├── proj │ ├── __init__.py │ ├── main │ │ ├── __init__.py │ │ ├── components │ │ │ ├── __init__.py │ │ │ ├── container_setup.py │ │ │ ├── evaluate.py │ │ │ ├── outputs.py │ │ │ ├── task_sampler.py │ │ │ └── write_configs.py │ │ ├── export_model.py │ │ ├── metarunner.py │ │ ├── modeling │ │ │ ├── __init__.py │ │ │ ├── heads.py │ │ │ ├── model_setup.py │ │ │ ├── primary.py │ │ │ └── taskmodels.py │ │ ├── preprocessing.py │ │ ├── runner.py │ │ ├── runscript.py │ │ ├── scripts │ │ │ ├── __init__.py │ │ │ └── configurator.py │ │ ├── tokenize_and_cache.py │ │ └── write_task_configs.py │ └── simple │ │ ├── __init__.py │ │ └── runscript.py ├── scripts │ ├── __init__.py │ ├── benchmarks │ │ 
├── __init__.py │ │ ├── benchmark_submission_formatter.py │ │ ├── benchmarks.py │ │ └── xtreme │ │ │ ├── subscripts │ │ │ ├── a_download_model.sh │ │ │ ├── b_download_data.sh │ │ │ ├── c_tokenize_and_cache.sh │ │ │ ├── d_write_configs.sh │ │ │ ├── e_run_models.sh │ │ │ └── run_all.sh │ │ │ ├── xtreme_runconfig_writer.py │ │ │ └── xtreme_submission.py │ ├── download_data │ │ ├── __init__.py │ │ ├── constants.py │ │ ├── dl_datasets │ │ │ ├── __init__.py │ │ │ ├── files_tasks.py │ │ │ ├── hf_datasets_tasks.py │ │ │ └── xtreme.py │ │ ├── runscript.py │ │ └── utils.py │ └── preproc │ │ ├── __init__.py │ │ └── export_glue_data.py ├── shared │ ├── __init__.py │ ├── caching.py │ ├── constants.py │ ├── distributed.py │ ├── initialization.py │ ├── metarunner.py │ ├── model_resolution.py │ ├── model_setup.py │ └── runner.py ├── tasks │ ├── __init__.py │ ├── constants.py │ ├── core.py │ ├── evaluate │ │ ├── __init__.py │ │ └── core.py │ ├── lib │ │ ├── __init__.py │ │ ├── abductive_nli.py │ │ ├── acceptability_judgement │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── coord.py │ │ │ ├── definiteness.py │ │ │ ├── eos.py │ │ │ └── whwords.py │ │ ├── adversarial_nli.py │ │ ├── arc_challenge.py │ │ ├── arc_easy.py │ │ ├── arct.py │ │ ├── boolq.py │ │ ├── bucc2018.py │ │ ├── ccg.py │ │ ├── cola.py │ │ ├── commitmentbank.py │ │ ├── commonsenseqa.py │ │ ├── copa.py │ │ ├── cosmosqa.py │ │ ├── edge_probing │ │ │ ├── __init__.py │ │ │ ├── coref.py │ │ │ ├── dep.py │ │ │ ├── dpr.py │ │ │ ├── ner.py │ │ │ ├── nonterminal.py │ │ │ ├── pos.py │ │ │ ├── semeval.py │ │ │ ├── spr1.py │ │ │ ├── spr2.py │ │ │ └── srl.py │ │ ├── fever_nli.py │ │ ├── glue_diagnostics.py │ │ ├── hellaswag.py │ │ ├── mcscript.py │ │ ├── mctaco.py │ │ ├── mctest.py │ │ ├── mlm_premasked.py │ │ ├── mlm_pretokenized.py │ │ ├── mlm_simple.py │ │ ├── mlqa.py │ │ ├── mnli.py │ │ ├── mnli_mismatched.py │ │ ├── mrpc.py │ │ ├── mrqa_natural_questions.py │ │ ├── multirc.py │ │ ├── mutual.py │ │ ├── mutual_plus.py │ │ ├── 
newsqa.py │ │ ├── panx.py │ │ ├── pawsx.py │ │ ├── piqa.py │ │ ├── qamr.py │ │ ├── qasrl.py │ │ ├── qnli.py │ │ ├── qqp.py │ │ ├── quail.py │ │ ├── quoref.py │ │ ├── race.py │ │ ├── record.py │ │ ├── ropes.py │ │ ├── rte.py │ │ ├── scitail.py │ │ ├── senteval │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── bigram_shift.py │ │ │ ├── coordination_inversion.py │ │ │ ├── obj_number.py │ │ │ ├── odd_man_out.py │ │ │ ├── past_present.py │ │ │ ├── sentence_length.py │ │ │ ├── subj_number.py │ │ │ ├── top_constituents.py │ │ │ ├── tree_depth.py │ │ │ └── word_content.py │ │ ├── snli.py │ │ ├── socialiqa.py │ │ ├── squad.py │ │ ├── sst.py │ │ ├── stsb.py │ │ ├── superglue_axb.py │ │ ├── superglue_axg.py │ │ ├── swag.py │ │ ├── tatoeba.py │ │ ├── templates │ │ │ ├── __init__.py │ │ │ ├── edge_probing_single_span.py │ │ │ ├── edge_probing_two_span.py │ │ │ ├── hacky_tokenization_matching.py │ │ │ ├── mlm.py │ │ │ ├── mlm_premasked.py │ │ │ ├── multiple_choice.py │ │ │ ├── shared.py │ │ │ ├── span_prediction.py │ │ │ └── squad_style │ │ │ │ ├── __init__.py │ │ │ │ ├── core.py │ │ │ │ └── utils.py │ │ ├── tydiqa.py │ │ ├── udpos.py │ │ ├── wic.py │ │ ├── winogrande.py │ │ ├── wnli.py │ │ ├── wsc.py │ │ ├── xnli.py │ │ └── xquad.py │ ├── retrieval.py │ └── utils.py └── utils │ ├── __init__.py │ ├── config_handlers.py │ ├── data_handlers.py │ ├── display.py │ ├── path_parse.py │ ├── python │ ├── __init__.py │ ├── checks.py │ ├── datastructures.py │ ├── filesystem.py │ ├── functional.py │ ├── io.py │ ├── logic.py │ └── strings.py │ ├── retokenize.py │ ├── string_comparing.py │ ├── testing │ ├── __init__.py │ ├── tokenizer.py │ └── utils.py │ ├── tokenization_normalization.py │ ├── tokenization_utils.py │ ├── torch_utils.py │ ├── zconf │ ├── __init__.py │ └── core.py │ └── zlog.py ├── pyproject.toml ├── requirements-dev.txt ├── requirements-no-torch.txt ├── requirements.txt ├── setup.py └── tests ├── README.md ├── __init__.py ├── proj ├── main │ ├── components │ │ └── 
test_task_sampler.py │ └── test_export_model.py └── simple │ └── test_runscript.py ├── tasks └── lib │ ├── resources │ ├── data │ │ ├── mnli │ │ │ ├── mnli_test.jsonl │ │ │ ├── mnli_train.jsonl │ │ │ └── mnli_val.jsonl │ │ ├── spr1 │ │ │ ├── test.jsonl │ │ │ └── train.jsonl │ │ └── sst │ │ │ ├── test.jsonl │ │ │ ├── train.jsonl │ │ │ └── val.jsonl │ ├── mnli.json │ ├── spr1.json │ └── sst.json │ ├── templates │ └── test_hacky_tokenization_matching.py │ ├── test_mlm_premasked.py │ ├── test_mlm_pretokenized.py │ ├── test_mnli.py │ ├── test_spr1.py │ ├── test_sst.py │ └── test_wic.py ├── test_zconf ├── __init__.py ├── jsons │ ├── empty.json │ ├── simple.json │ ├── store_true.json │ └── store_true_false.json ├── test_conf_jsons.py └── test_confs.py └── utils ├── __init__.py ├── config ├── base_config.json ├── final_config.json ├── first_override_config.json └── second_override_config.json ├── python ├── __init__.py ├── test_checks.py ├── test_datastructures.py ├── test_filesystem.py ├── test_functional.py └── test_logic.py ├── test_config_handlers.py ├── test_data_handlers.py ├── test_path_parse.py ├── test_token_alignment.py ├── test_tokenization_normalization.py └── test_utils.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | orbs: 4 | codecov: codecov/codecov@1.0.5 5 | 6 | jobs: 7 | test: 8 | docker: 9 | - image: python:3.8 10 | steps: 11 | - checkout 12 | - restore_cache: 13 | key: -v3-{{ checksum "requirements-dev.txt" }} 14 | - run: 15 | name: Install dependencies 16 | command: | 17 | python3 -m venv venv 18 | source venv/bin/activate 19 | pip install -r requirements-dev.txt 20 | - save_cache: 21 | key: -v3-{{ checksum "requirements-dev.txt" }} 22 | paths: 23 | - "venv" 24 | - run: 25 | name: black formatting check 26 | command: | 27 | source venv/bin/activate 28 | black --check jiant/ 29 | - run: 30 | name: flake8 31 | command: | 32 | source venv/bin/activate 33 | flake8 
--docstring-convention google jiant/ 34 | - run: 35 | name: Unit Tests 36 | command: | 37 | source venv/bin/activate 38 | pytest 39 | - run: 40 | name: Coverage Report 41 | command: | 42 | source venv/bin/activate 43 | pytest --cov-report=xml --cov=jiant tests/ 44 | - codecov/upload: 45 | file: coverage.xml 46 | 47 | workflows: 48 | build: 49 | jobs: 50 | - test 51 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | 4 | ignore = 5 | # "these rules don't play well with black", from AllenNLP 6 | E203 # whitespace before : 7 | W503 # line break before binary operatori 8 | # these are docstring-related ignores: 9 | D100 # Missing docstring in public module 10 | D101 # Missing docstring in public class 11 | D102 # Missing docstring in public method 12 | D103 # Missing docstring in public function 13 | D104 # Missing docstring in public package 14 | D105 # Missing docstring in magic method 15 | D107 # Missing docstring in __init__ 16 | D400 # First line should end with a period 17 | D401 # First line should be in imperative mood; try rephrasing 18 | D415 # First line should end with a period, question mark, or exclamation point 19 | D205 # 1 blank line required between summary line and description 20 | 21 | exclude = 22 | examples/** 23 | tests/** 24 | jiant/ext/allennlp.py # excluded to avoid modifying code copied from AllenNLP. 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | 1. 
Tell us which version of `jiant` you're using 15 | 2. Describe the environment where you're using `jiant`, e.g., "2 P40 GPUs"
Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false 18 | -------------------------------------------------------------------------------- /.github/workflows/publish-to-test-pypi.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python 🐍 distributions 📦 to PyPI 2 | 3 | on: push 4 | 5 | jobs: 6 | build-n-publish: 7 | name: Build and publish Python 🐍 distributions 📦 to PyPI 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - uses: actions/checkout@master 12 | - name: Set up Python 3.7 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: 3.7 16 | 17 | - name: Install 18 | if: startsWith(github.ref, 'refs/tags') 19 | run: >- 20 | python3 -m 21 | pip install 22 | --user 23 | --upgrade setuptools wheel 24 | 25 | - name: Build a binary wheel and a source tarball 26 | if: startsWith(github.ref, 'refs/tags') 27 | run: >- 28 | python3 setup.py 29 | sdist 30 | bdist_wheel 31 | 32 | - name: Publish distribution 📦 to PyPI 33 | if: startsWith(github.ref, 'refs/tags') 34 | uses: pypa/gh-action-pypi-publish@master 35 | with: 36 | password: ${{ secrets.pypi_password }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | wheels/ 22 | pip-wheel-metadata/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 
| *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 94 | __pypackages__/ 95 | 96 | # Celery stuff 97 | celerybeat-schedule 98 | celerybeat.pid 99 | 100 | # SageMath parsed files 101 | *.sage.py 102 | 103 | # Environments 104 | .env 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "jiant/cove"] 2 | path = jiant/modules/cove 3 | url = https://github.com/salesforce/cove.git 4 | ignore = untracked 5 | -------------------------------------------------------------------------------- /.pep8speaks.yml: -------------------------------------------------------------------------------- 1 | scanner: 2 | diff_only: False # If False, the entire file touched by the Pull Request is scanned for errors. If True, only the diff is scanned. 3 | 4 | pycodestyle: # Same as scanner.linter value. Other option is flake8 5 | max-line-length: 100 # Default is 79 in PEP 8 6 | ignore: 7 | - E203 # Whitespace before :, not a strict PEP8 requirement and sometimes incompatible with black. 8 | - W503 # Deprecated, incompatible with black. 9 | 10 | no_blank_comment: False # If True, no comment is made on PR without any errors. 11 | 12 | message: # Customize the comment made by the bot 13 | opened: # Messages when a new PR is submitted 14 | # The keyword {name} is converted into the author's username 15 | footer: "You can repair most issues by installing [black](https://github.com/ambv/black) and running: `black -l 100 ./*`. 
If you contribute often, have a look at the 'Contributing' section of the [README](https://github.com/nyu-mll/jiant) for instructions on doing this automatically." 16 | # The messages can be written as they would over GitHub 17 | updated: # Messages when new commits are added to the PR 18 | footer: "You can repair most issues by installing [black](https://github.com/ambv/black) and running: `black -l 100 ./*`. If you contribute often, have a look at the 'Contributing' section of the [README](https://github.com/nyu-mll/jiant) for instructions on doing this automatically." 19 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/ambv/black 3 | rev: stable 4 | hooks: 5 | - id: black 6 | language_version: python3.7 7 | - repo: https://gitlab.com/pycqa/flake8 8 | rev: 3.7.9 9 | hooks: 10 | - id: flake8 11 | -------------------------------------------------------------------------------- /.pre-commit-hooks.yaml: -------------------------------------------------------------------------------- 1 | - id: black 2 | name: black 3 | description: 'Black: The uncompromising Python code formatter' 4 | entry: black 5 | language: python 6 | language_version: python3 7 | require_serial: true 8 | types: [python] 9 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # These owners will be the default owners for everything in 2 | # the repo. Unless a later match takes precedence, 3 | # @global-owner1 and @global-owner2 will be requested for 4 | # review when someone opens a pull request. 
5 | * @zphang @jeswan 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing to `jiant` 2 | 3 | Thanks for considering contributing to `jiant`! :+1: 4 | 5 | #### Guidelines for a successful PR review process: 6 | 1. Choose a descriptive PR title (“Adding SNLI Task” rather than “add task”). 7 | 2. In the PR description field provide a summary explaining the motivation for the changes, and link to any related issues. 8 | 3. PRs should address only one issue (or a few very closely related issues). 9 | 4. While your PR is a work in progress (WIP), use the [Draft PR feature](https://github.blog/2019-02-14-introducing-draft-pull-requests/) to provide visibility without requesting a review. 10 | 5. Once your PR is ready for review, in your Draft PR press “Ready for review”. 11 | 12 | #### Requirements for pull requests (PR) into `jiant`'s master branch: 13 | 1. Requirements applied by the automated build system: 14 | 1. black formatting check 15 | 2. flake8 check for style and documentation 16 | 3. pytest unit tests 17 | 2. Requirements for successful code reviews: 18 | 1. Code changes must be paired with effective tests. 19 | 2. PRs adding or modifying code must make appropriate changes to related documentation (using [google style](https://google.github.io/styleguide/pyguide.html)). 
20 | 21 | #### Setting up your local dev environment to run the validation steps applied to PRs by the build system: 22 | ``` 23 | pip install -r requirements-dev.txt 24 | pre-commit install 25 | ``` 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 New York University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/simple_api/task_config_templates.json: -------------------------------------------------------------------------------- 1 | { 2 | "boolq": { 3 | "task": "boolq", 4 | "rel_paths": { 5 | "train": "./train.jsonl", 6 | "val": "./val.jsonl", 7 | "test": "./test.jsonl" 8 | }, 9 | "name": "boolq" 10 | }, 11 | "ccg": { 12 | "task": "ccg", 13 | "rel_paths": { 14 | "train": "./ccg.train", 15 | "val": "./ccg.dev", 16 | "test": "./ccg.test", 17 | "tags_to_id": "./tags_to_id.json" 18 | }, 19 | "name": "ccg" 20 | }, 21 | "cola": { 22 | "task": "cola", 23 | "rel_paths": { 24 | "train": "./train.jsonl", 25 | "val": "./val.jsonl", 26 | "test": "./test.jsonl" 27 | }, 28 | "name": "cola" 29 | }, 30 | "cosmosqa": { 31 | "task": "cosmosqa", 32 | "rel_paths": { 33 | "train": "./train.csv", 34 | "val": "./valid.csv", 35 | "test": "./test_no_label.csv" 36 | }, 37 | "name": "cosmosqa" 38 | }, 39 | "hellaswag": { 40 | "task": "hellaswag", 41 | "rel_paths": { 42 | "train": "./hellaswag_train.jsonl", 43 | "val": "./hellaswag_val.jsonl", 44 | "test": "./hellaswag_test.jsonl" 45 | }, 46 | "name": "hellaswag" 47 | }, 48 | "mnli": { 49 | "task": "mnli", 50 | "rel_paths": { 51 | "train": "./train.jsonl", 52 | "val": "./val.jsonl", 53 | "test": "./test.jsonl" 54 | }, 55 | "name": "mnli" 56 | }, 57 | "mrpc": { 58 | "task": "mrpc", 59 | "rel_paths": { 60 | "train": "./train.jsonl", 61 | "val": "./val.jsonl", 62 | "test": "./test.jsonl" 63 | }, 64 | "name": "mrpc" 65 | }, 66 | "qnli": { 67 | "task": "qnli", 68 | "rel_paths": { 69 | "train": "./train.jsonl", 70 | "val": "./val.jsonl", 71 | "test": "./test.jsonl" 72 | }, 73 | "name": "qnli" 74 | }, 75 | "qqp": { 76 | "task": "qqp", 77 | "rel_paths": { 78 | "train": "./train.jsonl", 79 | "val": "./val.jsonl", 80 | "test": "./test.jsonl" 81 | }, 82 | "name": "qqp" 83 | }, 84 | "rte": { 85 | "task": "rte", 86 | "rel_paths": { 87 | "train": "./train.jsonl", 88 | 
"val": "./val.jsonl", 89 | "test": "./test.jsonl" 90 | }, 91 | "name": "rte" 92 | }, 93 | "squad_v1": { 94 | "task": "squad", 95 | "rel_paths": { 96 | "train": "./train-v1.1.json", 97 | "val": "./dev-v1.1.json" 98 | }, 99 | "name": "squad_v1" 100 | }, 101 | "stsb": { 102 | "task": "stsb", 103 | "rel_paths": { 104 | "train": "./train.jsonl", 105 | "val": "./val.jsonl", 106 | "test": "./test.jsonl" 107 | }, 108 | "name": "stsb" 109 | }, 110 | "wic": { 111 | "task": "wic", 112 | "rel_paths": { 113 | "train": "./train.jsonl", 114 | "val": "./val.jsonl", 115 | "test": "./test.jsonl" 116 | }, 117 | "name": "wic" 118 | }, 119 | "wnli": { 120 | "task": "wnli", 121 | "rel_paths": { 122 | "train": "./train.jsonl", 123 | "val": "./val.jsonl", 124 | "test": "./test.jsonl" 125 | }, 126 | "name": "wnli" 127 | } 128 | } 129 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: false 4 | patch: false 5 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Skipping slow tests according to command line: https://docs.pytest.org/en/latest/example/simple.html 4 | """ 5 | import pytest 6 | 7 | 8 | def pytest_addoption(parser): 9 | parser.addoption("--runslow", action="store_true", default=False, help="run slow tests") 10 | parser.addoption("--rungpu", action="store_true", default=False, help="run gpu tests") 11 | parser.addoption( 12 | "--runovernight", action="store_true", default=False, help="run overnight tests" 13 | ) 14 | 15 | 16 | def pytest_configure(config): 17 | config.addinivalue_line("markers", "slow: mark test as slow to run") 18 | config.addinivalue_line("markers", "gpu: mark test as gpu required to run") 19 | config.addinivalue_line("markers", 
"overnight: mark test as overnight to run")
4 | 5 | ## Example Notebooks 6 | 7 | ### Simple Fine-tuning 8 | 9 | * **Fine-tuning on a single task with the simple API** [[Notebook](./notebooks/simple_api_fine_tuning.ipynb)] [[Colab](https://colab.research.google.com/github/nyu-mll/jiant/blob/master/examples/notebooks/simple_api_fine_tuning.ipynb)] 10 | * **Fine-tuning on a single task with the main API** [[Notebook](./notebooks/jiant_Basic_Example.ipynb)] [[Colab](https://colab.research.google.com/github/nyu-mll/jiant/blob/master/examples/notebooks/jiant_Basic_Example.ipynb)] 11 | 12 | ### Intermediate use cases 13 | 14 | * **Multi-task Training** [[Notebook](./notebooks/jiant_Multi_Task_Example.ipynb)] [[Colab](https://colab.research.google.com/github/nyu-mll/jiant/blob/master/examples/notebooks/jiant_Multi_Task_Example.ipynb)] 15 | * Fine-tuning on multiple tasks simultaneously 16 | * **STILTs Training** [[Notebook](./notebooks/jiant_STILTs_Example.ipynb)] [[Colab](https://colab.research.google.com/github/nyu-mll/jiant/blob/master/examples/notebooks/jiant_STILTs_Example.ipynb)] 17 | * Fine-tuning on multiple tasks sequentially 18 | * **Zero-shot transfer to XNLI** [[Notebook](./notebooks/jiant_XNLI_Example.ipynb)] [[Colab](https://colab.research.google.com/github/nyu-mll/jiant/blob/master/examples/notebooks/jiant_XNLI_Example.ipynb)] 19 | * Fine-tuning on MNLI for zero-shot transfer to XNLI 20 | * **MNLI-mismatched/GLUE Diagnostic** [[Notebook](./notebooks/jiant_MNLI_Diagnostic_Example.ipynb)] [[Colab](https://colab.research.google.com/github/nyu-mll/jiant/blob/master/examples/notebooks/jiant_MNLI_Diagnostic_Example.ipynb)] 21 | * Fine-tuning on MNLI for MNLI-mismatched and GLUE Diagnostic set 22 | -------------------------------------------------------------------------------- /guides/README.md: -------------------------------------------------------------------------------- 1 | # Guides 2 | 3 | Also check out our [Examples](../examples) to see `jiant` in action. 
4 | 5 | If you don't know what to read, why not read our [In-Depth Introduction to Jiant](general/in_depth_intro.md)? 6 | 7 | Contents: 8 | 9 | * [Tutorials](#tutorials) 10 | * [General](#general) 11 | * [Benchmarks](#benchmarks) 12 | * [Experiments](#experiments) 13 | * [Tasks](#tasks) 14 | * [Papers / Projects](#papers--projects) 15 | 16 | --- 17 | 18 | ## Tutorials 19 | 20 | These are quick tutorials that demonstrate `jiant` usage. 21 | 22 | * [Quick Start Guide — Using the "Simple" CLI](tutorials/quick_start_simple.md): A simple `jiant` training run in bash, using the "Simple" CLI 23 | * [Quick Start Guide — Using the "Main" CLI](tutorials/quick_start_main.md): A simple `jiant` training run in bash, using the "Main" CLI 24 | 25 | The "Simple" API provides a single command-line script for training and evaluating models on tasks, while the "Main" API offers more flexibilty by breaking the workflow down into discrete steps (downloading the model, tokenization & caching, writing a fully specific run-configuration, and finally running the experiment). Both interfaces use the same models and task implementations uner the hood. 26 | 27 | 28 | ## General 29 | 30 | These are general guides to `jiant`'s design and components. 
Refer to these if you have questions about parts of `jiant`: 31 | 32 | * [In-Depth Introduction to Jiant](general/in_depth_intro.md): Learn about `jiant` in greater detail 33 | * [`jiant`'s models](general/in_depth_intro.md#jiants-models) 34 | * [`jiant`'s tasks](general/in_depth_intro.md#jiants-tasks) 35 | * [`Runner`s and `Metarunner`s](general/in_depth_intro.md#runners-and-metarunners) 36 | * [Step-by-step through `jiant`'s pipeline](general/in_depth_intro.md#step-by-step-through-jiants-pipeline) 37 | 38 | ## Running benchmarks 39 | 40 | These are guides to running common NLP benchmarks using `jiant`: 41 | 42 | * [GLUE Benchmark](benchmarks/glue.md): Generate GLUE Benchmark submissions 43 | * [SuperGLUE Benchmark](benchmarks/superglue.md): Generate SuperGLUE Benchmark submissions 44 | * [XTREME](benchmarks/xtreme.md): End-to-end guide for training and generating submission for the XTREME bernchmark 45 | 46 | ## Tips & Tricks for Running Experiments 47 | 48 | These are more specific guides about running experiments in `jiant`: 49 | 50 | * [My Experiment and Me](experiments/my_experiment_and_me.md): More info about a `jiant` training/eval run 51 | * [Tips for Large-scale Experiments](experiments/large_scale_experiments.md) 52 | 53 | ## Adding a Model 54 | * [Guide for adding a model to `jiant`](models/adding_models.md) 55 | 56 | ## Tasks 57 | 58 | These are notes on the tasks supported in `jiant`: 59 | 60 | * [List of supported tasks in `jiant`](tasks/supported_tasks.md) 61 | * [Task-specific notes](tasks/task_specific.md): Learn about quirks/caveats about specific tasks 62 | * [Adding Tasks](tasks/adding_tasks.md): Guide on adding a task to `jiant` 63 | 64 | ## Papers / Projects 65 | 66 | * [English Intermediate-Task Training Improves Zero-Shot Cross-Lingual Transfer Too (X-STILTs)](projects/xstilts.md) -------------------------------------------------------------------------------- /guides/benchmarks/glue.md: 
-------------------------------------------------------------------------------- 1 | # GLUE Benchmark 2 | 3 | ## Generating Submissions 4 | 5 | `jiant` supports generating submission files for [GLUE](https://gluebenchmark.com/). To generate test predictions, use the `--write_test_preds` flag in [`runscript.py`](https://github.com/jiant-dev/jiant/blob/master/jiant/proj/main/runscript.py) when running your workflow. This will generate a `test_preds.p` file in the specified output directory. To convert `test_preds.p` to the required GLUE submission format, use the following command: 6 | 7 | ```bash 8 | python benchmark_submission_formatter.py \ 9 | --benchmark GLUE \ 10 | --input_base_path $INPUT_BASE_PATH \ 11 | --output_path $OUTPUT_BASE_PATH 12 | ``` 13 | 14 | where `$INPUT_BASE_PATH` contains the task folder(s) output by [runscript.py](https://github.com/jiant-dev/jiant/blob/master/jiant/proj/main/runscript.py). Alternatively, a subset of tasks can be formatted using: 15 | 16 | ```bash 17 | python benchmark_submission_formatter.py \ 18 | --benchmark GLUE \ 19 | --tasks cola mrpc \ 20 | --input_base_path $INPUT_BASE_PATH \ 21 | --output_path $OUTPUT_BASE_PATH 22 | ``` 23 | -------------------------------------------------------------------------------- /guides/benchmarks/superglue.md: -------------------------------------------------------------------------------- 1 | # SuperGLUE Benchmark Submission Formatter 2 | 3 | `jiant` supports generating submission files for [SuperGLUE](https://super.gluebenchmark.com/). To generate test predictions, use the `--write_test_preds` flag in [`runscript.py`](https://github.com/jiant-dev/jiant/blob/master/jiant/proj/main/runscript.py) when running your workflow. This will generate a `test_preds.p` file in the specified output directory. 
To convert `test_preds.p` to the required SuperGLUE submission format, use the following command: 4 | 5 | ```bash 6 | python benchmark_submission_formatter.py \ 7 | --benchmark SUPERGLUE \ 8 | --input_base_path $INPUT_BASE_PATH \ 9 | --output_path $OUTPUT_BASE_PATH 10 | ``` 11 | 12 | where `$INPUT_BASE_PATH` contains the task folder(s) output by [runscript.py](https://github.com/nyu-mll/jiant/blob/master/jiant/proj/main/runscript.py). Alternatively, a subset of tasks can be formatted using: 13 | 14 | ```bash 15 | python benchmark_submission_formatter.py \ 16 | --benchmark SUPERGLUE \ 17 | --tasks boolq copa \ 18 | --input_base_path $INPUT_BASE_PATH \ 19 | --output_path $OUTPUT_BASE_PATH 20 | ``` 21 | -------------------------------------------------------------------------------- /guides/experiments/large_scale_experiments.md: -------------------------------------------------------------------------------- 1 | # Tips for Large-scale Experiments 2 | 3 | `jiant` was designed with large-scale transfer-learning experiments in mind. Here are some tips to manage and collect results from multiple experiments. 4 | 5 | ### Aggregated results using `path_parse` 6 | 7 | One common format for running experiments is to run something like the following on SLURM: 8 | 9 | ```bash 10 | for TASK in mnli rte squad_v1; do 11 | for MODEL in roberta-base bert-base-cased; do 12 | export TASK=${TASK} 13 | export MODEL=${MODEL} 14 | export OUTPUT_PATH=/path/to/experiments/${MODEL}/${TASK} 15 | sbatch my_run_script.sbatch 16 | done 17 | done 18 | ``` 19 | where `my_run_script.sbatch` kicks off an experiment, and where the run is saved to the output path `/path/to/experiments/${MODEL}/${TASK}`. As seen in [my_experiment_and_me.md](./my_experiment_and_me.md), the results are stored in `val_metrics.json`.
20 | 21 | A quick way to pick up the results across the range of experiments is to run code like this: 22 | 23 | ```python 24 | import pandas as pd 25 | import jiant.utils.python.io as io 26 | import jiant.utils.path_parse as path_parse 27 | 28 | matches = path_parse.match_paths("/path/to/experiments/{model}/{task}/val_metrics.json") 29 | for match in matches: 30 | match["score"] = io.read_json(match["path"])[match["task"]]["major"] 31 | del match["path"] 32 | df = pd.DataFrame(matches).set_index(["model", "task"]) 33 | ``` 34 | 35 | This returns a nice table of the results for each run across your range of experiments. 36 | -------------------------------------------------------------------------------- /guides/experiments/my_experiment_and_me.md: -------------------------------------------------------------------------------- 1 | # My Experiment and Me 2 | 3 | ### Run outputs 4 | 5 | After running an experiment, you will see your run folder populated with many files and folders. Here's a quick run-down of what they are: 6 | 7 | * `args.json`: Saved a copy of your run arguments for future reference. 8 | * `last_model.p`: Model weights at the end of training. 9 | * `last_model.metadata.json`: Contains the metadata for the last-model-weights (e.g. last step of training). 10 | * `best_model.p`: The best version of the model weights based on validation-subset evaluation. 11 | * `best_model.metadata.json`: Contains the metadata for the best-model-weights (e.g. what step of training they were from, validation scores at that point). 12 | * `checkpoint.p`: A checkpoint for the run that allows you to resume interrupted runs. Contains additional training state, such as the optimizer state, so it's at least 2x as large as model weights. 13 | * `{log-timestamp}/loss_train.zlog`: JSONL log of training loss over training steps 14 | * `{log-timestamp}/early_stopping.zlog`: JSONL log of early-stopping progress (e.g.
steps since last best model) 15 | * `{log-timestamp}/train_val.zlog`: JSONL log of validation-subset evaluation over the course of training (i.e. what's used for early stopping) 16 | * `{log-timestamp}/train_val_best.zlog`: JSONL log of validation-subset evaluation, only recording the improving runs 17 | -------------------------------------------------------------------------------- /guides/general/pipeline_scripts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/guides/general/pipeline_scripts.png -------------------------------------------------------------------------------- /guides/general/pipeline_simplified.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/guides/general/pipeline_simplified.png -------------------------------------------------------------------------------- /guides/projects/xstilts.md: -------------------------------------------------------------------------------- 1 | # English Intermediate-Task Training Improves Zero-Shot Cross-Lingual Transfer Too 2 | 3 | This guide describes how to replicate the experiments in [English Intermediate-Task Training Improves Zero-Shot Cross-Lingual Transfer Too](https://arxiv.org/abs/2005.13013). 4 | 5 | ## Overview 6 | 7 | The experiments described in the paper follow a simple transfer learning procedure: 8 | 9 | 1. Fine-tune XLM-R (large) on an intermediate task (e.g. MNLI, SQuAD), or on multiple tasks 10 | 2. Evaluate the checkpoint from (1) on the XTREME benchmark. 11 | 12 | In the second step, this entails tuning the checkpoint from (1) on the English training sets of the 7 XTREME tasks, and zero-shot evaluation on the tuned models on the same tasks in other languages. 
There are also 2 XTREME tasks where the checkpoint from (1) is evaluated directly (BuCC, Tatoeba). 13 | 14 | ## Intermediate-Task Training 15 | 16 | ### Single/Multiple Intermediate Task 17 | 18 | For tuning on a single or multiple intermediate tasks, you can use the [Quick Start guide](../tutorials/quick_start_simple.md) as a reference. You should follow the same steps for downloading the data. Be sure to use the relevant tasks and XLM-R models. For instance, the training command should look something like: 19 | 20 | ```bash 21 | python jiant/jiant/proj/simple/runscript.py \ 22 | run \ 23 | --run_name mnli_and_squad \ 24 | --exp_dir ./experiments/stilts \ 25 | --data_dir $(pwd)/tasks/data \ 26 | --model_type xlm-roberta-large \ 27 | --train_batch_size 4 \ 28 | --tasks mnli,squad_v1 29 | ``` 30 | 31 | ## XTREME Benchmark Evaluation 32 | 33 | The [XTREME benchmark guide](../benchmarks/xtreme.md) describes how to evaluate XLM-R on the XTREME benchmark, end-to-end. You can generally follow the guide, except two steps: 34 | 35 | 1. You don't need to re-download `xlm-roberta-large`, since you should have a checkpoint from the previous step. 36 | 2. In the [final step](../benchmarks/xtreme.md#trainrun-models) where you train/run the models, replace the following line 37 | ```bash 38 | --model_load_mode from_transformers \ 39 | ``` 40 | with 41 | ```bash 42 | --ZZoverrides model_load_path \ 43 | --model_load_mode partial \ 44 | --model_load_path /path/to/my/model.p \ 45 | ``` 46 | 47 | This ensures that your encoder is loaded from the model tuned on the intermediate task. 48 | 49 | ## Citation 50 | 51 | If you would like to cite our work: 52 | 53 | Jason Phang, Iacer Calixto, Phu Mon Htut, Yada Pruksachatkun, Haokun Liu, Clara Vania, Katharina Kann, and Samuel R. 
Bowman **"English Intermediate-Task Training Improves Zero-Shot Cross-Lingual Transfer Too."** *Proceedings of AACL, 2020* 54 | 55 | ``` 56 | @inproceedings{phang2020english, 57 | author = {Jason Phang and Iacer Calixto and Phu Mon Htut and Yada Pruksachatkun and Haokun Liu and Clara Vania and Katharina Kann and Samuel R. Bowman}, 58 | title = {English Intermediate-Task Training Improves Zero-Shot Cross-Lingual Transfer Too}, 59 | booktitle = {Proceedings of AACL}, 60 | year = {2020} 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /guides/tasks/task_specific.md: -------------------------------------------------------------------------------- 1 | ## Task-specific Notes 2 | 3 | ### Adversarial NLI 4 | 5 | [Adversarial NLI](https://arxiv.org/pdf/1910.14599.pdf) has 3 rounds of adversarial data creation. A1, A2 and A3 are different rounds of data creation. When downloading, you can use the task names `adversarial_nli_r1`, `adversarial_nli_r2`, `adversarial_nli_r3` to point the the different rounds. 6 | 7 | When doing training on the full ANLI dataset, which is SNLI+MNLI+A1+A2+A3, perform training in a multi-task manner with proportional sampling, and be sure to set the `task_to_taskmodel_map` to have all tasks point to the same NLI head. 8 | 9 | 10 | ### Masked Language Modeling (MLM) 11 | 12 | MLM is a generic task, implemented with the `jiant_task_name` "`mlm_simple`". In other words, it is meant to be used with any appropriately formatted file. 13 | 14 | `mlm_simple` expects input data files to be a single text file per phase, where each line corresponds to one example, and empty lines are ignored. This means that if a line corresponds to more than the `max_seq_length` of tokens during tokenization, everything past the first `max_seq_length` tokens per line will be ignored. We plan to add more complex implementations in the future. 
15 | 16 | You can structure your MLM task config file as follows: 17 | 18 | ```json 19 | { 20 | "task": "mlm_simple", 21 | "paths": { 22 | "train": "/path/to/train.txt", 23 | "val": "/path/to/val.txt" 24 | }, 25 | "name": "my_mlm_task" 26 | } 27 | ``` 28 | 29 | ### UDPOS (XTREME) 30 | 31 | UDPOS requires a specific version of `networkx` to download. You can install it via 32 | 33 | ```bash 34 | pip install networkx==1.11 35 | ``` 36 | 37 | 38 | ### PAN-X (XTREME) 39 | 40 | To preprocess PAN-X, you actually first need to download the file from: https://www.amazon.com/clouddrive/share/d3KGCRCIYwhKJF0H3eWA26hjg2ZCRhjpEQtDL70FSBN. 41 | 42 | The file should be named `AmazonPhotos.zip`, and it should be placed in `${task_data_base_path}/panx_temp/AmazonPhotos.zip` before running the download script. 43 | 44 | 45 | ### Bucc2018, Tatoeba (XTREME) 46 | 47 | The Bucc2018 and Tatoeba tasks are sentence retrieval tasks, and require the `faiss` library to run. `faiss-gpu` is recommended for speed reasons. 48 | 49 | We recommend running: 50 | 51 | ```bash 52 | conda install faiss-gpu cudatoolkit=10.1 -c pytorch 53 | ``` 54 | 55 | (Use the appropriate `cudatoolkit` version, which you can check with `nvcc --version`.) 56 | 57 | Additionally, the task-model corresponding to retrieval tasks outputs a pooled embedding from a given layer of the encoder. As such, both the layer and pooling method need to be specified in the taskmodel config. For instance, to replicate the baseline used in the XTREME benchmark, consider using: 58 | 59 | ```python 60 | { 61 | "pooler_type": "mean", 62 | "layer": 14, 63 | } 64 | ``` 65 | 66 | Also note that neither task has training sets, and Tatoeba does not have a separate test set.
-------------------------------------------------------------------------------- /jiant/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/__init__.py -------------------------------------------------------------------------------- /jiant/ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/ext/__init__.py -------------------------------------------------------------------------------- /jiant/proj/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/proj/__init__.py -------------------------------------------------------------------------------- /jiant/proj/main/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/proj/main/__init__.py -------------------------------------------------------------------------------- /jiant/proj/main/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/proj/main/components/__init__.py -------------------------------------------------------------------------------- /jiant/proj/main/components/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import torch 5 | 6 | import jiant.utils.python.io as py_io 7 | import jiant.proj.main.components.task_sampler as jiant_task_sampler 8 | 9 | 10 | def write_val_results(val_results_dict, metrics_aggregator, output_dir, verbose=True): 
def write_preds(eval_results_dict, path):
    """Serialize per-task predictions (and their example guids) to `path`.

    Args:
        eval_results_dict: mapping of task name -> results dict containing
            "preds" and an "accumulator" exposing `get_guids()`.
        path: destination file; written with `torch.save`.
    """
    preds_dict = {
        task_name: {
            "preds": results["preds"],
            "guids": results["accumulator"].get_guids(),
        }
        for task_name, results in eval_results_dict.items()
    }
    torch.save(preds_dict, path)
def write_configs(config_dict, base_path):
    """Write a full jiant task-container config to `base_path`.

    Emits one JSON file per config section, plus `full.json` (everything)
    and `zz_full.json` (an index mapping `<section>_path` to each file).
    All task-config and cache paths referenced by `config_dict` must exist.
    """
    os.makedirs(base_path, exist_ok=True)
    config_keys = [
        "task_config_path_dict",
        "task_cache_config_dict",
        "sampler_config",
        "global_train_config",
        "task_specific_configs_dict",
        "metric_aggregator_config",
    ]
    # Validate every referenced path up front, before writing anything.
    for path in config_dict["task_config_path_dict"].values():
        assert os.path.exists(path)
    for path_dict in config_dict["task_cache_config_dict"].values():
        for path in path_dict.values():
            assert os.path.exists(path)
    for key in config_keys:
        py_io.write_json(config_dict[key], os.path.join(base_path, f"{key}.json"))
    py_io.write_json(config_dict, os.path.join(base_path, "full.json"))
    index = {f"{key}_path": os.path.join(base_path, f"{key}.json") for key in config_keys}
    py_io.write_json(index, path=os.path.join(base_path, "zz_full.json"))
def get_task_config(task_config_templates, task_name, task_data_dir):
    """Build a concrete task config from a template.

    Copies the template for `task_name`, resolves each entry of its
    `rel_paths` against `task_data_dir` into an absolute `paths` dict
    (asserting each file exists), and drops the `rel_paths` key.

    Raises:
        RuntimeError: if a `rel_paths` value is itself a dict (nested
            path dicts are not supported).
    """
    config = task_config_templates[task_name].copy()
    config["paths"] = {}
    for key, rel_path in config["rel_paths"].items():
        if isinstance(rel_path, dict):
            raise RuntimeError("Nested path dicts not currently supported")
        full_path = os.path.join(task_data_dir, rel_path)
        assert os.path.exists(full_path)
        config["paths"][key] = full_path
    del config["rel_paths"]
    return config
40 | 41 | def main(args: RunConfiguration): 42 | create_and_write_task_config( 43 | task_name=args.task_name, 44 | task_data_dir=args.task_data_dir, 45 | task_config_path=args.task_config_path, 46 | ) 47 | 48 | 49 | if __name__ == "__main__": 50 | main(RunConfiguration.default_run_cli()) 51 | -------------------------------------------------------------------------------- /jiant/proj/simple/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/proj/simple/__init__.py -------------------------------------------------------------------------------- /jiant/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/scripts/__init__.py -------------------------------------------------------------------------------- /jiant/scripts/benchmarks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/scripts/benchmarks/__init__.py -------------------------------------------------------------------------------- /jiant/scripts/benchmarks/benchmark_submission_formatter.py: -------------------------------------------------------------------------------- 1 | """Translate raw prediction files for benchmark tasks into format expected by 2 | benchmark leaderboards. 
3 | """ 4 | import os 5 | import argparse 6 | 7 | from jiant.scripts.benchmarks.benchmarks import GlueBenchmark, SuperglueBenchmark 8 | 9 | 10 | SUPPORTED_BENCHMARKS = {"GLUE": GlueBenchmark, "SUPERGLUE": SuperglueBenchmark} 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser( 15 | description="Generate formatted test prediction files for benchmark submission" 16 | ) 17 | parser.add_argument( 18 | "--input_base_path", 19 | required=True, 20 | help="base path where per-task folders contain raw prediction files", 21 | ) 22 | parser.add_argument("--output_path", required=True, help="output path for formatted files") 23 | parser.add_argument( 24 | "--benchmark", required=True, choices=SUPPORTED_BENCHMARKS, help="name of benchmark" 25 | ) 26 | parser.add_argument( 27 | "--tasks", required=False, nargs="+", help="subset of benchmark tasks to format" 28 | ) 29 | args = parser.parse_args() 30 | 31 | benchmark = SUPPORTED_BENCHMARKS[args.benchmark] 32 | 33 | if args.tasks: 34 | assert set(args.tasks) <= benchmark.TASKS 35 | task_names = args.tasks 36 | else: 37 | task_names = benchmark.TASKS 38 | 39 | for task_name in task_names: 40 | input_filepath = os.path.join(args.input_base_path, task_name, "test_preds.p") 41 | output_filepath = os.path.join( 42 | os.path.abspath(args.output_path), benchmark.BENCHMARK_SUBMISSION_FILENAMES[task_name] 43 | ) 44 | benchmark.write_predictions(task_name, input_filepath, output_filepath) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /jiant/scripts/benchmarks/benchmarks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import torch 4 | 5 | import jiant.utils.python.io as py_io 6 | from jiant.tasks import retrieval 7 | from jiant.tasks.constants import GLUE_TASKS, SUPERGLUE_TASKS 8 | 9 | 10 | class Benchmark: 11 | TASKS = NotImplemented 12 | 
# https://super.gluebenchmark.com/
class SuperglueBenchmark(Benchmark):
    """SuperGLUE benchmark: maps jiant task names to leaderboard submission files."""

    TASKS = SUPERGLUE_TASKS
    BENCHMARK_SUBMISSION_FILENAMES = {
        "boolq": "BoolQ.jsonl",
        "cb": "CB.jsonl",
        "copa": "COPA.jsonl",
        "multirc": "MultiRC.jsonl",
        "record": "ReCoRD.jsonl",
        "rte": "RTE.jsonl",
        "wic": "WiC.jsonl",
        "wsc": "WSC.jsonl",
        "superglue_axb": "AX-b.jsonl",
        "superglue_axg": "AX-g.jsonl",
    }

    @classmethod
    def write_predictions(cls, task_name: str, input_filepath: str, output_filepath: str):
        """Load raw predictions for `task_name` and write them in SuperGLUE
        submission format to `output_filepath`.

        Bug fix: write directly to `output_filepath`. The previous code
        re-joined the submission filename with the output path (with the
        arguments in the wrong order), producing a bogus destination —
        the caller has already resolved the final filename into
        `output_filepath` (see the GLUE counterpart and the formatter CLI).
        """
        task = retrieval.get_task_class(task_name)
        task_preds = torch.load(input_filepath)[task_name]
        formatted_preds = task.super_glue_format_preds(task_preds)
        py_io.write_jsonl(
            data=formatted_preds,
            path=output_filepath,
        )
# Fine-tuned XTREME tasks: train with periodic eval / early stopping, then validate.
for TASK in xnli pawsx udpos panx xquad mlqa tydiqa; do
    python jiant/proj/main/runscript.py \
        run_with_continue \
        --ZZsrc ${BASE_PATH}/models/${MODEL_TYPE}/config.json \
        --jiant_task_container_config_path ${BASE_PATH}/runconfigs/${TASK}.json \
        --model_load_mode from_transformers \
        --learning_rate 1e-5 \
        --eval_every_steps 1000 \
        --no_improvements_for_n_evals 30 \
        --do_save \
        --force_overwrite \
        --do_train --do_val \
        --output_dir ${BASE_PATH}/runs/${TASK}
done

# Retrieval XTREME tasks have no training set: run validation only.
for TASK in bucc2018 tatoeba; do
    python jiant/proj/main/runscript.py \
        run_with_continue \
        --ZZsrc ${BASE_PATH}/models/${MODEL_TYPE}/config.json \
        --jiant_task_container_config_path ${BASE_PATH}/runconfigs/${TASK}.json \
        --model_load_mode from_transformers \
        --force_overwrite \
        --do_val \
        --output_dir ${BASE_PATH}/runs/${TASK}
done
# Requires variables:
#   MODEL_TYPE (e.g. xlm-roberta-large)
#   BASE_PATH
#
# Runs the full XTREME pipeline end to end, one sub-script per stage.
for STEP in a_download_model b_download_data c_tokenize_and_cache d_write_configs e_run_models; do
    bash "${STEP}.sh"
done
class PHASE:
    """Canonical string names for the dataset phases used throughout jiant."""

    TRAIN = "train"  # training split
    VAL = "val"  # validation split
    TEST = "test"  # test split
local_rank not in [-1, 0]: 8 | # noinspection PyUnresolvedReferences 9 | torch.distributed.barrier() 10 | 11 | try: 12 | yield 13 | finally: 14 | if local_rank == 0: 15 | # noinspection PyUnresolvedReferences 16 | torch.distributed.barrier() 17 | -------------------------------------------------------------------------------- /jiant/shared/metarunner.py: -------------------------------------------------------------------------------- 1 | class AbstractMetarunner: 2 | def begin_training(self): 3 | raise NotImplementedError() 4 | 5 | def yield_train_step(self): 6 | raise NotImplementedError() 7 | 8 | def should_save_model(self) -> bool: 9 | raise NotImplementedError() 10 | 11 | def save_model(self): 12 | raise NotImplementedError() 13 | 14 | def should_save_checkpoint(self) -> bool: 15 | raise NotImplementedError() 16 | 17 | def save_checkpoint(self): 18 | raise NotImplementedError() 19 | 20 | def should_eval_model(self) -> bool: 21 | raise NotImplementedError() 22 | 23 | def eval_model(self): 24 | raise NotImplementedError() 25 | 26 | def should_break_training(self) -> bool: 27 | raise NotImplementedError() 28 | 29 | def done_training(self): 30 | raise NotImplementedError() 31 | 32 | def returned_result(self): 33 | raise NotImplementedError() 34 | 35 | def run_train_loop(self): 36 | self.begin_training() 37 | 38 | for _ in self.yield_train_step(): 39 | if self.should_save_model(): 40 | self.save_model() 41 | 42 | if self.should_save_checkpoint(): 43 | self.save_checkpoint() 44 | 45 | if self.should_eval_model(): 46 | self.eval_model() 47 | 48 | if self.should_break_training(): 49 | break 50 | 51 | self.eval_model() 52 | self.done_training() 53 | 54 | return self.returned_result() 55 | -------------------------------------------------------------------------------- /jiant/shared/model_resolution.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from 
jiant.utils.python.datastructures import BiDict 4 | 5 | import transformers 6 | 7 | 8 | class ModelArchitectures(Enum): 9 | BERT = "bert" 10 | XLM = "xlm" 11 | ROBERTA = "roberta" 12 | ALBERT = "albert" 13 | XLM_ROBERTA = "xlm-roberta" 14 | BART = "bart" 15 | MBART = "mbart" 16 | ELECTRA = "electra" 17 | DEBERTAV2 = "deberta-v2" 18 | 19 | @classmethod 20 | def from_model_type(cls, model_type: str): 21 | return cls(model_type) 22 | 23 | def get_encoder_prefix(self): 24 | if self.value == "xlm-roberta": 25 | return "roberta" 26 | else: 27 | return self.value 28 | 29 | 30 | TOKENIZER_CLASS_DICT = BiDict( 31 | { 32 | ModelArchitectures.BERT: transformers.BertTokenizer, 33 | ModelArchitectures.XLM: transformers.XLMTokenizer, 34 | ModelArchitectures.ROBERTA: transformers.RobertaTokenizer, 35 | ModelArchitectures.XLM_ROBERTA: transformers.XLMRobertaTokenizer, 36 | ModelArchitectures.ALBERT: transformers.AlbertTokenizer, 37 | ModelArchitectures.BART: transformers.BartTokenizer, 38 | ModelArchitectures.MBART: transformers.MBartTokenizer, 39 | ModelArchitectures.ELECTRA: transformers.ElectraTokenizer, 40 | ModelArchitectures.DEBERTAV2: transformers.DebertaV2Tokenizer, 41 | } 42 | ) 43 | 44 | 45 | @dataclass 46 | class ModelClassSpec: 47 | config_class: type 48 | tokenizer_class: type 49 | model_class: type 50 | 51 | 52 | def resolve_tokenizer_class(model_type): 53 | """Get tokenizer class for a given model architecture. 54 | 55 | Args: 56 | model_type (str): model shortcut name. 57 | 58 | Returns: 59 | Tokenizer associated with the given model. 60 | 61 | """ 62 | return TOKENIZER_CLASS_DICT[ModelArchitectures(model_type)] 63 | 64 | 65 | def resolve_model_arch_tokenizer(tokenizer): 66 | """Get the model architecture for a given tokenizer. 
def complex_backpropagate(
    loss, optimizer, model, fp16, n_gpu, gradient_accumulation_steps, max_grad_norm
):
    """Backpropagate ``loss`` with optional multi-GPU averaging, gradient
    accumulation scaling, fp16 (apex) loss scaling, and gradient clipping.

    Returns the (possibly averaged/rescaled) loss tensor.
    """
    if n_gpu > 1:
        # DataParallel returns one loss per device; average them.
        loss = loss.mean()
    if gradient_accumulation_steps > 1:
        # Scale down so accumulated gradients match a single full-batch step.
        loss = loss / gradient_accumulation_steps

    if not fp16:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        return loss

    # fp16 path: apex is imported lazily since it is an optional dependency.
    # noinspection PyUnresolvedReferences,PyPackageRequirements
    from apex import amp

    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
    return loss
os.path.join(output_dir, f"{file_name}.metadata.json")) 80 | 81 | 82 | def compare_steps_max_steps(step, max_steps): 83 | return max_steps is not None and max_steps != -1 and step >= max_steps 84 | -------------------------------------------------------------------------------- /jiant/tasks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/tasks/__init__.py -------------------------------------------------------------------------------- /jiant/tasks/constants.py: -------------------------------------------------------------------------------- 1 | GLUE_TASKS = { 2 | "cola", 3 | "sst", 4 | "mrpc", 5 | "qqp", 6 | "stsb", 7 | "mnli", 8 | "mnli_mismatched", 9 | "qnli", 10 | "rte", 11 | "wnli", 12 | "glue_diagnostics", 13 | } 14 | 15 | SUPERGLUE_TASKS = { 16 | "cb", 17 | "copa", 18 | "multirc", 19 | "wic", 20 | "wsc", 21 | "boolq", 22 | "record", 23 | "rte", 24 | "superglue_broadcoverage_diagnostics", 25 | "superglue_winogender_diagnostics", 26 | } 27 | 28 | XTREME_TASKS = { 29 | "xnli", 30 | "pawsx", 31 | "udpos", 32 | "panx", 33 | "xquad", 34 | "mlqa", 35 | "tydiqa", 36 | "bucc2018", 37 | "tatoeba", 38 | } 39 | 40 | BENCHMARKS = {"GLUE": GLUE_TASKS, "SUPERGLUE": SUPERGLUE_TASKS, "XTREME": XTREME_TASKS} 41 | -------------------------------------------------------------------------------- /jiant/tasks/evaluate/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * # noqa: F401,F403 2 | -------------------------------------------------------------------------------- /jiant/tasks/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/tasks/lib/__init__.py -------------------------------------------------------------------------------- 
/jiant/tasks/lib/acceptability_judgement/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/tasks/lib/acceptability_judgement/__init__.py -------------------------------------------------------------------------------- /jiant/tasks/lib/acceptability_judgement/coord.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from . import base as base 3 | 4 | 5 | @dataclass 6 | class Example(base.Example): 7 | pass 8 | 9 | 10 | @dataclass 11 | class TokenizedExample(base.TokenizedExample): 12 | pass 13 | 14 | 15 | @dataclass 16 | class DataRow(base.DataRow): 17 | pass 18 | 19 | 20 | @dataclass 21 | class Batch(base.Batch): 22 | pass 23 | 24 | 25 | class AcceptabilityCoordTask(base.BaseAcceptabilityTask): 26 | pass 27 | -------------------------------------------------------------------------------- /jiant/tasks/lib/acceptability_judgement/definiteness.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from . import base as base 3 | 4 | 5 | @dataclass 6 | class Example(base.Example): 7 | pass 8 | 9 | 10 | @dataclass 11 | class TokenizedExample(base.TokenizedExample): 12 | pass 13 | 14 | 15 | @dataclass 16 | class DataRow(base.DataRow): 17 | pass 18 | 19 | 20 | @dataclass 21 | class Batch(base.Batch): 22 | pass 23 | 24 | 25 | class AcceptabilityDefinitenessTask(base.BaseAcceptabilityTask): 26 | pass 27 | -------------------------------------------------------------------------------- /jiant/tasks/lib/acceptability_judgement/eos.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from . 
import base as base 3 | 4 | 5 | @dataclass 6 | class Example(base.Example): 7 | pass 8 | 9 | 10 | @dataclass 11 | class TokenizedExample(base.TokenizedExample): 12 | pass 13 | 14 | 15 | @dataclass 16 | class DataRow(base.DataRow): 17 | pass 18 | 19 | 20 | @dataclass 21 | class Batch(base.Batch): 22 | pass 23 | 24 | 25 | class AcceptabilityEOSTask(base.BaseAcceptabilityTask): 26 | pass 27 | -------------------------------------------------------------------------------- /jiant/tasks/lib/acceptability_judgement/whwords.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from . import base as base 3 | 4 | 5 | @dataclass 6 | class Example(base.Example): 7 | pass 8 | 9 | 10 | @dataclass 11 | class TokenizedExample(base.TokenizedExample): 12 | pass 13 | 14 | 15 | @dataclass 16 | class DataRow(base.DataRow): 17 | pass 18 | 19 | 20 | @dataclass 21 | class Batch(base.Batch): 22 | pass 23 | 24 | 25 | class AcceptabilityWHwordsTask(base.BaseAcceptabilityTask): 26 | pass 27 | -------------------------------------------------------------------------------- /jiant/tasks/lib/arc_challenge.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jiant.tasks.lib.templates.shared import labels_to_bimap 4 | from jiant.tasks.lib.templates import multiple_choice as mc_template 5 | from jiant.utils.python.io import read_json_lines 6 | 7 | 8 | @dataclass 9 | class Example(mc_template.Example): 10 | @property 11 | def task(self): 12 | return ArcChallengeTask 13 | 14 | 15 | @dataclass 16 | class TokenizedExample(mc_template.TokenizedExample): 17 | pass 18 | 19 | 20 | @dataclass 21 | class DataRow(mc_template.DataRow): 22 | pass 23 | 24 | 25 | @dataclass 26 | class Batch(mc_template.Batch): 27 | pass 28 | 29 | 30 | class ArcChallengeTask(mc_template.AbstractMultipleChoiceTask): 31 | Example = Example 32 | TokenizedExample = Example 33 | 
DataRow = DataRow 34 | Batch = Batch 35 | 36 | CHOICE_KEYS = ["A", "B", "C", "D", "E"] 37 | CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS) 38 | NUM_CHOICES = len(CHOICE_KEYS) 39 | 40 | def get_train_examples(self): 41 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 42 | 43 | def get_val_examples(self): 44 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 45 | 46 | def get_test_examples(self): 47 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 48 | 49 | @classmethod 50 | def _create_examples(cls, lines, set_type): 51 | potential_label_map = { 52 | "1": "A", 53 | "2": "B", 54 | "3": "C", 55 | "4": "D", 56 | "5": "E", 57 | } 58 | NUM_CHOICES = len(potential_label_map) 59 | examples = [] 60 | for i, line in enumerate(lines): 61 | label = line["answerKey"] 62 | if label in potential_label_map: 63 | label = potential_label_map[label] 64 | choice_list = [d for d in line["choices"]["text"]] 65 | filler_choice_list = ["." 
for i in range(NUM_CHOICES - len(choice_list))] 66 | choice_list = choice_list + filler_choice_list 67 | assert len(choice_list) == NUM_CHOICES 68 | 69 | examples.append( 70 | Example( 71 | guid="%s-%s" % (set_type, i), 72 | prompt=line["question"], 73 | choice_list=choice_list, 74 | label=label, 75 | ) 76 | ) 77 | return examples 78 | -------------------------------------------------------------------------------- /jiant/tasks/lib/arc_easy.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jiant.tasks.lib.templates.shared import labels_to_bimap 4 | from jiant.tasks.lib.templates import multiple_choice as mc_template 5 | from jiant.utils.python.io import read_json_lines 6 | 7 | 8 | @dataclass 9 | class Example(mc_template.Example): 10 | @property 11 | def task(self): 12 | return ArcEasyTask 13 | 14 | 15 | @dataclass 16 | class TokenizedExample(mc_template.TokenizedExample): 17 | pass 18 | 19 | 20 | @dataclass 21 | class DataRow(mc_template.DataRow): 22 | pass 23 | 24 | 25 | @dataclass 26 | class Batch(mc_template.Batch): 27 | pass 28 | 29 | 30 | class ArcEasyTask(mc_template.AbstractMultipleChoiceTask): 31 | Example = Example 32 | TokenizedExample = Example 33 | DataRow = DataRow 34 | Batch = Batch 35 | 36 | CHOICE_KEYS = ["A", "B", "C", "D", "E"] 37 | CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS) 38 | NUM_CHOICES = len(CHOICE_KEYS) 39 | 40 | def get_train_examples(self): 41 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 42 | 43 | def get_val_examples(self): 44 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 45 | 46 | def get_test_examples(self): 47 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 48 | 49 | @classmethod 50 | def _create_examples(cls, lines, set_type): 51 | potential_label_map = { 52 | "1": "A", 53 | "2": "B", 54 | "3": "C", 55 | "4": "D", 
56 | "5": "E", 57 | } 58 | NUM_CHOICES = len(potential_label_map) 59 | examples = [] 60 | for i, line in enumerate(lines): 61 | label = line["answerKey"] 62 | if label in potential_label_map: 63 | label = potential_label_map[label] 64 | choice_list = [d for d in line["choices"]["text"]] 65 | filler_choice_list = ["." for i in range(NUM_CHOICES - len(choice_list))] 66 | choice_list = choice_list + filler_choice_list 67 | assert len(choice_list) == NUM_CHOICES 68 | 69 | examples.append( 70 | Example( 71 | guid="%s-%s" % (set_type, i), 72 | prompt=line["question"], 73 | choice_list=choice_list, 74 | label=label, 75 | ) 76 | ) 77 | return examples 78 | -------------------------------------------------------------------------------- /jiant/tasks/lib/arct.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import pandas as pd 4 | 5 | from jiant.tasks.lib.templates.shared import labels_to_bimap 6 | from jiant.tasks.lib.templates import multiple_choice as mc_template 7 | 8 | 9 | @dataclass 10 | class Example(mc_template.Example): 11 | @property 12 | def task(self): 13 | return ArctTask 14 | 15 | 16 | @dataclass 17 | class TokenizedExample(mc_template.TokenizedExample): 18 | pass 19 | 20 | 21 | @dataclass 22 | class DataRow(mc_template.DataRow): 23 | pass 24 | 25 | 26 | @dataclass 27 | class Batch(mc_template.Batch): 28 | pass 29 | 30 | 31 | class ArctTask(mc_template.AbstractMultipleChoiceTask): 32 | Example = Example 33 | TokenizedExample = Example 34 | DataRow = DataRow 35 | Batch = Batch 36 | 37 | CHOICE_KEYS = [0, 1] 38 | CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS) 39 | NUM_CHOICES = len(CHOICE_KEYS) 40 | 41 | def get_train_examples(self): 42 | return self._create_examples(self.train_path, set_type="train") 43 | 44 | def get_val_examples(self): 45 | return self._create_examples(self.val_path, set_type="val") 46 | 47 | def get_test_examples(self): 48 | return 
self._create_examples(self.test_path, set_type="test") 49 | 50 | @classmethod 51 | def _create_examples(cls, path, set_type): 52 | df_names = [ 53 | "#id", 54 | "warrant0", 55 | "warrant1", 56 | "gold_label", 57 | "reason", 58 | "claim", 59 | "debateTitle", 60 | "debateInfo", 61 | ] 62 | 63 | df = pd.read_csv( 64 | path, 65 | sep="\t", 66 | header=0, 67 | names=df_names, 68 | ) 69 | choice_pre = "And since " 70 | examples = [] 71 | 72 | for i, row in enumerate(df.itertuples()): 73 | # Repo explanation from https://github.com/UKPLab/argument-reasoning-comprehension-task 74 | examples.append( 75 | Example( 76 | guid="%s-%s" % (set_type, i), 77 | prompt=row.reason + " ", 78 | choice_list=[ 79 | choice_pre + row.warrant0 + ", " + row.claim, 80 | choice_pre + row.warrant1 + ", " + row.claim, 81 | ], 82 | label=row.gold_label if set_type != "test" else cls.CHOICE_KEYS[-1], 83 | ) 84 | ) 85 | 86 | return examples 87 | -------------------------------------------------------------------------------- /jiant/tasks/lib/cola.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from dataclasses import dataclass 4 | from typing import List 5 | 6 | from jiant.tasks.core import ( 7 | BaseExample, 8 | BaseTokenizedExample, 9 | BaseDataRow, 10 | BatchMixin, 11 | GlueMixin, 12 | Task, 13 | TaskTypes, 14 | ) 15 | from jiant.tasks.lib.templates.shared import single_sentence_featurize, labels_to_bimap 16 | from jiant.utils.python.io import read_jsonl 17 | 18 | 19 | @dataclass 20 | class Example(BaseExample): 21 | guid: str 22 | text: str 23 | label: str 24 | 25 | def tokenize(self, tokenizer): 26 | return TokenizedExample( 27 | guid=self.guid, 28 | text=tokenizer.tokenize(self.text), 29 | label_id=ColaTask.LABEL_TO_ID[self.label], 30 | ) 31 | 32 | 33 | @dataclass 34 | class TokenizedExample(BaseTokenizedExample): 35 | guid: str 36 | text: List 37 | label_id: int 38 | 39 | def featurize(self, tokenizer, 
class ColaTask(GlueMixin, Task):
    """CoLA: single-sentence classification with string labels "0"/"1" (GLUE)."""

    Example = Example
    # NOTE(review): TokenizedExample is bound to Example, not TokenizedExample —
    # the same pattern appears in sibling tasks; confirm it is intentional.
    TokenizedExample = Example
    DataRow = DataRow
    Batch = Batch

    TASK_TYPE = TaskTypes.CLASSIFICATION
    LABELS = ["0", "1"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)

    def get_train_examples(self):
        """Load and convert the train split (jsonl)."""
        return self._create_examples(lines=read_jsonl(self.train_path), set_type="train")

    def get_val_examples(self):
        """Load and convert the validation split (jsonl)."""
        return self._create_examples(lines=read_jsonl(self.val_path), set_type="val")

    def get_test_examples(self):
        """Load and convert the test split (jsonl)."""
        return self._create_examples(lines=read_jsonl(self.test_path), set_type="test")

    @classmethod
    def _create_examples(cls, lines, set_type):
        """Turn parsed jsonl rows into Examples.

        Test rows get the last label as a placeholder since only non-test
        rows use ``line["label"]``.
        """
        examples = []
        for (i, line) in enumerate(lines):
            examples.append(
                Example(
                    # NOTE: get_glue_preds() is dependent on this guid format.
                    guid="%s-%s" % (set_type, i),
                    text=line["text"],
                    label=line["label"] if set_type != "test" else cls.LABELS[-1],
                )
            )
        return examples
60 | examples.append( 61 | Example( 62 | # NOTE: CopaTask.super_glue_format_preds() is dependent on this guid format. 63 | guid="%s-%s" % (set_type, line["idx"]), 64 | prompt=line["premise"] + " " + question, 65 | choice_list=[line["choice1"], line["choice2"]], 66 | label=line["label"] if set_type != "test" else cls.CHOICE_KEYS[-1], 67 | ) 68 | ) 69 | return examples 70 | 71 | @classmethod 72 | def super_glue_format_preds(cls, pred_dict): 73 | """Reformat this task's raw predictions to have the structure expected by SuperGLUE.""" 74 | lines = [] 75 | for pred, guid in zip(list(pred_dict["preds"]), list(pred_dict["guids"])): 76 | lines.append({"idx": int(guid.split("-")[1]), "label": cls.CHOICE_KEYS[pred]}) 77 | return lines 78 | -------------------------------------------------------------------------------- /jiant/tasks/lib/cosmosqa.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from dataclasses import dataclass 3 | 4 | from jiant.tasks.lib.templates.shared import labels_to_bimap 5 | from jiant.tasks.lib.templates import multiple_choice as mc_template 6 | 7 | 8 | @dataclass 9 | class Example(mc_template.Example): 10 | @property 11 | def task(self): 12 | return CosmosQATask 13 | 14 | 15 | @dataclass 16 | class TokenizedExample(mc_template.TokenizedExample): 17 | pass 18 | 19 | 20 | @dataclass 21 | class DataRow(mc_template.DataRow): 22 | pass 23 | 24 | 25 | @dataclass 26 | class Batch(mc_template.Batch): 27 | pass 28 | 29 | 30 | class CosmosQATask(mc_template.AbstractMultipleChoiceTask): 31 | Example = Example 32 | TokenizedExample = Example 33 | DataRow = DataRow 34 | Batch = Batch 35 | 36 | CHOICE_KEYS = [0, 1, 2, 3] 37 | CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS) 38 | NUM_CHOICES = len(CHOICE_KEYS) 39 | 40 | def get_train_examples(self): 41 | return self._create_examples(path=self.train_path, set_type="train") 42 | 43 | def get_val_examples(self): 44 | return 
self._create_examples(path=self.val_path, set_type="val") 45 | 46 | def get_test_examples(self): 47 | return self._create_examples(path=self.test_path, set_type="test") 48 | 49 | @classmethod 50 | def _create_examples(cls, path, set_type): 51 | if path.endswith(".csv"): 52 | df = pd.read_csv(path) 53 | elif path.endswith(".jsonl"): 54 | df = pd.read_json(path, lines=True) 55 | else: 56 | raise RuntimeError("Format not supported") 57 | examples = [] 58 | for i, row in enumerate(df.itertuples()): 59 | examples.append( 60 | Example( 61 | guid="%s-%s" % (set_type, i), 62 | prompt=row.context + " " + row.question, 63 | choice_list=[row.answer0, row.answer1, row.answer2, row.answer3], 64 | label=row.label if set_type != "test" else cls.CHOICE_KEYS[-1], 65 | ) 66 | ) 67 | return examples 68 | -------------------------------------------------------------------------------- /jiant/tasks/lib/edge_probing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/tasks/lib/edge_probing/__init__.py -------------------------------------------------------------------------------- /jiant/tasks/lib/edge_probing/coref.py: -------------------------------------------------------------------------------- 1 | """Coreference Edge Probing task. 2 | 3 | Task source paper: https://arxiv.org/pdf/1905.06316.pdf. 4 | Task data prep directions: https://github.com/nyu-mll/jiant/blob/master/probing/data/README.md. 
5 | 6 | """ 7 | from dataclasses import dataclass 8 | 9 | from jiant.tasks.lib.templates.shared import labels_to_bimap 10 | from jiant.tasks.lib.templates import edge_probing_two_span 11 | from jiant.utils.python.io import read_json_lines 12 | 13 | 14 | @dataclass 15 | class Example(edge_probing_two_span.Example): 16 | @property 17 | def task(self): 18 | return CorefTask 19 | 20 | 21 | @dataclass 22 | class TokenizedExample(edge_probing_two_span.TokenizedExample): 23 | pass 24 | 25 | 26 | @dataclass 27 | class DataRow(edge_probing_two_span.DataRow): 28 | pass 29 | 30 | 31 | @dataclass 32 | class Batch(edge_probing_two_span.Batch): 33 | pass 34 | 35 | 36 | class CorefTask(edge_probing_two_span.AbstractProbingTask): 37 | Example = Example 38 | TokenizedExample = TokenizedExample 39 | DataRow = DataRow 40 | Batch = Batch 41 | 42 | LABELS = ["0", "1"] 43 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 44 | 45 | @property 46 | def num_spans(self): 47 | return 2 48 | 49 | def get_train_examples(self): 50 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 51 | 52 | def get_val_examples(self): 53 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 54 | 55 | def get_test_examples(self): 56 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 57 | 58 | @classmethod 59 | def _create_examples(cls, lines, set_type): 60 | examples = [] 61 | for (line_num, line) in enumerate(lines): 62 | for (target_num, target) in enumerate(line["targets"]): 63 | span1 = target["span1"] 64 | span2 = target["span2"] 65 | examples.append( 66 | Example( 67 | guid="%s-%s-%s" % (set_type, line_num, target_num), 68 | text=line["text"], 69 | span1=span1, 70 | span2=span2, 71 | labels=[target["label"]] if set_type != "test" else [cls.LABELS[-1]], 72 | ) 73 | ) 74 | return examples 75 | -------------------------------------------------------------------------------- 
/jiant/tasks/lib/edge_probing/dpr.py: -------------------------------------------------------------------------------- 1 | """Definite Pronoun Resolution Edge Probing task. 2 | 3 | Task source paper: https://arxiv.org/pdf/1905.06316.pdf. 4 | Task data prep directions: https://github.com/nyu-mll/jiant/blob/master/probing/data/README.md. 5 | 6 | """ 7 | from dataclasses import dataclass 8 | 9 | from jiant.tasks.lib.templates.shared import labels_to_bimap 10 | from jiant.tasks.lib.templates import edge_probing_two_span 11 | from jiant.utils.python.io import read_json_lines 12 | 13 | 14 | @dataclass 15 | class Example(edge_probing_two_span.Example): 16 | @property 17 | def task(self): 18 | return DprTask 19 | 20 | 21 | @dataclass 22 | class TokenizedExample(edge_probing_two_span.TokenizedExample): 23 | pass 24 | 25 | 26 | @dataclass 27 | class DataRow(edge_probing_two_span.DataRow): 28 | pass 29 | 30 | 31 | @dataclass 32 | class Batch(edge_probing_two_span.Batch): 33 | pass 34 | 35 | 36 | class DprTask(edge_probing_two_span.AbstractProbingTask): 37 | Example = Example 38 | TokenizedExample = TokenizedExample 39 | DataRow = DataRow 40 | Batch = Batch 41 | 42 | LABELS = ["entailed", "not-entailed"] 43 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 44 | 45 | @property 46 | def num_spans(self): 47 | return 2 48 | 49 | def get_train_examples(self): 50 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 51 | 52 | def get_val_examples(self): 53 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 54 | 55 | def get_test_examples(self): 56 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 57 | 58 | @classmethod 59 | def _create_examples(cls, lines, set_type): 60 | examples = [] 61 | for (line_num, line) in enumerate(lines): 62 | for (target_num, target) in enumerate(line["targets"]): 63 | span1 = target["span1"] 64 | span2 = target["span2"] 65 | examples.append( 66 | 
Example( 67 | guid="%s-%s-%s" % (set_type, line_num, target_num), 68 | text=line["text"], 69 | span1=span1, 70 | span2=span2, 71 | labels=[target["label"]] if set_type != "test" else [cls.LABELS[-1]], 72 | ) 73 | ) 74 | return examples 75 | -------------------------------------------------------------------------------- /jiant/tasks/lib/edge_probing/ner.py: -------------------------------------------------------------------------------- 1 | """Named entity labeling Edge Probing task. 2 | 3 | Task source paper: https://arxiv.org/pdf/1905.06316.pdf. 4 | Task data prep directions: https://github.com/nyu-mll/jiant/blob/master/probing/data/README.md. 5 | 6 | """ 7 | from dataclasses import dataclass 8 | 9 | from jiant.tasks.lib.templates.shared import labels_to_bimap 10 | from jiant.tasks.lib.templates import edge_probing_single_span 11 | from jiant.utils.python.io import read_json_lines 12 | 13 | 14 | @dataclass 15 | class Example(edge_probing_single_span.Example): 16 | @property 17 | def task(self): 18 | return NerTask 19 | 20 | 21 | @dataclass 22 | class TokenizedExample(edge_probing_single_span.TokenizedExample): 23 | pass 24 | 25 | 26 | @dataclass 27 | class DataRow(edge_probing_single_span.DataRow): 28 | pass 29 | 30 | 31 | @dataclass 32 | class Batch(edge_probing_single_span.Batch): 33 | pass 34 | 35 | 36 | class NerTask(edge_probing_single_span.AbstractProbingTask): 37 | Example = Example 38 | TokenizedExample = TokenizedExample 39 | DataRow = DataRow 40 | Batch = Batch 41 | 42 | LABELS = [ 43 | "CARDINAL", 44 | "DATE", 45 | "EVENT", 46 | "FAC", 47 | "GPE", 48 | "LANGUAGE", 49 | "LAW", 50 | "LOC", 51 | "MONEY", 52 | "NORP", 53 | "ORDINAL", 54 | "ORG", 55 | "PERCENT", 56 | "PERSON", 57 | "PRODUCT", 58 | "QUANTITY", 59 | "TIME", 60 | "WORK_OF_ART", 61 | ] 62 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 63 | 64 | @property 65 | def num_spans(self): 66 | return 1 67 | 68 | def get_train_examples(self): 69 | return 
self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 70 | 71 | def get_val_examples(self): 72 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 73 | 74 | def get_test_examples(self): 75 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 76 | 77 | @classmethod 78 | def _create_examples(cls, lines, set_type): 79 | examples = [] 80 | for (line_num, line) in enumerate(lines): 81 | for (target_num, target) in enumerate(line["targets"]): 82 | span = target["span1"] 83 | examples.append( 84 | Example( 85 | guid="%s-%s-%s" % (set_type, line_num, target_num), 86 | text=line["text"], 87 | span=span, 88 | labels=[target["label"]] if set_type != "test" else [cls.LABELS[-1]], 89 | ) 90 | ) 91 | return examples 92 | -------------------------------------------------------------------------------- /jiant/tasks/lib/edge_probing/nonterminal.py: -------------------------------------------------------------------------------- 1 | """Constituent labeling (assigning a non-terminal label) Edge Probing task. 2 | 3 | Task source paper: https://arxiv.org/pdf/1905.06316.pdf. 4 | Task data prep directions: https://github.com/nyu-mll/jiant/blob/master/probing/data/README.md. 
5 | 6 | """ 7 | from dataclasses import dataclass 8 | 9 | from jiant.tasks.lib.templates.shared import labels_to_bimap 10 | from jiant.tasks.lib.templates import edge_probing_single_span 11 | from jiant.utils.python.io import read_json_lines 12 | 13 | 14 | @dataclass 15 | class Example(edge_probing_single_span.Example): 16 | @property 17 | def task(self): 18 | return NonterminalTask 19 | 20 | 21 | @dataclass 22 | class TokenizedExample(edge_probing_single_span.TokenizedExample): 23 | pass 24 | 25 | 26 | @dataclass 27 | class DataRow(edge_probing_single_span.DataRow): 28 | pass 29 | 30 | 31 | @dataclass 32 | class Batch(edge_probing_single_span.Batch): 33 | pass 34 | 35 | 36 | class NonterminalTask(edge_probing_single_span.AbstractProbingTask): 37 | Example = Example 38 | TokenizedExample = TokenizedExample 39 | DataRow = DataRow 40 | Batch = Batch 41 | 42 | LABELS = [ 43 | "ADJP", 44 | "ADVP", 45 | "CONJP", 46 | "EMBED", 47 | "FRAG", 48 | "INTJ", 49 | "LST", 50 | "META", 51 | "NAC", 52 | "NML", 53 | "NP", 54 | "NX", 55 | "PP", 56 | "PRN", 57 | "PRT", 58 | "QP", 59 | "RRC", 60 | "S", 61 | "SBAR", 62 | "SBARQ", 63 | "SINV", 64 | "SQ", 65 | "TOP", 66 | "UCP", 67 | "VP", 68 | "WHADJP", 69 | "WHADVP", 70 | "WHNP", 71 | "WHPP", 72 | "X", 73 | ] 74 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 75 | 76 | @property 77 | def num_spans(self): 78 | return 1 79 | 80 | def get_train_examples(self): 81 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 82 | 83 | def get_val_examples(self): 84 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 85 | 86 | def get_test_examples(self): 87 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 88 | 89 | @classmethod 90 | def _create_examples(cls, lines, set_type): 91 | examples = [] 92 | for (line_num, line) in enumerate(lines): 93 | for (target_num, target) in enumerate(line["targets"]): 94 | span = target["span1"] 95 | 
examples.append( 96 | Example( 97 | guid="%s-%s-%s" % (set_type, line_num, target_num), 98 | text=line["text"], 99 | span=span, 100 | labels=[target["label"]] if set_type != "test" else [cls.LABELS[-1]], 101 | ) 102 | ) 103 | return examples 104 | -------------------------------------------------------------------------------- /jiant/tasks/lib/edge_probing/semeval.py: -------------------------------------------------------------------------------- 1 | """Relation Classification Edge Probing task. 2 | 3 | Task source paper: https://arxiv.org/pdf/1905.06316.pdf. 4 | Task data prep directions: https://github.com/nyu-mll/jiant/blob/master/probing/data/README.md. 5 | 6 | """ 7 | from dataclasses import dataclass 8 | 9 | from jiant.tasks.lib.templates.shared import labels_to_bimap 10 | from jiant.tasks.lib.templates import edge_probing_two_span 11 | from jiant.utils.python.io import read_json_lines 12 | 13 | 14 | @dataclass 15 | class Example(edge_probing_two_span.Example): 16 | @property 17 | def task(self): 18 | return SemevalTask 19 | 20 | 21 | @dataclass 22 | class TokenizedExample(edge_probing_two_span.TokenizedExample): 23 | pass 24 | 25 | 26 | @dataclass 27 | class DataRow(edge_probing_two_span.DataRow): 28 | pass 29 | 30 | 31 | @dataclass 32 | class Batch(edge_probing_two_span.Batch): 33 | pass 34 | 35 | 36 | class SemevalTask(edge_probing_two_span.AbstractProbingTask): 37 | Example = Example 38 | TokenizedExample = TokenizedExample 39 | DataRow = DataRow 40 | Batch = Batch 41 | 42 | LABELS = [ 43 | "Cause-Effect(e1,e2)", 44 | "Cause-Effect(e2,e1)", 45 | "Component-Whole(e1,e2)", 46 | "Component-Whole(e2,e1)", 47 | "Content-Container(e1,e2)", 48 | "Content-Container(e2,e1)", 49 | "Entity-Destination(e1,e2)", 50 | "Entity-Destination(e2,e1)", 51 | "Entity-Origin(e1,e2)", 52 | "Entity-Origin(e2,e1)", 53 | "Instrument-Agency(e1,e2)", 54 | "Instrument-Agency(e2,e1)", 55 | "Member-Collection(e1,e2)", 56 | "Member-Collection(e2,e1)", 57 | "Message-Topic(e1,e2)", 58 | 
"Message-Topic(e2,e1)", 59 | "Other", 60 | "Product-Producer(e1,e2)", 61 | "Product-Producer(e2,e1)", 62 | ] 63 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 64 | 65 | @property 66 | def num_spans(self): 67 | return 2 68 | 69 | def get_train_examples(self): 70 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 71 | 72 | def get_val_examples(self): 73 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 74 | 75 | def get_test_examples(self): 76 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 77 | 78 | @classmethod 79 | def _create_examples(cls, lines, set_type): 80 | examples = [] 81 | for (line_num, line) in enumerate(lines): 82 | for (target_num, target) in enumerate(line["targets"]): 83 | span1 = target["span1"] 84 | span2 = target["span2"] 85 | examples.append( 86 | Example( 87 | guid="%s-%s-%s" % (set_type, line_num, target_num), 88 | text=line["text"], 89 | span1=span1, 90 | span2=span2, 91 | labels=[target["label"]] if set_type != "test" else [cls.LABELS[-1]], 92 | ) 93 | ) 94 | return examples 95 | -------------------------------------------------------------------------------- /jiant/tasks/lib/edge_probing/spr1.py: -------------------------------------------------------------------------------- 1 | """Semantic proto-role (1) Edge Probing task. 2 | 3 | Task source paper: https://arxiv.org/pdf/1905.06316.pdf. 4 | Task data prep directions: https://github.com/nyu-mll/jiant/blob/master/probing/data/README.md. 
5 | 6 | """ 7 | from dataclasses import dataclass 8 | 9 | from jiant.tasks.lib.templates.shared import labels_to_bimap 10 | from jiant.tasks.lib.templates import edge_probing_two_span 11 | from jiant.utils.python.io import read_json_lines 12 | 13 | 14 | @dataclass 15 | class Example(edge_probing_two_span.Example): 16 | @property 17 | def task(self): 18 | return Spr1Task 19 | 20 | 21 | @dataclass 22 | class TokenizedExample(edge_probing_two_span.TokenizedExample): 23 | pass 24 | 25 | 26 | @dataclass 27 | class DataRow(edge_probing_two_span.DataRow): 28 | pass 29 | 30 | 31 | @dataclass 32 | class Batch(edge_probing_two_span.Batch): 33 | pass 34 | 35 | 36 | class Spr1Task(edge_probing_two_span.AbstractProbingTask): 37 | Example = Example 38 | TokenizedExample = TokenizedExample 39 | DataRow = DataRow 40 | Batch = Batch 41 | 42 | LABELS = [ 43 | "awareness", 44 | "change_of_location", 45 | "change_of_state", 46 | "changes_possession", 47 | "created", 48 | "destroyed", 49 | "existed_after", 50 | "existed_before", 51 | "existed_during", 52 | "exists_as_physical", 53 | "instigation", 54 | "location_of_event", 55 | "makes_physical_contact", 56 | "manipulated_by_another", 57 | "predicate_changed_argument", 58 | "sentient", 59 | "stationary", 60 | "volition", 61 | ] 62 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 63 | 64 | @property 65 | def num_spans(self): 66 | return 2 67 | 68 | def get_train_examples(self): 69 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 70 | 71 | def get_val_examples(self): 72 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 73 | 74 | def get_test_examples(self): 75 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 76 | 77 | @classmethod 78 | def _create_examples(cls, lines, set_type): 79 | examples = [] 80 | for (line_num, line) in enumerate(lines): 81 | # A line in the task's data file can contain multiple targets (span-pair + 
labels). 82 | # We create an example for every target: 83 | for (target_num, target) in enumerate(line["targets"]): 84 | span1 = target["span1"] 85 | span2 = target["span2"] 86 | examples.append( 87 | Example( 88 | guid="%s-%s-%s" % (set_type, line_num, target_num), 89 | text=line["text"], 90 | span1=span1, 91 | span2=span2, 92 | labels=target["label"] if set_type != "test" else [cls.LABELS[-1]], 93 | ) 94 | ) 95 | return examples 96 | -------------------------------------------------------------------------------- /jiant/tasks/lib/edge_probing/spr2.py: -------------------------------------------------------------------------------- 1 | """Semantic proto-role (2) Edge Probing task. 2 | 3 | Task source paper: https://arxiv.org/pdf/1905.06316.pdf. 4 | Task data prep directions: https://github.com/nyu-mll/jiant/blob/master/probing/data/README.md. 5 | 6 | """ 7 | from dataclasses import dataclass 8 | 9 | from jiant.tasks.lib.templates.shared import labels_to_bimap 10 | from jiant.tasks.lib.templates import edge_probing_two_span 11 | from jiant.utils.python.io import read_json_lines 12 | 13 | 14 | @dataclass 15 | class Example(edge_probing_two_span.Example): 16 | @property 17 | def task(self): 18 | return Spr2Task 19 | 20 | 21 | @dataclass 22 | class TokenizedExample(edge_probing_two_span.TokenizedExample): 23 | pass 24 | 25 | 26 | @dataclass 27 | class DataRow(edge_probing_two_span.DataRow): 28 | pass 29 | 30 | 31 | @dataclass 32 | class Batch(edge_probing_two_span.Batch): 33 | pass 34 | 35 | 36 | class Spr2Task(edge_probing_two_span.AbstractProbingTask): 37 | Example = Example 38 | TokenizedExample = TokenizedExample 39 | DataRow = DataRow 40 | Batch = Batch 41 | 42 | LABELS = [ 43 | "awareness", 44 | "change_of_location", 45 | "change_of_possession", 46 | "change_of_state", 47 | "change_of_state_continuous", 48 | "changes_possession", 49 | "existed_after", 50 | "existed_before", 51 | "existed_during", 52 | "exists_as_physical", 53 | "instigation", 54 | 
"location_of_event", 55 | "makes_physical_contact", 56 | "partitive", 57 | "predicate_changed_argument", 58 | "sentient", 59 | "stationary", 60 | "volition", 61 | "was_for_benefit", 62 | "was_used", 63 | ] 64 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 65 | 66 | @property 67 | def num_spans(self): 68 | return 2 69 | 70 | def get_train_examples(self): 71 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 72 | 73 | def get_val_examples(self): 74 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 75 | 76 | def get_test_examples(self): 77 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 78 | 79 | @classmethod 80 | def _create_examples(cls, lines, set_type): 81 | examples = [] 82 | for (line_num, line) in enumerate(lines): 83 | for (target_num, target) in enumerate(line["targets"]): 84 | span1 = target["span1"] 85 | span2 = target["span2"] 86 | examples.append( 87 | Example( 88 | guid="%s-%s-%s" % (set_type, line_num, target_num), 89 | text=line["text"], 90 | span1=span1, 91 | span2=span2, 92 | labels=target["label"] if set_type != "test" else [cls.LABELS[-1]], 93 | ) 94 | ) 95 | return examples 96 | -------------------------------------------------------------------------------- /jiant/tasks/lib/fever_nli.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from dataclasses import dataclass 4 | from typing import List 5 | 6 | from jiant.tasks.core import ( 7 | BaseExample, 8 | BaseTokenizedExample, 9 | BaseDataRow, 10 | BatchMixin, 11 | Task, 12 | TaskTypes, 13 | ) 14 | from jiant.tasks.lib.templates.shared import double_sentence_featurize, labels_to_bimap 15 | from jiant.utils.python.io import read_jsonl 16 | 17 | 18 | @dataclass 19 | class Example(BaseExample): 20 | guid: str 21 | premise: str 22 | hypothesis: str 23 | label: str 24 | 25 | def tokenize(self, tokenizer): 26 | 
return TokenizedExample( 27 | guid=self.guid, 28 | premise=tokenizer.tokenize(self.premise), 29 | hypothesis=tokenizer.tokenize(self.hypothesis), 30 | label_id=FeverNliTask.LABEL_TO_ID[self.label], 31 | ) 32 | 33 | 34 | @dataclass 35 | class TokenizedExample(BaseTokenizedExample): 36 | guid: str 37 | premise: List 38 | hypothesis: List 39 | label_id: int 40 | 41 | def featurize(self, tokenizer, feat_spec): 42 | return double_sentence_featurize( 43 | guid=self.guid, 44 | input_tokens_a=self.premise, 45 | input_tokens_b=self.hypothesis, 46 | label_id=self.label_id, 47 | tokenizer=tokenizer, 48 | feat_spec=feat_spec, 49 | data_row_class=DataRow, 50 | ) 51 | 52 | 53 | @dataclass 54 | class DataRow(BaseDataRow): 55 | guid: str 56 | input_ids: np.ndarray 57 | input_mask: np.ndarray 58 | segment_ids: np.ndarray 59 | label_id: int 60 | tokens: list 61 | 62 | 63 | @dataclass 64 | class Batch(BatchMixin): 65 | input_ids: torch.LongTensor 66 | input_mask: torch.LongTensor 67 | segment_ids: torch.LongTensor 68 | label_id: torch.LongTensor 69 | tokens: list 70 | 71 | 72 | class FeverNliTask(Task): 73 | Example = Example 74 | TokenizedExample = Example 75 | DataRow = DataRow 76 | Batch = Batch 77 | 78 | TASK_TYPE = TaskTypes.CLASSIFICATION 79 | LABELS = ["REFUTES", "SUPPORTS", "NOT ENOUGH INFO"] 80 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 81 | 82 | def get_train_examples(self): 83 | return self._create_examples(lines=read_jsonl(self.train_path), set_type="train") 84 | 85 | def get_val_examples(self): 86 | return self._create_examples(lines=read_jsonl(self.val_path), set_type="val") 87 | 88 | def get_test_examples(self): 89 | return self._create_examples(lines=read_jsonl(self.test_path), set_type="test") 90 | 91 | @classmethod 92 | def _create_examples(cls, lines, set_type): 93 | # noinspection DuplicatedCode 94 | examples = [] 95 | for (i, line) in enumerate(lines): 96 | examples.append( 97 | Example( 98 | guid="%s-%s" % (set_type, i), 99 | premise=line["context"], 
100 | hypothesis=line["query"], 101 | label=line["label"] if set_type != "test" else cls.LABELS[-1], 102 | ) 103 | ) 104 | return examples 105 | -------------------------------------------------------------------------------- /jiant/tasks/lib/glue_diagnostics.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from . import mnli 4 | 5 | 6 | @dataclass 7 | class Example(mnli.Example): 8 | pass 9 | 10 | 11 | @dataclass 12 | class TokenizedExample(mnli.TokenizedExample): 13 | pass 14 | 15 | 16 | @dataclass 17 | class DataRow(mnli.DataRow): 18 | pass 19 | 20 | 21 | @dataclass 22 | class Batch(mnli.Batch): 23 | pass 24 | 25 | 26 | class GlueDiagnosticsTask(mnli.MnliTask): 27 | def get_train_examples(self): 28 | raise RuntimeError("This task does not support training examples") 29 | 30 | def get_val_examples(self): 31 | raise RuntimeError("This task does not support validation examples") 32 | -------------------------------------------------------------------------------- /jiant/tasks/lib/hellaswag.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jiant.tasks.lib.templates.shared import labels_to_bimap 4 | from jiant.tasks.lib.templates import multiple_choice as mc_template 5 | from jiant.utils.python.io import read_json_lines 6 | 7 | 8 | @dataclass 9 | class Example(mc_template.Example): 10 | @property 11 | def task(self): 12 | return HellaSwagTask 13 | 14 | 15 | @dataclass 16 | class TokenizedExample(mc_template.TokenizedExample): 17 | pass 18 | 19 | 20 | @dataclass 21 | class DataRow(mc_template.DataRow): 22 | pass 23 | 24 | 25 | @dataclass 26 | class Batch(mc_template.Batch): 27 | pass 28 | 29 | 30 | class HellaSwagTask(mc_template.AbstractMultipleChoiceTask): 31 | Example = Example 32 | TokenizedExample = Example 33 | DataRow = DataRow 34 | Batch = Batch 35 | 36 | CHOICE_KEYS = [0, 1, 2, 3] 37 | 
CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS) 38 | NUM_CHOICES = len(CHOICE_KEYS) 39 | 40 | def get_train_examples(self): 41 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 42 | 43 | def get_val_examples(self): 44 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 45 | 46 | def get_test_examples(self): 47 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 48 | 49 | @classmethod 50 | def _create_examples(cls, lines, set_type): 51 | examples = [] 52 | for i, line in enumerate(lines): 53 | examples.append( 54 | Example( 55 | guid="%s-%s" % (set_type, i), 56 | prompt=line["ctx_a"], 57 | choice_list=[line["ctx_b"] + " " + ending for ending in line["endings"]], 58 | label=line["label"] if set_type != "test" else cls.CHOICE_KEYS[-1], 59 | ) 60 | ) 61 | return examples 62 | -------------------------------------------------------------------------------- /jiant/tasks/lib/mcscript.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jiant.tasks.lib.templates.shared import labels_to_bimap 4 | from jiant.tasks.lib.templates import multiple_choice as mc_template 5 | from jiant.utils.python.io import read_json_lines 6 | 7 | 8 | @dataclass 9 | class Example(mc_template.Example): 10 | @property 11 | def task(self): 12 | return MCScriptTask 13 | 14 | 15 | @dataclass 16 | class TokenizedExample(mc_template.TokenizedExample): 17 | pass 18 | 19 | 20 | @dataclass 21 | class DataRow(mc_template.DataRow): 22 | pass 23 | 24 | 25 | @dataclass 26 | class Batch(mc_template.Batch): 27 | pass 28 | 29 | 30 | class MCScriptTask(mc_template.AbstractMultipleChoiceTask): 31 | Example = Example 32 | TokenizedExample = Example 33 | DataRow = DataRow 34 | Batch = Batch 35 | 36 | CHOICE_KEYS = [0, 1] 37 | CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS) 38 | NUM_CHOICES = len(CHOICE_KEYS) 39 
| 40 | def get_train_examples(self): 41 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 42 | 43 | def get_val_examples(self): 44 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 45 | 46 | def get_test_examples(self): 47 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 48 | 49 | @classmethod 50 | def _create_examples(cls, lines, set_type): 51 | examples = [] 52 | for line in lines: 53 | passage = line["passage"]["text"] 54 | passage_id = line["idx"] 55 | for question_dict in line["passage"]["questions"]: 56 | question = question_dict["question"] 57 | question_id = question_dict["idx"] 58 | answer_dicts = question_dict["answers"] 59 | examples.append( 60 | Example( 61 | guid="%s-%s-%s" % (set_type, passage_id, question_id), 62 | prompt=passage, 63 | choice_list=[ 64 | question + " " + answer_dict["text"] for answer_dict in answer_dicts 65 | ], 66 | label=answer_dicts[1]["label"] == "True" 67 | if set_type != "test" 68 | else cls.CHOICE_KEYS[-1], 69 | ) 70 | ) 71 | 72 | return examples 73 | -------------------------------------------------------------------------------- /jiant/tasks/lib/mctest.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jiant.tasks.lib.templates.shared import labels_to_bimap 4 | from jiant.tasks.lib.templates import multiple_choice as mc_template 5 | from jiant.utils.python.io import read_file_lines 6 | 7 | 8 | @dataclass 9 | class Example(mc_template.Example): 10 | @property 11 | def task(self): 12 | return MCTestTask 13 | 14 | 15 | @dataclass 16 | class TokenizedExample(mc_template.TokenizedExample): 17 | pass 18 | 19 | 20 | @dataclass 21 | class DataRow(mc_template.DataRow): 22 | pass 23 | 24 | 25 | @dataclass 26 | class Batch(mc_template.Batch): 27 | pass 28 | 29 | 30 | class MCTestTask(mc_template.AbstractMultipleChoiceTask): 31 | Example = 
Example 32 | TokenizedExample = TokenizedExample 33 | DataRow = DataRow 34 | Batch = Batch 35 | 36 | CHOICE_KEYS = ["A", "B", "C", "D"] 37 | CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS) 38 | NUM_CHOICES = len(CHOICE_KEYS) 39 | 40 | def get_train_examples(self): 41 | return self._create_examples( 42 | lines=read_file_lines(self.train_path, strip_lines=True), 43 | ans_lines=read_file_lines(self.path_dict["train_ans"], strip_lines=True), 44 | set_type="train", 45 | ) 46 | 47 | def get_val_examples(self): 48 | return self._create_examples( 49 | lines=read_file_lines(self.val_path, strip_lines=True), 50 | ans_lines=read_file_lines(self.path_dict["val_ans"], strip_lines=True), 51 | set_type="val", 52 | ) 53 | 54 | def get_test_examples(self): 55 | return self._create_examples( 56 | lines=read_file_lines(self.test_path, strip_lines=True), 57 | ans_lines=None, 58 | set_type="test", 59 | ) 60 | 61 | @classmethod 62 | def _create_examples(cls, lines, ans_lines, set_type): 63 | examples = [] 64 | if ans_lines is None: 65 | ans_lines = ["\t".join([cls.CHOICE_KEYS[-1]] * 4) for line in lines] 66 | for i, (line, ans) in enumerate(zip(lines, ans_lines)): 67 | line = line.split("\t") 68 | ans = ans.split("\t") 69 | for j in range(4): 70 | examples.append( 71 | Example( 72 | guid="%s-%s" % (set_type, i * 4 + j), 73 | prompt=line[2].replace("\\newline", " ") + " " + line[3 + j * 5], 74 | choice_list=line[4 + j * 5 : 8 + j * 5], 75 | label=ans[j], 76 | ) 77 | ) 78 | return examples 79 | -------------------------------------------------------------------------------- /jiant/tasks/lib/mlm_premasked.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | 4 | import jiant.utils.python.io as py_io 5 | from jiant.tasks.core import ( 6 | Task, 7 | TaskTypes, 8 | BaseExample, 9 | ) 10 | from jiant.tasks.utils import ExclusiveSpan 11 | from .templates import mlm_premasked as 
mlm_premasked_template 12 | 13 | 14 | @dataclass 15 | class Example(BaseExample): 16 | guid: str 17 | text: str 18 | # Spans over char indices 19 | masked_spans: List[ExclusiveSpan] 20 | 21 | def tokenize(self, tokenizer): 22 | # masked_tokens will be regular tokens except with tokenizer.mask_token for masked spans 23 | # label_tokens will be tokenizer.pad_token except with the regular tokens for masked spans 24 | masked_tokens = [] 25 | label_tokens = [] 26 | curr = 0 27 | for start, end in self.masked_spans: 28 | # Handle text before next mask 29 | tokenized_text = tokenizer.tokenize(self.text[curr:start]) 30 | masked_tokens += tokenized_text 31 | label_tokens += [tokenizer.pad_token] * len(tokenized_text) 32 | 33 | # Handle mask 34 | tokenized_masked_text = tokenizer.tokenize(self.text[start:end]) 35 | masked_tokens += [tokenizer.mask_token] * len(tokenized_masked_text) 36 | label_tokens += tokenized_masked_text 37 | curr = end 38 | if curr < len(self.text): 39 | tokenized_text = tokenizer.tokenize(self.text[curr:]) 40 | masked_tokens += tokenized_text 41 | label_tokens += [tokenizer.pad_token] * len(tokenized_text) 42 | 43 | return TokenizedExample( 44 | guid=self.guid, 45 | masked_tokens=masked_tokens, 46 | label_tokens=label_tokens, 47 | ) 48 | 49 | 50 | @dataclass 51 | class TokenizedExample(mlm_premasked_template.TokenizedExample): 52 | pass 53 | 54 | 55 | @dataclass 56 | class DataRow(mlm_premasked_template.BaseDataRow): 57 | pass 58 | 59 | 60 | @dataclass 61 | class Batch(mlm_premasked_template.Batch): 62 | pass 63 | 64 | 65 | class MLMPremaskedTask(Task): 66 | Example = Example 67 | TokenizedExample = TokenizedExample 68 | DataRow = DataRow 69 | Batch = Batch 70 | 71 | TASK_TYPE = TaskTypes.MASKED_LANGUAGE_MODELING 72 | 73 | def __init__(self, name, path_dict): 74 | super().__init__(name=name, path_dict=path_dict) 75 | self.mlm_probability = None 76 | self.do_mask = False 77 | 78 | def get_train_examples(self): 79 | return 
self._create_examples(path=self.train_path, set_type="train") 80 | 81 | def get_val_examples(self): 82 | return self._create_examples(path=self.val_path, set_type="val") 83 | 84 | def get_test_examples(self): 85 | return self._create_examples(path=self.test_path, set_type="test") 86 | 87 | @classmethod 88 | def _create_examples(cls, path, set_type): 89 | for i, row in enumerate(py_io.read_jsonl(path)): 90 | yield Example( 91 | guid="%s-%s" % (set_type, i), 92 | text=row["text"], 93 | masked_spans=[ExclusiveSpan(start, end) for start, end in row["masked_spans"]], 94 | ) 95 | -------------------------------------------------------------------------------- /jiant/tasks/lib/mlm_simple.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jiant.utils.python.datastructures import ReusableGenerator 4 | from jiant.tasks.lib.templates import mlm as mlm_template 5 | 6 | 7 | @dataclass 8 | class Example(mlm_template.Example): 9 | pass 10 | 11 | 12 | @dataclass 13 | class TokenizedExample(mlm_template.TokenizedExample): 14 | pass 15 | 16 | 17 | @dataclass 18 | class DataRow(mlm_template.DataRow): 19 | pass 20 | 21 | 22 | @dataclass 23 | class Batch(mlm_template.Batch): 24 | pass 25 | 26 | 27 | @dataclass 28 | class MaskedBatch(mlm_template.MaskedBatch): 29 | pass 30 | 31 | 32 | class MLMSimpleTask(mlm_template.MLMTask): 33 | """Simple implementation of MLM task 34 | - Reads from a single file per phase 35 | - One example per line (examples that are too long will be truncated) 36 | - Empty lines are skipped. 
37 | """ 38 | 39 | Example = Example 40 | TokenizedExample = Example 41 | DataRow = DataRow 42 | Batch = Batch 43 | 44 | def __init__(self, name, path_dict, mlm_probability=0.15, do_mask=True): 45 | super().__init__(name=name, path_dict=path_dict) 46 | self.mlm_probability = mlm_probability 47 | self.do_mask = do_mask 48 | 49 | def get_train_examples(self): 50 | return self._create_examples(path=self.train_path, set_type="train", return_generator=True) 51 | 52 | def get_val_examples(self): 53 | return self._create_examples(path=self.val_path, set_type="val", return_generator=True) 54 | 55 | def get_test_examples(self): 56 | return self._create_examples(path=self.test_path, set_type="test", return_generator=True) 57 | 58 | @classmethod 59 | def _get_examples_generator(cls, path, set_type): 60 | with open(path, "r") as f: 61 | for (i, line) in enumerate(f): 62 | line = line.strip() 63 | if not line: 64 | continue 65 | yield Example( 66 | guid="%s-%s" % (set_type, i), 67 | text=line, 68 | ) 69 | 70 | @classmethod 71 | def _create_examples(cls, path, set_type, return_generator): 72 | generator = ReusableGenerator(cls._get_examples_generator, path=path, set_type=set_type) 73 | if return_generator: 74 | return generator 75 | else: 76 | return list(generator) 77 | -------------------------------------------------------------------------------- /jiant/tasks/lib/mnli_mismatched.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from . 
# ---- jiant/tasks/lib/mnli_mismatched.py (continued) ----
from . import mnli


@dataclass
class Example(mnli.Example):
    """Structurally identical to the MNLI example."""

    pass


@dataclass
class TokenizedExample(mnli.TokenizedExample):
    pass


@dataclass
class DataRow(mnli.DataRow):
    pass


@dataclass
class Batch(mnli.Batch):
    pass


class MnliMismatchedTask(mnli.MnliTask):
    """Evaluation-only MNLI-mismatched split: identical to MNLI except that
    requesting training data is an error."""

    def get_train_examples(self):
        raise RuntimeError("This task does not support training examples")


# ---- jiant/tasks/lib/mrpc.py ----
import numpy as np
import torch
from dataclasses import dataclass
from typing import List

from jiant.tasks.core import (
    BaseExample,
    BaseTokenizedExample,
    BaseDataRow,
    BatchMixin,
    GlueMixin,
    Task,
    TaskTypes,
)
from jiant.tasks.lib.templates.shared import double_sentence_featurize, labels_to_bimap
from jiant.utils.python.io import read_jsonl


@dataclass
class Example(BaseExample):
    """One raw MRPC sentence pair with a string label ("0"/"1")."""

    guid: str
    text_a: str
    text_b: str
    label: str

    def tokenize(self, tokenizer):
        """Tokenize both sentences and map the label to its integer id."""
        tokens_a = tokenizer.tokenize(self.text_a)
        tokens_b = tokenizer.tokenize(self.text_b)
        return TokenizedExample(
            guid=self.guid,
            text_a=tokens_a,
            text_b=tokens_b,
            label_id=MrpcTask.LABEL_TO_ID[self.label],
        )


@dataclass
class TokenizedExample(BaseTokenizedExample):
    """Tokenized MRPC pair; label already mapped to an int id."""

    guid: str
    text_a: List
    text_b: List
    label_id: int

    def featurize(self, tokenizer, feat_spec):
        """Pack the token pair into model-ready arrays (ids/mask/segments)."""
        return double_sentence_featurize(
            guid=self.guid,
            input_tokens_a=self.text_a,
            input_tokens_b=self.text_b,
            label_id=self.label_id,
            tokenizer=tokenizer,
            feat_spec=feat_spec,
            data_row_class=DataRow,
        )
@dataclass
class DataRow(BaseDataRow):
    """Featurized MRPC row, ready to batch."""

    guid: str
    input_ids: np.ndarray
    input_mask: np.ndarray
    segment_ids: np.ndarray
    label_id: int
    tokens: list


@dataclass
class Batch(BatchMixin):
    """Tensor batch of featurized MRPC rows."""

    input_ids: torch.LongTensor
    input_mask: torch.LongTensor
    segment_ids: torch.LongTensor
    label_id: torch.LongTensor
    tokens: list


class MrpcTask(GlueMixin, Task):
    """MRPC (GLUE) paraphrase-detection task: binary sentence-pair classification."""

    Example = Example
    TokenizedExample = Example
    DataRow = DataRow
    Batch = Batch

    TASK_TYPE = TaskTypes.CLASSIFICATION
    LABELS = ["0", "1"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)

    def get_train_examples(self):
        return self._create_examples(lines=read_jsonl(self.train_path), set_type="train")

    def get_val_examples(self):
        return self._create_examples(lines=read_jsonl(self.val_path), set_type="val")

    def get_test_examples(self):
        return self._create_examples(lines=read_jsonl(self.test_path), set_type="test")

    @classmethod
    def _create_examples(cls, lines, set_type):
        # NOTE: get_glue_preds() is dependent on this guid format.
        examples = []
        for row_idx, row in enumerate(lines):
            # Test rows carry no usable label; substitute the last LABELS entry.
            label = row["label"] if set_type != "test" else cls.LABELS[-1]
            examples.append(
                Example(
                    guid="%s-%s" % (set_type, row_idx),
                    text_a=row["text_a"],
                    text_b=row["text_b"],
                    label=label,
                )
            )
        return examples


# ---- jiant/tasks/lib/mutual.py ----
from dataclasses import dataclass

from jiant.tasks.lib.templates.shared import labels_to_bimap
from jiant.tasks.lib.templates import multiple_choice as mc_template
from jiant.utils.python.io import read_json_lines


@dataclass
class Example(mc_template.Example):
    """MuTual raw example; `task` links back to the owning task class."""

    @property
    def task(self):
        return MutualTask


@dataclass
class TokenizedExample(mc_template.TokenizedExample):
    pass


@dataclass
class DataRow(mc_template.DataRow):
    pass


@dataclass
class Batch(mc_template.Batch):
    pass
class MutualTask(mc_template.AbstractMultipleChoiceTask):
    """MuTual dialogue-reasoning task: choose the correct response (A-D)."""

    Example = Example
    TokenizedExample = Example
    DataRow = DataRow
    Batch = Batch

    CHOICE_KEYS = ["A", "B", "C", "D"]
    CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS)
    NUM_CHOICES = len(CHOICE_KEYS)

    def get_train_examples(self):
        return self._create_examples(lines=read_json_lines(self.train_path), set_type="train")

    def get_val_examples(self):
        return self._create_examples(lines=read_json_lines(self.val_path), set_type="val")

    def get_test_examples(self):
        return self._create_examples(lines=read_json_lines(self.test_path), set_type="test")

    @classmethod
    def _create_examples(cls, lines, set_type):
        examples = []
        for row_idx, row in enumerate(lines):
            examples.append(
                Example(
                    guid="%s-%s" % (set_type, row_idx),
                    prompt=row["article"],
                    choice_list=list(row["options"]),
                    label=row["answers"],
                )
            )
        return examples


# ---- jiant/tasks/lib/mutual_plus.py ----
from dataclasses import dataclass

from jiant.tasks.lib.templates.shared import labels_to_bimap
from jiant.tasks.lib.templates import multiple_choice as mc_template
from jiant.utils.python.io import read_json_lines


@dataclass
class Example(mc_template.Example):
    """MuTual-plus raw example; `task` links back to the owning task class."""

    @property
    def task(self):
        return MutualPlusTask


@dataclass
class TokenizedExample(mc_template.TokenizedExample):
    pass


@dataclass
class DataRow(mc_template.DataRow):
    pass


@dataclass
class Batch(mc_template.Batch):
    pass


class MutualPlusTask(mc_template.AbstractMultipleChoiceTask):
    """MuTual-plus: same structure as MuTual, harder distractor set."""

    Example = Example
    TokenizedExample = Example
    DataRow = DataRow
    Batch = Batch

    CHOICE_KEYS = ["A", "B", "C", "D"]
    CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS)
    NUM_CHOICES = len(CHOICE_KEYS)

    def get_train_examples(self):
        return self._create_examples(lines=read_json_lines(self.train_path), set_type="train")

    def get_val_examples(self):
        return self._create_examples(lines=read_json_lines(self.val_path), set_type="val")

    def get_test_examples(self):
        return self._create_examples(lines=read_json_lines(self.test_path), set_type="test")

    @classmethod
    def _create_examples(cls, lines, set_type):
        examples = []
        for row_idx, row in enumerate(lines):
            examples.append(
                Example(
                    guid="%s-%s" % (set_type, row_idx),
                    prompt=row["article"],
                    choice_list=list(row["options"]),
                    label=row["answers"],
                )
            )
        return examples
# ---- jiant/tasks/lib/newsqa.py ----
from dataclasses import dataclass

from jiant.shared.constants import PHASE
from jiant.tasks.lib.templates.squad_style import core as squad_style_template
from jiant.utils.python.io import read_jsonl


@dataclass
class Example(squad_style_template.Example):
    def tokenize(self, tokenizer):
        # Per-example tokenization is deliberately disabled for squad-style
        # tasks; processing happens in the squad_style template instead.
        raise NotImplementedError("SQuaD is weird")


@dataclass
class DataRow(squad_style_template.DataRow):
    pass


@dataclass
class Batch(squad_style_template.Batch):
    pass


class NewsQATask(squad_style_template.BaseSquadStyleTask):
    """NewsQA extractive-QA task in SQuAD format (train/val only here)."""

    Example = Example
    DataRow = DataRow
    Batch = Batch

    def get_train_examples(self):
        return self.read_examples(path=self.train_path, set_type=PHASE.TRAIN)

    def get_val_examples(self):
        return self.read_examples(path=self.val_path, set_type=PHASE.VAL)

    @classmethod
    def read_examples(cls, path: str, set_type: str):
        """Flatten each article's QA pairs into one Example per question."""
        examples = []
        for entry in read_jsonl(path):
            context = entry["text"]
            for qa in entry["qas"]:
                start, end = qa["answer"]["s"], qa["answer"]["e"]
                answer_text = context[start:end]
                # qas_id uses the running example count, not a per-article index.
                examples.append(
                    Example(
                        qas_id=f"{set_type}-{len(examples)}",
                        question_text=qa["question"],
                        context_text=context,
                        answer_text=answer_text,
                        start_position_character=start,
                        title="",
                        is_impossible=False,
                        answers=[{"answer_start": start, "text": answer_text}],
                    )
                )
        return examples
# ---- jiant/tasks/lib/piqa.py ----
from dataclasses import dataclass

from jiant.tasks.lib.templates.shared import labels_to_bimap
from jiant.tasks.lib.templates import multiple_choice as mc_template
from jiant.utils.python.io import read_json_lines, read_file_lines


@dataclass
class Example(mc_template.Example):
    """PIQA raw example; `task` links back to the owning task class."""

    @property
    def task(self):
        return PiqaTask


@dataclass
class TokenizedExample(mc_template.TokenizedExample):
    pass


@dataclass
class DataRow(mc_template.DataRow):
    pass


@dataclass
class Batch(mc_template.Batch):
    pass


class PiqaTask(mc_template.AbstractMultipleChoiceTask):
    """PIQA physical-commonsense task: pick the right solution (0/1).

    Data and labels live in separate files; labels are one integer per line.
    """

    Example = Example
    TokenizedExample = Example
    DataRow = DataRow
    Batch = Batch

    CHOICE_KEYS = [0, 1]
    CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS)
    NUM_CHOICES = len(CHOICE_KEYS)

    def get_train_examples(self):
        return self._create_examples(
            lines=zip(
                read_json_lines(self.train_path),
                read_file_lines(self.path_dict["train_labels"], strip_lines=True),
            ),
            set_type="train",
        )

    def get_val_examples(self):
        return self._create_examples(
            lines=zip(
                read_json_lines(self.val_path),
                read_file_lines(self.path_dict["val_labels"], strip_lines=True),
            ),
            set_type="val",
        )

    def get_test_examples(self):
        # No labels exist for test: zip the data with itself so the pair
        # shape matches; the dummy second element is ignored below.
        return self._create_examples(
            lines=zip(read_json_lines(self.test_path), read_json_lines(self.test_path)),
            set_type="test",
        )

    @classmethod
    def _create_examples(cls, lines, set_type):
        examples = []

        for idx, (row, label_string) in enumerate(lines):
            # Test rows get the last choice key as a placeholder label.
            label = int(label_string) if set_type != "test" else cls.CHOICE_KEYS[-1]
            examples.append(
                Example(
                    guid="%s-%s" % (set_type, idx),
                    prompt=row["goal"],
                    choice_list=[row["sol1"], row["sol2"]],
                    label=label,
                )
            )

        return examples
# ---- jiant/tasks/lib/qamr.py ----
import pandas as pd
import nltk

from jiant.tasks.lib.templates import span_prediction as span_pred_template
from jiant.utils.retokenize import TokenAligner


class QAMRTask(span_pred_template.AbstractSpanPredictionTask):
    """QAMR span-prediction task.

    Answers arrive as PTB-token indices into a source sentence; each example
    is converted to a character span over a detokenized ("space-token") form
    of the sentence so it fits the span-prediction template.
    """

    def get_train_examples(self):
        return self._create_examples(self.train_path, set_type="train")

    def get_val_examples(self):
        return self._create_examples(self.val_path, set_type="val")

    def get_test_examples(self):
        return self._create_examples(self.test_path, set_type="test")

    def _create_examples(self, qa_file_path, set_type):
        # Build a sent_id -> sentence-text lookup from the wiki TSV.
        wiki_df = pd.read_csv(self.path_dict["wiki_dict"], sep="\t", names=["sent_id", "text"])
        wiki_dict = {row.sent_id: row.text for row in wiki_df.itertuples(index=False)}

        data_df = pd.read_csv(
            qa_file_path,
            sep="\t",
            header=None,
            names=[
                "sent_id",
                "target_ids",
                "worker_id",
                "qa_index",
                "qa_word",
                "question",
                "answer",
                "response1",
                "response2",
            ],
        )
        # Attach the referenced sentence text to each QA row.
        data_df["sent"] = data_df["sent_id"].apply(wiki_dict.get)

        examples = []
        ptb_detokenizer = nltk.tokenize.treebank.TreebankWordDetokenizer()
        for i, row in enumerate(data_df.itertuples(index=False)):
            # Answer indices are a space-delimited list of numbers.
            # We simply take the min/max of the indices
            answer_idxs = list(map(int, row.answer.split()))
            answer_token_start, answer_token_end = min(answer_idxs), max(answer_idxs)
            # Detokenize the PTB tokens, then re-split on whitespace to get
            # "space tokens" aligned with the detokenized string.
            passage_ptb_tokens = row.sent.split()
            passage_space_tokens = ptb_detokenizer.detokenize(
                passage_ptb_tokens, convert_parentheses=True
            ).split()
            passage_space_str = " ".join(passage_space_tokens)

            # Project the PTB-token answer span onto a character span in the
            # detokenized passage (inclusive on both ends).
            token_aligner = TokenAligner(source=passage_ptb_tokens, target=passage_space_tokens)
            answer_char_span = token_aligner.project_token_to_char_span(
                answer_token_start, answer_token_end, inclusive=True
            )
            # +1 because the char span end is inclusive.
            answer_str = passage_space_str[answer_char_span[0] : answer_char_span[1] + 1]

            examples.append(
                span_pred_template.Example(
                    guid="%s-%s" % (set_type, i),
                    passage=passage_space_str,
                    question=row.question,
                    answer=answer_str,
                    answer_char_span=answer_char_span,
                )
            )

        return examples


# ---- jiant/tasks/lib/qqp.py ----
import numpy as np
import torch
from dataclasses import dataclass
from typing import List

from jiant.tasks.core import (
    BaseExample,
    BaseTokenizedExample,
    BaseDataRow,
    BatchMixin,
    GlueMixin,
    Task,
    TaskTypes,
)
from jiant.tasks.lib.templates.shared import double_sentence_featurize, labels_to_bimap
from jiant.utils.python.io import read_jsonl


@dataclass
class Example(BaseExample):
    """One raw QQP question pair with a string label ("0"/"1")."""

    guid: str
    text_a: str
    text_b: str
    label: str

    def tokenize(self, tokenizer):
        """Tokenize both questions and map the label to its integer id."""
        return TokenizedExample(
            guid=self.guid,
            text_a=tokenizer.tokenize(self.text_a),
            text_b=tokenizer.tokenize(self.text_b),
            label_id=QqpTask.LABEL_TO_ID[self.label],
        )
@dataclass
class TokenizedExample(BaseTokenizedExample):
    """Tokenized QQP pair; label already mapped to an int id."""

    guid: str
    text_a: List
    text_b: List
    label_id: int

    def featurize(self, tokenizer, feat_spec):
        """Pack the token pair into model-ready arrays (ids/mask/segments)."""
        return double_sentence_featurize(
            guid=self.guid,
            input_tokens_a=self.text_a,
            input_tokens_b=self.text_b,
            label_id=self.label_id,
            tokenizer=tokenizer,
            feat_spec=feat_spec,
            data_row_class=DataRow,
        )


@dataclass
class DataRow(BaseDataRow):
    """Featurized QQP row, ready to batch."""

    guid: str
    input_ids: np.ndarray
    input_mask: np.ndarray
    segment_ids: np.ndarray
    label_id: int
    tokens: list


@dataclass
class Batch(BatchMixin):
    """Tensor batch of featurized QQP rows."""

    input_ids: torch.LongTensor
    input_mask: torch.LongTensor
    segment_ids: torch.LongTensor
    label_id: torch.LongTensor
    tokens: list


class QqpTask(GlueMixin, Task):
    """QQP (GLUE) duplicate-question task: binary sentence-pair classification."""

    Example = Example
    TokenizedExample = Example
    DataRow = DataRow
    Batch = Batch

    TASK_TYPE = TaskTypes.CLASSIFICATION
    LABELS = ["0", "1"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)

    def get_train_examples(self):
        return self._create_examples(lines=read_jsonl(self.train_path), set_type="train")

    def get_val_examples(self):
        return self._create_examples(lines=read_jsonl(self.val_path), set_type="val")

    def get_test_examples(self):
        return self._create_examples(lines=read_jsonl(self.test_path), set_type="test")

    @classmethod
    def _create_examples(cls, lines, set_type):
        examples = []
        for row_idx, row in enumerate(lines):
            # Test rows carry no usable label; substitute the last LABELS entry.
            label = row["label"] if set_type != "test" else cls.LABELS[-1]
            examples.append(
                Example(
                    guid="%s-%s" % (set_type, row_idx),
                    text_a=row["text_a"],
                    text_b=row["text_b"],
                    label=label,
                )
            )
        return examples
# ---- jiant/tasks/lib/quail.py ----
from dataclasses import dataclass

from jiant.tasks.lib.templates.shared import labels_to_bimap
from jiant.tasks.lib.templates import multiple_choice as mc_template
from jiant.utils.python.io import read_json_lines


@dataclass
class Example(mc_template.Example):
    """QuAIL raw example; `task` links back to the owning task class."""

    @property
    def task(self):
        return QuailTask


@dataclass
class TokenizedExample(mc_template.TokenizedExample):
    pass


@dataclass
class DataRow(mc_template.DataRow):
    pass


@dataclass
class Batch(mc_template.Batch):
    pass


class QuailTask(mc_template.AbstractMultipleChoiceTask):
    """QuAIL reading-comprehension task: four-way multiple choice."""

    Example = Example
    TokenizedExample = Example
    DataRow = DataRow
    Batch = Batch

    CHOICE_KEYS = [0, 1, 2, 3]
    CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS)
    NUM_CHOICES = len(CHOICE_KEYS)

    def get_train_examples(self):
        return self._create_examples(lines=read_json_lines(self.train_path), set_type="train")

    def get_val_examples(self):
        return self._create_examples(lines=read_json_lines(self.val_path), set_type="val")

    def get_test_examples(self):
        return self._create_examples(lines=read_json_lines(self.test_path), set_type="test")

    @classmethod
    def _create_examples(cls, lines, set_type):
        examples = []
        for row_idx, row in enumerate(lines):
            # Prompt is the passage followed by the question.
            examples.append(
                Example(
                    guid="%s-%s" % (set_type, row_idx),
                    prompt=row["context"] + " " + row["question"],
                    choice_list=list(row["answers"]),
                    label=row["correct_answer_id"],
                )
            )
        return examples
# ---- jiant/tasks/lib/quoref.py ----
from dataclasses import dataclass

from jiant.shared.constants import PHASE
from jiant.tasks.lib.templates.squad_style import core as squad_style_template
from jiant.utils.python.io import read_jsonl


@dataclass
class Example(squad_style_template.Example):
    def tokenize(self, tokenizer):
        # Per-example tokenization is deliberately disabled for squad-style
        # tasks; processing happens in the squad_style template instead.
        raise NotImplementedError("SQuaD is weird")


@dataclass
class DataRow(squad_style_template.DataRow):
    pass


@dataclass
class Batch(squad_style_template.Batch):
    pass


class QuorefTask(squad_style_template.BaseSquadStyleTask):
    """Quoref extractive-QA task in SQuAD format (train/val only here)."""

    Example = Example
    DataRow = DataRow
    Batch = Batch

    def get_train_examples(self):
        return self.read_examples(path=self.train_path, set_type=PHASE.TRAIN)

    def get_val_examples(self):
        return self.read_examples(path=self.val_path, set_type=PHASE.VAL)

    @classmethod
    def read_examples(cls, path: str, set_type: str):
        """Read examples; training fans out one Example per gold answer,
        evaluation keeps one Example per question with all answers attached."""
        examples = []
        for i, line in enumerate(read_jsonl(path)):
            if set_type == PHASE.TRAIN:
                # One training Example per (answer_start, answer_text) pair.
                # NOTE(review): guid is f"{set_type}-{i}" for every answer of
                # question i, so multi-answer questions produce duplicate
                # guids — confirm downstream consumers tolerate this.
                for j, (answer_start, answer_text) in enumerate(
                    zip(line["answers"]["answer_start"], line["answers"]["text"])
                ):
                    examples.append(
                        Example(
                            qas_id=f"{set_type}-{i}",
                            question_text=line["question"],
                            context_text=line["context"],
                            answer_text=answer_text,
                            start_position_character=answer_start,
                            title=line["title"],
                            is_impossible=False,
                            answers=[],
                        )
                    )
            else:
                # Evaluation: collect all gold answers on a single Example.
                answers = [
                    {"answer_start": answer_start, "text": answer_text}
                    for answer_start, answer_text in zip(
                        line["answers"]["answer_start"], line["answers"]["text"]
                    )
                ]
                examples.append(
                    Example(
                        qas_id=f"{set_type}-{i}",
                        question_text=line["question"],
                        context_text=line["context"],
                        answer_text=None,
                        start_position_character=None,
                        title=line["title"],
                        is_impossible=False,
                        answers=answers,
                    )
                )
        return examples
# ---- jiant/tasks/lib/senteval/base.py ----
import numpy as np
import pandas as pd
import torch
from dataclasses import dataclass
from typing import List

from jiant.tasks.core import (
    BaseExample,
    BaseTokenizedExample,
    BaseDataRow,
    BatchMixin,
    Task,
    TaskTypes,
)
from jiant.tasks.lib.templates.shared import single_sentence_featurize


@dataclass
class Example(BaseExample):
    """Raw SentEval probing example: one sentence plus a string label."""

    guid: str
    text: str
    label: str

    @property
    def label_to_id(self):
        # Each concrete SentEval task supplies its own label mapping.
        raise NotImplementedError()

    def tokenize(self, tokenizer):
        """Tokenize the sentence and map the label via the subclass mapping."""
        return TokenizedExample(
            guid=self.guid,
            text=tokenizer.tokenize(self.text),
            label_id=self.label_to_id[self.label],
        )


@dataclass
class TokenizedExample(BaseTokenizedExample):
    """Tokenized probing example; label already an int id."""

    guid: str
    text: List
    label_id: int

    def featurize(self, tokenizer, feat_spec):
        """Pack the tokens into model-ready arrays (ids/mask/segments)."""
        return single_sentence_featurize(
            guid=self.guid,
            input_tokens=self.text,
            label_id=self.label_id,
            tokenizer=tokenizer,
            feat_spec=feat_spec,
            data_row_class=DataRow,
        )


@dataclass
class DataRow(BaseDataRow):
    """Featurized probing row, ready to batch."""

    guid: str
    input_ids: np.ndarray
    input_mask: np.ndarray
    segment_ids: np.ndarray
    label_id: int
    tokens: list


@dataclass
class Batch(BatchMixin):
    """Tensor batch of featurized probing rows."""

    input_ids: torch.LongTensor
    input_mask: torch.LongTensor
    segment_ids: torch.LongTensor
    label_id: torch.LongTensor
    tokens: list


class BaseSentEvalTask(Task):
    """Shared base for SentEval probing tasks.

    All phases live in ONE tab-separated file whose first column marks the
    phase ("tr"/"va"/"te"); subclasses only define LABELS and the mapping.
    """

    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    TASK_TYPE = TaskTypes.CLASSIFICATION
    LABELS = None  # Override this

    def get_train_examples(self):
        return self._create_examples(set_type="train")

    def get_val_examples(self):
        return self._create_examples(set_type="val")

    def get_test_examples(self):
        return self._create_examples(set_type="test")

    def _create_examples(self, set_type):
        """Load the single TSV and keep only rows for the requested phase."""
        phase_code = {"train": "tr", "val": "va", "test": "te"}[set_type]
        df = pd.read_csv(self.path_dict["data"], sep="\t", names=["phase", "label", "text"])
        examples = []
        # iterrows keeps the original dataframe index, which becomes part of
        # the guid.
        for idx, row in df[df["phase"] == phase_code].iterrows():
            # Test rows get the last LABELS entry as a placeholder label.
            if set_type != "test":
                label = row.label
            else:
                label = self.LABELS[-1]
            # noinspection PyArgumentList
            examples.append(
                self.Example(
                    guid="%s-%s" % (set_type, idx),
                    text=row.text,
                    label=label,
                )
            )
        return examples
# ---- jiant/tasks/lib/senteval/bigram_shift.py ----
# SentEval probing task: detect whether a bigram in the sentence was shifted.
from dataclasses import dataclass
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


@dataclass
class Example(base.Example):
    @property
    def label_to_id(self):
        return SentEvalBigramShiftTask.LABEL_TO_ID


@dataclass
class TokenizedExample(base.TokenizedExample):
    pass


@dataclass
class DataRow(base.DataRow):
    pass


@dataclass
class Batch(base.Batch):
    pass


class SentEvalBigramShiftTask(base.BaseSentEvalTask):
    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    # "I"/"O" — presumably inverted vs. original word order; confirm against
    # the SentEval data files.
    LABELS = ["I", "O"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)


# ---- jiant/tasks/lib/senteval/coordination_inversion.py ----
# SentEval probing task: detect inversion of coordinated clauses.
from dataclasses import dataclass
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


@dataclass
class Example(base.Example):
    @property
    def label_to_id(self):
        return SentEvalCoordinationInversionTask.LABEL_TO_ID


@dataclass
class TokenizedExample(base.TokenizedExample):
    pass


@dataclass
class DataRow(base.DataRow):
    pass


@dataclass
class Batch(base.Batch):
    pass


class SentEvalCoordinationInversionTask(base.BaseSentEvalTask):
    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    LABELS = ["I", "O"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)


# ---- jiant/tasks/lib/senteval/obj_number.py ----
# SentEval probing task: number of the object noun (singular/plural tag).
from dataclasses import dataclass
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


@dataclass
class Example(base.Example):
    @property
    def label_to_id(self):
        return SentEvalObjNumberTask.LABEL_TO_ID


@dataclass
class TokenizedExample(base.TokenizedExample):
    pass


@dataclass
class DataRow(base.DataRow):
    pass


@dataclass
class Batch(base.Batch):
    pass


class SentEvalObjNumberTask(base.BaseSentEvalTask):
    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    # PTB noun tags: NN = singular, NNS = plural.
    LABELS = ["NN", "NNS"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)


# ---- jiant/tasks/lib/senteval/odd_man_out.py ----
# SentEval probing task: detect a replaced ("odd") word in the sentence.
from dataclasses import dataclass
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


@dataclass
class Example(base.Example):
    @property
    def label_to_id(self):
        return SentEvalOddManOutTask.LABEL_TO_ID


@dataclass
class TokenizedExample(base.TokenizedExample):
    pass


@dataclass
class DataRow(base.DataRow):
    pass


@dataclass
class Batch(base.Batch):
    pass


class SentEvalOddManOutTask(base.BaseSentEvalTask):
    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    LABELS = ["C", "O"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)


# ---- jiant/tasks/lib/senteval/past_present.py ----
# SentEval probing task: tense of the main verb.
from dataclasses import dataclass
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


@dataclass
class Example(base.Example):
    @property
    def label_to_id(self):
        return SentEvalPastPresentTask.LABEL_TO_ID


@dataclass
class TokenizedExample(base.TokenizedExample):
    pass


@dataclass
class DataRow(base.DataRow):
    pass


@dataclass
class Batch(base.Batch):
    pass


class SentEvalPastPresentTask(base.BaseSentEvalTask):
    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    LABELS = ["PAST", "PRES"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)


# ---- jiant/tasks/lib/senteval/sentence_length.py ----
# SentEval probing task: binned sentence length (7 buckets).
from dataclasses import dataclass
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


@dataclass
class Example(base.Example):
    @property
    def label_to_id(self):
        return SentEvalSentenceLengthTask.LABEL_TO_ID


@dataclass
class TokenizedExample(base.TokenizedExample):
    pass


@dataclass
class DataRow(base.DataRow):
    pass


@dataclass
class Batch(base.Batch):
    pass


class SentEvalSentenceLengthTask(base.BaseSentEvalTask):
    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    LABELS = [0, 1, 2, 3, 4, 5, 6]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)


# ---- jiant/tasks/lib/senteval/subj_number.py ----
# SentEval probing task: number of the subject noun (singular/plural tag).
from dataclasses import dataclass
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


@dataclass
class Example(base.Example):
    @property
    def label_to_id(self):
        return SentEvalSubjNumberTask.LABEL_TO_ID


@dataclass
class TokenizedExample(base.TokenizedExample):
    pass


@dataclass
class DataRow(base.DataRow):
    pass


@dataclass
class Batch(base.Batch):
    pass


class SentEvalSubjNumberTask(base.BaseSentEvalTask):
    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    # PTB noun tags: NN = singular, NNS = plural.
    LABELS = ["NN", "NNS"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)
"WHNP_SQ_.", 55 | ] 56 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 57 | -------------------------------------------------------------------------------- /jiant/tasks/lib/senteval/tree_depth.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from . import base as base 3 | from jiant.tasks.lib.templates.shared import labels_to_bimap 4 | 5 | 6 | @dataclass 7 | class Example(base.Example): 8 | @property 9 | def label_to_id(self): 10 | return SentEvalTreeDepthTask.LABEL_TO_ID 11 | 12 | 13 | @dataclass 14 | class TokenizedExample(base.TokenizedExample): 15 | pass 16 | 17 | 18 | @dataclass 19 | class DataRow(base.DataRow): 20 | pass 21 | 22 | 23 | @dataclass 24 | class Batch(base.Batch): 25 | pass 26 | 27 | 28 | class SentEvalTreeDepthTask(base.BaseSentEvalTask): 29 | Example = Example 30 | TokenizedExample = TokenizedExample 31 | DataRow = DataRow 32 | Batch = Batch 33 | 34 | LABELS = [5, 6, 7, 8, 9, 10, 11] 35 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 36 | -------------------------------------------------------------------------------- /jiant/tasks/lib/socialiqa.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jiant.tasks.lib.templates.shared import labels_to_bimap 4 | from jiant.tasks.lib.templates import multiple_choice as mc_template 5 | from jiant.utils.python.io import read_json_lines, read_file_lines 6 | 7 | 8 | @dataclass 9 | class Example(mc_template.Example): 10 | @property 11 | def task(self): 12 | return SocialIQATask 13 | 14 | 15 | @dataclass 16 | class TokenizedExample(mc_template.TokenizedExample): 17 | pass 18 | 19 | 20 | @dataclass 21 | class DataRow(mc_template.DataRow): 22 | pass 23 | 24 | 25 | @dataclass 26 | class Batch(mc_template.Batch): 27 | pass 28 | 29 | 30 | class SocialIQATask(mc_template.AbstractMultipleChoiceTask): 31 | Example = Example 32 | TokenizedExample = 
Example 33 | DataRow = DataRow 34 | Batch = Batch 35 | 36 | CHOICE_KEYS = ["A", "B", "C"] 37 | CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS) 38 | NUM_CHOICES = len(CHOICE_KEYS) 39 | 40 | def get_train_examples(self): 41 | return self._create_examples(lines=read_json_lines(self.train_path), set_type="train") 42 | 43 | def get_val_examples(self): 44 | return self._create_examples(lines=read_json_lines(self.val_path), set_type="val") 45 | 46 | def get_test_examples(self): 47 | return self._create_examples(lines=read_json_lines(self.test_path), set_type="test") 48 | 49 | @classmethod 50 | def _create_examples(cls, lines, set_type): 51 | examples = [] 52 | answer_key_ls = ["answerA", "answerB", "answerC"] 53 | hf_datasets_label_map = { 54 | "1\n": "A", 55 | "2\n": "B", 56 | "3\n": "C", 57 | } 58 | for i, line in enumerate(lines): 59 | if "label" in line: 60 | # Loading from HF Datasets data 61 | label = hf_datasets_label_map[line["label"]] 62 | else: 63 | # Loading from original data 64 | label = line["correct"] 65 | examples.append( 66 | Example( 67 | guid="%s-%s" % (set_type, i), 68 | prompt=line["context"] + " " + line["question"], 69 | choice_list=[line[answer_key] for answer_key in answer_key_ls], 70 | label=label, 71 | ) 72 | ) 73 | return examples 74 | 75 | @classmethod 76 | def _read_labels(cls, path): 77 | lines = read_file_lines(path) 78 | return [int(line.strip()) for line in lines] 79 | -------------------------------------------------------------------------------- /jiant/tasks/lib/squad.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jiant.tasks.lib.templates.squad_style import core as squad_style_template 4 | 5 | 6 | @dataclass 7 | class Example(squad_style_template.Example): 8 | def tokenize(self, tokenizer): 9 | raise NotImplementedError("SQuaD is weird") 10 | 11 | 12 | @dataclass 13 | class DataRow(squad_style_template.DataRow): 14 | pass 15 | 16 | 17 | 
@dataclass 18 | class Batch(squad_style_template.Batch): 19 | pass 20 | 21 | 22 | class SquadTask(squad_style_template.BaseSquadStyleTask): 23 | Example = Example 24 | DataRow = DataRow 25 | Batch = Batch 26 | -------------------------------------------------------------------------------- /jiant/tasks/lib/sst.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from dataclasses import dataclass 4 | from typing import List 5 | 6 | from jiant.tasks.core import ( 7 | BaseExample, 8 | BaseTokenizedExample, 9 | BaseDataRow, 10 | BatchMixin, 11 | GlueMixin, 12 | Task, 13 | TaskTypes, 14 | ) 15 | from jiant.tasks.lib.templates.shared import single_sentence_featurize, labels_to_bimap 16 | from jiant.utils.python.io import read_jsonl 17 | 18 | 19 | @dataclass 20 | class Example(BaseExample): 21 | guid: str 22 | text: str 23 | label: str 24 | 25 | def tokenize(self, tokenizer): 26 | return TokenizedExample( 27 | guid=self.guid, 28 | text=tokenizer.tokenize(self.text), 29 | label_id=SstTask.LABEL_TO_ID[self.label], 30 | ) 31 | 32 | 33 | @dataclass 34 | class TokenizedExample(BaseTokenizedExample): 35 | guid: str 36 | text: List 37 | label_id: int 38 | 39 | def featurize(self, tokenizer, feat_spec): 40 | return single_sentence_featurize( 41 | guid=self.guid, 42 | input_tokens=self.text, 43 | label_id=self.label_id, 44 | tokenizer=tokenizer, 45 | feat_spec=feat_spec, 46 | data_row_class=DataRow, 47 | ) 48 | 49 | 50 | @dataclass 51 | class DataRow(BaseDataRow): 52 | guid: str 53 | input_ids: np.ndarray 54 | input_mask: np.ndarray 55 | segment_ids: np.ndarray 56 | label_id: int 57 | tokens: list 58 | 59 | 60 | @dataclass 61 | class Batch(BatchMixin): 62 | input_ids: torch.LongTensor 63 | input_mask: torch.LongTensor 64 | segment_ids: torch.LongTensor 65 | label_id: torch.LongTensor 66 | tokens: list 67 | 68 | 69 | class SstTask(GlueMixin, Task): 70 | Example = Example 71 | TokenizedExample = Example 72 | 
DataRow = DataRow 73 | Batch = Batch 74 | 75 | TASK_TYPE = TaskTypes.CLASSIFICATION 76 | LABELS = ["0", "1"] 77 | LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS) 78 | 79 | def get_train_examples(self): 80 | return self._create_examples(lines=read_jsonl(self.train_path), set_type="train") 81 | 82 | def get_val_examples(self): 83 | return self._create_examples(lines=read_jsonl(self.val_path), set_type="val") 84 | 85 | def get_test_examples(self): 86 | return self._create_examples(lines=read_jsonl(self.test_path), set_type="test") 87 | 88 | @classmethod 89 | def _create_examples(cls, lines, set_type): 90 | examples = [] 91 | for (i, line) in enumerate(lines): 92 | examples.append( 93 | Example( 94 | guid="%s-%s" % (set_type, i), 95 | text=line["text"], 96 | label=line["label"] if set_type != "test" else cls.LABELS[-1], 97 | ) 98 | ) 99 | return examples 100 | -------------------------------------------------------------------------------- /jiant/tasks/lib/superglue_axb.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from . import rte 4 | 5 | 6 | @dataclass 7 | class Example(rte.Example): 8 | pass 9 | 10 | 11 | @dataclass 12 | class TokenizedExample(rte.Example): 13 | pass 14 | 15 | 16 | @dataclass 17 | class DataRow(rte.DataRow): 18 | pass 19 | 20 | 21 | @dataclass 22 | class Batch(rte.Batch): 23 | pass 24 | 25 | 26 | class SuperglueBroadcoverageDiagnosticsTask(rte.RteTask): 27 | def get_train_examples(self): 28 | raise RuntimeError("This task does not support training examples") 29 | 30 | def get_val_examples(self): 31 | raise RuntimeError("This task does not support validation examples") 32 | -------------------------------------------------------------------------------- /jiant/tasks/lib/superglue_axg.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from . 
import rte 4 | 5 | 6 | @dataclass 7 | class Example(rte.Example): 8 | pass 9 | 10 | 11 | @dataclass 12 | class TokenizedExample(rte.Example): 13 | pass 14 | 15 | 16 | @dataclass 17 | class DataRow(rte.DataRow): 18 | pass 19 | 20 | 21 | @dataclass 22 | class Batch(rte.Batch): 23 | pass 24 | 25 | 26 | class SuperglueWinogenderDiagnosticsTask(rte.RteTask): 27 | def get_train_examples(self): 28 | raise RuntimeError("This task does not support training examples") 29 | 30 | def get_val_examples(self): 31 | raise RuntimeError("This task does not support validation examples") 32 | -------------------------------------------------------------------------------- /jiant/tasks/lib/swag.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from dataclasses import dataclass 3 | 4 | from jiant.tasks.lib.templates.shared import labels_to_bimap 5 | from jiant.tasks.lib.templates import multiple_choice as mc_template 6 | 7 | 8 | @dataclass 9 | class Example(mc_template.Example): 10 | @property 11 | def task(self): 12 | return SWAGTask 13 | 14 | 15 | @dataclass 16 | class TokenizedExample(mc_template.TokenizedExample): 17 | pass 18 | 19 | 20 | @dataclass 21 | class DataRow(mc_template.DataRow): 22 | pass 23 | 24 | 25 | @dataclass 26 | class Batch(mc_template.Batch): 27 | pass 28 | 29 | 30 | class SWAGTask(mc_template.AbstractMultipleChoiceTask): 31 | Example = Example 32 | TokenizedExample = Example 33 | DataRow = DataRow 34 | Batch = Batch 35 | 36 | CHOICE_KEYS = [0, 1, 2, 3] 37 | CHOICE_TO_ID, ID_TO_CHOICE = labels_to_bimap(CHOICE_KEYS) 38 | NUM_CHOICES = len(CHOICE_KEYS) 39 | 40 | def get_train_examples(self): 41 | return self._create_examples(path=self.train_path, set_type="train") 42 | 43 | def get_val_examples(self): 44 | return self._create_examples(path=self.val_path, set_type="val") 45 | 46 | def get_test_examples(self): 47 | return self._create_examples(path=self.test_path, set_type="test") 48 | 49 | @classmethod 50 
| def _create_examples(cls, path, set_type): 51 | df = pd.read_csv(path) 52 | examples = [] 53 | for i, row in enumerate(df.itertuples()): 54 | examples.append( 55 | Example( 56 | guid="%s-%s" % (set_type, i), 57 | prompt=row.sent1, 58 | choice_list=[ 59 | row.sent2 + " " + row.ending0, 60 | row.sent2 + " " + row.ending1, 61 | row.sent2 + " " + row.ending2, 62 | row.sent2 + " " + row.ending3, 63 | ], 64 | label=row.label if set_type != "test" else cls.CHOICE_KEYS[-1], 65 | ) 66 | ) 67 | return examples 68 | -------------------------------------------------------------------------------- /jiant/tasks/lib/templates/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/tasks/lib/templates/__init__.py -------------------------------------------------------------------------------- /jiant/tasks/lib/templates/hacky_tokenization_matching.py: -------------------------------------------------------------------------------- 1 | """TODO: Remove when Tokenizers gets better (issue #1189)""" 2 | from jiant.tasks.utils import ExclusiveSpan 3 | 4 | 5 | def map_tags_to_token_position(flat_stripped, indices, split_text): 6 | char_index = 0 7 | current_string = flat_stripped 8 | positions = [None] * len(split_text) 9 | for i, token in enumerate(split_text): 10 | found_index = current_string.find(token.lower()) 11 | assert found_index != -1 12 | positions[i] = indices[char_index + found_index] 13 | char_index += found_index + len(token) 14 | current_string = flat_stripped[char_index:] 15 | for elem in positions: 16 | assert elem is not None 17 | return positions 18 | 19 | 20 | def convert_mapped_tags(positions, tag_ids, length): 21 | labels = [None] * length 22 | mask = [0] * length 23 | for pos, tag_id in zip(positions, tag_ids): 24 | labels[pos] = tag_id 25 | mask[pos] = 1 26 | return labels, mask 27 | 28 | 29 | def input_flat_strip(tokens): 30 | 
return "".join(tokens).lower() 31 | 32 | 33 | def flat_strip(tokens, tokenizer, return_indices=False): 34 | return tokenizer.convert_tokens_to_string(tokens).replace(" ", "").lower() 35 | 36 | 37 | def starts_with(ls, prefix): 38 | return ls[: len(prefix)] == prefix 39 | 40 | 41 | def get_token_span(sentence, span: ExclusiveSpan, tokenizer): 42 | tokenized = tokenizer.tokenize(sentence) 43 | tokenized_start1 = tokenizer.tokenize(sentence[: span.start]) 44 | tokenized_start2 = tokenizer.tokenize(sentence[: span.end]) 45 | assert starts_with(tokenized, tokenized_start1) 46 | # assert starts_with(tokenized, tokenized_start2) # <- fails because of "does" in "doesn't" 47 | word = sentence[span.to_slice()] 48 | assert word.lower().replace(" ", "") in flat_strip( 49 | tokenized_start2[len(tokenized_start1) :], 50 | tokenizer=tokenizer, 51 | ) 52 | token_span = ExclusiveSpan(start=len(tokenized_start1), end=len(tokenized_start2)) 53 | return tokenized, token_span 54 | -------------------------------------------------------------------------------- /jiant/tasks/lib/templates/squad_style/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/tasks/lib/templates/squad_style/__init__.py -------------------------------------------------------------------------------- /jiant/tasks/lib/tydiqa.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jiant.tasks.lib.templates.squad_style import core as squad_style_template 4 | 5 | 6 | @dataclass 7 | class Example(squad_style_template.Example): 8 | def tokenize(self, tokenizer): 9 | raise NotImplementedError("SQuaD is weird") 10 | 11 | 12 | @dataclass 13 | class DataRow(squad_style_template.DataRow): 14 | pass 15 | 16 | 17 | @dataclass 18 | class Batch(squad_style_template.Batch): 19 | pass 20 | 21 | 22 | class 
TyDiQATask(squad_style_template.BaseSquadStyleTask): 23 | Example = Example 24 | DataRow = DataRow 25 | Batch = Batch 26 | 27 | def __init__( 28 | self, 29 | name, 30 | path_dict, 31 | language, 32 | version_2_with_negative=False, 33 | n_best_size=20, 34 | max_answer_length=30, 35 | null_score_diff_threshold=0.0, 36 | ): 37 | super().__init__( 38 | name=name, 39 | path_dict=path_dict, 40 | version_2_with_negative=version_2_with_negative, 41 | n_best_size=n_best_size, 42 | max_answer_length=max_answer_length, 43 | null_score_diff_threshold=null_score_diff_threshold, 44 | ) 45 | self.language = language 46 | 47 | def get_train_examples(self): 48 | if self.language == "en": 49 | return self.read_squad_examples(path=self.train_path, set_type="train") 50 | else: 51 | raise NotImplementedError("TyDiQA does not have training examples except for English") 52 | 53 | @classmethod 54 | def read_squad_examples(cls, path, set_type): 55 | return squad_style_template.generic_read_squad_examples( 56 | path=path, 57 | set_type=set_type, 58 | example_class=cls.Example, 59 | ) 60 | -------------------------------------------------------------------------------- /jiant/tasks/lib/xquad.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from jiant.tasks.lib.templates.squad_style import core as squad_style_template 4 | 5 | 6 | @dataclass 7 | class Example(squad_style_template.Example): 8 | def tokenize(self, tokenizer): 9 | raise NotImplementedError("SQuaD is weird") 10 | 11 | 12 | @dataclass 13 | class DataRow(squad_style_template.DataRow): 14 | pass 15 | 16 | 17 | @dataclass 18 | class Batch(squad_style_template.Batch): 19 | pass 20 | 21 | 22 | class XquadTask(squad_style_template.BaseSquadStyleTask): 23 | Example = Example 24 | DataRow = DataRow 25 | Batch = Batch 26 | 27 | def __init__( 28 | self, 29 | name, 30 | path_dict, 31 | language, 32 | version_2_with_negative=False, 33 | n_best_size=20, 34 | 
max_answer_length=30, 35 | null_score_diff_threshold=0.0, 36 | ): 37 | super().__init__( 38 | name=name, 39 | path_dict=path_dict, 40 | version_2_with_negative=version_2_with_negative, 41 | n_best_size=n_best_size, 42 | max_answer_length=max_answer_length, 43 | null_score_diff_threshold=null_score_diff_threshold, 44 | ) 45 | self.language = language 46 | 47 | def get_train_examples(self): 48 | raise NotImplementedError("XQuAD does not have training examples") 49 | 50 | @classmethod 51 | def read_squad_examples(cls, path, set_type): 52 | return squad_style_template.generic_read_squad_examples( 53 | path=path, 54 | set_type=set_type, 55 | example_class=cls.Example, 56 | ) 57 | -------------------------------------------------------------------------------- /jiant/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/utils/__init__.py -------------------------------------------------------------------------------- /jiant/utils/config_handlers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """This module contains utilities for manipulating configs.""" 3 | from typing import List 4 | 5 | import json 6 | import _jsonnet # type: ignore 7 | 8 | 9 | def json_merge_patch(target_json: str, patch_json: str) -> str: 10 | """Merge json objects according to JSON merge patch spec: https://tools.ietf.org/html/rfc7396. 11 | 12 | Takes a target json string, and a patch json string and applies the patch json to the target 13 | json according to "JSON Merge Patch" (defined by https://tools.ietf.org/html/rfc7396). 14 | 15 | Args: 16 | target_json: the json to be overwritten by the patch json. 17 | patch_json: the json used to overwrite the target json. 18 | 19 | Returns: 20 | json str after applying the patch json to the target json using "JSON Merge Patch" method. 
21 | 22 | """ 23 | merged: str = """local target = {target_json}; 24 | local patch = {patch_json}; 25 | std.mergePatch(target, patch)""".format( 26 | target_json=target_json, patch_json=patch_json 27 | ) 28 | return _jsonnet.evaluate_snippet("snippet", merged) 29 | 30 | 31 | def merge_jsons_in_order(jsons: List[str]) -> str: 32 | """Applies JSON Merge Patch process to a list of json documents in order. 33 | 34 | Takes a list of json document strings and performs "JSON Merge Patch" (see json_merge_patch). 35 | The first element in the list of json docs is treated as the base, subsequent docs (if any) 36 | are applied as patches in order from first to last. 37 | 38 | Args: 39 | jsons: list of json docs to merge into a composite json document. 40 | 41 | Returns: 42 | The composite json document string. 43 | 44 | """ 45 | base_json = jsons.pop(0) 46 | # json.loads is called to check that input strings are valid json. 47 | json.loads(base_json) 48 | composite_json = base_json 49 | for json_str in jsons: 50 | json.loads(json_str) 51 | composite_json = json_merge_patch(composite_json, json_str) 52 | return composite_json 53 | -------------------------------------------------------------------------------- /jiant/utils/data_handlers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """This module contains utils for handling data (e.g., validating data)""" 3 | import hashlib 4 | 5 | 6 | def md5_checksum(filepath: str) -> str: 7 | """Calculate MD5 checksum hash for a given file. 8 | 9 | Code from example: https://stackoverflow.com/a/3431838/8734015. 10 | 11 | Args: 12 | filepath: file to calculate MD5 checksum. 13 | 14 | Returns: 15 | MD5 hash string. 
16 | 17 | """ 18 | hash_md5 = hashlib.md5() 19 | with open(filepath, "rb") as f: 20 | for chunk in iter(lambda: f.read(4096), b""): 21 | hash_md5.update(chunk) 22 | return hash_md5.hexdigest() 23 | -------------------------------------------------------------------------------- /jiant/utils/display.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import auto as tqdm_lib 3 | 4 | 5 | def tqdm(iterable=None, desc=None, total=None, initial=0): 6 | return tqdm_lib.tqdm( 7 | iterable=iterable, 8 | desc=desc, 9 | total=total, 10 | initial=initial, 11 | ) 12 | 13 | 14 | def trange(*args, desc=None, total=None): 15 | return tqdm(range(*args), desc=desc, total=total) 16 | 17 | 18 | def maybe_tqdm(iterable=None, desc=None, total=None, initial=0, verbose=True): 19 | if verbose: 20 | return tqdm(iterable=iterable, desc=desc, total=total, initial=initial) 21 | else: 22 | return iterable 23 | 24 | 25 | def maybe_trange(*args, verbose, **kwargs): 26 | return maybe_tqdm(range(*args), verbose=verbose, **kwargs) 27 | 28 | 29 | def show_json(obj, do_print=True): 30 | string = json.dumps(obj, indent=2) 31 | if do_print: 32 | print(string) 33 | else: 34 | return string 35 | 36 | 37 | def is_notebook(): 38 | try: 39 | # noinspection PyUnresolvedReferences 40 | shell = get_ipython().__class__.__name__ 41 | if shell == "ZMQInteractiveShell": 42 | return True # Jupyter notebook or qtconsole 43 | elif shell == "TerminalInteractiveShell": 44 | return False # Terminal running IPython 45 | else: 46 | return False # Other type (?) 
47 | except NameError: 48 | return False # Probably standard Python interpreter 49 | -------------------------------------------------------------------------------- /jiant/utils/python/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/utils/python/__init__.py -------------------------------------------------------------------------------- /jiant/utils/python/checks.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | def dict_equal(dict1: Dict, dict2: Dict) -> bool: 5 | if not len(dict1) == len(dict2): 6 | return False 7 | for (k1, v1), (k2, v2) in zip(dict1.items(), dict2.items()): 8 | if k1 != k2: 9 | return False 10 | if v1 != v2: 11 | return False 12 | return True 13 | -------------------------------------------------------------------------------- /jiant/utils/python/filesystem.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | import sys 4 | from contextlib import contextmanager 5 | 6 | 7 | def find_files(base_path, func): 8 | return sorted( 9 | [ 10 | os.path.join(dp, filename) 11 | for dp, dn, filenames in os.walk(base_path) 12 | for filename in filenames 13 | if func(filename) 14 | ] 15 | ) 16 | 17 | 18 | def find_files_with_ext(base_path, ext): 19 | return find_files(base_path=base_path, func=lambda filename: filename.endswith(f".{ext}")) 20 | 21 | 22 | def get_code_base_path(): 23 | """Gets path to root of jiant code base 24 | 25 | Returns: 26 | Path to root of jiant code base 27 | """ 28 | return os.path.abspath( 29 | os.path.join( 30 | __file__, 31 | os.pardir, 32 | os.pardir, 33 | os.pardir, 34 | os.pardir, 35 | ) 36 | ) 37 | 38 | 39 | def get_code_asset_path(*rel_path): 40 | """Get path to file/folder within code base 41 | 42 | Like os.path.join, you can supple either 
arguments: 43 | "path", "to", "file" 44 | or 45 | "path/to/file" 46 | 47 | Args: 48 | *rel_path: one or more strings representing folder/file name, 49 | similar to os.path.join(*rel_path) 50 | 51 | Returns: 52 | Path to file/folder within code base 53 | """ 54 | return os.path.join(get_code_base_path(), *rel_path) 55 | 56 | 57 | def find_case_insensitive_filename(filename, path): 58 | for f in os.listdir(path): 59 | if f.lower() == filename.lower(): 60 | return f 61 | 62 | 63 | @contextmanager 64 | def temporarily_add_sys_path(path): 65 | sys.path = [path] + sys.path 66 | yield 67 | sys.path = sys.path[1:] 68 | 69 | 70 | def import_from_path(path): 71 | base_path, filename = os.path.split(path) 72 | module_name = filename[:-3] if filename.endswith(".py") else filename 73 | with temporarily_add_sys_path(base_path): 74 | module = importlib.import_module(module_name) 75 | return module 76 | -------------------------------------------------------------------------------- /jiant/utils/python/functional.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | 4 | def getter(attr_name: Any): 5 | def f(obj): 6 | return getattr(obj, attr_name) 7 | 8 | return f 9 | 10 | 11 | def indexer(key): 12 | def f(obj): 13 | return obj[key] 14 | 15 | return f 16 | 17 | 18 | def identity(*args): 19 | if len(args) > 1: 20 | return args 21 | else: 22 | return args[0] 23 | 24 | 25 | # noinspection PyUnusedLocal 26 | def always_false(*args, **kwargs): 27 | return False 28 | 29 | 30 | # noinspection PyUnusedLocal 31 | def always_true(*args, **kwargs): 32 | return True 33 | -------------------------------------------------------------------------------- /jiant/utils/python/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import json 4 | 5 | 6 | def read_file(path, mode="r", **kwargs): 7 | with open(path, mode=mode, **kwargs) as f: 8 | return f.read() 9 | 10 | 
11 | def write_file(data, path, mode="w", **kwargs): 12 | with open(path, mode=mode, **kwargs) as f: 13 | f.write(data) 14 | 15 | 16 | def read_json(path, mode="r", **kwargs): 17 | return json.loads(read_file(path, mode=mode, **kwargs)) 18 | 19 | 20 | def write_json(data, path): 21 | return write_file(json.dumps(data, indent=2), path) 22 | 23 | 24 | def read_jsonl(path, mode="r", **kwargs): 25 | # Manually open because .splitlines is different from iterating over lines 26 | ls = [] 27 | with open(path, mode, **kwargs) as f: 28 | for line in f: 29 | ls.append(json.loads(line)) 30 | return ls 31 | 32 | 33 | def write_jsonl(data, path): 34 | assert isinstance(data, list) 35 | lines = [to_jsonl(elem) for elem in data] 36 | write_file("\n".join(lines), path) 37 | 38 | 39 | def read_file_lines(path, mode="r", encoding="utf-8", strip_lines=False, **kwargs): 40 | with open(path, mode=mode, encoding=encoding, **kwargs) as f: 41 | lines = f.readlines() 42 | if strip_lines: 43 | return [line.strip() for line in lines] 44 | else: 45 | return lines 46 | 47 | 48 | def read_json_lines(path, mode="r", encoding="utf-8", **kwargs): 49 | with open(path, mode=mode, encoding=encoding, **kwargs) as f: 50 | for line in f.readlines(): 51 | yield json.loads(line) 52 | 53 | 54 | def to_jsonl(data): 55 | return json.dumps(data).replace("\n", "") 56 | 57 | 58 | def create_containing_folder(path): 59 | fol_path = os.path.split(path)[0] 60 | os.makedirs(fol_path, exist_ok=True) 61 | 62 | 63 | def sorted_glob(pathname, *, recursive=False): 64 | return sorted(glob.glob(pathname, recursive=recursive)) 65 | 66 | 67 | def assert_exists(path): 68 | if not os.path.exists(path): 69 | raise FileNotFoundError(path) 70 | 71 | 72 | def assert_not_exists(path): 73 | if os.path.exists(path): 74 | raise FileExistsError(path) 75 | 76 | 77 | def get_num_lines(path): 78 | with open(path, "r") as f: 79 | for i, l in enumerate(f): 80 | pass 81 | return i + 1 82 | 83 | 84 | def create_dir(*args): 85 | """Makes a 
folder and returns the path 86 | 87 | Args: 88 | *args: args to os.path.join 89 | """ 90 | path = os.path.join(*args) 91 | os.makedirs(path, exist_ok=True) 92 | return path 93 | -------------------------------------------------------------------------------- /jiant/utils/python/logic.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | 4 | def replace_none(elem: Optional[Any], default: Any): 5 | """If elem is None, return default, else return elem 6 | 7 | Args: 8 | elem: element to possibly return 9 | default: default element 10 | 11 | Returns: 12 | elem, or default 13 | 14 | """ 15 | if elem is None: 16 | return default 17 | else: 18 | return elem 19 | -------------------------------------------------------------------------------- /jiant/utils/python/strings.py: -------------------------------------------------------------------------------- 1 | def remove_prefix(s, prefix): 2 | assert s.startswith(prefix) 3 | return s[len(prefix) :] 4 | 5 | 6 | def remove_suffix(s, suffix): 7 | assert s.endswith(suffix) 8 | return s[: -len(suffix)] 9 | 10 | 11 | def replace_prefix(s, prefix, new_prefix): 12 | return new_prefix + remove_prefix(s, prefix) 13 | 14 | 15 | def replace_suffix(s, suffix, new_suffix): 16 | return remove_suffix(s, suffix) + new_suffix 17 | -------------------------------------------------------------------------------- /jiant/utils/string_comparing.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | import collections 4 | 5 | 6 | def normalize_answer(s): 7 | """Lower text and remove punctuation, articles and extra whitespace. 
8 | From official ReCoRD eval script 9 | """ 10 | 11 | def remove_articles(text): 12 | return re.sub(r"\b(a|an|the)\b", " ", text) 13 | 14 | def white_space_fix(text): 15 | return " ".join(text.split()) 16 | 17 | def remove_punc(text): 18 | exclude = set(string.punctuation) 19 | return "".join(ch for ch in text if ch not in exclude) 20 | 21 | def lower(text): 22 | return text.lower() 23 | 24 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 25 | 26 | 27 | def string_f1_score(prediction, ground_truth): 28 | """Compute normalized token level F1 29 | From official ReCoRD eval script 30 | """ 31 | prediction_tokens = normalize_answer(prediction).split() 32 | ground_truth_tokens = normalize_answer(ground_truth).split() 33 | common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens) 34 | num_same = sum(common.values()) 35 | if num_same == 0: 36 | return 0 37 | precision = 1.0 * num_same / len(prediction_tokens) 38 | recall = 1.0 * num_same / len(ground_truth_tokens) 39 | f1 = (2 * precision * recall) / (precision + recall) 40 | return f1 41 | 42 | 43 | def exact_match_score(prediction, ground_truth): 44 | """Compute normalized exact match 45 | From official ReCoRD eval script 46 | """ 47 | return normalize_answer(prediction) == normalize_answer(ground_truth) 48 | -------------------------------------------------------------------------------- /jiant/utils/testing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/jiant/utils/testing/__init__.py -------------------------------------------------------------------------------- /jiant/utils/testing/tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from jiant.utils.python.datastructures import BiMap 4 | from jiant.tasks.core import FeaturizationSpec 5 | 6 | 7 | class 
SimpleSpaceTokenizer: 8 | 9 | pad_token = "" 10 | cls_token = "" 11 | sep_token = "" 12 | unk_token = "" 13 | SPECIAL_TOKENS = [pad_token, cls_token, sep_token, unk_token] 14 | 15 | def __init__(self, vocabulary: List[str], add_special=True): 16 | if add_special: 17 | vocabulary = self.SPECIAL_TOKENS + vocabulary 18 | self.tokens_to_ids, self.ids_to_tokens = BiMap( 19 | a=vocabulary, b=list(range(len(vocabulary))) 20 | ).get_maps() 21 | 22 | def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: 23 | return [self.tokens_to_ids[token] for token in tokens] 24 | 25 | def tokenize(self, string: str) -> List[str]: 26 | return [ 27 | token if token in self.tokens_to_ids else self.unk_token for token in string.split() 28 | ] 29 | 30 | @classmethod 31 | def get_feat_spec(cls, max_seq_length: int) -> FeaturizationSpec: 32 | return FeaturizationSpec( 33 | max_seq_length=max_seq_length, 34 | cls_token_at_end=False, 35 | pad_on_left=False, 36 | cls_token_segment_id=0, 37 | pad_token_segment_id=0, 38 | pad_token_id=0, 39 | pad_token_mask_id=0, 40 | sequence_a_segment_id=0, 41 | sequence_b_segment_id=1, 42 | sep_token_extra=False, 43 | ) 44 | -------------------------------------------------------------------------------- /jiant/utils/testing/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | def is_pytest(): 5 | return "pytest" in sys.modules 6 | -------------------------------------------------------------------------------- /jiant/utils/tokenization_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from typing import Sequence 4 | 5 | 6 | def bow_tag_tokens(tokens: Sequence[str], bow_tag: str = ""): 7 | """Applies a beginning of word (BoW) marker to every token in the tokens sequence.""" 8 | return [bow_tag + t for t in tokens] 9 | 10 | 11 | def eow_tag_tokens(tokens: Sequence[str], eow_tag: str = ""): 12 | """Applies a end of word (EoW) 
marker to every token in the tokens sequence.""" 13 | return [t + eow_tag for t in tokens] 14 | 15 | 16 | def process_wordpiece_tokens(tokens: Sequence[str]): 17 | return [process_wordpiece_token_for_alignment(token) for token in tokens] 18 | 19 | 20 | def process_sentencepiece_tokens(tokens: Sequence[str]): 21 | return [process_sentencepiece_token_for_alignment(token) for token in tokens] 22 | 23 | 24 | def process_bytebpe_tokens(tokens: Sequence[str]): 25 | return [process_bytebpe_token_for_alignment(token) for token in tokens] 26 | 27 | 28 | def process_wordpiece_token_for_alignment(t): 29 | """Add word boundary markers, removes token prefix (no-space meta-symbol — '##' for BERT).""" 30 | if t.startswith("##"): 31 | return re.sub(r"^##", "", t) 32 | else: 33 | return "" + t 34 | 35 | 36 | def process_sentencepiece_token_for_alignment(t): 37 | """Add word boundary markers, removes token prefix (space meta-symbol).""" 38 | if t.startswith("▁"): 39 | return "" + re.sub(r"^▁", "", t) 40 | else: 41 | return t 42 | 43 | 44 | def process_bytebpe_token_for_alignment(t): 45 | """Add word boundary markers, removes token prefix (space meta-symbol).""" 46 | if t.startswith("Ġ"): 47 | return "" + re.sub(r"^Ġ", "", t) 48 | else: 49 | return t 50 | -------------------------------------------------------------------------------- /jiant/utils/zconf/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import run_config, RunConfig, get_mode_and_cl_args, ModeLookupError 2 | from .core import argparse_attr as attr 3 | 4 | __all__ = ("run_config", "RunConfig", "get_mode_and_cl_args", "ModeLookupError", "attr") 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | 4 | [tool.pytest.ini_options] 5 | filterwarnings = [ 6 | "ignore::UserWarning", 7 | ] 8 | 9 | 
include = '\.pyi?$' 10 | 11 | exclude = ''' 12 | ( 13 | __pycache__ 14 | | \.git 15 | | \.mypy_cache 16 | | \.pytest_cache 17 | | \.venv 18 | ) 19 | ''' 20 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | -r requirements.txt 2 | sphinx==1.8.1 3 | pytest==3.10.0 4 | pytest-cov==2.8.1 5 | pre-commit==2.3.0 6 | flake8==3.7.9 7 | flake8-docstrings==1.5.0 8 | black==22.3.0 9 | mypy==0.770 10 | -------------------------------------------------------------------------------- /requirements-no-torch.txt: -------------------------------------------------------------------------------- 1 | attrs==19.3.0 2 | bs4==0.0.1 3 | jsonnet==0.15.0 4 | lxml==4.9.1 5 | datasets==1.1.2 6 | nltk>=3.5 7 | numexpr==2.7.1 8 | numpy==1.22.4 9 | pandas==1.0.3 10 | python-Levenshtein==0.12.0 11 | sacremoses==0.0.43 12 | seqeval==0.0.12 13 | scikit-learn==0.22.2.post1 14 | scipy==1.4.1 15 | sentencepiece==0.1.91 16 | tokenizers==0.10.1 17 | tqdm==4.46.0 18 | transformers==4.5.0 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -r requirements-no-torch.txt 2 | torch>=1.8.1 3 | torchvision==0.9.1 4 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | ### Tests 2 | 3 | To run tests, run `pytest` from the project root directory. 4 | 5 | #### Notes: 6 | * By default, [tests marked "slow" or "gpu" will be skipped](https://github.com/pyeres/jiant/pull/10#issue-414779551) by the [CI system](https://app.circleci.com/pipelines/github/pyeres/jiant). 7 | * To run "slow" and "gpu" tests (required for some PRs), run tests with `pytest jiant --runslow --rungpu`. 
def test_time_dependent_prob_multitask_sampler_const_p():
    """Constant unnormalized probabilities should give a uniform task distribution at every step."""
    sampler = task_sampler.TimeDependentProbMultiTaskSampler(
        task_dict={"rte": None, "mnli": None, "squad_v1": None},
        rng=0,
        task_to_unnormalized_prob_funcs_dict={"rte": "1", "mnli": "1", "squad_v1": "1"},
    )
    uniform = np.ones(3) / 3
    for step in (0, 500, 999):
        assert np.array_equal(sampler.get_task_p(step), uniform)


def test_time_dependent_prob_multitask_sampler_variable_p():
    """Time-dependent probability expressions should shift sampling mass across steps."""
    sampler = task_sampler.TimeDependentProbMultiTaskSampler(
        task_dict={"rte": None, "mnli": None, "squad_v1": None},
        rng=0,
        task_to_unnormalized_prob_funcs_dict={
            "rte": "1",
            "mnli": "1 - t/1000",
            "squad_v1": "exp(t/1000)",
        },
    )
    # Reference values below were computed from the normalized prob expressions.
    assert np.array_equal(sampler.get_task_p(0), np.ones(3) / 3)
    assert np.allclose(sampler.get_task_p(500), np.array([0.31758924, 0.15879462, 0.52361614]))
    assert np.allclose(
        sampler.get_task_p(999), np.array([2.69065663e-01, 2.69065663e-04, 7.30665271e-01])
    )
@pytest.mark.parametrize(
    "model_type, model_class, hf_pretrained_model_name_or_path",
    [
        ("bert", BertPreTrainedModel, "bert-base-cased"),
        ("roberta", RobertaForMaskedLM, "nyu-mll/roberta-med-small-1M-1"),
    ],
)
def test_export_model(tmp_path, model_type, model_class, hf_pretrained_model_name_or_path):
    """export_model should write an export config whose paths point at the exported files."""
    export_model(
        hf_pretrained_model_name_or_path=hf_pretrained_model_name_or_path,
        output_base_path=tmp_path,
    )
    # Plain string literal: the original used a pointless f-string prefix here.
    read_config = py_io.read_json(os.path.join(tmp_path, "config.json"))
    assert read_config["hf_pretrained_model_name_or_path"] == hf_pretrained_model_name_or_path
    assert read_config["model_path"] == os.path.join(tmp_path, "model", "model.p")
    assert read_config["model_config_path"] == os.path.join(tmp_path, "model", "config.json")
"microsoft/deberta-v2-xlarge",), ], 37 | ) 38 | def test_export_model_large(tmp_path, model_type, model_class, hf_pretrained_model_name_or_path): 39 | export_model( 40 | hf_pretrained_model_name_or_path=hf_pretrained_model_name_or_path, 41 | output_base_path=tmp_path, 42 | ) 43 | read_config = py_io.read_json(os.path.join(tmp_path, f"config.json")) 44 | assert read_config["hf_pretrained_model_name_or_path"] == hf_pretrained_model_name_or_path 45 | assert read_config["model_path"] == os.path.join(tmp_path, "model", "model.p") 46 | assert read_config["model_config_path"] == os.path.join(tmp_path, "model", "config.json") 47 | -------------------------------------------------------------------------------- /tests/tasks/lib/resources/data/mnli/mnli_test.jsonl: -------------------------------------------------------------------------------- 1 | {"premise": "Hierbas, ans seco, ans dulce, and frigola are just a few names worth keeping a look-out for.", "hypothesis": "Hierbas is a name worth looking out for."} 2 | {"premise": "The extent of the behavioral effects would depend in part on the structure of the individual account program and any limits on accessing the funds.", "hypothesis": "Many people would be very unhappy to loose control over their own money."} 3 | {"premise": "Timely access to information is in the best interests of both GAO and the agencies.", "hypothesis": "It is in everyone's best interest to have access to information in a timely manner."} 4 | {"premise": "Based in the Auvergnat spa town of Vichy, the French government often proved more zealous than its masters in suppressing civil liberties and drawing up anti-Jewish legislation.", "hypothesis": "The French government passed anti-Jewish laws aimed at helping the Nazi."} 5 | {"premise": "Built in 1870, its canopy of stained glass and cast iron is the oldest in Dublin; its enthusiastic interior decoration is also typical of the era.", "hypothesis": "It was constructed in 1870 and has the oldest canopy 
in Dublin."} -------------------------------------------------------------------------------- /tests/tasks/lib/resources/data/mnli/mnli_train.jsonl: -------------------------------------------------------------------------------- 1 | {"premise": "Conceptually cream skimming has two basic dimensions - product and geography.", "hypothesis": "Product and geography are what make cream skimming work. ", "label": "neutral"} 2 | {"premise": "you know during the season and i guess at at your level uh you lose them to the next level if if they decide to recall the the parent team the Braves decide to call to recall a guy from triple A then a double A guy goes up to replace him and a single A guy goes up to replace him", "hypothesis": "You lose the things to the following level if the people recall.", "label": "entailment"} 3 | {"premise": "One of our number will carry out your instructions minutely.", "hypothesis": "A member of my team will execute your orders with immense precision.", "label": "entailment"} 4 | {"premise": "How do you know? 
All this is their information again.", "hypothesis": "This information belongs to them.", "label": "entailment"} 5 | {"premise": "yeah i tell you what though if you go price some of those tennis shoes i can see why now you know they're getting up in the hundred dollar range", "hypothesis": "The tennis shoes have a range of prices.", "label": "neutral"} -------------------------------------------------------------------------------- /tests/tasks/lib/resources/data/mnli/mnli_val.jsonl: -------------------------------------------------------------------------------- 1 | {"premise": "The new rights are nice enough", "hypothesis": "Everyone really likes the newest benefits ", "label": "neutral"} 2 | {"premise": "This site includes a list of all award winners and a searchable database of Government Executive articles.", "hypothesis": "The Government Executive articles housed on the website are not able to be searched.", "label": "contradiction"} 3 | {"premise": "uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him", "hypothesis": "I like him for the most part, but would still enjoy seeing someone beat him.", "label": "entailment"} 4 | {"premise": "yeah i i think my favorite restaurant is always been the one closest you know the closest as long as it's it meets the minimum criteria you know of good food", "hypothesis": "My favorite restaurants are always at least a hundred miles away from my house. 
", "label": "contradiction"} 5 | {"premise": "i don't know um do you do a lot of camping", "hypothesis": "I know exactly.", "label": "contradiction"} -------------------------------------------------------------------------------- /tests/tasks/lib/resources/data/spr1/test.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "The acquisition strengthens BSN 's position in the European pasta market .", "info": {"split": "test", "sent_id": "2254_8"}, "targets": [{"span1": [2, 3], "span2": [1, 2], "label": ["existed_after", "existed_during", "instigation", "manipulated_by_another"]}, {"span1": [2, 3], "span2": [5, 6], "label": ["change_of_state", "existed_after", "existed_before", "existed_during", "location_of_event", "manipulated_by_another", "predicate_changed_argument"]}]} 2 | {"text": "In fact , some of the association 's members -- long-term , buy-and-hold investors -- welcomed the drop in prices .", "info": {"split": "test", "sent_id": "2386_13"}, "targets": [{"span1": [15, 16], "span2": [17, 18], "label": ["existed_during", "manipulated_by_another"]}, {"span1": [15, 16], "span2": [3, 4], "label": ["awareness", "existed_after", "existed_before", "existed_during", "instigation", "volition"]}]} 3 | -------------------------------------------------------------------------------- /tests/tasks/lib/resources/data/spr1/train.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "This conjures up images of a nation full of trim , muscular folks , and suggests couch potatoes are out of season .", "info": {"split": "train", "sent_id": "0409_35"}, "targets": [{"span1": [1, 2], "span2": [0, 1], "label": []}, {"span1": [15, 16], "span2": [0, 1], "label": ["existed_during", "instigation", "manipulated_by_another"]}]} 2 | {"text": "`` I spent so much money that if I look at it , and I 'm not on it , I feel guilty . 
''", "info": {"split": "train", "sent_id": "0409_32"}, "targets": [{"span1": [2, 3], "span2": [1, 2], "label": ["awareness", "change_of_state", "existed_after", "existed_before", "existed_during", "exists_as_physical", "instigation", "makes_physical_contact", "predicate_changed_argument", "sentient", "volition"]}, {"span1": [2, 3], "span2": [5, 6], "label": ["changes_possession", "existed_after", "existed_before", "existed_during", "manipulated_by_another"]}]} 3 | -------------------------------------------------------------------------------- /tests/tasks/lib/resources/data/sst/test.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "uneasy mishmash of styles and genres ."} 2 | {"text": "this film 's relationship to actual tension is the same as what christmas-tree flocking in a spray can is to actual snow : a poor -- if durable -- imitation ."} 3 | {"text": "by the end of no such thing the audience , like beatrice , has a watchful affection for the monster ."} 4 | {"text": "director rob marshall went out gunning to make a great one ."} 5 | {"text": "lathan and diggs have considerable personal charm , and their screen rapport makes the old story seem new ."} 6 | -------------------------------------------------------------------------------- /tests/tasks/lib/resources/data/sst/train.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "hide new secretions from the parental units ", "label": "0"} 2 | {"text": "contains no wit , only labored gags ", "label": "0"} 3 | {"text": "that loves its characters and communicates something rather beautiful about human nature ", "label": "1"} 4 | {"text": "remains utterly satisfied to remain the same throughout ", "label": "0"} 5 | {"text": "on the worst revenge-of-the-nerds clich\u00e9s the filmmakers could dredge up ", "label": "0"} 6 | -------------------------------------------------------------------------------- 
/tests/tasks/lib/resources/data/sst/val.jsonl: -------------------------------------------------------------------------------- 1 | {"text": "it 's a charming and often affecting journey . ", "label": "1"} 2 | {"text": "unflinchingly bleak and desperate ", "label": "0"} 3 | {"text": "allows us to hope that nolan is poised to embark a major career as a commercial yet inventive filmmaker . ", "label": "1"} 4 | {"text": "the acting , costumes , music , cinematography and sound are all astounding given the production 's austere locales . ", "label": "1"} 5 | {"text": "it 's slow -- very , very slow . ", "label": "0"} 6 | -------------------------------------------------------------------------------- /tests/tasks/lib/resources/mnli.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "mnli", 3 | "paths": { 4 | "train": "data/mnli/mnli_train.jsonl", 5 | "val": "data/mnli/mnli_val.jsonl", 6 | "test": "data/mnli/mnli_test.jsonl" 7 | }, 8 | "name": "mnli" 9 | } -------------------------------------------------------------------------------- /tests/tasks/lib/resources/spr1.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "spr1", 3 | "paths": { 4 | "train": "data/spr1/train.jsonl", 5 | "val": "data/spr1/test.jsonl" 6 | }, 7 | "name": "spr1" 8 | } 9 | -------------------------------------------------------------------------------- /tests/tasks/lib/resources/sst.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "sst", 3 | "paths": { 4 | "train": "data/sst/train.jsonl", 5 | "val": "data/sst/val.jsonl", 6 | "test": "data/sst/test.jsonl" 7 | }, 8 | "name": "sst" 9 | } 10 | -------------------------------------------------------------------------------- /tests/tasks/lib/templates/test_hacky_tokenization_matching.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 
TEST_STRINGS = ["Hi, my name is Bob Roberts."]
FLAT_STRIP_EXPECTED_STRINGS = ["hi,mynameisbobroberts."]


@pytest.mark.parametrize("model_type", ["albert-base-v2", "roberta-base", "bert-base-uncased"])
def test_delegate_flat_strip(model_type):
    """flat_strip should collapse any tokenizer's output back to lowercased, space-free text."""
    # The tokenizer class is resolved from the model family (text before the first '-').
    tokenizer_cls = model_resolution.resolve_tokenizer_class(model_type.split("-")[0])
    tokenizer = tokenizer_cls.from_pretrained(model_type)
    for raw_string, expected in zip(TEST_STRINGS, FLAT_STRIP_EXPECTED_STRINGS):
        tokens = tokenizer.tokenize(raw_string)
        assert flat_strip(tokens, tokenizer, return_indices=False) == expected
33 | "", 34 | "", 35 | ] 36 | 37 | data_row = tokenized_example.featurize( 38 | tokenizer=tokenizer, 39 | feat_spec=JiantTransformersModelFactory.build_featurization_spec( 40 | model_type="roberta", max_seq_length=16, 41 | ), 42 | ) 43 | assert list(data_row.masked_input_ids) == [ 44 | 0, 45 | 30086, 46 | 6, 47 | 127, 48 | 766, 49 | 16, 50 | 1437, 51 | 50264, 52 | 6274, 53 | 4, 54 | 2, 55 | 1, 56 | 1, 57 | 1, 58 | 1, 59 | 1, 60 | ] 61 | assert list(data_row.masked_lm_labels) == [ 62 | -100, 63 | -100, 64 | -100, 65 | -100, 66 | -100, 67 | -100, 68 | -100, 69 | 25158, 70 | -100, 71 | -100, 72 | -100, 73 | -100, 74 | -100, 75 | -100, 76 | -100, 77 | -100, 78 | ] 79 | -------------------------------------------------------------------------------- /tests/tasks/lib/test_mlm_pretokenized.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | 3 | import jiant.shared.model_resolution as model_resolution 4 | from jiant.tasks.retrieval import MLMPretokenizedTask 5 | 6 | from jiant.shared.model_resolution import ModelArchitectures 7 | from jiant.proj.main.modeling.primary import JiantTransformersModelFactory 8 | 9 | 10 | def test_tokenization_and_featurization(): 11 | task = MLMPretokenizedTask(name="mlm_pretokenized", path_dict={}) 12 | tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base") 13 | example = task.Example( 14 | guid=None, 15 | tokenized_text=["Hi", ",", "Ġmy", "Ġname", "Ġis", "ĠBob", "ĠRoberts", "."], 16 | masked_spans=[[2, 3], [5, 6]], 17 | ) 18 | tokenized_example = example.tokenize(tokenizer=tokenizer) 19 | assert tokenized_example.masked_tokens == [ 20 | "Hi", 21 | ",", 22 | "Ġmy", 23 | "Ġname", 24 | "Ġis", 25 | "ĠBob", 26 | "ĠRoberts", 27 | ".", 28 | ] 29 | assert tokenized_example.label_tokens == [ 30 | "", 31 | "", 32 | "Ġmy", 33 | "", 34 | "", 35 | "ĠBob", 36 | "", 37 | "", 38 | ] 39 | 40 | data_row = tokenized_example.featurize( 41 | tokenizer=tokenizer, 42 | 
feat_spec=JiantTransformersModelFactory.build_featurization_spec( 43 | model_type=ModelArchitectures.ROBERTA.value, max_seq_length=16, 44 | ), 45 | ) 46 | assert list(data_row.masked_input_ids) == [ 47 | 0, 48 | 30086, 49 | 6, 50 | 127, 51 | 766, 52 | 16, 53 | 3045, 54 | 6274, 55 | 4, 56 | 2, 57 | 1, 58 | 1, 59 | 1, 60 | 1, 61 | 1, 62 | 1, 63 | ] 64 | assert list(data_row.masked_lm_labels) == [ 65 | -100, 66 | -100, 67 | -100, 68 | 127, 69 | -100, 70 | -100, 71 | 3045, 72 | -100, 73 | -100, 74 | -100, 75 | -100, 76 | -100, 77 | -100, 78 | -100, 79 | -100, 80 | -100, 81 | ] 82 | -------------------------------------------------------------------------------- /tests/tasks/lib/test_wic.py: -------------------------------------------------------------------------------- 1 | from collections import Counter 2 | 3 | from jiant.tasks.utils import ExclusiveSpan 4 | from jiant.tasks.lib.wic import Example, TokenizedExample 5 | from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer 6 | 7 | 8 | EXAMPLES = [ 9 | Example( 10 | guid="train-1", 11 | sentence1="Approach a task.", 12 | sentence2="To approach the city.", 13 | word="approach", 14 | span1=ExclusiveSpan(start=0, end=8), 15 | span2=ExclusiveSpan(start=3, end=11), 16 | label=False, 17 | ), 18 | Example( 19 | guid="train-2", 20 | sentence1="In England they call takeout food 'takeaway'.", 21 | sentence2="If you're hungry, there's a takeaway just around the corner.", 22 | word="takeaway", 23 | span1=ExclusiveSpan(start=35, end=43), 24 | span2=ExclusiveSpan(start=28, end=36), 25 | label=True, 26 | ), 27 | ] 28 | 29 | TOKENIZED_EXAMPLES = [ 30 | TokenizedExample( 31 | guid="train-1", 32 | sentence1_tokens=["Approach", "a", "task."], 33 | sentence2_tokens=["To", "approach", "the", "city."], 34 | word=["approach"], 35 | sentence1_span=ExclusiveSpan(start=0, end=1), 36 | sentence2_span=ExclusiveSpan(start=1, end=2), 37 | label_id=0, 38 | ), 39 | TokenizedExample( 40 | guid="train-2", 41 | sentence1_tokens=["In", 
"England", "they", "call", "takeout", "food", "'takeaway'."], 42 | sentence2_tokens=[ 43 | "If", 44 | "you're", 45 | "hungry,", 46 | "there's", 47 | "a", 48 | "takeaway", 49 | "just", 50 | "around", 51 | "the", 52 | "corner.", 53 | ], 54 | word=["takeaway"], 55 | sentence1_span=ExclusiveSpan(start=6, end=7), 56 | sentence2_span=ExclusiveSpan(start=5, end=6), 57 | label_id=1, 58 | ), 59 | ] 60 | 61 | 62 | def test_task_tokenization(): 63 | token_counter = Counter() 64 | for example in EXAMPLES: 65 | token_counter.update(example.sentence1.split() + example.sentence2.split()) 66 | token_vocab = list(token_counter.keys()) 67 | tokenizer = SimpleSpaceTokenizer(vocabulary=token_vocab) 68 | 69 | for example, tokenized_example in zip(EXAMPLES, TOKENIZED_EXAMPLES): 70 | assert example.tokenize(tokenizer).to_dict() == tokenized_example.to_dict() 71 | -------------------------------------------------------------------------------- /tests/test_zconf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nyu-mll/jiant/daa5a258e3af5e7503288de8401429eaf3f58e13/tests/test_zconf/__init__.py -------------------------------------------------------------------------------- /tests/test_zconf/jsons/empty.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /tests/test_zconf/jsons/simple.json: -------------------------------------------------------------------------------- 1 | { 2 | "str_attr": "hello", 3 | "int_default_attr": 3 4 | } -------------------------------------------------------------------------------- /tests/test_zconf/jsons/store_true.json: -------------------------------------------------------------------------------- 1 | { 2 | "store_true_attr": true 3 | } -------------------------------------------------------------------------------- 
@zconf.run_config
class Config(zconf.RunConfig):
    # attr1 is optional (defaults to None); attr2 must always be supplied.
    attr1 = zconf.attr(default=None)
    attr2 = zconf.attr(required=True)


def test_args():
    """Attributes are settable by keyword; omitted optional attrs fall back to their default."""
    full = Config(attr1=1, attr2=2)
    assert full.attr1 == 1
    assert full.attr2 == 2

    partial = Config(attr2=2)
    assert partial.attr1 is None
    assert partial.attr2 == 2


def test_args_required():
    """Constructing without a required attribute raises TypeError."""
    with pytest.raises(TypeError):
        Config()


def test_args_required_command_line():
    """Missing required attributes on the command line exit via argparse (SystemExit)."""
    with pytest.raises(SystemExit):
        Config.run_cli_json_prepend(cl_args=[])


def test_to_dict():
    """to_dict should expose exactly the declared attributes and their values."""
    conf_dict = Config(attr1=1, attr2=2).to_dict()
    assert len(conf_dict) == 2
    assert conf_dict["attr1"] == 1
    assert conf_dict["attr2"] == 2
def test_dict_equal():
    """dict_equal is truthy only when both dicts have identical keys and values."""
    cases = [
        ({1: 2}, {1: 2}, True),
        ({1: 2}, {1: 3}, False),
        ({1: 2}, {2: 2}, False),
        ({1: 2}, {2: 2, 1: 1}, False),
    ]
    for left, right, expected in cases:
        assert bool(py_checks.dict_equal(left, right)) == expected
test_get_code_base_path(): 7 | code_base_path = py_filesystem.get_code_base_path() 8 | assert os.path.exists(code_base_path) 9 | 10 | 11 | def test_get_code_asset_path(): 12 | import jiant 13 | 14 | assert py_filesystem.get_code_asset_path(os.path.join("jiant", "__init__.py")) == jiant.__file__ 15 | -------------------------------------------------------------------------------- /tests/utils/python/test_functional.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import jiant.utils.python.functional as py_functional 4 | 5 | 6 | def test_indexer(): 7 | assert py_functional.indexer(1)({1: 2}) == 2 8 | with pytest.raises(KeyError): 9 | py_functional.indexer("1")({1: 2}) 10 | -------------------------------------------------------------------------------- /tests/utils/python/test_logic.py: -------------------------------------------------------------------------------- 1 | import jiant.utils.python.logic as py_logic 2 | 3 | 4 | def test_replace_none(): 5 | assert py_logic.replace_none(1, default=2) == 1 6 | assert py_logic.replace_none(None, default=2) == 2 7 | -------------------------------------------------------------------------------- /tests/utils/test_config_handlers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import json 3 | import os 4 | 5 | from jiant.utils import config_handlers 6 | 7 | 8 | def test_json_merge_patch(): 9 | """Tests that JSON Merge Patch works as expected (https://tools.ietf.org/html/rfc7396)""" 10 | target = """ 11 | { 12 | "title": "Goodbye!", 13 | "author" : { 14 | "givenName" : "John", 15 | "familyName" : "Doe" 16 | }, 17 | "tags":[ "example", "sample" ], 18 | "content": "This will be unchanged" 19 | } 20 | """ 21 | patch = """ 22 | { 23 | "title": "Hello!", 24 | "phoneNumber": "+01-123-456-7890", 25 | "author": { 26 | "familyName": null 27 | }, 28 | "tags": [ "example" ] 29 | } 30 | """ 31 | merged = 
def test_merging_multiple_json_configs():
    """merge_jsons_in_order should apply later configs as JSON Merge Patches over earlier ones."""

    def _read_config(name):
        # Fixture configs live in the "config" directory next to this test module.
        # Uses a uniform relative path (the original mixed "config/..." and "./config/...").
        with open(os.path.join(os.path.dirname(__file__), "config", name)) as f:
            return f.read()

    merged_config = config_handlers.merge_jsons_in_order(
        [
            _read_config("base_config.json"),
            _read_config("first_override_config.json"),
            _read_config("second_override_config.json"),
        ]
    )
    expected_config = _read_config("final_config.json")
    # Compare key-sorted serializations so dict ordering differences don't matter.
    sorted_merged_config = json.dumps(json.loads(merged_config), sort_keys=True)
    sorted_expected_config = json.dumps(json.loads(expected_config), sort_keys=True)
    assert sorted_merged_config == sorted_expected_config
def test_tags_to_regex():
    """tags_to_regex should turn each {tag} placeholder into a named capture group.

    NOTE(review): the expected strings previously read "(?P\\w+)", which is not a valid
    regex — "(?P" must be followed by "<name>". The group names below are restored from
    the {model}/{task} placeholders; confirm against path_parse.tags_to_regex.
    """
    assert (
        path_parse.tags_to_regex("/path/to/experiments/{model}/{task}")
        == "/path/to/experiments/(?P<model>\\w+)/(?P<task>\\w+)"
    )

    assert (
        path_parse.tags_to_regex("/path/to/experiments/{model}/{task}", default_format="(\\w|_)+")
        == "/path/to/experiments/(?P<model>(\\w|_)+)/(?P<task>(\\w|_)+)"
    )

    assert (
        path_parse.tags_to_regex(
            "/path/to/experiments/{model}/{task}", format_dict={"task": "(\\w|_)+"}
        )
        == "/path/to/experiments/(?P<model>\\w+)/(?P<task>(\\w|_)+)"
    )
"def", "ghi", "jkl"], ["mno", "pqr", "stu", "vwx"]] 26 | 27 | 28 | def test_truncate_more_than_two_sequences_trunc_start(): 29 | seqs = [["abc", "def", "ghi"], ["jkl", "mno", "pqr"], ["stu", "vwx", "yz"]] 30 | trunc_seqs = truncate_sequences(seqs, 8) 31 | assert trunc_seqs == [["abc", "def"], ["jkl", "mno", "pqr"], ["stu", "vwx", "yz"]] 32 | 33 | 34 | def test_truncate_two_sequences_default_trunc_start(): 35 | seqs = [["abc", "def", "ghi", "jkl"], ["mno", "pqr", "stu", "vwx", "yz"]] 36 | trunc_seqs = truncate_sequences(seqs, 8, False) 37 | assert trunc_seqs == [["abc", "def", "ghi", "jkl"], ["pqr", "stu", "vwx", "yz"]] 38 | --------------------------------------------------------------------------------