├── .github
│   └── workflows
│       ├── build-docs.yml
│       ├── python-publish.yml
│       └── test.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── docs
│   ├── .gitignore
│   ├── generate_docs.sh
│   ├── generate_docs_netlify.sh
│   ├── html
│   │   └── .keep
│   └── source
│       ├── api
│       │   └── index.rst
│       ├── conf.py
│       ├── index.rst
│       ├── modules.rst
│       ├── quaterion_models.encoders.encoder.rst
│       ├── quaterion_models.encoders.extras.fasttext_encoder.rst
│       ├── quaterion_models.encoders.extras.rst
│       ├── quaterion_models.encoders.rst
│       ├── quaterion_models.encoders.switch_encoder.rst
│       ├── quaterion_models.heads.empty_head.rst
│       ├── quaterion_models.heads.encoder_head.rst
│       ├── quaterion_models.heads.gated_head.rst
│       ├── quaterion_models.heads.rst
│       ├── quaterion_models.heads.sequential_head.rst
│       ├── quaterion_models.heads.skip_connection_head.rst
│       ├── quaterion_models.heads.softmax_head.rst
│       ├── quaterion_models.heads.stacked_projection_head.rst
│       ├── quaterion_models.heads.widening_head.rst
│       ├── quaterion_models.model.rst
│       ├── quaterion_models.modules.rst
│       ├── quaterion_models.modules.simple.rst
│       ├── quaterion_models.rst
│       ├── quaterion_models.types.rst
│       ├── quaterion_models.utils.classes.rst
│       ├── quaterion_models.utils.rst
│       └── quaterion_models.utils.tensors.rst
├── netlify.toml
├── poetry.lock
├── pyproject.toml
├── quaterion_models
│   ├── __init__.py
│   ├── encoders
│   │   ├── __init__.py
│   │   ├── encoder.py
│   │   ├── extras
│   │   │   ├── __init__.py
│   │   │   └── fasttext_encoder.py
│   │   └── switch_encoder.py
│   ├── heads
│   │   ├── __init__.py
│   │   ├── empty_head.py
│   │   ├── encoder_head.py
│   │   ├── gated_head.py
│   │   ├── sequential_head.py
│   │   ├── skip_connection_head.py
│   │   ├── softmax_head.py
│   │   ├── stacked_projection_head.py
│   │   ├── switch_head.py
│   │   └── widening_head.py
│   ├── model.py
│   ├── modules
│   │   ├── __init__.py
│   │   └── simple.py
│   ├── types
│   │   └── __init__.py
│   └── utils
│       ├── __init__.py
│       ├── classes.py
│       ├── meta.py
│       └── tensors.py
└── tests
    ├── __init__.py
    ├── encoders
    │   ├── test_fasttext_encoder.py
    │   ├── test_switch_encoder.py
    │   └── test_switch_head.py
    ├── heads
    │   └── test_head.py
    ├── test_model.py
    └── utils
        └── test_classes.py
/.github/workflows/build-docs.yml:
--------------------------------------------------------------------------------
1 | name: build-docs
2 | on:
3 |   pull_request_target:
4 |     types:
5 |       - closed
6 |
7 | jobs:
8 |   build:
9 |     if: github.event.pull_request.merged == true
10 |     runs-on: ubuntu-latest
11 |     steps:
12 |       - uses: actions/checkout@v2
13 |         with:
14 |           fetch-depth: 0
15 |       - name: Set up Python
16 |         uses: actions/setup-python@v2
17 |         with:
18 |           python-version: '3.9.x'
19 |       - name: Install dependencies
20 |         run: |
21 |           python -m pip install poetry
22 |           poetry install -E fasttext
23 |           poetry run pip install 'setuptools==59.5.0' # temporary fix for https://github.com/pytorch/pytorch/pull/69904
24 |       - name: Generate docs
25 |         run: |
26 |           bash -x docs/generate_docs.sh
27 |           git config user.name 'qdrant'
28 |           git config user.email 'qdrant@users.noreply.github.com'
29 |           git config pull.rebase false
30 |           git remote set-url origin https://x-access-token:${{ secrets.GITHUB_TOKEN }}@github.com/$GITHUB_REPOSITORY
31 |           git checkout $GITHUB_HEAD_REF
32 |           git add ./docs && git commit -m "docs: auto-generate docs with sphinx" && git pull && git push || true
33 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 |   workflow_dispatch:
13 |   push:
14 |     # Pattern matched against refs/tags
15 |     tags:
16 |       - 'v*' # Push events to every version tag
17 |
18 |
19 | jobs:
20 |   deploy:
21 |
22 |     runs-on: ubuntu-latest
23 |
24 |     steps:
25 |       - uses: actions/checkout@v2
26 |       - name: Set up Python
27 |         uses: actions/setup-python@v2
28 |         with:
29 |           python-version: '3.9.x'
30 |       - name: Install dependencies
31 |         run: |
32 |           python -m pip install poetry
33 |           poetry install
34 |       - name: Build package
35 |         run: poetry build
36 |       - name: Publish package
37 |         uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
38 |         with:
39 |           user: __token__
40 |           password: ${{ secrets.PYPI_API_TOKEN }}
41 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Test
2 |
3 | on: [push]
4 |
5 | jobs:
6 |   build:
7 |     runs-on: ${{ matrix.os }}
8 |     strategy:
9 |       fail-fast: true
10 |       matrix:
11 |         python-version: [3.8, 3.9, "3.10"]
12 |         os: [ubuntu-latest] # [ubuntu-latest, macOS-latest, windows-latest]
13 |
14 |     steps:
15 |       - uses: actions/checkout@v1
16 |
17 |       - name: Set up Python ${{ matrix.python-version }}
18 |         uses: actions/setup-python@v2
19 |         with:
20 |           python-version: ${{ matrix.python-version }}
21 |
22 |       - name: Install dependencies
23 |         run: |
24 |           python -m pip install --upgrade pip poetry
25 |           poetry install -E fasttext
26 |           poetry run pip install 'setuptools==59.5.0' # temporary fix for https://github.com/pytorch/pytorch/pull/69904
27 |
28 |       - name: Unit tests
29 |         run: |
30 |           poetry run pytest
31 |
--------------------------------------------------------------------------------
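The test workflow above maps directly onto a local development loop. A minimal sketch of the same steps run by hand (assuming Python and Poetry are available locally, exactly as in the CI job):

```bash
# Mirror the CI steps from .github/workflows/test.yml on a local machine
python -m pip install --upgrade pip poetry
poetry install -E fasttext                   # include the optional fasttext extra
poetry run pip install 'setuptools==59.5.0'  # same temporary pin as CI (pytorch/pytorch#69904)
poetry run pytest                            # run the unit tests under tests/
```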
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | .idea/
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100 | __pypackages__/
101 |
102 | # Celery stuff
103 | celerybeat-schedule
104 | celerybeat.pid
105 |
106 | # SageMath parsed files
107 | *.sage.py
108 |
109 | # Environments
110 | .env
111 | .venv
112 | env/
113 | venv/
114 | ENV/
115 | env.bak/
116 | venv.bak/
117 |
118 | # Spyder project settings
119 | .spyderproject
120 | .spyproject
121 |
122 | # Rope project settings
123 | .ropeproject
124 |
125 | # mkdocs documentation
126 | /site
127 |
128 | # mypy
129 | .mypy_cache/
130 | .dmypy.json
131 | dmypy.json
132 |
133 | # Pyre type checker
134 | .pyre/
135 |
136 | # pytype static type analyzer
137 | .pytype/
138 |
139 | # Cython debug symbols
140 | cython_debug/
141 |
142 | # Do not include auto-generated setup.py
143 | /setup.py
144 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | default_language_version:
4 |   python: python3.8
5 |
6 | ci:
7 |   autofix_prs: true
8 |   autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
9 |   autoupdate_schedule: quarterly
10 |   # submodules: true
11 |
12 | repos:
13 |   - repo: https://github.com/pre-commit/pre-commit-hooks
14 |     rev: v4.4.0
15 |     hooks:
16 |       - id: trailing-whitespace
17 |       - id: end-of-file-fixer
18 |       - id: check-added-large-files
19 |
20 |   - repo: https://github.com/psf/black
21 |     rev: 22.12.0
22 |     hooks:
23 |       - id: black
24 |         name: "Black: The uncompromising Python code formatter"
25 |
26 |   - repo: https://github.com/PyCQA/isort
27 |     rev: 5.11.4
28 |     hooks:
29 |       - id: isort
30 |         name: "Sort Imports"
31 |         args: ["--profile", "black"]
32 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | email: team@qdrant.tech.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Quaterion Models
2 | We love your input! We want to make contributing to this project as easy and transparent as possible, whether it's:
3 |
4 | - Reporting a bug
5 | - Discussing the current state of the code
6 | - Submitting a fix
7 | - Proposing new features
8 |
9 | ## We Develop with GitHub
10 | We use GitHub to host code, track issues and feature requests, and accept pull requests.
11 |
12 | ## We Use [Github Flow](https://guides.github.com/introduction/flow/index.html), So All Code Changes Happen Through Pull Requests
13 | Pull requests are the best way to propose changes to the codebase (we use [Github Flow](https://guides.github.com/introduction/flow/index.html)).
14 | We actively welcome your pull requests:
15 |
16 | 1. Fork the repo and create your branch from `master`.
17 | 2. If you've added code that should be tested, add tests.
18 | 3. Ensure the test suite passes.
19 | 4. Make sure your code lints (ToDo).
20 | 5. Make sure that commits reference the related issue (e.g. `Fix model training #num_of_issue`)
21 | 6. Issue that pull request!
22 |
23 | ## Any contributions you make will be under the Apache License 2.0
24 | In short, when you submit code changes, your submissions are understood to be under the same [Apache License 2.0](https://choosealicense.com/licenses/apache-2.0/) that covers the project. Feel free to contact the maintainers if that's a concern.
25 |
26 | ## Report bugs using Github's [issues](https://github.com/qdrant/quaterion-models/issues)
27 | We use GitHub issues to track public bugs. Report a bug by [opening a new issue](https://github.com/qdrant/quaterion-models/issues/new); it's that easy!
28 |
29 | ## Write bug reports with detail, background, and sample code
30 |
31 | **Great Bug Reports** tend to have:
32 |
33 | - A quick summary and/or background
34 | - Steps to reproduce
35 |   - Be specific!
36 |   - Give sample code if you can.
37 | - What you expected would happen
38 | - What actually happens
39 | - Notes (possibly including why you think this might be happening, or stuff you tried that didn't work)
40 |
41 | ## Coding Style
42 |
43 | 1. We use the [PEP8](https://www.python.org/dev/peps/pep-0008/) code style
44 | 2. We use [Python Type Annotations](https://docs.python.org/3/library/typing.html) whenever necessary
45 |    1. If your IDE cannot infer the type of some variable, that is a good sign to add a type annotation
46 | 3. We document tensor transformations - tensor types alone are usually not enough for a comfortable understanding of the code (see the short sketch after this file)
47 | 4. We prefer simplicity and a practical approach over Kaggle-level state-of-the-art accuracy
48 |    1. If some modules or layers have a complicated interface, complicated dependencies, or are just very complicated internally, we prefer to keep them outside Quaterion Models.
49 |
50 | ## License
51 | By contributing, you agree that your contributions will be licensed under its Apache License 2.0.
52 |
--------------------------------------------------------------------------------
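The Coding Style section above asks for type annotations and documented tensor transformations. A minimal, hypothetical sketch of what that looks like in practice (the function, its name, and the shapes are invented for illustration and are not part of `quaterion_models`):

```python
import torch
from torch import Tensor


def average_pool(token_embeddings: Tensor, mask: Tensor) -> Tensor:
    """Average token embeddings, ignoring padded positions.

    Shapes are documented explicitly, as the contribution guide requests:
        token_embeddings: (batch_size, seq_len, embedding_size)
        mask:             (batch_size, seq_len) - 1 for real tokens, 0 for padding
        returns:          (batch_size, embedding_size)
    """
    # (batch_size, seq_len) -> (batch_size, seq_len, 1) so it broadcasts over embeddings
    expanded_mask = mask.unsqueeze(-1).type_as(token_embeddings)
    # sum over the sequence dimension, then divide by the number of real tokens
    summed = (token_embeddings * expanded_mask).sum(dim=1)
    counts = expanded_mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
```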
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Quaterion Models
2 |
3 | `quaterion-models` is a part of [`Quaterion`](https://github.com/qdrant/quaterion), a similarity learning framework.
4 | It is kept as a separate package to make servable models lightweight and free from training dependencies.
5 |
6 | It contains definitions of the base classes used for model inference, as well as a collection of building blocks for fine-tunable similarity learning models.
7 | The documentation can be found [here](https://quaterion-models.qdrant.tech/).
8 |
9 | If you are looking for the training-related part of Quaterion, please see the [main repository](https://github.com/qdrant/quaterion) instead.
10 |
11 | ## Install
12 |
13 | ```bash
14 | pip install quaterion-models
15 | ```
16 |
17 | It makes sense to install `quaterion-models` independently of the main framework if you already have a trained
18 | model and only need to run inference.
19 |
20 | ## Load and inference
21 |
22 | ```python
23 | from quaterion_models import SimilarityModel
24 |
25 | model = SimilarityModel.load("./path/to/saved/model")
26 |
27 | embeddings = model.encode([
28 |     {"description": "this is an example input"},
29 |     {"description": "you may have a different format"},
30 |     {"description": "the output will be a numpy array"},
31 |     {"description": "of size [batch_size, embedding_size]"},
32 | ])
33 | ```
34 |
35 | ## Content
36 |
37 | * `SimilarityModel` - the main class, which combines encoder models with a head layer
38 | * Base class for Encoders
39 | * Base class and various implementations of the Head Layers
40 | * Additional helper functions
41 |
--------------------------------------------------------------------------------
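The `encode` call in the README returns a NumPy array of size `[batch_size, embedding_size]`. A minimal sketch that restates the example and checks that shape (the model path is hypothetical and assumes a model was saved earlier, e.g. after training with Quaterion):

```python
from quaterion_models import SimilarityModel

# Hypothetical path: assumes a previously saved SimilarityModel
model = SimilarityModel.load("./path/to/saved/model")

batch = [
    {"description": "first example"},
    {"description": "second example"},
]

embeddings = model.encode(batch)
# Per the README, the result is a numpy array of size [batch_size, embedding_size]
print(embeddings.shape)  # -> (2, <embedding_size of the loaded model>)
```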
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | .doctrees/
2 | html/
3 |
--------------------------------------------------------------------------------
/docs/generate_docs.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | set -e
4 |
5 | # Ensure current path is project root
6 | cd "$(dirname "$0")/../"
7 |
8 | poetry run sphinx-apidoc -f -e -o docs/source quaterion_models
9 | poetry run sphinx-build docs/source docs/html
10 |
--------------------------------------------------------------------------------
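This is the same script invoked by the build-docs workflow earlier in the repo. A local documentation build might look like the following (assuming the project is installed with Poetry, as in CI):

```bash
# Install the project (including the optional fasttext extra), then build the docs.
poetry install -E fasttext
bash -x docs/generate_docs.sh   # regenerates the API stubs and writes HTML to docs/html/
```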
/docs/generate_docs_netlify.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | set -xe
4 |
5 | # Ensure current path is project root
6 | cd "$(dirname "$0")/../"
7 |
8 | # install CPU torch, cause it is smaller
9 | pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
10 |
11 | pip install poetry
12 | poetry build -f wheel
13 | pip install dist/$(ls -1 dist | grep .whl)
14 | poetry install --extras "fasttext"
15 |
16 | pip install sphinx>=4.4.0
17 | pip install "git+https://github.com/qdrant/qdrant_sphinx_theme.git@master#egg=qdrant-sphinx-theme"
18 |
19 | sphinx-build docs/source docs/html
20 |
--------------------------------------------------------------------------------
/docs/html/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qdrant/quaterion-models/f5f97a27f8d9b35194fe316592bd70c4852048f7/docs/html/.keep
--------------------------------------------------------------------------------
/docs/source/api/index.rst:
--------------------------------------------------------------------------------
1 | API References
2 | ~~~~~~~~~~~~~~
3 |
4 | Similarity Model
5 | ----------------
6 |
7 | .. py:currentmodule:: quaterion_models.model
8 |
9 | .. autosummary::
10 |    :nosignatures:
11 |
12 |    SimilarityModel
13 |
14 | Encoders
15 | --------
16 |
17 | .. py:currentmodule:: quaterion_models.encoders
18 |
19 | .. autosummary::
20 |    :nosignatures:
21 |
22 |    ~encoder.Encoder
23 |    ~switch_encoder.SwitchEncoder
24 |    ~extras.fasttext_encoder.FasttextEncoder
25 |
26 |
27 | Head Layers
28 | -----------
29 |
30 | .. py:currentmodule:: quaterion_models.heads
31 |
32 | .. autosummary::
33 |    :nosignatures:
34 |
35 |    ~empty_head.EmptyHead
36 |    ~gated_head.GatedHead
37 |    ~sequential_head.SequentialHead
38 |    ~skip_connection_head.SkipConnectionHead
39 |    ~softmax_head.SoftmaxEmbeddingsHead
40 |    ~stacked_projection_head.StackedProjectionHead
41 |    ~widening_head.WideningHead
42 |
--------------------------------------------------------------------------------
/docs/source/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | # import os
14 | # import sys
15 | # sys.path.insert(0, os.path.abspath('.'))
16 |
17 |
18 | # -- Project information -----------------------------------------------------
19 |
20 | project = "Quaterion Models"
21 | copyright = "2022, Quaterion Models Authors"
22 | author = "Quaterion Models Authors"
23 |
24 |
25 | # -- General configuration ---------------------------------------------------
26 |
27 | # Add any Sphinx extension module names here, as strings. They can be
28 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
29 | # ones.
30 | extensions = [
31 |     "sphinx.ext.napoleon",
32 |     "sphinx.ext.autodoc",
33 |     "sphinx.ext.viewcode",
34 |     "sphinx.ext.intersphinx",
35 |     "sphinx.ext.autosummary",
36 | ]
37 |
38 | # mapping to other sphinx doc
39 | # tuple: (target, inventory)
40 | # Each target is the base URI of a foreign Sphinx documentation set and can be a local path or an
41 | # HTTP URI. The inventory indicates where the inventory file can be found: it can be None (an
42 | # objects.inv file at the same location as the base URI) or another local file path or a full
43 | # HTTP URI to an inventory file.
44 | intersphinx_mapping = {
45 |     "gensim": ("https://radimrehurek.com/gensim/", None),
46 | }
47 |
48 | # prevents sphinx from adding full path to type hints
49 | autodoc_typehints_format = "short"
50 |
51 | # order members by type and not alphabetically, it prevents mixing of class attributes
52 | # and methods
53 | autodoc_member_order = "groupwise"
54 |
55 | # moves ``Return type`` to ``Returns``
56 | napoleon_use_rtype = False
57 |
58 | # If true, suppress the module name of the python reference if it can be resolved.
59 | # Experimental feature:
60 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#confval-python_use_unqualified_type_names
61 | python_use_unqualified_type_names = True
62 |
63 | # prevents sphinx from adding the full module path in titles
64 | add_module_names = False
65 |
66 | # Add any paths that contain templates here, relative to this directory.
67 | templates_path = ["_templates"]
68 |
69 | # prevents unfolding type hints
70 | autodoc_type_aliases = {
71 |     "TensorInterchange": "TensorInterchange",
72 |     "CollateFnType": "CollateFnType",
73 | }
74 |
75 | # List of patterns, relative to source directory, that match files and
76 | # directories to ignore when looking for source files.
77 | # This pattern also affects html_static_path and html_extra_path.
78 | exclude_patterns = []
79 | # -- Options for HTML output -------------------------------------------------
80 |
81 | # The theme to use for HTML and HTML Help pages. See the documentation for
82 | # a list of builtin themes.
83 | #
84 |
85 | html_theme = "qdrant_sphinx_theme"
86 |
87 | # Add any paths that contain custom static files (such as style sheets) here,
88 | # relative to this directory. They are copied after the builtin static files,
89 | # so a file named "default.css" will overwrite the builtin "default.css".
90 | # html_static_path = []
91 |
92 | # Files excluded via exclude_patterns are still generated by sphinx-apidoc.
93 | # As they are generated, some documents have links to them. It leads to a warning like:
94 | # `WARNING: toctree contains reference to excluded document '...'`.
95 | # suppress_warnings allows removing such warnings
96 | suppress_warnings = ["toc.excluded"]
97 |
98 | html_theme_options = {
99 |     # google analytics can be added here
100 |     "logo_only": False,
101 |     "display_version": True,
102 |     "prev_next_buttons_location": "bottom",
103 |     "style_external_links": False,
104 |     # Toc options
105 |     "collapse_navigation": True,
106 |     "sticky_navigation": True,
107 |     "titles_only": False,
108 |     "qdrant_project": "quaterion-models",
109 |     "qdrant_logo": "/_static/images/quaterion_logo_horisontal_opt.svg",
110 | }
111 |
112 | # default is false
113 | _FAST_DOCS_DEV = False
114 |
--------------------------------------------------------------------------------
/docs/source/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to quaterion-models's documentation!
2 | ============================================
3 |
4 |
5 | ``quaterion-models`` is a part of
6 | `Quaterion <https://github.com/qdrant/quaterion>`_, a similarity
7 | learning framework. It is kept as a separate package to make servable
8 | models lightweight and free from training dependencies.
9 |
10 | It contains definitions of the base classes used for model inference, as
11 | well as a collection of building blocks for fine-tunable
12 | similarity learning models.
13 |
14 | If you are looking for the training-related part of Quaterion, please
15 | see the `main repository <https://github.com/qdrant/quaterion>`_
16 | instead.
17 |
18 | Install
19 | -------
20 |
21 | .. code:: bash
22 |
23 |    pip install quaterion-models
24 |
25 | It makes sense to install ``quaterion-models`` independently of the main
26 | framework if you already have a trained model and only need to run
27 | inference.
28 |
29 | Load and inference
30 | ------------------
31 |
32 | .. code:: python
33 |
34 |    from quaterion_models import SimilarityModel
35 |
36 |    model = SimilarityModel.load("./path/to/saved/model")
37 |
38 |    embeddings = model.encode([
39 |        {"description": "this is an example input"},
40 |        {"description": "you may have a different format"},
41 |        {"description": "the output will be a numpy array"},
42 |        {"description": "of size [batch_size, embedding_size]"},
43 |    ])
44 |
45 | Content
46 | -------
47 |
48 | .. toctree::
49 |    :maxdepth: 1
50 |
51 |    api/index
52 |
53 |
54 | Indices and tables
55 | ==================
56 |
57 | * :ref:`genindex`
58 | * :ref:`modindex`
59 |
--------------------------------------------------------------------------------
/docs/source/modules.rst:
--------------------------------------------------------------------------------
1 | quaterion_models
2 | ================
3 |
4 | .. toctree::
5 |    :maxdepth: 4
6 |
7 |    quaterion_models
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.encoders.encoder.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.encoders.encoder module
2 | =========================================
3 |
4 | .. automodule:: quaterion_models.encoders.encoder
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.encoders.extras.fasttext_encoder.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.encoders.extras.fasttext\_encoder module
2 | ==========================================================
3 |
4 | .. automodule:: quaterion_models.encoders.extras.fasttext_encoder
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.encoders.extras.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.encoders.extras package
2 | =========================================
3 |
4 | Submodules
5 | ----------
6 |
7 | .. toctree::
8 |    :maxdepth: 4
9 |
10 |    quaterion_models.encoders.extras.fasttext_encoder
11 |
12 | Module contents
13 | ---------------
14 |
15 | .. automodule:: quaterion_models.encoders.extras
16 |    :members:
17 |    :undoc-members:
18 |    :show-inheritance:
19 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.encoders.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.encoders package
2 | ==================================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |    :maxdepth: 4
9 |
10 |    quaterion_models.encoders.extras
11 |
12 | Submodules
13 | ----------
14 |
15 | .. toctree::
16 |    :maxdepth: 4
17 |
18 |    quaterion_models.encoders.encoder
19 |    quaterion_models.encoders.switch_encoder
20 |
21 | Module contents
22 | ---------------
23 |
24 | .. automodule:: quaterion_models.encoders
25 |    :members:
26 |    :undoc-members:
27 |    :show-inheritance:
28 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.encoders.switch_encoder.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.encoders.switch\_encoder module
2 | =================================================
3 |
4 | .. automodule:: quaterion_models.encoders.switch_encoder
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.heads.empty_head.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.heads.empty\_head module
2 | ==========================================
3 |
4 | .. automodule:: quaterion_models.heads.empty_head
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.heads.encoder_head.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.heads.encoder\_head module
2 | ============================================
3 |
4 | .. automodule:: quaterion_models.heads.encoder_head
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.heads.gated_head.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.heads.gated\_head module
2 | ==========================================
3 |
4 | .. automodule:: quaterion_models.heads.gated_head
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.heads.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.heads package
2 | ===============================
3 |
4 | Submodules
5 | ----------
6 |
7 | .. toctree::
8 |    :maxdepth: 4
9 |
10 |    quaterion_models.heads.empty_head
11 |    quaterion_models.heads.encoder_head
12 |    quaterion_models.heads.gated_head
13 |    quaterion_models.heads.sequential_head
14 |    quaterion_models.heads.skip_connection_head
15 |    quaterion_models.heads.softmax_head
16 |    quaterion_models.heads.stacked_projection_head
17 |    quaterion_models.heads.widening_head
18 |
19 | Module contents
20 | ---------------
21 |
22 | .. automodule:: quaterion_models.heads
23 |    :members:
24 |    :undoc-members:
25 |    :show-inheritance:
26 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.heads.sequential_head.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.heads.sequential\_head module
2 | ===============================================
3 |
4 | .. automodule:: quaterion_models.heads.sequential_head
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.heads.skip_connection_head.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.heads.skip\_connection\_head module
2 | =====================================================
3 |
4 | .. automodule:: quaterion_models.heads.skip_connection_head
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.heads.softmax_head.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.heads.softmax\_head module
2 | ============================================
3 |
4 | .. automodule:: quaterion_models.heads.softmax_head
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.heads.stacked_projection_head.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.heads.stacked\_projection\_head module
2 | ========================================================
3 |
4 | .. automodule:: quaterion_models.heads.stacked_projection_head
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.heads.widening_head.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.heads.widening\_head module
2 | =============================================
3 |
4 | .. automodule:: quaterion_models.heads.widening_head
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.model.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.model module
2 | ==============================
3 |
4 | .. automodule:: quaterion_models.model
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.modules.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.modules package
2 | =================================
3 |
4 | Submodules
5 | ----------
6 |
7 | .. toctree::
8 |    :maxdepth: 4
9 |
10 |    quaterion_models.modules.simple
11 |
12 | Module contents
13 | ---------------
14 |
15 | .. automodule:: quaterion_models.modules
16 |    :members:
17 |    :undoc-members:
18 |    :show-inheritance:
19 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.modules.simple.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.modules.simple module
2 | =======================================
3 |
4 | .. automodule:: quaterion_models.modules.simple
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models package
2 | =========================
3 |
4 | Subpackages
5 | -----------
6 |
7 | .. toctree::
8 |    :maxdepth: 4
9 |
10 |    quaterion_models.encoders
11 |    quaterion_models.heads
12 |    quaterion_models.modules
13 |    quaterion_models.types
14 |    quaterion_models.utils
15 |
16 | Submodules
17 | ----------
18 |
19 | .. toctree::
20 |    :maxdepth: 4
21 |
22 |    quaterion_models.model
23 |
24 | Module contents
25 | ---------------
26 |
27 | .. automodule:: quaterion_models
28 |    :members:
29 |    :undoc-members:
30 |    :show-inheritance:
31 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.types.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.types package
2 | ===============================
3 |
4 | Module contents
5 | ---------------
6 |
7 | .. automodule:: quaterion_models.types
8 |    :members:
9 |    :undoc-members:
10 |    :show-inheritance:
11 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.utils.classes.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.utils.classes module
2 | ======================================
3 |
4 | .. automodule:: quaterion_models.utils.classes
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.utils.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.utils package
2 | ===============================
3 |
4 | Submodules
5 | ----------
6 |
7 | .. toctree::
8 |    :maxdepth: 4
9 |
10 |    quaterion_models.utils.classes
11 |    quaterion_models.utils.tensors
12 |
13 | Module contents
14 | ---------------
15 |
16 | .. automodule:: quaterion_models.utils
17 |    :members:
18 |    :undoc-members:
19 |    :show-inheritance:
20 |
--------------------------------------------------------------------------------
/docs/source/quaterion_models.utils.tensors.rst:
--------------------------------------------------------------------------------
1 | quaterion\_models.utils.tensors module
2 | ======================================
3 |
4 | .. automodule:: quaterion_models.utils.tensors
5 |    :members:
6 |    :undoc-members:
7 |    :show-inheritance:
8 |
--------------------------------------------------------------------------------
/netlify.toml:
--------------------------------------------------------------------------------
1 | [build]
2 | publish = "docs/html"
3 | command = "bash -x docs/generate_docs_netlify.sh"
4 |
--------------------------------------------------------------------------------
/poetry.lock:
--------------------------------------------------------------------------------
1 | [[package]]
2 | name = "alabaster"
3 | version = "0.7.12"
4 | description = "A configurable sidebar-enabled Sphinx theme"
5 | category = "dev"
6 | optional = false
7 | python-versions = "*"
8 |
9 | [[package]]
10 | name = "atomicwrites"
11 | version = "1.4.1"
12 | description = "Atomic file writes."
13 | category = "dev"
14 | optional = false
15 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
16 |
17 | [[package]]
18 | name = "attrs"
19 | version = "22.1.0"
20 | description = "Classes Without Boilerplate"
21 | category = "dev"
22 | optional = false
23 | python-versions = ">=3.5"
24 |
25 | [package.extras]
26 | dev = ["cloudpickle", "coverage[toml] (>=5.0.2)", "furo", "hypothesis", "mypy (>=0.900,!=0.940)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "sphinx", "sphinx-notfound-page", "zope.interface"]
27 | docs = ["furo", "sphinx", "sphinx-notfound-page", "zope.interface"]
28 | tests = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "zope.interface"]
29 | tests_no_zope = ["cloudpickle", "coverage[toml] (>=5.0.2)", "hypothesis", "mypy (>=0.900,!=0.940)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins"]
30 |
31 | [[package]]
32 | name = "babel"
33 | version = "2.10.3"
34 | description = "Internationalization utilities"
35 | category = "dev"
36 | optional = false
37 | python-versions = ">=3.6"
38 |
39 | [package.dependencies]
40 | pytz = ">=2015.7"
41 |
42 | [[package]]
43 | name = "black"
44 | version = "22.6.0"
45 | description = "The uncompromising code formatter."
46 | category = "dev"
47 | optional = false
48 | python-versions = ">=3.6.2"
49 |
50 | [package.dependencies]
51 | click = ">=8.0.0"
52 | mypy-extensions = ">=0.4.3"
53 | pathspec = ">=0.9.0"
54 | platformdirs = ">=2"
55 | tomli = {version = ">=1.1.0", markers = "python_full_version < \"3.11.0a7\""}
56 | typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}
57 |
58 | [package.extras]
59 | colorama = ["colorama (>=0.4.3)"]
60 | d = ["aiohttp (>=3.7.4)"]
61 | jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
62 | uvloop = ["uvloop (>=0.15.2)"]
63 |
64 | [[package]]
65 | name = "certifi"
66 | version = "2022.6.15"
67 | description = "Python package for providing Mozilla's CA Bundle."
68 | category = "dev"
69 | optional = false
70 | python-versions = ">=3.6"
71 |
72 | [[package]]
73 | name = "charset-normalizer"
74 | version = "2.1.1"
75 | description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
76 | category = "dev"
77 | optional = false
78 | python-versions = ">=3.6.0"
79 |
80 | [package.extras]
81 | unicode_backport = ["unicodedata2"]
82 |
83 | [[package]]
84 | name = "click"
85 | version = "8.1.3"
86 | description = "Composable command line interface toolkit"
87 | category = "dev"
88 | optional = false
89 | python-versions = ">=3.7"
90 |
91 | [package.dependencies]
92 | colorama = {version = "*", markers = "platform_system == \"Windows\""}
93 |
94 | [[package]]
95 | name = "colorama"
96 | version = "0.4.5"
97 | description = "Cross-platform colored terminal text."
98 | category = "dev"
99 | optional = false
100 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
101 |
102 | [[package]]
103 | name = "docutils"
104 | version = "0.19"
105 | description = "Docutils -- Python Documentation Utilities"
106 | category = "dev"
107 | optional = false
108 | python-versions = ">=3.7"
109 |
110 | [[package]]
111 | name = "gensim"
112 | version = "4.2.0"
113 | description = "Python framework for fast Vector Space Modelling"
114 | category = "main"
115 | optional = true
116 | python-versions = ">=3.6"
117 |
118 | [package.dependencies]
119 | numpy = ">=1.17.0"
120 | scipy = ">=0.18.1"
121 | smart-open = ">=1.8.1"
122 |
123 | [package.extras]
124 | distributed = ["Pyro4 (>=4.27)"]
125 | docs = ["Pyro4 (>=4.27)", "annoy", "cython", "matplotlib", "memory-profiler", "mock", "nltk", "nmslib", "pandas", "pyemd", "pyro4", "pytest", "pytest-cov", "sphinx", "sphinx-gallery", "sphinxcontrib-napoleon", "sphinxcontrib.programoutput", "statsmodels", "testfixtures", "visdom (>0.1.8.7)"]
126 | test = ["cython", "mock", "nmslib", "pyemd", "pytest", "pytest-cov", "testfixtures", "visdom (>0.1.8.7)"]
127 | test-win = ["cython", "mock", "nmslib", "pyemd", "pytest", "pytest-cov", "testfixtures"]
128 |
129 | [[package]]
130 | name = "idna"
131 | version = "3.3"
132 | description = "Internationalized Domain Names in Applications (IDNA)"
133 | category = "dev"
134 | optional = false
135 | python-versions = ">=3.5"
136 |
137 | [[package]]
138 | name = "imagesize"
139 | version = "1.4.1"
140 | description = "Getting image size from png/jpeg/jpeg2000/gif file"
141 | category = "dev"
142 | optional = false
143 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
144 |
145 | [[package]]
146 | name = "importlib-metadata"
147 | version = "4.12.0"
148 | description = "Read metadata from Python packages"
149 | category = "dev"
150 | optional = false
151 | python-versions = ">=3.7"
152 |
153 | [package.dependencies]
154 | zipp = ">=0.5"
155 |
156 | [package.extras]
157 | docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"]
158 | perf = ["ipython"]
159 | testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"]
160 |
161 | [[package]]
162 | name = "iniconfig"
163 | version = "1.1.1"
164 | description = "iniconfig: brain-dead simple config-ini parsing"
165 | category = "dev"
166 | optional = false
167 | python-versions = "*"
168 |
169 | [[package]]
170 | name = "jinja2"
171 | version = "3.1.2"
172 | description = "A very fast and expressive template engine."
173 | category = "dev"
174 | optional = false
175 | python-versions = ">=3.7"
176 |
177 | [package.dependencies]
178 | MarkupSafe = ">=2.0"
179 |
180 | [package.extras]
181 | i18n = ["Babel (>=2.7)"]
182 |
183 | [[package]]
184 | name = "markupsafe"
185 | version = "2.1.1"
186 | description = "Safely add untrusted strings to HTML/XML markup."
187 | category = "dev"
188 | optional = false
189 | python-versions = ">=3.7"
190 |
191 | [[package]]
192 | name = "mypy-extensions"
193 | version = "0.4.3"
194 | description = "Experimental type system extensions for programs checked with the mypy typechecker."
195 | category = "dev"
196 | optional = false
197 | python-versions = "*"
198 |
199 | [[package]]
200 | name = "numpy"
201 | version = "1.23.2"
202 | description = "NumPy is the fundamental package for array computing with Python."
203 | category = "main"
204 | optional = false
205 | python-versions = ">=3.8"
206 |
207 | [[package]]
208 | name = "packaging"
209 | version = "21.3"
210 | description = "Core utilities for Python packages"
211 | category = "dev"
212 | optional = false
213 | python-versions = ">=3.6"
214 |
215 | [package.dependencies]
216 | pyparsing = ">=2.0.2,<3.0.5 || >3.0.5"
217 |
218 | [[package]]
219 | name = "pathspec"
220 | version = "0.9.0"
221 | description = "Utility library for gitignore style pattern matching of file paths."
222 | category = "dev"
223 | optional = false
224 | python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
225 |
226 | [[package]]
227 | name = "platformdirs"
228 | version = "2.5.2"
229 | description = "A small Python module for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
230 | category = "dev"
231 | optional = false
232 | python-versions = ">=3.7"
233 |
234 | [package.extras]
235 | docs = ["furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx (>=4)", "sphinx-autodoc-typehints (>=1.12)"]
236 | test = ["appdirs (==1.4.4)", "pytest (>=6)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)"]
237 |
238 | [[package]]
239 | name = "pluggy"
240 | version = "1.0.0"
241 | description = "plugin and hook calling mechanisms for python"
242 | category = "dev"
243 | optional = false
244 | python-versions = ">=3.6"
245 |
246 | [package.extras]
247 | dev = ["pre-commit", "tox"]
248 | testing = ["pytest", "pytest-benchmark"]
249 |
250 | [[package]]
251 | name = "py"
252 | version = "1.11.0"
253 | description = "library with cross-python path, ini-parsing, io, code, log facilities"
254 | category = "dev"
255 | optional = false
256 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
257 |
258 | [[package]]
259 | name = "pygments"
260 | version = "2.13.0"
261 | description = "Pygments is a syntax highlighting package written in Python."
262 | category = "dev"
263 | optional = false
264 | python-versions = ">=3.6"
265 |
266 | [package.extras]
267 | plugins = ["importlib-metadata"]
268 |
269 | [[package]]
270 | name = "pyparsing"
271 | version = "3.0.9"
272 | description = "pyparsing module - Classes and methods to define and execute parsing grammars"
273 | category = "dev"
274 | optional = false
275 | python-versions = ">=3.6.8"
276 |
277 | [package.extras]
278 | diagrams = ["jinja2", "railroad-diagrams"]
279 |
280 | [[package]]
281 | name = "pytest"
282 | version = "6.2.5"
283 | description = "pytest: simple powerful testing with Python"
284 | category = "dev"
285 | optional = false
286 | python-versions = ">=3.6"
287 |
288 | [package.dependencies]
289 | atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""}
290 | attrs = ">=19.2.0"
291 | colorama = {version = "*", markers = "sys_platform == \"win32\""}
292 | iniconfig = "*"
293 | packaging = "*"
294 | pluggy = ">=0.12,<2.0"
295 | py = ">=1.8.2"
296 | toml = "*"
297 |
298 | [package.extras]
299 | testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "requests", "xmlschema"]
300 |
301 | [[package]]
302 | name = "pytz"
303 | version = "2022.2.1"
304 | description = "World timezone definitions, modern and historical"
305 | category = "dev"
306 | optional = false
307 | python-versions = "*"
308 |
309 | [[package]]
310 | name = "qdrant-sphinx-theme"
311 | version = "0.0.28"
312 | description = "Qudrant Sphinx Theme"
313 | category = "dev"
314 | optional = false
315 | python-versions = "*"
316 | develop = false
317 |
318 | [package.dependencies]
319 | sphinx = "*"
320 |
321 | [package.source]
322 | type = "git"
323 | url = "https://github.com/qdrant/qdrant_sphinx_theme.git"
324 | reference = "master"
325 | resolved_reference = "a90cdd5925783c2b0ed3b8d39897cd4eaf942e2a"
326 |
327 | [[package]]
328 | name = "requests"
329 | version = "2.28.1"
330 | description = "Python HTTP for Humans."
331 | category = "dev"
332 | optional = false
333 | python-versions = ">=3.7, <4"
334 |
335 | [package.dependencies]
336 | certifi = ">=2017.4.17"
337 | charset-normalizer = ">=2,<3"
338 | idna = ">=2.5,<4"
339 | urllib3 = ">=1.21.1,<1.27"
340 |
341 | [package.extras]
342 | socks = ["PySocks (>=1.5.6,!=1.5.7)"]
343 | use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"]
344 |
345 | [[package]]
346 | name = "scipy"
347 | version = "1.9.1"
348 | description = "SciPy: Scientific Library for Python"
349 | category = "main"
350 | optional = true
351 | python-versions = ">=3.8,<3.12"
352 |
353 | [package.dependencies]
354 | numpy = ">=1.18.5,<1.25.0"
355 |
356 | [[package]]
357 | name = "smart-open"
358 | version = "6.1.0"
359 | description = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)"
360 | category = "main"
361 | optional = true
362 | python-versions = ">=3.6,<4.0"
363 |
364 | [package.extras]
365 | all = ["azure-common", "azure-core", "azure-storage-blob", "boto3", "google-cloud-storage (>=1.31.0)", "requests"]
366 | azure = ["azure-common", "azure-core", "azure-storage-blob"]
367 | gcs = ["google-cloud-storage (>=1.31.0)"]
368 | http = ["requests"]
369 | s3 = ["boto3"]
370 | test = ["azure-common", "azure-core", "azure-storage-blob", "boto3", "google-cloud-storage (>=1.31.0)", "moto", "paramiko", "pathlib2", "pytest", "pytest-rerunfailures", "requests", "responses"]
371 | webhdfs = ["requests"]
372 |
373 | [[package]]
374 | name = "snowballstemmer"
375 | version = "2.2.0"
376 | description = "This package provides 29 stemmers for 28 languages generated from Snowball algorithms."
377 | category = "dev"
378 | optional = false
379 | python-versions = "*"
380 |
381 | [[package]]
382 | name = "sphinx"
383 | version = "5.1.1"
384 | description = "Python documentation generator"
385 | category = "dev"
386 | optional = false
387 | python-versions = ">=3.6"
388 |
389 | [package.dependencies]
390 | alabaster = ">=0.7,<0.8"
391 | babel = ">=1.3"
392 | colorama = {version = ">=0.3.5", markers = "sys_platform == \"win32\""}
393 | docutils = ">=0.14,<0.20"
394 | imagesize = "*"
395 | importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""}
396 | Jinja2 = ">=2.3"
397 | packaging = "*"
398 | Pygments = ">=2.0"
399 | requests = ">=2.5.0"
400 | snowballstemmer = ">=1.1"
401 | sphinxcontrib-applehelp = "*"
402 | sphinxcontrib-devhelp = "*"
403 | sphinxcontrib-htmlhelp = ">=2.0.0"
404 | sphinxcontrib-jsmath = "*"
405 | sphinxcontrib-qthelp = "*"
406 | sphinxcontrib-serializinghtml = ">=1.1.5"
407 |
408 | [package.extras]
409 | docs = ["sphinxcontrib-websupport"]
410 | lint = ["docutils-stubs", "flake8 (>=3.5.0)", "flake8-bugbear", "flake8-comprehensions", "isort", "mypy (>=0.971)", "sphinx-lint", "types-requests", "types-typed-ast"]
411 | test = ["cython", "html5lib", "pytest (>=4.6)", "typed-ast"]
412 |
413 | [[package]]
414 | name = "sphinxcontrib-applehelp"
415 | version = "1.0.2"
416 | description = "sphinxcontrib-applehelp is a sphinx extension which outputs Apple help books"
417 | category = "dev"
418 | optional = false
419 | python-versions = ">=3.5"
420 |
421 | [package.extras]
422 | lint = ["docutils-stubs", "flake8", "mypy"]
423 | test = ["pytest"]
424 |
425 | [[package]]
426 | name = "sphinxcontrib-devhelp"
427 | version = "1.0.2"
428 | description = "sphinxcontrib-devhelp is a sphinx extension which outputs Devhelp document."
429 | category = "dev"
430 | optional = false
431 | python-versions = ">=3.5"
432 |
433 | [package.extras]
434 | lint = ["docutils-stubs", "flake8", "mypy"]
435 | test = ["pytest"]
436 |
437 | [[package]]
438 | name = "sphinxcontrib-htmlhelp"
439 | version = "2.0.0"
440 | description = "sphinxcontrib-htmlhelp is a sphinx extension which renders HTML help files"
441 | category = "dev"
442 | optional = false
443 | python-versions = ">=3.6"
444 |
445 | [package.extras]
446 | lint = ["docutils-stubs", "flake8", "mypy"]
447 | test = ["html5lib", "pytest"]
448 |
449 | [[package]]
450 | name = "sphinxcontrib-jsmath"
451 | version = "1.0.1"
452 | description = "A sphinx extension which renders display math in HTML via JavaScript"
453 | category = "dev"
454 | optional = false
455 | python-versions = ">=3.5"
456 |
457 | [package.extras]
458 | test = ["flake8", "mypy", "pytest"]
459 |
460 | [[package]]
461 | name = "sphinxcontrib-qthelp"
462 | version = "1.0.3"
463 | description = "sphinxcontrib-qthelp is a sphinx extension which outputs QtHelp document."
464 | category = "dev"
465 | optional = false
466 | python-versions = ">=3.5"
467 |
468 | [package.extras]
469 | lint = ["docutils-stubs", "flake8", "mypy"]
470 | test = ["pytest"]
471 |
472 | [[package]]
473 | name = "sphinxcontrib-serializinghtml"
474 | version = "1.1.5"
475 | description = "sphinxcontrib-serializinghtml is a sphinx extension which outputs \"serialized\" HTML files (json and pickle)."
476 | category = "dev"
477 | optional = false
478 | python-versions = ">=3.5"
479 |
480 | [package.extras]
481 | lint = ["docutils-stubs", "flake8", "mypy"]
482 | test = ["pytest"]
483 |
484 | [[package]]
485 | name = "toml"
486 | version = "0.10.2"
487 | description = "Python Library for Tom's Obvious, Minimal Language"
488 | category = "dev"
489 | optional = false
490 | python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
491 |
492 | [[package]]
493 | name = "tomli"
494 | version = "2.0.1"
495 | description = "A lil' TOML parser"
496 | category = "dev"
497 | optional = false
498 | python-versions = ">=3.7"
499 |
500 | [[package]]
501 | name = "torch"
502 | version = "1.12.1"
503 | description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
504 | category = "main"
505 | optional = false
506 | python-versions = ">=3.7.0"
507 |
508 | [package.dependencies]
509 | typing-extensions = "*"
510 |
511 | [[package]]
512 | name = "typing-extensions"
513 | version = "4.3.0"
514 | description = "Backported and Experimental Type Hints for Python 3.7+"
515 | category = "main"
516 | optional = false
517 | python-versions = ">=3.7"
518 |
519 | [[package]]
520 | name = "urllib3"
521 | version = "1.26.12"
522 | description = "HTTP library with thread-safe connection pooling, file post, and more."
523 | category = "dev"
524 | optional = false
525 | python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, <4"
526 |
527 | [package.extras]
528 | brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"]
529 | secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"]
530 | socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
531 |
532 | [[package]]
533 | name = "zipp"
534 | version = "3.8.1"
535 | description = "Backport of pathlib-compatible object wrapper for zip files"
536 | category = "dev"
537 | optional = false
538 | python-versions = ">=3.7"
539 |
540 | [package.extras]
541 | docs = ["jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx"]
542 | testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
543 |
544 | [extras]
545 | fasttext = ["gensim"]
546 |
547 | [metadata]
548 | lock-version = "1.1"
549 | python-versions = ">=3.8,<3.11"
550 | content-hash = "75d18a64b5e87ca3592968a5360ed2aef2dbfe9a2abe9af29ac5b9a34c59fcf8"
551 |
552 | [metadata.files]
553 | alabaster = [
554 | {file = "alabaster-0.7.12-py2.py3-none-any.whl", hash = "sha256:446438bdcca0e05bd45ea2de1668c1d9b032e1a9154c2c259092d77031ddd359"},
555 | {file = "alabaster-0.7.12.tar.gz", hash = "sha256:a661d72d58e6ea8a57f7a86e37d86716863ee5e92788398526d58b26a4e4dc02"},
556 | ]
557 | atomicwrites = [
558 | {file = "atomicwrites-1.4.1.tar.gz", hash = "sha256:81b2c9071a49367a7f770170e5eec8cb66567cfbbc8c73d20ce5ca4a8d71cf11"},
559 | ]
560 | attrs = [
561 | {file = "attrs-22.1.0-py2.py3-none-any.whl", hash = "sha256:86efa402f67bf2df34f51a335487cf46b1ec130d02b8d39fd248abfd30da551c"},
562 | {file = "attrs-22.1.0.tar.gz", hash = "sha256:29adc2665447e5191d0e7c568fde78b21f9672d344281d0c6e1ab085429b22b6"},
563 | ]
564 | babel = [
565 | {file = "Babel-2.10.3-py3-none-any.whl", hash = "sha256:ff56f4892c1c4bf0d814575ea23471c230d544203c7748e8c68f0089478d48eb"},
566 | {file = "Babel-2.10.3.tar.gz", hash = "sha256:7614553711ee97490f732126dc077f8d0ae084ebc6a96e23db1482afabdb2c51"},
567 | ]
568 | black = [
569 | {file = "black-22.6.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f586c26118bc6e714ec58c09df0157fe2d9ee195c764f630eb0d8e7ccce72e69"},
570 | {file = "black-22.6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b270a168d69edb8b7ed32c193ef10fd27844e5c60852039599f9184460ce0807"},
571 | {file = "black-22.6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6797f58943fceb1c461fb572edbe828d811e719c24e03375fd25170ada53825e"},
572 | {file = "black-22.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c85928b9d5f83b23cee7d0efcb310172412fbf7cb9d9ce963bd67fd141781def"},
573 | {file = "black-22.6.0-cp310-cp310-win_amd64.whl", hash = "sha256:f6fe02afde060bbeef044af7996f335fbe90b039ccf3f5eb8f16df8b20f77666"},
574 | {file = "black-22.6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:cfaf3895a9634e882bf9d2363fed5af8888802d670f58b279b0bece00e9a872d"},
575 | {file = "black-22.6.0-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94783f636bca89f11eb5d50437e8e17fbc6a929a628d82304c80fa9cd945f256"},
576 | {file = "black-22.6.0-cp36-cp36m-win_amd64.whl", hash = "sha256:2ea29072e954a4d55a2ff58971b83365eba5d3d357352a07a7a4df0d95f51c78"},
577 | {file = "black-22.6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e439798f819d49ba1c0bd9664427a05aab79bfba777a6db94fd4e56fae0cb849"},
578 | {file = "black-22.6.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:187d96c5e713f441a5829e77120c269b6514418f4513a390b0499b0987f2ff1c"},
579 | {file = "black-22.6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:074458dc2f6e0d3dab7928d4417bb6957bb834434516f21514138437accdbe90"},
580 | {file = "black-22.6.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:a218d7e5856f91d20f04e931b6f16d15356db1c846ee55f01bac297a705ca24f"},
581 | {file = "black-22.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:568ac3c465b1c8b34b61cd7a4e349e93f91abf0f9371eda1cf87194663ab684e"},
582 | {file = "black-22.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6c1734ab264b8f7929cef8ae5f900b85d579e6cbfde09d7387da8f04771b51c6"},
583 | {file = "black-22.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9a3ac16efe9ec7d7381ddebcc022119794872abce99475345c5a61aa18c45ad"},
584 | {file = "black-22.6.0-cp38-cp38-win_amd64.whl", hash = "sha256:b9fd45787ba8aa3f5e0a0a98920c1012c884622c6c920dbe98dbd05bc7c70fbf"},
585 | {file = "black-22.6.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7ba9be198ecca5031cd78745780d65a3f75a34b2ff9be5837045dce55db83d1c"},
586 | {file = "black-22.6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a3db5b6409b96d9bd543323b23ef32a1a2b06416d525d27e0f67e74f1446c8f2"},
587 | {file = "black-22.6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:560558527e52ce8afba936fcce93a7411ab40c7d5fe8c2463e279e843c0328ee"},
588 | {file = "black-22.6.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b154e6bbde1e79ea3260c4b40c0b7b3109ffcdf7bc4ebf8859169a6af72cd70b"},
589 | {file = "black-22.6.0-cp39-cp39-win_amd64.whl", hash = "sha256:4af5bc0e1f96be5ae9bd7aaec219c901a94d6caa2484c21983d043371c733fc4"},
590 | {file = "black-22.6.0-py3-none-any.whl", hash = "sha256:ac609cf8ef5e7115ddd07d85d988d074ed00e10fbc3445aee393e70164a2219c"},
591 | {file = "black-22.6.0.tar.gz", hash = "sha256:6c6d39e28aed379aec40da1c65434c77d75e65bb59a1e1c283de545fb4e7c6c9"},
592 | ]
593 | certifi = [
594 | {file = "certifi-2022.6.15-py3-none-any.whl", hash = "sha256:fe86415d55e84719d75f8b69414f6438ac3547d2078ab91b67e779ef69378412"},
595 | {file = "certifi-2022.6.15.tar.gz", hash = "sha256:84c85a9078b11105f04f3036a9482ae10e4621616db313fe045dd24743a0820d"},
596 | ]
597 | charset-normalizer = [
598 | {file = "charset-normalizer-2.1.1.tar.gz", hash = "sha256:5a3d016c7c547f69d6f81fb0db9449ce888b418b5b9952cc5e6e66843e9dd845"},
599 | {file = "charset_normalizer-2.1.1-py3-none-any.whl", hash = "sha256:83e9a75d1911279afd89352c68b45348559d1fc0506b054b346651b5e7fee29f"},
600 | ]
601 | click = [
602 | {file = "click-8.1.3-py3-none-any.whl", hash = "sha256:bb4d8133cb15a609f44e8213d9b391b0809795062913b383c62be0ee95b1db48"},
603 | {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"},
604 | ]
605 | colorama = [
606 | {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"},
607 | {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"},
608 | ]
609 | docutils = []
610 | gensim = []
611 | idna = [
612 | {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"},
613 | {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"},
614 | ]
615 | imagesize = [
616 | {file = "imagesize-1.4.1-py2.py3-none-any.whl", hash = "sha256:0d8d18d08f840c19d0ee7ca1fd82490fdc3729b7ac93f49870406ddde8ef8d8b"},
617 | {file = "imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a"},
618 | ]
619 | importlib-metadata = [
620 | {file = "importlib_metadata-4.12.0-py3-none-any.whl", hash = "sha256:7401a975809ea1fdc658c3aa4f78cc2195a0e019c5cbc4c06122884e9ae80c23"},
621 | {file = "importlib_metadata-4.12.0.tar.gz", hash = "sha256:637245b8bab2b6502fcbc752cc4b7a6f6243bb02b31c5c26156ad103d3d45670"},
622 | ]
623 | iniconfig = [
624 | {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
625 | {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
626 | ]
627 | jinja2 = [
628 | {file = "Jinja2-3.1.2-py3-none-any.whl", hash = "sha256:6088930bfe239f0e6710546ab9c19c9ef35e29792895fed6e6e31a023a182a61"},
629 | {file = "Jinja2-3.1.2.tar.gz", hash = "sha256:31351a702a408a9e7595a8fc6150fc3f43bb6bf7e319770cbc0db9df9437e852"},
630 | ]
631 | markupsafe = [
632 | {file = "MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:86b1f75c4e7c2ac2ccdaec2b9022845dbb81880ca318bb7a0a01fbf7813e3812"},
633 | {file = "MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f121a1420d4e173a5d96e47e9a0c0dcff965afdf1626d28de1460815f7c4ee7a"},
634 | {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a49907dd8420c5685cfa064a1335b6754b74541bbb3706c259c02ed65b644b3e"},
635 | {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10c1bfff05d95783da83491be968e8fe789263689c02724e0c691933c52994f5"},
636 | {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7bd98b796e2b6553da7225aeb61f447f80a1ca64f41d83612e6139ca5213aa4"},
637 | {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b09bf97215625a311f669476f44b8b318b075847b49316d3e28c08e41a7a573f"},
638 | {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:694deca8d702d5db21ec83983ce0bb4b26a578e71fbdbd4fdcd387daa90e4d5e"},
639 | {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:efc1913fd2ca4f334418481c7e595c00aad186563bbc1ec76067848c7ca0a933"},
640 | {file = "MarkupSafe-2.1.1-cp310-cp310-win32.whl", hash = "sha256:4a33dea2b688b3190ee12bd7cfa29d39c9ed176bda40bfa11099a3ce5d3a7ac6"},
641 | {file = "MarkupSafe-2.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:dda30ba7e87fbbb7eab1ec9f58678558fd9a6b8b853530e176eabd064da81417"},
642 | {file = "MarkupSafe-2.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:671cd1187ed5e62818414afe79ed29da836dde67166a9fac6d435873c44fdd02"},
643 | {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3799351e2336dc91ea70b034983ee71cf2f9533cdff7c14c90ea126bfd95d65a"},
644 | {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e72591e9ecd94d7feb70c1cbd7be7b3ebea3f548870aa91e2732960fa4d57a37"},
645 | {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6fbf47b5d3728c6aea2abb0589b5d30459e369baa772e0f37a0320185e87c980"},
646 | {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d5ee4f386140395a2c818d149221149c54849dfcfcb9f1debfe07a8b8bd63f9a"},
647 | {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:bcb3ed405ed3222f9904899563d6fc492ff75cce56cba05e32eff40e6acbeaa3"},
648 | {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e1c0b87e09fa55a220f058d1d49d3fb8df88fbfab58558f1198e08c1e1de842a"},
649 | {file = "MarkupSafe-2.1.1-cp37-cp37m-win32.whl", hash = "sha256:8dc1c72a69aa7e082593c4a203dcf94ddb74bb5c8a731e4e1eb68d031e8498ff"},
650 | {file = "MarkupSafe-2.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:97a68e6ada378df82bc9f16b800ab77cbf4b2fada0081794318520138c088e4a"},
651 | {file = "MarkupSafe-2.1.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e8c843bbcda3a2f1e3c2ab25913c80a3c5376cd00c6e8c4a86a89a28c8dc5452"},
652 | {file = "MarkupSafe-2.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0212a68688482dc52b2d45013df70d169f542b7394fc744c02a57374a4207003"},
653 | {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e576a51ad59e4bfaac456023a78f6b5e6e7651dcd383bcc3e18d06f9b55d6d1"},
654 | {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b9fe39a2ccc108a4accc2676e77da025ce383c108593d65cc909add5c3bd601"},
655 | {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96e37a3dc86e80bf81758c152fe66dbf60ed5eca3d26305edf01892257049925"},
656 | {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6d0072fea50feec76a4c418096652f2c3238eaa014b2f94aeb1d56a66b41403f"},
657 | {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:089cf3dbf0cd6c100f02945abeb18484bd1ee57a079aefd52cffd17fba910b88"},
658 | {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6a074d34ee7a5ce3effbc526b7083ec9731bb3cbf921bbe1d3005d4d2bdb3a63"},
659 | {file = "MarkupSafe-2.1.1-cp38-cp38-win32.whl", hash = "sha256:421be9fbf0ffe9ffd7a378aafebbf6f4602d564d34be190fc19a193232fd12b1"},
660 | {file = "MarkupSafe-2.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:fc7b548b17d238737688817ab67deebb30e8073c95749d55538ed473130ec0c7"},
661 | {file = "MarkupSafe-2.1.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e04e26803c9c3851c931eac40c695602c6295b8d432cbe78609649ad9bd2da8a"},
662 | {file = "MarkupSafe-2.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b87db4360013327109564f0e591bd2a3b318547bcef31b468a92ee504d07ae4f"},
663 | {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99a2a507ed3ac881b975a2976d59f38c19386d128e7a9a18b7df6fff1fd4c1d6"},
664 | {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56442863ed2b06d19c37f94d999035e15ee982988920e12a5b4ba29b62ad1f77"},
665 | {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ce11ee3f23f79dbd06fb3d63e2f6af7b12db1d46932fe7bd8afa259a5996603"},
666 | {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:33b74d289bd2f5e527beadcaa3f401e0df0a89927c1559c8566c066fa4248ab7"},
667 | {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:43093fb83d8343aac0b1baa75516da6092f58f41200907ef92448ecab8825135"},
668 | {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8e3dcf21f367459434c18e71b2a9532d96547aef8a871872a5bd69a715c15f96"},
669 | {file = "MarkupSafe-2.1.1-cp39-cp39-win32.whl", hash = "sha256:d4306c36ca495956b6d568d276ac11fdd9c30a36f1b6eb928070dc5360b22e1c"},
670 | {file = "MarkupSafe-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:46d00d6cfecdde84d40e572d63735ef81423ad31184100411e6e3388d405e247"},
671 | {file = "MarkupSafe-2.1.1.tar.gz", hash = "sha256:7f91197cc9e48f989d12e4e6fbc46495c446636dfc81b9ccf50bb0ec74b91d4b"},
672 | ]
673 | mypy-extensions = []
674 | numpy = [
675 | {file = "numpy-1.23.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e603ca1fb47b913942f3e660a15e55a9ebca906857edfea476ae5f0fe9b457d5"},
676 | {file = "numpy-1.23.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:633679a472934b1c20a12ed0c9a6c9eb167fbb4cb89031939bfd03dd9dbc62b8"},
677 | {file = "numpy-1.23.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17e5226674f6ea79e14e3b91bfbc153fdf3ac13f5cc54ee7bc8fdbe820a32da0"},
678 | {file = "numpy-1.23.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdc02c0235b261925102b1bd586579b7158e9d0d07ecb61148a1799214a4afd5"},
679 | {file = "numpy-1.23.2-cp310-cp310-win32.whl", hash = "sha256:df28dda02c9328e122661f399f7655cdcbcf22ea42daa3650a26bce08a187450"},
680 | {file = "numpy-1.23.2-cp310-cp310-win_amd64.whl", hash = "sha256:8ebf7e194b89bc66b78475bd3624d92980fca4e5bb86dda08d677d786fefc414"},
681 | {file = "numpy-1.23.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:dc76bca1ca98f4b122114435f83f1fcf3c0fe48e4e6f660e07996abf2f53903c"},
682 | {file = "numpy-1.23.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ecfdd68d334a6b97472ed032b5b37a30d8217c097acfff15e8452c710e775524"},
683 | {file = "numpy-1.23.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5593f67e66dea4e237f5af998d31a43e447786b2154ba1ad833676c788f37cde"},
684 | {file = "numpy-1.23.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac987b35df8c2a2eab495ee206658117e9ce867acf3ccb376a19e83070e69418"},
685 | {file = "numpy-1.23.2-cp311-cp311-win32.whl", hash = "sha256:d98addfd3c8728ee8b2c49126f3c44c703e2b005d4a95998e2167af176a9e722"},
686 | {file = "numpy-1.23.2-cp311-cp311-win_amd64.whl", hash = "sha256:8ecb818231afe5f0f568c81f12ce50f2b828ff2b27487520d85eb44c71313b9e"},
687 | {file = "numpy-1.23.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:909c56c4d4341ec8315291a105169d8aae732cfb4c250fbc375a1efb7a844f8f"},
688 | {file = "numpy-1.23.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8247f01c4721479e482cc2f9f7d973f3f47810cbc8c65e38fd1bbd3141cc9842"},
689 | {file = "numpy-1.23.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8b97a8a87cadcd3f94659b4ef6ec056261fa1e1c3317f4193ac231d4df70215"},
690 | {file = "numpy-1.23.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bd5b7ccae24e3d8501ee5563e82febc1771e73bd268eef82a1e8d2b4d556ae66"},
691 | {file = "numpy-1.23.2-cp38-cp38-win32.whl", hash = "sha256:9b83d48e464f393d46e8dd8171687394d39bc5abfe2978896b77dc2604e8635d"},
692 | {file = "numpy-1.23.2-cp38-cp38-win_amd64.whl", hash = "sha256:dec198619b7dbd6db58603cd256e092bcadef22a796f778bf87f8592b468441d"},
693 | {file = "numpy-1.23.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4f41f5bf20d9a521f8cab3a34557cd77b6f205ab2116651f12959714494268b0"},
694 | {file = "numpy-1.23.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:806cc25d5c43e240db709875e947076b2826f47c2c340a5a2f36da5bb10c58d6"},
695 | {file = "numpy-1.23.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f9d84a24889ebb4c641a9b99e54adb8cab50972f0166a3abc14c3b93163f074"},
696 | {file = "numpy-1.23.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c403c81bb8ffb1c993d0165a11493fd4bf1353d258f6997b3ee288b0a48fce77"},
697 | {file = "numpy-1.23.2-cp39-cp39-win32.whl", hash = "sha256:cf8c6aed12a935abf2e290860af8e77b26a042eb7f2582ff83dc7ed5f963340c"},
698 | {file = "numpy-1.23.2-cp39-cp39-win_amd64.whl", hash = "sha256:5e28cd64624dc2354a349152599e55308eb6ca95a13ce6a7d5679ebff2962913"},
699 | {file = "numpy-1.23.2-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:806970e69106556d1dd200e26647e9bee5e2b3f1814f9da104a943e8d548ca38"},
700 | {file = "numpy-1.23.2-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bd879d3ca4b6f39b7770829f73278b7c5e248c91d538aab1e506c628353e47f"},
701 | {file = "numpy-1.23.2-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:be6b350dfbc7f708d9d853663772a9310783ea58f6035eec649fb9c4371b5389"},
702 | {file = "numpy-1.23.2.tar.gz", hash = "sha256:b78d00e48261fbbd04aa0d7427cf78d18401ee0abd89c7559bbf422e5b1c7d01"},
703 | ]
704 | packaging = [
705 | {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"},
706 | {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"},
707 | ]
708 | pathspec = []
709 | platformdirs = [
710 | {file = "platformdirs-2.5.2-py3-none-any.whl", hash = "sha256:027d8e83a2d7de06bbac4e5ef7e023c02b863d7ea5d079477e722bb41ab25788"},
711 | {file = "platformdirs-2.5.2.tar.gz", hash = "sha256:58c8abb07dcb441e6ee4b11d8df0ac856038f944ab98b7be6b27b2a3c7feef19"},
712 | ]
713 | pluggy = [
714 | {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"},
715 | {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"},
716 | ]
717 | py = [
718 | {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
719 | {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
720 | ]
721 | pygments = []
722 | pyparsing = [
723 | {file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
724 | {file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"},
725 | ]
726 | pytest = [
727 | {file = "pytest-6.2.5-py3-none-any.whl", hash = "sha256:7310f8d27bc79ced999e760ca304d69f6ba6c6649c0b60fb0e04a4a77cacc134"},
728 | {file = "pytest-6.2.5.tar.gz", hash = "sha256:131b36680866a76e6781d13f101efb86cf674ebb9762eb70d3082b6f29889e89"},
729 | ]
730 | pytz = []
731 | qdrant-sphinx-theme = []
732 | requests = [
733 | {file = "requests-2.28.1-py3-none-any.whl", hash = "sha256:8fefa2a1a1365bf5520aac41836fbee479da67864514bdb821f31ce07ce65349"},
734 | {file = "requests-2.28.1.tar.gz", hash = "sha256:7c5599b102feddaa661c826c56ab4fee28bfd17f5abca1ebbe3e7f19d7c97983"},
735 | ]
736 | scipy = []
737 | smart-open = []
738 | snowballstemmer = [
739 | {file = "snowballstemmer-2.2.0-py2.py3-none-any.whl", hash = "sha256:c8e1716e83cc398ae16824e5572ae04e0d9fc2c6b985fb0f900f5f0c96ecba1a"},
740 | {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"},
741 | ]
742 | sphinx = [
743 | {file = "Sphinx-5.1.1-py3-none-any.whl", hash = "sha256:309a8da80cb6da9f4713438e5b55861877d5d7976b69d87e336733637ea12693"},
744 | {file = "Sphinx-5.1.1.tar.gz", hash = "sha256:ba3224a4e206e1fbdecf98a4fae4992ef9b24b85ebf7b584bb340156eaf08d89"},
745 | ]
746 | sphinxcontrib-applehelp = [
747 | {file = "sphinxcontrib-applehelp-1.0.2.tar.gz", hash = "sha256:a072735ec80e7675e3f432fcae8610ecf509c5f1869d17e2eecff44389cdbc58"},
748 | {file = "sphinxcontrib_applehelp-1.0.2-py2.py3-none-any.whl", hash = "sha256:806111e5e962be97c29ec4c1e7fe277bfd19e9652fb1a4392105b43e01af885a"},
749 | ]
750 | sphinxcontrib-devhelp = [
751 | {file = "sphinxcontrib-devhelp-1.0.2.tar.gz", hash = "sha256:ff7f1afa7b9642e7060379360a67e9c41e8f3121f2ce9164266f61b9f4b338e4"},
752 | {file = "sphinxcontrib_devhelp-1.0.2-py2.py3-none-any.whl", hash = "sha256:8165223f9a335cc1af7ffe1ed31d2871f325254c0423bc0c4c7cd1c1e4734a2e"},
753 | ]
754 | sphinxcontrib-htmlhelp = [
755 | {file = "sphinxcontrib-htmlhelp-2.0.0.tar.gz", hash = "sha256:f5f8bb2d0d629f398bf47d0d69c07bc13b65f75a81ad9e2f71a63d4b7a2f6db2"},
756 | {file = "sphinxcontrib_htmlhelp-2.0.0-py2.py3-none-any.whl", hash = "sha256:d412243dfb797ae3ec2b59eca0e52dac12e75a241bf0e4eb861e450d06c6ed07"},
757 | ]
758 | sphinxcontrib-jsmath = [
759 | {file = "sphinxcontrib-jsmath-1.0.1.tar.gz", hash = "sha256:a9925e4a4587247ed2191a22df5f6970656cb8ca2bd6284309578f2153e0c4b8"},
760 | {file = "sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl", hash = "sha256:2ec2eaebfb78f3f2078e73666b1415417a116cc848b72e5172e596c871103178"},
761 | ]
762 | sphinxcontrib-qthelp = [
763 | {file = "sphinxcontrib-qthelp-1.0.3.tar.gz", hash = "sha256:4c33767ee058b70dba89a6fc5c1892c0d57a54be67ddd3e7875a18d14cba5a72"},
764 | {file = "sphinxcontrib_qthelp-1.0.3-py2.py3-none-any.whl", hash = "sha256:bd9fc24bcb748a8d51fd4ecaade681350aa63009a347a8c14e637895444dfab6"},
765 | ]
766 | sphinxcontrib-serializinghtml = [
767 | {file = "sphinxcontrib-serializinghtml-1.1.5.tar.gz", hash = "sha256:aa5f6de5dfdf809ef505c4895e51ef5c9eac17d0f287933eb49ec495280b6952"},
768 | {file = "sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl", hash = "sha256:352a9a00ae864471d3a7ead8d7d79f5fc0b57e8b3f95e9867eb9eb28999b92fd"},
769 | ]
770 | toml = [
771 | {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
772 | {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
773 | ]
774 | tomli = [
775 | {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"},
776 | {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"},
777 | ]
778 | torch = [
779 | {file = "torch-1.12.1-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:9c038662db894a23e49e385df13d47b2a777ffd56d9bcd5b832593fab0a7e286"},
780 | {file = "torch-1.12.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:4e1b9c14cf13fd2ab8d769529050629a0e68a6fc5cb8e84b4a3cc1dd8c4fe541"},
781 | {file = "torch-1.12.1-cp310-cp310-win_amd64.whl", hash = "sha256:e9c8f4a311ac29fc7e8e955cfb7733deb5dbe1bdaabf5d4af2765695824b7e0d"},
782 | {file = "torch-1.12.1-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:976c3f997cea38ee91a0dd3c3a42322785414748d1761ef926b789dfa97c6134"},
783 | {file = "torch-1.12.1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:68104e4715a55c4bb29a85c6a8d57d820e0757da363be1ba680fa8cc5be17b52"},
784 | {file = "torch-1.12.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:743784ccea0dc8f2a3fe6a536bec8c4763bd82c1352f314937cb4008d4805de1"},
785 | {file = "torch-1.12.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b5dbcca369800ce99ba7ae6dee3466607a66958afca3b740690d88168752abcf"},
786 | {file = "torch-1.12.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f3b52a634e62821e747e872084ab32fbcb01b7fa7dbb7471b6218279f02a178a"},
787 | {file = "torch-1.12.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:8a34a2fbbaa07c921e1b203f59d3d6e00ed379f2b384445773bd14e328a5b6c8"},
788 | {file = "torch-1.12.1-cp37-none-macosx_11_0_arm64.whl", hash = "sha256:42f639501928caabb9d1d55ddd17f07cd694de146686c24489ab8c615c2871f2"},
789 | {file = "torch-1.12.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0b44601ec56f7dd44ad8afc00846051162ef9c26a8579dda0a02194327f2d55e"},
790 | {file = "torch-1.12.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cd26d8c5640c3a28c526d41ccdca14cf1cbca0d0f2e14e8263a7ac17194ab1d2"},
791 | {file = "torch-1.12.1-cp38-cp38-win_amd64.whl", hash = "sha256:42e115dab26f60c29e298559dbec88444175528b729ae994ec4c65d56fe267dd"},
792 | {file = "torch-1.12.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:a8320ba9ad87e80ca5a6a016e46ada4d1ba0c54626e135d99b2129a4541c509d"},
793 | {file = "torch-1.12.1-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:03e31c37711db2cd201e02de5826de875529e45a55631d317aadce2f1ed45aa8"},
794 | {file = "torch-1.12.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:9b356aea223772cd754edb4d9ecf2a025909b8615a7668ac7d5130f86e7ec421"},
795 | {file = "torch-1.12.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:6cf6f54b43c0c30335428195589bd00e764a6d27f3b9ba637aaa8c11aaf93073"},
796 | {file = "torch-1.12.1-cp39-cp39-win_amd64.whl", hash = "sha256:f00c721f489089dc6364a01fd84906348fe02243d0af737f944fddb36003400d"},
797 | {file = "torch-1.12.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:bfec2843daa654f04fda23ba823af03e7b6f7650a873cdb726752d0e3718dada"},
798 | {file = "torch-1.12.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:69fe2cae7c39ccadd65a123793d30e0db881f1c1927945519c5c17323131437e"},
799 | ]
800 | typing-extensions = [
801 | {file = "typing_extensions-4.3.0-py3-none-any.whl", hash = "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"},
802 | {file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"},
803 | ]
804 | urllib3 = []
805 | zipp = [
806 | {file = "zipp-3.8.1-py3-none-any.whl", hash = "sha256:47c40d7fe183a6f21403a199b3e4192cca5774656965b0a4988ad2f8feb5f009"},
807 | {file = "zipp-3.8.1.tar.gz", hash = "sha256:05b45f1ee8f807d0cc928485ca40a07cb491cf092ff587c0df9cb1fd154848d2"},
808 | ]
809 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "quaterion-models"
3 | version = "0.1.19"
4 | description = "The collection of building blocks to build fine-tunable similarity learning models"
5 | authors = ["Quaterion Authors "]
6 | packages = [
7 | {include = "quaterion_models"},
8 | ]
9 | readme = "README.md"
10 | homepage = "https://github.com/qdrant/quaterion-models"
11 | repository = "https://github.com/qdrant/quaterion-models"
12 | keywords = ["framework", "metric-learning", "similarity", "similarity-learning", "deep-learning", "pytorch"]
13 |
14 |
15 | [tool.poetry.dependencies]
16 | python = ">=3.8,<3.11"
17 | torch = ">=1.8.2"
18 | numpy = "^1.22"
19 | gensim = {version = "^4.1.2", optional = true}
20 |
21 |
22 | [tool.poetry.dev-dependencies]
23 | pytest = "^6.2.5"
24 | sphinx = ">=5.0.1"
25 | qdrant-sphinx-theme = { git = "https://github.com/qdrant/qdrant_sphinx_theme.git", branch = "master" }
26 | black = "^22.3.0"
27 |
28 | [tool.poetry.extras]
29 | fasttext = ["gensim"]
30 |
31 | [build-system]
32 | requires = ["poetry-core>=1.0.0", "setuptools"]
33 | build-backend = "poetry.core.masonry.api"
34 |
--------------------------------------------------------------------------------
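The `[tool.poetry.extras]` table above makes `gensim` optional: it is only pulled in when the `fasttext` extra is requested, so the base install stays free of it. A minimal sketch (not code from this repository) of how downstream code might guard the optional import when the extra may be absent:

    # Hypothetical usage sketch: the extras encoder imports gensim at module
    # load time, so guard it when the "fasttext" extra might not be installed.
    try:
        from quaterion_models.encoders.extras.fasttext_encoder import FasttextEncoder
    except ImportError:  # gensim is missing unless installed with the extra
        FasttextEncoder = None

--------------------------------------------------------------------------------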
/quaterion_models/__init__.py:
--------------------------------------------------------------------------------
1 | from quaterion_models.model import SimilarityModel
2 |
--------------------------------------------------------------------------------
/quaterion_models/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from quaterion_models.encoders.encoder import Encoder
2 | from quaterion_models.encoders.switch_encoder import SwitchEncoder
3 |
--------------------------------------------------------------------------------
/quaterion_models/encoders/encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Any, List
4 |
5 | from torch import Tensor, nn
6 | from torch.utils.data.dataloader import default_collate
7 |
8 | from quaterion_models.types import CollateFnType, MetaExtractorFnType, TensorInterchange
9 |
10 |
11 | class Encoder(nn.Module):
12 | """Base class for encoder abstraction"""
13 |
14 | def __init__(self):
15 | super(Encoder, self).__init__()
16 |
17 | def disable_gradients_if_required(self):
18 | """Disables gradients of the model if it is declared as not trainable"""
19 |
20 | if not self.trainable:
21 | for _, weights in self.named_parameters():
22 | weights.requires_grad = False
23 |
24 | @property
25 | def trainable(self) -> bool:
26 |         """Defines whether the encoder is trainable.
27 |
28 | This flag affects caching and checkpoint saving of the encoder.
29 | """
30 | raise NotImplementedError()
31 |
32 | @property
33 | def embedding_size(self) -> int:
34 | """Size of resulting embedding"""
35 | raise NotImplementedError()
36 |
37 | def get_collate_fn(self) -> CollateFnType:
38 |         """Provides a function that converts a raw data batch into suitable model input
39 |
40 | Returns:
41 | :const:`~quaterion_models.types.CollateFnType`: model's collate function
42 | """
43 | return default_collate
44 |
45 | @classmethod
46 | def extract_meta(cls, batch: List[Any]) -> List[dict]:
47 | """Extracts meta information from the batch
48 |
49 | Args:
50 | batch: raw batch of data
51 |
52 | Returns:
53 | meta information
54 | """
55 | return [{} for _ in batch]
56 |
57 | def get_meta_extractor(self) -> MetaExtractorFnType:
58 | return self.extract_meta
59 |
60 | def forward(self, batch: TensorInterchange) -> Tensor:
61 |         """Run encoder inference - convert the input batch into embeddings
62 |
63 | Args:
64 | batch: processed batch
65 | Returns:
66 | embeddings: shape: (batch_size, embedding_size)
67 | """
68 | raise NotImplementedError()
69 |
70 | def save(self, output_path: str):
71 | """Persist current state to the provided directory
72 |
73 | Args:
74 | output_path: path to save model
75 |
76 | """
77 | raise NotImplementedError()
78 |
79 | @classmethod
80 | def load(cls, input_path: str) -> Encoder:
81 |         """Instantiate an encoder from a saved state.
82 | 
83 |         If no state is required - just call `create` instead
84 |
85 | Args:
86 | input_path: path to load from
87 |
88 | Returns:
89 | :class:`~quaterion_models.encoders.encoder.Encoder`: loaded encoder
90 | """
91 | raise NotImplementedError()
92 |
--------------------------------------------------------------------------------
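The abstract methods above are everything a concrete encoder has to provide: `trainable`, `embedding_size`, `forward`, `save` and `load`; `get_collate_fn` falls back to `default_collate`. A minimal sketch of a custom subclass, assuming an illustrative frozen linear backbone and tensor inputs (none of this is code from the repository):

    import os

    import torch
    from torch import Tensor, nn

    from quaterion_models.encoders import Encoder


    class FrozenLinearEncoder(Encoder):
        """Illustrative encoder: a frozen linear projection of 32-d feature vectors."""

        def __init__(self):
            super().__init__()
            self.backbone = nn.Linear(32, 128)

        @property
        def trainable(self) -> bool:
            return False  # gradients get disabled by disable_gradients_if_required()

        @property
        def embedding_size(self) -> int:
            return 128

        def forward(self, batch: Tensor) -> Tensor:
            # batch comes from the default collate function: (batch_size, 32)
            return self.backbone(batch)

        def save(self, output_path: str):
            torch.save(self.state_dict(), os.path.join(output_path, "encoder.pt"))

        @classmethod
        def load(cls, input_path: str) -> "FrozenLinearEncoder":
            encoder = cls()
            encoder.load_state_dict(torch.load(os.path.join(input_path, "encoder.pt")))
            return encoder

--------------------------------------------------------------------------------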
/quaterion_models/encoders/extras/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qdrant/quaterion-models/f5f97a27f8d9b35194fe316592bd70c4852048f7/quaterion_models/encoders/extras/__init__.py
--------------------------------------------------------------------------------
/quaterion_models/encoders/extras/fasttext_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import os
5 | from typing import Any, List, Union
6 |
7 | import gensim
8 | import numpy as np
9 | import torch
10 | from gensim.models import FastText, KeyedVectors
11 | from gensim.models.fasttext import FastTextKeyedVectors
12 | from torch import Tensor
13 |
14 | from quaterion_models.encoders import Encoder
15 | from quaterion_models.types import CollateFnType
16 |
17 |
18 | def load_fasttext_model(path: str) -> Union[FastText, KeyedVectors]:
19 |     """Load a fasttext model in a universal way
20 | 
21 |     Try the known ways of loading a FastText model and use the first one that succeeds
22 |
23 | Args:
24 | path: path to FastText model or vectors
25 |
26 | Returns:
27 | :class:`~gensim.models.fasttext.FastText` or :class:`~gensim.models.KeyedVectors`:
28 | loaded model
29 | """
30 | try:
31 | model = FastText.load(path).wv
32 | except Exception:
33 | try:
34 | model = FastText.load_fasttext_format(path).wv
35 | except Exception:
36 | model = gensim.models.KeyedVectors.load(path)
37 |
38 | return model
39 |
40 |
41 | class FasttextEncoder(Encoder):
42 |     """Creates a fasttext encoder, which generates a vector for a list of tokens based on the given
43 |     fasttext model.
44 |
45 | Args:
46 | model_path: Path to model to load
47 | on_disk: If True - use mmap to keep embeddings out of RAM
48 | aggregations: What types of aggregations to use to combine multiple vectors into one. If
49 | multiple aggregations are specified - concatenation of all of them will be used as a
50 | result.
51 |
52 | """
53 |
54 | aggregation_options = ["min", "max", "avg"]
55 |
56 | def __init__(self, model_path: str, on_disk: bool, aggregations: List[str] = None):
57 | super(FasttextEncoder, self).__init__()
58 |
59 | # workaround tensor to keep information about required model device
60 | self._device_tensor = torch.nn.Parameter(torch.zeros(1))
61 |
62 | if aggregations is None:
63 | aggregations = ["avg"]
64 | self.aggregations = aggregations
65 | self.on_disk = on_disk
66 |
67 | # noinspection PyTypeChecker
68 | self.model: FastTextKeyedVectors = gensim.models.KeyedVectors.load(
69 | model_path, mmap="r" if self.on_disk else None
70 | )
71 |
72 | @property
73 | def trainable(self) -> bool:
74 | return False
75 |
76 | @property
77 | def embedding_size(self) -> int:
78 | return self.model.vector_size * len(self.aggregations)
79 |
80 | @classmethod
81 | def get_tokens(cls, batch: List[Any]) -> List[List[str]]:
82 | raise NotImplementedError()
83 |
84 | def get_collate_fn(self) -> CollateFnType:
85 | return self.__class__.get_tokens
86 |
87 | @classmethod
88 | def aggregate(cls, embeddings: Tensor, operation: str) -> Tensor:
89 | """Apply aggregation operation to embeddings along the first dimension
90 |
91 | Args:
92 | embeddings: embeddings to aggregate
93 | operation: one of :attr:`aggregation_options`
94 |
95 | Returns:
96 | Tensor: aggregated embeddings
97 | """
98 | if operation == "avg":
99 | return torch.mean(embeddings, dim=0)
100 | if operation == "max":
101 | return torch.max(embeddings, dim=0).values
102 | if operation == "min":
103 | return torch.min(embeddings, dim=0).values
104 |
105 | raise RuntimeError(f"Unknown operation: {operation}")
106 |
107 | def forward(self, batch: List[List[str]]) -> Tensor:
108 | embeddings = []
109 | for record in batch:
110 | token_vectors = [self.model.get_vector(token) for token in record]
111 | if token_vectors:
112 | record_vectors = np.stack(token_vectors)
113 | else:
114 | record_vectors = np.zeros((1, self.model.vector_size))
115 | token_tensor = torch.tensor(
116 | record_vectors, device=self._device_tensor.device
117 | )
118 | record_embedding = torch.cat(
119 | [
120 | self.aggregate(token_tensor, operation)
121 | for operation in self.aggregations
122 | ]
123 | )
124 | embeddings.append(record_embedding)
125 |
126 | return torch.stack(embeddings)
127 |
128 | def save(self, output_path: str):
129 | model_path = os.path.join(output_path, "fasttext.model")
130 | self.model.save(
131 | model_path, separately=["vectors_ngrams", "vectors", "vectors_vocab"]
132 | )
133 | with open(os.path.join(output_path, "config.json"), "w") as f_out:
134 | json.dump(
135 | {
136 | "on_disk": self.on_disk,
137 | "aggregations": self.aggregations,
138 | },
139 | f_out,
140 | indent=2,
141 | )
142 |
143 | @classmethod
144 | def load(cls, input_path: str) -> Encoder:
145 | model_path = os.path.join(input_path, "fasttext.model")
146 | with open(os.path.join(input_path, "config.json")) as f_in:
147 | config = json.load(f_in)
148 |
149 | return cls(model_path=model_path, **config)
150 |
--------------------------------------------------------------------------------
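`FasttextEncoder.get_tokens` is deliberately left abstract: a concrete subclass decides how raw records are turned into token lists before the fasttext lookup. A minimal sketch, assuming whitespace-separated text records and a hypothetical path to a gensim KeyedVectors file saved with `.save()` (both are assumptions, not code from the repository):

    from typing import Any, List

    from quaterion_models.encoders.extras.fasttext_encoder import FasttextEncoder


    class WhitespaceFasttextEncoder(FasttextEncoder):
        @classmethod
        def get_tokens(cls, batch: List[Any]) -> List[List[str]]:
            # Split each raw text record into tokens for the fasttext lookup
            return [str(record).split() for record in batch]


    # "fasttext.model" is a hypothetical path to previously saved KeyedVectors
    encoder = WhitespaceFasttextEncoder(model_path="fasttext.model", on_disk=False)
    collate = encoder.get_collate_fn()
    embeddings = encoder(collate(["a small example", "another record"]))
    print(embeddings.shape)  # (2, encoder.embedding_size)

--------------------------------------------------------------------------------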
/quaterion_models/encoders/switch_encoder.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import os
5 | from functools import partial
6 | from typing import Any, Dict, List
7 |
8 | import torch
9 | from torch import Tensor
10 |
11 | from quaterion_models.encoders import Encoder
12 | from quaterion_models.types import CollateFnType, TensorInterchange
13 | from quaterion_models.utils import move_to_device, restore_class, save_class_import
14 |
15 |
16 | def inverse_permutation(perm):
17 | inv = torch.empty_like(perm)
18 | inv[perm] = torch.arange(perm.size(0), device=perm.device)
19 | return inv
20 |
21 |
22 | class SwitchEncoder(Encoder):
23 |     """Allows the use of alternative encoders based on the input data.
24 | 
25 |     For example, train a shared embedding representation for images and texts.
26 |     In this case, the image encoder should be used if the input is an image, and the text encoder otherwise.
27 | """
28 |
29 | @classmethod
30 | def encoder_selection(cls, record: Any) -> str:
31 | """Decide which encoder to use for given record.
32 |
33 | Args:
34 | record: input piece of data
35 |
36 | Returns:
37 | name of the related encoder
38 | """
39 | raise NotImplementedError()
40 |
41 | def __init__(self, options: Dict[str, Encoder]):
42 | super(SwitchEncoder, self).__init__()
43 | self.options = options
44 |
45 | embedding_sizes = set()
46 | for key, encoder in self.options.items():
47 | self.add_module(key, encoder)
48 | embedding_sizes.add(encoder.embedding_size)
49 |
50 | if len(embedding_sizes) != 1:
51 | raise RuntimeError(
52 | f"Alternative encoders have inconsistent output size: {embedding_sizes}"
53 | )
54 |
55 | self._embedding_size = list(embedding_sizes)[0]
56 |
57 | def disable_gradients_if_required(self):
58 | for encoder in self.options.values():
59 | encoder.disable_gradients_if_required()
60 |
61 | @property
62 | def trainable(self) -> bool:
63 | return any(encoder.trainable for encoder in self.options.values())
64 |
65 | @property
66 | def embedding_size(self) -> int:
67 | return self._embedding_size
68 |
69 | @classmethod
70 | def switch_collate_fn(
71 | cls, batch: List[Any], encoder_collates: Dict[str, CollateFnType]
72 | ) -> TensorInterchange:
73 | switch_batches = dict((key, []) for key in encoder_collates.keys())
74 | switch_ordering = dict((key, []) for key in encoder_collates.keys())
75 | for original_id, record in enumerate(batch):
76 | record_encoder = cls.encoder_selection(record)
77 | switch_batches[record_encoder].append(record)
78 | switch_ordering[record_encoder].append(original_id)
79 |
80 | switch_batches = {
81 | key: encoder_collates[key](batch)
82 | for key, batch in switch_batches.items()
83 | if len(batch) > 0
84 | }
85 |
86 | return {"ordering": switch_ordering, "batches": switch_batches}
87 |
88 | def get_collate_fn(self) -> CollateFnType:
89 | return partial(
90 | self.__class__.switch_collate_fn,
91 | encoder_collates=dict(
92 | (key, encoder.get_collate_fn()) for key, encoder in self.options.items()
93 | ),
94 | )
95 |
96 | @classmethod
97 | def extract_meta(cls, batch: List[Any]) -> List[dict]:
98 | meta = []
99 | for record in batch:
100 | meta.append(
101 | {
102 | "encoder": cls.encoder_selection(record),
103 | }
104 | )
105 | return meta
106 |
107 | def forward(self, batch: TensorInterchange) -> Tensor:
108 | switch_ordering: dict = batch["ordering"]
109 | switch_batches: dict = batch["batches"]
110 | embeddings = []
111 | ordering = []
112 | for key, batch in switch_batches.items():
113 | embeddings.append(self.options[key].forward(batch))
114 | ordering += switch_ordering[key]
115 | ordering_tensor: Tensor = inverse_permutation(torch.tensor(ordering))
116 | embeddings_tensor: Tensor = torch.cat(embeddings)
117 | ordering_tensor = move_to_device(ordering_tensor, embeddings_tensor.device)
118 | return embeddings_tensor[ordering_tensor]
119 |
120 | def save(self, output_path: str):
121 | encoders = {}
122 | for key, encoder in self.options.items():
123 | encoders[key] = save_class_import(encoder)
124 | encoder_path = os.path.join(output_path, key)
125 | os.makedirs(encoder_path, exist_ok=True)
126 | encoder.save(encoder_path)
127 |
128 | with open(os.path.join(output_path, "config.json"), "w") as f_out:
129 | json.dump(
130 | {
131 | "encoders": encoders,
132 | },
133 | f_out,
134 | indent=2,
135 | )
136 |
137 | @classmethod
138 | def load(cls, input_path: str) -> "Encoder":
139 | with open(os.path.join(input_path, "config.json")) as f_in:
140 | config = json.load(f_in)
141 |
142 | encoders = {}
143 | encoders_params: dict = config["encoders"]
144 | for key, class_params in encoders_params.items():
145 | encoder_path = os.path.join(input_path, key)
146 | encoder_class = restore_class(class_params)
147 | encoders[key] = encoder_class.load(encoder_path)
148 |
149 | return cls(options=encoders)
150 |
--------------------------------------------------------------------------------
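`SwitchEncoder.encoder_selection` is the only method a subclass has to supply: it maps each raw record to the key of one of the configured encoders, while the collate function and `forward` take care of grouping the batch per encoder and restoring the original order. A minimal sketch, assuming records are dicts with a "type" field and that `image_encoder` and `text_encoder` are already-built `Encoder` instances with equal embedding sizes (all illustrative, not code from the repository):

    from typing import Any

    from quaterion_models.encoders import SwitchEncoder


    class ModalitySwitchEncoder(SwitchEncoder):
        @classmethod
        def encoder_selection(cls, record: Any) -> str:
            # Route each record to the encoder registered under this key
            return "image" if record["type"] == "image" else "text"


    # encoder = ModalitySwitchEncoder(options={"image": image_encoder, "text": text_encoder})
    # collate_fn = encoder.get_collate_fn()
    # embeddings = encoder(collate_fn(mixed_batch))  # rows follow the original batch order

--------------------------------------------------------------------------------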
/quaterion_models/heads/__init__.py:
--------------------------------------------------------------------------------
1 | from quaterion_models.heads.empty_head import EmptyHead
2 | from quaterion_models.heads.encoder_head import EncoderHead
3 | from quaterion_models.heads.gated_head import GatedHead
4 | from quaterion_models.heads.sequential_head import SequentialHead
5 | from quaterion_models.heads.skip_connection_head import SkipConnectionHead
6 | from quaterion_models.heads.softmax_head import SoftmaxEmbeddingsHead
7 | from quaterion_models.heads.stacked_projection_head import StackedProjectionHead
8 | from quaterion_models.heads.switch_head import SwitchHead
9 | from quaterion_models.heads.widening_head import WideningHead
10 |
--------------------------------------------------------------------------------
/quaterion_models/heads/empty_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from quaterion_models.heads.encoder_head import EncoderHead
4 |
5 |
6 | class EmptyHead(EncoderHead):
7 | """Returns input embeddings without any modification"""
8 |
9 | def __init__(self, input_embedding_size: int, dropout: float = 0.0):
10 | super(EmptyHead, self).__init__(input_embedding_size, dropout=dropout)
11 |
12 | @property
13 | def output_size(self) -> int:
14 | return self.input_embedding_size
15 |
16 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
17 | return input_vectors
18 |
--------------------------------------------------------------------------------
/quaterion_models/heads/encoder_head.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Any, Dict, List
4 |
5 | import torch
6 | from torch import nn
7 |
8 |
9 | class EncoderHead(nn.Module):
 10 |     """Base class for the final layer of a fine-tuned model.
 11 |     EncoderHead is the only trainable component in the case of frozen encoders.
12 |
13 | Args:
14 | input_embedding_size:
15 | Size of the concatenated embedding, obtained from combination of all configured encoders
16 | dropout:
17 | Probability of Dropout. If `dropout > 0.`, apply dropout layer
18 | on embeddings before applying head layer transformations
19 | **kwargs:
20 | """
21 |
22 | def __init__(self, input_embedding_size: int, dropout: float = 0.0, **kwargs):
23 | super(EncoderHead, self).__init__()
24 | self.input_embedding_size = input_embedding_size
25 | self._dropout_prob = dropout
26 | self.dropout = (
27 | torch.nn.Dropout(p=dropout) if dropout > 0.0 else torch.nn.Identity()
28 | )
29 |
30 | @property
31 | def output_size(self) -> int:
32 | raise NotImplementedError()
33 |
34 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
35 | """Apply head-specific transformations to the embeddings tensor.
 36 |         Called as part of the `forward` method, after generic wrappers such as dropout have been applied.
37 |
38 | Args:
39 | input_vectors: Concatenated embeddings of all encoders. Shape: (batch_size, self.input_embedding_size)
40 |
41 | Returns:
42 | Final embeddings for a batch: (batch_size, self.output_size)
43 | """
44 | raise NotImplementedError()
45 |
46 | def forward(
47 | self, input_vectors: torch.Tensor, meta: List[Any] = None
48 | ) -> torch.Tensor:
49 | return self.transform(self.dropout(input_vectors))
50 |
51 | def get_config_dict(self) -> Dict[str, Any]:
52 | """Constructs savable params dict
53 |
54 | Returns:
55 | Serializable parameters for __init__ of the Module
56 | """
57 | return {
58 | "input_embedding_size": self.input_embedding_size,
59 | "dropout": self._dropout_prob,
60 | }
61 |
62 | def save(self, output_path):
63 | torch.save(self.state_dict(), os.path.join(output_path, "weights.bin"))
64 |
65 | with open(os.path.join(output_path, "config.json"), "w") as f_out:
66 | json.dump(self.get_config_dict(), f_out, indent=2)
67 |
68 | @classmethod
69 | def load(cls, input_path: str) -> "EncoderHead":
70 | with open(os.path.join(input_path, "config.json")) as f_in:
71 | config = json.load(f_in)
72 | model = cls(**config)
73 | model.load_state_dict(torch.load(os.path.join(input_path, "weights.bin")))
74 | return model
75 |
--------------------------------------------------------------------------------
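
A custom head only needs to implement `output_size` and `transform`; `forward`, dropout handling
and the JSON/weights persistence come from `EncoderHead`. A minimal sketch (the `L2NormalizeHead`
class below is illustrative and not part of the package):

    import torch
    from quaterion_models.heads.encoder_head import EncoderHead

    class L2NormalizeHead(EncoderHead):
        """Hypothetical head that L2-normalizes the concatenated embeddings."""

        @property
        def output_size(self) -> int:
            return self.input_embedding_size

        def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
            return torch.nn.functional.normalize(input_vectors, p=2, dim=-1)

    head = L2NormalizeHead(input_embedding_size=8, dropout=0.1)
    out = head(torch.rand(4, 8))  # dropout is applied before `transform`
    assert out.shape == (4, head.output_size)
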
/quaterion_models/heads/gated_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import Parameter
3 |
4 | from quaterion_models.heads.encoder_head import EncoderHead
5 |
6 |
7 | class GatedHead(EncoderHead):
 8 |     """Disables or amplifies some components of the input embedding.
 9 | 
 10 |     This layer has a minimal number of trainable parameters and is suitable even for small
 11 |     training sets.
12 | """
13 |
14 | def __init__(self, input_embedding_size: int, dropout: float = 0.0):
15 | super(GatedHead, self).__init__(input_embedding_size, dropout=dropout)
16 | self.gates = Parameter(torch.Tensor(self.input_embedding_size))
17 | self.reset_parameters()
18 |
19 | @property
20 | def output_size(self) -> int:
21 | return self.input_embedding_size
22 |
23 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
24 | """
25 |
26 | Args:
27 | input_vectors: shape: (batch_size, vector_size)
28 |
29 | Returns:
30 | Tensor: (batch_size, vector_size)
31 | """
32 | return input_vectors * torch.tanh(self.gates)
33 |
34 | def reset_parameters(self) -> None:
35 | torch.nn.init.constant_(
36 | self.gates, 2.0
37 | ) # 2. ensures that all vector components are enabled by default
38 |
--------------------------------------------------------------------------------
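
Usage sketch: the gates are initialized to 2.0 and tanh(2.0) is roughly 0.96, so a freshly
constructed GatedHead passes embeddings through almost unchanged; training can then learn to
suppress or amplify individual components:

    import torch
    from quaterion_models.heads.gated_head import GatedHead

    head = GatedHead(input_embedding_size=4)
    x = torch.ones(2, 4)
    y = head(x)  # shape: (2, 4)
    assert torch.allclose(y, torch.tanh(torch.tensor(2.0)) * x)
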
/quaterion_models/heads/sequential_head.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Any, Dict, Iterator, Union
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | from quaterion_models.heads.encoder_head import EncoderHead
9 |
10 |
11 | class SequentialHead(EncoderHead):
 12 |     """A `torch.nn.Sequential`-like head layer to which you can freely add any layers.
13 |
14 | Unlike `torch.nn.Sequential`, it also expects the output size to be passed
15 | as a required keyword-only argument. It is required because some loss functions
16 | may need this information.
17 |
18 | Args:
19 | args: Any sequence of `torch.nn.Module` instances. See `torch.nn.Sequential` for more info.
20 | output_size: Final output dimension from this head.
21 |
22 | Examples::
23 |
24 | head = SequentialHead(
25 | nn.Linear(10, 20),
26 | nn.ReLU(),
27 | nn.Linear(20, 30),
28 | output_size=30
29 | )
30 | """
31 |
32 | def __init__(self, *args, output_size: int):
33 | super().__init__(None)
34 | self._sequential = nn.Sequential(*args)
35 | self._output_size = output_size
36 |
37 | @property
38 | def output_size(self) -> int:
39 | return self._output_size
40 |
41 | def forward(self, input_vectors: torch.Tensor, meta=None) -> torch.Tensor:
42 | """Forward pass for this head layer.
43 |
44 | Just like `torch.nn.Sequential`, it passes the input to the first module,
45 | and the output of each module is input to the next. The final output of this head layer is
46 | the output from the last module in the sequence.
47 |
48 | Args:
49 | input_vectors: Batch of input vectors.
50 | meta: Optional metadata for this batch.
51 |
52 | Returns:
53 | Output from the last module in the sequence.
54 | """
55 | return self._sequential.forward(input_vectors)
56 |
57 | def append(self, module: nn.Module) -> "SequentialHead":
58 | self._sequential.append(module)
59 | return self
60 |
61 | def get_config_dict(self) -> Dict[str, Any]:
62 | """Constructs savable params dict
63 |
64 | Returns:
65 | Serializable parameters for __init__ of the Module
66 | """
67 | return {
68 | "output_size": self._output_size,
69 | }
70 |
71 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
72 | return input_vectors
73 |
74 | def save(self, output_path):
75 | torch.save(self._sequential, os.path.join(output_path, "weights.bin"))
76 |
77 | with open(os.path.join(output_path, "config.json"), "w") as f_out:
78 | json.dump(self.get_config_dict(), f_out, indent=2)
79 |
80 | @classmethod
81 | def load(cls, input_path: str) -> "EncoderHead":
82 | with open(os.path.join(input_path, "config.json")) as f_in:
83 | config = json.load(f_in)
84 | sequential = torch.load(
85 | os.path.join(input_path, "weights.bin"), map_location="cpu"
86 | )
87 | model = cls(*sequential, **config)
88 | return model
89 |
90 | def __getitem__(self, idx) -> Union[nn.Sequential, nn.Module]:
91 | return self._sequential[idx]
92 |
93 | def __delitem__(self, idx: Union[slice, int]) -> None:
 94 |         return self._sequential.__delitem__(idx)
95 |
96 | def __setitem__(self, idx: int, module: nn.Module) -> None:
97 | return self._sequential.__setitem__(idx, module)
98 |
99 | def __len__(self) -> int:
100 | return self._sequential.__len__()
101 |
102 | def __dir__(self):
103 | return self._sequential.__dir__()
104 |
105 | def __iter__(self) -> Iterator[nn.Module]:
106 | return self._sequential.__iter__()
107 |
--------------------------------------------------------------------------------
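
Usage sketch continuing the class docstring example: the head behaves like the wrapped
`nn.Sequential` in the forward pass, while `output_size` stays available for components that
need it:

    import torch
    import torch.nn as nn
    from quaterion_models.heads.sequential_head import SequentialHead

    head = SequentialHead(
        nn.Linear(10, 20),
        nn.ReLU(),
        nn.Linear(20, 30),
        output_size=30,
    )
    out = head(torch.rand(5, 10))
    assert out.shape == (5, head.output_size)  # (5, 30)
    assert len(head) == 3                      # container behavior is delegated to nn.Sequential
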
/quaterion_models/heads/skip_connection_head.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | import torch
4 | from torch.nn import Parameter
5 |
6 | from quaterion_models.heads.encoder_head import EncoderHead
7 |
8 |
9 | class SkipConnectionHead(EncoderHead):
10 | """Unites the idea of gated head and residual connections.
11 |
12 | Schema:
13 | .. code-block:: none
14 |
15 | ├──────────┐
16 | ┌───────┴───────┐ │
17 | │ Skip-Dropout │ │
18 | └───────┬───────┘ │
19 | ┌───────┴───────┐ │
20 | │ Linear │ │
21 | └───────┬───────┘ │
22 | ┌───────┴───────┐ │
23 | │ Gated │ │
24 | └───────┬───────┘ │
25 | + <────────┘
26 | │
27 |
28 | Args:
29 | input_embedding_size:
30 | Size of the concatenated embedding, obtained from combination of all configured encoders
31 | dropout:
32 | Probability of Dropout. If `dropout > 0.`, apply dropout layer
33 | on embeddings before applying head layer transformations
34 | skip_dropout:
35 | Additional dropout, applied to the trainable branch only.
 36 |             Using additional dropout makes it possible to avoid modifying the original embedding.
37 | n_layers:
38 | Number of gated residual blocks stacked on top of each other.
39 | """
40 |
41 | def __init__(
42 | self,
43 | input_embedding_size: int,
44 | dropout: float = 0.0,
45 | skip_dropout: float = 0.0,
46 | n_layers: int = 1,
47 | ):
48 | super().__init__(input_embedding_size, dropout=dropout)
49 | for i in range(n_layers):
50 | self.register_parameter(
51 | f"gates_{i}", Parameter(torch.Tensor(self.input_embedding_size))
52 | )
53 | setattr(
54 | self,
55 | f"fc_{i}",
56 | torch.nn.Linear(input_embedding_size, input_embedding_size),
57 | )
58 |
59 | self.skip_dropout = (
60 | torch.nn.Dropout(p=skip_dropout)
61 | if skip_dropout > 0.0
62 | else torch.nn.Identity()
63 | )
64 |
65 | self._skip_dropout_prob = skip_dropout
66 | self._n_layers = n_layers
67 | self.reset_parameters()
68 |
69 | @property
70 | def output_size(self) -> int:
71 | return self.input_embedding_size
72 |
73 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
74 | """
75 | Args:
76 | input_vectors: shape: (batch_size, input_embedding_size)
77 |
78 | Returns:
79 | torch.Tensor: shape: (batch_size, input_embedding_size)
80 | """
81 | for i in range(self._n_layers):
82 | fc = getattr(self, f"fc_{i}")
83 | gates = getattr(self, f"gates_{i}")
84 | input_vectors = (
85 | fc(self.skip_dropout(input_vectors)) * torch.sigmoid(gates)
86 | + input_vectors
87 | )
88 |
89 | return input_vectors
90 |
91 | def reset_parameters(self) -> None:
92 | for i in range(self._n_layers):
93 | torch.nn.init.constant_(
94 | getattr(self, f"gates_{i}"), -4.0
 95 |             )  # -4. ensures that the trainable branch is almost disabled by default, so the head starts close to identity
96 |
97 | def get_config_dict(self) -> Dict[str, Any]:
98 | config = super().get_config_dict()
99 | config.update(
100 | {"skip_dropout": self._skip_dropout_prob, "n_layers": self._n_layers}
101 | )
102 | return config
103 |
--------------------------------------------------------------------------------
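
Usage sketch: since the gates are initialized to -4.0 (sigmoid(-4) is roughly 0.018), a freshly
constructed head is close to an identity mapping, and the gated residual branch is learned during
training:

    import torch
    from quaterion_models.heads.skip_connection_head import SkipConnectionHead

    head = SkipConnectionHead(input_embedding_size=8, n_layers=2)
    x = torch.rand(3, 8)
    y = head(x)
    assert y.shape == x.shape
    print((y - x).abs().max())  # small right after initialization
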
/quaterion_models/heads/softmax_head.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | import torch
4 | from torch.nn import Linear
5 |
6 | from quaterion_models.heads.encoder_head import EncoderHead
7 |
8 |
9 | class SoftmaxEmbeddingsHead(EncoderHead):
 10 |     """Provides a concatenation of independent softmax embedding groups as a head layer.
11 |
12 | Useful for deriving embedding confidence.
13 |
14 | Schema:
15 | .. code-block:: none
16 |
17 | ┌──────────────────┐
18 | │ Encoder │
19 | └──┬───────────┬───┘
20 | │ │
21 | │ │
22 | ┌───────────┼───────────┼───────────┐
23 | │ │ │ │
24 | │ ┌─────────▼──┐ ┌──▼─────────┐ │
25 | │ │ Linear │ ... │ Linear │ │
26 | │ └─────┬──────┘ └─────┬──────┘ │
27 | │ │ │ │
28 | │ ┌─────┴──────┐ ┌─────┴──────┐ │
29 | │ │ SoftMax │ ... │ SoftMax │ │
30 | │ └─────┬──────┘ └─────┬──────┘ │
31 | │ │ │ │
32 | │ ┌────┴──────────────────┴─────┐ │
33 | │ │ Concatenation │ │
34 | │ └──────────────┬──────────────┘ │
35 | │ │ │
36 | └─────────────────┼─────────────────┘
37 | │
38 | ▼
39 |
40 | """
41 |
42 | def __init__(
43 | self,
44 | output_groups: int,
45 | output_size_per_group: int,
46 | input_embedding_size: int,
47 | dropout: float = 0.0,
48 | **kwargs
49 | ):
50 | super().__init__(input_embedding_size, dropout=dropout, **kwargs)
51 |
52 | self.output_groups = output_groups
53 | self.output_size_per_group = output_size_per_group
54 | self.projectors = []
55 |
56 | self.projection_layer = Linear(
57 | self.input_embedding_size, self.output_size_per_group * self.output_groups
58 | )
59 |
60 | @property
61 | def output_size(self) -> int:
62 | return self.output_size_per_group * self.output_groups
63 |
64 | def transform(self, input_vectors: torch.Tensor):
65 | """
66 |
67 | Args:
68 | input_vectors: shape: (batch_size, ..., input_dim)
69 |
70 | Returns:
71 | shape (batch_size, ..., self.output_size_per_group * self.output_groups)
72 | """
73 |
74 | # shape: [batch_size, ..., self.output_size_per_group * self.output_groups]
75 | projection = self.projection_layer(input_vectors)
76 |
77 | init_shape = projection.shape
78 | groups_shape = list(init_shape)
79 | groups_shape[-1] = self.output_groups
80 | groups_shape.append(-1)
81 |
82 | # shape: [batch_size, ..., self.output_groups, self.output_size_per_group]
83 | grouped_projection = torch.softmax(projection.view(*groups_shape), dim=-1)
84 |
85 | return grouped_projection.view(init_shape)
86 |
87 | def get_config_dict(self) -> Dict[str, Any]:
88 | config = super().get_config_dict()
89 | config.update(
90 | {
91 | "output_groups": self.output_groups,
92 | "output_size_per_group": self.output_size_per_group,
93 | }
94 | )
95 |
96 | return config
97 |
--------------------------------------------------------------------------------
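
Usage sketch: every group of `output_size_per_group` values in the output is an independent
softmax distribution, so the values within each group sum to 1; this is what makes the head
useful for confidence-style outputs:

    import torch
    from quaterion_models.heads.softmax_head import SoftmaxEmbeddingsHead

    head = SoftmaxEmbeddingsHead(
        output_groups=2, output_size_per_group=4, input_embedding_size=8
    )
    out = head(torch.rand(3, 8))  # shape: (3, 2 * 4)
    groups = out.view(3, 2, 4)    # split back into the 2 groups
    assert torch.allclose(groups.sum(dim=-1), torch.ones(3, 2))
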
/quaterion_models/heads/stacked_projection_head.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from quaterion_models.heads.encoder_head import EncoderHead
7 | from quaterion_models.modules import ActivationFromFnName
8 |
9 |
10 | class StackedProjectionHead(EncoderHead):
11 | """Stacks any number of projection layers with specified output sizes.
12 |
13 | Args:
14 | input_embedding_size: Dimensionality of the input to this stack of layers.
15 | output_sizes: List of output sizes for each one of the layers stacked.
16 | activation_fn: Name of the activation function to apply between the layers stacked.
17 | Must be an attribute of `torch.nn.functional` and defaults to `relu`.
18 | dropout: Probability of Dropout. If `dropout > 0.`, apply dropout layer
19 | on embeddings before applying head layer transformations
20 | """
21 |
22 | def __init__(
23 | self,
24 | input_embedding_size: int,
25 | output_sizes: List[int],
26 | activation_fn: str = "relu",
27 | dropout: float = 0.0,
28 | ):
29 | super(StackedProjectionHead, self).__init__(
30 | input_embedding_size, dropout=dropout
31 | )
32 | self._output_sizes = output_sizes
33 | self._activation_fn = activation_fn
34 |
35 | modules = [nn.Linear(input_embedding_size, self._output_sizes[0])]
36 |
37 | if len(self._output_sizes) > 1:
38 | for i in range(1, len(self._output_sizes)):
39 | modules.extend(
40 | [
41 | ActivationFromFnName(self._activation_fn),
42 | nn.Linear(self._output_sizes[i - 1], self._output_sizes[i]),
43 | ]
44 | )
45 |
46 | self._stack = nn.Sequential(*modules)
47 |
48 | @property
49 | def output_size(self) -> int:
50 | return self._output_sizes[-1]
51 |
52 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
53 | return self._stack(input_vectors)
54 |
55 | def get_config_dict(self) -> Dict[str, Any]:
56 | config = super().get_config_dict()
57 | config.update(
58 | {"output_sizes": self._output_sizes, "activation_fn": self._activation_fn}
59 | )
60 |
61 | return config
62 |
--------------------------------------------------------------------------------
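
Usage sketch: each entry of `output_sizes` becomes one linear layer, with the named activation in
between, so the head below maps 32 -> 64 -> 16:

    import torch
    from quaterion_models.heads.stacked_projection_head import StackedProjectionHead

    head = StackedProjectionHead(
        input_embedding_size=32,
        output_sizes=[64, 16],
        activation_fn="relu",
    )
    out = head(torch.rand(4, 32))
    assert out.shape == (4, head.output_size)  # (4, 16)
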
/quaterion_models/heads/switch_head.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Any, Dict, List
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 | from quaterion_models.encoders.switch_encoder import inverse_permutation
9 | from quaterion_models.heads import EncoderHead
10 | from quaterion_models.utils import restore_class, save_class_import
11 |
12 |
13 | class SwitchHead(EncoderHead):
 14 |     """Head which switches between different sub-heads based on the batch metadata.
 15 |     Useful in combination with the SwitchEncoder for training multimodal models.
16 |
17 | Args:
18 | options: dict of heads. Choice of head is based on the metadata key
19 | """
20 |
21 | def __init__(
22 | self, options: Dict[str, EncoderHead], input_embedding_size: int, **kwargs
23 | ):
24 | super().__init__(input_embedding_size, dropout=0.0, **kwargs)
25 | self._heads = options
26 | for key, head in self._heads.items():
27 | self.add_module(key, head)
28 |
29 | @property
30 | def output_size(self) -> int:
31 | return next(iter(self._heads.values())).output_size
32 |
33 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
34 | pass
35 |
36 | def forward(
37 | self, input_vectors: torch.Tensor, meta: List[Any] = None
38 | ) -> torch.Tensor:
39 | # Shape: [batch_size x input_embedding_size]
40 | dropout_input = self.dropout(input_vectors)
41 |
42 | switch_mask = dict((key, []) for key in self._heads.keys())
43 | switch_ordering = dict((key, []) for key in self._heads.keys())
44 | for i, m in enumerate(meta):
45 | switch_ordering[m["encoder"]].append(i)
46 | for key, mask in switch_mask.items():
47 | mask.append(int(key == m["encoder"]))
48 |
49 | head_outputs = []
50 | ordering = []
51 | for key, mask in switch_mask.items():
52 | # Shape: [batch_size]
53 | mask = torch.tensor(mask, dtype=torch.bool, device=input_vectors.device)
54 | head_outputs.append(self._heads[key].transform(dropout_input[mask]))
55 | ordering += switch_ordering[key]
56 |
57 | ordering_tensor: Tensor = inverse_permutation(torch.tensor(ordering))
58 | # Shape: [batch_size x output_size]
59 | return torch.cat(head_outputs)[ordering_tensor]
60 |
61 | def get_config_dict(self) -> Dict[str, Any]:
62 | """Constructs savable params dict
63 |
64 | Returns:
65 | Serializable parameters for __init__ of the Module
66 | """
67 | return {
68 | "heads": {
69 | k: {"config": v.get_config_dict(), "class": save_class_import(v)}
70 | for k, v in self._heads.items()
71 | },
72 | "input_embedding_size": self.input_embedding_size,
73 | }
74 |
75 | @classmethod
76 | def load(cls, input_path: str) -> "EncoderHead":
77 | with open(os.path.join(input_path, "config.json")) as f_in:
78 | config = json.load(f_in)
79 |
80 | heads_config = config.pop("heads")
81 |
82 | heads = dict(
83 | (key, restore_class(head_config["class"])(**head_config["config"]))
84 | for key, head_config in heads_config.items()
85 | )
86 |
87 | model = cls(options=heads, **config)
88 | model.load_state_dict(torch.load(os.path.join(input_path, "weights.bin")))
89 | return model
90 |
--------------------------------------------------------------------------------
/quaterion_models/heads/widening_head.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | from quaterion_models.heads.stacked_projection_head import StackedProjectionHead
4 |
5 |
6 | class WideningHead(StackedProjectionHead):
 7 |     """Implements a narrow-wide-narrow architecture.
 8 | 
 9 |     Widens the dimensionality by a factor of `expansion_factor` and narrows it back down to
 10 |     `input_embedding_size`.
11 |
12 | Args:
13 | input_embedding_size: Dimensionality of the input to this head layer.
14 | expansion_factor: Widen the dimensionality by this factor in the intermediate layer.
15 | activation_fn: Name of the activation function to apply after the intermediate layer.
16 | Must be an attribute of `torch.nn.functional` and defaults to `relu`.
17 | dropout: Probability of Dropout. If `dropout > 0.`, apply dropout layer
18 | on embeddings before applying head layer transformations
19 | """
20 |
21 | def __init__(
22 | self,
23 | input_embedding_size: int,
24 | expansion_factor: float = 4.0,
25 | activation_fn: str = "relu",
26 | dropout: float = 0.0,
27 | **kwargs
28 | ):
29 | self._expansion_factor = expansion_factor
30 | self._activation_fn = activation_fn
31 | super(WideningHead, self).__init__(
32 | input_embedding_size=input_embedding_size,
33 | output_sizes=[
34 | int(input_embedding_size * expansion_factor),
35 | input_embedding_size,
36 | ],
37 | activation_fn=activation_fn,
38 | dropout=dropout,
39 | )
40 |
41 | def get_config_dict(self) -> Dict[str, Any]:
42 | config = super().get_config_dict()
43 | config.update(
44 | {
45 | "expansion_factor": self._expansion_factor,
46 | }
47 | )
48 | return config
49 |
--------------------------------------------------------------------------------
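
Usage sketch: with the defaults, a 16-dimensional embedding is widened to 16 * 4 = 64 in the
intermediate layer and narrowed back to 16, so the output size equals the input size:

    import torch
    from quaterion_models.heads.widening_head import WideningHead

    head = WideningHead(input_embedding_size=16)  # expansion_factor defaults to 4.0
    out = head(torch.rand(2, 16))
    assert out.shape == (2, 16)
    assert head.output_size == 16
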
/quaterion_models/model.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import json
4 | import os
5 | from functools import partial
6 | from typing import Any, Callable, Dict, List, Type, Union
7 |
8 | import numpy as np
9 | import torch
10 | from torch import nn
11 |
12 | from quaterion_models.encoders import Encoder
13 | from quaterion_models.heads.encoder_head import EncoderHead
14 | from quaterion_models.types import CollateFnType, MetaExtractorFnType, TensorInterchange
15 | from quaterion_models.utils.classes import restore_class, save_class_import
16 | from quaterion_models.utils.meta import merge_meta
17 | from quaterion_models.utils.tensors import move_to_device
18 |
19 | DEFAULT_ENCODER_KEY = "default"
20 |
21 |
22 | class SimilarityModel(nn.Module):
23 | """Main class which contains encoder models with the head layer."""
24 |
25 | def __init__(self, encoders: Union[Encoder, Dict[str, Encoder]], head: EncoderHead):
26 | super().__init__()
27 |
28 | if not isinstance(encoders, dict):
29 | self.encoders: Dict[str, Encoder] = {DEFAULT_ENCODER_KEY: encoders}
30 | else:
31 | self.encoders: Dict[str, Encoder] = encoders
32 |
33 | for key, encoder in self.encoders.items():
34 | encoder.disable_gradients_if_required()
35 | self.add_module(key, encoder)
36 |
37 | self.head = head
38 |
39 | @classmethod
40 | def collate_fn(
41 | cls,
42 | batch: List[dict],
43 | encoders_collate_fns: Dict[str, CollateFnType],
44 | meta_extractors: Dict[str, MetaExtractorFnType],
45 | ) -> TensorInterchange:
46 | """Construct batches for all encoders
47 |
48 | Args:
 49 |             batch: List of raw input records
50 | encoders_collate_fns: Dict (or single) of collate functions associated with encoders
51 | meta_extractors: Dict (or single) of meta extractor functions associated with encoders
52 |
53 | """
54 | data = dict(
55 | (key, collate_fn(batch)) for key, collate_fn in encoders_collate_fns.items()
56 | )
57 | meta = dict(
58 | (key, meta_extractor_fn(batch))
59 | for key, meta_extractor_fn in meta_extractors.items()
60 | )
61 | return {
62 | "data": data,
63 | "meta": merge_meta(meta),
64 | }
65 |
66 | @classmethod
67 | def get_encoders_output_size(cls, encoders: Union[Encoder, Dict[str, Encoder]]):
68 | """Calculate total output size of given encoders
69 |
70 | Args:
 71 |             encoders: Single encoder or dict of named encoders
72 |
73 | """
74 | encoders = encoders.values() if isinstance(encoders, dict) else [encoders]
75 | total_size = 0
76 | for encoder in encoders:
77 | total_size += encoder.embedding_size
78 | return total_size
79 |
80 | def train(self, mode: bool = True):
81 | super().train(mode)
82 |
83 | def get_collate_fn(self) -> Callable:
84 | """Construct a function to convert input data into neural network inputs
85 |
86 | Returns:
 87 |             Collate function that converts raw input data into neural network inputs
88 | """
89 | return partial(
90 | SimilarityModel.collate_fn,
91 | encoders_collate_fns=dict(
92 | (key, encoder.get_collate_fn())
93 | for key, encoder in self.encoders.items()
94 | ),
95 | meta_extractors=dict(
96 | (key, encoder.get_meta_extractor())
97 | for key, encoder in self.encoders.items()
98 | ),
99 | )
100 |
101 | # -------------------------------------------
102 | # ---------- Inference methods --------------
103 | # -------------------------------------------
104 |
105 | def encode(
106 | self, inputs: Union[List[Any], Any], batch_size=32, to_numpy=True
107 | ) -> Union[torch.Tensor, np.ndarray]:
108 | """Encode data in batches
109 |
110 | Args:
111 | inputs: list of input data to encode
 112 |             batch_size: Number of records encoded per forward pass
 113 |             to_numpy: If true, return the result as a numpy array, otherwise as a torch.Tensor
114 |
115 | Returns:
116 | Numpy array or torch.Tensor of shape (input_size, embedding_size)
117 | """
118 | self.eval()
119 | device = next(self.parameters(), torch.tensor(0)).device
120 | collate_fn = self.get_collate_fn()
121 |
122 | input_was_list = True
123 | if not isinstance(inputs, list):
124 | input_was_list = False
125 | inputs = [inputs]
126 |
127 | all_embeddings = []
128 |
129 | for start_index in range(0, len(inputs), batch_size):
130 | input_batch = [
131 | inputs[i]
132 | for i in range(start_index, min(len(inputs), start_index + batch_size))
133 | ]
134 | features = collate_fn(input_batch)
135 | features = move_to_device(features, device)
136 |
137 | with torch.no_grad():
138 | embeddings = self.forward(features)
139 | embeddings = embeddings.detach()
140 | if to_numpy:
141 | embeddings = embeddings.cpu().numpy()
142 | all_embeddings.append(embeddings)
143 |
144 | if to_numpy:
145 | all_embeddings = np.concatenate(all_embeddings, axis=0)
146 | else:
147 | all_embeddings = torch.cat(all_embeddings, dim=0)
148 |
149 | if not input_was_list:
150 | all_embeddings = all_embeddings.squeeze()
151 |
152 | if to_numpy:
153 | all_embeddings = np.atleast_2d(all_embeddings)
154 | else:
155 | all_embeddings = torch.atleast_2d(all_embeddings)
156 |
157 | return all_embeddings
158 |
159 | def forward(self, batch):
160 | embeddings = [
161 | (key, encoder.forward(batch["data"][key]))
162 | for key, encoder in self.encoders.items()
163 | ]
164 |
165 | meta = batch["meta"]
166 |
 167 |         # Order embeddings by key name to ensure reproducibility
168 | embeddings = sorted(embeddings, key=lambda x: x[0])
169 |
170 | # Only embedding tensors of shape [batch_size x encoder_output_size]
171 | embedding_tensors = [embedding[1] for embedding in embeddings]
172 |
173 | # Shape: [batch_size x sum( encoders_emb_sizes )]
174 | joined_embeddings = torch.cat(embedding_tensors, dim=1)
175 |
176 | # Shape: [batch_size x output_emb_size]
177 | result_embedding = self.head(joined_embeddings, meta=meta)
178 |
179 | return result_embedding
180 |
181 | # -------------------------------------------
182 | # ---------- Persistence methods ------------
183 | # -------------------------------------------
184 |
185 | @classmethod
186 | def _get_head_path(cls, directory: str):
187 | return os.path.join(directory, "head")
188 |
189 | @classmethod
190 | def _get_encoders_path(cls, directory: str):
191 | return os.path.join(directory, "encoders")
192 |
193 | def save(self, output_path: str):
194 | head_path = self._get_head_path(output_path)
195 | os.makedirs(head_path, exist_ok=True)
196 | self.head.save(head_path)
197 |
198 | head_config = save_class_import(self.head)
199 |
200 | encoders_path = self._get_encoders_path(output_path)
201 | os.makedirs(encoders_path, exist_ok=True)
202 |
203 | encoders_config = []
204 |
205 | for encoder_key, encoder in self.encoders.items():
206 | encoder_path = os.path.join(encoders_path, encoder_key)
207 | os.mkdir(encoder_path)
208 | encoder.save(encoder_path)
209 | encoders_config.append({"key": encoder_key, **save_class_import(encoder)})
210 |
211 | with open(os.path.join(output_path, "config.json"), "w") as f_out:
212 | json.dump(
213 | {"encoders": encoders_config, "head": head_config}, f_out, indent=2
214 | )
215 |
216 | @classmethod
217 | def load(cls, input_path: str) -> SimilarityModel:
218 | with open(os.path.join(input_path, "config.json")) as f_in:
219 | config = json.load(f_in)
220 |
221 | head_config = config["head"]
222 | head_class: Type[EncoderHead] = restore_class(head_config)
223 | head_path = cls._get_head_path(input_path)
224 | head = head_class.load(head_path)
225 |
226 | encoders: Union[Encoder, Dict[str, Encoder]] = {}
227 | encoders_path = cls._get_encoders_path(input_path)
228 | encoders_config = config["encoders"]
229 |
230 | for encoder_params in encoders_config:
231 | encoder_key = encoder_params["key"]
232 | encoder_class = restore_class(encoder_params)
233 | encoders[encoder_key] = encoder_class.load(
234 | os.path.join(encoders_path, encoder_key)
235 | )
236 |
237 | return cls(head=head, encoders=encoders)
238 |
239 |
240 | # In this framework, the terms Metric Learning and Similarity Learning are considered synonymous.
241 | # However, the word "Metric" overlaps with other concepts in model training.
242 | # In addition, the semantics of the word "Similarity" are simpler.
243 | # It better reflects the basic idea of this training approach.
244 | # That's why we prefer to use Similarity over Metric.
245 | MetricModel = SimilarityModel
246 |
--------------------------------------------------------------------------------
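
A minimal end-to-end sketch of assembling and using SimilarityModel. The ToyEncoder below is an
illustrative stand-in (similar to the stub encoders used in the tests further down) and is not
part of the package:

    from typing import Any, List

    import torch
    from torch import Tensor

    from quaterion_models import SimilarityModel
    from quaterion_models.encoders import Encoder
    from quaterion_models.heads import EmptyHead
    from quaterion_models.types import CollateFnType, TensorInterchange

    class ToyEncoder(Encoder):
        """Toy encoder that treats each record as a ready-made 5-dim vector."""

        @property
        def trainable(self) -> bool:
            return False

        @property
        def embedding_size(self) -> int:
            return 5

        @classmethod
        def collate_fn(cls, batch: List[Any]) -> TensorInterchange:
            return torch.stack(batch)

        def get_collate_fn(self) -> CollateFnType:
            return self.__class__.collate_fn

        def forward(self, batch: TensorInterchange) -> Tensor:
            return batch

        def save(self, output_path: str):
            pass

        @classmethod
        def load(cls, input_path: str) -> "Encoder":
            return cls()

    encoder = ToyEncoder()
    model = SimilarityModel(encoders=encoder, head=EmptyHead(encoder.embedding_size))
    embeddings = model.encode([torch.rand(5) for _ in range(7)], batch_size=4)
    assert embeddings.shape == (7, 5)  # numpy array by default (to_numpy=True)
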
/quaterion_models/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from quaterion_models.modules.simple import ActivationFromFnName
2 |
--------------------------------------------------------------------------------
/quaterion_models/modules/simple.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ActivationFromFnName(nn.Module):
7 | """Simple module constructed from function name to be used in `nn.Sequential`
8 |
9 | Construct a `nn.Module` that applies the specified activation function to inputs
10 |
11 | Args:
12 | activation_fn: Name of the activation function to apply to input.
13 | Must be an attribute of `torch.nn.functional`.
14 | """
15 |
16 | def __init__(self, activation_fn: str):
17 | super().__init__()
18 | self._activation_fn = activation_fn
19 |
20 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
21 | return vars(F)[self._activation_fn](inputs)
22 |
--------------------------------------------------------------------------------
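
Usage sketch: the module looks up the named function in `torch.nn.functional` at call time, so
any activation available there can be used by name:

    import torch
    from quaterion_models.modules import ActivationFromFnName

    relu = ActivationFromFnName("relu")
    print(relu(torch.tensor([-1.0, 0.5, 2.0])))  # tensor([0.0000, 0.5000, 2.0000])

    gelu = ActivationFromFnName("gelu")          # any attribute of torch.nn.functional works
    print(gelu(torch.tensor([1.0])))
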
/quaterion_models/types/__init__.py:
--------------------------------------------------------------------------------
1 | # TODO: Split them into separate files once we have more of them.
2 | from typing import Any, Callable, Dict, List, Tuple, Union
3 |
4 | from torch import Tensor
5 |
6 | #:
7 | TensorInterchange = Union[
8 | Tensor,
9 | Tuple[Tensor],
10 | List[Tensor],
11 | Dict[str, Tensor],
12 | Dict[str, dict],
13 | Any,
14 | ]
15 | #:
16 | CollateFnType = Callable[[List[Any]], TensorInterchange]
17 | MetaExtractorFnType = Callable[[List[Any]], List[dict]]
18 |
--------------------------------------------------------------------------------
/quaterion_models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from quaterion_models.utils.classes import restore_class, save_class_import
2 | from quaterion_models.utils.tensors import move_to_device
3 |
--------------------------------------------------------------------------------
/quaterion_models/utils/classes.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import logging
3 | from typing import Any
4 |
5 |
6 | def save_class_import(obj: Any) -> dict:
7 | """
 8 |     Serializes information about an object's class
9 | :param obj:
10 | :return: serializable class info
11 | """
12 | if obj.__module__ == "__main__":
13 | logging.warning(
 14 |             f"Class {obj.__class__.__qualname__} is defined in the same file as the training loop."
15 | f" It won't be possible to load it properly later."
16 | )
17 |
18 | return {"module": obj.__module__, "class": obj.__class__.__qualname__}
19 |
20 |
21 | def restore_class(data: dict) -> Any:
22 | """
23 | :param data: name of module and class
24 | :return: Class
25 | """
26 | module = data["module"]
27 | class_name = data["class"]
28 | module = importlib.import_module(module)
29 | return getattr(module, class_name)
30 |
--------------------------------------------------------------------------------
/quaterion_models/utils/meta.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 |
3 |
4 | def merge_meta(meta: Dict[str, list]) -> List[dict]:
5 | """Merge meta information from multiple encoders into one
6 |
7 | Combine meta from all encoders
8 | Example: Encoder 1 meta: `[{"a": 1}, {"a": 2}]`, Encoder 2 meta: `[{"b": 3}, {"b": 4}]`
9 | Result: `[{"a": 1, "b": 3}, {"a": 2, "b": 4}]`
10 |
11 | Args:
12 | meta: meta information to merge
13 | """
14 | aggregated = None
15 | for key, encoder_meta in meta.items():
16 | if aggregated is None:
17 | aggregated = encoder_meta
18 | else:
 19 |             for i in range(len(encoder_meta)):
20 | aggregated[i].update(encoder_meta[i])
21 | return aggregated
22 |
--------------------------------------------------------------------------------
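
Usage sketch reproducing the docstring example, assuming two encoders named "text" and "image":

    from quaterion_models.utils.meta import merge_meta

    meta = {
        "text": [{"a": 1}, {"a": 2}],
        "image": [{"b": 3}, {"b": 4}],
    }
    print(merge_meta(meta))  # [{'a': 1, 'b': 3}, {'a': 2, 'b': 4}]
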
/quaterion_models/utils/tensors.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def move_to_device(obj, device: torch.device):
5 | """
6 | Given a structure (possibly) containing Tensors on the CPU,
7 | move all the Tensors to the specified GPU (or do nothing, if they should be on the CPU).
8 | """
9 |
10 | if device == torch.device("cpu"):
11 | return obj
12 | elif isinstance(obj, torch.Tensor):
13 | return obj.cuda(device)
14 | elif isinstance(obj, dict):
15 | return {key: move_to_device(value, device) for key, value in obj.items()}
16 | elif isinstance(obj, list):
17 | return [move_to_device(item, device) for item in obj]
18 | elif isinstance(obj, tuple) and hasattr(obj, "_fields"):
19 | # This is the best way to detect a NamedTuple, it turns out.
20 | return obj.__class__(*(move_to_device(item, device) for item in obj))
21 | elif isinstance(obj, tuple):
22 | return tuple(move_to_device(item, device) for item in obj)
23 | else:
24 | return obj
25 |
--------------------------------------------------------------------------------
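
Usage sketch: nested containers are traversed recursively, tensors are moved, and everything else
is returned unchanged:

    import torch
    from quaterion_models.utils.tensors import move_to_device

    batch = {
        "data": {"text": torch.rand(2, 3)},
        "meta": [{"encoder": "text"}],
    }
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    moved = move_to_device(batch, device)
    print(moved["data"]["text"].device)
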
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/qdrant/quaterion-models/f5f97a27f8d9b35194fe316592bd70c4852048f7/tests/__init__.py
--------------------------------------------------------------------------------
/tests/encoders/test_fasttext_encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 |
4 | from quaterion_models.encoders.extras.fasttext_encoder import FasttextEncoder
5 |
6 |
7 | def test_fasttext_encoder():
8 | from gensim.models import FastText
9 |
10 | demo_texts = [
11 | ["aaa", "bbb", "ccc", "ddd", "123"],
12 | ["aaa", "bbb", "ccc", "aaa", "123"],
13 | ["aaa", "bbb", "ccc", "bbb", "123"],
14 | ["aaa", "bbb", "ccc", "ccc", "123"],
15 | ["aaa", "bbb", "ccc", "123", "123"],
16 | ]
17 |
18 | model = FastText(
19 | vector_size=10, window=1, min_count=0, min_n=0, max_n=0, sg=1, bucket=1_000
20 | )
21 | epochs = 10
22 | model.build_vocab(demo_texts)
23 | model.train(demo_texts, epochs=epochs, total_examples=len(demo_texts))
24 | tempdir = tempfile.TemporaryDirectory()
25 |
26 | model_path = os.path.join(tempdir.name, "fasttext.model")
27 | model.wv.save(model_path, separately=["vectors_ngrams", "vectors", "vectors_vocab"])
28 |
29 | encoder = FasttextEncoder(
30 | model_path=model_path, on_disk=False, aggregations=["avg", "max"]
31 | )
32 |
33 | assert encoder.embedding_size == 20
34 |
35 | embeddings = encoder.forward([["aaa", "123"], ["aaa", "ccc"]])
36 |
37 | assert embeddings.shape[0] == 2
38 | assert embeddings.shape[1] == 20
39 |
40 | encoder.named_parameters()
41 |
42 | empty_embeddings = encoder.forward([[], []])
43 |
44 | assert empty_embeddings.shape[0] == 2
45 | assert empty_embeddings.shape[1] == 20
46 |
--------------------------------------------------------------------------------
/tests/encoders/test_switch_encoder.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from abc import ABC
3 | from typing import Any, List
4 |
5 | import numpy as np
6 | import torch
7 | from torch import Tensor
8 |
9 | from quaterion_models.encoders import Encoder
10 | from quaterion_models.encoders.switch_encoder import SwitchEncoder
11 | from quaterion_models.heads.empty_head import EmptyHead
12 | from quaterion_models.model import SimilarityModel
13 | from quaterion_models.types import CollateFnType, TensorInterchange
14 |
15 |
16 | class CustomEncoder(Encoder, ABC):
17 | @property
18 | def trainable(self) -> bool:
19 | return False
20 |
21 | @property
22 | def embedding_size(self) -> int:
23 | return 3
24 |
25 | def save(self, output_path: str):
26 | pass
27 |
28 | @classmethod
29 | def load(cls, input_path: str) -> "Encoder":
30 | return cls()
31 |
32 | @classmethod
33 | def collate_fn(cls, batch: List[Any]) -> TensorInterchange:
34 | return [torch.zeros(1) for _ in batch]
35 |
36 | def get_collate_fn(self) -> CollateFnType:
37 | return self.__class__.collate_fn
38 |
39 |
40 | class EncoderA(CustomEncoder):
41 | def forward(self, batch: TensorInterchange) -> Tensor:
42 | return torch.zeros(len(batch), self.embedding_size)
43 |
44 |
45 | class EncoderB(CustomEncoder):
46 | def forward(self, batch: TensorInterchange) -> Tensor:
47 | return torch.ones(len(batch), self.embedding_size)
48 |
49 |
50 | class CustomSwitchEncoder(SwitchEncoder):
51 | @classmethod
52 | def encoder_selection(cls, record: Any) -> str:
53 | if record == "zeros":
54 | return "a"
55 | if record == "ones":
56 | return "b"
57 |
58 |
59 | def test_forward():
60 | encoder = CustomSwitchEncoder({"a": EncoderA(), "b": EncoderB()})
61 |
62 | model = SimilarityModel(encoders=encoder, head=EmptyHead(encoder.embedding_size))
63 | batch = ["zeros", "zeros", "ones", "ones", "zeros", "zeros", "ones"]
64 |
65 | res = model.encode(batch)
66 |
67 | assert res.shape[0] == len(batch)
68 | assert all(res[:, 0] == np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]))
69 |
70 |
71 | def test_meta():
72 | encoder = CustomSwitchEncoder({"a": EncoderA(), "b": EncoderB()})
73 | meta_extractor = encoder.get_meta_extractor()
74 |
75 | batch = ["zeros", "zeros", "ones", "ones", "zeros", "zeros", "ones"]
76 |
77 | meta = meta_extractor(batch)
78 | print("")
79 | print(meta)
80 |
81 | assert meta[0]["encoder"] == "a"
82 | assert meta[1]["encoder"] == "a"
83 | assert meta[2]["encoder"] == "b"
84 | assert meta[3]["encoder"] == "b"
85 | assert meta[4]["encoder"] == "a"
86 | assert meta[5]["encoder"] == "a"
87 | assert meta[6]["encoder"] == "b"
88 |
89 |
90 | def test_save_and_load():
91 | encoder = CustomSwitchEncoder({"a": EncoderA(), "b": EncoderB()})
92 |
93 | tempdir = tempfile.TemporaryDirectory()
94 | model = SimilarityModel(encoders=encoder, head=EmptyHead(encoder.embedding_size))
95 | model.save(tempdir.name)
96 | model = model.load(tempdir.name)
97 |
98 | batch = ["zeros", "zeros", "ones", "ones", "zeros", "zeros", "ones"]
99 | res = model.encode(batch)
100 |
101 | assert res.shape[0] == len(batch)
102 | assert all(res[:, 0] == np.array([0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]))
103 |
--------------------------------------------------------------------------------
/tests/encoders/test_switch_head.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 |
3 | import torch
4 |
5 | from quaterion_models.heads import EncoderHead
6 | from quaterion_models.heads.switch_head import SwitchHead
7 |
8 |
9 | class FakeHeadA(EncoderHead):
10 | @property
11 | def output_size(self) -> int:
12 | return 3
13 |
14 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
15 | return input_vectors + torch.tensor(
16 | [[0, 0, 1] for _ in range(input_vectors.shape[0])]
17 | )
18 |
19 |
20 | class FakeHeadB(EncoderHead):
21 | @property
22 | def output_size(self) -> int:
23 | return 3
24 |
25 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
26 | return input_vectors + torch.tensor(
27 | [[0, 2, 0] for _ in range(input_vectors.shape[0])]
28 | )
29 |
30 |
31 | def test_save_and_load():
32 | head = SwitchHead(
33 | options={"a": FakeHeadA(3), "b": FakeHeadB(3)},
34 | input_embedding_size=3,
35 | )
36 |
37 | temp_dir = tempfile.TemporaryDirectory()
38 | head.save(temp_dir.name)
39 |
40 | loaded_head = SwitchHead.load(temp_dir.name)
41 |
42 | print(loaded_head)
43 |
44 |
45 | def test_forward():
46 | head = SwitchHead(
47 | options={"a": FakeHeadA(3), "b": FakeHeadB(3)},
48 | input_embedding_size=3,
49 | )
50 |
51 | batch = torch.tensor(
52 | [
53 | [1, 0, 0],
54 | [2, 0, 0],
55 | [3, 0, 0],
56 | [4, 0, 0],
57 | [5, 0, 0],
58 | ]
59 | )
60 |
61 | meta = [
62 | {"encoder": "b"},
63 | {"encoder": "a"},
64 | {"encoder": "b"},
65 | {"encoder": "b"},
66 | {"encoder": "a"},
67 | ]
68 |
69 | res = head.forward(batch, meta)
70 |
71 | assert res.shape[0] == batch.shape[0]
72 | assert res.shape[1] == 3
73 |
74 | assert res[0][0] == 1
75 | assert res[1][0] == 2
76 | assert res[2][0] == 3
77 | assert res[3][0] == 4
78 | assert res[4][0] == 5
79 |
80 | assert res[0][1] == 2
81 | assert res[1][1] == 0
82 | assert res[2][1] == 2
83 | assert res[3][1] == 2
84 | assert res[4][1] == 0
85 |
86 | assert res[0][2] == 0
87 | assert res[1][2] == 1
88 | assert res[2][2] == 0
89 | assert res[3][2] == 0
90 | assert res[4][2] == 1
91 |
--------------------------------------------------------------------------------
/tests/heads/test_head.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from typing import Any, List
4 |
5 | import pytest
6 | import torch
7 | import torch.nn as nn
8 | from torch import Tensor
9 |
10 | from quaterion_models import SimilarityModel
11 | from quaterion_models.encoders import Encoder
12 | from quaterion_models.heads import (
13 | EmptyHead,
14 | GatedHead,
15 | SequentialHead,
16 | SkipConnectionHead,
17 | SoftmaxEmbeddingsHead,
18 | WideningHead,
19 | )
20 | from quaterion_models.types import CollateFnType, TensorInterchange
21 |
22 | BATCH_SIZE = 3
23 | INPUT_EMBEDDING_SIZE = 5
24 | HIDDEN_EMBEDDING_SIZE = 7
25 | OUTPUT_EMBEDDING_SIZE = 10
26 |
27 | _HEADS = (
28 | SequentialHead(
29 | nn.Linear(INPUT_EMBEDDING_SIZE, HIDDEN_EMBEDDING_SIZE),
30 | nn.ReLU(),
31 | nn.Linear(HIDDEN_EMBEDDING_SIZE, OUTPUT_EMBEDDING_SIZE),
32 | output_size=OUTPUT_EMBEDDING_SIZE,
33 | ),
34 | GatedHead(INPUT_EMBEDDING_SIZE),
35 | WideningHead(INPUT_EMBEDDING_SIZE),
36 | SkipConnectionHead(INPUT_EMBEDDING_SIZE),
37 | EmptyHead(INPUT_EMBEDDING_SIZE),
38 | SoftmaxEmbeddingsHead(
39 | output_groups=2,
40 | output_size_per_group=OUTPUT_EMBEDDING_SIZE,
41 | input_embedding_size=INPUT_EMBEDDING_SIZE,
42 | ),
43 | )
44 | HEADS = {head_.__class__.__name__: head_ for head_ in _HEADS}
45 |
46 |
47 | class CustomEncoder(Encoder):
48 | def save(self, output_path: str):
49 | pass
50 |
51 | @classmethod
52 | def load(cls, input_path: str) -> "Encoder":
53 | return cls()
54 |
55 | @property
56 | def trainable(self) -> bool:
57 | return False
58 |
59 | @property
60 | def embedding_size(self) -> int:
61 | return INPUT_EMBEDDING_SIZE
62 |
63 | @classmethod
64 | def collate_fn(cls, batch: List[Any]):
65 | return torch.stack(batch)
66 |
67 | def get_collate_fn(self) -> CollateFnType:
68 | return self.__class__.collate_fn
69 |
70 | def forward(self, batch: TensorInterchange) -> Tensor:
71 | return batch
72 |
73 |
74 | @pytest.mark.parametrize("head", HEADS.values(), ids=HEADS.keys())
75 | def test_save_and_load(head):
76 | encoder = CustomEncoder()
77 |
78 | model = SimilarityModel(encoders=encoder, head=head)
79 | tempdir = tempfile.TemporaryDirectory()
80 |
81 | model.save(tempdir.name)
82 |
83 | config_path = os.path.join(tempdir.name, "config.json")
84 |
85 | assert os.path.exists(config_path)
86 |
87 | batch = torch.rand(BATCH_SIZE, INPUT_EMBEDDING_SIZE)
88 | origin_output = model.encode(batch, to_numpy=False)
89 |
90 | loaded_model = SimilarityModel.load(tempdir.name)
91 |
92 | assert model.encoders.keys() == loaded_model.encoders.keys()
93 | assert [type(encoder) for encoder in model.encoders.values()] == [
94 | type(encoder) for encoder in loaded_model.encoders.values()
95 | ]
96 |
97 | assert type(model.head) == type(loaded_model.head)
98 | assert torch.allclose(origin_output, loaded_model.encode(batch, to_numpy=False))
99 |
100 |
101 | @pytest.mark.parametrize("head", HEADS.values(), ids=HEADS.keys())
102 | def test_forward_shape(head):
103 | batch = torch.rand(BATCH_SIZE, INPUT_EMBEDDING_SIZE)
104 | res = head.forward(batch)
105 | assert res.shape == (BATCH_SIZE, head.output_size)
106 |
--------------------------------------------------------------------------------
/tests/test_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from multiprocessing import Pool
4 | from typing import Any, List
5 |
6 | import torch
7 | from torch import Tensor
8 |
9 | from quaterion_models import SimilarityModel
10 | from quaterion_models.encoders import Encoder
11 | from quaterion_models.heads import EmptyHead, EncoderHead
12 | from quaterion_models.types import CollateFnType, TensorInterchange
13 |
14 | TEST_EMB_SIZE = 5
15 |
16 |
17 | class LambdaHead(EncoderHead):
18 | def __init__(self):
19 | super(LambdaHead, self).__init__(TEST_EMB_SIZE)
20 | self.my_lambda = lambda x: "hello"
21 |
22 | @property
23 | def output_size(self) -> int:
24 | return 0
25 |
26 | def transform(self, input_vectors: torch.Tensor) -> torch.Tensor:
27 | return input_vectors
28 |
29 |
30 | class CustomEncoder(Encoder):
31 | def save(self, output_path: str):
32 | pass
33 |
34 | @classmethod
35 | def load(cls, input_path: str) -> "Encoder":
36 | return cls()
37 |
38 | def __init__(self):
39 | super().__init__()
40 | self.unpickable = lambda x: x + 1
41 |
42 | @property
43 | def trainable(self) -> bool:
44 | return False
45 |
46 | @property
47 | def embedding_size(self) -> int:
48 | return TEST_EMB_SIZE
49 |
50 | @classmethod
51 | def collate_fn(cls, batch: List[Any]):
52 | return torch.rand(len(batch), TEST_EMB_SIZE)
53 |
54 | def get_collate_fn(self) -> CollateFnType:
55 | return self.__class__.collate_fn
56 |
57 | def forward(self, batch: TensorInterchange) -> Tensor:
58 | return batch
59 |
60 |
61 | class Tst:
62 | def __init__(self, foo):
63 | self.foo = foo
64 |
65 | def bar(self, x):
66 | return self.foo(x)
67 |
68 |
69 | def test_get_collate_fn():
70 | model = SimilarityModel(encoders={"test": CustomEncoder()}, head=LambdaHead())
71 |
72 | tester = Tst(foo=model.get_collate_fn())
73 |
74 | with Pool(2) as pool:
75 | res = pool.map(tester.bar, [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10]])
76 |
77 | assert len(res) == 4
78 |
79 | first_batch = res[0]
80 |
81 | assert "test" in first_batch["data"]
82 |
83 | tensor = first_batch["data"]["test"]
84 |
85 | assert tensor.shape == (3, TEST_EMB_SIZE)
86 |
87 |
88 | def test_model_save_and_load():
89 | tempdir = tempfile.TemporaryDirectory()
90 | model = SimilarityModel(encoders={"test": CustomEncoder()}, head=EmptyHead(100))
91 |
92 | model.save(tempdir.name)
93 |
94 | config_path = os.path.join(tempdir.name, "config.json")
95 |
96 | assert os.path.exists(config_path)
97 |
98 | loaded_model = SimilarityModel.load(tempdir.name)
99 |
100 | assert model.encoders.keys() == loaded_model.encoders.keys()
101 | assert [type(encoder) for encoder in model.encoders.values()] == [
102 | type(encoder) for encoder in loaded_model.encoders.values()
103 | ]
104 |
105 | assert type(model.head) == type(loaded_model.head)
106 |
--------------------------------------------------------------------------------
/tests/utils/test_classes.py:
--------------------------------------------------------------------------------
1 | from quaterion_models.utils import restore_class, save_class_import
2 |
3 |
4 | def test_restore_class():
5 | model_class = restore_class({"module": "torch.nn", "class": "Linear"})
6 |
7 | from torch.nn import Linear
8 |
9 | model: Linear = model_class(10, 10)
10 |
11 | assert model.out_features == 10
12 |
13 |
14 | def test_save_class_import():
15 | from collections import Counter
16 |
17 | class_serialized = save_class_import(Counter())
18 |
19 | assert class_serialized["class"] == "Counter"
20 | assert class_serialized["module"] == "collections"
21 |
--------------------------------------------------------------------------------