├── .github └── workflows │ ├── build-docs.yml │ ├── python-publish.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── docs ├── .gitignore ├── generate_docs.sh ├── generate_docs_netlify.sh ├── html │ └── .keep └── source │ ├── api │ └── index.rst │ ├── conf.py │ ├── index.rst │ ├── modules.rst │ ├── quaterion_models.encoders.encoder.rst │ ├── quaterion_models.encoders.extras.fasttext_encoder.rst │ ├── quaterion_models.encoders.extras.rst │ ├── quaterion_models.encoders.rst │ ├── quaterion_models.encoders.switch_encoder.rst │ ├── quaterion_models.heads.empty_head.rst │ ├── quaterion_models.heads.encoder_head.rst │ ├── quaterion_models.heads.gated_head.rst │ ├── quaterion_models.heads.rst │ ├── quaterion_models.heads.sequential_head.rst │ ├── quaterion_models.heads.skip_connection_head.rst │ ├── quaterion_models.heads.softmax_head.rst │ ├── quaterion_models.heads.stacked_projection_head.rst │ ├── quaterion_models.heads.widening_head.rst │ ├── quaterion_models.model.rst │ ├── quaterion_models.modules.rst │ ├── quaterion_models.modules.simple.rst │ ├── quaterion_models.rst │ ├── quaterion_models.types.rst │ ├── quaterion_models.utils.classes.rst │ ├── quaterion_models.utils.rst │ └── quaterion_models.utils.tensors.rst ├── netlify.toml ├── poetry.lock ├── pyproject.toml ├── quaterion_models ├── __init__.py ├── encoders │ ├── __init__.py │ ├── encoder.py │ ├── extras │ │ ├── __init__.py │ │ └── fasttext_encoder.py │ └── switch_encoder.py ├── heads │ ├── __init__.py │ ├── empty_head.py │ ├── encoder_head.py │ ├── gated_head.py │ ├── sequential_head.py │ ├── skip_connection_head.py │ ├── softmax_head.py │ ├── stacked_projection_head.py │ ├── switch_head.py │ └── widening_head.py ├── model.py ├── modules │ ├── __init__.py │ └── simple.py ├── types │ └── __init__.py └── utils │ ├── __init__.py │ ├── classes.py │ ├── meta.py │ └── tensors.py └── tests ├── __init__.py ├── encoders ├── test_fasttext_encoder.py ├── test_switch_encoder.py └── test_switch_head.py ├── heads └── test_head.py ├── test_model.py └── utils └── test_classes.py /.github/workflows/build-docs.yml: -------------------------------------------------------------------------------- 1 | name: build-docs 2 | on: 3 | pull_request_target: 4 | types: 5 | - closed 6 | 7 | jobs: 8 | build: 9 | if: github.event.pull_request.merged == true 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | with: 14 | fetch-depth: 0 15 | - name: Set up Python 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: '3.9.x' 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install poetry 22 | poetry install -E fasttext 23 | poetry run pip install 'setuptools==59.5.0' # temporary fix for https://github.com/pytorch/pytorch/pull/69904 24 | - name: Generate docs 25 | run: | 26 | bash -x docs/generate_docs.sh 27 | git config user.name 'qdrant' 28 | git config user.email 'qdrant@users.noreply.github.com' 29 | git config pull.rebase false 30 | git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY 31 | git checkout $GITHUB_HEAD_REF 32 | git add ./docs && git commit -m "docs: auto-generate docs with sphinx" && git pull && git push || true 33 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python 
Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | workflow_dispatch: 13 | push: 14 | # Pattern matched against refs/tags 15 | tags: 16 | - 'v*' # Push events to every version tag 17 | 18 | 19 | jobs: 20 | deploy: 21 | 22 | runs-on: ubuntu-latest 23 | 24 | steps: 25 | - uses: actions/checkout@v2 26 | - name: Set up Python 27 | uses: actions/setup-python@v2 28 | with: 29 | python-version: '3.9.x' 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install poetry 33 | poetry install 34 | - name: Build package 35 | run: poetry build 36 | - name: Publish package 37 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 38 | with: 39 | user: __token__ 40 | password: ${{ secrets.PYPI_API_TOKEN }} 41 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ${{ matrix.os }} 8 | strategy: 9 | fail-fast: true 10 | matrix: 11 | python-version: [3.8, 3.9, "3.10"] 12 | os: [ubuntu-latest] # [ubuntu-latest, macOS-latest, windows-latest] 13 | 14 | steps: 15 | - uses: actions/checkout@v1 16 | 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip poetry 25 | poetry install -E fasttext 26 | poetry run pip install 'setuptools==59.5.0' # temporary fix for https://github.com/pytorch/pytorch/pull/69904 27 | 28 | - name: Unit tests 29 | run: | 30 | poetry run pytest 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | .idea/ 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | # pytype static type analyzer 137 | .pytype/ 138 | 139 | # Cython debug symbols 140 | cython_debug/ 141 | 142 | # Do not include auto-generated setup.py 143 | /setup.py 144 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | default_language_version: 4 | python: python3.8 5 | 6 | ci: 7 | autofix_prs: true 8 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 9 | autoupdate_schedule: quarterly 10 | # submodules: true 11 | 12 | repos: 13 | - repo: https://github.com/pre-commit/pre-commit-hooks 14 | rev: v4.4.0 15 | hooks: 16 | - id: trailing-whitespace 17 | - id: end-of-file-fixer 18 | - id: check-added-large-files 19 | 20 | - repo: https://github.com/psf/black 21 | rev: 22.12.0 22 | hooks: 23 | - id: black 24 | name: "Black: The uncompromising Python code formatter" 25 | 26 | - repo: https://github.com/PyCQA/isort 27 | rev: 5.11.4 28 | hooks: 29 | - id: isort 30 | name: "Sort Imports" 31 | args: ["--profile", "black"] 32 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make 
participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | email: team@qdrant.tech. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. 
A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Quaterion Models 2 | We love your input! We want to make contributing to this project as easy and transparent as possible, whether it's: 3 | 4 | - Reporting a bug 5 | - Discussing the current state of the code 6 | - Submitting a fix 7 | - Proposing new features 8 | 9 | ## We Develop with GitHub 10 | We use GitHub to host code, track issues and feature requests, and accept pull requests. 11 | 12 | ## We Use [Github Flow](https://guides.github.com/introduction/flow/index.html), So All Code Changes Happen Through Pull Requests 13 | Pull requests are the best way to propose changes to the codebase (we use [Github Flow](https://guides.github.com/introduction/flow/index.html)). 14 | We actively welcome your pull requests: 15 | 16 | 1. Fork the repo and create your branch from `master`. 17 | 2. If you've added code that should be tested, add tests. 18 | 3. Ensure the test suite passes. 19 | 4. Make sure your code lints (ToDo). 20 | 5. Make sure that commits reference the related issue (e.g. `Fix model training #num_of_issue`). 21 | 6. Issue that pull request!
22 | 23 | ## Any contributions you make will be under the Apache License 2.0 24 | In short, when you submit code changes, your submissions are understood to be under the same [Apache License 2.0](https://choosealicense.com/licenses/apache-2.0/) that covers the project. Feel free to contact the maintainers if that's a concern. 25 | 26 | ## Report bugs using Github's [issues](https://github.com/qdrant/quaterion-models/issues) 27 | We use GitHub issues to track public bugs. Report a bug by [opening a new issue](https://github.com/qdrant/quaterion-models/issues/new); it's that easy! 28 | 29 | ## Write bug reports with detail, background, and sample code 30 | 31 | **Great Bug Reports** tend to have: 32 | 33 | - A quick summary and/or background 34 | - Steps to reproduce 35 | - Be specific! 36 | - Give sample code if you can. 37 | - What you expected would happen 38 | - What actually happens 39 | - Notes (possibly including why you think this might be happening, or stuff you tried that didn't work) 40 | 41 | ## Coding Style 42 | 43 | 1. We use [PEP8](https://www.python.org/dev/peps/pep-0008/) code style 44 | 2. We use [Python Type Annotations](https://docs.python.org/3/library/typing.html) whenever necessary 45 | 1. If your IDE cannot infer the type of some variable, it is a good sign to add some more type annotations 46 | 3. We document tensor transformations - tensor types alone are usually not enough for a comfortable understanding of the code 47 | 4. We prefer simplicity and a practical approach over Kaggle-level state-of-the-art accuracy 48 | 1. If some modules or layers have a complicated interface or dependencies, or are just very complicated internally - we prefer to keep them outside Quaterion Models. 49 | 50 | ## License 51 | By contributing, you agree that your contributions will be licensed under the project's Apache License 2.0. 52 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Quaterion Models 2 | 3 | `quaterion-models` is a part of [`Quaterion`](https://github.com/qdrant/quaterion), a similarity learning framework. 4 | It is kept as a separate package to make servable models lightweight and free from training dependencies.
5 | 6 | It contains definitions of the base classes used for model inference, as well as a collection of building blocks for assembling fine-tunable similarity learning models. 7 | The documentation can be found [here](https://quaterion-models.qdrant.tech/). 8 | 9 | If you are looking for the training-related part of Quaterion, please see the [main repository](https://github.com/qdrant/quaterion) instead. 10 | 11 | ## Install 12 | 13 | ```bash 14 | pip install quaterion-models 15 | ``` 16 | 17 | It makes sense to install `quaterion-models` independently of the main framework if you already have a trained model 18 | and only need to run inference. 19 | 20 | ## Load and inference 21 | 22 | ```python 23 | from quaterion_models import SimilarityModel 24 | 25 | model = SimilarityModel.load("./path/to/saved/model") 26 | 27 | embeddings = model.encode([ 28 | {"description": "this is an example input"}, 29 | {"description": "you may have a different format"}, 30 | {"description": "the output will be a numpy array"}, 31 | {"description": "of size [batch_size, embedding_size]"}, 32 | ]) 33 | ``` 34 | 35 | ## Content 36 | 37 | * `SimilarityModel` - the main class, which combines encoder models with a head layer 38 | * Base class for Encoders 39 | * Base class and various implementations of the Head Layers 40 | * Additional helper functions 41 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | .doctrees/ 2 | html/ 3 | -------------------------------------------------------------------------------- /docs/generate_docs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | set -e 4 | 5 | # Ensure current path is project root 6 | cd "$(dirname "$0")/../" 7 | 8 | poetry run sphinx-apidoc -f -e -o docs/source quaterion_models 9 | poetry run sphinx-build docs/source docs/html 10 | -------------------------------------------------------------------------------- /docs/generate_docs_netlify.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | set -xe 4 | 5 | # Ensure current path is project root 6 | cd "$(dirname "$0")/../" 7 | 8 | # install CPU torch, because it is smaller 9 | pip install torch --extra-index-url https://download.pytorch.org/whl/cpu 10 | 11 | pip install poetry 12 | poetry build -f wheel 13 | pip install dist/$(ls -1 dist | grep .whl) 14 | poetry install --extras "fasttext" 15 | 16 | pip install "sphinx>=4.4.0" 17 | pip install "git+https://github.com/qdrant/qdrant_sphinx_theme.git@master#egg=qdrant-sphinx-theme" 18 | 19 | sphinx-build docs/source docs/html 20 | -------------------------------------------------------------------------------- /docs/html/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/quaterion-models/f5f97a27f8d9b35194fe316592bd70c4852048f7/docs/html/.keep -------------------------------------------------------------------------------- /docs/source/api/index.rst: -------------------------------------------------------------------------------- 1 | API References 2 | ~~~~~~~~~~~~~~ 3 | 4 | Similarity Model 5 | ---------------- 6 | 7 | .. py:currentmodule:: quaterion_models.model 8 | 9 | .. autosummary:: 10 | :nosignatures: 11 | 12 | SimilarityModel 13 | 14 | Encoders 15 | -------- 16 | 17 | .. py:currentmodule:: quaterion_models.encoders 18 | 19 | ..
autosummary:: 20 | :nosignatures: 21 | 22 | ~encoder.Encoder 23 | ~switch_encoder.SwitchEncoder 24 | ~extras.fasttext_encoder.FasttextEncoder 25 | 26 | 27 | Head Layers 28 | ----------- 29 | 30 | .. py:currentmodule:: quaterion_models.heads 31 | 32 | .. autosummary:: 33 | :nosignatures: 34 | 35 | ~empty_head.EmptyHead 36 | ~gated_head.GatedHead 37 | ~sequential_head.SequentialHead 38 | ~skip_connection_head.SkipConnectionHead 39 | ~softmax_head.SoftmaxEmbeddingsHead 40 | ~stacked_projection_head.StackedProjectionHead 41 | ~widening_head.WideningHead 42 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | # import os 14 | # import sys 15 | # sys.path.insert(0, os.path.abspath('.')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = "Quaterion Models" 21 | copyright = "2022, Quaterion Models Authors" 22 | author = "Quaterion Models Authors" 23 | 24 | 25 | # -- General configuration --------------------------------------------------- 26 | 27 | # Add any Sphinx extension module names here, as strings. They can be 28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 29 | # ones. 30 | extensions = [ 31 | "sphinx.ext.napoleon", 32 | "sphinx.ext.autodoc", 33 | "sphinx.ext.viewcode", 34 | "sphinx.ext.intersphinx", 35 | "sphinx.ext.autosummary", 36 | ] 37 | 38 | # mapping to other sphinx doc 39 | # tuple: (target, inventory) 40 | # Each target is the base URI of a foreign Sphinx documentation set and can be a local path or an 41 | # HTTP URI. The inventory indicates where the inventory file can be found: it can be None (an 42 | # objects.inv file at the same location as the base URI) or another local file path or a full 43 | # HTTP URI to an inventory file. 44 | intersphinx_mapping = { 45 | "gensim": ("https://radimrehurek.com/gensim/", None), 46 | } 47 | 48 | # prevents sphinx from adding full path to type hints 49 | autodoc_typehints_format = "short" 50 | 51 | # order members by type and not alphabetically, it prevents mixing of class attributes 52 | # and methods 53 | autodoc_member_order = "groupwise" 54 | 55 | # moves ``Return type`` to ``Returns`` 56 | napoleon_use_rtype = False 57 | 58 | # If true, suppress the module name of the python reference if it can be resolved. 59 | # Experimental feature: 60 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-python_use_unqualified_type_names 61 | python_use_unqualified_type_names = True 62 | 63 | # prevents sphinx to add full module path in titles 64 | add_module_names = False 65 | 66 | # Add any paths that contain templates here, relative to this directory. 
67 | templates_path = ["_templates"] 68 | 69 | # prevents unfolding type hints 70 | autodoc_type_aliases = { 71 | "TensorInterchange": "TensorInterchange", 72 | "CollateFnType": "CollateFnType", 73 | } 74 | 75 | # List of patterns, relative to source directory, that match files and 76 | # directories to ignore when looking for source files. 77 | # This pattern also affects html_static_path and html_extra_path. 78 | exclude_patterns = [] 79 | # -- Options for HTML output ------------------------------------------------- 80 | 81 | # The theme to use for HTML and HTML Help pages. See the documentation for 82 | # a list of builtin themes. 83 | # 84 | 85 | html_theme = "qdrant_sphinx_theme" 86 | 87 | # Add any paths that contain custom static files (such as style sheets) here, 88 | # relative to this directory. They are copied after the builtin static files, 89 | # so a file named "default.css" will overwrite the builtin "default.css". 90 | # html_static_path = [] 91 | 92 | # Files excluded via exclude_patterns are still generated by sphinx-apidoc. 93 | # As they are generated, some documents have links to them. It leads to a warning like: 94 | # `WARNING: toctree contains reference to excluded document '...'`. 95 | # suppress_warnings allows removing such warnings 96 | suppress_warnings = ["toc.excluded"] 97 | 98 | html_theme_options = { 99 | # google analytics can be added here 100 | "logo_only": False, 101 | "display_version": True, 102 | "prev_next_buttons_location": "bottom", 103 | "style_external_links": False, 104 | # Toc options 105 | "collapse_navigation": True, 106 | "sticky_navigation": True, 107 | "titles_only": False, 108 | "qdrant_project": "quaterion-models", 109 | "qdrant_logo": "/_static/images/quaterion_logo_horisontal_opt.svg", 110 | } 111 | 112 | # default is false 113 | _FAST_DOCS_DEV = False 114 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to quaterion-models's documentation! 2 | ============================================ 3 | 4 | 5 | ``quaterion-models`` is a part of 6 | `Quaterion <https://github.com/qdrant/quaterion>`_, a similarity 7 | learning framework. It is kept as a separate package to make servable 8 | models lightweight and free from training dependencies. 9 | 10 | It contains definitions of the base classes used for model inference, as 11 | well as a collection of building blocks for assembling fine-tunable 12 | similarity learning models. 13 | 14 | If you are looking for the training-related part of Quaterion, please 15 | see the `main repository <https://github.com/qdrant/quaterion>`_ 16 | instead. 17 | 18 | Install 19 | ------- 20 | 21 | .. code:: bash 22 | 23 | pip install quaterion-models 24 | 25 | It makes sense to install ``quaterion-models`` independently of the main 26 | framework if you already have a trained model and only need to run 27 | inference. 28 | 29 | Load and inference 30 | ------------------ 31 | 32 | .. code:: python 33 | 34 | from quaterion_models import SimilarityModel 35 | 36 | model = SimilarityModel.load("./path/to/saved/model") 37 | 38 | embeddings = model.encode([ 39 | {"description": "this is an example input"}, 40 | {"description": "you may have a different format"}, 41 | {"description": "the output will be a numpy array"}, 42 | {"description": "of size [batch_size, embedding_size]"}, 43 | ]) 44 | 45 | Content 46 | ------- 47 | 48 | ..
toctree:: 49 | :maxdepth: 1 50 | 51 | api/index 52 | 53 | 54 | Indices and tables 55 | ================== 56 | 57 | * :ref:`genindex` 58 | * :ref:`modindex` 59 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | quaterion_models 2 | ================ 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | quaterion_models 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.encoders.encoder.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.encoders.encoder module 2 | ========================================= 3 | 4 | .. automodule:: quaterion_models.encoders.encoder 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.encoders.extras.fasttext_encoder.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.encoders.extras.fasttext\_encoder module 2 | ========================================================== 3 | 4 | .. automodule:: quaterion_models.encoders.extras.fasttext_encoder 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.encoders.extras.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.encoders.extras package 2 | ========================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | quaterion_models.encoders.extras.fasttext_encoder 11 | 12 | Module contents 13 | --------------- 14 | 15 | .. automodule:: quaterion_models.encoders.extras 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.encoders.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.encoders package 2 | ================================== 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | quaterion_models.encoders.extras 11 | 12 | Submodules 13 | ---------- 14 | 15 | .. toctree:: 16 | :maxdepth: 4 17 | 18 | quaterion_models.encoders.encoder 19 | quaterion_models.encoders.switch_encoder 20 | 21 | Module contents 22 | --------------- 23 | 24 | .. automodule:: quaterion_models.encoders 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.encoders.switch_encoder.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.encoders.switch\_encoder module 2 | ================================================= 3 | 4 | .. automodule:: quaterion_models.encoders.switch_encoder 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.heads.empty_head.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.heads.empty\_head module 2 | ========================================== 3 | 4 | .. 
automodule:: quaterion_models.heads.empty_head 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.heads.encoder_head.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.heads.encoder\_head module 2 | ============================================ 3 | 4 | .. automodule:: quaterion_models.heads.encoder_head 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.heads.gated_head.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.heads.gated\_head module 2 | ========================================== 3 | 4 | .. automodule:: quaterion_models.heads.gated_head 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.heads.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.heads package 2 | =============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | quaterion_models.heads.empty_head 11 | quaterion_models.heads.encoder_head 12 | quaterion_models.heads.gated_head 13 | quaterion_models.heads.sequential_head 14 | quaterion_models.heads.skip_connection_head 15 | quaterion_models.heads.softmax_head 16 | quaterion_models.heads.stacked_projection_head 17 | quaterion_models.heads.widening_head 18 | 19 | Module contents 20 | --------------- 21 | 22 | .. automodule:: quaterion_models.heads 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.heads.sequential_head.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.heads.sequential\_head module 2 | =============================================== 3 | 4 | .. automodule:: quaterion_models.heads.sequential_head 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.heads.skip_connection_head.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.heads.skip\_connection\_head module 2 | ===================================================== 3 | 4 | .. automodule:: quaterion_models.heads.skip_connection_head 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.heads.softmax_head.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.heads.softmax\_head module 2 | ============================================ 3 | 4 | .. automodule:: quaterion_models.heads.softmax_head 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.heads.stacked_projection_head.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.heads.stacked\_projection\_head module 2 | ======================================================== 3 | 4 | .. 
automodule:: quaterion_models.heads.stacked_projection_head 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.heads.widening_head.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.heads.widening\_head module 2 | ============================================= 3 | 4 | .. automodule:: quaterion_models.heads.widening_head 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.model.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.model module 2 | ============================== 3 | 4 | .. automodule:: quaterion_models.model 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.modules.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.modules package 2 | ================================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | quaterion_models.modules.simple 11 | 12 | Module contents 13 | --------------- 14 | 15 | .. automodule:: quaterion_models.modules 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.modules.simple.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.modules.simple module 2 | ======================================= 3 | 4 | .. automodule:: quaterion_models.modules.simple 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models package 2 | ========================= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | quaterion_models.encoders 11 | quaterion_models.heads 12 | quaterion_models.modules 13 | quaterion_models.types 14 | quaterion_models.utils 15 | 16 | Submodules 17 | ---------- 18 | 19 | .. toctree:: 20 | :maxdepth: 4 21 | 22 | quaterion_models.model 23 | 24 | Module contents 25 | --------------- 26 | 27 | .. automodule:: quaterion_models 28 | :members: 29 | :undoc-members: 30 | :show-inheritance: 31 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.types.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.types package 2 | =============================== 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. automodule:: quaterion_models.types 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.utils.classes.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.utils.classes module 2 | ====================================== 3 | 4 | .. 
automodule:: quaterion_models.utils.classes 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.utils.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.utils package 2 | =============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | quaterion_models.utils.classes 11 | quaterion_models.utils.tensors 12 | 13 | Module contents 14 | --------------- 15 | 16 | .. automodule:: quaterion_models.utils 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | -------------------------------------------------------------------------------- /docs/source/quaterion_models.utils.tensors.rst: -------------------------------------------------------------------------------- 1 | quaterion\_models.utils.tensors module 2 | ====================================== 3 | 4 | .. automodule:: quaterion_models.utils.tensors 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /netlify.toml: -------------------------------------------------------------------------------- 1 | [build] 2 | publish = "docs/html" 3 | command = "bash -x docs/generate_docs_netlify.sh" 4 | -------------------------------------------------------------------------------- /poetry.lock: -------------------------------------------------------------------------------- 1 | [[package]] 2 | name = "alabaster" 3 | version = "0.7.12" 4 | description = "A configurable sidebar-enabled Sphinx theme" 5 | category = "dev" 6 | optional = false 7 | python-versions = "*" 8 | 9 | [[package]] 10 | name = "atomicwrites" 11 | version = "1.4.1" 12 | description = "Atomic file writes." 13 | category = "dev" 14 | optional = false 15 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 16 | 17 | [[package]] 18 | name = "attrs" 19 | version = "22.1.0" 20 | description = "Classes Without Boilerplate" 21 | category = "dev" 22 | optional = false 23 | python-versions = ">=3.5" 24 | 25 | [package.extras] 26 | dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"] 27 | docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"] 28 | tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"] 29 | tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"] 30 | 31 | [[package]] 32 | name = "babel" 33 | version = "2.10.3" 34 | description = "Internationalization utilities" 35 | category = "dev" 36 | optional = false 37 | python-versions = ">=3.6" 38 | 39 | [package.dependencies] 40 | pytz = ">=2015.7" 41 | 42 | [[package]] 43 | name = "black" 44 | version = "22.6.0" 45 | description = "The uncompromising code formatter." 
46 | category = "dev" 47 | optional = false 48 | python-versions = ">=3.6.2" 49 | 50 | [package.dependencies] 51 | click = ">=8.0.0" 52 | mypy-extensions = ">=0.4.3" 53 | pathspec = ">=0.9.0" 54 | platformdirs = ">=2" 55 | tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""} 56 | typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""} 57 | 58 | [package.extras] 59 | colorama = ["colorama (>=0.4.3)"] 60 | d = ["aiohttp (>=3.7.4)"] 61 | jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] 62 | uvloop = ["uvloop (>=0.15.2)"] 63 | 64 | [[package]] 65 | name = "certifi" 66 | version = "2022.6.15" 67 | description = "Python package for providing Mozilla's CA Bundle." 68 | category = "dev" 69 | optional = false 70 | python-versions = ">=3.6" 71 | 72 | [[package]] 73 | name = "charset-normalizer" 74 | version = "2.1.1" 75 | description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." 76 | category = "dev" 77 | optional = false 78 | python-versions = ">=3.6.0" 79 | 80 | [package.extras] 81 | unicode_backport = ["unicodedata2"] 82 | 83 | [[package]] 84 | name = "click" 85 | version = "8.1.3" 86 | description = "Composable command line interface toolkit" 87 | category = "dev" 88 | optional = false 89 | python-versions = ">=3.7" 90 | 91 | [package.dependencies] 92 | colorama = {version = "*", markers = "platform_system == \"Windows\""} 93 | 94 | [[package]] 95 | name = "colorama" 96 | version = "0.4.5" 97 | description = "Cross-platform colored terminal text." 98 | category = "dev" 99 | optional = false 100 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 101 | 102 | [[package]] 103 | name = "docutils" 104 | version = "0.19" 105 | description = "Docutils -- Python Documentation Utilities" 106 | category = "dev" 107 | optional = false 108 | python-versions = ">=3.7" 109 | 110 | [[package]] 111 | name = "gensim" 112 | version = "4.2.0" 113 | description = "Python framework for fast Vector Space Modelling" 114 | category = "main" 115 | optional = true 116 | python-versions = ">=3.6" 117 | 118 | [package.dependencies] 119 | numpy = ">=1.17.0" 120 | scipy = ">=0.18.1" 121 | smart-open = ">=1.8.1" 122 | 123 | [package.extras] 124 | distributed = ["Pyro4 (>=4.27)"] 125 | docs = ["Pyro4 (>=4.27)", "annoy", "cython", "matplotlib", "memory-profiler", "mock", "nltk", "nmslib", "pandas", "pyemd", "pyro4", "pytest", "pytest-cov", "sphinx", "sphinx-gallery", "sphinxcontrib-napoleon", "sphinxcontrib.programoutput", "statsmodels", "testfixtures", "visdom (>0.1.8.7)"] 126 | test = ["cython", "mock", "nmslib", "pyemd", "pytest", "pytest-cov", "testfixtures", "visdom (>0.1.8.7)"] 127 | test-win = ["cython", "mock", "nmslib", "pyemd", "pytest", "pytest-cov", "testfixtures"] 128 | 129 | [[package]] 130 | name = "idna" 131 | version = "3.3" 132 | description = "Internationalized Domain Names in Applications (IDNA)" 133 | category = "dev" 134 | optional = false 135 | python-versions = ">=3.5" 136 | 137 | [[package]] 138 | name = "imagesize" 139 | version = "1.4.1" 140 | description = "Getting image size from png/jpeg/jpeg2000/gif file" 141 | category = "dev" 142 | optional = false 143 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" 144 | 145 | [[package]] 146 | name = "importlib-metadata" 147 | version = "4.12.0" 148 | description = "Read metadata from Python packages" 149 | category = "dev" 150 | optional = false 151 | python-versions = ">=3.7" 152 | 153 | 
[package.dependencies] 154 | zipp = ">=0.5" 155 | 156 | [package.extras] 157 | docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"] 158 | perf = ["ipython"] 159 | testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"] 160 | 161 | [[package]] 162 | name = "iniconfig" 163 | version = "1.1.1" 164 | description = "iniconfig: brain-dead simple config-ini parsing" 165 | category = "dev" 166 | optional = false 167 | python-versions = "*" 168 | 169 | [[package]] 170 | name = "jinja2" 171 | version = "3.1.2" 172 | description = "A very fast and expressive template engine." 173 | category = "dev" 174 | optional = false 175 | python-versions = ">=3.7" 176 | 177 | [package.dependencies] 178 | MarkupSafe = ">=2.0" 179 | 180 | [package.extras] 181 | i18n = ["Babel (>=2.7)"] 182 | 183 | [[package]] 184 | name = "markupsafe" 185 | version = "2.1.1" 186 | description = "Safely add untrusted strings to HTML/XML markup." 187 | category = "dev" 188 | optional = false 189 | python-versions = ">=3.7" 190 | 191 | [[package]] 192 | name = "mypy-extensions" 193 | version = "0.4.3" 194 | description = "Experimental type system extensions for programs checked with the mypy typechecker." 195 | category = "dev" 196 | optional = false 197 | python-versions = "*" 198 | 199 | [[package]] 200 | name = "numpy" 201 | version = "1.23.2" 202 | description = "NumPy is the fundamental package for array computing with Python." 203 | category = "main" 204 | optional = false 205 | python-versions = ">=3.8" 206 | 207 | [[package]] 208 | name = "packaging" 209 | version = "21.3" 210 | description = "Core utilities for Python packages" 211 | category = "dev" 212 | optional = false 213 | python-versions = ">=3.6" 214 | 215 | [package.dependencies] 216 | pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" 217 | 218 | [[package]] 219 | name = "pathspec" 220 | version = "0.9.0" 221 | description = "Utility library for gitignore style pattern matching of file paths." 222 | category = "dev" 223 | optional = false 224 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" 225 | 226 | [[package]] 227 | name = "platformdirs" 228 | version = "2.5.2" 229 | description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." 230 | category = "dev" 231 | optional = false 232 | python-versions = ">=3.7" 233 | 234 | [package.extras] 235 | docs = ["furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx (>=4)", "sphinx-autodoc-typehints (>=1.12)"] 236 | test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"] 237 | 238 | [[package]] 239 | name = "pluggy" 240 | version = "1.0.0" 241 | description = "plugin and hook calling mechanisms for python" 242 | category = "dev" 243 | optional = false 244 | python-versions = ">=3.6" 245 | 246 | [package.extras] 247 | dev = ["pre-commit", "tox"] 248 | testing = ["pytest", "pytest-benchmark"] 249 | 250 | [[package]] 251 | name = "py" 252 | version = "1.11.0" 253 | description = "library with cross-python path, ini-parsing, io, code, log facilities" 254 | category = "dev" 255 | optional = false 256 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" 257 | 258 | [[package]] 259 | name = "pygments" 260 | version = "2.13.0" 261 | description = "Pygments is a syntax highlighting package written in Python." 
262 | category = "dev" 263 | optional = false 264 | python-versions = ">=3.6" 265 | 266 | [package.extras] 267 | plugins = ["importlib-metadata"] 268 | 269 | [[package]] 270 | name = "pyparsing" 271 | version = "3.0.9" 272 | description = "pyparsing module - Classes and methods to define and execute parsing grammars" 273 | category = "dev" 274 | optional = false 275 | python-versions = ">=3.6.8" 276 | 277 | [package.extras] 278 | diagrams = ["jinja2", "railroad-diagrams"] 279 | 280 | [[package]] 281 | name = "pytest" 282 | version = "6.2.5" 283 | description = "pytest: simple powerful testing with Python" 284 | category = "dev" 285 | optional = false 286 | python-versions = ">=3.6" 287 | 288 | [package.dependencies] 289 | atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} 290 | attrs = ">=19.2.0" 291 | colorama = {version = "*", markers = "sys_platform == \"win32\""} 292 | iniconfig = "*" 293 | packaging = "*" 294 | pluggy = ">=0.12,<2.0" 295 | py = ">=1.8.2" 296 | toml = "*" 297 | 298 | [package.extras] 299 | testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"] 300 | 301 | [[package]] 302 | name = "pytz" 303 | version = "2022.2.1" 304 | description = "World timezone definitions, modern and historical" 305 | category = "dev" 306 | optional = false 307 | python-versions = "*" 308 | 309 | [[package]] 310 | name = "qdrant-sphinx-theme" 311 | version = "0.0.28" 312 | description = "Qudrant Sphinx Theme" 313 | category = "dev" 314 | optional = false 315 | python-versions = "*" 316 | develop = false 317 | 318 | [package.dependencies] 319 | sphinx = "*" 320 | 321 | [package.source] 322 | type = "git" 323 | url = "https://github.com/qdrant/qdrant_sphinx_theme.git" 324 | reference = "master" 325 | resolved_reference = "a90cdd5925783c2b0ed3b8d39897cd4eaf942e2a" 326 | 327 | [[package]] 328 | name = "requests" 329 | version = "2.28.1" 330 | description = "Python HTTP for Humans." 
331 | category = "dev" 332 | optional = false 333 | python-versions = ">=3.7, <4" 334 | 335 | [package.dependencies] 336 | certifi = ">=2017.4.17" 337 | charset-normalizer = ">=2,<3" 338 | idna = ">=2.5,<4" 339 | urllib3 = ">=1.21.1,<1.27" 340 | 341 | [package.extras] 342 | socks = ["PySocks (>=1.5.6,!=1.5.7)"] 343 | use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] 344 | 345 | [[package]] 346 | name = "scipy" 347 | version = "1.9.1" 348 | description = "SciPy: Scientific Library for Python" 349 | category = "main" 350 | optional = true 351 | python-versions = ">=3.8,<3.12" 352 | 353 | [package.dependencies] 354 | numpy = ">=1.18.5,<1.25.0" 355 | 356 | [[package]] 357 | name = "smart-open" 358 | version = "6.1.0" 359 | description = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)" 360 | category = "main" 361 | optional = true 362 | python-versions = ">=3.6,<4.0" 363 | 364 | [package.extras] 365 | all = ["azure-common", "azure-core", "azure-storage-blob", "boto3", "google-cloud-storage (>=1.31.0)", "requests"] 366 | azure = ["azure-common", "azure-core", "azure-storage-blob"] 367 | gcs = ["google-cloud-storage (>=1.31.0)"] 368 | http = ["requests"] 369 | s3 = ["boto3"] 370 | test = ["azure-common", "azure-core", "azure-storage-blob", "boto3", "google-cloud-storage (>=1.31.0)", "moto", "paramiko", "pathlib2", "pytest", "pytest-rerunfailures", "requests", "responses"] 371 | webhdfs = ["requests"] 372 | 373 | [[package]] 374 | name = "snowballstemmer" 375 | version = "2.2.0" 376 | description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms." 377 | category = "dev" 378 | optional = false 379 | python-versions = "*" 380 | 381 | [[package]] 382 | name = "sphinx" 383 | version = "5.1.1" 384 | description = "Python documentation generator" 385 | category = "dev" 386 | optional = false 387 | python-versions = ">=3.6" 388 | 389 | [package.dependencies] 390 | alabaster = ">=0.7,<0.8" 391 | babel = ">=1.3" 392 | colorama = {version = ">=0.3.5", markers = "sys_platform == \"win32\""} 393 | docutils = ">=0.14,<0.20" 394 | imagesize = "*" 395 | importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} 396 | Jinja2 = ">=2.3" 397 | packaging = "*" 398 | Pygments = ">=2.0" 399 | requests = ">=2.5.0" 400 | snowballstemmer = ">=1.1" 401 | sphinxcontrib-applehelp = "*" 402 | sphinxcontrib-devhelp = "*" 403 | sphinxcontrib-htmlhelp = ">=2.0.0" 404 | sphinxcontrib-jsmath = "*" 405 | sphinxcontrib-qthelp = "*" 406 | sphinxcontrib-serializinghtml = ">=1.1.5" 407 | 408 | [package.extras] 409 | docs = ["sphinxcontrib-websupport"] 410 | lint = ["docutils-stubs", "flake8 (>=3.5.0)", "flake8-bugbear", "flake8-comprehensions", "isort", "mypy (>=0.971)", "sphinx-lint", "types-requests", "types-typed-ast"] 411 | test = ["cython", "html5lib", "pytest (>=4.6)", "typed-ast"] 412 | 413 | [[package]] 414 | name = "sphinxcontrib-applehelp" 415 | version = "1.0.2" 416 | description = "sphinxcontrib-applehelp is a sphinx extension which outputs Apple help books" 417 | category = "dev" 418 | optional = false 419 | python-versions = ">=3.5" 420 | 421 | [package.extras] 422 | lint = ["docutils-stubs", "flake8", "mypy"] 423 | test = ["pytest"] 424 | 425 | [[package]] 426 | name = "sphinxcontrib-devhelp" 427 | version = "1.0.2" 428 | description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document." 
429 | category = "dev" 430 | optional = false 431 | python-versions = ">=3.5" 432 | 433 | [package.extras] 434 | lint = ["docutils-stubs", "flake8", "mypy"] 435 | test = ["pytest"] 436 | 437 | [[package]] 438 | name = "sphinxcontrib-htmlhelp" 439 | version = "2.0.0" 440 | description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files" 441 | category = "dev" 442 | optional = false 443 | python-versions = ">=3.6" 444 | 445 | [package.extras] 446 | lint = ["docutils-stubs", "flake8", "mypy"] 447 | test = ["html5lib", "pytest"] 448 | 449 | [[package]] 450 | name = "sphinxcontrib-jsmath" 451 | version = "1.0.1" 452 | description = "A sphinx extension which renders display math in HTML via JavaScript" 453 | category = "dev" 454 | optional = false 455 | python-versions = ">=3.5" 456 | 457 | [package.extras] 458 | test = ["flake8", "mypy", "pytest"] 459 | 460 | [[package]] 461 | name = "sphinxcontrib-qthelp" 462 | version = "1.0.3" 463 | description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document." 464 | category = "dev" 465 | optional = false 466 | python-versions = ">=3.5" 467 | 468 | [package.extras] 469 | lint = ["docutils-stubs", "flake8", "mypy"] 470 | test = ["pytest"] 471 | 472 | [[package]] 473 | name = "sphinxcontrib-serializinghtml" 474 | version = "1.1.5" 475 | description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)." 476 | category = "dev" 477 | optional = false 478 | python-versions = ">=3.5" 479 | 480 | [package.extras] 481 | lint = ["docutils-stubs", "flake8", "mypy"] 482 | test = ["pytest"] 483 | 484 | [[package]] 485 | name = "toml" 486 | version = "0.10.2" 487 | description = "Python Library for Tom's Obvious, Minimal Language" 488 | category = "dev" 489 | optional = false 490 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" 491 | 492 | [[package]] 493 | name = "tomli" 494 | version = "2.0.1" 495 | description = "A lil' TOML parser" 496 | category = "dev" 497 | optional = false 498 | python-versions = ">=3.7" 499 | 500 | [[package]] 501 | name = "torch" 502 | version = "1.12.1" 503 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" 504 | category = "main" 505 | optional = false 506 | python-versions = ">=3.7.0" 507 | 508 | [package.dependencies] 509 | typing-extensions = "*" 510 | 511 | [[package]] 512 | name = "typing-extensions" 513 | version = "4.3.0" 514 | description = "Backported and Experimental Type Hints for Python 3.7+" 515 | category = "main" 516 | optional = false 517 | python-versions = ">=3.7" 518 | 519 | [[package]] 520 | name = "urllib3" 521 | version = "1.26.12" 522 | description = "HTTP library with thread-safe connection pooling, file post, and more." 
523 | category = "dev" 524 | optional = false 525 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, <4" 526 | 527 | [package.extras] 528 | brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] 529 | secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] 530 | socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] 531 | 532 | [[package]] 533 | name = "zipp" 534 | version = "3.8.1" 535 | description = "Backport of pathlib-compatible object wrapper for zip files" 536 | category = "dev" 537 | optional = false 538 | python-versions = ">=3.7" 539 | 540 | [package.extras] 541 | docs = ["jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx"] 542 | testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] 543 | 544 | [extras] 545 | fasttext = ["gensim"] 546 | 547 | [metadata] 548 | lock-version = "1.1" 549 | python-versions = ">=3.8,<3.11" 550 | content-hash = "75d18a64b5e87ca3592968a5360ed2aef2dbfe9a2abe9af29ac5b9a34c59fcf8" 551 | 552 | [metadata.files] 553 | alabaster = [ 554 | {file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"}, 555 | {file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"}, 556 | ] 557 | atomicwrites = [ 558 | {file = "atomicwrites-1.4.1.tar.gz", hash = "sha256:81b2c9071a49367a7f770170e5eec8cb66567cfbbc8c73d20ce5ca4a8d71cf11"}, 559 | ] 560 | attrs = [ 561 | {file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"}, 562 | {file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"}, 563 | ] 564 | babel = [ 565 | {file = "Babel-2.10.3-py3-none-any.whl", hash = "sha256:ff56f4892c1c4bf0d814575ea23471c230d544203c7748e8c68f0089478d48eb"}, 566 | {file = "Babel-2.10.3.tar.gz", hash = "sha256:7614553711ee97490f732126dc077f8d0ae084ebc6a96e23db1482afabdb2c51"}, 567 | ] 568 | black = [ 569 | {file = "black-22.6.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f586c26118bc6e714ec58c09df0157fe2d9ee195c764f630eb0d8e7ccce72e69"}, 570 | {file = "black-22.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b270a168d69edb8b7ed32c193ef10fd27844e5c60852039599f9184460ce0807"}, 571 | {file = "black-22.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6797f58943fceb1c461fb572edbe828d811e719c24e03375fd25170ada53825e"}, 572 | {file = "black-22.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c85928b9d5f83b23cee7d0efcb310172412fbf7cb9d9ce963bd67fd141781def"}, 573 | {file = "black-22.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:f6fe02afde060bbeef044af7996f335fbe90b039ccf3f5eb8f16df8b20f77666"}, 574 | {file = "black-22.6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cfaf3895a9634e882bf9d2363fed5af8888802d670f58b279b0bece00e9a872d"}, 575 | {file = "black-22.6.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94783f636bca89f11eb5d50437e8e17fbc6a929a628d82304c80fa9cd945f256"}, 576 | {file = "black-22.6.0-cp36-cp36m-win_amd64.whl", hash = "sha256:2ea29072e954a4d55a2ff58971b83365eba5d3d357352a07a7a4df0d95f51c78"}, 577 | {file = "black-22.6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = 
"sha256:e439798f819d49ba1c0bd9664427a05aab79bfba777a6db94fd4e56fae0cb849"}, 578 | {file = "black-22.6.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:187d96c5e713f441a5829e77120c269b6514418f4513a390b0499b0987f2ff1c"}, 579 | {file = "black-22.6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:074458dc2f6e0d3dab7928d4417bb6957bb834434516f21514138437accdbe90"}, 580 | {file = "black-22.6.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a218d7e5856f91d20f04e931b6f16d15356db1c846ee55f01bac297a705ca24f"}, 581 | {file = "black-22.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:568ac3c465b1c8b34b61cd7a4e349e93f91abf0f9371eda1cf87194663ab684e"}, 582 | {file = "black-22.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6c1734ab264b8f7929cef8ae5f900b85d579e6cbfde09d7387da8f04771b51c6"}, 583 | {file = "black-22.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a3ac16efe9ec7d7381ddebcc022119794872abce99475345c5a61aa18c45ad"}, 584 | {file = "black-22.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:b9fd45787ba8aa3f5e0a0a98920c1012c884622c6c920dbe98dbd05bc7c70fbf"}, 585 | {file = "black-22.6.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7ba9be198ecca5031cd78745780d65a3f75a34b2ff9be5837045dce55db83d1c"}, 586 | {file = "black-22.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a3db5b6409b96d9bd543323b23ef32a1a2b06416d525d27e0f67e74f1446c8f2"}, 587 | {file = "black-22.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:560558527e52ce8afba936fcce93a7411ab40c7d5fe8c2463e279e843c0328ee"}, 588 | {file = "black-22.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b154e6bbde1e79ea3260c4b40c0b7b3109ffcdf7bc4ebf8859169a6af72cd70b"}, 589 | {file = "black-22.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:4af5bc0e1f96be5ae9bd7aaec219c901a94d6caa2484c21983d043371c733fc4"}, 590 | {file = "black-22.6.0-py3-none-any.whl", hash = "sha256:ac609cf8ef5e7115ddd07d85d988d074ed00e10fbc3445aee393e70164a2219c"}, 591 | {file = "black-22.6.0.tar.gz", hash = "sha256:6c6d39e28aed379aec40da1c65434c77d75e65bb59a1e1c283de545fb4e7c6c9"}, 592 | ] 593 | certifi = [ 594 | {file = "certifi-2022.6.15-py3-none-any.whl", hash = "sha256:fe86415d55e84719d75f8b69414f6438ac3547d2078ab91b67e779ef69378412"}, 595 | {file = "certifi-2022.6.15.tar.gz", hash = "sha256:84c85a9078b11105f04f3036a9482ae10e4621616db313fe045dd24743a0820d"}, 596 | ] 597 | charset-normalizer = [ 598 | {file = "charset-normalizer-2.1.1.tar.gz", hash = "sha256:5a3d016c7c547f69d6f81fb0db9449ce888b418b5b9952cc5e6e66843e9dd845"}, 599 | {file = "charset_normalizer-2.1.1-py3-none-any.whl", hash = "sha256:83e9a75d1911279afd89352c68b45348559d1fc0506b054b346651b5e7fee29f"}, 600 | ] 601 | click = [ 602 | {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"}, 603 | {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, 604 | ] 605 | colorama = [ 606 | {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, 607 | {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, 608 | ] 609 | docutils = [] 610 | gensim = [] 611 | idna = [ 612 | {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"}, 613 | {file = "idna-3.3.tar.gz", hash = 
"sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"}, 614 | ] 615 | imagesize = [ 616 | {file = "imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b"}, 617 | {file = "imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a"}, 618 | ] 619 | importlib-metadata = [ 620 | {file = "importlib_metadata-4.12.0-py3-none-any.whl", hash = "sha256:7401a975809ea1fdc658c3aa4f78cc2195a0e019c5cbc4c06122884e9ae80c23"}, 621 | {file = "importlib_metadata-4.12.0.tar.gz", hash = "sha256:637245b8bab2b6502fcbc752cc4b7a6f6243bb02b31c5c26156ad103d3d45670"}, 622 | ] 623 | iniconfig = [ 624 | {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, 625 | {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, 626 | ] 627 | jinja2 = [ 628 | {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"}, 629 | {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"}, 630 | ] 631 | markupsafe = [ 632 | {file = "MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:86b1f75c4e7c2ac2ccdaec2b9022845dbb81880ca318bb7a0a01fbf7813e3812"}, 633 | {file = "MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f121a1420d4e173a5d96e47e9a0c0dcff965afdf1626d28de1460815f7c4ee7a"}, 634 | {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a49907dd8420c5685cfa064a1335b6754b74541bbb3706c259c02ed65b644b3e"}, 635 | {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10c1bfff05d95783da83491be968e8fe789263689c02724e0c691933c52994f5"}, 636 | {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7bd98b796e2b6553da7225aeb61f447f80a1ca64f41d83612e6139ca5213aa4"}, 637 | {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b09bf97215625a311f669476f44b8b318b075847b49316d3e28c08e41a7a573f"}, 638 | {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:694deca8d702d5db21ec83983ce0bb4b26a578e71fbdbd4fdcd387daa90e4d5e"}, 639 | {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:efc1913fd2ca4f334418481c7e595c00aad186563bbc1ec76067848c7ca0a933"}, 640 | {file = "MarkupSafe-2.1.1-cp310-cp310-win32.whl", hash = "sha256:4a33dea2b688b3190ee12bd7cfa29d39c9ed176bda40bfa11099a3ce5d3a7ac6"}, 641 | {file = "MarkupSafe-2.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:dda30ba7e87fbbb7eab1ec9f58678558fd9a6b8b853530e176eabd064da81417"}, 642 | {file = "MarkupSafe-2.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:671cd1187ed5e62818414afe79ed29da836dde67166a9fac6d435873c44fdd02"}, 643 | {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3799351e2336dc91ea70b034983ee71cf2f9533cdff7c14c90ea126bfd95d65a"}, 644 | {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e72591e9ecd94d7feb70c1cbd7be7b3ebea3f548870aa91e2732960fa4d57a37"}, 645 | {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:6fbf47b5d3728c6aea2abb0589b5d30459e369baa772e0f37a0320185e87c980"}, 646 | {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d5ee4f386140395a2c818d149221149c54849dfcfcb9f1debfe07a8b8bd63f9a"}, 647 | {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:bcb3ed405ed3222f9904899563d6fc492ff75cce56cba05e32eff40e6acbeaa3"}, 648 | {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e1c0b87e09fa55a220f058d1d49d3fb8df88fbfab58558f1198e08c1e1de842a"}, 649 | {file = "MarkupSafe-2.1.1-cp37-cp37m-win32.whl", hash = "sha256:8dc1c72a69aa7e082593c4a203dcf94ddb74bb5c8a731e4e1eb68d031e8498ff"}, 650 | {file = "MarkupSafe-2.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:97a68e6ada378df82bc9f16b800ab77cbf4b2fada0081794318520138c088e4a"}, 651 | {file = "MarkupSafe-2.1.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e8c843bbcda3a2f1e3c2ab25913c80a3c5376cd00c6e8c4a86a89a28c8dc5452"}, 652 | {file = "MarkupSafe-2.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0212a68688482dc52b2d45013df70d169f542b7394fc744c02a57374a4207003"}, 653 | {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e576a51ad59e4bfaac456023a78f6b5e6e7651dcd383bcc3e18d06f9b55d6d1"}, 654 | {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b9fe39a2ccc108a4accc2676e77da025ce383c108593d65cc909add5c3bd601"}, 655 | {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96e37a3dc86e80bf81758c152fe66dbf60ed5eca3d26305edf01892257049925"}, 656 | {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6d0072fea50feec76a4c418096652f2c3238eaa014b2f94aeb1d56a66b41403f"}, 657 | {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:089cf3dbf0cd6c100f02945abeb18484bd1ee57a079aefd52cffd17fba910b88"}, 658 | {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6a074d34ee7a5ce3effbc526b7083ec9731bb3cbf921bbe1d3005d4d2bdb3a63"}, 659 | {file = "MarkupSafe-2.1.1-cp38-cp38-win32.whl", hash = "sha256:421be9fbf0ffe9ffd7a378aafebbf6f4602d564d34be190fc19a193232fd12b1"}, 660 | {file = "MarkupSafe-2.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:fc7b548b17d238737688817ab67deebb30e8073c95749d55538ed473130ec0c7"}, 661 | {file = "MarkupSafe-2.1.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e04e26803c9c3851c931eac40c695602c6295b8d432cbe78609649ad9bd2da8a"}, 662 | {file = "MarkupSafe-2.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b87db4360013327109564f0e591bd2a3b318547bcef31b468a92ee504d07ae4f"}, 663 | {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99a2a507ed3ac881b975a2976d59f38c19386d128e7a9a18b7df6fff1fd4c1d6"}, 664 | {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56442863ed2b06d19c37f94d999035e15ee982988920e12a5b4ba29b62ad1f77"}, 665 | {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ce11ee3f23f79dbd06fb3d63e2f6af7b12db1d46932fe7bd8afa259a5996603"}, 666 | {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:33b74d289bd2f5e527beadcaa3f401e0df0a89927c1559c8566c066fa4248ab7"}, 667 | {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = 
"sha256:43093fb83d8343aac0b1baa75516da6092f58f41200907ef92448ecab8825135"}, 668 | {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8e3dcf21f367459434c18e71b2a9532d96547aef8a871872a5bd69a715c15f96"}, 669 | {file = "MarkupSafe-2.1.1-cp39-cp39-win32.whl", hash = "sha256:d4306c36ca495956b6d568d276ac11fdd9c30a36f1b6eb928070dc5360b22e1c"}, 670 | {file = "MarkupSafe-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:46d00d6cfecdde84d40e572d63735ef81423ad31184100411e6e3388d405e247"}, 671 | {file = "MarkupSafe-2.1.1.tar.gz", hash = "sha256:7f91197cc9e48f989d12e4e6fbc46495c446636dfc81b9ccf50bb0ec74b91d4b"}, 672 | ] 673 | mypy-extensions = [] 674 | numpy = [ 675 | {file = "numpy-1.23.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e603ca1fb47b913942f3e660a15e55a9ebca906857edfea476ae5f0fe9b457d5"}, 676 | {file = "numpy-1.23.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:633679a472934b1c20a12ed0c9a6c9eb167fbb4cb89031939bfd03dd9dbc62b8"}, 677 | {file = "numpy-1.23.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17e5226674f6ea79e14e3b91bfbc153fdf3ac13f5cc54ee7bc8fdbe820a32da0"}, 678 | {file = "numpy-1.23.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdc02c0235b261925102b1bd586579b7158e9d0d07ecb61148a1799214a4afd5"}, 679 | {file = "numpy-1.23.2-cp310-cp310-win32.whl", hash = "sha256:df28dda02c9328e122661f399f7655cdcbcf22ea42daa3650a26bce08a187450"}, 680 | {file = "numpy-1.23.2-cp310-cp310-win_amd64.whl", hash = "sha256:8ebf7e194b89bc66b78475bd3624d92980fca4e5bb86dda08d677d786fefc414"}, 681 | {file = "numpy-1.23.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dc76bca1ca98f4b122114435f83f1fcf3c0fe48e4e6f660e07996abf2f53903c"}, 682 | {file = "numpy-1.23.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ecfdd68d334a6b97472ed032b5b37a30d8217c097acfff15e8452c710e775524"}, 683 | {file = "numpy-1.23.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5593f67e66dea4e237f5af998d31a43e447786b2154ba1ad833676c788f37cde"}, 684 | {file = "numpy-1.23.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac987b35df8c2a2eab495ee206658117e9ce867acf3ccb376a19e83070e69418"}, 685 | {file = "numpy-1.23.2-cp311-cp311-win32.whl", hash = "sha256:d98addfd3c8728ee8b2c49126f3c44c703e2b005d4a95998e2167af176a9e722"}, 686 | {file = "numpy-1.23.2-cp311-cp311-win_amd64.whl", hash = "sha256:8ecb818231afe5f0f568c81f12ce50f2b828ff2b27487520d85eb44c71313b9e"}, 687 | {file = "numpy-1.23.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:909c56c4d4341ec8315291a105169d8aae732cfb4c250fbc375a1efb7a844f8f"}, 688 | {file = "numpy-1.23.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8247f01c4721479e482cc2f9f7d973f3f47810cbc8c65e38fd1bbd3141cc9842"}, 689 | {file = "numpy-1.23.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8b97a8a87cadcd3f94659b4ef6ec056261fa1e1c3317f4193ac231d4df70215"}, 690 | {file = "numpy-1.23.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd5b7ccae24e3d8501ee5563e82febc1771e73bd268eef82a1e8d2b4d556ae66"}, 691 | {file = "numpy-1.23.2-cp38-cp38-win32.whl", hash = "sha256:9b83d48e464f393d46e8dd8171687394d39bc5abfe2978896b77dc2604e8635d"}, 692 | {file = "numpy-1.23.2-cp38-cp38-win_amd64.whl", hash = "sha256:dec198619b7dbd6db58603cd256e092bcadef22a796f778bf87f8592b468441d"}, 693 | {file = "numpy-1.23.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = 
"sha256:4f41f5bf20d9a521f8cab3a34557cd77b6f205ab2116651f12959714494268b0"}, 694 | {file = "numpy-1.23.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:806cc25d5c43e240db709875e947076b2826f47c2c340a5a2f36da5bb10c58d6"}, 695 | {file = "numpy-1.23.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f9d84a24889ebb4c641a9b99e54adb8cab50972f0166a3abc14c3b93163f074"}, 696 | {file = "numpy-1.23.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c403c81bb8ffb1c993d0165a11493fd4bf1353d258f6997b3ee288b0a48fce77"}, 697 | {file = "numpy-1.23.2-cp39-cp39-win32.whl", hash = "sha256:cf8c6aed12a935abf2e290860af8e77b26a042eb7f2582ff83dc7ed5f963340c"}, 698 | {file = "numpy-1.23.2-cp39-cp39-win_amd64.whl", hash = "sha256:5e28cd64624dc2354a349152599e55308eb6ca95a13ce6a7d5679ebff2962913"}, 699 | {file = "numpy-1.23.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:806970e69106556d1dd200e26647e9bee5e2b3f1814f9da104a943e8d548ca38"}, 700 | {file = "numpy-1.23.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bd879d3ca4b6f39b7770829f73278b7c5e248c91d538aab1e506c628353e47f"}, 701 | {file = "numpy-1.23.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:be6b350dfbc7f708d9d853663772a9310783ea58f6035eec649fb9c4371b5389"}, 702 | {file = "numpy-1.23.2.tar.gz", hash = "sha256:b78d00e48261fbbd04aa0d7427cf78d18401ee0abd89c7559bbf422e5b1c7d01"}, 703 | ] 704 | packaging = [ 705 | {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, 706 | {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, 707 | ] 708 | pathspec = [] 709 | platformdirs = [ 710 | {file = "platformdirs-2.5.2-py3-none-any.whl", hash = "sha256:027d8e83a2d7de06bbac4e5ef7e023c02b863d7ea5d079477e722bb41ab25788"}, 711 | {file = "platformdirs-2.5.2.tar.gz", hash = "sha256:58c8abb07dcb441e6ee4b11d8df0ac856038f944ab98b7be6b27b2a3c7feef19"}, 712 | ] 713 | pluggy = [ 714 | {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, 715 | {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, 716 | ] 717 | py = [ 718 | {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, 719 | {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, 720 | ] 721 | pygments = [] 722 | pyparsing = [ 723 | {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"}, 724 | {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"}, 725 | ] 726 | pytest = [ 727 | {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"}, 728 | {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"}, 729 | ] 730 | pytz = [] 731 | qdrant-sphinx-theme = [] 732 | requests = [ 733 | {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"}, 734 | {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"}, 735 | ] 736 | scipy = [] 737 | smart-open = [] 738 | snowballstemmer = [ 
739 | {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"}, 740 | {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"}, 741 | ] 742 | sphinx = [ 743 | {file = "Sphinx-5.1.1-py3-none-any.whl", hash = "sha256:309a8da80cb6da9f4713438e5b55861877d5d7976b69d87e336733637ea12693"}, 744 | {file = "Sphinx-5.1.1.tar.gz", hash = "sha256:ba3224a4e206e1fbdecf98a4fae4992ef9b24b85ebf7b584bb340156eaf08d89"}, 745 | ] 746 | sphinxcontrib-applehelp = [ 747 | {file = "sphinxcontrib-applehelp-1.0.2.tar.gz", hash = "sha256:a072735ec80e7675e3f432fcae8610ecf509c5f1869d17e2eecff44389cdbc58"}, 748 | {file = "sphinxcontrib_applehelp-1.0.2-py2.py3-none-any.whl", hash = "sha256:806111e5e962be97c29ec4c1e7fe277bfd19e9652fb1a4392105b43e01af885a"}, 749 | ] 750 | sphinxcontrib-devhelp = [ 751 | {file = "sphinxcontrib-devhelp-1.0.2.tar.gz", hash = "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4"}, 752 | {file = "sphinxcontrib_devhelp-1.0.2-py2.py3-none-any.whl", hash = "sha256:8165223f9a335cc1af7ffe1ed31d2871f325254c0423bc0c4c7cd1c1e4734a2e"}, 753 | ] 754 | sphinxcontrib-htmlhelp = [ 755 | {file = "sphinxcontrib-htmlhelp-2.0.0.tar.gz", hash = "sha256:f5f8bb2d0d629f398bf47d0d69c07bc13b65f75a81ad9e2f71a63d4b7a2f6db2"}, 756 | {file = "sphinxcontrib_htmlhelp-2.0.0-py2.py3-none-any.whl", hash = "sha256:d412243dfb797ae3ec2b59eca0e52dac12e75a241bf0e4eb861e450d06c6ed07"}, 757 | ] 758 | sphinxcontrib-jsmath = [ 759 | {file = "sphinxcontrib-jsmath-1.0.1.tar.gz", hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8"}, 760 | {file = "sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178"}, 761 | ] 762 | sphinxcontrib-qthelp = [ 763 | {file = "sphinxcontrib-qthelp-1.0.3.tar.gz", hash = "sha256:4c33767ee058b70dba89a6fc5c1892c0d57a54be67ddd3e7875a18d14cba5a72"}, 764 | {file = "sphinxcontrib_qthelp-1.0.3-py2.py3-none-any.whl", hash = "sha256:bd9fc24bcb748a8d51fd4ecaade681350aa63009a347a8c14e637895444dfab6"}, 765 | ] 766 | sphinxcontrib-serializinghtml = [ 767 | {file = "sphinxcontrib-serializinghtml-1.1.5.tar.gz", hash = "sha256:aa5f6de5dfdf809ef505c4895e51ef5c9eac17d0f287933eb49ec495280b6952"}, 768 | {file = "sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl", hash = "sha256:352a9a00ae864471d3a7ead8d7d79f5fc0b57e8b3f95e9867eb9eb28999b92fd"}, 769 | ] 770 | toml = [ 771 | {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"}, 772 | {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, 773 | ] 774 | tomli = [ 775 | {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, 776 | {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, 777 | ] 778 | torch = [ 779 | {file = "torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c038662db894a23e49e385df13d47b2a777ffd56d9bcd5b832593fab0a7e286"}, 780 | {file = "torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4e1b9c14cf13fd2ab8d769529050629a0e68a6fc5cb8e84b4a3cc1dd8c4fe541"}, 781 | {file = "torch-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:e9c8f4a311ac29fc7e8e955cfb7733deb5dbe1bdaabf5d4af2765695824b7e0d"}, 782 | {file = 
"torch-1.12.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:976c3f997cea38ee91a0dd3c3a42322785414748d1761ef926b789dfa97c6134"}, 783 | {file = "torch-1.12.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:68104e4715a55c4bb29a85c6a8d57d820e0757da363be1ba680fa8cc5be17b52"}, 784 | {file = "torch-1.12.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:743784ccea0dc8f2a3fe6a536bec8c4763bd82c1352f314937cb4008d4805de1"}, 785 | {file = "torch-1.12.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b5dbcca369800ce99ba7ae6dee3466607a66958afca3b740690d88168752abcf"}, 786 | {file = "torch-1.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f3b52a634e62821e747e872084ab32fbcb01b7fa7dbb7471b6218279f02a178a"}, 787 | {file = "torch-1.12.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:8a34a2fbbaa07c921e1b203f59d3d6e00ed379f2b384445773bd14e328a5b6c8"}, 788 | {file = "torch-1.12.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:42f639501928caabb9d1d55ddd17f07cd694de146686c24489ab8c615c2871f2"}, 789 | {file = "torch-1.12.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0b44601ec56f7dd44ad8afc00846051162ef9c26a8579dda0a02194327f2d55e"}, 790 | {file = "torch-1.12.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cd26d8c5640c3a28c526d41ccdca14cf1cbca0d0f2e14e8263a7ac17194ab1d2"}, 791 | {file = "torch-1.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:42e115dab26f60c29e298559dbec88444175528b729ae994ec4c65d56fe267dd"}, 792 | {file = "torch-1.12.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:a8320ba9ad87e80ca5a6a016e46ada4d1ba0c54626e135d99b2129a4541c509d"}, 793 | {file = "torch-1.12.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:03e31c37711db2cd201e02de5826de875529e45a55631d317aadce2f1ed45aa8"}, 794 | {file = "torch-1.12.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9b356aea223772cd754edb4d9ecf2a025909b8615a7668ac7d5130f86e7ec421"}, 795 | {file = "torch-1.12.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:6cf6f54b43c0c30335428195589bd00e764a6d27f3b9ba637aaa8c11aaf93073"}, 796 | {file = "torch-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:f00c721f489089dc6364a01fd84906348fe02243d0af737f944fddb36003400d"}, 797 | {file = "torch-1.12.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:bfec2843daa654f04fda23ba823af03e7b6f7650a873cdb726752d0e3718dada"}, 798 | {file = "torch-1.12.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:69fe2cae7c39ccadd65a123793d30e0db881f1c1927945519c5c17323131437e"}, 799 | ] 800 | typing-extensions = [ 801 | {file = "typing_extensions-4.3.0-py3-none-any.whl", hash = "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"}, 802 | {file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"}, 803 | ] 804 | urllib3 = [] 805 | zipp = [ 806 | {file = "zipp-3.8.1-py3-none-any.whl", hash = "sha256:47c40d7fe183a6f21403a199b3e4192cca5774656965b0a4988ad2f8feb5f009"}, 807 | {file = "zipp-3.8.1.tar.gz", hash = "sha256:05b45f1ee8f807d0cc928485ca40a07cb491cf092ff587c0df9cb1fd154848d2"}, 808 | ] 809 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "quaterion-models" 3 | version = "0.1.19" 4 | description = "The collection of building blocks to build fine-tunable similarity learning models" 5 | authors = ["Quaterion Authors "] 6 | packages = [ 7 | {include = "quaterion_models"}, 8 | ] 9 | readme = "README.md" 10 | homepage = 
"https://github.com/qdrant/quaterion-models" 11 | repository = "https://github.com/qdrant/quaterion-models" 12 | keywords = ["framework", "metric-learning", "similarity", "similarity-learning", "deep-learning", "pytorch"] 13 | 14 | 15 | [tool.poetry.dependencies] 16 | python = ">=3.8,<3.11" 17 | torch = ">=1.8.2" 18 | numpy = "^1.22" 19 | gensim = {version = "^4.1.2", optional = true} 20 | 21 | 22 | [tool.poetry.dev-dependencies] 23 | pytest = "^6.2.5" 24 | sphinx = ">=5.0.1" 25 | qdrant-sphinx-theme = { git = "https://github.com/qdrant/qdrant_sphinx_theme.git", branch = "master" } 26 | black = "^22.3.0" 27 | 28 | [tool.poetry.extras] 29 | fasttext = ["gensim"] 30 | 31 | [build-system] 32 | requires = ["poetry-core>=1.0.0", "setuptools"] 33 | build-backend = "poetry.core.masonry.api" 34 | -------------------------------------------------------------------------------- /quaterion_models/__init__.py: -------------------------------------------------------------------------------- 1 | from quaterion_models.model import SimilarityModel 2 | -------------------------------------------------------------------------------- /quaterion_models/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from quaterion_models.encoders.encoder import Encoder 2 | from quaterion_models.encoders.switch_encoder import SwitchEncoder 3 | -------------------------------------------------------------------------------- /quaterion_models/encoders/encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Any, List 4 | 5 | from torch import Tensor, nn 6 | from torch.utils.data.dataloader import default_collate 7 | 8 | from quaterion_models.types import CollateFnType, MetaExtractorFnType, TensorInterchange 9 | 10 | 11 | class Encoder(nn.Module): 12 | """Base class for encoder abstraction""" 13 | 14 | def __init__(self): 15 | super(Encoder, self).__init__() 16 | 17 | def disable_gradients_if_required(self): 18 | """Disables gradients of the model if it is declared as not trainable""" 19 | 20 | if not self.trainable: 21 | for _, weights in self.named_parameters(): 22 | weights.requires_grad = False 23 | 24 | @property 25 | def trainable(self) -> bool: 26 | """Defines if encoder is trainable. 27 | 28 | This flag affects caching and checkpoint saving of the encoder. 
29 | """ 30 | raise NotImplementedError() 31 | 32 | @property 33 | def embedding_size(self) -> int: 34 | """Size of resulting embedding""" 35 | raise NotImplementedError() 36 | 37 | def get_collate_fn(self) -> CollateFnType: 38 | """Provides function that converts raw data batch into suitable model input 39 | 40 | Returns: 41 | :const:`~quaterion_models.types.CollateFnType`: model's collate function 42 | """ 43 | return default_collate 44 | 45 | @classmethod 46 | def extract_meta(cls, batch: List[Any]) -> List[dict]: 47 | """Extracts meta information from the batch 48 | 49 | Args: 50 | batch: raw batch of data 51 | 52 | Returns: 53 | meta information 54 | """ 55 | return [{} for _ in batch] 56 | 57 | def get_meta_extractor(self) -> MetaExtractorFnType: 58 | return self.extract_meta 59 | 60 | def forward(self, batch: TensorInterchange) -> Tensor: 61 | """Infer encoder - convert input batch to embeddings 62 | 63 | Args: 64 | batch: processed batch 65 | Returns: 66 | embeddings: shape: (batch_size, embedding_size) 67 | """ 68 | raise NotImplementedError() 69 | 70 | def save(self, output_path: str): 71 | """Persist current state to the provided directory 72 | 73 | Args: 74 | output_path: path to save model 75 | 76 | """ 77 | raise NotImplementedError() 78 | 79 | @classmethod 80 | def load(cls, input_path: str) -> Encoder: 81 | """Instantiate encoder from saved state. 82 | 83 | If no state required - just call `create` instead 84 | 85 | Args: 86 | input_path: path to load from 87 | 88 | Returns: 89 | :class:`~quaterion_models.encoders.encoder.Encoder`: loaded encoder 90 | """ 91 | raise NotImplementedError() 92 | -------------------------------------------------------------------------------- /quaterion_models/encoders/extras/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/quaterion-models/f5f97a27f8d9b35194fe316592bd70c4852048f7/quaterion_models/encoders/extras/__init__.py -------------------------------------------------------------------------------- /quaterion_models/encoders/extras/fasttext_encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import os 5 | from typing import Any, List, Union 6 | 7 | import gensim 8 | import numpy as np 9 | import torch 10 | from gensim.models import FastText, KeyedVectors 11 | from gensim.models.fasttext import FastTextKeyedVectors 12 | from torch import Tensor 13 | 14 | from quaterion_models.encoders import Encoder 15 | from quaterion_models.types import CollateFnType 16 | 17 | 18 | def load_fasttext_model(path: str) -> Union[FastText, KeyedVectors]: 19 | """Load fasttext model in a universal way 20 | 21 | Try to find possible way of loading FastText model and load it 22 | 23 | Args: 24 | path: path to FastText model or vectors 25 | 26 | Returns: 27 | :class:`~gensim.models.fasttext.FastText` or :class:`~gensim.models.KeyedVectors`: 28 | loaded model 29 | """ 30 | try: 31 | model = FastText.load(path).wv 32 | except Exception: 33 | try: 34 | model = FastText.load_fasttext_format(path).wv 35 | except Exception: 36 | model = gensim.models.KeyedVectors.load(path) 37 | 38 | return model 39 | 40 | 41 | class FasttextEncoder(Encoder): 42 | """Creates a fasttext encoder, which generates vector for a list of tokens based in given 43 | fasttext model 44 | 45 | Args: 46 | model_path: Path to model to load 47 | on_disk: If True - use mmap to keep embeddings out of RAM 48 | 
aggregations: What types of aggregations to use to combine multiple vectors into one. If 49 | multiple aggregations are specified - concatenation of all of them will be used as a 50 | result. 51 | 52 | """ 53 | 54 | aggregation_options = ["min", "max", "avg"] 55 | 56 | def __init__(self, model_path: str, on_disk: bool, aggregations: List[str] = None): 57 | super(FasttextEncoder, self).__init__() 58 | 59 | # workaround tensor to keep information about required model device 60 | self._device_tensor = torch.nn.Parameter(torch.zeros(1)) 61 | 62 | if aggregations is None: 63 | aggregations = ["avg"] 64 | self.aggregations = aggregations 65 | self.on_disk = on_disk 66 | 67 | # noinspection PyTypeChecker 68 | self.model: FastTextKeyedVectors = gensim.models.KeyedVectors.load( 69 | model_path, mmap="r" if self.on_disk else None 70 | ) 71 | 72 | @property 73 | def trainable(self) -> bool: 74 | return False 75 | 76 | @property 77 | def embedding_size(self) -> int: 78 | return self.model.vector_size * len(self.aggregations) 79 | 80 | @classmethod 81 | def get_tokens(cls, batch: List[Any]) -> List[List[str]]: 82 | raise NotImplementedError() 83 | 84 | def get_collate_fn(self) -> CollateFnType: 85 | return self.__class__.get_tokens 86 | 87 | @classmethod 88 | def aggregate(cls, embeddings: Tensor, operation: str) -> Tensor: 89 | """Apply aggregation operation to embeddings along the first dimension 90 | 91 | Args: 92 | embeddings: embeddings to aggregate 93 | operation: one of :attr:`aggregation_options` 94 | 95 | Returns: 96 | Tensor: aggregated embeddings 97 | """ 98 | if operation == "avg": 99 | return torch.mean(embeddings, dim=0) 100 | if operation == "max": 101 | return torch.max(embeddings, dim=0).values 102 | if operation == "min": 103 | return torch.min(embeddings, dim=0).values 104 | 105 | raise RuntimeError(f"Unknown operation: {operation}") 106 | 107 | def forward(self, batch: List[List[str]]) -> Tensor: 108 | embeddings = [] 109 | for record in batch: 110 | token_vectors = [self.model.get_vector(token) for token in record] 111 | if token_vectors: 112 | record_vectors = np.stack(token_vectors) 113 | else: 114 | record_vectors = np.zeros((1, self.model.vector_size)) 115 | token_tensor = torch.tensor( 116 | record_vectors, device=self._device_tensor.device 117 | ) 118 | record_embedding = torch.cat( 119 | [ 120 | self.aggregate(token_tensor, operation) 121 | for operation in self.aggregations 122 | ] 123 | ) 124 | embeddings.append(record_embedding) 125 | 126 | return torch.stack(embeddings) 127 | 128 | def save(self, output_path: str): 129 | model_path = os.path.join(output_path, "fasttext.model") 130 | self.model.save( 131 | model_path, separately=["vectors_ngrams", "vectors", "vectors_vocab"] 132 | ) 133 | with open(os.path.join(output_path, "config.json"), "w") as f_out: 134 | json.dump( 135 | { 136 | "on_disk": self.on_disk, 137 | "aggregations": self.aggregations, 138 | }, 139 | f_out, 140 | indent=2, 141 | ) 142 | 143 | @classmethod 144 | def load(cls, input_path: str) -> Encoder: 145 | model_path = os.path.join(input_path, "fasttext.model") 146 | with open(os.path.join(input_path, "config.json")) as f_in: 147 | config = json.load(f_in) 148 | 149 | return cls(model_path=model_path, **config) 150 | -------------------------------------------------------------------------------- /quaterion_models/encoders/switch_encoder.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import os 5 | from 
functools import partial 6 | from typing import Any, Dict, List 7 | 8 | import torch 9 | from torch import Tensor 10 | 11 | from quaterion_models.encoders import Encoder 12 | from quaterion_models.types import CollateFnType, TensorInterchange 13 | from quaterion_models.utils import move_to_device, restore_class, save_class_import 14 | 15 | 16 | def inverse_permutation(perm): 17 | inv = torch.empty_like(perm) 18 | inv[perm] = torch.arange(perm.size(0), device=perm.device) 19 | return inv 20 | 21 | 22 | class SwitchEncoder(Encoder): 23 | """Allows using alternative embeddings based on the input data. 24 | 25 | For example, train a shared embedding representation for images and texts. 26 | In this case, the image encoder should be used if the input is an image, and the text encoder otherwise. 27 | """ 28 | 29 | @classmethod 30 | def encoder_selection(cls, record: Any) -> str: 31 | """Decide which encoder to use for a given record. 32 | 33 | Args: 34 | record: input piece of data 35 | 36 | Returns: 37 | name of the related encoder 38 | """ 39 | raise NotImplementedError() 40 | 41 | def __init__(self, options: Dict[str, Encoder]): 42 | super(SwitchEncoder, self).__init__() 43 | self.options = options 44 | 45 | embedding_sizes = set() 46 | for key, encoder in self.options.items(): 47 | self.add_module(key, encoder) 48 | embedding_sizes.add(encoder.embedding_size) 49 | 50 | if len(embedding_sizes) != 1: 51 | raise RuntimeError( 52 | f"Alternative encoders have inconsistent output size: {embedding_sizes}" 53 | ) 54 | 55 | self._embedding_size = list(embedding_sizes)[0] 56 | 57 | def disable_gradients_if_required(self): 58 | for encoder in self.options.values(): 59 | encoder.disable_gradients_if_required() 60 | 61 | @property 62 | def trainable(self) -> bool: 63 | return any(encoder.trainable for encoder in self.options.values()) 64 | 65 | @property 66 | def embedding_size(self) -> int: 67 | return self._embedding_size 68 | 69 | @classmethod 70 | def switch_collate_fn( 71 | cls, batch: List[Any], encoder_collates: Dict[str, CollateFnType] 72 | ) -> TensorInterchange: 73 | switch_batches = dict((key, []) for key in encoder_collates.keys()) 74 | switch_ordering = dict((key, []) for key in encoder_collates.keys()) 75 | for original_id, record in enumerate(batch): 76 | record_encoder = cls.encoder_selection(record) 77 | switch_batches[record_encoder].append(record) 78 | switch_ordering[record_encoder].append(original_id) 79 | 80 | switch_batches = { 81 | key: encoder_collates[key](batch) 82 | for key, batch in switch_batches.items() 83 | if len(batch) > 0 84 | } 85 | 86 | return {"ordering": switch_ordering, "batches": switch_batches} 87 | 88 | def get_collate_fn(self) -> CollateFnType: 89 | return partial( 90 | self.__class__.switch_collate_fn, 91 | encoder_collates=dict( 92 | (key, encoder.get_collate_fn()) for key, encoder in self.options.items() 93 | ), 94 | ) 95 | 96 | @classmethod 97 | def extract_meta(cls, batch: List[Any]) -> List[dict]: 98 | meta = [] 99 | for record in batch: 100 | meta.append( 101 | { 102 | "encoder": cls.encoder_selection(record), 103 | } 104 | ) 105 | return meta 106 | 107 | def forward(self, batch: TensorInterchange) -> Tensor: 108 | switch_ordering: dict = batch["ordering"] 109 | switch_batches: dict = batch["batches"] 110 | embeddings = [] 111 | ordering = [] 112 | for key, batch in switch_batches.items(): 113 | embeddings.append(self.options[key].forward(batch)) 114 | ordering += switch_ordering[key] 115 | ordering_tensor: Tensor = inverse_permutation(torch.tensor(ordering)) 116 | 
embeddings_tensor: Tensor = torch.cat(embeddings) 117 | ordering_tensor = move_to_device(ordering_tensor, embeddings_tensor.device) 118 | return embeddings_tensor[ordering_tensor] 119 | 120 | def save(self, output_path: str): 121 | encoders = {} 122 | for key, encoder in self.options.items(): 123 | encoders[key] = save_class_import(encoder) 124 | encoder_path = os.path.join(output_path, key) 125 | os.makedirs(encoder_path, exist_ok=True) 126 | encoder.save(encoder_path) 127 | 128 | with open(os.path.join(output_path, "config.json"), "w") as f_out: 129 | json.dump( 130 | { 131 | "encoders": encoders, 132 | }, 133 | f_out, 134 | indent=2, 135 | ) 136 | 137 | @classmethod 138 | def load(cls, input_path: str) -> "Encoder": 139 | with open(os.path.join(input_path, "config.json")) as f_in: 140 | config = json.load(f_in) 141 | 142 | encoders = {} 143 | encoders_params: dict = config["encoders"] 144 | for key, class_params in encoders_params.items(): 145 | encoder_path = os.path.join(input_path, key) 146 | encoder_class = restore_class(class_params) 147 | encoders[key] = encoder_class.load(encoder_path) 148 | 149 | return cls(options=encoders) 150 | -------------------------------------------------------------------------------- /quaterion_models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from quaterion_models.heads.empty_head import EmptyHead 2 | from quaterion_models.heads.encoder_head import EncoderHead 3 | from quaterion_models.heads.gated_head import GatedHead 4 | from quaterion_models.heads.sequential_head import SequentialHead 5 | from quaterion_models.heads.skip_connection_head import SkipConnectionHead 6 | from quaterion_models.heads.softmax_head import SoftmaxEmbeddingsHead 7 | from quaterion_models.heads.stacked_projection_head import StackedProjectionHead 8 | from quaterion_models.heads.switch_head import SwitchHead 9 | from quaterion_models.heads.widening_head import WideningHead 10 | -------------------------------------------------------------------------------- /quaterion_models/heads/empty_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from quaterion_models.heads.encoder_head import EncoderHead 4 | 5 | 6 | class EmptyHead(EncoderHead): 7 | """Returns input embeddings without any modification""" 8 | 9 | def __init__(self, input_embedding_size: int, dropout: float = 0.0): 10 | super(EmptyHead, self).__init__(input_embedding_size, dropout=dropout) 11 | 12 | @property 13 | def output_size(self) -> int: 14 | return self.input_embedding_size 15 | 16 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor: 17 | return input_vectors 18 | -------------------------------------------------------------------------------- /quaterion_models/heads/encoder_head.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Dict, List 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class EncoderHead(nn.Module): 10 | """Base class for the final layer of fine-tuned model. 11 | EncoderHead is the only trainable component in case of frozen encoders. 12 | 13 | Args: 14 | input_embedding_size: 15 | Size of the concatenated embedding, obtained from combination of all configured encoders 16 | dropout: 17 | Probability of Dropout. 
If `dropout > 0.`, apply dropout layer 18 | on embeddings before applying head layer transformations 19 | **kwargs: 20 | """ 21 | 22 | def __init__(self, input_embedding_size: int, dropout: float = 0.0, **kwargs): 23 | super(EncoderHead, self).__init__() 24 | self.input_embedding_size = input_embedding_size 25 | self._dropout_prob = dropout 26 | self.dropout = ( 27 | torch.nn.Dropout(p=dropout) if dropout > 0.0 else torch.nn.Identity() 28 | ) 29 | 30 | @property 31 | def output_size(self) -> int: 32 | raise NotImplementedError() 33 | 34 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor: 35 | """Apply head-specific transformations to the embeddings tensor. 36 | Called as part of `forward` function, but with generic wrappings 37 | 38 | Args: 39 | input_vectors: Concatenated embeddings of all encoders. Shape: (batch_size, self.input_embedding_size) 40 | 41 | Returns: 42 | Final embeddings for a batch: (batch_size, self.output_size) 43 | """ 44 | raise NotImplementedError() 45 | 46 | def forward( 47 | self, input_vectors: torch.Tensor, meta: List[Any] = None 48 | ) -> torch.Tensor: 49 | return self.transform(self.dropout(input_vectors)) 50 | 51 | def get_config_dict(self) -> Dict[str, Any]: 52 | """Constructs savable params dict 53 | 54 | Returns: 55 | Serializable parameters for __init__ of the Module 56 | """ 57 | return { 58 | "input_embedding_size": self.input_embedding_size, 59 | "dropout": self._dropout_prob, 60 | } 61 | 62 | def save(self, output_path): 63 | torch.save(self.state_dict(), os.path.join(output_path, "weights.bin")) 64 | 65 | with open(os.path.join(output_path, "config.json"), "w") as f_out: 66 | json.dump(self.get_config_dict(), f_out, indent=2) 67 | 68 | @classmethod 69 | def load(cls, input_path: str) -> "EncoderHead": 70 | with open(os.path.join(input_path, "config.json")) as f_in: 71 | config = json.load(f_in) 72 | model = cls(**config) 73 | model.load_state_dict(torch.load(os.path.join(input_path, "weights.bin"))) 74 | return model 75 | -------------------------------------------------------------------------------- /quaterion_models/heads/gated_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Parameter 3 | 4 | from quaterion_models.heads.encoder_head import EncoderHead 5 | 6 | 7 | class GatedHead(EncoderHead): 8 | """Disables or amplifies some components of input embedding. 9 | 10 | This layer has minimal amount of trainable parameters and is suitable for even small training 11 | sets. 12 | """ 13 | 14 | def __init__(self, input_embedding_size: int, dropout: float = 0.0): 15 | super(GatedHead, self).__init__(input_embedding_size, dropout=dropout) 16 | self.gates = Parameter(torch.Tensor(self.input_embedding_size)) 17 | self.reset_parameters() 18 | 19 | @property 20 | def output_size(self) -> int: 21 | return self.input_embedding_size 22 | 23 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor: 24 | """ 25 | 26 | Args: 27 | input_vectors: shape: (batch_size, vector_size) 28 | 29 | Returns: 30 | Tensor: (batch_size, vector_size) 31 | """ 32 | return input_vectors * torch.tanh(self.gates) 33 | 34 | def reset_parameters(self) -> None: 35 | torch.nn.init.constant_( 36 | self.gates, 2.0 37 | ) # 2. 
ensures that all vector components are enabled by default 38 | -------------------------------------------------------------------------------- /quaterion_models/heads/sequential_head.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Dict, Iterator, Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from quaterion_models.heads.encoder_head import EncoderHead 9 | 10 | 11 | class SequentialHead(EncoderHead): 12 | """A `torch.nn.Sequential`-like head layer to which you can freely add any layers. 13 | 14 | Unlike `torch.nn.Sequential`, it also expects the output size to be passed 15 | as a required keyword-only argument. It is required because some loss functions 16 | may need this information. 17 | 18 | Args: 19 | args: Any sequence of `torch.nn.Module` instances. See `torch.nn.Sequential` for more info. 20 | output_size: Final output dimension from this head. 21 | 22 | Examples:: 23 | 24 | head = SequentialHead( 25 | nn.Linear(10, 20), 26 | nn.ReLU(), 27 | nn.Linear(20, 30), 28 | output_size=30 29 | ) 30 | """ 31 | 32 | def __init__(self, *args, output_size: int): 33 | super().__init__(None) 34 | self._sequential = nn.Sequential(*args) 35 | self._output_size = output_size 36 | 37 | @property 38 | def output_size(self) -> int: 39 | return self._output_size 40 | 41 | def forward(self, input_vectors: torch.Tensor, meta=None) -> torch.Tensor: 42 | """Forward pass for this head layer. 43 | 44 | Just like `torch.nn.Sequential`, it passes the input to the first module, 45 | and the output of each module is the input to the next. The final output of this head layer is 46 | the output from the last module in the sequence. 47 | 48 | Args: 49 | input_vectors: Batch of input vectors. 50 | meta: Optional metadata for this batch. 51 | 52 | Returns: 53 | Output from the last module in the sequence.
54 | """ 55 | return self._sequential.forward(input_vectors) 56 | 57 | def append(self, module: nn.Module) -> "SequentialHead": 58 | self._sequential.append(module) 59 | return self 60 | 61 | def get_config_dict(self) -> Dict[str, Any]: 62 | """Constructs savable params dict 63 | 64 | Returns: 65 | Serializable parameters for __init__ of the Module 66 | """ 67 | return { 68 | "output_size": self._output_size, 69 | } 70 | 71 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor: 72 | return input_vectors 73 | 74 | def save(self, output_path): 75 | torch.save(self._sequential, os.path.join(output_path, "weights.bin")) 76 | 77 | with open(os.path.join(output_path, "config.json"), "w") as f_out: 78 | json.dump(self.get_config_dict(), f_out, indent=2) 79 | 80 | @classmethod 81 | def load(cls, input_path: str) -> "EncoderHead": 82 | with open(os.path.join(input_path, "config.json")) as f_in: 83 | config = json.load(f_in) 84 | sequential = torch.load( 85 | os.path.join(input_path, "weights.bin"), map_location="cpu" 86 | ) 87 | model = cls(*sequential, **config) 88 | return model 89 | 90 | def __getitem__(self, idx) -> Union[nn.Sequential, nn.Module]: 91 | return self._sequential[idx] 92 | 93 | def __delitem__(self, idx: Union[slice, int]) -> None: 94 | return self._sequential.__delitem(idx) 95 | 96 | def __setitem__(self, idx: int, module: nn.Module) -> None: 97 | return self._sequential.__setitem__(idx, module) 98 | 99 | def __len__(self) -> int: 100 | return self._sequential.__len__() 101 | 102 | def __dir__(self): 103 | return self._sequential.__dir__() 104 | 105 | def __iter__(self) -> Iterator[nn.Module]: 106 | return self._sequential.__iter__() 107 | -------------------------------------------------------------------------------- /quaterion_models/heads/skip_connection_head.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import torch 4 | from torch.nn import Parameter 5 | 6 | from quaterion_models.heads.encoder_head import EncoderHead 7 | 8 | 9 | class SkipConnectionHead(EncoderHead): 10 | """Unites the idea of gated head and residual connections. 11 | 12 | Schema: 13 | .. code-block:: none 14 | 15 | ├──────────┐ 16 | ┌───────┴───────┐ │ 17 | │ Skip-Dropout │ │ 18 | └───────┬───────┘ │ 19 | ┌───────┴───────┐ │ 20 | │ Linear │ │ 21 | └───────┬───────┘ │ 22 | ┌───────┴───────┐ │ 23 | │ Gated │ │ 24 | └───────┬───────┘ │ 25 | + <────────┘ 26 | │ 27 | 28 | Args: 29 | input_embedding_size: 30 | Size of the concatenated embedding, obtained from combination of all configured encoders 31 | dropout: 32 | Probability of Dropout. If `dropout > 0.`, apply dropout layer 33 | on embeddings before applying head layer transformations 34 | skip_dropout: 35 | Additional dropout, applied to the trainable branch only. 36 | Using additional dropout allows to avoid the modification of original embedding. 37 | n_layers: 38 | Number of gated residual blocks stacked on top of each other. 
39 | """ 40 | 41 | def __init__( 42 | self, 43 | input_embedding_size: int, 44 | dropout: float = 0.0, 45 | skip_dropout: float = 0.0, 46 | n_layers: int = 1, 47 | ): 48 | super().__init__(input_embedding_size, dropout=dropout) 49 | for i in range(n_layers): 50 | self.register_parameter( 51 | f"gates_{i}", Parameter(torch.Tensor(self.input_embedding_size)) 52 | ) 53 | setattr( 54 | self, 55 | f"fc_{i}", 56 | torch.nn.Linear(input_embedding_size, input_embedding_size), 57 | ) 58 | 59 | self.skip_dropout = ( 60 | torch.nn.Dropout(p=skip_dropout) 61 | if skip_dropout > 0.0 62 | else torch.nn.Identity() 63 | ) 64 | 65 | self._skip_dropout_prob = skip_dropout 66 | self._n_layers = n_layers 67 | self.reset_parameters() 68 | 69 | @property 70 | def output_size(self) -> int: 71 | return self.input_embedding_size 72 | 73 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor: 74 | """ 75 | Args: 76 | input_vectors: shape: (batch_size, input_embedding_size) 77 | 78 | Returns: 79 | torch.Tensor: shape: (batch_size, input_embedding_size) 80 | """ 81 | for i in range(self._n_layers): 82 | fc = getattr(self, f"fc_{i}") 83 | gates = getattr(self, f"gates_{i}") 84 | input_vectors = ( 85 | fc(self.skip_dropout(input_vectors)) * torch.sigmoid(gates) 86 | + input_vectors 87 | ) 88 | 89 | return input_vectors 90 | 91 | def reset_parameters(self) -> None: 92 | for i in range(self._n_layers): 93 | torch.nn.init.constant_( 94 | getattr(self, f"gates_{i}"), -4.0 95 | ) # -4. ensures that all vector components are disabled by default 96 | 97 | def get_config_dict(self) -> Dict[str, Any]: 98 | config = super().get_config_dict() 99 | config.update( 100 | {"skip_dropout": self._skip_dropout_prob, "n_layers": self._n_layers} 101 | ) 102 | return config 103 | -------------------------------------------------------------------------------- /quaterion_models/heads/softmax_head.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | import torch 4 | from torch.nn import Linear 5 | 6 | from quaterion_models.heads.encoder_head import EncoderHead 7 | 8 | 9 | class SoftmaxEmbeddingsHead(EncoderHead): 10 | """Provides a concatenation of the independent softmax embeddings groups as a head layer 11 | 12 | Useful for deriving embedding confidence. 13 | 14 | Schema: 15 | .. code-block:: none 16 | 17 | ┌──────────────────┐ 18 | │ Encoder │ 19 | └──┬───────────┬───┘ 20 | │ │ 21 | │ │ 22 | ┌───────────┼───────────┼───────────┐ 23 | │ │ │ │ 24 | │ ┌─────────▼──┐ ┌──▼─────────┐ │ 25 | │ │ Linear │ ... │ Linear │ │ 26 | │ └─────┬──────┘ └─────┬──────┘ │ 27 | │ │ │ │ 28 | │ ┌─────┴──────┐ ┌─────┴──────┐ │ 29 | │ │ SoftMax │ ... 
│ SoftMax │ │ 30 | │ └─────┬──────┘ └─────┬──────┘ │ 31 | │ │ │ │ 32 | │ ┌────┴──────────────────┴─────┐ │ 33 | │ │ Concatenation │ │ 34 | │ └──────────────┬──────────────┘ │ 35 | │ │ │ 36 | └─────────────────┼─────────────────┘ 37 | │ 38 | ▼ 39 | 40 | """ 41 | 42 | def __init__( 43 | self, 44 | output_groups: int, 45 | output_size_per_group: int, 46 | input_embedding_size: int, 47 | dropout: float = 0.0, 48 | **kwargs 49 | ): 50 | super().__init__(input_embedding_size, dropout=dropout, **kwargs) 51 | 52 | self.output_groups = output_groups 53 | self.output_size_per_group = output_size_per_group 54 | self.projectors = [] 55 | 56 | self.projection_layer = Linear( 57 | self.input_embedding_size, self.output_size_per_group * self.output_groups 58 | ) 59 | 60 | @property 61 | def output_size(self) -> int: 62 | return self.output_size_per_group * self.output_groups 63 | 64 | def transform(self, input_vectors: torch.Tensor): 65 | """ 66 | 67 | Args: 68 | input_vectors: shape: (batch_size, ..., input_dim) 69 | 70 | Returns: 71 | shape (batch_size, ..., self.output_size_per_group * self.output_groups) 72 | """ 73 | 74 | # shape: [batch_size, ..., self.output_size_per_group * self.output_groups] 75 | projection = self.projection_layer(input_vectors) 76 | 77 | init_shape = projection.shape 78 | groups_shape = list(init_shape) 79 | groups_shape[-1] = self.output_groups 80 | groups_shape.append(-1) 81 | 82 | # shape: [batch_size, ..., self.output_groups, self.output_size_per_group] 83 | grouped_projection = torch.softmax(projection.view(*groups_shape), dim=-1) 84 | 85 | return grouped_projection.view(init_shape) 86 | 87 | def get_config_dict(self) -> Dict[str, Any]: 88 | config = super().get_config_dict() 89 | config.update( 90 | { 91 | "output_groups": self.output_groups, 92 | "output_size_per_group": self.output_size_per_group, 93 | } 94 | ) 95 | 96 | return config 97 | -------------------------------------------------------------------------------- /quaterion_models/heads/stacked_projection_head.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from quaterion_models.heads.encoder_head import EncoderHead 7 | from quaterion_models.modules import ActivationFromFnName 8 | 9 | 10 | class StackedProjectionHead(EncoderHead): 11 | """Stacks any number of projection layers with specified output sizes. 12 | 13 | Args: 14 | input_embedding_size: Dimensionality of the input to this stack of layers. 15 | output_sizes: List of output sizes for each one of the layers stacked. 16 | activation_fn: Name of the activation function to apply between the layers stacked. 17 | Must be an attribute of `torch.nn.functional` and defaults to `relu`. 18 | dropout: Probability of Dropout. 
If `dropout > 0.`, apply dropout layer 19 | on embeddings before applying head layer transformations 20 | """ 21 | 22 | def __init__( 23 | self, 24 | input_embedding_size: int, 25 | output_sizes: List[int], 26 | activation_fn: str = "relu", 27 | dropout: float = 0.0, 28 | ): 29 | super(StackedProjectionHead, self).__init__( 30 | input_embedding_size, dropout=dropout 31 | ) 32 | self._output_sizes = output_sizes 33 | self._activation_fn = activation_fn 34 | 35 | modules = [nn.Linear(input_embedding_size, self._output_sizes[0])] 36 | 37 | if len(self._output_sizes) > 1: 38 | for i in range(1, len(self._output_sizes)): 39 | modules.extend( 40 | [ 41 | ActivationFromFnName(self._activation_fn), 42 | nn.Linear(self._output_sizes[i - 1], self._output_sizes[i]), 43 | ] 44 | ) 45 | 46 | self._stack = nn.Sequential(*modules) 47 | 48 | @property 49 | def output_size(self) -> int: 50 | return self._output_sizes[-1] 51 | 52 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor: 53 | return self._stack(input_vectors) 54 | 55 | def get_config_dict(self) -> Dict[str, Any]: 56 | config = super().get_config_dict() 57 | config.update( 58 | {"output_sizes": self._output_sizes, "activation_fn": self._activation_fn} 59 | ) 60 | 61 | return config 62 | -------------------------------------------------------------------------------- /quaterion_models/heads/switch_head.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Dict, List 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from quaterion_models.encoders.switch_encoder import inverse_permutation 9 | from quaterion_models.heads import EncoderHead 10 | from quaterion_models.utils import restore_class, save_class_import 11 | 12 | 13 | class SwitchHead(EncoderHead): 14 | """Head layer that switches between different heads based on the metadata 15 | Useful in combination with the SwitchEncoder for training multimodal models 16 | 17 | Args: 18 | options: dict of heads.
Choice of head is based on the metadata key 19 | """ 20 | 21 | def __init__( 22 | self, options: Dict[str, EncoderHead], input_embedding_size: int, **kwargs 23 | ): 24 | super().__init__(input_embedding_size, dropout=0.0, **kwargs) 25 | self._heads = options 26 | for key, head in self._heads.items(): 27 | self.add_module(key, head) 28 | 29 | @property 30 | def output_size(self) -> int: 31 | return next(iter(self._heads.values())).output_size 32 | 33 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor: 34 | pass 35 | 36 | def forward( 37 | self, input_vectors: torch.Tensor, meta: List[Any] = None 38 | ) -> torch.Tensor: 39 | # Shape: [batch_size x input_embedding_size] 40 | dropout_input = self.dropout(input_vectors) 41 | 42 | switch_mask = dict((key, []) for key in self._heads.keys()) 43 | switch_ordering = dict((key, []) for key in self._heads.keys()) 44 | for i, m in enumerate(meta): 45 | switch_ordering[m["encoder"]].append(i) 46 | for key, mask in switch_mask.items(): 47 | mask.append(int(key == m["encoder"])) 48 | 49 | head_outputs = [] 50 | ordering = [] 51 | for key, mask in switch_mask.items(): 52 | # Shape: [batch_size] 53 | mask = torch.tensor(mask, dtype=torch.bool, device=input_vectors.device) 54 | head_outputs.append(self._heads[key].transform(dropout_input[mask])) 55 | ordering += switch_ordering[key] 56 | 57 | ordering_tensor: Tensor = inverse_permutation(torch.tensor(ordering)) 58 | # Shape: [batch_size x output_size] 59 | return torch.cat(head_outputs)[ordering_tensor] 60 | 61 | def get_config_dict(self) -> Dict[str, Any]: 62 | """Constructs savable params dict 63 | 64 | Returns: 65 | Serializable parameters for __init__ of the Module 66 | """ 67 | return { 68 | "heads": { 69 | k: {"config": v.get_config_dict(), "class": save_class_import(v)} 70 | for k, v in self._heads.items() 71 | }, 72 | "input_embedding_size": self.input_embedding_size, 73 | } 74 | 75 | @classmethod 76 | def load(cls, input_path: str) -> "EncoderHead": 77 | with open(os.path.join(input_path, "config.json")) as f_in: 78 | config = json.load(f_in) 79 | 80 | heads_config = config.pop("heads") 81 | 82 | heads = dict( 83 | (key, restore_class(head_config["class"])(**head_config["config"])) 84 | for key, head_config in heads_config.items() 85 | ) 86 | 87 | model = cls(options=heads, **config) 88 | model.load_state_dict(torch.load(os.path.join(input_path, "weights.bin"))) 89 | return model 90 | -------------------------------------------------------------------------------- /quaterion_models/heads/widening_head.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from quaterion_models.heads.stacked_projection_head import StackedProjectionHead 4 | 5 | 6 | class WideningHead(StackedProjectionHead): 7 | """Implements narrow-wide-narrow architecture. 8 | 9 | Widen the dimensionality by a factor of `expansion_factor` and narrow it down back to 10 | `input_embedding_size`. 11 | 12 | Args: 13 | input_embedding_size: Dimensionality of the input to this head layer. 14 | expansion_factor: Widen the dimensionality by this factor in the intermediate layer. 15 | activation_fn: Name of the activation function to apply after the intermediate layer. 16 | Must be an attribute of `torch.nn.functional` and defaults to `relu`. 17 | dropout: Probability of Dropout. 
If `dropout > 0.`, apply dropout layer 18 | on embeddings before applying head layer transformations 19 | """ 20 | 21 | def __init__( 22 | self, 23 | input_embedding_size: int, 24 | expansion_factor: float = 4.0, 25 | activation_fn: str = "relu", 26 | dropout: float = 0.0, 27 | **kwargs 28 | ): 29 | self._expansion_factor = expansion_factor 30 | self._activation_fn = activation_fn 31 | super(WideningHead, self).__init__( 32 | input_embedding_size=input_embedding_size, 33 | output_sizes=[ 34 | int(input_embedding_size * expansion_factor), 35 | input_embedding_size, 36 | ], 37 | activation_fn=activation_fn, 38 | dropout=dropout, 39 | ) 40 | 41 | def get_config_dict(self) -> Dict[str, Any]: 42 | config = super().get_config_dict() 43 | config.update( 44 | { 45 | "expansion_factor": self._expansion_factor, 46 | } 47 | ) 48 | return config 49 | -------------------------------------------------------------------------------- /quaterion_models/model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import json 4 | import os 5 | from functools import partial 6 | from typing import Any, Callable, Dict, List, Type, Union 7 | 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | 12 | from quaterion_models.encoders import Encoder 13 | from quaterion_models.heads.encoder_head import EncoderHead 14 | from quaterion_models.types import CollateFnType, MetaExtractorFnType, TensorInterchange 15 | from quaterion_models.utils.classes import restore_class, save_class_import 16 | from quaterion_models.utils.meta import merge_meta 17 | from quaterion_models.utils.tensors import move_to_device 18 | 19 | DEFAULT_ENCODER_KEY = "default" 20 | 21 | 22 | class SimilarityModel(nn.Module): 23 | """Main class which contains encoder models with the head layer.""" 24 | 25 | def __init__(self, encoders: Union[Encoder, Dict[str, Encoder]], head: EncoderHead): 26 | super().__init__() 27 | 28 | if not isinstance(encoders, dict): 29 | self.encoders: Dict[str, Encoder] = {DEFAULT_ENCODER_KEY: encoders} 30 | else: 31 | self.encoders: Dict[str, Encoder] = encoders 32 | 33 | for key, encoder in self.encoders.items(): 34 | encoder.disable_gradients_if_required() 35 | self.add_module(key, encoder) 36 | 37 | self.head = head 38 | 39 | @classmethod 40 | def collate_fn( 41 | cls, 42 | batch: List[dict], 43 | encoders_collate_fns: Dict[str, CollateFnType], 44 | meta_extractors: Dict[str, MetaExtractorFnType], 45 | ) -> TensorInterchange: 46 | """Construct batches for all encoders 47 | 48 | Args: 49 | batch: 50 | encoders_collate_fns: Dict (or single) of collate functions associated with encoders 51 | meta_extractors: Dict (or single) of meta extractor functions associated with encoders 52 | 53 | """ 54 | data = dict( 55 | (key, collate_fn(batch)) for key, collate_fn in encoders_collate_fns.items() 56 | ) 57 | meta = dict( 58 | (key, meta_extractor_fn(batch)) 59 | for key, meta_extractor_fn in meta_extractors.items() 60 | ) 61 | return { 62 | "data": data, 63 | "meta": merge_meta(meta), 64 | } 65 | 66 | @classmethod 67 | def get_encoders_output_size(cls, encoders: Union[Encoder, Dict[str, Encoder]]): 68 | """Calculate total output size of given encoders 69 | 70 | Args: 71 | encoders: 72 | 73 | """ 74 | encoders = encoders.values() if isinstance(encoders, dict) else [encoders] 75 | total_size = 0 76 | for encoder in encoders: 77 | total_size += encoder.embedding_size 78 | return total_size 79 | 80 | def train(self, mode: bool = True): 81 | 
super().train(mode) 82 | 83 | def get_collate_fn(self) -> Callable: 84 | """Construct a function to convert input data into neural network inputs 85 | 86 | Returns: 87 | neural network inputs 88 | """ 89 | return partial( 90 | SimilarityModel.collate_fn, 91 | encoders_collate_fns=dict( 92 | (key, encoder.get_collate_fn()) 93 | for key, encoder in self.encoders.items() 94 | ), 95 | meta_extractors=dict( 96 | (key, encoder.get_meta_extractor()) 97 | for key, encoder in self.encoders.items() 98 | ), 99 | ) 100 | 101 | # ------------------------------------------- 102 | # ---------- Inference methods -------------- 103 | # ------------------------------------------- 104 | 105 | def encode( 106 | self, inputs: Union[List[Any], Any], batch_size=32, to_numpy=True 107 | ) -> Union[torch.Tensor, np.ndarray]: 108 | """Encode data in batches 109 | 110 | Args: 111 | inputs: list of input data to encode 112 | batch_size: 113 | to_numpy: 114 | 115 | Returns: 116 | Numpy array or torch.Tensor of shape (input_size, embedding_size) 117 | """ 118 | self.eval() 119 | device = next(self.parameters(), torch.tensor(0)).device 120 | collate_fn = self.get_collate_fn() 121 | 122 | input_was_list = True 123 | if not isinstance(inputs, list): 124 | input_was_list = False 125 | inputs = [inputs] 126 | 127 | all_embeddings = [] 128 | 129 | for start_index in range(0, len(inputs), batch_size): 130 | input_batch = [ 131 | inputs[i] 132 | for i in range(start_index, min(len(inputs), start_index + batch_size)) 133 | ] 134 | features = collate_fn(input_batch) 135 | features = move_to_device(features, device) 136 | 137 | with torch.no_grad(): 138 | embeddings = self.forward(features) 139 | embeddings = embeddings.detach() 140 | if to_numpy: 141 | embeddings = embeddings.cpu().numpy() 142 | all_embeddings.append(embeddings) 143 | 144 | if to_numpy: 145 | all_embeddings = np.concatenate(all_embeddings, axis=0) 146 | else: 147 | all_embeddings = torch.cat(all_embeddings, dim=0) 148 | 149 | if not input_was_list: 150 | all_embeddings = all_embeddings.squeeze() 151 | 152 | if to_numpy: 153 | all_embeddings = np.atleast_2d(all_embeddings) 154 | else: 155 | all_embeddings = torch.atleast_2d(all_embeddings) 156 | 157 | return all_embeddings 158 | 159 | def forward(self, batch): 160 | embeddings = [ 161 | (key, encoder.forward(batch["data"][key])) 162 | for key, encoder in self.encoders.items() 163 | ] 164 | 165 | meta = batch["meta"] 166 | 167 | # Order embeddings by key name, to ensure reproduction 168 | embeddings = sorted(embeddings, key=lambda x: x[0]) 169 | 170 | # Only embedding tensors of shape [batch_size x encoder_output_size] 171 | embedding_tensors = [embedding[1] for embedding in embeddings] 172 | 173 | # Shape: [batch_size x sum( encoders_emb_sizes )] 174 | joined_embeddings = torch.cat(embedding_tensors, dim=1) 175 | 176 | # Shape: [batch_size x output_emb_size] 177 | result_embedding = self.head(joined_embeddings, meta=meta) 178 | 179 | return result_embedding 180 | 181 | # ------------------------------------------- 182 | # ---------- Persistence methods ------------ 183 | # ------------------------------------------- 184 | 185 | @classmethod 186 | def _get_head_path(cls, directory: str): 187 | return os.path.join(directory, "head") 188 | 189 | @classmethod 190 | def _get_encoders_path(cls, directory: str): 191 | return os.path.join(directory, "encoders") 192 | 193 | def save(self, output_path: str): 194 | head_path = self._get_head_path(output_path) 195 | os.makedirs(head_path, exist_ok=True) 196 | 
self.head.save(head_path) 197 | 198 | head_config = save_class_import(self.head) 199 | 200 | encoders_path = self._get_encoders_path(output_path) 201 | os.makedirs(encoders_path, exist_ok=True) 202 | 203 | encoders_config = [] 204 | 205 | for encoder_key, encoder in self.encoders.items(): 206 | encoder_path = os.path.join(encoders_path, encoder_key) 207 | os.mkdir(encoder_path) 208 | encoder.save(encoder_path) 209 | encoders_config.append({"key": encoder_key, **save_class_import(encoder)}) 210 | 211 | with open(os.path.join(output_path, "config.json"), "w") as f_out: 212 | json.dump( 213 | {"encoders": encoders_config, "head": head_config}, f_out, indent=2 214 | ) 215 | 216 | @classmethod 217 | def load(cls, input_path: str) -> SimilarityModel: 218 | with open(os.path.join(input_path, "config.json")) as f_in: 219 | config = json.load(f_in) 220 | 221 | head_config = config["head"] 222 | head_class: Type[EncoderHead] = restore_class(head_config) 223 | head_path = cls._get_head_path(input_path) 224 | head = head_class.load(head_path) 225 | 226 | encoders: Union[Encoder, Dict[str, Encoder]] = {} 227 | encoders_path = cls._get_encoders_path(input_path) 228 | encoders_config = config["encoders"] 229 | 230 | for encoder_params in encoders_config: 231 | encoder_key = encoder_params["key"] 232 | encoder_class = restore_class(encoder_params) 233 | encoders[encoder_key] = encoder_class.load( 234 | os.path.join(encoders_path, encoder_key) 235 | ) 236 | 237 | return cls(head=head, encoders=encoders) 238 | 239 | 240 | # In this framework, the terms Metric Learning and Similarity Learning are considered synonymous. 241 | # However, the word "Metric" overlaps with other concepts in model training. 242 | # In addition, the semantics of the word "Similarity" are simpler. 243 | # It better reflects the basic idea of this training approach. 244 | # That's why we prefer to use Similarity over Metric. 245 | MetricModel = SimilarityModel 246 | -------------------------------------------------------------------------------- /quaterion_models/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from quaterion_models.modules.simple import ActivationFromFnName 2 | -------------------------------------------------------------------------------- /quaterion_models/modules/simple.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ActivationFromFnName(nn.Module): 7 | """Simple module constructed from function name to be used in `nn.Sequential` 8 | 9 | Construct a `nn.Module` that applies the specified activation function to inputs 10 | 11 | Args: 12 | activation_fn: Name of the activation function to apply to input. 13 | Must be an attribute of `torch.nn.functional`. 14 | """ 15 | 16 | def __init__(self, activation_fn: str): 17 | super().__init__() 18 | self._activation_fn = activation_fn 19 | 20 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 21 | return vars(F)[self._activation_fn](inputs) 22 | -------------------------------------------------------------------------------- /quaterion_models/types/__init__.py: -------------------------------------------------------------------------------- 1 | # TODO: Split them into separate files once we have more of them. 
2 | from typing import Any, Callable, Dict, List, Tuple, Union 3 | 4 | from torch import Tensor 5 | 6 | #: 7 | TensorInterchange = Union[ 8 | Tensor, 9 | Tuple[Tensor], 10 | List[Tensor], 11 | Dict[str, Tensor], 12 | Dict[str, dict], 13 | Any, 14 | ] 15 | #: 16 | CollateFnType = Callable[[List[Any]], TensorInterchange] 17 | MetaExtractorFnType = Callable[[List[Any]], List[dict]] 18 | -------------------------------------------------------------------------------- /quaterion_models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from quaterion_models.utils.classes import restore_class, save_class_import 2 | from quaterion_models.utils.tensors import move_to_device 3 | -------------------------------------------------------------------------------- /quaterion_models/utils/classes.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | from typing import Any 4 | 5 | 6 | def save_class_import(obj: Any) -> dict: 7 | """ 8 | Serializes information about object class 9 | :param obj: 10 | :return: serializable class info 11 | """ 12 | if obj.__module__ == "__main__": 13 | logging.warning( 14 | f"Class {obj.__class__.__qualname__} is defined in the same file as the training loop." 15 | f" It won't be possible to load it properly later." 16 | ) 17 | 18 | return {"module": obj.__module__, "class": obj.__class__.__qualname__} 19 | 20 | 21 | def restore_class(data: dict) -> Any: 22 | """ 23 | :param data: name of module and class 24 | :return: Class 25 | """ 26 | module = data["module"] 27 | class_name = data["class"] 28 | module = importlib.import_module(module) 29 | return getattr(module, class_name) 30 | -------------------------------------------------------------------------------- /quaterion_models/utils/meta.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | 4 | def merge_meta(meta: Dict[str, list]) -> List[dict]: 5 | """Merge meta information from multiple encoders into one 6 | 7 | Combine meta from all encoders 8 | Example: Encoder 1 meta: `[{"a": 1}, {"a": 2}]`, Encoder 2 meta: `[{"b": 3}, {"b": 4}]` 9 | Result: `[{"a": 1, "b": 3}, {"a": 2, "b": 4}]` 10 | 11 | Args: 12 | meta: meta information to merge 13 | """ 14 | aggregated = None 15 | for key, encoder_meta in meta.items(): 16 | if aggregated is None: 17 | aggregated = encoder_meta 18 | else: 19 | for i in range(len(encoder_meta)):  # one update per batch item, not per encoder 20 | aggregated[i].update(encoder_meta[i]) 21 | return aggregated 22 | -------------------------------------------------------------------------------- /quaterion_models/utils/tensors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def move_to_device(obj, device: torch.device): 5 | """ 6 | Given a structure (possibly) containing Tensors on the CPU, 7 | move all the Tensors to the specified GPU (or do nothing, if they should be on the CPU). 8 | """ 9 | 10 | if device == torch.device("cpu"): 11 | return obj 12 | elif isinstance(obj, torch.Tensor): 13 | return obj.cuda(device) 14 | elif isinstance(obj, dict): 15 | return {key: move_to_device(value, device) for key, value in obj.items()} 16 | elif isinstance(obj, list): 17 | return [move_to_device(item, device) for item in obj] 18 | elif isinstance(obj, tuple) and hasattr(obj, "_fields"): 19 | # This is the best way to detect a NamedTuple, it turns out.
20 | return obj.__class__(*(move_to_device(item, device) for item in obj)) 21 | elif isinstance(obj, tuple): 22 | return tuple(move_to_device(item, device) for item in obj) 23 | else: 24 | return obj 25 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/qdrant/quaterion-models/f5f97a27f8d9b35194fe316592bd70c4852048f7/tests/__init__.py -------------------------------------------------------------------------------- /tests/encoders/test_fasttext_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | from quaterion_models.encoders.extras.fasttext_encoder import FasttextEncoder 5 | 6 | 7 | def test_fasttext_encoder(): 8 | from gensim.models import FastText 9 | 10 | demo_texts = [ 11 | ["aaa", "bbb", "ccc", "ddd", "123"], 12 | ["aaa", "bbb", "ccc", "aaa", "123"], 13 | ["aaa", "bbb", "ccc", "bbb", "123"], 14 | ["aaa", "bbb", "ccc", "ccc", "123"], 15 | ["aaa", "bbb", "ccc", "123", "123"], 16 | ] 17 | 18 | model = FastText( 19 | vector_size=10, window=1, min_count=0, min_n=0, max_n=0, sg=1, bucket=1_000 20 | ) 21 | epochs = 10 22 | model.build_vocab(demo_texts) 23 | model.train(demo_texts, epochs=epochs, total_examples=len(demo_texts)) 24 | tempdir = tempfile.TemporaryDirectory() 25 | 26 | model_path = os.path.join(tempdir.name, "fasttext.model") 27 | model.wv.save(model_path, separately=["vectors_ngrams", "vectors", "vectors_vocab"]) 28 | 29 | encoder = FasttextEncoder( 30 | model_path=model_path, on_disk=False, aggregations=["avg", "max"] 31 | ) 32 | 33 | assert encoder.embedding_size == 20 34 | 35 | embeddings = encoder.forward([["aaa", "123"], ["aaa", "ccc"]]) 36 | 37 | assert embeddings.shape[0] == 2 38 | assert embeddings.shape[1] == 20 39 | 40 | encoder.named_parameters() 41 | 42 | empty_embeddings = encoder.forward([[], []]) 43 | 44 | assert empty_embeddings.shape[0] == 2 45 | assert empty_embeddings.shape[1] == 20 46 | -------------------------------------------------------------------------------- /tests/encoders/test_switch_encoder.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from abc import ABC 3 | from typing import Any, List 4 | 5 | import numpy as np 6 | import torch 7 | from torch import Tensor 8 | 9 | from quaterion_models.encoders import Encoder 10 | from quaterion_models.encoders.switch_encoder import SwitchEncoder 11 | from quaterion_models.heads.empty_head import EmptyHead 12 | from quaterion_models.model import SimilarityModel 13 | from quaterion_models.types import CollateFnType, TensorInterchange 14 | 15 | 16 | class CustomEncoder(Encoder, ABC): 17 | @property 18 | def trainable(self) -> bool: 19 | return False 20 | 21 | @property 22 | def embedding_size(self) -> int: 23 | return 3 24 | 25 | def save(self, output_path: str): 26 | pass 27 | 28 | @classmethod 29 | def load(cls, input_path: str) -> "Encoder": 30 | return cls() 31 | 32 | @classmethod 33 | def collate_fn(cls, batch: List[Any]) -> TensorInterchange: 34 | return [torch.zeros(1) for _ in batch] 35 | 36 | def get_collate_fn(self) -> CollateFnType: 37 | return self.__class__.collate_fn 38 | 39 | 40 | class EncoderA(CustomEncoder): 41 | def forward(self, batch: TensorInterchange) -> Tensor: 42 | return torch.zeros(len(batch), self.embedding_size) 43 | 44 | 45 | class EncoderB(CustomEncoder): 46 | def 
forward(self, batch: TensorInterchange) -> Tensor: 47 | return torch.ones(len(batch), self.embedding_size) 48 | 49 | 50 | class CustomSwitchEncoder(SwitchEncoder): 51 | @classmethod 52 | def encoder_selection(cls, record: Any) -> str: 53 | if record == "zeros": 54 | return "a" 55 | if record == "ones": 56 | return "b" 57 | 58 | 59 | def test_forward(): 60 | encoder = CustomSwitchEncoder({"a": EncoderA(), "b": EncoderB()}) 61 | 62 | model = SimilarityModel(encoders=encoder, head=EmptyHead(encoder.embedding_size)) 63 | batch = ["zeros", "zeros", "ones", "ones", "zeros", "zeros", "ones"] 64 | 65 | res = model.encode(batch) 66 | 67 | assert res.shape[0] == len(batch) 68 | assert all(res[:, 0] == np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0])) 69 | 70 | 71 | def test_meta(): 72 | encoder = CustomSwitchEncoder({"a": EncoderA(), "b": EncoderB()}) 73 | meta_extractor = encoder.get_meta_extractor() 74 | 75 | batch = ["zeros", "zeros", "ones", "ones", "zeros", "zeros", "ones"] 76 | 77 | meta = meta_extractor(batch) 78 | print("") 79 | print(meta) 80 | 81 | assert meta[0]["encoder"] == "a" 82 | assert meta[1]["encoder"] == "a" 83 | assert meta[2]["encoder"] == "b" 84 | assert meta[3]["encoder"] == "b" 85 | assert meta[4]["encoder"] == "a" 86 | assert meta[5]["encoder"] == "a" 87 | assert meta[6]["encoder"] == "b" 88 | 89 | 90 | def test_save_and_load(): 91 | encoder = CustomSwitchEncoder({"a": EncoderA(), "b": EncoderB()}) 92 | 93 | tempdir = tempfile.TemporaryDirectory() 94 | model = SimilarityModel(encoders=encoder, head=EmptyHead(encoder.embedding_size)) 95 | model.save(tempdir.name) 96 | model = model.load(tempdir.name) 97 | 98 | batch = ["zeros", "zeros", "ones", "ones", "zeros", "zeros", "ones"] 99 | res = model.encode(batch) 100 | 101 | assert res.shape[0] == len(batch) 102 | assert all(res[:, 0] == np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0])) 103 | -------------------------------------------------------------------------------- /tests/encoders/test_switch_head.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | import torch 4 | 5 | from quaterion_models.heads import EncoderHead 6 | from quaterion_models.heads.switch_head import SwitchHead 7 | 8 | 9 | class FakeHeadA(EncoderHead): 10 | @property 11 | def output_size(self) -> int: 12 | return 3 13 | 14 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor: 15 | return input_vectors + torch.tensor( 16 | [[0, 0, 1] for _ in range(input_vectors.shape[0])] 17 | ) 18 | 19 | 20 | class FakeHeadB(EncoderHead): 21 | @property 22 | def output_size(self) -> int: 23 | return 3 24 | 25 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor: 26 | return input_vectors + torch.tensor( 27 | [[0, 2, 0] for _ in range(input_vectors.shape[0])] 28 | ) 29 | 30 | 31 | def test_save_and_load(): 32 | head = SwitchHead( 33 | options={"a": FakeHeadA(3), "b": FakeHeadB(3)}, 34 | input_embedding_size=3, 35 | ) 36 | 37 | temp_dir = tempfile.TemporaryDirectory() 38 | head.save(temp_dir.name) 39 | 40 | loaded_head = SwitchHead.load(temp_dir.name) 41 | 42 | print(loaded_head) 43 | 44 | 45 | def test_forward(): 46 | head = SwitchHead( 47 | options={"a": FakeHeadA(3), "b": FakeHeadB(3)}, 48 | input_embedding_size=3, 49 | ) 50 | 51 | batch = torch.tensor( 52 | [ 53 | [1, 0, 0], 54 | [2, 0, 0], 55 | [3, 0, 0], 56 | [4, 0, 0], 57 | [5, 0, 0], 58 | ] 59 | ) 60 | 61 | meta = [ 62 | {"encoder": "b"}, 63 | {"encoder": "a"}, 64 | {"encoder": "b"}, 65 | {"encoder": "b"}, 66 | {"encoder": "a"}, 67 | ] 
68 | 69 | res = head.forward(batch, meta) 70 | 71 | assert res.shape[0] == batch.shape[0] 72 | assert res.shape[1] == 3 73 | 74 | assert res[0][0] == 1 75 | assert res[1][0] == 2 76 | assert res[2][0] == 3 77 | assert res[3][0] == 4 78 | assert res[4][0] == 5 79 | 80 | assert res[0][1] == 2 81 | assert res[1][1] == 0 82 | assert res[2][1] == 2 83 | assert res[3][1] == 2 84 | assert res[4][1] == 0 85 | 86 | assert res[0][2] == 0 87 | assert res[1][2] == 1 88 | assert res[2][2] == 0 89 | assert res[3][2] == 0 90 | assert res[4][2] == 1 91 | -------------------------------------------------------------------------------- /tests/heads/test_head.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from typing import Any, List 4 | 5 | import pytest 6 | import torch 7 | import torch.nn as nn 8 | from torch import Tensor 9 | 10 | from quaterion_models import SimilarityModel 11 | from quaterion_models.encoders import Encoder 12 | from quaterion_models.heads import ( 13 | EmptyHead, 14 | GatedHead, 15 | SequentialHead, 16 | SkipConnectionHead, 17 | SoftmaxEmbeddingsHead, 18 | WideningHead, 19 | ) 20 | from quaterion_models.types import CollateFnType, TensorInterchange 21 | 22 | BATCH_SIZE = 3 23 | INPUT_EMBEDDING_SIZE = 5 24 | HIDDEN_EMBEDDING_SIZE = 7 25 | OUTPUT_EMBEDDING_SIZE = 10 26 | 27 | _HEADS = ( 28 | SequentialHead( 29 | nn.Linear(INPUT_EMBEDDING_SIZE, HIDDEN_EMBEDDING_SIZE), 30 | nn.ReLU(), 31 | nn.Linear(HIDDEN_EMBEDDING_SIZE, OUTPUT_EMBEDDING_SIZE), 32 | output_size=OUTPUT_EMBEDDING_SIZE, 33 | ), 34 | GatedHead(INPUT_EMBEDDING_SIZE), 35 | WideningHead(INPUT_EMBEDDING_SIZE), 36 | SkipConnectionHead(INPUT_EMBEDDING_SIZE), 37 | EmptyHead(INPUT_EMBEDDING_SIZE), 38 | SoftmaxEmbeddingsHead( 39 | output_groups=2, 40 | output_size_per_group=OUTPUT_EMBEDDING_SIZE, 41 | input_embedding_size=INPUT_EMBEDDING_SIZE, 42 | ), 43 | ) 44 | HEADS = {head_.__class__.__name__: head_ for head_ in _HEADS} 45 | 46 | 47 | class CustomEncoder(Encoder): 48 | def save(self, output_path: str): 49 | pass 50 | 51 | @classmethod 52 | def load(cls, input_path: str) -> "Encoder": 53 | return cls() 54 | 55 | @property 56 | def trainable(self) -> bool: 57 | return False 58 | 59 | @property 60 | def embedding_size(self) -> int: 61 | return INPUT_EMBEDDING_SIZE 62 | 63 | @classmethod 64 | def collate_fn(cls, batch: List[Any]): 65 | return torch.stack(batch) 66 | 67 | def get_collate_fn(self) -> CollateFnType: 68 | return self.__class__.collate_fn 69 | 70 | def forward(self, batch: TensorInterchange) -> Tensor: 71 | return batch 72 | 73 | 74 | @pytest.mark.parametrize("head", HEADS.values(), ids=HEADS.keys()) 75 | def test_save_and_load(head): 76 | encoder = CustomEncoder() 77 | 78 | model = SimilarityModel(encoders=encoder, head=head) 79 | tempdir = tempfile.TemporaryDirectory() 80 | 81 | model.save(tempdir.name) 82 | 83 | config_path = os.path.join(tempdir.name, "config.json") 84 | 85 | assert os.path.exists(config_path) 86 | 87 | batch = torch.rand(BATCH_SIZE, INPUT_EMBEDDING_SIZE) 88 | origin_output = model.encode(batch, to_numpy=False) 89 | 90 | loaded_model = SimilarityModel.load(tempdir.name) 91 | 92 | assert model.encoders.keys() == loaded_model.encoders.keys() 93 | assert [type(encoder) for encoder in model.encoders.values()] == [ 94 | type(encoder) for encoder in loaded_model.encoders.values() 95 | ] 96 | 97 | assert type(model.head) == type(loaded_model.head) 98 | assert torch.allclose(origin_output, loaded_model.encode(batch, to_numpy=False)) 99 
| 100 | 101 | @pytest.mark.parametrize("head", HEADS.values(), ids=HEADS.keys()) 102 | def test_forward_shape(head): 103 | batch = torch.rand(BATCH_SIZE, INPUT_EMBEDDING_SIZE) 104 | res = head.forward(batch) 105 | assert res.shape == (BATCH_SIZE, head.output_size) 106 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from multiprocessing import Pool 4 | from typing import Any, List 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from quaterion_models import SimilarityModel 10 | from quaterion_models.encoders import Encoder 11 | from quaterion_models.heads import EmptyHead, EncoderHead 12 | from quaterion_models.types import CollateFnType, TensorInterchange 13 | 14 | TEST_EMB_SIZE = 5 15 | 16 | 17 | class LambdaHead(EncoderHead): 18 | def __init__(self): 19 | super(LambdaHead, self).__init__(TEST_EMB_SIZE) 20 | self.my_lambda = lambda x: "hello" 21 | 22 | @property 23 | def output_size(self) -> int: 24 | return 0 25 | 26 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor: 27 | return input_vectors 28 | 29 | 30 | class CustomEncoder(Encoder): 31 | def save(self, output_path: str): 32 | pass 33 | 34 | @classmethod 35 | def load(cls, input_path: str) -> "Encoder": 36 | return cls() 37 | 38 | def __init__(self): 39 | super().__init__() 40 | self.unpickable = lambda x: x + 1 41 | 42 | @property 43 | def trainable(self) -> bool: 44 | return False 45 | 46 | @property 47 | def embedding_size(self) -> int: 48 | return TEST_EMB_SIZE 49 | 50 | @classmethod 51 | def collate_fn(cls, batch: List[Any]): 52 | return torch.rand(len(batch), TEST_EMB_SIZE) 53 | 54 | def get_collate_fn(self) -> CollateFnType: 55 | return self.__class__.collate_fn 56 | 57 | def forward(self, batch: TensorInterchange) -> Tensor: 58 | return batch 59 | 60 | 61 | class Tst: 62 | def __init__(self, foo): 63 | self.foo = foo 64 | 65 | def bar(self, x): 66 | return self.foo(x) 67 | 68 | 69 | def test_get_collate_fn(): 70 | model = SimilarityModel(encoders={"test": CustomEncoder()}, head=LambdaHead()) 71 | 72 | tester = Tst(foo=model.get_collate_fn()) 73 | 74 | with Pool(2) as pool: 75 | res = pool.map(tester.bar, [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]]) 76 | 77 | assert len(res) == 4 78 | 79 | first_batch = res[0] 80 | 81 | assert "test" in first_batch["data"] 82 | 83 | tensor = first_batch["data"]["test"] 84 | 85 | assert tensor.shape == (3, TEST_EMB_SIZE) 86 | 87 | 88 | def test_model_save_and_load(): 89 | tempdir = tempfile.TemporaryDirectory() 90 | model = SimilarityModel(encoders={"test": CustomEncoder()}, head=EmptyHead(100)) 91 | 92 | model.save(tempdir.name) 93 | 94 | config_path = os.path.join(tempdir.name, "config.json") 95 | 96 | assert os.path.exists(config_path) 97 | 98 | loaded_model = SimilarityModel.load(tempdir.name) 99 | 100 | assert model.encoders.keys() == loaded_model.encoders.keys() 101 | assert [type(encoder) for encoder in model.encoders.values()] == [ 102 | type(encoder) for encoder in loaded_model.encoders.values() 103 | ] 104 | 105 | assert type(model.head) == type(loaded_model.head) 106 | -------------------------------------------------------------------------------- /tests/utils/test_classes.py: -------------------------------------------------------------------------------- 1 | from quaterion_models.utils import restore_class, save_class_import 2 | 3 | 4 | def test_restore_class(): 5 | model_class = 
restore_class({"module": "torch.nn", "class": "Linear"}) 6 | 7 | from torch.nn import Linear 8 | 9 | model: Linear = model_class(10, 10) 10 | 11 | assert model.out_features == 10 12 | 13 | 14 | def test_save_class_import(): 15 | from collections import Counter 16 | 17 | class_serialized = save_class_import(Counter()) 18 | 19 | assert class_serialized["class"] == "Counter" 20 | assert class_serialized["module"] == "collections" 21 | --------------------------------------------------------------------------------