├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── custom.md │ └── feature_request.md └── workflows │ ├── ci.yml │ ├── docker.yml │ └── docs.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── classy ├── __init__.py ├── data │ ├── __init__.py │ ├── data_drivers.py │ ├── data_modules.py │ └── dataset │ │ ├── __init__.py │ │ ├── base.py │ │ └── hf │ │ ├── __init__.py │ │ ├── base.py │ │ ├── classification.py │ │ └── generation.py ├── evaluation │ ├── __init__.py │ ├── base.py │ ├── generation.py │ ├── simple.py │ ├── span.py │ └── squad.py ├── optim │ ├── __init__.py │ ├── factories.py │ └── optimizers │ │ ├── __init__.py │ │ └── radam.py ├── pl_callbacks │ ├── __init__.py │ ├── best_checkpoint.py │ └── prediction.py ├── pl_modules │ ├── __init__.py │ ├── base.py │ ├── hf │ │ ├── __init__.py │ │ ├── classification.py │ │ └── generation.py │ └── mixins │ │ ├── __init__.py │ │ ├── prediction.py │ │ ├── saving.py │ │ ├── task.py │ │ ├── task_serve.py │ │ └── task_ui.py ├── scripts │ ├── __init__.py │ ├── cli │ │ ├── __init__.py │ │ ├── demo.py │ │ ├── describe.py │ │ ├── download.py │ │ ├── evaluate.py │ │ ├── export.py │ │ ├── import_.py │ │ ├── predict.py │ │ ├── serve.py │ │ ├── train.py │ │ ├── upload.py │ │ └── utils.py │ └── model │ │ ├── __init__.py │ │ ├── demo.py │ │ ├── describe.py │ │ ├── download.py │ │ ├── evaluate.py │ │ ├── export.py │ │ ├── import_.py │ │ ├── predict.py │ │ ├── serve.py │ │ ├── train.py │ │ └── upload.py ├── utils │ ├── __init__.py │ ├── commons.py │ ├── config.py │ ├── data.py │ ├── experiment.py │ ├── file.py │ ├── help_cli.py │ ├── hydra.py │ ├── hydra_patch.py │ ├── lightning.py │ ├── log.py │ ├── omegaconf.py │ ├── optional_deps.py │ ├── plotly.py │ ├── rich_config.py │ ├── streamlit.py │ ├── train_coordinates.py │ └── vocabulary.py └── version.py ├── configurations ├── .placeholder ├── __init__.py ├── callbacks │ ├── empty.yaml │ ├── evaluation.yaml │ └── file-dumper.yaml ├── data │ ├── generation.yaml │ ├── qa.yaml │ ├── sentence-pair.yaml │ ├── sequence.yaml │ └── token.yaml ├── evaluation │ ├── generation.yaml │ ├── qa.yaml │ ├── rouge.yaml │ ├── sacrebleu.yaml │ ├── sentence-pair.yaml │ ├── sequence.yaml │ ├── span.yaml │ ├── squad-v1.yaml │ └── token.yaml ├── generation.yaml ├── logging │ └── default.yaml ├── model │ ├── generation.yaml │ ├── qa.yaml │ ├── sentence-pair.yaml │ ├── sequence.yaml │ └── token.yaml ├── prediction-params │ ├── generation-beam.yaml │ └── generation-sample.yaml ├── prediction │ └── default.yaml ├── profiles │ ├── bart-base.yaml │ ├── bart-large.yaml │ ├── bert-base.yaml │ ├── bert-large.yaml │ ├── deberta-base.yaml │ ├── deberta-large.yaml │ ├── distilbert.yaml │ ├── distilroberta.yaml │ ├── gpt2-large.yaml │ ├── gpt2-medium.yaml │ ├── gpt2.yaml │ ├── mbart.yaml │ ├── multilingual-bert.yaml │ ├── roberta-base.yaml │ ├── roberta-large.yaml │ ├── squeezebert.yaml │ ├── xlm-roberta-base.yaml │ └── xlm-roberta-large.yaml ├── qa.yaml ├── sentence-pair.yaml ├── sequence.yaml ├── token.yaml └── training │ └── default.yaml ├── data ├── .placeholder ├── generation │ ├── cnn_dailymail.ipynb │ └── tatoeba.ipynb ├── qa │ └── SQuAD.ipynb ├── sentence-pair │ └── quora_question_pairs.ipynb ├── sequence │ └── sst2.ipynb └── token │ └── conll2003.ipynb ├── docs ├── .gitignore ├── README.md ├── babel.config.js ├── docs │ ├── getting-started │ │ ├── _category_.json │ │ ├── basic │ │ │ ├── _category_.json │ │ │ ├── choosing-profile.md │ │ │ ├── 
data-formatting.md │ │ │ ├── inference.md │ │ │ ├── intro.md │ │ │ └── train.md │ │ └── customizing-things │ │ │ ├── _category_.json │ │ │ ├── changing-profile.md │ │ │ ├── config.md │ │ │ ├── custom-data-format.md │ │ │ ├── custom-dataset.md │ │ │ ├── custom-metric.md │ │ │ ├── custom-model.md │ │ │ ├── custom-optimizer.md │ │ │ └── template.md │ ├── glossary │ │ ├── _category_.json │ │ └── token-batch-size.md │ ├── installation.md │ ├── intro.md │ └── reference-manual │ │ ├── _category_.json │ │ ├── cli │ │ ├── _category_.json │ │ ├── describe.md │ │ ├── export.md │ │ ├── inference.md │ │ ├── predict.md │ │ ├── train.md │ │ └── up-download.md │ │ ├── mixins.md │ │ ├── profiles.md │ │ ├── structured-configs │ │ ├── _category_.json │ │ ├── changing-config.md │ │ ├── overall-structure.md │ │ └── visualize.md │ │ └── tasks-and-formats.md ├── docusaurus.config.js ├── local-dev.sh ├── package.json ├── pdoc │ ├── pdoc_postprocess.py │ └── templates │ │ ├── config.mako │ │ └── text.mako ├── sidebars.js ├── src │ ├── components │ │ ├── HomepageFeatures.js │ │ ├── HomepageFeatures.module.css │ │ ├── api-link.js │ │ ├── termynal.css │ │ └── termynal.js │ ├── css │ │ ├── api.css │ │ └── custom.css │ └── pages │ │ ├── examples.md │ │ ├── index.js │ │ ├── index.module.css │ │ └── markdown-page.md └── static │ ├── .nojekyll │ └── img │ ├── CLASSY.svg │ ├── classy_logo-short.png │ ├── classy_logo-short_transparent.png │ ├── classy_logo.png │ ├── classy_logotypo.png │ ├── docusaurus.png │ ├── favicon.ico │ ├── hyperlink.svg │ ├── intro │ ├── classy-demo-seq-config.png │ ├── classy-demo-seq-model.png │ ├── classy-demo-tok-config.png │ ├── classy-demo-tok-model.png │ ├── classy-describe-seq-chars.png │ ├── classy-describe-seq-labels.png │ ├── classy-serve-tok.png │ ├── classy-serve.png │ ├── classy-train-print-tok.png │ ├── demo.png │ └── serve-docs.png │ ├── logo.svg │ ├── tutorial │ ├── docsVersionDropdown.png │ └── localeDropdown.png │ ├── undraw_docusaurus_mountain.svg │ ├── undraw_docusaurus_react.svg │ └── undraw_docusaurus_tree.svg ├── experiments └── .placeholder ├── extra-requirements.txt ├── img └── logo.png ├── requirements.txt ├── setup.py ├── setup.sh └── tests ├── test_profiles.py └── test_version.py /.dockerignore: -------------------------------------------------------------------------------- 1 | data/* 2 | docs/* 3 | experiments/* 4 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behaviour: 15 | 16 | **Expected behaviour** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Actual behaviour** 20 | A clear and concise description of what happened. 21 | 22 | **Screenshots** 23 | If applicable, add screenshots to help explain your problem. 24 | 25 | **Desktop (please complete the following information):** 26 | - OS: [e.g. Ubuntu 20.04] 27 | - PyTorch Version: 28 | - Classy Version: 29 | 30 | **Additional context** 31 | Add any other context about the problem here. 
32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/custom.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Custom issue template 3 | about: Describe this issue template's purpose here. 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | schedule: 7 | - cron: '00 08 * * *' # every day at 08:00 UTC 8 | 9 | jobs: 10 | check_date: 11 | runs-on: ubuntu-latest 12 | name: Check latest commit 13 | outputs: 14 | WAS_EDITED: ${{ steps.check_date.outputs.WAS_EDITED }} 15 | steps: 16 | - uses: actions/checkout@v2 17 | with: 18 | ref: develop 19 | 20 | - id: check_date 21 | name: Check if there were commits in the last 72 hours 22 | if: ${{ github.event_name == 'schedule' }} 23 | run: echo '::set-output name=WAS_EDITED::'$(test -n "$(git log --format=%H --since='72 hours ago')" && echo 'true' || echo 'false') 24 | 25 | build: 26 | needs: [check_date] 27 | if: ${{ github.event_name == 'release' || needs.check_date.outputs.WAS_EDITED == 'true' }} 28 | 29 | name: Build Package 30 | 31 | runs-on: ubuntu-latest 32 | steps: 33 | - uses: actions/checkout@v2 34 | with: 35 | ref: develop 36 | 37 | - name: Setup Python 38 | uses: actions/setup-python@v2 39 | with: 40 | python-version: 3.8 41 | 42 | - uses: actions/cache@v2 43 | with: 44 | path: ${{ env.pythonLocation }} 45 | key: ${{ runner.os }}-pydeps-${{ env.pythonLocation }}-${{ hashFiles('setup.py') }}-${{ hashFiles('requirements.txt') }} 46 | 47 | - name: Install requirements 48 | run: | 49 | pip install -r requirements.txt 50 | 51 | - name: Set version name 52 | if: ${{ github.event_name == 'schedule' }} 53 | run: | 54 | # You can't set env variables to the output of bash commands directly; we need 55 | # to export them through $GITHUB_ENV instead.
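# CLASSY_VERSION_SUFFIX below presumably feeds setup.py's version handling, marking scheduled builds as dated dev releases (e.g. <version>.dev20220131)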
56 | echo "CLASSY_VERSION_SUFFIX=dev$(date -u +%Y%m%d)" >> $GITHUB_ENV 57 | 58 | - name: Build wheel 59 | run: | 60 | echo "Building packages for pypi push" 61 | python setup.py bdist_wheel sdist 62 | 63 | - name: Save package 64 | uses: actions/upload-artifact@v1 65 | with: 66 | name: package 67 | path: dist 68 | 69 | - name: Clean up 70 | if: always() 71 | run: | 72 | pip uninstall -y classy-core 73 | 74 | publish: 75 | name: Publish to PyPI 76 | needs: [build] 77 | if: ${{ (github.repository == 'sunglasses-ai/classy') && (github.event_name == 'release' || github.event_name == 'schedule') }} 78 | runs-on: ubuntu-latest 79 | 80 | steps: 81 | - uses: actions/checkout@v2 82 | with: 83 | ref: develop 84 | 85 | - name: Setup Python 86 | uses: actions/setup-python@v2 87 | with: 88 | python-version: 3.8 89 | 90 | - name: Install requirements 91 | run: | 92 | pip install --upgrade pip setuptools wheel twine 93 | 94 | - name: Download package 95 | uses: actions/download-artifact@v1 96 | with: 97 | name: package 98 | path: dist 99 | 100 | - name: Publish core package 101 | run: | 102 | twine upload -u ${{ secrets.PYPI_USERNAME }} -p ${{ secrets.PYPI_PASSWORD }} dist/* 103 | -------------------------------------------------------------------------------- /.github/workflows/docker.yml: -------------------------------------------------------------------------------- 1 | name: Publish Docker image 2 | on: 3 | push: 4 | tags: 5 | - 'v*' 6 | 7 | jobs: 8 | push_to_registry: 9 | name: Push Docker image to Docker Hub 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Check out the repo 13 | uses: actions/checkout@v2 14 | 15 | - name: Log in to Docker Hub 16 | uses: docker/login-action@f054a8b539a109f9f41c372932f1ae047eff08c9 17 | with: 18 | username: ${{ secrets.DOCKERHUB_USERNAME }} 19 | password: ${{ secrets.DOCKERHUB_TOKEN }} 20 | 21 | - name: Extract metadata (tags, labels) for Docker 22 | id: meta 23 | uses: docker/metadata-action@98669ae865ea3cffbcbaa878cf57c20bbf1c6c38 24 | with: 25 | images: poccio/classy 26 | tags: | 27 | type=schedule 28 | type=ref,event=branch 29 | type=ref,event=tag 30 | 31 | - name: Build and push Docker image 32 | uses: docker/build-push-action@ad44023a93711e3deb337508980b4b5e9bcdc5dc 33 | with: 34 | context: . 
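# push the built image using the tags and labels computed by the metadata step above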
35 | push: true 36 | tags: ${{ steps.meta.outputs.tags }} 37 | labels: ${{ steps.meta.outputs.labels }} 38 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Build and Deploy Docs 2 | on: 3 | release: 4 | types: 5 | - published 6 | workflow_dispatch: 7 | 8 | jobs: 9 | build-and-deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout 13 | uses: actions/checkout@v2.3.1 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.8 19 | 20 | - name: Set up Node 21 | uses: actions/setup-node@v2 22 | with: 23 | node-version: "16" 24 | 25 | - name: Generate pdoc output & Build Docusaurus website 26 | run: | 27 | pip install -e .[all] 28 | pdoc -f --template-dir docs/pdoc/templates -o docs/docs classy 29 | mv docs/docs/classy docs/docs/api 30 | python docs/pdoc/pdoc_postprocess.py 31 | cd docs 32 | yarn install 33 | # builds https://gitlab.grnet.gr/terminology/docusaurus-terminology 34 | yarn docusaurus parse 35 | yarn docusaurus glossary 36 | # creates the build/ folder 37 | yarn build 38 | 39 | - name: Deploy to GitHub Pages 40 | if: success() 41 | uses: crazy-max/ghaction-github-pages@v2 42 | with: 43 | target_branch: gh-pages 44 | build_dir: docs/build 45 | env: 46 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # ide stuff 132 | .idea 133 | .vscode 134 | 135 | # custom stuff 136 | data/* 137 | experiments/* 138 | !.placeholder 139 | 140 | # docs stuff 141 | docs/node_modules 142 | docs/build 143 | docs/.docusaurus 144 | docs/.cache-loader 145 | docs/.DS_Store 146 | docs/.env.local 147 | docs/.env.development.local 148 | docs/.env.test.local 149 | docs/.env.production.local 150 | docs/npm-debug.log* 151 | docs/yarn-debug.log* 152 | docs/yarn-error.log* 153 | /docs/docs/api/ 154 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.8 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 7 | autoupdate_schedule: quarterly 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.3.0 12 | hooks: 13 | - id: check-yaml 14 | - id: check-case-conflict 15 | - id: detect-private-key 16 | - id: end-of-file-fixer 17 | - id: trailing-whitespace 18 | # - repo: https://github.com/PyCQA/pydocstyle # todo: uncomment and add api docs 19 | # rev: 6.1.1 20 | # hooks: 21 | # - id: pydocstyle 22 | # name: Check docs 23 | # files: ^classy/ 24 | # language_version: python3.8 25 | # args: ["--convention=google"] 26 | - repo: https://github.com/PyCQA/isort 27 | rev: 5.10.1 28 | hooks: 29 | - id: isort 30 | name: Format imports 31 | files: ^classy/ 32 | args: ["--profile", "black"] 33 | - repo: https://github.com/psf/black 34 | rev: 22.6.0 35 | hooks: 36 | - id: black 37 | name: Format code 38 | files: ^classy/ 39 | - repo: https://github.com/asottile/blacken-docs 40 | rev: v1.12.1 41 | hooks: 42 | - id: blacken-docs 43 | name: Format code in docs 44 | additional_dependencies: [ black==21.12b0 ] 45 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to contribute to `classy`? 2 | 3 | Everything you do for the `classy` community is a valuable contribution! This includes: 4 | 5 | - 🐞 Reporting bugs 6 | - :zap: Proposing new features 7 | - :scroll: Improving the docs 8 | - :book: Adding [examples](https://github.com/sunglasses-ai/classy-examples) 9 | - :question: Solving open issues 10 | 11 | ### 🐞 Did you find a bug? 12 | 13 | First, please [quickly search](https://github.com/sunglasses-ai/classy/issues) our issues to see whether it was already reported, in which case you could comment on the existing issue. 14 | 15 | Otherwise, open [a new GitHub issue](https://github.com/sunglasses-ai/classy/issues/new).
16 | 17 | Ideally, your issue's title should briefly describe the bug, while the description should contain as much relevant information as possible: 18 | - a minimal example to reproduce the issue 19 | - the observed behaviour 20 | - the expected behaviour 21 | - environment information (classy, pytorch, pytorch-lightning, huggingface & any relevant package versions, NVIDIA driver version) 22 | - [*optional*] a code sample / test case / example demonstrating the expected behaviour 23 | 24 | ### :zap: Do you have a suggestion for an enhancement? 25 | 26 | We track enhancement requests via GitHub issues. As with bugs, **before you create a new issue**, please [quickly search](https://github.com/sunglasses-ai/classy/issues) our issues to see whether it was suggested already, in which case you can comment on the existing issue (even to suggest improvements / changes or to help us with reviewing the code!). 27 | 28 | When creating your enhancement request, please: 29 | 30 | - Provide a clear title and description. 31 | - Provide a brief explanation of why the enhancement would be useful, possibly with examples. 32 | - If you're not sure how you would go about it, keep it more general and request an open discussion of the feature's design / implementation. 33 | 34 | ### :scroll: Improving the docs 35 | 36 | Did you find a mistake, or do you want to improve the current documentation in some way? Again, thank you! Documentation and examples are what make software truly usable, and you are contributing to making `classy` a better framework this way. You can simply open a new issue with the section(s) you want to work on / see improved, and we'll get there together. 37 | 38 | ### :book: Do you have an example you want to share? 39 | 40 | Please head over to [sunglasses-ai/classy-examples](https://github.com/sunglasses-ai/classy-examples) and follow the guides there. 41 | If you have any trouble making a contribution there, open [a new issue there](https://github.com/sunglasses-ai/classy-examples/issues/new) and we'll be in touch as soon as possible! 42 | 43 | ### :question: Solving open issues 44 | 45 | Did you stumble onto something (a feature request, a bug, a question) you think you could provide a solution / answer to? Please do! 46 | Anyone is more than welcome to collaborate and contribute in as many ways as possible, and we thank every contributor for their effort in making `classy` a better framework. 47 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 2 | 3 | WORKDIR /root 4 | 5 | # install utilities 6 | 7 | RUN \ 8 | DEBIAN_FRONTEND="noninteractive" apt-get update && \ 9 | DEBIAN_FRONTEND="noninteractive" apt-get install -y rsync byobu tmux vim nano htop wget curl git lm-sensors openssh-server && \ 10 | mkdir .ssh 11 | EXPOSE 22 12 | 13 | # install conda 14 | RUN \ 15 | wget -O miniconda.sh "https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh" && \ 16 | bash miniconda.sh -b -p /root/miniconda3 && \ 17 | rm -f miniconda.sh 18 | ENV PATH=/root/miniconda3/bin:${PATH} 19 | RUN conda update -y conda && conda init 20 | 21 | # setup env 22 | WORKDIR /classy 23 | COPY . .
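# the printf below answers setup.sh's interactive prompts in order (env name, then python / torch / cuda versions; the trailing N presumably declines an optional extra step)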
24 | RUN \ 25 | bash -c "source ~/miniconda3/etc/profile.d/conda.sh && printf 'classy\n3.8\n1.10.2\n11.3\nN\n' | bash setup.sh" 26 | 27 | # standard cmd 28 | CMD [ "/bin/bash" ] 29 | -------------------------------------------------------------------------------- /classy/__init__.py: -------------------------------------------------------------------------------- 1 | from classy.version import VERSION as __version__ # noqa 2 | -------------------------------------------------------------------------------- /classy/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/data/__init__.py -------------------------------------------------------------------------------- /classy/data/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/data/dataset/__init__.py -------------------------------------------------------------------------------- /classy/data/dataset/hf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/data/dataset/hf/__init__.py -------------------------------------------------------------------------------- /classy/data/dataset/hf/base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | from transformers import AutoTokenizer 4 | 5 | from classy.data.dataset.base import BaseDataset 6 | from classy.utils.log import get_project_logger 7 | 8 | logger = get_project_logger(__name__) 9 | 10 | 11 | class HFBaseDataset(BaseDataset): 12 | _shared_state = {} 13 | 14 | def __init__( 15 | self, 16 | transformer_model: str, 17 | additional_special_tokens: Optional[List[str]] = None, 18 | truncation: Union[bool, str] = False, 19 | max_length: int = -1, 20 | **kwargs, 21 | ): 22 | # cache the tokenizer on the class, keyed by the additional special tokens; 23 | # note that the full key must be used in the membership check too (checking 24 | # just "tokenizer" would never match, rebuilding the tokenizer every time) 25 | tokenizer_key = ("tokenizer", tuple(additional_special_tokens or [])) 26 | if tokenizer_key not in self._shared_state: 27 | self._shared_state[tokenizer_key] = AutoTokenizer.from_pretrained( 28 | transformer_model, 29 | use_fast=True, 30 | add_prefix_space=True, 31 | additional_special_tokens=list(additional_special_tokens) 32 | if additional_special_tokens is not None 33 | else None, 34 | ) 35 | self.tokenizer = self._shared_state[tokenizer_key] 36 | self.transformer_model = transformer_model 37 | self.truncation = truncation 38 | self.additional_special_tokens = additional_special_tokens 39 | 40 | super().__init__( 41 | batching_fields=kwargs.pop("batching_fields", ["input_ids"]), 42 | max_length=max_length 43 | if max_length != -1 44 | else self.tokenizer.model_max_length, 45 | **kwargs, 46 | ) 47 | -------------------------------------------------------------------------------- /classy/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/evaluation/__init__.py -------------------------------------------------------------------------------- /classy/evaluation/base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 |
from classy.data.data_drivers import ClassySample 4 | 5 | 6 | class Evaluation: 7 | def __call__( 8 | self, 9 | path: str, 10 | predicted_samples: List[ClassySample], 11 | ) -> Dict: 12 | raise NotImplementedError 13 | -------------------------------------------------------------------------------- /classy/evaluation/generation.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | import nltk 4 | from datasets import load_metric 5 | 6 | from classy.data.data_drivers import GenerationSample 7 | from classy.evaluation.base import Evaluation 8 | from classy.utils.log import get_project_logger 9 | 10 | logger = get_project_logger(__name__) 11 | 12 | 13 | class RougeEvaluation(Evaluation): 14 | def __init__(self): 15 | self.rouge = load_metric("rouge") 16 | 17 | def __call__(self, path: str, predicted_samples: List[GenerationSample]) -> Dict: 18 | assert all( 19 | sample.reference_annotation is not None for sample in predicted_samples 20 | ) 21 | 22 | gold_summaries = [sample.reference_annotation for sample in predicted_samples] 23 | pred_summaries = [sample.predicted_annotation for sample in predicted_samples] 24 | 25 | # process summaries 26 | # todo maybe improve with something like ptb/stanza/some real sentence tokenizer 27 | gold_summaries = [ 28 | "\n".join(nltk.sent_tokenize(gs.replace(". ", "\n").rstrip())) 29 | for gs in gold_summaries 30 | ] 31 | pred_summaries = [ 32 | "\n".join(nltk.sent_tokenize(ps.replace(". ", "\n").rstrip())) 33 | for ps in pred_summaries 34 | ] 35 | 36 | results = self.rouge.compute( 37 | predictions=pred_summaries, references=gold_summaries 38 | ) 39 | scores = {} 40 | 41 | for k, v in results.items(): 42 | scores[k] = v.mid.fmeasure 43 | 44 | return scores 45 | 46 | 47 | class SacreBleuEvaluation(Evaluation): 48 | def __init__(self): 49 | self.bleu = load_metric("sacrebleu") 50 | 51 | def __call__( 52 | self, 53 | path: str, 54 | predicted_samples: List[GenerationSample], 55 | ): 56 | 57 | assert all( 58 | sample.reference_annotation is not None for sample in predicted_samples 59 | ) 60 | 61 | references = [sample.reference_annotation for sample in predicted_samples] 62 | predictions = [sample.predicted_annotation for sample in predicted_samples] 63 | 64 | results = self.bleu.compute( 65 | predictions=predictions, references=[[r] for r in references] 66 | ) 67 | return {"bleu": results["score"]} 68 | -------------------------------------------------------------------------------- /classy/evaluation/simple.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support 4 | 5 | from classy.data.data_drivers import ( 6 | GenerationSample, 7 | QASample, 8 | SentencePairSample, 9 | SequenceSample, 10 | TokensSample, 11 | ) 12 | from classy.evaluation.base import Evaluation 13 | from classy.utils.commons import flatten 14 | 15 | 16 | def accuracy(gold, pred) -> float: 17 | return accuracy_score(gold, pred) 18 | 19 | 20 | def p_r_f_support(gold, pred) -> Dict[str, float]: 21 | result = {} 22 | for avg in ["micro", "macro", "weighted"]: 23 | p, r, f1, _ = precision_recall_fscore_support(gold, pred, average=avg) 24 | for k, v in zip(["precision", "recall", "f1"], [p, r, f1]): 25 | result[f"{avg}_{k}"] = v 26 | return result 27 | 28 | 29 | class SequenceSimpleEvaluation(Evaluation): 30 | def __call__(self, path: str, predicted_samples: 
List[SequenceSample]) -> Dict: 31 | gold = [sample.reference_annotation for sample in predicted_samples] 32 | pred = [sample.predicted_annotation for sample in predicted_samples] 33 | return {"accuracy": accuracy(gold, pred), **p_r_f_support(gold, pred)} 34 | 35 | 36 | class SentencePairSimpleEvaluation(Evaluation): 37 | def __call__(self, path: str, predicted_samples: List[SentencePairSample]) -> Dict: 38 | gold = [sample.reference_annotation for sample in predicted_samples] 39 | pred = [sample.predicted_annotation for sample in predicted_samples] 40 | return {"accuracy": accuracy(gold, pred), **p_r_f_support(gold, pred)} 41 | 42 | 43 | class TokenSimpleEvaluation(Evaluation): 44 | def __call__(self, path: str, predicted_samples: List[TokensSample]) -> Dict: 45 | gold = [sample.reference_annotation for sample in predicted_samples] 46 | pred = [sample.predicted_annotation for sample in predicted_samples] 47 | gold, pred = flatten(gold), flatten(pred) 48 | return {"accuracy": accuracy(gold, pred), **p_r_f_support(gold, pred)} 49 | 50 | 51 | class QASimpleEvaluation(Evaluation): 52 | """ 53 | Computes a simple exact-match accuracy 54 | """ 55 | 56 | def __call__(self, path: str, predicted_samples: List[QASample]) -> Dict: 57 | n, d = 0, 0 58 | for sample in predicted_samples: 59 | d += 1 60 | if sample.reference_annotation == sample.predicted_annotation: 61 | n += 1 62 | return {"exact-match-accuracy": f"{n / d:.2f}"} 63 | 64 | 65 | class GenerationSimpleEvaluation(Evaluation): 66 | """ 67 | Computes a simple full-text accuracy 68 | """ 69 | 70 | def __call__(self, path: str, predicted_samples: List[GenerationSample]) -> Dict: 71 | n, d = 0, 0 72 | for sample in predicted_samples: 73 | d += 1 74 | if sample.reference_annotation == sample.predicted_annotation: 75 | n += 1 76 | return {"full-generation-accuracy": f"{n / d:.2f}"} 77 | -------------------------------------------------------------------------------- /classy/evaluation/span.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | from datasets import load_metric 4 | 5 | from classy.data.data_drivers import TokensSample 6 | from classy.evaluation.base import Evaluation 7 | 8 | 9 | class SeqEvalSpanEvaluation(Evaluation): 10 | def __init__(self): 11 | self.backend_metric = load_metric("seqeval") 12 | 13 | def __call__(self, path: str, predicted_samples: List[TokensSample]) -> Dict: 14 | 15 | metric_out = self.backend_metric.compute( 16 | predictions=[sample.predicted_annotation for sample in predicted_samples], 17 | references=[sample.reference_annotation for sample in predicted_samples], 18 | ) 19 | p, r, f1 = ( 20 | metric_out["overall_precision"], 21 | metric_out["overall_recall"], 22 | metric_out["overall_f1"], 23 | ) 24 | 25 | return {"precision": p, "recall": r, "f1": f1} 26 | -------------------------------------------------------------------------------- /classy/evaluation/squad.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Tuple 2 | 3 | from datasets import load_metric 4 | 5 | from classy.data.data_drivers import QASample 6 | from classy.evaluation.base import Evaluation 7 | 8 | 9 | class SQuADV1Evaluation(Evaluation): 10 | def __init__(self): 11 | self.squad = load_metric("squad") 12 | 13 | def __call__( 14 | self, 15 | path: str, 16 | predicted_samples: List[QASample], 17 | ) -> Dict: 18 | 19 | pred = [ 20 | { 21 | "id": sample.squad_id, 22 | "prediction_text": 
sample.context[ 23 | sample.predicted_annotation[0] : sample.predicted_annotation[1] 24 | ], 25 | } 26 | for sample in predicted_samples 27 | ] 28 | gold = [ 29 | {"id": sample.squad_id, "answers": sample.full_answers} 30 | for sample in predicted_samples 31 | ] 32 | 33 | assert all( 34 | g["id"] is not None and g["answers"] is not None for g in gold 35 | ), "Expected 'id' and 'answers' in gold, but found None" 36 | 37 | results = self.squad.compute(predictions=pred, references=gold) 38 | exact_match, f1 = results["exact_match"], results["f1"] 39 | 40 | return {"exact_match": exact_match, "f1": f1} 41 | -------------------------------------------------------------------------------- /classy/optim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/optim/__init__.py -------------------------------------------------------------------------------- /classy/optim/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/optim/optimizers/__init__.py -------------------------------------------------------------------------------- /classy/pl_callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/pl_callbacks/__init__.py -------------------------------------------------------------------------------- /classy/pl_callbacks/best_checkpoint.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | 4 | from pytorch_lightning.callbacks import ModelCheckpoint 5 | 6 | 7 | class ModelCheckpointWithBest(ModelCheckpoint): 8 | """ 9 | A callback that explicitly saves the best checkpoint as best.ckpt.
10 | Note that best.ckpt is a copy of the best checkpoint, not a symlink to it. 11 | """ 12 | 13 | CHECKPOINT_NAME_BEST = "best.ckpt" 14 | 15 | def on_validation_end(self, trainer, pl_module): 16 | super().on_validation_end(trainer, pl_module) 17 | if self.best_model_path == "": 18 | return 19 | orig_best = Path(self.best_model_path) 20 | shutil.copyfile(orig_best, orig_best.parent / self.CHECKPOINT_NAME_BEST) 21 | -------------------------------------------------------------------------------- /classy/pl_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/pl_modules/__init__.py -------------------------------------------------------------------------------- /classy/pl_modules/base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, NamedTuple, Optional 2 | 3 | import hydra 4 | import omegaconf 5 | import pytorch_lightning as pl 6 | import torch 7 | 8 | from classy.pl_modules.mixins.prediction import PredictionMixin 9 | from classy.pl_modules.mixins.saving import SavingMixin 10 | from classy.utils.vocabulary import Vocabulary 11 | 12 | 13 | class ClassificationOutput(NamedTuple): 14 | logits: torch.Tensor 15 | probabilities: torch.Tensor 16 | predictions: torch.Tensor 17 | loss: Optional[torch.Tensor] = None 18 | 19 | 20 | class ClassyPLModule(SavingMixin, PredictionMixin, pl.LightningModule): 21 | def __init__( 22 | self, vocabulary: Optional[Vocabulary], optim_conf: omegaconf.DictConfig 23 | ): 24 | super().__init__() 25 | self.vocabulary: Vocabulary = vocabulary 26 | self._optim_conf = optim_conf 27 | 28 | def load_prediction_params(self, prediction_params: Dict): 29 | pass 30 | 31 | def forward(self, *args, **kwargs) -> ClassificationOutput: 32 | raise NotImplementedError 33 | 34 | def configure_optimizers(self): 35 | """Instantiate the optimizer factory from the hydra optim config and apply it to this module.""" 36 | return hydra.utils.instantiate(self._optim_conf, _recursive_=False)(module=self) 37 | -------------------------------------------------------------------------------- /classy/pl_modules/hf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/pl_modules/hf/__init__.py -------------------------------------------------------------------------------- /classy/pl_modules/mixins/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/pl_modules/mixins/__init__.py -------------------------------------------------------------------------------- /classy/pl_modules/mixins/prediction.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Iterator, Union 2 | 3 | import hydra 4 | import torch 5 | from omegaconf import DictConfig 6 | from pytorch_lightning.utilities import move_data_to_device 7 | from torch.cuda.amp import autocast 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | from classy.data.data_drivers import ClassySample 12 | from classy.utils.log import get_project_logger 13 | 14 | logger = get_project_logger(__name__) 15 | 16 | 17 | class PredictionMixin: 18 | """ 19 | Simple mixin that models the prediction behavior of a classy.pl_modules.base.ClassyPLModule.
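A rough usage sketch ('model', 'samples' and 'dataset_conf' are illustrative names, not part of this API): for predicted_sample in model.predict(samples, dataset_conf, progress_bar=True): print(predicted_sample.predicted_annotation)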
20 | """ 21 | 22 | def predict( 23 | self, 24 | samples: Iterator[ClassySample], 25 | dataset_conf: Union[Dict, DictConfig], 26 | token_batch_size: int = 1024, 27 | progress_bar: bool = False, 28 | **kwargs 29 | ) -> Iterator[ClassySample]: 30 | """ 31 | Exposed method of each classy.pl_modules.base.ClassyPLModule invoked to annotate a collection of input 32 | samples. 33 | 34 | Args: 35 | samples: iterator over the samples that have to be annotated. 36 | dataset_conf: the dataset configuration used to instantiate the Dataset with hydra. 37 | token_batch_size: the maximum number of tokens in each batch. 38 | progress_bar: whether to show a progress bar during prediction. 39 | **kwargs: additional parameters (currently unused, kept for future-proofing). 40 | 41 | Returns: 42 | An iterator over the input samples with the predicted annotation updated. 43 | 44 | """ 45 | 46 | # set up bookkeeping to re-yield samples in their original order 47 | def samples_it(): 48 | for i, sample in enumerate(samples): 49 | assert sample._mixin_prediction_position is None 50 | sample._mixin_prediction_position = i 51 | yield sample 52 | 53 | next_prediction_position = 0 54 | position2predicted_sample = {} 55 | 56 | # instantiate dataset 57 | dataset_conf["tokens_per_batch"] = token_batch_size 58 | dataset = hydra.utils.instantiate( 59 | dataset_conf, samples=samples_it(), vocabulary=self.vocabulary 60 | ) 61 | 62 | # instantiate dataloader 63 | iterator = DataLoader(dataset, batch_size=None, num_workers=0) 64 | if progress_bar: 65 | iterator = tqdm(iterator, desc="Predicting") 66 | 67 | for batch in iterator: 68 | # do batch predict 69 | with torch.inference_mode(): 70 | with autocast(enabled=True): # todo: always enabled? 71 | batch = move_data_to_device(batch, self.device) 72 | batch_out = self.batch_predict(**batch) 73 | # record each predicted sample at its original input position 74 | for sample in batch_out: 75 | position2predicted_sample[sample._mixin_prediction_position] = sample 76 | # yield 77 | while next_prediction_position in position2predicted_sample: 78 | yield position2predicted_sample[next_prediction_position] 79 | del position2predicted_sample[next_prediction_position] 80 | next_prediction_position += 1 81 | 82 | if len(position2predicted_sample) > 0: 83 | logger.warning( 84 | "It seems some samples were discarded by your dataset. This means that you WON'T have a prediction for each input sample, and prediction order will be partially disrupted" 85 | ) 86 | for k, v in sorted(position2predicted_sample.items(), key=lambda x: x[0]): 87 | yield v 88 | 89 | if progress_bar: 90 | iterator.close() 91 | 92 | def batch_predict(self, *args, **kwargs) -> Iterator[ClassySample]: 93 | """ 94 | General method that must be implemented by each classy.pl_modules.base.ClassyPLModule in order to perform 95 | batch prediction. 96 | 97 | Returns: 98 | An iterator over a collection of samples with the predicted annotation updated with the model outputs.
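Note: concrete implementations (e.g. the HF classification and generation modules) are expected to decode the model outputs into each sample's predicted_annotation and yield the updated samples, which is the contract that predict above relies on.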
99 | """ 100 | raise NotImplementedError 101 | -------------------------------------------------------------------------------- /classy/pl_modules/mixins/saving.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | 5 | import pytorch_lightning as pl 6 | from omegaconf import DictConfig 7 | 8 | from classy.data.data_drivers import get_data_driver 9 | from classy.data.data_modules import ClassyDataModule 10 | from classy.utils.hydra import fix_paths 11 | 12 | 13 | class SavingMixin: 14 | def save_resources_and_update_config( 15 | self, 16 | conf: DictConfig, 17 | working_folder: str, 18 | experiment_folder: str, 19 | data_module: ClassyDataModule, 20 | ): 21 | 22 | working_folder = Path(working_folder) 23 | experiment_folder = Path(experiment_folder) 24 | 25 | # save examples 26 | source, examples = data_module.get_examples(n=5) 27 | experiment_folder.joinpath("data").mkdir(exist_ok=True) 28 | get_data_driver(self.task, "jsonl").save( 29 | examples, str(experiment_folder / "data" / f"examples-{source}.jsonl") 30 | ) 31 | 32 | # move every path into "./resources/" and overwrite the config accordingly 33 | Path(experiment_folder / "resources").mkdir() 34 | 35 | # the same resource might be used by multiple components at the same time; 36 | # avoid copying it multiple times 37 | colored_paths = set() 38 | 39 | def fix_with_copy_side_effect(path): 40 | input_path = Path(path) 41 | assert input_path.exists() 42 | output_path = ( 43 | experiment_folder / "resources" / input_path.relative_to(working_folder) 44 | ) 45 | output_path.parent.mkdir(parents=True, exist_ok=True) 46 | if input_path not in colored_paths: 47 | if Path(input_path).is_dir(): 48 | shutil.copytree(input_path, output_path) 49 | else: 50 | shutil.copy(input_path, output_path) 51 | colored_paths.add(input_path) 52 | return str(output_path.relative_to(experiment_folder)) 53 | 54 | fix_paths( 55 | conf.model, 56 | check_fn=lambda path: Path(path).exists(), 57 | fix_fn=fix_with_copy_side_effect, 58 | ) 59 | fix_paths( 60 | conf.prediction, 61 | check_fn=lambda path: Path(path).exists(), 62 | fix_fn=fix_with_copy_side_effect, 63 | ) 64 | if "evaluation" in conf: 65 | fix_paths( 66 | conf.evaluation, 67 | check_fn=lambda path: Path(path).exists(), 68 | fix_fn=fix_with_copy_side_effect, 69 | ) 70 | -------------------------------------------------------------------------------- /classy/pl_modules/mixins/task.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from classy.data.data_drivers import ( 4 | GENERATION, 5 | JSONL, 6 | QA, 7 | SENTENCE_PAIR, 8 | SEQUENCE, 9 | TOKEN, 10 | ClassySample, 11 | GenerationSample, 12 | QASample, 13 | SentencePairSample, 14 | SequenceSample, 15 | TokensSample, 16 | get_data_driver, 17 | ) 18 | from classy.pl_modules.mixins.task_serve import ( 19 | GenerationTaskServeMixin, 20 | QATaskServeMixin, 21 | SentencePairTaskServeMixin, 22 | SequenceTaskServeMixin, 23 | TokenTaskServeMixin, 24 | ) 25 | from classy.pl_modules.mixins.task_ui import ( 26 | GenerationTaskUIMixin, 27 | QATaskUIMixin, 28 | SentencePairTaskUIMixin, 29 | SequenceTaskUIMixin, 30 | TokenTaskUIMixin, 31 | ) 32 | 33 | 34 | class TaskMixin: 35 | def read_input_from_bash( 36 | self, 37 | ) -> ClassySample: 38 | raise NotImplementedError 39 | 40 | @property 41 | def task(self) -> str: 42 | raise NotImplementedError 43 | 44 | 45 | class SequenceTask(SequenceTaskServeMixin, SequenceTaskUIMixin, TaskMixin):
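"""Sequence-classification flavour of TaskMixin: reads a 'sequence' field from bash input and parses it with the jsonl sequence data driver."""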
46 | 47 | __data_driver = get_data_driver(SEQUENCE, JSONL) 48 | 49 | def read_input_from_bash(self) -> SequenceSample: 50 | sequence = input("Enter sequence text: ").strip() 51 | sample = json.dumps({"sequence": sequence}) 52 | return next(self.__data_driver.read([sample])) 53 | 54 | @property 55 | def task(self) -> str: 56 | return SEQUENCE 57 | 58 | 59 | class TokensTask(TokenTaskServeMixin, TokenTaskUIMixin, TaskMixin): 60 | __data_driver = get_data_driver(TOKEN, JSONL) 61 | 62 | def read_input_from_bash(self) -> TokensSample: 63 | tokens = input("Enter space-separated tokens: ").strip() 64 | sample = json.dumps({"tokens": tokens.split(" ")}) 65 | return next(self.__data_driver.read([sample])) 66 | 67 | @property 68 | def task(self) -> str: 69 | return TOKEN 70 | 71 | 72 | class GenerationTask(GenerationTaskServeMixin, GenerationTaskUIMixin, TaskMixin): 73 | __data_driver = get_data_driver(GENERATION, JSONL) 74 | 75 | def read_input_from_bash(self) -> GenerationSample: 76 | source_sequence = input("Enter source sequence text: ").strip() 77 | source_language = ( 78 | input("Enter source language (leave empty to set it to None): ").strip() 79 | or None 80 | ) 81 | target_language = ( 82 | input("Enter target language (leave empty to set it to None): ").strip() 83 | or None 84 | ) 85 | sample = json.dumps( 86 | dict( 87 | source_sequence=source_sequence, 88 | source_language=source_language, 89 | target_language=target_language, 90 | ) 91 | ) 92 | return next(self.__data_driver.read([sample])) 93 | 94 | @property 95 | def task(self) -> str: 96 | return GENERATION 97 | 98 | 99 | class SentencePairTask(SentencePairTaskServeMixin, SentencePairTaskUIMixin, TaskMixin): 100 | __data_driver = get_data_driver(SENTENCE_PAIR, JSONL) 101 | 102 | def read_input_from_bash(self) -> SentencePairSample: 103 | sentence1 = input("Enter first sentence: ").strip() 104 | sentence2 = input("Enter second sentence: ").strip() 105 | sample = json.dumps({"sentence1": sentence1, "sentence2": sentence2}) 106 | return next(self.__data_driver.read([sample])) 107 | 108 | @property 109 | def task(self) -> str: 110 | return SENTENCE_PAIR 111 | 112 | 113 | class QATask(QATaskServeMixin, QATaskUIMixin, TaskMixin): 114 | __data_driver = get_data_driver(QA, JSONL) 115 | 116 | def read_input_from_bash(self) -> QASample: 117 | question = input("Enter question: ").strip() 118 | context = input("Enter context: ").strip() 119 | sample = json.dumps({"question": question, "context": context}) 120 | return next(self.__data_driver.read([sample])) 121 | 122 | @property 123 | def task(self) -> str: 124 | return QA 125 | -------------------------------------------------------------------------------- /classy/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/scripts/__init__.py -------------------------------------------------------------------------------- /classy/scripts/cli/demo.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from classy.scripts.cli.utils import ( 4 | autocomplete_model_path, 5 | checkpoint_path_from_user_input, 6 | get_device, 7 | ) 8 | from classy.utils.help_cli import HELP_MODEL_PATH, HELP_PREDICTION_PARAMS 9 | from classy.utils.optional_deps import requires 10 | 11 | 12 | def populate_parser(parser: ArgumentParser): 13 | parser.add_argument( 14 | "model_path", 15 | 
type=checkpoint_path_from_user_input, 16 | help=HELP_MODEL_PATH, 17 | ).completer = autocomplete_model_path 18 | parser.add_argument( 19 | "-p", 20 | "--port", 21 | type=int, 22 | default=8000, 23 | help="The port where the streamlit demo will be exposed.", 24 | ) 25 | parser.add_argument( 26 | "-d", 27 | "--device", 28 | default=None, 29 | help="On which device the model for the demo will be loaded. If not provided, classy will try to infer the desired behavior from the available environment.", 30 | ) 31 | parser.add_argument( 32 | "--prediction-params", type=str, default=None, help=HELP_PREDICTION_PARAMS 33 | ) 34 | 35 | 36 | def get_parser(subparser=None) -> ArgumentParser: 37 | # subparser: Optional[argparse._SubParsersAction] 38 | 39 | parser_kwargs = dict( 40 | name="demo", 41 | description="expose a demo of a classy model with Streamlit", 42 | help="Expose a demo of a classy model with Streamlit.", 43 | ) 44 | parser = (subparser.add_parser if subparser is not None else ArgumentParser)( 45 | **parser_kwargs 46 | ) 47 | 48 | populate_parser(parser) 49 | 50 | return parser 51 | 52 | 53 | def parse_args(): 54 | return get_parser().parse_args() 55 | 56 | 57 | @requires("streamlit", "demo") 58 | def main(args): 59 | # import here to avoid importing before needed 60 | import sys 61 | 62 | from streamlit.cli import main as st_main 63 | 64 | device = get_device(args.device) 65 | 66 | # script params 67 | script_params = [args.model_path] 68 | if device is not None and device != -1: 69 | # todo ugly workaround for streamlit which interprets -1 as a streamlit param) 70 | script_params += ["cuda_device", str(device)] 71 | if args.prediction_params is not None: 72 | script_params += ["prediction_params", args.prediction_params] 73 | 74 | sys.argv = [ 75 | "streamlit", 76 | "run", 77 | # __file__ points to this file's location, even when pip installed. 78 | # given our code structure (this file is [...]/classy/scripts/cli/demo.py), 79 | # if we replace /cli/ with /model/ we get the actual streamlit python file we need to run. 80 | __file__.replace("/cli/", "/model/"), 81 | *script_params, 82 | "--server.port", 83 | str(args.port), 84 | ] 85 | 86 | sys.exit(st_main()) 87 | 88 | 89 | if __name__ == "__main__": 90 | main(parse_args()) 91 | -------------------------------------------------------------------------------- /classy/scripts/cli/describe.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from classy.data.data_drivers import GENERATION, QA, SENTENCE_PAIR, SEQUENCE, TOKEN 4 | from classy.utils.help_cli import HELP_TASKS 5 | from classy.utils.optional_deps import requires 6 | 7 | 8 | def populate_parser(parser: ArgumentParser): 9 | parser.add_argument( 10 | "task", 11 | choices=[SEQUENCE, SENTENCE_PAIR, TOKEN, QA, GENERATION], 12 | help=HELP_TASKS, 13 | ) 14 | parser.add_argument( 15 | "--dataset", 16 | required=True, 17 | help="The dataset you want to describe (run statistics on).", 18 | ) 19 | parser.add_argument( 20 | "--tokenize", 21 | default=None, 22 | help="Indicates the language of the dataset in order to select " 23 | "the correct tokenizer. 
Must be a valid language code for " 24 | "the sacremoses tokenizer (see https://github.com/alvations/sacremoses).", 25 | ) 26 | parser.add_argument( 27 | "-p", 28 | "--port", 29 | type=int, 30 | default=8000, 31 | help="The port where the streamlit demo will be exposed.", 32 | ) 33 | 34 | 35 | def get_parser(subparser=None) -> ArgumentParser: 36 | parser_kwargs = dict( 37 | name="describe", 38 | description="run several statistics on the input dataset and expose them on a streamlit page", 39 | help="Run several statistics on the input dataset and expose them on a streamlit page.", 40 | ) 41 | parser = (subparser.add_parser if subparser is not None else ArgumentParser)( 42 | **parser_kwargs 43 | ) 44 | 45 | populate_parser(parser) 46 | 47 | return parser 48 | 49 | 50 | def parse_args(): 51 | return get_parser().parse_args() 52 | 53 | 54 | @requires("streamlit", "describe") 55 | def main(args): 56 | # import here to avoid importing before needed 57 | import sys 58 | 59 | from streamlit.cli import main as st_main 60 | 61 | # script params 62 | script_params = [args.task, args.dataset] 63 | 64 | if args.tokenize is not None: 65 | script_params += [args.tokenize] 66 | 67 | sys.argv = [ 68 | "streamlit", 69 | "run", 70 | # see classy/scripts/cli/demo.py for an explanation of this line :) 71 | __file__.replace("/cli/", "/model/"), 72 | *script_params, 73 | "--server.port", 74 | str(args.port), 75 | ] 76 | 77 | sys.exit(st_main()) 78 | 79 | 80 | if __name__ == "__main__": 81 | main(parse_args()) 82 | -------------------------------------------------------------------------------- /classy/scripts/cli/download.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def populate_parser(parser: ArgumentParser): 5 | parser.add_argument( 6 | "model_name", 7 | help="The model you want to download (use user@model for a specific model)", 8 | ) 9 | parser.add_argument( 10 | "--force-download", 11 | action="store_true", 12 | help="Download the model even if it is already in the " 13 | "cache. Usually required if a previous download was interrupted.", 14 | ) 15 | 16 | 17 | def get_parser(subparser=None) -> ArgumentParser: 18 | parser_kwargs = dict( 19 | name="download", 20 | description="download a pretrained model from sunglasses-ai's (or a user's) HuggingFace Hub", 21 | help="Download a pretrained model from sunglasses-ai's (or a user's) HuggingFace Hub.", 22 | ) 23 | parser = (subparser.add_parser if subparser is not None else ArgumentParser)( 24 | **parser_kwargs 25 | ) 26 | 27 | populate_parser(parser) 28 | 29 | return parser 30 | 31 | 32 | def parse_args(): 33 | return get_parser().parse_args() 34 | 35 | 36 | def main(args): 37 | from classy.scripts.model.download import download 38 | 39 | download(args.model_name, args.force_download) 40 | 41 | 42 | if __name__ == "__main__": 43 | main(parse_args()) 44 | -------------------------------------------------------------------------------- /classy/scripts/cli/export.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def populate_parser(parser: ArgumentParser): 5 | parser.add_argument("model_name", help="The model you want to export") 6 | parser.add_argument( 7 | "--zip-name", 8 | help="Name of the output file.
Defaults to classy-export-{model_name}.zip", 9 | ) 10 | parser.add_argument( 11 | "-ns", 12 | "--no-strip", 13 | action="store_true", 14 | default=False, 15 | help="Do not strip the checkpoint of optimizer states, schedulers and callbacks " 16 | "(by default they are stripped). Keep them only if you plan to resume training; stripping is fine for inference-only use.", 17 | ) 18 | parser.add_argument( 19 | "-a", 20 | "--all-ckpts", 21 | action="store_true", 22 | default=False, 23 | help="Whether to include every checkpoint under the /checkpoints/ folder or just the `best.ckpt`.", 24 | ) 25 | 26 | 27 | def get_parser(subparser=None) -> ArgumentParser: 28 | parser_kwargs = dict( 29 | name="export", 30 | description="export a trained model as a zip file", 31 | help="Export a trained model as a zip file", 32 | ) 33 | parser = (subparser.add_parser if subparser is not None else ArgumentParser)( 34 | **parser_kwargs 35 | ) 36 | 37 | populate_parser(parser) 38 | 39 | return parser 40 | 41 | 42 | def parse_args(): 43 | return get_parser().parse_args() 44 | 45 | 46 | def main(args): 47 | from classy.scripts.model.export import export 48 | 49 | export(args.model_name, args.no_strip, args.all_ckpts, args.zip_name) 50 | 51 | 52 | if __name__ == "__main__": 53 | main(parse_args()) 54 | -------------------------------------------------------------------------------- /classy/scripts/cli/import_.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def populate_parser(parser: ArgumentParser): 5 | parser.add_argument( 6 | "path", help="Path to the zip file with the model run you want to import" 7 | ) 8 | parser.add_argument( 9 | "--exp-dir", 10 | help="Path to the experiments folder where the exported model should be added.
" 11 | "Optional, automatically inferred if running from a classy project root dir.", 12 | ) 13 | 14 | 15 | def get_parser(subparser=None) -> ArgumentParser: 16 | parser_kwargs = dict( 17 | name="import", 18 | description="import a previously exported trained model from a zip file", 19 | help="import a previously exported trained model from a zip file", 20 | ) 21 | parser = (subparser.add_parser if subparser is not None else ArgumentParser)( 22 | **parser_kwargs 23 | ) 24 | 25 | populate_parser(parser) 26 | 27 | return parser 28 | 29 | 30 | def parse_args(): 31 | return get_parser().parse_args() 32 | 33 | 34 | def main(args): 35 | from classy.scripts.model.import_ import import_zip 36 | 37 | import_zip(args.path, args.exp_dir) 38 | 39 | 40 | if __name__ == "__main__": 41 | main(parse_args()) 42 | -------------------------------------------------------------------------------- /classy/scripts/cli/predict.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from argcomplete import FilesCompleter 4 | 5 | from classy.scripts.cli.utils import ( 6 | autocomplete_model_path, 7 | checkpoint_path_from_user_input, 8 | get_device, 9 | ) 10 | from classy.utils.help_cli import ( 11 | HELP_MODEL_PATH, 12 | HELP_PREDICTION_PARAMS, 13 | HELP_TOKEN_BATCH_SIZE, 14 | ) 15 | 16 | 17 | def populate_parser(parser: ArgumentParser): 18 | # TODO: would be cool to have it work with exp_name and add an optional --checkpoint-name flag (default=best.ckpt) 19 | # the user should not need to know what a checkpoint is :) 20 | 21 | subcmd = parser.add_subparsers( 22 | dest="subcmd", 23 | required=True, 24 | help="Whether you want to use the model interactively or to process a file.", 25 | ) 26 | interactive_parser = subcmd.add_parser("interactive") 27 | interactive_parser.add_argument( 28 | "model_path", type=checkpoint_path_from_user_input, help=HELP_MODEL_PATH 29 | ).completer = autocomplete_model_path 30 | interactive_parser.add_argument( 31 | "-d", 32 | "--device", 33 | default=None, 34 | help="The device where the dataset prediction will be run. If not provided, classy will try to infer the desired behavior from the available environment.", 35 | ) 36 | interactive_parser.add_argument( 37 | "--prediction-params", type=str, default=None, help="Path to prediction params." 38 | ) 39 | 40 | file_parser = subcmd.add_parser("file") 41 | file_parser.add_argument( 42 | "model_path", type=checkpoint_path_from_user_input, help=HELP_MODEL_PATH 43 | ).completer = autocomplete_model_path 44 | file_parser.add_argument( 45 | "file_path", help="The file containing the instances that you want to process." 
46 | ).completer = FilesCompleter() 47 | file_parser.add_argument( 48 | "-d", 49 | "--device", 50 | default="gpu", 51 | help="The device you will use for the prediction.", 52 | ) 53 | file_parser.add_argument( 54 | "-o", 55 | "--output-path", 56 | required=True, 57 | help="The file where the predictions will be stored.", 58 | ).completer = FilesCompleter() 59 | file_parser.add_argument( 60 | "--prediction-params", type=str, default=None, help=HELP_PREDICTION_PARAMS 61 | ) 62 | file_parser.add_argument( 63 | "--token-batch-size", type=int, default=1024, help=HELP_TOKEN_BATCH_SIZE 64 | ) 65 | 66 | 67 | def get_parser(subparser=None) -> ArgumentParser: 68 | # subparser: Optional[argparse._SubParsersAction] 69 | 70 | parser_kwargs = dict( 71 | name="predict", 72 | description="predict with a model trained using classy", 73 | help="Predict with a model trained using classy.", 74 | ) 75 | parser = (subparser.add_parser if subparser is not None else ArgumentParser)( 76 | **parser_kwargs 77 | ) 78 | 79 | populate_parser(parser) 80 | 81 | return parser 82 | 83 | 84 | def parse_args(): 85 | return get_parser().parse_args() 86 | 87 | 88 | def main(args): 89 | # import here to avoid importing torch before it's actually needed 90 | import torch 91 | 92 | from classy.scripts.model.predict import file_main, interactive_main 93 | 94 | subcmd = args.subcmd 95 | 96 | # read device 97 | device = args.device 98 | if device is None and torch.cuda.is_available(): 99 | device = 0 100 | device = get_device(device) 101 | 102 | if subcmd == "file": 103 | file_main( 104 | args.model_path, 105 | args.file_path, 106 | args.output_path, 107 | args.prediction_params, 108 | device, 109 | args.token_batch_size, 110 | ) 111 | elif subcmd == "interactive": 112 | interactive_main(args.model_path, args.prediction_params, device) 113 | else: 114 | raise NotImplementedError 115 | 116 | 117 | if __name__ == "__main__": 118 | main(parse_args()) 119 | -------------------------------------------------------------------------------- /classy/scripts/cli/serve.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from classy.scripts.cli.utils import ( 4 | autocomplete_model_path, 5 | checkpoint_path_from_user_input, 6 | get_device, 7 | ) 8 | from classy.utils.help_cli import ( 9 | HELP_MODEL_PATH, 10 | HELP_PREDICTION_PARAMS, 11 | HELP_TOKEN_BATCH_SIZE, 12 | ) 13 | 14 | 15 | def populate_parser(parser: ArgumentParser): 16 | parser.add_argument( 17 | "model_path", type=checkpoint_path_from_user_input, help=HELP_MODEL_PATH 18 | ).completer = autocomplete_model_path 19 | parser.add_argument( 20 | "-p", 21 | "--port", 22 | type=int, 23 | default=8000, 24 | help="The port where the REST api will be exposed.", 25 | ) 26 | parser.add_argument( 27 | "-d", 28 | "--device", 29 | default=None, 30 | help="On which device the model for the REST api will be loaded. 
If not provided, classy will try to infer the desired behavior from the available environment.", 31 | ) 32 | parser.add_argument( 33 | "--token-batch-size", type=int, default=1024, help=HELP_TOKEN_BATCH_SIZE 34 | ) 35 | parser.add_argument( 36 | "--prediction-params", type=str, default=None, help=HELP_PREDICTION_PARAMS 37 | ) 38 | 39 | 40 | def get_parser(subparser=None) -> ArgumentParser: 41 | # subparser: Optional[argparse._SubParsersAction] 42 | 43 | parser_kwargs = dict( 44 | name="serve", 45 | description="Expose a model trained with classy on a REST API", 46 | help="Expose a model trained with classy on a REST API.", 47 | ) 48 | parser = (subparser.add_parser if subparser is not None else ArgumentParser)( 49 | **parser_kwargs 50 | ) 51 | 52 | populate_parser(parser) 53 | 54 | return parser 55 | 56 | 57 | def parse_args(): 58 | return get_parser().parse_args() 59 | 60 | 61 | def main(args): 62 | # import here to avoid importing torch before it's actually needed 63 | import torch 64 | 65 | from classy.scripts.model.serve import serve 66 | 67 | # read device 68 | device = args.device 69 | if device is None and torch.cuda.is_available(): 70 | device = 0 71 | device = get_device(device) 72 | 73 | serve( 74 | args.model_path, 75 | args.port, 76 | device, 77 | args.token_batch_size, 78 | prediction_params=args.prediction_params, 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | main(parse_args()) 84 | -------------------------------------------------------------------------------- /classy/scripts/cli/upload.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def populate_parser(parser: ArgumentParser): 5 | parser.add_argument("model_name", help="The model you want to upload.") 6 | parser.add_argument( 7 | "--organization", 8 | help="The name of the organization where you want to upload the model.", 9 | ) 10 | parser.add_argument( 11 | "--name", 12 | help="Optional name to use when uploading to the HuggingFace repository.", 13 | ) 14 | parser.add_argument( 15 | "--commit", help="Commit message to use when pushing to the HuggingFace Hub." 16 | ) 17 | 18 | 19 | def get_parser(subparser=None) -> ArgumentParser: 20 | parser_kwargs = dict( 21 | name="upload", 22 | description="upload a pretrained model to your (or an organization's) HuggingFace Hub", 23 | help="Upload a pretrained model to your (or an organization's) HuggingFace Hub.", 24 | ) 25 | parser = (subparser.add_parser if subparser is not None else ArgumentParser)( 26 | **parser_kwargs 27 | ) 28 | 29 | populate_parser(parser) 30 | 31 | return parser 32 | 33 | 34 | def parse_args(): 35 | return get_parser().parse_args() 36 | 37 | 38 | def main(args): 39 | from classy.scripts.model.upload import upload 40 | 41 | upload(args.model_name, args.organization, args.name, args.commit) 42 | 43 | 44 | if __name__ == "__main__": 45 | main(parse_args()) 46 | -------------------------------------------------------------------------------- /classy/scripts/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import transformers 4 | 5 | # setting the transformers logging level to error with their own function 6 | transformers.logging.set_verbosity_error() 7 | 8 | # todo: we should analyze this error a little bit. Gigi can be useful here. 
9 | # According to: https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning 10 | # We should check if we really need FastTokenizers instead of the plain tokenizer. 11 | # silence the tokenizers parallelism warning (any explicit value does; "true" keeps parallelism enabled) 12 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 13 | -------------------------------------------------------------------------------- /classy/scripts/model/export.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import zipfile 3 | from pathlib import Path 4 | from typing import Optional 5 | 6 | from classy.utils.experiment import Experiment, Run 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def strip_checkpoint( 12 | checkpoint_path: Path, 13 | destination: Path, 14 | keys_to_remove=("callbacks", "optimizer_states", "lr_schedulers"), 15 | ): 16 | import torch 17 | 18 | logger.debug(f"loading checkpoint {checkpoint_path}") 19 | ckpt = torch.load(checkpoint_path, map_location="cpu") 20 | 21 | for key in keys_to_remove: 22 | if ckpt.pop(key, None) is None: 23 | logger.debug(f"key {key} did not exist in checkpoint {checkpoint_path}") 24 | 25 | logger.debug(f"saving stripped checkpoint to {destination}") 26 | torch.save(ckpt, destination) 27 | 28 | 29 | def zip_run( 30 | run: Run, 31 | tmpdir: Path, 32 | zip_name: str = "model.zip", 33 | strip_ckpt: bool = True, 34 | is_export: bool = False, 35 | best_only: bool = True, 36 | ) -> Path: 37 | 38 | logger.debug(f"zipping run {run} to {tmpdir}") 39 | # creates a zip version of the provided Run (with a single stripped checkpoint) in a model.zip file under `tmpdir` 40 | run_dir = run.directory 41 | ckpt_path = tmpdir / "best.ckpt" 42 | zip_path = tmpdir / zip_name 43 | 44 | relative_directory = run.experiment.directory.parent if is_export else run_dir 45 | 46 | with zipfile.ZipFile(zip_path, "w") as zip_file: 47 | 48 | # fully zip the run directory maintaining its structure 49 | for file in run_dir.rglob("*.*"): 50 | relative_name = file.relative_to(relative_directory) 51 | 52 | if file.is_dir(): 53 | continue 54 | 55 | # skip checkpoints as we add a single checkpoint later 56 | if "checkpoints/" in str(relative_name): 57 | if best_only: 58 | continue 59 | 60 | if strip_ckpt: 61 | strip_checkpoint(file, ckpt_path) 62 | zip_file.write( 63 | ckpt_path, arcname=file.relative_to(relative_directory) 64 | ) 65 | continue 66 | 67 | zip_file.write(file, arcname=file.relative_to(relative_directory)) 68 | 69 | if best_only: 70 | ckpt_name = ( 71 | run_dir.relative_to(relative_directory) / "checkpoints/best.ckpt" 72 | ) 73 | 74 | if strip_ckpt: 75 | logger.debug("Stripping checkpoint before writing to zip file") 76 | strip_checkpoint(run.best_checkpoint, ckpt_path) 77 | logger.debug("Writing stripped checkpoint to zip file") 78 | zip_file.write(ckpt_path, arcname=ckpt_name) 79 | else: 80 | zip_file.write(run.best_checkpoint, arcname=ckpt_name) 81 | 82 | # remove the stripped checkpoint file, if one was created, as it's now inside the zip 83 | if ckpt_path.exists(): ckpt_path.unlink() 84 | 85 | return zip_path 86 | 87 | 88 | def export( 89 | model_name: str, 90 | no_strip: bool, 91 | all_ckpts: bool, 92 | zip_name: Optional[str] = None, 93 | ): 94 | exp = Experiment.from_name(model_name) 95 | if exp is None: 96 | print(f"No experiment named {model_name} found. Exiting...") 97 | return 98 | 99 | run = exp.last_valid_run 100 | if run is None: 101 | print(f"No valid run found for experiment {model_name}. Exiting...") 102 | return 103 | 104 | zip_name = zip_name or f"classy-export-{model_name}.zip" 105 | 106 | zip_file = zip_run( 107 | run, 108 | Path.cwd(), 109 | zip_name=zip_name, 110 | strip_ckpt=not no_strip, 111 | is_export=True, 112 | best_only=not all_ckpts, 113 | ) 114 | print(f"Model exported at {zip_file}") 115 | 
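A quick way to verify what stripping removed (editor's sketch; the checkpoint path below is hypothetical, adjust it to one of your runs):

import torch

ckpt = torch.load("experiments/my-exp/2021-10-20/15-23-58/checkpoints/best.ckpt", map_location="cpu")
# strip_checkpoint pops these training-only keys, so on a stripped checkpoint this prints []
print([k for k in ("callbacks", "optimizer_states", "lr_schedulers") if k in ckpt])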
Exiting...") 102 | return 103 | 104 | zip_name = zip_name or f"classy-export-{model_name}.zip" 105 | 106 | zip_file = zip_run( 107 | run, 108 | Path.cwd(), 109 | zip_name=zip_name, 110 | strip_ckpt=not no_strip, 111 | is_export=True, 112 | best_only=not all_ckpts, 113 | ) 114 | print(f"Model exported at {zip_file}") 115 | -------------------------------------------------------------------------------- /classy/scripts/model/import_.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Optional 4 | 5 | from classy.utils.experiment import Experiment 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | 10 | def unzip(zip_path: str, target: Path): 11 | import zipfile 12 | 13 | """ 14 | Unzip the contents of `zip_path` into `target`. 15 | """ 16 | logger.debug(f"Unzipping {zip_path} to {target}") 17 | with zipfile.ZipFile(zip_path) as f: 18 | f.extractall(target) 19 | 20 | 21 | def import_zip(zip_path: str, target_path: Optional[str] = None): 22 | if target_path is None: 23 | target = Experiment.try_get_experiment_dir() 24 | else: 25 | target = Path(target_path) 26 | unzip(zip_path, target) 27 | -------------------------------------------------------------------------------- /classy/scripts/model/predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from omegaconf import OmegaConf 5 | 6 | from classy.data.data_drivers import get_data_driver 7 | from classy.utils.lightning import ( 8 | load_classy_module_from_checkpoint, 9 | load_prediction_dataset_conf_from_checkpoint, 10 | ) 11 | 12 | 13 | def interactive_main( 14 | model_checkpoint_path: str, 15 | prediction_params: str, 16 | cuda_device: int, 17 | ): 18 | 19 | model = load_classy_module_from_checkpoint(model_checkpoint_path) 20 | model.to(torch.device(cuda_device if cuda_device != -1 else "cpu")) 21 | model.freeze() 22 | 23 | if prediction_params is not None: 24 | model.load_prediction_params(dict(OmegaConf.load(prediction_params))) 25 | 26 | dataset_conf = load_prediction_dataset_conf_from_checkpoint(model_checkpoint_path) 27 | 28 | # mock call to load resources 29 | next(model.predict(samples=[], dataset_conf=dataset_conf), None) 30 | 31 | while True: 32 | predicted_sample = next( 33 | model.predict( 34 | [model.read_input_from_bash()], 35 | dataset_conf=dataset_conf, 36 | ) 37 | ) 38 | print(f"\t# prediction: \t{predicted_sample.predicted_annotation}") 39 | 40 | 41 | def file_main( 42 | model_checkpoint_path: str, 43 | input_path: str, 44 | output_path: str, 45 | prediction_params: str, # todo: u sure? 
46 | cuda_device: int, 47 | token_batch_size: int, 48 | ): 49 | 50 | model = load_classy_module_from_checkpoint(model_checkpoint_path) 51 | model.to(torch.device(cuda_device if cuda_device != -1 else "cpu")) 52 | model.freeze() 53 | 54 | if prediction_params is not None: 55 | model.load_prediction_params(dict(OmegaConf.load(prediction_params))) 56 | 57 | dataset_conf = load_prediction_dataset_conf_from_checkpoint(model_checkpoint_path) 58 | input_extension, output_extension = ( 59 | input_path.split(".")[-1], 60 | output_path.split(".")[-1], 61 | ) 62 | assert input_extension == output_extension, ( 63 | f"Having different input and output extensions is not currently a supported use case: " 64 | f"input {input_extension} != output {output_extension}" 65 | ) 66 | data_driver = get_data_driver(model.task, input_extension) 67 | 68 | data_driver.save( 69 | model.predict( 70 | data_driver.read_from_path(input_path), 71 | token_batch_size=token_batch_size, 72 | dataset_conf=dataset_conf, 73 | progress_bar=True, 74 | ), 75 | output_path, 76 | use_predicted_annotation=True, 77 | ) 78 | 79 | 80 | def main(): 81 | args = parse_args() 82 | if args.t: 83 | interactive_main( 84 | args.model_checkpoint, 85 | prediction_params=args.prediction_params, 86 | cuda_device=args.cuda_device, 87 | ) 88 | else: 89 | file_main( 90 | args.model_checkpoint, 91 | args.f, 92 | args.o, 93 | prediction_params=args.prediction_params, 94 | cuda_device=args.cuda_device, 95 | token_batch_size=args.token_batch_size, 96 | ) 97 | 98 | 99 | def parse_args(): 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument( 102 | "model_checkpoint", type=str, help="Path to pl_modules checkpoint" 103 | ) 104 | parser.add_argument( 105 | "--prediction-params", type=str, default=None, help="Path to prediction params" 106 | ) 107 | parser.add_argument("--cuda-device", type=int, default=-1, help="Cuda device") 108 | # interactive params 109 | parser.add_argument("-t", action="store_true", help="Interactive mode") 110 | # file params 111 | parser.add_argument("-f", type=str, default=None, help="Input file") 112 | parser.add_argument("-o", type=str, default=None, help="Output file") 113 | parser.add_argument( 114 | "--token-batch-size", type=int, default=128, help="Token batch size" 115 | ) 116 | # return 117 | return parser.parse_args() 118 | 119 | 120 | if __name__ == "__main__": 121 | main() 122 | -------------------------------------------------------------------------------- /classy/scripts/model/serve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import List, Optional 3 | 4 | import torch 5 | from omegaconf import OmegaConf 6 | 7 | from classy.utils.optional_deps import requires 8 | 9 | try: 10 | import uvicorn 11 | from fastapi import FastAPI 12 | except ImportError: 13 | uvicorn = None 14 | FastAPI = None 15 | 16 | from classy.utils.commons import get_local_ip_address 17 | from classy.utils.lightning import ( 18 | load_classy_module_from_checkpoint, 19 | load_prediction_dataset_conf_from_checkpoint, 20 | ) 21 | from classy.utils.log import get_project_logger 22 | 23 | logger = get_project_logger(__name__) 24 | 25 | 26 | @requires("uvicorn", "serve") 27 | @requires("fastapi", "serve") 28 | def serve( 29 | model_checkpoint_path: str, 30 | port: int, 31 | cuda_device: int, 32 | token_batch_size: int, 33 | prediction_params: Optional[str] = None, 34 | ): 35 | 36 | # load model 37 | model = load_classy_module_from_checkpoint(model_checkpoint_path) 38 | 
model.to(torch.device(cuda_device if cuda_device != -1 else "cpu")) 39 | model.freeze() 40 | 41 | if prediction_params is not None: 42 | model.load_prediction_params(dict(OmegaConf.load(prediction_params))) 43 | 44 | # load dataset conf 45 | dataset_conf = load_prediction_dataset_conf_from_checkpoint(model_checkpoint_path) 46 | 47 | # mock call to load resources 48 | next(model.predict(samples=[], dataset_conf=dataset_conf), None) 49 | 50 | # for better readability on the OpenAPI docs 51 | # no need to leak the confusing inner class names 52 | class InputSample(model.serve_input_class): 53 | pass 54 | 55 | class OutputSample(model.serve_output_class): 56 | pass 57 | 58 | app = FastAPI(title="Classy Serve") 59 | 60 | @app.post("/", response_model=List[OutputSample], description="Prediction endpoint") 61 | def predict(input_samples: List[InputSample]) -> List[OutputSample]: 62 | 63 | output_samples = [] 64 | 65 | for predicted_sample in model.predict( 66 | # note (editor): predict is a bound method, so the model itself must not be passed again 67 | samples=[input_sample.unmarshal() for input_sample in input_samples], 68 | dataset_conf=dataset_conf, 69 | token_batch_size=token_batch_size, 70 | ): 71 | output_samples.append(OutputSample.marshal(predicted_sample)) 72 | 73 | return output_samples 74 | 75 | @app.get("/healthz") 76 | def healthz(): 77 | return "ok" 78 | 79 | local_ip_address = get_local_ip_address() 80 | print(f"Model exposed at http://{local_ip_address}:{port}") 81 | print(f"Remember you can check out the API docs at http://{local_ip_address}:{port}/docs") 82 | uvicorn.run(app, host="0.0.0.0", port=port) 83 | 84 | 85 | def main(): 86 | args = parse_args() 87 | serve( 88 | model_checkpoint_path=args.model_checkpoint, 89 | prediction_params=args.prediction_params, 90 | port=args.p, 91 | cuda_device=args.cuda_device, 92 | token_batch_size=args.token_batch_size, 93 | ) 94 | 95 | 96 | def parse_args(): 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument( 99 | "model_checkpoint", type=str, help="Path to pl_modules checkpoint" 100 | ) 101 | parser.add_argument( 102 | "--prediction-params", type=str, default=None, help="Path to prediction params" 103 | ) 104 | parser.add_argument( 105 | "-p", type=int, default=8000, help="Port on which to expose the model" 106 | ) 107 | parser.add_argument("--cuda-device", type=int, default=-1, help="Cuda device") 108 | parser.add_argument( 109 | "--token-batch-size", type=int, default=128, help="Token batch size" 110 | ) 111 | return parser.parse_args() 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /classy/scripts/model/upload.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import tempfile 5 | from datetime import datetime 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import huggingface_hub 10 | 11 | from classy.scripts.model.download import CLASSY_DATE_FORMAT, get_md5 12 | from classy.scripts.model.export import zip_run 13 | from classy.utils.experiment import Experiment 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def create_info_file(tmpdir: Path): 19 | logger.debug("Computing md5 of model.zip") 20 | md5 = get_md5(tmpdir / "model.zip") 21 | date = datetime.now().strftime(CLASSY_DATE_FORMAT) 22 | 23 | logger.debug("Dumping info.json file") 24 | with (tmpdir / "info.json").open("w") as f: 25 | json.dump(dict(md5=md5, upload_date=date), f, indent=2) 26 | 27 | 28 | def upload( 29 | model_name, 30 | 
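    # NOTE (editor): the keyword arguments below mirror the optional CLI flags
    # --organization, --name (repo_name) and --commit of `classy upload`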
organization: Optional[str] = None, 31 | repo_name: Optional[str] = None, 32 | commit: Optional[str] = None, 33 | ): 34 | token = huggingface_hub.HfFolder.get_token() 35 | if token is None: 36 | print( 37 | "No HuggingFace token found. You need to execute `huggingface-cli login` first!" 38 | ) 39 | return 40 | 41 | exp = Experiment.from_name(model_name) 42 | if exp is None: 43 | print(f"No experiment named {model_name} found. Exiting...") 44 | return 45 | 46 | run = exp.last_valid_run 47 | if run is None: 48 | print(f"No valid run found for experiment {model_name}. Exiting...") 49 | return 50 | 51 | with tempfile.TemporaryDirectory() as tmpdir: 52 | api = huggingface_hub.hf_api.HfApi() 53 | repo_url = api.create_repo( 54 | token=token, 55 | name=repo_name or model_name, 56 | organization=organization, 57 | exist_ok=True, 58 | ) 59 | repo = huggingface_hub.Repository( 60 | str(tmpdir), clone_from=repo_url, use_auth_token=token 61 | ) 62 | 63 | tmp_path = Path(tmpdir) 64 | zip_run(run, tmp_path) 65 | create_info_file(tmp_path) 66 | 67 | # this method automatically puts large files (>10MB) into git lfs 68 | repo.push_to_hub(commit_message=commit or "Automatic push from classy") 69 | 70 | 71 | def parse_args() -> argparse.Namespace: 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("model_name", help="The model you want to upload") 74 | parser.add_argument( 75 | "--organization", 76 | help="[optional] the name of the organization where you want to upload the model", 77 | ) 78 | parser.add_argument( 79 | "--name", 80 | help="Optional name to use when uploading to the HuggingFace repository", 81 | ) 82 | parser.add_argument( 83 | "--commit", help="Commit message to use when pushing to the HuggingFace Hub" 84 | ) 85 | return parser.parse_args() 86 | 87 | 88 | def main(): 89 | args = parse_args() 90 | upload(args.model_name, args.organization, args.name, args.commit) 91 | 92 | 93 | if __name__ == "__main__": 94 | main() 95 | -------------------------------------------------------------------------------- /classy/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/classy/utils/__init__.py -------------------------------------------------------------------------------- /classy/utils/commons.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import socket 3 | import subprocess 4 | from typing import Iterable, Optional, Tuple 5 | 6 | import numpy as np 7 | 8 | from classy.utils.log import get_project_logger 9 | 10 | logger = get_project_logger(__name__) 11 | 12 | 13 | def execute_bash_command(command: str) -> Optional[str]: 14 | command_result = subprocess.run(command, shell=True, capture_output=True) 15 | try: 16 | command_result.check_returncode() 17 | return command_result.stdout.decode("utf-8") 18 | except subprocess.CalledProcessError: 19 | logger.warning(f"failed executing command: {command}") 20 | logger.warning(f"return code was: {command_result.returncode}") 21 | logger.warning(f'stdout was: {command_result.stdout.decode("utf-8")}') 22 | logger.warning(f'stderr code was: {command_result.stderr.decode("utf-8")}') 23 | return None 24 | 25 | 26 | def flatten(lst: Iterable[list]) -> list: 27 | return [_e for sub_l in lst for _e in sub_l] 28 | 29 | 30 | def chunks(lst, n): 31 | """Yield successive n-sized chunks from lst.""" 32 | for i in range(0, len(lst), n): 33 | yield lst[i : i + n] 
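# Example (editor's note): chunks() requires a sequence, as it relies on len()
# and slicing, while grouper() below accepts any iterable and yields tuples lazily:
#   >>> list(chunks([1, 2, 3, 4, 5], 2))
#   [[1, 2], [3, 4], [5]]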
34 | 35 | 36 | def grouper(iterable, n): 37 | it = iter(iterable) 38 | while True: 39 | chunk = tuple(itertools.islice(it, n)) 40 | if not chunk: 41 | return 42 | yield chunk 43 | 44 | 45 | def split_by_first(text: str, split: str) -> Tuple[str, str]: 46 | split_idx = text.index(split) 47 | return text[:split_idx], text[split_idx + len(split) :] 48 | 49 | 50 | def add_noise_to_value(value: int, noise_param: float): 51 | noise_value = value * noise_param 52 | noise = np.random.uniform(-noise_value, noise_value) 53 | return max(1, value + noise) 54 | 55 | 56 | def get_local_ip_address() -> str: 57 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 58 | s.connect(("8.8.8.8", 80)) 59 | address = s.getsockname()[0] 60 | s.close() 61 | return address 62 | -------------------------------------------------------------------------------- /classy/utils/data.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from pathlib import Path 3 | from typing import Dict, Optional, Tuple 4 | 5 | import numpy as np 6 | 7 | from classy.data.data_drivers import DataDriver 8 | from classy.utils.log import get_project_logger 9 | 10 | logger = get_project_logger(__name__) 11 | 12 | 13 | def create_data_dir(): 14 | # create data folder 15 | output_folder = Path("data") 16 | output_folder.mkdir(exist_ok=True) 17 | 18 | 19 | def load_dataset( 20 | path2data_driver: Dict[str, DataDriver], 21 | ) -> list: 22 | dataset_iterator = itertools.chain( 23 | *[dd.read_from_path(p) for p, dd in path2data_driver.items()] 24 | ) 25 | return list(dataset_iterator) 26 | 27 | 28 | def shuffle_dataset( 29 | path2data_driver: Dict[str, DataDriver], 30 | ) -> list: 31 | samples = load_dataset(path2data_driver) 32 | np.random.shuffle(samples) 33 | return samples 34 | 35 | 36 | def shuffle_and_store_dataset( 37 | path2data_driver: Dict[str, DataDriver], 38 | main_data_driver: DataDriver, 39 | output_path: str, 40 | ) -> None: 41 | samples = shuffle_dataset(path2data_driver) 42 | main_data_driver.save(samples, output_path) 43 | 44 | 45 | # TODO: we have to modify this script in order to support the split without loading the whole dataset in memory 46 | def split_dataset( 47 | path2data_driver: Dict[str, DataDriver], 48 | main_data_driver: DataDriver, 49 | main_extension: str, 50 | output_folder: str, 51 | validation_split_size: Optional[float] = None, 52 | test_split_size: Optional[float] = None, 53 | data_max_split: Optional[int] = None, 54 | shuffle: bool = True, 55 | ) -> Tuple[ 56 | Dict[str, DataDriver], 57 | Optional[Dict[str, DataDriver]], 58 | Optional[Dict[str, DataDriver]], 59 | ]: 60 | 61 | assert ( 62 | sum([validation_split_size or 0.0, test_split_size or 0.0]) > 0.0 63 | ), "At least one between validation_split_size and test_split_size must be provided with a value > 0" 64 | 65 | # create output folder 66 | create_data_dir() 67 | 68 | # read samples and shuffle 69 | if shuffle: 70 | logger.info("Materializing and shuffling dataset before splitting it") 71 | samples = shuffle_dataset(path2data_driver) 72 | else: 73 | logger.info("Materializing dataset before splitting it") 74 | samples = load_dataset(path2data_driver) 75 | 76 | # splitting 77 | training_samples = samples 78 | train_path, validation_path, test_path = None, None, None 79 | 80 | output_folder = Path(output_folder) 81 | 82 | if validation_split_size is not None: 83 | n_validation_samples = min( 84 | int(len(samples) * validation_split_size), data_max_split or len(samples) 85 | ) 86 | validation_samples, 
training_samples = ( 87 | training_samples[:n_validation_samples], 88 | training_samples[n_validation_samples:], 89 | ) 90 | validation_path = str(output_folder.joinpath(f"validation.{main_extension}")) 91 | main_data_driver.save(validation_samples, validation_path) 92 | 93 | if test_split_size is not None: 94 | n_test_samples = min( 95 | int(len(samples) * test_split_size), data_max_split or len(samples) 96 | ) 97 | test_samples, training_samples = ( 98 | training_samples[:n_test_samples], 99 | training_samples[n_test_samples:], 100 | ) 101 | test_path = str(output_folder.joinpath(f"test.{main_extension}")) 102 | main_data_driver.save(test_samples, test_path) 103 | 104 | train_path = str(output_folder.joinpath(f"train.{main_extension}")) 105 | main_data_driver.save(training_samples, train_path) 106 | 107 | train_bundle = {train_path: main_data_driver} if train_path is not None else None 108 | validation_bundle = ( 109 | {validation_path: main_data_driver} if validation_path is not None else None 110 | ) 111 | test_bundle = {test_path: main_data_driver} if test_path is not None else None 112 | 113 | return train_bundle, validation_bundle, test_bundle 114 | -------------------------------------------------------------------------------- /classy/utils/file.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | CLASSY_HF_MODEL_URL = ( 8 | "https://huggingface.co/{user_name}/{model_name}/resolve/main/model.zip" 9 | ) 10 | CLASSY_HF_INFO_URL = ( 11 | "https://huggingface.co/{user_name}/{model_name}/raw/main/info.json" 12 | ) 13 | CLASSY_MODELS_CACHE_DIR = ".cache/sunglasses-ai/classy" 14 | CLASSY_MODELS_CACHE_PATH = Path.home() / CLASSY_MODELS_CACHE_DIR 15 | CLASSY_DATE_FORMAT = "%Y-%m-%d %H-%M-%S" 16 | 17 | 18 | def ensure_dir(path) -> Path: 19 | """ 20 | Create the directory in case it does not exist. 21 | """ 22 | p = Path(path) 23 | p.mkdir(parents=True, exist_ok=True) 24 | return p 25 | 26 | 27 | def get_md5(path: Path): 28 | """ 29 | Get the MD5 digest of a file. 30 | """ 31 | import hashlib 32 | 33 | with path.open("rb") as fin: 34 | data = fin.read() 35 | return hashlib.md5(data).hexdigest() 36 | 
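A minimal usage sketch for the two helpers above (editor's addition; assumes a model.zip in the working directory, and "my-model" is a hypothetical cache entry):

from pathlib import Path

from classy.utils.file import CLASSY_MODELS_CACHE_PATH, ensure_dir, get_md5

cache_dir = ensure_dir(CLASSY_MODELS_CACHE_PATH / "my-model")  # created if missing
digest = get_md5(Path("model.zip"))  # full-file MD5, e.g. to compare against a model's info.json
print(cache_dir, digest)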
-------------------------------------------------------------------------------- /classy/utils/help_cli.py: -------------------------------------------------------------------------------- 1 | HELP_MODEL_PATH = """ 2 | The model you want to use for the demo. Can be: 3 | 1) the experiment name: "my_experiment", and classy will automatically 4 | look for the most recent run and the best checkpoint for that run under 5 | "experiments/my_experiment". 2) an experiment directory path: 6 | "experiments/my_experiment", and classy will automatically look for the most 7 | recent run and the best checkpoint of that run under the provided model directory. 8 | 3) an experiment directory including date and hour (i.e., a specific run): 9 | "experiments/my_experiments/20-10-2021/15-23-58", and classy will look for the best 10 | checkpoint of that specific run. 4) a specific checkpoint: 11 | "experiments/my_experiments/20-10-2021/15-23-58/checkpoints/last.ckpt". 12 | """ 13 | 14 | 15 | HELP_TOKEN_BATCH_SIZE = "The maximum number of tokens in a batch." 16 | 17 | HELP_FILE_PATH = """ 18 | Optional. If specified, the evaluation will be performed on this file. Otherwise, classy will try to infer 19 | the file_path from the training configuration: either by searching under dataset_path/test.data_format, where 20 | "dataset_path" is the directory passed at training time, or under the "model_path" directory if you passed only 21 | one file at training time. 22 | """ 23 | 24 | HELP_EVALUATE = """ 25 | Path to evaluation config to use. 26 | """ 27 | 28 | HELP_PREDICTION_PARAMS = """ 29 | Path to prediction params. 30 | """ 31 | 32 | HELP_TASKS = """ 33 | One of the tasks that classy supports [sequence, sentence-pair, token, qa, generation]. 34 | """ 35 | -------------------------------------------------------------------------------- /classy/utils/hydra.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import omegaconf 4 | 5 | 6 | def fix_paths(conf, check_fn: Callable[[str], bool], fix_fn: Callable[[str], str]): 7 | if type(conf) == list or type(conf) == omegaconf.listconfig.ListConfig: 8 | for i in range(len(conf)): 9 | conf[i] = fix_paths(conf[i], check_fn=check_fn, fix_fn=fix_fn) 10 | return conf 11 | elif type(conf) == dict or type(conf) == omegaconf.dictconfig.DictConfig: 12 | for k, v in conf.items(): 13 | conf[k] = fix_paths(v, check_fn=check_fn, fix_fn=fix_fn) 14 | return conf 15 | elif type(conf) == str: 16 | if "/" in conf and check_fn(conf): 17 | return fix_fn(conf) 18 | else: 19 | return conf 20 | elif type(conf) in [float, int, bool]: 21 | return conf 22 | elif conf is None: 23 | return conf 24 | else: 25 | raise ValueError(f"Unexpected type {type(conf)}: {conf}") 26 | -------------------------------------------------------------------------------- /classy/utils/lightning.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Callable, Dict 4 | 5 | import hydra 6 | import omegaconf 7 | from omegaconf import DictConfig, OmegaConf 8 | 9 | from classy.pl_modules.base import ClassyPLModule 10 | from classy.utils.hydra import fix_paths 11 | from classy.utils.vocabulary import Vocabulary 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def load_training_conf_from_checkpoint( 17 | checkpoint_path: str, post_trainer_init: bool = False 18 | ) -> DictConfig: 19 | # find hydra config path 20 | experiment_folder = Path(checkpoint_path).parent.parent 21 | # load hydra configs 22 | conf = OmegaConf.load(f"{experiment_folder}/.hydra/config.yaml") 23 | 24 | # fix paths 25 | def check_fn(path): 26 | # check whether the path exists relative to the experiment resources folder 27 | # if it does, fix it 28 | return (experiment_folder / "resources").joinpath(path).exists() 29 | 30 | fix_paths( 31 | conf, 32 | check_fn=check_fn, 33 | fix_fn=lambda path: str((experiment_folder / "resources").joinpath(path)), 34 | ) 35 | # return 36 | return conf 37 | 38 | 39 | def load_classy_module_from_checkpoint(checkpoint_path: str) -> ClassyPLModule: 40 | """ 41 | Load a PL module from a checkpoint path only. 
Infer the model to load from the dumped hydra conf 42 | 43 | Args: 44 | checkpoint_path (str): 45 | 46 | Returns: 47 | pl.LightningModule 48 | 49 | """ 50 | 51 | conf = load_training_conf_from_checkpoint(checkpoint_path) 52 | 53 | # check if the model requires a vocab 54 | train_dataset_class = conf["data"]["datamodule"]["dataset"]["_target_"] 55 | if not train_dataset_class.split(".")[-1][ 56 | 0 57 | ].isupper(): # if it is not upper then it is a class method 58 | train_dataset_class = ".".join(train_dataset_class.split(".")[:-1]) 59 | 60 | requires_vocab = hydra.utils.instantiate( 61 | {"_target_": f"{train_dataset_class}.requires_vocab"} 62 | ) 63 | 64 | # extract and build vocabulary 65 | vocabulary_path = Path(checkpoint_path).parent.parent / "vocabulary" 66 | 67 | assert ( 68 | not requires_vocab 69 | ) or vocabulary_path.exists(), f"No vocabulary found at path {vocabulary_path}" 70 | 71 | vocabulary = None 72 | if vocabulary_path.exists(): 73 | vocabulary = Vocabulary.from_folder(vocabulary_path) 74 | 75 | # prepare instantiate params 76 | instantiate_input = dict( 77 | checkpoint_path=checkpoint_path, _recursive_=False, **conf.model 78 | ) 79 | if vocabulary is not None: 80 | instantiate_input["vocabulary"] = vocabulary 81 | instantiate_input["_target_"] = f'{conf.model["_target_"]}.load_from_checkpoint' 82 | 83 | # instantiate and return 84 | return hydra.utils.instantiate(instantiate_input) 85 | 86 | 87 | def load_prediction_dataset_conf_from_checkpoint(checkpoint_path: str) -> DictConfig: 88 | """ 89 | Load a dataset conf from a checkpoint path only, inferring it from the dumped hydra conf 90 | 91 | Args: 92 | checkpoint_path (str): 93 | 94 | Returns: 95 | Dict 96 | 97 | """ 98 | conf = load_training_conf_from_checkpoint(checkpoint_path) 99 | return conf.prediction.dataset 100 | -------------------------------------------------------------------------------- /classy/utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | _loggers = {} 4 | 5 | 6 | def get_project_logger(module_name: str) -> logging.Logger: 7 | return _loggers.setdefault(module_name, logging.getLogger(module_name)) 8 | -------------------------------------------------------------------------------- /classy/utils/omegaconf.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import re 3 | 4 | import hydra 5 | from omegaconf import DictConfig, OmegaConf 6 | 7 | 8 | def is_interpolation(cfg, key): 9 | # OmegaConf.is_interpolation(base_cfg, key) requires key to be a direct children of base_cfg 10 | # this function patch this behavior so as to support any successor 11 | parent = cfg 12 | if "." 
in key: 13 | parent = OmegaConf.select(cfg, key[: key.rindex(".")]) 14 | key = key[key.rindex(".") + 1 :] 15 | return OmegaConf.is_interpolation(parent, key) 16 | 17 | 18 | def adapt_dataset_from(training_dataset: DictConfig, setting: str): 19 | # duplicate configuration 20 | training_dataset = copy.deepcopy(training_dataset) 21 | # resolve otherwise hydra.utils.instantiate will crash 22 | OmegaConf.resolve(training_dataset) 23 | # identify adaptation code 24 | train_dataset_class = training_dataset["_target_"] 25 | if not train_dataset_class.split(".")[-1][ 26 | 0 27 | ].isupper(): # if it is not upper then it is a class method 28 | train_dataset_class = ".".join(train_dataset_class.split(".")[:-1]) 29 | # invoke it and return 30 | return hydra.utils.instantiate( 31 | {"_target_": f"{train_dataset_class}.adapt_dataset_from"}, 32 | training_dataset=training_dataset, 33 | setting=setting, 34 | _recursive_=False, 35 | ) 36 | 37 | 38 | OmegaConf.register_new_resolver("adapt_dataset_from", adapt_dataset_from) 39 | 40 | 41 | def resolve_hf_generation_base_dataset_on_transformer_model( 42 | transformer_model: str, 43 | ) -> str: 44 | if re.fullmatch("facebook/bart-(base|large)", transformer_model): 45 | return "classy.data.dataset.hf.generation.BartHFGenerationDataset.from_file" 46 | elif re.fullmatch("facebook/mbart-large-(cc25|50)", transformer_model): 47 | return "classy.data.dataset.hf.generation.MBartHFGenerationDataset.from_file" 48 | elif transformer_model.startswith("gpt2"): 49 | return "classy.data.dataset.hf.generation.GPT2HFGenerationDataset.from_file" 50 | elif ( 51 | transformer_model.startswith("t5-") 52 | or transformer_model.startswith("google/t5-") 53 | or transformer_model.startswith("google/mt5-") 54 | ): 55 | return "classy.data.dataset.hf.generation.T5HFGenerationDataset.from_file" 56 | else: 57 | raise ValueError( 58 | f"{transformer_model} not currently supported in automatic resolution. But you can still write your own dataset (write _target_ and its parameters)." 59 | ) 60 | 61 | 62 | OmegaConf.register_new_resolver( 63 | "resolve_hf_generation_base_dataset_on_transformer_model", 64 | resolve_hf_generation_base_dataset_on_transformer_model, 65 | ) 66 | 67 | 68 | def resolve_hf_generation_module_on_transformer_model( 69 | transformer_model: str, 70 | ) -> str: 71 | if re.fullmatch("facebook/bart-(base|large)", transformer_model): 72 | return "classy.pl_modules.hf.generation.BartGenerativeModule" 73 | elif re.fullmatch("facebook/mbart-large-(cc25|50)", transformer_model): 74 | return "classy.pl_modules.hf.generation.MBartGenerativeModule" 75 | elif transformer_model.startswith("gpt2"): 76 | return "classy.pl_modules.hf.generation.GPT2GenerativeModule" 77 | elif ( 78 | transformer_model.startswith("t5-") 79 | or transformer_model.startswith("google/t5-") 80 | or transformer_model.startswith("google/mt5-") 81 | ): 82 | return "classy.pl_modules.hf.generation.T5GenerativeModule" 83 | else: 84 | raise ValueError( 85 | f"{transformer_model} not currently supported in automatic resolution. But you can still write your own module (write _target_ and its parameters)." 
86 | ) 87 | 88 | 89 | OmegaConf.register_new_resolver( 90 | "resolve_hf_generation_module_on_transformer_model", 91 | resolve_hf_generation_module_on_transformer_model, 92 | ) 93 | -------------------------------------------------------------------------------- /classy/utils/optional_deps.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import Optional 3 | 4 | 5 | def requires(library, extra_required: Optional[str] = None): 6 | def closure(decorated_arg): 7 | def inner(*args, **kwargs): 8 | try: 9 | importlib.import_module(library) 10 | except ModuleNotFoundError as e: 11 | error_message = f"ModuleNotFoundError: {library} not found." 12 | if extra_required is not None: 13 | error_message += f" It seems you haven't installed classy[{extra_required}], try doing `pip install classy[{extra_required}]`" 14 | raise ModuleNotFoundError(error_message) 15 | return decorated_arg(*args, **kwargs) 16 | 17 | return inner 18 | 19 | return closure 20 | -------------------------------------------------------------------------------- /classy/utils/plotly.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from classy.utils.optional_deps import requires 4 | 5 | try: 6 | import plotly.graph_objects as go 7 | except ImportError: 8 | go = None 9 | 10 | 11 | @requires("plotly.graph_objects") 12 | def boxplot(y: np.ndarray, x_name: str, y_name: str, color: str): 13 | fig = go.Figure() 14 | fig.add_trace( 15 | go.Box( 16 | y=y, 17 | name=x_name, 18 | marker_color=color, 19 | boxmean="sd", 20 | ), 21 | ) 22 | fig.update_layout( 23 | yaxis_title=y_name, 24 | margin=dict(l=20, r=20, t=20, b=20), 25 | ) 26 | return fig 27 | -------------------------------------------------------------------------------- /classy/utils/streamlit.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Callable, List 3 | 4 | import numpy as np 5 | 6 | 7 | def get_random_color_generator(colors: List[str]) -> Callable[[], str]: 8 | 9 | random.shuffle(colors) 10 | 11 | colors = iter(colors) 12 | 13 | def f(): 14 | try: 15 | return next(colors) 16 | except StopIteration: 17 | return "#%06x" % random.randint(0x000000, 0xFFFFFF) 18 | 19 | return f 20 | 21 | 22 | def get_md_200_random_color_generator() -> Callable[[], str]: 23 | 24 | # colors taken from https://gist.githubusercontent.com/daniellevass/b0b8cfa773488e138037/raw/d2182c212a4132c0f3bb093fd0010395f927a219/android_material_design_colours.xml 25 | # md_.*_200 26 | colors_md_200 = [ 27 | "#EF9A9A", 28 | "#F48FB1", 29 | "#CE93D8", 30 | "#B39DDB", 31 | "#9FA8DA", 32 | "#90CAF9", 33 | "#81D4fA", 34 | "#80DEEA", 35 | "#80CBC4", 36 | "#A5D6A7", 37 | "#C5E1A5", 38 | "#E6EE9C", 39 | "#FFF590", 40 | "#FFE082", 41 | "#FFCC80", 42 | "#FFAB91", 43 | "#BCAAA4", 44 | "#EEEEEE", 45 | "#B0BBC5", 46 | ] 47 | return get_random_color_generator(colors_md_200) 48 | 49 | 50 | def get_md_400_random_color_generator() -> Callable[[], str]: 51 | 52 | # colors taken from https://gist.githubusercontent.com/daniellevass/b0b8cfa773488e138037/raw/d2182c212a4132c0f3bb093fd0010395f927a219/android_material_design_colours.xml 53 | # md_.*_400 54 | colors_md_400 = [ 55 | "#EF5350", 56 | "#EC407A", 57 | "#AB47BC", 58 | "#7E57C2", 59 | "#5C6BC0", 60 | "#42A5F5", 61 | "#29B6FC", 62 | "#26C6DA", 63 | "#26A69A", 64 | "#66BB6A", 65 | "#9CCC65", 66 | "#D4E157", 67 | "#FFEE58", 68 | "#FFCA28", 69 | "#FFA726", 70 | "#FF7043", 
71 | "#8D6E63", 72 | "#BDBDBD", 73 | "#78909C", 74 | ] 75 | return get_random_color_generator(colors_md_400) 76 | -------------------------------------------------------------------------------- /classy/utils/train_coordinates.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from typing import Dict, Optional, Tuple, Union 3 | 4 | import hydra 5 | from omegaconf import DictConfig, ListConfig 6 | 7 | from classy.data.data_drivers import DataDriver, get_data_driver 8 | from classy.utils.log import get_project_logger 9 | 10 | logger = get_project_logger(__name__) 11 | 12 | 13 | def load_bundle( 14 | bundle_conf: Optional[Union[str, Dict[str, str]]], 15 | task: str, 16 | compute_main_extension: bool = False, 17 | ) -> Optional[Union[Dict[str, DataDriver], Tuple[Dict[str, DataDriver], str]]]: 18 | if bundle_conf is None: 19 | return None 20 | 21 | main_extension = None 22 | if type(bundle_conf) == str: 23 | file_extension = bundle_conf.split(".")[-1] 24 | bundle_store = { 25 | hydra.utils.to_absolute_path(bundle_conf): get_data_driver( 26 | task, file_extension 27 | ) 28 | } 29 | if compute_main_extension: 30 | main_extension = file_extension 31 | elif type(bundle_conf) == ListConfig: 32 | file_extensions = [path.split(".")[-1] for path in bundle_conf] 33 | bundle_store = { 34 | hydra.utils.to_absolute_path(path): get_data_driver(task, file_extension) 35 | for path, file_extension in zip(bundle_conf, file_extensions) 36 | } 37 | if compute_main_extension: 38 | main_extension = collections.Counter(file_extensions).most_common(1)[0][0] 39 | elif type(bundle_conf) == DictConfig: 40 | bundle_store = { 41 | hydra.utils.to_absolute_path(path): get_data_driver(task, file_extension) 42 | for path, file_extension in bundle_conf.items() 43 | } 44 | if compute_main_extension: 45 | main_extension = collections.Counter(bundle_conf.values()).most_common(1)[ 46 | 0 47 | ][0] 48 | else: 49 | logger.error( 50 | "The value of the dataset in the coordinates file " 51 | "must be either a string indicating the dataset, a " 52 | "list of string or a dict path -> file_extension" 53 | ) 54 | raise NotImplementedError 55 | 56 | if main_extension is not None: 57 | return bundle_store, main_extension 58 | else: 59 | return bundle_store 60 | -------------------------------------------------------------------------------- /classy/utils/vocabulary.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import logging 3 | from pathlib import Path 4 | from typing import Dict, Iterable 5 | 6 | FIELDS_VOCABULARY_PATH = "fields_vocabulary_path.tsv" 7 | LABELS_VOCABULARY_PATH = "labels_vocabulary_path.tsv" 8 | 9 | 10 | class Vocabulary: 11 | 12 | PAD = "" 13 | UNK = "" 14 | 15 | @classmethod 16 | def from_samples(cls, samples: Iterable[Dict[str, str]], add_pad_unk: bool = True): 17 | if add_pad_unk: 18 | backend_vocab = collections.defaultdict( 19 | lambda: {Vocabulary.PAD: 0, Vocabulary.UNK: 1} 20 | ) 21 | else: 22 | backend_vocab = collections.defaultdict(dict) 23 | for sample in samples: 24 | for k, v in sample.items(): 25 | elem2idx = backend_vocab[k] 26 | if v not in elem2idx: 27 | elem2idx[v] = len(elem2idx) 28 | return cls(backend_vocab) 29 | 30 | @classmethod 31 | def from_folder(cls, path: str): 32 | backend_vocab = {} 33 | folder = Path(path) 34 | for f in folder.iterdir(): 35 | k = f.name[: f.name.rindex(".txt")] 36 | elem2idx = {} 37 | with open(f) as _f: 38 | for line in _f: 39 | _k, _v = 
line.strip().split("\t") 40 | elem2idx[_k] = int(_v) 41 | backend_vocab[k] = elem2idx 42 | return cls(backend_vocab) 43 | 44 | def __init__(self, backend_vocab: Dict[str, Dict[str, int]]): 45 | self.backend_vocab = backend_vocab 46 | self.reverse_backend_vocab = { 47 | k: {_v: _k for _k, _v in v.items()} for k, v in backend_vocab.items() 48 | } 49 | 50 | def get_size(self, k: str) -> int: 51 | return len(self.backend_vocab[k]) 52 | 53 | def get_idx(self, k: str, elem: str) -> int: 54 | idx = self.backend_vocab[k].get(elem) 55 | if idx is None: 56 | if Vocabulary.UNK in self.backend_vocab[k]: 57 | idx = self.backend_vocab[k][Vocabulary.UNK] 58 | else: 59 | logging.error( 60 | f"Unknown element found {elem} but no {Vocabulary.UNK} in vocabulary" 61 | ) 62 | raise KeyError 63 | return idx 64 | 65 | def get_elem(self, k: str, idx: int) -> str: 66 | return self.reverse_backend_vocab[k][idx] 67 | 68 | def save(self, path: str) -> None: 69 | folder = Path(path) 70 | folder.mkdir() 71 | for k, v in self.backend_vocab.items(): 72 | with open(folder / f"{k}.txt", "w") as f: 73 | for _k, _v in v.items(): 74 | f.write(f"{_k}\t{_v}\n") 75 | -------------------------------------------------------------------------------- /classy/version.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | _MAJOR = "0" 4 | _MINOR = "3" 5 | # On main and in a nightly release the patch should be one ahead of the last 6 | # released build. 7 | _PATCH = "2" 8 | # This is mainly for nightly builds which have the suffix ".dev$DATE". See 9 | # https://semver.org/#is-v123-a-semantic-version for the semantics. 10 | _SUFFIX = os.environ.get("CLASSY_VERSION_SUFFIX", "") 11 | 12 | VERSION_SHORT = "{0}.{1}".format(_MAJOR, _MINOR) 13 | VERSION = "{0}.{1}.{2}{3}".format(_MAJOR, _MINOR, _PATCH, _SUFFIX) 14 | -------------------------------------------------------------------------------- /configurations/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/configurations/.placeholder -------------------------------------------------------------------------------- /configurations/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is only here because hydra needs it when packaging configs inside applications :) 2 | -------------------------------------------------------------------------------- /configurations/callbacks/empty.yaml: -------------------------------------------------------------------------------- 1 | [] 2 | -------------------------------------------------------------------------------- /configurations/callbacks/evaluation.yaml: -------------------------------------------------------------------------------- 1 | - _target_: "classy.pl_callbacks.prediction.PredictionPLCallback" 2 | path: null # leave it to null to set it to validation path 3 | prediction_dataset_conf: ${prediction.dataset} 4 | on_result: 5 | file_dumper: 6 | _target_: "classy.pl_callbacks.prediction.FileDumperPredictionCallback" 7 | evaluation: 8 | _target_: "classy.pl_callbacks.prediction.EvaluationPredictionCallback" 9 | evaluation: ${evaluation} 10 | settings: 11 | - name: "val" 12 | path: null # leave it to null to set it to PredictionPLCallback.path 13 | token_batch_size: 800 14 | prediction_param_conf_path: null 15 | limit: 1000 16 | on_result: 17 | - "file_dumper" 18 | - "evaluation" 19 | 
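# (editor's note) each entry under `settings` triggers a prediction pass during
# validation: `path: null` falls back to the validation set, `limit` caps the
# number of predicted samples, and `on_result` selects which of the callbacks
# declared above consume the predictions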
-------------------------------------------------------------------------------- /configurations/callbacks/file-dumper.yaml: -------------------------------------------------------------------------------- 1 | - _target_: "classy.pl_callbacks.prediction.PredictionPLCallback" 2 | validation_path: null 3 | prediction_confs: 4 | - name: "validation" 5 | path: null # if you leave it to None, it will be set to the validation path 6 | token_batch_size: 400 7 | prediction_param_conf_path: "configurations/prediction-params/generation-beam.yaml" 8 | limit: 1000 9 | enabled_prediction_callbacks: 10 | - "file_dumper" 11 | prediction_callbacks: 12 | file_dumper: 13 | _target_: "classy.pl_callbacks.prediction.FileDumperPredictionCallback" 14 | prediction_dataset_conf: ${prediction.dataset} 15 | -------------------------------------------------------------------------------- /configurations/data/generation.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: 'classy.data.data_modules.ClassyDataModule' 3 | task: ${task} 4 | dataset_path: null # set via kwargs 5 | dataset: 6 | _target_: "${resolve_hf_generation_base_dataset_on_transformer_model:${transformer_model}}" 7 | transformer_model: ${transformer_model} 8 | truncation: true 9 | additional_special_tokens: "${oc.select:'model.additional_special_tokens',${oc.decode:'[]'}}" 10 | teacher_forcing: True 11 | min_length: 5 12 | max_length: 512 13 | tokens_per_batch: 1000 14 | max_batch_size: 10 15 | section_size: 10000 16 | prebatch: True 17 | materialize: False 18 | for_inference: False 19 | validation_dataset: "${adapt_dataset_from:${data.datamodule.dataset},validation}" 20 | validation_split_size: 0.1 21 | test_split_size: 0.1 22 | max_nontrain_split_size: 500 23 | shuffle_dataset: True 24 | -------------------------------------------------------------------------------- /configurations/data/qa.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: 'classy.data.data_modules.ClassyDataModule' 3 | task: ${task} 4 | dataset_path: null # set via kwargs 5 | dataset: 6 | _target_: 'classy.data.dataset.hf.classification.HFQADataset.from_file' 7 | transformer_model: ${transformer_model} 8 | truncation: true 9 | additional_special_tokens: "${oc.select:'model.additional_special_tokens',${oc.decode:'[]'}}" 10 | min_length: 5 11 | max_length: 512 12 | tokens_per_batch: 2000 13 | max_batch_size: 10 14 | section_size: 10000 15 | prebatch: True 16 | materialize: False 17 | for_inference: False 18 | validation_dataset: "${adapt_dataset_from:${data.datamodule.dataset},validation}" 19 | validation_split_size: 0.1 20 | test_split_size: 0.1 21 | max_nontrain_split_size: 10000 22 | shuffle_dataset: True 23 | -------------------------------------------------------------------------------- /configurations/data/sentence-pair.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: 'classy.data.data_modules.ClassyDataModule' 3 | task: ${task} 4 | dataset_path: null # set via kwargs 5 | dataset: 6 | _target_: 'classy.data.dataset.hf.classification.HFSentencePairDataset.from_file' 7 | transformer_model: ${transformer_model} 8 | truncation: true 9 | additional_special_tokens: "${oc.select:'model.additional_special_tokens',${oc.decode:'[]'}}" 10 | min_length: 5 11 | max_length: 512 12 | tokens_per_batch: 800 13 | max_batch_size: 10 14 | section_size: 10000 15 | prebatch: True 16 | materialize: False 
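    # (editor's note) batching above is token-based: a batch holds at most
    # `tokens_per_batch` tokens and never more than `max_batch_size` samples;
    # `prebatch` pre-groups samples of similar length (per `section_size`
    # window) to reduce padding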
17 | for_inference: False 18 | validation_dataset: "${adapt_dataset_from:${data.datamodule.dataset},validation}" 19 | validation_split_size: 0.1 20 | test_split_size: 0.1 21 | max_nontrain_split_size: 10000 22 | shuffle_dataset: True 23 | -------------------------------------------------------------------------------- /configurations/data/sequence.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: 'classy.data.data_modules.ClassyDataModule' 3 | task: ${task} 4 | dataset_path: null # set via kwargs 5 | dataset: 6 | _target_: 'classy.data.dataset.hf.classification.HFSequenceDataset.from_file' 7 | transformer_model: ${transformer_model} 8 | truncation: true 9 | additional_special_tokens: "${oc.select:'model.additional_special_tokens',${oc.decode:'[]'}}" 10 | min_length: 5 11 | max_length: 512 12 | tokens_per_batch: 2000 13 | max_batch_size: 10 14 | section_size: 10000 15 | prebatch: True 16 | materialize: False 17 | for_inference: False 18 | validation_dataset: "${adapt_dataset_from:${data.datamodule.dataset},validation}" 19 | validation_split_size: 0.1 20 | test_split_size: 0.1 21 | max_nontrain_split_size: 10000 22 | shuffle_dataset: True 23 | -------------------------------------------------------------------------------- /configurations/data/token.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: 'classy.data.data_modules.ClassyDataModule' 3 | task: ${task} 4 | dataset_path: null # set via kwargs 5 | dataset: 6 | _target_: 'classy.data.dataset.hf.classification.HFTokenDataset.from_file' 7 | transformer_model: ${transformer_model} 8 | truncation: true 9 | additional_special_tokens: "${oc.select:'model.additional_special_tokens',${oc.decode:'[]'}}" 10 | min_length: 5 11 | max_length: 512 12 | tokens_per_batch: 1000 13 | max_batch_size: 16 14 | section_size: 10000 15 | prebatch: True 16 | materialize: False 17 | for_inference: False 18 | validation_dataset: "${adapt_dataset_from:${data.datamodule.dataset},validation}" 19 | validation_split_size: 0.1 20 | test_split_size: 0.1 21 | max_nontrain_split_size: 10000 22 | shuffle_dataset: True 23 | -------------------------------------------------------------------------------- /configurations/evaluation/generation.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.evaluation.simple.GenerationSimpleEvaluation' 2 | -------------------------------------------------------------------------------- /configurations/evaluation/qa.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.evaluation.simple.QASimpleEvaluation' 2 | -------------------------------------------------------------------------------- /configurations/evaluation/rouge.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.evaluation.generation.RougeEvaluation' 2 | -------------------------------------------------------------------------------- /configurations/evaluation/sacrebleu.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.evaluation.generation.SacreBleuEvaluation' 2 | -------------------------------------------------------------------------------- /configurations/evaluation/sentence-pair.yaml: -------------------------------------------------------------------------------- 1 | _target_: 
'classy.evaluation.simple.SentencePairSimpleEvaluation' 2 | -------------------------------------------------------------------------------- /configurations/evaluation/sequence.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.evaluation.simple.SequenceSimpleEvaluation' 2 | -------------------------------------------------------------------------------- /configurations/evaluation/span.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.evaluation.span.SeqEvalSpanEvaluation' 2 | -------------------------------------------------------------------------------- /configurations/evaluation/squad-v1.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.evaluation.squad.SQuADV1Evaluation' 2 | -------------------------------------------------------------------------------- /configurations/evaluation/token.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.evaluation.simple.TokenSimpleEvaluation' 2 | -------------------------------------------------------------------------------- /configurations/generation.yaml: -------------------------------------------------------------------------------- 1 | task: generation 2 | project_name: classy 3 | exp_name: null 4 | exp_folder: ./experiments/${exp_name} 5 | 6 | transformer_model: "facebook/bart-base" 7 | 8 | callbacks_monitor: 'val_loss' 9 | callbacks_mode: 'min' 10 | 11 | hydra: 12 | # customize working dir 13 | run: 14 | dir: ./experiments/${exp_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 15 | job_logging: 16 | formatters: 17 | simple: 18 | format: '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 19 | root: 20 | level: WARN 21 | 22 | # defaults 23 | defaults: 24 | - callbacks: empty 25 | - data: generation 26 | - evaluation: generation 27 | - model: generation 28 | - prediction: default 29 | - training: default 30 | - logging: default 31 | - _self_ 32 | -------------------------------------------------------------------------------- /configurations/logging/default.yaml: -------------------------------------------------------------------------------- 1 | # WANDB 2 | wandb: 3 | use_wandb: False 4 | project_name: null 5 | experiment_name: null 6 | anonymous: null 7 | run_id: null 8 | -------------------------------------------------------------------------------- /configurations/model/generation.yaml: -------------------------------------------------------------------------------- 1 | _target_: "${resolve_hf_generation_module_on_transformer_model:${transformer_model}}" 2 | transformer_model: ${transformer_model} 3 | additional_special_tokens: [] 4 | decoding_skip_special_tokens: True 5 | decoding_clean_up_tokenization_spaces: False 6 | optim_conf: 7 | _target_: classy.optim.factories.RAdamFactory 8 | lr: 1e-5 9 | weight_decay: 0.01 10 | no_decay_params: 11 | - bias 12 | - LayerNorm.weight 13 | -------------------------------------------------------------------------------- /configurations/model/qa.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.pl_modules.hf.classification.HFQAPLModule' 2 | transformer_model: ${transformer_model} 3 | additional_special_tokens: [] 4 | optim_conf: 5 | _target_: classy.optim.factories.RAdamFactory 6 | lr: 1e-5 7 | weight_decay: 0.01 8 | no_decay_params: 9 | - bias 10 | - LayerNorm.weight 11 | 
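# (editor's note) parameters whose names match an entry in `no_decay_params`
# (biases and LayerNorm weights) are exempted from weight decay, the usual
# practice when fine-tuning transformers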
-------------------------------------------------------------------------------- /configurations/model/sentence-pair.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.pl_modules.hf.classification.HFSentencePairPLModule' 2 | transformer_model: ${transformer_model} 3 | additional_special_tokens: [] 4 | optim_conf: 5 | _target_: classy.optim.factories.RAdamFactory 6 | lr: 1e-5 7 | weight_decay: 0.01 8 | no_decay_params: 9 | - bias 10 | - LayerNorm.weight 11 | -------------------------------------------------------------------------------- /configurations/model/sequence.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.pl_modules.hf.classification.HFSequencePLModule' 2 | transformer_model: ${transformer_model} 3 | additional_special_tokens: [] 4 | optim_conf: 5 | _target_: classy.optim.factories.RAdamFactory 6 | lr: 1e-5 7 | weight_decay: 0.01 8 | no_decay_params: 9 | - bias 10 | - LayerNorm.weight 11 | -------------------------------------------------------------------------------- /configurations/model/token.yaml: -------------------------------------------------------------------------------- 1 | _target_: 'classy.pl_modules.hf.classification.HFTokensPLModule' 2 | transformer_model: ${transformer_model} 3 | additional_special_tokens: [] 4 | use_last_n_layers: 1 5 | fine_tune: True 6 | optim_conf: 7 | _target_: classy.optim.factories.RAdamFactory 8 | lr: 1e-5 9 | weight_decay: 0.01 10 | no_decay_params: 11 | - bias 12 | - LayerNorm.weight 13 | -------------------------------------------------------------------------------- /configurations/prediction-params/generation-beam.yaml: -------------------------------------------------------------------------------- 1 | num_return_sequences: 1 2 | num_beams: 5 3 | min_length: 5 4 | max_length: 100 5 | length_penalty: 1.0 6 | repetition_penalty: 1.0 7 | early_stopping: false 8 | -------------------------------------------------------------------------------- /configurations/prediction-params/generation-sample.yaml: -------------------------------------------------------------------------------- 1 | num_return_sequences: 1 2 | num_beams: 1 3 | do_sample: True 4 | top_p: 0.8 5 | temperature: 1.0 6 | min_length: 25 7 | max_length: 200 8 | length_penalty: 1.0 9 | repetition_penalty: 1.0 10 | -------------------------------------------------------------------------------- /configurations/prediction/default.yaml: -------------------------------------------------------------------------------- 1 | dataset: "${adapt_dataset_from:${data.datamodule.dataset},prediction}" 2 | -------------------------------------------------------------------------------- /configurations/profiles/bart-base.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | - generation 7 | 8 | # global params 9 | transformer_model: facebook/bart-base 10 | 11 | # trainer 12 | training: 13 | pl_trainer: 14 | accumulate_grad_batches: 1 15 | val_check_interval: 1.0 16 | max_steps: 100_000 17 | 18 | # MODEL PARAMS 19 | model: 20 | optim_conf: 21 | _target_: classy.optim.factories.RAdamFactory 22 | lr: 1e-5 23 | weight_decay: 0.01 24 | no_decay_params: 25 | - bias 26 | - LayerNorm.weight 27 | -------------------------------------------------------------------------------- /configurations/profiles/bart-large.yaml: 
-------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | - generation 7 | 8 | # global params 9 | transformer_model: facebook/bart-large 10 | 11 | # trainer 12 | training: 13 | pl_trainer: 14 | accumulate_grad_batches: 1 15 | val_check_interval: 1.0 16 | max_steps: 100_000 17 | 18 | # MODEL PARAMS 19 | model: 20 | optim_conf: 21 | _target_: classy.optim.factories.RAdamFactory 22 | lr: 1e-5 23 | weight_decay: 0.01 24 | no_decay_params: 25 | - bias 26 | - LayerNorm.weight 27 | -------------------------------------------------------------------------------- /configurations/profiles/bert-base.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: bert-base-cased 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | max_steps: 100_000 16 | 17 | # MODEL PARAMS 18 | model: 19 | optim_conf: 20 | _target_: classy.optim.factories.AdamWWithWarmupFactory 21 | lr: 3e-5 22 | warmup_steps: 5000 23 | total_steps: ${training.pl_trainer.max_steps} 24 | weight_decay: 0.01 25 | no_decay_params: 26 | - bias 27 | - LayerNorm.weight 28 | -------------------------------------------------------------------------------- /configurations/profiles/bert-large.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: bert-large-cased 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | max_steps: 100_000 16 | 17 | # MODEL PARAMS 18 | model: 19 | optim_conf: 20 | _target_: classy.optim.factories.AdamWWithWarmupFactory 21 | lr: 3e-5 22 | warmup_steps: 5000 23 | total_steps: ${training.pl_trainer.max_steps} 24 | weight_decay: 0.01 25 | no_decay_params: 26 | - bias 27 | - LayerNorm.weight 28 | -------------------------------------------------------------------------------- /configurations/profiles/deberta-base.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: microsoft/deberta-base 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | gradient_clip_val: 1.0 15 | val_check_interval: 1.0 16 | max_steps: 100_000 17 | 18 | # MODEL PARAMS 19 | model: 20 | optim_conf: 21 | _target_: classy.optim.factories.RAdamFactory 22 | lr: 3e-5 23 | weight_decay: 0.01 24 | no_decay_params: 25 | - bias 26 | - LayerNorm.weight 27 | -------------------------------------------------------------------------------- /configurations/profiles/deberta-large.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: microsoft/deberta-large 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | gradient_clip_val: 1.0 15 | val_check_interval: 1.0 16 | max_steps: 100_000 17 | 18 | # MODEL PARAMS 19 | model: 20 | optim_conf: 21 | _target_: classy.optim.factories.RAdamFactory 22 | lr: 5e-6 23 | weight_decay: 0.01 24 | no_decay_params: 25 | - bias 26 | - 
LayerNorm.weight 27 | -------------------------------------------------------------------------------- /configurations/profiles/distilbert.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: distilbert-base-cased 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | max_steps: 100_000 16 | 17 | # MODEL PARAMS 18 | model: 19 | optim_conf: 20 | _target_: classy.optim.factories.AdamWWithWarmupFactory 21 | lr: 3e-5 22 | warmup_steps: 5000 23 | total_steps: ${training.pl_trainer.max_steps} 24 | weight_decay: 0.01 25 | no_decay_params: 26 | - bias 27 | - LayerNorm.weight 28 | -------------------------------------------------------------------------------- /configurations/profiles/distilroberta.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: distilroberta-base 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | max_steps: 100_000 16 | 17 | # MODEL PARAMS 18 | model: 19 | optim_conf: 20 | _target_: classy.optim.factories.AdamWWithWarmupFactory 21 | lr: 2e-5 22 | warmup_steps: 5000 23 | total_steps: ${training.pl_trainer.max_steps} 24 | weight_decay: 0.01 25 | no_decay_params: 26 | - bias 27 | - LayerNorm.weight 28 | -------------------------------------------------------------------------------- /configurations/profiles/gpt2-large.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - generation 3 | 4 | # global params 5 | transformer_model: gpt2-large 6 | 7 | # trainer 8 | training: 9 | pl_trainer: 10 | accumulate_grad_batches: 1 11 | val_check_interval: 1.0 12 | max_steps: 100_000 13 | 14 | # MODEL PARAMS 15 | model: 16 | optim_conf: 17 | _target_: "classy.optim.factories.TorchFactory" 18 | optimizer: 19 | _target_: torch.optim.Adam 20 | lr: 1e-5 21 | -------------------------------------------------------------------------------- /configurations/profiles/gpt2-medium.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - generation 3 | 4 | # global params 5 | transformer_model: gpt2-medium 6 | 7 | # trainer 8 | training: 9 | pl_trainer: 10 | accumulate_grad_batches: 1 11 | val_check_interval: 1.0 12 | max_steps: 100_000 13 | 14 | # MODEL PARAMS 15 | model: 16 | optim_conf: 17 | _target_: "classy.optim.factories.TorchFactory" 18 | optimizer: 19 | _target_: torch.optim.Adam 20 | lr: 1e-5 21 | -------------------------------------------------------------------------------- /configurations/profiles/gpt2.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - generation 3 | 4 | # global params 5 | transformer_model: gpt2 6 | 7 | # trainer 8 | training: 9 | pl_trainer: 10 | accumulate_grad_batches: 1 11 | val_check_interval: 1.0 12 | max_steps: 100_000 13 | 14 | # MODEL PARAMS 15 | model: 16 | optim_conf: 17 | _target_: "classy.optim.factories.TorchFactory" 18 | optimizer: 19 | _target_: torch.optim.Adam 20 | lr: 1e-5 21 | -------------------------------------------------------------------------------- /configurations/profiles/mbart.yaml:
-------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | - generation 7 | 8 | # global params 9 | transformer_model: facebook/mbart-large-cc25 10 | 11 | # trainer 12 | training: 13 | pl_trainer: 14 | accumulate_grad_batches: 1 15 | val_check_interval: 1.0 16 | max_steps: 100_000 17 | 18 | # MODEL PARAMS 19 | model: 20 | optim_conf: 21 | _target_: classy.optim.factories.AdafactorWithWarmupFactory 22 | lr: 1e-5 23 | warmup_steps: 5000 24 | total_steps: ${training.pl_trainer.max_steps} 25 | weight_decay: 0.01 26 | no_decay_params: 27 | - bias 28 | - LayerNorm.weight 29 | -------------------------------------------------------------------------------- /configurations/profiles/multilingual-bert.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: bert-base-multilingual-cased 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | max_steps: 100_000 16 | 17 | # MODEL PARAMS 18 | model: 19 | optim_conf: 20 | _target_: classy.optim.factories.AdamWWithWarmupFactory 21 | lr: 3e-5 22 | warmup_steps: 5000 23 | total_steps: ${training.pl_trainer.max_steps} 24 | weight_decay: 0.01 25 | no_decay_params: 26 | - bias 27 | - LayerNorm.weight 28 | -------------------------------------------------------------------------------- /configurations/profiles/roberta-base.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: roberta-base 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | max_steps: 100_000 16 | 17 | # MODEL PARAMS 18 | model: 19 | optim_conf: 20 | _target_: classy.optim.factories.AdamWWithWarmupFactory 21 | lr: 2e-5 22 | warmup_steps: 5000 23 | total_steps: ${training.pl_trainer.max_steps} 24 | weight_decay: 0.01 25 | no_decay_params: 26 | - bias 27 | - LayerNorm.weight 28 | -------------------------------------------------------------------------------- /configurations/profiles/roberta-large.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: roberta-large 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | max_steps: 100_000 16 | 17 | # MODEL PARAMS 18 | model: 19 | optim_conf: 20 | _target_: classy.optim.factories.AdamWWithWarmupFactory 21 | lr: 2e-5 22 | warmup_steps: 5000 23 | total_steps: ${training.pl_trainer.max_steps} 24 | weight_decay: 0.01 25 | no_decay_params: 26 | - bias 27 | - LayerNorm.weight 28 | -------------------------------------------------------------------------------- /configurations/profiles/squeezebert.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: squeezebert/squeezebert-uncased 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | max_steps: 100_000 16 | 17 | # MODEL PARAMS 18 | model: 19 | optim_conf:
20 | _target_: classy.optim.factories.AdamWWithWarmupFactory 21 | lr: 3e-5 22 | warmup_steps: 5000 23 | total_steps: ${training.pl_trainer.max_steps} 24 | weight_decay: 0.01 25 | no_decay_params: 26 | - bias 27 | - LayerNorm.weight 28 | -------------------------------------------------------------------------------- /configurations/profiles/xlm-roberta-base.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: xlm-roberta-base 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | max_steps: 100_000 16 | 17 | # MODEL PARAMS 18 | model: 19 | optim_conf: 20 | _target_: classy.optim.factories.AdamWWithWarmupFactory 21 | lr: 2e-5 22 | warmup_steps: 5000 23 | total_steps: ${training.pl_trainer.max_steps} 24 | weight_decay: 0.01 25 | no_decay_params: 26 | - bias 27 | - LayerNorm.weight 28 | -------------------------------------------------------------------------------- /configurations/profiles/xlm-roberta-large.yaml: -------------------------------------------------------------------------------- 1 | supported_tasks: 2 | - qa 3 | - sentence-pair 4 | - sequence 5 | - token 6 | 7 | # global params 8 | transformer_model: xlm-roberta-large 9 | 10 | # trainer 11 | training: 12 | pl_trainer: 13 | accumulate_grad_batches: 1 14 | val_check_interval: 1.0 15 | max_steps: 100_000 16 | 17 | # MODEL PARAMS 18 | model: 19 | optim_conf: 20 | _target_: classy.optim.factories.AdamWWithWarmupFactory 21 | lr: 2e-5 22 | warmup_steps: 5000 23 | total_steps: ${training.pl_trainer.max_steps} 24 | weight_decay: 0.01 25 | no_decay_params: 26 | - bias 27 | - LayerNorm.weight 28 | -------------------------------------------------------------------------------- /configurations/qa.yaml: -------------------------------------------------------------------------------- 1 | task: qa 2 | project_name: classy 3 | exp_name: null 4 | exp_folder: ./experiments/${exp_name} 5 | 6 | transformer_model: "bert-base-cased" 7 | 8 | callbacks_monitor: 'val_accuracy' 9 | callbacks_mode: 'max' 10 | 11 | hydra: 12 | # customize working dir 13 | run: 14 | dir: ./experiments/${exp_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 15 | job_logging: 16 | formatters: 17 | simple: 18 | format: '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 19 | root: 20 | level: WARN 21 | 22 | # defaults 23 | defaults: 24 | - callbacks: empty 25 | - data: qa 26 | - evaluation: qa 27 | - model: qa 28 | - prediction: default 29 | - training: default 30 | - logging: default 31 | - _self_ 32 | -------------------------------------------------------------------------------- /configurations/sentence-pair.yaml: -------------------------------------------------------------------------------- 1 | task: sentence-pair 2 | project_name: classy 3 | exp_name: null 4 | exp_folder: ./experiments/${exp_name} 5 | 6 | transformer_model: "bert-large-cased" 7 | 8 | callbacks_monitor: 'val_micro-f1-score' 9 | callbacks_mode: 'max' 10 | 11 | hydra: 12 | # customize working dir 13 | run: 14 | dir: ./experiments/${exp_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 15 | job_logging: 16 | formatters: 17 | simple: 18 | format: '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 19 | root: 20 | level: WARN 21 | 22 | # defaults 23 | defaults: 24 | - callbacks: empty 25 | - data: sentence-pair 26 | - evaluation: sentence-pair 27 | - model: sentence-pair 28 | - prediction: default 29 | - 
training: default 30 | - logging: default 31 | - _self_ 32 | -------------------------------------------------------------------------------- /configurations/sequence.yaml: -------------------------------------------------------------------------------- 1 | task: sequence 2 | project_name: classy 3 | exp_name: null 4 | exp_folder: ./experiments/${exp_name} 5 | 6 | transformer_model: "bert-large-cased" 7 | 8 | callbacks_monitor: 'val_micro-f1-score' 9 | callbacks_mode: 'max' 10 | 11 | hydra: 12 | # customize working dir 13 | run: 14 | dir: ./experiments/${exp_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 15 | job_logging: 16 | formatters: 17 | simple: 18 | format: '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 19 | root: 20 | level: WARN 21 | 22 | # defaults 23 | defaults: 24 | - callbacks: empty 25 | - data: sequence 26 | - evaluation: sequence 27 | - model: sequence 28 | - prediction: default 29 | - training: default 30 | - logging: default 31 | - _self_ 32 | -------------------------------------------------------------------------------- /configurations/token.yaml: -------------------------------------------------------------------------------- 1 | task: token 2 | project_name: classy 3 | exp_name: null 4 | exp_folder: ./experiments/${exp_name} 5 | 6 | transformer_model: "bert-base-cased" 7 | 8 | callbacks_monitor: 'val_micro-f1-score' 9 | callbacks_mode: 'max' 10 | 11 | hydra: 12 | # customize working dir 13 | run: 14 | dir: ./experiments/${exp_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 15 | job_logging: 16 | formatters: 17 | simple: 18 | format: '%(asctime)s - %(levelname)s - %(name)s - %(message)s' 19 | root: 20 | level: WARN 21 | 22 | # defaults 23 | defaults: 24 | - callbacks: empty 25 | - data: token 26 | - evaluation: token 27 | - model: token 28 | - prediction: default 29 | - training: default 30 | - logging: default 31 | - _self_ 32 | -------------------------------------------------------------------------------- /configurations/training/default.yaml: -------------------------------------------------------------------------------- 1 | # reproducibility 2 | seed: 12 3 | 4 | # pl_trainer 5 | pl_trainer: 6 | _target_: pytorch_lightning.Trainer 7 | accumulate_grad_batches: 4 8 | gradient_clip_val: 10.0 9 | val_check_interval: 1.0 10 | max_steps: 1_000_000 11 | 12 | # early stopping callback 13 | # "early_stopping_callback: null" will disable early stopping 14 | early_stopping_callback: 15 | _target_: pytorch_lightning.callbacks.EarlyStopping 16 | monitor: ${callbacks_monitor} 17 | mode: ${callbacks_mode} 18 | patience: 5 19 | 20 | # model_checkpoint_callback 21 | # "model_checkpoint_callback: null" will disable model checkpointing 22 | model_checkpoint_callback: 23 | _target_: classy.pl_callbacks.best_checkpoint.ModelCheckpointWithBest 24 | monitor: ${callbacks_monitor} 25 | mode: ${callbacks_mode} 26 | verbose: True 27 | save_top_k: 3 28 | dirpath: checkpoints 29 | save_last: true 30 | 31 | resume_from: null 32 | -------------------------------------------------------------------------------- /data/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/litus-ai/classy/6ce11571d3ac193c9e1729afcbe74e6342838160/data/.placeholder -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | # Dependencies 2 | /node_modules 3 | 4 | # Production 5 | /build 6 | 7 | # Generated files 8 | .docusaurus 9 
| .cache-loader 10 | 11 | # Misc 12 | .DS_Store 13 | .env.local 14 | .env.development.local 15 | .env.test.local 16 | .env.production.local 17 | 18 | npm-debug.log* 19 | yarn-debug.log* 20 | yarn-error.log* 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Website 2 | 3 | This website is built using [Docusaurus 2](https://docusaurus.io/), a modern static website generator. 4 | 5 | ### Local Development 6 | 7 | We suggest using Docker to avoid npm/yarn installations. Simply run: 8 | ``` 9 | $ bash local-dev.sh 10 | ``` 11 | and follow the interactive prompt. 12 | 13 | ### Build 14 | 15 | ``` 16 | $ yarn build 17 | ``` 18 | 19 | This command generates static content into the `build` directory, which can be served using any static content hosting service. 20 | 21 | ### Deployment 22 | 23 | ``` 24 | $ GIT_USER=<username> USE_SSH=true yarn deploy 25 | ``` 26 | 27 | If you are using GitHub Pages for hosting, this command is a convenient way to build the website and push to the `gh-pages` branch. 28 | -------------------------------------------------------------------------------- /docs/babel.config.js: -------------------------------------------------------------------------------- 1 | module.exports = { 2 | presets: [require.resolve('@docusaurus/core/lib/babel/preset')], 3 | }; 4 | -------------------------------------------------------------------------------- /docs/docs/getting-started/_category_.json: -------------------------------------------------------------------------------- 1 | { 2 | "label": "Getting Started", 3 | "position": 2 4 | } 5 | -------------------------------------------------------------------------------- /docs/docs/getting-started/basic/_category_.json: -------------------------------------------------------------------------------- 1 | { 2 | "label": "ML without Coding", 3 | "position": 1 4 | } 5 | -------------------------------------------------------------------------------- /docs/docs/getting-started/basic/choosing-profile.md: -------------------------------------------------------------------------------- 1 | --- 2 | sidebar_position: 3 3 | title: Choosing a profile 4 | --- 5 | 6 | :::tip 7 | This step is **not mandatory**, but we highly recommend reading it, as it touches on an important component of 8 | `classy`, profiles, which you will need whenever you want to heavily modify your training configuration. 9 | 10 | ::: 11 | 12 | You may have constraints of various kinds (hardware, performance, etc.), and you might 13 | want to change the default underlying model / optimizer used for training, so as to 14 | fit into smaller GPUs, train faster, or achieve higher accuracy. 15 | 16 | In `classy`, this is achieved through *Profiles*, which let you change the training configuration 17 | of your model to fit different criteria. 18 | 19 | `classy` comes with a predefined set of profiles, which you can find [here](/docs/reference-manual/profiles/). 20 | The list includes the underlying transformer model, the optimizer, and a few key features in which each profile shines. 21 | 22 | For this tutorial, we'll stick with a fast yet powerful model, *DistilBERT*; a sketch of the corresponding training command follows below.
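With a profile picked, training is a one-liner. The sketch below assumes the `--profile` flag and argument order described in the reference manual; flag names can differ across `classy` versions, so double-check with `classy train --help`:

```
# Hypothetical invocation: train a token-classification model on data/ner-data,
# naming the experiment fast-ner and switching to the DistilBERT profile.
$ classy train token data/ner-data -n fast-ner --profile distilbert
```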
23 | -------------------------------------------------------------------------------- /docs/docs/getting-started/basic/data-formatting.md: -------------------------------------------------------------------------------- 1 | --- 2 | sidebar_position: 2 3 | title: Organizing your data 4 | --- 5 | 6 | ## Data Formatting 7 | 8 | `classy` requires data to be formatted in a specific way according to the task you're tackling (check out [Tasks and Input Formats](/docs/reference-manual/tasks-and-formats) in the documentation). 9 | 10 | In our case of **Named Entity Recognition** (i.e., *Token Classification*), we need the data to be formatted such that each line represents a single sample. 11 | For instance, taking again our running example of *Barack Obama visited Google in California*, we can format it as follows: 12 | 13 | ```text 14 | Barack Obama visited Google in California\tPER PER O ORG O LOC 15 | ``` 16 | 17 | That is, a TSV (tab-separated values) file which has a space-separated sequence of tokens as the first column, and 18 | a space-separated sequence of labels as the second column (both sequences **must have** the same number of elements). 19 | 20 | :::tip 21 | `classy` by default supports `.tsv` and `.jsonl` as input formats (see [the documentation](/docs/reference-manual/tasks-and-formats)), 22 | but you can [add custom formats](/docs/getting-started/customizing-things/custom-data-format/) as well. 23 | ::: 24 | 25 | If your dataset is already formatted like this, great! Otherwise, this is the only bit where coding is required. 26 | You can either convert it yourself (via a Python or bash script, whichever you're comfortable with; see the converter sketch at the end of this section), or you can register 27 | a [custom data reader](/docs/getting-started/customizing-things/custom-data-format/) to support your dataset format. 28 | 29 | 30 | ## Organizing Datasets 31 | In `classy`, as in standard machine learning projects, the simplest way to organize your datasets is to create 32 | a directory containing the train, validation, and test datasets. 33 | ``` 34 | data/ner-data 35 | ├── train.tsv 36 | ├── validation.tsv 37 | └── test.tsv 38 | ``` 39 | 40 | This way, `classy` will automatically infer the splits of your dataset from the directory structure. 41 | 42 | :::tip 43 | If you have multiple training files, or you want to specify the splits using a different directory structure, you can 44 | use a _training coordinates_ file. You can find a complete guide on how to do it in the 45 | [Reference Manual](/docs/reference-manual/cli/train/). 46 | ::: 47 | -------------------------------------------------------------------------------- /docs/docs/getting-started/basic/inference.md: -------------------------------------------------------------------------------- 1 | --- 2 | sidebar_position: 5 3 | title: Using your trained model 4 | --- 5 | 6 | import ReactTermynal from '/src/components/termynal'; 7 | 8 | Now that we have our trained model called `fast-ner`, stored under `experiments/fast-ner//
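To complement the data-formatting guide above, here is a hypothetical converter for the token-classification TSV format it describes. The `samples` list stands in for whatever loader your original dataset needs; nothing in this sketch is part of classy itself:

```python
# Hypothetical helper: write (tokens, labels) pairs in the classy token
# classification TSV format (space-separated tokens, TAB, space-separated
# labels, one sample per line).
import os
from typing import Iterable, List, Tuple


def write_classy_tsv(samples: Iterable[Tuple[List[str], List[str]]], path: str) -> None:
    with open(path, "w", encoding="utf-8") as f:
        for tokens, labels in samples:
            # both sequences must have the same number of elements
            assert len(tokens) == len(labels), "tokens/labels length mismatch"
            f.write(" ".join(tokens) + "\t" + " ".join(labels) + "\n")


# usage sketch with the running example
samples = [
    (
        ["Barack", "Obama", "visited", "Google", "in", "California"],
        ["PER", "PER", "O", "ORG", "O", "LOC"],
    )
]
os.makedirs("data/ner-data", exist_ok=True)
write_classy_tsv(samples, "data/ner-data/train.tsv")
```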