├── tests ├── models │ ├── __init__.py │ ├── failing.py │ └── dummy.py ├── .allennlp_plugins ├── utils.py ├── test_sweep.py ├── parameter-tying_sweep_v1.0.0.yaml ├── test_single_run.py ├── configs │ └── parameter_tying_v1.0.0.jsonnet ├── test_failing_single_run.py └── data │ └── snli_1.0_test │ ├── snli_1.0_train.jsonl │ ├── snli_1.0_test.jsonl │ └── snli_1.0_dev.jsonl ├── src └── wandb_allennlp │ ├── __main__.py │ ├── commands │ ├── __init__.py │ ├── download_from_wandb.py │ ├── parser_base.py │ └── train_with_wandb.py │ ├── training │ ├── __init__.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── subcallbacks.py │ │ ├── utils.py │ │ └── log_to_wandb.py │ └── train_and_test.py │ ├── config.py │ ├── __init__.py │ ├── versioned.py │ └── utils.py ├── .allennlp_plugins ├── docs_source ├── .gitignore ├── images │ └── banner.png ├── index.rst ├── templates │ ├── module.rst_t │ ├── versioning.html │ ├── version_banner.html │ ├── versions.html │ └── package.rst_t ├── make.py └── conf.py ├── doc_requirements.txt ├── test_requirements.txt ├── docs └── index.html ├── .github ├── workflows │ ├── labeler.yml │ ├── deployment.yml │ ├── pull_requests.yml │ ├── ci-build-docs.yml │ └── changelog.yml └── labels.yml ├── pyproject.toml ├── setup.cfg ├── LICENSE ├── setup.py ├── tox.ini ├── examples ├── README.md ├── training_configs │ └── pair_classification │ │ └── decomposable_attention.jsonnet └── sweep_configs │ └── pair_classification │ └── decomposable_attention.yaml ├── .gitignore ├── CHANGELOG.md └── README.md /tests/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/wandb_allennlp/__main__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.allennlp_plugins: 
-------------------------------------------------------------------------------- 1 | wandb_allennlp 2 | -------------------------------------------------------------------------------- /tests/.allennlp_plugins: -------------------------------------------------------------------------------- 1 | wandb_allennlp 2 | -------------------------------------------------------------------------------- /docs_source/.gitignore: -------------------------------------------------------------------------------- 1 | box_embeddings/ 2 | conf/ 3 | test_case/ 4 | -------------------------------------------------------------------------------- /docs_source/images/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dhruvdcoder/wandb-allennlp/HEAD/docs_source/images/banner.png -------------------------------------------------------------------------------- /src/wandb_allennlp/commands/__init__.py: -------------------------------------------------------------------------------- 1 | from wandb_allennlp.commands import parser_base, train_with_wandb, download_from_wandb 2 | -------------------------------------------------------------------------------- /src/wandb_allennlp/training/__init__.py: -------------------------------------------------------------------------------- 1 | from wandb_allennlp.training import train_and_test 2 | from wandb_allennlp.training import callbacks 3 | -------------------------------------------------------------------------------- /doc_requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-rtd-theme 3 | git+https://github.com/miyakogi/m2r.git 4 | docstr-coverage 5 | sphinx-autoapi 6 | sphinx-multiversion 7 | -------------------------------------------------------------------------------- /src/wandb_allennlp/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 
ALLENNLP_SERIALIZATION_DIR = os.environ.get( 4 | "ALLENNLP_SERIALIZATION_DIR", ".allennlp_models" 5 | ) 6 | -------------------------------------------------------------------------------- /docs_source/index.rst: -------------------------------------------------------------------------------- 1 | 2 | .. mdinclude:: ../README.md 3 | 4 | .. toctree:: 5 | :maxdepth: 3 6 | :caption: Code 7 | 8 | API Reference 9 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture 5 | def base_translator(): 6 | from wandb_allennlp.commandline import Translator 7 | 8 | return Translator() 9 | -------------------------------------------------------------------------------- /src/wandb_allennlp/training/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from wandb_allennlp.training.callbacks import log_to_wandb 2 | from wandb_allennlp.training.callbacks.subcallbacks import LogBestValidationMetrics 3 | -------------------------------------------------------------------------------- /src/wandb_allennlp/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Dict, Any, Optional, Callable 2 | import wandb_allennlp.commands 3 | import wandb_allennlp.training 4 | import os 5 | 6 | __version__ = "0.3.2" 7 | -------------------------------------------------------------------------------- /src/wandb_allennlp/versioned.py: -------------------------------------------------------------------------------- 1 | import allennlp 2 | 3 | if int(allennlp.version._MAJOR) >= 1: 4 | from allennlp.__main__ import run as allennlp_run 5 | else: 6 | from allennlp.run import run as allennlp_run 7 | -------------------------------------------------------------------------------- /test_requirements.txt: 
-------------------------------------------------------------------------------- 1 | mypy 2 | pytest 3 | pytest-console-scripts==1.1.0 4 | black 5 | pytest-flake8 6 | flake8 7 | flake8-docstrings 8 | flake8-annotations 9 | flake8-black 10 | pre-commit 11 | darglint 12 | coverage 13 | -------------------------------------------------------------------------------- /docs_source/templates/module.rst_t: -------------------------------------------------------------------------------- 1 | {%- if show_headings %} 2 | {{- [basename] | join(' ') | e | heading }} 3 | 4 | {% endif -%} 5 | .. automodule:: {{ qualname }} 6 | {%- for option in automodule_options %} 7 | :{{ option }}: 8 | {%- endfor %} 9 | -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |

See the latest documentation here.

9 | 10 | 11 | -------------------------------------------------------------------------------- /docs_source/templates/versioning.html: -------------------------------------------------------------------------------- 1 | {% if versions %} 2 |

{{ _('Branches') }}

3 | 8 |

{{ _('Tags') }}

9 | 14 | {% endif %} 15 | -------------------------------------------------------------------------------- /src/wandb_allennlp/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Dict, Any, Optional, Callable 2 | import os 3 | 4 | 5 | def read_from_env( 6 | key: str, type_: Optional[Callable[[str], Any]] = None 7 | ) -> Optional[Any]: 8 | val_str = os.environ.get(key, None) 9 | 10 | if type_ is not None and isinstance(val_str, str): 11 | val = type_(val_str) 12 | else: 13 | val = val_str 14 | 15 | return val 16 | -------------------------------------------------------------------------------- /.github/workflows/labeler.yml: -------------------------------------------------------------------------------- 1 | name: Labeler 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | 7 | jobs: 8 | labeler: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Check out the repository 12 | uses: actions/checkout@v2.3.4 13 | 14 | - name: Run Labeler 15 | if: success() 16 | uses: crazy-max/ghaction-github-labeler@v3.1.1 17 | with: 18 | skip-delete: true 19 | yaml-file: .github/labels.yml 20 | -------------------------------------------------------------------------------- /tests/test_sweep.py: -------------------------------------------------------------------------------- 1 | def test_sweep(script_runner): 2 | 3 | ret = script_runner.run( 4 | "wandb", 5 | "agent", 6 | "--count=1", 7 | "dhruveshpate/wandb-allennlp-wandb_allennlp_tests/fyntzj7v", 8 | ) 9 | 10 | assert ret.success 11 | assert ( 12 | "(success)." in ret.stderr 13 | or "wandb: Program ended successfully." 
in ret.stderr 14 | ) 15 | assert not ( 16 | "(failed 1)" in ret.stderr 17 | or "wandb: Program failed with code 1" in ret.stderr 18 | ) 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools", 4 | "wheel", 5 | ] 6 | 7 | [tool.autopep8] 8 | max_line_length = 120 9 | ignore = ["W504", "W504", "E402", "E731", "C40", "E741", "F40", "F841"] 10 | 11 | [tool.black] 12 | # https://github.com/psf/black 13 | line-length = 79 14 | target-version = ["py37"] 15 | exclude = "(.eggs|.git|.hg|.mypy_cache|.nox|.tox|.venv*|.svn|_build|buck-out|build|dist|__pycache__)" 16 | 17 | [tool.darglint] 18 | docstring_style = "google" 19 | #ignore checking doc strings on tests 20 | ignore_regex="^test_(.*)" 21 | -------------------------------------------------------------------------------- /.github/workflows/deployment.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - name: Set up Python 13 | uses: actions/setup-python@v1 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine 20 | - name: Build and publish 21 | env: 22 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 23 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 24 | run: | 25 | python setup.py sdist bdist_wheel 26 | twine upload dist/* 27 | -------------------------------------------------------------------------------- /docs_source/templates/version_banner.html: -------------------------------------------------------------------------------- 1 | {% extends "page.html" %} 2 | {% block body %} 3 | {% if current_version and 
latest_version and current_version != latest_version %} 4 |

5 | 6 | {% if current_version.is_released %} 7 | You're reading an old version of this documentation. 8 | If you want up-to-date information, please have a look at {{latest_version.name}}. 9 | {% else %} 10 | You're reading the documentation for a development version. 11 | For the latest released version, please have a look at {{latest_version.name}}. 12 | {% endif %} 13 | 14 |

15 | {% endif %} 16 | {{ super() }} 17 | {% endblock %}% 18 | 19 | -------------------------------------------------------------------------------- /tests/parameter-tying_sweep_v1.0.0.yaml: -------------------------------------------------------------------------------- 1 | name: v1.0.0 2 | program: allennlp 3 | command: 4 | - ${program} #omit the interpreter as we use allennlp train command directly 5 | - "train-with-wandb" # subcommand 6 | - "configs/parameter_tying_v1.0.0.jsonnet" 7 | - "--include-package=models" # add all packages containing your registered classes here 8 | - ${args} 9 | method: bayes 10 | metric: 11 | name: training_loss 12 | goal: minimize 13 | parameters: 14 | # hyperparameters start with overrides 15 | # Ranges 16 | # Add env. to tell that it is a top level parameter 17 | env.a: 18 | min: 1 19 | max: 10 20 | distribution: uniform 21 | env.bool_value: 22 | values: [true, false] 23 | env.int_value: 24 | values: [-1, 0, 1, 10] 25 | model.d: 26 | value: 1 27 | -------------------------------------------------------------------------------- /tests/test_single_run.py: -------------------------------------------------------------------------------- 1 | def test_run(script_runner): 2 | ret = script_runner.run( 3 | "allennlp", 4 | "train-with-wandb", 5 | "configs/parameter_tying_v1.0.0.jsonnet", 6 | "--wandb-entity=dhruveshpate", 7 | "--wandb-project=wandb-allennlp-wandb_allennlp_tests", 8 | "--wandb-name=plugging_test_run", 9 | "--include-package=models", 10 | "--env.a=1.1", 11 | "--env.bool_value=true", 12 | "--env.int_value=10", 13 | "--model.d=1", # keep this 1 14 | "--env.call_finish_on_end=false", 15 | ) 16 | assert ret.success 17 | assert ( 18 | "(success)." in ret.stderr 19 | or "wandb: Program ended successfully." 
in ret.stderr 20 | ) 21 | assert not ( 22 | "(failed 1)" in ret.stderr 23 | or "wandb: Program failed with code 1" in ret.stderr 24 | ) 25 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pydocstyle] 2 | convention = google 3 | add-ignore = D104,D107,D202 4 | 5 | [darglint] 6 | docstring_style=google 7 | # ignore checking doc strings on tests 8 | ignore_regex=^test_(.*) 9 | 10 | [mypy] 11 | ignore_missing_imports =True 12 | 13 | [flake8] 14 | # TODO: this should be 88 or 100 according PEP8 15 | max-line-length = 79 16 | exclude = .tox,*.egg,build,temp 17 | select = E,W,F 18 | doctests = True 19 | verbose = 2 20 | # https://pep8.readthedocs.io/en/latest/intro.html#error-codes 21 | format = pylint 22 | ignore = 23 | E731 24 | W504 25 | F401 26 | F841 27 | E203 # E203 - whitespace before ':'. Opposite convention enforced by black 28 | E231 # E231: missing whitespace after ',', ';', or ':'; for black 29 | E501 # E501 - line too long. Handled by black, we have longer lines 30 | W503 # W503 - line break before binary operator, need for black 31 | -------------------------------------------------------------------------------- /docs_source/templates/versions.html: -------------------------------------------------------------------------------- 1 |
2 | 3 | Other Versions 4 | v: {{ current_version.name }} 5 | 6 | 7 |
8 | {%- if versions.tags %} 9 |
10 |
Tags
11 | {%- for item in versions.tags %} 12 |
{{ item.name }}
13 | {%- endfor %} 14 |
15 | {%- endif %} 16 | {%- if versions.branches %} 17 |
18 |
Branches
19 | {%- for item in versions.branches %} 20 |
{{ item.name }}
21 | {%- endfor %} 22 |
23 | {%- endif %} 24 |
25 |
26 | {%- endif %} 27 | 28 | -------------------------------------------------------------------------------- /docs_source/templates/package.rst_t: -------------------------------------------------------------------------------- 1 | {%- macro automodule(modname, options) -%} 2 | .. automodule:: {{ modname }} 3 | {%- for option in options %} 4 | :{{ option }}: 5 | {%- endfor %} 6 | {%- endmacro %} 7 | 8 | {%- macro toctree(docnames) -%} 9 | .. toctree:: 10 | :maxdepth: {{ maxdepth }} 11 | {% for docname in docnames %} 12 | {{ docname }} 13 | {%- endfor %} 14 | {%- endmacro %} 15 | 16 | {%- if is_namespace %} 17 | {{- [pkgname] | join(" ") | e | heading }} 18 | {% else %} 19 | {{- [pkgname] | join(" ") | e | heading }} 20 | {% endif %} 21 | 22 | {%- if modulefirst and not is_namespace %} 23 | {{ automodule(pkgname, automodule_options) }} 24 | {% endif %} 25 | 26 | {%- if subpackages %} 27 | 28 | {{ toctree(subpackages) }} 29 | {% endif %} 30 | 31 | {%- if submodules %} 32 | 33 | {% if separatemodules %} 34 | {{ toctree(submodules) }} 35 | {%- else %} 36 | {%- for submodule in submodules %} 37 | {% if show_headings %} 38 | {{- [submodule] | join(" ") | e | heading(2) }} 39 | {% endif %} 40 | {{ automodule(submodule, automodule_options) }} 41 | {% endfor %} 42 | {%- endif %} 43 | {% endif %} 44 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Dhruvesh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 
11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tests/configs/parameter_tying_v1.0.0.jsonnet: -------------------------------------------------------------------------------- 1 | local data_path = std.extVar('DATA_PATH'); 2 | local a = std.parseJson(std.extVar('a')); 3 | local bool_value = std.parseJson(std.extVar('bool_value')); 4 | local int_value = std.parseJson(std.extVar('int_value')); 5 | local call_finish_on_end = std.parseJson(std.extVar('call_finish_on_end')); 6 | { 7 | type: 'train_test_log_to_wandb', 8 | evaluate_on_test: true, 9 | dataset_reader: { 10 | type: 'dummy', 11 | }, 12 | train_data_path: data_path + '/snli_1.0_test/snli_1.0_train.jsonl', 13 | validation_data_path: data_path + '/snli_1.0_test/snli_1.0_dev.jsonl', 14 | test_data_path: data_path + '/snli_1.0_test/snli_1.0_test.jsonl', 15 | model: { 16 | type: 'parameter-tying', 17 | a: a, 18 | b: a, 19 | d: 0, 20 | bool_value: bool_value, 21 | bool_value_not: !bool_value, 22 | int_value: int_value, 23 | int_value_10: int_value + 10, 24 | 25 | }, 26 | data_loader: { 27 | batch_size: 2, 28 | }, 29 | trainer: { 30 | optimizer: { 31 | type: 'adam', 32 | lr: 0.001, 33 | weight_decay: 0.0, 34 | }, 35 | cuda_device: -1, 36 | num_epochs: 5, 37 | callbacks: [ 38 | { 39 | type: 'wandb_allennlp', 40 | files_to_save: ['config.json'], 41 | 
files_to_save_at_end: ['*.tar.gz'], 42 | finish_on_end: call_finish_on_end, 43 | }, 44 | ], 45 | }, 46 | } 47 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | install_requires = [ 7 | "allennlp>=2.5.0", 8 | "wandb>=0.10.11", 9 | "pyyaml", 10 | "tensorboard", 11 | "overrides", 12 | "shortuuid", 13 | # allennlp 2.9+ needs a newer version - this may break older versions 14 | # "nltk<3.6.6" # remove this once the support for older versions of ALLENNLP is dropped. 15 | ] 16 | 17 | setup( 18 | name="wandb_allennlp", 19 | version="0.3.3", 20 | author="Dhruvesh Patel", 21 | author_email="1793dnp@gmail.com", 22 | description="Utilities to use allennlp with wandb", 23 | long_description=long_description, 24 | long_description_content_type="text/markdown", 25 | url="https://github.com/dhruvdcoder/wandb-allennlp", 26 | packages=find_packages( 27 | where="src", 28 | exclude=[ 29 | "*.tests", 30 | "*.tests.*", 31 | "tests.*", 32 | "tests", 33 | "examples", 34 | "wandb", 35 | ], 36 | ), 37 | package_dir={"": "src"}, 38 | package_data={"wandb_allennlp": ["py.typed"]}, 39 | install_requires=install_requires, 40 | classifiers=[ 41 | "Programming Language :: Python :: 3", 42 | "License :: OSI Approved :: MIT License", 43 | "Operating System :: OS Independent", 44 | ], 45 | python_requires=">=3.6", 46 | ) 47 | -------------------------------------------------------------------------------- /src/wandb_allennlp/training/callbacks/subcallbacks.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Dict, Any, Optional 2 | 3 | from wandb_allennlp.training.callbacks.log_to_wandb import ( 4 | AllennlpWandbSubCallback, 5 | AllennlpWandbCallback, 6 | 
GradientDescentTrainer, 7 | ) 8 | 9 | 10 | @AllennlpWandbSubCallback.register("log_best_validation_metrics") 11 | class LogBestValidationMetrics(AllennlpWandbSubCallback): 12 | def on_epoch_( 13 | self, 14 | super_callback: AllennlpWandbCallback, 15 | trainer: "GradientDescentTrainer", 16 | metrics: Dict[str, Any], 17 | epoch: int, 18 | is_primary: bool = True, 19 | **kwargs: Any, 20 | ) -> None: 21 | """ 22 | Logs the best_validation_* to wandb. 23 | """ 24 | # identify the tracked metric 25 | # this logic fragile and depends on an internal varaible of the trainer. 26 | # Hence it will need to be updated with newer versions of allennlp. 27 | metric_names_to_take = [ 28 | f"best_validation_{name}" 29 | for sign, name in trainer._metric_tracker.tracked_metrics 30 | ] 31 | 32 | super_callback.log_scalars( 33 | { 34 | name.replace("validation_", "", 1): value 35 | for name, value in metrics.items() 36 | for metric_name in metric_names_to_take 37 | if name == metric_name 38 | }, 39 | log_prefix="validation", 40 | epoch=epoch, 41 | ) 42 | -------------------------------------------------------------------------------- /docs_source/make.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import os 3 | import argparse 4 | 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--local", action="store_true") 9 | parser.add_argument( 10 | "-b", "--smv_branch_whitelist", help="Name of the local branch" 11 | ) 12 | args = parser.parse_args() 13 | subprocess_args = [ 14 | "sphinx-multiversion", 15 | ".", 16 | "../docs", 17 | "-D", 18 | "autoapi_dirs=${sourcedir}/../wandb_allennlp", 19 | "-D", 20 | "autoapi_root=${sourcedir}", 21 | ] 22 | 23 | if args.local: 24 | if args.smv_branch_whitelist is None: 25 | raise ValueError( 26 | "argument smv_branch_whitelist is required if using --local" 27 | ) 28 | else: 29 | subprocess_args += [ 30 | "-D", 31 | 
f"smv_branch_whitelist={args.smv_branch_whitelist}", 32 | ] 33 | subprocess.run(subprocess_args) 34 | print("Writing root index.html in the docs/") 35 | redirect_url = ( 36 | "master" if not args.local else args.smv_branch_whitelist 37 | ) + "/index.html" 38 | main_index = f""" 39 | 40 | 41 | 42 | 43 | 44 | 45 |

See the latest documentation here.

46 | 47 | 48 | """ 49 | with open("../docs/index.html", "w") as f: 50 | f.write(main_index) 51 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | script_launch_mode = subprocess 3 | 4 | [tox] 5 | # check: https://tox.readthedocs.io/en/latest/config.html#generating-environments-conditional-settings 6 | envlist = py{37,38,39,py3}-allennlp{2.5.0,2.8.0,2.9.3},lint 7 | 8 | [testenv] 9 | passenv = WANDB_API_KEY 10 | setenv = DATA_PATH = {toxinidir}/tests/data 11 | deps = 12 | {toxinidir} 13 | pytest 14 | pytest-console-scripts 15 | allennlp2.5.0: allennlp==2.5.0 16 | allennlp2.5.0: wandb==0.10.11 17 | allennlp2.8.0: allennlp==2.8.0 18 | allennlp2.8.0: wandb==0.12.10 19 | allennlp2.9.3: allennlp==2.9.3 20 | allennlp2.9.3: wandb==0.12.15 21 | 22 | changedir={toxinidir}/tests 23 | # based the nltk download is based on https://github.com/allenai/allennlp/pull/5540/files 24 | # it can be removed after the support for older versions of allennlp is dropped. 25 | commands = 26 | python -c 'import nltk; [nltk.download(p) for p in ("wordnet", "wordnet_ic", "sentiwordnet", "omw", "omw-1.4")]' 27 | pytest 28 | 29 | [testenv:lint] 30 | basepython = python3.8 31 | changedir = {toxinidir} 32 | deps = 33 | # check-manifest 34 | # readme_renderer[md] 35 | flake8 36 | # flake8-docstrings 37 | # flake8-commas 38 | # pep8-naming 39 | # twine 40 | 41 | commands = 42 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 43 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 44 | # check-manifest --ignore *.ini,tests*,.*.yml,demo* 45 | # twine check .tox/dist/* 46 | # flake8 pytest_console_scripts.py setup.py tests 47 | 48 | [flake8] 49 | exclude = .tox,*.egg,build 50 | select = E,W,F 51 | ignore = W503,W504 52 | -------------------------------------------------------------------------------- /.github/workflows/pull_requests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | --- 4 | name: Tests for pull requests and pushes on master 5 | 6 | on: 7 | push: 8 | branches: [master, develop/master] 9 | pull_request: 10 | branches: [master, develop/master] 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | python-version: ["3.8", "3.9"] 19 | allennlp-version: ["2.10.1"] 20 | env: 21 | TOXENV: allennlp${{ matrix.allennlp-version }} 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python ${{ matrix.python-version }} 25 | uses: actions/setup-python@v1 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | - name: Cache pip 29 | id: cache 30 | uses: actions/cache@v1 31 | with: 32 | path: ~/.cache/pip # This path is specific to Ubuntu 33 | # Look to see if there is a cache hit for the corresponding requirements file 34 | key: ${{ runner.os }}-${{ matrix.allennlp-version }}-pip-${{ hashFiles('setup.py') }} 35 | restore-keys: | 36 | ${{ runner.os }}-${{ matrix.allennlp-version }}-pip- 37 | ${{ runner.os }}- 38 | - name: Install Tox and any other packages 39 | run: pip install tox 40 | - name: Run test using Tox for allennlp==${{ matrix.allennlp-version }} 41 | shell: bash 42 | env: 43 | WANDB_API_KEY: 
${{secrets.WANDB_API_KEY}} 44 | run: tox 45 | # run: tox -e py 46 | -------------------------------------------------------------------------------- /tests/models/failing.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, Any 3 | import allennlp 4 | from allennlp.models import Model 5 | from allennlp.data.vocabulary import Vocabulary 6 | 7 | 8 | @Model.register("dummy-failing") 9 | class DummyModel(Model): 10 | def __init__(self, vocab: Vocabulary, a: float, b: float = 0): 11 | super().__init__(vocab=vocab) 12 | self.a = a 13 | self.b = b 14 | self.param = torch.nn.Parameter(torch.tensor(10.0)) 15 | self.x = 0 16 | 17 | def forward(self, *args, **kwargs): 18 | self.x += 1 19 | 20 | return {"loss": self.a * self.x + self.b + self.param} 21 | 22 | 23 | @Model.register("parameter-tying-failing") 24 | class DummyModel(Model): 25 | def __init__( 26 | self, 27 | vocab: Vocabulary, 28 | a: float, 29 | b: float, 30 | d: float, 31 | bool_value: bool, 32 | bool_value_not: bool, 33 | int_value: int, 34 | int_value_10: int, 35 | ): 36 | super().__init__(vocab=vocab) 37 | self.a = a 38 | self.b = b 39 | self.d = d 40 | self.param = torch.nn.Parameter(torch.tensor(10.0)) 41 | assert a == b 42 | assert isinstance(bool_value, bool) 43 | assert isinstance(bool_value_not, bool) 44 | assert bool_value == (not bool_value_not) 45 | assert isinstance(int_value, int) 46 | assert isinstance(int_value_10, int) 47 | assert int_value + 10 == int_value_10 48 | assert d == 1 49 | self.x = 0 50 | self.count = 0 51 | 52 | def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: 53 | self.x += 1 54 | 55 | if self.count > 12: 56 | raise Exception 57 | 58 | self.count += 1 59 | 60 | return {"loss": self.a * self.x + self.b + self.param} 61 | -------------------------------------------------------------------------------- /src/wandb_allennlp/commands/download_from_wandb.py: 
-------------------------------------------------------------------------------- 1 | from typing import Tuple, List, Dict 2 | from .parser_base import WandbParserBase 3 | from allennlp.commands import Subcommand 4 | import argparse 5 | import wandb 6 | import logging 7 | import tqdm 8 | from pathlib import Path 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def main(args: argparse.Namespace) -> None: 14 | api = wandb.Api() # type: ignore 15 | run = api.run( 16 | f"{args.wandb_entity}/{args.wandb_project}/{args.wandb_run_id}" 17 | ) 18 | args.output_folder.mkdir(parents=True, exist_ok=True) 19 | pbar = tqdm.tqdm(run.files(), desc="Downloading files") 20 | 21 | for file_ in pbar: 22 | pbar.set_description(f"Downloading: {file_.name}") 23 | file_.download(args.output_folder, replace=args.replace) 24 | 25 | logger.info(f"Downloaded all files to {args.output_folder}") 26 | 27 | 28 | @Subcommand.register("wandb_download") 29 | class DownloadFromWandb(WandbParserBase): 30 | description = "Downloads all files for a run from wandb" 31 | help_message = ( 32 | "Use this subcommand to perform download" 33 | " all the files for a particular run from wandb" 34 | ) 35 | require_run_id = True 36 | entry_point = main 37 | 38 | def add_arguments( 39 | self, subparser: argparse.ArgumentParser 40 | ) -> argparse.ArgumentParser: 41 | subparser.add_argument( 42 | "-o", 43 | "--output_folder", 44 | type=Path, 45 | required=True, 46 | help="Path to the output folder.", 47 | ) 48 | subparser.add_argument( 49 | "-r", 50 | "--replace", 51 | action="store_true", 52 | help="Whether to overrite the contents if the files/folder exists", 53 | ) 54 | subparser.set_defaults(func=main) 55 | 56 | return subparser 57 | -------------------------------------------------------------------------------- /src/wandb_allennlp/training/callbacks/utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Dict, Any, Optional 
from typing import List, Tuple, Union, Dict, Any, Optional
import json
import logging
from copy import deepcopy
from pathlib import Path

logger = logging.getLogger(__name__)
Number = Union[int, float]
# Leaf values kept by flatten_dict; includes None because the flattener
# explicitly stores None leaves (see the `inp is None` branch below).
Value = Union[int, float, bool, str, None]


def get_allennlp_major_minor_versions() -> Tuple[int, int]:
    """Return the installed allennlp version as a ``(major, minor)`` int tuple."""
    # Imported lazily so this module stays importable without allennlp.
    import allennlp.version

    return int(allennlp.version._MAJOR), int(allennlp.version._MINOR)


def flatten_dict(params: Dict[str, Any],
                 delimiter: str = ".") -> Dict[str, Value]:
    """
    Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a.b': 'c'}``.

    Lists are flattened using the element index as the key component,
    e.g. ``{'a': [1, 2]} -> {'a.0': 1, 'a.1': 2}``.

    Args:
        params: Dictionary containing the hyperparameters
        delimiter: Delimiter to express the hierarchy. Defaults to ``'.'``.

    Returns:
        Flattened dict.

    Raises:
        ValueError: If a leaf is not one of str/int/float/bool/None.
    """
    output: Dict[str, Value] = {}

    def populate(inp: Any, prefix: List[str]) -> None:
        # Depth-first walk; `prefix` is the key path accumulated so far.
        if isinstance(inp, dict):
            for k, v in inp.items():
                populate(v, deepcopy(prefix) + [k])
        elif isinstance(inp, list):
            for i, val in enumerate(inp):
                populate(val, deepcopy(prefix) + [str(i)])
        elif isinstance(inp, (str, float, int, bool)) or (inp is None):
            output[delimiter.join(prefix)] = inp
        else:  # unsupported leaf type
            raise ValueError(
                f"Unsupported type {type(inp)} at {delimiter.join(prefix)} for flattening."
            )

    populate(params, [])

    return output


def get_config_from_serialization_dir(dir_: str) -> Dict[str, Value]:
    """Load ``config.json`` from a serialization dir and return it flattened."""
    # Imported lazily (mirroring get_allennlp_major_minor_versions) so the
    # pure helpers above remain usable without allennlp installed.
    from allennlp.models.archival import CONFIG_NAME

    with open(Path(dir_) / CONFIG_NAME) as f:
        config_dict = json.load(f)
    config_dict = flatten_dict(config_dict)

    return config_dict
47 | description: Further information is requested 48 | color: d876e3 49 | - name: refactoring 50 | description: Refactoring 51 | color: ef67c4 52 | - name: removal 53 | description: Removals 54 | color: 9ae7ea 55 | - name: deprecation 56 | description: Deprecations 57 | color: 9ae7eb 58 | - name: style 59 | description: Code style, linting, typechecking, etc. 60 | color: c120e5 61 | - name: testing 62 | description: Testing 63 | color: b1fc6f 64 | - name: wontfix 65 | description: This will not be worked on 66 | color: ffffff 67 | -------------------------------------------------------------------------------- /docs_source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | 4 | import sys 5 | import os 6 | 7 | sys.path.insert(0, os.path.abspath("../")) 8 | 9 | extensions = [ 10 | "autoapi.extension", 11 | "sphinx.ext.todo", 12 | "sphinx.ext.napoleon", 13 | "sphinx.ext.graphviz", 14 | "sphinx.ext.inheritance_diagram", 15 | "sphinx_multiversion", 16 | "m2r", 17 | ] 18 | source_suffix = [".rst", ".md"] 19 | master_doc = "index" 20 | project = "wandb-allennlp" 21 | copyright = "Dhruvesh Patel" 22 | exclude_patterns = ["_build", "**/docs", "**/.docs", "**/tests", "tests/**"] 23 | pygments_style = "sphinx" 24 | templates_path = ["templates"] # needed for multiversion 25 | autoclass_content = "class" 26 | html_baseurl = "http://dhruveshp.com/wandb-allennlp/" 27 | html_logo = "images/banner.png" 28 | html_theme_options = { 29 | "github_user": "dhruvdcoder", 30 | "github_repo": "wandb-allennlp", 31 | "github_banner": True, 32 | "github_button": True, 33 | } 34 | 35 | # https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autoclass_content 36 | autoclass_content = "both" 37 | # autodoc_default_options = {'undoc-members': True} 38 | 39 | # API Generation 40 | autoapi_dirs = ["../wandb_allennlp"] 41 | autoapi_root = "." 
from typing import List, Tuple, Union, Dict, Any, Optional
import torch
import allennlp
from allennlp.models import Model
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.fields import TensorField
from allennlp.data.instance import Instance


@Model.register("dummy")
class DummyModel(Model):
    """Trivial test model whose loss is ``a * step + b + param``."""

    def __init__(self, vocab: Vocabulary, a: float, b: float = 0):
        super().__init__(vocab=vocab)
        self.a = a
        self.b = b
        self.param = torch.nn.Parameter(torch.tensor(10.0))
        self.x = 0  # counts forward() calls

    def forward(self, *args, **kwargs):
        self.x += 1

        return {"loss": self.a * self.x + self.b + self.param}


@Model.register("parameter-tying")
class ParameterTyingModel(Model):
    """Asserts that tied/derived hyperparameters arrive with correct values.

    Renamed from ``DummyModel``: the file previously defined two classes
    with the same name, so this second definition silently shadowed the
    first at module level (registration still worked, but any direct
    reference to ``DummyModel`` resolved to the wrong class). The registry
    key ``"parameter-tying"`` is unchanged, so configs are unaffected.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        a: float,
        b: float,
        d: float,
        bool_value: bool,
        bool_value_not: bool,
        int_value: int,
        int_value_10: int,
    ):
        super().__init__(vocab=vocab)
        self.a = a
        self.b = b
        self.d = d
        self.param = torch.nn.Parameter(torch.tensor(10.0))
        # Parameter-tying contract: fail fast if the sweep/CLI machinery
        # did not propagate tied values and their types correctly.
        assert a == b
        assert isinstance(bool_value, bool)
        assert isinstance(bool_value_not, bool)
        assert bool_value == (not bool_value_not)
        assert isinstance(int_value, int)
        assert isinstance(int_value_10, int)
        assert int_value + 10 == int_value_10
        assert d == 1
        self.x = 0  # counts forward() calls

    def forward(self, *args: Any, **kwargs: Any):
        self.x += 1

        return {"loss": self.a * self.x + self.b + self.param}


@DatasetReader.register("dummy")
class Dummy(DatasetReader):
    """Yields ten single-tensor instances — enough to drive a training loop."""

    def _read(self, file_path: str):
        for i in range(10):
            yield Instance({"x": TensorField(torch.tensor([i]))})
44 | 45 | trainer: { 46 | epoch_callbacks: [ 47 | ..., 48 | 49 | { 50 | type: 'log_metrics_to_wandb', 51 | }, 52 | 53 | ..., 54 | ], 55 | ..., 56 | } 57 | ... 58 | ... 59 | ``` 60 | 61 | **and the following for allennlp v2.x :** 62 | ``` 63 | ... 64 | 65 | trainer: { 66 | callbacks: [ 67 | ..., 68 | 69 | { 70 | type: 'log_metrics_to_wandb', 71 | }, 72 | 73 | ..., 74 | ], 75 | ..., 76 | } 77 | ... 78 | ... 79 | ``` 80 | 81 | 3. Create a sweep config file. See [sweep_configs/pair_classification/decomposable_attention.yaml](examples/sweep_configs/pair_classification/decomposable_attention.yaml) for the example config. 82 | 83 | 4. Create the sweep. 84 | 85 | ``` 86 | wandb sweep sweep_configs/pair_classification/decomposable_attention.yaml 87 | ``` 88 | 89 | 5. Start the agent(s). Run the following command on multiple machines for parallelizing the search. 90 | 91 | ``` 92 | wandb agent 93 | ``` 94 | 95 | 6. Look at the results! Results for the presented examples can be found [here](https://app.wandb.ai/dhruveshpate/wandb_allennlp_models_demo/sweeps/vwwu3sa0). 96 | 97 | 98 | -------------------------------------------------------------------------------- /.github/workflows/ci-build-docs.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | name: Build Docs 4 | 5 | on: 6 | push: 7 | branches: [ master, develop/master ] 8 | pull_request: 9 | branches: [ master, develop/master ] 10 | 11 | jobs: 12 | build_docs: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | # we will need the complete history. 17 | # checkout@v2 will fetch only last commit 18 | # So we fetch all 19 | # checkout@v2 will leave git in a detached head state 20 | # Hence, we need to temporarily create a branch. 
21 | # NOTE: we cannot push this brach 22 | - run: | 23 | git switch -c "temp-branch-for-docs" 24 | git fetch --prune --unshallow --tags 25 | - name: Set up Python 26 | uses: actions/setup-python@v1 27 | with: 28 | python-version: '3.8' 29 | - name: Cache pip 30 | id: cache 31 | uses: actions/cache@v1 32 | with: 33 | path: ~/.cache/pip # This path is specific to Ubuntu 34 | # Look to see if there is a cache hit for the corresponding requirements file 35 | key: ${{ runner.os }}-pip-${{ hashFiles('*_requirements.txt') }} 36 | restore-keys: | 37 | ${{ runner.os }}-pip- 38 | ${{ runner.os }}- 39 | - name: Install test and doc requirements 40 | run: | 41 | pip install -r test_requirements.txt 42 | pip install -r doc_requirements.txt 43 | - name: Build Docs 44 | shell: bash -l {0} 45 | working-directory: docs_source 46 | run: | 47 | echo "See all fetched branches" 48 | git branch 49 | echo "Building Docs" 50 | python make.py 51 | - name: Upload html doc as Artifact 52 | uses: actions/upload-artifact@v1 53 | with: 54 | name: DocumentationHTML 55 | path: docs/ 56 | - name: Deploy Docs 57 | # deploy only after merge 58 | if: ${{github.event_name == 'push'}} 59 | uses: peaceiris/actions-gh-pages@v3.7.3 60 | with: 61 | github_token: ${{ secrets.GITHUB_TOKEN }} 62 | publish_dir: ./docs 63 | keep_files: false 64 | cname: dhruveshp.com 65 | enable_jekyll: false 66 | force_orphan: true 67 | publish_branch: gh-pages 68 | -------------------------------------------------------------------------------- /tests/test_failing_single_run.py: -------------------------------------------------------------------------------- 1 | def test_run_fail_and_dont_close(script_runner): 2 | """ 3 | The model training will fail after one epoch. 4 | The method `wandb.close()` will not be called in the 5 | `close()` function of the wandb callback. 6 | Hence, the fact that the run has failed should propogate to wandb. 
7 | """ 8 | ret = script_runner.run( 9 | "allennlp", 10 | "train-with-wandb", 11 | "configs/parameter_tying_v1.0.0.jsonnet", 12 | "--wandb-entity=dhruveshpate", 13 | "--wandb-project=wandb-allennlp-wandb_allennlp_tests", 14 | "--wandb-name=plugging_test__test_run_fail_and_dont_close", 15 | "--include-package=models", 16 | "--env.a=1.1", 17 | "--env.bool_value=true", 18 | "--env.int_value=10", 19 | "--model.type=parameter-tying-failing", 20 | "--model.d=1", # keep this 1 21 | "--env.call_finish_on_end=false", 22 | ) 23 | assert not ret.success 24 | assert not ( 25 | "(success)." in ret.stderr 26 | or "wandb: Program ended successfully." in ret.stderr 27 | ) 28 | assert ( 29 | "(failed 1)" in ret.stderr 30 | or "wandb: Program failed with code 1" in ret.stderr 31 | ) 32 | 33 | 34 | def test_run_fail_and_close(script_runner): 35 | """ 36 | The model training will fail after one epoch. 37 | The method `wandb.close()` will be called in the 38 | `close()` function of the wandb callback. 39 | Hence, the fact that the run has failed should not propogate to wandb. 40 | """ 41 | ret = script_runner.run( 42 | "allennlp", 43 | "train-with-wandb", 44 | "configs/parameter_tying_v1.0.0.jsonnet", 45 | "--wandb-entity=dhruveshpate", 46 | "--wandb-project=wandb-allennlp-wandb_allennlp_tests", 47 | "--wandb-name=plugging_test__test_run_fail_and_close", 48 | "--include-package=models", 49 | "--env.a=1.1", 50 | "--env.bool_value=true", 51 | "--env.int_value=10", 52 | "--model.type=parameter-tying-failing", 53 | "--model.d=1", # keep this 1 54 | "--env.call_finish_on_end=true", 55 | ) 56 | assert not ret.success 57 | assert ( 58 | "(success)." in ret.stderr 59 | or "wandb: Program ended successfully." 
in ret.stderr 60 | ) 61 | assert not ( 62 | "(failed 1)" in ret.stderr 63 | or "wandb: Program failed with code 1" in ret.stderr 64 | ) 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | wandb_allennlp/tests/wandb/ 3 | docs 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /examples/training_configs/pair_classification/decomposable_attention.jsonnet: -------------------------------------------------------------------------------- 1 | // Configuraiton for a textual entailment model based on: 2 | // Parikh, Ankur P. et al. “A Decomposable Attention Model for Natural Language Inference.” EMNLP (2016). 
3 | // As presented in the allennlp-models : https://github.com/allenai/allennlp-models/blob/master/training_config/pair_classification/decomposable_attention.jsonnet 4 | { 5 | dataset_reader: { 6 | type: 'snli', 7 | token_indexers: { 8 | tokens: { 9 | type: 'single_id', 10 | lowercase_tokens: true, 11 | }, 12 | }, 13 | tokenizer: { 14 | end_tokens: ['@@NULL@@'], 15 | }, 16 | }, 17 | train_data_path: 'https://allennlp.s3.amazonaws.com/datasets/snli/snli_1.0_train.jsonl', 18 | validation_data_path: 'https://allennlp.s3.amazonaws.com/datasets/snli/snli_1.0_dev.jsonl', 19 | model: { 20 | type: 'decomposable_attention', 21 | text_field_embedder: { 22 | token_embedders: { 23 | tokens: { 24 | type: 'embedding', 25 | projection_dim: 200, 26 | pretrained_file: 'https://allennlp.s3.amazonaws.com/datasets/glove/glove.6B.300d.txt.gz', 27 | embedding_dim: 300, 28 | trainable: false, 29 | }, 30 | }, 31 | }, 32 | attend_feedforward: { 33 | input_dim: 200, 34 | num_layers: 2, 35 | hidden_dims: 200, 36 | activations: 'relu', 37 | dropout: 0.2, 38 | }, 39 | matrix_attention: { type: 'dot_product' }, 40 | compare_feedforward: { 41 | input_dim: 400, 42 | num_layers: 2, 43 | hidden_dims: 200, 44 | activations: 'relu', 45 | dropout: 0.2, 46 | }, 47 | aggregate_feedforward: { 48 | input_dim: 400, 49 | num_layers: 2, 50 | hidden_dims: [200, 3], 51 | activations: ['relu', 'linear'], 52 | dropout: [0.2, 0.0], 53 | }, 54 | initializer: { 55 | regexes: [ 56 | ['.*linear_layers.*weight', { type: 'xavier_normal' }], 57 | ['.*token_embedder_tokens\\._projection.*weight', { type: 'xavier_normal' }], 58 | ], 59 | }, 60 | }, 61 | data_loader: { 62 | batch_sampler: { 63 | type: 'bucket', 64 | batch_size: 64, 65 | }, 66 | }, 67 | trainer: { 68 | num_epochs: 140, 69 | patience: 20, 70 | cuda_device: -1, 71 | grad_clipping: 5.0, 72 | validation_metric: '+accuracy', 73 | optimizer: { 74 | type: 'adagrad', 75 | }, 76 | callbacks: [{ type: 'log_metrics_to_wandb' }], // The only extra line in the 
config!!! 77 | }, 78 | } 79 | -------------------------------------------------------------------------------- /examples/sweep_configs/pair_classification/decomposable_attention.yaml: -------------------------------------------------------------------------------- 1 | # See: https://docs.wandb.com/sweeps/configuration 2 | # for details of all the variables present in this file. 3 | # this is the name of the sweep. It can be anything. 4 | name: decomposable_attention_full_search 5 | # The training script/program 6 | # We do not give "allennlp train" directly by use wandb_allennlp to translate the arguments. 7 | program: wandb_allennlp 8 | # "command" gives the template for how the wandb client agent should start the training scrip 9 | command: 10 | - ${program} #omit the interpreter as wandb_allennlp is registered as a console script. 11 | - "--subcommand=train" # allennlp subcommnd. For now only train can be used. 12 | - "--include-package=allennlp_models" # Module/packaged with your registered classes. 13 | # You can have as many packages as you need 14 | # Just include them one after the other like so: 15 | # --include-package=pack1 16 | # --include-package=pack2 17 | - "--config_file=training_configs/pair_classification/decomposable_attention.jsonnet" # Path 18 | - ${args} # this is where the hyperparams generated by wandb server will go. Do not modify. 19 | method: bayes # can be random, grid or bayes 20 | early_terminate: 21 | type: hyperband # Using hyperband with right bands significally reduces serach time 22 | min_iter: 10 # See: https://docs.wandb.com/sweeps/configuration#stopping-criteria 23 | metric: 24 | name: best_validation_accuracy # name of your allennlp metric to optimize for 25 | goal: maximize 26 | parameters: 27 | ## Search Ranges 28 | # boolean 29 | model.text_field_embedder.token_embedders.tokens.trainable: 30 | values: [true, false] # any valid yaml bool like True,true,yes will work. 0/1 won't. 
31 | # string categorical 32 | model.matrix_attention.type: 33 | values: ["dot_product", "cosine"] # can be an arbitrary list of strings 34 | # uniform distribution 35 | model.compare_feedforward.dropout: 36 | min: 0.1 37 | max: 0.4 38 | distribution: uniform 39 | # lists 40 | model.aggregate_feedforward.activations.0: 41 | values: ['relu', 'elu'] 42 | model.aggregate_feedforward.dropout: 43 | min: 0.1 44 | max: 0.4 45 | distribution: uniform 46 | # can also do log_uniform or quantized log_uniform 47 | trainer.optimizer.lr: 48 | max: -5.23 # exp(-5.23) ~= 0.005 49 | min: -11.51 # exp(-11.51) ~= 1e-5 50 | distribution: log_uniform 51 | -------------------------------------------------------------------------------- /.github/workflows/changelog.yml: -------------------------------------------------------------------------------- 1 | name: Changelog Generator 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | 7 | jobs: 8 | generate_changelog_file: 9 | runs-on: ubuntu-latest 10 | 11 | steps: 12 | - uses: actions/checkout@v2 13 | with: 14 | token: ${{ secrets.GH_TOKEN }} # Need to use token to commit the changelog as admin. 
15 | 16 | - name: Generate changelog 17 | uses: heinrichreimer/github-changelog-generator-action@v2.2 18 | with: 19 | token: ${{ secrets.GH_TOKEN }} 20 | author: true 21 | unreleased: true 22 | unreleasedLabel: "🚧Unreleased🚧" 23 | issues: true 24 | pullRequests: true 25 | prWoLabels: true 26 | compareLink: true 27 | output: CHANGELOG.md 28 | breakingLabel: "### 💥 Breaking Changes:" 29 | breakingLabels: "breaking" 30 | deprecatedLabel: "### 👋 Depricated" 31 | deprecatedLabels: "deprecation" 32 | enhancementLabel: "### ✨ Features and Enhancements:" 33 | enhancementLabels: "enhancement" 34 | bugsLabel: "### 🐛 Bug Fixes:" 35 | bugLabels: "bug,bug-fix,fix,fixes" 36 | removedLabel: '### 🗑️ Removals:' 37 | removedLabels: 'removal' 38 | addSections: >- 39 | { 40 | "documentation": { 41 | "prefix":"### 📖 Documentation updates", 42 | "labels":["documentation"] 43 | }, 44 | "style": { 45 | "prefix":"### 💄 Style", 46 | "labels":["style"] 47 | }, 48 | "dependencies": { 49 | "prefix":"### 📦 Dependencies", 50 | "labels":["dependencies"] 51 | }, 52 | "refactoring": { 53 | "prefix":"### 🔨 Refactoring", 54 | "labels":["refactoring"] 55 | }, 56 | "ci":{ 57 | "prefix":"### 👷 Testing and CI", 58 | "labels":["ci", "test", "testing", "tests", "build"] 59 | } 60 | } 61 | # See: emojipedia.org for emoji codes 62 | - name: Commit CHANGELOG.md 63 | uses: EndBug/add-and-commit@v7 # You can change this to use a specific version 64 | if: success() 65 | with: 66 | add: 'CHANGELOG.md' 67 | -------------------------------------------------------------------------------- /src/wandb_allennlp/training/train_and_test.py: -------------------------------------------------------------------------------- 1 | """This is essentially the same command as `allennlp.train` but it allows to run evaluation on test set after training is completed and logs the results to wandb""" 2 | from typing import Dict, Any 3 | from allennlp.commands.train import TrainModel 4 | from allennlp.commands import Subcommand 5 | from 
from allennlp.common import util as common_util
from allennlp.training import util as training_util
from wandb_allennlp.utils import read_from_env
import os
import logging

logger = logging.getLogger(__name__)


@TrainModel.register(
    "train_test_log_to_wandb", constructor="from_partial_objects"
)  # same construction pipeline as parent
class TrainTestAndLogToWandb(TrainModel):
    """Does the same thing as `allennlp.commands.train.TrainModel` but
    logs final metrics to `wandb` summary.

    To use this class, add the following to the top-level config.

    .. code-block:: JSON

        {
            type: 'train_test_log_to_wandb',
            evaluate_on_test: true,
            dataset_reader: { ...},
            model: {...},
            ...
        }

    """

    def finish(self, metrics: Dict[str, Any]) -> None:
        """Optionally evaluate on the test set, dump metrics.json, and push
        all final metrics to the wandb run summary.

        ``metrics`` is mutated in place: test-set results are added under
        ``test_``-prefixed keys before being dumped and uploaded.
        """
        # import wandb here to be sure that it was initialized
        # before this line was executed
        import wandb  # noqa

        if self.evaluation_data_loader is not None and self.evaluate_on_test:
            logger.info(
                "The model will be evaluated using the best epoch weights."
            )
            test_metrics = training_util.evaluate(
                self.model,
                self.evaluation_data_loader,  # type:ignore
                cuda_device=self.trainer.cuda_device,  # type: ignore
                batch_weight_key=self.batch_weight_key,
            )

            # Prefix with "test_" so test metrics can't collide with the
            # train/validation metrics already present in `metrics`.
            for key, value in test_metrics.items():
                metrics["test_" + key] = value
        elif self.evaluation_data_loader is not None:
            # A test loader exists but evaluation was not requested; only
            # advise the user — do not evaluate implicitly.
            logger.info(
                "To evaluate on the test set after training, pass the "
                "'evaluate_on_test' flag, or use the 'allennlp evaluate' command."
            )
        # Persist the final metrics next to the model artifacts.
        common_util.dump_metrics(
            os.path.join(self.serialization_dir, "metrics.json"),
            metrics,
            log=True,
        )
        # update the summary with all metrics

        if wandb.run is None:
            # The run was already closed (e.g. by the training callback);
            # reattach to the same run id so the summary update lands there.
            # NOTE(review): read_from_env may return None if these variables
            # are unset, and resume="must" would then fail — confirm the
            # wandb callback always exports WANDB_RUN_ID/PROJECT/ENTITY.
            logger.info("wandb run was closed. Resuming to update summary.")
            run = wandb.init(
                id=read_from_env("WANDB_RUN_ID"),
                project=read_from_env("WANDB_PROJECT"),
                entity=read_from_env("WANDB_ENTITY"),
                resume="must",
            )
        else:
            logger.info(
                "There is an active wandb run. Using that to update summary."
            )
            run = wandb.run

        if run is not None:
            logger.info("Updating summary on wandb.")
            run.summary.update(metrics)
[\#26](https://github.com/dhruvdcoder/wandb-allennlp/pull/26) ([dhruvdcoder](https://github.com/dhruvdcoder)) 34 | 35 | ### 🐛 Bug Fixes: 36 | 37 | - Handle SIGTERM [\#25](https://github.com/dhruvdcoder/wandb-allennlp/issues/25) 38 | 39 | ### 👷 Testing and CI 40 | 41 | - Fix tests [\#27](https://github.com/dhruvdcoder/wandb-allennlp/pull/27) ([dhruvdcoder](https://github.com/dhruvdcoder)) 42 | 43 | ## [v0.2.4.beta](https://github.com/dhruvdcoder/wandb-allennlp/tree/v0.2.4.beta) (2021-08-27) 44 | 45 | [Full Changelog](https://github.com/dhruvdcoder/wandb-allennlp/compare/seal...v0.2.4.beta) 46 | 47 | ## [seal](https://github.com/dhruvdcoder/wandb-allennlp/tree/seal) (2021-08-25) 48 | 49 | [Full Changelog](https://github.com/dhruvdcoder/wandb-allennlp/compare/v0.2.3...seal) 50 | 51 | ### 🐛 Bug Fixes: 52 | 53 | - Multi-GPU training errors [\#11](https://github.com/dhruvdcoder/wandb-allennlp/issues/11) 54 | 55 | ## [v0.2.3](https://github.com/dhruvdcoder/wandb-allennlp/tree/v0.2.3) (2021-02-09) 56 | 57 | [Full Changelog](https://github.com/dhruvdcoder/wandb-allennlp/compare/v0.2.2...v0.2.3) 58 | 59 | ### ✨ Features and Enhancements: 60 | 61 | - add support for AllenNLP 2.0 \(\#19\) [\#20](https://github.com/dhruvdcoder/wandb-allennlp/pull/20) ([dhruvdcoder](https://github.com/dhruvdcoder)) 62 | 63 | ### 🐛 Bug Fixes: 64 | 65 | - Support flattening of configs with null values [\#18](https://github.com/dhruvdcoder/wandb-allennlp/pull/18) ([dhruvdcoder](https://github.com/dhruvdcoder)) 66 | 67 | ## [v0.2.2](https://github.com/dhruvdcoder/wandb-allennlp/tree/v0.2.2) (2020-11-04) 68 | 69 | [Full Changelog](https://github.com/dhruvdcoder/wandb-allennlp/compare/v0.2.1...v0.2.2) 70 | 71 | ### ✨ Features and Enhancements: 72 | 73 | - Parameter tying through external variables in jsonnet. 
[\#15](https://github.com/dhruvdcoder/wandb-allennlp/issues/15) 74 | - Add model watch option \(\#13\) [\#14](https://github.com/dhruvdcoder/wandb-allennlp/pull/14) ([dhruvdcoder](https://github.com/dhruvdcoder)) 75 | 76 | ## [v0.2.1](https://github.com/dhruvdcoder/wandb-allennlp/tree/v0.2.1) (2020-10-15) 77 | 78 | [Full Changelog](https://github.com/dhruvdcoder/wandb-allennlp/compare/v0.2.0...v0.2.1) 79 | 80 | ### 🐛 Bug Fixes: 81 | 82 | - Check if master before sending config to wandb [\#12](https://github.com/dhruvdcoder/wandb-allennlp/pull/12) ([dhruvdcoder](https://github.com/dhruvdcoder)) 83 | 84 | ## [v0.2.0](https://github.com/dhruvdcoder/wandb-allennlp/tree/v0.2.0) (2020-06-22) 85 | 86 | [Full Changelog](https://github.com/dhruvdcoder/wandb-allennlp/compare/v0.1.1...v0.2.0) 87 | 88 | ### 🐛 Bug Fixes: 89 | 90 | - Fix passing of ints, floats and bool [\#7](https://github.com/dhruvdcoder/wandb-allennlp/pull/7) ([dhruvdcoder](https://github.com/dhruvdcoder)) 91 | 92 | ## [v0.1.1](https://github.com/dhruvdcoder/wandb-allennlp/tree/v0.1.1) (2020-05-12) 93 | 94 | [Full Changelog](https://github.com/dhruvdcoder/wandb-allennlp/compare/v0.1...v0.1.1) 95 | 96 | ## [v0.1](https://github.com/dhruvdcoder/wandb-allennlp/tree/v0.1) (2020-05-12) 97 | 98 | [Full Changelog](https://github.com/dhruvdcoder/wandb-allennlp/compare/v0.0.4...v0.1) 99 | 100 | ## [v0.0.4](https://github.com/dhruvdcoder/wandb-allennlp/tree/v0.0.4) (2020-05-11) 101 | 102 | [Full Changelog](https://github.com/dhruvdcoder/wandb-allennlp/compare/2f080f0ea27060f33de4f083a4e086e56e50416d...v0.0.4) 103 | 104 | 105 | 106 | \* *This Changelog was automatically generated by [github_changelog_generator](https://github.com/github-changelog-generator/github-changelog-generator)* 107 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wandb-allennlp 2 | 3 | 4 | 
![Tests](https://github.com/dhruvdcoder/wandb-allennlp/workflows/Tests/badge.svg) 5 | 6 | 7 | Utilities and boilerplate code which allows using [Weights & Biases](https://www.wandb.com/) to tune the hypereparameters for any AllenNLP model **without a single line of extra code!** 8 | 9 | # What does it do? 10 | 11 | 1. Log a single run or a hyperparameter search sweep without any extra code, just using configuration files. 12 | 13 | 2. Use [Weights & Biases'](https://www.wandb.com/) bayesian hyperparameter search engine + hyperband in any AllenNLP project. 14 | 15 | 16 | 17 | # Quick start 18 | 19 | ## Installation 20 | 21 | ``` 22 | $ pip install wandb-allennlp 23 | $ echo wandb_allennlp >> .allennlp_plugins 24 | ``` 25 | 26 | 27 | 28 | ## Log a single run 29 | 30 | 1. Create your model using AllenNLP along with a *training configuration* file as you would normally do. 31 | 32 | 2. Add a trainer callback in your config file. Use one of the following based on your AllenNLP version: 33 | 34 | 35 | ``` 36 | ..., 37 | 38 | trainer: { 39 | type: 'callback', 40 | callbacks: [ 41 | ..., 42 | { 43 | type: 'wandb_allennlp', 44 | files_to_save: ['config.json'], 45 | files_to_save_at_end: ['*.tar.gz'], 46 | }, 47 | ..., 48 | ], 49 | ..., 50 | } 51 | ... 52 | ... 53 | ``` 54 | 55 | 2. Execute the `allennlp train-with-wandb` command instead of `allennlp train`. It supports all the arguments present in `allennlp train`. However, the `--overrides` have to be specified in the `--kw value` or `--kw=value` form, where `kw` is the parameter to override and `value` is its value. Use the dot notation for nested parameters. For instance, `{'model': {'embedder': {'type': xyz}}}` can be provided as `--model.embedder.type xyz`. 
56 | 57 | ``` 58 | allennlp train-with-wandb model_configs/my_config.jsonnet --include-package=package_with_my_registered_classes --include-package=another_package --wandb-run-name=my_first_run --wandb-tags=any,set,of,non-unique,tags,that,identify,the,run,without,spaces 59 | 60 | ``` 61 | 62 | 63 | ## Hyperparameter Search 64 | 65 | 1. Create your model using AllenNLP along with a *training configuration* file as you would normally do. For example: 66 | 67 | ``` 68 | local data_path = std.extVar('DATA_PATH'); 69 | local a = std.parseJson(std.extVar('a')); 70 | local bool_value = std.parseJson(std.extVar('bool_value')); 71 | local int_value = std.parseJson(std.extVar('int_value')); 72 | 73 | { 74 | type: 'train_test_log_to_wandb', 75 | evaluate_on_test: true, 76 | dataset_reader: { 77 | type: 'snli', 78 | token_indexers: { 79 | tokens: { 80 | type: 'single_id', 81 | lowercase_tokens: true, 82 | }, 83 | }, 84 | }, 85 | train_data_path: data_path + '/snli_1.0_test/snli_1.0_train.jsonl', 86 | validation_data_path: data_path + '/snli_1.0_test/snli_1.0_dev.jsonl', 87 | test_data_path: data_path + '/snli_1.0_test/snli_1.0_test.jsonl', 88 | model: { 89 | type: 'parameter-tying', 90 | a: a, 91 | b: a, 92 | d: 0, 93 | bool_value: bool_value, 94 | bool_value_not: !bool_value, 95 | int_value: int_value, 96 | int_value_10: int_value + 10, 97 | 98 | }, 99 | data_loader: { 100 | batch_sampler: { 101 | type: 'bucket', 102 | batch_size: 64, 103 | }, 104 | }, 105 | trainer: { 106 | optimizer: { 107 | type: 'adam', 108 | lr: 0.001, 109 | weight_decay: 0.0, 110 | }, 111 | cuda_device: -1, 112 | num_epochs: 2, 113 | callbacks: [ 114 | { 115 | type: 'wandb_allennlp', 116 | files_to_save: ['config.json'], 117 | files_to_save_at_end: ['*.tar.gz'], 118 | }, 119 | ], 120 | }, 121 | } 122 | ``` 123 | 124 | 2. Create a *sweep configuration* file and generate a sweep on the wandb server. 
Note that the tied parameters that are accepted through environment variables are specified using the prefix `env.` in the sweep config. For example: 125 | 126 | ``` 127 | name: parameter_tying_test_console_script_v0.2.4 128 | program: allennlp 129 | command: 130 | - ${program} #omit the interpreter as we use allennlp train command directly 131 | - "train-with-wandb" # subcommand 132 | - "configs/parameter_tying_v0.2.4.jsonnet" 133 | - "--include-package=models" # add all packages containing your registered classes here 134 | - "--include-package=allennlp_models" 135 | - ${args} 136 | method: bayes 137 | metric: 138 | name: training_loss 139 | goal: minimize 140 | parameters: 141 | # hyperparameters start with overrides 142 | # Ranges 143 | # Add env. to tell that it is a top level parameter 144 | env.a: 145 | min: 1 146 | max: 10 147 | distribution: uniform 148 | env.bool_value: 149 | values: [true, false] 150 | env.int_value: 151 | values: [-1, 0, 1, 10] 152 | model.d: 153 | value: 1 154 | ``` 155 | 3. Create the sweep on wandb. 156 | 157 | ``` 158 | $ wandb sweep path_to_sweep.yaml 159 | ``` 160 | 161 | 4. Set the other environment variables required by your jsonnet. 162 | 163 | ``` 164 | export DATA_DIR=./data 165 | ``` 166 | 167 | 5. Start the search agents. 
168 | 169 | ``` 170 | wandb agent 171 | ``` 172 | -------------------------------------------------------------------------------- /src/wandb_allennlp/commands/parser_base.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Optional, List, Dict, Any 2 | from allennlp.commands.subcommand import Subcommand 3 | from pathlib import Path 4 | from wandb_allennlp.utils import read_from_env 5 | import logging 6 | import argparse 7 | import wandb 8 | import os 9 | 10 | wandb_major, wandb_minor, wandb_patch = wandb.__version__.split('.') 11 | 12 | if int(wandb_minor) >=12 or int(wandb_major) > 0: 13 | wandb_get_dir= wandb.sdk.wandb_settings._get_wandb_dir 14 | else: 15 | wandb_get_dir= wandb.sdk.wandb_settings.get_wandb_dir 16 | 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | class SetWandbEnvVar(argparse.Action): 22 | """Used as an action callback in argparse argument to set env vars that are read by wandb. 23 | 24 | Can be used like so: :: 25 | 26 | parser.add_argument('--wandb_entity', type=str, action=SetWandbEnvVar) 27 | """ 28 | 29 | def __call__( # type: ignore 30 | self, 31 | parser: argparse.ArgumentParser, 32 | namespace: argparse.Namespace, 33 | values: str, 34 | option_string: str = None, 35 | ) -> None: 36 | setattr(namespace, self.dest, values) 37 | assert isinstance(values, str) 38 | os.environ.update({self.dest.replace("-", "_").upper(): values}) 39 | 40 | 41 | class WandbParserBase(Subcommand): 42 | """This subcommand cannot be used directly. It is indented to 43 | be a common base class for all commands that use Weights & Biases. 44 | It as few common args for wandb. 45 | 46 | 47 | The way to use this in a child subcommand would be: :: 48 | 49 | Subcommand.register("some_subcommand") 50 | class SomeSubcommand(WandbParserBase): 51 | description = "Does some stuff" 52 | help_message = "Help message to do some stuff." 
53 | entry_point=SomeCallable 54 | def add_arguments(self, subparser): 55 | subparser.add_argument("--some_important_argument", ...) 56 | ... 57 | ... 58 | return subparser 59 | 60 | 61 | """ 62 | 63 | description: str = "A wandb base command" 64 | help_message: str = ( 65 | "Most arguments starting with 'wandb_' are in one-to-one correspondance with wandb.init()." 66 | " See https://docs.wandb.ai/ref/run/init for reference." 67 | ) 68 | require_run_id = False 69 | 70 | def add_arguments( 71 | self, subparser: argparse.ArgumentParser 72 | ) -> argparse.ArgumentParser: 73 | logger.warning( 74 | f"add_arguments() for {self.__class__} not overriden." 75 | " Did you forget to add arguments?" 76 | ) 77 | 78 | return subparser 79 | 80 | @classmethod 81 | def get_wandb_run_args(cls, args: argparse.Namespace) -> Dict[str, Any]: 82 | return { 83 | "_".join(arg.split("_")[1:]): value 84 | for arg, value in vars(args).items() 85 | if arg.startswith("wandb_") and (value is not None) 86 | } 87 | 88 | @classmethod 89 | def init_wandb_run( 90 | cls, args: argparse.Namespace 91 | ) -> wandb.sdk.wandb_run.Run: 92 | run = wandb.init(**cls.get_wandb_run_args(args)) 93 | # just use the log files and do not dynamically patch tensorboard as it messes up the 94 | # the global_step and breaks the normal use of wandb.log() 95 | wandb.tensorboard.patch(save=True, tensorboardX=False) 96 | 97 | return run # type: ignore 98 | 99 | def add_subparser( 100 | self, parser: argparse._SubParsersAction 101 | ) -> argparse.ArgumentParser: 102 | subparser = parser.add_parser( 103 | self.name, 104 | description=self.description, 105 | help=self.help_message, 106 | conflict_handler="resolve", 107 | ) 108 | subparser.add_argument( 109 | "--wandb-run-id", 110 | type=str, 111 | required=self.require_run_id, 112 | default=read_from_env("WANDB_RUN_ID"), 113 | ) 114 | subparser.add_argument( 115 | "--wandb-entity", 116 | type=str, 117 | action=SetWandbEnvVar, 118 | default=read_from_env("WANDB_ENTITY"), 119 | 
) 120 | subparser.add_argument( 121 | "--wandb-project", 122 | type=str, 123 | action=SetWandbEnvVar, 124 | default=read_from_env("WANDB_PROJECT"), 125 | ) 126 | subparser.add_argument( 127 | "--wandb-tags", 128 | type=str, 129 | action=SetWandbEnvVar, 130 | help="Comma seperated list of tags.", 131 | ) 132 | subparser.add_argument("--wandb-name", action=SetWandbEnvVar, type=str) 133 | subparser.add_argument( 134 | "--wandb-group", action=SetWandbEnvVar, type=str 135 | ) 136 | subparser.add_argument( 137 | "--wandb-job-type", action=SetWandbEnvVar, type=str 138 | ) 139 | subparser.add_argument( 140 | "--wandb-notes", action=SetWandbEnvVar, type=str 141 | ) 142 | subparser.add_argument( 143 | "--wandb-dir", 144 | type=str, 145 | action=SetWandbEnvVar, 146 | default=wandb_get_dir( 147 | read_from_env("WANDB_DIR") or "" 148 | ), 149 | ) 150 | # subparser.add_argument("--wandb_sync_tensorboard", action="store_true") 151 | subparser.add_argument( 152 | "--wandb-config-exclude-keys", 153 | type=str, 154 | action=SetWandbEnvVar, 155 | help="Comma seperated list.", 156 | ) 157 | subparser.add_argument( 158 | "--wandb-config-include-keys", 159 | type=str, 160 | action=SetWandbEnvVar, 161 | help="Comma seperated list.", 162 | ) 163 | subparser.add_argument( 164 | "--wandb-mode", 165 | action=SetWandbEnvVar, 166 | choices=["online", "offline", "disabled"], 167 | default="online", 168 | ) 169 | 170 | subparser = self.add_arguments(subparser) 171 | 172 | return subparser 173 | -------------------------------------------------------------------------------- /tests/data/snli_1.0_test/snli_1.0_train.jsonl: -------------------------------------------------------------------------------- 1 | {"annotator_labels": ["neutral"], "captionID": "3416050480.jpg#4", "gold_label": "neutral", "pairID": "3416050480.jpg#4r1n", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( 
down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is training his horse for a competition.", "sentence2_binary_parse": "( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))"} 2 | {"annotator_labels": ["contradiction"], "captionID": "3416050480.jpg#4", "gold_label": "contradiction", "pairID": "3416050480.jpg#4r1c", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is at a diner, ordering an omelette.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is ( at ( a diner ) ) ) , ) ( ordering ( an omelette ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (PP (IN at) (NP (DT a) (NN diner))) (, ,) (S (VP (VBG ordering) (NP (DT an) (NN omelette))))) (. .)))"} 3 | {"annotator_labels": ["entailment"], "captionID": "3416050480.jpg#4", "gold_label": "entailment", "pairID": "3416050480.jpg#4r1e", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. 
.)))", "sentence2": "A person is outdoors, on a horse.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .)))"} 4 | {"annotator_labels": ["neutral"], "captionID": "2267923837.jpg#2", "gold_label": "neutral", "pairID": "2267923837.jpg#2r1n", "sentence1": "Children smiling and waving at camera", "sentence1_binary_parse": "( Children ( ( ( smiling and ) waving ) ( at camera ) ) )", "sentence1_parse": "(ROOT (NP (S (NP (NNP Children)) (VP (VBG smiling) (CC and) (VBG waving) (PP (IN at) (NP (NN camera)))))))", "sentence2": "They are smiling at their parents", "sentence2_binary_parse": "( They ( are ( smiling ( at ( their parents ) ) ) ) )", "sentence2_parse": "(ROOT (S (NP (PRP They)) (VP (VBP are) (VP (VBG smiling) (PP (IN at) (NP (PRP$ their) (NNS parents)))))))"} 5 | {"annotator_labels": ["entailment"], "captionID": "2267923837.jpg#2", "gold_label": "entailment", "pairID": "2267923837.jpg#2r1e", "sentence1": "Children smiling and waving at camera", "sentence1_binary_parse": "( Children ( ( ( smiling and ) waving ) ( at camera ) ) )", "sentence1_parse": "(ROOT (NP (S (NP (NNP Children)) (VP (VBG smiling) (CC and) (VBG waving) (PP (IN at) (NP (NN camera)))))))", "sentence2": "There are children present", "sentence2_binary_parse": "( There ( ( are children ) present ) )", "sentence2_parse": "(ROOT (S (NP (EX There)) (VP (VBP are) (NP (NNS children)) (ADVP (RB present)))))"} 6 | {"annotator_labels": ["contradiction"], "captionID": "2267923837.jpg#2", "gold_label": "contradiction", "pairID": "2267923837.jpg#2r1c", "sentence1": "Children smiling and waving at camera", "sentence1_binary_parse": "( Children ( ( ( smiling and ) waving ) ( at camera ) ) )", "sentence1_parse": "(ROOT (NP (S (NP (NNP Children)) (VP (VBG smiling) (CC and) (VBG waving) (PP (IN at) (NP (NN camera)))))))", "sentence2": 
"The kids are frowning", "sentence2_binary_parse": "( ( The kids ) ( are frowning ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NNS kids)) (VP (VBP are) (VP (VBG frowning)))))"} 7 | {"annotator_labels": ["contradiction"], "captionID": "3691670743.jpg#0", "gold_label": "contradiction", "pairID": "3691670743.jpg#0r1c", "sentence1": "A boy is jumping on skateboard in the middle of a red bridge.", "sentence1_binary_parse": "( ( A boy ) ( ( is ( ( jumping ( on skateboard ) ) ( in ( ( the middle ) ( of ( a ( red bridge ) ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (DT A) (NN boy)) (VP (VBZ is) (VP (VBG jumping) (PP (IN on) (NP (NN skateboard))) (PP (IN in) (NP (NP (DT the) (NN middle)) (PP (IN of) (NP (DT a) (JJ red) (NN bridge))))))) (. .)))", "sentence2": "The boy skates down the sidewalk.", "sentence2_binary_parse": "( ( The boy ) ( ( ( skates down ) ( the sidewalk ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NN boy)) (VP (VBZ skates) (PRT (RP down)) (NP (DT the) (NN sidewalk))) (. .)))"} 8 | {"annotator_labels": ["entailment"], "captionID": "3691670743.jpg#0", "gold_label": "entailment", "pairID": "3691670743.jpg#0r1e", "sentence1": "A boy is jumping on skateboard in the middle of a red bridge.", "sentence1_binary_parse": "( ( A boy ) ( ( is ( ( jumping ( on skateboard ) ) ( in ( ( the middle ) ( of ( a ( red bridge ) ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (DT A) (NN boy)) (VP (VBZ is) (VP (VBG jumping) (PP (IN on) (NP (NN skateboard))) (PP (IN in) (NP (NP (DT the) (NN middle)) (PP (IN of) (NP (DT a) (JJ red) (NN bridge))))))) (. .)))", "sentence2": "The boy does a skateboarding trick.", "sentence2_binary_parse": "( ( The boy ) ( ( does ( a ( skateboarding trick ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NN boy)) (VP (VBZ does) (NP (DT a) (NNP skateboarding) (NN trick))) (. 
.)))"} 9 | {"annotator_labels": ["neutral"], "captionID": "3691670743.jpg#0", "gold_label": "neutral", "pairID": "3691670743.jpg#0r1n", "sentence1": "A boy is jumping on skateboard in the middle of a red bridge.", "sentence1_binary_parse": "( ( A boy ) ( ( is ( ( jumping ( on skateboard ) ) ( in ( ( the middle ) ( of ( a ( red bridge ) ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (DT A) (NN boy)) (VP (VBZ is) (VP (VBG jumping) (PP (IN on) (NP (NN skateboard))) (PP (IN in) (NP (NP (DT the) (NN middle)) (PP (IN of) (NP (DT a) (JJ red) (NN bridge))))))) (. .)))", "sentence2": "The boy is wearing safety equipment.", "sentence2_binary_parse": "( ( The boy ) ( ( is ( wearing ( safety equipment ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NN boy)) (VP (VBZ is) (VP (VBG wearing) (NP (NN safety) (NN equipment)))) (. .)))"} 10 | {"annotator_labels": ["neutral"], "captionID": "4804607632.jpg#0", "gold_label": "neutral", "pairID": "4804607632.jpg#0r1n", "sentence1": "An older man sits with his orange juice at a small table in a coffee shop while employees in bright colored shirts smile in the background.", "sentence1_binary_parse": "( ( An ( older man ) ) ( ( ( sits ( with ( ( his ( orange juice ) ) ( at ( ( a ( small table ) ) ( in ( a ( coffee shop ) ) ) ) ) ) ) ) ( while ( ( employees ( in ( bright ( colored shirts ) ) ) ) ( smile ( in ( the background ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (DT An) (JJR older) (NN man)) (VP (VBZ sits) (PP (IN with) (NP (NP (PRP$ his) (JJ orange) (NN juice)) (PP (IN at) (NP (NP (DT a) (JJ small) (NN table)) (PP (IN in) (NP (DT a) (NN coffee) (NN shop))))))) (SBAR (IN while) (S (NP (NP (NNS employees)) (PP (IN in) (NP (JJ bright) (JJ colored) (NNS shirts)))) (VP (VBP smile) (PP (IN in) (NP (DT the) (NN background))))))) (. 
.)))", "sentence2": "An older man drinks his juice as he waits for his daughter to get off work.", "sentence2_binary_parse": "( ( An ( older man ) ) ( ( ( drinks ( his juice ) ) ( as ( he ( waits ( for ( his ( daughter ( to ( ( get off ) work ) ) ) ) ) ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT An) (JJR older) (NN man)) (VP (VBZ drinks) (NP (PRP$ his) (NN juice)) (SBAR (IN as) (S (NP (PRP he)) (VP (VBZ waits) (PP (IN for) (NP (PRP$ his) (NN daughter) (S (VP (TO to) (VP (VB get) (PRT (RP off)) (NP (NN work))))))))))) (. .)))"} -------------------------------------------------------------------------------- /tests/data/snli_1.0_test/snli_1.0_test.jsonl: -------------------------------------------------------------------------------- 1 | {"annotator_labels": ["neutral", "contradiction", "contradiction", "neutral", "neutral"], "captionID": "2677109430.jpg#1", "gold_label": "neutral", "pairID": "2677109430.jpg#1r1n", "sentence1": "This church choir sings to the masses as they sing joyous songs from the book at a church.", "sentence1_binary_parse": "( ( This ( church choir ) ) ( ( ( sings ( to ( the masses ) ) ) ( as ( they ( ( sing ( joyous songs ) ) ( from ( ( the book ) ( at ( a church ) ) ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (DT This) (NN church) (NN choir)) (VP (VBZ sings) (PP (TO to) (NP (DT the) (NNS masses))) (SBAR (IN as) (S (NP (PRP they)) (VP (VBP sing) (NP (JJ joyous) (NNS songs)) (PP (IN from) (NP (NP (DT the) (NN book)) (PP (IN at) (NP (DT a) (NN church))))))))) (. .)))", "sentence2": "The church has cracks in the ceiling.", "sentence2_binary_parse": "( ( The church ) ( ( has ( cracks ( in ( the ceiling ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NN church)) (VP (VBZ has) (NP (NP (NNS cracks)) (PP (IN in) (NP (DT the) (NN ceiling))))) (. 
.)))"} 2 | {"annotator_labels": ["entailment", "entailment", "entailment", "neutral", "entailment"], "captionID": "2677109430.jpg#1", "gold_label": "entailment", "pairID": "2677109430.jpg#1r1e", "sentence1": "This church choir sings to the masses as they sing joyous songs from the book at a church.", "sentence1_binary_parse": "( ( This ( church choir ) ) ( ( ( sings ( to ( the masses ) ) ) ( as ( they ( ( sing ( joyous songs ) ) ( from ( ( the book ) ( at ( a church ) ) ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (DT This) (NN church) (NN choir)) (VP (VBZ sings) (PP (TO to) (NP (DT the) (NNS masses))) (SBAR (IN as) (S (NP (PRP they)) (VP (VBP sing) (NP (JJ joyous) (NNS songs)) (PP (IN from) (NP (NP (DT the) (NN book)) (PP (IN at) (NP (DT a) (NN church))))))))) (. .)))", "sentence2": "The church is filled with song.", "sentence2_binary_parse": "( ( The church ) ( ( is ( filled ( with song ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NN church)) (VP (VBZ is) (VP (VBN filled) (PP (IN with) (NP (NN song))))) (. .)))"} 3 | {"annotator_labels": ["contradiction", "contradiction", "contradiction", "contradiction", "contradiction"], "captionID": "2677109430.jpg#1", "gold_label": "contradiction", "pairID": "2677109430.jpg#1r1c", "sentence1": "This church choir sings to the masses as they sing joyous songs from the book at a church.", "sentence1_binary_parse": "( ( This ( church choir ) ) ( ( ( sings ( to ( the masses ) ) ) ( as ( they ( ( sing ( joyous songs ) ) ( from ( ( the book ) ( at ( a church ) ) ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (DT This) (NN church) (NN choir)) (VP (VBZ sings) (PP (TO to) (NP (DT the) (NNS masses))) (SBAR (IN as) (S (NP (PRP they)) (VP (VBP sing) (NP (JJ joyous) (NNS songs)) (PP (IN from) (NP (NP (DT the) (NN book)) (PP (IN at) (NP (DT a) (NN church))))))))) (. .)))", "sentence2": "A choir singing at a baseball game.", "sentence2_binary_parse": "( ( ( A choir ) ( singing ( at ( a ( baseball game ) ) ) ) ) . 
)", "sentence2_parse": "(ROOT (NP (NP (DT A) (NN choir)) (VP (VBG singing) (PP (IN at) (NP (DT a) (NN baseball) (NN game)))) (. .)))"} 4 | {"annotator_labels": ["neutral", "neutral", "neutral", "neutral", "neutral"], "captionID": "6160193920.jpg#4", "gold_label": "neutral", "pairID": "6160193920.jpg#4r1n", "sentence1": "A woman with a green headscarf, blue shirt and a very big grin.", "sentence1_binary_parse": "( ( ( A woman ) ( with ( ( ( ( ( a ( green headscarf ) ) , ) ( blue shirt ) ) and ) ( a ( ( very big ) grin ) ) ) ) ) . )", "sentence1_parse": "(ROOT (NP (NP (DT A) (NN woman)) (PP (IN with) (NP (NP (DT a) (JJ green) (NN headscarf)) (, ,) (NP (JJ blue) (NN shirt)) (CC and) (NP (DT a) (ADJP (RB very) (JJ big)) (NN grin)))) (. .)))", "sentence2": "The woman is young.", "sentence2_binary_parse": "( ( The woman ) ( ( is young ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NN woman)) (VP (VBZ is) (ADJP (JJ young))) (. .)))"} 5 | {"annotator_labels": ["entailment", "entailment", "contradiction", "entailment", "neutral"], "captionID": "6160193920.jpg#4", "gold_label": "entailment", "pairID": "6160193920.jpg#4r1e", "sentence1": "A woman with a green headscarf, blue shirt and a very big grin.", "sentence1_binary_parse": "( ( ( A woman ) ( with ( ( ( ( ( a ( green headscarf ) ) , ) ( blue shirt ) ) and ) ( a ( ( very big ) grin ) ) ) ) ) . )", "sentence1_parse": "(ROOT (NP (NP (DT A) (NN woman)) (PP (IN with) (NP (NP (DT a) (JJ green) (NN headscarf)) (, ,) (NP (JJ blue) (NN shirt)) (CC and) (NP (DT a) (ADJP (RB very) (JJ big)) (NN grin)))) (. .)))", "sentence2": "The woman is very happy.", "sentence2_binary_parse": "( ( The woman ) ( ( is ( very happy ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NN woman)) (VP (VBZ is) (ADJP (RB very) (JJ happy))) (. 
.)))"} 6 | {"annotator_labels": ["contradiction", "contradiction", "neutral", "contradiction", "contradiction"], "captionID": "6160193920.jpg#4", "gold_label": "contradiction", "pairID": "6160193920.jpg#4r1c", "sentence1": "A woman with a green headscarf, blue shirt and a very big grin.", "sentence1_binary_parse": "( ( ( A woman ) ( with ( ( ( ( ( a ( green headscarf ) ) , ) ( blue shirt ) ) and ) ( a ( ( very big ) grin ) ) ) ) ) . )", "sentence1_parse": "(ROOT (NP (NP (DT A) (NN woman)) (PP (IN with) (NP (NP (DT a) (JJ green) (NN headscarf)) (, ,) (NP (JJ blue) (NN shirt)) (CC and) (NP (DT a) (ADJP (RB very) (JJ big)) (NN grin)))) (. .)))", "sentence2": "The woman has been shot.", "sentence2_binary_parse": "( ( The woman ) ( ( has ( been shot ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NN woman)) (VP (VBZ has) (VP (VBN been) (VP (VBN shot)))) (. .)))"} 7 | {"annotator_labels": ["entailment", "entailment", "entailment", "entailment", "entailment"], "captionID": "4791890474.jpg#3", "gold_label": "entailment", "pairID": "4791890474.jpg#3r1e", "sentence1": "An old man with a package poses in front of an advertisement.", "sentence1_binary_parse": "( ( ( An ( old man ) ) ( with ( a package ) ) ) ( ( poses ( in ( front ( of ( an advertisement ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT An) (JJ old) (NN man)) (PP (IN with) (NP (DT a) (NN package)))) (VP (VBZ poses) (PP (IN in) (NP (NP (NN front)) (PP (IN of) (NP (DT an) (NN advertisement)))))) (. .)))", "sentence2": "A man poses in front of an ad.", "sentence2_binary_parse": "( ( A man ) ( ( poses ( in ( front ( of ( an ad ) ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN man)) (VP (VBZ poses) (PP (IN in) (NP (NP (NN front)) (PP (IN of) (NP (DT an) (NN ad)))))) (. 
.)))"} 8 | {"annotator_labels": ["neutral", "neutral", "entailment", "neutral", "neutral"], "captionID": "4791890474.jpg#3", "gold_label": "neutral", "pairID": "4791890474.jpg#3r1n", "sentence1": "An old man with a package poses in front of an advertisement.", "sentence1_binary_parse": "( ( ( An ( old man ) ) ( with ( a package ) ) ) ( ( poses ( in ( front ( of ( an advertisement ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT An) (JJ old) (NN man)) (PP (IN with) (NP (DT a) (NN package)))) (VP (VBZ poses) (PP (IN in) (NP (NP (NN front)) (PP (IN of) (NP (DT an) (NN advertisement)))))) (. .)))", "sentence2": "A man poses in front of an ad for beer.", "sentence2_binary_parse": "( ( A man ) ( ( poses ( in ( front ( of ( ( an ad ) ( for beer ) ) ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN man)) (VP (VBZ poses) (PP (IN in) (NP (NP (NN front)) (PP (IN of) (NP (NP (DT an) (NN ad)) (PP (IN for) (NP (NN beer)))))))) (. .)))"} 9 | {"annotator_labels": ["contradiction", "neutral", "contradiction", "contradiction", "contradiction"], "captionID": "4791890474.jpg#3", "gold_label": "contradiction", "pairID": "4791890474.jpg#3r1c", "sentence1": "An old man with a package poses in front of an advertisement.", "sentence1_binary_parse": "( ( ( An ( old man ) ) ( with ( a package ) ) ) ( ( poses ( in ( front ( of ( an advertisement ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT An) (JJ old) (NN man)) (PP (IN with) (NP (DT a) (NN package)))) (VP (VBZ poses) (PP (IN in) (NP (NP (NN front)) (PP (IN of) (NP (DT an) (NN advertisement)))))) (. .)))", "sentence2": "A man walks by an ad.", "sentence2_binary_parse": "( ( A man ) ( ( walks ( by ( an ad ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN man)) (VP (VBZ walks) (PP (IN by) (NP (DT an) (NN ad)))) (. 
.)))"} 10 | {"annotator_labels": ["neutral", "neutral", "contradiction", "contradiction", "neutral"], "captionID": "6526219567.jpg#4", "gold_label": "neutral", "pairID": "6526219567.jpg#4r1n", "sentence1": "A statue at a museum that no seems to be looking at.", "sentence1_binary_parse": "( ( ( A statue ) ( at ( a museum ) ) ) ( ( that no ) ( ( seems ( to ( be ( looking at ) ) ) ) . ) ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN statue)) (PP (IN at) (NP (DT a) (NN museum)))) (ADVP (RB that) (RB no)) (VP (VBZ seems) (S (VP (TO to) (VP (VB be) (VP (VBG looking) (ADVP (IN at))))))) (. .)))", "sentence2": "The statue is offensive and people are mad that it is on display.", "sentence2_binary_parse": "( ( ( ( ( The statue ) ( is offensive ) ) and ) ( people ( are ( mad ( that ( it ( is ( on display ) ) ) ) ) ) ) ) . )", "sentence2_parse": "(ROOT (S (S (NP (DT The) (NN statue)) (VP (VBZ is) (ADJP (JJ offensive)))) (CC and) (S (NP (NNS people)) (VP (VBP are) (ADJP (JJ mad) (SBAR (IN that) (S (NP (PRP it)) (VP (VBZ is) (PP (IN on) (NP (NN display))))))))) (. .)))"} -------------------------------------------------------------------------------- /tests/data/snli_1.0_test/snli_1.0_dev.jsonl: -------------------------------------------------------------------------------- 1 | {"annotator_labels": ["neutral", "entailment", "neutral", "neutral", "neutral"], "captionID": "4705552913.jpg#2", "gold_label": "neutral", "pairID": "4705552913.jpg#2r1n", "sentence1": "Two women are embracing while holding to go packages.", "sentence1_binary_parse": "( ( Two women ) ( ( are ( embracing ( while ( holding ( to ( go packages ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP are) (VP (VBG embracing) (SBAR (IN while) (S (NP (VBG holding)) (VP (TO to) (VP (VB go) (NP (NNS packages)))))))) (. 
.)))", "sentence2": "The sisters are hugging goodbye while holding to go packages after just eating lunch.", "sentence2_binary_parse": "( ( The sisters ) ( ( are ( ( hugging goodbye ) ( while ( holding ( to ( ( go packages ) ( after ( just ( eating lunch ) ) ) ) ) ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NNS sisters)) (VP (VBP are) (VP (VBG hugging) (NP (UH goodbye)) (PP (IN while) (S (VP (VBG holding) (S (VP (TO to) (VP (VB go) (NP (NNS packages)) (PP (IN after) (S (ADVP (RB just)) (VP (VBG eating) (NP (NN lunch))))))))))))) (. .)))"} 2 | {"annotator_labels": ["entailment", "entailment", "entailment", "entailment", "entailment"], "captionID": "4705552913.jpg#2", "gold_label": "entailment", "pairID": "4705552913.jpg#2r1e", "sentence1": "Two women are embracing while holding to go packages.", "sentence1_binary_parse": "( ( Two women ) ( ( are ( embracing ( while ( holding ( to ( go packages ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP are) (VP (VBG embracing) (SBAR (IN while) (S (NP (VBG holding)) (VP (TO to) (VP (VB go) (NP (NNS packages)))))))) (. .)))", "sentence2": "Two woman are holding packages.", "sentence2_binary_parse": "( ( Two woman ) ( ( are ( holding packages ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (CD Two) (NN woman)) (VP (VBP are) (VP (VBG holding) (NP (NNS packages)))) (. .)))"} 3 | {"annotator_labels": ["contradiction", "contradiction", "contradiction", "contradiction", "contradiction"], "captionID": "4705552913.jpg#2", "gold_label": "contradiction", "pairID": "4705552913.jpg#2r1c", "sentence1": "Two women are embracing while holding to go packages.", "sentence1_binary_parse": "( ( Two women ) ( ( are ( embracing ( while ( holding ( to ( go packages ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP are) (VP (VBG embracing) (SBAR (IN while) (S (NP (VBG holding)) (VP (TO to) (VP (VB go) (NP (NNS packages)))))))) (. 
.)))", "sentence2": "The men are fighting outside a deli.", "sentence2_binary_parse": "( ( The men ) ( ( are ( fighting ( outside ( a deli ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT The) (NNS men)) (VP (VBP are) (VP (VBG fighting) (PP (IN outside) (NP (DT a) (NNS deli))))) (. .)))"} 4 | {"annotator_labels": ["entailment", "entailment", "entailment", "entailment", "entailment"], "captionID": "2407214681.jpg#0", "gold_label": "entailment", "pairID": "2407214681.jpg#0r1e", "sentence1": "Two young children in blue jerseys, one with the number 9 and one with the number 2 are standing on wooden steps in a bathroom and washing their hands in a sink.", "sentence1_binary_parse": "( ( ( Two ( young children ) ) ( in ( ( ( ( ( blue jerseys ) , ) ( one ( with ( the ( number 9 ) ) ) ) ) and ) ( one ( with ( the ( number 2 ) ) ) ) ) ) ) ( ( are ( ( ( standing ( on ( ( wooden steps ) ( in ( a bathroom ) ) ) ) ) and ) ( ( washing ( their hands ) ) ( in ( a sink ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (CD Two) (JJ young) (NNS children)) (PP (IN in) (NP (NP (JJ blue) (NNS jerseys)) (, ,) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 9)))) (CC and) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 2))))))) (VP (VBP are) (VP (VP (VBG standing) (PP (IN on) (NP (NP (JJ wooden) (NNS steps)) (PP (IN in) (NP (DT a) (NN bathroom)))))) (CC and) (VP (VBG washing) (NP (PRP$ their) (NNS hands)) (PP (IN in) (NP (DT a) (NN sink)))))) (. .)))", "sentence2": "Two kids in numbered jerseys wash their hands.", "sentence2_binary_parse": "( ( ( Two kids ) ( in ( numbered jerseys ) ) ) ( ( wash ( their hands ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (NP (CD Two) (NNS kids)) (PP (IN in) (NP (JJ numbered) (NNS jerseys)))) (VP (VBP wash) (NP (PRP$ their) (NNS hands))) (. 
.)))"} 5 | {"annotator_labels": ["neutral", "neutral", "neutral", "entailment", "entailment"], "captionID": "2407214681.jpg#0", "gold_label": "neutral", "pairID": "2407214681.jpg#0r1n", "sentence1": "Two young children in blue jerseys, one with the number 9 and one with the number 2 are standing on wooden steps in a bathroom and washing their hands in a sink.", "sentence1_binary_parse": "( ( ( Two ( young children ) ) ( in ( ( ( ( ( blue jerseys ) , ) ( one ( with ( the ( number 9 ) ) ) ) ) and ) ( one ( with ( the ( number 2 ) ) ) ) ) ) ) ( ( are ( ( ( standing ( on ( ( wooden steps ) ( in ( a bathroom ) ) ) ) ) and ) ( ( washing ( their hands ) ) ( in ( a sink ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (CD Two) (JJ young) (NNS children)) (PP (IN in) (NP (NP (JJ blue) (NNS jerseys)) (, ,) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 9)))) (CC and) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 2))))))) (VP (VBP are) (VP (VP (VBG standing) (PP (IN on) (NP (NP (JJ wooden) (NNS steps)) (PP (IN in) (NP (DT a) (NN bathroom)))))) (CC and) (VP (VBG washing) (NP (PRP$ their) (NNS hands)) (PP (IN in) (NP (DT a) (NN sink)))))) (. .)))", "sentence2": "Two kids at a ballgame wash their hands.", "sentence2_binary_parse": "( ( ( Two kids ) ( at ( a ballgame ) ) ) ( ( wash ( their hands ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (NP (CD Two) (NNS kids)) (PP (IN at) (NP (DT a) (NN ballgame)))) (VP (VBP wash) (NP (PRP$ their) (NNS hands))) (. 
.)))"} 6 | {"annotator_labels": ["contradiction", "contradiction", "contradiction", "contradiction", "contradiction"], "captionID": "2407214681.jpg#0", "gold_label": "contradiction", "pairID": "2407214681.jpg#0r1c", "sentence1": "Two young children in blue jerseys, one with the number 9 and one with the number 2 are standing on wooden steps in a bathroom and washing their hands in a sink.", "sentence1_binary_parse": "( ( ( Two ( young children ) ) ( in ( ( ( ( ( blue jerseys ) , ) ( one ( with ( the ( number 9 ) ) ) ) ) and ) ( one ( with ( the ( number 2 ) ) ) ) ) ) ) ( ( are ( ( ( standing ( on ( ( wooden steps ) ( in ( a bathroom ) ) ) ) ) and ) ( ( washing ( their hands ) ) ( in ( a sink ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (CD Two) (JJ young) (NNS children)) (PP (IN in) (NP (NP (JJ blue) (NNS jerseys)) (, ,) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 9)))) (CC and) (NP (NP (CD one)) (PP (IN with) (NP (DT the) (NN number) (CD 2))))))) (VP (VBP are) (VP (VP (VBG standing) (PP (IN on) (NP (NP (JJ wooden) (NNS steps)) (PP (IN in) (NP (DT a) (NN bathroom)))))) (CC and) (VP (VBG washing) (NP (PRP$ their) (NNS hands)) (PP (IN in) (NP (DT a) (NN sink)))))) (. .)))", "sentence2": "Two kids in jackets walk to school.", "sentence2_binary_parse": "( ( ( Two kids ) ( in jackets ) ) ( ( walk ( to school ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (NP (CD Two) (NNS kids)) (PP (IN in) (NP (NNS jackets)))) (VP (VBP walk) (PP (TO to) (NP (NN school)))) (. 
.)))"} 7 | {"annotator_labels": ["contradiction", "contradiction", "contradiction", "contradiction", "contradiction"], "captionID": "4718146904.jpg#2", "gold_label": "contradiction", "pairID": "4718146904.jpg#2r1c", "sentence1": "A man selling donuts to a customer during a world exhibition event held in the city of Angeles", "sentence1_binary_parse": "( ( A ( man selling ) ) ( ( donuts ( to ( a customer ) ) ) ( during ( ( a ( world ( exhibition event ) ) ) ( held ( in ( ( the city ) ( of Angeles ) ) ) ) ) ) ) )", "sentence1_parse": "(ROOT (S (NP (DT A) (NN man) (NN selling)) (VP (VBZ donuts) (PP (TO to) (NP (DT a) (NN customer))) (PP (IN during) (NP (NP (DT a) (NN world) (NN exhibition) (NN event)) (VP (VBN held) (PP (IN in) (NP (NP (DT the) (NN city)) (PP (IN of) (NP (NNP Angeles)))))))))))", "sentence2": "A woman drinks her coffee in a small cafe.", "sentence2_binary_parse": "( ( A woman ) ( ( ( drinks ( her coffee ) ) ( in ( a ( small cafe ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN woman)) (VP (VBZ drinks) (NP (PRP$ her) (NN coffee)) (PP (IN in) (NP (DT a) (JJ small) (NN cafe)))) (. 
.)))"} 8 | {"annotator_labels": ["neutral", "entailment", "entailment", "neutral", "neutral"], "captionID": "4718146904.jpg#2", "gold_label": "neutral", "pairID": "4718146904.jpg#2r1n", "sentence1": "A man selling donuts to a customer during a world exhibition event held in the city of Angeles", "sentence1_binary_parse": "( ( A ( man selling ) ) ( ( donuts ( to ( a customer ) ) ) ( during ( ( a ( world ( exhibition event ) ) ) ( held ( in ( ( the city ) ( of Angeles ) ) ) ) ) ) ) )", "sentence1_parse": "(ROOT (S (NP (DT A) (NN man) (NN selling)) (VP (VBZ donuts) (PP (TO to) (NP (DT a) (NN customer))) (PP (IN during) (NP (NP (DT a) (NN world) (NN exhibition) (NN event)) (VP (VBN held) (PP (IN in) (NP (NP (DT the) (NN city)) (PP (IN of) (NP (NNP Angeles)))))))))))", "sentence2": "A man selling donuts to a customer during a world exhibition event while people wait in line behind him.", "sentence2_binary_parse": "( ( A ( man selling ) ) ( ( ( donuts ( to ( ( a customer ) ( during ( a ( world ( exhibition event ) ) ) ) ) ) ) ( while ( people ( ( wait ( in line ) ) ( behind him ) ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN man) (NN selling)) (VP (VBZ donuts) (PP (TO to) (NP (NP (DT a) (NN customer)) (PP (IN during) (NP (DT a) (NN world) (NN exhibition) (NN event))))) (SBAR (IN while) (S (NP (NNS people)) (VP (VBP wait) (PP (IN in) (NP (NN line))) (PP (IN behind) (NP (PRP him))))))) (. 
.)))"} 9 | {"annotator_labels": ["entailment", "neutral", "entailment", "entailment", "entailment"], "captionID": "4718146904.jpg#2", "gold_label": "entailment", "pairID": "4718146904.jpg#2r1e", "sentence1": "A man selling donuts to a customer during a world exhibition event held in the city of Angeles", "sentence1_binary_parse": "( ( A ( man selling ) ) ( ( donuts ( to ( a customer ) ) ) ( during ( ( a ( world ( exhibition event ) ) ) ( held ( in ( ( the city ) ( of Angeles ) ) ) ) ) ) ) )", "sentence1_parse": "(ROOT (S (NP (DT A) (NN man) (NN selling)) (VP (VBZ donuts) (PP (TO to) (NP (DT a) (NN customer))) (PP (IN during) (NP (NP (DT a) (NN world) (NN exhibition) (NN event)) (VP (VBN held) (PP (IN in) (NP (NP (DT the) (NN city)) (PP (IN of) (NP (NNP Angeles)))))))))))", "sentence2": "A man selling donuts to a customer.", "sentence2_binary_parse": "( ( A ( man selling ) ) ( ( donuts ( to ( a customer ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN man) (NN selling)) (VP (VBZ donuts) (PP (TO to) (NP (DT a) (NN customer)))) (. .)))"} 10 | {"annotator_labels": ["entailment", "neutral", "entailment", "entailment", "neutral"], "captionID": "3980085662.jpg#0", "gold_label": "entailment", "pairID": "3980085662.jpg#0r1e", "sentence1": "Two young boys of opposing teams play football, while wearing full protection uniforms and helmets.", "sentence1_binary_parse": "( ( ( Two ( young boys ) ) ( of ( opposing teams ) ) ) ( ( ( ( play football ) , ) ( while ( wearing ( full ( protection ( ( uniforms and ) helmets ) ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (CD Two) (JJ young) (NNS boys)) (PP (IN of) (NP (VBG opposing) (NNS teams)))) (VP (VBP play) (NP (NN football)) (, ,) (PP (IN while) (S (VP (VBG wearing) (NP (JJ full) (NN protection) (NNS uniforms) (CC and) (NNS helmets)))))) (. 
.)))", "sentence2": "boys play football", "sentence2_binary_parse": "( boys ( play football ) )", "sentence2_parse": "(ROOT (S (NP (NNS boys)) (VP (VBP play) (NP (NN football)))))"} -------------------------------------------------------------------------------- /src/wandb_allennlp/training/callbacks/log_to_wandb.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Dict, Any, Optional, Callable 2 | import logging 3 | from allennlp.common.registrable import Registrable 4 | from allennlp.training.callbacks import ( 5 | WandBCallback, 6 | TrainerCallback, 7 | ) 8 | from allennlp.training.callbacks.log_writer import LogWriterCallback 9 | from allennlp.training import GradientDescentTrainer 10 | from allennlp.data import TensorDict 11 | 12 | from allennlp.models.archival import archive_model, verify_include_in_archive 13 | from wandb_allennlp.utils import read_from_env 14 | from overrides import overrides 15 | import os 16 | import torch 17 | from .utils import flatten_dict 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class AllennlpWandbSubCallback(Registrable): 23 | """ 24 | This is the abstract class that describes a sub-callback to be used with 25 | AllennlpWandbCallback. 26 | 27 | There will be only one isinstance of AllennlpWandbCallback per trainer. 28 | To add custom functionallity to this isinstance will require inheritance 29 | and code duplication. This class is intented to aid extensibility using 30 | composition. 31 | """ 32 | 33 | def __init__(self, priority: int, **kwargs: Any): 34 | self.priority = priority 35 | 36 | def on_start_( 37 | self, 38 | super_callback: "AllennlpWandbCallback", 39 | trainer: "GradientDescentTrainer", 40 | is_primary: bool = True, 41 | **kwargs: Any, 42 | ) -> None: 43 | """ 44 | This callback hook is called before the training is started. 
45 | """ 46 | self.trainer = trainer 47 | 48 | def on_batch_( 49 | self, 50 | super_callback: "AllennlpWandbCallback", 51 | trainer: "GradientDescentTrainer", 52 | batch_inputs: List[TensorDict], 53 | batch_outputs: List[Dict[str, Any]], 54 | batch_metrics: Dict[str, Any], 55 | epoch: int, 56 | batch_number: int, 57 | is_training: bool, 58 | is_primary: bool = True, 59 | batch_grad_norm: Optional[float] = None, 60 | **kwargs: Any, 61 | ) -> None: 62 | """ 63 | This callback hook is called after the end of each batch. 64 | """ 65 | pass 66 | 67 | def on_epoch_( 68 | self, 69 | super_callback: "AllennlpWandbCallback", 70 | trainer: "GradientDescentTrainer", 71 | metrics: Dict[str, Any], 72 | epoch: int, 73 | is_primary: bool = True, 74 | **kwargs: Any, 75 | ) -> None: 76 | """ 77 | This callback hook is called after the end of each epoch. 78 | """ 79 | pass 80 | 81 | def on_end_( 82 | self, 83 | super_callback: "AllennlpWandbCallback", 84 | trainer: "GradientDescentTrainer", 85 | metrics: Dict[str, Any] = None, 86 | epoch: int = None, 87 | is_primary: bool = True, 88 | **kwargs: Any, 89 | ) -> None: 90 | """ 91 | This callback hook is called after the final training epoch. 92 | """ 93 | pass 94 | 95 | 96 | @TrainerCallback.register("wandb_allennlp") 97 | class AllennlpWandbCallback(WandBCallback): 98 | """ 99 | This callback should only be used with `train_with_wandb` command. 100 | 101 | Note: 102 | If used with `allennlp train` command, this might have unexpected 103 | behaviour because we read some arguments from environment variables. 
104 | """ 105 | 106 | def __init__( 107 | self, 108 | serialization_dir: str, 109 | summary_interval: int = 100, 110 | distribution_interval: Optional[int] = None, 111 | batch_size_interval: Optional[int] = None, 112 | should_log_parameter_statistics: bool = True, 113 | should_log_learning_rate: bool = False, 114 | project: Optional[str] = None, 115 | entity: Optional[str] = None, 116 | group: Optional[str] = None, 117 | name: Optional[str] = None, 118 | notes: Optional[str] = None, 119 | tags: Optional[Union[str, List[str]]] = None, 120 | watch_model: bool = True, 121 | files_to_save: List[str] = ["config.json", "out.log"], 122 | files_to_save_at_end: Optional[List[str]] = None, 123 | include_in_archive: List[str] = None, 124 | save_model_archive: bool = True, 125 | wandb_kwargs: Optional[Dict[str, Any]] = None, 126 | finish_on_end: bool = False, 127 | sub_callbacks: Optional[List[AllennlpWandbSubCallback]] = None, 128 | ) -> None: 129 | logger.debug("Wandb related varaibles") 130 | logger.debug( 131 | "%s | %s | %s", 132 | "variable".ljust(15), 133 | "value from env".ljust(50), 134 | "value in constructor".ljust(50), 135 | ) 136 | 137 | for e, a in [("PROJECT", project), ("ENTITY", entity)]: 138 | logger.debug( 139 | "%s | %s | %s", 140 | str(e).lower()[:15].ljust(15), 141 | str(read_from_env("WANDB_" + e))[:50].ljust(50), 142 | str(a)[:50].ljust(50), 143 | ) 144 | logger.debug("All wandb related envirnment varaibles") 145 | logger.debug("%s | %s ", "ENV VAR.".ljust(15), "VALUE".ljust(50)) 146 | 147 | for k, v in os.environ.items(): 148 | if "WANDB" in k or "ALLENNLP" in k: 149 | logger.debug( 150 | "%s | %s ", 151 | str(k)[:15].ljust(15), 152 | str(v)[:50].ljust(50), 153 | ) 154 | t = read_from_env("WANDB_TAGS") or tags 155 | 156 | if isinstance(t, str): 157 | tags = t.split(",") 158 | else: 159 | tags = t 160 | super().__init__( 161 | serialization_dir, 162 | summary_interval=summary_interval, 163 | distribution_interval=distribution_interval, 164 | 
            batch_size_interval=batch_size_interval,
            should_log_parameter_statistics=should_log_parameter_statistics,
            should_log_learning_rate=should_log_learning_rate,
            # prefer values read from WANDB_* env variables over the
            # constructor arguments
            project=read_from_env("WANDB_PROJECT") or project,
            entity=read_from_env("WANDB_ENTITY") or entity,
            group=read_from_env("WANDB_GROUP") or group,
            name=read_from_env("WANDB_NAME") or name,
            notes=read_from_env("WANDB_NOTES") or notes,
            tags=tags,
            watch_model=watch_model,
            files_to_save=tuple(files_to_save),
            wandb_kwargs=wandb_kwargs,
        )
        self.finish_on_end = finish_on_end
        self._files_to_save_at_end = files_to_save_at_end or []
        self.include_in_archive = include_in_archive
        verify_include_in_archive(include_in_archive)
        self.save_model_archive = save_model_archive
        # fixed priority of this callback itself among trainer callbacks
        self.priority = 100
        # higher-priority sub-callbacks run first
        self.sub_callbacks = sorted(
            sub_callbacks or [], key=lambda x: x.priority, reverse=True
        )

        if save_model_archive:
            # the archive is produced in close(); upload it at run end
            self._files_to_save_at_end.append("model.tar.gz")
        # do not set wandb dir to be inside the serialization directory.

        if "dir" in self._wandb_kwargs:
            self._wandb_kwargs["dir"] = None

        if "config" in self._wandb_kwargs:
            # normalize the nested config dict via flatten_dict before
            # handing it to wandb
            self._wandb_kwargs["config"] = flatten_dict(
                self._wandb_kwargs["config"]
            )

    def on_start(
        self,
        trainer: "GradientDescentTrainer",
        is_primary: bool = True,
        **kwargs: Any,
    ) -> None:
        """Run the parent hook, then the sub-callbacks (in priority order)."""
        super().on_start(trainer, is_primary=is_primary, **kwargs)

        for subcallback in self.sub_callbacks:
            subcallback.on_start_(self, trainer, is_primary=is_primary)

    def on_batch(
        self,
        trainer: "GradientDescentTrainer",
        batch_inputs: List[TensorDict],
        batch_outputs: List[Dict[str, Any]],
        batch_metrics: Dict[str, Any],
        epoch: int,
        batch_number: int,
        is_training: bool,
        is_primary: bool = True,
        batch_grad_norm: Optional[float] = None,
        **kwargs: Any,
    ) -> None:
        """
        This callback hook is called after the end of each batch.

        Runs the parent hook first, then each sub-callback's ``on_batch_``.
        """
        super().on_batch(
            trainer,
            batch_inputs,
            batch_outputs,
            batch_metrics,
            epoch,
            batch_number,
            is_training,
            is_primary=is_primary,
            batch_grad_norm=batch_grad_norm,
        )

        for sub_callback in self.sub_callbacks:
            sub_callback.on_batch_(
                self,
                trainer,
                batch_inputs,
                batch_outputs,
                batch_metrics,
                epoch,
                batch_number,
                is_training,
                is_primary=is_primary,
                batch_grad_norm=batch_grad_norm,
            )

    def on_epoch(
        self,
        trainer: "GradientDescentTrainer",
        metrics: Dict[str, Any],
        epoch: int,
        is_primary: bool = True,
        **kwargs: Any,
    ) -> None:
        """Run the parent hook, then each sub-callback's ``on_epoch_``."""
        super().on_epoch(
            trainer, metrics, epoch, is_primary=is_primary, **kwargs
        )

        for sub_callback in self.sub_callbacks:
            sub_callback.on_epoch_(
                self, trainer, metrics, epoch, is_primary=is_primary, **kwargs
            )

    def on_end(
        self,
        trainer: "GradientDescentTrainer",
        metrics: Optional[Dict[str, Any]] = None,
        epoch: Optional[int] = None,
        is_primary: bool = True,
        **kwargs: Any,
    ) -> None:
        # NOTE: unlike on_batch/on_epoch, sub-callbacks run *before*
        # super().on_end here.
        for sub_callback in self.sub_callbacks:
            sub_callback.on_end_(
                self,
                trainer,
                metrics=metrics,
                epoch=epoch,
                is_primary=is_primary,
            )
        super().on_end(
            trainer, metrics=metrics, epoch=epoch, is_primary=is_primary
        )

    @overrides
    def close(self) -> None:
        import wandb

        assert wandb.run is not None
        # set this here for resuming
        os.environ.update({"WANDB_RUN_ID": str(wandb.run.id)})

        if self.save_model_archive:
            # we will have to create archive prematurely here.
            # the `train_model()` in `allennlp train` will
            # recreate the same model archive later. However,
            # this duplication cannot be avoided at this stage.
            logger.info("Archiving model before closing wandb.")
            archive_model(
                self.serialization_dir,
                include_in_archive=self.include_in_archive,
            )

        if self._files_to_save_at_end:
            for fpath in self._files_to_save_at_end:
                # policy="end": upload once when the run finishes
                self.wandb.save(  # type: ignore
                    os.path.join(self.serialization_dir, fpath),
                    base_path=self.serialization_dir,
                    policy="end",
                )

        # NOTE(review): calls LogWriterCallback.close directly instead of
        # super().close(), presumably to skip WandBCallback's own close
        # logic -- verify against the installed allennlp version.
        LogWriterCallback.close(self)

        if self.finish_on_end:
            wandb.finish()
--------------------------------------------------------------------------------
/src/wandb_allennlp/commands/train_with_wandb.py:
--------------------------------------------------------------------------------
from typing import Tuple, List, Dict, Any, Optional
from .parser_base import WandbParserBase, read_from_env
from allennlp.commands.train import train_model_from_args
from allennlp.commands import Subcommand
import argparse
import logging
import re
import json
import yaml
import os
import sys
from datetime import datetime
from pathlib import Path
from wandb_allennlp.config import ALLENNLP_SERIALIZATION_DIR
import shortuuid
import signal

logger = logging.getLogger(__name__)


def generate_serialization_dir(wandb_run_id: Optional[str] = None) -> Path:
    """Return a fresh run-specific serialization path.

    The name follows wandb's ``run-<timestamp>-<id>`` convention; a random
    8-char id is generated when ``wandb_run_id`` is not given. The root
    directory (ALLENNLP_SERIALIZATION_DIR) is created if missing, but the
    returned run directory itself is NOT created here.
    """
    # ref: https://github.com/wandb/client/blob/c4548d3871c4cbdd8c253e46c912c95205bbc7f6/wandb/sdk/wandb_settings.py#L740
    root_dir = Path(ALLENNLP_SERIALIZATION_DIR)
    root_dir.mkdir(parents=True, exist_ok=True)
    datetime_now: datetime = datetime.now()

    if wandb_run_id is None:
        # ref: wandb/sdk/lib/runid.py
        run_gen = shortuuid.ShortUUID(
            alphabet=list("0123456789abcdefghijklmnopqrstuvwxyz")
        )
        wandb_run_id = run_gen.random(8)  # type: ignore[no-untyped-call]
    s = f'run-{datetime.strftime(datetime_now, "%Y%m%d_%H%M%S")}-{wandb_run_id}'

    return root_dir / s


def create_dict_for_numbers(k: str, value: Any) -> None:
    """Warn when a dotted override key tries to index into a list.

    Keys such as ``--key.0.something`` (numeric path components) are not
    supported for hyperparameter overrides; this helper only emits a warning
    and always returns ``None`` (the previous ``-> Dict`` annotation was
    wrong).
    """
    for ki in k.split("."):
        # BUG FIX: the original tested the bound method ``ki.isdigit``
        # (always truthy) instead of calling it, so this warning fired for
        # every key component, numeric or not.
        if ki.isdigit():
            logger.warning(
                "Using something like "
                "--key.0.something.1=value to index into a list is not"
                " supported while overriding hyperparams."
                "If you are trying this, you should consider using env"
                "variable to achive this."
            )

    return None


def translate(
    hyperparams: List[str],
) -> Tuple[List[str], Dict[str, Any], Dict[str, Any]]:
    """Split unknown CLI args into (arg names, override params, env vars).

    Returns:
        all_args: raw names of all keyword/flag arguments seen.
        hparams: ``--key=value`` pairs (leading dashes stripped from keys)
            destined for allennlp ``--overrides``; values are parsed with
            yaml so ints/floats/bools get proper python types.
        env: pairs given as ``--env.NAME=value``; values are json-encoded
            strings ready to be put into ``os.environ``.
    """
    hparams = {}  #: temporary variable
    env = {}  #: params that start with env.
    all_args: List[str] = []  #: raw strings of all the unknown arguments
    # pattern for the leading -- or - in --key=value
    pattern = re.compile(r"-{1,2}")

    for possible_kwarg in hyperparams:
        kw_val = possible_kwarg.split("=")

        if len(kw_val) > 2:
            raise ValueError(f"{possible_kwarg} not in valid form.")

        elif len(kw_val) == 2:
            k, v = kw_val
            all_args.append(k)
            # pass through yaml.load to handle
            # booleans, ints and floats correctly;
            # yaml.load will output correct python types
            loader = yaml.SafeLoader
            loader.add_implicit_resolver(  # type: ignore
                "tag:yaml.org,2002:float",
                re.compile(
                    """^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)|\\.[0-9_]+(?:[eE][-+][0-9]+)?|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*|[-+]?\\.(?:inf|Inf|INF)|\\.(?:nan|NaN|NAN))$""",
                    re.X,
                ),
                list("-+0123456789."),
            )
            v = yaml.load(v, Loader=loader)

            if k.startswith("--env.") or k.startswith("-env."):
                # split on . and remove the "--env." in the begining
                # use json dumps to convert the python type (int, float, bool)
                # to string which can be understood by a json reader
                # the environment variables have to be stored as string
                env[".".join(k.split(".")[1:])] = json.dumps(v)
            else:
                hparams[pattern.sub("", k)] = v

        elif len(kw_val) == 1:  # flag or positional argument
            kw_val_ = kw_val[0]
            flag = re.match("^-{1,2}(.+)", kw_val_)

            if flag:
                all_args.append(flag.group(1))
            else:  # positional arg
                # for positional arguments like param_path, we don't have to do anything
                pass

        else:
            logger.warning(
                f"{kw_val} not a know argument for allennlp train, "
                "or in --hyperparam=value form required for hyperparam overrides"
                "or a non-kwarg, "
                "or a boolean --flag"
                ". Will be ignored by train_with_wandb command."
            )

    # setting the env is left to the caller
    # os.environ.update(env)

    return all_args, hparams, env


@Subcommand.register("train-with-wandb")
class TrainWithWandb(WandbParserBase):
    description = "Train with logging to wandb"
    help_message = (
        "Use `allennlp train_with_wandb` subcommand instead of "
        "`allennp train` to log training to wandb. "
        "It supports all the arguments present in `allennlp train`. "
        "However, the --overrides have to be specified in the `--kw value` or `--kw=value` form, "
        "where 'kw' is the parameter to override and 'value' is its value. "
        "Use the dot notation for nested parameters. "
        "For instance, {'model': {'embedder': {'type': xyz}}} can be provided as --model.embedder.type xyz"
    )
    require_run_id = False
    wandb_common_args = ["entity", "project", "notes", "group", "tags"]

    @classmethod
    def init_wandb_run(
        cls, args: argparse.Namespace
    ) -> Optional["wandb.sdk.wandb_run.Run"]:  # type: ignore
        """Initialize the wandb run eagerly (used with --early-init)."""

        import wandb

        wandb_args_dict = cls.get_wandb_run_args(args)

        logger.info(
            "Early init is ON. Initializing wandb with the following args."
        )

        for k, v in wandb_args_dict.items():
            # str(v): values need not be strings (e.g. a list of tags)
            logger.info("%s | %s", k[:15].ljust(15), str(v)[:15].ljust(15))

        run = wandb.init(**wandb_args_dict)
        # just use the log files and do not dynamically patch tensorboard as it messes up the
        # the global_step and breaks the normal use of wandb.log()
        # after wandb version 0.10.20
        # any call to patch either through sync_tensorboard or .patch()
        # messes up the log. So we drop it completely. Wandb somehow still
        # syncs the tensorboard log folder along with the other folders.
        # wandb.tensorboard.patch(save=True, tensorboardX=False)

        # BUG FIX: this is a classmethod, so the original `self.wandb.save`
        # raised NameError; save through the freshly created run instead.
        # BUG FIX: the argument registered below is
        # --wandb-allennlp-files-to-save, whose argparse dest is
        # wandb_allennlp_files_to_save (the original read a non-existent
        # args.wandb_files_to_save).
        for fpath in args.wandb_allennlp_files_to_save:
            run.save(  # type: ignore
                os.path.join(args.serialization_dir, fpath),
                base_path=args.serialization_dir,
                policy="live",
            )

        return run

    def add_arguments(
        self, subparser: argparse.ArgumentParser
    ) -> argparse.ArgumentParser:
        # we use the same args as the allennlp train command
        # except the --overrides
        # and param_path because
        # overrides is something we will create
        # and param_path is not a kwarg and hence is always required
        # We cannot have a compulsory arg here because if we do and
        # we are not trying to call train_with_wandb but some other command
        # The feeler call to parse_known_args() will throw an error.

        ######## Begin: arguments for `allennlp train`##########
        subparser.add_argument(
            "-s",
            "--serialization-dir",
            required=False,
            type=str,
            help="directory in which to save the model and its logs",
        )

        subparser.add_argument(
            "-r",
            "--recover",
            action="store_true",
            default=False,
            help="recover training from the state in serialization_dir",
        )

        subparser.add_argument(
            "-f",
            "--force",
            action="store_true",
            required=False,
            help="overwrite the output directory if it exists",
        )

        subparser.add_argument(
            "--node-rank",
            type=int,
            default=0,
            help="rank of this node in the distributed setup",
        )

        subparser.add_argument(
            "--dry-run",
            action="store_true",
            help=(
                "do not train a model, but create a vocabulary, show dataset statistics and "
                "other training information"
            ),
        )
        subparser.add_argument(
            "--file-friendly-logging",
            action="store_true",
            default=False,
            help="outputs tqdm status on separate lines and slows tqdm refresh rate",
        )

        subparser.add_argument(
            "--include-package",
            type=str,
            action="append",
            default=[],
            help="additional packages to include",
        )
        ######## End: arguments for `allennlp train`##########

        ######## Begin: Specific keyword arguments for `allennlp train_with_wandb`##########
        subparser.add_argument(
            "--early-init",
            action="store_true",
            default=False,
            help=(
                "Initialize wandb in the command processing itself."
                " The default (False) is to initialize wandb in the `on_start` method of the logging callback."
                "!!WARNING!! Early initialization of wandb can create problems when using "
                "multi-process dataloader or distributed training."
                " The only use-case for early initialization is the early population of console log in wandb UI."
            ),
        )
        subparser.add_argument(
            "--wandb-allennlp-files-to-save",
            type=str,
            action="append",
            default=[],
            help=(
                "Globs describing files to save from the allennlp serialization directory."
                "Default: ['config.json', 'out.log']"
            ),
        )
        ######## End: Specific keyword arguments for `allennlp train_with_wandb`##########

        # we will not do anything if the subcommand is not train_with_wandb
        # because otherwise parse_known_args() can throw error or show train_with_wandb's help
        # even if we are asking for --help for some other command

        # BUG FIX: guard against bare `allennlp` invocations where
        # sys.argv has no subcommand (the original indexing raised
        # IndexError).
        if len(sys.argv) < 2 or sys.argv[1] != "train-with-wandb":
            subparser.set_defaults(func=main)

            return subparser
        # Add dynamic args for overrides and env variables
        known_args, hyperparams = subparser.parse_known_args(sys.argv[2:])
        all_args, hparams_for_overrides, env_vars = translate(hyperparams)
        overrides_json = f"--overrides={json.dumps(hparams_for_overrides)}"

        # update sys.argv with the json form of the overrides
        sys.argv.append(overrides_json)
        # add all hyperparams in both forms--json as well as dot notation
        # we do this so that parse_args() in the allennlp code does not throw error

        for arg in all_args:
            subparser.add_argument(f"{arg}")

        # Add the rest of the arguments of `allennlp train` that we held out due to the feeler call to parse_known_args()
        subparser.add_argument(
            "-o",
            "--overrides",
            type=str,
            default="",
            help=(
                "a json(net) structure used to override the experiment configuration, e.g., "
                "'{\"iterator.batch_size\": 16}'. Nested parameters can be specified either"
                " with nested dictionaries or with dot syntax."
            ),
        )
        subparser.add_argument(
            "param_path",
            type=str,
            help="path to parameter file describing the model to be trained",
        )

        # set env vars
        os.environ.update(env_vars)

        subparser.set_defaults(func=main)

        return subparser


def main(args: argparse.Namespace) -> None:
    """Entry point: pick a serialization dir, optionally init wandb, train.

    We keep serialization_dir and the wandb run directory separate now.
    We will generate a suitable serialization-dir if not specified:

    1. If run_id is given either as cli argument (single run) or as
       environment variable (run of a sweep), we will use the same format
       as wandb to create a serialization-dir in
       ALLENNLP_SERIALIZATION_DIR as the root dir.
    2. If run_id cannot be obtained, we will generate a random id and
       treat it as run_id to generate a serialization-dir in
       ALLENNLP_SERIALIZATION_DIR.
    """

    if args.serialization_dir is None:
        # args.wandb_run_id is presumably populated by WandbParserBase
        # (require_run_id) -- verify against parser_base.py
        args.serialization_dir = generate_serialization_dir(args.wandb_run_id)
        # BUG FIX: log after assigning (the original logged the still-None
        # value with a "Set set" typo), and use the module logger.
        logger.info("Set serialization_dir to %s", args.serialization_dir)

    if args.early_init:
        TrainWithWandb.init_wandb_run(args)
    train_model_from_args(args)